manager.go 35 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066
  1. package user
  2. import (
  3. "database/sql"
  4. "encoding/json"
  5. "errors"
  6. "fmt"
  7. _ "github.com/mattn/go-sqlite3" // SQLite driver
  8. "github.com/stripe/stripe-go/v74"
  9. "golang.org/x/crypto/bcrypt"
  10. "heckel.io/ntfy/log"
  11. "heckel.io/ntfy/util"
  12. "strings"
  13. "sync"
  14. "time"
  15. )
  16. const (
  17. bcryptCost = 10
  18. intentionalSlowDownHash = "$2a$10$YFCQvqQDwIIwnJM1xkAYOeih0dg17UVGanaTStnrSzC8NCWxcLDwy" // Cost should match bcryptCost
  19. userStatsQueueWriterInterval = 33 * time.Second
  20. tokenLength = 32
  21. tokenExpiryDuration = 72 * time.Hour // Extend tokens by this much
  22. syncTopicLength = 16
  23. tokenMaxCount = 10 // Only keep this many tokens in the table per user
  24. )
  25. var (
  26. errNoTokenProvided = errors.New("no token provided")
  27. errTopicOwnedByOthers = errors.New("topic owned by others")
  28. errNoRows = errors.New("no rows found")
  29. )
  30. // Manager-related queries
  31. const (
  32. createTablesQueriesNoTx = `
  33. CREATE TABLE IF NOT EXISTS tier (
  34. id INTEGER PRIMARY KEY AUTOINCREMENT,
  35. code TEXT NOT NULL,
  36. name TEXT NOT NULL,
  37. messages_limit INT NOT NULL,
  38. messages_expiry_duration INT NOT NULL,
  39. emails_limit INT NOT NULL,
  40. reservations_limit INT NOT NULL,
  41. attachment_file_size_limit INT NOT NULL,
  42. attachment_total_size_limit INT NOT NULL,
  43. attachment_expiry_duration INT NOT NULL,
  44. stripe_price_id TEXT
  45. );
  46. CREATE UNIQUE INDEX idx_tier_code ON tier (code);
  47. CREATE UNIQUE INDEX idx_tier_price_id ON tier (stripe_price_id);
  48. CREATE TABLE IF NOT EXISTS user (
  49. id INTEGER PRIMARY KEY AUTOINCREMENT,
  50. tier_id INT,
  51. user TEXT NOT NULL,
  52. pass TEXT NOT NULL,
  53. role TEXT CHECK (role IN ('anonymous', 'admin', 'user')) NOT NULL,
  54. prefs JSON NOT NULL DEFAULT '{}',
  55. sync_topic TEXT NOT NULL,
  56. stats_messages INT NOT NULL DEFAULT (0),
  57. stats_emails INT NOT NULL DEFAULT (0),
  58. stripe_customer_id TEXT,
  59. stripe_subscription_id TEXT,
  60. stripe_subscription_status TEXT,
  61. stripe_subscription_paid_until INT,
  62. stripe_subscription_cancel_at INT,
  63. created_by TEXT NOT NULL,
  64. created_at INT NOT NULL,
  65. last_seen INT NOT NULL,
  66. FOREIGN KEY (tier_id) REFERENCES tier (id)
  67. );
  68. CREATE UNIQUE INDEX idx_user ON user (user);
  69. CREATE UNIQUE INDEX idx_user_stripe_customer_id ON user (stripe_customer_id);
  70. CREATE UNIQUE INDEX idx_user_stripe_subscription_id ON user (stripe_subscription_id);
  71. CREATE TABLE IF NOT EXISTS user_access (
  72. user_id INT NOT NULL,
  73. topic TEXT NOT NULL,
  74. read INT NOT NULL,
  75. write INT NOT NULL,
  76. owner_user_id INT,
  77. PRIMARY KEY (user_id, topic),
  78. FOREIGN KEY (user_id) REFERENCES user (id) ON DELETE CASCADE,
  79. FOREIGN KEY (owner_user_id) REFERENCES user (id) ON DELETE CASCADE
  80. );
  81. CREATE TABLE IF NOT EXISTS user_token (
  82. user_id INT NOT NULL,
  83. token TEXT NOT NULL,
  84. expires INT NOT NULL,
  85. PRIMARY KEY (user_id, token),
  86. FOREIGN KEY (user_id) REFERENCES user (id) ON DELETE CASCADE
  87. );
  88. CREATE TABLE IF NOT EXISTS schemaVersion (
  89. id INT PRIMARY KEY,
  90. version INT NOT NULL
  91. );
  92. INSERT INTO user (id, user, pass, role, sync_topic, created_by, created_at, last_seen)
  93. VALUES (1, '*', '', 'anonymous', '', 'system', UNIXEPOCH(), 0)
  94. ON CONFLICT (id) DO NOTHING;
  95. `
  96. createTablesQueries = `BEGIN; ` + createTablesQueriesNoTx + ` COMMIT;`
  97. builtinStartupQueries = `
  98. PRAGMA foreign_keys = ON;
  99. `
  100. selectUserByNameQuery = `
  101. SELECT u.user, u.pass, u.role, u.prefs, u.sync_topic, u.stats_messages, u.stats_emails, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.stripe_price_id
  102. FROM user u
  103. LEFT JOIN tier t on t.id = u.tier_id
  104. WHERE user = ?
  105. `
  106. selectUserByTokenQuery = `
  107. SELECT u.user, u.pass, u.role, u.prefs, u.sync_topic, u.stats_messages, u.stats_emails, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.stripe_price_id
  108. FROM user u
  109. JOIN user_token t on u.id = t.user_id
  110. LEFT JOIN tier t on t.id = u.tier_id
  111. WHERE t.token = ? AND t.expires >= ?
  112. `
  113. selectUserByStripeCustomerIDQuery = `
  114. SELECT u.user, u.pass, u.role, u.prefs, u.sync_topic, u.stats_messages, u.stats_emails, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.stripe_price_id
  115. FROM user u
  116. LEFT JOIN tier t on t.id = u.tier_id
  117. WHERE u.stripe_customer_id = ?
  118. `
  119. selectTopicPermsQuery = `
  120. SELECT read, write
  121. FROM user_access a
  122. JOIN user u ON u.id = a.user_id
  123. WHERE (u.user = ? OR u.user = ?) AND ? LIKE a.topic
  124. ORDER BY u.user DESC
  125. `
  126. insertUserQuery = `
  127. INSERT INTO user (user, pass, role, sync_topic, created_by, created_at, last_seen)
  128. VALUES (?, ?, ?, ?, ?, ?, ?)
  129. `
  130. selectUsernamesQuery = `
  131. SELECT user
  132. FROM user
  133. ORDER BY
  134. CASE role
  135. WHEN 'admin' THEN 1
  136. WHEN 'anonymous' THEN 3
  137. ELSE 2
  138. END, user
  139. `
  140. updateUserPassQuery = `UPDATE user SET pass = ? WHERE user = ?`
  141. updateUserRoleQuery = `UPDATE user SET role = ? WHERE user = ?`
  142. updateUserPrefsQuery = `UPDATE user SET prefs = ? WHERE user = ?`
  143. updateUserStatsQuery = `UPDATE user SET stats_messages = ?, stats_emails = ? WHERE user = ?`
  144. updateUserStatsResetAllQuery = `UPDATE user SET stats_messages = 0, stats_emails = 0`
  145. deleteUserQuery = `DELETE FROM user WHERE user = ?`
  146. upsertUserAccessQuery = `
  147. INSERT INTO user_access (user_id, topic, read, write, owner_user_id)
  148. VALUES ((SELECT id FROM user WHERE user = ?), ?, ?, ?, (SELECT IIF(?='',NULL,(SELECT id FROM user WHERE user=?))))
  149. ON CONFLICT (user_id, topic)
  150. DO UPDATE SET read=excluded.read, write=excluded.write, owner_user_id=excluded.owner_user_id
  151. `
  152. selectUserAccessQuery = `
  153. SELECT topic, read, write
  154. FROM user_access
  155. WHERE user_id = (SELECT id FROM user WHERE user = ?)
  156. ORDER BY write DESC, read DESC, topic
  157. `
  158. selectUserReservationsQuery = `
  159. SELECT a_user.topic, a_user.read, a_user.write, a_everyone.read AS everyone_read, a_everyone.write AS everyone_write
  160. FROM user_access a_user
  161. LEFT JOIN user_access a_everyone ON a_user.topic = a_everyone.topic AND a_everyone.user_id = (SELECT id FROM user WHERE user = ?)
  162. WHERE a_user.user_id = a_user.owner_user_id
  163. AND a_user.owner_user_id = (SELECT id FROM user WHERE user = ?)
  164. ORDER BY a_user.topic
  165. `
  166. selectUserReservationsCountQuery = `
  167. SELECT COUNT(*)
  168. FROM user_access
  169. WHERE user_id = owner_user_id AND owner_user_id = (SELECT id FROM user WHERE user = ?)
  170. `
  171. selectUserHasReservationQuery = `
  172. SELECT COUNT(*)
  173. FROM user_access
  174. WHERE user_id = owner_user_id
  175. AND owner_user_id = (SELECT id FROM user WHERE user = ?)
  176. AND topic = ?
  177. `
  178. selectOtherAccessCountQuery = `
  179. SELECT COUNT(*)
  180. FROM user_access
  181. WHERE (topic = ? OR ? LIKE topic)
  182. AND (owner_user_id IS NULL OR owner_user_id != (SELECT id FROM user WHERE user = ?))
  183. `
  184. deleteAllAccessQuery = `DELETE FROM user_access`
  185. deleteUserAccessQuery = `
  186. DELETE FROM user_access
  187. WHERE user_id = (SELECT id FROM user WHERE user = ?)
  188. OR owner_user_id = (SELECT id FROM user WHERE user = ?)
  189. `
  190. deleteTopicAccessQuery = `
  191. DELETE FROM user_access
  192. WHERE (user_id = (SELECT id FROM user WHERE user = ?) OR owner_user_id = (SELECT id FROM user WHERE user = ?))
  193. AND topic = ?
  194. `
  195. selectTokenCountQuery = `SELECT COUNT(*) FROM user_token WHERE (SELECT id FROM user WHERE user = ?)`
  196. insertTokenQuery = `INSERT INTO user_token (user_id, token, expires) VALUES ((SELECT id FROM user WHERE user = ?), ?, ?)`
  197. updateTokenExpiryQuery = `UPDATE user_token SET expires = ? WHERE user_id = (SELECT id FROM user WHERE user = ?) AND token = ?`
  198. deleteTokenQuery = `DELETE FROM user_token WHERE user_id = (SELECT id FROM user WHERE user = ?) AND token = ?`
  199. deleteExpiredTokensQuery = `DELETE FROM user_token WHERE expires < ?`
  200. deleteExcessTokensQuery = `
  201. DELETE FROM user_token
  202. WHERE (user_id, token) NOT IN (
  203. SELECT user_id, token
  204. FROM user_token
  205. WHERE user_id = (SELECT id FROM user WHERE user = ?)
  206. ORDER BY expires DESC
  207. LIMIT ?
  208. )
  209. `
  210. insertTierQuery = `
  211. INSERT INTO tier (code, name, messages_limit, messages_expiry_duration, emails_limit, reservations_limit, attachment_file_size_limit, attachment_total_size_limit, attachment_expiry_duration, stripe_price_id)
  212. VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
  213. `
  214. selectTierIDQuery = `SELECT id FROM tier WHERE code = ?`
  215. selectTiersQuery = `
  216. SELECT code, name, messages_limit, messages_expiry_duration, emails_limit, reservations_limit, attachment_file_size_limit, attachment_total_size_limit, attachment_expiry_duration, stripe_price_id
  217. FROM tier
  218. `
  219. selectTierByCodeQuery = `
  220. SELECT code, name, messages_limit, messages_expiry_duration, emails_limit, reservations_limit, attachment_file_size_limit, attachment_total_size_limit, attachment_expiry_duration, stripe_price_id
  221. FROM tier
  222. WHERE code = ?
  223. `
  224. selectTierByPriceIDQuery = `
  225. SELECT code, name, messages_limit, messages_expiry_duration, emails_limit, reservations_limit, attachment_file_size_limit, attachment_total_size_limit, attachment_expiry_duration, stripe_price_id
  226. FROM tier
  227. WHERE stripe_price_id = ?
  228. `
  229. updateUserTierQuery = `UPDATE user SET tier_id = ? WHERE user = ?`
  230. deleteUserTierQuery = `UPDATE user SET tier_id = null WHERE user = ?`
  231. updateBillingQuery = `
  232. UPDATE user
  233. SET stripe_customer_id = ?, stripe_subscription_id = ?, stripe_subscription_status = ?, stripe_subscription_paid_until = ?, stripe_subscription_cancel_at = ?
  234. WHERE user = ?
  235. `
  236. )
  237. // Schema management queries
  238. const (
  239. currentSchemaVersion = 2
  240. insertSchemaVersion = `INSERT INTO schemaVersion VALUES (1, ?)`
  241. updateSchemaVersion = `UPDATE schemaVersion SET version = ? WHERE id = 1`
  242. selectSchemaVersionQuery = `SELECT version FROM schemaVersion WHERE id = 1`
  243. // 1 -> 2 (complex migration!)
  244. migrate1To2RenameUserTableQueryNoTx = `
  245. ALTER TABLE user RENAME TO user_old;
  246. `
  247. migrate1To2InsertFromOldTablesAndDropNoTx = `
  248. INSERT INTO user (user, pass, role, sync_topic, created_by, created_at, last_seen)
  249. SELECT user, pass, role, '', 'admin', UNIXEPOCH(), UNIXEPOCH() FROM user_old;
  250. INSERT INTO user_access (user_id, topic, read, write)
  251. SELECT u.id, a.topic, a.read, a.write
  252. FROM user u
  253. JOIN access a ON u.user = a.user;
  254. DROP TABLE access;
  255. DROP TABLE user_old;
  256. `
  257. migrate1To2SelectAllUsersIDsNoTx = `SELECT id FROM user`
  258. migrate1To2UpdateSyncTopicNoTx = `UPDATE user SET sync_topic = ? WHERE id = ?`
  259. )
// Manager is an implementation of Auther. It stores users, access control entries, tokens
// and tiers in a SQLite database.
type Manager struct {
	db            *sql.DB          // Handle to the SQLite user database
	defaultAccess Permission       // Default permission if no ACL matches
	statsQueue    map[string]*User // Username -> User, for "unimportant" user updates
	mu            sync.Mutex       // Guards statsQueue
}

// Compile-time assertion that *Manager implements the Auther interface.
var _ Auther = (*Manager)(nil)
  269. // NewManager creates a new Manager instance
  270. func NewManager(filename, startupQueries string, defaultAccess Permission) (*Manager, error) {
  271. return newManager(filename, startupQueries, defaultAccess, userStatsQueueWriterInterval)
  272. }
  273. // NewManager creates a new Manager instance
  274. func newManager(filename, startupQueries string, defaultAccess Permission, statsWriterInterval time.Duration) (*Manager, error) {
  275. db, err := sql.Open("sqlite3", filename)
  276. if err != nil {
  277. return nil, err
  278. }
  279. if err := setupDB(db); err != nil {
  280. return nil, err
  281. }
  282. if err := runStartupQueries(db, startupQueries); err != nil {
  283. return nil, err
  284. }
  285. manager := &Manager{
  286. db: db,
  287. defaultAccess: defaultAccess,
  288. statsQueue: make(map[string]*User),
  289. }
  290. go manager.userStatsQueueWriter(statsWriterInterval)
  291. return manager, nil
  292. }
  293. // Authenticate checks username and password and returns a User if correct. The method
  294. // returns in constant-ish time, regardless of whether the user exists or the password is
  295. // correct or incorrect.
  296. func (a *Manager) Authenticate(username, password string) (*User, error) {
  297. if username == Everyone {
  298. return nil, ErrUnauthenticated
  299. }
  300. user, err := a.User(username)
  301. if err != nil {
  302. log.Trace("authentication of user %s failed (1): %s", username, err.Error())
  303. bcrypt.CompareHashAndPassword([]byte(intentionalSlowDownHash), []byte("intentional slow-down to avoid timing attacks"))
  304. return nil, ErrUnauthenticated
  305. }
  306. if err := bcrypt.CompareHashAndPassword([]byte(user.Hash), []byte(password)); err != nil {
  307. log.Trace("authentication of user %s failed (2): %s", username, err.Error())
  308. return nil, ErrUnauthenticated
  309. }
  310. return user, nil
  311. }
  312. // AuthenticateToken checks if the token exists and returns the associated User if it does.
  313. // The method sets the User.Token value to the token that was used for authentication.
  314. func (a *Manager) AuthenticateToken(token string) (*User, error) {
  315. if len(token) != tokenLength {
  316. return nil, ErrUnauthenticated
  317. }
  318. user, err := a.userByToken(token)
  319. if err != nil {
  320. return nil, ErrUnauthenticated
  321. }
  322. user.Token = token
  323. return user, nil
  324. }
  325. // CreateToken generates a random token for the given user and returns it. The token expires
  326. // after a fixed duration unless ExtendToken is called. This function also prunes tokens for the
  327. // given user, if there are too many of them.
  328. func (a *Manager) CreateToken(user *User) (*Token, error) {
  329. token, expires := util.RandomString(tokenLength), time.Now().Add(tokenExpiryDuration)
  330. tx, err := a.db.Begin()
  331. if err != nil {
  332. return nil, err
  333. }
  334. defer tx.Rollback()
  335. if _, err := tx.Exec(insertTokenQuery, user.Name, token, expires.Unix()); err != nil {
  336. return nil, err
  337. }
  338. rows, err := tx.Query(selectTokenCountQuery, user.Name)
  339. if err != nil {
  340. return nil, err
  341. }
  342. defer rows.Close()
  343. if !rows.Next() {
  344. return nil, errNoRows
  345. }
  346. var tokenCount int
  347. if err := rows.Scan(&tokenCount); err != nil {
  348. return nil, err
  349. }
  350. if tokenCount >= tokenMaxCount {
  351. // This pruning logic is done in two queries for efficiency. The SELECT above is a lookup
  352. // on two indices, whereas the query below is a full table scan.
  353. if _, err := tx.Exec(deleteExcessTokensQuery, user.Name, tokenMaxCount); err != nil {
  354. return nil, err
  355. }
  356. }
  357. if err := tx.Commit(); err != nil {
  358. return nil, err
  359. }
  360. return &Token{
  361. Value: token,
  362. Expires: expires,
  363. }, nil
  364. }
  365. // ExtendToken sets the new expiry date for a token, thereby extending its use further into the future.
  366. func (a *Manager) ExtendToken(user *User) (*Token, error) {
  367. if user.Token == "" {
  368. return nil, errNoTokenProvided
  369. }
  370. newExpires := time.Now().Add(tokenExpiryDuration)
  371. if _, err := a.db.Exec(updateTokenExpiryQuery, newExpires.Unix(), user.Name, user.Token); err != nil {
  372. return nil, err
  373. }
  374. return &Token{
  375. Value: user.Token,
  376. Expires: newExpires,
  377. }, nil
  378. }
  379. // RemoveToken deletes the token defined in User.Token
  380. func (a *Manager) RemoveToken(user *User) error {
  381. if user.Token == "" {
  382. return ErrUnauthorized
  383. }
  384. if _, err := a.db.Exec(deleteTokenQuery, user.Name, user.Token); err != nil {
  385. return err
  386. }
  387. return nil
  388. }
  389. // RemoveExpiredTokens deletes all expired tokens from the database
  390. func (a *Manager) RemoveExpiredTokens() error {
  391. if _, err := a.db.Exec(deleteExpiredTokensQuery, time.Now().Unix()); err != nil {
  392. return err
  393. }
  394. return nil
  395. }
  396. // ChangeSettings persists the user settings
  397. func (a *Manager) ChangeSettings(user *User) error {
  398. prefs, err := json.Marshal(user.Prefs)
  399. if err != nil {
  400. return err
  401. }
  402. if _, err := a.db.Exec(updateUserPrefsQuery, string(prefs), user.Name); err != nil {
  403. return err
  404. }
  405. return nil
  406. }
  407. // ResetStats resets all user stats in the user database. This touches all users.
  408. func (a *Manager) ResetStats() error {
  409. a.mu.Lock()
  410. defer a.mu.Unlock()
  411. if _, err := a.db.Exec(updateUserStatsResetAllQuery); err != nil {
  412. return err
  413. }
  414. a.statsQueue = make(map[string]*User)
  415. return nil
  416. }
  417. // EnqueueStats adds the user to a queue which writes out user stats (messages, emails, ..) in
  418. // batches at a regular interval
  419. func (a *Manager) EnqueueStats(user *User) {
  420. a.mu.Lock()
  421. defer a.mu.Unlock()
  422. a.statsQueue[user.Name] = user
  423. }
  424. func (a *Manager) userStatsQueueWriter(interval time.Duration) {
  425. ticker := time.NewTicker(interval)
  426. for range ticker.C {
  427. if err := a.writeUserStatsQueue(); err != nil {
  428. log.Warn("User Manager: Writing user stats queue failed: %s", err.Error())
  429. }
  430. }
  431. }
  432. func (a *Manager) writeUserStatsQueue() error {
  433. a.mu.Lock()
  434. if len(a.statsQueue) == 0 {
  435. a.mu.Unlock()
  436. log.Trace("User Manager: No user stats updates to commit")
  437. return nil
  438. }
  439. statsQueue := a.statsQueue
  440. a.statsQueue = make(map[string]*User)
  441. a.mu.Unlock()
  442. tx, err := a.db.Begin()
  443. if err != nil {
  444. return err
  445. }
  446. defer tx.Rollback()
  447. log.Debug("User Manager: Writing user stats queue for %d user(s)", len(statsQueue))
  448. for username, u := range statsQueue {
  449. log.Trace("User Manager: Updating stats for user %s: messages=%d, emails=%d", username, u.Stats.Messages, u.Stats.Emails)
  450. if _, err := tx.Exec(updateUserStatsQuery, u.Stats.Messages, u.Stats.Emails, username); err != nil {
  451. return err
  452. }
  453. }
  454. return tx.Commit()
  455. }
  456. // Authorize returns nil if the given user has access to the given topic using the desired
  457. // permission. The user param may be nil to signal an anonymous user.
  458. func (a *Manager) Authorize(user *User, topic string, perm Permission) error {
  459. if user != nil && user.Role == RoleAdmin {
  460. return nil // Admin can do everything
  461. }
  462. username := Everyone
  463. if user != nil {
  464. username = user.Name
  465. }
  466. // Select the read/write permissions for this user/topic combo. The query may return two
  467. // rows (one for everyone, and one for the user), but prioritizes the user.
  468. rows, err := a.db.Query(selectTopicPermsQuery, Everyone, username, topic)
  469. if err != nil {
  470. return err
  471. }
  472. defer rows.Close()
  473. if !rows.Next() {
  474. return a.resolvePerms(a.defaultAccess, perm)
  475. }
  476. var read, write bool
  477. if err := rows.Scan(&read, &write); err != nil {
  478. return err
  479. } else if err := rows.Err(); err != nil {
  480. return err
  481. }
  482. return a.resolvePerms(NewPermission(read, write), perm)
  483. }
  484. func (a *Manager) resolvePerms(base, perm Permission) error {
  485. if perm == PermissionRead && base.IsRead() {
  486. return nil
  487. } else if perm == PermissionWrite && base.IsWrite() {
  488. return nil
  489. }
  490. return ErrUnauthorized
  491. }
  492. // AddUser adds a user with the given username, password and role
  493. func (a *Manager) AddUser(username, password string, role Role, createdBy string) error {
  494. if !AllowedUsername(username) || !AllowedRole(role) {
  495. return ErrInvalidArgument
  496. }
  497. hash, err := bcrypt.GenerateFromPassword([]byte(password), bcryptCost)
  498. if err != nil {
  499. return err
  500. }
  501. syncTopic, now := util.RandomString(syncTopicLength), time.Now().Unix()
  502. if _, err = a.db.Exec(insertUserQuery, username, hash, role, syncTopic, createdBy, now, now); err != nil {
  503. return err
  504. }
  505. return nil
  506. }
  507. // RemoveUser deletes the user with the given username. The function returns nil on success, even
  508. // if the user did not exist in the first place.
  509. func (a *Manager) RemoveUser(username string) error {
  510. if !AllowedUsername(username) {
  511. return ErrInvalidArgument
  512. }
  513. // Rows in user_access, user_token, etc. are deleted via foreign keys
  514. if _, err := a.db.Exec(deleteUserQuery, username); err != nil {
  515. return err
  516. }
  517. return nil
  518. }
  519. // Users returns a list of users. It always also returns the Everyone user ("*").
  520. func (a *Manager) Users() ([]*User, error) {
  521. rows, err := a.db.Query(selectUsernamesQuery)
  522. if err != nil {
  523. return nil, err
  524. }
  525. defer rows.Close()
  526. usernames := make([]string, 0)
  527. for rows.Next() {
  528. var username string
  529. if err := rows.Scan(&username); err != nil {
  530. return nil, err
  531. } else if err := rows.Err(); err != nil {
  532. return nil, err
  533. }
  534. usernames = append(usernames, username)
  535. }
  536. rows.Close()
  537. users := make([]*User, 0)
  538. for _, username := range usernames {
  539. user, err := a.User(username)
  540. if err != nil {
  541. return nil, err
  542. }
  543. users = append(users, user)
  544. }
  545. return users, nil
  546. }
  547. // User returns the user with the given username if it exists, or ErrUserNotFound otherwise.
  548. // You may also pass Everyone to retrieve the anonymous user and its Grant list.
  549. func (a *Manager) User(username string) (*User, error) {
  550. rows, err := a.db.Query(selectUserByNameQuery, username)
  551. if err != nil {
  552. return nil, err
  553. }
  554. return a.readUser(rows)
  555. }
  556. func (a *Manager) UserByStripeCustomer(stripeCustomerID string) (*User, error) {
  557. rows, err := a.db.Query(selectUserByStripeCustomerIDQuery, stripeCustomerID)
  558. if err != nil {
  559. return nil, err
  560. }
  561. return a.readUser(rows)
  562. }
  563. func (a *Manager) userByToken(token string) (*User, error) {
  564. rows, err := a.db.Query(selectUserByTokenQuery, token, time.Now().Unix())
  565. if err != nil {
  566. return nil, err
  567. }
  568. return a.readUser(rows)
  569. }
  570. func (a *Manager) readUser(rows *sql.Rows) (*User, error) {
  571. defer rows.Close()
  572. var username, hash, role, prefs, syncTopic string
  573. var stripeCustomerID, stripeSubscriptionID, stripeSubscriptionStatus, stripePriceID, tierCode, tierName sql.NullString
  574. var messages, emails int64
  575. var messagesLimit, messagesExpiryDuration, emailsLimit, reservationsLimit, attachmentFileSizeLimit, attachmentTotalSizeLimit, attachmentExpiryDuration, stripeSubscriptionPaidUntil, stripeSubscriptionCancelAt sql.NullInt64
  576. if !rows.Next() {
  577. return nil, ErrUserNotFound
  578. }
  579. if err := rows.Scan(&username, &hash, &role, &prefs, &syncTopic, &messages, &emails, &stripeCustomerID, &stripeSubscriptionID, &stripeSubscriptionStatus, &stripeSubscriptionPaidUntil, &stripeSubscriptionCancelAt, &tierCode, &tierName, &messagesLimit, &messagesExpiryDuration, &emailsLimit, &reservationsLimit, &attachmentFileSizeLimit, &attachmentTotalSizeLimit, &attachmentExpiryDuration, &stripePriceID); err != nil {
  580. return nil, err
  581. } else if err := rows.Err(); err != nil {
  582. return nil, err
  583. }
  584. user := &User{
  585. Name: username,
  586. Hash: hash,
  587. Role: Role(role),
  588. Prefs: &Prefs{},
  589. SyncTopic: syncTopic,
  590. Stats: &Stats{
  591. Messages: messages,
  592. Emails: emails,
  593. },
  594. Billing: &Billing{
  595. StripeCustomerID: stripeCustomerID.String, // May be empty
  596. StripeSubscriptionID: stripeSubscriptionID.String, // May be empty
  597. StripeSubscriptionStatus: stripe.SubscriptionStatus(stripeSubscriptionStatus.String), // May be empty
  598. StripeSubscriptionPaidUntil: time.Unix(stripeSubscriptionPaidUntil.Int64, 0), // May be zero
  599. StripeSubscriptionCancelAt: time.Unix(stripeSubscriptionCancelAt.Int64, 0), // May be zero
  600. },
  601. }
  602. if err := json.Unmarshal([]byte(prefs), user.Prefs); err != nil {
  603. return nil, err
  604. }
  605. if tierCode.Valid {
  606. // See readTier() when this is changed!
  607. user.Tier = &Tier{
  608. Code: tierCode.String,
  609. Name: tierName.String,
  610. Paid: stripePriceID.Valid, // If there is a price, it's a paid tier
  611. MessagesLimit: messagesLimit.Int64,
  612. MessagesExpiryDuration: time.Duration(messagesExpiryDuration.Int64) * time.Second,
  613. EmailsLimit: emailsLimit.Int64,
  614. ReservationsLimit: reservationsLimit.Int64,
  615. AttachmentFileSizeLimit: attachmentFileSizeLimit.Int64,
  616. AttachmentTotalSizeLimit: attachmentTotalSizeLimit.Int64,
  617. AttachmentExpiryDuration: time.Duration(attachmentExpiryDuration.Int64) * time.Second,
  618. StripePriceID: stripePriceID.String, // May be empty
  619. }
  620. }
  621. return user, nil
  622. }
  623. // Grants returns all user-specific access control entries
  624. func (a *Manager) Grants(username string) ([]Grant, error) {
  625. rows, err := a.db.Query(selectUserAccessQuery, username)
  626. if err != nil {
  627. return nil, err
  628. }
  629. defer rows.Close()
  630. grants := make([]Grant, 0)
  631. for rows.Next() {
  632. var topic string
  633. var read, write bool
  634. if err := rows.Scan(&topic, &read, &write); err != nil {
  635. return nil, err
  636. } else if err := rows.Err(); err != nil {
  637. return nil, err
  638. }
  639. grants = append(grants, Grant{
  640. TopicPattern: fromSQLWildcard(topic),
  641. Allow: NewPermission(read, write),
  642. })
  643. }
  644. return grants, nil
  645. }
  646. // Reservations returns all user-owned topics, and the associated everyone-access
  647. func (a *Manager) Reservations(username string) ([]Reservation, error) {
  648. rows, err := a.db.Query(selectUserReservationsQuery, Everyone, username)
  649. if err != nil {
  650. return nil, err
  651. }
  652. defer rows.Close()
  653. reservations := make([]Reservation, 0)
  654. for rows.Next() {
  655. var topic string
  656. var ownerRead, ownerWrite bool
  657. var everyoneRead, everyoneWrite sql.NullBool
  658. if err := rows.Scan(&topic, &ownerRead, &ownerWrite, &everyoneRead, &everyoneWrite); err != nil {
  659. return nil, err
  660. } else if err := rows.Err(); err != nil {
  661. return nil, err
  662. }
  663. reservations = append(reservations, Reservation{
  664. Topic: topic,
  665. Owner: NewPermission(ownerRead, ownerWrite),
  666. Everyone: NewPermission(everyoneRead.Bool, everyoneWrite.Bool), // false if null
  667. })
  668. }
  669. return reservations, nil
  670. }
  671. // HasReservation returns true if the given topic access is owned by the user
  672. func (a *Manager) HasReservation(username, topic string) (bool, error) {
  673. rows, err := a.db.Query(selectUserHasReservationQuery, username, topic)
  674. if err != nil {
  675. return false, err
  676. }
  677. defer rows.Close()
  678. if !rows.Next() {
  679. return false, errNoRows
  680. }
  681. var count int64
  682. if err := rows.Scan(&count); err != nil {
  683. return false, err
  684. }
  685. return count > 0, nil
  686. }
  687. // ReservationsCount returns the number of reservations owned by this user
  688. func (a *Manager) ReservationsCount(username string) (int64, error) {
  689. rows, err := a.db.Query(selectUserReservationsCountQuery, username)
  690. if err != nil {
  691. return 0, err
  692. }
  693. defer rows.Close()
  694. if !rows.Next() {
  695. return 0, errNoRows
  696. }
  697. var count int64
  698. if err := rows.Scan(&count); err != nil {
  699. return 0, err
  700. }
  701. return count, nil
  702. }
  703. // ChangePassword changes a user's password
  704. func (a *Manager) ChangePassword(username, password string) error {
  705. hash, err := bcrypt.GenerateFromPassword([]byte(password), bcryptCost)
  706. if err != nil {
  707. return err
  708. }
  709. if _, err := a.db.Exec(updateUserPassQuery, hash, username); err != nil {
  710. return err
  711. }
  712. return nil
  713. }
  714. // ChangeRole changes a user's role. When a role is changed from RoleUser to RoleAdmin,
  715. // all existing access control entries (Grant) are removed, since they are no longer needed.
  716. func (a *Manager) ChangeRole(username string, role Role) error {
  717. if !AllowedUsername(username) || !AllowedRole(role) {
  718. return ErrInvalidArgument
  719. }
  720. if _, err := a.db.Exec(updateUserRoleQuery, string(role), username); err != nil {
  721. return err
  722. }
  723. if role == RoleAdmin {
  724. if _, err := a.db.Exec(deleteUserAccessQuery, username, username); err != nil {
  725. return err
  726. }
  727. }
  728. return nil
  729. }
  730. // ChangeTier changes a user's tier using the tier code
  731. func (a *Manager) ChangeTier(username, tier string) error {
  732. if !AllowedUsername(username) {
  733. return ErrInvalidArgument
  734. }
  735. rows, err := a.db.Query(selectTierIDQuery, tier)
  736. if err != nil {
  737. return err
  738. }
  739. defer rows.Close()
  740. if !rows.Next() {
  741. return ErrInvalidArgument
  742. }
  743. var tierID int64
  744. if err := rows.Scan(&tierID); err != nil {
  745. return err
  746. }
  747. rows.Close()
  748. if _, err := a.db.Exec(updateUserTierQuery, tierID, username); err != nil {
  749. return err
  750. }
  751. return nil
  752. }
  753. // CheckAllowAccess tests if a user may create an access control entry for the given topic.
  754. // If there are any ACL entries that are not owned by the user, an error is returned.
  755. func (a *Manager) CheckAllowAccess(username string, topic string) error {
  756. if (!AllowedUsername(username) && username != Everyone) || !AllowedTopic(topic) {
  757. return ErrInvalidArgument
  758. }
  759. rows, err := a.db.Query(selectOtherAccessCountQuery, topic, topic, username)
  760. if err != nil {
  761. return err
  762. }
  763. defer rows.Close()
  764. if !rows.Next() {
  765. return errNoRows
  766. }
  767. var otherCount int
  768. if err := rows.Scan(&otherCount); err != nil {
  769. return err
  770. }
  771. if otherCount > 0 {
  772. return errTopicOwnedByOthers
  773. }
  774. return nil
  775. }
  776. // AllowAccess adds or updates an entry in th access control list for a specific user. It controls
  777. // read/write access to a topic. The parameter topicPattern may include wildcards (*). The ACL entry
  778. // owner may either be a user (username), or the system (empty).
  779. func (a *Manager) AllowAccess(owner, username string, topicPattern string, read bool, write bool) error {
  780. if !AllowedUsername(username) && username != Everyone {
  781. return ErrInvalidArgument
  782. } else if owner != "" && !AllowedUsername(owner) {
  783. return ErrInvalidArgument
  784. } else if !AllowedTopicPattern(topicPattern) {
  785. return ErrInvalidArgument
  786. }
  787. if _, err := a.db.Exec(upsertUserAccessQuery, username, toSQLWildcard(topicPattern), read, write, owner, owner); err != nil {
  788. return err
  789. }
  790. return nil
  791. }
  792. // ResetAccess removes an access control list entry for a specific username/topic, or (if topic is
  793. // empty) for an entire user. The parameter topicPattern may include wildcards (*).
  794. func (a *Manager) ResetAccess(username string, topicPattern string) error {
  795. if !AllowedUsername(username) && username != Everyone && username != "" {
  796. return ErrInvalidArgument
  797. } else if !AllowedTopicPattern(topicPattern) && topicPattern != "" {
  798. return ErrInvalidArgument
  799. }
  800. if username == "" && topicPattern == "" {
  801. _, err := a.db.Exec(deleteAllAccessQuery, username)
  802. return err
  803. } else if topicPattern == "" {
  804. _, err := a.db.Exec(deleteUserAccessQuery, username, username)
  805. return err
  806. }
  807. _, err := a.db.Exec(deleteTopicAccessQuery, username, username, toSQLWildcard(topicPattern))
  808. return err
  809. }
  810. // ResetTier removes the tier from the given user
  811. func (a *Manager) ResetTier(username string) error {
  812. if !AllowedUsername(username) && username != Everyone && username != "" {
  813. return ErrInvalidArgument
  814. }
  815. _, err := a.db.Exec(deleteUserTierQuery, username)
  816. return err
  817. }
  818. // DefaultAccess returns the default read/write access if no access control entry matches
  819. func (a *Manager) DefaultAccess() Permission {
  820. return a.defaultAccess
  821. }
  822. // CreateTier creates a new tier in the database
  823. func (a *Manager) CreateTier(tier *Tier) error {
  824. if _, err := a.db.Exec(insertTierQuery, tier.Code, tier.Name, tier.MessagesLimit, int64(tier.MessagesExpiryDuration.Seconds()), tier.EmailsLimit, tier.ReservationsLimit, tier.AttachmentFileSizeLimit, tier.AttachmentTotalSizeLimit, int64(tier.AttachmentExpiryDuration.Seconds()), tier.StripePriceID); err != nil {
  825. return err
  826. }
  827. return nil
  828. }
  829. func (a *Manager) ChangeBilling(user *User) error {
  830. if _, err := a.db.Exec(updateBillingQuery, nullString(user.Billing.StripeCustomerID), nullString(user.Billing.StripeSubscriptionID), nullString(string(user.Billing.StripeSubscriptionStatus)), nullInt64(user.Billing.StripeSubscriptionPaidUntil.Unix()), nullInt64(user.Billing.StripeSubscriptionCancelAt.Unix()), user.Name); err != nil {
  831. return err
  832. }
  833. return nil
  834. }
  835. func (a *Manager) Tiers() ([]*Tier, error) {
  836. rows, err := a.db.Query(selectTiersQuery)
  837. if err != nil {
  838. return nil, err
  839. }
  840. defer rows.Close()
  841. tiers := make([]*Tier, 0)
  842. for {
  843. tier, err := a.readTier(rows)
  844. if err == ErrTierNotFound {
  845. break
  846. } else if err != nil {
  847. return nil, err
  848. }
  849. tiers = append(tiers, tier)
  850. }
  851. return tiers, nil
  852. }
  853. func (a *Manager) Tier(code string) (*Tier, error) {
  854. rows, err := a.db.Query(selectTierByCodeQuery, code)
  855. if err != nil {
  856. return nil, err
  857. }
  858. defer rows.Close()
  859. return a.readTier(rows)
  860. }
  861. func (a *Manager) TierByStripePrice(priceID string) (*Tier, error) {
  862. rows, err := a.db.Query(selectTierByPriceIDQuery, priceID)
  863. if err != nil {
  864. return nil, err
  865. }
  866. defer rows.Close()
  867. return a.readTier(rows)
  868. }
  869. func (a *Manager) readTier(rows *sql.Rows) (*Tier, error) {
  870. var code, name string
  871. var stripePriceID sql.NullString
  872. var messagesLimit, messagesExpiryDuration, emailsLimit, reservationsLimit, attachmentFileSizeLimit, attachmentTotalSizeLimit, attachmentExpiryDuration sql.NullInt64
  873. if !rows.Next() {
  874. return nil, ErrTierNotFound
  875. }
  876. if err := rows.Scan(&code, &name, &messagesLimit, &messagesExpiryDuration, &emailsLimit, &reservationsLimit, &attachmentFileSizeLimit, &attachmentTotalSizeLimit, &attachmentExpiryDuration, &stripePriceID); err != nil {
  877. return nil, err
  878. } else if err := rows.Err(); err != nil {
  879. return nil, err
  880. }
  881. // When changed, note readUser() as well
  882. return &Tier{
  883. Code: code,
  884. Name: name,
  885. Paid: stripePriceID.Valid, // If there is a price, it's a paid tier
  886. MessagesLimit: messagesLimit.Int64,
  887. MessagesExpiryDuration: time.Duration(messagesExpiryDuration.Int64) * time.Second,
  888. EmailsLimit: emailsLimit.Int64,
  889. ReservationsLimit: reservationsLimit.Int64,
  890. AttachmentFileSizeLimit: attachmentFileSizeLimit.Int64,
  891. AttachmentTotalSizeLimit: attachmentTotalSizeLimit.Int64,
  892. AttachmentExpiryDuration: time.Duration(attachmentExpiryDuration.Int64) * time.Second,
  893. StripePriceID: stripePriceID.String, // May be empty
  894. }, nil
  895. }
  896. func toSQLWildcard(s string) string {
  897. return strings.ReplaceAll(s, "*", "%")
  898. }
  899. func fromSQLWildcard(s string) string {
  900. return strings.ReplaceAll(s, "%", "*")
  901. }
  902. func runStartupQueries(db *sql.DB, startupQueries string) error {
  903. if _, err := db.Exec(startupQueries); err != nil {
  904. return err
  905. }
  906. if _, err := db.Exec(builtinStartupQueries); err != nil {
  907. return err
  908. }
  909. return nil
  910. }
  911. func setupDB(db *sql.DB) error {
  912. // If 'schemaVersion' table does not exist, this must be a new database
  913. rowsSV, err := db.Query(selectSchemaVersionQuery)
  914. if err != nil {
  915. return setupNewDB(db)
  916. }
  917. defer rowsSV.Close()
  918. // If 'schemaVersion' table exists, read version and potentially upgrade
  919. schemaVersion := 0
  920. if !rowsSV.Next() {
  921. return errors.New("cannot determine schema version: database file may be corrupt")
  922. }
  923. if err := rowsSV.Scan(&schemaVersion); err != nil {
  924. return err
  925. }
  926. rowsSV.Close()
  927. // Do migrations
  928. if schemaVersion == currentSchemaVersion {
  929. return nil
  930. } else if schemaVersion == 1 {
  931. return migrateFrom1(db)
  932. }
  933. return fmt.Errorf("unexpected schema version found: %d", schemaVersion)
  934. }
  935. func setupNewDB(db *sql.DB) error {
  936. if _, err := db.Exec(createTablesQueries); err != nil {
  937. return err
  938. }
  939. if _, err := db.Exec(insertSchemaVersion, currentSchemaVersion); err != nil {
  940. return err
  941. }
  942. return nil
  943. }
  944. func migrateFrom1(db *sql.DB) error {
  945. log.Info("Migrating user database schema: from 1 to 2")
  946. tx, err := db.Begin()
  947. if err != nil {
  948. return err
  949. }
  950. defer tx.Rollback()
  951. if _, err := tx.Exec(migrate1To2RenameUserTableQueryNoTx); err != nil {
  952. return err
  953. }
  954. if _, err := tx.Exec(createTablesQueriesNoTx); err != nil {
  955. return err
  956. }
  957. if _, err := tx.Exec(migrate1To2InsertFromOldTablesAndDropNoTx); err != nil {
  958. return err
  959. }
  960. rows, err := tx.Query(migrate1To2SelectAllUsersIDsNoTx)
  961. if err != nil {
  962. return err
  963. }
  964. defer rows.Close()
  965. syncTopics := make(map[int]string)
  966. for rows.Next() {
  967. var userID int
  968. if err := rows.Scan(&userID); err != nil {
  969. return err
  970. }
  971. syncTopics[userID] = util.RandomString(syncTopicLength)
  972. }
  973. if err := rows.Close(); err != nil {
  974. return err
  975. }
  976. for userID, syncTopic := range syncTopics {
  977. if _, err := tx.Exec(migrate1To2UpdateSyncTopicNoTx, syncTopic, userID); err != nil {
  978. return err
  979. }
  980. }
  981. if _, err := tx.Exec(updateSchemaVersion, 2); err != nil {
  982. return err
  983. }
  984. if err := tx.Commit(); err != nil {
  985. return err
  986. }
  987. return nil // Update this when a new version is added
  988. }
  989. func nullString(s string) sql.NullString {
  990. if s == "" {
  991. return sql.NullString{}
  992. }
  993. return sql.NullString{String: s, Valid: true}
  994. }
  995. func nullInt64(v int64) sql.NullInt64 {
  996. if v == 0 {
  997. return sql.NullInt64{}
  998. }
  999. return sql.NullInt64{Int64: v, Valid: true}
  1000. }