manager.go 45 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
7127812791280128112821283128412851286128712881289129012911292129312941295129612971298129913001301130213031304130513061307130813091310131113121313131413151316131713181319132013211322
  1. package user
  2. import (
  3. "database/sql"
  4. "encoding/json"
  5. "errors"
  6. "fmt"
  7. _ "github.com/mattn/go-sqlite3" // SQLite driver
  8. "github.com/stripe/stripe-go/v74"
  9. "golang.org/x/crypto/bcrypt"
  10. "heckel.io/ntfy/log"
  11. "heckel.io/ntfy/util"
  12. "net/netip"
  13. "strings"
  14. "sync"
  15. "time"
  16. )
  // Internal identifiers, limits and timings used by the Manager.
  const (
  	// Generated identifiers: prefix plus total length handed to util.RandomStringPrefix
  	userIDPrefix    = "u_"
  	userIDLength    = 12
  	tierIDPrefix    = "ti_"
  	tierIDLength    = 8
  	syncTopicPrefix = "st_"
  	syncTopicLength = 16

  	// Access tokens
  	tokenPrefix   = "tk_"
  	tokenLength   = 32
  	tokenMaxCount = 10 // Only keep this many tokens in the table per user

  	// Authentication: bcrypt hash compared against on failed logins so that they take
  	// about as long as a real password check. Cost should match DefaultUserPasswordBcryptCost
  	userAuthIntentionalSlowDownHash = "$2a$10$YFCQvqQDwIIwnJM1xkAYOeih0dg17UVGanaTStnrSzC8NCWxcLDwy"

  	// Account removal: grace period between MarkUserRemoved and the hard delete
  	userHardDeleteAfterDuration = 7 * 24 * time.Hour
  )
  // Defaults for Manager settings that may be overridden through configuration.
  const (
  	// DefaultUserPasswordBcryptCost is the default bcrypt cost for password hashes.
  	DefaultUserPasswordBcryptCost = 10

  	// DefaultUserStatsQueueWriterInterval is the default period at which queued user stats
  	// and token access updates are flushed to the database (see NewManager).
  	DefaultUserStatsQueueWriterInterval = 33 * time.Second
  )
  // errNoTokenProvided is returned by token operations when the token argument is empty.
  var errNoTokenProvided = errors.New("no token provided")

  // errTopicOwnedByOthers signals that a topic is owned by a different user.
  var errTopicOwnedByOthers = errors.New("topic owned by others")

  // errNoRows is returned when a query that should yield a row yields none.
  var errNoRows = errors.New("no rows found")
  40. // Manager-related queries
  41. const (
  42. createTablesQueriesNoTx = `
  43. CREATE TABLE IF NOT EXISTS tier (
  44. id TEXT PRIMARY KEY,
  45. code TEXT NOT NULL,
  46. name TEXT NOT NULL,
  47. messages_limit INT NOT NULL,
  48. messages_expiry_duration INT NOT NULL,
  49. emails_limit INT NOT NULL,
  50. reservations_limit INT NOT NULL,
  51. attachment_file_size_limit INT NOT NULL,
  52. attachment_total_size_limit INT NOT NULL,
  53. attachment_expiry_duration INT NOT NULL,
  54. attachment_bandwidth_limit INT NOT NULL,
  55. stripe_price_id TEXT
  56. );
  57. CREATE UNIQUE INDEX idx_tier_code ON tier (code);
  58. CREATE UNIQUE INDEX idx_tier_price_id ON tier (stripe_price_id);
  59. CREATE TABLE IF NOT EXISTS user (
  60. id TEXT PRIMARY KEY,
  61. tier_id TEXT,
  62. user TEXT NOT NULL,
  63. pass TEXT NOT NULL,
  64. role TEXT CHECK (role IN ('anonymous', 'admin', 'user')) NOT NULL,
  65. prefs JSON NOT NULL DEFAULT '{}',
  66. sync_topic TEXT NOT NULL,
  67. stats_messages INT NOT NULL DEFAULT (0),
  68. stats_emails INT NOT NULL DEFAULT (0),
  69. stripe_customer_id TEXT,
  70. stripe_subscription_id TEXT,
  71. stripe_subscription_status TEXT,
  72. stripe_subscription_paid_until INT,
  73. stripe_subscription_cancel_at INT,
  74. created INT NOT NULL,
  75. deleted INT,
  76. FOREIGN KEY (tier_id) REFERENCES tier (id)
  77. );
  78. CREATE UNIQUE INDEX idx_user ON user (user);
  79. CREATE UNIQUE INDEX idx_user_stripe_customer_id ON user (stripe_customer_id);
  80. CREATE UNIQUE INDEX idx_user_stripe_subscription_id ON user (stripe_subscription_id);
  81. CREATE TABLE IF NOT EXISTS user_access (
  82. user_id TEXT NOT NULL,
  83. topic TEXT NOT NULL,
  84. read INT NOT NULL,
  85. write INT NOT NULL,
  86. owner_user_id INT,
  87. PRIMARY KEY (user_id, topic),
  88. FOREIGN KEY (user_id) REFERENCES user (id) ON DELETE CASCADE,
  89. FOREIGN KEY (owner_user_id) REFERENCES user (id) ON DELETE CASCADE
  90. );
  91. CREATE TABLE IF NOT EXISTS user_token (
  92. user_id TEXT NOT NULL,
  93. token TEXT NOT NULL,
  94. label TEXT NOT NULL,
  95. last_access INT NOT NULL,
  96. last_origin TEXT NOT NULL,
  97. expires INT NOT NULL,
  98. PRIMARY KEY (user_id, token),
  99. FOREIGN KEY (user_id) REFERENCES user (id) ON DELETE CASCADE
  100. );
  101. CREATE TABLE IF NOT EXISTS schemaVersion (
  102. id INT PRIMARY KEY,
  103. version INT NOT NULL
  104. );
  105. INSERT INTO user (id, user, pass, role, sync_topic, created)
  106. VALUES ('` + everyoneID + `', '*', '', 'anonymous', '', UNIXEPOCH())
  107. ON CONFLICT (id) DO NOTHING;
  108. `
  109. createTablesQueries = `BEGIN; ` + createTablesQueriesNoTx + ` COMMIT;`
  110. builtinStartupQueries = `
  111. PRAGMA foreign_keys = ON;
  112. `
  113. selectUserByIDQuery = `
  114. SELECT u.id, u.user, u.pass, u.role, u.prefs, u.sync_topic, u.stats_messages, u.stats_emails, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, deleted, t.id, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.attachment_bandwidth_limit, t.stripe_price_id
  115. FROM user u
  116. LEFT JOIN tier t on t.id = u.tier_id
  117. WHERE u.id = ?
  118. `
  119. selectUserByNameQuery = `
  120. SELECT u.id, u.user, u.pass, u.role, u.prefs, u.sync_topic, u.stats_messages, u.stats_emails, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, deleted, t.id, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.attachment_bandwidth_limit, t.stripe_price_id
  121. FROM user u
  122. LEFT JOIN tier t on t.id = u.tier_id
  123. WHERE user = ?
  124. `
  125. selectUserByTokenQuery = `
  126. SELECT u.id, u.user, u.pass, u.role, u.prefs, u.sync_topic, u.stats_messages, u.stats_emails, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, deleted, t.id, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.attachment_bandwidth_limit, t.stripe_price_id
  127. FROM user u
  128. JOIN user_token tk on u.id = tk.user_id
  129. LEFT JOIN tier t on t.id = u.tier_id
  130. WHERE tk.token = ? AND (tk.expires = 0 OR tk.expires >= ?)
  131. `
  132. selectUserByStripeCustomerIDQuery = `
  133. SELECT u.id, u.user, u.pass, u.role, u.prefs, u.sync_topic, u.stats_messages, u.stats_emails, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, deleted, t.id, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.attachment_bandwidth_limit, t.stripe_price_id
  134. FROM user u
  135. LEFT JOIN tier t on t.id = u.tier_id
  136. WHERE u.stripe_customer_id = ?
  137. `
  138. selectTopicPermsQuery = `
  139. SELECT read, write
  140. FROM user_access a
  141. JOIN user u ON u.id = a.user_id
  142. WHERE (u.user = ? OR u.user = ?) AND ? LIKE a.topic
  143. ORDER BY u.user DESC
  144. `
  145. insertUserQuery = `
  146. INSERT INTO user (id, user, pass, role, sync_topic, created)
  147. VALUES (?, ?, ?, ?, ?, ?)
  148. `
  149. selectUsernamesQuery = `
  150. SELECT user
  151. FROM user
  152. ORDER BY
  153. CASE role
  154. WHEN 'admin' THEN 1
  155. WHEN 'anonymous' THEN 3
  156. ELSE 2
  157. END, user
  158. `
  159. updateUserPassQuery = `UPDATE user SET pass = ? WHERE user = ?`
  160. updateUserRoleQuery = `UPDATE user SET role = ? WHERE user = ?`
  161. updateUserPrefsQuery = `UPDATE user SET prefs = ? WHERE user = ?`
  162. updateUserStatsQuery = `UPDATE user SET stats_messages = ?, stats_emails = ? WHERE id = ?`
  163. updateUserStatsResetAllQuery = `UPDATE user SET stats_messages = 0, stats_emails = 0`
  164. updateUserDeletedQuery = `UPDATE user SET deleted = ? WHERE id = ?`
  165. deleteUsersMarkedQuery = `DELETE FROM user WHERE deleted < ?`
  166. deleteUserQuery = `DELETE FROM user WHERE user = ?`
  167. upsertUserAccessQuery = `
  168. INSERT INTO user_access (user_id, topic, read, write, owner_user_id)
  169. VALUES ((SELECT id FROM user WHERE user = ?), ?, ?, ?, (SELECT IIF(?='',NULL,(SELECT id FROM user WHERE user=?))))
  170. ON CONFLICT (user_id, topic)
  171. DO UPDATE SET read=excluded.read, write=excluded.write, owner_user_id=excluded.owner_user_id
  172. `
  173. selectUserAccessQuery = `
  174. SELECT topic, read, write
  175. FROM user_access
  176. WHERE user_id = (SELECT id FROM user WHERE user = ?)
  177. ORDER BY write DESC, read DESC, topic
  178. `
  179. selectUserReservationsQuery = `
  180. SELECT a_user.topic, a_user.read, a_user.write, a_everyone.read AS everyone_read, a_everyone.write AS everyone_write
  181. FROM user_access a_user
  182. LEFT JOIN user_access a_everyone ON a_user.topic = a_everyone.topic AND a_everyone.user_id = (SELECT id FROM user WHERE user = ?)
  183. WHERE a_user.user_id = a_user.owner_user_id
  184. AND a_user.owner_user_id = (SELECT id FROM user WHERE user = ?)
  185. ORDER BY a_user.topic
  186. `
  187. selectUserReservationsCountQuery = `
  188. SELECT COUNT(*)
  189. FROM user_access
  190. WHERE user_id = owner_user_id AND owner_user_id = (SELECT id FROM user WHERE user = ?)
  191. `
  192. selectUserHasReservationQuery = `
  193. SELECT COUNT(*)
  194. FROM user_access
  195. WHERE user_id = owner_user_id
  196. AND owner_user_id = (SELECT id FROM user WHERE user = ?)
  197. AND topic = ?
  198. `
  199. selectOtherAccessCountQuery = `
  200. SELECT COUNT(*)
  201. FROM user_access
  202. WHERE (topic = ? OR ? LIKE topic)
  203. AND (owner_user_id IS NULL OR owner_user_id != (SELECT id FROM user WHERE user = ?))
  204. `
  205. deleteAllAccessQuery = `DELETE FROM user_access`
  206. deleteUserAccessQuery = `
  207. DELETE FROM user_access
  208. WHERE user_id = (SELECT id FROM user WHERE user = ?)
  209. OR owner_user_id = (SELECT id FROM user WHERE user = ?)
  210. `
  211. deleteTopicAccessQuery = `
  212. DELETE FROM user_access
  213. WHERE (user_id = (SELECT id FROM user WHERE user = ?) OR owner_user_id = (SELECT id FROM user WHERE user = ?))
  214. AND topic = ?
  215. `
  216. selectTokenCountQuery = `SELECT COUNT(*) FROM user_token WHERE user_id = ?`
  217. selectTokensQuery = `SELECT token, label, last_access, last_origin, expires FROM user_token WHERE user_id = ?`
  218. selectTokenQuery = `SELECT token, label, last_access, last_origin, expires FROM user_token WHERE user_id = ? AND token = ?`
  219. insertTokenQuery = `INSERT INTO user_token (user_id, token, label, last_access, last_origin, expires) VALUES (?, ?, ?, ?, ?, ?)`
  220. updateTokenExpiryQuery = `UPDATE user_token SET expires = ? WHERE user_id = ? AND token = ?`
  221. updateTokenLabelQuery = `UPDATE user_token SET label = ? WHERE user_id = ? AND token = ?`
  222. updateTokenLastAccessQuery = `UPDATE user_token SET last_access = ?, last_origin = ? WHERE token = ?`
  223. deleteTokenQuery = `DELETE FROM user_token WHERE user_id = ? AND token = ?`
  224. deleteAllTokenQuery = `DELETE FROM user_token WHERE user_id = ?`
  225. deleteExpiredTokensQuery = `DELETE FROM user_token WHERE expires > 0 AND expires < ?`
  226. deleteExcessTokensQuery = `
  227. DELETE FROM user_token
  228. WHERE (user_id, token) NOT IN (
  229. SELECT user_id, token
  230. FROM user_token
  231. WHERE user_id = ?
  232. ORDER BY expires DESC
  233. LIMIT ?
  234. )
  235. `
  236. insertTierQuery = `
  237. INSERT INTO tier (id, code, name, messages_limit, messages_expiry_duration, emails_limit, reservations_limit, attachment_file_size_limit, attachment_total_size_limit, attachment_expiry_duration, attachment_bandwidth_limit, stripe_price_id)
  238. VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
  239. `
  240. selectTiersQuery = `
  241. SELECT id, code, name, messages_limit, messages_expiry_duration, emails_limit, reservations_limit, attachment_file_size_limit, attachment_total_size_limit, attachment_expiry_duration, attachment_bandwidth_limit, stripe_price_id
  242. FROM tier
  243. `
  244. selectTierByCodeQuery = `
  245. SELECT id, code, name, messages_limit, messages_expiry_duration, emails_limit, reservations_limit, attachment_file_size_limit, attachment_total_size_limit, attachment_expiry_duration, attachment_bandwidth_limit, stripe_price_id
  246. FROM tier
  247. WHERE code = ?
  248. `
  249. selectTierByPriceIDQuery = `
  250. SELECT id, code, name, messages_limit, messages_expiry_duration, emails_limit, reservations_limit, attachment_file_size_limit, attachment_total_size_limit, attachment_expiry_duration, attachment_bandwidth_limit, stripe_price_id
  251. FROM tier
  252. WHERE stripe_price_id = ?
  253. `
  254. updateUserTierQuery = `UPDATE user SET tier_id = (SELECT id FROM tier WHERE code = ?) WHERE user = ?`
  255. deleteUserTierQuery = `UPDATE user SET tier_id = null WHERE user = ?`
  256. updateBillingQuery = `
  257. UPDATE user
  258. SET stripe_customer_id = ?, stripe_subscription_id = ?, stripe_subscription_status = ?, stripe_subscription_paid_until = ?, stripe_subscription_cancel_at = ?
  259. WHERE user = ?
  260. `
  261. )
  // Schema management queries
  const (
  	// currentSchemaVersion is the schema version this code writes and expects.
  	currentSchemaVersion = 2

  	// The schemaVersion table holds a single row with id 1.
  	selectSchemaVersionQuery = `SELECT version FROM schemaVersion WHERE id = 1`
  	insertSchemaVersion      = `INSERT INTO schemaVersion VALUES (1, ?)`
  	updateSchemaVersion      = `UPDATE schemaVersion SET version = ? WHERE id = 1`

  	// Migration 1 -> 2 (complex migration!). The old user table is renamed to user_old. Each old
  	// username is then copied into the new user table with a fresh ID and sync topic. Finally,
  	// the old access rows are copied into user_access and both old tables are dropped.
  	// The NoTx suffix means the statements contain no BEGIN/COMMIT of their own.
  	migrate1To2RenameUserTableQueryNoTx = `
  		ALTER TABLE user RENAME TO user_old;
  	`
  	migrate1To2SelectAllOldUsernamesNoTx = `SELECT user FROM user_old`
  	migrate1To2InsertUserNoTx            = `
  		INSERT INTO user (id, user, pass, role, sync_topic, created)
  		SELECT ?, user, pass, role, ?, UNIXEPOCH() FROM user_old WHERE user = ?
  	`
  	migrate1To2InsertFromOldTablesAndDropNoTx = `
  		INSERT INTO user_access (user_id, topic, read, write)
  		SELECT u.id, a.topic, a.read, a.write
  		FROM user u
  		JOIN access a ON u.user = a.user;
  		DROP TABLE access;
  		DROP TABLE user_old;
  	`
  )
  286. // Manager is an implementation of Manager. It stores users and access control list
  287. // in a SQLite database.
  288. type Manager struct {
  289. db *sql.DB
  290. defaultAccess Permission // Default permission if no ACL matches
  291. statsQueue map[string]*Stats // "Queue" to asynchronously write user stats to the database (UserID -> Stats)
  292. tokenQueue map[string]*TokenUpdate // "Queue" to asynchronously write token access stats to the database (Token ID -> TokenUpdate)
  293. bcryptCost int // Makes testing easier
  294. mu sync.Mutex
  295. }
  296. var _ Auther = (*Manager)(nil)
  297. // NewManager creates a new Manager instance
  298. func NewManager(filename, startupQueries string, defaultAccess Permission, bcryptCost int, queueWriterInterval time.Duration) (*Manager, error) {
  299. db, err := sql.Open("sqlite3", filename)
  300. if err != nil {
  301. return nil, err
  302. }
  303. if err := setupDB(db); err != nil {
  304. return nil, err
  305. }
  306. if err := runStartupQueries(db, startupQueries); err != nil {
  307. return nil, err
  308. }
  309. manager := &Manager{
  310. db: db,
  311. defaultAccess: defaultAccess,
  312. statsQueue: make(map[string]*Stats),
  313. tokenQueue: make(map[string]*TokenUpdate),
  314. bcryptCost: bcryptCost,
  315. }
  316. go manager.asyncQueueWriter(queueWriterInterval)
  317. return manager, nil
  318. }
  319. // Authenticate checks username and password and returns a User if correct, and the user has not been
  320. // marked as deleted. The method returns in constant-ish time, regardless of whether the user exists or
  321. // the password is correct or incorrect.
  322. func (a *Manager) Authenticate(username, password string) (*User, error) {
  323. if username == Everyone {
  324. return nil, ErrUnauthenticated
  325. }
  326. user, err := a.User(username)
  327. if err != nil {
  328. log.Trace("authentication of user %s failed (1): %s", username, err.Error())
  329. bcrypt.CompareHashAndPassword([]byte(userAuthIntentionalSlowDownHash), []byte("intentional slow-down to avoid timing attacks"))
  330. return nil, ErrUnauthenticated
  331. } else if user.Deleted {
  332. log.Trace("authentication of user %s failed (2): user marked deleted", username)
  333. bcrypt.CompareHashAndPassword([]byte(userAuthIntentionalSlowDownHash), []byte("intentional slow-down to avoid timing attacks"))
  334. return nil, ErrUnauthenticated
  335. } else if err := bcrypt.CompareHashAndPassword([]byte(user.Hash), []byte(password)); err != nil {
  336. log.Trace("authentication of user %s failed (3): %s", username, err.Error())
  337. return nil, ErrUnauthenticated
  338. }
  339. return user, nil
  340. }
  341. // AuthenticateToken checks if the token exists and returns the associated User if it does.
  342. // The method sets the User.Token value to the token that was used for authentication.
  343. func (a *Manager) AuthenticateToken(token string) (*User, error) {
  344. if len(token) != tokenLength {
  345. return nil, ErrUnauthenticated
  346. }
  347. user, err := a.userByToken(token)
  348. if err != nil {
  349. return nil, ErrUnauthenticated
  350. }
  351. user.Token = token
  352. return user, nil
  353. }
  354. // CreateToken generates a random token for the given user and returns it. The token expires
  355. // after a fixed duration unless ChangeToken is called. This function also prunes tokens for the
  356. // given user, if there are too many of them.
  357. func (a *Manager) CreateToken(userID, label string, expires time.Time, origin netip.Addr) (*Token, error) {
  358. token := util.RandomStringPrefix(tokenPrefix, tokenLength)
  359. tx, err := a.db.Begin()
  360. if err != nil {
  361. return nil, err
  362. }
  363. defer tx.Rollback()
  364. access := time.Now()
  365. if _, err := tx.Exec(insertTokenQuery, userID, token, label, access.Unix(), origin.String(), expires.Unix()); err != nil {
  366. return nil, err
  367. }
  368. rows, err := tx.Query(selectTokenCountQuery, userID)
  369. if err != nil {
  370. return nil, err
  371. }
  372. defer rows.Close()
  373. if !rows.Next() {
  374. return nil, errNoRows
  375. }
  376. var tokenCount int
  377. if err := rows.Scan(&tokenCount); err != nil {
  378. return nil, err
  379. }
  380. if tokenCount >= tokenMaxCount {
  381. // This pruning logic is done in two queries for efficiency. The SELECT above is a lookup
  382. // on two indices, whereas the query below is a full table scan.
  383. if _, err := tx.Exec(deleteExcessTokensQuery, userID, tokenMaxCount); err != nil {
  384. return nil, err
  385. }
  386. }
  387. if err := tx.Commit(); err != nil {
  388. return nil, err
  389. }
  390. return &Token{
  391. Value: token,
  392. Label: label,
  393. LastAccess: access,
  394. LastOrigin: origin,
  395. Expires: expires,
  396. }, nil
  397. }
  398. // Tokens returns all existing tokens for the user with the given user ID
  399. func (a *Manager) Tokens(userID string) ([]*Token, error) {
  400. rows, err := a.db.Query(selectTokensQuery, userID)
  401. if err != nil {
  402. return nil, err
  403. }
  404. defer rows.Close()
  405. tokens := make([]*Token, 0)
  406. for {
  407. token, err := a.readToken(rows)
  408. if err == ErrTokenNotFound {
  409. break
  410. } else if err != nil {
  411. return nil, err
  412. }
  413. tokens = append(tokens, token)
  414. }
  415. return tokens, nil
  416. }
  417. // Token returns a specific token for a user
  418. func (a *Manager) Token(userID, token string) (*Token, error) {
  419. rows, err := a.db.Query(selectTokenQuery, userID, token)
  420. if err != nil {
  421. return nil, err
  422. }
  423. defer rows.Close()
  424. return a.readToken(rows)
  425. }
  426. func (a *Manager) readToken(rows *sql.Rows) (*Token, error) {
  427. var token, label, lastOrigin string
  428. var lastAccess, expires int64
  429. if !rows.Next() {
  430. return nil, ErrTokenNotFound
  431. }
  432. if err := rows.Scan(&token, &label, &lastAccess, &lastOrigin, &expires); err != nil {
  433. return nil, err
  434. } else if err := rows.Err(); err != nil {
  435. return nil, err
  436. }
  437. lastOriginIP, err := netip.ParseAddr(lastOrigin)
  438. if err != nil {
  439. lastOriginIP = netip.IPv4Unspecified()
  440. }
  441. return &Token{
  442. Value: token,
  443. Label: label,
  444. LastAccess: time.Unix(lastAccess, 0),
  445. LastOrigin: lastOriginIP,
  446. Expires: time.Unix(expires, 0),
  447. }, nil
  448. }
  449. // ChangeToken updates a token's label and/or expiry date
  450. func (a *Manager) ChangeToken(userID, token string, label *string, expires *time.Time) (*Token, error) {
  451. if token == "" {
  452. return nil, errNoTokenProvided
  453. }
  454. tx, err := a.db.Begin()
  455. if err != nil {
  456. return nil, err
  457. }
  458. defer tx.Rollback()
  459. if label != nil {
  460. if _, err := tx.Exec(updateTokenLabelQuery, *label, userID, token); err != nil {
  461. return nil, err
  462. }
  463. }
  464. if expires != nil {
  465. if _, err := tx.Exec(updateTokenExpiryQuery, expires.Unix(), userID, token); err != nil {
  466. return nil, err
  467. }
  468. }
  469. if err := tx.Commit(); err != nil {
  470. return nil, err
  471. }
  472. return a.Token(userID, token)
  473. }
  474. // RemoveToken deletes the token defined in User.Token
  475. func (a *Manager) RemoveToken(userID, token string) error {
  476. if token == "" {
  477. return errNoTokenProvided
  478. }
  479. if _, err := a.db.Exec(deleteTokenQuery, userID, token); err != nil {
  480. return err
  481. }
  482. return nil
  483. }
  484. // RemoveExpiredTokens deletes all expired tokens from the database
  485. func (a *Manager) RemoveExpiredTokens() error {
  486. if _, err := a.db.Exec(deleteExpiredTokensQuery, time.Now().Unix()); err != nil {
  487. return err
  488. }
  489. return nil
  490. }
  491. // RemoveDeletedUsers deletes all users that have been marked deleted for
  492. func (a *Manager) RemoveDeletedUsers() error {
  493. if _, err := a.db.Exec(deleteUsersMarkedQuery, time.Now().Unix()); err != nil {
  494. return err
  495. }
  496. return nil
  497. }
  498. // ChangeSettings persists the user settings
  499. func (a *Manager) ChangeSettings(user *User) error {
  500. prefs, err := json.Marshal(user.Prefs)
  501. if err != nil {
  502. return err
  503. }
  504. if _, err := a.db.Exec(updateUserPrefsQuery, string(prefs), user.Name); err != nil {
  505. return err
  506. }
  507. return nil
  508. }
  509. // ResetStats resets all user stats in the user database. This touches all users.
  510. func (a *Manager) ResetStats() error {
  511. a.mu.Lock() // Includes database query to avoid races!
  512. defer a.mu.Unlock()
  513. if _, err := a.db.Exec(updateUserStatsResetAllQuery); err != nil {
  514. return err
  515. }
  516. a.statsQueue = make(map[string]*Stats)
  517. return nil
  518. }
  519. // EnqueueStats adds the user to a queue which writes out user stats (messages, emails, ..) in
  520. // batches at a regular interval
  521. func (a *Manager) EnqueueStats(userID string, stats *Stats) {
  522. a.mu.Lock()
  523. defer a.mu.Unlock()
  524. a.statsQueue[userID] = stats
  525. }
  526. // EnqueueTokenUpdate adds the token update to a queue which writes out token access times
  527. // in batches at a regular interval
  528. func (a *Manager) EnqueueTokenUpdate(tokenID string, update *TokenUpdate) {
  529. a.mu.Lock()
  530. defer a.mu.Unlock()
  531. a.tokenQueue[tokenID] = update
  532. }
  533. func (a *Manager) asyncQueueWriter(interval time.Duration) {
  534. ticker := time.NewTicker(interval)
  535. for range ticker.C {
  536. if err := a.writeUserStatsQueue(); err != nil {
  537. log.Warn("User Manager: Writing user stats queue failed: %s", err.Error())
  538. }
  539. if err := a.writeTokenUpdateQueue(); err != nil {
  540. log.Warn("User Manager: Writing token update queue failed: %s", err.Error())
  541. }
  542. }
  543. }
  544. func (a *Manager) writeUserStatsQueue() error {
  545. a.mu.Lock()
  546. if len(a.statsQueue) == 0 {
  547. a.mu.Unlock()
  548. log.Trace("User Manager: No user stats updates to commit")
  549. return nil
  550. }
  551. statsQueue := a.statsQueue
  552. a.statsQueue = make(map[string]*Stats)
  553. a.mu.Unlock()
  554. tx, err := a.db.Begin()
  555. if err != nil {
  556. return err
  557. }
  558. defer tx.Rollback()
  559. log.Debug("User Manager: Writing user stats queue for %d user(s)", len(statsQueue))
  560. for userID, update := range statsQueue {
  561. log.Trace("User Manager: Updating stats for user %s: messages=%d, emails=%d", userID, update.Messages, update.Emails)
  562. if _, err := tx.Exec(updateUserStatsQuery, update.Messages, update.Emails, userID); err != nil {
  563. return err
  564. }
  565. }
  566. return tx.Commit()
  567. }
  568. func (a *Manager) writeTokenUpdateQueue() error {
  569. a.mu.Lock()
  570. if len(a.tokenQueue) == 0 {
  571. a.mu.Unlock()
  572. log.Trace("User Manager: No token updates to commit")
  573. return nil
  574. }
  575. tokenQueue := a.tokenQueue
  576. a.tokenQueue = make(map[string]*TokenUpdate)
  577. a.mu.Unlock()
  578. tx, err := a.db.Begin()
  579. if err != nil {
  580. return err
  581. }
  582. defer tx.Rollback()
  583. log.Debug("User Manager: Writing token update queue for %d token(s)", len(tokenQueue))
  584. for tokenID, update := range tokenQueue {
  585. log.Trace("User Manager: Updating token %s with last access time %v", tokenID, update.LastAccess.Unix())
  586. if _, err := tx.Exec(updateTokenLastAccessQuery, update.LastAccess.Unix(), update.LastOrigin.String(), tokenID); err != nil {
  587. return err
  588. }
  589. }
  590. return tx.Commit()
  591. }
  592. // Authorize returns nil if the given user has access to the given topic using the desired
  593. // permission. The user param may be nil to signal an anonymous user.
  594. func (a *Manager) Authorize(user *User, topic string, perm Permission) error {
  595. if user != nil && user.Role == RoleAdmin {
  596. return nil // Admin can do everything
  597. }
  598. username := Everyone
  599. if user != nil {
  600. username = user.Name
  601. }
  602. // Select the read/write permissions for this user/topic combo. The query may return two
  603. // rows (one for everyone, and one for the user), but prioritizes the user.
  604. rows, err := a.db.Query(selectTopicPermsQuery, Everyone, username, topic)
  605. if err != nil {
  606. return err
  607. }
  608. defer rows.Close()
  609. if !rows.Next() {
  610. return a.resolvePerms(a.defaultAccess, perm)
  611. }
  612. var read, write bool
  613. if err := rows.Scan(&read, &write); err != nil {
  614. return err
  615. } else if err := rows.Err(); err != nil {
  616. return err
  617. }
  618. return a.resolvePerms(NewPermission(read, write), perm)
  619. }
  620. func (a *Manager) resolvePerms(base, perm Permission) error {
  621. if perm == PermissionRead && base.IsRead() {
  622. return nil
  623. } else if perm == PermissionWrite && base.IsWrite() {
  624. return nil
  625. }
  626. return ErrUnauthorized
  627. }
  628. // AddUser adds a user with the given username, password and role
  629. func (a *Manager) AddUser(username, password string, role Role) error {
  630. if !AllowedUsername(username) || !AllowedRole(role) {
  631. return ErrInvalidArgument
  632. }
  633. hash, err := bcrypt.GenerateFromPassword([]byte(password), a.bcryptCost)
  634. if err != nil {
  635. return err
  636. }
  637. userID := util.RandomStringPrefix(userIDPrefix, userIDLength)
  638. syncTopic, now := util.RandomStringPrefix(syncTopicPrefix, syncTopicLength), time.Now().Unix()
  639. if _, err = a.db.Exec(insertUserQuery, userID, username, hash, role, syncTopic, now); err != nil {
  640. return err
  641. }
  642. return nil
  643. }
  644. // RemoveUser deletes the user with the given username. The function returns nil on success, even
  645. // if the user did not exist in the first place.
  646. func (a *Manager) RemoveUser(username string) error {
  647. if !AllowedUsername(username) {
  648. return ErrInvalidArgument
  649. }
  650. // Rows in user_access, user_token, etc. are deleted via foreign keys
  651. if _, err := a.db.Exec(deleteUserQuery, username); err != nil {
  652. return err
  653. }
  654. return nil
  655. }
  656. // MarkUserRemoved sets the deleted flag on the user, and deletes all access tokens. This prevents
  657. // successful auth via Authenticate. A background process will delete the user at a later date.
  658. func (a *Manager) MarkUserRemoved(user *User) error {
  659. if !AllowedUsername(user.Name) {
  660. return ErrInvalidArgument
  661. }
  662. tx, err := a.db.Begin()
  663. if err != nil {
  664. return err
  665. }
  666. defer tx.Rollback()
  667. if _, err := a.db.Exec(deleteUserAccessQuery, user.Name, user.Name); err != nil {
  668. return err
  669. }
  670. if _, err := tx.Exec(deleteAllTokenQuery, user.ID); err != nil {
  671. return err
  672. }
  673. if _, err := tx.Exec(updateUserDeletedQuery, time.Now().Add(userHardDeleteAfterDuration).Unix(), user.ID); err != nil {
  674. return err
  675. }
  676. return tx.Commit()
  677. }
  678. // Users returns a list of users. It always also returns the Everyone user ("*").
  679. func (a *Manager) Users() ([]*User, error) {
  680. rows, err := a.db.Query(selectUsernamesQuery)
  681. if err != nil {
  682. return nil, err
  683. }
  684. defer rows.Close()
  685. usernames := make([]string, 0)
  686. for rows.Next() {
  687. var username string
  688. if err := rows.Scan(&username); err != nil {
  689. return nil, err
  690. } else if err := rows.Err(); err != nil {
  691. return nil, err
  692. }
  693. usernames = append(usernames, username)
  694. }
  695. rows.Close()
  696. users := make([]*User, 0)
  697. for _, username := range usernames {
  698. user, err := a.User(username)
  699. if err != nil {
  700. return nil, err
  701. }
  702. users = append(users, user)
  703. }
  704. return users, nil
  705. }
  706. // User returns the user with the given username if it exists, or ErrUserNotFound otherwise.
  707. // You may also pass Everyone to retrieve the anonymous user and its Grant list.
  708. func (a *Manager) User(username string) (*User, error) {
  709. rows, err := a.db.Query(selectUserByNameQuery, username)
  710. if err != nil {
  711. return nil, err
  712. }
  713. return a.readUser(rows)
  714. }
  715. // UserByID returns the user with the given ID if it exists, or ErrUserNotFound otherwise
  716. func (a *Manager) UserByID(id string) (*User, error) {
  717. rows, err := a.db.Query(selectUserByIDQuery, id)
  718. if err != nil {
  719. return nil, err
  720. }
  721. return a.readUser(rows)
  722. }
  723. // UserByStripeCustomer returns the user with the given Stripe customer ID if it exists, or ErrUserNotFound otherwise.
  724. func (a *Manager) UserByStripeCustomer(stripeCustomerID string) (*User, error) {
  725. rows, err := a.db.Query(selectUserByStripeCustomerIDQuery, stripeCustomerID)
  726. if err != nil {
  727. return nil, err
  728. }
  729. return a.readUser(rows)
  730. }
  731. func (a *Manager) userByToken(token string) (*User, error) {
  732. rows, err := a.db.Query(selectUserByTokenQuery, token, time.Now().Unix())
  733. if err != nil {
  734. return nil, err
  735. }
  736. return a.readUser(rows)
  737. }
  738. func (a *Manager) readUser(rows *sql.Rows) (*User, error) {
  739. defer rows.Close()
  740. var id, username, hash, role, prefs, syncTopic string
  741. var stripeCustomerID, stripeSubscriptionID, stripeSubscriptionStatus, stripePriceID, tierID, tierCode, tierName sql.NullString
  742. var messages, emails int64
  743. var messagesLimit, messagesExpiryDuration, emailsLimit, reservationsLimit, attachmentFileSizeLimit, attachmentTotalSizeLimit, attachmentExpiryDuration, attachmentBandwidthLimit, stripeSubscriptionPaidUntil, stripeSubscriptionCancelAt, deleted sql.NullInt64
  744. if !rows.Next() {
  745. return nil, ErrUserNotFound
  746. }
  747. if err := rows.Scan(&id, &username, &hash, &role, &prefs, &syncTopic, &messages, &emails, &stripeCustomerID, &stripeSubscriptionID, &stripeSubscriptionStatus, &stripeSubscriptionPaidUntil, &stripeSubscriptionCancelAt, &deleted, &tierID, &tierCode, &tierName, &messagesLimit, &messagesExpiryDuration, &emailsLimit, &reservationsLimit, &attachmentFileSizeLimit, &attachmentTotalSizeLimit, &attachmentExpiryDuration, &attachmentBandwidthLimit, &stripePriceID); err != nil {
  748. return nil, err
  749. } else if err := rows.Err(); err != nil {
  750. return nil, err
  751. }
  752. user := &User{
  753. ID: id,
  754. Name: username,
  755. Hash: hash,
  756. Role: Role(role),
  757. Prefs: &Prefs{},
  758. SyncTopic: syncTopic,
  759. Stats: &Stats{
  760. Messages: messages,
  761. Emails: emails,
  762. },
  763. Billing: &Billing{
  764. StripeCustomerID: stripeCustomerID.String, // May be empty
  765. StripeSubscriptionID: stripeSubscriptionID.String, // May be empty
  766. StripeSubscriptionStatus: stripe.SubscriptionStatus(stripeSubscriptionStatus.String), // May be empty
  767. StripeSubscriptionPaidUntil: time.Unix(stripeSubscriptionPaidUntil.Int64, 0), // May be zero
  768. StripeSubscriptionCancelAt: time.Unix(stripeSubscriptionCancelAt.Int64, 0), // May be zero
  769. },
  770. Deleted: deleted.Valid,
  771. }
  772. if err := json.Unmarshal([]byte(prefs), user.Prefs); err != nil {
  773. return nil, err
  774. }
  775. if tierCode.Valid {
  776. // See readTier() when this is changed!
  777. user.Tier = &Tier{
  778. ID: tierID.String,
  779. Code: tierCode.String,
  780. Name: tierName.String,
  781. MessageLimit: messagesLimit.Int64,
  782. MessageExpiryDuration: time.Duration(messagesExpiryDuration.Int64) * time.Second,
  783. EmailLimit: emailsLimit.Int64,
  784. ReservationLimit: reservationsLimit.Int64,
  785. AttachmentFileSizeLimit: attachmentFileSizeLimit.Int64,
  786. AttachmentTotalSizeLimit: attachmentTotalSizeLimit.Int64,
  787. AttachmentExpiryDuration: time.Duration(attachmentExpiryDuration.Int64) * time.Second,
  788. AttachmentBandwidthLimit: attachmentBandwidthLimit.Int64,
  789. StripePriceID: stripePriceID.String, // May be empty
  790. }
  791. }
  792. return user, nil
  793. }
  794. // Grants returns all user-specific access control entries
  795. func (a *Manager) Grants(username string) ([]Grant, error) {
  796. rows, err := a.db.Query(selectUserAccessQuery, username)
  797. if err != nil {
  798. return nil, err
  799. }
  800. defer rows.Close()
  801. grants := make([]Grant, 0)
  802. for rows.Next() {
  803. var topic string
  804. var read, write bool
  805. if err := rows.Scan(&topic, &read, &write); err != nil {
  806. return nil, err
  807. } else if err := rows.Err(); err != nil {
  808. return nil, err
  809. }
  810. grants = append(grants, Grant{
  811. TopicPattern: fromSQLWildcard(topic),
  812. Allow: NewPermission(read, write),
  813. })
  814. }
  815. return grants, nil
  816. }
  817. // Reservations returns all user-owned topics, and the associated everyone-access
  818. func (a *Manager) Reservations(username string) ([]Reservation, error) {
  819. rows, err := a.db.Query(selectUserReservationsQuery, Everyone, username)
  820. if err != nil {
  821. return nil, err
  822. }
  823. defer rows.Close()
  824. reservations := make([]Reservation, 0)
  825. for rows.Next() {
  826. var topic string
  827. var ownerRead, ownerWrite bool
  828. var everyoneRead, everyoneWrite sql.NullBool
  829. if err := rows.Scan(&topic, &ownerRead, &ownerWrite, &everyoneRead, &everyoneWrite); err != nil {
  830. return nil, err
  831. } else if err := rows.Err(); err != nil {
  832. return nil, err
  833. }
  834. reservations = append(reservations, Reservation{
  835. Topic: topic,
  836. Owner: NewPermission(ownerRead, ownerWrite),
  837. Everyone: NewPermission(everyoneRead.Bool, everyoneWrite.Bool), // false if null
  838. })
  839. }
  840. return reservations, nil
  841. }
  842. // HasReservation returns true if the given topic access is owned by the user
  843. func (a *Manager) HasReservation(username, topic string) (bool, error) {
  844. rows, err := a.db.Query(selectUserHasReservationQuery, username, topic)
  845. if err != nil {
  846. return false, err
  847. }
  848. defer rows.Close()
  849. if !rows.Next() {
  850. return false, errNoRows
  851. }
  852. var count int64
  853. if err := rows.Scan(&count); err != nil {
  854. return false, err
  855. }
  856. return count > 0, nil
  857. }
  858. // ReservationsCount returns the number of reservations owned by this user
  859. func (a *Manager) ReservationsCount(username string) (int64, error) {
  860. rows, err := a.db.Query(selectUserReservationsCountQuery, username)
  861. if err != nil {
  862. return 0, err
  863. }
  864. defer rows.Close()
  865. if !rows.Next() {
  866. return 0, errNoRows
  867. }
  868. var count int64
  869. if err := rows.Scan(&count); err != nil {
  870. return 0, err
  871. }
  872. return count, nil
  873. }
  874. // ChangePassword changes a user's password
  875. func (a *Manager) ChangePassword(username, password string) error {
  876. hash, err := bcrypt.GenerateFromPassword([]byte(password), a.bcryptCost)
  877. if err != nil {
  878. return err
  879. }
  880. if _, err := a.db.Exec(updateUserPassQuery, hash, username); err != nil {
  881. return err
  882. }
  883. return nil
  884. }
  885. // ChangeRole changes a user's role. When a role is changed from RoleUser to RoleAdmin,
  886. // all existing access control entries (Grant) are removed, since they are no longer needed.
  887. func (a *Manager) ChangeRole(username string, role Role) error {
  888. if !AllowedUsername(username) || !AllowedRole(role) {
  889. return ErrInvalidArgument
  890. }
  891. if _, err := a.db.Exec(updateUserRoleQuery, string(role), username); err != nil {
  892. return err
  893. }
  894. if role == RoleAdmin {
  895. if _, err := a.db.Exec(deleteUserAccessQuery, username, username); err != nil {
  896. return err
  897. }
  898. }
  899. return nil
  900. }
  901. // ChangeTier changes a user's tier using the tier code. This function does not delete reservations, messages,
  902. // or attachments, even if the new tier has lower limits in this regard. That has to be done elsewhere.
  903. func (a *Manager) ChangeTier(username, tier string) error {
  904. if !AllowedUsername(username) {
  905. return ErrInvalidArgument
  906. }
  907. t, err := a.Tier(tier)
  908. if err != nil {
  909. return err
  910. } else if err := a.checkReservationsLimit(username, t.ReservationLimit); err != nil {
  911. return err
  912. }
  913. if _, err := a.db.Exec(updateUserTierQuery, tier, username); err != nil {
  914. return err
  915. }
  916. return nil
  917. }
  918. // ResetTier removes the tier from the given user
  919. func (a *Manager) ResetTier(username string) error {
  920. if !AllowedUsername(username) && username != Everyone && username != "" {
  921. return ErrInvalidArgument
  922. } else if err := a.checkReservationsLimit(username, 0); err != nil {
  923. return err
  924. }
  925. _, err := a.db.Exec(deleteUserTierQuery, username)
  926. return err
  927. }
  928. func (a *Manager) checkReservationsLimit(username string, reservationsLimit int64) error {
  929. u, err := a.User(username)
  930. if err != nil {
  931. return err
  932. }
  933. if u.Tier != nil && reservationsLimit < u.Tier.ReservationLimit {
  934. reservations, err := a.Reservations(username)
  935. if err != nil {
  936. return err
  937. } else if int64(len(reservations)) > reservationsLimit {
  938. return ErrTooManyReservations
  939. }
  940. }
  941. return nil
  942. }
  943. // CheckAllowAccess tests if a user may create an access control entry for the given topic.
  944. // If there are any ACL entries that are not owned by the user, an error is returned.
  945. // FIXME is this the same as HasReservation?
  946. func (a *Manager) CheckAllowAccess(username string, topic string) error {
  947. if (!AllowedUsername(username) && username != Everyone) || !AllowedTopic(topic) {
  948. return ErrInvalidArgument
  949. }
  950. rows, err := a.db.Query(selectOtherAccessCountQuery, topic, topic, username)
  951. if err != nil {
  952. return err
  953. }
  954. defer rows.Close()
  955. if !rows.Next() {
  956. return errNoRows
  957. }
  958. var otherCount int
  959. if err := rows.Scan(&otherCount); err != nil {
  960. return err
  961. }
  962. if otherCount > 0 {
  963. return errTopicOwnedByOthers
  964. }
  965. return nil
  966. }
  967. // AllowAccess adds or updates an entry in th access control list for a specific user. It controls
  968. // read/write access to a topic. The parameter topicPattern may include wildcards (*). The ACL entry
  969. // owner may either be a user (username), or the system (empty).
  970. func (a *Manager) AllowAccess(username string, topicPattern string, permission Permission) error {
  971. if !AllowedUsername(username) && username != Everyone {
  972. return ErrInvalidArgument
  973. } else if !AllowedTopicPattern(topicPattern) {
  974. return ErrInvalidArgument
  975. }
  976. owner := ""
  977. if _, err := a.db.Exec(upsertUserAccessQuery, username, toSQLWildcard(topicPattern), permission.IsRead(), permission.IsWrite(), owner, owner); err != nil {
  978. return err
  979. }
  980. return nil
  981. }
  982. // ResetAccess removes an access control list entry for a specific username/topic, or (if topic is
  983. // empty) for an entire user. The parameter topicPattern may include wildcards (*).
  984. func (a *Manager) ResetAccess(username string, topicPattern string) error {
  985. if !AllowedUsername(username) && username != Everyone && username != "" {
  986. return ErrInvalidArgument
  987. } else if !AllowedTopicPattern(topicPattern) && topicPattern != "" {
  988. return ErrInvalidArgument
  989. }
  990. if username == "" && topicPattern == "" {
  991. _, err := a.db.Exec(deleteAllAccessQuery, username)
  992. return err
  993. } else if topicPattern == "" {
  994. _, err := a.db.Exec(deleteUserAccessQuery, username, username)
  995. return err
  996. }
  997. _, err := a.db.Exec(deleteTopicAccessQuery, username, username, toSQLWildcard(topicPattern))
  998. return err
  999. }
  1000. // AddReservation creates two access control entries for the given topic: one with full read/write access for the
  1001. // given user, and one for Everyone with the permission passed as everyone. The user also owns the entries, and
  1002. // can modify or delete them.
  1003. func (a *Manager) AddReservation(username string, topic string, everyone Permission) error {
  1004. if !AllowedUsername(username) || username == Everyone || !AllowedTopic(topic) {
  1005. return ErrInvalidArgument
  1006. }
  1007. tx, err := a.db.Begin()
  1008. if err != nil {
  1009. return err
  1010. }
  1011. defer tx.Rollback()
  1012. if _, err := tx.Exec(upsertUserAccessQuery, username, topic, true, true, username, username); err != nil {
  1013. return err
  1014. }
  1015. if _, err := tx.Exec(upsertUserAccessQuery, Everyone, topic, everyone.IsRead(), everyone.IsWrite(), username, username); err != nil {
  1016. return err
  1017. }
  1018. return tx.Commit()
  1019. }
  1020. // RemoveReservations deletes the access control entries associated with the given username/topic, as
  1021. // well as all entries with Everyone/topic. This is the counterpart for AddReservation.
  1022. func (a *Manager) RemoveReservations(username string, topics ...string) error {
  1023. if !AllowedUsername(username) || username == Everyone || len(topics) == 0 {
  1024. return ErrInvalidArgument
  1025. }
  1026. for _, topic := range topics {
  1027. if !AllowedTopic(topic) {
  1028. return ErrInvalidArgument
  1029. }
  1030. }
  1031. tx, err := a.db.Begin()
  1032. if err != nil {
  1033. return err
  1034. }
  1035. defer tx.Rollback()
  1036. for _, topic := range topics {
  1037. if _, err := tx.Exec(deleteTopicAccessQuery, username, username, topic); err != nil {
  1038. return err
  1039. }
  1040. if _, err := tx.Exec(deleteTopicAccessQuery, Everyone, Everyone, topic); err != nil {
  1041. return err
  1042. }
  1043. }
  1044. return tx.Commit()
  1045. }
  1046. // DefaultAccess returns the default read/write access if no access control entry matches
  1047. func (a *Manager) DefaultAccess() Permission {
  1048. return a.defaultAccess
  1049. }
  1050. // CreateTier creates a new tier in the database
  1051. func (a *Manager) CreateTier(tier *Tier) error {
  1052. if tier.ID == "" {
  1053. tier.ID = util.RandomStringPrefix(tierIDPrefix, tierIDLength)
  1054. }
  1055. if _, err := a.db.Exec(insertTierQuery, tier.ID, tier.Code, tier.Name, tier.MessageLimit, int64(tier.MessageExpiryDuration.Seconds()), tier.EmailLimit, tier.ReservationLimit, tier.AttachmentFileSizeLimit, tier.AttachmentTotalSizeLimit, int64(tier.AttachmentExpiryDuration.Seconds()), tier.AttachmentBandwidthLimit, nullString(tier.StripePriceID)); err != nil {
  1056. return err
  1057. }
  1058. return nil
  1059. }
  1060. // ChangeBilling updates a user's billing fields, namely the Stripe customer ID, and subscription information
  1061. func (a *Manager) ChangeBilling(username string, billing *Billing) error {
  1062. if _, err := a.db.Exec(updateBillingQuery, nullString(billing.StripeCustomerID), nullString(billing.StripeSubscriptionID), nullString(string(billing.StripeSubscriptionStatus)), nullInt64(billing.StripeSubscriptionPaidUntil.Unix()), nullInt64(billing.StripeSubscriptionCancelAt.Unix()), username); err != nil {
  1063. return err
  1064. }
  1065. return nil
  1066. }
  1067. // Tiers returns a list of all Tier structs
  1068. func (a *Manager) Tiers() ([]*Tier, error) {
  1069. rows, err := a.db.Query(selectTiersQuery)
  1070. if err != nil {
  1071. return nil, err
  1072. }
  1073. defer rows.Close()
  1074. tiers := make([]*Tier, 0)
  1075. for {
  1076. tier, err := a.readTier(rows)
  1077. if err == ErrTierNotFound {
  1078. break
  1079. } else if err != nil {
  1080. return nil, err
  1081. }
  1082. tiers = append(tiers, tier)
  1083. }
  1084. return tiers, nil
  1085. }
  1086. // Tier returns a Tier based on the code, or ErrTierNotFound if it does not exist
  1087. func (a *Manager) Tier(code string) (*Tier, error) {
  1088. rows, err := a.db.Query(selectTierByCodeQuery, code)
  1089. if err != nil {
  1090. return nil, err
  1091. }
  1092. defer rows.Close()
  1093. return a.readTier(rows)
  1094. }
  1095. // TierByStripePrice returns a Tier based on the Stripe price ID, or ErrTierNotFound if it does not exist
  1096. func (a *Manager) TierByStripePrice(priceID string) (*Tier, error) {
  1097. rows, err := a.db.Query(selectTierByPriceIDQuery, priceID)
  1098. if err != nil {
  1099. return nil, err
  1100. }
  1101. defer rows.Close()
  1102. return a.readTier(rows)
  1103. }
  1104. func (a *Manager) readTier(rows *sql.Rows) (*Tier, error) {
  1105. var id, code, name string
  1106. var stripePriceID sql.NullString
  1107. var messagesLimit, messagesExpiryDuration, emailsLimit, reservationsLimit, attachmentFileSizeLimit, attachmentTotalSizeLimit, attachmentExpiryDuration, attachmentBandwidthLimit sql.NullInt64
  1108. if !rows.Next() {
  1109. return nil, ErrTierNotFound
  1110. }
  1111. if err := rows.Scan(&id, &code, &name, &messagesLimit, &messagesExpiryDuration, &emailsLimit, &reservationsLimit, &attachmentFileSizeLimit, &attachmentTotalSizeLimit, &attachmentExpiryDuration, &attachmentBandwidthLimit, &stripePriceID); err != nil {
  1112. return nil, err
  1113. } else if err := rows.Err(); err != nil {
  1114. return nil, err
  1115. }
  1116. // When changed, note readUser() as well
  1117. return &Tier{
  1118. ID: id,
  1119. Code: code,
  1120. Name: name,
  1121. MessageLimit: messagesLimit.Int64,
  1122. MessageExpiryDuration: time.Duration(messagesExpiryDuration.Int64) * time.Second,
  1123. EmailLimit: emailsLimit.Int64,
  1124. ReservationLimit: reservationsLimit.Int64,
  1125. AttachmentFileSizeLimit: attachmentFileSizeLimit.Int64,
  1126. AttachmentTotalSizeLimit: attachmentTotalSizeLimit.Int64,
  1127. AttachmentExpiryDuration: time.Duration(attachmentExpiryDuration.Int64) * time.Second,
  1128. AttachmentBandwidthLimit: attachmentBandwidthLimit.Int64,
  1129. StripePriceID: stripePriceID.String, // May be empty
  1130. }, nil
  1131. }
  1132. // Close closes the underlying database
  1133. func (a *Manager) Close() error {
  1134. return a.db.Close()
  1135. }
  // toSQLWildcard converts a topic pattern that uses "*" wildcards into SQL wildcard form by
// replacing every "*" with "%". All other characters are kept unchanged.
//
// NOTE(review): "_" and "%" are not escaped. If the result is used with LIKE, a literal "_" in a
// topic also acts as a single-character wildcard; confirm this is acceptable for the ACL queries.
func toSQLWildcard(s string) string {
	const topicWildcard, sqlWildcard = "*", "%"
	return strings.Replace(s, topicWildcard, sqlWildcard, -1)
}
  // fromSQLWildcard is the inverse of toSQLWildcard: it replaces every SQL "%" wildcard with the
// topic pattern wildcard "*", keeping all other characters unchanged.
func fromSQLWildcard(s string) string {
	const sqlWildcard, topicWildcard = "%", "*"
	return strings.Replace(s, sqlWildcard, topicWildcard, -1)
}
  1142. func runStartupQueries(db *sql.DB, startupQueries string) error {
  1143. if _, err := db.Exec(startupQueries); err != nil {
  1144. return err
  1145. }
  1146. if _, err := db.Exec(builtinStartupQueries); err != nil {
  1147. return err
  1148. }
  1149. return nil
  1150. }
  1151. func setupDB(db *sql.DB) error {
  1152. // If 'schemaVersion' table does not exist, this must be a new database
  1153. rowsSV, err := db.Query(selectSchemaVersionQuery)
  1154. if err != nil {
  1155. return setupNewDB(db)
  1156. }
  1157. defer rowsSV.Close()
  1158. // If 'schemaVersion' table exists, read version and potentially upgrade
  1159. schemaVersion := 0
  1160. if !rowsSV.Next() {
  1161. return errors.New("cannot determine schema version: database file may be corrupt")
  1162. }
  1163. if err := rowsSV.Scan(&schemaVersion); err != nil {
  1164. return err
  1165. }
  1166. rowsSV.Close()
  1167. // Do migrations
  1168. if schemaVersion == currentSchemaVersion {
  1169. return nil
  1170. } else if schemaVersion == 1 {
  1171. return migrateFrom1(db)
  1172. }
  1173. return fmt.Errorf("unexpected schema version found: %d", schemaVersion)
  1174. }
  1175. func setupNewDB(db *sql.DB) error {
  1176. if _, err := db.Exec(createTablesQueries); err != nil {
  1177. return err
  1178. }
  1179. if _, err := db.Exec(insertSchemaVersion, currentSchemaVersion); err != nil {
  1180. return err
  1181. }
  1182. return nil
  1183. }
  1184. func migrateFrom1(db *sql.DB) error {
  1185. log.Info("Migrating user database schema: from 1 to 2")
  1186. tx, err := db.Begin()
  1187. if err != nil {
  1188. return err
  1189. }
  1190. defer tx.Rollback()
  1191. // Rename user -> user_old, and create new tables
  1192. if _, err := tx.Exec(migrate1To2RenameUserTableQueryNoTx); err != nil {
  1193. return err
  1194. }
  1195. if _, err := tx.Exec(createTablesQueriesNoTx); err != nil {
  1196. return err
  1197. }
  1198. // Insert users from user_old into new user table, with ID and sync_topic
  1199. rows, err := tx.Query(migrate1To2SelectAllOldUsernamesNoTx)
  1200. if err != nil {
  1201. return err
  1202. }
  1203. defer rows.Close()
  1204. usernames := make([]string, 0)
  1205. for rows.Next() {
  1206. var username string
  1207. if err := rows.Scan(&username); err != nil {
  1208. return err
  1209. }
  1210. usernames = append(usernames, username)
  1211. }
  1212. if err := rows.Close(); err != nil {
  1213. return err
  1214. }
  1215. for _, username := range usernames {
  1216. userID := util.RandomStringPrefix(userIDPrefix, userIDLength)
  1217. syncTopic := util.RandomStringPrefix(syncTopicPrefix, syncTopicLength)
  1218. if _, err := tx.Exec(migrate1To2InsertUserNoTx, userID, syncTopic, username); err != nil {
  1219. return err
  1220. }
  1221. }
  1222. // Migrate old "access" table to "user_access" and drop "access" and "user_old"
  1223. if _, err := tx.Exec(migrate1To2InsertFromOldTablesAndDropNoTx); err != nil {
  1224. return err
  1225. }
  1226. if _, err := tx.Exec(updateSchemaVersion, 2); err != nil {
  1227. return err
  1228. }
  1229. if err := tx.Commit(); err != nil {
  1230. return err
  1231. }
  1232. return nil // Update this when a new version is added
  1233. }
  // nullString wraps s in a sql.NullString, mapping the empty string to SQL NULL (Valid == false).
func nullString(s string) sql.NullString {
	return sql.NullString{String: s, Valid: s != ""}
}
  // nullInt64 wraps v in a sql.NullInt64, mapping zero to SQL NULL (Valid == false). Negative values
// are kept as valid values.
func nullInt64(v int64) sql.NullInt64 {
	return sql.NullInt64{Int64: v, Valid: v != 0}
}