manager.go 47 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
712781279128012811282128312841285128612871288128912901291129212931294129512961297129812991300130113021303130413051306130713081309131013111312131313141315131613171318131913201321132213231324132513261327132813291330133113321333133413351336133713381339134013411342134313441345134613471348134913501351135213531354135513561357
  1. package user
  2. import (
  3. "database/sql"
  4. "encoding/json"
  5. "errors"
  6. "fmt"
  7. _ "github.com/mattn/go-sqlite3" // SQLite driver
  8. "github.com/stripe/stripe-go/v74"
  9. "golang.org/x/crypto/bcrypt"
  10. "heckel.io/ntfy/log"
  11. "heckel.io/ntfy/util"
  12. "net/netip"
  13. "strings"
  14. "sync"
  15. "time"
  16. )
// Internal settings of the user manager: prefixes/lengths used when generating random
// identifiers, token limits, and the log tag.
const (
	tierIDPrefix                    = "ti_"
	tierIDLength                    = 8
	syncTopicPrefix                 = "st_" // Prefix of the random sync topic generated in AddUser
	syncTopicLength                 = 16
	userIDPrefix                    = "u_" // Prefix of the random user ID generated in AddUser
	userIDLength                    = 12
	userAuthIntentionalSlowDownHash = "$2a$10$YFCQvqQDwIIwnJM1xkAYOeih0dg17UVGanaTStnrSzC8NCWxcLDwy" // Cost should match DefaultUserPasswordBcryptCost
	userHardDeleteAfterDuration     = 7 * 24 * time.Hour
	tokenPrefix                     = "tk_"
	tokenLength                     = 32 // AuthenticateToken rejects any token whose length differs from this
	tokenMaxCount                   = 20 // Only keep this many tokens in the table per user
	tagManager                      = "user_manager" // Tag attached to all log messages of the manager
)
// Default constants that may be overridden by configs
const (
	// DefaultUserStatsQueueWriterInterval is the default flush interval for the async queues
	// (presumably passed to NewManager as queueWriterInterval -- confirm against the caller).
	DefaultUserStatsQueueWriterInterval = 33 * time.Second
	// DefaultUserPasswordBcryptCost is the default bcrypt cost; the cost embedded in
	// userAuthIntentionalSlowDownHash ("$2a$10$...") is expected to match it.
	DefaultUserPasswordBcryptCost = 10
)
// Internal sentinel errors of the manager
var (
	errNoTokenProvided    = errors.New("no token provided")     // Returned by ChangeToken/RemoveToken for an empty token
	errTopicOwnedByOthers = errors.New("topic owned by others") // Not used in the part of the file visible here
	errNoRows             = errors.New("no rows found")         // Returned when a query unexpectedly yields no row
)
  41. // Manager-related queries
  42. const (
  43. createTablesQueriesNoTx = `
  44. CREATE TABLE IF NOT EXISTS tier (
  45. id TEXT PRIMARY KEY,
  46. code TEXT NOT NULL,
  47. name TEXT NOT NULL,
  48. messages_limit INT NOT NULL,
  49. messages_expiry_duration INT NOT NULL,
  50. emails_limit INT NOT NULL,
  51. reservations_limit INT NOT NULL,
  52. attachment_file_size_limit INT NOT NULL,
  53. attachment_total_size_limit INT NOT NULL,
  54. attachment_expiry_duration INT NOT NULL,
  55. attachment_bandwidth_limit INT NOT NULL,
  56. stripe_price_id TEXT
  57. );
  58. CREATE UNIQUE INDEX idx_tier_code ON tier (code);
  59. CREATE UNIQUE INDEX idx_tier_price_id ON tier (stripe_price_id);
  60. CREATE TABLE IF NOT EXISTS user (
  61. id TEXT PRIMARY KEY,
  62. tier_id TEXT,
  63. user TEXT NOT NULL,
  64. pass TEXT NOT NULL,
  65. role TEXT CHECK (role IN ('anonymous', 'admin', 'user')) NOT NULL,
  66. prefs JSON NOT NULL DEFAULT '{}',
  67. sync_topic TEXT NOT NULL,
  68. stats_messages INT NOT NULL DEFAULT (0),
  69. stats_emails INT NOT NULL DEFAULT (0),
  70. stripe_customer_id TEXT,
  71. stripe_subscription_id TEXT,
  72. stripe_subscription_status TEXT,
  73. stripe_subscription_paid_until INT,
  74. stripe_subscription_cancel_at INT,
  75. created INT NOT NULL,
  76. deleted INT,
  77. FOREIGN KEY (tier_id) REFERENCES tier (id)
  78. );
  79. CREATE UNIQUE INDEX idx_user ON user (user);
  80. CREATE UNIQUE INDEX idx_user_stripe_customer_id ON user (stripe_customer_id);
  81. CREATE UNIQUE INDEX idx_user_stripe_subscription_id ON user (stripe_subscription_id);
  82. CREATE TABLE IF NOT EXISTS user_access (
  83. user_id TEXT NOT NULL,
  84. topic TEXT NOT NULL,
  85. read INT NOT NULL,
  86. write INT NOT NULL,
  87. owner_user_id INT,
  88. PRIMARY KEY (user_id, topic),
  89. FOREIGN KEY (user_id) REFERENCES user (id) ON DELETE CASCADE,
  90. FOREIGN KEY (owner_user_id) REFERENCES user (id) ON DELETE CASCADE
  91. );
  92. CREATE TABLE IF NOT EXISTS user_token (
  93. user_id TEXT NOT NULL,
  94. token TEXT NOT NULL,
  95. label TEXT NOT NULL,
  96. last_access INT NOT NULL,
  97. last_origin TEXT NOT NULL,
  98. expires INT NOT NULL,
  99. PRIMARY KEY (user_id, token),
  100. FOREIGN KEY (user_id) REFERENCES user (id) ON DELETE CASCADE
  101. );
  102. CREATE TABLE IF NOT EXISTS schemaVersion (
  103. id INT PRIMARY KEY,
  104. version INT NOT NULL
  105. );
  106. INSERT INTO user (id, user, pass, role, sync_topic, created)
  107. VALUES ('` + everyoneID + `', '*', '', 'anonymous', '', UNIXEPOCH())
  108. ON CONFLICT (id) DO NOTHING;
  109. `
  110. createTablesQueries = `BEGIN; ` + createTablesQueriesNoTx + ` COMMIT;`
  111. builtinStartupQueries = `
  112. PRAGMA foreign_keys = ON;
  113. `
  114. selectUserByIDQuery = `
  115. SELECT u.id, u.user, u.pass, u.role, u.prefs, u.sync_topic, u.stats_messages, u.stats_emails, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, deleted, t.id, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.attachment_bandwidth_limit, t.stripe_price_id
  116. FROM user u
  117. LEFT JOIN tier t on t.id = u.tier_id
  118. WHERE u.id = ?
  119. `
  120. selectUserByNameQuery = `
  121. SELECT u.id, u.user, u.pass, u.role, u.prefs, u.sync_topic, u.stats_messages, u.stats_emails, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, deleted, t.id, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.attachment_bandwidth_limit, t.stripe_price_id
  122. FROM user u
  123. LEFT JOIN tier t on t.id = u.tier_id
  124. WHERE user = ?
  125. `
  126. selectUserByTokenQuery = `
  127. SELECT u.id, u.user, u.pass, u.role, u.prefs, u.sync_topic, u.stats_messages, u.stats_emails, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, deleted, t.id, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.attachment_bandwidth_limit, t.stripe_price_id
  128. FROM user u
  129. JOIN user_token tk on u.id = tk.user_id
  130. LEFT JOIN tier t on t.id = u.tier_id
  131. WHERE tk.token = ? AND (tk.expires = 0 OR tk.expires >= ?)
  132. `
  133. selectUserByStripeCustomerIDQuery = `
  134. SELECT u.id, u.user, u.pass, u.role, u.prefs, u.sync_topic, u.stats_messages, u.stats_emails, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, deleted, t.id, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.attachment_bandwidth_limit, t.stripe_price_id
  135. FROM user u
  136. LEFT JOIN tier t on t.id = u.tier_id
  137. WHERE u.stripe_customer_id = ?
  138. `
  139. selectTopicPermsQuery = `
  140. SELECT read, write
  141. FROM user_access a
  142. JOIN user u ON u.id = a.user_id
  143. WHERE (u.user = ? OR u.user = ?) AND ? LIKE a.topic
  144. ORDER BY u.user DESC
  145. `
  146. insertUserQuery = `
  147. INSERT INTO user (id, user, pass, role, sync_topic, created)
  148. VALUES (?, ?, ?, ?, ?, ?)
  149. `
  150. selectUsernamesQuery = `
  151. SELECT user
  152. FROM user
  153. ORDER BY
  154. CASE role
  155. WHEN 'admin' THEN 1
  156. WHEN 'anonymous' THEN 3
  157. ELSE 2
  158. END, user
  159. `
  160. updateUserPassQuery = `UPDATE user SET pass = ? WHERE user = ?`
  161. updateUserRoleQuery = `UPDATE user SET role = ? WHERE user = ?`
  162. updateUserPrefsQuery = `UPDATE user SET prefs = ? WHERE user = ?`
  163. updateUserStatsQuery = `UPDATE user SET stats_messages = ?, stats_emails = ? WHERE id = ?`
  164. updateUserStatsResetAllQuery = `UPDATE user SET stats_messages = 0, stats_emails = 0`
  165. updateUserDeletedQuery = `UPDATE user SET deleted = ? WHERE id = ?`
  166. deleteUsersMarkedQuery = `DELETE FROM user WHERE deleted < ?`
  167. deleteUserQuery = `DELETE FROM user WHERE user = ?`
  168. upsertUserAccessQuery = `
  169. INSERT INTO user_access (user_id, topic, read, write, owner_user_id)
  170. VALUES ((SELECT id FROM user WHERE user = ?), ?, ?, ?, (SELECT IIF(?='',NULL,(SELECT id FROM user WHERE user=?))))
  171. ON CONFLICT (user_id, topic)
  172. DO UPDATE SET read=excluded.read, write=excluded.write, owner_user_id=excluded.owner_user_id
  173. `
  174. selectUserAccessQuery = `
  175. SELECT topic, read, write
  176. FROM user_access
  177. WHERE user_id = (SELECT id FROM user WHERE user = ?)
  178. ORDER BY write DESC, read DESC, topic
  179. `
  180. selectUserReservationsQuery = `
  181. SELECT a_user.topic, a_user.read, a_user.write, a_everyone.read AS everyone_read, a_everyone.write AS everyone_write
  182. FROM user_access a_user
  183. LEFT JOIN user_access a_everyone ON a_user.topic = a_everyone.topic AND a_everyone.user_id = (SELECT id FROM user WHERE user = ?)
  184. WHERE a_user.user_id = a_user.owner_user_id
  185. AND a_user.owner_user_id = (SELECT id FROM user WHERE user = ?)
  186. ORDER BY a_user.topic
  187. `
  188. selectUserReservationsCountQuery = `
  189. SELECT COUNT(*)
  190. FROM user_access
  191. WHERE user_id = owner_user_id AND owner_user_id = (SELECT id FROM user WHERE user = ?)
  192. `
  193. selectUserHasReservationQuery = `
  194. SELECT COUNT(*)
  195. FROM user_access
  196. WHERE user_id = owner_user_id
  197. AND owner_user_id = (SELECT id FROM user WHERE user = ?)
  198. AND topic = ?
  199. `
  200. selectOtherAccessCountQuery = `
  201. SELECT COUNT(*)
  202. FROM user_access
  203. WHERE (topic = ? OR ? LIKE topic)
  204. AND (owner_user_id IS NULL OR owner_user_id != (SELECT id FROM user WHERE user = ?))
  205. `
  206. deleteAllAccessQuery = `DELETE FROM user_access`
  207. deleteUserAccessQuery = `
  208. DELETE FROM user_access
  209. WHERE user_id = (SELECT id FROM user WHERE user = ?)
  210. OR owner_user_id = (SELECT id FROM user WHERE user = ?)
  211. `
  212. deleteTopicAccessQuery = `
  213. DELETE FROM user_access
  214. WHERE (user_id = (SELECT id FROM user WHERE user = ?) OR owner_user_id = (SELECT id FROM user WHERE user = ?))
  215. AND topic = ?
  216. `
  217. selectTokenCountQuery = `SELECT COUNT(*) FROM user_token WHERE user_id = ?`
  218. selectTokensQuery = `SELECT token, label, last_access, last_origin, expires FROM user_token WHERE user_id = ?`
  219. selectTokenQuery = `SELECT token, label, last_access, last_origin, expires FROM user_token WHERE user_id = ? AND token = ?`
  220. insertTokenQuery = `INSERT INTO user_token (user_id, token, label, last_access, last_origin, expires) VALUES (?, ?, ?, ?, ?, ?)`
  221. updateTokenExpiryQuery = `UPDATE user_token SET expires = ? WHERE user_id = ? AND token = ?`
  222. updateTokenLabelQuery = `UPDATE user_token SET label = ? WHERE user_id = ? AND token = ?`
  223. updateTokenLastAccessQuery = `UPDATE user_token SET last_access = ?, last_origin = ? WHERE token = ?`
  224. deleteTokenQuery = `DELETE FROM user_token WHERE user_id = ? AND token = ?`
  225. deleteAllTokenQuery = `DELETE FROM user_token WHERE user_id = ?`
  226. deleteExpiredTokensQuery = `DELETE FROM user_token WHERE expires > 0 AND expires < ?`
  227. deleteExcessTokensQuery = `
  228. DELETE FROM user_token
  229. WHERE (user_id, token) NOT IN (
  230. SELECT user_id, token
  231. FROM user_token
  232. WHERE user_id = ?
  233. ORDER BY expires DESC
  234. LIMIT ?
  235. )
  236. `
  237. insertTierQuery = `
  238. INSERT INTO tier (id, code, name, messages_limit, messages_expiry_duration, emails_limit, reservations_limit, attachment_file_size_limit, attachment_total_size_limit, attachment_expiry_duration, attachment_bandwidth_limit, stripe_price_id)
  239. VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
  240. `
  241. updateTierQuery = `
  242. UPDATE tier
  243. SET name = ?, messages_limit = ?, messages_expiry_duration = ?, emails_limit = ?, reservations_limit = ?, attachment_file_size_limit = ?, attachment_total_size_limit = ?, attachment_expiry_duration = ?, attachment_bandwidth_limit = ?, stripe_price_id = ?
  244. WHERE code = ?
  245. `
  246. selectTiersQuery = `
  247. SELECT id, code, name, messages_limit, messages_expiry_duration, emails_limit, reservations_limit, attachment_file_size_limit, attachment_total_size_limit, attachment_expiry_duration, attachment_bandwidth_limit, stripe_price_id
  248. FROM tier
  249. `
  250. selectTierByCodeQuery = `
  251. SELECT id, code, name, messages_limit, messages_expiry_duration, emails_limit, reservations_limit, attachment_file_size_limit, attachment_total_size_limit, attachment_expiry_duration, attachment_bandwidth_limit, stripe_price_id
  252. FROM tier
  253. WHERE code = ?
  254. `
  255. selectTierByPriceIDQuery = `
  256. SELECT id, code, name, messages_limit, messages_expiry_duration, emails_limit, reservations_limit, attachment_file_size_limit, attachment_total_size_limit, attachment_expiry_duration, attachment_bandwidth_limit, stripe_price_id
  257. FROM tier
  258. WHERE stripe_price_id = ?
  259. `
  260. updateUserTierQuery = `UPDATE user SET tier_id = (SELECT id FROM tier WHERE code = ?) WHERE user = ?`
  261. deleteUserTierQuery = `UPDATE user SET tier_id = null WHERE user = ?`
  262. deleteTierQuery = `DELETE FROM tier WHERE code = ?`
  263. updateBillingQuery = `
  264. UPDATE user
  265. SET stripe_customer_id = ?, stripe_subscription_id = ?, stripe_subscription_status = ?, stripe_subscription_paid_until = ?, stripe_subscription_cancel_at = ?
  266. WHERE user = ?
  267. `
  268. )
// Schema management queries
const (
	currentSchemaVersion     = 2
	insertSchemaVersion      = `INSERT INTO schemaVersion VALUES (1, ?)`
	updateSchemaVersion      = `UPDATE schemaVersion SET version = ? WHERE id = 1`
	selectSchemaVersionQuery = `SELECT version FROM schemaVersion WHERE id = 1`

	// 1 -> 2 (complex migration!)
	//
	// The statements below move the old "user" table aside (user_old), re-insert every user with
	// a new id and sync topic (both bound as parameters), copy the rows of the old "access" table
	// into user_access, and finally drop both old tables. The "NoTx" suffix presumably means that
	// the caller is expected to provide the surrounding transaction -- confirm against the migration code.
	migrate1To2RenameUserTableQueryNoTx = `
ALTER TABLE user RENAME TO user_old;
`
	migrate1To2SelectAllOldUsernamesNoTx = `SELECT user FROM user_old`
	migrate1To2InsertUserNoTx            = `
INSERT INTO user (id, user, pass, role, sync_topic, created)
SELECT ?, user, pass, role, ?, UNIXEPOCH() FROM user_old WHERE user = ?
`
	migrate1To2InsertFromOldTablesAndDropNoTx = `
INSERT INTO user_access (user_id, topic, read, write)
SELECT u.id, a.topic, a.read, a.write
FROM user u
JOIN access a ON u.user = a.user;
DROP TABLE access;
DROP TABLE user_old;
`
)
// Manager is an implementation of Auther. It stores users and access control list
// in a SQLite database.
//
// Writes of user stats and token access information are not performed immediately; they are
// collected in statsQueue/tokenQueue and written to the database in batches by a background
// goroutine (see asyncQueueWriter). Both queues are guarded by mu.
type Manager struct {
	db            *sql.DB
	defaultAccess Permission              // Default permission if no ACL matches
	statsQueue    map[string]*Stats       // "Queue" to asynchronously write user stats to the database (UserID -> Stats)
	tokenQueue    map[string]*TokenUpdate // "Queue" to asynchronously write token access stats to the database (Token ID -> TokenUpdate)
	bcryptCost    int                     // Makes testing easier
	mu            sync.Mutex              // Protects statsQueue and tokenQueue
}

// Compile-time check that *Manager satisfies the Auther interface
var _ Auther = (*Manager)(nil)
  304. // NewManager creates a new Manager instance
  305. func NewManager(filename, startupQueries string, defaultAccess Permission, bcryptCost int, queueWriterInterval time.Duration) (*Manager, error) {
  306. db, err := sql.Open("sqlite3", filename)
  307. if err != nil {
  308. return nil, err
  309. }
  310. if err := setupDB(db); err != nil {
  311. return nil, err
  312. }
  313. if err := runStartupQueries(db, startupQueries); err != nil {
  314. return nil, err
  315. }
  316. manager := &Manager{
  317. db: db,
  318. defaultAccess: defaultAccess,
  319. statsQueue: make(map[string]*Stats),
  320. tokenQueue: make(map[string]*TokenUpdate),
  321. bcryptCost: bcryptCost,
  322. }
  323. go manager.asyncQueueWriter(queueWriterInterval)
  324. return manager, nil
  325. }
  326. // Authenticate checks username and password and returns a User if correct, and the user has not been
  327. // marked as deleted. The method returns in constant-ish time, regardless of whether the user exists or
  328. // the password is correct or incorrect.
  329. func (a *Manager) Authenticate(username, password string) (*User, error) {
  330. if username == Everyone {
  331. return nil, ErrUnauthenticated
  332. }
  333. user, err := a.User(username)
  334. if err != nil {
  335. log.Tag(tagManager).Field("user_name", username).Err(err).Trace("Authentication of user failed (1)")
  336. bcrypt.CompareHashAndPassword([]byte(userAuthIntentionalSlowDownHash), []byte("intentional slow-down to avoid timing attacks"))
  337. return nil, ErrUnauthenticated
  338. } else if user.Deleted {
  339. log.Tag(tagManager).Field("user_name", username).Trace("Authentication of user failed (2): user marked deleted")
  340. bcrypt.CompareHashAndPassword([]byte(userAuthIntentionalSlowDownHash), []byte("intentional slow-down to avoid timing attacks"))
  341. return nil, ErrUnauthenticated
  342. } else if err := bcrypt.CompareHashAndPassword([]byte(user.Hash), []byte(password)); err != nil {
  343. log.Tag(tagManager).Field("user_name", username).Err(err).Trace("Authentication of user failed (3)")
  344. return nil, ErrUnauthenticated
  345. }
  346. return user, nil
  347. }
  348. // AuthenticateToken checks if the token exists and returns the associated User if it does.
  349. // The method sets the User.Token value to the token that was used for authentication.
  350. func (a *Manager) AuthenticateToken(token string) (*User, error) {
  351. if len(token) != tokenLength {
  352. return nil, ErrUnauthenticated
  353. }
  354. user, err := a.userByToken(token)
  355. if err != nil {
  356. log.Tag(tagManager).Field("token", token).Err(err).Trace("Authentication of token failed")
  357. return nil, ErrUnauthenticated
  358. }
  359. user.Token = token
  360. return user, nil
  361. }
  362. // CreateToken generates a random token for the given user and returns it. The token expires
  363. // after a fixed duration unless ChangeToken is called. This function also prunes tokens for the
  364. // given user, if there are too many of them.
  365. func (a *Manager) CreateToken(userID, label string, expires time.Time, origin netip.Addr) (*Token, error) {
  366. token := util.RandomStringPrefix(tokenPrefix, tokenLength)
  367. tx, err := a.db.Begin()
  368. if err != nil {
  369. return nil, err
  370. }
  371. defer tx.Rollback()
  372. access := time.Now()
  373. if _, err := tx.Exec(insertTokenQuery, userID, token, label, access.Unix(), origin.String(), expires.Unix()); err != nil {
  374. return nil, err
  375. }
  376. rows, err := tx.Query(selectTokenCountQuery, userID)
  377. if err != nil {
  378. return nil, err
  379. }
  380. defer rows.Close()
  381. if !rows.Next() {
  382. return nil, errNoRows
  383. }
  384. var tokenCount int
  385. if err := rows.Scan(&tokenCount); err != nil {
  386. return nil, err
  387. }
  388. if tokenCount >= tokenMaxCount {
  389. // This pruning logic is done in two queries for efficiency. The SELECT above is a lookup
  390. // on two indices, whereas the query below is a full table scan.
  391. if _, err := tx.Exec(deleteExcessTokensQuery, userID, tokenMaxCount); err != nil {
  392. return nil, err
  393. }
  394. }
  395. if err := tx.Commit(); err != nil {
  396. return nil, err
  397. }
  398. return &Token{
  399. Value: token,
  400. Label: label,
  401. LastAccess: access,
  402. LastOrigin: origin,
  403. Expires: expires,
  404. }, nil
  405. }
  406. // Tokens returns all existing tokens for the user with the given user ID
  407. func (a *Manager) Tokens(userID string) ([]*Token, error) {
  408. rows, err := a.db.Query(selectTokensQuery, userID)
  409. if err != nil {
  410. return nil, err
  411. }
  412. defer rows.Close()
  413. tokens := make([]*Token, 0)
  414. for {
  415. token, err := a.readToken(rows)
  416. if err == ErrTokenNotFound {
  417. break
  418. } else if err != nil {
  419. return nil, err
  420. }
  421. tokens = append(tokens, token)
  422. }
  423. return tokens, nil
  424. }
  425. // Token returns a specific token for a user
  426. func (a *Manager) Token(userID, token string) (*Token, error) {
  427. rows, err := a.db.Query(selectTokenQuery, userID, token)
  428. if err != nil {
  429. return nil, err
  430. }
  431. defer rows.Close()
  432. return a.readToken(rows)
  433. }
  434. func (a *Manager) readToken(rows *sql.Rows) (*Token, error) {
  435. var token, label, lastOrigin string
  436. var lastAccess, expires int64
  437. if !rows.Next() {
  438. return nil, ErrTokenNotFound
  439. }
  440. if err := rows.Scan(&token, &label, &lastAccess, &lastOrigin, &expires); err != nil {
  441. return nil, err
  442. } else if err := rows.Err(); err != nil {
  443. return nil, err
  444. }
  445. lastOriginIP, err := netip.ParseAddr(lastOrigin)
  446. if err != nil {
  447. lastOriginIP = netip.IPv4Unspecified()
  448. }
  449. return &Token{
  450. Value: token,
  451. Label: label,
  452. LastAccess: time.Unix(lastAccess, 0),
  453. LastOrigin: lastOriginIP,
  454. Expires: time.Unix(expires, 0),
  455. }, nil
  456. }
  457. // ChangeToken updates a token's label and/or expiry date
  458. func (a *Manager) ChangeToken(userID, token string, label *string, expires *time.Time) (*Token, error) {
  459. if token == "" {
  460. return nil, errNoTokenProvided
  461. }
  462. tx, err := a.db.Begin()
  463. if err != nil {
  464. return nil, err
  465. }
  466. defer tx.Rollback()
  467. if label != nil {
  468. if _, err := tx.Exec(updateTokenLabelQuery, *label, userID, token); err != nil {
  469. return nil, err
  470. }
  471. }
  472. if expires != nil {
  473. if _, err := tx.Exec(updateTokenExpiryQuery, expires.Unix(), userID, token); err != nil {
  474. return nil, err
  475. }
  476. }
  477. if err := tx.Commit(); err != nil {
  478. return nil, err
  479. }
  480. return a.Token(userID, token)
  481. }
  482. // RemoveToken deletes the token defined in User.Token
  483. func (a *Manager) RemoveToken(userID, token string) error {
  484. if token == "" {
  485. return errNoTokenProvided
  486. }
  487. if _, err := a.db.Exec(deleteTokenQuery, userID, token); err != nil {
  488. return err
  489. }
  490. return nil
  491. }
  492. // RemoveExpiredTokens deletes all expired tokens from the database
  493. func (a *Manager) RemoveExpiredTokens() error {
  494. if _, err := a.db.Exec(deleteExpiredTokensQuery, time.Now().Unix()); err != nil {
  495. return err
  496. }
  497. return nil
  498. }
  499. // RemoveDeletedUsers deletes all users that have been marked deleted for
  500. func (a *Manager) RemoveDeletedUsers() error {
  501. if _, err := a.db.Exec(deleteUsersMarkedQuery, time.Now().Unix()); err != nil {
  502. return err
  503. }
  504. return nil
  505. }
  506. // ChangeSettings persists the user settings
  507. func (a *Manager) ChangeSettings(user *User) error {
  508. prefs, err := json.Marshal(user.Prefs)
  509. if err != nil {
  510. return err
  511. }
  512. if _, err := a.db.Exec(updateUserPrefsQuery, string(prefs), user.Name); err != nil {
  513. return err
  514. }
  515. return nil
  516. }
  517. // ResetStats resets all user stats in the user database. This touches all users.
  518. func (a *Manager) ResetStats() error {
  519. a.mu.Lock() // Includes database query to avoid races!
  520. defer a.mu.Unlock()
  521. if _, err := a.db.Exec(updateUserStatsResetAllQuery); err != nil {
  522. return err
  523. }
  524. a.statsQueue = make(map[string]*Stats)
  525. return nil
  526. }
  527. // EnqueueStats adds the user to a queue which writes out user stats (messages, emails, ..) in
  528. // batches at a regular interval
  529. func (a *Manager) EnqueueStats(userID string, stats *Stats) {
  530. a.mu.Lock()
  531. defer a.mu.Unlock()
  532. a.statsQueue[userID] = stats
  533. }
  534. // EnqueueTokenUpdate adds the token update to a queue which writes out token access times
  535. // in batches at a regular interval
  536. func (a *Manager) EnqueueTokenUpdate(tokenID string, update *TokenUpdate) {
  537. a.mu.Lock()
  538. defer a.mu.Unlock()
  539. a.tokenQueue[tokenID] = update
  540. }
  541. func (a *Manager) asyncQueueWriter(interval time.Duration) {
  542. ticker := time.NewTicker(interval)
  543. for range ticker.C {
  544. if err := a.writeUserStatsQueue(); err != nil {
  545. log.Tag(tagManager).Err(err).Warn("Writing user stats queue failed")
  546. }
  547. if err := a.writeTokenUpdateQueue(); err != nil {
  548. log.Tag(tagManager).Err(err).Warn("Writing token update queue failed")
  549. }
  550. }
  551. }
  552. func (a *Manager) writeUserStatsQueue() error {
  553. a.mu.Lock()
  554. if len(a.statsQueue) == 0 {
  555. a.mu.Unlock()
  556. log.Tag(tagManager).Trace("No user stats updates to commit")
  557. return nil
  558. }
  559. statsQueue := a.statsQueue
  560. a.statsQueue = make(map[string]*Stats)
  561. a.mu.Unlock()
  562. tx, err := a.db.Begin()
  563. if err != nil {
  564. return err
  565. }
  566. defer tx.Rollback()
  567. log.Tag(tagManager).Debug("Writing user stats queue for %d user(s)", len(statsQueue))
  568. for userID, update := range statsQueue {
  569. log.
  570. Tag(tagManager).
  571. Fields(log.Context{
  572. "user_id": userID,
  573. "messages_count": update.Messages,
  574. "emails_count": update.Emails,
  575. }).
  576. Trace("Updating stats for user %s", userID)
  577. if _, err := tx.Exec(updateUserStatsQuery, update.Messages, update.Emails, userID); err != nil {
  578. return err
  579. }
  580. }
  581. return tx.Commit()
  582. }
  583. func (a *Manager) writeTokenUpdateQueue() error {
  584. a.mu.Lock()
  585. if len(a.tokenQueue) == 0 {
  586. a.mu.Unlock()
  587. log.Tag(tagManager).Trace("No token updates to commit")
  588. return nil
  589. }
  590. tokenQueue := a.tokenQueue
  591. a.tokenQueue = make(map[string]*TokenUpdate)
  592. a.mu.Unlock()
  593. tx, err := a.db.Begin()
  594. if err != nil {
  595. return err
  596. }
  597. defer tx.Rollback()
  598. log.Tag(tagManager).Debug("Writing token update queue for %d token(s)", len(tokenQueue))
  599. for tokenID, update := range tokenQueue {
  600. log.Tag(tagManager).Trace("Updating token %s with last access time %v", tokenID, update.LastAccess.Unix())
  601. if _, err := tx.Exec(updateTokenLastAccessQuery, update.LastAccess.Unix(), update.LastOrigin.String(), tokenID); err != nil {
  602. return err
  603. }
  604. }
  605. return tx.Commit()
  606. }
  607. // Authorize returns nil if the given user has access to the given topic using the desired
  608. // permission. The user param may be nil to signal an anonymous user.
  609. func (a *Manager) Authorize(user *User, topic string, perm Permission) error {
  610. if user != nil && user.Role == RoleAdmin {
  611. return nil // Admin can do everything
  612. }
  613. username := Everyone
  614. if user != nil {
  615. username = user.Name
  616. }
  617. // Select the read/write permissions for this user/topic combo. The query may return two
  618. // rows (one for everyone, and one for the user), but prioritizes the user.
  619. rows, err := a.db.Query(selectTopicPermsQuery, Everyone, username, topic)
  620. if err != nil {
  621. return err
  622. }
  623. defer rows.Close()
  624. if !rows.Next() {
  625. return a.resolvePerms(a.defaultAccess, perm)
  626. }
  627. var read, write bool
  628. if err := rows.Scan(&read, &write); err != nil {
  629. return err
  630. } else if err := rows.Err(); err != nil {
  631. return err
  632. }
  633. return a.resolvePerms(NewPermission(read, write), perm)
  634. }
  635. func (a *Manager) resolvePerms(base, perm Permission) error {
  636. if perm == PermissionRead && base.IsRead() {
  637. return nil
  638. } else if perm == PermissionWrite && base.IsWrite() {
  639. return nil
  640. }
  641. return ErrUnauthorized
  642. }
  643. // AddUser adds a user with the given username, password and role
  644. func (a *Manager) AddUser(username, password string, role Role) error {
  645. if !AllowedUsername(username) || !AllowedRole(role) {
  646. return ErrInvalidArgument
  647. }
  648. hash, err := bcrypt.GenerateFromPassword([]byte(password), a.bcryptCost)
  649. if err != nil {
  650. return err
  651. }
  652. userID := util.RandomStringPrefix(userIDPrefix, userIDLength)
  653. syncTopic, now := util.RandomStringPrefix(syncTopicPrefix, syncTopicLength), time.Now().Unix()
  654. if _, err = a.db.Exec(insertUserQuery, userID, username, hash, role, syncTopic, now); err != nil {
  655. return err
  656. }
  657. return nil
  658. }
  659. // RemoveUser deletes the user with the given username. The function returns nil on success, even
  660. // if the user did not exist in the first place.
  661. func (a *Manager) RemoveUser(username string) error {
  662. if !AllowedUsername(username) {
  663. return ErrInvalidArgument
  664. }
  665. // Rows in user_access, user_token, etc. are deleted via foreign keys
  666. if _, err := a.db.Exec(deleteUserQuery, username); err != nil {
  667. return err
  668. }
  669. return nil
  670. }
  671. // MarkUserRemoved sets the deleted flag on the user, and deletes all access tokens. This prevents
  672. // successful auth via Authenticate. A background process will delete the user at a later date.
  673. func (a *Manager) MarkUserRemoved(user *User) error {
  674. if !AllowedUsername(user.Name) {
  675. return ErrInvalidArgument
  676. }
  677. tx, err := a.db.Begin()
  678. if err != nil {
  679. return err
  680. }
  681. defer tx.Rollback()
  682. if _, err := a.db.Exec(deleteUserAccessQuery, user.Name, user.Name); err != nil {
  683. return err
  684. }
  685. if _, err := tx.Exec(deleteAllTokenQuery, user.ID); err != nil {
  686. return err
  687. }
  688. if _, err := tx.Exec(updateUserDeletedQuery, time.Now().Add(userHardDeleteAfterDuration).Unix(), user.ID); err != nil {
  689. return err
  690. }
  691. return tx.Commit()
  692. }
  693. // Users returns a list of users. It always also returns the Everyone user ("*").
  694. func (a *Manager) Users() ([]*User, error) {
  695. rows, err := a.db.Query(selectUsernamesQuery)
  696. if err != nil {
  697. return nil, err
  698. }
  699. defer rows.Close()
  700. usernames := make([]string, 0)
  701. for rows.Next() {
  702. var username string
  703. if err := rows.Scan(&username); err != nil {
  704. return nil, err
  705. } else if err := rows.Err(); err != nil {
  706. return nil, err
  707. }
  708. usernames = append(usernames, username)
  709. }
  710. rows.Close()
  711. users := make([]*User, 0)
  712. for _, username := range usernames {
  713. user, err := a.User(username)
  714. if err != nil {
  715. return nil, err
  716. }
  717. users = append(users, user)
  718. }
  719. return users, nil
  720. }
  721. // User returns the user with the given username if it exists, or ErrUserNotFound otherwise.
  722. // You may also pass Everyone to retrieve the anonymous user and its Grant list.
  723. func (a *Manager) User(username string) (*User, error) {
  724. rows, err := a.db.Query(selectUserByNameQuery, username)
  725. if err != nil {
  726. return nil, err
  727. }
  728. return a.readUser(rows)
  729. }
  730. // UserByID returns the user with the given ID if it exists, or ErrUserNotFound otherwise
  731. func (a *Manager) UserByID(id string) (*User, error) {
  732. rows, err := a.db.Query(selectUserByIDQuery, id)
  733. if err != nil {
  734. return nil, err
  735. }
  736. return a.readUser(rows)
  737. }
  738. // UserByStripeCustomer returns the user with the given Stripe customer ID if it exists, or ErrUserNotFound otherwise.
  739. func (a *Manager) UserByStripeCustomer(stripeCustomerID string) (*User, error) {
  740. rows, err := a.db.Query(selectUserByStripeCustomerIDQuery, stripeCustomerID)
  741. if err != nil {
  742. return nil, err
  743. }
  744. return a.readUser(rows)
  745. }
  746. func (a *Manager) userByToken(token string) (*User, error) {
  747. rows, err := a.db.Query(selectUserByTokenQuery, token, time.Now().Unix())
  748. if err != nil {
  749. return nil, err
  750. }
  751. return a.readUser(rows)
  752. }
  753. func (a *Manager) readUser(rows *sql.Rows) (*User, error) {
  754. defer rows.Close()
  755. var id, username, hash, role, prefs, syncTopic string
  756. var stripeCustomerID, stripeSubscriptionID, stripeSubscriptionStatus, stripePriceID, tierID, tierCode, tierName sql.NullString
  757. var messages, emails int64
  758. var messagesLimit, messagesExpiryDuration, emailsLimit, reservationsLimit, attachmentFileSizeLimit, attachmentTotalSizeLimit, attachmentExpiryDuration, attachmentBandwidthLimit, stripeSubscriptionPaidUntil, stripeSubscriptionCancelAt, deleted sql.NullInt64
  759. if !rows.Next() {
  760. return nil, ErrUserNotFound
  761. }
  762. if err := rows.Scan(&id, &username, &hash, &role, &prefs, &syncTopic, &messages, &emails, &stripeCustomerID, &stripeSubscriptionID, &stripeSubscriptionStatus, &stripeSubscriptionPaidUntil, &stripeSubscriptionCancelAt, &deleted, &tierID, &tierCode, &tierName, &messagesLimit, &messagesExpiryDuration, &emailsLimit, &reservationsLimit, &attachmentFileSizeLimit, &attachmentTotalSizeLimit, &attachmentExpiryDuration, &attachmentBandwidthLimit, &stripePriceID); err != nil {
  763. return nil, err
  764. } else if err := rows.Err(); err != nil {
  765. return nil, err
  766. }
  767. user := &User{
  768. ID: id,
  769. Name: username,
  770. Hash: hash,
  771. Role: Role(role),
  772. Prefs: &Prefs{},
  773. SyncTopic: syncTopic,
  774. Stats: &Stats{
  775. Messages: messages,
  776. Emails: emails,
  777. },
  778. Billing: &Billing{
  779. StripeCustomerID: stripeCustomerID.String, // May be empty
  780. StripeSubscriptionID: stripeSubscriptionID.String, // May be empty
  781. StripeSubscriptionStatus: stripe.SubscriptionStatus(stripeSubscriptionStatus.String), // May be empty
  782. StripeSubscriptionPaidUntil: time.Unix(stripeSubscriptionPaidUntil.Int64, 0), // May be zero
  783. StripeSubscriptionCancelAt: time.Unix(stripeSubscriptionCancelAt.Int64, 0), // May be zero
  784. },
  785. Deleted: deleted.Valid,
  786. }
  787. if err := json.Unmarshal([]byte(prefs), user.Prefs); err != nil {
  788. return nil, err
  789. }
  790. if tierCode.Valid {
  791. // See readTier() when this is changed!
  792. user.Tier = &Tier{
  793. ID: tierID.String,
  794. Code: tierCode.String,
  795. Name: tierName.String,
  796. MessageLimit: messagesLimit.Int64,
  797. MessageExpiryDuration: time.Duration(messagesExpiryDuration.Int64) * time.Second,
  798. EmailLimit: emailsLimit.Int64,
  799. ReservationLimit: reservationsLimit.Int64,
  800. AttachmentFileSizeLimit: attachmentFileSizeLimit.Int64,
  801. AttachmentTotalSizeLimit: attachmentTotalSizeLimit.Int64,
  802. AttachmentExpiryDuration: time.Duration(attachmentExpiryDuration.Int64) * time.Second,
  803. AttachmentBandwidthLimit: attachmentBandwidthLimit.Int64,
  804. StripePriceID: stripePriceID.String, // May be empty
  805. }
  806. }
  807. return user, nil
  808. }
  809. // Grants returns all user-specific access control entries
  810. func (a *Manager) Grants(username string) ([]Grant, error) {
  811. rows, err := a.db.Query(selectUserAccessQuery, username)
  812. if err != nil {
  813. return nil, err
  814. }
  815. defer rows.Close()
  816. grants := make([]Grant, 0)
  817. for rows.Next() {
  818. var topic string
  819. var read, write bool
  820. if err := rows.Scan(&topic, &read, &write); err != nil {
  821. return nil, err
  822. } else if err := rows.Err(); err != nil {
  823. return nil, err
  824. }
  825. grants = append(grants, Grant{
  826. TopicPattern: fromSQLWildcard(topic),
  827. Allow: NewPermission(read, write),
  828. })
  829. }
  830. return grants, nil
  831. }
  832. // Reservations returns all user-owned topics, and the associated everyone-access
  833. func (a *Manager) Reservations(username string) ([]Reservation, error) {
  834. rows, err := a.db.Query(selectUserReservationsQuery, Everyone, username)
  835. if err != nil {
  836. return nil, err
  837. }
  838. defer rows.Close()
  839. reservations := make([]Reservation, 0)
  840. for rows.Next() {
  841. var topic string
  842. var ownerRead, ownerWrite bool
  843. var everyoneRead, everyoneWrite sql.NullBool
  844. if err := rows.Scan(&topic, &ownerRead, &ownerWrite, &everyoneRead, &everyoneWrite); err != nil {
  845. return nil, err
  846. } else if err := rows.Err(); err != nil {
  847. return nil, err
  848. }
  849. reservations = append(reservations, Reservation{
  850. Topic: topic,
  851. Owner: NewPermission(ownerRead, ownerWrite),
  852. Everyone: NewPermission(everyoneRead.Bool, everyoneWrite.Bool), // false if null
  853. })
  854. }
  855. return reservations, nil
  856. }
  857. // HasReservation returns true if the given topic access is owned by the user
  858. func (a *Manager) HasReservation(username, topic string) (bool, error) {
  859. rows, err := a.db.Query(selectUserHasReservationQuery, username, topic)
  860. if err != nil {
  861. return false, err
  862. }
  863. defer rows.Close()
  864. if !rows.Next() {
  865. return false, errNoRows
  866. }
  867. var count int64
  868. if err := rows.Scan(&count); err != nil {
  869. return false, err
  870. }
  871. return count > 0, nil
  872. }
  873. // ReservationsCount returns the number of reservations owned by this user
  874. func (a *Manager) ReservationsCount(username string) (int64, error) {
  875. rows, err := a.db.Query(selectUserReservationsCountQuery, username)
  876. if err != nil {
  877. return 0, err
  878. }
  879. defer rows.Close()
  880. if !rows.Next() {
  881. return 0, errNoRows
  882. }
  883. var count int64
  884. if err := rows.Scan(&count); err != nil {
  885. return 0, err
  886. }
  887. return count, nil
  888. }
  889. // ChangePassword changes a user's password
  890. func (a *Manager) ChangePassword(username, password string) error {
  891. hash, err := bcrypt.GenerateFromPassword([]byte(password), a.bcryptCost)
  892. if err != nil {
  893. return err
  894. }
  895. if _, err := a.db.Exec(updateUserPassQuery, hash, username); err != nil {
  896. return err
  897. }
  898. return nil
  899. }
  900. // ChangeRole changes a user's role. When a role is changed from RoleUser to RoleAdmin,
  901. // all existing access control entries (Grant) are removed, since they are no longer needed.
  902. func (a *Manager) ChangeRole(username string, role Role) error {
  903. if !AllowedUsername(username) || !AllowedRole(role) {
  904. return ErrInvalidArgument
  905. }
  906. if _, err := a.db.Exec(updateUserRoleQuery, string(role), username); err != nil {
  907. return err
  908. }
  909. if role == RoleAdmin {
  910. if _, err := a.db.Exec(deleteUserAccessQuery, username, username); err != nil {
  911. return err
  912. }
  913. }
  914. return nil
  915. }
  916. // ChangeTier changes a user's tier using the tier code. This function does not delete reservations, messages,
  917. // or attachments, even if the new tier has lower limits in this regard. That has to be done elsewhere.
  918. func (a *Manager) ChangeTier(username, tier string) error {
  919. if !AllowedUsername(username) {
  920. return ErrInvalidArgument
  921. }
  922. t, err := a.Tier(tier)
  923. if err != nil {
  924. return err
  925. } else if err := a.checkReservationsLimit(username, t.ReservationLimit); err != nil {
  926. return err
  927. }
  928. if _, err := a.db.Exec(updateUserTierQuery, tier, username); err != nil {
  929. return err
  930. }
  931. return nil
  932. }
  933. // ResetTier removes the tier from the given user
  934. func (a *Manager) ResetTier(username string) error {
  935. if !AllowedUsername(username) && username != Everyone && username != "" {
  936. return ErrInvalidArgument
  937. } else if err := a.checkReservationsLimit(username, 0); err != nil {
  938. return err
  939. }
  940. _, err := a.db.Exec(deleteUserTierQuery, username)
  941. return err
  942. }
  943. func (a *Manager) checkReservationsLimit(username string, reservationsLimit int64) error {
  944. u, err := a.User(username)
  945. if err != nil {
  946. return err
  947. }
  948. if u.Tier != nil && reservationsLimit < u.Tier.ReservationLimit {
  949. reservations, err := a.Reservations(username)
  950. if err != nil {
  951. return err
  952. } else if int64(len(reservations)) > reservationsLimit {
  953. return ErrTooManyReservations
  954. }
  955. }
  956. return nil
  957. }
  958. // CheckAllowAccess tests if a user may create an access control entry for the given topic.
  959. // If there are any ACL entries that are not owned by the user, an error is returned.
  960. // FIXME is this the same as HasReservation?
  961. func (a *Manager) CheckAllowAccess(username string, topic string) error {
  962. if (!AllowedUsername(username) && username != Everyone) || !AllowedTopic(topic) {
  963. return ErrInvalidArgument
  964. }
  965. rows, err := a.db.Query(selectOtherAccessCountQuery, topic, topic, username)
  966. if err != nil {
  967. return err
  968. }
  969. defer rows.Close()
  970. if !rows.Next() {
  971. return errNoRows
  972. }
  973. var otherCount int
  974. if err := rows.Scan(&otherCount); err != nil {
  975. return err
  976. }
  977. if otherCount > 0 {
  978. return errTopicOwnedByOthers
  979. }
  980. return nil
  981. }
  982. // AllowAccess adds or updates an entry in th access control list for a specific user. It controls
  983. // read/write access to a topic. The parameter topicPattern may include wildcards (*). The ACL entry
  984. // owner may either be a user (username), or the system (empty).
  985. func (a *Manager) AllowAccess(username string, topicPattern string, permission Permission) error {
  986. if !AllowedUsername(username) && username != Everyone {
  987. return ErrInvalidArgument
  988. } else if !AllowedTopicPattern(topicPattern) {
  989. return ErrInvalidArgument
  990. }
  991. owner := ""
  992. if _, err := a.db.Exec(upsertUserAccessQuery, username, toSQLWildcard(topicPattern), permission.IsRead(), permission.IsWrite(), owner, owner); err != nil {
  993. return err
  994. }
  995. return nil
  996. }
  997. // ResetAccess removes an access control list entry for a specific username/topic, or (if topic is
  998. // empty) for an entire user. The parameter topicPattern may include wildcards (*).
  999. func (a *Manager) ResetAccess(username string, topicPattern string) error {
  1000. if !AllowedUsername(username) && username != Everyone && username != "" {
  1001. return ErrInvalidArgument
  1002. } else if !AllowedTopicPattern(topicPattern) && topicPattern != "" {
  1003. return ErrInvalidArgument
  1004. }
  1005. if username == "" && topicPattern == "" {
  1006. _, err := a.db.Exec(deleteAllAccessQuery, username)
  1007. return err
  1008. } else if topicPattern == "" {
  1009. _, err := a.db.Exec(deleteUserAccessQuery, username, username)
  1010. return err
  1011. }
  1012. _, err := a.db.Exec(deleteTopicAccessQuery, username, username, toSQLWildcard(topicPattern))
  1013. return err
  1014. }
  1015. // AddReservation creates two access control entries for the given topic: one with full read/write access for the
  1016. // given user, and one for Everyone with the permission passed as everyone. The user also owns the entries, and
  1017. // can modify or delete them.
  1018. func (a *Manager) AddReservation(username string, topic string, everyone Permission) error {
  1019. if !AllowedUsername(username) || username == Everyone || !AllowedTopic(topic) {
  1020. return ErrInvalidArgument
  1021. }
  1022. tx, err := a.db.Begin()
  1023. if err != nil {
  1024. return err
  1025. }
  1026. defer tx.Rollback()
  1027. if _, err := tx.Exec(upsertUserAccessQuery, username, topic, true, true, username, username); err != nil {
  1028. return err
  1029. }
  1030. if _, err := tx.Exec(upsertUserAccessQuery, Everyone, topic, everyone.IsRead(), everyone.IsWrite(), username, username); err != nil {
  1031. return err
  1032. }
  1033. return tx.Commit()
  1034. }
  1035. // RemoveReservations deletes the access control entries associated with the given username/topic, as
  1036. // well as all entries with Everyone/topic. This is the counterpart for AddReservation.
  1037. func (a *Manager) RemoveReservations(username string, topics ...string) error {
  1038. if !AllowedUsername(username) || username == Everyone || len(topics) == 0 {
  1039. return ErrInvalidArgument
  1040. }
  1041. for _, topic := range topics {
  1042. if !AllowedTopic(topic) {
  1043. return ErrInvalidArgument
  1044. }
  1045. }
  1046. tx, err := a.db.Begin()
  1047. if err != nil {
  1048. return err
  1049. }
  1050. defer tx.Rollback()
  1051. for _, topic := range topics {
  1052. if _, err := tx.Exec(deleteTopicAccessQuery, username, username, topic); err != nil {
  1053. return err
  1054. }
  1055. if _, err := tx.Exec(deleteTopicAccessQuery, Everyone, Everyone, topic); err != nil {
  1056. return err
  1057. }
  1058. }
  1059. return tx.Commit()
  1060. }
  1061. // DefaultAccess returns the default read/write access if no access control entry matches
  1062. func (a *Manager) DefaultAccess() Permission {
  1063. return a.defaultAccess
  1064. }
  1065. // AddTier creates a new tier in the database
  1066. func (a *Manager) AddTier(tier *Tier) error {
  1067. if tier.ID == "" {
  1068. tier.ID = util.RandomStringPrefix(tierIDPrefix, tierIDLength)
  1069. }
  1070. if _, err := a.db.Exec(insertTierQuery, tier.ID, tier.Code, tier.Name, tier.MessageLimit, int64(tier.MessageExpiryDuration.Seconds()), tier.EmailLimit, tier.ReservationLimit, tier.AttachmentFileSizeLimit, tier.AttachmentTotalSizeLimit, int64(tier.AttachmentExpiryDuration.Seconds()), tier.AttachmentBandwidthLimit, nullString(tier.StripePriceID)); err != nil {
  1071. return err
  1072. }
  1073. return nil
  1074. }
  1075. // UpdateTier updates a tier's properties in the database
  1076. func (a *Manager) UpdateTier(tier *Tier) error {
  1077. if _, err := a.db.Exec(updateTierQuery, tier.Name, tier.MessageLimit, int64(tier.MessageExpiryDuration.Seconds()), tier.EmailLimit, tier.ReservationLimit, tier.AttachmentFileSizeLimit, tier.AttachmentTotalSizeLimit, int64(tier.AttachmentExpiryDuration.Seconds()), tier.AttachmentBandwidthLimit, nullString(tier.StripePriceID), tier.Code); err != nil {
  1078. return err
  1079. }
  1080. return nil
  1081. }
  1082. // RemoveTier deletes the tier with the given code
  1083. func (a *Manager) RemoveTier(code string) error {
  1084. if !AllowedTier(code) {
  1085. return ErrInvalidArgument
  1086. }
  1087. // This fails if any user has this tier
  1088. if _, err := a.db.Exec(deleteTierQuery, code); err != nil {
  1089. return err
  1090. }
  1091. return nil
  1092. }
  1093. // ChangeBilling updates a user's billing fields, namely the Stripe customer ID, and subscription information
  1094. func (a *Manager) ChangeBilling(username string, billing *Billing) error {
  1095. if _, err := a.db.Exec(updateBillingQuery, nullString(billing.StripeCustomerID), nullString(billing.StripeSubscriptionID), nullString(string(billing.StripeSubscriptionStatus)), nullInt64(billing.StripeSubscriptionPaidUntil.Unix()), nullInt64(billing.StripeSubscriptionCancelAt.Unix()), username); err != nil {
  1096. return err
  1097. }
  1098. return nil
  1099. }
  1100. // Tiers returns a list of all Tier structs
  1101. func (a *Manager) Tiers() ([]*Tier, error) {
  1102. rows, err := a.db.Query(selectTiersQuery)
  1103. if err != nil {
  1104. return nil, err
  1105. }
  1106. defer rows.Close()
  1107. tiers := make([]*Tier, 0)
  1108. for {
  1109. tier, err := a.readTier(rows)
  1110. if err == ErrTierNotFound {
  1111. break
  1112. } else if err != nil {
  1113. return nil, err
  1114. }
  1115. tiers = append(tiers, tier)
  1116. }
  1117. return tiers, nil
  1118. }
  1119. // Tier returns a Tier based on the code, or ErrTierNotFound if it does not exist
  1120. func (a *Manager) Tier(code string) (*Tier, error) {
  1121. rows, err := a.db.Query(selectTierByCodeQuery, code)
  1122. if err != nil {
  1123. return nil, err
  1124. }
  1125. defer rows.Close()
  1126. return a.readTier(rows)
  1127. }
  1128. // TierByStripePrice returns a Tier based on the Stripe price ID, or ErrTierNotFound if it does not exist
  1129. func (a *Manager) TierByStripePrice(priceID string) (*Tier, error) {
  1130. rows, err := a.db.Query(selectTierByPriceIDQuery, priceID)
  1131. if err != nil {
  1132. return nil, err
  1133. }
  1134. defer rows.Close()
  1135. return a.readTier(rows)
  1136. }
  1137. func (a *Manager) readTier(rows *sql.Rows) (*Tier, error) {
  1138. var id, code, name string
  1139. var stripePriceID sql.NullString
  1140. var messagesLimit, messagesExpiryDuration, emailsLimit, reservationsLimit, attachmentFileSizeLimit, attachmentTotalSizeLimit, attachmentExpiryDuration, attachmentBandwidthLimit sql.NullInt64
  1141. if !rows.Next() {
  1142. return nil, ErrTierNotFound
  1143. }
  1144. if err := rows.Scan(&id, &code, &name, &messagesLimit, &messagesExpiryDuration, &emailsLimit, &reservationsLimit, &attachmentFileSizeLimit, &attachmentTotalSizeLimit, &attachmentExpiryDuration, &attachmentBandwidthLimit, &stripePriceID); err != nil {
  1145. return nil, err
  1146. } else if err := rows.Err(); err != nil {
  1147. return nil, err
  1148. }
  1149. // When changed, note readUser() as well
  1150. return &Tier{
  1151. ID: id,
  1152. Code: code,
  1153. Name: name,
  1154. MessageLimit: messagesLimit.Int64,
  1155. MessageExpiryDuration: time.Duration(messagesExpiryDuration.Int64) * time.Second,
  1156. EmailLimit: emailsLimit.Int64,
  1157. ReservationLimit: reservationsLimit.Int64,
  1158. AttachmentFileSizeLimit: attachmentFileSizeLimit.Int64,
  1159. AttachmentTotalSizeLimit: attachmentTotalSizeLimit.Int64,
  1160. AttachmentExpiryDuration: time.Duration(attachmentExpiryDuration.Int64) * time.Second,
  1161. AttachmentBandwidthLimit: attachmentBandwidthLimit.Int64,
  1162. StripePriceID: stripePriceID.String, // May be empty
  1163. }, nil
  1164. }
  1165. // Close closes the underlying database
  1166. func (a *Manager) Close() error {
  1167. return a.db.Close()
  1168. }
  // toSQLWildcard converts a topic pattern to its SQL form by replacing every "*" with "%".
  func toSQLWildcard(s string) string {
  	const everyOccurrence = -1
  	return strings.Replace(s, "*", "%", everyOccurrence)
  }
  // fromSQLWildcard converts a SQL pattern back to a topic pattern by replacing every "%"
  // with "*".
  func fromSQLWildcard(s string) string {
  	const everyOccurrence = -1
  	return strings.Replace(s, "%", "*", everyOccurrence)
  }
  1175. func runStartupQueries(db *sql.DB, startupQueries string) error {
  1176. if _, err := db.Exec(startupQueries); err != nil {
  1177. return err
  1178. }
  1179. if _, err := db.Exec(builtinStartupQueries); err != nil {
  1180. return err
  1181. }
  1182. return nil
  1183. }
  1184. func setupDB(db *sql.DB) error {
  1185. // If 'schemaVersion' table does not exist, this must be a new database
  1186. rowsSV, err := db.Query(selectSchemaVersionQuery)
  1187. if err != nil {
  1188. return setupNewDB(db)
  1189. }
  1190. defer rowsSV.Close()
  1191. // If 'schemaVersion' table exists, read version and potentially upgrade
  1192. schemaVersion := 0
  1193. if !rowsSV.Next() {
  1194. return errors.New("cannot determine schema version: database file may be corrupt")
  1195. }
  1196. if err := rowsSV.Scan(&schemaVersion); err != nil {
  1197. return err
  1198. }
  1199. rowsSV.Close()
  1200. // Do migrations
  1201. if schemaVersion == currentSchemaVersion {
  1202. return nil
  1203. } else if schemaVersion == 1 {
  1204. return migrateFrom1(db)
  1205. }
  1206. return fmt.Errorf("unexpected schema version found: %d", schemaVersion)
  1207. }
  1208. func setupNewDB(db *sql.DB) error {
  1209. if _, err := db.Exec(createTablesQueries); err != nil {
  1210. return err
  1211. }
  1212. if _, err := db.Exec(insertSchemaVersion, currentSchemaVersion); err != nil {
  1213. return err
  1214. }
  1215. return nil
  1216. }
  1217. func migrateFrom1(db *sql.DB) error {
  1218. log.Tag(tagManager).Info("Migrating user database schema: from 1 to 2")
  1219. tx, err := db.Begin()
  1220. if err != nil {
  1221. return err
  1222. }
  1223. defer tx.Rollback()
  1224. // Rename user -> user_old, and create new tables
  1225. if _, err := tx.Exec(migrate1To2RenameUserTableQueryNoTx); err != nil {
  1226. return err
  1227. }
  1228. if _, err := tx.Exec(createTablesQueriesNoTx); err != nil {
  1229. return err
  1230. }
  1231. // Insert users from user_old into new user table, with ID and sync_topic
  1232. rows, err := tx.Query(migrate1To2SelectAllOldUsernamesNoTx)
  1233. if err != nil {
  1234. return err
  1235. }
  1236. defer rows.Close()
  1237. usernames := make([]string, 0)
  1238. for rows.Next() {
  1239. var username string
  1240. if err := rows.Scan(&username); err != nil {
  1241. return err
  1242. }
  1243. usernames = append(usernames, username)
  1244. }
  1245. if err := rows.Close(); err != nil {
  1246. return err
  1247. }
  1248. for _, username := range usernames {
  1249. userID := util.RandomStringPrefix(userIDPrefix, userIDLength)
  1250. syncTopic := util.RandomStringPrefix(syncTopicPrefix, syncTopicLength)
  1251. if _, err := tx.Exec(migrate1To2InsertUserNoTx, userID, syncTopic, username); err != nil {
  1252. return err
  1253. }
  1254. }
  1255. // Migrate old "access" table to "user_access" and drop "access" and "user_old"
  1256. if _, err := tx.Exec(migrate1To2InsertFromOldTablesAndDropNoTx); err != nil {
  1257. return err
  1258. }
  1259. if _, err := tx.Exec(updateSchemaVersion, 2); err != nil {
  1260. return err
  1261. }
  1262. if err := tx.Commit(); err != nil {
  1263. return err
  1264. }
  1265. return nil // Update this when a new version is added
  1266. }
  // nullString converts s into a sql.NullString. The empty string becomes the invalid (NULL)
  // value; any other string becomes a valid value holding s.
  func nullString(s string) sql.NullString {
  	return sql.NullString{String: s, Valid: s != ""}
  }
  // nullInt64 converts v into a sql.NullInt64. Zero becomes the invalid (NULL) value; any
  // other number becomes a valid value holding v.
  func nullInt64(v int64) sql.NullInt64 {
  	return sql.NullInt64{Int64: v, Valid: v != 0}
  }