manager.go 73 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495969798991001011021031041051061071081091101111121131141151161171181191201211221231241251261271281291301311321331341351361371381391401411421431441451461471481491501511521531541551561571581591601611621631641651661671681691701711721731741751761771781791801811821831841851861871881891901911921931941951961971981992002012022032042052062072082092102112122132142152162172182192202212222232242252262272282292302312322332342352362372382392402412422432442452462472482492502512522532542552562572582592602612622632642652662672682692702712722732742752762772782792802812822832842852862872882892902912922932942952962972982993003013023033043053063073083093103113123133143153163173183193203213223233243253263273283293303313323333343353363373383393403413423433443453463473483493503513523533543553563573583593603613623633643653663673683693703713723733743753763773783793803813823833843853863873883893903913923933943953963973983994004014024034044054064074084094104114124134144154164174184194204214224234244254264274284294304314324334344354364374384394404414424434444454464474484494504514524534544554564574584594604614624634644654664674684694704714724734744754764774784794804814824834844854864874884894904914924934944954964974984995005015025035045055065075085095105115125135145155165175185195205215225235245255265275285295305315325335345355365375385395405415425435445455465475485495505515525535545555565575585595605615625635645655665675685695705715725735745755765775785795805815825835845855865875885895905915925935945955965975985996006016026036046056066076086096106116126136146156166176186196206216226236246256266276286296306316326336346356366376386396406416426436446456466476486496506516526536546556566576586596606616626636646656666676686696706716726736746756766776786796806816826836846856866876886896906916926936946956966976986997007017027037047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127712781279128012811282128312841285128612871288128912901291129212931294129512961297129812991300130113021303130413051306130713081309131013111312131313141315131613171318131913201321132213231324132513261327132813291330133113321333133413351336133713381339134013411342134313441345134613471348134913501351135213531354135513561357135813591360136113621363136413651366136713681369137013711372137313741375137613771378137913801381138213831384138513861387138813891390139113921393139413951396139713981399140014011402140314041405140614071408140914101411141214131414141514161417141814191420142114221423142414251426142714281429143014311432143314341435143614371438143914401441144214431444144514461447144814491450145114521453145414551456145714581459146014611462146314641465146614671468146914701471147214731474147514761477147814791480148114821483148414851486148714881489149014911492149314941495149614971498149915001501150215031504150515061507150815091510151115121513151415151516151715181519152015211522152315241525152615271528152915301531153215331534153515361537153815391540154115421543154415451546154715481549155015511552155315541555155615571558155915601561156215631564156515661567156815691570157115721573157415751576157715781579158015811582158315841585158615871588158915901591159215931594159515961597159815991600160116021603160416051606160716081609161016111612161316141615161616171618161916201621162216231624162516261627162816291630163116321633163416351636163716381639164016411642164316441645164616471648164916501651165216531654165516561657165816591660166116621663166416651666166716681669167016711672167316741675167616771678167916801681168216831684168516861687168816891690169116921693169416951696169716981699170017011702170317041705170617071708170917101711171217131714171517161717171817191720172117221723172417251726172717281729173017311732173317341735173617371738173917401741174217431744174517461747174817491750175117521753175417551756175717581759176017611762176317641765176617671768176917701771177217731774177517761777177817791780178117821783178417851786178717881789179017911792179317941795179617971798179918001801180218031804180518061807180818091810181118121813181418151816181718181819182018211822182318241825182618271828182918301831183218331834183518361837183818391840184118421843184418451846184718481849185018511852185318541855185618571858185918601861186218631864186518661867186818691870187118721873187418751876187718781879188018811882188318841885188618871888188918901891189218931894189518961897189818991900190119021903190419051906190719081909191019111912191319141915191619171918191919201921192219231924192519261927192819291930193119321933193419351936193719381939194019411942194319441945194619471948194919501951195219531954195519561957195819591960196119621963196419651966196719681969197019711972197319741975197619771978197919801981198219831984198519861987198819891990199119921993199419951996199719981999200020012002200320042005200620072008200920102011201220132014201520162017201820192020202120222023202420252026202720282029203020312032203320342035203620372038203920402041204220432044204520462047204820492050205120522053205420552056205720582059206020612062206320642065206620672068206920702071207220732074207520762077207820792080208120822083208420852086208720882089209020912092209320942095209620972098209921002101210221032104210521062107210821092110211121122113211421152116211721182119212021212122212321242125212621272128212921302131213221332134213521362137213821392140214121422143
  1. // Package user deals with authentication and authorization against topics
  2. package user
  3. import (
  4. "database/sql"
  5. "encoding/json"
  6. "errors"
  7. "fmt"
  8. "github.com/mattn/go-sqlite3"
  9. "github.com/stripe/stripe-go/v74"
  10. "golang.org/x/crypto/bcrypt"
  11. "heckel.io/ntfy/v2/log"
  12. "heckel.io/ntfy/v2/util"
  13. "net/netip"
  14. "path/filepath"
  15. "slices"
  16. "strings"
  17. "sync"
  18. "time"
  19. )
  20. const (
  21. tierIDPrefix = "ti_"
  22. tierIDLength = 8
  23. syncTopicPrefix = "st_"
  24. syncTopicLength = 16
  25. userIDPrefix = "u_"
  26. userIDLength = 12
  27. userAuthIntentionalSlowDownHash = "$2a$10$YFCQvqQDwIIwnJM1xkAYOeih0dg17UVGanaTStnrSzC8NCWxcLDwy" // Cost should match DefaultUserPasswordBcryptCost
  28. userHardDeleteAfterDuration = 7 * 24 * time.Hour
  29. tokenPrefix = "tk_"
  30. tokenLength = 32
  31. tokenMaxCount = 60 // Only keep this many tokens in the table per user
  32. tag = "user_manager"
  33. )
  34. // Default constants that may be overridden by configs
  35. const (
  36. DefaultUserStatsQueueWriterInterval = 33 * time.Second
  37. DefaultUserPasswordBcryptCost = 10
  38. )
  39. var (
  40. errNoTokenProvided = errors.New("no token provided")
  41. errTopicOwnedByOthers = errors.New("topic owned by others")
  42. errNoRows = errors.New("no rows found")
  43. )
  44. // Manager-related queries
  45. const (
  46. createTablesQueries = `
  47. BEGIN;
  48. CREATE TABLE IF NOT EXISTS tier (
  49. id TEXT PRIMARY KEY,
  50. code TEXT NOT NULL,
  51. name TEXT NOT NULL,
  52. messages_limit INT NOT NULL,
  53. messages_expiry_duration INT NOT NULL,
  54. emails_limit INT NOT NULL,
  55. calls_limit INT NOT NULL,
  56. reservations_limit INT NOT NULL,
  57. attachment_file_size_limit INT NOT NULL,
  58. attachment_total_size_limit INT NOT NULL,
  59. attachment_expiry_duration INT NOT NULL,
  60. attachment_bandwidth_limit INT NOT NULL,
  61. stripe_monthly_price_id TEXT,
  62. stripe_yearly_price_id TEXT
  63. );
  64. CREATE UNIQUE INDEX idx_tier_code ON tier (code);
  65. CREATE UNIQUE INDEX idx_tier_stripe_monthly_price_id ON tier (stripe_monthly_price_id);
  66. CREATE UNIQUE INDEX idx_tier_stripe_yearly_price_id ON tier (stripe_yearly_price_id);
  67. CREATE TABLE IF NOT EXISTS user (
  68. id TEXT PRIMARY KEY,
  69. tier_id TEXT,
  70. user TEXT NOT NULL,
  71. pass TEXT NOT NULL,
  72. role TEXT CHECK (role IN ('anonymous', 'admin', 'user')) NOT NULL,
  73. prefs JSON NOT NULL DEFAULT '{}',
  74. sync_topic TEXT NOT NULL,
  75. provisioned INT NOT NULL,
  76. stats_messages INT NOT NULL DEFAULT (0),
  77. stats_emails INT NOT NULL DEFAULT (0),
  78. stats_calls INT NOT NULL DEFAULT (0),
  79. stripe_customer_id TEXT,
  80. stripe_subscription_id TEXT,
  81. stripe_subscription_status TEXT,
  82. stripe_subscription_interval TEXT,
  83. stripe_subscription_paid_until INT,
  84. stripe_subscription_cancel_at INT,
  85. created INT NOT NULL,
  86. deleted INT,
  87. FOREIGN KEY (tier_id) REFERENCES tier (id)
  88. );
  89. CREATE UNIQUE INDEX idx_user ON user (user);
  90. CREATE UNIQUE INDEX idx_user_stripe_customer_id ON user (stripe_customer_id);
  91. CREATE UNIQUE INDEX idx_user_stripe_subscription_id ON user (stripe_subscription_id);
  92. CREATE TABLE IF NOT EXISTS user_access (
  93. user_id TEXT NOT NULL,
  94. topic TEXT NOT NULL,
  95. read INT NOT NULL,
  96. write INT NOT NULL,
  97. owner_user_id INT,
  98. provisioned INT NOT NULL,
  99. PRIMARY KEY (user_id, topic),
  100. FOREIGN KEY (user_id) REFERENCES user (id) ON DELETE CASCADE,
  101. FOREIGN KEY (owner_user_id) REFERENCES user (id) ON DELETE CASCADE
  102. );
  103. CREATE TABLE IF NOT EXISTS user_token (
  104. user_id TEXT NOT NULL,
  105. token TEXT NOT NULL,
  106. label TEXT NOT NULL,
  107. last_access INT NOT NULL,
  108. last_origin TEXT NOT NULL,
  109. expires INT NOT NULL,
  110. provisioned INT NOT NULL,
  111. PRIMARY KEY (user_id, token),
  112. FOREIGN KEY (user_id) REFERENCES user (id) ON DELETE CASCADE
  113. );
  114. CREATE UNIQUE INDEX idx_user_token ON user_token (token);
  115. CREATE TABLE IF NOT EXISTS user_phone (
  116. user_id TEXT NOT NULL,
  117. phone_number TEXT NOT NULL,
  118. PRIMARY KEY (user_id, phone_number),
  119. FOREIGN KEY (user_id) REFERENCES user (id) ON DELETE CASCADE
  120. );
  121. CREATE TABLE IF NOT EXISTS schemaVersion (
  122. id INT PRIMARY KEY,
  123. version INT NOT NULL
  124. );
  125. INSERT INTO user (id, user, pass, role, sync_topic, provisioned, created)
  126. VALUES ('` + everyoneID + `', '*', '', 'anonymous', '', false, UNIXEPOCH())
  127. ON CONFLICT (id) DO NOTHING;
  128. COMMIT;
  129. `
  130. builtinStartupQueries = `
  131. PRAGMA foreign_keys = ON;
  132. `
  133. selectUserByIDQuery = `
  134. SELECT u.id, u.user, u.pass, u.role, u.prefs, u.sync_topic, u.provisioned, u.stats_messages, u.stats_emails, u.stats_calls, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_interval, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, deleted, t.id, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.calls_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.attachment_bandwidth_limit, t.stripe_monthly_price_id, t.stripe_yearly_price_id
  135. FROM user u
  136. LEFT JOIN tier t on t.id = u.tier_id
  137. WHERE u.id = ?
  138. `
  139. selectUserByNameQuery = `
  140. SELECT u.id, u.user, u.pass, u.role, u.prefs, u.sync_topic, u.provisioned, u.stats_messages, u.stats_emails, u.stats_calls, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_interval, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, deleted, t.id, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.calls_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.attachment_bandwidth_limit, t.stripe_monthly_price_id, t.stripe_yearly_price_id
  141. FROM user u
  142. LEFT JOIN tier t on t.id = u.tier_id
  143. WHERE user = ?
  144. `
  145. selectUserByTokenQuery = `
  146. SELECT u.id, u.user, u.pass, u.role, u.prefs, u.sync_topic, u.provisioned, u.stats_messages, u.stats_emails, u.stats_calls, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_interval, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, deleted, t.id, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.calls_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.attachment_bandwidth_limit, t.stripe_monthly_price_id, t.stripe_yearly_price_id
  147. FROM user u
  148. JOIN user_token tk on u.id = tk.user_id
  149. LEFT JOIN tier t on t.id = u.tier_id
  150. WHERE tk.token = ? AND (tk.expires = 0 OR tk.expires >= ?)
  151. `
  152. selectUserByStripeCustomerIDQuery = `
  153. SELECT u.id, u.user, u.pass, u.role, u.prefs, u.sync_topic, u.provisioned, u.stats_messages, u.stats_emails, u.stats_calls, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_interval, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, deleted, t.id, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.calls_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.attachment_bandwidth_limit, t.stripe_monthly_price_id, t.stripe_yearly_price_id
  154. FROM user u
  155. LEFT JOIN tier t on t.id = u.tier_id
  156. WHERE u.stripe_customer_id = ?
  157. `
  158. selectTopicPermsQuery = `
  159. SELECT read, write
  160. FROM user_access a
  161. JOIN user u ON u.id = a.user_id
  162. WHERE (u.user = ? OR u.user = ?) AND ? LIKE a.topic ESCAPE '\'
  163. ORDER BY u.user DESC, LENGTH(a.topic) DESC, a.write DESC
  164. `
  165. insertUserQuery = `
  166. INSERT INTO user (id, user, pass, role, sync_topic, provisioned, created)
  167. VALUES (?, ?, ?, ?, ?, ?, ?)
  168. `
  169. selectUsernamesQuery = `
  170. SELECT user
  171. FROM user
  172. ORDER BY
  173. CASE role
  174. WHEN 'admin' THEN 1
  175. WHEN 'anonymous' THEN 3
  176. ELSE 2
  177. END, user
  178. `
  179. selectUserCountQuery = `SELECT COUNT(*) FROM user`
  180. selectUserIDFromUsernameQuery = `SELECT id FROM user WHERE user = ?`
  181. updateUserPassQuery = `UPDATE user SET pass = ? WHERE user = ?`
  182. updateUserRoleQuery = `UPDATE user SET role = ? WHERE user = ?`
  183. updateUserProvisionedQuery = `UPDATE user SET provisioned = ? WHERE user = ?`
  184. updateUserPrefsQuery = `UPDATE user SET prefs = ? WHERE id = ?`
  185. updateUserStatsQuery = `UPDATE user SET stats_messages = ?, stats_emails = ?, stats_calls = ? WHERE id = ?`
  186. updateUserStatsResetAllQuery = `UPDATE user SET stats_messages = 0, stats_emails = 0, stats_calls = 0`
  187. updateUserDeletedQuery = `UPDATE user SET deleted = ? WHERE id = ?`
  188. deleteUsersMarkedQuery = `DELETE FROM user WHERE deleted < ?`
  189. deleteUserQuery = `DELETE FROM user WHERE user = ?`
  190. upsertUserAccessQuery = `
  191. INSERT INTO user_access (user_id, topic, read, write, owner_user_id, provisioned)
  192. VALUES ((SELECT id FROM user WHERE user = ?), ?, ?, ?, (SELECT IIF(?='',NULL,(SELECT id FROM user WHERE user=?))), ?)
  193. ON CONFLICT (user_id, topic)
  194. DO UPDATE SET read=excluded.read, write=excluded.write, owner_user_id=excluded.owner_user_id, provisioned=excluded.provisioned
  195. `
  196. selectUserAllAccessQuery = `
  197. SELECT user_id, topic, read, write, provisioned
  198. FROM user_access
  199. ORDER BY LENGTH(topic) DESC, write DESC, read DESC, topic
  200. `
  201. selectUserAccessQuery = `
  202. SELECT topic, read, write, provisioned
  203. FROM user_access
  204. WHERE user_id = (SELECT id FROM user WHERE user = ?)
  205. ORDER BY LENGTH(topic) DESC, write DESC, read DESC, topic
  206. `
  207. selectUserReservationsQuery = `
  208. SELECT a_user.topic, a_user.read, a_user.write, a_everyone.read AS everyone_read, a_everyone.write AS everyone_write
  209. FROM user_access a_user
  210. LEFT JOIN user_access a_everyone ON a_user.topic = a_everyone.topic AND a_everyone.user_id = (SELECT id FROM user WHERE user = ?)
  211. WHERE a_user.user_id = a_user.owner_user_id
  212. AND a_user.owner_user_id = (SELECT id FROM user WHERE user = ?)
  213. ORDER BY a_user.topic
  214. `
  215. selectUserReservationsCountQuery = `
  216. SELECT COUNT(*)
  217. FROM user_access
  218. WHERE user_id = owner_user_id
  219. AND owner_user_id = (SELECT id FROM user WHERE user = ?)
  220. `
  221. selectUserReservationsOwnerQuery = `
  222. SELECT owner_user_id
  223. FROM user_access
  224. WHERE topic = ?
  225. AND user_id = owner_user_id
  226. `
  227. selectUserHasReservationQuery = `
  228. SELECT COUNT(*)
  229. FROM user_access
  230. WHERE user_id = owner_user_id
  231. AND owner_user_id = (SELECT id FROM user WHERE user = ?)
  232. AND topic = ?
  233. `
  234. selectOtherAccessCountQuery = `
  235. SELECT COUNT(*)
  236. FROM user_access
  237. WHERE (topic = ? OR ? LIKE topic ESCAPE '\')
  238. AND (owner_user_id IS NULL OR owner_user_id != (SELECT id FROM user WHERE user = ?))
  239. `
  240. deleteAllAccessQuery = `DELETE FROM user_access`
  241. deleteUserAccessQuery = `
  242. DELETE FROM user_access
  243. WHERE user_id = (SELECT id FROM user WHERE user = ?)
  244. OR owner_user_id = (SELECT id FROM user WHERE user = ?)
  245. `
  246. deleteUserAccessProvisionedQuery = `DELETE FROM user_access WHERE provisioned = 1`
  247. deleteTopicAccessQuery = `
  248. DELETE FROM user_access
  249. WHERE (user_id = (SELECT id FROM user WHERE user = ?) OR owner_user_id = (SELECT id FROM user WHERE user = ?))
  250. AND topic = ?
  251. `
  252. selectTokenCountQuery = `SELECT COUNT(*) FROM user_token WHERE user_id = ?`
  253. selectTokensQuery = `SELECT token, label, last_access, last_origin, expires, provisioned FROM user_token WHERE user_id = ?`
  254. selectTokenQuery = `SELECT token, label, last_access, last_origin, expires, provisioned FROM user_token WHERE user_id = ? AND token = ?`
  255. selectAllProvisionedTokensQuery = `SELECT token, label, last_access, last_origin, expires, provisioned FROM user_token WHERE provisioned = 1`
  256. upsertTokenQuery = `
  257. INSERT INTO user_token (user_id, token, label, last_access, last_origin, expires, provisioned)
  258. VALUES (?, ?, ?, ?, ?, ?, ?)
  259. ON CONFLICT (user_id, token)
  260. DO UPDATE SET label = excluded.label, expires = excluded.expires, provisioned = excluded.provisioned;
  261. `
  262. updateTokenExpiryQuery = `UPDATE user_token SET expires = ? WHERE user_id = ? AND token = ?`
  263. updateTokenLabelQuery = `UPDATE user_token SET label = ? WHERE user_id = ? AND token = ?`
  264. updateTokenLastAccessQuery = `UPDATE user_token SET last_access = ?, last_origin = ? WHERE token = ?`
  265. deleteTokenQuery = `DELETE FROM user_token WHERE user_id = ? AND token = ?`
  266. deleteProvisionedTokenQuery = `DELETE FROM user_token WHERE token = ?`
  267. deleteAllTokenQuery = `DELETE FROM user_token WHERE user_id = ?`
  268. deleteExpiredTokensQuery = `DELETE FROM user_token WHERE expires > 0 AND expires < ?`
  269. deleteExcessTokensQuery = `
  270. DELETE FROM user_token
  271. WHERE user_id = ?
  272. AND (user_id, token) NOT IN (
  273. SELECT user_id, token
  274. FROM user_token
  275. WHERE user_id = ?
  276. ORDER BY expires DESC
  277. LIMIT ?
  278. )
  279. `
  280. selectPhoneNumbersQuery = `SELECT phone_number FROM user_phone WHERE user_id = ?`
  281. insertPhoneNumberQuery = `INSERT INTO user_phone (user_id, phone_number) VALUES (?, ?)`
  282. deletePhoneNumberQuery = `DELETE FROM user_phone WHERE user_id = ? AND phone_number = ?`
  283. insertTierQuery = `
  284. INSERT INTO tier (id, code, name, messages_limit, messages_expiry_duration, emails_limit, calls_limit, reservations_limit, attachment_file_size_limit, attachment_total_size_limit, attachment_expiry_duration, attachment_bandwidth_limit, stripe_monthly_price_id, stripe_yearly_price_id)
  285. VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
  286. `
  287. updateTierQuery = `
  288. UPDATE tier
  289. SET name = ?, messages_limit = ?, messages_expiry_duration = ?, emails_limit = ?, calls_limit = ?, reservations_limit = ?, attachment_file_size_limit = ?, attachment_total_size_limit = ?, attachment_expiry_duration = ?, attachment_bandwidth_limit = ?, stripe_monthly_price_id = ?, stripe_yearly_price_id = ?
  290. WHERE code = ?
  291. `
  292. selectTiersQuery = `
  293. SELECT id, code, name, messages_limit, messages_expiry_duration, emails_limit, calls_limit, reservations_limit, attachment_file_size_limit, attachment_total_size_limit, attachment_expiry_duration, attachment_bandwidth_limit, stripe_monthly_price_id, stripe_yearly_price_id
  294. FROM tier
  295. `
  296. selectTierByCodeQuery = `
  297. SELECT id, code, name, messages_limit, messages_expiry_duration, emails_limit, calls_limit, reservations_limit, attachment_file_size_limit, attachment_total_size_limit, attachment_expiry_duration, attachment_bandwidth_limit, stripe_monthly_price_id, stripe_yearly_price_id
  298. FROM tier
  299. WHERE code = ?
  300. `
  301. selectTierByPriceIDQuery = `
  302. SELECT id, code, name, messages_limit, messages_expiry_duration, emails_limit, calls_limit, reservations_limit, attachment_file_size_limit, attachment_total_size_limit, attachment_expiry_duration, attachment_bandwidth_limit, stripe_monthly_price_id, stripe_yearly_price_id
  303. FROM tier
  304. WHERE (stripe_monthly_price_id = ? OR stripe_yearly_price_id = ?)
  305. `
  306. updateUserTierQuery = `UPDATE user SET tier_id = (SELECT id FROM tier WHERE code = ?) WHERE user = ?`
  307. deleteUserTierQuery = `UPDATE user SET tier_id = null WHERE user = ?`
  308. deleteTierQuery = `DELETE FROM tier WHERE code = ?`
  309. updateBillingQuery = `
  310. UPDATE user
  311. SET stripe_customer_id = ?, stripe_subscription_id = ?, stripe_subscription_status = ?, stripe_subscription_interval = ?, stripe_subscription_paid_until = ?, stripe_subscription_cancel_at = ?
  312. WHERE user = ?
  313. `
  314. )
  315. // Schema management queries
  316. const (
  317. currentSchemaVersion = 6
  318. insertSchemaVersion = `INSERT INTO schemaVersion VALUES (1, ?)`
  319. updateSchemaVersion = `UPDATE schemaVersion SET version = ? WHERE id = 1`
  320. selectSchemaVersionQuery = `SELECT version FROM schemaVersion WHERE id = 1`
  321. // 1 -> 2 (complex migration!)
  322. migrate1To2CreateTablesQueries = `
  323. ALTER TABLE user RENAME TO user_old;
  324. CREATE TABLE IF NOT EXISTS tier (
  325. id TEXT PRIMARY KEY,
  326. code TEXT NOT NULL,
  327. name TEXT NOT NULL,
  328. messages_limit INT NOT NULL,
  329. messages_expiry_duration INT NOT NULL,
  330. emails_limit INT NOT NULL,
  331. reservations_limit INT NOT NULL,
  332. attachment_file_size_limit INT NOT NULL,
  333. attachment_total_size_limit INT NOT NULL,
  334. attachment_expiry_duration INT NOT NULL,
  335. attachment_bandwidth_limit INT NOT NULL,
  336. stripe_price_id TEXT
  337. );
  338. CREATE UNIQUE INDEX idx_tier_code ON tier (code);
  339. CREATE UNIQUE INDEX idx_tier_price_id ON tier (stripe_price_id);
  340. CREATE TABLE IF NOT EXISTS user (
  341. id TEXT PRIMARY KEY,
  342. tier_id TEXT,
  343. user TEXT NOT NULL,
  344. pass TEXT NOT NULL,
  345. role TEXT CHECK (role IN ('anonymous', 'admin', 'user')) NOT NULL,
  346. prefs JSON NOT NULL DEFAULT '{}',
  347. sync_topic TEXT NOT NULL,
  348. stats_messages INT NOT NULL DEFAULT (0),
  349. stats_emails INT NOT NULL DEFAULT (0),
  350. stripe_customer_id TEXT,
  351. stripe_subscription_id TEXT,
  352. stripe_subscription_status TEXT,
  353. stripe_subscription_paid_until INT,
  354. stripe_subscription_cancel_at INT,
  355. created INT NOT NULL,
  356. deleted INT,
  357. FOREIGN KEY (tier_id) REFERENCES tier (id)
  358. );
  359. CREATE UNIQUE INDEX idx_user ON user (user);
  360. CREATE UNIQUE INDEX idx_user_stripe_customer_id ON user (stripe_customer_id);
  361. CREATE UNIQUE INDEX idx_user_stripe_subscription_id ON user (stripe_subscription_id);
  362. CREATE TABLE IF NOT EXISTS user_access (
  363. user_id TEXT NOT NULL,
  364. topic TEXT NOT NULL,
  365. read INT NOT NULL,
  366. write INT NOT NULL,
  367. owner_user_id INT,
  368. PRIMARY KEY (user_id, topic),
  369. FOREIGN KEY (user_id) REFERENCES user (id) ON DELETE CASCADE,
  370. FOREIGN KEY (owner_user_id) REFERENCES user (id) ON DELETE CASCADE
  371. );
  372. CREATE TABLE IF NOT EXISTS user_token (
  373. user_id TEXT NOT NULL,
  374. token TEXT NOT NULL,
  375. label TEXT NOT NULL,
  376. last_access INT NOT NULL,
  377. last_origin TEXT NOT NULL,
  378. expires INT NOT NULL,
  379. PRIMARY KEY (user_id, token),
  380. FOREIGN KEY (user_id) REFERENCES user (id) ON DELETE CASCADE
  381. );
  382. CREATE TABLE IF NOT EXISTS schemaVersion (
  383. id INT PRIMARY KEY,
  384. version INT NOT NULL
  385. );
  386. INSERT INTO user (id, user, pass, role, sync_topic, created)
  387. VALUES ('u_everyone', '*', '', 'anonymous', '', UNIXEPOCH())
  388. ON CONFLICT (id) DO NOTHING;
  389. `
  390. migrate1To2SelectAllOldUsernamesNoTx = `SELECT user FROM user_old`
  391. migrate1To2InsertUserNoTx = `
  392. INSERT INTO user (id, user, pass, role, sync_topic, created)
  393. SELECT ?, user, pass, role, ?, UNIXEPOCH() FROM user_old WHERE user = ?
  394. `
  395. migrate1To2InsertFromOldTablesAndDropNoTx = `
  396. INSERT INTO user_access (user_id, topic, read, write)
  397. SELECT u.id, a.topic, a.read, a.write
  398. FROM user u
  399. JOIN access a ON u.user = a.user;
  400. DROP TABLE access;
  401. DROP TABLE user_old;
  402. `
  403. // 2 -> 3
  404. migrate2To3UpdateQueries = `
  405. ALTER TABLE user ADD COLUMN stripe_subscription_interval TEXT;
  406. ALTER TABLE tier RENAME COLUMN stripe_price_id TO stripe_monthly_price_id;
  407. ALTER TABLE tier ADD COLUMN stripe_yearly_price_id TEXT;
  408. DROP INDEX IF EXISTS idx_tier_price_id;
  409. CREATE UNIQUE INDEX idx_tier_stripe_monthly_price_id ON tier (stripe_monthly_price_id);
  410. CREATE UNIQUE INDEX idx_tier_stripe_yearly_price_id ON tier (stripe_yearly_price_id);
  411. `
  412. // 3 -> 4
  413. migrate3To4UpdateQueries = `
  414. ALTER TABLE tier ADD COLUMN calls_limit INT NOT NULL DEFAULT (0);
  415. ALTER TABLE user ADD COLUMN stats_calls INT NOT NULL DEFAULT (0);
  416. CREATE TABLE IF NOT EXISTS user_phone (
  417. user_id TEXT NOT NULL,
  418. phone_number TEXT NOT NULL,
  419. PRIMARY KEY (user_id, phone_number),
  420. FOREIGN KEY (user_id) REFERENCES user (id) ON DELETE CASCADE
  421. );
  422. `
  423. // 4 -> 5
  424. migrate4To5UpdateQueries = `
  425. UPDATE user_access SET topic = REPLACE(topic, '_', '\_');
  426. `
  427. // 5 -> 6
  428. migrate5To6UpdateQueries = `
  429. PRAGMA foreign_keys=off;
  430. -- Alter user table: Add provisioned column
  431. ALTER TABLE user RENAME TO user_old;
  432. CREATE TABLE IF NOT EXISTS user (
  433. id TEXT PRIMARY KEY,
  434. tier_id TEXT,
  435. user TEXT NOT NULL,
  436. pass TEXT NOT NULL,
  437. role TEXT CHECK (role IN ('anonymous', 'admin', 'user')) NOT NULL,
  438. prefs JSON NOT NULL DEFAULT '{}',
  439. sync_topic TEXT NOT NULL,
  440. provisioned INT NOT NULL,
  441. stats_messages INT NOT NULL DEFAULT (0),
  442. stats_emails INT NOT NULL DEFAULT (0),
  443. stats_calls INT NOT NULL DEFAULT (0),
  444. stripe_customer_id TEXT,
  445. stripe_subscription_id TEXT,
  446. stripe_subscription_status TEXT,
  447. stripe_subscription_interval TEXT,
  448. stripe_subscription_paid_until INT,
  449. stripe_subscription_cancel_at INT,
  450. created INT NOT NULL,
  451. deleted INT,
  452. FOREIGN KEY (tier_id) REFERENCES tier (id)
  453. );
  454. INSERT INTO user
  455. SELECT
  456. id,
  457. tier_id,
  458. user,
  459. pass,
  460. role,
  461. prefs,
  462. sync_topic,
  463. 0, -- provisioned
  464. stats_messages,
  465. stats_emails,
  466. stats_calls,
  467. stripe_customer_id,
  468. stripe_subscription_id,
  469. stripe_subscription_status,
  470. stripe_subscription_interval,
  471. stripe_subscription_paid_until,
  472. stripe_subscription_cancel_at,
  473. created,
  474. deleted
  475. FROM user_old;
  476. DROP TABLE user_old;
  477. -- Alter user_access table: Add provisioned column
  478. ALTER TABLE user_access RENAME TO user_access_old;
  479. CREATE TABLE user_access (
  480. user_id TEXT NOT NULL,
  481. topic TEXT NOT NULL,
  482. read INT NOT NULL,
  483. write INT NOT NULL,
  484. owner_user_id INT,
  485. provisioned INTEGER NOT NULL,
  486. PRIMARY KEY (user_id, topic),
  487. FOREIGN KEY (user_id) REFERENCES user (id) ON DELETE CASCADE,
  488. FOREIGN KEY (owner_user_id) REFERENCES user (id) ON DELETE CASCADE
  489. );
  490. INSERT INTO user_access SELECT *, 0 FROM user_access_old;
  491. DROP TABLE user_access_old;
  492. -- Alter user_token table: Add provisioned column
  493. ALTER TABLE user_token RENAME TO user_token_old;
  494. CREATE TABLE IF NOT EXISTS user_token (
  495. user_id TEXT NOT NULL,
  496. token TEXT NOT NULL,
  497. label TEXT NOT NULL,
  498. last_access INT NOT NULL,
  499. last_origin TEXT NOT NULL,
  500. expires INT NOT NULL,
  501. provisioned INT NOT NULL,
  502. PRIMARY KEY (user_id, token),
  503. FOREIGN KEY (user_id) REFERENCES user (id) ON DELETE CASCADE
  504. );
  505. INSERT INTO user_token SELECT *, 0 FROM user_token_old;
  506. DROP TABLE user_token_old;
  507. -- Recreate indices
  508. CREATE UNIQUE INDEX idx_user ON user (user);
  509. CREATE UNIQUE INDEX idx_user_stripe_customer_id ON user (stripe_customer_id);
  510. CREATE UNIQUE INDEX idx_user_stripe_subscription_id ON user (stripe_subscription_id);
  511. CREATE UNIQUE INDEX idx_user_token ON user_token (token);
  512. -- Re-enable foreign keys
  513. PRAGMA foreign_keys=on;
  514. `
  515. )
  516. var (
  517. migrations = map[int]func(db *sql.DB) error{
  518. 1: migrateFrom1,
  519. 2: migrateFrom2,
  520. 3: migrateFrom3,
  521. 4: migrateFrom4,
  522. 5: migrateFrom5,
  523. }
  524. )
  525. // Manager is an implementation of Manager. It stores users and access control list
  526. // in a SQLite database.
  527. type Manager struct {
  528. config *Config
  529. db *sql.DB
  530. statsQueue map[string]*Stats // "Queue" to asynchronously write user stats to the database (UserID -> Stats)
  531. tokenQueue map[string]*TokenUpdate // "Queue" to asynchronously write token access stats to the database (Token ID -> TokenUpdate)
  532. mu sync.Mutex
  533. }
  534. // Config holds the configuration for the user Manager
  535. type Config struct {
  536. Filename string // Database filename, e.g. "/var/lib/ntfy/user.db"
  537. StartupQueries string // Queries to run on startup, e.g. to create initial users or tiers
  538. DefaultAccess Permission // Default permission if no ACL matches
  539. ProvisionEnabled bool // Hack: Enable auto-provisioning of users and access grants, disabled for "ntfy user" commands
  540. Users []*User // Predefined users to create on startup
  541. Access map[string][]*Grant // Predefined access grants to create on startup (username -> []*Grant)
  542. Tokens map[string][]*Token // Predefined users to create on startup (username -> []*Token)
  543. QueueWriterInterval time.Duration // Interval for the async queue writer to flush stats and token updates to the database
  544. BcryptCost int // Cost of generated passwords; lowering makes testing faster
  545. }
  546. var _ Auther = (*Manager)(nil)
  547. // NewManager creates a new Manager instance
  548. func NewManager(config *Config) (*Manager, error) {
  549. // Set defaults
  550. if config.BcryptCost <= 0 {
  551. config.BcryptCost = DefaultUserPasswordBcryptCost
  552. }
  553. if config.QueueWriterInterval.Seconds() <= 0 {
  554. config.QueueWriterInterval = DefaultUserStatsQueueWriterInterval
  555. }
  556. // Check the parent directory of the database file (makes for friendly error messages)
  557. parentDir := filepath.Dir(config.Filename)
  558. if !util.FileExists(parentDir) {
  559. return nil, fmt.Errorf("user database directory %s does not exist or is not accessible", parentDir)
  560. }
  561. // Open DB and run setup queries
  562. db, err := sql.Open("sqlite3", config.Filename)
  563. if err != nil {
  564. return nil, err
  565. }
  566. if err := setupDB(db); err != nil {
  567. return nil, err
  568. }
  569. if err := runStartupQueries(db, config.StartupQueries); err != nil {
  570. return nil, err
  571. }
  572. manager := &Manager{
  573. db: db,
  574. config: config,
  575. statsQueue: make(map[string]*Stats),
  576. tokenQueue: make(map[string]*TokenUpdate),
  577. }
  578. if err := manager.maybeProvisionUsersAccessAndTokens(); err != nil {
  579. return nil, err
  580. }
  581. go manager.asyncQueueWriter(config.QueueWriterInterval)
  582. return manager, nil
  583. }
  584. // Authenticate checks username and password and returns a User if correct, and the user has not been
  585. // marked as deleted. The method returns in constant-ish time, regardless of whether the user exists or
  586. // the password is correct or incorrect.
  587. func (a *Manager) Authenticate(username, password string) (*User, error) {
  588. if username == Everyone {
  589. return nil, ErrUnauthenticated
  590. }
  591. user, err := a.User(username)
  592. if err != nil {
  593. log.Tag(tag).Field("user_name", username).Err(err).Trace("Authentication of user failed (1)")
  594. bcrypt.CompareHashAndPassword([]byte(userAuthIntentionalSlowDownHash), []byte("intentional slow-down to avoid timing attacks"))
  595. return nil, ErrUnauthenticated
  596. } else if user.Deleted {
  597. log.Tag(tag).Field("user_name", username).Trace("Authentication of user failed (2): user marked deleted")
  598. bcrypt.CompareHashAndPassword([]byte(userAuthIntentionalSlowDownHash), []byte("intentional slow-down to avoid timing attacks"))
  599. return nil, ErrUnauthenticated
  600. } else if err := bcrypt.CompareHashAndPassword([]byte(user.Hash), []byte(password)); err != nil {
  601. log.Tag(tag).Field("user_name", username).Err(err).Trace("Authentication of user failed (3)")
  602. return nil, ErrUnauthenticated
  603. }
  604. return user, nil
  605. }
  606. // AuthenticateToken checks if the token exists and returns the associated User if it does.
  607. // The method sets the User.Token value to the token that was used for authentication.
  608. func (a *Manager) AuthenticateToken(token string) (*User, error) {
  609. if len(token) != tokenLength {
  610. return nil, ErrUnauthenticated
  611. }
  612. user, err := a.userByToken(token)
  613. if err != nil {
  614. log.Tag(tag).Field("token", token).Err(err).Trace("Authentication of token failed")
  615. return nil, ErrUnauthenticated
  616. }
  617. user.Token = token
  618. return user, nil
  619. }
  620. // CreateToken generates a random token for the given user and returns it. The token expires
  621. // after a fixed duration unless ChangeToken is called. This function also prunes tokens for the
  622. // given user, if there are too many of them.
  623. func (a *Manager) CreateToken(userID, label string, expires time.Time, origin netip.Addr, provisioned bool) (*Token, error) {
  624. return queryTx(a.db, func(tx *sql.Tx) (*Token, error) {
  625. return a.createTokenTx(tx, userID, GenerateToken(), label, expires, origin, provisioned)
  626. })
  627. }
  628. func (a *Manager) createTokenTx(tx *sql.Tx, userID, token, label string, expires time.Time, origin netip.Addr, provisioned bool) (*Token, error) {
  629. access := time.Now()
  630. if _, err := tx.Exec(upsertTokenQuery, userID, token, label, access.Unix(), origin.String(), expires.Unix(), provisioned); err != nil {
  631. return nil, err
  632. }
  633. rows, err := tx.Query(selectTokenCountQuery, userID)
  634. if err != nil {
  635. return nil, err
  636. }
  637. defer rows.Close()
  638. if !rows.Next() {
  639. return nil, errNoRows
  640. }
  641. var tokenCount int
  642. if err := rows.Scan(&tokenCount); err != nil {
  643. return nil, err
  644. }
  645. if tokenCount >= tokenMaxCount {
  646. // This pruning logic is done in two queries for efficiency. The SELECT above is a lookup
  647. // on two indices, whereas the query below is a full table scan.
  648. if _, err := tx.Exec(deleteExcessTokensQuery, userID, userID, tokenMaxCount); err != nil {
  649. return nil, err
  650. }
  651. }
  652. return &Token{
  653. Value: token,
  654. Label: label,
  655. LastAccess: access,
  656. LastOrigin: origin,
  657. Expires: expires,
  658. Provisioned: provisioned,
  659. }, nil
  660. }
  661. // Tokens returns all existing tokens for the user with the given user ID
  662. func (a *Manager) Tokens(userID string) ([]*Token, error) {
  663. rows, err := a.db.Query(selectTokensQuery, userID)
  664. if err != nil {
  665. return nil, err
  666. }
  667. defer rows.Close()
  668. tokens := make([]*Token, 0)
  669. for {
  670. token, err := a.readToken(rows)
  671. if errors.Is(err, ErrTokenNotFound) {
  672. break
  673. } else if err != nil {
  674. return nil, err
  675. }
  676. tokens = append(tokens, token)
  677. }
  678. return tokens, nil
  679. }
  680. func (a *Manager) allProvisionedTokens() ([]*Token, error) {
  681. rows, err := a.db.Query(selectAllProvisionedTokensQuery)
  682. if err != nil {
  683. return nil, err
  684. }
  685. defer rows.Close()
  686. tokens := make([]*Token, 0)
  687. for {
  688. token, err := a.readToken(rows)
  689. if errors.Is(err, ErrTokenNotFound) {
  690. break
  691. } else if err != nil {
  692. return nil, err
  693. }
  694. tokens = append(tokens, token)
  695. }
  696. return tokens, nil
  697. }
  698. // Token returns a specific token for a user
  699. func (a *Manager) Token(userID, token string) (*Token, error) {
  700. rows, err := a.db.Query(selectTokenQuery, userID, token)
  701. if err != nil {
  702. return nil, err
  703. }
  704. defer rows.Close()
  705. return a.readToken(rows)
  706. }
  707. func (a *Manager) readToken(rows *sql.Rows) (*Token, error) {
  708. var token, label, lastOrigin string
  709. var lastAccess, expires int64
  710. var provisioned bool
  711. if !rows.Next() {
  712. return nil, ErrTokenNotFound
  713. }
  714. if err := rows.Scan(&token, &label, &lastAccess, &lastOrigin, &expires, &provisioned); err != nil {
  715. return nil, err
  716. } else if err := rows.Err(); err != nil {
  717. return nil, err
  718. }
  719. lastOriginIP, err := netip.ParseAddr(lastOrigin)
  720. if err != nil {
  721. lastOriginIP = netip.IPv4Unspecified()
  722. }
  723. return &Token{
  724. Value: token,
  725. Label: label,
  726. LastAccess: time.Unix(lastAccess, 0),
  727. LastOrigin: lastOriginIP,
  728. Expires: time.Unix(expires, 0),
  729. Provisioned: provisioned,
  730. }, nil
  731. }
  732. // ChangeToken updates a token's label and/or expiry date
  733. func (a *Manager) ChangeToken(userID, token string, label *string, expires *time.Time) (*Token, error) {
  734. if token == "" {
  735. return nil, errNoTokenProvided
  736. }
  737. if err := a.CanChangeToken(userID, token); err != nil {
  738. return nil, err
  739. }
  740. tx, err := a.db.Begin()
  741. if err != nil {
  742. return nil, err
  743. }
  744. defer tx.Rollback()
  745. if label != nil {
  746. if _, err := tx.Exec(updateTokenLabelQuery, *label, userID, token); err != nil {
  747. return nil, err
  748. }
  749. }
  750. if expires != nil {
  751. if _, err := tx.Exec(updateTokenExpiryQuery, expires.Unix(), userID, token); err != nil {
  752. return nil, err
  753. }
  754. }
  755. if err := tx.Commit(); err != nil {
  756. return nil, err
  757. }
  758. return a.Token(userID, token)
  759. }
  760. // RemoveToken deletes the token defined in User.Token
  761. func (a *Manager) RemoveToken(userID, token string) error {
  762. if err := a.CanChangeToken(userID, token); err != nil {
  763. return err
  764. }
  765. return execTx(a.db, func(tx *sql.Tx) error {
  766. return a.removeTokenTx(tx, userID, token)
  767. })
  768. }
  769. func (a *Manager) removeTokenTx(tx *sql.Tx, userID, token string) error {
  770. if token == "" {
  771. return errNoTokenProvided
  772. }
  773. if _, err := tx.Exec(deleteTokenQuery, userID, token); err != nil {
  774. return err
  775. }
  776. return nil
  777. }
  778. // CanChangeToken checks if the token can be changed. If the token is provisioned, it cannot be changed.
  779. func (a *Manager) CanChangeToken(userID, token string) error {
  780. t, err := a.Token(userID, token)
  781. if err != nil {
  782. return err
  783. } else if t.Provisioned {
  784. return ErrProvisionedTokenChange
  785. }
  786. return nil
  787. }
  788. // RemoveExpiredTokens deletes all expired tokens from the database
  789. func (a *Manager) RemoveExpiredTokens() error {
  790. if _, err := a.db.Exec(deleteExpiredTokensQuery, time.Now().Unix()); err != nil {
  791. return err
  792. }
  793. return nil
  794. }
  795. // PhoneNumbers returns all phone numbers for the user with the given user ID
  796. func (a *Manager) PhoneNumbers(userID string) ([]string, error) {
  797. rows, err := a.db.Query(selectPhoneNumbersQuery, userID)
  798. if err != nil {
  799. return nil, err
  800. }
  801. defer rows.Close()
  802. phoneNumbers := make([]string, 0)
  803. for {
  804. phoneNumber, err := a.readPhoneNumber(rows)
  805. if errors.Is(err, ErrPhoneNumberNotFound) {
  806. break
  807. } else if err != nil {
  808. return nil, err
  809. }
  810. phoneNumbers = append(phoneNumbers, phoneNumber)
  811. }
  812. return phoneNumbers, nil
  813. }
  814. func (a *Manager) readPhoneNumber(rows *sql.Rows) (string, error) {
  815. var phoneNumber string
  816. if !rows.Next() {
  817. return "", ErrPhoneNumberNotFound
  818. }
  819. if err := rows.Scan(&phoneNumber); err != nil {
  820. return "", err
  821. } else if err := rows.Err(); err != nil {
  822. return "", err
  823. }
  824. return phoneNumber, nil
  825. }
  826. // AddPhoneNumber adds a phone number to the user with the given user ID
  827. func (a *Manager) AddPhoneNumber(userID string, phoneNumber string) error {
  828. if _, err := a.db.Exec(insertPhoneNumberQuery, userID, phoneNumber); err != nil {
  829. if sqliteErr, ok := err.(sqlite3.Error); ok && sqliteErr.ExtendedCode == sqlite3.ErrConstraintUnique {
  830. return ErrPhoneNumberExists
  831. }
  832. return err
  833. }
  834. return nil
  835. }
  836. // RemovePhoneNumber deletes a phone number from the user with the given user ID
  837. func (a *Manager) RemovePhoneNumber(userID string, phoneNumber string) error {
  838. _, err := a.db.Exec(deletePhoneNumberQuery, userID, phoneNumber)
  839. return err
  840. }
  841. // RemoveDeletedUsers deletes all users that have been marked deleted for
  842. func (a *Manager) RemoveDeletedUsers() error {
  843. if _, err := a.db.Exec(deleteUsersMarkedQuery, time.Now().Unix()); err != nil {
  844. return err
  845. }
  846. return nil
  847. }
  848. // ChangeSettings persists the user settings
  849. func (a *Manager) ChangeSettings(userID string, prefs *Prefs) error {
  850. b, err := json.Marshal(prefs)
  851. if err != nil {
  852. return err
  853. }
  854. if _, err := a.db.Exec(updateUserPrefsQuery, string(b), userID); err != nil {
  855. return err
  856. }
  857. return nil
  858. }
  859. // ResetStats resets all user stats in the user database. This touches all users.
  860. func (a *Manager) ResetStats() error {
  861. a.mu.Lock() // Includes database query to avoid races!
  862. defer a.mu.Unlock()
  863. if _, err := a.db.Exec(updateUserStatsResetAllQuery); err != nil {
  864. return err
  865. }
  866. a.statsQueue = make(map[string]*Stats)
  867. return nil
  868. }
  869. // EnqueueUserStats adds the user to a queue which writes out user stats (messages, emails, ..) in
  870. // batches at a regular interval
  871. func (a *Manager) EnqueueUserStats(userID string, stats *Stats) {
  872. a.mu.Lock()
  873. defer a.mu.Unlock()
  874. a.statsQueue[userID] = stats
  875. }
  876. // EnqueueTokenUpdate adds the token update to a queue which writes out token access times
  877. // in batches at a regular interval
  878. func (a *Manager) EnqueueTokenUpdate(tokenID string, update *TokenUpdate) {
  879. a.mu.Lock()
  880. defer a.mu.Unlock()
  881. a.tokenQueue[tokenID] = update
  882. }
  883. func (a *Manager) asyncQueueWriter(interval time.Duration) {
  884. ticker := time.NewTicker(interval)
  885. for range ticker.C {
  886. if err := a.writeUserStatsQueue(); err != nil {
  887. log.Tag(tag).Err(err).Warn("Writing user stats queue failed")
  888. }
  889. if err := a.writeTokenUpdateQueue(); err != nil {
  890. log.Tag(tag).Err(err).Warn("Writing token update queue failed")
  891. }
  892. }
  893. }
  894. func (a *Manager) writeUserStatsQueue() error {
  895. a.mu.Lock()
  896. if len(a.statsQueue) == 0 {
  897. a.mu.Unlock()
  898. log.Tag(tag).Trace("No user stats updates to commit")
  899. return nil
  900. }
  901. statsQueue := a.statsQueue
  902. a.statsQueue = make(map[string]*Stats)
  903. a.mu.Unlock()
  904. tx, err := a.db.Begin()
  905. if err != nil {
  906. return err
  907. }
  908. defer tx.Rollback()
  909. log.Tag(tag).Debug("Writing user stats queue for %d user(s)", len(statsQueue))
  910. for userID, update := range statsQueue {
  911. log.
  912. Tag(tag).
  913. Fields(log.Context{
  914. "user_id": userID,
  915. "messages_count": update.Messages,
  916. "emails_count": update.Emails,
  917. "calls_count": update.Calls,
  918. }).
  919. Trace("Updating stats for user %s", userID)
  920. if _, err := tx.Exec(updateUserStatsQuery, update.Messages, update.Emails, update.Calls, userID); err != nil {
  921. return err
  922. }
  923. }
  924. return tx.Commit()
  925. }
  926. func (a *Manager) writeTokenUpdateQueue() error {
  927. a.mu.Lock()
  928. if len(a.tokenQueue) == 0 {
  929. a.mu.Unlock()
  930. log.Tag(tag).Trace("No token updates to commit")
  931. return nil
  932. }
  933. tokenQueue := a.tokenQueue
  934. a.tokenQueue = make(map[string]*TokenUpdate)
  935. a.mu.Unlock()
  936. tx, err := a.db.Begin()
  937. if err != nil {
  938. return err
  939. }
  940. defer tx.Rollback()
  941. log.Tag(tag).Debug("Writing token update queue for %d token(s)", len(tokenQueue))
  942. for tokenID, update := range tokenQueue {
  943. log.Tag(tag).Trace("Updating token %s with last access time %v", tokenID, update.LastAccess.Unix())
  944. if err := a.updateTokenLastAccessTx(tx, tokenID, update.LastAccess.Unix(), update.LastOrigin.String()); err != nil {
  945. return err
  946. }
  947. }
  948. return tx.Commit()
  949. }
  950. func (a *Manager) updateTokenLastAccessTx(tx *sql.Tx, token string, lastAccess int64, lastOrigin string) error {
  951. if _, err := tx.Exec(updateTokenLastAccessQuery, lastAccess, lastOrigin, token); err != nil {
  952. return err
  953. }
  954. return nil
  955. }
  956. // Authorize returns nil if the given user has access to the given topic using the desired
  957. // permission. The user param may be nil to signal an anonymous user.
  958. func (a *Manager) Authorize(user *User, topic string, perm Permission) error {
  959. if user != nil && user.Role == RoleAdmin {
  960. return nil // Admin can do everything
  961. }
  962. username := Everyone
  963. if user != nil {
  964. username = user.Name
  965. }
  966. // Select the read/write permissions for this user/topic combo.
  967. // - The query may return two rows (one for everyone, and one for the user), but prioritizes the user.
  968. // - Furthermore, the query prioritizes more specific permissions (longer!) over more generic ones, e.g. "test*" > "*"
  969. // - It also prioritizes write permissions over read permissions
  970. rows, err := a.db.Query(selectTopicPermsQuery, Everyone, username, topic)
  971. if err != nil {
  972. return err
  973. }
  974. defer rows.Close()
  975. if !rows.Next() {
  976. return a.resolvePerms(a.config.DefaultAccess, perm)
  977. }
  978. var read, write bool
  979. if err := rows.Scan(&read, &write); err != nil {
  980. return err
  981. } else if err := rows.Err(); err != nil {
  982. return err
  983. }
  984. return a.resolvePerms(NewPermission(read, write), perm)
  985. }
  986. func (a *Manager) resolvePerms(base, perm Permission) error {
  987. if perm == PermissionRead && base.IsRead() {
  988. return nil
  989. } else if perm == PermissionWrite && base.IsWrite() {
  990. return nil
  991. }
  992. return ErrUnauthorized
  993. }
  994. // AddUser adds a user with the given username, password and role
  995. func (a *Manager) AddUser(username, password string, role Role, hashed bool) error {
  996. return execTx(a.db, func(tx *sql.Tx) error {
  997. return a.addUserTx(tx, username, password, role, hashed, false)
  998. })
  999. }
  1000. // AddUser adds a user with the given username, password and role
  1001. func (a *Manager) addUserTx(tx *sql.Tx, username, password string, role Role, hashed, provisioned bool) error {
  1002. if !AllowedUsername(username) || !AllowedRole(role) {
  1003. return ErrInvalidArgument
  1004. }
  1005. var hash string
  1006. var err error = nil
  1007. if hashed {
  1008. hash = password
  1009. if err := ValidPasswordHash(hash); err != nil {
  1010. return err
  1011. }
  1012. } else {
  1013. hash, err = hashPassword(password, a.config.BcryptCost)
  1014. if err != nil {
  1015. return err
  1016. }
  1017. }
  1018. userID := util.RandomStringPrefix(userIDPrefix, userIDLength)
  1019. syncTopic, now := util.RandomStringPrefix(syncTopicPrefix, syncTopicLength), time.Now().Unix()
  1020. if _, err = tx.Exec(insertUserQuery, userID, username, hash, role, syncTopic, provisioned, now); err != nil {
  1021. if errors.Is(err, sqlite3.ErrConstraintUnique) {
  1022. return ErrUserExists
  1023. }
  1024. return err
  1025. }
  1026. return nil
  1027. }
  1028. // RemoveUser deletes the user with the given username. The function returns nil on success, even
  1029. // if the user did not exist in the first place.
  1030. func (a *Manager) RemoveUser(username string) error {
  1031. if err := a.CanChangeUser(username); err != nil {
  1032. return err
  1033. }
  1034. return execTx(a.db, func(tx *sql.Tx) error {
  1035. return a.removeUserTx(tx, username)
  1036. })
  1037. }
  1038. func (a *Manager) removeUserTx(tx *sql.Tx, username string) error {
  1039. if !AllowedUsername(username) {
  1040. return ErrInvalidArgument
  1041. }
  1042. // Rows in user_access, user_token, etc. are deleted via foreign keys
  1043. if _, err := tx.Exec(deleteUserQuery, username); err != nil {
  1044. return err
  1045. }
  1046. return nil
  1047. }
  1048. // MarkUserRemoved sets the deleted flag on the user, and deletes all access tokens. This prevents
  1049. // successful auth via Authenticate. A background process will delete the user at a later date.
  1050. func (a *Manager) MarkUserRemoved(user *User) error {
  1051. if !AllowedUsername(user.Name) {
  1052. return ErrInvalidArgument
  1053. }
  1054. tx, err := a.db.Begin()
  1055. if err != nil {
  1056. return err
  1057. }
  1058. defer tx.Rollback()
  1059. if _, err := tx.Exec(deleteUserAccessQuery, user.Name, user.Name); err != nil {
  1060. return err
  1061. }
  1062. if _, err := tx.Exec(deleteAllTokenQuery, user.ID); err != nil {
  1063. return err
  1064. }
  1065. if _, err := tx.Exec(updateUserDeletedQuery, time.Now().Add(userHardDeleteAfterDuration).Unix(), user.ID); err != nil {
  1066. return err
  1067. }
  1068. return tx.Commit()
  1069. }
  1070. // Users returns a list of users. It always also returns the Everyone user ("*").
  1071. func (a *Manager) Users() ([]*User, error) {
  1072. rows, err := a.db.Query(selectUsernamesQuery)
  1073. if err != nil {
  1074. return nil, err
  1075. }
  1076. defer rows.Close()
  1077. usernames := make([]string, 0)
  1078. for rows.Next() {
  1079. var username string
  1080. if err := rows.Scan(&username); err != nil {
  1081. return nil, err
  1082. } else if err := rows.Err(); err != nil {
  1083. return nil, err
  1084. }
  1085. usernames = append(usernames, username)
  1086. }
  1087. rows.Close()
  1088. users := make([]*User, 0)
  1089. for _, username := range usernames {
  1090. user, err := a.User(username)
  1091. if err != nil {
  1092. return nil, err
  1093. }
  1094. users = append(users, user)
  1095. }
  1096. return users, nil
  1097. }
  1098. // UsersCount returns the number of users in the databsae
  1099. func (a *Manager) UsersCount() (int64, error) {
  1100. rows, err := a.db.Query(selectUserCountQuery)
  1101. if err != nil {
  1102. return 0, err
  1103. }
  1104. defer rows.Close()
  1105. if !rows.Next() {
  1106. return 0, errNoRows
  1107. }
  1108. var count int64
  1109. if err := rows.Scan(&count); err != nil {
  1110. return 0, err
  1111. }
  1112. return count, nil
  1113. }
  1114. // User returns the user with the given username if it exists, or ErrUserNotFound otherwise.
  1115. // You may also pass Everyone to retrieve the anonymous user and its Grant list.
  1116. func (a *Manager) User(username string) (*User, error) {
  1117. rows, err := a.db.Query(selectUserByNameQuery, username)
  1118. if err != nil {
  1119. return nil, err
  1120. }
  1121. return a.readUser(rows)
  1122. }
  1123. // UserByID returns the user with the given ID if it exists, or ErrUserNotFound otherwise
  1124. func (a *Manager) UserByID(id string) (*User, error) {
  1125. rows, err := a.db.Query(selectUserByIDQuery, id)
  1126. if err != nil {
  1127. return nil, err
  1128. }
  1129. return a.readUser(rows)
  1130. }
  1131. // UserByStripeCustomer returns the user with the given Stripe customer ID if it exists, or ErrUserNotFound otherwise.
  1132. func (a *Manager) UserByStripeCustomer(stripeCustomerID string) (*User, error) {
  1133. rows, err := a.db.Query(selectUserByStripeCustomerIDQuery, stripeCustomerID)
  1134. if err != nil {
  1135. return nil, err
  1136. }
  1137. return a.readUser(rows)
  1138. }
  1139. func (a *Manager) userByToken(token string) (*User, error) {
  1140. rows, err := a.db.Query(selectUserByTokenQuery, token, time.Now().Unix())
  1141. if err != nil {
  1142. return nil, err
  1143. }
  1144. return a.readUser(rows)
  1145. }
  1146. func (a *Manager) readUser(rows *sql.Rows) (*User, error) {
  1147. defer rows.Close()
  1148. var id, username, hash, role, prefs, syncTopic string
  1149. var provisioned bool
  1150. var stripeCustomerID, stripeSubscriptionID, stripeSubscriptionStatus, stripeSubscriptionInterval, stripeMonthlyPriceID, stripeYearlyPriceID, tierID, tierCode, tierName sql.NullString
  1151. var messages, emails, calls int64
  1152. var messagesLimit, messagesExpiryDuration, emailsLimit, callsLimit, reservationsLimit, attachmentFileSizeLimit, attachmentTotalSizeLimit, attachmentExpiryDuration, attachmentBandwidthLimit, stripeSubscriptionPaidUntil, stripeSubscriptionCancelAt, deleted sql.NullInt64
  1153. if !rows.Next() {
  1154. return nil, ErrUserNotFound
  1155. }
  1156. if err := rows.Scan(&id, &username, &hash, &role, &prefs, &syncTopic, &provisioned, &messages, &emails, &calls, &stripeCustomerID, &stripeSubscriptionID, &stripeSubscriptionStatus, &stripeSubscriptionInterval, &stripeSubscriptionPaidUntil, &stripeSubscriptionCancelAt, &deleted, &tierID, &tierCode, &tierName, &messagesLimit, &messagesExpiryDuration, &emailsLimit, &callsLimit, &reservationsLimit, &attachmentFileSizeLimit, &attachmentTotalSizeLimit, &attachmentExpiryDuration, &attachmentBandwidthLimit, &stripeMonthlyPriceID, &stripeYearlyPriceID); err != nil {
  1157. return nil, err
  1158. } else if err := rows.Err(); err != nil {
  1159. return nil, err
  1160. }
  1161. user := &User{
  1162. ID: id,
  1163. Name: username,
  1164. Hash: hash,
  1165. Role: Role(role),
  1166. Prefs: &Prefs{},
  1167. SyncTopic: syncTopic,
  1168. Provisioned: provisioned,
  1169. Stats: &Stats{
  1170. Messages: messages,
  1171. Emails: emails,
  1172. Calls: calls,
  1173. },
  1174. Billing: &Billing{
  1175. StripeCustomerID: stripeCustomerID.String, // May be empty
  1176. StripeSubscriptionID: stripeSubscriptionID.String, // May be empty
  1177. StripeSubscriptionStatus: stripe.SubscriptionStatus(stripeSubscriptionStatus.String), // May be empty
  1178. StripeSubscriptionInterval: stripe.PriceRecurringInterval(stripeSubscriptionInterval.String), // May be empty
  1179. StripeSubscriptionPaidUntil: time.Unix(stripeSubscriptionPaidUntil.Int64, 0), // May be zero
  1180. StripeSubscriptionCancelAt: time.Unix(stripeSubscriptionCancelAt.Int64, 0), // May be zero
  1181. },
  1182. Deleted: deleted.Valid,
  1183. }
  1184. if err := json.Unmarshal([]byte(prefs), user.Prefs); err != nil {
  1185. return nil, err
  1186. }
  1187. if tierCode.Valid {
  1188. // See readTier() when this is changed!
  1189. user.Tier = &Tier{
  1190. ID: tierID.String,
  1191. Code: tierCode.String,
  1192. Name: tierName.String,
  1193. MessageLimit: messagesLimit.Int64,
  1194. MessageExpiryDuration: time.Duration(messagesExpiryDuration.Int64) * time.Second,
  1195. EmailLimit: emailsLimit.Int64,
  1196. CallLimit: callsLimit.Int64,
  1197. ReservationLimit: reservationsLimit.Int64,
  1198. AttachmentFileSizeLimit: attachmentFileSizeLimit.Int64,
  1199. AttachmentTotalSizeLimit: attachmentTotalSizeLimit.Int64,
  1200. AttachmentExpiryDuration: time.Duration(attachmentExpiryDuration.Int64) * time.Second,
  1201. AttachmentBandwidthLimit: attachmentBandwidthLimit.Int64,
  1202. StripeMonthlyPriceID: stripeMonthlyPriceID.String, // May be empty
  1203. StripeYearlyPriceID: stripeYearlyPriceID.String, // May be empty
  1204. }
  1205. }
  1206. return user, nil
  1207. }
  1208. // AllGrants returns all user-specific access control entries, mapped to their respective user IDs
  1209. func (a *Manager) AllGrants() (map[string][]Grant, error) {
  1210. rows, err := a.db.Query(selectUserAllAccessQuery)
  1211. if err != nil {
  1212. return nil, err
  1213. }
  1214. defer rows.Close()
  1215. grants := make(map[string][]Grant, 0)
  1216. for rows.Next() {
  1217. var userID, topic string
  1218. var read, write, provisioned bool
  1219. if err := rows.Scan(&userID, &topic, &read, &write, &provisioned); err != nil {
  1220. return nil, err
  1221. } else if err := rows.Err(); err != nil {
  1222. return nil, err
  1223. }
  1224. if _, ok := grants[userID]; !ok {
  1225. grants[userID] = make([]Grant, 0)
  1226. }
  1227. grants[userID] = append(grants[userID], Grant{
  1228. TopicPattern: fromSQLWildcard(topic),
  1229. Permission: NewPermission(read, write),
  1230. Provisioned: provisioned,
  1231. })
  1232. }
  1233. return grants, nil
  1234. }
  1235. // Grants returns all user-specific access control entries
  1236. func (a *Manager) Grants(username string) ([]Grant, error) {
  1237. rows, err := a.db.Query(selectUserAccessQuery, username)
  1238. if err != nil {
  1239. return nil, err
  1240. }
  1241. defer rows.Close()
  1242. grants := make([]Grant, 0)
  1243. for rows.Next() {
  1244. var topic string
  1245. var read, write, provisioned bool
  1246. if err := rows.Scan(&topic, &read, &write, &provisioned); err != nil {
  1247. return nil, err
  1248. } else if err := rows.Err(); err != nil {
  1249. return nil, err
  1250. }
  1251. grants = append(grants, Grant{
  1252. TopicPattern: fromSQLWildcard(topic),
  1253. Permission: NewPermission(read, write),
  1254. Provisioned: provisioned,
  1255. })
  1256. }
  1257. return grants, nil
  1258. }
  1259. // Reservations returns all user-owned topics, and the associated everyone-access
  1260. func (a *Manager) Reservations(username string) ([]Reservation, error) {
  1261. rows, err := a.db.Query(selectUserReservationsQuery, Everyone, username)
  1262. if err != nil {
  1263. return nil, err
  1264. }
  1265. defer rows.Close()
  1266. reservations := make([]Reservation, 0)
  1267. for rows.Next() {
  1268. var topic string
  1269. var ownerRead, ownerWrite bool
  1270. var everyoneRead, everyoneWrite sql.NullBool
  1271. if err := rows.Scan(&topic, &ownerRead, &ownerWrite, &everyoneRead, &everyoneWrite); err != nil {
  1272. return nil, err
  1273. } else if err := rows.Err(); err != nil {
  1274. return nil, err
  1275. }
  1276. reservations = append(reservations, Reservation{
  1277. Topic: unescapeUnderscore(topic),
  1278. Owner: NewPermission(ownerRead, ownerWrite),
  1279. Everyone: NewPermission(everyoneRead.Bool, everyoneWrite.Bool), // false if null
  1280. })
  1281. }
  1282. return reservations, nil
  1283. }
  1284. // HasReservation returns true if the given topic access is owned by the user
  1285. func (a *Manager) HasReservation(username, topic string) (bool, error) {
  1286. rows, err := a.db.Query(selectUserHasReservationQuery, username, escapeUnderscore(topic))
  1287. if err != nil {
  1288. return false, err
  1289. }
  1290. defer rows.Close()
  1291. if !rows.Next() {
  1292. return false, errNoRows
  1293. }
  1294. var count int64
  1295. if err := rows.Scan(&count); err != nil {
  1296. return false, err
  1297. }
  1298. return count > 0, nil
  1299. }
  1300. // ReservationsCount returns the number of reservations owned by this user
  1301. func (a *Manager) ReservationsCount(username string) (int64, error) {
  1302. rows, err := a.db.Query(selectUserReservationsCountQuery, username)
  1303. if err != nil {
  1304. return 0, err
  1305. }
  1306. defer rows.Close()
  1307. if !rows.Next() {
  1308. return 0, errNoRows
  1309. }
  1310. var count int64
  1311. if err := rows.Scan(&count); err != nil {
  1312. return 0, err
  1313. }
  1314. return count, nil
  1315. }
  1316. // ReservationOwner returns user ID of the user that owns this topic, or an
  1317. // empty string if it's not owned by anyone
  1318. func (a *Manager) ReservationOwner(topic string) (string, error) {
  1319. rows, err := a.db.Query(selectUserReservationsOwnerQuery, escapeUnderscore(topic))
  1320. if err != nil {
  1321. return "", err
  1322. }
  1323. defer rows.Close()
  1324. if !rows.Next() {
  1325. return "", nil
  1326. }
  1327. var ownerUserID string
  1328. if err := rows.Scan(&ownerUserID); err != nil {
  1329. return "", err
  1330. }
  1331. return ownerUserID, nil
  1332. }
  1333. // ChangePassword changes a user's password
  1334. func (a *Manager) ChangePassword(username, password string, hashed bool) error {
  1335. if err := a.CanChangeUser(username); err != nil {
  1336. return err
  1337. }
  1338. return execTx(a.db, func(tx *sql.Tx) error {
  1339. return a.changePasswordTx(tx, username, password, hashed)
  1340. })
  1341. }
  1342. // CanChangeUser checks if the user with the given username can be changed.
  1343. // This is used to prevent changes to provisioned users, which are defined in the config file.
  1344. func (a *Manager) CanChangeUser(username string) error {
  1345. user, err := a.User(username)
  1346. if err != nil {
  1347. return err
  1348. } else if user.Provisioned {
  1349. return ErrProvisionedUserChange
  1350. }
  1351. return nil
  1352. }
  1353. func (a *Manager) changePasswordTx(tx *sql.Tx, username, password string, hashed bool) error {
  1354. var hash string
  1355. var err error
  1356. if hashed {
  1357. hash = password
  1358. if err := ValidPasswordHash(hash); err != nil {
  1359. return err
  1360. }
  1361. } else {
  1362. hash, err = hashPassword(password, a.config.BcryptCost)
  1363. if err != nil {
  1364. return err
  1365. }
  1366. }
  1367. if _, err := tx.Exec(updateUserPassQuery, hash, username); err != nil {
  1368. return err
  1369. }
  1370. return nil
  1371. }
  1372. // ChangeRole changes a user's role. When a role is changed from RoleUser to RoleAdmin,
  1373. // all existing access control entries (Grant) are removed, since they are no longer needed.
  1374. func (a *Manager) ChangeRole(username string, role Role) error {
  1375. if err := a.CanChangeUser(username); err != nil {
  1376. return err
  1377. }
  1378. return execTx(a.db, func(tx *sql.Tx) error {
  1379. return a.changeRoleTx(tx, username, role)
  1380. })
  1381. }
  1382. func (a *Manager) changeRoleTx(tx *sql.Tx, username string, role Role) error {
  1383. if !AllowedUsername(username) || !AllowedRole(role) {
  1384. return ErrInvalidArgument
  1385. }
  1386. if _, err := tx.Exec(updateUserRoleQuery, string(role), username); err != nil {
  1387. return err
  1388. }
  1389. if role == RoleAdmin {
  1390. if _, err := tx.Exec(deleteUserAccessQuery, username, username); err != nil {
  1391. return err
  1392. }
  1393. }
  1394. return nil
  1395. }
  1396. // changeProvisionedTx changes the provisioned status of a user. This is used to mark users as
  1397. // provisioned. A provisioned user is a user defined in the config file.
  1398. func (a *Manager) changeProvisionedTx(tx *sql.Tx, username string, provisioned bool) error {
  1399. if _, err := tx.Exec(updateUserProvisionedQuery, provisioned, username); err != nil {
  1400. return err
  1401. }
  1402. return nil
  1403. }
  1404. // ChangeTier changes a user's tier using the tier code. This function does not delete reservations, messages,
  1405. // or attachments, even if the new tier has lower limits in this regard. That has to be done elsewhere.
  1406. func (a *Manager) ChangeTier(username, tier string) error {
  1407. if !AllowedUsername(username) {
  1408. return ErrInvalidArgument
  1409. }
  1410. t, err := a.Tier(tier)
  1411. if err != nil {
  1412. return err
  1413. } else if err := a.checkReservationsLimit(username, t.ReservationLimit); err != nil {
  1414. return err
  1415. }
  1416. if _, err := a.db.Exec(updateUserTierQuery, tier, username); err != nil {
  1417. return err
  1418. }
  1419. return nil
  1420. }
  1421. // ResetTier removes the tier from the given user
  1422. func (a *Manager) ResetTier(username string) error {
  1423. if !AllowedUsername(username) && username != Everyone && username != "" {
  1424. return ErrInvalidArgument
  1425. } else if err := a.checkReservationsLimit(username, 0); err != nil {
  1426. return err
  1427. }
  1428. _, err := a.db.Exec(deleteUserTierQuery, username)
  1429. return err
  1430. }
  1431. func (a *Manager) checkReservationsLimit(username string, reservationsLimit int64) error {
  1432. u, err := a.User(username)
  1433. if err != nil {
  1434. return err
  1435. }
  1436. if u.Tier != nil && reservationsLimit < u.Tier.ReservationLimit {
  1437. reservations, err := a.Reservations(username)
  1438. if err != nil {
  1439. return err
  1440. } else if int64(len(reservations)) > reservationsLimit {
  1441. return ErrTooManyReservations
  1442. }
  1443. }
  1444. return nil
  1445. }
  1446. // AllowReservation tests if a user may create an access control entry for the given topic.
  1447. // If there are any ACL entries that are not owned by the user, an error is returned.
  1448. func (a *Manager) AllowReservation(username string, topic string) error {
  1449. if (!AllowedUsername(username) && username != Everyone) || !AllowedTopic(topic) {
  1450. return ErrInvalidArgument
  1451. }
  1452. rows, err := a.db.Query(selectOtherAccessCountQuery, escapeUnderscore(topic), escapeUnderscore(topic), username)
  1453. if err != nil {
  1454. return err
  1455. }
  1456. defer rows.Close()
  1457. if !rows.Next() {
  1458. return errNoRows
  1459. }
  1460. var otherCount int
  1461. if err := rows.Scan(&otherCount); err != nil {
  1462. return err
  1463. }
  1464. if otherCount > 0 {
  1465. return errTopicOwnedByOthers
  1466. }
  1467. return nil
  1468. }
  1469. // AllowAccess adds or updates an entry in th access control list for a specific user. It controls
  1470. // read/write access to a topic. The parameter topicPattern may include wildcards (*). The ACL entry
  1471. // owner may either be a user (username), or the system (empty).
  1472. func (a *Manager) AllowAccess(username string, topicPattern string, permission Permission) error {
  1473. return execTx(a.db, func(tx *sql.Tx) error {
  1474. return a.allowAccessTx(tx, username, topicPattern, permission, false)
  1475. })
  1476. }
  1477. func (a *Manager) allowAccessTx(tx *sql.Tx, username string, topicPattern string, permission Permission, provisioned bool) error {
  1478. if !AllowedUsername(username) && username != Everyone {
  1479. return ErrInvalidArgument
  1480. } else if !AllowedTopicPattern(topicPattern) {
  1481. return ErrInvalidArgument
  1482. }
  1483. owner := ""
  1484. if _, err := tx.Exec(upsertUserAccessQuery, username, toSQLWildcard(topicPattern), permission.IsRead(), permission.IsWrite(), owner, owner, provisioned); err != nil {
  1485. return err
  1486. }
  1487. return nil
  1488. }
  1489. // ResetAccess removes an access control list entry for a specific username/topic, or (if topic is
  1490. // empty) for an entire user. The parameter topicPattern may include wildcards (*).
  1491. func (a *Manager) ResetAccess(username string, topicPattern string) error {
  1492. return execTx(a.db, func(tx *sql.Tx) error {
  1493. return a.resetAccessTx(tx, username, topicPattern)
  1494. })
  1495. }
  1496. func (a *Manager) resetAccessTx(tx *sql.Tx, username string, topicPattern string) error {
  1497. if !AllowedUsername(username) && username != Everyone && username != "" {
  1498. return ErrInvalidArgument
  1499. } else if !AllowedTopicPattern(topicPattern) && topicPattern != "" {
  1500. return ErrInvalidArgument
  1501. }
  1502. if username == "" && topicPattern == "" {
  1503. _, err := tx.Exec(deleteAllAccessQuery, username)
  1504. return err
  1505. } else if topicPattern == "" {
  1506. _, err := tx.Exec(deleteUserAccessQuery, username, username)
  1507. return err
  1508. }
  1509. _, err := tx.Exec(deleteTopicAccessQuery, username, username, toSQLWildcard(topicPattern))
  1510. return err
  1511. }
  1512. // AddReservation creates two access control entries for the given topic: one with full read/write access for the
  1513. // given user, and one for Everyone with the permission passed as everyone. The user also owns the entries, and
  1514. // can modify or delete them.
  1515. func (a *Manager) AddReservation(username string, topic string, everyone Permission) error {
  1516. if !AllowedUsername(username) || username == Everyone || !AllowedTopic(topic) {
  1517. return ErrInvalidArgument
  1518. }
  1519. tx, err := a.db.Begin()
  1520. if err != nil {
  1521. return err
  1522. }
  1523. defer tx.Rollback()
  1524. if _, err := tx.Exec(upsertUserAccessQuery, username, escapeUnderscore(topic), true, true, username, username, false); err != nil {
  1525. return err
  1526. }
  1527. if _, err := tx.Exec(upsertUserAccessQuery, Everyone, escapeUnderscore(topic), everyone.IsRead(), everyone.IsWrite(), username, username, false); err != nil {
  1528. return err
  1529. }
  1530. return tx.Commit()
  1531. }
  1532. // RemoveReservations deletes the access control entries associated with the given username/topic, as
  1533. // well as all entries with Everyone/topic. This is the counterpart for AddReservation.
  1534. func (a *Manager) RemoveReservations(username string, topics ...string) error {
  1535. if !AllowedUsername(username) || username == Everyone || len(topics) == 0 {
  1536. return ErrInvalidArgument
  1537. }
  1538. for _, topic := range topics {
  1539. if !AllowedTopic(topic) {
  1540. return ErrInvalidArgument
  1541. }
  1542. }
  1543. tx, err := a.db.Begin()
  1544. if err != nil {
  1545. return err
  1546. }
  1547. defer tx.Rollback()
  1548. for _, topic := range topics {
  1549. if _, err := tx.Exec(deleteTopicAccessQuery, username, username, escapeUnderscore(topic)); err != nil {
  1550. return err
  1551. }
  1552. if _, err := tx.Exec(deleteTopicAccessQuery, Everyone, Everyone, escapeUnderscore(topic)); err != nil {
  1553. return err
  1554. }
  1555. }
  1556. return tx.Commit()
  1557. }
  1558. // DefaultAccess returns the default read/write access if no access control entry matches
  1559. func (a *Manager) DefaultAccess() Permission {
  1560. return a.config.DefaultAccess
  1561. }
  1562. // AddTier creates a new tier in the database
  1563. func (a *Manager) AddTier(tier *Tier) error {
  1564. if tier.ID == "" {
  1565. tier.ID = util.RandomStringPrefix(tierIDPrefix, tierIDLength)
  1566. }
  1567. if _, err := a.db.Exec(insertTierQuery, tier.ID, tier.Code, tier.Name, tier.MessageLimit, int64(tier.MessageExpiryDuration.Seconds()), tier.EmailLimit, tier.CallLimit, tier.ReservationLimit, tier.AttachmentFileSizeLimit, tier.AttachmentTotalSizeLimit, int64(tier.AttachmentExpiryDuration.Seconds()), tier.AttachmentBandwidthLimit, nullString(tier.StripeMonthlyPriceID), nullString(tier.StripeYearlyPriceID)); err != nil {
  1568. return err
  1569. }
  1570. return nil
  1571. }
  1572. // UpdateTier updates a tier's properties in the database
  1573. func (a *Manager) UpdateTier(tier *Tier) error {
  1574. if _, err := a.db.Exec(updateTierQuery, tier.Name, tier.MessageLimit, int64(tier.MessageExpiryDuration.Seconds()), tier.EmailLimit, tier.CallLimit, tier.ReservationLimit, tier.AttachmentFileSizeLimit, tier.AttachmentTotalSizeLimit, int64(tier.AttachmentExpiryDuration.Seconds()), tier.AttachmentBandwidthLimit, nullString(tier.StripeMonthlyPriceID), nullString(tier.StripeYearlyPriceID), tier.Code); err != nil {
  1575. return err
  1576. }
  1577. return nil
  1578. }
  1579. // RemoveTier deletes the tier with the given code
  1580. func (a *Manager) RemoveTier(code string) error {
  1581. if !AllowedTier(code) {
  1582. return ErrInvalidArgument
  1583. }
  1584. // This fails if any user has this tier
  1585. if _, err := a.db.Exec(deleteTierQuery, code); err != nil {
  1586. return err
  1587. }
  1588. return nil
  1589. }
  1590. // ChangeBilling updates a user's billing fields, namely the Stripe customer ID, and subscription information
  1591. func (a *Manager) ChangeBilling(username string, billing *Billing) error {
  1592. if _, err := a.db.Exec(updateBillingQuery, nullString(billing.StripeCustomerID), nullString(billing.StripeSubscriptionID), nullString(string(billing.StripeSubscriptionStatus)), nullString(string(billing.StripeSubscriptionInterval)), nullInt64(billing.StripeSubscriptionPaidUntil.Unix()), nullInt64(billing.StripeSubscriptionCancelAt.Unix()), username); err != nil {
  1593. return err
  1594. }
  1595. return nil
  1596. }
  1597. // Tiers returns a list of all Tier structs
  1598. func (a *Manager) Tiers() ([]*Tier, error) {
  1599. rows, err := a.db.Query(selectTiersQuery)
  1600. if err != nil {
  1601. return nil, err
  1602. }
  1603. defer rows.Close()
  1604. tiers := make([]*Tier, 0)
  1605. for {
  1606. tier, err := a.readTier(rows)
  1607. if errors.Is(err, ErrTierNotFound) {
  1608. break
  1609. } else if err != nil {
  1610. return nil, err
  1611. }
  1612. tiers = append(tiers, tier)
  1613. }
  1614. return tiers, nil
  1615. }
  1616. // Tier returns a Tier based on the code, or ErrTierNotFound if it does not exist
  1617. func (a *Manager) Tier(code string) (*Tier, error) {
  1618. rows, err := a.db.Query(selectTierByCodeQuery, code)
  1619. if err != nil {
  1620. return nil, err
  1621. }
  1622. defer rows.Close()
  1623. return a.readTier(rows)
  1624. }
  1625. // TierByStripePrice returns a Tier based on the Stripe price ID, or ErrTierNotFound if it does not exist
  1626. func (a *Manager) TierByStripePrice(priceID string) (*Tier, error) {
  1627. rows, err := a.db.Query(selectTierByPriceIDQuery, priceID, priceID)
  1628. if err != nil {
  1629. return nil, err
  1630. }
  1631. defer rows.Close()
  1632. return a.readTier(rows)
  1633. }
  1634. func (a *Manager) readTier(rows *sql.Rows) (*Tier, error) {
  1635. var id, code, name string
  1636. var stripeMonthlyPriceID, stripeYearlyPriceID sql.NullString
  1637. var messagesLimit, messagesExpiryDuration, emailsLimit, callsLimit, reservationsLimit, attachmentFileSizeLimit, attachmentTotalSizeLimit, attachmentExpiryDuration, attachmentBandwidthLimit sql.NullInt64
  1638. if !rows.Next() {
  1639. return nil, ErrTierNotFound
  1640. }
  1641. if err := rows.Scan(&id, &code, &name, &messagesLimit, &messagesExpiryDuration, &emailsLimit, &callsLimit, &reservationsLimit, &attachmentFileSizeLimit, &attachmentTotalSizeLimit, &attachmentExpiryDuration, &attachmentBandwidthLimit, &stripeMonthlyPriceID, &stripeYearlyPriceID); err != nil {
  1642. return nil, err
  1643. } else if err := rows.Err(); err != nil {
  1644. return nil, err
  1645. }
  1646. // When changed, note readUser() as well
  1647. return &Tier{
  1648. ID: id,
  1649. Code: code,
  1650. Name: name,
  1651. MessageLimit: messagesLimit.Int64,
  1652. MessageExpiryDuration: time.Duration(messagesExpiryDuration.Int64) * time.Second,
  1653. EmailLimit: emailsLimit.Int64,
  1654. CallLimit: callsLimit.Int64,
  1655. ReservationLimit: reservationsLimit.Int64,
  1656. AttachmentFileSizeLimit: attachmentFileSizeLimit.Int64,
  1657. AttachmentTotalSizeLimit: attachmentTotalSizeLimit.Int64,
  1658. AttachmentExpiryDuration: time.Duration(attachmentExpiryDuration.Int64) * time.Second,
  1659. AttachmentBandwidthLimit: attachmentBandwidthLimit.Int64,
  1660. StripeMonthlyPriceID: stripeMonthlyPriceID.String, // May be empty
  1661. StripeYearlyPriceID: stripeYearlyPriceID.String, // May be empty
  1662. }, nil
  1663. }
  1664. // Close closes the underlying database
  1665. func (a *Manager) Close() error {
  1666. return a.db.Close()
  1667. }
  1668. // maybeProvisionUsersAccessAndTokens provisions users, access control entries, and tokens based on the config.
  1669. func (a *Manager) maybeProvisionUsersAccessAndTokens() error {
  1670. if !a.config.ProvisionEnabled {
  1671. return nil
  1672. }
  1673. existingUsers, err := a.Users()
  1674. if err != nil {
  1675. return err
  1676. }
  1677. provisionUsernames := util.Map(a.config.Users, func(u *User) string {
  1678. return u.Name
  1679. })
  1680. return execTx(a.db, func(tx *sql.Tx) error {
  1681. if err := a.maybeProvisionUsers(tx, provisionUsernames, existingUsers); err != nil {
  1682. return fmt.Errorf("failed to provision users: %v", err)
  1683. }
  1684. if err := a.maybeProvisionGrants(tx); err != nil {
  1685. return fmt.Errorf("failed to provision grants: %v", err)
  1686. }
  1687. if err := a.maybeProvisionTokens(tx, provisionUsernames); err != nil {
  1688. return fmt.Errorf("failed to provision tokens: %v", err)
  1689. }
  1690. return nil
  1691. })
  1692. }
  1693. // maybeProvisionUsers checks if the users in the config are provisioned, and adds or updates them.
  1694. // It also removes users that are provisioned, but not in the config anymore.
  1695. func (a *Manager) maybeProvisionUsers(tx *sql.Tx, provisionUsernames []string, existingUsers []*User) error {
  1696. // Remove users that are provisioned, but not in the config anymore
  1697. for _, user := range existingUsers {
  1698. if user.Name == Everyone {
  1699. continue
  1700. } else if user.Provisioned && !util.Contains(provisionUsernames, user.Name) {
  1701. if err := a.removeUserTx(tx, user.Name); err != nil {
  1702. return fmt.Errorf("failed to remove provisioned user %s: %v", user.Name, err)
  1703. }
  1704. }
  1705. }
  1706. // Add or update provisioned users
  1707. for _, user := range a.config.Users {
  1708. if user.Name == Everyone {
  1709. continue
  1710. }
  1711. existingUser, exists := util.Find(existingUsers, func(u *User) bool {
  1712. return u.Name == user.Name
  1713. })
  1714. if !exists {
  1715. if err := a.addUserTx(tx, user.Name, user.Hash, user.Role, true, true); err != nil && !errors.Is(err, ErrUserExists) {
  1716. return fmt.Errorf("failed to add provisioned user %s: %v", user.Name, err)
  1717. }
  1718. } else {
  1719. if !existingUser.Provisioned {
  1720. if err := a.changeProvisionedTx(tx, user.Name, true); err != nil {
  1721. return fmt.Errorf("failed to change provisioned status for user %s: %v", user.Name, err)
  1722. }
  1723. }
  1724. if existingUser.Hash != user.Hash {
  1725. if err := a.changePasswordTx(tx, user.Name, user.Hash, true); err != nil {
  1726. return fmt.Errorf("failed to change password for provisioned user %s: %v", user.Name, err)
  1727. }
  1728. }
  1729. if existingUser.Role != user.Role {
  1730. if err := a.changeRoleTx(tx, user.Name, user.Role); err != nil {
  1731. return fmt.Errorf("failed to change role for provisioned user %s: %v", user.Name, err)
  1732. }
  1733. }
  1734. }
  1735. }
  1736. return nil
  1737. }
  1738. // maybyProvisionGrants removes all provisioned grants, and (re-)adds the grants from the config.
  1739. //
  1740. // Unlike users and tokens, grants can be just re-added, because they do not carry any state (such as last
  1741. // access time) or do not have dependent resources (such as grants or tokens).
  1742. func (a *Manager) maybeProvisionGrants(tx *sql.Tx) error {
  1743. // Remove all provisioned grants
  1744. if _, err := tx.Exec(deleteUserAccessProvisionedQuery); err != nil {
  1745. return err
  1746. }
  1747. // (Re-)add provisioned grants
  1748. for username, grants := range a.config.Access {
  1749. user, exists := util.Find(a.config.Users, func(u *User) bool {
  1750. return u.Name == username
  1751. })
  1752. if !exists && username != Everyone {
  1753. return fmt.Errorf("user %s is not a provisioned user, refusing to add ACL entry", username)
  1754. } else if user != nil && user.Role == RoleAdmin {
  1755. return fmt.Errorf("adding access control entries is not allowed for admin roles for user %s", username)
  1756. }
  1757. for _, grant := range grants {
  1758. if err := a.resetAccessTx(tx, username, grant.TopicPattern); err != nil {
  1759. return fmt.Errorf("failed to reset access for user %s and topic %s: %v", username, grant.TopicPattern, err)
  1760. }
  1761. if err := a.allowAccessTx(tx, username, grant.TopicPattern, grant.Permission, true); err != nil {
  1762. return err
  1763. }
  1764. }
  1765. }
  1766. return nil
  1767. }
  1768. func (a *Manager) maybeProvisionTokens(tx *sql.Tx, provisionUsernames []string) error {
  1769. // Remove tokens that are provisioned, but not in the config anymore
  1770. existingTokens, err := a.allProvisionedTokens()
  1771. if err != nil {
  1772. return fmt.Errorf("failed to retrieve existing provisioned tokens: %v", err)
  1773. }
  1774. var provisionTokens []string
  1775. for _, userTokens := range a.config.Tokens {
  1776. for _, token := range userTokens {
  1777. provisionTokens = append(provisionTokens, token.Value)
  1778. }
  1779. }
  1780. for _, existingToken := range existingTokens {
  1781. if !slices.Contains(provisionTokens, existingToken.Value) {
  1782. if _, err := tx.Exec(deleteProvisionedTokenQuery, existingToken.Value); err != nil {
  1783. return fmt.Errorf("failed to remove provisioned token %s: %v", existingToken.Value, err)
  1784. }
  1785. }
  1786. }
  1787. // (Re-)add provisioned tokens
  1788. for username, tokens := range a.config.Tokens {
  1789. if !slices.Contains(provisionUsernames, username) && username != Everyone {
  1790. return fmt.Errorf("user %s is not a provisioned user, refusing to add tokens", username)
  1791. }
  1792. var userID string
  1793. row := tx.QueryRow(selectUserIDFromUsernameQuery, username)
  1794. if err := row.Scan(&userID); err != nil {
  1795. return fmt.Errorf("failed to find provisioned user %s for provisioned tokens", username)
  1796. }
  1797. for _, token := range tokens {
  1798. if _, err := a.createTokenTx(tx, userID, token.Value, token.Label, time.Unix(0, 0), netip.IPv4Unspecified(), true); err != nil {
  1799. return err
  1800. }
  1801. }
  1802. }
  1803. return nil
  1804. }
  1805. // toSQLWildcard converts a wildcard string to a SQL wildcard string. It only allows '*' as wildcards,
  1806. // and escapes '_', assuming '\' as escape character.
  1807. func toSQLWildcard(s string) string {
  1808. return escapeUnderscore(strings.ReplaceAll(s, "*", "%"))
  1809. }
  1810. // fromSQLWildcard converts a SQL wildcard string to a wildcard string. It converts '%' to '*',
  1811. // and removes the '\_' escape character.
  1812. func fromSQLWildcard(s string) string {
  1813. return strings.ReplaceAll(unescapeUnderscore(s), "%", "*")
  1814. }
  1815. func escapeUnderscore(s string) string {
  1816. return strings.ReplaceAll(s, "_", "\\_")
  1817. }
  1818. func unescapeUnderscore(s string) string {
  1819. return strings.ReplaceAll(s, "\\_", "_")
  1820. }
  1821. func runStartupQueries(db *sql.DB, startupQueries string) error {
  1822. if _, err := db.Exec(startupQueries); err != nil {
  1823. return err
  1824. }
  1825. if _, err := db.Exec(builtinStartupQueries); err != nil {
  1826. return err
  1827. }
  1828. return nil
  1829. }
  1830. func setupDB(db *sql.DB) error {
  1831. // If 'schemaVersion' table does not exist, this must be a new database
  1832. rowsSV, err := db.Query(selectSchemaVersionQuery)
  1833. if err != nil {
  1834. return setupNewDB(db)
  1835. }
  1836. defer rowsSV.Close()
  1837. // If 'schemaVersion' table exists, read version and potentially upgrade
  1838. schemaVersion := 0
  1839. if !rowsSV.Next() {
  1840. return errors.New("cannot determine schema version: database file may be corrupt")
  1841. }
  1842. if err := rowsSV.Scan(&schemaVersion); err != nil {
  1843. return err
  1844. }
  1845. rowsSV.Close()
  1846. // Do migrations
  1847. if schemaVersion == currentSchemaVersion {
  1848. return nil
  1849. } else if schemaVersion > currentSchemaVersion {
  1850. return fmt.Errorf("unexpected schema version: version %d is higher than current version %d", schemaVersion, currentSchemaVersion)
  1851. }
  1852. for i := schemaVersion; i < currentSchemaVersion; i++ {
  1853. fn, ok := migrations[i]
  1854. if !ok {
  1855. return fmt.Errorf("cannot find migration step from schema version %d to %d", i, i+1)
  1856. } else if err := fn(db); err != nil {
  1857. return err
  1858. }
  1859. }
  1860. return nil
  1861. }
  1862. func setupNewDB(db *sql.DB) error {
  1863. if _, err := db.Exec(createTablesQueries); err != nil {
  1864. return err
  1865. }
  1866. if _, err := db.Exec(insertSchemaVersion, currentSchemaVersion); err != nil {
  1867. return err
  1868. }
  1869. return nil
  1870. }
  1871. func migrateFrom1(db *sql.DB) error {
  1872. log.Tag(tag).Info("Migrating user database schema: from 1 to 2")
  1873. tx, err := db.Begin()
  1874. if err != nil {
  1875. return err
  1876. }
  1877. defer tx.Rollback()
  1878. // Rename user -> user_old, and create new tables
  1879. if _, err := tx.Exec(migrate1To2CreateTablesQueries); err != nil {
  1880. return err
  1881. }
  1882. // Insert users from user_old into new user table, with ID and sync_topic
  1883. rows, err := tx.Query(migrate1To2SelectAllOldUsernamesNoTx)
  1884. if err != nil {
  1885. return err
  1886. }
  1887. defer rows.Close()
  1888. usernames := make([]string, 0)
  1889. for rows.Next() {
  1890. var username string
  1891. if err := rows.Scan(&username); err != nil {
  1892. return err
  1893. }
  1894. usernames = append(usernames, username)
  1895. }
  1896. if err := rows.Close(); err != nil {
  1897. return err
  1898. }
  1899. for _, username := range usernames {
  1900. userID := util.RandomStringPrefix(userIDPrefix, userIDLength)
  1901. syncTopic := util.RandomStringPrefix(syncTopicPrefix, syncTopicLength)
  1902. if _, err := tx.Exec(migrate1To2InsertUserNoTx, userID, syncTopic, username); err != nil {
  1903. return err
  1904. }
  1905. }
  1906. // Migrate old "access" table to "user_access" and drop "access" and "user_old"
  1907. if _, err := tx.Exec(migrate1To2InsertFromOldTablesAndDropNoTx); err != nil {
  1908. return err
  1909. }
  1910. if _, err := tx.Exec(updateSchemaVersion, 2); err != nil {
  1911. return err
  1912. }
  1913. if err := tx.Commit(); err != nil {
  1914. return err
  1915. }
  1916. return nil
  1917. }
  1918. func migrateFrom2(db *sql.DB) error {
  1919. log.Tag(tag).Info("Migrating user database schema: from 2 to 3")
  1920. tx, err := db.Begin()
  1921. if err != nil {
  1922. return err
  1923. }
  1924. defer tx.Rollback()
  1925. if _, err := tx.Exec(migrate2To3UpdateQueries); err != nil {
  1926. return err
  1927. }
  1928. if _, err := tx.Exec(updateSchemaVersion, 3); err != nil {
  1929. return err
  1930. }
  1931. return tx.Commit()
  1932. }
  1933. func migrateFrom3(db *sql.DB) error {
  1934. log.Tag(tag).Info("Migrating user database schema: from 3 to 4")
  1935. tx, err := db.Begin()
  1936. if err != nil {
  1937. return err
  1938. }
  1939. defer tx.Rollback()
  1940. if _, err := tx.Exec(migrate3To4UpdateQueries); err != nil {
  1941. return err
  1942. }
  1943. if _, err := tx.Exec(updateSchemaVersion, 4); err != nil {
  1944. return err
  1945. }
  1946. return tx.Commit()
  1947. }
  1948. func migrateFrom4(db *sql.DB) error {
  1949. log.Tag(tag).Info("Migrating user database schema: from 4 to 5")
  1950. tx, err := db.Begin()
  1951. if err != nil {
  1952. return err
  1953. }
  1954. defer tx.Rollback()
  1955. if _, err := tx.Exec(migrate4To5UpdateQueries); err != nil {
  1956. return err
  1957. }
  1958. if _, err := tx.Exec(updateSchemaVersion, 5); err != nil {
  1959. return err
  1960. }
  1961. return tx.Commit()
  1962. }
  1963. func migrateFrom5(db *sql.DB) error {
  1964. log.Tag(tag).Info("Migrating user database schema: from 5 to 6")
  1965. tx, err := db.Begin()
  1966. if err != nil {
  1967. return err
  1968. }
  1969. defer tx.Rollback()
  1970. if _, err := tx.Exec(migrate5To6UpdateQueries); err != nil {
  1971. return err
  1972. }
  1973. if _, err := tx.Exec(updateSchemaVersion, 6); err != nil {
  1974. return err
  1975. }
  1976. return tx.Commit()
  1977. }
  1978. func nullString(s string) sql.NullString {
  1979. if s == "" {
  1980. return sql.NullString{}
  1981. }
  1982. return sql.NullString{String: s, Valid: true}
  1983. }
  1984. func nullInt64(v int64) sql.NullInt64 {
  1985. if v == 0 {
  1986. return sql.NullInt64{}
  1987. }
  1988. return sql.NullInt64{Int64: v, Valid: true}
  1989. }
  1990. // execTx executes a function in a transaction. If the function returns an error, the transaction is rolled back.
  1991. func execTx(db *sql.DB, f func(tx *sql.Tx) error) error {
  1992. tx, err := db.Begin()
  1993. if err != nil {
  1994. return err
  1995. }
  1996. defer tx.Rollback()
  1997. if err := f(tx); err != nil {
  1998. return err
  1999. }
  2000. return tx.Commit()
  2001. }
  2002. // queryTx executes a function in a transaction and returns the result. If the function
  2003. // returns an error, the transaction is rolled back.
  2004. func queryTx[T any](db *sql.DB, f func(tx *sql.Tx) (T, error)) (T, error) {
  2005. tx, err := db.Begin()
  2006. if err != nil {
  2007. var zero T
  2008. return zero, err
  2009. }
  2010. defer tx.Rollback()
  2011. t, err := f(tx)
  2012. if err != nil {
  2013. return t, err
  2014. }
  2015. if err := tx.Commit(); err != nil {
  2016. return t, err
  2017. }
  2018. return t, nil
  2019. }