manager.go 72 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
71778177917801781178217831784178517861787178817891790179117921793179417951796179717981799180018011802180318041805180618071808180918101811181218131814181518161817181818191820182118221823182418251826182718281829183018311832183318341835183618371838183918401841184218431844184518461847184818491850185118521853185418551856185718581859186018611862186318641865186618671868186918701871187218731874187518761877187818791880188118821883188418851886188718881889189018911892189318941895189618971898189919001901190219031904190519061907190819091910191119121913191419151916191719181919192019211922192319241925192619271928192919301931193219331934193519361937193819391940194119421943194419451946194719481949195019511952195319541955195619571958195919601961196219631964196519661967196819691970197119721973197419751976197719781979198019811982198319841985198619871988198919901991199219931994199519961997199819992000200120022003200420052006200720082009201020112012201320142015201620172018201920202021202220232024202520262027202820292030203120322033203420352036203720382039204020412042204320442045204620472048204920502051205220532054205520562057205820592060206120622063206420652066206720682069207020712072207320742075207620772078207920802081208220832084208520862087208820892090209120922093209420952096209720982099210021012102210321042105210621072108210921102111
  1. // Package user deals with authentication and authorization against topics
  2. package user
  3. import (
  4. "database/sql"
  5. "encoding/json"
  6. "errors"
  7. "fmt"
  8. "github.com/mattn/go-sqlite3"
  9. "github.com/stripe/stripe-go/v74"
  10. "golang.org/x/crypto/bcrypt"
  11. "heckel.io/ntfy/v2/log"
  12. "heckel.io/ntfy/v2/util"
  13. "net/netip"
  14. "path/filepath"
  15. "slices"
  16. "strings"
  17. "sync"
  18. "time"
  19. )
  20. const (
  21. tierIDPrefix = "ti_"
  22. tierIDLength = 8
  23. syncTopicPrefix = "st_"
  24. syncTopicLength = 16
  25. userIDPrefix = "u_"
  26. userIDLength = 12
  27. userAuthIntentionalSlowDownHash = "$2a$10$YFCQvqQDwIIwnJM1xkAYOeih0dg17UVGanaTStnrSzC8NCWxcLDwy" // Cost should match DefaultUserPasswordBcryptCost
  28. userHardDeleteAfterDuration = 7 * 24 * time.Hour
  29. tokenPrefix = "tk_"
  30. tokenLength = 32
  31. tokenMaxCount = 60 // Only keep this many tokens in the table per user
  32. tag = "user_manager"
  33. )
  34. // Default constants that may be overridden by configs
  35. const (
  36. DefaultUserStatsQueueWriterInterval = 33 * time.Second
  37. DefaultUserPasswordBcryptCost = 10
  38. )
  39. var (
  40. errNoTokenProvided = errors.New("no token provided")
  41. errTopicOwnedByOthers = errors.New("topic owned by others")
  42. errNoRows = errors.New("no rows found")
  43. )
// Manager-related queries.
//
// All statements are SQLite SQL with positional "?" placeholders; arguments
// must be passed in the order in which the placeholders appear.
const (
	// createTablesQueries creates the complete current schema (tables tier,
	// user, user_access, user_token, user_phone and schemaVersion) inside one
	// BEGIN/COMMIT transaction. It then inserts the built-in "everyone" user
	// (username '*', role 'anonymous', empty password) with ID everyoneID,
	// unless a row with that ID already exists.
	//
	// NOTE(review): the CREATE UNIQUE INDEX statements have no IF NOT EXISTS.
	// This script therefore presumably only runs against an empty database;
	// confirm in setupDB.
	//
	// NOTE(review): user_access.owner_user_id is declared INT, while user.id is
	// TEXT. Under SQLite type affinity, non-numeric IDs (user IDs use the "u_"
	// prefix) are still stored as TEXT, so comparisons work. The declared type
	// is misleading, though.
	createTablesQueries = `
BEGIN;
CREATE TABLE IF NOT EXISTS tier (
id TEXT PRIMARY KEY,
code TEXT NOT NULL,
name TEXT NOT NULL,
messages_limit INT NOT NULL,
messages_expiry_duration INT NOT NULL,
emails_limit INT NOT NULL,
calls_limit INT NOT NULL,
reservations_limit INT NOT NULL,
attachment_file_size_limit INT NOT NULL,
attachment_total_size_limit INT NOT NULL,
attachment_expiry_duration INT NOT NULL,
attachment_bandwidth_limit INT NOT NULL,
stripe_monthly_price_id TEXT,
stripe_yearly_price_id TEXT
);
CREATE UNIQUE INDEX idx_tier_code ON tier (code);
CREATE UNIQUE INDEX idx_tier_stripe_monthly_price_id ON tier (stripe_monthly_price_id);
CREATE UNIQUE INDEX idx_tier_stripe_yearly_price_id ON tier (stripe_yearly_price_id);
CREATE TABLE IF NOT EXISTS user (
id TEXT PRIMARY KEY,
tier_id TEXT,
user TEXT NOT NULL,
pass TEXT NOT NULL,
role TEXT CHECK (role IN ('anonymous', 'admin', 'user')) NOT NULL,
prefs JSON NOT NULL DEFAULT '{}',
sync_topic TEXT NOT NULL,
provisioned INT NOT NULL,
stats_messages INT NOT NULL DEFAULT (0),
stats_emails INT NOT NULL DEFAULT (0),
stats_calls INT NOT NULL DEFAULT (0),
stripe_customer_id TEXT,
stripe_subscription_id TEXT,
stripe_subscription_status TEXT,
stripe_subscription_interval TEXT,
stripe_subscription_paid_until INT,
stripe_subscription_cancel_at INT,
created INT NOT NULL,
deleted INT,
FOREIGN KEY (tier_id) REFERENCES tier (id)
);
CREATE UNIQUE INDEX idx_user ON user (user);
CREATE UNIQUE INDEX idx_user_stripe_customer_id ON user (stripe_customer_id);
CREATE UNIQUE INDEX idx_user_stripe_subscription_id ON user (stripe_subscription_id);
CREATE TABLE IF NOT EXISTS user_access (
user_id TEXT NOT NULL,
topic TEXT NOT NULL,
read INT NOT NULL,
write INT NOT NULL,
owner_user_id INT,
provisioned INT NOT NULL,
PRIMARY KEY (user_id, topic),
FOREIGN KEY (user_id) REFERENCES user (id) ON DELETE CASCADE,
FOREIGN KEY (owner_user_id) REFERENCES user (id) ON DELETE CASCADE
);
CREATE TABLE IF NOT EXISTS user_token (
user_id TEXT NOT NULL,
token TEXT NOT NULL,
label TEXT NOT NULL,
last_access INT NOT NULL,
last_origin TEXT NOT NULL,
expires INT NOT NULL,
provisioned INT NOT NULL,
PRIMARY KEY (user_id, token),
FOREIGN KEY (user_id) REFERENCES user (id) ON DELETE CASCADE
);
CREATE UNIQUE INDEX idx_user_token ON user_token (token);
CREATE TABLE IF NOT EXISTS user_phone (
user_id TEXT NOT NULL,
phone_number TEXT NOT NULL,
PRIMARY KEY (user_id, phone_number),
FOREIGN KEY (user_id) REFERENCES user (id) ON DELETE CASCADE
);
CREATE TABLE IF NOT EXISTS schemaVersion (
id INT PRIMARY KEY,
version INT NOT NULL
);
INSERT INTO user (id, user, pass, role, sync_topic, provisioned, created)
VALUES ('` + everyoneID + `', '*', '', 'anonymous', '', false, UNIXEPOCH())
ON CONFLICT (id) DO NOTHING;
COMMIT;
`
	// builtinStartupQueries turns on SQLite foreign-key enforcement. The
	// ON DELETE CASCADE clauses in the schema depend on it.
	//
	// NOTE(review): PRAGMA foreign_keys applies per connection. If this runs
	// once through *sql.DB, only the pooled connection that executed it is
	// affected. Confirm that the pool is limited to a single connection, or
	// that the pragma is also set via the DSN (go-sqlite3 "_foreign_keys").
	builtinStartupQueries = `
PRAGMA foreign_keys = ON;
`
	// The selectUserBy*Query statements all return the same columns in the
	// same order: the user's own columns, followed by the columns of its tier.
	// The tier is LEFT JOINed, so its columns are NULL for users without a
	// tier. Keep the four column lists identical to whatever scans these rows.
	//
	// selectUserByIDQuery looks a user up by user ID.
	selectUserByIDQuery = `
SELECT u.id, u.user, u.pass, u.role, u.prefs, u.sync_topic, u.provisioned, u.stats_messages, u.stats_emails, u.stats_calls, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_interval, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, deleted, t.id, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.calls_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.attachment_bandwidth_limit, t.stripe_monthly_price_id, t.stripe_yearly_price_id
FROM user u
LEFT JOIN tier t on t.id = u.tier_id
WHERE u.id = ?
`
	// selectUserByNameQuery looks a user up by username.
	selectUserByNameQuery = `
SELECT u.id, u.user, u.pass, u.role, u.prefs, u.sync_topic, u.provisioned, u.stats_messages, u.stats_emails, u.stats_calls, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_interval, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, deleted, t.id, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.calls_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.attachment_bandwidth_limit, t.stripe_monthly_price_id, t.stripe_yearly_price_id
FROM user u
LEFT JOIN tier t on t.id = u.tier_id
WHERE user = ?
`
	// selectUserByTokenQuery looks up the owner of a token.
	// Args: token, then a timestamp (presumably the current Unix time).
	// A token matches only if it never expires (expires = 0) or expires at or
	// after that timestamp.
	selectUserByTokenQuery = `
SELECT u.id, u.user, u.pass, u.role, u.prefs, u.sync_topic, u.provisioned, u.stats_messages, u.stats_emails, u.stats_calls, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_interval, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, deleted, t.id, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.calls_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.attachment_bandwidth_limit, t.stripe_monthly_price_id, t.stripe_yearly_price_id
FROM user u
JOIN user_token tk on u.id = tk.user_id
LEFT JOIN tier t on t.id = u.tier_id
WHERE tk.token = ? AND (tk.expires = 0 OR tk.expires >= ?)
`
	// selectUserByStripeCustomerIDQuery looks a user up by Stripe customer ID.
	selectUserByStripeCustomerIDQuery = `
SELECT u.id, u.user, u.pass, u.role, u.prefs, u.sync_topic, u.provisioned, u.stats_messages, u.stats_emails, u.stats_calls, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_interval, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, deleted, t.id, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.calls_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.attachment_bandwidth_limit, t.stripe_monthly_price_id, t.stripe_yearly_price_id
FROM user u
LEFT JOIN tier t on t.id = u.tier_id
WHERE u.stripe_customer_id = ?
`
	// selectTopicPermsQuery returns the read/write flags of every ACL entry
	// whose topic pattern matches a topic, for either of two usernames.
	// Args: username, second username (presumably the everyone user '*'),
	// topic. Patterns are matched with LIKE, using '\' as the escape character.
	//
	// Rows are ordered by username descending, then longest pattern first,
	// then write grants first.
	// NOTE(review): presumably the caller only uses the first row. '*' (0x2A)
	// sorts after usernames that start with a higher byte, which is what lets
	// user-specific entries win; confirm against the allowed username charset.
	selectTopicPermsQuery = `
SELECT read, write
FROM user_access a
JOIN user u ON u.id = a.user_id
WHERE (u.user = ? OR u.user = ?) AND ? LIKE a.topic ESCAPE '\'
ORDER BY u.user DESC, LENGTH(a.topic) DESC, a.write DESC
`
	// insertUserQuery creates a user.
	// Args: id, username, pass (presumably the bcrypt hash), role, sync topic,
	// provisioned, created.
	insertUserQuery = `
INSERT INTO user (id, user, pass, role, sync_topic, provisioned, created)
VALUES (?, ?, ?, ?, ?, ?, ?)
`
	// selectUsernamesQuery lists all usernames in this order: admins, then
	// regular users, then anonymous users. Each group is sorted by username.
	selectUsernamesQuery = `
SELECT user
FROM user
ORDER BY
CASE role
WHEN 'admin' THEN 1
WHEN 'anonymous' THEN 3
ELSE 2
END, user
`
	// Single-statement user queries. Queries with "WHERE user = ?" are keyed
	// by username, and those with "WHERE id = ?" by user ID; in both cases the
	// key is the last argument. selectUserCountQuery also counts the everyone
	// user. updateUserStatsResetAllQuery zeroes the counters of every user.
	// deleteUsersMarkedQuery hard-deletes users whose deleted timestamp is
	// before its argument; users with a NULL deleted column never match.
	selectUserCountQuery          = `SELECT COUNT(*) FROM user`
	selectUserIDFromUsernameQuery = `SELECT id FROM user WHERE user = ?`
	updateUserPassQuery           = `UPDATE user SET pass = ? WHERE user = ?`
	updateUserRoleQuery           = `UPDATE user SET role = ? WHERE user = ?`
	updateUserProvisionedQuery    = `UPDATE user SET provisioned = ? WHERE user = ?`
	updateUserPrefsQuery          = `UPDATE user SET prefs = ? WHERE id = ?`
	updateUserStatsQuery          = `UPDATE user SET stats_messages = ?, stats_emails = ?, stats_calls = ? WHERE id = ?`
	updateUserStatsResetAllQuery  = `UPDATE user SET stats_messages = 0, stats_emails = 0, stats_calls = 0`
	updateUserDeletedQuery        = `UPDATE user SET deleted = ? WHERE id = ?`
	deleteUsersMarkedQuery        = `DELETE FROM user WHERE deleted < ?`
	deleteUserQuery               = `DELETE FROM user WHERE user = ?`
	// upsertUserAccessQuery inserts an ACL entry, or updates it if one exists.
	// Args: username, topic, read, write, owner username, owner username
	// (again), provisioned. An empty owner username stores a NULL owner.
	// If the username does not exist, user_id resolves to NULL, and the
	// NOT NULL constraint makes the statement fail.
	upsertUserAccessQuery = `
INSERT INTO user_access (user_id, topic, read, write, owner_user_id, provisioned)
VALUES ((SELECT id FROM user WHERE user = ?), ?, ?, ?, (SELECT IIF(?='',NULL,(SELECT id FROM user WHERE user=?))), ?)
ON CONFLICT (user_id, topic)
DO UPDATE SET read=excluded.read, write=excluded.write, owner_user_id=excluded.owner_user_id, provisioned=excluded.provisioned
`
	// selectUserAllAccessQuery lists the ACL entries of all users, ordered as
	// follows: longest pattern first, then write grants, then read grants,
	// then by topic.
	selectUserAllAccessQuery = `
SELECT user_id, topic, read, write, provisioned
FROM user_access
ORDER BY LENGTH(topic) DESC, write DESC, read DESC, topic
`
	// selectUserAccessQuery lists one username's ACL entries, using the same
	// ordering as selectUserAllAccessQuery.
	selectUserAccessQuery = `
SELECT topic, read, write, provisioned
FROM user_access
WHERE user_id = (SELECT id FROM user WHERE user = ?)
ORDER BY LENGTH(topic) DESC, write DESC, read DESC, topic
`
	// A "reservation" is an ACL entry owned by the user it belongs to
	// (user_id = owner_user_id).
	//
	// selectUserReservationsQuery lists a user's reservations, together with
	// a second user's permissions on the same topic (NULL if that user has no
	// entry). Args: second username (presumably the everyone user '*'),
	// owner username.
	selectUserReservationsQuery = `
SELECT a_user.topic, a_user.read, a_user.write, a_everyone.read AS everyone_read, a_everyone.write AS everyone_write
FROM user_access a_user
LEFT JOIN user_access a_everyone ON a_user.topic = a_everyone.topic AND a_everyone.user_id = (SELECT id FROM user WHERE user = ?)
WHERE a_user.user_id = a_user.owner_user_id
AND a_user.owner_user_id = (SELECT id FROM user WHERE user = ?)
ORDER BY a_user.topic
`
	// selectUserReservationsCountQuery counts the reservations of a username.
	selectUserReservationsCountQuery = `
SELECT COUNT(*)
FROM user_access
WHERE user_id = owner_user_id
AND owner_user_id = (SELECT id FROM user WHERE user = ?)
`
	// selectUserReservationsOwnerQuery returns the owner ID of the reservation
	// for an exact topic. No pattern matching is done.
	selectUserReservationsOwnerQuery = `
SELECT owner_user_id
FROM user_access
WHERE topic = ?
AND user_id = owner_user_id
`
	// selectUserHasReservationQuery yields 1 if the username has reserved the
	// exact topic, and 0 otherwise (the (user_id, topic) primary key allows at
	// most one row). Args: username, topic.
	selectUserHasReservationQuery = `
SELECT COUNT(*)
FROM user_access
WHERE user_id = owner_user_id
AND owner_user_id = (SELECT id FROM user WHERE user = ?)
AND topic = ?
`
	// selectOtherAccessCountQuery counts ACL entries that match a topic
	// (exactly, or as a LIKE pattern) and are not owned by the given username.
	// Entries without an owner are counted. Args: topic, topic, username.
	selectOtherAccessCountQuery = `
SELECT COUNT(*)
FROM user_access
WHERE (topic = ? OR ? LIKE topic ESCAPE '\')
AND (owner_user_id IS NULL OR owner_user_id != (SELECT id FROM user WHERE user = ?))
`
	// deleteAllAccessQuery removes every ACL entry.
	deleteAllAccessQuery = `DELETE FROM user_access`
	// deleteUserAccessQuery removes a username's own entries and every entry
	// it owns. Args: username, username.
	deleteUserAccessQuery = `
DELETE FROM user_access
WHERE user_id = (SELECT id FROM user WHERE user = ?)
OR owner_user_id = (SELECT id FROM user WHERE user = ?)
`
	// deleteUserAccessProvisionedQuery removes all provisioned ACL entries.
	deleteUserAccessProvisionedQuery = `DELETE FROM user_access WHERE provisioned = 1`
	// deleteTopicAccessQuery removes a user's own or user-owned entries for
	// one exact topic. Args: username, username, topic.
	deleteTopicAccessQuery = `
DELETE FROM user_access
WHERE (user_id = (SELECT id FROM user WHERE user = ?) OR owner_user_id = (SELECT id FROM user WHERE user = ?))
AND topic = ?
`
	// Token queries. An expires value of 0 means "never expires": such tokens
	// are always accepted by selectUserByTokenQuery and never removed by
	// deleteExpiredTokensQuery. updateTokenLastAccessQuery is keyed by token
	// alone (args: last access, origin, token).
	// NOTE(review): deleteProvisionedTokenQuery deletes by token value and
	// does not check the provisioned column; callers must ensure the token is
	// provisioned.
	selectTokenCountQuery           = `SELECT COUNT(*) FROM user_token WHERE user_id = ?`
	selectTokensQuery               = `SELECT token, label, last_access, last_origin, expires, provisioned FROM user_token WHERE user_id = ?`
	selectTokenQuery                = `SELECT token, label, last_access, last_origin, expires, provisioned FROM user_token WHERE user_id = ? AND token = ?`
	selectAllProvisionedTokensQuery = `SELECT token, label, last_access, last_origin, expires, provisioned FROM user_token WHERE provisioned = 1`
	// upsertTokenQuery inserts a token, or updates it if it exists. On
	// conflict, only label, expires and provisioned are updated; last_access
	// and last_origin keep their stored values.
	upsertTokenQuery = `
INSERT INTO user_token (user_id, token, label, last_access, last_origin, expires, provisioned)
VALUES (?, ?, ?, ?, ?, ?, ?)
ON CONFLICT (user_id, token)
DO UPDATE SET label = excluded.label, expires = excluded.expires, provisioned = excluded.provisioned;
`
	updateTokenExpiryQuery      = `UPDATE user_token SET expires = ? WHERE user_id = ? AND token = ?`
	updateTokenLabelQuery       = `UPDATE user_token SET label = ? WHERE user_id = ? AND token = ?`
	updateTokenLastAccessQuery  = `UPDATE user_token SET last_access = ?, last_origin = ? WHERE token = ?`
	deleteTokenQuery            = `DELETE FROM user_token WHERE user_id = ? AND token = ?`
	deleteProvisionedTokenQuery = `DELETE FROM user_token WHERE token = ?`
	deleteAllTokenQuery         = `DELETE FROM user_token WHERE user_id = ?`
	deleteExpiredTokensQuery    = `DELETE FROM user_token WHERE expires > 0 AND expires < ?`
	// deleteExcessTokensQuery keeps only a user's tokens with the highest
	// expires values. Args: user ID, user ID, number of tokens to keep
	// (presumably tokenMaxCount).
	// NOTE(review): non-expiring tokens have expires = 0, so they sort last
	// under DESC and are the first to be pruned. Provisioned tokens are not
	// exempt either. Confirm this is intended.
	deleteExcessTokensQuery = `
DELETE FROM user_token
WHERE user_id = ?
AND (user_id, token) NOT IN (
SELECT user_id, token
FROM user_token
WHERE user_id = ?
ORDER BY expires DESC
LIMIT ?
)
`
	// Phone number queries, keyed by user ID.
	selectPhoneNumbersQuery = `SELECT phone_number FROM user_phone WHERE user_id = ?`
	insertPhoneNumberQuery  = `INSERT INTO user_phone (user_id, phone_number) VALUES (?, ?)`
	deletePhoneNumberQuery  = `DELETE FROM user_phone WHERE user_id = ? AND phone_number = ?`
	// insertTierQuery creates a tier; the arguments follow the column list.
	insertTierQuery = `
INSERT INTO tier (id, code, name, messages_limit, messages_expiry_duration, emails_limit, calls_limit, reservations_limit, attachment_file_size_limit, attachment_total_size_limit, attachment_expiry_duration, attachment_bandwidth_limit, stripe_monthly_price_id, stripe_yearly_price_id)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
`
	// updateTierQuery updates the tier identified by its code (the last
	// argument). Its id and code cannot be changed this way.
	updateTierQuery = `
UPDATE tier
SET name = ?, messages_limit = ?, messages_expiry_duration = ?, emails_limit = ?, calls_limit = ?, reservations_limit = ?, attachment_file_size_limit = ?, attachment_total_size_limit = ?, attachment_expiry_duration = ?, attachment_bandwidth_limit = ?, stripe_monthly_price_id = ?, stripe_yearly_price_id = ?
WHERE code = ?
`
	// selectTiersQuery lists all tiers. It has no ORDER BY, so the row order
	// is unspecified.
	selectTiersQuery = `
SELECT id, code, name, messages_limit, messages_expiry_duration, emails_limit, calls_limit, reservations_limit, attachment_file_size_limit, attachment_total_size_limit, attachment_expiry_duration, attachment_bandwidth_limit, stripe_monthly_price_id, stripe_yearly_price_id
FROM tier
`
	selectTierByCodeQuery = `
SELECT id, code, name, messages_limit, messages_expiry_duration, emails_limit, calls_limit, reservations_limit, attachment_file_size_limit, attachment_total_size_limit, attachment_expiry_duration, attachment_bandwidth_limit, stripe_monthly_price_id, stripe_yearly_price_id
FROM tier
WHERE code = ?
`
	// selectTierByPriceIDQuery finds the tier whose monthly or yearly Stripe
	// price matches. Args: monthly match value, yearly match value
	// (presumably the same price ID twice).
	selectTierByPriceIDQuery = `
SELECT id, code, name, messages_limit, messages_expiry_duration, emails_limit, calls_limit, reservations_limit, attachment_file_size_limit, attachment_total_size_limit, attachment_expiry_duration, attachment_bandwidth_limit, stripe_monthly_price_id, stripe_yearly_price_id
FROM tier
WHERE (stripe_monthly_price_id = ? OR stripe_yearly_price_id = ?)
`
	// updateUserTierQuery assigns a tier by code. Args: tier code, username.
	// An unknown tier code sets tier_id to NULL, because the scalar subquery
	// yields no row.
	updateUserTierQuery = `UPDATE user SET tier_id = (SELECT id FROM tier WHERE code = ?) WHERE user = ?`
	deleteUserTierQuery = `UPDATE user SET tier_id = null WHERE user = ?`
	deleteTierQuery     = `DELETE FROM tier WHERE code = ?`
	// updateBillingQuery sets all Stripe billing fields of a user.
	// Args: the six stripe_* values in column order, then the username.
	updateBillingQuery = `
UPDATE user
SET stripe_customer_id = ?, stripe_subscription_id = ?, stripe_subscription_status = ?, stripe_subscription_interval = ?, stripe_subscription_paid_until = ?, stripe_subscription_cancel_at = ?
WHERE user = ?
`
)
// Schema management queries.
//
// The schemaVersion table holds a single row (id = 1) with the database's
// schema version. The migrateXToY queries upgrade an existing database one
// version at a time. They describe the schema as it was at that version, and
// hard-code values (e.g. 'u_everyone') rather than referencing today's
// constants. This is presumably intentional, so that old migrations never
// change.
const (
	// currentSchemaVersion is the version of the schema built by
	// createTablesQueries, which already contains the version-6 provisioned
	// columns.
	currentSchemaVersion     = 6
	insertSchemaVersion      = `INSERT INTO schemaVersion VALUES (1, ?)`
	updateSchemaVersion      = `UPDATE schemaVersion SET version = ? WHERE id = 1`
	selectSchemaVersionQuery = `SELECT version FROM schemaVersion WHERE id = 1`
	// 1 -> 2 (complex migration!)
	//
	// migrate1To2CreateTablesQueries does three things:
	//   - renames the version-1 user table to user_old;
	//   - creates the version-2 tables;
	//   - inserts the everyone user.
	// It contains no BEGIN/COMMIT, so transaction handling is left to the
	// caller.
	migrate1To2CreateTablesQueries = `
ALTER TABLE user RENAME TO user_old;
CREATE TABLE IF NOT EXISTS tier (
id TEXT PRIMARY KEY,
code TEXT NOT NULL,
name TEXT NOT NULL,
messages_limit INT NOT NULL,
messages_expiry_duration INT NOT NULL,
emails_limit INT NOT NULL,
reservations_limit INT NOT NULL,
attachment_file_size_limit INT NOT NULL,
attachment_total_size_limit INT NOT NULL,
attachment_expiry_duration INT NOT NULL,
attachment_bandwidth_limit INT NOT NULL,
stripe_price_id TEXT
);
CREATE UNIQUE INDEX idx_tier_code ON tier (code);
CREATE UNIQUE INDEX idx_tier_price_id ON tier (stripe_price_id);
CREATE TABLE IF NOT EXISTS user (
id TEXT PRIMARY KEY,
tier_id TEXT,
user TEXT NOT NULL,
pass TEXT NOT NULL,
role TEXT CHECK (role IN ('anonymous', 'admin', 'user')) NOT NULL,
prefs JSON NOT NULL DEFAULT '{}',
sync_topic TEXT NOT NULL,
stats_messages INT NOT NULL DEFAULT (0),
stats_emails INT NOT NULL DEFAULT (0),
stripe_customer_id TEXT,
stripe_subscription_id TEXT,
stripe_subscription_status TEXT,
stripe_subscription_paid_until INT,
stripe_subscription_cancel_at INT,
created INT NOT NULL,
deleted INT,
FOREIGN KEY (tier_id) REFERENCES tier (id)
);
CREATE UNIQUE INDEX idx_user ON user (user);
CREATE UNIQUE INDEX idx_user_stripe_customer_id ON user (stripe_customer_id);
CREATE UNIQUE INDEX idx_user_stripe_subscription_id ON user (stripe_subscription_id);
CREATE TABLE IF NOT EXISTS user_access (
user_id TEXT NOT NULL,
topic TEXT NOT NULL,
read INT NOT NULL,
write INT NOT NULL,
owner_user_id INT,
PRIMARY KEY (user_id, topic),
FOREIGN KEY (user_id) REFERENCES user (id) ON DELETE CASCADE,
FOREIGN KEY (owner_user_id) REFERENCES user (id) ON DELETE CASCADE
);
CREATE TABLE IF NOT EXISTS user_token (
user_id TEXT NOT NULL,
token TEXT NOT NULL,
label TEXT NOT NULL,
last_access INT NOT NULL,
last_origin TEXT NOT NULL,
expires INT NOT NULL,
PRIMARY KEY (user_id, token),
FOREIGN KEY (user_id) REFERENCES user (id) ON DELETE CASCADE
);
CREATE TABLE IF NOT EXISTS schemaVersion (
id INT PRIMARY KEY,
version INT NOT NULL
);
INSERT INTO user (id, user, pass, role, sync_topic, created)
VALUES ('u_everyone', '*', '', 'anonymous', '', UNIXEPOCH())
ON CONFLICT (id) DO NOTHING;
`
	// The following 1 -> 2 statements carry a "NoTx" suffix, presumably
	// because they run without a transaction of their own.
	//
	// migrate1To2SelectAllOldUsernamesNoTx lists the version-1 usernames.
	migrate1To2SelectAllOldUsernamesNoTx = `SELECT user FROM user_old`
	// migrate1To2InsertUserNoTx copies one version-1 user into the new user
	// table. Args: new user ID, sync topic, old username.
	migrate1To2InsertUserNoTx = `
INSERT INTO user (id, user, pass, role, sync_topic, created)
SELECT ?, user, pass, role, ?, UNIXEPOCH() FROM user_old WHERE user = ?
`
	// migrate1To2InsertFromOldTablesAndDropNoTx copies the ACL entries from
	// the legacy access table (matched by username), then drops the legacy
	// access and user_old tables.
	migrate1To2InsertFromOldTablesAndDropNoTx = `
INSERT INTO user_access (user_id, topic, read, write)
SELECT u.id, a.topic, a.read, a.write
FROM user u
JOIN access a ON u.user = a.user;
DROP TABLE access;
DROP TABLE user_old;
`
	// 2 -> 3
	//
	// migrate2To3UpdateQueries makes the following changes:
	//   - adds stripe_subscription_interval;
	//   - renames stripe_price_id to stripe_monthly_price_id;
	//   - adds stripe_yearly_price_id;
	//   - replaces the single price index with one unique index per price
	//     column.
	migrate2To3UpdateQueries = `
ALTER TABLE user ADD COLUMN stripe_subscription_interval TEXT;
ALTER TABLE tier RENAME COLUMN stripe_price_id TO stripe_monthly_price_id;
ALTER TABLE tier ADD COLUMN stripe_yearly_price_id TEXT;
DROP INDEX IF EXISTS idx_tier_price_id;
CREATE UNIQUE INDEX idx_tier_stripe_monthly_price_id ON tier (stripe_monthly_price_id);
CREATE UNIQUE INDEX idx_tier_stripe_yearly_price_id ON tier (stripe_yearly_price_id);
`
	// 3 -> 4
	//
	// migrate3To4UpdateQueries adds the call limit and call stats columns
	// (default 0) and creates the user_phone table.
	migrate3To4UpdateQueries = `
ALTER TABLE tier ADD COLUMN calls_limit INT NOT NULL DEFAULT (0);
ALTER TABLE user ADD COLUMN stats_calls INT NOT NULL DEFAULT (0);
CREATE TABLE IF NOT EXISTS user_phone (
user_id TEXT NOT NULL,
phone_number TEXT NOT NULL,
PRIMARY KEY (user_id, phone_number),
FOREIGN KEY (user_id) REFERENCES user (id) ON DELETE CASCADE
);
`
	// 4 -> 5
	//
	// migrate4To5UpdateQueries escapes '_' in the stored topic patterns as
	// '\_'. Topic patterns are matched with LIKE ... ESCAPE '\', so the
	// underscore then matches literally instead of acting as a wildcard.
	// The statement is not idempotent: running it twice would double-escape.
	migrate4To5UpdateQueries = `
UPDATE user_access SET topic = REPLACE(topic, '_', '\_');
`
	// 5 -> 6
	//
	// migrate5To6UpdateQueries rebuilds the user, user_access and user_token
	// tables to add a NOT NULL provisioned column, set to 0 for existing rows.
	// It then recreates the indices that were dropped along with the old
	// tables.
	//
	// NOTE(review): SQLite documents PRAGMA foreign_keys as a no-op inside a
	// transaction. If migrateFrom5 runs this inside one, the off/on toggles
	// have no effect; confirm.
	//
	// NOTE(review): since SQLite 3.26, ALTER TABLE ... RENAME also rewrites
	// FOREIGN KEY references in other tables (unless legacy_alter_table=ON).
	// Renaming user to user_old may therefore point the user_access,
	// user_token and user_phone foreign keys at user_old. That has these
	// consequences:
	//   - if foreign keys are still enforced, DROP TABLE user_old performs an
	//     implicit DELETE that could cascade to those rows;
	//   - user_phone is not recreated here, so its reference may dangle.
	// SQLite's ALTER TABLE docs recommend create-new/copy/drop-old/rename-new
	// instead. Verify user_phone's schema after migrating.
	migrate5To6UpdateQueries = `
PRAGMA foreign_keys=off;
-- Alter user table: Add provisioned column
ALTER TABLE user RENAME TO user_old;
CREATE TABLE IF NOT EXISTS user (
id TEXT PRIMARY KEY,
tier_id TEXT,
user TEXT NOT NULL,
pass TEXT NOT NULL,
role TEXT CHECK (role IN ('anonymous', 'admin', 'user')) NOT NULL,
prefs JSON NOT NULL DEFAULT '{}',
sync_topic TEXT NOT NULL,
provisioned INT NOT NULL,
stats_messages INT NOT NULL DEFAULT (0),
stats_emails INT NOT NULL DEFAULT (0),
stats_calls INT NOT NULL DEFAULT (0),
stripe_customer_id TEXT,
stripe_subscription_id TEXT,
stripe_subscription_status TEXT,
stripe_subscription_interval TEXT,
stripe_subscription_paid_until INT,
stripe_subscription_cancel_at INT,
created INT NOT NULL,
deleted INT,
FOREIGN KEY (tier_id) REFERENCES tier (id)
);
INSERT INTO user
SELECT
id,
tier_id,
user,
pass,
role,
prefs,
sync_topic,
0, -- provisioned
stats_messages,
stats_emails,
stats_calls,
stripe_customer_id,
stripe_subscription_id,
stripe_subscription_status,
stripe_subscription_interval,
stripe_subscription_paid_until,
stripe_subscription_cancel_at,
created,
deleted
FROM user_old;
DROP TABLE user_old;
-- Alter user_access table: Add provisioned column
ALTER TABLE user_access RENAME TO user_access_old;
CREATE TABLE user_access (
user_id TEXT NOT NULL,
topic TEXT NOT NULL,
read INT NOT NULL,
write INT NOT NULL,
owner_user_id INT,
provisioned INTEGER NOT NULL,
PRIMARY KEY (user_id, topic),
FOREIGN KEY (user_id) REFERENCES user (id) ON DELETE CASCADE,
FOREIGN KEY (owner_user_id) REFERENCES user (id) ON DELETE CASCADE
);
INSERT INTO user_access SELECT *, 0 FROM user_access_old;
DROP TABLE user_access_old;
-- Alter user_token table: Add provisioned column
ALTER TABLE user_token RENAME TO user_token_old;
CREATE TABLE IF NOT EXISTS user_token (
user_id TEXT NOT NULL,
token TEXT NOT NULL,
label TEXT NOT NULL,
last_access INT NOT NULL,
last_origin TEXT NOT NULL,
expires INT NOT NULL,
provisioned INT NOT NULL,
PRIMARY KEY (user_id, token),
FOREIGN KEY (user_id) REFERENCES user (id) ON DELETE CASCADE
);
INSERT INTO user_token SELECT *, 0 FROM user_token_old;
DROP TABLE user_token_old;
-- Recreate indices
CREATE UNIQUE INDEX idx_user ON user (user);
CREATE UNIQUE INDEX idx_user_stripe_customer_id ON user (stripe_customer_id);
CREATE UNIQUE INDEX idx_user_stripe_subscription_id ON user (stripe_subscription_id);
CREATE UNIQUE INDEX idx_user_token ON user_token (token);
-- Re-enable foreign keys
PRAGMA foreign_keys=on;
`
)
  516. var (
  517. migrations = map[int]func(db *sql.DB) error{
  518. 1: migrateFrom1,
  519. 2: migrateFrom2,
  520. 3: migrateFrom3,
  521. 4: migrateFrom4,
  522. 5: migrateFrom5,
  523. }
  524. )
// Manager is an implementation of Auther. It stores users, tiers, tokens and
// the access control list in a SQLite database.
type Manager struct {
	config     *Config                 // Configuration passed to NewManager (which fills in defaults)
	db         *sql.DB                 // Handle to the SQLite user database
	statsQueue map[string]*Stats       // "Queue" to asynchronously write user stats to the database (UserID -> Stats)
	tokenQueue map[string]*TokenUpdate // "Queue" to asynchronously write token access stats to the database (Token ID -> TokenUpdate)
	mu         sync.Mutex              // Presumably guards statsQueue and tokenQueue; confirm at the call sites
}
  534. // Config holds the configuration for the user Manager
  535. type Config struct {
  536. Filename string // Database filename, e.g. "/var/lib/ntfy/user.db"
  537. StartupQueries string // Queries to run on startup, e.g. to create initial users or tiers
  538. DefaultAccess Permission // Default permission if no ACL matches
  539. ProvisionEnabled bool // Hack: Enable auto-provisioning of users and access grants, disabled for "ntfy user" commands
  540. Users []*User // Predefined users to create on startup
  541. Access map[string][]*Grant // Predefined access grants to create on startup (username -> []*Grant)
  542. Tokens map[string][]*Token // Predefined users to create on startup (username -> []*Token)
  543. QueueWriterInterval time.Duration // Interval for the async queue writer to flush stats and token updates to the database
  544. BcryptCost int // Cost of generated passwords; lowering makes testing faster
  545. }
  546. var _ Auther = (*Manager)(nil)
  547. // NewManager creates a new Manager instance
  548. func NewManager(config *Config) (*Manager, error) {
  549. // Set defaults
  550. if config.BcryptCost <= 0 {
  551. config.BcryptCost = DefaultUserPasswordBcryptCost
  552. }
  553. if config.QueueWriterInterval.Seconds() <= 0 {
  554. config.QueueWriterInterval = DefaultUserStatsQueueWriterInterval
  555. }
  556. // Check the parent directory of the database file (makes for friendly error messages)
  557. parentDir := filepath.Dir(config.Filename)
  558. if !util.FileExists(parentDir) {
  559. return nil, fmt.Errorf("user database directory %s does not exist or is not accessible", parentDir)
  560. }
  561. // Open DB and run setup queries
  562. db, err := sql.Open("sqlite3", config.Filename)
  563. if err != nil {
  564. return nil, err
  565. }
  566. if err := setupDB(db); err != nil {
  567. return nil, err
  568. }
  569. if err := runStartupQueries(db, config.StartupQueries); err != nil {
  570. return nil, err
  571. }
  572. manager := &Manager{
  573. db: db,
  574. config: config,
  575. statsQueue: make(map[string]*Stats),
  576. tokenQueue: make(map[string]*TokenUpdate),
  577. }
  578. if err := manager.maybeProvisionUsersAccessAndTokens(); err != nil {
  579. return nil, err
  580. }
  581. go manager.asyncQueueWriter(config.QueueWriterInterval)
  582. return manager, nil
  583. }
  584. // Authenticate checks username and password and returns a User if correct, and the user has not been
  585. // marked as deleted. The method returns in constant-ish time, regardless of whether the user exists or
  586. // the password is correct or incorrect.
  587. func (a *Manager) Authenticate(username, password string) (*User, error) {
  588. if username == Everyone {
  589. return nil, ErrUnauthenticated
  590. }
  591. user, err := a.User(username)
  592. if err != nil {
  593. log.Tag(tag).Field("user_name", username).Err(err).Trace("Authentication of user failed (1)")
  594. bcrypt.CompareHashAndPassword([]byte(userAuthIntentionalSlowDownHash), []byte("intentional slow-down to avoid timing attacks"))
  595. return nil, ErrUnauthenticated
  596. } else if user.Deleted {
  597. log.Tag(tag).Field("user_name", username).Trace("Authentication of user failed (2): user marked deleted")
  598. bcrypt.CompareHashAndPassword([]byte(userAuthIntentionalSlowDownHash), []byte("intentional slow-down to avoid timing attacks"))
  599. return nil, ErrUnauthenticated
  600. } else if err := bcrypt.CompareHashAndPassword([]byte(user.Hash), []byte(password)); err != nil {
  601. log.Tag(tag).Field("user_name", username).Err(err).Trace("Authentication of user failed (3)")
  602. return nil, ErrUnauthenticated
  603. }
  604. return user, nil
  605. }
  606. // AuthenticateToken checks if the token exists and returns the associated User if it does.
  607. // The method sets the User.Token value to the token that was used for authentication.
  608. func (a *Manager) AuthenticateToken(token string) (*User, error) {
  609. if len(token) != tokenLength {
  610. return nil, ErrUnauthenticated
  611. }
  612. user, err := a.userByToken(token)
  613. if err != nil {
  614. log.Tag(tag).Field("token", token).Err(err).Trace("Authentication of token failed")
  615. return nil, ErrUnauthenticated
  616. }
  617. user.Token = token
  618. return user, nil
  619. }
  620. // CreateToken generates a random token for the given user and returns it. The token expires
  621. // after a fixed duration unless ChangeToken is called. This function also prunes tokens for the
  622. // given user, if there are too many of them.
  623. func (a *Manager) CreateToken(userID, label string, expires time.Time, origin netip.Addr, provisioned bool) (*Token, error) {
  624. return queryTx(a.db, func(tx *sql.Tx) (*Token, error) {
  625. return a.createTokenTx(tx, userID, GenerateToken(), label, expires, origin, provisioned)
  626. })
  627. }
  628. func (a *Manager) createTokenTx(tx *sql.Tx, userID, token, label string, expires time.Time, origin netip.Addr, provisioned bool) (*Token, error) {
  629. access := time.Now()
  630. if _, err := tx.Exec(upsertTokenQuery, userID, token, label, access.Unix(), origin.String(), expires.Unix(), provisioned); err != nil {
  631. return nil, err
  632. }
  633. rows, err := tx.Query(selectTokenCountQuery, userID)
  634. if err != nil {
  635. return nil, err
  636. }
  637. defer rows.Close()
  638. if !rows.Next() {
  639. return nil, errNoRows
  640. }
  641. var tokenCount int
  642. if err := rows.Scan(&tokenCount); err != nil {
  643. return nil, err
  644. }
  645. if tokenCount >= tokenMaxCount {
  646. // This pruning logic is done in two queries for efficiency. The SELECT above is a lookup
  647. // on two indices, whereas the query below is a full table scan.
  648. if _, err := tx.Exec(deleteExcessTokensQuery, userID, userID, tokenMaxCount); err != nil {
  649. return nil, err
  650. }
  651. }
  652. return &Token{
  653. Value: token,
  654. Label: label,
  655. LastAccess: access,
  656. LastOrigin: origin,
  657. Expires: expires,
  658. Provisioned: provisioned,
  659. }, nil
  660. }
  661. // Tokens returns all existing tokens for the user with the given user ID
  662. func (a *Manager) Tokens(userID string) ([]*Token, error) {
  663. rows, err := a.db.Query(selectTokensQuery, userID)
  664. if err != nil {
  665. return nil, err
  666. }
  667. defer rows.Close()
  668. tokens := make([]*Token, 0)
  669. for {
  670. token, err := a.readToken(rows)
  671. if errors.Is(err, ErrTokenNotFound) {
  672. break
  673. } else if err != nil {
  674. return nil, err
  675. }
  676. tokens = append(tokens, token)
  677. }
  678. return tokens, nil
  679. }
  680. func (a *Manager) allProvisionedTokens() ([]*Token, error) {
  681. rows, err := a.db.Query(selectAllProvisionedTokensQuery)
  682. if err != nil {
  683. return nil, err
  684. }
  685. defer rows.Close()
  686. tokens := make([]*Token, 0)
  687. for {
  688. token, err := a.readToken(rows)
  689. if errors.Is(err, ErrTokenNotFound) {
  690. break
  691. } else if err != nil {
  692. return nil, err
  693. }
  694. tokens = append(tokens, token)
  695. }
  696. return tokens, nil
  697. }
  698. // Token returns a specific token for a user
  699. func (a *Manager) Token(userID, token string) (*Token, error) {
  700. rows, err := a.db.Query(selectTokenQuery, userID, token)
  701. if err != nil {
  702. return nil, err
  703. }
  704. defer rows.Close()
  705. return a.readToken(rows)
  706. }
  707. func (a *Manager) readToken(rows *sql.Rows) (*Token, error) {
  708. var token, label, lastOrigin string
  709. var lastAccess, expires int64
  710. var provisioned bool
  711. if !rows.Next() {
  712. return nil, ErrTokenNotFound
  713. }
  714. if err := rows.Scan(&token, &label, &lastAccess, &lastOrigin, &expires, &provisioned); err != nil {
  715. return nil, err
  716. } else if err := rows.Err(); err != nil {
  717. return nil, err
  718. }
  719. lastOriginIP, err := netip.ParseAddr(lastOrigin)
  720. if err != nil {
  721. lastOriginIP = netip.IPv4Unspecified()
  722. }
  723. return &Token{
  724. Value: token,
  725. Label: label,
  726. LastAccess: time.Unix(lastAccess, 0),
  727. LastOrigin: lastOriginIP,
  728. Expires: time.Unix(expires, 0),
  729. Provisioned: provisioned,
  730. }, nil
  731. }
  732. // ChangeToken updates a token's label and/or expiry date
  733. func (a *Manager) ChangeToken(userID, token string, label *string, expires *time.Time) (*Token, error) {
  734. if token == "" {
  735. return nil, errNoTokenProvided
  736. }
  737. tx, err := a.db.Begin()
  738. if err != nil {
  739. return nil, err
  740. }
  741. defer tx.Rollback()
  742. if label != nil {
  743. if _, err := tx.Exec(updateTokenLabelQuery, *label, userID, token); err != nil {
  744. return nil, err
  745. }
  746. }
  747. if expires != nil {
  748. if _, err := tx.Exec(updateTokenExpiryQuery, expires.Unix(), userID, token); err != nil {
  749. return nil, err
  750. }
  751. }
  752. if err := tx.Commit(); err != nil {
  753. return nil, err
  754. }
  755. return a.Token(userID, token)
  756. }
  757. // RemoveToken deletes the token defined in User.Token
  758. func (a *Manager) RemoveToken(userID, token string) error {
  759. return execTx(a.db, func(tx *sql.Tx) error {
  760. return a.removeTokenTx(tx, userID, token)
  761. })
  762. }
  763. func (a *Manager) removeTokenTx(tx *sql.Tx, userID, token string) error {
  764. if token == "" {
  765. return errNoTokenProvided
  766. }
  767. if _, err := tx.Exec(deleteTokenQuery, userID, token); err != nil {
  768. return err
  769. }
  770. return nil
  771. }
  772. // RemoveExpiredTokens deletes all expired tokens from the database
  773. func (a *Manager) RemoveExpiredTokens() error {
  774. if _, err := a.db.Exec(deleteExpiredTokensQuery, time.Now().Unix()); err != nil {
  775. return err
  776. }
  777. return nil
  778. }
  779. // PhoneNumbers returns all phone numbers for the user with the given user ID
  780. func (a *Manager) PhoneNumbers(userID string) ([]string, error) {
  781. rows, err := a.db.Query(selectPhoneNumbersQuery, userID)
  782. if err != nil {
  783. return nil, err
  784. }
  785. defer rows.Close()
  786. phoneNumbers := make([]string, 0)
  787. for {
  788. phoneNumber, err := a.readPhoneNumber(rows)
  789. if errors.Is(err, ErrPhoneNumberNotFound) {
  790. break
  791. } else if err != nil {
  792. return nil, err
  793. }
  794. phoneNumbers = append(phoneNumbers, phoneNumber)
  795. }
  796. return phoneNumbers, nil
  797. }
  798. func (a *Manager) readPhoneNumber(rows *sql.Rows) (string, error) {
  799. var phoneNumber string
  800. if !rows.Next() {
  801. return "", ErrPhoneNumberNotFound
  802. }
  803. if err := rows.Scan(&phoneNumber); err != nil {
  804. return "", err
  805. } else if err := rows.Err(); err != nil {
  806. return "", err
  807. }
  808. return phoneNumber, nil
  809. }
  810. // AddPhoneNumber adds a phone number to the user with the given user ID
  811. func (a *Manager) AddPhoneNumber(userID string, phoneNumber string) error {
  812. if _, err := a.db.Exec(insertPhoneNumberQuery, userID, phoneNumber); err != nil {
  813. if sqliteErr, ok := err.(sqlite3.Error); ok && sqliteErr.ExtendedCode == sqlite3.ErrConstraintUnique {
  814. return ErrPhoneNumberExists
  815. }
  816. return err
  817. }
  818. return nil
  819. }
  820. // RemovePhoneNumber deletes a phone number from the user with the given user ID
  821. func (a *Manager) RemovePhoneNumber(userID string, phoneNumber string) error {
  822. _, err := a.db.Exec(deletePhoneNumberQuery, userID, phoneNumber)
  823. return err
  824. }
  825. // RemoveDeletedUsers deletes all users that have been marked deleted for
  826. func (a *Manager) RemoveDeletedUsers() error {
  827. if _, err := a.db.Exec(deleteUsersMarkedQuery, time.Now().Unix()); err != nil {
  828. return err
  829. }
  830. return nil
  831. }
  832. // ChangeSettings persists the user settings
  833. func (a *Manager) ChangeSettings(userID string, prefs *Prefs) error {
  834. b, err := json.Marshal(prefs)
  835. if err != nil {
  836. return err
  837. }
  838. if _, err := a.db.Exec(updateUserPrefsQuery, string(b), userID); err != nil {
  839. return err
  840. }
  841. return nil
  842. }
  843. // ResetStats resets all user stats in the user database. This touches all users.
  844. func (a *Manager) ResetStats() error {
  845. a.mu.Lock() // Includes database query to avoid races!
  846. defer a.mu.Unlock()
  847. if _, err := a.db.Exec(updateUserStatsResetAllQuery); err != nil {
  848. return err
  849. }
  850. a.statsQueue = make(map[string]*Stats)
  851. return nil
  852. }
  853. // EnqueueUserStats adds the user to a queue which writes out user stats (messages, emails, ..) in
  854. // batches at a regular interval
  855. func (a *Manager) EnqueueUserStats(userID string, stats *Stats) {
  856. a.mu.Lock()
  857. defer a.mu.Unlock()
  858. a.statsQueue[userID] = stats
  859. }
  860. // EnqueueTokenUpdate adds the token update to a queue which writes out token access times
  861. // in batches at a regular interval
  862. func (a *Manager) EnqueueTokenUpdate(tokenID string, update *TokenUpdate) {
  863. a.mu.Lock()
  864. defer a.mu.Unlock()
  865. a.tokenQueue[tokenID] = update
  866. }
  867. func (a *Manager) asyncQueueWriter(interval time.Duration) {
  868. ticker := time.NewTicker(interval)
  869. for range ticker.C {
  870. if err := a.writeUserStatsQueue(); err != nil {
  871. log.Tag(tag).Err(err).Warn("Writing user stats queue failed")
  872. }
  873. if err := a.writeTokenUpdateQueue(); err != nil {
  874. log.Tag(tag).Err(err).Warn("Writing token update queue failed")
  875. }
  876. }
  877. }
  878. func (a *Manager) writeUserStatsQueue() error {
  879. a.mu.Lock()
  880. if len(a.statsQueue) == 0 {
  881. a.mu.Unlock()
  882. log.Tag(tag).Trace("No user stats updates to commit")
  883. return nil
  884. }
  885. statsQueue := a.statsQueue
  886. a.statsQueue = make(map[string]*Stats)
  887. a.mu.Unlock()
  888. tx, err := a.db.Begin()
  889. if err != nil {
  890. return err
  891. }
  892. defer tx.Rollback()
  893. log.Tag(tag).Debug("Writing user stats queue for %d user(s)", len(statsQueue))
  894. for userID, update := range statsQueue {
  895. log.
  896. Tag(tag).
  897. Fields(log.Context{
  898. "user_id": userID,
  899. "messages_count": update.Messages,
  900. "emails_count": update.Emails,
  901. "calls_count": update.Calls,
  902. }).
  903. Trace("Updating stats for user %s", userID)
  904. if _, err := tx.Exec(updateUserStatsQuery, update.Messages, update.Emails, update.Calls, userID); err != nil {
  905. return err
  906. }
  907. }
  908. return tx.Commit()
  909. }
  910. func (a *Manager) writeTokenUpdateQueue() error {
  911. a.mu.Lock()
  912. if len(a.tokenQueue) == 0 {
  913. a.mu.Unlock()
  914. log.Tag(tag).Trace("No token updates to commit")
  915. return nil
  916. }
  917. tokenQueue := a.tokenQueue
  918. a.tokenQueue = make(map[string]*TokenUpdate)
  919. a.mu.Unlock()
  920. tx, err := a.db.Begin()
  921. if err != nil {
  922. return err
  923. }
  924. defer tx.Rollback()
  925. log.Tag(tag).Debug("Writing token update queue for %d token(s)", len(tokenQueue))
  926. for tokenID, update := range tokenQueue {
  927. log.Tag(tag).Trace("Updating token %s with last access time %v", tokenID, update.LastAccess.Unix())
  928. if err := a.updateTokenLastAccessTx(tx, tokenID, update.LastAccess.Unix(), update.LastOrigin.String()); err != nil {
  929. return err
  930. }
  931. }
  932. return tx.Commit()
  933. }
  934. func (a *Manager) updateTokenLastAccessTx(tx *sql.Tx, token string, lastAccess int64, lastOrigin string) error {
  935. if _, err := tx.Exec(updateTokenLastAccessQuery, lastAccess, lastOrigin, token); err != nil {
  936. return err
  937. }
  938. return nil
  939. }
  940. // Authorize returns nil if the given user has access to the given topic using the desired
  941. // permission. The user param may be nil to signal an anonymous user.
  942. func (a *Manager) Authorize(user *User, topic string, perm Permission) error {
  943. if user != nil && user.Role == RoleAdmin {
  944. return nil // Admin can do everything
  945. }
  946. username := Everyone
  947. if user != nil {
  948. username = user.Name
  949. }
  950. // Select the read/write permissions for this user/topic combo.
  951. // - The query may return two rows (one for everyone, and one for the user), but prioritizes the user.
  952. // - Furthermore, the query prioritizes more specific permissions (longer!) over more generic ones, e.g. "test*" > "*"
  953. // - It also prioritizes write permissions over read permissions
  954. rows, err := a.db.Query(selectTopicPermsQuery, Everyone, username, topic)
  955. if err != nil {
  956. return err
  957. }
  958. defer rows.Close()
  959. if !rows.Next() {
  960. return a.resolvePerms(a.config.DefaultAccess, perm)
  961. }
  962. var read, write bool
  963. if err := rows.Scan(&read, &write); err != nil {
  964. return err
  965. } else if err := rows.Err(); err != nil {
  966. return err
  967. }
  968. return a.resolvePerms(NewPermission(read, write), perm)
  969. }
  970. func (a *Manager) resolvePerms(base, perm Permission) error {
  971. if perm == PermissionRead && base.IsRead() {
  972. return nil
  973. } else if perm == PermissionWrite && base.IsWrite() {
  974. return nil
  975. }
  976. return ErrUnauthorized
  977. }
  978. // AddUser adds a user with the given username, password and role
  979. func (a *Manager) AddUser(username, password string, role Role, hashed bool) error {
  980. return execTx(a.db, func(tx *sql.Tx) error {
  981. return a.addUserTx(tx, username, password, role, hashed, false)
  982. })
  983. }
  984. // AddUser adds a user with the given username, password and role
  985. func (a *Manager) addUserTx(tx *sql.Tx, username, password string, role Role, hashed, provisioned bool) error {
  986. if !AllowedUsername(username) || !AllowedRole(role) {
  987. return ErrInvalidArgument
  988. }
  989. var hash string
  990. var err error = nil
  991. if hashed {
  992. hash = password
  993. if err := ValidPasswordHash(hash); err != nil {
  994. return err
  995. }
  996. } else {
  997. hash, err = hashPassword(password, a.config.BcryptCost)
  998. if err != nil {
  999. return err
  1000. }
  1001. }
  1002. userID := util.RandomStringPrefix(userIDPrefix, userIDLength)
  1003. syncTopic, now := util.RandomStringPrefix(syncTopicPrefix, syncTopicLength), time.Now().Unix()
  1004. if _, err = tx.Exec(insertUserQuery, userID, username, hash, role, syncTopic, provisioned, now); err != nil {
  1005. if errors.Is(err, sqlite3.ErrConstraintUnique) {
  1006. return ErrUserExists
  1007. }
  1008. return err
  1009. }
  1010. return nil
  1011. }
  1012. // RemoveUser deletes the user with the given username. The function returns nil on success, even
  1013. // if the user did not exist in the first place.
  1014. func (a *Manager) RemoveUser(username string) error {
  1015. return execTx(a.db, func(tx *sql.Tx) error {
  1016. return a.removeUserTx(tx, username)
  1017. })
  1018. }
  1019. func (a *Manager) removeUserTx(tx *sql.Tx, username string) error {
  1020. if !AllowedUsername(username) {
  1021. return ErrInvalidArgument
  1022. }
  1023. // Rows in user_access, user_token, etc. are deleted via foreign keys
  1024. if _, err := tx.Exec(deleteUserQuery, username); err != nil {
  1025. return err
  1026. }
  1027. return nil
  1028. }
  1029. // MarkUserRemoved sets the deleted flag on the user, and deletes all access tokens. This prevents
  1030. // successful auth via Authenticate. A background process will delete the user at a later date.
  1031. func (a *Manager) MarkUserRemoved(user *User) error {
  1032. if !AllowedUsername(user.Name) {
  1033. return ErrInvalidArgument
  1034. }
  1035. tx, err := a.db.Begin()
  1036. if err != nil {
  1037. return err
  1038. }
  1039. defer tx.Rollback()
  1040. if _, err := tx.Exec(deleteUserAccessQuery, user.Name, user.Name); err != nil {
  1041. return err
  1042. }
  1043. if _, err := tx.Exec(deleteAllTokenQuery, user.ID); err != nil {
  1044. return err
  1045. }
  1046. if _, err := tx.Exec(updateUserDeletedQuery, time.Now().Add(userHardDeleteAfterDuration).Unix(), user.ID); err != nil {
  1047. return err
  1048. }
  1049. return tx.Commit()
  1050. }
  1051. // Users returns a list of users. It always also returns the Everyone user ("*").
  1052. func (a *Manager) Users() ([]*User, error) {
  1053. rows, err := a.db.Query(selectUsernamesQuery)
  1054. if err != nil {
  1055. return nil, err
  1056. }
  1057. defer rows.Close()
  1058. usernames := make([]string, 0)
  1059. for rows.Next() {
  1060. var username string
  1061. if err := rows.Scan(&username); err != nil {
  1062. return nil, err
  1063. } else if err := rows.Err(); err != nil {
  1064. return nil, err
  1065. }
  1066. usernames = append(usernames, username)
  1067. }
  1068. rows.Close()
  1069. users := make([]*User, 0)
  1070. for _, username := range usernames {
  1071. user, err := a.User(username)
  1072. if err != nil {
  1073. return nil, err
  1074. }
  1075. users = append(users, user)
  1076. }
  1077. return users, nil
  1078. }
  1079. // UsersCount returns the number of users in the databsae
  1080. func (a *Manager) UsersCount() (int64, error) {
  1081. rows, err := a.db.Query(selectUserCountQuery)
  1082. if err != nil {
  1083. return 0, err
  1084. }
  1085. defer rows.Close()
  1086. if !rows.Next() {
  1087. return 0, errNoRows
  1088. }
  1089. var count int64
  1090. if err := rows.Scan(&count); err != nil {
  1091. return 0, err
  1092. }
  1093. return count, nil
  1094. }
  1095. // User returns the user with the given username if it exists, or ErrUserNotFound otherwise.
  1096. // You may also pass Everyone to retrieve the anonymous user and its Grant list.
  1097. func (a *Manager) User(username string) (*User, error) {
  1098. rows, err := a.db.Query(selectUserByNameQuery, username)
  1099. if err != nil {
  1100. return nil, err
  1101. }
  1102. return a.readUser(rows)
  1103. }
  1104. // UserByID returns the user with the given ID if it exists, or ErrUserNotFound otherwise
  1105. func (a *Manager) UserByID(id string) (*User, error) {
  1106. rows, err := a.db.Query(selectUserByIDQuery, id)
  1107. if err != nil {
  1108. return nil, err
  1109. }
  1110. return a.readUser(rows)
  1111. }
  1112. // UserByStripeCustomer returns the user with the given Stripe customer ID if it exists, or ErrUserNotFound otherwise.
  1113. func (a *Manager) UserByStripeCustomer(stripeCustomerID string) (*User, error) {
  1114. rows, err := a.db.Query(selectUserByStripeCustomerIDQuery, stripeCustomerID)
  1115. if err != nil {
  1116. return nil, err
  1117. }
  1118. return a.readUser(rows)
  1119. }
  1120. func (a *Manager) userByToken(token string) (*User, error) {
  1121. rows, err := a.db.Query(selectUserByTokenQuery, token, time.Now().Unix())
  1122. if err != nil {
  1123. return nil, err
  1124. }
  1125. return a.readUser(rows)
  1126. }
  1127. func (a *Manager) readUser(rows *sql.Rows) (*User, error) {
  1128. defer rows.Close()
  1129. var id, username, hash, role, prefs, syncTopic string
  1130. var provisioned bool
  1131. var stripeCustomerID, stripeSubscriptionID, stripeSubscriptionStatus, stripeSubscriptionInterval, stripeMonthlyPriceID, stripeYearlyPriceID, tierID, tierCode, tierName sql.NullString
  1132. var messages, emails, calls int64
  1133. var messagesLimit, messagesExpiryDuration, emailsLimit, callsLimit, reservationsLimit, attachmentFileSizeLimit, attachmentTotalSizeLimit, attachmentExpiryDuration, attachmentBandwidthLimit, stripeSubscriptionPaidUntil, stripeSubscriptionCancelAt, deleted sql.NullInt64
  1134. if !rows.Next() {
  1135. return nil, ErrUserNotFound
  1136. }
  1137. if err := rows.Scan(&id, &username, &hash, &role, &prefs, &syncTopic, &provisioned, &messages, &emails, &calls, &stripeCustomerID, &stripeSubscriptionID, &stripeSubscriptionStatus, &stripeSubscriptionInterval, &stripeSubscriptionPaidUntil, &stripeSubscriptionCancelAt, &deleted, &tierID, &tierCode, &tierName, &messagesLimit, &messagesExpiryDuration, &emailsLimit, &callsLimit, &reservationsLimit, &attachmentFileSizeLimit, &attachmentTotalSizeLimit, &attachmentExpiryDuration, &attachmentBandwidthLimit, &stripeMonthlyPriceID, &stripeYearlyPriceID); err != nil {
  1138. return nil, err
  1139. } else if err := rows.Err(); err != nil {
  1140. return nil, err
  1141. }
  1142. user := &User{
  1143. ID: id,
  1144. Name: username,
  1145. Hash: hash,
  1146. Role: Role(role),
  1147. Prefs: &Prefs{},
  1148. SyncTopic: syncTopic,
  1149. Provisioned: provisioned,
  1150. Stats: &Stats{
  1151. Messages: messages,
  1152. Emails: emails,
  1153. Calls: calls,
  1154. },
  1155. Billing: &Billing{
  1156. StripeCustomerID: stripeCustomerID.String, // May be empty
  1157. StripeSubscriptionID: stripeSubscriptionID.String, // May be empty
  1158. StripeSubscriptionStatus: stripe.SubscriptionStatus(stripeSubscriptionStatus.String), // May be empty
  1159. StripeSubscriptionInterval: stripe.PriceRecurringInterval(stripeSubscriptionInterval.String), // May be empty
  1160. StripeSubscriptionPaidUntil: time.Unix(stripeSubscriptionPaidUntil.Int64, 0), // May be zero
  1161. StripeSubscriptionCancelAt: time.Unix(stripeSubscriptionCancelAt.Int64, 0), // May be zero
  1162. },
  1163. Deleted: deleted.Valid,
  1164. }
  1165. if err := json.Unmarshal([]byte(prefs), user.Prefs); err != nil {
  1166. return nil, err
  1167. }
  1168. if tierCode.Valid {
  1169. // See readTier() when this is changed!
  1170. user.Tier = &Tier{
  1171. ID: tierID.String,
  1172. Code: tierCode.String,
  1173. Name: tierName.String,
  1174. MessageLimit: messagesLimit.Int64,
  1175. MessageExpiryDuration: time.Duration(messagesExpiryDuration.Int64) * time.Second,
  1176. EmailLimit: emailsLimit.Int64,
  1177. CallLimit: callsLimit.Int64,
  1178. ReservationLimit: reservationsLimit.Int64,
  1179. AttachmentFileSizeLimit: attachmentFileSizeLimit.Int64,
  1180. AttachmentTotalSizeLimit: attachmentTotalSizeLimit.Int64,
  1181. AttachmentExpiryDuration: time.Duration(attachmentExpiryDuration.Int64) * time.Second,
  1182. AttachmentBandwidthLimit: attachmentBandwidthLimit.Int64,
  1183. StripeMonthlyPriceID: stripeMonthlyPriceID.String, // May be empty
  1184. StripeYearlyPriceID: stripeYearlyPriceID.String, // May be empty
  1185. }
  1186. }
  1187. return user, nil
  1188. }
  1189. // AllGrants returns all user-specific access control entries, mapped to their respective user IDs
  1190. func (a *Manager) AllGrants() (map[string][]Grant, error) {
  1191. rows, err := a.db.Query(selectUserAllAccessQuery)
  1192. if err != nil {
  1193. return nil, err
  1194. }
  1195. defer rows.Close()
  1196. grants := make(map[string][]Grant, 0)
  1197. for rows.Next() {
  1198. var userID, topic string
  1199. var read, write, provisioned bool
  1200. if err := rows.Scan(&userID, &topic, &read, &write, &provisioned); err != nil {
  1201. return nil, err
  1202. } else if err := rows.Err(); err != nil {
  1203. return nil, err
  1204. }
  1205. if _, ok := grants[userID]; !ok {
  1206. grants[userID] = make([]Grant, 0)
  1207. }
  1208. grants[userID] = append(grants[userID], Grant{
  1209. TopicPattern: fromSQLWildcard(topic),
  1210. Permission: NewPermission(read, write),
  1211. Provisioned: provisioned,
  1212. })
  1213. }
  1214. return grants, nil
  1215. }
  1216. // Grants returns all user-specific access control entries
  1217. func (a *Manager) Grants(username string) ([]Grant, error) {
  1218. rows, err := a.db.Query(selectUserAccessQuery, username)
  1219. if err != nil {
  1220. return nil, err
  1221. }
  1222. defer rows.Close()
  1223. grants := make([]Grant, 0)
  1224. for rows.Next() {
  1225. var topic string
  1226. var read, write, provisioned bool
  1227. if err := rows.Scan(&topic, &read, &write, &provisioned); err != nil {
  1228. return nil, err
  1229. } else if err := rows.Err(); err != nil {
  1230. return nil, err
  1231. }
  1232. grants = append(grants, Grant{
  1233. TopicPattern: fromSQLWildcard(topic),
  1234. Permission: NewPermission(read, write),
  1235. Provisioned: provisioned,
  1236. })
  1237. }
  1238. return grants, nil
  1239. }
  1240. // Reservations returns all user-owned topics, and the associated everyone-access
  1241. func (a *Manager) Reservations(username string) ([]Reservation, error) {
  1242. rows, err := a.db.Query(selectUserReservationsQuery, Everyone, username)
  1243. if err != nil {
  1244. return nil, err
  1245. }
  1246. defer rows.Close()
  1247. reservations := make([]Reservation, 0)
  1248. for rows.Next() {
  1249. var topic string
  1250. var ownerRead, ownerWrite bool
  1251. var everyoneRead, everyoneWrite sql.NullBool
  1252. if err := rows.Scan(&topic, &ownerRead, &ownerWrite, &everyoneRead, &everyoneWrite); err != nil {
  1253. return nil, err
  1254. } else if err := rows.Err(); err != nil {
  1255. return nil, err
  1256. }
  1257. reservations = append(reservations, Reservation{
  1258. Topic: unescapeUnderscore(topic),
  1259. Owner: NewPermission(ownerRead, ownerWrite),
  1260. Everyone: NewPermission(everyoneRead.Bool, everyoneWrite.Bool), // false if null
  1261. })
  1262. }
  1263. return reservations, nil
  1264. }
  1265. // HasReservation returns true if the given topic access is owned by the user
  1266. func (a *Manager) HasReservation(username, topic string) (bool, error) {
  1267. rows, err := a.db.Query(selectUserHasReservationQuery, username, escapeUnderscore(topic))
  1268. if err != nil {
  1269. return false, err
  1270. }
  1271. defer rows.Close()
  1272. if !rows.Next() {
  1273. return false, errNoRows
  1274. }
  1275. var count int64
  1276. if err := rows.Scan(&count); err != nil {
  1277. return false, err
  1278. }
  1279. return count > 0, nil
  1280. }
  1281. // ReservationsCount returns the number of reservations owned by this user
  1282. func (a *Manager) ReservationsCount(username string) (int64, error) {
  1283. rows, err := a.db.Query(selectUserReservationsCountQuery, username)
  1284. if err != nil {
  1285. return 0, err
  1286. }
  1287. defer rows.Close()
  1288. if !rows.Next() {
  1289. return 0, errNoRows
  1290. }
  1291. var count int64
  1292. if err := rows.Scan(&count); err != nil {
  1293. return 0, err
  1294. }
  1295. return count, nil
  1296. }
  1297. // ReservationOwner returns user ID of the user that owns this topic, or an
  1298. // empty string if it's not owned by anyone
  1299. func (a *Manager) ReservationOwner(topic string) (string, error) {
  1300. rows, err := a.db.Query(selectUserReservationsOwnerQuery, escapeUnderscore(topic))
  1301. if err != nil {
  1302. return "", err
  1303. }
  1304. defer rows.Close()
  1305. if !rows.Next() {
  1306. return "", nil
  1307. }
  1308. var ownerUserID string
  1309. if err := rows.Scan(&ownerUserID); err != nil {
  1310. return "", err
  1311. }
  1312. return ownerUserID, nil
  1313. }
  1314. // ChangePassword changes a user's password
  1315. func (a *Manager) ChangePassword(username, password string, hashed bool) error {
  1316. return execTx(a.db, func(tx *sql.Tx) error {
  1317. return a.changePasswordTx(tx, username, password, hashed)
  1318. })
  1319. }
  1320. func (a *Manager) changePasswordTx(tx *sql.Tx, username, password string, hashed bool) error {
  1321. var hash string
  1322. var err error
  1323. if hashed {
  1324. hash = password
  1325. if err := ValidPasswordHash(hash); err != nil {
  1326. return err
  1327. }
  1328. } else {
  1329. hash, err = hashPassword(password, a.config.BcryptCost)
  1330. if err != nil {
  1331. return err
  1332. }
  1333. }
  1334. if _, err := tx.Exec(updateUserPassQuery, hash, username); err != nil {
  1335. return err
  1336. }
  1337. return nil
  1338. }
  1339. // ChangeRole changes a user's role. When a role is changed from RoleUser to RoleAdmin,
  1340. // all existing access control entries (Grant) are removed, since they are no longer needed.
  1341. func (a *Manager) ChangeRole(username string, role Role) error {
  1342. return execTx(a.db, func(tx *sql.Tx) error {
  1343. return a.changeRoleTx(tx, username, role)
  1344. })
  1345. }
  1346. func (a *Manager) changeRoleTx(tx *sql.Tx, username string, role Role) error {
  1347. if !AllowedUsername(username) || !AllowedRole(role) {
  1348. return ErrInvalidArgument
  1349. }
  1350. if _, err := tx.Exec(updateUserRoleQuery, string(role), username); err != nil {
  1351. return err
  1352. }
  1353. if role == RoleAdmin {
  1354. if _, err := tx.Exec(deleteUserAccessQuery, username, username); err != nil {
  1355. return err
  1356. }
  1357. }
  1358. return nil
  1359. }
  1360. // ChangeProvisioned changes the provisioned status of a user. This is used to mark users as
  1361. // provisioned. A provisioned user is a user defined in the config file.
  1362. func (a *Manager) ChangeProvisioned(username string, provisioned bool) error {
  1363. return execTx(a.db, func(tx *sql.Tx) error {
  1364. return a.changeProvisionedTx(tx, username, provisioned)
  1365. })
  1366. }
  1367. func (a *Manager) changeProvisionedTx(tx *sql.Tx, username string, provisioned bool) error {
  1368. if _, err := tx.Exec(updateUserProvisionedQuery, provisioned, username); err != nil {
  1369. return err
  1370. }
  1371. return nil
  1372. }
  1373. // ChangeTier changes a user's tier using the tier code. This function does not delete reservations, messages,
  1374. // or attachments, even if the new tier has lower limits in this regard. That has to be done elsewhere.
  1375. func (a *Manager) ChangeTier(username, tier string) error {
  1376. if !AllowedUsername(username) {
  1377. return ErrInvalidArgument
  1378. }
  1379. t, err := a.Tier(tier)
  1380. if err != nil {
  1381. return err
  1382. } else if err := a.checkReservationsLimit(username, t.ReservationLimit); err != nil {
  1383. return err
  1384. }
  1385. if _, err := a.db.Exec(updateUserTierQuery, tier, username); err != nil {
  1386. return err
  1387. }
  1388. return nil
  1389. }
  1390. // ResetTier removes the tier from the given user
  1391. func (a *Manager) ResetTier(username string) error {
  1392. if !AllowedUsername(username) && username != Everyone && username != "" {
  1393. return ErrInvalidArgument
  1394. } else if err := a.checkReservationsLimit(username, 0); err != nil {
  1395. return err
  1396. }
  1397. _, err := a.db.Exec(deleteUserTierQuery, username)
  1398. return err
  1399. }
  1400. func (a *Manager) checkReservationsLimit(username string, reservationsLimit int64) error {
  1401. u, err := a.User(username)
  1402. if err != nil {
  1403. return err
  1404. }
  1405. if u.Tier != nil && reservationsLimit < u.Tier.ReservationLimit {
  1406. reservations, err := a.Reservations(username)
  1407. if err != nil {
  1408. return err
  1409. } else if int64(len(reservations)) > reservationsLimit {
  1410. return ErrTooManyReservations
  1411. }
  1412. }
  1413. return nil
  1414. }
  1415. // AllowReservation tests if a user may create an access control entry for the given topic.
  1416. // If there are any ACL entries that are not owned by the user, an error is returned.
  1417. func (a *Manager) AllowReservation(username string, topic string) error {
  1418. if (!AllowedUsername(username) && username != Everyone) || !AllowedTopic(topic) {
  1419. return ErrInvalidArgument
  1420. }
  1421. rows, err := a.db.Query(selectOtherAccessCountQuery, escapeUnderscore(topic), escapeUnderscore(topic), username)
  1422. if err != nil {
  1423. return err
  1424. }
  1425. defer rows.Close()
  1426. if !rows.Next() {
  1427. return errNoRows
  1428. }
  1429. var otherCount int
  1430. if err := rows.Scan(&otherCount); err != nil {
  1431. return err
  1432. }
  1433. if otherCount > 0 {
  1434. return errTopicOwnedByOthers
  1435. }
  1436. return nil
  1437. }
  1438. // AllowAccess adds or updates an entry in th access control list for a specific user. It controls
  1439. // read/write access to a topic. The parameter topicPattern may include wildcards (*). The ACL entry
  1440. // owner may either be a user (username), or the system (empty).
  1441. func (a *Manager) AllowAccess(username string, topicPattern string, permission Permission) error {
  1442. return execTx(a.db, func(tx *sql.Tx) error {
  1443. return a.allowAccessTx(tx, username, topicPattern, permission, false)
  1444. })
  1445. }
  1446. func (a *Manager) allowAccessTx(tx *sql.Tx, username string, topicPattern string, permission Permission, provisioned bool) error {
  1447. if !AllowedUsername(username) && username != Everyone {
  1448. return ErrInvalidArgument
  1449. } else if !AllowedTopicPattern(topicPattern) {
  1450. return ErrInvalidArgument
  1451. }
  1452. owner := ""
  1453. if _, err := tx.Exec(upsertUserAccessQuery, username, toSQLWildcard(topicPattern), permission.IsRead(), permission.IsWrite(), owner, owner, provisioned); err != nil {
  1454. return err
  1455. }
  1456. return nil
  1457. }
  1458. // ResetAccess removes an access control list entry for a specific username/topic, or (if topic is
  1459. // empty) for an entire user. The parameter topicPattern may include wildcards (*).
  1460. func (a *Manager) ResetAccess(username string, topicPattern string) error {
  1461. return execTx(a.db, func(tx *sql.Tx) error {
  1462. return a.resetAccessTx(tx, username, topicPattern)
  1463. })
  1464. }
  1465. func (a *Manager) resetAccessTx(tx *sql.Tx, username string, topicPattern string) error {
  1466. if !AllowedUsername(username) && username != Everyone && username != "" {
  1467. return ErrInvalidArgument
  1468. } else if !AllowedTopicPattern(topicPattern) && topicPattern != "" {
  1469. return ErrInvalidArgument
  1470. }
  1471. if username == "" && topicPattern == "" {
  1472. _, err := tx.Exec(deleteAllAccessQuery, username)
  1473. return err
  1474. } else if topicPattern == "" {
  1475. _, err := tx.Exec(deleteUserAccessQuery, username, username)
  1476. return err
  1477. }
  1478. _, err := tx.Exec(deleteTopicAccessQuery, username, username, toSQLWildcard(topicPattern))
  1479. return err
  1480. }
  1481. // AddReservation creates two access control entries for the given topic: one with full read/write access for the
  1482. // given user, and one for Everyone with the permission passed as everyone. The user also owns the entries, and
  1483. // can modify or delete them.
  1484. func (a *Manager) AddReservation(username string, topic string, everyone Permission) error {
  1485. if !AllowedUsername(username) || username == Everyone || !AllowedTopic(topic) {
  1486. return ErrInvalidArgument
  1487. }
  1488. tx, err := a.db.Begin()
  1489. if err != nil {
  1490. return err
  1491. }
  1492. defer tx.Rollback()
  1493. if _, err := tx.Exec(upsertUserAccessQuery, username, escapeUnderscore(topic), true, true, username, username, false); err != nil {
  1494. return err
  1495. }
  1496. if _, err := tx.Exec(upsertUserAccessQuery, Everyone, escapeUnderscore(topic), everyone.IsRead(), everyone.IsWrite(), username, username, false); err != nil {
  1497. return err
  1498. }
  1499. return tx.Commit()
  1500. }
  1501. // RemoveReservations deletes the access control entries associated with the given username/topic, as
  1502. // well as all entries with Everyone/topic. This is the counterpart for AddReservation.
  1503. func (a *Manager) RemoveReservations(username string, topics ...string) error {
  1504. if !AllowedUsername(username) || username == Everyone || len(topics) == 0 {
  1505. return ErrInvalidArgument
  1506. }
  1507. for _, topic := range topics {
  1508. if !AllowedTopic(topic) {
  1509. return ErrInvalidArgument
  1510. }
  1511. }
  1512. tx, err := a.db.Begin()
  1513. if err != nil {
  1514. return err
  1515. }
  1516. defer tx.Rollback()
  1517. for _, topic := range topics {
  1518. if _, err := tx.Exec(deleteTopicAccessQuery, username, username, escapeUnderscore(topic)); err != nil {
  1519. return err
  1520. }
  1521. if _, err := tx.Exec(deleteTopicAccessQuery, Everyone, Everyone, escapeUnderscore(topic)); err != nil {
  1522. return err
  1523. }
  1524. }
  1525. return tx.Commit()
  1526. }
  1527. // DefaultAccess returns the default read/write access if no access control entry matches
  1528. func (a *Manager) DefaultAccess() Permission {
  1529. return a.config.DefaultAccess
  1530. }
  1531. // AddTier creates a new tier in the database
  1532. func (a *Manager) AddTier(tier *Tier) error {
  1533. if tier.ID == "" {
  1534. tier.ID = util.RandomStringPrefix(tierIDPrefix, tierIDLength)
  1535. }
  1536. if _, err := a.db.Exec(insertTierQuery, tier.ID, tier.Code, tier.Name, tier.MessageLimit, int64(tier.MessageExpiryDuration.Seconds()), tier.EmailLimit, tier.CallLimit, tier.ReservationLimit, tier.AttachmentFileSizeLimit, tier.AttachmentTotalSizeLimit, int64(tier.AttachmentExpiryDuration.Seconds()), tier.AttachmentBandwidthLimit, nullString(tier.StripeMonthlyPriceID), nullString(tier.StripeYearlyPriceID)); err != nil {
  1537. return err
  1538. }
  1539. return nil
  1540. }
  1541. // UpdateTier updates a tier's properties in the database
  1542. func (a *Manager) UpdateTier(tier *Tier) error {
  1543. if _, err := a.db.Exec(updateTierQuery, tier.Name, tier.MessageLimit, int64(tier.MessageExpiryDuration.Seconds()), tier.EmailLimit, tier.CallLimit, tier.ReservationLimit, tier.AttachmentFileSizeLimit, tier.AttachmentTotalSizeLimit, int64(tier.AttachmentExpiryDuration.Seconds()), tier.AttachmentBandwidthLimit, nullString(tier.StripeMonthlyPriceID), nullString(tier.StripeYearlyPriceID), tier.Code); err != nil {
  1544. return err
  1545. }
  1546. return nil
  1547. }
  1548. // RemoveTier deletes the tier with the given code
  1549. func (a *Manager) RemoveTier(code string) error {
  1550. if !AllowedTier(code) {
  1551. return ErrInvalidArgument
  1552. }
  1553. // This fails if any user has this tier
  1554. if _, err := a.db.Exec(deleteTierQuery, code); err != nil {
  1555. return err
  1556. }
  1557. return nil
  1558. }
  1559. // ChangeBilling updates a user's billing fields, namely the Stripe customer ID, and subscription information
  1560. func (a *Manager) ChangeBilling(username string, billing *Billing) error {
  1561. if _, err := a.db.Exec(updateBillingQuery, nullString(billing.StripeCustomerID), nullString(billing.StripeSubscriptionID), nullString(string(billing.StripeSubscriptionStatus)), nullString(string(billing.StripeSubscriptionInterval)), nullInt64(billing.StripeSubscriptionPaidUntil.Unix()), nullInt64(billing.StripeSubscriptionCancelAt.Unix()), username); err != nil {
  1562. return err
  1563. }
  1564. return nil
  1565. }
  1566. // Tiers returns a list of all Tier structs
  1567. func (a *Manager) Tiers() ([]*Tier, error) {
  1568. rows, err := a.db.Query(selectTiersQuery)
  1569. if err != nil {
  1570. return nil, err
  1571. }
  1572. defer rows.Close()
  1573. tiers := make([]*Tier, 0)
  1574. for {
  1575. tier, err := a.readTier(rows)
  1576. if err == ErrTierNotFound {
  1577. break
  1578. } else if err != nil {
  1579. return nil, err
  1580. }
  1581. tiers = append(tiers, tier)
  1582. }
  1583. return tiers, nil
  1584. }
  1585. // Tier returns a Tier based on the code, or ErrTierNotFound if it does not exist
  1586. func (a *Manager) Tier(code string) (*Tier, error) {
  1587. rows, err := a.db.Query(selectTierByCodeQuery, code)
  1588. if err != nil {
  1589. return nil, err
  1590. }
  1591. defer rows.Close()
  1592. return a.readTier(rows)
  1593. }
  1594. // TierByStripePrice returns a Tier based on the Stripe price ID, or ErrTierNotFound if it does not exist
  1595. func (a *Manager) TierByStripePrice(priceID string) (*Tier, error) {
  1596. rows, err := a.db.Query(selectTierByPriceIDQuery, priceID, priceID)
  1597. if err != nil {
  1598. return nil, err
  1599. }
  1600. defer rows.Close()
  1601. return a.readTier(rows)
  1602. }
  1603. func (a *Manager) readTier(rows *sql.Rows) (*Tier, error) {
  1604. var id, code, name string
  1605. var stripeMonthlyPriceID, stripeYearlyPriceID sql.NullString
  1606. var messagesLimit, messagesExpiryDuration, emailsLimit, callsLimit, reservationsLimit, attachmentFileSizeLimit, attachmentTotalSizeLimit, attachmentExpiryDuration, attachmentBandwidthLimit sql.NullInt64
  1607. if !rows.Next() {
  1608. return nil, ErrTierNotFound
  1609. }
  1610. if err := rows.Scan(&id, &code, &name, &messagesLimit, &messagesExpiryDuration, &emailsLimit, &callsLimit, &reservationsLimit, &attachmentFileSizeLimit, &attachmentTotalSizeLimit, &attachmentExpiryDuration, &attachmentBandwidthLimit, &stripeMonthlyPriceID, &stripeYearlyPriceID); err != nil {
  1611. return nil, err
  1612. } else if err := rows.Err(); err != nil {
  1613. return nil, err
  1614. }
  1615. // When changed, note readUser() as well
  1616. return &Tier{
  1617. ID: id,
  1618. Code: code,
  1619. Name: name,
  1620. MessageLimit: messagesLimit.Int64,
  1621. MessageExpiryDuration: time.Duration(messagesExpiryDuration.Int64) * time.Second,
  1622. EmailLimit: emailsLimit.Int64,
  1623. CallLimit: callsLimit.Int64,
  1624. ReservationLimit: reservationsLimit.Int64,
  1625. AttachmentFileSizeLimit: attachmentFileSizeLimit.Int64,
  1626. AttachmentTotalSizeLimit: attachmentTotalSizeLimit.Int64,
  1627. AttachmentExpiryDuration: time.Duration(attachmentExpiryDuration.Int64) * time.Second,
  1628. AttachmentBandwidthLimit: attachmentBandwidthLimit.Int64,
  1629. StripeMonthlyPriceID: stripeMonthlyPriceID.String, // May be empty
  1630. StripeYearlyPriceID: stripeYearlyPriceID.String, // May be empty
  1631. }, nil
  1632. }
  1633. // Close closes the underlying database
  1634. func (a *Manager) Close() error {
  1635. return a.db.Close()
  1636. }
  1637. // maybeProvisionUsersAccessAndTokens provisions users, access control entries, and tokens based on the config.
  1638. func (a *Manager) maybeProvisionUsersAccessAndTokens() error {
  1639. if !a.config.ProvisionEnabled {
  1640. return nil
  1641. }
  1642. existingUsers, err := a.Users()
  1643. if err != nil {
  1644. return err
  1645. }
  1646. provisionUsernames := util.Map(a.config.Users, func(u *User) string {
  1647. return u.Name
  1648. })
  1649. return execTx(a.db, func(tx *sql.Tx) error {
  1650. if err := a.maybeProvisionUsers(tx, provisionUsernames, existingUsers); err != nil {
  1651. return fmt.Errorf("failed to provision users: %v", err)
  1652. }
  1653. if err := a.maybeProvisionGrants(tx); err != nil {
  1654. return fmt.Errorf("failed to provision grants: %v", err)
  1655. }
  1656. if err := a.maybeProvisionTokens(tx, provisionUsernames); err != nil {
  1657. return fmt.Errorf("failed to provision tokens: %v", err)
  1658. }
  1659. return nil
  1660. })
  1661. }
  1662. // maybeProvisionUsers checks if the users in the config are provisioned, and adds or updates them.
  1663. // It also removes users that are provisioned, but not in the config anymore.
  1664. func (a *Manager) maybeProvisionUsers(tx *sql.Tx, provisionUsernames []string, existingUsers []*User) error {
  1665. // Remove users that are provisioned, but not in the config anymore
  1666. for _, user := range existingUsers {
  1667. if user.Name == Everyone {
  1668. continue
  1669. } else if user.Provisioned && !util.Contains(provisionUsernames, user.Name) {
  1670. if err := a.removeUserTx(tx, user.Name); err != nil {
  1671. return fmt.Errorf("failed to remove provisioned user %s: %v", user.Name, err)
  1672. }
  1673. }
  1674. }
  1675. // Add or update provisioned users
  1676. for _, user := range a.config.Users {
  1677. if user.Name == Everyone {
  1678. continue
  1679. }
  1680. existingUser, exists := util.Find(existingUsers, func(u *User) bool {
  1681. return u.Name == user.Name
  1682. })
  1683. if !exists {
  1684. if err := a.addUserTx(tx, user.Name, user.Hash, user.Role, true, true); err != nil && !errors.Is(err, ErrUserExists) {
  1685. return fmt.Errorf("failed to add provisioned user %s: %v", user.Name, err)
  1686. }
  1687. } else {
  1688. if !existingUser.Provisioned {
  1689. if err := a.changeProvisionedTx(tx, user.Name, true); err != nil {
  1690. return fmt.Errorf("failed to change provisioned status for user %s: %v", user.Name, err)
  1691. }
  1692. }
  1693. if existingUser.Hash != user.Hash {
  1694. if err := a.changePasswordTx(tx, user.Name, user.Hash, true); err != nil {
  1695. return fmt.Errorf("failed to change password for provisioned user %s: %v", user.Name, err)
  1696. }
  1697. }
  1698. if existingUser.Role != user.Role {
  1699. if err := a.changeRoleTx(tx, user.Name, user.Role); err != nil {
  1700. return fmt.Errorf("failed to change role for provisioned user %s: %v", user.Name, err)
  1701. }
  1702. }
  1703. }
  1704. }
  1705. return nil
  1706. }
  1707. // maybyProvisionGrants removes all provisioned grants, and (re-)adds the grants from the config.
  1708. //
  1709. // Unlike users and tokens, grants can be just re-added, because they do not carry any state (such as last
  1710. // access time) or do not have dependent resources (such as grants or tokens).
  1711. func (a *Manager) maybeProvisionGrants(tx *sql.Tx) error {
  1712. // Remove all provisioned grants
  1713. if _, err := tx.Exec(deleteUserAccessProvisionedQuery); err != nil {
  1714. return err
  1715. }
  1716. // (Re-)add provisioned grants
  1717. for username, grants := range a.config.Access {
  1718. user, exists := util.Find(a.config.Users, func(u *User) bool {
  1719. return u.Name == username
  1720. })
  1721. if !exists && username != Everyone {
  1722. return fmt.Errorf("user %s is not a provisioned user, refusing to add ACL entry", username)
  1723. } else if user != nil && user.Role == RoleAdmin {
  1724. return fmt.Errorf("adding access control entries is not allowed for admin roles for user %s", username)
  1725. }
  1726. for _, grant := range grants {
  1727. if err := a.resetAccessTx(tx, username, grant.TopicPattern); err != nil {
  1728. return fmt.Errorf("failed to reset access for user %s and topic %s: %v", username, grant.TopicPattern, err)
  1729. }
  1730. if err := a.allowAccessTx(tx, username, grant.TopicPattern, grant.Permission, true); err != nil {
  1731. return err
  1732. }
  1733. }
  1734. }
  1735. return nil
  1736. }
  1737. func (a *Manager) maybeProvisionTokens(tx *sql.Tx, provisionUsernames []string) error {
  1738. // Remove tokens that are provisioned, but not in the config anymore
  1739. existingTokens, err := a.allProvisionedTokens()
  1740. if err != nil {
  1741. return fmt.Errorf("failed to retrieve existing provisioned tokens: %v", err)
  1742. }
  1743. var provisionTokens []string
  1744. for _, userTokens := range a.config.Tokens {
  1745. for _, token := range userTokens {
  1746. provisionTokens = append(provisionTokens, token.Value)
  1747. }
  1748. }
  1749. for _, existingToken := range existingTokens {
  1750. if !slices.Contains(provisionTokens, existingToken.Value) {
  1751. if _, err := tx.Exec(deleteProvisionedTokenQuery, existingToken.Value); err != nil {
  1752. return fmt.Errorf("failed to remove provisioned token %s: %v", existingToken.Value, err)
  1753. }
  1754. }
  1755. }
  1756. // (Re-)add provisioned tokens
  1757. for username, tokens := range a.config.Tokens {
  1758. if !slices.Contains(provisionUsernames, username) && username != Everyone {
  1759. return fmt.Errorf("user %s is not a provisioned user, refusing to add tokens", username)
  1760. }
  1761. var userID string
  1762. row := tx.QueryRow(selectUserIDFromUsernameQuery, username)
  1763. if err := row.Scan(&userID); err != nil {
  1764. return fmt.Errorf("failed to find provisioned user %s for provisioned tokens", username)
  1765. }
  1766. for _, token := range tokens {
  1767. if _, err := a.createTokenTx(tx, userID, token.Value, token.Label, time.Unix(0, 0), netip.IPv4Unspecified(), true); err != nil {
  1768. return err
  1769. }
  1770. }
  1771. }
  1772. return nil
  1773. }
  // toSQLWildcard converts a user-facing wildcard pattern into a SQL LIKE-style pattern. Only
// '*' is treated as a wildcard (it becomes '%'), and every '_' is escaped as '\_', assuming
// '\' as escape character. Other characters, including a literal '%', pass through unchanged.
func toSQLWildcard(s string) string {
	// A single pass is equivalent to replacing '*' first and escaping '_' afterwards: neither
	// replacement produces text the other one would match.
	return strings.NewReplacer("*", "%", "_", `\_`).Replace(s)
}

// fromSQLWildcard is the inverse of toSQLWildcard: it turns '\_' back into '_' and '%' into '*'.
func fromSQLWildcard(s string) string {
	return strings.NewReplacer(`\_`, "_", "%", "*").Replace(s)
}

// escapeUnderscore prefixes every '_' in s with '\' so it matches literally in a SQL pattern.
func escapeUnderscore(s string) string {
	return strings.ReplaceAll(s, "_", `\_`)
}

// unescapeUnderscore reverses escapeUnderscore, turning every '\_' back into '_'.
func unescapeUnderscore(s string) string {
	return strings.ReplaceAll(s, `\_`, "_")
}
  1790. func runStartupQueries(db *sql.DB, startupQueries string) error {
  1791. if _, err := db.Exec(startupQueries); err != nil {
  1792. return err
  1793. }
  1794. if _, err := db.Exec(builtinStartupQueries); err != nil {
  1795. return err
  1796. }
  1797. return nil
  1798. }
  1799. func setupDB(db *sql.DB) error {
  1800. // If 'schemaVersion' table does not exist, this must be a new database
  1801. rowsSV, err := db.Query(selectSchemaVersionQuery)
  1802. if err != nil {
  1803. return setupNewDB(db)
  1804. }
  1805. defer rowsSV.Close()
  1806. // If 'schemaVersion' table exists, read version and potentially upgrade
  1807. schemaVersion := 0
  1808. if !rowsSV.Next() {
  1809. return errors.New("cannot determine schema version: database file may be corrupt")
  1810. }
  1811. if err := rowsSV.Scan(&schemaVersion); err != nil {
  1812. return err
  1813. }
  1814. rowsSV.Close()
  1815. // Do migrations
  1816. if schemaVersion == currentSchemaVersion {
  1817. return nil
  1818. } else if schemaVersion > currentSchemaVersion {
  1819. return fmt.Errorf("unexpected schema version: version %d is higher than current version %d", schemaVersion, currentSchemaVersion)
  1820. }
  1821. for i := schemaVersion; i < currentSchemaVersion; i++ {
  1822. fn, ok := migrations[i]
  1823. if !ok {
  1824. return fmt.Errorf("cannot find migration step from schema version %d to %d", i, i+1)
  1825. } else if err := fn(db); err != nil {
  1826. return err
  1827. }
  1828. }
  1829. return nil
  1830. }
  1831. func setupNewDB(db *sql.DB) error {
  1832. if _, err := db.Exec(createTablesQueries); err != nil {
  1833. return err
  1834. }
  1835. if _, err := db.Exec(insertSchemaVersion, currentSchemaVersion); err != nil {
  1836. return err
  1837. }
  1838. return nil
  1839. }
  1840. func migrateFrom1(db *sql.DB) error {
  1841. log.Tag(tag).Info("Migrating user database schema: from 1 to 2")
  1842. tx, err := db.Begin()
  1843. if err != nil {
  1844. return err
  1845. }
  1846. defer tx.Rollback()
  1847. // Rename user -> user_old, and create new tables
  1848. if _, err := tx.Exec(migrate1To2CreateTablesQueries); err != nil {
  1849. return err
  1850. }
  1851. // Insert users from user_old into new user table, with ID and sync_topic
  1852. rows, err := tx.Query(migrate1To2SelectAllOldUsernamesNoTx)
  1853. if err != nil {
  1854. return err
  1855. }
  1856. defer rows.Close()
  1857. usernames := make([]string, 0)
  1858. for rows.Next() {
  1859. var username string
  1860. if err := rows.Scan(&username); err != nil {
  1861. return err
  1862. }
  1863. usernames = append(usernames, username)
  1864. }
  1865. if err := rows.Close(); err != nil {
  1866. return err
  1867. }
  1868. for _, username := range usernames {
  1869. userID := util.RandomStringPrefix(userIDPrefix, userIDLength)
  1870. syncTopic := util.RandomStringPrefix(syncTopicPrefix, syncTopicLength)
  1871. if _, err := tx.Exec(migrate1To2InsertUserNoTx, userID, syncTopic, username); err != nil {
  1872. return err
  1873. }
  1874. }
  1875. // Migrate old "access" table to "user_access" and drop "access" and "user_old"
  1876. if _, err := tx.Exec(migrate1To2InsertFromOldTablesAndDropNoTx); err != nil {
  1877. return err
  1878. }
  1879. if _, err := tx.Exec(updateSchemaVersion, 2); err != nil {
  1880. return err
  1881. }
  1882. if err := tx.Commit(); err != nil {
  1883. return err
  1884. }
  1885. return nil
  1886. }
  1887. func migrateFrom2(db *sql.DB) error {
  1888. log.Tag(tag).Info("Migrating user database schema: from 2 to 3")
  1889. tx, err := db.Begin()
  1890. if err != nil {
  1891. return err
  1892. }
  1893. defer tx.Rollback()
  1894. if _, err := tx.Exec(migrate2To3UpdateQueries); err != nil {
  1895. return err
  1896. }
  1897. if _, err := tx.Exec(updateSchemaVersion, 3); err != nil {
  1898. return err
  1899. }
  1900. return tx.Commit()
  1901. }
  1902. func migrateFrom3(db *sql.DB) error {
  1903. log.Tag(tag).Info("Migrating user database schema: from 3 to 4")
  1904. tx, err := db.Begin()
  1905. if err != nil {
  1906. return err
  1907. }
  1908. defer tx.Rollback()
  1909. if _, err := tx.Exec(migrate3To4UpdateQueries); err != nil {
  1910. return err
  1911. }
  1912. if _, err := tx.Exec(updateSchemaVersion, 4); err != nil {
  1913. return err
  1914. }
  1915. return tx.Commit()
  1916. }
  1917. func migrateFrom4(db *sql.DB) error {
  1918. log.Tag(tag).Info("Migrating user database schema: from 4 to 5")
  1919. tx, err := db.Begin()
  1920. if err != nil {
  1921. return err
  1922. }
  1923. defer tx.Rollback()
  1924. if _, err := tx.Exec(migrate4To5UpdateQueries); err != nil {
  1925. return err
  1926. }
  1927. if _, err := tx.Exec(updateSchemaVersion, 5); err != nil {
  1928. return err
  1929. }
  1930. return tx.Commit()
  1931. }
  1932. func migrateFrom5(db *sql.DB) error {
  1933. log.Tag(tag).Info("Migrating user database schema: from 5 to 6")
  1934. tx, err := db.Begin()
  1935. if err != nil {
  1936. return err
  1937. }
  1938. defer tx.Rollback()
  1939. if _, err := tx.Exec(migrate5To6UpdateQueries); err != nil {
  1940. return err
  1941. }
  1942. if _, err := tx.Exec(updateSchemaVersion, 6); err != nil {
  1943. return err
  1944. }
  1945. return tx.Commit()
  1946. }
  // nullString converts s to a sql.NullString, mapping the empty string to NULL.
func nullString(s string) sql.NullString {
	return sql.NullString{String: s, Valid: s != ""}
}

// nullInt64 converts v to a sql.NullInt64, mapping 0 to NULL.
func nullInt64(v int64) sql.NullInt64 {
	return sql.NullInt64{Int64: v, Valid: v != 0}
}
  // execTx runs f inside a new transaction on db. The transaction is committed if f returns nil,
// and rolled back if f returns an error (or panics, since Rollback is deferred). The error from
// Begin, f, or Commit is returned unchanged.
func execTx(db *sql.DB, f func(tx *sql.Tx) error) error {
	tx, beginErr := db.Begin()
	if beginErr != nil {
		return beginErr
	}
	defer tx.Rollback() // no-op after a successful Commit (returns sql.ErrTxDone, ignored)
	if err := f(tx); err != nil {
		return err
	}
	return tx.Commit()
}
  // queryTx runs f inside a new transaction on db and returns f's result. The transaction is
// committed if f succeeds and rolled back (via the deferred Rollback) otherwise. f's result is
// returned alongside f's error or a Commit error. Only a failed Begin yields the zero value of T.
func queryTx[T any](db *sql.DB, f func(tx *sql.Tx) (T, error)) (T, error) {
	tx, err := db.Begin()
	if err != nil {
		var zero T
		return zero, err
	}
	defer tx.Rollback() // no-op after a successful Commit
	result, err := f(tx)
	if err == nil {
		err = tx.Commit()
	}
	return result, err
}