1
0

manager.go 41 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035103610371038103910401041104210431044104510461047104810491050105110521053105410551056105710581059106010611062106310641065106610671068106910701071107210731074107510761077107810791080108110821083108410851086108710881089109010911092109310941095109610971098109911001101110211031104110511061107110811091110111111121113111411151116111711181119112011211122112311241125112611271128112911301131113211331134113511361137113811391140114111421143114411451146114711481149115011511152115311541155115611571158115911601161116211631164116511661167116811691170117111721173117411751176117711781179118011811182118311841185118611871188118911901191119211931194119511961197
  1. package user
  2. import (
  3. "database/sql"
  4. "encoding/json"
  5. "errors"
  6. "fmt"
  7. _ "github.com/mattn/go-sqlite3" // SQLite driver
  8. "github.com/stripe/stripe-go/v74"
  9. "golang.org/x/crypto/bcrypt"
  10. "heckel.io/ntfy/log"
  11. "heckel.io/ntfy/util"
  12. "strings"
  13. "sync"
  14. "time"
  15. )
  // ID and random-string formats: a fixed prefix plus a length, as passed to util.RandomStringPrefix.
  const (
  	tierIDPrefix    = "ti_"
  	tierIDLength    = 8
  	syncTopicPrefix = "st_"
  	syncTopicLength = 16
  	userIDPrefix    = "u_"
  	userIDLength    = 12
  )
  
  // Password hashing and account lifecycle.
  const (
  	userPasswordBcryptCost = 10
  	// A throwaway bcrypt hash that Authenticate compares against when the user is unknown or deleted.
  	// Its cost ("$2a$10$") must match userPasswordBcryptCost so the dummy check takes as long as a real one.
  	userAuthIntentionalSlowDownHash = "$2a$10$YFCQvqQDwIIwnJM1xkAYOeih0dg17UVGanaTStnrSzC8NCWxcLDwy"
  	userStatsQueueWriterInterval    = 33 * time.Second   // How often queued user stats are flushed
  	userHardDeleteAfterDuration     = 7 * 24 * time.Hour // Grace period between MarkUserRemoved and hard delete
  )
  
  // Access tokens.
  const (
  	tokenPrefix         = "tk_"
  	tokenLength         = 32
  	tokenMaxCount       = 10             // Only keep this many tokens in the table per user
  	tokenExpiryDuration = 72 * time.Hour // Extend tokens by this much
  )
  // errNoTokenProvided is returned when a token operation is requested but User.Token is empty.
  var errNoTokenProvided = errors.New("no token provided")
  
  // errTopicOwnedByOthers signals that a topic is reserved by a different user.
  var errTopicOwnedByOthers = errors.New("topic owned by others")
  
  // errNoRows is returned when a query that must yield a row (e.g. a COUNT) returns none.
  var errNoRows = errors.New("no rows found")
  37. // Manager-related queries
  38. const (
  39. createTablesQueriesNoTx = `
  40. CREATE TABLE IF NOT EXISTS tier (
  41. id TEXT PRIMARY KEY,
  42. code TEXT NOT NULL,
  43. name TEXT NOT NULL,
  44. messages_limit INT NOT NULL,
  45. messages_expiry_duration INT NOT NULL,
  46. emails_limit INT NOT NULL,
  47. reservations_limit INT NOT NULL,
  48. attachment_file_size_limit INT NOT NULL,
  49. attachment_total_size_limit INT NOT NULL,
  50. attachment_expiry_duration INT NOT NULL,
  51. stripe_price_id TEXT
  52. );
  53. CREATE UNIQUE INDEX idx_tier_code ON tier (code);
  54. CREATE UNIQUE INDEX idx_tier_price_id ON tier (stripe_price_id);
  55. CREATE TABLE IF NOT EXISTS user (
  56. id TEXT PRIMARY KEY,
  57. tier_id TEXT,
  58. user TEXT NOT NULL,
  59. pass TEXT NOT NULL,
  60. role TEXT CHECK (role IN ('anonymous', 'admin', 'user')) NOT NULL,
  61. prefs JSON NOT NULL DEFAULT '{}',
  62. sync_topic TEXT NOT NULL,
  63. stats_messages INT NOT NULL DEFAULT (0),
  64. stats_emails INT NOT NULL DEFAULT (0),
  65. stripe_customer_id TEXT,
  66. stripe_subscription_id TEXT,
  67. stripe_subscription_status TEXT,
  68. stripe_subscription_paid_until INT,
  69. stripe_subscription_cancel_at INT,
  70. created INT NOT NULL,
  71. deleted INT,
  72. FOREIGN KEY (tier_id) REFERENCES tier (id)
  73. );
  74. CREATE UNIQUE INDEX idx_user ON user (user);
  75. CREATE UNIQUE INDEX idx_user_stripe_customer_id ON user (stripe_customer_id);
  76. CREATE UNIQUE INDEX idx_user_stripe_subscription_id ON user (stripe_subscription_id);
  77. CREATE TABLE IF NOT EXISTS user_access (
  78. user_id TEXT NOT NULL,
  79. topic TEXT NOT NULL,
  80. read INT NOT NULL,
  81. write INT NOT NULL,
  82. owner_user_id INT,
  83. PRIMARY KEY (user_id, topic),
  84. FOREIGN KEY (user_id) REFERENCES user (id) ON DELETE CASCADE,
  85. FOREIGN KEY (owner_user_id) REFERENCES user (id) ON DELETE CASCADE
  86. );
  87. CREATE TABLE IF NOT EXISTS user_token (
  88. user_id TEXT NOT NULL,
  89. token TEXT NOT NULL,
  90. expires INT NOT NULL,
  91. PRIMARY KEY (user_id, token),
  92. FOREIGN KEY (user_id) REFERENCES user (id) ON DELETE CASCADE
  93. );
  94. CREATE TABLE IF NOT EXISTS schemaVersion (
  95. id INT PRIMARY KEY,
  96. version INT NOT NULL
  97. );
  98. INSERT INTO user (id, user, pass, role, sync_topic, created)
  99. VALUES ('` + everyoneID + `', '*', '', 'anonymous', '', UNIXEPOCH())
  100. ON CONFLICT (id) DO NOTHING;
  101. `
  102. createTablesQueries = `BEGIN; ` + createTablesQueriesNoTx + ` COMMIT;`
  103. builtinStartupQueries = `
  104. PRAGMA foreign_keys = ON;
  105. `
  106. selectUserByIDQuery = `
  107. SELECT u.id, u.user, u.pass, u.role, u.prefs, u.sync_topic, u.stats_messages, u.stats_emails, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, deleted, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.stripe_price_id
  108. FROM user u
  109. LEFT JOIN tier t on t.id = u.tier_id
  110. WHERE u.id = ?
  111. `
  112. selectUserByNameQuery = `
  113. SELECT u.id, u.user, u.pass, u.role, u.prefs, u.sync_topic, u.stats_messages, u.stats_emails, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, deleted, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.stripe_price_id
  114. FROM user u
  115. LEFT JOIN tier t on t.id = u.tier_id
  116. WHERE user = ?
  117. `
  118. selectUserByTokenQuery = `
  119. SELECT u.id, u.user, u.pass, u.role, u.prefs, u.sync_topic, u.stats_messages, u.stats_emails, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, deleted, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.stripe_price_id
  120. FROM user u
  121. JOIN user_token t on u.id = t.user_id
  122. LEFT JOIN tier t on t.id = u.tier_id
  123. WHERE t.token = ? AND t.expires >= ?
  124. `
  125. selectUserByStripeCustomerIDQuery = `
  126. SELECT u.id, u.user, u.pass, u.role, u.prefs, u.sync_topic, u.stats_messages, u.stats_emails, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, deleted, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.stripe_price_id
  127. FROM user u
  128. LEFT JOIN tier t on t.id = u.tier_id
  129. WHERE u.stripe_customer_id = ?
  130. `
  131. selectTopicPermsQuery = `
  132. SELECT read, write
  133. FROM user_access a
  134. JOIN user u ON u.id = a.user_id
  135. WHERE (u.user = ? OR u.user = ?) AND ? LIKE a.topic
  136. ORDER BY u.user DESC
  137. `
  138. insertUserQuery = `
  139. INSERT INTO user (id, user, pass, role, sync_topic, created)
  140. VALUES (?, ?, ?, ?, ?, ?)
  141. `
  142. selectUsernamesQuery = `
  143. SELECT user
  144. FROM user
  145. ORDER BY
  146. CASE role
  147. WHEN 'admin' THEN 1
  148. WHEN 'anonymous' THEN 3
  149. ELSE 2
  150. END, user
  151. `
  152. updateUserPassQuery = `UPDATE user SET pass = ? WHERE user = ?`
  153. updateUserRoleQuery = `UPDATE user SET role = ? WHERE user = ?`
  154. updateUserPrefsQuery = `UPDATE user SET prefs = ? WHERE user = ?`
  155. updateUserStatsQuery = `UPDATE user SET stats_messages = ?, stats_emails = ? WHERE id = ?`
  156. updateUserStatsResetAllQuery = `UPDATE user SET stats_messages = 0, stats_emails = 0`
  157. updateUserDeletedQuery = `UPDATE user SET deleted = ? WHERE id = ?`
  158. deleteUsersMarkedQuery = `DELETE FROM user WHERE deleted < ?`
  159. deleteUserQuery = `DELETE FROM user WHERE user = ?`
  160. upsertUserAccessQuery = `
  161. INSERT INTO user_access (user_id, topic, read, write, owner_user_id)
  162. VALUES ((SELECT id FROM user WHERE user = ?), ?, ?, ?, (SELECT IIF(?='',NULL,(SELECT id FROM user WHERE user=?))))
  163. ON CONFLICT (user_id, topic)
  164. DO UPDATE SET read=excluded.read, write=excluded.write, owner_user_id=excluded.owner_user_id
  165. `
  166. selectUserAccessQuery = `
  167. SELECT topic, read, write
  168. FROM user_access
  169. WHERE user_id = (SELECT id FROM user WHERE user = ?)
  170. ORDER BY write DESC, read DESC, topic
  171. `
  172. selectUserReservationsQuery = `
  173. SELECT a_user.topic, a_user.read, a_user.write, a_everyone.read AS everyone_read, a_everyone.write AS everyone_write
  174. FROM user_access a_user
  175. LEFT JOIN user_access a_everyone ON a_user.topic = a_everyone.topic AND a_everyone.user_id = (SELECT id FROM user WHERE user = ?)
  176. WHERE a_user.user_id = a_user.owner_user_id
  177. AND a_user.owner_user_id = (SELECT id FROM user WHERE user = ?)
  178. ORDER BY a_user.topic
  179. `
  180. selectUserReservationsCountQuery = `
  181. SELECT COUNT(*)
  182. FROM user_access
  183. WHERE user_id = owner_user_id AND owner_user_id = (SELECT id FROM user WHERE user = ?)
  184. `
  185. selectUserHasReservationQuery = `
  186. SELECT COUNT(*)
  187. FROM user_access
  188. WHERE user_id = owner_user_id
  189. AND owner_user_id = (SELECT id FROM user WHERE user = ?)
  190. AND topic = ?
  191. `
  192. selectOtherAccessCountQuery = `
  193. SELECT COUNT(*)
  194. FROM user_access
  195. WHERE (topic = ? OR ? LIKE topic)
  196. AND (owner_user_id IS NULL OR owner_user_id != (SELECT id FROM user WHERE user = ?))
  197. `
  198. deleteAllAccessQuery = `DELETE FROM user_access`
  199. deleteUserAccessQuery = `
  200. DELETE FROM user_access
  201. WHERE user_id = (SELECT id FROM user WHERE user = ?)
  202. OR owner_user_id = (SELECT id FROM user WHERE user = ?)
  203. `
  204. deleteTopicAccessQuery = `
  205. DELETE FROM user_access
  206. WHERE (user_id = (SELECT id FROM user WHERE user = ?) OR owner_user_id = (SELECT id FROM user WHERE user = ?))
  207. AND topic = ?
  208. `
  209. selectTokenCountQuery = `SELECT COUNT(*) FROM user_token WHERE user_id = ?`
  210. insertTokenQuery = `INSERT INTO user_token (user_id, token, expires) VALUES (?, ?, ?)`
  211. updateTokenExpiryQuery = `UPDATE user_token SET expires = ? WHERE user_id = (SELECT id FROM user WHERE user = ?) AND token = ?`
  212. deleteTokenQuery = `DELETE FROM user_token WHERE user_id = ? AND token = ?`
  213. deleteAllTokenQuery = `DELETE FROM user_token WHERE user_id = ?`
  214. deleteExpiredTokensQuery = `DELETE FROM user_token WHERE expires < ?`
  215. deleteExcessTokensQuery = `
  216. DELETE FROM user_token
  217. WHERE (user_id, token) NOT IN (
  218. SELECT user_id, token
  219. FROM user_token
  220. WHERE user_id = ?
  221. ORDER BY expires DESC
  222. LIMIT ?
  223. )
  224. `
  225. insertTierQuery = `
  226. INSERT INTO tier (id, code, name, messages_limit, messages_expiry_duration, emails_limit, reservations_limit, attachment_file_size_limit, attachment_total_size_limit, attachment_expiry_duration, stripe_price_id)
  227. VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
  228. `
  229. selectTiersQuery = `
  230. SELECT id, code, name, messages_limit, messages_expiry_duration, emails_limit, reservations_limit, attachment_file_size_limit, attachment_total_size_limit, attachment_expiry_duration, stripe_price_id
  231. FROM tier
  232. `
  233. selectTierByCodeQuery = `
  234. SELECT id, code, name, messages_limit, messages_expiry_duration, emails_limit, reservations_limit, attachment_file_size_limit, attachment_total_size_limit, attachment_expiry_duration, stripe_price_id
  235. FROM tier
  236. WHERE code = ?
  237. `
  238. selectTierByPriceIDQuery = `
  239. SELECT id, code, name, messages_limit, messages_expiry_duration, emails_limit, reservations_limit, attachment_file_size_limit, attachment_total_size_limit, attachment_expiry_duration, stripe_price_id
  240. FROM tier
  241. WHERE stripe_price_id = ?
  242. `
  243. updateUserTierQuery = `UPDATE user SET tier_id = (SELECT id FROM tier WHERE code = ?) WHERE user = ?`
  244. deleteUserTierQuery = `UPDATE user SET tier_id = null WHERE user = ?`
  245. updateBillingQuery = `
  246. UPDATE user
  247. SET stripe_customer_id = ?, stripe_subscription_id = ?, stripe_subscription_status = ?, stripe_subscription_paid_until = ?, stripe_subscription_cancel_at = ?
  248. WHERE user = ?
  249. `
  250. )
  // Schema management queries. The schema version lives in a single-row table (id = 1).
  const (
  	currentSchemaVersion     = 2
  	insertSchemaVersion      = `INSERT INTO schemaVersion VALUES (1, ?)`
  	updateSchemaVersion      = `UPDATE schemaVersion SET version = ? WHERE id = 1`
  	selectSchemaVersionQuery = `SELECT version FROM schemaVersion WHERE id = 1`
  )
  
  // Migration 1 -> 2 (complex migration!). As with createTablesQueriesNoTx, the "NoTx" statements
  // carry no BEGIN/COMMIT of their own; presumably the migration code runs them inside one
  // transaction (verify there).
  const (
  	// Move the v1 user table aside so the new user table can be created under the same name.
  	migrate1To2RenameUserTableQueryNoTx = `
  		ALTER TABLE user RENAME TO user_old;
  	`
  	migrate1To2SelectAllOldUsernamesNoTx = `SELECT user FROM user_old`
  	// Args: new user ID, sync topic, username (copied from user_old).
  	migrate1To2InsertUserNoTx = `
  		INSERT INTO user (id, user, pass, role, sync_topic, created)
  		SELECT ?, user, pass, role, ?, UNIXEPOCH() FROM user_old WHERE user = ?
  	`
  	// Re-key the old access table by user ID, then drop both old tables.
  	migrate1To2InsertFromOldTablesAndDropNoTx = `
  		INSERT INTO user_access (user_id, topic, read, write)
  		SELECT u.id, a.topic, a.read, a.write
  		FROM user u
  		JOIN access a ON u.user = a.user;
  		DROP TABLE access;
  		DROP TABLE user_old;
  	`
  	migrate1To2UpdateSyncTopicNoTx = `UPDATE user SET sync_topic = ? WHERE id = ?`
  )
  276. // Manager is an implementation of Manager. It stores users and access control list
  277. // in a SQLite database.
  278. type Manager struct {
  279. db *sql.DB
  280. defaultAccess Permission // Default permission if no ACL matches
  281. statsQueue map[string]*User // Username -> User, for "unimportant" user updates
  282. mu sync.Mutex
  283. }
  284. var _ Auther = (*Manager)(nil)
  285. // NewManager creates a new Manager instance
  286. func NewManager(filename, startupQueries string, defaultAccess Permission) (*Manager, error) {
  287. return newManager(filename, startupQueries, defaultAccess, userStatsQueueWriterInterval)
  288. }
  289. // NewManager creates a new Manager instance
  290. func newManager(filename, startupQueries string, defaultAccess Permission, statsWriterInterval time.Duration) (*Manager, error) {
  291. db, err := sql.Open("sqlite3", filename)
  292. if err != nil {
  293. return nil, err
  294. }
  295. if err := setupDB(db); err != nil {
  296. return nil, err
  297. }
  298. if err := runStartupQueries(db, startupQueries); err != nil {
  299. return nil, err
  300. }
  301. manager := &Manager{
  302. db: db,
  303. defaultAccess: defaultAccess,
  304. statsQueue: make(map[string]*User),
  305. }
  306. go manager.userStatsQueueWriter(statsWriterInterval)
  307. return manager, nil
  308. }
  309. // Authenticate checks username and password and returns a User if correct, and the user has not been
  310. // marked as deleted. The method returns in constant-ish time, regardless of whether the user exists or
  311. // the password is correct or incorrect.
  312. func (a *Manager) Authenticate(username, password string) (*User, error) {
  313. if username == Everyone {
  314. return nil, ErrUnauthenticated
  315. }
  316. user, err := a.User(username)
  317. if err != nil {
  318. log.Trace("authentication of user %s failed (1): %s", username, err.Error())
  319. bcrypt.CompareHashAndPassword([]byte(userAuthIntentionalSlowDownHash), []byte("intentional slow-down to avoid timing attacks"))
  320. return nil, ErrUnauthenticated
  321. } else if user.Deleted {
  322. log.Trace("authentication of user %s failed (2): user marked deleted", username)
  323. bcrypt.CompareHashAndPassword([]byte(userAuthIntentionalSlowDownHash), []byte("intentional slow-down to avoid timing attacks"))
  324. return nil, ErrUnauthenticated
  325. } else if err := bcrypt.CompareHashAndPassword([]byte(user.Hash), []byte(password)); err != nil {
  326. log.Trace("authentication of user %s failed (3): %s", username, err.Error())
  327. return nil, ErrUnauthenticated
  328. }
  329. return user, nil
  330. }
  331. // AuthenticateToken checks if the token exists and returns the associated User if it does.
  332. // The method sets the User.Token value to the token that was used for authentication.
  333. func (a *Manager) AuthenticateToken(token string) (*User, error) {
  334. if len(token) != tokenLength {
  335. return nil, ErrUnauthenticated
  336. }
  337. user, err := a.userByToken(token)
  338. if err != nil {
  339. return nil, ErrUnauthenticated
  340. }
  341. user.Token = token
  342. return user, nil
  343. }
  344. // CreateToken generates a random token for the given user and returns it. The token expires
  345. // after a fixed duration unless ExtendToken is called. This function also prunes tokens for the
  346. // given user, if there are too many of them.
  347. func (a *Manager) CreateToken(user *User) (*Token, error) {
  348. token, expires := util.RandomStringPrefix(tokenPrefix, tokenLength), time.Now().Add(tokenExpiryDuration)
  349. tx, err := a.db.Begin()
  350. if err != nil {
  351. return nil, err
  352. }
  353. defer tx.Rollback()
  354. if _, err := tx.Exec(insertTokenQuery, user.ID, token, expires.Unix()); err != nil {
  355. return nil, err
  356. }
  357. rows, err := tx.Query(selectTokenCountQuery, user.ID)
  358. if err != nil {
  359. return nil, err
  360. }
  361. defer rows.Close()
  362. if !rows.Next() {
  363. return nil, errNoRows
  364. }
  365. var tokenCount int
  366. if err := rows.Scan(&tokenCount); err != nil {
  367. return nil, err
  368. }
  369. if tokenCount >= tokenMaxCount {
  370. // This pruning logic is done in two queries for efficiency. The SELECT above is a lookup
  371. // on two indices, whereas the query below is a full table scan.
  372. if _, err := tx.Exec(deleteExcessTokensQuery, user.ID, tokenMaxCount); err != nil {
  373. return nil, err
  374. }
  375. }
  376. if err := tx.Commit(); err != nil {
  377. return nil, err
  378. }
  379. return &Token{
  380. Value: token,
  381. Expires: expires,
  382. }, nil
  383. }
  384. // ExtendToken sets the new expiry date for a token, thereby extending its use further into the future.
  385. func (a *Manager) ExtendToken(user *User) (*Token, error) {
  386. if user.Token == "" {
  387. return nil, errNoTokenProvided
  388. }
  389. newExpires := time.Now().Add(tokenExpiryDuration)
  390. if _, err := a.db.Exec(updateTokenExpiryQuery, newExpires.Unix(), user.Name, user.Token); err != nil {
  391. return nil, err
  392. }
  393. return &Token{
  394. Value: user.Token,
  395. Expires: newExpires,
  396. }, nil
  397. }
  398. // RemoveToken deletes the token defined in User.Token
  399. func (a *Manager) RemoveToken(user *User) error {
  400. if user.Token == "" {
  401. return ErrUnauthorized
  402. }
  403. if _, err := a.db.Exec(deleteTokenQuery, user.ID, user.Token); err != nil {
  404. return err
  405. }
  406. return nil
  407. }
  408. // RemoveExpiredTokens deletes all expired tokens from the database
  409. func (a *Manager) RemoveExpiredTokens() error {
  410. if _, err := a.db.Exec(deleteExpiredTokensQuery, time.Now().Unix()); err != nil {
  411. return err
  412. }
  413. return nil
  414. }
  415. // RemoveDeletedUsers deletes all users that have been marked deleted for
  416. func (a *Manager) RemoveDeletedUsers() error {
  417. if _, err := a.db.Exec(deleteUsersMarkedQuery, time.Now().Unix()); err != nil {
  418. return err
  419. }
  420. return nil
  421. }
  422. // ChangeSettings persists the user settings
  423. func (a *Manager) ChangeSettings(user *User) error {
  424. prefs, err := json.Marshal(user.Prefs)
  425. if err != nil {
  426. return err
  427. }
  428. if _, err := a.db.Exec(updateUserPrefsQuery, string(prefs), user.Name); err != nil {
  429. return err
  430. }
  431. return nil
  432. }
  433. // ResetStats resets all user stats in the user database. This touches all users.
  434. func (a *Manager) ResetStats() error {
  435. a.mu.Lock()
  436. defer a.mu.Unlock()
  437. if _, err := a.db.Exec(updateUserStatsResetAllQuery); err != nil {
  438. return err
  439. }
  440. a.statsQueue = make(map[string]*User)
  441. return nil
  442. }
  443. // EnqueueStats adds the user to a queue which writes out user stats (messages, emails, ..) in
  444. // batches at a regular interval
  445. func (a *Manager) EnqueueStats(user *User) {
  446. a.mu.Lock()
  447. defer a.mu.Unlock()
  448. a.statsQueue[user.ID] = user
  449. }
  450. func (a *Manager) userStatsQueueWriter(interval time.Duration) {
  451. ticker := time.NewTicker(interval)
  452. for range ticker.C {
  453. if err := a.writeUserStatsQueue(); err != nil {
  454. log.Warn("User Manager: Writing user stats queue failed: %s", err.Error())
  455. }
  456. }
  457. }
  458. func (a *Manager) writeUserStatsQueue() error {
  459. a.mu.Lock()
  460. if len(a.statsQueue) == 0 {
  461. a.mu.Unlock()
  462. log.Trace("User Manager: No user stats updates to commit")
  463. return nil
  464. }
  465. statsQueue := a.statsQueue
  466. a.statsQueue = make(map[string]*User)
  467. a.mu.Unlock()
  468. tx, err := a.db.Begin()
  469. if err != nil {
  470. return err
  471. }
  472. defer tx.Rollback()
  473. log.Debug("User Manager: Writing user stats queue for %d user(s)", len(statsQueue))
  474. for userID, u := range statsQueue {
  475. log.Trace("User Manager: Updating stats for user %s: messages=%d, emails=%d", userID, u.Stats.Messages, u.Stats.Emails)
  476. if _, err := tx.Exec(updateUserStatsQuery, u.Stats.Messages, u.Stats.Emails, userID); err != nil {
  477. return err
  478. }
  479. }
  480. return tx.Commit()
  481. }
  482. // Authorize returns nil if the given user has access to the given topic using the desired
  483. // permission. The user param may be nil to signal an anonymous user.
  484. func (a *Manager) Authorize(user *User, topic string, perm Permission) error {
  485. if user != nil && user.Role == RoleAdmin {
  486. return nil // Admin can do everything
  487. }
  488. username := Everyone
  489. if user != nil {
  490. username = user.Name
  491. }
  492. // Select the read/write permissions for this user/topic combo. The query may return two
  493. // rows (one for everyone, and one for the user), but prioritizes the user.
  494. rows, err := a.db.Query(selectTopicPermsQuery, Everyone, username, topic)
  495. if err != nil {
  496. return err
  497. }
  498. defer rows.Close()
  499. if !rows.Next() {
  500. return a.resolvePerms(a.defaultAccess, perm)
  501. }
  502. var read, write bool
  503. if err := rows.Scan(&read, &write); err != nil {
  504. return err
  505. } else if err := rows.Err(); err != nil {
  506. return err
  507. }
  508. return a.resolvePerms(NewPermission(read, write), perm)
  509. }
  510. func (a *Manager) resolvePerms(base, perm Permission) error {
  511. if perm == PermissionRead && base.IsRead() {
  512. return nil
  513. } else if perm == PermissionWrite && base.IsWrite() {
  514. return nil
  515. }
  516. return ErrUnauthorized
  517. }
  518. // AddUser adds a user with the given username, password and role
  519. func (a *Manager) AddUser(username, password string, role Role) error {
  520. if !AllowedUsername(username) || !AllowedRole(role) {
  521. return ErrInvalidArgument
  522. }
  523. hash, err := bcrypt.GenerateFromPassword([]byte(password), userPasswordBcryptCost)
  524. if err != nil {
  525. return err
  526. }
  527. userID := util.RandomStringPrefix(userIDPrefix, userIDLength)
  528. syncTopic, now := util.RandomStringPrefix(syncTopicPrefix, syncTopicLength), time.Now().Unix()
  529. if _, err = a.db.Exec(insertUserQuery, userID, username, hash, role, syncTopic, now); err != nil {
  530. return err
  531. }
  532. return nil
  533. }
  534. // RemoveUser deletes the user with the given username. The function returns nil on success, even
  535. // if the user did not exist in the first place.
  536. func (a *Manager) RemoveUser(username string) error {
  537. if !AllowedUsername(username) {
  538. return ErrInvalidArgument
  539. }
  540. // Rows in user_access, user_token, etc. are deleted via foreign keys
  541. if _, err := a.db.Exec(deleteUserQuery, username); err != nil {
  542. return err
  543. }
  544. return nil
  545. }
  546. // MarkUserRemoved sets the deleted flag on the user, and deletes all access tokens. This prevents
  547. // successful auth via Authenticate. A background process will delete the user at a later date.
  548. func (a *Manager) MarkUserRemoved(user *User) error {
  549. if !AllowedUsername(user.Name) {
  550. return ErrInvalidArgument
  551. }
  552. tx, err := a.db.Begin()
  553. if err != nil {
  554. return err
  555. }
  556. defer tx.Rollback()
  557. if _, err := a.db.Exec(deleteUserAccessQuery, user.Name, user.Name); err != nil {
  558. return err
  559. }
  560. if _, err := tx.Exec(deleteAllTokenQuery, user.ID); err != nil {
  561. return err
  562. }
  563. if _, err := tx.Exec(updateUserDeletedQuery, time.Now().Add(userHardDeleteAfterDuration).Unix(), user.ID); err != nil {
  564. return err
  565. }
  566. return tx.Commit()
  567. }
  568. // Users returns a list of users. It always also returns the Everyone user ("*").
  569. func (a *Manager) Users() ([]*User, error) {
  570. rows, err := a.db.Query(selectUsernamesQuery)
  571. if err != nil {
  572. return nil, err
  573. }
  574. defer rows.Close()
  575. usernames := make([]string, 0)
  576. for rows.Next() {
  577. var username string
  578. if err := rows.Scan(&username); err != nil {
  579. return nil, err
  580. } else if err := rows.Err(); err != nil {
  581. return nil, err
  582. }
  583. usernames = append(usernames, username)
  584. }
  585. rows.Close()
  586. users := make([]*User, 0)
  587. for _, username := range usernames {
  588. user, err := a.User(username)
  589. if err != nil {
  590. return nil, err
  591. }
  592. users = append(users, user)
  593. }
  594. return users, nil
  595. }
  596. // User returns the user with the given username if it exists, or ErrUserNotFound otherwise.
  597. // You may also pass Everyone to retrieve the anonymous user and its Grant list.
  598. func (a *Manager) User(username string) (*User, error) {
  599. rows, err := a.db.Query(selectUserByNameQuery, username)
  600. if err != nil {
  601. return nil, err
  602. }
  603. return a.readUser(rows)
  604. }
  605. // UserByID returns the user with the given ID if it exists, or ErrUserNotFound otherwise
  606. func (a *Manager) UserByID(id string) (*User, error) {
  607. rows, err := a.db.Query(selectUserByIDQuery, id)
  608. if err != nil {
  609. return nil, err
  610. }
  611. return a.readUser(rows)
  612. }
  613. // UserByStripeCustomer returns the user with the given Stripe customer ID if it exists, or ErrUserNotFound otherwise.
  614. func (a *Manager) UserByStripeCustomer(stripeCustomerID string) (*User, error) {
  615. rows, err := a.db.Query(selectUserByStripeCustomerIDQuery, stripeCustomerID)
  616. if err != nil {
  617. return nil, err
  618. }
  619. return a.readUser(rows)
  620. }
  621. func (a *Manager) userByToken(token string) (*User, error) {
  622. rows, err := a.db.Query(selectUserByTokenQuery, token, time.Now().Unix())
  623. if err != nil {
  624. return nil, err
  625. }
  626. return a.readUser(rows)
  627. }
  628. func (a *Manager) readUser(rows *sql.Rows) (*User, error) {
  629. defer rows.Close()
  630. var id, username, hash, role, prefs, syncTopic string
  631. var stripeCustomerID, stripeSubscriptionID, stripeSubscriptionStatus, stripePriceID, tierCode, tierName sql.NullString
  632. var messages, emails int64
  633. var messagesLimit, messagesExpiryDuration, emailsLimit, reservationsLimit, attachmentFileSizeLimit, attachmentTotalSizeLimit, attachmentExpiryDuration, stripeSubscriptionPaidUntil, stripeSubscriptionCancelAt, deleted sql.NullInt64
  634. if !rows.Next() {
  635. return nil, ErrUserNotFound
  636. }
  637. if err := rows.Scan(&id, &username, &hash, &role, &prefs, &syncTopic, &messages, &emails, &stripeCustomerID, &stripeSubscriptionID, &stripeSubscriptionStatus, &stripeSubscriptionPaidUntil, &stripeSubscriptionCancelAt, &deleted, &tierCode, &tierName, &messagesLimit, &messagesExpiryDuration, &emailsLimit, &reservationsLimit, &attachmentFileSizeLimit, &attachmentTotalSizeLimit, &attachmentExpiryDuration, &stripePriceID); err != nil {
  638. return nil, err
  639. } else if err := rows.Err(); err != nil {
  640. return nil, err
  641. }
  642. user := &User{
  643. ID: id,
  644. Name: username,
  645. Hash: hash,
  646. Role: Role(role),
  647. Prefs: &Prefs{},
  648. SyncTopic: syncTopic,
  649. Stats: &Stats{
  650. Messages: messages,
  651. Emails: emails,
  652. },
  653. Billing: &Billing{
  654. StripeCustomerID: stripeCustomerID.String, // May be empty
  655. StripeSubscriptionID: stripeSubscriptionID.String, // May be empty
  656. StripeSubscriptionStatus: stripe.SubscriptionStatus(stripeSubscriptionStatus.String), // May be empty
  657. StripeSubscriptionPaidUntil: time.Unix(stripeSubscriptionPaidUntil.Int64, 0), // May be zero
  658. StripeSubscriptionCancelAt: time.Unix(stripeSubscriptionCancelAt.Int64, 0), // May be zero
  659. },
  660. Deleted: deleted.Valid,
  661. }
  662. if err := json.Unmarshal([]byte(prefs), user.Prefs); err != nil {
  663. return nil, err
  664. }
  665. if tierCode.Valid {
  666. // See readTier() when this is changed!
  667. user.Tier = &Tier{
  668. Code: tierCode.String,
  669. Name: tierName.String,
  670. MessagesLimit: messagesLimit.Int64,
  671. MessagesExpiryDuration: time.Duration(messagesExpiryDuration.Int64) * time.Second,
  672. EmailsLimit: emailsLimit.Int64,
  673. ReservationsLimit: reservationsLimit.Int64,
  674. AttachmentFileSizeLimit: attachmentFileSizeLimit.Int64,
  675. AttachmentTotalSizeLimit: attachmentTotalSizeLimit.Int64,
  676. AttachmentExpiryDuration: time.Duration(attachmentExpiryDuration.Int64) * time.Second,
  677. StripePriceID: stripePriceID.String, // May be empty
  678. }
  679. }
  680. return user, nil
  681. }
  682. // Grants returns all user-specific access control entries
  683. func (a *Manager) Grants(username string) ([]Grant, error) {
  684. rows, err := a.db.Query(selectUserAccessQuery, username)
  685. if err != nil {
  686. return nil, err
  687. }
  688. defer rows.Close()
  689. grants := make([]Grant, 0)
  690. for rows.Next() {
  691. var topic string
  692. var read, write bool
  693. if err := rows.Scan(&topic, &read, &write); err != nil {
  694. return nil, err
  695. } else if err := rows.Err(); err != nil {
  696. return nil, err
  697. }
  698. grants = append(grants, Grant{
  699. TopicPattern: fromSQLWildcard(topic),
  700. Allow: NewPermission(read, write),
  701. })
  702. }
  703. return grants, nil
  704. }
  705. // Reservations returns all user-owned topics, and the associated everyone-access
  706. func (a *Manager) Reservations(username string) ([]Reservation, error) {
  707. rows, err := a.db.Query(selectUserReservationsQuery, Everyone, username)
  708. if err != nil {
  709. return nil, err
  710. }
  711. defer rows.Close()
  712. reservations := make([]Reservation, 0)
  713. for rows.Next() {
  714. var topic string
  715. var ownerRead, ownerWrite bool
  716. var everyoneRead, everyoneWrite sql.NullBool
  717. if err := rows.Scan(&topic, &ownerRead, &ownerWrite, &everyoneRead, &everyoneWrite); err != nil {
  718. return nil, err
  719. } else if err := rows.Err(); err != nil {
  720. return nil, err
  721. }
  722. reservations = append(reservations, Reservation{
  723. Topic: topic,
  724. Owner: NewPermission(ownerRead, ownerWrite),
  725. Everyone: NewPermission(everyoneRead.Bool, everyoneWrite.Bool), // false if null
  726. })
  727. }
  728. return reservations, nil
  729. }
  730. // HasReservation returns true if the given topic access is owned by the user
  731. func (a *Manager) HasReservation(username, topic string) (bool, error) {
  732. rows, err := a.db.Query(selectUserHasReservationQuery, username, topic)
  733. if err != nil {
  734. return false, err
  735. }
  736. defer rows.Close()
  737. if !rows.Next() {
  738. return false, errNoRows
  739. }
  740. var count int64
  741. if err := rows.Scan(&count); err != nil {
  742. return false, err
  743. }
  744. return count > 0, nil
  745. }
  746. // ReservationsCount returns the number of reservations owned by this user
  747. func (a *Manager) ReservationsCount(username string) (int64, error) {
  748. rows, err := a.db.Query(selectUserReservationsCountQuery, username)
  749. if err != nil {
  750. return 0, err
  751. }
  752. defer rows.Close()
  753. if !rows.Next() {
  754. return 0, errNoRows
  755. }
  756. var count int64
  757. if err := rows.Scan(&count); err != nil {
  758. return 0, err
  759. }
  760. return count, nil
  761. }
  762. // ChangePassword changes a user's password
  763. func (a *Manager) ChangePassword(username, password string) error {
  764. hash, err := bcrypt.GenerateFromPassword([]byte(password), userPasswordBcryptCost)
  765. if err != nil {
  766. return err
  767. }
  768. if _, err := a.db.Exec(updateUserPassQuery, hash, username); err != nil {
  769. return err
  770. }
  771. return nil
  772. }
  773. // ChangeRole changes a user's role. When a role is changed from RoleUser to RoleAdmin,
  774. // all existing access control entries (Grant) are removed, since they are no longer needed.
  775. func (a *Manager) ChangeRole(username string, role Role) error {
  776. if !AllowedUsername(username) || !AllowedRole(role) {
  777. return ErrInvalidArgument
  778. }
  779. if _, err := a.db.Exec(updateUserRoleQuery, string(role), username); err != nil {
  780. return err
  781. }
  782. if role == RoleAdmin {
  783. if _, err := a.db.Exec(deleteUserAccessQuery, username, username); err != nil {
  784. return err
  785. }
  786. }
  787. return nil
  788. }
  789. // ChangeTier changes a user's tier using the tier code. This function does not delete reservations, messages,
  790. // or attachments, even if the new tier has lower limits in this regard. That has to be done elsewhere.
  791. func (a *Manager) ChangeTier(username, tier string) error {
  792. if !AllowedUsername(username) {
  793. return ErrInvalidArgument
  794. }
  795. t, err := a.Tier(tier)
  796. if err != nil {
  797. return err
  798. } else if err := a.checkReservationsLimit(username, t.ReservationsLimit); err != nil {
  799. return err
  800. }
  801. if _, err := a.db.Exec(updateUserTierQuery, tier, username); err != nil {
  802. return err
  803. }
  804. return nil
  805. }
  806. // ResetTier removes the tier from the given user
  807. func (a *Manager) ResetTier(username string) error {
  808. if !AllowedUsername(username) && username != Everyone && username != "" {
  809. return ErrInvalidArgument
  810. } else if err := a.checkReservationsLimit(username, 0); err != nil {
  811. return err
  812. }
  813. _, err := a.db.Exec(deleteUserTierQuery, username)
  814. return err
  815. }
  816. func (a *Manager) checkReservationsLimit(username string, reservationsLimit int64) error {
  817. u, err := a.User(username)
  818. if err != nil {
  819. return err
  820. }
  821. if u.Tier != nil && reservationsLimit < u.Tier.ReservationsLimit {
  822. reservations, err := a.Reservations(username)
  823. if err != nil {
  824. return err
  825. } else if int64(len(reservations)) > reservationsLimit {
  826. return ErrTooManyReservations
  827. }
  828. }
  829. return nil
  830. }
  831. // CheckAllowAccess tests if a user may create an access control entry for the given topic.
  832. // If there are any ACL entries that are not owned by the user, an error is returned.
  833. // FIXME is this the same as HasReservation?
  834. func (a *Manager) CheckAllowAccess(username string, topic string) error {
  835. if (!AllowedUsername(username) && username != Everyone) || !AllowedTopic(topic) {
  836. return ErrInvalidArgument
  837. }
  838. rows, err := a.db.Query(selectOtherAccessCountQuery, topic, topic, username)
  839. if err != nil {
  840. return err
  841. }
  842. defer rows.Close()
  843. if !rows.Next() {
  844. return errNoRows
  845. }
  846. var otherCount int
  847. if err := rows.Scan(&otherCount); err != nil {
  848. return err
  849. }
  850. if otherCount > 0 {
  851. return errTopicOwnedByOthers
  852. }
  853. return nil
  854. }
  855. // AllowAccess adds or updates an entry in th access control list for a specific user. It controls
  856. // read/write access to a topic. The parameter topicPattern may include wildcards (*). The ACL entry
  857. // owner may either be a user (username), or the system (empty).
  858. func (a *Manager) AllowAccess(username string, topicPattern string, permission Permission) error {
  859. if !AllowedUsername(username) && username != Everyone {
  860. return ErrInvalidArgument
  861. } else if !AllowedTopicPattern(topicPattern) {
  862. return ErrInvalidArgument
  863. }
  864. owner := ""
  865. if _, err := a.db.Exec(upsertUserAccessQuery, username, toSQLWildcard(topicPattern), permission.IsRead(), permission.IsWrite(), owner, owner); err != nil {
  866. return err
  867. }
  868. return nil
  869. }
  870. // ResetAccess removes an access control list entry for a specific username/topic, or (if topic is
  871. // empty) for an entire user. The parameter topicPattern may include wildcards (*).
  872. func (a *Manager) ResetAccess(username string, topicPattern string) error {
  873. if !AllowedUsername(username) && username != Everyone && username != "" {
  874. return ErrInvalidArgument
  875. } else if !AllowedTopicPattern(topicPattern) && topicPattern != "" {
  876. return ErrInvalidArgument
  877. }
  878. if username == "" && topicPattern == "" {
  879. _, err := a.db.Exec(deleteAllAccessQuery, username)
  880. return err
  881. } else if topicPattern == "" {
  882. _, err := a.db.Exec(deleteUserAccessQuery, username, username)
  883. return err
  884. }
  885. _, err := a.db.Exec(deleteTopicAccessQuery, username, username, toSQLWildcard(topicPattern))
  886. return err
  887. }
  888. // AddReservation creates two access control entries for the given topic: one with full read/write access for the
  889. // given user, and one for Everyone with the permission passed as everyone. The user also owns the entries, and
  890. // can modify or delete them.
  891. func (a *Manager) AddReservation(username string, topic string, everyone Permission) error {
  892. if !AllowedUsername(username) || username == Everyone || !AllowedTopic(topic) {
  893. return ErrInvalidArgument
  894. }
  895. tx, err := a.db.Begin()
  896. if err != nil {
  897. return err
  898. }
  899. defer tx.Rollback()
  900. if _, err := tx.Exec(upsertUserAccessQuery, username, topic, true, true, username, username); err != nil {
  901. return err
  902. }
  903. if _, err := tx.Exec(upsertUserAccessQuery, Everyone, topic, everyone.IsRead(), everyone.IsWrite(), username, username); err != nil {
  904. return err
  905. }
  906. return tx.Commit()
  907. }
  908. // RemoveReservations deletes the access control entries associated with the given username/topic, as
  909. // well as all entries with Everyone/topic. This is the counterpart for AddReservation.
  910. func (a *Manager) RemoveReservations(username string, topics ...string) error {
  911. if !AllowedUsername(username) || username == Everyone || len(topics) == 0 {
  912. return ErrInvalidArgument
  913. }
  914. for _, topic := range topics {
  915. if !AllowedTopic(topic) {
  916. return ErrInvalidArgument
  917. }
  918. }
  919. tx, err := a.db.Begin()
  920. if err != nil {
  921. return err
  922. }
  923. defer tx.Rollback()
  924. for _, topic := range topics {
  925. if _, err := tx.Exec(deleteTopicAccessQuery, username, username, topic); err != nil {
  926. return err
  927. }
  928. if _, err := tx.Exec(deleteTopicAccessQuery, Everyone, Everyone, topic); err != nil {
  929. return err
  930. }
  931. }
  932. return tx.Commit()
  933. }
  934. // DefaultAccess returns the default read/write access if no access control entry matches
  935. func (a *Manager) DefaultAccess() Permission {
  936. return a.defaultAccess
  937. }
  938. // CreateTier creates a new tier in the database
  939. func (a *Manager) CreateTier(tier *Tier) error {
  940. tierID := util.RandomStringPrefix(tierIDPrefix, tierIDLength)
  941. if _, err := a.db.Exec(insertTierQuery, tierID, tier.Code, tier.Name, tier.MessagesLimit, int64(tier.MessagesExpiryDuration.Seconds()), tier.EmailsLimit, tier.ReservationsLimit, tier.AttachmentFileSizeLimit, tier.AttachmentTotalSizeLimit, int64(tier.AttachmentExpiryDuration.Seconds()), tier.StripePriceID); err != nil {
  942. return err
  943. }
  944. return nil
  945. }
  946. // ChangeBilling updates a user's billing fields, namely the Stripe customer ID, and subscription information
  947. func (a *Manager) ChangeBilling(username string, billing *Billing) error {
  948. if _, err := a.db.Exec(updateBillingQuery, nullString(billing.StripeCustomerID), nullString(billing.StripeSubscriptionID), nullString(string(billing.StripeSubscriptionStatus)), nullInt64(billing.StripeSubscriptionPaidUntil.Unix()), nullInt64(billing.StripeSubscriptionCancelAt.Unix()), username); err != nil {
  949. return err
  950. }
  951. return nil
  952. }
  953. // Tiers returns a list of all Tier structs
  954. func (a *Manager) Tiers() ([]*Tier, error) {
  955. rows, err := a.db.Query(selectTiersQuery)
  956. if err != nil {
  957. return nil, err
  958. }
  959. defer rows.Close()
  960. tiers := make([]*Tier, 0)
  961. for {
  962. tier, err := a.readTier(rows)
  963. if err == ErrTierNotFound {
  964. break
  965. } else if err != nil {
  966. return nil, err
  967. }
  968. tiers = append(tiers, tier)
  969. }
  970. return tiers, nil
  971. }
  972. // Tier returns a Tier based on the code, or ErrTierNotFound if it does not exist
  973. func (a *Manager) Tier(code string) (*Tier, error) {
  974. rows, err := a.db.Query(selectTierByCodeQuery, code)
  975. if err != nil {
  976. return nil, err
  977. }
  978. defer rows.Close()
  979. return a.readTier(rows)
  980. }
  981. // TierByStripePrice returns a Tier based on the Stripe price ID, or ErrTierNotFound if it does not exist
  982. func (a *Manager) TierByStripePrice(priceID string) (*Tier, error) {
  983. rows, err := a.db.Query(selectTierByPriceIDQuery, priceID)
  984. if err != nil {
  985. return nil, err
  986. }
  987. defer rows.Close()
  988. return a.readTier(rows)
  989. }
  990. func (a *Manager) readTier(rows *sql.Rows) (*Tier, error) {
  991. var id, code, name string
  992. var stripePriceID sql.NullString
  993. var messagesLimit, messagesExpiryDuration, emailsLimit, reservationsLimit, attachmentFileSizeLimit, attachmentTotalSizeLimit, attachmentExpiryDuration sql.NullInt64
  994. if !rows.Next() {
  995. return nil, ErrTierNotFound
  996. }
  997. if err := rows.Scan(&id, &code, &name, &messagesLimit, &messagesExpiryDuration, &emailsLimit, &reservationsLimit, &attachmentFileSizeLimit, &attachmentTotalSizeLimit, &attachmentExpiryDuration, &stripePriceID); err != nil {
  998. return nil, err
  999. } else if err := rows.Err(); err != nil {
  1000. return nil, err
  1001. }
  1002. // When changed, note readUser() as well
  1003. return &Tier{
  1004. ID: id,
  1005. Code: code,
  1006. Name: name,
  1007. MessagesLimit: messagesLimit.Int64,
  1008. MessagesExpiryDuration: time.Duration(messagesExpiryDuration.Int64) * time.Second,
  1009. EmailsLimit: emailsLimit.Int64,
  1010. ReservationsLimit: reservationsLimit.Int64,
  1011. AttachmentFileSizeLimit: attachmentFileSizeLimit.Int64,
  1012. AttachmentTotalSizeLimit: attachmentTotalSizeLimit.Int64,
  1013. AttachmentExpiryDuration: time.Duration(attachmentExpiryDuration.Int64) * time.Second,
  1014. StripePriceID: stripePriceID.String, // May be empty
  1015. }, nil
  1016. }
  // toSQLWildcard converts a topic pattern's * wildcards into SQL LIKE wildcards (%).
  func toSQLWildcard(s string) string {
  	return strings.Replace(s, "*", "%", -1)
  }
  // fromSQLWildcard converts SQL LIKE wildcards (%) back into the * wildcard form.
  func fromSQLWildcard(s string) string {
  	return strings.Replace(s, "%", "*", -1)
  }
  1023. func runStartupQueries(db *sql.DB, startupQueries string) error {
  1024. if _, err := db.Exec(startupQueries); err != nil {
  1025. return err
  1026. }
  1027. if _, err := db.Exec(builtinStartupQueries); err != nil {
  1028. return err
  1029. }
  1030. return nil
  1031. }
  1032. func setupDB(db *sql.DB) error {
  1033. // If 'schemaVersion' table does not exist, this must be a new database
  1034. rowsSV, err := db.Query(selectSchemaVersionQuery)
  1035. if err != nil {
  1036. return setupNewDB(db)
  1037. }
  1038. defer rowsSV.Close()
  1039. // If 'schemaVersion' table exists, read version and potentially upgrade
  1040. schemaVersion := 0
  1041. if !rowsSV.Next() {
  1042. return errors.New("cannot determine schema version: database file may be corrupt")
  1043. }
  1044. if err := rowsSV.Scan(&schemaVersion); err != nil {
  1045. return err
  1046. }
  1047. rowsSV.Close()
  1048. // Do migrations
  1049. if schemaVersion == currentSchemaVersion {
  1050. return nil
  1051. } else if schemaVersion == 1 {
  1052. return migrateFrom1(db)
  1053. }
  1054. return fmt.Errorf("unexpected schema version found: %d", schemaVersion)
  1055. }
  1056. func setupNewDB(db *sql.DB) error {
  1057. if _, err := db.Exec(createTablesQueries); err != nil {
  1058. return err
  1059. }
  1060. if _, err := db.Exec(insertSchemaVersion, currentSchemaVersion); err != nil {
  1061. return err
  1062. }
  1063. return nil
  1064. }
  1065. func migrateFrom1(db *sql.DB) error {
  1066. log.Info("Migrating user database schema: from 1 to 2")
  1067. tx, err := db.Begin()
  1068. if err != nil {
  1069. return err
  1070. }
  1071. defer tx.Rollback()
  1072. // Rename user -> user_old, and create new tables
  1073. if _, err := tx.Exec(migrate1To2RenameUserTableQueryNoTx); err != nil {
  1074. return err
  1075. }
  1076. if _, err := tx.Exec(createTablesQueriesNoTx); err != nil {
  1077. return err
  1078. }
  1079. // Insert users from user_old into new user table, with ID and sync_topic
  1080. rows, err := tx.Query(migrate1To2SelectAllOldUsernamesNoTx)
  1081. if err != nil {
  1082. return err
  1083. }
  1084. defer rows.Close()
  1085. usernames := make([]string, 0)
  1086. for rows.Next() {
  1087. var username string
  1088. if err := rows.Scan(&username); err != nil {
  1089. return err
  1090. }
  1091. usernames = append(usernames, username)
  1092. }
  1093. if err := rows.Close(); err != nil {
  1094. return err
  1095. }
  1096. for _, username := range usernames {
  1097. userID := util.RandomStringPrefix(userIDPrefix, userIDLength)
  1098. syncTopic := util.RandomStringPrefix(syncTopicPrefix, syncTopicLength)
  1099. if _, err := tx.Exec(migrate1To2InsertUserNoTx, userID, syncTopic, username); err != nil {
  1100. return err
  1101. }
  1102. }
  1103. // Migrate old "access" table to "user_access" and drop "access" and "user_old"
  1104. if _, err := tx.Exec(migrate1To2InsertFromOldTablesAndDropNoTx); err != nil {
  1105. return err
  1106. }
  1107. if _, err := tx.Exec(updateSchemaVersion, 2); err != nil {
  1108. return err
  1109. }
  1110. if err := tx.Commit(); err != nil {
  1111. return err
  1112. }
  1113. return nil // Update this when a new version is added
  1114. }
  // nullString wraps s in a sql.NullString, mapping the empty string to NULL.
  func nullString(s string) sql.NullString {
  	return sql.NullString{String: s, Valid: s != ""}
  }
  // nullInt64 wraps v in a sql.NullInt64, mapping 0 to NULL.
  func nullInt64(v int64) sql.NullInt64 {
  	return sql.NullInt64{Int64: v, Valid: v != 0}
  }