manager.go 43 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035103610371038103910401041104210431044104510461047104810491050105110521053105410551056105710581059106010611062106310641065106610671068106910701071107210731074107510761077107810791080108110821083108410851086108710881089109010911092109310941095109610971098109911001101110211031104110511061107110811091110111111121113111411151116111711181119112011211122112311241125112611271128112911301131113211331134113511361137113811391140114111421143114411451146114711481149115011511152115311541155115611571158115911601161116211631164116511661167116811691170117111721173117411751176117711781179118011811182118311841185118611871188118911901191119211931194119511961197119811991200120112021203120412051206120712081209121012111212121312141215121612171218121912201221122212231224122512261227122812291230123112321233123412351236123712381239124012411242124312441245124612471248124912501251125212531254125512561257125812591260126112621263
  1. package user
  2. import (
  3. "database/sql"
  4. "encoding/json"
  5. "errors"
  6. "fmt"
  7. _ "github.com/mattn/go-sqlite3" // SQLite driver
  8. "github.com/stripe/stripe-go/v74"
  9. "golang.org/x/crypto/bcrypt"
  10. "heckel.io/ntfy/log"
  11. "heckel.io/ntfy/util"
  12. "strings"
  13. "sync"
  14. "time"
  15. )
  16. const (
  17. tierIDPrefix = "ti_"
  18. tierIDLength = 8
  19. syncTopicPrefix = "st_"
  20. syncTopicLength = 16
  21. userIDPrefix = "u_"
  22. userIDLength = 12
  23. userPasswordBcryptCost = 10
  24. userAuthIntentionalSlowDownHash = "$2a$10$YFCQvqQDwIIwnJM1xkAYOeih0dg17UVGanaTStnrSzC8NCWxcLDwy" // Cost should match userPasswordBcryptCost
  25. userStatsQueueWriterInterval = 33 * time.Second
  26. userHardDeleteAfterDuration = 7 * 24 * time.Hour
  27. tokenPrefix = "tk_"
  28. tokenLength = 32
  29. tokenMaxCount = 10 // Only keep this many tokens in the table per user
  30. )
  31. var (
  32. errNoTokenProvided = errors.New("no token provided")
  33. errTopicOwnedByOthers = errors.New("topic owned by others")
  34. errNoRows = errors.New("no rows found")
  35. )
  36. // Manager-related queries
  37. const (
  38. createTablesQueriesNoTx = `
  39. CREATE TABLE IF NOT EXISTS tier (
  40. id TEXT PRIMARY KEY,
  41. code TEXT NOT NULL,
  42. name TEXT NOT NULL,
  43. messages_limit INT NOT NULL,
  44. messages_expiry_duration INT NOT NULL,
  45. emails_limit INT NOT NULL,
  46. reservations_limit INT NOT NULL,
  47. attachment_file_size_limit INT NOT NULL,
  48. attachment_total_size_limit INT NOT NULL,
  49. attachment_expiry_duration INT NOT NULL,
  50. attachment_bandwidth_limit INT NOT NULL,
  51. stripe_price_id TEXT
  52. );
  53. CREATE UNIQUE INDEX idx_tier_code ON tier (code);
  54. CREATE UNIQUE INDEX idx_tier_price_id ON tier (stripe_price_id);
  55. CREATE TABLE IF NOT EXISTS user (
  56. id TEXT PRIMARY KEY,
  57. tier_id TEXT,
  58. user TEXT NOT NULL,
  59. pass TEXT NOT NULL,
  60. role TEXT CHECK (role IN ('anonymous', 'admin', 'user')) NOT NULL,
  61. prefs JSON NOT NULL DEFAULT '{}',
  62. sync_topic TEXT NOT NULL,
  63. stats_messages INT NOT NULL DEFAULT (0),
  64. stats_emails INT NOT NULL DEFAULT (0),
  65. stripe_customer_id TEXT,
  66. stripe_subscription_id TEXT,
  67. stripe_subscription_status TEXT,
  68. stripe_subscription_paid_until INT,
  69. stripe_subscription_cancel_at INT,
  70. created INT NOT NULL,
  71. deleted INT,
  72. FOREIGN KEY (tier_id) REFERENCES tier (id)
  73. );
  74. CREATE UNIQUE INDEX idx_user ON user (user);
  75. CREATE UNIQUE INDEX idx_user_stripe_customer_id ON user (stripe_customer_id);
  76. CREATE UNIQUE INDEX idx_user_stripe_subscription_id ON user (stripe_subscription_id);
  77. CREATE TABLE IF NOT EXISTS user_access (
  78. user_id TEXT NOT NULL,
  79. topic TEXT NOT NULL,
  80. read INT NOT NULL,
  81. write INT NOT NULL,
  82. owner_user_id INT,
  83. PRIMARY KEY (user_id, topic),
  84. FOREIGN KEY (user_id) REFERENCES user (id) ON DELETE CASCADE,
  85. FOREIGN KEY (owner_user_id) REFERENCES user (id) ON DELETE CASCADE
  86. );
  87. CREATE TABLE IF NOT EXISTS user_token (
  88. user_id TEXT NOT NULL,
  89. token TEXT NOT NULL,
  90. label TEXT NOT NULL,
  91. expires INT NOT NULL,
  92. PRIMARY KEY (user_id, token),
  93. FOREIGN KEY (user_id) REFERENCES user (id) ON DELETE CASCADE
  94. );
  95. CREATE TABLE IF NOT EXISTS schemaVersion (
  96. id INT PRIMARY KEY,
  97. version INT NOT NULL
  98. );
  99. INSERT INTO user (id, user, pass, role, sync_topic, created)
  100. VALUES ('` + everyoneID + `', '*', '', 'anonymous', '', UNIXEPOCH())
  101. ON CONFLICT (id) DO NOTHING;
  102. `
  103. createTablesQueries = `BEGIN; ` + createTablesQueriesNoTx + ` COMMIT;`
  104. builtinStartupQueries = `
  105. PRAGMA foreign_keys = ON;
  106. `
  107. selectUserByIDQuery = `
  108. SELECT u.id, u.user, u.pass, u.role, u.prefs, u.sync_topic, u.stats_messages, u.stats_emails, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, deleted, t.id, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.attachment_bandwidth_limit, t.stripe_price_id
  109. FROM user u
  110. LEFT JOIN tier t on t.id = u.tier_id
  111. WHERE u.id = ?
  112. `
  113. selectUserByNameQuery = `
  114. SELECT u.id, u.user, u.pass, u.role, u.prefs, u.sync_topic, u.stats_messages, u.stats_emails, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, deleted, t.id, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.attachment_bandwidth_limit, t.stripe_price_id
  115. FROM user u
  116. LEFT JOIN tier t on t.id = u.tier_id
  117. WHERE user = ?
  118. `
  119. selectUserByTokenQuery = `
  120. SELECT u.id, u.user, u.pass, u.role, u.prefs, u.sync_topic, u.stats_messages, u.stats_emails, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, deleted, t.id, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.attachment_bandwidth_limit, t.stripe_price_id
  121. FROM user u
  122. JOIN user_token t on u.id = t.user_id
  123. LEFT JOIN tier t on t.id = u.tier_id
  124. WHERE t.token = ? AND (t.expires = 0 OR t.expires >= ?)
  125. `
  126. selectUserByStripeCustomerIDQuery = `
  127. SELECT u.id, u.user, u.pass, u.role, u.prefs, u.sync_topic, u.stats_messages, u.stats_emails, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, deleted, t.id, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.attachment_bandwidth_limit, t.stripe_price_id
  128. FROM user u
  129. LEFT JOIN tier t on t.id = u.tier_id
  130. WHERE u.stripe_customer_id = ?
  131. `
  132. selectTopicPermsQuery = `
  133. SELECT read, write
  134. FROM user_access a
  135. JOIN user u ON u.id = a.user_id
  136. WHERE (u.user = ? OR u.user = ?) AND ? LIKE a.topic
  137. ORDER BY u.user DESC
  138. `
  139. insertUserQuery = `
  140. INSERT INTO user (id, user, pass, role, sync_topic, created)
  141. VALUES (?, ?, ?, ?, ?, ?)
  142. `
  143. selectUsernamesQuery = `
  144. SELECT user
  145. FROM user
  146. ORDER BY
  147. CASE role
  148. WHEN 'admin' THEN 1
  149. WHEN 'anonymous' THEN 3
  150. ELSE 2
  151. END, user
  152. `
  153. updateUserPassQuery = `UPDATE user SET pass = ? WHERE user = ?`
  154. updateUserRoleQuery = `UPDATE user SET role = ? WHERE user = ?`
  155. updateUserPrefsQuery = `UPDATE user SET prefs = ? WHERE user = ?`
  156. updateUserStatsQuery = `UPDATE user SET stats_messages = ?, stats_emails = ? WHERE id = ?`
  157. updateUserStatsResetAllQuery = `UPDATE user SET stats_messages = 0, stats_emails = 0`
  158. updateUserDeletedQuery = `UPDATE user SET deleted = ? WHERE id = ?`
  159. deleteUsersMarkedQuery = `DELETE FROM user WHERE deleted < ?`
  160. deleteUserQuery = `DELETE FROM user WHERE user = ?`
  161. upsertUserAccessQuery = `
  162. INSERT INTO user_access (user_id, topic, read, write, owner_user_id)
  163. VALUES ((SELECT id FROM user WHERE user = ?), ?, ?, ?, (SELECT IIF(?='',NULL,(SELECT id FROM user WHERE user=?))))
  164. ON CONFLICT (user_id, topic)
  165. DO UPDATE SET read=excluded.read, write=excluded.write, owner_user_id=excluded.owner_user_id
  166. `
  167. selectUserAccessQuery = `
  168. SELECT topic, read, write
  169. FROM user_access
  170. WHERE user_id = (SELECT id FROM user WHERE user = ?)
  171. ORDER BY write DESC, read DESC, topic
  172. `
  173. selectUserReservationsQuery = `
  174. SELECT a_user.topic, a_user.read, a_user.write, a_everyone.read AS everyone_read, a_everyone.write AS everyone_write
  175. FROM user_access a_user
  176. LEFT JOIN user_access a_everyone ON a_user.topic = a_everyone.topic AND a_everyone.user_id = (SELECT id FROM user WHERE user = ?)
  177. WHERE a_user.user_id = a_user.owner_user_id
  178. AND a_user.owner_user_id = (SELECT id FROM user WHERE user = ?)
  179. ORDER BY a_user.topic
  180. `
  181. selectUserReservationsCountQuery = `
  182. SELECT COUNT(*)
  183. FROM user_access
  184. WHERE user_id = owner_user_id AND owner_user_id = (SELECT id FROM user WHERE user = ?)
  185. `
  186. selectUserHasReservationQuery = `
  187. SELECT COUNT(*)
  188. FROM user_access
  189. WHERE user_id = owner_user_id
  190. AND owner_user_id = (SELECT id FROM user WHERE user = ?)
  191. AND topic = ?
  192. `
  193. selectOtherAccessCountQuery = `
  194. SELECT COUNT(*)
  195. FROM user_access
  196. WHERE (topic = ? OR ? LIKE topic)
  197. AND (owner_user_id IS NULL OR owner_user_id != (SELECT id FROM user WHERE user = ?))
  198. `
  199. deleteAllAccessQuery = `DELETE FROM user_access`
  200. deleteUserAccessQuery = `
  201. DELETE FROM user_access
  202. WHERE user_id = (SELECT id FROM user WHERE user = ?)
  203. OR owner_user_id = (SELECT id FROM user WHERE user = ?)
  204. `
  205. deleteTopicAccessQuery = `
  206. DELETE FROM user_access
  207. WHERE (user_id = (SELECT id FROM user WHERE user = ?) OR owner_user_id = (SELECT id FROM user WHERE user = ?))
  208. AND topic = ?
  209. `
  210. selectTokenCountQuery = `SELECT COUNT(*) FROM user_token WHERE user_id = ?`
  211. selectTokensQuery = `SELECT token, label, expires FROM user_token WHERE user_id = ?`
  212. selectTokenQuery = `SELECT token, label, expires FROM user_token WHERE user_id = ? AND token = ?`
  213. insertTokenQuery = `INSERT INTO user_token (user_id, token, label, expires) VALUES (?, ?, ?, ?)`
  214. updateTokenExpiryQuery = `UPDATE user_token SET expires = ? WHERE user_id = ? AND token = ?`
  215. updateTokenLabelQuery = `UPDATE user_token SET label = ? WHERE user_id = ? AND token = ?`
  216. deleteTokenQuery = `DELETE FROM user_token WHERE user_id = ? AND token = ?`
  217. deleteAllTokenQuery = `DELETE FROM user_token WHERE user_id = ?`
  218. deleteExpiredTokensQuery = `DELETE FROM user_token WHERE expires > 0 AND expires < ?`
  219. deleteExcessTokensQuery = `
  220. DELETE FROM user_token
  221. WHERE (user_id, token) NOT IN (
  222. SELECT user_id, token
  223. FROM user_token
  224. WHERE user_id = ?
  225. ORDER BY expires DESC
  226. LIMIT ?
  227. )
  228. `
  229. insertTierQuery = `
  230. INSERT INTO tier (id, code, name, messages_limit, messages_expiry_duration, emails_limit, reservations_limit, attachment_file_size_limit, attachment_total_size_limit, attachment_expiry_duration, attachment_bandwidth_limit, stripe_price_id)
  231. VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
  232. `
  233. selectTiersQuery = `
  234. SELECT id, code, name, messages_limit, messages_expiry_duration, emails_limit, reservations_limit, attachment_file_size_limit, attachment_total_size_limit, attachment_expiry_duration, attachment_bandwidth_limit, stripe_price_id
  235. FROM tier
  236. `
  237. selectTierByCodeQuery = `
  238. SELECT id, code, name, messages_limit, messages_expiry_duration, emails_limit, reservations_limit, attachment_file_size_limit, attachment_total_size_limit, attachment_expiry_duration, attachment_bandwidth_limit, stripe_price_id
  239. FROM tier
  240. WHERE code = ?
  241. `
  242. selectTierByPriceIDQuery = `
  243. SELECT id, code, name, messages_limit, messages_expiry_duration, emails_limit, reservations_limit, attachment_file_size_limit, attachment_total_size_limit, attachment_expiry_duration, attachment_bandwidth_limit, stripe_price_id
  244. FROM tier
  245. WHERE stripe_price_id = ?
  246. `
  247. updateUserTierQuery = `UPDATE user SET tier_id = (SELECT id FROM tier WHERE code = ?) WHERE user = ?`
  248. deleteUserTierQuery = `UPDATE user SET tier_id = null WHERE user = ?`
  249. updateBillingQuery = `
  250. UPDATE user
  251. SET stripe_customer_id = ?, stripe_subscription_id = ?, stripe_subscription_status = ?, stripe_subscription_paid_until = ?, stripe_subscription_cancel_at = ?
  252. WHERE user = ?
  253. `
  254. )
  255. // Schema management queries
  256. const (
  257. currentSchemaVersion = 2
  258. insertSchemaVersion = `INSERT INTO schemaVersion VALUES (1, ?)`
  259. updateSchemaVersion = `UPDATE schemaVersion SET version = ? WHERE id = 1`
  260. selectSchemaVersionQuery = `SELECT version FROM schemaVersion WHERE id = 1`
  261. // 1 -> 2 (complex migration!)
  262. migrate1To2RenameUserTableQueryNoTx = `
  263. ALTER TABLE user RENAME TO user_old;
  264. `
  265. migrate1To2SelectAllOldUsernamesNoTx = `SELECT user FROM user_old`
  266. migrate1To2InsertUserNoTx = `
  267. INSERT INTO user (id, user, pass, role, sync_topic, created)
  268. SELECT ?, user, pass, role, ?, UNIXEPOCH() FROM user_old WHERE user = ?
  269. `
  270. migrate1To2InsertFromOldTablesAndDropNoTx = `
  271. INSERT INTO user_access (user_id, topic, read, write)
  272. SELECT u.id, a.topic, a.read, a.write
  273. FROM user u
  274. JOIN access a ON u.user = a.user;
  275. DROP TABLE access;
  276. DROP TABLE user_old;
  277. `
  278. )
  279. // Manager is an implementation of Manager. It stores users and access control list
  280. // in a SQLite database.
  281. type Manager struct {
  282. db *sql.DB
  283. defaultAccess Permission // Default permission if no ACL matches
  284. statsQueue map[string]*Stats // "Queue" to asynchronously write user stats to the database (UserID -> Stats)
  285. mu sync.Mutex
  286. }
  287. var _ Auther = (*Manager)(nil)
  288. // NewManager creates a new Manager instance
  289. func NewManager(filename, startupQueries string, defaultAccess Permission) (*Manager, error) {
  290. return NewManagerWithStatsInterval(filename, startupQueries, defaultAccess, userStatsQueueWriterInterval)
  291. }
  292. // NewManagerWithStatsInterval creates a new Manager instance
  293. func NewManagerWithStatsInterval(filename, startupQueries string, defaultAccess Permission, statsWriterInterval time.Duration) (*Manager, error) {
  294. db, err := sql.Open("sqlite3", filename)
  295. if err != nil {
  296. return nil, err
  297. }
  298. if err := setupDB(db); err != nil {
  299. return nil, err
  300. }
  301. if err := runStartupQueries(db, startupQueries); err != nil {
  302. return nil, err
  303. }
  304. manager := &Manager{
  305. db: db,
  306. defaultAccess: defaultAccess,
  307. statsQueue: make(map[string]*Stats),
  308. }
  309. go manager.userStatsQueueWriter(statsWriterInterval)
  310. return manager, nil
  311. }
  312. // Authenticate checks username and password and returns a User if correct, and the user has not been
  313. // marked as deleted. The method returns in constant-ish time, regardless of whether the user exists or
  314. // the password is correct or incorrect.
  315. func (a *Manager) Authenticate(username, password string) (*User, error) {
  316. if username == Everyone {
  317. return nil, ErrUnauthenticated
  318. }
  319. user, err := a.User(username)
  320. if err != nil {
  321. log.Trace("authentication of user %s failed (1): %s", username, err.Error())
  322. bcrypt.CompareHashAndPassword([]byte(userAuthIntentionalSlowDownHash), []byte("intentional slow-down to avoid timing attacks"))
  323. return nil, ErrUnauthenticated
  324. } else if user.Deleted {
  325. log.Trace("authentication of user %s failed (2): user marked deleted", username)
  326. bcrypt.CompareHashAndPassword([]byte(userAuthIntentionalSlowDownHash), []byte("intentional slow-down to avoid timing attacks"))
  327. return nil, ErrUnauthenticated
  328. } else if err := bcrypt.CompareHashAndPassword([]byte(user.Hash), []byte(password)); err != nil {
  329. log.Trace("authentication of user %s failed (3): %s", username, err.Error())
  330. return nil, ErrUnauthenticated
  331. }
  332. return user, nil
  333. }
  334. // AuthenticateToken checks if the token exists and returns the associated User if it does.
  335. // The method sets the User.Token value to the token that was used for authentication.
  336. func (a *Manager) AuthenticateToken(token string) (*User, error) {
  337. if len(token) != tokenLength {
  338. return nil, ErrUnauthenticated
  339. }
  340. user, err := a.userByToken(token)
  341. if err != nil {
  342. return nil, ErrUnauthenticated
  343. }
  344. user.Token = token
  345. return user, nil
  346. }
  347. // CreateToken generates a random token for the given user and returns it. The token expires
  348. // after a fixed duration unless ChangeToken is called. This function also prunes tokens for the
  349. // given user, if there are too many of them.
  350. func (a *Manager) CreateToken(userID, label string, expires time.Time) (*Token, error) {
  351. token := util.RandomStringPrefix(tokenPrefix, tokenLength)
  352. tx, err := a.db.Begin()
  353. if err != nil {
  354. return nil, err
  355. }
  356. defer tx.Rollback()
  357. if _, err := tx.Exec(insertTokenQuery, userID, token, label, expires.Unix()); err != nil {
  358. return nil, err
  359. }
  360. rows, err := tx.Query(selectTokenCountQuery, userID)
  361. if err != nil {
  362. return nil, err
  363. }
  364. defer rows.Close()
  365. if !rows.Next() {
  366. return nil, errNoRows
  367. }
  368. var tokenCount int
  369. if err := rows.Scan(&tokenCount); err != nil {
  370. return nil, err
  371. }
  372. if tokenCount >= tokenMaxCount {
  373. // This pruning logic is done in two queries for efficiency. The SELECT above is a lookup
  374. // on two indices, whereas the query below is a full table scan.
  375. if _, err := tx.Exec(deleteExcessTokensQuery, userID, tokenMaxCount); err != nil {
  376. return nil, err
  377. }
  378. }
  379. if err := tx.Commit(); err != nil {
  380. return nil, err
  381. }
  382. return &Token{
  383. Value: token,
  384. Label: label,
  385. Expires: expires,
  386. }, nil
  387. }
  388. func (a *Manager) Tokens(userID string) ([]*Token, error) {
  389. rows, err := a.db.Query(selectTokensQuery, userID)
  390. if err != nil {
  391. return nil, err
  392. }
  393. defer rows.Close()
  394. tokens := make([]*Token, 0)
  395. for {
  396. token, err := a.readToken(rows)
  397. if err == ErrTokenNotFound {
  398. break
  399. } else if err != nil {
  400. return nil, err
  401. }
  402. tokens = append(tokens, token)
  403. }
  404. return tokens, nil
  405. }
  406. func (a *Manager) Token(userID, token string) (*Token, error) {
  407. rows, err := a.db.Query(selectTokenQuery, userID, token)
  408. if err != nil {
  409. return nil, err
  410. }
  411. defer rows.Close()
  412. return a.readToken(rows)
  413. }
  414. func (a *Manager) readToken(rows *sql.Rows) (*Token, error) {
  415. var token, label string
  416. var expires int64
  417. if !rows.Next() {
  418. return nil, ErrTokenNotFound
  419. }
  420. if err := rows.Scan(&token, &label, &expires); err != nil {
  421. return nil, err
  422. } else if err := rows.Err(); err != nil {
  423. return nil, err
  424. }
  425. return &Token{
  426. Value: token,
  427. Label: label,
  428. Expires: time.Unix(expires, 0),
  429. }, nil
  430. }
  431. // ChangeToken updates a token's label and/or expiry date
  432. func (a *Manager) ChangeToken(userID, token string, label *string, expires *time.Time) (*Token, error) {
  433. if token == "" {
  434. return nil, errNoTokenProvided
  435. }
  436. tx, err := a.db.Begin()
  437. if err != nil {
  438. return nil, err
  439. }
  440. defer tx.Rollback()
  441. if label != nil {
  442. if _, err := tx.Exec(updateTokenLabelQuery, *label, userID, token); err != nil {
  443. return nil, err
  444. }
  445. }
  446. if expires != nil {
  447. if _, err := tx.Exec(updateTokenExpiryQuery, expires.Unix(), userID, token); err != nil {
  448. return nil, err
  449. }
  450. }
  451. if err := tx.Commit(); err != nil {
  452. return nil, err
  453. }
  454. return a.Token(userID, token)
  455. }
  456. // RemoveToken deletes the token defined in User.Token
  457. func (a *Manager) RemoveToken(userID, token string) error {
  458. if token == "" {
  459. return errNoTokenProvided
  460. }
  461. if _, err := a.db.Exec(deleteTokenQuery, userID, token); err != nil {
  462. return err
  463. }
  464. return nil
  465. }
  466. // RemoveExpiredTokens deletes all expired tokens from the database
  467. func (a *Manager) RemoveExpiredTokens() error {
  468. if _, err := a.db.Exec(deleteExpiredTokensQuery, time.Now().Unix()); err != nil {
  469. return err
  470. }
  471. return nil
  472. }
  473. // RemoveDeletedUsers deletes all users that have been marked deleted for
  474. func (a *Manager) RemoveDeletedUsers() error {
  475. if _, err := a.db.Exec(deleteUsersMarkedQuery, time.Now().Unix()); err != nil {
  476. return err
  477. }
  478. return nil
  479. }
  480. // ChangeSettings persists the user settings
  481. func (a *Manager) ChangeSettings(user *User) error {
  482. prefs, err := json.Marshal(user.Prefs)
  483. if err != nil {
  484. return err
  485. }
  486. if _, err := a.db.Exec(updateUserPrefsQuery, string(prefs), user.Name); err != nil {
  487. return err
  488. }
  489. return nil
  490. }
  491. // ResetStats resets all user stats in the user database. This touches all users.
  492. func (a *Manager) ResetStats() error {
  493. a.mu.Lock()
  494. defer a.mu.Unlock()
  495. if _, err := a.db.Exec(updateUserStatsResetAllQuery); err != nil {
  496. return err
  497. }
  498. a.statsQueue = make(map[string]*Stats)
  499. return nil
  500. }
  501. // EnqueueStats adds the user to a queue which writes out user stats (messages, emails, ..) in
  502. // batches at a regular interval
  503. func (a *Manager) EnqueueStats(userID string, stats *Stats) {
  504. a.mu.Lock()
  505. defer a.mu.Unlock()
  506. a.statsQueue[userID] = stats
  507. }
  508. func (a *Manager) userStatsQueueWriter(interval time.Duration) {
  509. ticker := time.NewTicker(interval)
  510. for range ticker.C {
  511. if err := a.writeUserStatsQueue(); err != nil {
  512. log.Warn("User Manager: Writing user stats queue failed: %s", err.Error())
  513. }
  514. }
  515. }
  516. func (a *Manager) writeUserStatsQueue() error {
  517. a.mu.Lock()
  518. if len(a.statsQueue) == 0 {
  519. a.mu.Unlock()
  520. log.Trace("User Manager: No user stats updates to commit")
  521. return nil
  522. }
  523. statsQueue := a.statsQueue
  524. a.statsQueue = make(map[string]*Stats)
  525. a.mu.Unlock()
  526. tx, err := a.db.Begin()
  527. if err != nil {
  528. return err
  529. }
  530. defer tx.Rollback()
  531. log.Debug("User Manager: Writing user stats queue for %d user(s)", len(statsQueue))
  532. for userID, update := range statsQueue {
  533. log.Trace("User Manager: Updating stats for user %s: messages=%d, emails=%d", userID, update.Messages, update.Emails)
  534. if _, err := tx.Exec(updateUserStatsQuery, update.Messages, update.Emails, userID); err != nil {
  535. return err
  536. }
  537. }
  538. return tx.Commit()
  539. }
  540. // Authorize returns nil if the given user has access to the given topic using the desired
  541. // permission. The user param may be nil to signal an anonymous user.
  542. func (a *Manager) Authorize(user *User, topic string, perm Permission) error {
  543. if user != nil && user.Role == RoleAdmin {
  544. return nil // Admin can do everything
  545. }
  546. username := Everyone
  547. if user != nil {
  548. username = user.Name
  549. }
  550. // Select the read/write permissions for this user/topic combo. The query may return two
  551. // rows (one for everyone, and one for the user), but prioritizes the user.
  552. rows, err := a.db.Query(selectTopicPermsQuery, Everyone, username, topic)
  553. if err != nil {
  554. return err
  555. }
  556. defer rows.Close()
  557. if !rows.Next() {
  558. return a.resolvePerms(a.defaultAccess, perm)
  559. }
  560. var read, write bool
  561. if err := rows.Scan(&read, &write); err != nil {
  562. return err
  563. } else if err := rows.Err(); err != nil {
  564. return err
  565. }
  566. return a.resolvePerms(NewPermission(read, write), perm)
  567. }
  568. func (a *Manager) resolvePerms(base, perm Permission) error {
  569. if perm == PermissionRead && base.IsRead() {
  570. return nil
  571. } else if perm == PermissionWrite && base.IsWrite() {
  572. return nil
  573. }
  574. return ErrUnauthorized
  575. }
  576. // AddUser adds a user with the given username, password and role
  577. func (a *Manager) AddUser(username, password string, role Role) error {
  578. if !AllowedUsername(username) || !AllowedRole(role) {
  579. return ErrInvalidArgument
  580. }
  581. hash, err := bcrypt.GenerateFromPassword([]byte(password), userPasswordBcryptCost)
  582. if err != nil {
  583. return err
  584. }
  585. userID := util.RandomStringPrefix(userIDPrefix, userIDLength)
  586. syncTopic, now := util.RandomStringPrefix(syncTopicPrefix, syncTopicLength), time.Now().Unix()
  587. if _, err = a.db.Exec(insertUserQuery, userID, username, hash, role, syncTopic, now); err != nil {
  588. return err
  589. }
  590. return nil
  591. }
  592. // RemoveUser deletes the user with the given username. The function returns nil on success, even
  593. // if the user did not exist in the first place.
  594. func (a *Manager) RemoveUser(username string) error {
  595. if !AllowedUsername(username) {
  596. return ErrInvalidArgument
  597. }
  598. // Rows in user_access, user_token, etc. are deleted via foreign keys
  599. if _, err := a.db.Exec(deleteUserQuery, username); err != nil {
  600. return err
  601. }
  602. return nil
  603. }
  604. // MarkUserRemoved sets the deleted flag on the user, and deletes all access tokens. This prevents
  605. // successful auth via Authenticate. A background process will delete the user at a later date.
  606. func (a *Manager) MarkUserRemoved(user *User) error {
  607. if !AllowedUsername(user.Name) {
  608. return ErrInvalidArgument
  609. }
  610. tx, err := a.db.Begin()
  611. if err != nil {
  612. return err
  613. }
  614. defer tx.Rollback()
  615. if _, err := a.db.Exec(deleteUserAccessQuery, user.Name, user.Name); err != nil {
  616. return err
  617. }
  618. if _, err := tx.Exec(deleteAllTokenQuery, user.ID); err != nil {
  619. return err
  620. }
  621. if _, err := tx.Exec(updateUserDeletedQuery, time.Now().Add(userHardDeleteAfterDuration).Unix(), user.ID); err != nil {
  622. return err
  623. }
  624. return tx.Commit()
  625. }
  626. // Users returns a list of users. It always also returns the Everyone user ("*").
  627. func (a *Manager) Users() ([]*User, error) {
  628. rows, err := a.db.Query(selectUsernamesQuery)
  629. if err != nil {
  630. return nil, err
  631. }
  632. defer rows.Close()
  633. usernames := make([]string, 0)
  634. for rows.Next() {
  635. var username string
  636. if err := rows.Scan(&username); err != nil {
  637. return nil, err
  638. } else if err := rows.Err(); err != nil {
  639. return nil, err
  640. }
  641. usernames = append(usernames, username)
  642. }
  643. rows.Close()
  644. users := make([]*User, 0)
  645. for _, username := range usernames {
  646. user, err := a.User(username)
  647. if err != nil {
  648. return nil, err
  649. }
  650. users = append(users, user)
  651. }
  652. return users, nil
  653. }
  654. // User returns the user with the given username if it exists, or ErrUserNotFound otherwise.
  655. // You may also pass Everyone to retrieve the anonymous user and its Grant list.
  656. func (a *Manager) User(username string) (*User, error) {
  657. rows, err := a.db.Query(selectUserByNameQuery, username)
  658. if err != nil {
  659. return nil, err
  660. }
  661. return a.readUser(rows)
  662. }
  663. // UserByID returns the user with the given ID if it exists, or ErrUserNotFound otherwise
  664. func (a *Manager) UserByID(id string) (*User, error) {
  665. rows, err := a.db.Query(selectUserByIDQuery, id)
  666. if err != nil {
  667. return nil, err
  668. }
  669. return a.readUser(rows)
  670. }
  671. // UserByStripeCustomer returns the user with the given Stripe customer ID if it exists, or ErrUserNotFound otherwise.
  672. func (a *Manager) UserByStripeCustomer(stripeCustomerID string) (*User, error) {
  673. rows, err := a.db.Query(selectUserByStripeCustomerIDQuery, stripeCustomerID)
  674. if err != nil {
  675. return nil, err
  676. }
  677. return a.readUser(rows)
  678. }
  679. func (a *Manager) userByToken(token string) (*User, error) {
  680. rows, err := a.db.Query(selectUserByTokenQuery, token, time.Now().Unix())
  681. if err != nil {
  682. return nil, err
  683. }
  684. return a.readUser(rows)
  685. }
  686. func (a *Manager) readUser(rows *sql.Rows) (*User, error) {
  687. defer rows.Close()
  688. var id, username, hash, role, prefs, syncTopic string
  689. var stripeCustomerID, stripeSubscriptionID, stripeSubscriptionStatus, stripePriceID, tierID, tierCode, tierName sql.NullString
  690. var messages, emails int64
  691. var messagesLimit, messagesExpiryDuration, emailsLimit, reservationsLimit, attachmentFileSizeLimit, attachmentTotalSizeLimit, attachmentExpiryDuration, attachmentBandwidthLimit, stripeSubscriptionPaidUntil, stripeSubscriptionCancelAt, deleted sql.NullInt64
  692. if !rows.Next() {
  693. return nil, ErrUserNotFound
  694. }
  695. if err := rows.Scan(&id, &username, &hash, &role, &prefs, &syncTopic, &messages, &emails, &stripeCustomerID, &stripeSubscriptionID, &stripeSubscriptionStatus, &stripeSubscriptionPaidUntil, &stripeSubscriptionCancelAt, &deleted, &tierID, &tierCode, &tierName, &messagesLimit, &messagesExpiryDuration, &emailsLimit, &reservationsLimit, &attachmentFileSizeLimit, &attachmentTotalSizeLimit, &attachmentExpiryDuration, &attachmentBandwidthLimit, &stripePriceID); err != nil {
  696. return nil, err
  697. } else if err := rows.Err(); err != nil {
  698. return nil, err
  699. }
  700. user := &User{
  701. ID: id,
  702. Name: username,
  703. Hash: hash,
  704. Role: Role(role),
  705. Prefs: &Prefs{},
  706. SyncTopic: syncTopic,
  707. Stats: &Stats{
  708. Messages: messages,
  709. Emails: emails,
  710. },
  711. Billing: &Billing{
  712. StripeCustomerID: stripeCustomerID.String, // May be empty
  713. StripeSubscriptionID: stripeSubscriptionID.String, // May be empty
  714. StripeSubscriptionStatus: stripe.SubscriptionStatus(stripeSubscriptionStatus.String), // May be empty
  715. StripeSubscriptionPaidUntil: time.Unix(stripeSubscriptionPaidUntil.Int64, 0), // May be zero
  716. StripeSubscriptionCancelAt: time.Unix(stripeSubscriptionCancelAt.Int64, 0), // May be zero
  717. },
  718. Deleted: deleted.Valid,
  719. }
  720. if err := json.Unmarshal([]byte(prefs), user.Prefs); err != nil {
  721. return nil, err
  722. }
  723. if tierCode.Valid {
  724. // See readTier() when this is changed!
  725. user.Tier = &Tier{
  726. ID: tierID.String,
  727. Code: tierCode.String,
  728. Name: tierName.String,
  729. MessageLimit: messagesLimit.Int64,
  730. MessageExpiryDuration: time.Duration(messagesExpiryDuration.Int64) * time.Second,
  731. EmailLimit: emailsLimit.Int64,
  732. ReservationLimit: reservationsLimit.Int64,
  733. AttachmentFileSizeLimit: attachmentFileSizeLimit.Int64,
  734. AttachmentTotalSizeLimit: attachmentTotalSizeLimit.Int64,
  735. AttachmentExpiryDuration: time.Duration(attachmentExpiryDuration.Int64) * time.Second,
  736. AttachmentBandwidthLimit: attachmentBandwidthLimit.Int64,
  737. StripePriceID: stripePriceID.String, // May be empty
  738. }
  739. }
  740. return user, nil
  741. }
  742. // Grants returns all user-specific access control entries
  743. func (a *Manager) Grants(username string) ([]Grant, error) {
  744. rows, err := a.db.Query(selectUserAccessQuery, username)
  745. if err != nil {
  746. return nil, err
  747. }
  748. defer rows.Close()
  749. grants := make([]Grant, 0)
  750. for rows.Next() {
  751. var topic string
  752. var read, write bool
  753. if err := rows.Scan(&topic, &read, &write); err != nil {
  754. return nil, err
  755. } else if err := rows.Err(); err != nil {
  756. return nil, err
  757. }
  758. grants = append(grants, Grant{
  759. TopicPattern: fromSQLWildcard(topic),
  760. Allow: NewPermission(read, write),
  761. })
  762. }
  763. return grants, nil
  764. }
  765. // Reservations returns all user-owned topics, and the associated everyone-access
  766. func (a *Manager) Reservations(username string) ([]Reservation, error) {
  767. rows, err := a.db.Query(selectUserReservationsQuery, Everyone, username)
  768. if err != nil {
  769. return nil, err
  770. }
  771. defer rows.Close()
  772. reservations := make([]Reservation, 0)
  773. for rows.Next() {
  774. var topic string
  775. var ownerRead, ownerWrite bool
  776. var everyoneRead, everyoneWrite sql.NullBool
  777. if err := rows.Scan(&topic, &ownerRead, &ownerWrite, &everyoneRead, &everyoneWrite); err != nil {
  778. return nil, err
  779. } else if err := rows.Err(); err != nil {
  780. return nil, err
  781. }
  782. reservations = append(reservations, Reservation{
  783. Topic: topic,
  784. Owner: NewPermission(ownerRead, ownerWrite),
  785. Everyone: NewPermission(everyoneRead.Bool, everyoneWrite.Bool), // false if null
  786. })
  787. }
  788. return reservations, nil
  789. }
  790. // HasReservation returns true if the given topic access is owned by the user
  791. func (a *Manager) HasReservation(username, topic string) (bool, error) {
  792. rows, err := a.db.Query(selectUserHasReservationQuery, username, topic)
  793. if err != nil {
  794. return false, err
  795. }
  796. defer rows.Close()
  797. if !rows.Next() {
  798. return false, errNoRows
  799. }
  800. var count int64
  801. if err := rows.Scan(&count); err != nil {
  802. return false, err
  803. }
  804. return count > 0, nil
  805. }
  806. // ReservationsCount returns the number of reservations owned by this user
  807. func (a *Manager) ReservationsCount(username string) (int64, error) {
  808. rows, err := a.db.Query(selectUserReservationsCountQuery, username)
  809. if err != nil {
  810. return 0, err
  811. }
  812. defer rows.Close()
  813. if !rows.Next() {
  814. return 0, errNoRows
  815. }
  816. var count int64
  817. if err := rows.Scan(&count); err != nil {
  818. return 0, err
  819. }
  820. return count, nil
  821. }
  822. // ChangePassword changes a user's password
  823. func (a *Manager) ChangePassword(username, password string) error {
  824. hash, err := bcrypt.GenerateFromPassword([]byte(password), userPasswordBcryptCost)
  825. if err != nil {
  826. return err
  827. }
  828. if _, err := a.db.Exec(updateUserPassQuery, hash, username); err != nil {
  829. return err
  830. }
  831. return nil
  832. }
  833. // ChangeRole changes a user's role. When a role is changed from RoleUser to RoleAdmin,
  834. // all existing access control entries (Grant) are removed, since they are no longer needed.
  835. func (a *Manager) ChangeRole(username string, role Role) error {
  836. if !AllowedUsername(username) || !AllowedRole(role) {
  837. return ErrInvalidArgument
  838. }
  839. if _, err := a.db.Exec(updateUserRoleQuery, string(role), username); err != nil {
  840. return err
  841. }
  842. if role == RoleAdmin {
  843. if _, err := a.db.Exec(deleteUserAccessQuery, username, username); err != nil {
  844. return err
  845. }
  846. }
  847. return nil
  848. }
  849. // ChangeTier changes a user's tier using the tier code. This function does not delete reservations, messages,
  850. // or attachments, even if the new tier has lower limits in this regard. That has to be done elsewhere.
  851. func (a *Manager) ChangeTier(username, tier string) error {
  852. if !AllowedUsername(username) {
  853. return ErrInvalidArgument
  854. }
  855. t, err := a.Tier(tier)
  856. if err != nil {
  857. return err
  858. } else if err := a.checkReservationsLimit(username, t.ReservationLimit); err != nil {
  859. return err
  860. }
  861. if _, err := a.db.Exec(updateUserTierQuery, tier, username); err != nil {
  862. return err
  863. }
  864. return nil
  865. }
  866. // ResetTier removes the tier from the given user
  867. func (a *Manager) ResetTier(username string) error {
  868. if !AllowedUsername(username) && username != Everyone && username != "" {
  869. return ErrInvalidArgument
  870. } else if err := a.checkReservationsLimit(username, 0); err != nil {
  871. return err
  872. }
  873. _, err := a.db.Exec(deleteUserTierQuery, username)
  874. return err
  875. }
  876. func (a *Manager) checkReservationsLimit(username string, reservationsLimit int64) error {
  877. u, err := a.User(username)
  878. if err != nil {
  879. return err
  880. }
  881. if u.Tier != nil && reservationsLimit < u.Tier.ReservationLimit {
  882. reservations, err := a.Reservations(username)
  883. if err != nil {
  884. return err
  885. } else if int64(len(reservations)) > reservationsLimit {
  886. return ErrTooManyReservations
  887. }
  888. }
  889. return nil
  890. }
  891. // CheckAllowAccess tests if a user may create an access control entry for the given topic.
  892. // If there are any ACL entries that are not owned by the user, an error is returned.
  893. // FIXME is this the same as HasReservation?
  894. func (a *Manager) CheckAllowAccess(username string, topic string) error {
  895. if (!AllowedUsername(username) && username != Everyone) || !AllowedTopic(topic) {
  896. return ErrInvalidArgument
  897. }
  898. rows, err := a.db.Query(selectOtherAccessCountQuery, topic, topic, username)
  899. if err != nil {
  900. return err
  901. }
  902. defer rows.Close()
  903. if !rows.Next() {
  904. return errNoRows
  905. }
  906. var otherCount int
  907. if err := rows.Scan(&otherCount); err != nil {
  908. return err
  909. }
  910. if otherCount > 0 {
  911. return errTopicOwnedByOthers
  912. }
  913. return nil
  914. }
  915. // AllowAccess adds or updates an entry in th access control list for a specific user. It controls
  916. // read/write access to a topic. The parameter topicPattern may include wildcards (*). The ACL entry
  917. // owner may either be a user (username), or the system (empty).
  918. func (a *Manager) AllowAccess(username string, topicPattern string, permission Permission) error {
  919. if !AllowedUsername(username) && username != Everyone {
  920. return ErrInvalidArgument
  921. } else if !AllowedTopicPattern(topicPattern) {
  922. return ErrInvalidArgument
  923. }
  924. owner := ""
  925. if _, err := a.db.Exec(upsertUserAccessQuery, username, toSQLWildcard(topicPattern), permission.IsRead(), permission.IsWrite(), owner, owner); err != nil {
  926. return err
  927. }
  928. return nil
  929. }
  930. // ResetAccess removes an access control list entry for a specific username/topic, or (if topic is
  931. // empty) for an entire user. The parameter topicPattern may include wildcards (*).
  932. func (a *Manager) ResetAccess(username string, topicPattern string) error {
  933. if !AllowedUsername(username) && username != Everyone && username != "" {
  934. return ErrInvalidArgument
  935. } else if !AllowedTopicPattern(topicPattern) && topicPattern != "" {
  936. return ErrInvalidArgument
  937. }
  938. if username == "" && topicPattern == "" {
  939. _, err := a.db.Exec(deleteAllAccessQuery, username)
  940. return err
  941. } else if topicPattern == "" {
  942. _, err := a.db.Exec(deleteUserAccessQuery, username, username)
  943. return err
  944. }
  945. _, err := a.db.Exec(deleteTopicAccessQuery, username, username, toSQLWildcard(topicPattern))
  946. return err
  947. }
  948. // AddReservation creates two access control entries for the given topic: one with full read/write access for the
  949. // given user, and one for Everyone with the permission passed as everyone. The user also owns the entries, and
  950. // can modify or delete them.
  951. func (a *Manager) AddReservation(username string, topic string, everyone Permission) error {
  952. if !AllowedUsername(username) || username == Everyone || !AllowedTopic(topic) {
  953. return ErrInvalidArgument
  954. }
  955. tx, err := a.db.Begin()
  956. if err != nil {
  957. return err
  958. }
  959. defer tx.Rollback()
  960. if _, err := tx.Exec(upsertUserAccessQuery, username, topic, true, true, username, username); err != nil {
  961. return err
  962. }
  963. if _, err := tx.Exec(upsertUserAccessQuery, Everyone, topic, everyone.IsRead(), everyone.IsWrite(), username, username); err != nil {
  964. return err
  965. }
  966. return tx.Commit()
  967. }
  968. // RemoveReservations deletes the access control entries associated with the given username/topic, as
  969. // well as all entries with Everyone/topic. This is the counterpart for AddReservation.
  970. func (a *Manager) RemoveReservations(username string, topics ...string) error {
  971. if !AllowedUsername(username) || username == Everyone || len(topics) == 0 {
  972. return ErrInvalidArgument
  973. }
  974. for _, topic := range topics {
  975. if !AllowedTopic(topic) {
  976. return ErrInvalidArgument
  977. }
  978. }
  979. tx, err := a.db.Begin()
  980. if err != nil {
  981. return err
  982. }
  983. defer tx.Rollback()
  984. for _, topic := range topics {
  985. if _, err := tx.Exec(deleteTopicAccessQuery, username, username, topic); err != nil {
  986. return err
  987. }
  988. if _, err := tx.Exec(deleteTopicAccessQuery, Everyone, Everyone, topic); err != nil {
  989. return err
  990. }
  991. }
  992. return tx.Commit()
  993. }
  994. // DefaultAccess returns the default read/write access if no access control entry matches
  995. func (a *Manager) DefaultAccess() Permission {
  996. return a.defaultAccess
  997. }
  998. // CreateTier creates a new tier in the database
  999. func (a *Manager) CreateTier(tier *Tier) error {
  1000. if tier.ID == "" {
  1001. tier.ID = util.RandomStringPrefix(tierIDPrefix, tierIDLength)
  1002. }
  1003. if _, err := a.db.Exec(insertTierQuery, tier.ID, tier.Code, tier.Name, tier.MessageLimit, int64(tier.MessageExpiryDuration.Seconds()), tier.EmailLimit, tier.ReservationLimit, tier.AttachmentFileSizeLimit, tier.AttachmentTotalSizeLimit, int64(tier.AttachmentExpiryDuration.Seconds()), tier.AttachmentBandwidthLimit, tier.StripePriceID); err != nil {
  1004. return err
  1005. }
  1006. return nil
  1007. }
  1008. // ChangeBilling updates a user's billing fields, namely the Stripe customer ID, and subscription information
  1009. func (a *Manager) ChangeBilling(username string, billing *Billing) error {
  1010. if _, err := a.db.Exec(updateBillingQuery, nullString(billing.StripeCustomerID), nullString(billing.StripeSubscriptionID), nullString(string(billing.StripeSubscriptionStatus)), nullInt64(billing.StripeSubscriptionPaidUntil.Unix()), nullInt64(billing.StripeSubscriptionCancelAt.Unix()), username); err != nil {
  1011. return err
  1012. }
  1013. return nil
  1014. }
  1015. // Tiers returns a list of all Tier structs
  1016. func (a *Manager) Tiers() ([]*Tier, error) {
  1017. rows, err := a.db.Query(selectTiersQuery)
  1018. if err != nil {
  1019. return nil, err
  1020. }
  1021. defer rows.Close()
  1022. tiers := make([]*Tier, 0)
  1023. for {
  1024. tier, err := a.readTier(rows)
  1025. if err == ErrTierNotFound {
  1026. break
  1027. } else if err != nil {
  1028. return nil, err
  1029. }
  1030. tiers = append(tiers, tier)
  1031. }
  1032. return tiers, nil
  1033. }
  1034. // Tier returns a Tier based on the code, or ErrTierNotFound if it does not exist
  1035. func (a *Manager) Tier(code string) (*Tier, error) {
  1036. rows, err := a.db.Query(selectTierByCodeQuery, code)
  1037. if err != nil {
  1038. return nil, err
  1039. }
  1040. defer rows.Close()
  1041. return a.readTier(rows)
  1042. }
  1043. // TierByStripePrice returns a Tier based on the Stripe price ID, or ErrTierNotFound if it does not exist
  1044. func (a *Manager) TierByStripePrice(priceID string) (*Tier, error) {
  1045. rows, err := a.db.Query(selectTierByPriceIDQuery, priceID)
  1046. if err != nil {
  1047. return nil, err
  1048. }
  1049. defer rows.Close()
  1050. return a.readTier(rows)
  1051. }
  1052. func (a *Manager) readTier(rows *sql.Rows) (*Tier, error) {
  1053. var id, code, name string
  1054. var stripePriceID sql.NullString
  1055. var messagesLimit, messagesExpiryDuration, emailsLimit, reservationsLimit, attachmentFileSizeLimit, attachmentTotalSizeLimit, attachmentExpiryDuration, attachmentBandwidthLimit sql.NullInt64
  1056. if !rows.Next() {
  1057. return nil, ErrTierNotFound
  1058. }
  1059. if err := rows.Scan(&id, &code, &name, &messagesLimit, &messagesExpiryDuration, &emailsLimit, &reservationsLimit, &attachmentFileSizeLimit, &attachmentTotalSizeLimit, &attachmentExpiryDuration, &attachmentBandwidthLimit, &stripePriceID); err != nil {
  1060. return nil, err
  1061. } else if err := rows.Err(); err != nil {
  1062. return nil, err
  1063. }
  1064. // When changed, note readUser() as well
  1065. return &Tier{
  1066. ID: id,
  1067. Code: code,
  1068. Name: name,
  1069. MessageLimit: messagesLimit.Int64,
  1070. MessageExpiryDuration: time.Duration(messagesExpiryDuration.Int64) * time.Second,
  1071. EmailLimit: emailsLimit.Int64,
  1072. ReservationLimit: reservationsLimit.Int64,
  1073. AttachmentFileSizeLimit: attachmentFileSizeLimit.Int64,
  1074. AttachmentTotalSizeLimit: attachmentTotalSizeLimit.Int64,
  1075. AttachmentExpiryDuration: time.Duration(attachmentExpiryDuration.Int64) * time.Second,
  1076. AttachmentBandwidthLimit: attachmentBandwidthLimit.Int64,
  1077. StripePriceID: stripePriceID.String, // May be empty
  1078. }, nil
  1079. }
  1080. func toSQLWildcard(s string) string {
  1081. return strings.ReplaceAll(s, "*", "%")
  1082. }
  1083. func fromSQLWildcard(s string) string {
  1084. return strings.ReplaceAll(s, "%", "*")
  1085. }
  1086. func runStartupQueries(db *sql.DB, startupQueries string) error {
  1087. if _, err := db.Exec(startupQueries); err != nil {
  1088. return err
  1089. }
  1090. if _, err := db.Exec(builtinStartupQueries); err != nil {
  1091. return err
  1092. }
  1093. return nil
  1094. }
  1095. func setupDB(db *sql.DB) error {
  1096. // If 'schemaVersion' table does not exist, this must be a new database
  1097. rowsSV, err := db.Query(selectSchemaVersionQuery)
  1098. if err != nil {
  1099. return setupNewDB(db)
  1100. }
  1101. defer rowsSV.Close()
  1102. // If 'schemaVersion' table exists, read version and potentially upgrade
  1103. schemaVersion := 0
  1104. if !rowsSV.Next() {
  1105. return errors.New("cannot determine schema version: database file may be corrupt")
  1106. }
  1107. if err := rowsSV.Scan(&schemaVersion); err != nil {
  1108. return err
  1109. }
  1110. rowsSV.Close()
  1111. // Do migrations
  1112. if schemaVersion == currentSchemaVersion {
  1113. return nil
  1114. } else if schemaVersion == 1 {
  1115. return migrateFrom1(db)
  1116. }
  1117. return fmt.Errorf("unexpected schema version found: %d", schemaVersion)
  1118. }
  1119. func setupNewDB(db *sql.DB) error {
  1120. if _, err := db.Exec(createTablesQueries); err != nil {
  1121. return err
  1122. }
  1123. if _, err := db.Exec(insertSchemaVersion, currentSchemaVersion); err != nil {
  1124. return err
  1125. }
  1126. return nil
  1127. }
  1128. func migrateFrom1(db *sql.DB) error {
  1129. log.Info("Migrating user database schema: from 1 to 2")
  1130. tx, err := db.Begin()
  1131. if err != nil {
  1132. return err
  1133. }
  1134. defer tx.Rollback()
  1135. // Rename user -> user_old, and create new tables
  1136. if _, err := tx.Exec(migrate1To2RenameUserTableQueryNoTx); err != nil {
  1137. return err
  1138. }
  1139. if _, err := tx.Exec(createTablesQueriesNoTx); err != nil {
  1140. return err
  1141. }
  1142. // Insert users from user_old into new user table, with ID and sync_topic
  1143. rows, err := tx.Query(migrate1To2SelectAllOldUsernamesNoTx)
  1144. if err != nil {
  1145. return err
  1146. }
  1147. defer rows.Close()
  1148. usernames := make([]string, 0)
  1149. for rows.Next() {
  1150. var username string
  1151. if err := rows.Scan(&username); err != nil {
  1152. return err
  1153. }
  1154. usernames = append(usernames, username)
  1155. }
  1156. if err := rows.Close(); err != nil {
  1157. return err
  1158. }
  1159. for _, username := range usernames {
  1160. userID := util.RandomStringPrefix(userIDPrefix, userIDLength)
  1161. syncTopic := util.RandomStringPrefix(syncTopicPrefix, syncTopicLength)
  1162. if _, err := tx.Exec(migrate1To2InsertUserNoTx, userID, syncTopic, username); err != nil {
  1163. return err
  1164. }
  1165. }
  1166. // Migrate old "access" table to "user_access" and drop "access" and "user_old"
  1167. if _, err := tx.Exec(migrate1To2InsertFromOldTablesAndDropNoTx); err != nil {
  1168. return err
  1169. }
  1170. if _, err := tx.Exec(updateSchemaVersion, 2); err != nil {
  1171. return err
  1172. }
  1173. if err := tx.Commit(); err != nil {
  1174. return err
  1175. }
  1176. return nil // Update this when a new version is added
  1177. }
  1178. func nullString(s string) sql.NullString {
  1179. if s == "" {
  1180. return sql.NullString{}
  1181. }
  1182. return sql.NullString{String: s, Valid: true}
  1183. }
  1184. func nullInt64(v int64) sql.NullInt64 {
  1185. if v == 0 {
  1186. return sql.NullInt64{}
  1187. }
  1188. return sql.NullInt64{Int64: v, Valid: true}
  1189. }