manager.go 33 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022
  1. package user
  2. import (
  3. "database/sql"
  4. "encoding/json"
  5. "errors"
  6. "fmt"
  7. _ "github.com/mattn/go-sqlite3" // SQLite driver
  8. "golang.org/x/crypto/bcrypt"
  9. "heckel.io/ntfy/log"
  10. "heckel.io/ntfy/util"
  11. "strings"
  12. "sync"
  13. "time"
  14. )
  15. const (
  16. bcryptCost = 10
  17. intentionalSlowDownHash = "$2a$10$YFCQvqQDwIIwnJM1xkAYOeih0dg17UVGanaTStnrSzC8NCWxcLDwy" // Cost should match bcryptCost
  18. userStatsQueueWriterInterval = 33 * time.Second
  19. tokenLength = 32
  20. tokenExpiryDuration = 72 * time.Hour // Extend tokens by this much
  21. syncTopicLength = 16
  22. tokenMaxCount = 10 // Only keep this many tokens in the table per user
  23. )
  24. var (
  25. errNoTokenProvided = errors.New("no token provided")
  26. errTopicOwnedByOthers = errors.New("topic owned by others")
  27. errNoRows = errors.New("no rows found")
  28. )
  29. // Manager-related queries
  30. const (
  31. createTablesQueriesNoTx = `
  32. CREATE TABLE IF NOT EXISTS tier (
  33. id INTEGER PRIMARY KEY AUTOINCREMENT,
  34. code TEXT NOT NULL,
  35. name TEXT NOT NULL,
  36. paid INT NOT NULL,
  37. messages_limit INT NOT NULL,
  38. messages_expiry_duration INT NOT NULL,
  39. emails_limit INT NOT NULL,
  40. reservations_limit INT NOT NULL,
  41. attachment_file_size_limit INT NOT NULL,
  42. attachment_total_size_limit INT NOT NULL,
  43. attachment_expiry_duration INT NOT NULL,
  44. stripe_price_id TEXT
  45. );
  46. CREATE UNIQUE INDEX idx_tier_code ON tier (code);
  47. CREATE UNIQUE INDEX idx_tier_price_id ON tier (stripe_price_id);
  48. CREATE TABLE IF NOT EXISTS user (
  49. id INTEGER PRIMARY KEY AUTOINCREMENT,
  50. tier_id INT,
  51. user TEXT NOT NULL,
  52. pass TEXT NOT NULL,
  53. role TEXT CHECK (role IN ('anonymous', 'admin', 'user')) NOT NULL,
  54. prefs JSON NOT NULL DEFAULT '{}',
  55. sync_topic TEXT NOT NULL,
  56. stats_messages INT NOT NULL DEFAULT (0),
  57. stats_emails INT NOT NULL DEFAULT (0),
  58. stripe_customer_id TEXT,
  59. stripe_subscription_id TEXT,
  60. created_by TEXT NOT NULL,
  61. created_at INT NOT NULL,
  62. last_seen INT NOT NULL,
  63. FOREIGN KEY (tier_id) REFERENCES tier (id)
  64. );
  65. CREATE UNIQUE INDEX idx_user ON user (user);
  66. CREATE UNIQUE INDEX idx_user_stripe_customer_id ON user (stripe_customer_id);
  67. CREATE UNIQUE INDEX idx_user_stripe_subscription_id ON user (stripe_subscription_id);
  68. CREATE TABLE IF NOT EXISTS user_access (
  69. user_id INT NOT NULL,
  70. topic TEXT NOT NULL,
  71. read INT NOT NULL,
  72. write INT NOT NULL,
  73. owner_user_id INT,
  74. PRIMARY KEY (user_id, topic),
  75. FOREIGN KEY (user_id) REFERENCES user (id) ON DELETE CASCADE,
  76. FOREIGN KEY (owner_user_id) REFERENCES user (id) ON DELETE CASCADE
  77. );
  78. CREATE TABLE IF NOT EXISTS user_token (
  79. user_id INT NOT NULL,
  80. token TEXT NOT NULL,
  81. expires INT NOT NULL,
  82. PRIMARY KEY (user_id, token),
  83. FOREIGN KEY (user_id) REFERENCES user (id) ON DELETE CASCADE
  84. );
  85. CREATE TABLE IF NOT EXISTS schemaVersion (
  86. id INT PRIMARY KEY,
  87. version INT NOT NULL
  88. );
  89. INSERT INTO user (id, user, pass, role, sync_topic, created_by, created_at, last_seen)
  90. VALUES (1, '*', '', 'anonymous', '', 'system', UNIXEPOCH(), 0)
  91. ON CONFLICT (id) DO NOTHING;
  92. `
  93. createTablesQueries = `BEGIN; ` + createTablesQueriesNoTx + ` COMMIT;`
  94. builtinStartupQueries = `
  95. PRAGMA foreign_keys = ON;
  96. `
  97. selectUserByNameQuery = `
  98. SELECT u.user, u.pass, u.role, u.prefs, u.sync_topic, u.stats_messages, u.stats_emails, u.stripe_customer_id, u.stripe_subscription_id, p.code, p.name, p.paid, p.messages_limit, p.messages_expiry_duration, p.emails_limit, p.reservations_limit, p.attachment_file_size_limit, p.attachment_total_size_limit, p.attachment_expiry_duration, p.stripe_price_id
  99. FROM user u
  100. LEFT JOIN tier p on p.id = u.tier_id
  101. WHERE user = ?
  102. `
  103. selectUserByTokenQuery = `
  104. SELECT u.user, u.pass, u.role, u.prefs, u.sync_topic, u.stats_messages, u.stats_emails, u.stripe_customer_id, u.stripe_subscription_id, p.code, p.name, p.paid, p.messages_limit, p.messages_expiry_duration, p.emails_limit, p.reservations_limit, p.attachment_file_size_limit, p.attachment_total_size_limit, p.attachment_expiry_duration, p.stripe_price_id
  105. FROM user u
  106. JOIN user_token t on u.id = t.user_id
  107. LEFT JOIN tier p on p.id = u.tier_id
  108. WHERE t.token = ? AND t.expires >= ?
  109. `
  110. selectUserByStripeCustomerIDQuery = `
  111. SELECT u.user, u.pass, u.role, u.prefs, u.sync_topic, u.stats_messages, u.stats_emails, u.stripe_customer_id, u.stripe_subscription_id, p.code, p.name, p.paid, p.messages_limit, p.messages_expiry_duration, p.emails_limit, p.reservations_limit, p.attachment_file_size_limit, p.attachment_total_size_limit, p.attachment_expiry_duration, p.stripe_price_id
  112. FROM user u
  113. LEFT JOIN tier p on p.id = u.tier_id
  114. WHERE u.stripe_customer_id = ?
  115. `
  116. selectTopicPermsQuery = `
  117. SELECT read, write
  118. FROM user_access a
  119. JOIN user u ON u.id = a.user_id
  120. WHERE (u.user = ? OR u.user = ?) AND ? LIKE a.topic
  121. ORDER BY u.user DESC
  122. `
  123. insertUserQuery = `
  124. INSERT INTO user (user, pass, role, sync_topic, created_by, created_at, last_seen)
  125. VALUES (?, ?, ?, ?, ?, ?, ?)
  126. `
  127. selectUsernamesQuery = `
  128. SELECT user
  129. FROM user
  130. ORDER BY
  131. CASE role
  132. WHEN 'admin' THEN 1
  133. WHEN 'anonymous' THEN 3
  134. ELSE 2
  135. END, user
  136. `
  137. updateUserPassQuery = `UPDATE user SET pass = ? WHERE user = ?`
  138. updateUserRoleQuery = `UPDATE user SET role = ? WHERE user = ?`
  139. updateUserPrefsQuery = `UPDATE user SET prefs = ? WHERE user = ?`
  140. updateUserStatsQuery = `UPDATE user SET stats_messages = ?, stats_emails = ? WHERE user = ?`
  141. updateUserStatsResetAllQuery = `UPDATE user SET stats_messages = 0, stats_emails = 0`
  142. deleteUserQuery = `DELETE FROM user WHERE user = ?`
  143. upsertUserAccessQuery = `
  144. INSERT INTO user_access (user_id, topic, read, write, owner_user_id)
  145. VALUES ((SELECT id FROM user WHERE user = ?), ?, ?, ?, (SELECT IIF(?='',NULL,(SELECT id FROM user WHERE user=?))))
  146. ON CONFLICT (user_id, topic)
  147. DO UPDATE SET read=excluded.read, write=excluded.write, owner_user_id=excluded.owner_user_id
  148. `
  149. selectUserAccessQuery = `
  150. SELECT topic, read, write
  151. FROM user_access
  152. WHERE user_id = (SELECT id FROM user WHERE user = ?)
  153. ORDER BY write DESC, read DESC, topic
  154. `
  155. selectUserReservationsQuery = `
  156. SELECT a_user.topic, a_user.read, a_user.write, a_everyone.read AS everyone_read, a_everyone.write AS everyone_write
  157. FROM user_access a_user
  158. LEFT JOIN user_access a_everyone ON a_user.topic = a_everyone.topic AND a_everyone.user_id = (SELECT id FROM user WHERE user = ?)
  159. WHERE a_user.user_id = a_user.owner_user_id
  160. AND a_user.owner_user_id = (SELECT id FROM user WHERE user = ?)
  161. ORDER BY a_user.topic
  162. `
  163. selectUserReservationsCountQuery = `
  164. SELECT COUNT(*)
  165. FROM user_access
  166. WHERE user_id = owner_user_id AND owner_user_id = (SELECT id FROM user WHERE user = ?)
  167. `
  168. selectUserHasReservationQuery = `
  169. SELECT COUNT(*)
  170. FROM user_access
  171. WHERE user_id = owner_user_id
  172. AND owner_user_id = (SELECT id FROM user WHERE user = ?)
  173. AND topic = ?
  174. `
  175. selectOtherAccessCountQuery = `
  176. SELECT COUNT(*)
  177. FROM user_access
  178. WHERE (topic = ? OR ? LIKE topic)
  179. AND (owner_user_id IS NULL OR owner_user_id != (SELECT id FROM user WHERE user = ?))
  180. `
  181. deleteAllAccessQuery = `DELETE FROM user_access`
  182. deleteUserAccessQuery = `
  183. DELETE FROM user_access
  184. WHERE user_id = (SELECT id FROM user WHERE user = ?)
  185. OR owner_user_id = (SELECT id FROM user WHERE user = ?)
  186. `
  187. deleteTopicAccessQuery = `
  188. DELETE FROM user_access
  189. WHERE (user_id = (SELECT id FROM user WHERE user = ?) OR owner_user_id = (SELECT id FROM user WHERE user = ?))
  190. AND topic = ?
  191. `
  192. selectTokenCountQuery = `SELECT COUNT(*) FROM user_token WHERE (SELECT id FROM user WHERE user = ?)`
  193. insertTokenQuery = `INSERT INTO user_token (user_id, token, expires) VALUES ((SELECT id FROM user WHERE user = ?), ?, ?)`
  194. updateTokenExpiryQuery = `UPDATE user_token SET expires = ? WHERE user_id = (SELECT id FROM user WHERE user = ?) AND token = ?`
  195. deleteTokenQuery = `DELETE FROM user_token WHERE user_id = (SELECT id FROM user WHERE user = ?) AND token = ?`
  196. deleteExpiredTokensQuery = `DELETE FROM user_token WHERE expires < ?`
  197. deleteExcessTokensQuery = `
  198. DELETE FROM user_token
  199. WHERE (user_id, token) NOT IN (
  200. SELECT user_id, token
  201. FROM user_token
  202. WHERE user_id = (SELECT id FROM user WHERE user = ?)
  203. ORDER BY expires DESC
  204. LIMIT ?
  205. )
  206. `
  207. insertTierQuery = `
  208. INSERT INTO tier (code, name, paid, messages_limit, messages_expiry_duration, emails_limit, reservations_limit, attachment_file_size_limit, attachment_total_size_limit, attachment_expiry_duration)
  209. VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
  210. `
  211. selectTierIDQuery = `SELECT id FROM tier WHERE code = ?`
  212. selectTierByCodeQuery = `
  213. SELECT code, name, paid, messages_limit, messages_expiry_duration, emails_limit, reservations_limit, attachment_file_size_limit, attachment_total_size_limit, attachment_expiry_duration, stripe_price_id
  214. FROM tier
  215. WHERE code = ?
  216. `
  217. selectTierByPriceIDQuery = `
  218. SELECT code, name, paid, messages_limit, messages_expiry_duration, emails_limit, reservations_limit, attachment_file_size_limit, attachment_total_size_limit, attachment_expiry_duration, stripe_price_id
  219. FROM tier
  220. WHERE stripe_price_id = ?
  221. `
  222. updateUserTierQuery = `UPDATE user SET tier_id = ? WHERE user = ?`
  223. deleteUserTierQuery = `UPDATE user SET tier_id = null WHERE user = ?`
  224. updateBillingQuery = `UPDATE user SET stripe_customer_id = ?, stripe_subscription_id = ? WHERE user = ?`
  225. )
  226. // Schema management queries
  227. const (
  228. currentSchemaVersion = 2
  229. insertSchemaVersion = `INSERT INTO schemaVersion VALUES (1, ?)`
  230. updateSchemaVersion = `UPDATE schemaVersion SET version = ? WHERE id = 1`
  231. selectSchemaVersionQuery = `SELECT version FROM schemaVersion WHERE id = 1`
  232. // 1 -> 2 (complex migration!)
  233. migrate1To2RenameUserTableQueryNoTx = `
  234. ALTER TABLE user RENAME TO user_old;
  235. `
  236. migrate1To2InsertFromOldTablesAndDropNoTx = `
  237. INSERT INTO user (user, pass, role, sync_topic, created_by, created_at, last_seen)
  238. SELECT user, pass, role, '', 'admin', UNIXEPOCH(), UNIXEPOCH() FROM user_old;
  239. INSERT INTO user_access (user_id, topic, read, write)
  240. SELECT u.id, a.topic, a.read, a.write
  241. FROM user u
  242. JOIN access a ON u.user = a.user;
  243. DROP TABLE access;
  244. DROP TABLE user_old;
  245. `
  246. migrate1To2SelectAllUsersIDsNoTx = `SELECT id FROM user`
  247. migrate1To2UpdateSyncTopicNoTx = `UPDATE user SET sync_topic = ? WHERE id = ?`
  248. )
  249. // Manager is an implementation of Manager. It stores users and access control list
  250. // in a SQLite database.
  251. type Manager struct {
  252. db *sql.DB
  253. defaultAccess Permission // Default permission if no ACL matches
  254. statsQueue map[string]*User // Username -> User, for "unimportant" user updates
  255. mu sync.Mutex
  256. }
  257. var _ Auther = (*Manager)(nil)
  258. // NewManager creates a new Manager instance
  259. func NewManager(filename, startupQueries string, defaultAccess Permission) (*Manager, error) {
  260. return newManager(filename, startupQueries, defaultAccess, userStatsQueueWriterInterval)
  261. }
  262. // NewManager creates a new Manager instance
  263. func newManager(filename, startupQueries string, defaultAccess Permission, statsWriterInterval time.Duration) (*Manager, error) {
  264. db, err := sql.Open("sqlite3", filename)
  265. if err != nil {
  266. return nil, err
  267. }
  268. if err := setupDB(db); err != nil {
  269. return nil, err
  270. }
  271. if err := runStartupQueries(db, startupQueries); err != nil {
  272. return nil, err
  273. }
  274. manager := &Manager{
  275. db: db,
  276. defaultAccess: defaultAccess,
  277. statsQueue: make(map[string]*User),
  278. }
  279. go manager.userStatsQueueWriter(statsWriterInterval)
  280. return manager, nil
  281. }
  282. // Authenticate checks username and password and returns a User if correct. The method
  283. // returns in constant-ish time, regardless of whether the user exists or the password is
  284. // correct or incorrect.
  285. func (a *Manager) Authenticate(username, password string) (*User, error) {
  286. if username == Everyone {
  287. return nil, ErrUnauthenticated
  288. }
  289. user, err := a.User(username)
  290. if err != nil {
  291. log.Trace("authentication of user %s failed (1): %s", username, err.Error())
  292. bcrypt.CompareHashAndPassword([]byte(intentionalSlowDownHash), []byte("intentional slow-down to avoid timing attacks"))
  293. return nil, ErrUnauthenticated
  294. }
  295. if err := bcrypt.CompareHashAndPassword([]byte(user.Hash), []byte(password)); err != nil {
  296. log.Trace("authentication of user %s failed (2): %s", username, err.Error())
  297. return nil, ErrUnauthenticated
  298. }
  299. return user, nil
  300. }
  301. // AuthenticateToken checks if the token exists and returns the associated User if it does.
  302. // The method sets the User.Token value to the token that was used for authentication.
  303. func (a *Manager) AuthenticateToken(token string) (*User, error) {
  304. if len(token) != tokenLength {
  305. return nil, ErrUnauthenticated
  306. }
  307. user, err := a.userByToken(token)
  308. if err != nil {
  309. return nil, ErrUnauthenticated
  310. }
  311. user.Token = token
  312. return user, nil
  313. }
  314. // CreateToken generates a random token for the given user and returns it. The token expires
  315. // after a fixed duration unless ExtendToken is called. This function also prunes tokens for the
  316. // given user, if there are too many of them.
  317. func (a *Manager) CreateToken(user *User) (*Token, error) {
  318. token, expires := util.RandomString(tokenLength), time.Now().Add(tokenExpiryDuration)
  319. tx, err := a.db.Begin()
  320. if err != nil {
  321. return nil, err
  322. }
  323. defer tx.Rollback()
  324. if _, err := tx.Exec(insertTokenQuery, user.Name, token, expires.Unix()); err != nil {
  325. return nil, err
  326. }
  327. rows, err := tx.Query(selectTokenCountQuery, user.Name)
  328. if err != nil {
  329. return nil, err
  330. }
  331. defer rows.Close()
  332. if !rows.Next() {
  333. return nil, errNoRows
  334. }
  335. var tokenCount int
  336. if err := rows.Scan(&tokenCount); err != nil {
  337. return nil, err
  338. }
  339. if tokenCount >= tokenMaxCount {
  340. // This pruning logic is done in two queries for efficiency. The SELECT above is a lookup
  341. // on two indices, whereas the query below is a full table scan.
  342. if _, err := tx.Exec(deleteExcessTokensQuery, user.Name, tokenMaxCount); err != nil {
  343. return nil, err
  344. }
  345. }
  346. if err := tx.Commit(); err != nil {
  347. return nil, err
  348. }
  349. return &Token{
  350. Value: token,
  351. Expires: expires,
  352. }, nil
  353. }
  354. // ExtendToken sets the new expiry date for a token, thereby extending its use further into the future.
  355. func (a *Manager) ExtendToken(user *User) (*Token, error) {
  356. if user.Token == "" {
  357. return nil, errNoTokenProvided
  358. }
  359. newExpires := time.Now().Add(tokenExpiryDuration)
  360. if _, err := a.db.Exec(updateTokenExpiryQuery, newExpires.Unix(), user.Name, user.Token); err != nil {
  361. return nil, err
  362. }
  363. return &Token{
  364. Value: user.Token,
  365. Expires: newExpires,
  366. }, nil
  367. }
  368. // RemoveToken deletes the token defined in User.Token
  369. func (a *Manager) RemoveToken(user *User) error {
  370. if user.Token == "" {
  371. return ErrUnauthorized
  372. }
  373. if _, err := a.db.Exec(deleteTokenQuery, user.Name, user.Token); err != nil {
  374. return err
  375. }
  376. return nil
  377. }
  378. // RemoveExpiredTokens deletes all expired tokens from the database
  379. func (a *Manager) RemoveExpiredTokens() error {
  380. if _, err := a.db.Exec(deleteExpiredTokensQuery, time.Now().Unix()); err != nil {
  381. return err
  382. }
  383. return nil
  384. }
  385. // ChangeSettings persists the user settings
  386. func (a *Manager) ChangeSettings(user *User) error {
  387. prefs, err := json.Marshal(user.Prefs)
  388. if err != nil {
  389. return err
  390. }
  391. if _, err := a.db.Exec(updateUserPrefsQuery, string(prefs), user.Name); err != nil {
  392. return err
  393. }
  394. return nil
  395. }
  396. // ResetStats resets all user stats in the user database. This touches all users.
  397. func (a *Manager) ResetStats() error {
  398. a.mu.Lock()
  399. defer a.mu.Unlock()
  400. if _, err := a.db.Exec(updateUserStatsResetAllQuery); err != nil {
  401. return err
  402. }
  403. a.statsQueue = make(map[string]*User)
  404. return nil
  405. }
  406. // EnqueueStats adds the user to a queue which writes out user stats (messages, emails, ..) in
  407. // batches at a regular interval
  408. func (a *Manager) EnqueueStats(user *User) {
  409. a.mu.Lock()
  410. defer a.mu.Unlock()
  411. a.statsQueue[user.Name] = user
  412. }
  413. func (a *Manager) userStatsQueueWriter(interval time.Duration) {
  414. ticker := time.NewTicker(interval)
  415. for range ticker.C {
  416. if err := a.writeUserStatsQueue(); err != nil {
  417. log.Warn("User Manager: Writing user stats queue failed: %s", err.Error())
  418. }
  419. }
  420. }
  421. func (a *Manager) writeUserStatsQueue() error {
  422. a.mu.Lock()
  423. if len(a.statsQueue) == 0 {
  424. a.mu.Unlock()
  425. log.Trace("User Manager: No user stats updates to commit")
  426. return nil
  427. }
  428. statsQueue := a.statsQueue
  429. a.statsQueue = make(map[string]*User)
  430. a.mu.Unlock()
  431. tx, err := a.db.Begin()
  432. if err != nil {
  433. return err
  434. }
  435. defer tx.Rollback()
  436. log.Debug("User Manager: Writing user stats queue for %d user(s)", len(statsQueue))
  437. for username, u := range statsQueue {
  438. log.Trace("User Manager: Updating stats for user %s: messages=%d, emails=%d", username, u.Stats.Messages, u.Stats.Emails)
  439. if _, err := tx.Exec(updateUserStatsQuery, u.Stats.Messages, u.Stats.Emails, username); err != nil {
  440. return err
  441. }
  442. }
  443. return tx.Commit()
  444. }
  445. // Authorize returns nil if the given user has access to the given topic using the desired
  446. // permission. The user param may be nil to signal an anonymous user.
  447. func (a *Manager) Authorize(user *User, topic string, perm Permission) error {
  448. if user != nil && user.Role == RoleAdmin {
  449. return nil // Admin can do everything
  450. }
  451. username := Everyone
  452. if user != nil {
  453. username = user.Name
  454. }
  455. // Select the read/write permissions for this user/topic combo. The query may return two
  456. // rows (one for everyone, and one for the user), but prioritizes the user.
  457. rows, err := a.db.Query(selectTopicPermsQuery, Everyone, username, topic)
  458. if err != nil {
  459. return err
  460. }
  461. defer rows.Close()
  462. if !rows.Next() {
  463. return a.resolvePerms(a.defaultAccess, perm)
  464. }
  465. var read, write bool
  466. if err := rows.Scan(&read, &write); err != nil {
  467. return err
  468. } else if err := rows.Err(); err != nil {
  469. return err
  470. }
  471. return a.resolvePerms(NewPermission(read, write), perm)
  472. }
  473. func (a *Manager) resolvePerms(base, perm Permission) error {
  474. if perm == PermissionRead && base.IsRead() {
  475. return nil
  476. } else if perm == PermissionWrite && base.IsWrite() {
  477. return nil
  478. }
  479. return ErrUnauthorized
  480. }
  481. // AddUser adds a user with the given username, password and role
  482. func (a *Manager) AddUser(username, password string, role Role, createdBy string) error {
  483. if !AllowedUsername(username) || !AllowedRole(role) {
  484. return ErrInvalidArgument
  485. }
  486. hash, err := bcrypt.GenerateFromPassword([]byte(password), bcryptCost)
  487. if err != nil {
  488. return err
  489. }
  490. syncTopic, now := util.RandomString(syncTopicLength), time.Now().Unix()
  491. if _, err = a.db.Exec(insertUserQuery, username, hash, role, syncTopic, createdBy, now, now); err != nil {
  492. return err
  493. }
  494. return nil
  495. }
  496. // RemoveUser deletes the user with the given username. The function returns nil on success, even
  497. // if the user did not exist in the first place.
  498. func (a *Manager) RemoveUser(username string) error {
  499. if !AllowedUsername(username) {
  500. return ErrInvalidArgument
  501. }
  502. // Rows in user_access, user_token, etc. are deleted via foreign keys
  503. if _, err := a.db.Exec(deleteUserQuery, username); err != nil {
  504. return err
  505. }
  506. return nil
  507. }
  508. // Users returns a list of users. It always also returns the Everyone user ("*").
  509. func (a *Manager) Users() ([]*User, error) {
  510. rows, err := a.db.Query(selectUsernamesQuery)
  511. if err != nil {
  512. return nil, err
  513. }
  514. defer rows.Close()
  515. usernames := make([]string, 0)
  516. for rows.Next() {
  517. var username string
  518. if err := rows.Scan(&username); err != nil {
  519. return nil, err
  520. } else if err := rows.Err(); err != nil {
  521. return nil, err
  522. }
  523. usernames = append(usernames, username)
  524. }
  525. rows.Close()
  526. users := make([]*User, 0)
  527. for _, username := range usernames {
  528. user, err := a.User(username)
  529. if err != nil {
  530. return nil, err
  531. }
  532. users = append(users, user)
  533. }
  534. return users, nil
  535. }
  536. // User returns the user with the given username if it exists, or ErrUserNotFound otherwise.
  537. // You may also pass Everyone to retrieve the anonymous user and its Grant list.
  538. func (a *Manager) User(username string) (*User, error) {
  539. rows, err := a.db.Query(selectUserByNameQuery, username)
  540. if err != nil {
  541. return nil, err
  542. }
  543. return a.readUser(rows)
  544. }
  545. func (a *Manager) UserByStripeCustomer(stripeCustomerID string) (*User, error) {
  546. rows, err := a.db.Query(selectUserByStripeCustomerIDQuery, stripeCustomerID)
  547. if err != nil {
  548. return nil, err
  549. }
  550. return a.readUser(rows)
  551. }
  552. func (a *Manager) userByToken(token string) (*User, error) {
  553. rows, err := a.db.Query(selectUserByTokenQuery, token, time.Now().Unix())
  554. if err != nil {
  555. return nil, err
  556. }
  557. return a.readUser(rows)
  558. }
  559. func (a *Manager) readUser(rows *sql.Rows) (*User, error) {
  560. defer rows.Close()
  561. var username, hash, role, prefs, syncTopic string
  562. var stripeCustomerID, stripeSubscriptionID, stripePriceID, tierCode, tierName sql.NullString
  563. var paid sql.NullBool
  564. var messages, emails int64
  565. var messagesLimit, messagesExpiryDuration, emailsLimit, reservationsLimit, attachmentFileSizeLimit, attachmentTotalSizeLimit, attachmentExpiryDuration sql.NullInt64
  566. if !rows.Next() {
  567. return nil, ErrUserNotFound
  568. }
  569. if err := rows.Scan(&username, &hash, &role, &prefs, &syncTopic, &messages, &emails, &stripeCustomerID, &stripeSubscriptionID, &tierCode, &tierName, &paid, &messagesLimit, &messagesExpiryDuration, &emailsLimit, &reservationsLimit, &attachmentFileSizeLimit, &attachmentTotalSizeLimit, &attachmentExpiryDuration, &stripePriceID); err != nil {
  570. return nil, err
  571. } else if err := rows.Err(); err != nil {
  572. return nil, err
  573. }
  574. user := &User{
  575. Name: username,
  576. Hash: hash,
  577. Role: Role(role),
  578. Prefs: &Prefs{},
  579. SyncTopic: syncTopic,
  580. Stats: &Stats{
  581. Messages: messages,
  582. Emails: emails,
  583. },
  584. }
  585. if err := json.Unmarshal([]byte(prefs), user.Prefs); err != nil {
  586. return nil, err
  587. }
  588. if stripeCustomerID.Valid && stripeSubscriptionID.Valid {
  589. user.Billing = &Billing{
  590. StripeCustomerID: stripeCustomerID.String,
  591. StripeSubscriptionID: stripeSubscriptionID.String,
  592. }
  593. }
  594. if tierCode.Valid {
  595. // See readTier() when this is changed!
  596. user.Tier = &Tier{
  597. Code: tierCode.String,
  598. Name: tierName.String,
  599. Paid: paid.Bool,
  600. MessagesLimit: messagesLimit.Int64,
  601. MessagesExpiryDuration: time.Duration(messagesExpiryDuration.Int64) * time.Second,
  602. EmailsLimit: emailsLimit.Int64,
  603. ReservationsLimit: reservationsLimit.Int64,
  604. AttachmentFileSizeLimit: attachmentFileSizeLimit.Int64,
  605. AttachmentTotalSizeLimit: attachmentTotalSizeLimit.Int64,
  606. AttachmentExpiryDuration: time.Duration(attachmentExpiryDuration.Int64) * time.Second,
  607. StripePriceID: stripePriceID.String,
  608. }
  609. }
  610. return user, nil
  611. }
  612. // Grants returns all user-specific access control entries
  613. func (a *Manager) Grants(username string) ([]Grant, error) {
  614. rows, err := a.db.Query(selectUserAccessQuery, username)
  615. if err != nil {
  616. return nil, err
  617. }
  618. defer rows.Close()
  619. grants := make([]Grant, 0)
  620. for rows.Next() {
  621. var topic string
  622. var read, write bool
  623. if err := rows.Scan(&topic, &read, &write); err != nil {
  624. return nil, err
  625. } else if err := rows.Err(); err != nil {
  626. return nil, err
  627. }
  628. grants = append(grants, Grant{
  629. TopicPattern: fromSQLWildcard(topic),
  630. Allow: NewPermission(read, write),
  631. })
  632. }
  633. return grants, nil
  634. }
  635. // Reservations returns all user-owned topics, and the associated everyone-access
  636. func (a *Manager) Reservations(username string) ([]Reservation, error) {
  637. rows, err := a.db.Query(selectUserReservationsQuery, Everyone, username)
  638. if err != nil {
  639. return nil, err
  640. }
  641. defer rows.Close()
  642. reservations := make([]Reservation, 0)
  643. for rows.Next() {
  644. var topic string
  645. var ownerRead, ownerWrite bool
  646. var everyoneRead, everyoneWrite sql.NullBool
  647. if err := rows.Scan(&topic, &ownerRead, &ownerWrite, &everyoneRead, &everyoneWrite); err != nil {
  648. return nil, err
  649. } else if err := rows.Err(); err != nil {
  650. return nil, err
  651. }
  652. reservations = append(reservations, Reservation{
  653. Topic: topic,
  654. Owner: NewPermission(ownerRead, ownerWrite),
  655. Everyone: NewPermission(everyoneRead.Bool, everyoneWrite.Bool), // false if null
  656. })
  657. }
  658. return reservations, nil
  659. }
  660. // HasReservation returns true if the given topic access is owned by the user
  661. func (a *Manager) HasReservation(username, topic string) (bool, error) {
  662. rows, err := a.db.Query(selectUserHasReservationQuery, username, topic)
  663. if err != nil {
  664. return false, err
  665. }
  666. defer rows.Close()
  667. if !rows.Next() {
  668. return false, errNoRows
  669. }
  670. var count int64
  671. if err := rows.Scan(&count); err != nil {
  672. return false, err
  673. }
  674. return count > 0, nil
  675. }
  676. // ReservationsCount returns the number of reservations owned by this user
  677. func (a *Manager) ReservationsCount(username string) (int64, error) {
  678. rows, err := a.db.Query(selectUserReservationsCountQuery, username)
  679. if err != nil {
  680. return 0, err
  681. }
  682. defer rows.Close()
  683. if !rows.Next() {
  684. return 0, errNoRows
  685. }
  686. var count int64
  687. if err := rows.Scan(&count); err != nil {
  688. return 0, err
  689. }
  690. return count, nil
  691. }
  692. // ChangePassword changes a user's password
  693. func (a *Manager) ChangePassword(username, password string) error {
  694. hash, err := bcrypt.GenerateFromPassword([]byte(password), bcryptCost)
  695. if err != nil {
  696. return err
  697. }
  698. if _, err := a.db.Exec(updateUserPassQuery, hash, username); err != nil {
  699. return err
  700. }
  701. return nil
  702. }
  703. // ChangeRole changes a user's role. When a role is changed from RoleUser to RoleAdmin,
  704. // all existing access control entries (Grant) are removed, since they are no longer needed.
  705. func (a *Manager) ChangeRole(username string, role Role) error {
  706. if !AllowedUsername(username) || !AllowedRole(role) {
  707. return ErrInvalidArgument
  708. }
  709. if _, err := a.db.Exec(updateUserRoleQuery, string(role), username); err != nil {
  710. return err
  711. }
  712. if role == RoleAdmin {
  713. if _, err := a.db.Exec(deleteUserAccessQuery, username, username); err != nil {
  714. return err
  715. }
  716. }
  717. return nil
  718. }
  719. // ChangeTier changes a user's tier using the tier code
  720. func (a *Manager) ChangeTier(username, tier string) error {
  721. if !AllowedUsername(username) {
  722. return ErrInvalidArgument
  723. }
  724. rows, err := a.db.Query(selectTierIDQuery, tier)
  725. if err != nil {
  726. return err
  727. }
  728. defer rows.Close()
  729. if !rows.Next() {
  730. return ErrInvalidArgument
  731. }
  732. var tierID int64
  733. if err := rows.Scan(&tierID); err != nil {
  734. return err
  735. }
  736. rows.Close()
  737. if _, err := a.db.Exec(updateUserTierQuery, tierID, username); err != nil {
  738. return err
  739. }
  740. return nil
  741. }
  742. // CheckAllowAccess tests if a user may create an access control entry for the given topic.
  743. // If there are any ACL entries that are not owned by the user, an error is returned.
  744. func (a *Manager) CheckAllowAccess(username string, topic string) error {
  745. if (!AllowedUsername(username) && username != Everyone) || !AllowedTopic(topic) {
  746. return ErrInvalidArgument
  747. }
  748. rows, err := a.db.Query(selectOtherAccessCountQuery, topic, topic, username)
  749. if err != nil {
  750. return err
  751. }
  752. defer rows.Close()
  753. if !rows.Next() {
  754. return errNoRows
  755. }
  756. var otherCount int
  757. if err := rows.Scan(&otherCount); err != nil {
  758. return err
  759. }
  760. if otherCount > 0 {
  761. return errTopicOwnedByOthers
  762. }
  763. return nil
  764. }
  765. // AllowAccess adds or updates an entry in th access control list for a specific user. It controls
  766. // read/write access to a topic. The parameter topicPattern may include wildcards (*). The ACL entry
  767. // owner may either be a user (username), or the system (empty).
  768. func (a *Manager) AllowAccess(owner, username string, topicPattern string, read bool, write bool) error {
  769. if !AllowedUsername(username) && username != Everyone {
  770. return ErrInvalidArgument
  771. } else if owner != "" && !AllowedUsername(owner) {
  772. return ErrInvalidArgument
  773. } else if !AllowedTopicPattern(topicPattern) {
  774. return ErrInvalidArgument
  775. }
  776. if _, err := a.db.Exec(upsertUserAccessQuery, username, toSQLWildcard(topicPattern), read, write, owner, owner); err != nil {
  777. return err
  778. }
  779. return nil
  780. }
  781. // ResetAccess removes an access control list entry for a specific username/topic, or (if topic is
  782. // empty) for an entire user. The parameter topicPattern may include wildcards (*).
  783. func (a *Manager) ResetAccess(username string, topicPattern string) error {
  784. if !AllowedUsername(username) && username != Everyone && username != "" {
  785. return ErrInvalidArgument
  786. } else if !AllowedTopicPattern(topicPattern) && topicPattern != "" {
  787. return ErrInvalidArgument
  788. }
  789. if username == "" && topicPattern == "" {
  790. _, err := a.db.Exec(deleteAllAccessQuery, username)
  791. return err
  792. } else if topicPattern == "" {
  793. _, err := a.db.Exec(deleteUserAccessQuery, username, username)
  794. return err
  795. }
  796. _, err := a.db.Exec(deleteTopicAccessQuery, username, username, toSQLWildcard(topicPattern))
  797. return err
  798. }
  799. // ResetTier removes the tier from the given user
  800. func (a *Manager) ResetTier(username string) error {
  801. if !AllowedUsername(username) && username != Everyone && username != "" {
  802. return ErrInvalidArgument
  803. }
  804. _, err := a.db.Exec(deleteUserTierQuery, username)
  805. return err
  806. }
  807. // DefaultAccess returns the default read/write access if no access control entry matches
  808. func (a *Manager) DefaultAccess() Permission {
  809. return a.defaultAccess
  810. }
  811. // CreateTier creates a new tier in the database
  812. func (a *Manager) CreateTier(tier *Tier) error {
  813. if _, err := a.db.Exec(insertTierQuery, tier.Code, tier.Name, tier.Paid, tier.MessagesLimit, int64(tier.MessagesExpiryDuration.Seconds()), tier.EmailsLimit, tier.ReservationsLimit, tier.AttachmentFileSizeLimit, tier.AttachmentTotalSizeLimit, int64(tier.AttachmentExpiryDuration.Seconds())); err != nil {
  814. return err
  815. }
  816. return nil
  817. }
  818. func (a *Manager) ChangeBilling(user *User) error {
  819. if _, err := a.db.Exec(updateBillingQuery, user.Billing.StripeCustomerID, user.Billing.StripeSubscriptionID, user.Name); err != nil {
  820. return err
  821. }
  822. return nil
  823. }
  824. func (a *Manager) Tier(code string) (*Tier, error) {
  825. rows, err := a.db.Query(selectTierByCodeQuery, code)
  826. if err != nil {
  827. return nil, err
  828. }
  829. return a.readTier(rows)
  830. }
  831. func (a *Manager) TierByStripePrice(priceID string) (*Tier, error) {
  832. rows, err := a.db.Query(selectTierByPriceIDQuery, priceID)
  833. if err != nil {
  834. return nil, err
  835. }
  836. return a.readTier(rows)
  837. }
  838. func (a *Manager) readTier(rows *sql.Rows) (*Tier, error) {
  839. defer rows.Close()
  840. var code, name string
  841. var stripePriceID sql.NullString
  842. var paid bool
  843. var messagesLimit, messagesExpiryDuration, emailsLimit, reservationsLimit, attachmentFileSizeLimit, attachmentTotalSizeLimit, attachmentExpiryDuration sql.NullInt64
  844. if !rows.Next() {
  845. return nil, ErrTierNotFound
  846. }
  847. if err := rows.Scan(&code, &name, &paid, &messagesLimit, &messagesExpiryDuration, &emailsLimit, &reservationsLimit, &attachmentFileSizeLimit, &attachmentTotalSizeLimit, &attachmentExpiryDuration, &stripePriceID); err != nil {
  848. return nil, err
  849. } else if err := rows.Err(); err != nil {
  850. return nil, err
  851. }
  852. // When changed, note readUser() as well
  853. return &Tier{
  854. Code: code,
  855. Name: name,
  856. Paid: paid,
  857. MessagesLimit: messagesLimit.Int64,
  858. MessagesExpiryDuration: time.Duration(messagesExpiryDuration.Int64) * time.Second,
  859. EmailsLimit: emailsLimit.Int64,
  860. ReservationsLimit: reservationsLimit.Int64,
  861. AttachmentFileSizeLimit: attachmentFileSizeLimit.Int64,
  862. AttachmentTotalSizeLimit: attachmentTotalSizeLimit.Int64,
  863. AttachmentExpiryDuration: time.Duration(attachmentExpiryDuration.Int64) * time.Second,
  864. StripePriceID: stripePriceID.String, // May be empty!
  865. }, nil
  866. }
  867. func toSQLWildcard(s string) string {
  868. return strings.ReplaceAll(s, "*", "%")
  869. }
  870. func fromSQLWildcard(s string) string {
  871. return strings.ReplaceAll(s, "%", "*")
  872. }
  873. func runStartupQueries(db *sql.DB, startupQueries string) error {
  874. if _, err := db.Exec(startupQueries); err != nil {
  875. return err
  876. }
  877. if _, err := db.Exec(builtinStartupQueries); err != nil {
  878. return err
  879. }
  880. return nil
  881. }
  882. func setupDB(db *sql.DB) error {
  883. // If 'schemaVersion' table does not exist, this must be a new database
  884. rowsSV, err := db.Query(selectSchemaVersionQuery)
  885. if err != nil {
  886. return setupNewDB(db)
  887. }
  888. defer rowsSV.Close()
  889. // If 'schemaVersion' table exists, read version and potentially upgrade
  890. schemaVersion := 0
  891. if !rowsSV.Next() {
  892. return errors.New("cannot determine schema version: database file may be corrupt")
  893. }
  894. if err := rowsSV.Scan(&schemaVersion); err != nil {
  895. return err
  896. }
  897. rowsSV.Close()
  898. // Do migrations
  899. if schemaVersion == currentSchemaVersion {
  900. return nil
  901. } else if schemaVersion == 1 {
  902. return migrateFrom1(db)
  903. }
  904. return fmt.Errorf("unexpected schema version found: %d", schemaVersion)
  905. }
  906. func setupNewDB(db *sql.DB) error {
  907. if _, err := db.Exec(createTablesQueries); err != nil {
  908. return err
  909. }
  910. if _, err := db.Exec(insertSchemaVersion, currentSchemaVersion); err != nil {
  911. return err
  912. }
  913. return nil
  914. }
  915. func migrateFrom1(db *sql.DB) error {
  916. log.Info("Migrating user database schema: from 1 to 2")
  917. tx, err := db.Begin()
  918. if err != nil {
  919. return err
  920. }
  921. defer tx.Rollback()
  922. if _, err := tx.Exec(migrate1To2RenameUserTableQueryNoTx); err != nil {
  923. return err
  924. }
  925. if _, err := tx.Exec(createTablesQueriesNoTx); err != nil {
  926. return err
  927. }
  928. if _, err := tx.Exec(migrate1To2InsertFromOldTablesAndDropNoTx); err != nil {
  929. return err
  930. }
  931. rows, err := tx.Query(migrate1To2SelectAllUsersIDsNoTx)
  932. if err != nil {
  933. return err
  934. }
  935. defer rows.Close()
  936. syncTopics := make(map[int]string)
  937. for rows.Next() {
  938. var userID int
  939. if err := rows.Scan(&userID); err != nil {
  940. return err
  941. }
  942. syncTopics[userID] = util.RandomString(syncTopicLength)
  943. }
  944. if err := rows.Close(); err != nil {
  945. return err
  946. }
  947. for userID, syncTopic := range syncTopics {
  948. if _, err := tx.Exec(migrate1To2UpdateSyncTopicNoTx, syncTopic, userID); err != nil {
  949. return err
  950. }
  951. }
  952. if _, err := tx.Exec(updateSchemaVersion, 2); err != nil {
  953. return err
  954. }
  955. if err := tx.Commit(); err != nil {
  956. return err
  957. }
  958. return nil // Update this when a new version is added
  959. }