manager.go 36 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069
  1. package user
  2. import (
  3. "database/sql"
  4. "encoding/json"
  5. "errors"
  6. "fmt"
  7. _ "github.com/mattn/go-sqlite3" // SQLite driver
  8. "github.com/stripe/stripe-go/v74"
  9. "golang.org/x/crypto/bcrypt"
  10. "heckel.io/ntfy/log"
  11. "heckel.io/ntfy/util"
  12. "strings"
  13. "sync"
  14. "time"
  15. )
  16. const (
  17. bcryptCost = 10
  18. intentionalSlowDownHash = "$2a$10$YFCQvqQDwIIwnJM1xkAYOeih0dg17UVGanaTStnrSzC8NCWxcLDwy" // Cost should match bcryptCost
  19. userStatsQueueWriterInterval = 33 * time.Second
  20. tokenLength = 32
  21. tokenExpiryDuration = 72 * time.Hour // Extend tokens by this much
  22. syncTopicLength = 16
  23. tokenMaxCount = 10 // Only keep this many tokens in the table per user
  24. )
  25. var (
  26. errNoTokenProvided = errors.New("no token provided")
  27. errTopicOwnedByOthers = errors.New("topic owned by others")
  28. errNoRows = errors.New("no rows found")
  29. )
  30. // Manager-related queries
  31. const (
  32. createTablesQueriesNoTx = `
  33. CREATE TABLE IF NOT EXISTS tier (
  34. id INTEGER PRIMARY KEY AUTOINCREMENT,
  35. code TEXT NOT NULL,
  36. name TEXT NOT NULL,
  37. messages_limit INT NOT NULL,
  38. messages_expiry_duration INT NOT NULL,
  39. emails_limit INT NOT NULL,
  40. reservations_limit INT NOT NULL,
  41. attachment_file_size_limit INT NOT NULL,
  42. attachment_total_size_limit INT NOT NULL,
  43. attachment_expiry_duration INT NOT NULL,
  44. features TEXT,
  45. stripe_price_id TEXT
  46. );
  47. CREATE UNIQUE INDEX idx_tier_code ON tier (code);
  48. CREATE UNIQUE INDEX idx_tier_price_id ON tier (stripe_price_id);
  49. CREATE TABLE IF NOT EXISTS user (
  50. id INTEGER PRIMARY KEY AUTOINCREMENT,
  51. tier_id INT,
  52. user TEXT NOT NULL,
  53. pass TEXT NOT NULL,
  54. role TEXT CHECK (role IN ('anonymous', 'admin', 'user')) NOT NULL,
  55. prefs JSON NOT NULL DEFAULT '{}',
  56. sync_topic TEXT NOT NULL,
  57. stats_messages INT NOT NULL DEFAULT (0),
  58. stats_emails INT NOT NULL DEFAULT (0),
  59. stripe_customer_id TEXT,
  60. stripe_subscription_id TEXT,
  61. stripe_subscription_status TEXT,
  62. stripe_subscription_paid_until INT,
  63. stripe_subscription_cancel_at INT,
  64. created_by TEXT NOT NULL,
  65. created_at INT NOT NULL,
  66. last_seen INT NOT NULL,
  67. FOREIGN KEY (tier_id) REFERENCES tier (id)
  68. );
  69. CREATE UNIQUE INDEX idx_user ON user (user);
  70. CREATE UNIQUE INDEX idx_user_stripe_customer_id ON user (stripe_customer_id);
  71. CREATE UNIQUE INDEX idx_user_stripe_subscription_id ON user (stripe_subscription_id);
  72. CREATE TABLE IF NOT EXISTS user_access (
  73. user_id INT NOT NULL,
  74. topic TEXT NOT NULL,
  75. read INT NOT NULL,
  76. write INT NOT NULL,
  77. owner_user_id INT,
  78. PRIMARY KEY (user_id, topic),
  79. FOREIGN KEY (user_id) REFERENCES user (id) ON DELETE CASCADE,
  80. FOREIGN KEY (owner_user_id) REFERENCES user (id) ON DELETE CASCADE
  81. );
  82. CREATE TABLE IF NOT EXISTS user_token (
  83. user_id INT NOT NULL,
  84. token TEXT NOT NULL,
  85. expires INT NOT NULL,
  86. PRIMARY KEY (user_id, token),
  87. FOREIGN KEY (user_id) REFERENCES user (id) ON DELETE CASCADE
  88. );
  89. CREATE TABLE IF NOT EXISTS schemaVersion (
  90. id INT PRIMARY KEY,
  91. version INT NOT NULL
  92. );
  93. INSERT INTO user (id, user, pass, role, sync_topic, created_by, created_at, last_seen)
  94. VALUES (1, '*', '', 'anonymous', '', 'system', UNIXEPOCH(), 0)
  95. ON CONFLICT (id) DO NOTHING;
  96. `
  97. createTablesQueries = `BEGIN; ` + createTablesQueriesNoTx + ` COMMIT;`
  98. builtinStartupQueries = `
  99. PRAGMA foreign_keys = ON;
  100. `
  101. selectUserByNameQuery = `
  102. SELECT u.user, u.pass, u.role, u.prefs, u.sync_topic, u.stats_messages, u.stats_emails, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.features, t.stripe_price_id
  103. FROM user u
  104. LEFT JOIN tier t on t.id = u.tier_id
  105. WHERE user = ?
  106. `
  107. selectUserByTokenQuery = `
  108. SELECT u.user, u.pass, u.role, u.prefs, u.sync_topic, u.stats_messages, u.stats_emails, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.features, t.stripe_price_id
  109. FROM user u
  110. JOIN user_token t on u.id = t.user_id
  111. LEFT JOIN tier t on t.id = u.tier_id
  112. WHERE t.token = ? AND t.expires >= ?
  113. `
  114. selectUserByStripeCustomerIDQuery = `
  115. SELECT u.user, u.pass, u.role, u.prefs, u.sync_topic, u.stats_messages, u.stats_emails, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.features, t.stripe_price_id
  116. FROM user u
  117. LEFT JOIN tier t on t.id = u.tier_id
  118. WHERE u.stripe_customer_id = ?
  119. `
  120. selectTopicPermsQuery = `
  121. SELECT read, write
  122. FROM user_access a
  123. JOIN user u ON u.id = a.user_id
  124. WHERE (u.user = ? OR u.user = ?) AND ? LIKE a.topic
  125. ORDER BY u.user DESC
  126. `
  127. insertUserQuery = `
  128. INSERT INTO user (user, pass, role, sync_topic, created_by, created_at, last_seen)
  129. VALUES (?, ?, ?, ?, ?, ?, ?)
  130. `
  131. selectUsernamesQuery = `
  132. SELECT user
  133. FROM user
  134. ORDER BY
  135. CASE role
  136. WHEN 'admin' THEN 1
  137. WHEN 'anonymous' THEN 3
  138. ELSE 2
  139. END, user
  140. `
  141. updateUserPassQuery = `UPDATE user SET pass = ? WHERE user = ?`
  142. updateUserRoleQuery = `UPDATE user SET role = ? WHERE user = ?`
  143. updateUserPrefsQuery = `UPDATE user SET prefs = ? WHERE user = ?`
  144. updateUserStatsQuery = `UPDATE user SET stats_messages = ?, stats_emails = ? WHERE user = ?`
  145. updateUserStatsResetAllQuery = `UPDATE user SET stats_messages = 0, stats_emails = 0`
  146. deleteUserQuery = `DELETE FROM user WHERE user = ?`
  147. upsertUserAccessQuery = `
  148. INSERT INTO user_access (user_id, topic, read, write, owner_user_id)
  149. VALUES ((SELECT id FROM user WHERE user = ?), ?, ?, ?, (SELECT IIF(?='',NULL,(SELECT id FROM user WHERE user=?))))
  150. ON CONFLICT (user_id, topic)
  151. DO UPDATE SET read=excluded.read, write=excluded.write, owner_user_id=excluded.owner_user_id
  152. `
  153. selectUserAccessQuery = `
  154. SELECT topic, read, write
  155. FROM user_access
  156. WHERE user_id = (SELECT id FROM user WHERE user = ?)
  157. ORDER BY write DESC, read DESC, topic
  158. `
  159. selectUserReservationsQuery = `
  160. SELECT a_user.topic, a_user.read, a_user.write, a_everyone.read AS everyone_read, a_everyone.write AS everyone_write
  161. FROM user_access a_user
  162. LEFT JOIN user_access a_everyone ON a_user.topic = a_everyone.topic AND a_everyone.user_id = (SELECT id FROM user WHERE user = ?)
  163. WHERE a_user.user_id = a_user.owner_user_id
  164. AND a_user.owner_user_id = (SELECT id FROM user WHERE user = ?)
  165. ORDER BY a_user.topic
  166. `
  167. selectUserReservationsCountQuery = `
  168. SELECT COUNT(*)
  169. FROM user_access
  170. WHERE user_id = owner_user_id AND owner_user_id = (SELECT id FROM user WHERE user = ?)
  171. `
  172. selectUserHasReservationQuery = `
  173. SELECT COUNT(*)
  174. FROM user_access
  175. WHERE user_id = owner_user_id
  176. AND owner_user_id = (SELECT id FROM user WHERE user = ?)
  177. AND topic = ?
  178. `
  179. selectOtherAccessCountQuery = `
  180. SELECT COUNT(*)
  181. FROM user_access
  182. WHERE (topic = ? OR ? LIKE topic)
  183. AND (owner_user_id IS NULL OR owner_user_id != (SELECT id FROM user WHERE user = ?))
  184. `
  185. deleteAllAccessQuery = `DELETE FROM user_access`
  186. deleteUserAccessQuery = `
  187. DELETE FROM user_access
  188. WHERE user_id = (SELECT id FROM user WHERE user = ?)
  189. OR owner_user_id = (SELECT id FROM user WHERE user = ?)
  190. `
  191. deleteTopicAccessQuery = `
  192. DELETE FROM user_access
  193. WHERE (user_id = (SELECT id FROM user WHERE user = ?) OR owner_user_id = (SELECT id FROM user WHERE user = ?))
  194. AND topic = ?
  195. `
  196. selectTokenCountQuery = `SELECT COUNT(*) FROM user_token WHERE (SELECT id FROM user WHERE user = ?)`
  197. insertTokenQuery = `INSERT INTO user_token (user_id, token, expires) VALUES ((SELECT id FROM user WHERE user = ?), ?, ?)`
  198. updateTokenExpiryQuery = `UPDATE user_token SET expires = ? WHERE user_id = (SELECT id FROM user WHERE user = ?) AND token = ?`
  199. deleteTokenQuery = `DELETE FROM user_token WHERE user_id = (SELECT id FROM user WHERE user = ?) AND token = ?`
  200. deleteExpiredTokensQuery = `DELETE FROM user_token WHERE expires < ?`
  201. deleteExcessTokensQuery = `
  202. DELETE FROM user_token
  203. WHERE (user_id, token) NOT IN (
  204. SELECT user_id, token
  205. FROM user_token
  206. WHERE user_id = (SELECT id FROM user WHERE user = ?)
  207. ORDER BY expires DESC
  208. LIMIT ?
  209. )
  210. `
  211. insertTierQuery = `
  212. INSERT INTO tier (code, name, messages_limit, messages_expiry_duration, emails_limit, reservations_limit, attachment_file_size_limit, attachment_total_size_limit, attachment_expiry_duration)
  213. VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
  214. `
  215. selectTierIDQuery = `SELECT id FROM tier WHERE code = ?`
  216. selectTiersQuery = `
  217. SELECT code, name, messages_limit, messages_expiry_duration, emails_limit, reservations_limit, attachment_file_size_limit, attachment_total_size_limit, attachment_expiry_duration, features, stripe_price_id
  218. FROM tier
  219. `
  220. selectTierByCodeQuery = `
  221. SELECT code, name, messages_limit, messages_expiry_duration, emails_limit, reservations_limit, attachment_file_size_limit, attachment_total_size_limit, attachment_expiry_duration, features, stripe_price_id
  222. FROM tier
  223. WHERE code = ?
  224. `
  225. selectTierByPriceIDQuery = `
  226. SELECT code, name, messages_limit, messages_expiry_duration, emails_limit, reservations_limit, attachment_file_size_limit, attachment_total_size_limit, attachment_expiry_duration, features, stripe_price_id
  227. FROM tier
  228. WHERE stripe_price_id = ?
  229. `
  230. updateUserTierQuery = `UPDATE user SET tier_id = ? WHERE user = ?`
  231. deleteUserTierQuery = `UPDATE user SET tier_id = null WHERE user = ?`
  232. updateBillingQuery = `
  233. UPDATE user
  234. SET stripe_customer_id = ?, stripe_subscription_id = ?, stripe_subscription_status = ?, stripe_subscription_paid_until = ?, stripe_subscription_cancel_at = ?
  235. WHERE user = ?
  236. `
  237. )
  238. // Schema management queries
  239. const (
  240. currentSchemaVersion = 2
  241. insertSchemaVersion = `INSERT INTO schemaVersion VALUES (1, ?)`
  242. updateSchemaVersion = `UPDATE schemaVersion SET version = ? WHERE id = 1`
  243. selectSchemaVersionQuery = `SELECT version FROM schemaVersion WHERE id = 1`
  244. // 1 -> 2 (complex migration!)
  245. migrate1To2RenameUserTableQueryNoTx = `
  246. ALTER TABLE user RENAME TO user_old;
  247. `
  248. migrate1To2InsertFromOldTablesAndDropNoTx = `
  249. INSERT INTO user (user, pass, role, sync_topic, created_by, created_at, last_seen)
  250. SELECT user, pass, role, '', 'admin', UNIXEPOCH(), UNIXEPOCH() FROM user_old;
  251. INSERT INTO user_access (user_id, topic, read, write)
  252. SELECT u.id, a.topic, a.read, a.write
  253. FROM user u
  254. JOIN access a ON u.user = a.user;
  255. DROP TABLE access;
  256. DROP TABLE user_old;
  257. `
  258. migrate1To2SelectAllUsersIDsNoTx = `SELECT id FROM user`
  259. migrate1To2UpdateSyncTopicNoTx = `UPDATE user SET sync_topic = ? WHERE id = ?`
  260. )
  261. // Manager is an implementation of Manager. It stores users and access control list
  262. // in a SQLite database.
  263. type Manager struct {
  264. db *sql.DB
  265. defaultAccess Permission // Default permission if no ACL matches
  266. statsQueue map[string]*User // Username -> User, for "unimportant" user updates
  267. mu sync.Mutex
  268. }
  269. var _ Auther = (*Manager)(nil)
  270. // NewManager creates a new Manager instance
  271. func NewManager(filename, startupQueries string, defaultAccess Permission) (*Manager, error) {
  272. return newManager(filename, startupQueries, defaultAccess, userStatsQueueWriterInterval)
  273. }
  274. // NewManager creates a new Manager instance
  275. func newManager(filename, startupQueries string, defaultAccess Permission, statsWriterInterval time.Duration) (*Manager, error) {
  276. db, err := sql.Open("sqlite3", filename)
  277. if err != nil {
  278. return nil, err
  279. }
  280. if err := setupDB(db); err != nil {
  281. return nil, err
  282. }
  283. if err := runStartupQueries(db, startupQueries); err != nil {
  284. return nil, err
  285. }
  286. manager := &Manager{
  287. db: db,
  288. defaultAccess: defaultAccess,
  289. statsQueue: make(map[string]*User),
  290. }
  291. go manager.userStatsQueueWriter(statsWriterInterval)
  292. return manager, nil
  293. }
  294. // Authenticate checks username and password and returns a User if correct. The method
  295. // returns in constant-ish time, regardless of whether the user exists or the password is
  296. // correct or incorrect.
  297. func (a *Manager) Authenticate(username, password string) (*User, error) {
  298. if username == Everyone {
  299. return nil, ErrUnauthenticated
  300. }
  301. user, err := a.User(username)
  302. if err != nil {
  303. log.Trace("authentication of user %s failed (1): %s", username, err.Error())
  304. bcrypt.CompareHashAndPassword([]byte(intentionalSlowDownHash), []byte("intentional slow-down to avoid timing attacks"))
  305. return nil, ErrUnauthenticated
  306. }
  307. if err := bcrypt.CompareHashAndPassword([]byte(user.Hash), []byte(password)); err != nil {
  308. log.Trace("authentication of user %s failed (2): %s", username, err.Error())
  309. return nil, ErrUnauthenticated
  310. }
  311. return user, nil
  312. }
  313. // AuthenticateToken checks if the token exists and returns the associated User if it does.
  314. // The method sets the User.Token value to the token that was used for authentication.
  315. func (a *Manager) AuthenticateToken(token string) (*User, error) {
  316. if len(token) != tokenLength {
  317. return nil, ErrUnauthenticated
  318. }
  319. user, err := a.userByToken(token)
  320. if err != nil {
  321. return nil, ErrUnauthenticated
  322. }
  323. user.Token = token
  324. return user, nil
  325. }
  326. // CreateToken generates a random token for the given user and returns it. The token expires
  327. // after a fixed duration unless ExtendToken is called. This function also prunes tokens for the
  328. // given user, if there are too many of them.
  329. func (a *Manager) CreateToken(user *User) (*Token, error) {
  330. token, expires := util.RandomString(tokenLength), time.Now().Add(tokenExpiryDuration)
  331. tx, err := a.db.Begin()
  332. if err != nil {
  333. return nil, err
  334. }
  335. defer tx.Rollback()
  336. if _, err := tx.Exec(insertTokenQuery, user.Name, token, expires.Unix()); err != nil {
  337. return nil, err
  338. }
  339. rows, err := tx.Query(selectTokenCountQuery, user.Name)
  340. if err != nil {
  341. return nil, err
  342. }
  343. defer rows.Close()
  344. if !rows.Next() {
  345. return nil, errNoRows
  346. }
  347. var tokenCount int
  348. if err := rows.Scan(&tokenCount); err != nil {
  349. return nil, err
  350. }
  351. if tokenCount >= tokenMaxCount {
  352. // This pruning logic is done in two queries for efficiency. The SELECT above is a lookup
  353. // on two indices, whereas the query below is a full table scan.
  354. if _, err := tx.Exec(deleteExcessTokensQuery, user.Name, tokenMaxCount); err != nil {
  355. return nil, err
  356. }
  357. }
  358. if err := tx.Commit(); err != nil {
  359. return nil, err
  360. }
  361. return &Token{
  362. Value: token,
  363. Expires: expires,
  364. }, nil
  365. }
  366. // ExtendToken sets the new expiry date for a token, thereby extending its use further into the future.
  367. func (a *Manager) ExtendToken(user *User) (*Token, error) {
  368. if user.Token == "" {
  369. return nil, errNoTokenProvided
  370. }
  371. newExpires := time.Now().Add(tokenExpiryDuration)
  372. if _, err := a.db.Exec(updateTokenExpiryQuery, newExpires.Unix(), user.Name, user.Token); err != nil {
  373. return nil, err
  374. }
  375. return &Token{
  376. Value: user.Token,
  377. Expires: newExpires,
  378. }, nil
  379. }
  380. // RemoveToken deletes the token defined in User.Token
  381. func (a *Manager) RemoveToken(user *User) error {
  382. if user.Token == "" {
  383. return ErrUnauthorized
  384. }
  385. if _, err := a.db.Exec(deleteTokenQuery, user.Name, user.Token); err != nil {
  386. return err
  387. }
  388. return nil
  389. }
  390. // RemoveExpiredTokens deletes all expired tokens from the database
  391. func (a *Manager) RemoveExpiredTokens() error {
  392. if _, err := a.db.Exec(deleteExpiredTokensQuery, time.Now().Unix()); err != nil {
  393. return err
  394. }
  395. return nil
  396. }
  397. // ChangeSettings persists the user settings
  398. func (a *Manager) ChangeSettings(user *User) error {
  399. prefs, err := json.Marshal(user.Prefs)
  400. if err != nil {
  401. return err
  402. }
  403. if _, err := a.db.Exec(updateUserPrefsQuery, string(prefs), user.Name); err != nil {
  404. return err
  405. }
  406. return nil
  407. }
  408. // ResetStats resets all user stats in the user database. This touches all users.
  409. func (a *Manager) ResetStats() error {
  410. a.mu.Lock()
  411. defer a.mu.Unlock()
  412. if _, err := a.db.Exec(updateUserStatsResetAllQuery); err != nil {
  413. return err
  414. }
  415. a.statsQueue = make(map[string]*User)
  416. return nil
  417. }
  418. // EnqueueStats adds the user to a queue which writes out user stats (messages, emails, ..) in
  419. // batches at a regular interval
  420. func (a *Manager) EnqueueStats(user *User) {
  421. a.mu.Lock()
  422. defer a.mu.Unlock()
  423. a.statsQueue[user.Name] = user
  424. }
  425. func (a *Manager) userStatsQueueWriter(interval time.Duration) {
  426. ticker := time.NewTicker(interval)
  427. for range ticker.C {
  428. if err := a.writeUserStatsQueue(); err != nil {
  429. log.Warn("User Manager: Writing user stats queue failed: %s", err.Error())
  430. }
  431. }
  432. }
  433. func (a *Manager) writeUserStatsQueue() error {
  434. a.mu.Lock()
  435. if len(a.statsQueue) == 0 {
  436. a.mu.Unlock()
  437. log.Trace("User Manager: No user stats updates to commit")
  438. return nil
  439. }
  440. statsQueue := a.statsQueue
  441. a.statsQueue = make(map[string]*User)
  442. a.mu.Unlock()
  443. tx, err := a.db.Begin()
  444. if err != nil {
  445. return err
  446. }
  447. defer tx.Rollback()
  448. log.Debug("User Manager: Writing user stats queue for %d user(s)", len(statsQueue))
  449. for username, u := range statsQueue {
  450. log.Trace("User Manager: Updating stats for user %s: messages=%d, emails=%d", username, u.Stats.Messages, u.Stats.Emails)
  451. if _, err := tx.Exec(updateUserStatsQuery, u.Stats.Messages, u.Stats.Emails, username); err != nil {
  452. return err
  453. }
  454. }
  455. return tx.Commit()
  456. }
  457. // Authorize returns nil if the given user has access to the given topic using the desired
  458. // permission. The user param may be nil to signal an anonymous user.
  459. func (a *Manager) Authorize(user *User, topic string, perm Permission) error {
  460. if user != nil && user.Role == RoleAdmin {
  461. return nil // Admin can do everything
  462. }
  463. username := Everyone
  464. if user != nil {
  465. username = user.Name
  466. }
  467. // Select the read/write permissions for this user/topic combo. The query may return two
  468. // rows (one for everyone, and one for the user), but prioritizes the user.
  469. rows, err := a.db.Query(selectTopicPermsQuery, Everyone, username, topic)
  470. if err != nil {
  471. return err
  472. }
  473. defer rows.Close()
  474. if !rows.Next() {
  475. return a.resolvePerms(a.defaultAccess, perm)
  476. }
  477. var read, write bool
  478. if err := rows.Scan(&read, &write); err != nil {
  479. return err
  480. } else if err := rows.Err(); err != nil {
  481. return err
  482. }
  483. return a.resolvePerms(NewPermission(read, write), perm)
  484. }
  485. func (a *Manager) resolvePerms(base, perm Permission) error {
  486. if perm == PermissionRead && base.IsRead() {
  487. return nil
  488. } else if perm == PermissionWrite && base.IsWrite() {
  489. return nil
  490. }
  491. return ErrUnauthorized
  492. }
  493. // AddUser adds a user with the given username, password and role
  494. func (a *Manager) AddUser(username, password string, role Role, createdBy string) error {
  495. if !AllowedUsername(username) || !AllowedRole(role) {
  496. return ErrInvalidArgument
  497. }
  498. hash, err := bcrypt.GenerateFromPassword([]byte(password), bcryptCost)
  499. if err != nil {
  500. return err
  501. }
  502. syncTopic, now := util.RandomString(syncTopicLength), time.Now().Unix()
  503. if _, err = a.db.Exec(insertUserQuery, username, hash, role, syncTopic, createdBy, now, now); err != nil {
  504. return err
  505. }
  506. return nil
  507. }
  508. // RemoveUser deletes the user with the given username. The function returns nil on success, even
  509. // if the user did not exist in the first place.
  510. func (a *Manager) RemoveUser(username string) error {
  511. if !AllowedUsername(username) {
  512. return ErrInvalidArgument
  513. }
  514. // Rows in user_access, user_token, etc. are deleted via foreign keys
  515. if _, err := a.db.Exec(deleteUserQuery, username); err != nil {
  516. return err
  517. }
  518. return nil
  519. }
  520. // Users returns a list of users. It always also returns the Everyone user ("*").
  521. func (a *Manager) Users() ([]*User, error) {
  522. rows, err := a.db.Query(selectUsernamesQuery)
  523. if err != nil {
  524. return nil, err
  525. }
  526. defer rows.Close()
  527. usernames := make([]string, 0)
  528. for rows.Next() {
  529. var username string
  530. if err := rows.Scan(&username); err != nil {
  531. return nil, err
  532. } else if err := rows.Err(); err != nil {
  533. return nil, err
  534. }
  535. usernames = append(usernames, username)
  536. }
  537. rows.Close()
  538. users := make([]*User, 0)
  539. for _, username := range usernames {
  540. user, err := a.User(username)
  541. if err != nil {
  542. return nil, err
  543. }
  544. users = append(users, user)
  545. }
  546. return users, nil
  547. }
  548. // User returns the user with the given username if it exists, or ErrUserNotFound otherwise.
  549. // You may also pass Everyone to retrieve the anonymous user and its Grant list.
  550. func (a *Manager) User(username string) (*User, error) {
  551. rows, err := a.db.Query(selectUserByNameQuery, username)
  552. if err != nil {
  553. return nil, err
  554. }
  555. return a.readUser(rows)
  556. }
  557. func (a *Manager) UserByStripeCustomer(stripeCustomerID string) (*User, error) {
  558. rows, err := a.db.Query(selectUserByStripeCustomerIDQuery, stripeCustomerID)
  559. if err != nil {
  560. return nil, err
  561. }
  562. return a.readUser(rows)
  563. }
  564. func (a *Manager) userByToken(token string) (*User, error) {
  565. rows, err := a.db.Query(selectUserByTokenQuery, token, time.Now().Unix())
  566. if err != nil {
  567. return nil, err
  568. }
  569. return a.readUser(rows)
  570. }
  571. func (a *Manager) readUser(rows *sql.Rows) (*User, error) {
  572. defer rows.Close()
  573. var username, hash, role, prefs, syncTopic string
  574. var stripeCustomerID, stripeSubscriptionID, stripeSubscriptionStatus, stripePriceID, tierCode, tierName, tierFeatures sql.NullString
  575. var messages, emails int64
  576. var messagesLimit, messagesExpiryDuration, emailsLimit, reservationsLimit, attachmentFileSizeLimit, attachmentTotalSizeLimit, attachmentExpiryDuration, stripeSubscriptionPaidUntil, stripeSubscriptionCancelAt sql.NullInt64
  577. if !rows.Next() {
  578. return nil, ErrUserNotFound
  579. }
  580. if err := rows.Scan(&username, &hash, &role, &prefs, &syncTopic, &messages, &emails, &stripeCustomerID, &stripeSubscriptionID, &stripeSubscriptionStatus, &stripeSubscriptionPaidUntil, &stripeSubscriptionCancelAt, &tierCode, &tierName, &messagesLimit, &messagesExpiryDuration, &emailsLimit, &reservationsLimit, &attachmentFileSizeLimit, &attachmentTotalSizeLimit, &attachmentExpiryDuration, &tierFeatures, &stripePriceID); err != nil {
  581. return nil, err
  582. } else if err := rows.Err(); err != nil {
  583. return nil, err
  584. }
  585. user := &User{
  586. Name: username,
  587. Hash: hash,
  588. Role: Role(role),
  589. Prefs: &Prefs{},
  590. SyncTopic: syncTopic,
  591. Stats: &Stats{
  592. Messages: messages,
  593. Emails: emails,
  594. },
  595. Billing: &Billing{
  596. StripeCustomerID: stripeCustomerID.String, // May be empty
  597. StripeSubscriptionID: stripeSubscriptionID.String, // May be empty
  598. StripeSubscriptionStatus: stripe.SubscriptionStatus(stripeSubscriptionStatus.String), // May be empty
  599. StripeSubscriptionPaidUntil: time.Unix(stripeSubscriptionPaidUntil.Int64, 0), // May be zero
  600. StripeSubscriptionCancelAt: time.Unix(stripeSubscriptionCancelAt.Int64, 0), // May be zero
  601. },
  602. }
  603. if err := json.Unmarshal([]byte(prefs), user.Prefs); err != nil {
  604. return nil, err
  605. }
  606. if tierCode.Valid {
  607. // See readTier() when this is changed!
  608. user.Tier = &Tier{
  609. Code: tierCode.String,
  610. Name: tierName.String,
  611. Paid: stripePriceID.Valid, // If there is a price, it's a paid tier
  612. MessagesLimit: messagesLimit.Int64,
  613. MessagesExpiryDuration: time.Duration(messagesExpiryDuration.Int64) * time.Second,
  614. EmailsLimit: emailsLimit.Int64,
  615. ReservationsLimit: reservationsLimit.Int64,
  616. AttachmentFileSizeLimit: attachmentFileSizeLimit.Int64,
  617. AttachmentTotalSizeLimit: attachmentTotalSizeLimit.Int64,
  618. AttachmentExpiryDuration: time.Duration(attachmentExpiryDuration.Int64) * time.Second,
  619. Features: tierFeatures.String, // May be empty
  620. StripePriceID: stripePriceID.String, // May be empty
  621. }
  622. }
  623. return user, nil
  624. }
  625. // Grants returns all user-specific access control entries
  626. func (a *Manager) Grants(username string) ([]Grant, error) {
  627. rows, err := a.db.Query(selectUserAccessQuery, username)
  628. if err != nil {
  629. return nil, err
  630. }
  631. defer rows.Close()
  632. grants := make([]Grant, 0)
  633. for rows.Next() {
  634. var topic string
  635. var read, write bool
  636. if err := rows.Scan(&topic, &read, &write); err != nil {
  637. return nil, err
  638. } else if err := rows.Err(); err != nil {
  639. return nil, err
  640. }
  641. grants = append(grants, Grant{
  642. TopicPattern: fromSQLWildcard(topic),
  643. Allow: NewPermission(read, write),
  644. })
  645. }
  646. return grants, nil
  647. }
  648. // Reservations returns all user-owned topics, and the associated everyone-access
  649. func (a *Manager) Reservations(username string) ([]Reservation, error) {
  650. rows, err := a.db.Query(selectUserReservationsQuery, Everyone, username)
  651. if err != nil {
  652. return nil, err
  653. }
  654. defer rows.Close()
  655. reservations := make([]Reservation, 0)
  656. for rows.Next() {
  657. var topic string
  658. var ownerRead, ownerWrite bool
  659. var everyoneRead, everyoneWrite sql.NullBool
  660. if err := rows.Scan(&topic, &ownerRead, &ownerWrite, &everyoneRead, &everyoneWrite); err != nil {
  661. return nil, err
  662. } else if err := rows.Err(); err != nil {
  663. return nil, err
  664. }
  665. reservations = append(reservations, Reservation{
  666. Topic: topic,
  667. Owner: NewPermission(ownerRead, ownerWrite),
  668. Everyone: NewPermission(everyoneRead.Bool, everyoneWrite.Bool), // false if null
  669. })
  670. }
  671. return reservations, nil
  672. }
  673. // HasReservation returns true if the given topic access is owned by the user
  674. func (a *Manager) HasReservation(username, topic string) (bool, error) {
  675. rows, err := a.db.Query(selectUserHasReservationQuery, username, topic)
  676. if err != nil {
  677. return false, err
  678. }
  679. defer rows.Close()
  680. if !rows.Next() {
  681. return false, errNoRows
  682. }
  683. var count int64
  684. if err := rows.Scan(&count); err != nil {
  685. return false, err
  686. }
  687. return count > 0, nil
  688. }
  689. // ReservationsCount returns the number of reservations owned by this user
  690. func (a *Manager) ReservationsCount(username string) (int64, error) {
  691. rows, err := a.db.Query(selectUserReservationsCountQuery, username)
  692. if err != nil {
  693. return 0, err
  694. }
  695. defer rows.Close()
  696. if !rows.Next() {
  697. return 0, errNoRows
  698. }
  699. var count int64
  700. if err := rows.Scan(&count); err != nil {
  701. return 0, err
  702. }
  703. return count, nil
  704. }
  705. // ChangePassword changes a user's password
  706. func (a *Manager) ChangePassword(username, password string) error {
  707. hash, err := bcrypt.GenerateFromPassword([]byte(password), bcryptCost)
  708. if err != nil {
  709. return err
  710. }
  711. if _, err := a.db.Exec(updateUserPassQuery, hash, username); err != nil {
  712. return err
  713. }
  714. return nil
  715. }
  716. // ChangeRole changes a user's role. When a role is changed from RoleUser to RoleAdmin,
  717. // all existing access control entries (Grant) are removed, since they are no longer needed.
  718. func (a *Manager) ChangeRole(username string, role Role) error {
  719. if !AllowedUsername(username) || !AllowedRole(role) {
  720. return ErrInvalidArgument
  721. }
  722. if _, err := a.db.Exec(updateUserRoleQuery, string(role), username); err != nil {
  723. return err
  724. }
  725. if role == RoleAdmin {
  726. if _, err := a.db.Exec(deleteUserAccessQuery, username, username); err != nil {
  727. return err
  728. }
  729. }
  730. return nil
  731. }
  732. // ChangeTier changes a user's tier using the tier code
  733. func (a *Manager) ChangeTier(username, tier string) error {
  734. if !AllowedUsername(username) {
  735. return ErrInvalidArgument
  736. }
  737. rows, err := a.db.Query(selectTierIDQuery, tier)
  738. if err != nil {
  739. return err
  740. }
  741. defer rows.Close()
  742. if !rows.Next() {
  743. return ErrInvalidArgument
  744. }
  745. var tierID int64
  746. if err := rows.Scan(&tierID); err != nil {
  747. return err
  748. }
  749. rows.Close()
  750. if _, err := a.db.Exec(updateUserTierQuery, tierID, username); err != nil {
  751. return err
  752. }
  753. return nil
  754. }
  755. // CheckAllowAccess tests if a user may create an access control entry for the given topic.
  756. // If there are any ACL entries that are not owned by the user, an error is returned.
  757. func (a *Manager) CheckAllowAccess(username string, topic string) error {
  758. if (!AllowedUsername(username) && username != Everyone) || !AllowedTopic(topic) {
  759. return ErrInvalidArgument
  760. }
  761. rows, err := a.db.Query(selectOtherAccessCountQuery, topic, topic, username)
  762. if err != nil {
  763. return err
  764. }
  765. defer rows.Close()
  766. if !rows.Next() {
  767. return errNoRows
  768. }
  769. var otherCount int
  770. if err := rows.Scan(&otherCount); err != nil {
  771. return err
  772. }
  773. if otherCount > 0 {
  774. return errTopicOwnedByOthers
  775. }
  776. return nil
  777. }
  778. // AllowAccess adds or updates an entry in th access control list for a specific user. It controls
  779. // read/write access to a topic. The parameter topicPattern may include wildcards (*). The ACL entry
  780. // owner may either be a user (username), or the system (empty).
  781. func (a *Manager) AllowAccess(owner, username string, topicPattern string, read bool, write bool) error {
  782. if !AllowedUsername(username) && username != Everyone {
  783. return ErrInvalidArgument
  784. } else if owner != "" && !AllowedUsername(owner) {
  785. return ErrInvalidArgument
  786. } else if !AllowedTopicPattern(topicPattern) {
  787. return ErrInvalidArgument
  788. }
  789. if _, err := a.db.Exec(upsertUserAccessQuery, username, toSQLWildcard(topicPattern), read, write, owner, owner); err != nil {
  790. return err
  791. }
  792. return nil
  793. }
  794. // ResetAccess removes an access control list entry for a specific username/topic, or (if topic is
  795. // empty) for an entire user. The parameter topicPattern may include wildcards (*).
  796. func (a *Manager) ResetAccess(username string, topicPattern string) error {
  797. if !AllowedUsername(username) && username != Everyone && username != "" {
  798. return ErrInvalidArgument
  799. } else if !AllowedTopicPattern(topicPattern) && topicPattern != "" {
  800. return ErrInvalidArgument
  801. }
  802. if username == "" && topicPattern == "" {
  803. _, err := a.db.Exec(deleteAllAccessQuery, username)
  804. return err
  805. } else if topicPattern == "" {
  806. _, err := a.db.Exec(deleteUserAccessQuery, username, username)
  807. return err
  808. }
  809. _, err := a.db.Exec(deleteTopicAccessQuery, username, username, toSQLWildcard(topicPattern))
  810. return err
  811. }
  812. // ResetTier removes the tier from the given user
  813. func (a *Manager) ResetTier(username string) error {
  814. if !AllowedUsername(username) && username != Everyone && username != "" {
  815. return ErrInvalidArgument
  816. }
  817. _, err := a.db.Exec(deleteUserTierQuery, username)
  818. return err
  819. }
  820. // DefaultAccess returns the default read/write access if no access control entry matches
  821. func (a *Manager) DefaultAccess() Permission {
  822. return a.defaultAccess
  823. }
  824. // CreateTier creates a new tier in the database
  825. func (a *Manager) CreateTier(tier *Tier) error {
  826. if _, err := a.db.Exec(insertTierQuery, tier.Code, tier.Name, tier.MessagesLimit, int64(tier.MessagesExpiryDuration.Seconds()), tier.EmailsLimit, tier.ReservationsLimit, tier.AttachmentFileSizeLimit, tier.AttachmentTotalSizeLimit, int64(tier.AttachmentExpiryDuration.Seconds())); err != nil {
  827. return err
  828. }
  829. return nil
  830. }
  831. func (a *Manager) ChangeBilling(user *User) error {
  832. if _, err := a.db.Exec(updateBillingQuery, nullString(user.Billing.StripeCustomerID), nullString(user.Billing.StripeSubscriptionID), nullString(string(user.Billing.StripeSubscriptionStatus)), nullInt64(user.Billing.StripeSubscriptionPaidUntil.Unix()), nullInt64(user.Billing.StripeSubscriptionCancelAt.Unix()), user.Name); err != nil {
  833. return err
  834. }
  835. return nil
  836. }
  837. func (a *Manager) Tiers() ([]*Tier, error) {
  838. rows, err := a.db.Query(selectTiersQuery)
  839. if err != nil {
  840. return nil, err
  841. }
  842. defer rows.Close()
  843. tiers := make([]*Tier, 0)
  844. for {
  845. tier, err := a.readTier(rows)
  846. if err == ErrTierNotFound {
  847. break
  848. } else if err != nil {
  849. return nil, err
  850. }
  851. tiers = append(tiers, tier)
  852. }
  853. return tiers, nil
  854. }
  855. func (a *Manager) Tier(code string) (*Tier, error) {
  856. rows, err := a.db.Query(selectTierByCodeQuery, code)
  857. if err != nil {
  858. return nil, err
  859. }
  860. defer rows.Close()
  861. return a.readTier(rows)
  862. }
  863. func (a *Manager) TierByStripePrice(priceID string) (*Tier, error) {
  864. rows, err := a.db.Query(selectTierByPriceIDQuery, priceID)
  865. if err != nil {
  866. return nil, err
  867. }
  868. defer rows.Close()
  869. return a.readTier(rows)
  870. }
  871. func (a *Manager) readTier(rows *sql.Rows) (*Tier, error) {
  872. var code, name string
  873. var features, stripePriceID sql.NullString
  874. var messagesLimit, messagesExpiryDuration, emailsLimit, reservationsLimit, attachmentFileSizeLimit, attachmentTotalSizeLimit, attachmentExpiryDuration sql.NullInt64
  875. if !rows.Next() {
  876. return nil, ErrTierNotFound
  877. }
  878. if err := rows.Scan(&code, &name, &messagesLimit, &messagesExpiryDuration, &emailsLimit, &reservationsLimit, &attachmentFileSizeLimit, &attachmentTotalSizeLimit, &attachmentExpiryDuration, &features, &stripePriceID); err != nil {
  879. return nil, err
  880. } else if err := rows.Err(); err != nil {
  881. return nil, err
  882. }
  883. // When changed, note readUser() as well
  884. return &Tier{
  885. Code: code,
  886. Name: name,
  887. Paid: stripePriceID.Valid, // If there is a price, it's a paid tier
  888. MessagesLimit: messagesLimit.Int64,
  889. MessagesExpiryDuration: time.Duration(messagesExpiryDuration.Int64) * time.Second,
  890. EmailsLimit: emailsLimit.Int64,
  891. ReservationsLimit: reservationsLimit.Int64,
  892. AttachmentFileSizeLimit: attachmentFileSizeLimit.Int64,
  893. AttachmentTotalSizeLimit: attachmentTotalSizeLimit.Int64,
  894. AttachmentExpiryDuration: time.Duration(attachmentExpiryDuration.Int64) * time.Second,
  895. Features: features.String, // May be empty
  896. StripePriceID: stripePriceID.String, // May be empty
  897. }, nil
  898. }
  // toSQLWildcard converts a user-facing topic pattern, where '*' is the
  // wildcard, into its SQL form by replacing every '*' with '%'.
  // NOTE(review): '_' is also a single-character LIKE wildcard and is passed
  // through unescaped. If these patterns are matched with LIKE, a literal '_'
  // in a topic matches any character; confirm against the consuming queries.
  func toSQLWildcard(s string) string {
  	return strings.Replace(s, "*", "%", -1)
  }
  // fromSQLWildcard is the inverse of toSQLWildcard for stored patterns. It turns
  // a pattern in SQL form back into the user-facing form by replacing every '%'
  // with '*'.
  func fromSQLWildcard(s string) string {
  	return strings.Replace(s, "%", "*", -1)
  }
  905. func runStartupQueries(db *sql.DB, startupQueries string) error {
  906. if _, err := db.Exec(startupQueries); err != nil {
  907. return err
  908. }
  909. if _, err := db.Exec(builtinStartupQueries); err != nil {
  910. return err
  911. }
  912. return nil
  913. }
  914. func setupDB(db *sql.DB) error {
  915. // If 'schemaVersion' table does not exist, this must be a new database
  916. rowsSV, err := db.Query(selectSchemaVersionQuery)
  917. if err != nil {
  918. return setupNewDB(db)
  919. }
  920. defer rowsSV.Close()
  921. // If 'schemaVersion' table exists, read version and potentially upgrade
  922. schemaVersion := 0
  923. if !rowsSV.Next() {
  924. return errors.New("cannot determine schema version: database file may be corrupt")
  925. }
  926. if err := rowsSV.Scan(&schemaVersion); err != nil {
  927. return err
  928. }
  929. rowsSV.Close()
  930. // Do migrations
  931. if schemaVersion == currentSchemaVersion {
  932. return nil
  933. } else if schemaVersion == 1 {
  934. return migrateFrom1(db)
  935. }
  936. return fmt.Errorf("unexpected schema version found: %d", schemaVersion)
  937. }
  938. func setupNewDB(db *sql.DB) error {
  939. if _, err := db.Exec(createTablesQueries); err != nil {
  940. return err
  941. }
  942. if _, err := db.Exec(insertSchemaVersion, currentSchemaVersion); err != nil {
  943. return err
  944. }
  945. return nil
  946. }
  947. func migrateFrom1(db *sql.DB) error {
  948. log.Info("Migrating user database schema: from 1 to 2")
  949. tx, err := db.Begin()
  950. if err != nil {
  951. return err
  952. }
  953. defer tx.Rollback()
  954. if _, err := tx.Exec(migrate1To2RenameUserTableQueryNoTx); err != nil {
  955. return err
  956. }
  957. if _, err := tx.Exec(createTablesQueriesNoTx); err != nil {
  958. return err
  959. }
  960. if _, err := tx.Exec(migrate1To2InsertFromOldTablesAndDropNoTx); err != nil {
  961. return err
  962. }
  963. rows, err := tx.Query(migrate1To2SelectAllUsersIDsNoTx)
  964. if err != nil {
  965. return err
  966. }
  967. defer rows.Close()
  968. syncTopics := make(map[int]string)
  969. for rows.Next() {
  970. var userID int
  971. if err := rows.Scan(&userID); err != nil {
  972. return err
  973. }
  974. syncTopics[userID] = util.RandomString(syncTopicLength)
  975. }
  976. if err := rows.Close(); err != nil {
  977. return err
  978. }
  979. for userID, syncTopic := range syncTopics {
  980. if _, err := tx.Exec(migrate1To2UpdateSyncTopicNoTx, syncTopic, userID); err != nil {
  981. return err
  982. }
  983. }
  984. if _, err := tx.Exec(updateSchemaVersion, 2); err != nil {
  985. return err
  986. }
  987. if err := tx.Commit(); err != nil {
  988. return err
  989. }
  990. return nil // Update this when a new version is added
  991. }
  // nullString converts s into a sql.NullString, mapping the empty string to
  // SQL NULL (Valid == false) and any other value to a valid string.
  func nullString(s string) sql.NullString {
  	return sql.NullString{String: s, Valid: s != ""}
  }
  // nullInt64 converts v into a sql.NullInt64, mapping 0 to SQL NULL
  // (Valid == false) and any other value, including negatives, to a valid int.
  func nullInt64(v int64) sql.NullInt64 {
  	return sql.NullInt64{Int64: v, Valid: v != 0}
  }