manager.go 56 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
7127812791280128112821283128412851286128712881289129012911292129312941295129612971298129913001301130213031304130513061307130813091310131113121313131413151316131713181319132013211322132313241325132613271328132913301331133213331334133513361337133813391340134113421343134413451346134713481349135013511352135313541355135613571358135913601361136213631364136513661367136813691370137113721373137413751376137713781379138013811382138313841385138613871388138913901391139213931394139513961397139813991400140114021403140414051406140714081409141014111412141314141415141614171418141914201421142214231424142514261427142814291430143114321433143414351436143714381439144014411442144314441445144614471448144914501451145214531454145514561457145814591460146114621463146414651466146714681469147014711472147314741475147614771478147914801481148214831484148514861487148814891490149114921493149414951496149714981499150015011502150315041505150615071508150915101511151215131514151515161517151815191520152115221523152415251526152715281529153015311532153315341535153615371538153915401541154215431544154515461547154815491550155115521553155415551556155715581559156015611562156315641565156615671568156915701571157215731574157515761577157815791580158115821583158415851586158715881589159015911592159315941595159615971598159916001601160216031604160516061607160816091610161116121613161416151616161716181619162016211622162316241625162616271628162916301631163216331634163516361637163816391640164116421643
  1. // Package user deals with authentication and authorization against topics
  2. package user
  3. import (
  4. "database/sql"
  5. "encoding/json"
  6. "errors"
  7. "fmt"
  8. "github.com/mattn/go-sqlite3"
  9. _ "github.com/mattn/go-sqlite3" // SQLite driver
  10. "github.com/stripe/stripe-go/v74"
  11. "golang.org/x/crypto/bcrypt"
  12. "heckel.io/ntfy/log"
  13. "heckel.io/ntfy/util"
  14. "net/netip"
  15. "strings"
  16. "sync"
  17. "time"
  18. )
  // Prefixes and lengths used when generating random identifiers.
  const (
	tierIDPrefix    = "ti_"
	tierIDLength    = 8
	syncTopicPrefix = "st_"
	syncTopicLength = 16
	userIDPrefix    = "u_"
	userIDLength    = 12
  )

  // Authentication and account lifecycle settings.
  const (
	// Dummy bcrypt hash that Authenticate compares against when a login cannot succeed
	// (unknown or deleted user), to avoid timing attacks.
	// Its cost should match DefaultUserPasswordBcryptCost.
	userAuthIntentionalSlowDownHash = "$2a$10$YFCQvqQDwIIwnJM1xkAYOeih0dg17UVGanaTStnrSzC8NCWxcLDwy"

	// Seven days; presumably the grace period before users marked deleted are purged.
	userHardDeleteAfterDuration = 7 * 24 * time.Hour
  )

  // API token settings.
  const (
	tokenPrefix   = "tk_"
	tokenLength   = 32
	tokenMaxCount = 20 // Only keep this many tokens in the table per user
  )

  // tag is the log tag used by this package.
  const tag = "user_manager"
  // DefaultUserStatsQueueWriterInterval is the default interval for the asynchronous stats/token
  // queue writer (presumably what callers pass to NewManager as queueWriterInterval). Like the
  // other defaults, configs may override it.
  const DefaultUserStatsQueueWriterInterval = 33 * time.Second

  // DefaultUserPasswordBcryptCost is the default bcrypt cost for user passwords; configs may
  // override it. userAuthIntentionalSlowDownHash is generated with this same cost ($2a$10$).
  const DefaultUserPasswordBcryptCost = 10
  // errNoTokenProvided is returned when a token operation is called with an empty token.
  var errNoTokenProvided = errors.New("no token provided")

  // errTopicOwnedByOthers signals that a topic is reserved by a different user.
  var errTopicOwnedByOthers = errors.New("topic owned by others")

  // errNoRows signals that a query unexpectedly returned no rows.
  var errNoRows = errors.New("no rows found")
  43. // Manager-related queries
  44. const (
  45. createTablesQueries = `
  46. BEGIN;
  47. CREATE TABLE IF NOT EXISTS tier (
  48. id TEXT PRIMARY KEY,
  49. code TEXT NOT NULL,
  50. name TEXT NOT NULL,
  51. messages_limit INT NOT NULL,
  52. messages_expiry_duration INT NOT NULL,
  53. emails_limit INT NOT NULL,
  54. calls_limit INT NOT NULL,
  55. reservations_limit INT NOT NULL,
  56. attachment_file_size_limit INT NOT NULL,
  57. attachment_total_size_limit INT NOT NULL,
  58. attachment_expiry_duration INT NOT NULL,
  59. attachment_bandwidth_limit INT NOT NULL,
  60. stripe_monthly_price_id TEXT,
  61. stripe_yearly_price_id TEXT
  62. );
  63. CREATE UNIQUE INDEX idx_tier_code ON tier (code);
  64. CREATE UNIQUE INDEX idx_tier_stripe_monthly_price_id ON tier (stripe_monthly_price_id);
  65. CREATE UNIQUE INDEX idx_tier_stripe_yearly_price_id ON tier (stripe_yearly_price_id);
  66. CREATE TABLE IF NOT EXISTS user (
  67. id TEXT PRIMARY KEY,
  68. tier_id TEXT,
  69. user TEXT NOT NULL,
  70. pass TEXT NOT NULL,
  71. role TEXT CHECK (role IN ('anonymous', 'admin', 'user')) NOT NULL,
  72. prefs JSON NOT NULL DEFAULT '{}',
  73. sync_topic TEXT NOT NULL,
  74. stats_messages INT NOT NULL DEFAULT (0),
  75. stats_emails INT NOT NULL DEFAULT (0),
  76. stats_calls INT NOT NULL DEFAULT (0),
  77. stripe_customer_id TEXT,
  78. stripe_subscription_id TEXT,
  79. stripe_subscription_status TEXT,
  80. stripe_subscription_interval TEXT,
  81. stripe_subscription_paid_until INT,
  82. stripe_subscription_cancel_at INT,
  83. created INT NOT NULL,
  84. deleted INT,
  85. FOREIGN KEY (tier_id) REFERENCES tier (id)
  86. );
  87. CREATE UNIQUE INDEX idx_user ON user (user);
  88. CREATE UNIQUE INDEX idx_user_stripe_customer_id ON user (stripe_customer_id);
  89. CREATE UNIQUE INDEX idx_user_stripe_subscription_id ON user (stripe_subscription_id);
  90. CREATE TABLE IF NOT EXISTS user_access (
  91. user_id TEXT NOT NULL,
  92. topic TEXT NOT NULL,
  93. read INT NOT NULL,
  94. write INT NOT NULL,
  95. owner_user_id INT,
  96. PRIMARY KEY (user_id, topic),
  97. FOREIGN KEY (user_id) REFERENCES user (id) ON DELETE CASCADE,
  98. FOREIGN KEY (owner_user_id) REFERENCES user (id) ON DELETE CASCADE
  99. );
  100. CREATE TABLE IF NOT EXISTS user_token (
  101. user_id TEXT NOT NULL,
  102. token TEXT NOT NULL,
  103. label TEXT NOT NULL,
  104. last_access INT NOT NULL,
  105. last_origin TEXT NOT NULL,
  106. expires INT NOT NULL,
  107. PRIMARY KEY (user_id, token),
  108. FOREIGN KEY (user_id) REFERENCES user (id) ON DELETE CASCADE
  109. );
  110. CREATE TABLE IF NOT EXISTS user_phone (
  111. user_id TEXT NOT NULL,
  112. phone_number TEXT NOT NULL,
  113. PRIMARY KEY (user_id, phone_number),
  114. FOREIGN KEY (user_id) REFERENCES user (id) ON DELETE CASCADE
  115. );
  116. CREATE UNIQUE INDEX idx_user_phone_number ON user_phone (phone_number);
  117. CREATE TABLE IF NOT EXISTS schemaVersion (
  118. id INT PRIMARY KEY,
  119. version INT NOT NULL
  120. );
  121. INSERT INTO user (id, user, pass, role, sync_topic, created)
  122. VALUES ('` + everyoneID + `', '*', '', 'anonymous', '', UNIXEPOCH())
  123. ON CONFLICT (id) DO NOTHING;
  124. COMMIT;
  125. `
  126. builtinStartupQueries = `
  127. PRAGMA foreign_keys = ON;
  128. `
  129. selectUserByIDQuery = `
  130. SELECT u.id, u.user, u.pass, u.role, u.prefs, u.sync_topic, u.stats_messages, u.stats_emails, u.stats_calls, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_interval, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, deleted, t.id, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.calls_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.attachment_bandwidth_limit, t.stripe_monthly_price_id, t.stripe_yearly_price_id
  131. FROM user u
  132. LEFT JOIN tier t on t.id = u.tier_id
  133. WHERE u.id = ?
  134. `
  135. selectUserByNameQuery = `
  136. SELECT u.id, u.user, u.pass, u.role, u.prefs, u.sync_topic, u.stats_messages, u.stats_emails, u.stats_calls, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_interval, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, deleted, t.id, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.calls_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.attachment_bandwidth_limit, t.stripe_monthly_price_id, t.stripe_yearly_price_id
  137. FROM user u
  138. LEFT JOIN tier t on t.id = u.tier_id
  139. WHERE user = ?
  140. `
  141. selectUserByTokenQuery = `
  142. SELECT u.id, u.user, u.pass, u.role, u.prefs, u.sync_topic, u.stats_messages, u.stats_emails, u.stats_calls, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_interval, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, deleted, t.id, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.calls_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.attachment_bandwidth_limit, t.stripe_monthly_price_id, t.stripe_yearly_price_id
  143. FROM user u
  144. JOIN user_token tk on u.id = tk.user_id
  145. LEFT JOIN tier t on t.id = u.tier_id
  146. WHERE tk.token = ? AND (tk.expires = 0 OR tk.expires >= ?)
  147. `
  148. selectUserByStripeCustomerIDQuery = `
  149. SELECT u.id, u.user, u.pass, u.role, u.prefs, u.sync_topic, u.stats_messages, u.stats_emails, u.stats_calls, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_interval, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, deleted, t.id, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.calls_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.attachment_bandwidth_limit, t.stripe_monthly_price_id, t.stripe_yearly_price_id
  150. FROM user u
  151. LEFT JOIN tier t on t.id = u.tier_id
  152. WHERE u.stripe_customer_id = ?
  153. `
  154. selectTopicPermsQuery = `
  155. SELECT read, write
  156. FROM user_access a
  157. JOIN user u ON u.id = a.user_id
  158. WHERE (u.user = ? OR u.user = ?) AND ? LIKE a.topic
  159. ORDER BY u.user DESC
  160. `
  161. insertUserQuery = `
  162. INSERT INTO user (id, user, pass, role, sync_topic, created)
  163. VALUES (?, ?, ?, ?, ?, ?)
  164. `
  165. selectUsernamesQuery = `
  166. SELECT user
  167. FROM user
  168. ORDER BY
  169. CASE role
  170. WHEN 'admin' THEN 1
  171. WHEN 'anonymous' THEN 3
  172. ELSE 2
  173. END, user
  174. `
  175. selectUserCountQuery = `SELECT COUNT(*) FROM user`
  176. updateUserPassQuery = `UPDATE user SET pass = ? WHERE user = ?`
  177. updateUserRoleQuery = `UPDATE user SET role = ? WHERE user = ?`
  178. updateUserPrefsQuery = `UPDATE user SET prefs = ? WHERE id = ?`
  179. updateUserStatsQuery = `UPDATE user SET stats_messages = ?, stats_emails = ?, stats_calls = ? WHERE id = ?`
  180. updateUserStatsResetAllQuery = `UPDATE user SET stats_messages = 0, stats_emails = 0, stats_calls = 0`
  181. updateUserDeletedQuery = `UPDATE user SET deleted = ? WHERE id = ?`
  182. deleteUsersMarkedQuery = `DELETE FROM user WHERE deleted < ?`
  183. deleteUserQuery = `DELETE FROM user WHERE user = ?`
  184. upsertUserAccessQuery = `
  185. INSERT INTO user_access (user_id, topic, read, write, owner_user_id)
  186. VALUES ((SELECT id FROM user WHERE user = ?), ?, ?, ?, (SELECT IIF(?='',NULL,(SELECT id FROM user WHERE user=?))))
  187. ON CONFLICT (user_id, topic)
  188. DO UPDATE SET read=excluded.read, write=excluded.write, owner_user_id=excluded.owner_user_id
  189. `
  190. selectUserAllAccessQuery = `
  191. SELECT user_id, topic, read, write
  192. FROM user_access
  193. ORDER BY write DESC, read DESC, topic
  194. `
  195. selectUserAccessQuery = `
  196. SELECT topic, read, write
  197. FROM user_access
  198. WHERE user_id = (SELECT id FROM user WHERE user = ?)
  199. ORDER BY write DESC, read DESC, topic
  200. `
  201. selectUserReservationsQuery = `
  202. SELECT a_user.topic, a_user.read, a_user.write, a_everyone.read AS everyone_read, a_everyone.write AS everyone_write
  203. FROM user_access a_user
  204. LEFT JOIN user_access a_everyone ON a_user.topic = a_everyone.topic AND a_everyone.user_id = (SELECT id FROM user WHERE user = ?)
  205. WHERE a_user.user_id = a_user.owner_user_id
  206. AND a_user.owner_user_id = (SELECT id FROM user WHERE user = ?)
  207. ORDER BY a_user.topic
  208. `
  209. selectUserReservationsCountQuery = `
  210. SELECT COUNT(*)
  211. FROM user_access
  212. WHERE user_id = owner_user_id
  213. AND owner_user_id = (SELECT id FROM user WHERE user = ?)
  214. `
  215. selectUserReservationsOwnerQuery = `
  216. SELECT owner_user_id
  217. FROM user_access
  218. WHERE topic = ?
  219. AND user_id = owner_user_id
  220. `
  221. selectUserHasReservationQuery = `
  222. SELECT COUNT(*)
  223. FROM user_access
  224. WHERE user_id = owner_user_id
  225. AND owner_user_id = (SELECT id FROM user WHERE user = ?)
  226. AND topic = ?
  227. `
  228. selectOtherAccessCountQuery = `
  229. SELECT COUNT(*)
  230. FROM user_access
  231. WHERE (topic = ? OR ? LIKE topic)
  232. AND (owner_user_id IS NULL OR owner_user_id != (SELECT id FROM user WHERE user = ?))
  233. `
  234. deleteAllAccessQuery = `DELETE FROM user_access`
  235. deleteUserAccessQuery = `
  236. DELETE FROM user_access
  237. WHERE user_id = (SELECT id FROM user WHERE user = ?)
  238. OR owner_user_id = (SELECT id FROM user WHERE user = ?)
  239. `
  240. deleteTopicAccessQuery = `
  241. DELETE FROM user_access
  242. WHERE (user_id = (SELECT id FROM user WHERE user = ?) OR owner_user_id = (SELECT id FROM user WHERE user = ?))
  243. AND topic = ?
  244. `
  245. selectTokenCountQuery = `SELECT COUNT(*) FROM user_token WHERE user_id = ?`
  246. selectTokensQuery = `SELECT token, label, last_access, last_origin, expires FROM user_token WHERE user_id = ?`
  247. selectTokenQuery = `SELECT token, label, last_access, last_origin, expires FROM user_token WHERE user_id = ? AND token = ?`
  248. insertTokenQuery = `INSERT INTO user_token (user_id, token, label, last_access, last_origin, expires) VALUES (?, ?, ?, ?, ?, ?)`
  249. updateTokenExpiryQuery = `UPDATE user_token SET expires = ? WHERE user_id = ? AND token = ?`
  250. updateTokenLabelQuery = `UPDATE user_token SET label = ? WHERE user_id = ? AND token = ?`
  251. updateTokenLastAccessQuery = `UPDATE user_token SET last_access = ?, last_origin = ? WHERE token = ?`
  252. deleteTokenQuery = `DELETE FROM user_token WHERE user_id = ? AND token = ?`
  253. deleteAllTokenQuery = `DELETE FROM user_token WHERE user_id = ?`
  254. deleteExpiredTokensQuery = `DELETE FROM user_token WHERE expires > 0 AND expires < ?`
  255. deleteExcessTokensQuery = `
  256. DELETE FROM user_token
  257. WHERE (user_id, token) NOT IN (
  258. SELECT user_id, token
  259. FROM user_token
  260. WHERE user_id = ?
  261. ORDER BY expires DESC
  262. LIMIT ?
  263. )
  264. `
  265. selectPhoneNumbersQuery = `SELECT phone_number FROM user_phone WHERE user_id = ?`
  266. insertPhoneNumberQuery = `INSERT INTO user_phone (user_id, phone_number) VALUES (?, ?)`
  267. deletePhoneNumberQuery = `DELETE FROM user_phone WHERE user_id = ? AND phone_number = ?`
  268. insertTierQuery = `
  269. INSERT INTO tier (id, code, name, messages_limit, messages_expiry_duration, emails_limit, calls_limit, reservations_limit, attachment_file_size_limit, attachment_total_size_limit, attachment_expiry_duration, attachment_bandwidth_limit, stripe_monthly_price_id, stripe_yearly_price_id)
  270. VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
  271. `
  272. updateTierQuery = `
  273. UPDATE tier
  274. SET name = ?, messages_limit = ?, messages_expiry_duration = ?, emails_limit = ?, calls_limit = ?, reservations_limit = ?, attachment_file_size_limit = ?, attachment_total_size_limit = ?, attachment_expiry_duration = ?, attachment_bandwidth_limit = ?, stripe_monthly_price_id = ?, stripe_yearly_price_id = ?
  275. WHERE code = ?
  276. `
  277. selectTiersQuery = `
  278. SELECT id, code, name, messages_limit, messages_expiry_duration, emails_limit, calls_limit, reservations_limit, attachment_file_size_limit, attachment_total_size_limit, attachment_expiry_duration, attachment_bandwidth_limit, stripe_monthly_price_id, stripe_yearly_price_id
  279. FROM tier
  280. `
  281. selectTierByCodeQuery = `
  282. SELECT id, code, name, messages_limit, messages_expiry_duration, emails_limit, calls_limit, reservations_limit, attachment_file_size_limit, attachment_total_size_limit, attachment_expiry_duration, attachment_bandwidth_limit, stripe_monthly_price_id, stripe_yearly_price_id
  283. FROM tier
  284. WHERE code = ?
  285. `
  286. selectTierByPriceIDQuery = `
  287. SELECT id, code, name, messages_limit, messages_expiry_duration, emails_limit, calls_limit, reservations_limit, attachment_file_size_limit, attachment_total_size_limit, attachment_expiry_duration, attachment_bandwidth_limit, stripe_monthly_price_id, stripe_yearly_price_id
  288. FROM tier
  289. WHERE (stripe_monthly_price_id = ? OR stripe_yearly_price_id = ?)
  290. `
  291. updateUserTierQuery = `UPDATE user SET tier_id = (SELECT id FROM tier WHERE code = ?) WHERE user = ?`
  292. deleteUserTierQuery = `UPDATE user SET tier_id = null WHERE user = ?`
  293. deleteTierQuery = `DELETE FROM tier WHERE code = ?`
  294. updateBillingQuery = `
  295. UPDATE user
  296. SET stripe_customer_id = ?, stripe_subscription_id = ?, stripe_subscription_status = ?, stripe_subscription_interval = ?, stripe_subscription_paid_until = ?, stripe_subscription_cancel_at = ?
  297. WHERE user = ?
  298. `
  299. )
  // Schema management queries.
  //
  // The schemaVersion table holds a single row (id = 1) recording the schema version of the
  // database file. The migrateXToY queries below are snapshots of the schema at each step and
  // must not be edited to track later schema changes.
  const (
	// currentSchemaVersion is the latest schema version; createTablesQueries produces this schema
	// directly. Older databases are presumably upgraded step by step via the migrations map.
	currentSchemaVersion     = 4
	insertSchemaVersion      = `INSERT INTO schemaVersion VALUES (1, ?)`
	updateSchemaVersion      = `UPDATE schemaVersion SET version = ? WHERE id = 1`
	selectSchemaVersionQuery = `SELECT version FROM schemaVersion WHERE id = 1`

	// 1 -> 2 (complex migration!)
	// Renames the old user table to user_old, then creates the v2 tables (tier, user, user_access,
	// user_token, schemaVersion) and the anonymous "everyone" user.
	// NOTE(review): the everyone ID is a literal 'u_everyone' here; presumably it must equal
	// everyoneID used by createTablesQueries — confirm.
	migrate1To2CreateTablesQueries = `
		ALTER TABLE user RENAME TO user_old;
		CREATE TABLE IF NOT EXISTS tier (
			id TEXT PRIMARY KEY,
			code TEXT NOT NULL,
			name TEXT NOT NULL,
			messages_limit INT NOT NULL,
			messages_expiry_duration INT NOT NULL,
			emails_limit INT NOT NULL,
			reservations_limit INT NOT NULL,
			attachment_file_size_limit INT NOT NULL,
			attachment_total_size_limit INT NOT NULL,
			attachment_expiry_duration INT NOT NULL,
			attachment_bandwidth_limit INT NOT NULL,
			stripe_price_id TEXT
		);
		CREATE UNIQUE INDEX idx_tier_code ON tier (code);
		CREATE UNIQUE INDEX idx_tier_price_id ON tier (stripe_price_id);
		CREATE TABLE IF NOT EXISTS user (
			id TEXT PRIMARY KEY,
			tier_id TEXT,
			user TEXT NOT NULL,
			pass TEXT NOT NULL,
			role TEXT CHECK (role IN ('anonymous', 'admin', 'user')) NOT NULL,
			prefs JSON NOT NULL DEFAULT '{}',
			sync_topic TEXT NOT NULL,
			stats_messages INT NOT NULL DEFAULT (0),
			stats_emails INT NOT NULL DEFAULT (0),
			stripe_customer_id TEXT,
			stripe_subscription_id TEXT,
			stripe_subscription_status TEXT,
			stripe_subscription_paid_until INT,
			stripe_subscription_cancel_at INT,
			created INT NOT NULL,
			deleted INT,
			FOREIGN KEY (tier_id) REFERENCES tier (id)
		);
		CREATE UNIQUE INDEX idx_user ON user (user);
		CREATE UNIQUE INDEX idx_user_stripe_customer_id ON user (stripe_customer_id);
		CREATE UNIQUE INDEX idx_user_stripe_subscription_id ON user (stripe_subscription_id);
		CREATE TABLE IF NOT EXISTS user_access (
			user_id TEXT NOT NULL,
			topic TEXT NOT NULL,
			read INT NOT NULL,
			write INT NOT NULL,
			owner_user_id INT,
			PRIMARY KEY (user_id, topic),
			FOREIGN KEY (user_id) REFERENCES user (id) ON DELETE CASCADE,
			FOREIGN KEY (owner_user_id) REFERENCES user (id) ON DELETE CASCADE
		);
		CREATE TABLE IF NOT EXISTS user_token (
			user_id TEXT NOT NULL,
			token TEXT NOT NULL,
			label TEXT NOT NULL,
			last_access INT NOT NULL,
			last_origin TEXT NOT NULL,
			expires INT NOT NULL,
			PRIMARY KEY (user_id, token),
			FOREIGN KEY (user_id) REFERENCES user (id) ON DELETE CASCADE
		);
		CREATE TABLE IF NOT EXISTS schemaVersion (
			id INT PRIMARY KEY,
			version INT NOT NULL
		);
		INSERT INTO user (id, user, pass, role, sync_topic, created)
		VALUES ('u_everyone', '*', '', 'anonymous', '', UNIXEPOCH())
		ON CONFLICT (id) DO NOTHING;
	`
	// Lists every username from the renamed v1 table.
	migrate1To2SelectAllOldUsernamesNoTx = `SELECT user FROM user_old`
	// Copies one v1 user into the new user table. Args: (new user ID, sync topic, username).
	migrate1To2InsertUserNoTx = `
		INSERT INTO user (id, user, pass, role, sync_topic, created)
		SELECT ?, user, pass, role, ?, UNIXEPOCH() FROM user_old WHERE user = ?
	`
	// Moves v1 ACL rows from the old access table into user_access (keyed by the new user IDs),
	// then drops the old access and user_old tables.
	migrate1To2InsertFromOldTablesAndDropNoTx = `
		INSERT INTO user_access (user_id, topic, read, write)
		SELECT u.id, a.topic, a.read, a.write
		FROM user u
		JOIN access a ON u.user = a.user;
		DROP TABLE access;
		DROP TABLE user_old;
	`

	// 2 -> 3
	// Adds the subscription interval column and splits the tier's single Stripe price into
	// monthly and yearly prices, replacing the old price index with one unique index per column.
	migrate2To3UpdateQueries = `
		ALTER TABLE user ADD COLUMN stripe_subscription_interval TEXT;
		ALTER TABLE tier RENAME COLUMN stripe_price_id TO stripe_monthly_price_id;
		ALTER TABLE tier ADD COLUMN stripe_yearly_price_id TEXT;
		DROP INDEX IF EXISTS idx_tier_price_id;
		CREATE UNIQUE INDEX idx_tier_stripe_monthly_price_id ON tier (stripe_monthly_price_id);
		CREATE UNIQUE INDEX idx_tier_stripe_yearly_price_id ON tier (stripe_yearly_price_id);
	`

	// 3 -> 4
	// Adds call limits/stats (defaulting to 0) and the user_phone table with a unique index on
	// phone_number.
	migrate3To4UpdateQueries = `
		ALTER TABLE tier ADD COLUMN calls_limit INT NOT NULL DEFAULT (0);
		ALTER TABLE user ADD COLUMN stats_calls INT NOT NULL DEFAULT (0);
		CREATE TABLE IF NOT EXISTS user_phone (
			user_id TEXT NOT NULL,
			phone_number TEXT NOT NULL,
			PRIMARY KEY (user_id, phone_number),
			FOREIGN KEY (user_id) REFERENCES user (id) ON DELETE CASCADE
		);
		CREATE UNIQUE INDEX idx_user_phone_number ON user_phone (phone_number);
	`
  )
  410. var (
  411. migrations = map[int]func(db *sql.DB) error{
  412. 1: migrateFrom1,
  413. 2: migrateFrom2,
  414. 3: migrateFrom3,
  415. }
  416. )
  // Manager is an implementation of Auther (checked at compile time below). It stores users and
  // the access control list in a SQLite database.
  type Manager struct {
	db            *sql.DB                 // SQLite handle opened by NewManager
	defaultAccess Permission              // Default permission if no ACL matches
	statsQueue    map[string]*Stats       // "Queue" to asynchronously write user stats to the database (UserID -> Stats)
	tokenQueue    map[string]*TokenUpdate // "Queue" to asynchronously write token access stats to the database (Token ID -> TokenUpdate)
	bcryptCost    int                     // bcrypt cost, configurable via NewManager; makes testing easier
	mu            sync.Mutex              // NOTE(review): presumably guards statsQueue/tokenQueue — confirm at their use sites
  }

  // Compile-time assertion that *Manager satisfies the Auther interface.
  var _ Auther = (*Manager)(nil)
  428. // NewManager creates a new Manager instance
  429. func NewManager(filename, startupQueries string, defaultAccess Permission, bcryptCost int, queueWriterInterval time.Duration) (*Manager, error) {
  430. db, err := sql.Open("sqlite3", filename)
  431. if err != nil {
  432. return nil, err
  433. }
  434. if err := setupDB(db); err != nil {
  435. return nil, err
  436. }
  437. if err := runStartupQueries(db, startupQueries); err != nil {
  438. return nil, err
  439. }
  440. manager := &Manager{
  441. db: db,
  442. defaultAccess: defaultAccess,
  443. statsQueue: make(map[string]*Stats),
  444. tokenQueue: make(map[string]*TokenUpdate),
  445. bcryptCost: bcryptCost,
  446. }
  447. go manager.asyncQueueWriter(queueWriterInterval)
  448. return manager, nil
  449. }
  450. // Authenticate checks username and password and returns a User if correct, and the user has not been
  451. // marked as deleted. The method returns in constant-ish time, regardless of whether the user exists or
  452. // the password is correct or incorrect.
  453. func (a *Manager) Authenticate(username, password string) (*User, error) {
  454. if username == Everyone {
  455. return nil, ErrUnauthenticated
  456. }
  457. user, err := a.User(username)
  458. if err != nil {
  459. log.Tag(tag).Field("user_name", username).Err(err).Trace("Authentication of user failed (1)")
  460. bcrypt.CompareHashAndPassword([]byte(userAuthIntentionalSlowDownHash), []byte("intentional slow-down to avoid timing attacks"))
  461. return nil, ErrUnauthenticated
  462. } else if user.Deleted {
  463. log.Tag(tag).Field("user_name", username).Trace("Authentication of user failed (2): user marked deleted")
  464. bcrypt.CompareHashAndPassword([]byte(userAuthIntentionalSlowDownHash), []byte("intentional slow-down to avoid timing attacks"))
  465. return nil, ErrUnauthenticated
  466. } else if err := bcrypt.CompareHashAndPassword([]byte(user.Hash), []byte(password)); err != nil {
  467. log.Tag(tag).Field("user_name", username).Err(err).Trace("Authentication of user failed (3)")
  468. return nil, ErrUnauthenticated
  469. }
  470. return user, nil
  471. }
  472. // AuthenticateToken checks if the token exists and returns the associated User if it does.
  473. // The method sets the User.Token value to the token that was used for authentication.
  474. func (a *Manager) AuthenticateToken(token string) (*User, error) {
  475. if len(token) != tokenLength {
  476. return nil, ErrUnauthenticated
  477. }
  478. user, err := a.userByToken(token)
  479. if err != nil {
  480. log.Tag(tag).Field("token", token).Err(err).Trace("Authentication of token failed")
  481. return nil, ErrUnauthenticated
  482. }
  483. user.Token = token
  484. return user, nil
  485. }
  486. // CreateToken generates a random token for the given user and returns it. The token expires
  487. // after a fixed duration unless ChangeToken is called. This function also prunes tokens for the
  488. // given user, if there are too many of them.
  489. func (a *Manager) CreateToken(userID, label string, expires time.Time, origin netip.Addr) (*Token, error) {
  490. token := util.RandomStringPrefix(tokenPrefix, tokenLength)
  491. tx, err := a.db.Begin()
  492. if err != nil {
  493. return nil, err
  494. }
  495. defer tx.Rollback()
  496. access := time.Now()
  497. if _, err := tx.Exec(insertTokenQuery, userID, token, label, access.Unix(), origin.String(), expires.Unix()); err != nil {
  498. return nil, err
  499. }
  500. rows, err := tx.Query(selectTokenCountQuery, userID)
  501. if err != nil {
  502. return nil, err
  503. }
  504. defer rows.Close()
  505. if !rows.Next() {
  506. return nil, errNoRows
  507. }
  508. var tokenCount int
  509. if err := rows.Scan(&tokenCount); err != nil {
  510. return nil, err
  511. }
  512. if tokenCount >= tokenMaxCount {
  513. // This pruning logic is done in two queries for efficiency. The SELECT above is a lookup
  514. // on two indices, whereas the query below is a full table scan.
  515. if _, err := tx.Exec(deleteExcessTokensQuery, userID, tokenMaxCount); err != nil {
  516. return nil, err
  517. }
  518. }
  519. if err := tx.Commit(); err != nil {
  520. return nil, err
  521. }
  522. return &Token{
  523. Value: token,
  524. Label: label,
  525. LastAccess: access,
  526. LastOrigin: origin,
  527. Expires: expires,
  528. }, nil
  529. }
  530. // Tokens returns all existing tokens for the user with the given user ID
  531. func (a *Manager) Tokens(userID string) ([]*Token, error) {
  532. rows, err := a.db.Query(selectTokensQuery, userID)
  533. if err != nil {
  534. return nil, err
  535. }
  536. defer rows.Close()
  537. tokens := make([]*Token, 0)
  538. for {
  539. token, err := a.readToken(rows)
  540. if err == ErrTokenNotFound {
  541. break
  542. } else if err != nil {
  543. return nil, err
  544. }
  545. tokens = append(tokens, token)
  546. }
  547. return tokens, nil
  548. }
  549. // Token returns a specific token for a user
  550. func (a *Manager) Token(userID, token string) (*Token, error) {
  551. rows, err := a.db.Query(selectTokenQuery, userID, token)
  552. if err != nil {
  553. return nil, err
  554. }
  555. defer rows.Close()
  556. return a.readToken(rows)
  557. }
  558. func (a *Manager) readToken(rows *sql.Rows) (*Token, error) {
  559. var token, label, lastOrigin string
  560. var lastAccess, expires int64
  561. if !rows.Next() {
  562. return nil, ErrTokenNotFound
  563. }
  564. if err := rows.Scan(&token, &label, &lastAccess, &lastOrigin, &expires); err != nil {
  565. return nil, err
  566. } else if err := rows.Err(); err != nil {
  567. return nil, err
  568. }
  569. lastOriginIP, err := netip.ParseAddr(lastOrigin)
  570. if err != nil {
  571. lastOriginIP = netip.IPv4Unspecified()
  572. }
  573. return &Token{
  574. Value: token,
  575. Label: label,
  576. LastAccess: time.Unix(lastAccess, 0),
  577. LastOrigin: lastOriginIP,
  578. Expires: time.Unix(expires, 0),
  579. }, nil
  580. }
  581. // ChangeToken updates a token's label and/or expiry date
  582. func (a *Manager) ChangeToken(userID, token string, label *string, expires *time.Time) (*Token, error) {
  583. if token == "" {
  584. return nil, errNoTokenProvided
  585. }
  586. tx, err := a.db.Begin()
  587. if err != nil {
  588. return nil, err
  589. }
  590. defer tx.Rollback()
  591. if label != nil {
  592. if _, err := tx.Exec(updateTokenLabelQuery, *label, userID, token); err != nil {
  593. return nil, err
  594. }
  595. }
  596. if expires != nil {
  597. if _, err := tx.Exec(updateTokenExpiryQuery, expires.Unix(), userID, token); err != nil {
  598. return nil, err
  599. }
  600. }
  601. if err := tx.Commit(); err != nil {
  602. return nil, err
  603. }
  604. return a.Token(userID, token)
  605. }
  606. // RemoveToken deletes the token defined in User.Token
  607. func (a *Manager) RemoveToken(userID, token string) error {
  608. if token == "" {
  609. return errNoTokenProvided
  610. }
  611. if _, err := a.db.Exec(deleteTokenQuery, userID, token); err != nil {
  612. return err
  613. }
  614. return nil
  615. }
  616. // RemoveExpiredTokens deletes all expired tokens from the database
  617. func (a *Manager) RemoveExpiredTokens() error {
  618. if _, err := a.db.Exec(deleteExpiredTokensQuery, time.Now().Unix()); err != nil {
  619. return err
  620. }
  621. return nil
  622. }
  623. // PhoneNumbers returns all phone numbers for the user with the given user ID
  624. func (a *Manager) PhoneNumbers(userID string) ([]string, error) {
  625. rows, err := a.db.Query(selectPhoneNumbersQuery, userID)
  626. if err != nil {
  627. return nil, err
  628. }
  629. defer rows.Close()
  630. phoneNumbers := make([]string, 0)
  631. for {
  632. phoneNumber, err := a.readPhoneNumber(rows)
  633. if err == ErrPhoneNumberNotFound {
  634. break
  635. } else if err != nil {
  636. return nil, err
  637. }
  638. phoneNumbers = append(phoneNumbers, phoneNumber)
  639. }
  640. return phoneNumbers, nil
  641. }
  642. func (a *Manager) readPhoneNumber(rows *sql.Rows) (string, error) {
  643. var phoneNumber string
  644. if !rows.Next() {
  645. return "", ErrPhoneNumberNotFound
  646. }
  647. if err := rows.Scan(&phoneNumber); err != nil {
  648. return "", err
  649. } else if err := rows.Err(); err != nil {
  650. return "", err
  651. }
  652. return phoneNumber, nil
  653. }
  654. // AddPhoneNumber adds a phone number to the user with the given user ID
  655. func (a *Manager) AddPhoneNumber(userID string, phoneNumber string) error {
  656. if _, err := a.db.Exec(insertPhoneNumberQuery, userID, phoneNumber); err != nil {
  657. if sqliteErr, ok := err.(sqlite3.Error); ok && sqliteErr.ExtendedCode == sqlite3.ErrConstraintUnique {
  658. return ErrPhoneNumberExists
  659. }
  660. return err
  661. }
  662. return nil
  663. }
  664. // DeletePhoneNumber deletes a phone number from the user with the given user ID
  665. func (a *Manager) DeletePhoneNumber(userID string, phoneNumber string) error {
  666. _, err := a.db.Exec(deletePhoneNumberQuery, userID, phoneNumber)
  667. return err
  668. }
  669. // RemoveDeletedUsers deletes all users that have been marked deleted for
  670. func (a *Manager) RemoveDeletedUsers() error {
  671. if _, err := a.db.Exec(deleteUsersMarkedQuery, time.Now().Unix()); err != nil {
  672. return err
  673. }
  674. return nil
  675. }
  676. // ChangeSettings persists the user settings
  677. func (a *Manager) ChangeSettings(userID string, prefs *Prefs) error {
  678. b, err := json.Marshal(prefs)
  679. if err != nil {
  680. return err
  681. }
  682. if _, err := a.db.Exec(updateUserPrefsQuery, string(b), userID); err != nil {
  683. return err
  684. }
  685. return nil
  686. }
  687. // ResetStats resets all user stats in the user database. This touches all users.
  688. func (a *Manager) ResetStats() error {
  689. a.mu.Lock() // Includes database query to avoid races!
  690. defer a.mu.Unlock()
  691. if _, err := a.db.Exec(updateUserStatsResetAllQuery); err != nil {
  692. return err
  693. }
  694. a.statsQueue = make(map[string]*Stats)
  695. return nil
  696. }
  697. // EnqueueUserStats adds the user to a queue which writes out user stats (messages, emails, ..) in
  698. // batches at a regular interval
  699. func (a *Manager) EnqueueUserStats(userID string, stats *Stats) {
  700. a.mu.Lock()
  701. defer a.mu.Unlock()
  702. a.statsQueue[userID] = stats
  703. }
  704. // EnqueueTokenUpdate adds the token update to a queue which writes out token access times
  705. // in batches at a regular interval
  706. func (a *Manager) EnqueueTokenUpdate(tokenID string, update *TokenUpdate) {
  707. a.mu.Lock()
  708. defer a.mu.Unlock()
  709. a.tokenQueue[tokenID] = update
  710. }
  711. func (a *Manager) asyncQueueWriter(interval time.Duration) {
  712. ticker := time.NewTicker(interval)
  713. for range ticker.C {
  714. if err := a.writeUserStatsQueue(); err != nil {
  715. log.Tag(tag).Err(err).Warn("Writing user stats queue failed")
  716. }
  717. if err := a.writeTokenUpdateQueue(); err != nil {
  718. log.Tag(tag).Err(err).Warn("Writing token update queue failed")
  719. }
  720. }
  721. }
  722. func (a *Manager) writeUserStatsQueue() error {
  723. a.mu.Lock()
  724. if len(a.statsQueue) == 0 {
  725. a.mu.Unlock()
  726. log.Tag(tag).Trace("No user stats updates to commit")
  727. return nil
  728. }
  729. statsQueue := a.statsQueue
  730. a.statsQueue = make(map[string]*Stats)
  731. a.mu.Unlock()
  732. tx, err := a.db.Begin()
  733. if err != nil {
  734. return err
  735. }
  736. defer tx.Rollback()
  737. log.Tag(tag).Debug("Writing user stats queue for %d user(s)", len(statsQueue))
  738. for userID, update := range statsQueue {
  739. log.
  740. Tag(tag).
  741. Fields(log.Context{
  742. "user_id": userID,
  743. "messages_count": update.Messages,
  744. "emails_count": update.Emails,
  745. "calls_count": update.Calls,
  746. }).
  747. Trace("Updating stats for user %s", userID)
  748. if _, err := tx.Exec(updateUserStatsQuery, update.Messages, update.Emails, update.Calls, userID); err != nil {
  749. return err
  750. }
  751. }
  752. return tx.Commit()
  753. }
  754. func (a *Manager) writeTokenUpdateQueue() error {
  755. a.mu.Lock()
  756. if len(a.tokenQueue) == 0 {
  757. a.mu.Unlock()
  758. log.Tag(tag).Trace("No token updates to commit")
  759. return nil
  760. }
  761. tokenQueue := a.tokenQueue
  762. a.tokenQueue = make(map[string]*TokenUpdate)
  763. a.mu.Unlock()
  764. tx, err := a.db.Begin()
  765. if err != nil {
  766. return err
  767. }
  768. defer tx.Rollback()
  769. log.Tag(tag).Debug("Writing token update queue for %d token(s)", len(tokenQueue))
  770. for tokenID, update := range tokenQueue {
  771. log.Tag(tag).Trace("Updating token %s with last access time %v", tokenID, update.LastAccess.Unix())
  772. if _, err := tx.Exec(updateTokenLastAccessQuery, update.LastAccess.Unix(), update.LastOrigin.String(), tokenID); err != nil {
  773. return err
  774. }
  775. }
  776. return tx.Commit()
  777. }
  778. // Authorize returns nil if the given user has access to the given topic using the desired
  779. // permission. The user param may be nil to signal an anonymous user.
  780. func (a *Manager) Authorize(user *User, topic string, perm Permission) error {
  781. if user != nil && user.Role == RoleAdmin {
  782. return nil // Admin can do everything
  783. }
  784. username := Everyone
  785. if user != nil {
  786. username = user.Name
  787. }
  788. // Select the read/write permissions for this user/topic combo. The query may return two
  789. // rows (one for everyone, and one for the user), but prioritizes the user.
  790. rows, err := a.db.Query(selectTopicPermsQuery, Everyone, username, topic)
  791. if err != nil {
  792. return err
  793. }
  794. defer rows.Close()
  795. if !rows.Next() {
  796. return a.resolvePerms(a.defaultAccess, perm)
  797. }
  798. var read, write bool
  799. if err := rows.Scan(&read, &write); err != nil {
  800. return err
  801. } else if err := rows.Err(); err != nil {
  802. return err
  803. }
  804. return a.resolvePerms(NewPermission(read, write), perm)
  805. }
  806. func (a *Manager) resolvePerms(base, perm Permission) error {
  807. if perm == PermissionRead && base.IsRead() {
  808. return nil
  809. } else if perm == PermissionWrite && base.IsWrite() {
  810. return nil
  811. }
  812. return ErrUnauthorized
  813. }
  814. // AddUser adds a user with the given username, password and role
  815. func (a *Manager) AddUser(username, password string, role Role) error {
  816. if !AllowedUsername(username) || !AllowedRole(role) {
  817. return ErrInvalidArgument
  818. }
  819. hash, err := bcrypt.GenerateFromPassword([]byte(password), a.bcryptCost)
  820. if err != nil {
  821. return err
  822. }
  823. userID := util.RandomStringPrefix(userIDPrefix, userIDLength)
  824. syncTopic, now := util.RandomStringPrefix(syncTopicPrefix, syncTopicLength), time.Now().Unix()
  825. if _, err = a.db.Exec(insertUserQuery, userID, username, hash, role, syncTopic, now); err != nil {
  826. if sqliteErr, ok := err.(sqlite3.Error); ok && sqliteErr.ExtendedCode == sqlite3.ErrConstraintUnique {
  827. return ErrUserExists
  828. }
  829. return err
  830. }
  831. return nil
  832. }
  833. // RemoveUser deletes the user with the given username. The function returns nil on success, even
  834. // if the user did not exist in the first place.
  835. func (a *Manager) RemoveUser(username string) error {
  836. if !AllowedUsername(username) {
  837. return ErrInvalidArgument
  838. }
  839. // Rows in user_access, user_token, etc. are deleted via foreign keys
  840. if _, err := a.db.Exec(deleteUserQuery, username); err != nil {
  841. return err
  842. }
  843. return nil
  844. }
  845. // MarkUserRemoved sets the deleted flag on the user, and deletes all access tokens. This prevents
  846. // successful auth via Authenticate. A background process will delete the user at a later date.
  847. func (a *Manager) MarkUserRemoved(user *User) error {
  848. if !AllowedUsername(user.Name) {
  849. return ErrInvalidArgument
  850. }
  851. tx, err := a.db.Begin()
  852. if err != nil {
  853. return err
  854. }
  855. defer tx.Rollback()
  856. if _, err := tx.Exec(deleteUserAccessQuery, user.Name, user.Name); err != nil {
  857. return err
  858. }
  859. if _, err := tx.Exec(deleteAllTokenQuery, user.ID); err != nil {
  860. return err
  861. }
  862. if _, err := tx.Exec(updateUserDeletedQuery, time.Now().Add(userHardDeleteAfterDuration).Unix(), user.ID); err != nil {
  863. return err
  864. }
  865. return tx.Commit()
  866. }
  867. // Users returns a list of users. It always also returns the Everyone user ("*").
  868. func (a *Manager) Users() ([]*User, error) {
  869. rows, err := a.db.Query(selectUsernamesQuery)
  870. if err != nil {
  871. return nil, err
  872. }
  873. defer rows.Close()
  874. usernames := make([]string, 0)
  875. for rows.Next() {
  876. var username string
  877. if err := rows.Scan(&username); err != nil {
  878. return nil, err
  879. } else if err := rows.Err(); err != nil {
  880. return nil, err
  881. }
  882. usernames = append(usernames, username)
  883. }
  884. rows.Close()
  885. users := make([]*User, 0)
  886. for _, username := range usernames {
  887. user, err := a.User(username)
  888. if err != nil {
  889. return nil, err
  890. }
  891. users = append(users, user)
  892. }
  893. return users, nil
  894. }
  895. // UsersCount returns the number of users in the databsae
  896. func (a *Manager) UsersCount() (int64, error) {
  897. rows, err := a.db.Query(selectUserCountQuery)
  898. if err != nil {
  899. return 0, err
  900. }
  901. defer rows.Close()
  902. if !rows.Next() {
  903. return 0, errNoRows
  904. }
  905. var count int64
  906. if err := rows.Scan(&count); err != nil {
  907. return 0, err
  908. }
  909. return count, nil
  910. }
  911. // User returns the user with the given username if it exists, or ErrUserNotFound otherwise.
  912. // You may also pass Everyone to retrieve the anonymous user and its Grant list.
  913. func (a *Manager) User(username string) (*User, error) {
  914. rows, err := a.db.Query(selectUserByNameQuery, username)
  915. if err != nil {
  916. return nil, err
  917. }
  918. return a.readUser(rows)
  919. }
  920. // UserByID returns the user with the given ID if it exists, or ErrUserNotFound otherwise
  921. func (a *Manager) UserByID(id string) (*User, error) {
  922. rows, err := a.db.Query(selectUserByIDQuery, id)
  923. if err != nil {
  924. return nil, err
  925. }
  926. return a.readUser(rows)
  927. }
  928. // UserByStripeCustomer returns the user with the given Stripe customer ID if it exists, or ErrUserNotFound otherwise.
  929. func (a *Manager) UserByStripeCustomer(stripeCustomerID string) (*User, error) {
  930. rows, err := a.db.Query(selectUserByStripeCustomerIDQuery, stripeCustomerID)
  931. if err != nil {
  932. return nil, err
  933. }
  934. return a.readUser(rows)
  935. }
  936. func (a *Manager) userByToken(token string) (*User, error) {
  937. rows, err := a.db.Query(selectUserByTokenQuery, token, time.Now().Unix())
  938. if err != nil {
  939. return nil, err
  940. }
  941. return a.readUser(rows)
  942. }
  943. func (a *Manager) readUser(rows *sql.Rows) (*User, error) {
  944. defer rows.Close()
  945. var id, username, hash, role, prefs, syncTopic string
  946. var stripeCustomerID, stripeSubscriptionID, stripeSubscriptionStatus, stripeSubscriptionInterval, stripeMonthlyPriceID, stripeYearlyPriceID, tierID, tierCode, tierName sql.NullString
  947. var messages, emails, calls int64
  948. var messagesLimit, messagesExpiryDuration, emailsLimit, callsLimit, reservationsLimit, attachmentFileSizeLimit, attachmentTotalSizeLimit, attachmentExpiryDuration, attachmentBandwidthLimit, stripeSubscriptionPaidUntil, stripeSubscriptionCancelAt, deleted sql.NullInt64
  949. if !rows.Next() {
  950. return nil, ErrUserNotFound
  951. }
  952. if err := rows.Scan(&id, &username, &hash, &role, &prefs, &syncTopic, &messages, &emails, &calls, &stripeCustomerID, &stripeSubscriptionID, &stripeSubscriptionStatus, &stripeSubscriptionInterval, &stripeSubscriptionPaidUntil, &stripeSubscriptionCancelAt, &deleted, &tierID, &tierCode, &tierName, &messagesLimit, &messagesExpiryDuration, &emailsLimit, &callsLimit, &reservationsLimit, &attachmentFileSizeLimit, &attachmentTotalSizeLimit, &attachmentExpiryDuration, &attachmentBandwidthLimit, &stripeMonthlyPriceID, &stripeYearlyPriceID); err != nil {
  953. return nil, err
  954. } else if err := rows.Err(); err != nil {
  955. return nil, err
  956. }
  957. user := &User{
  958. ID: id,
  959. Name: username,
  960. Hash: hash,
  961. Role: Role(role),
  962. Prefs: &Prefs{},
  963. SyncTopic: syncTopic,
  964. Stats: &Stats{
  965. Messages: messages,
  966. Emails: emails,
  967. Calls: calls,
  968. },
  969. Billing: &Billing{
  970. StripeCustomerID: stripeCustomerID.String, // May be empty
  971. StripeSubscriptionID: stripeSubscriptionID.String, // May be empty
  972. StripeSubscriptionStatus: stripe.SubscriptionStatus(stripeSubscriptionStatus.String), // May be empty
  973. StripeSubscriptionInterval: stripe.PriceRecurringInterval(stripeSubscriptionInterval.String), // May be empty
  974. StripeSubscriptionPaidUntil: time.Unix(stripeSubscriptionPaidUntil.Int64, 0), // May be zero
  975. StripeSubscriptionCancelAt: time.Unix(stripeSubscriptionCancelAt.Int64, 0), // May be zero
  976. },
  977. Deleted: deleted.Valid,
  978. }
  979. if err := json.Unmarshal([]byte(prefs), user.Prefs); err != nil {
  980. return nil, err
  981. }
  982. if tierCode.Valid {
  983. // See readTier() when this is changed!
  984. user.Tier = &Tier{
  985. ID: tierID.String,
  986. Code: tierCode.String,
  987. Name: tierName.String,
  988. MessageLimit: messagesLimit.Int64,
  989. MessageExpiryDuration: time.Duration(messagesExpiryDuration.Int64) * time.Second,
  990. EmailLimit: emailsLimit.Int64,
  991. CallLimit: callsLimit.Int64,
  992. ReservationLimit: reservationsLimit.Int64,
  993. AttachmentFileSizeLimit: attachmentFileSizeLimit.Int64,
  994. AttachmentTotalSizeLimit: attachmentTotalSizeLimit.Int64,
  995. AttachmentExpiryDuration: time.Duration(attachmentExpiryDuration.Int64) * time.Second,
  996. AttachmentBandwidthLimit: attachmentBandwidthLimit.Int64,
  997. StripeMonthlyPriceID: stripeMonthlyPriceID.String, // May be empty
  998. StripeYearlyPriceID: stripeYearlyPriceID.String, // May be empty
  999. }
  1000. }
  1001. return user, nil
  1002. }
  1003. // AllGrants returns all user-specific access control entries, mapped to their respective user IDs
  1004. func (a *Manager) AllGrants() (map[string][]Grant, error) {
  1005. rows, err := a.db.Query(selectUserAllAccessQuery)
  1006. if err != nil {
  1007. return nil, err
  1008. }
  1009. defer rows.Close()
  1010. grants := make(map[string][]Grant, 0)
  1011. for rows.Next() {
  1012. var userID, topic string
  1013. var read, write bool
  1014. if err := rows.Scan(&userID, &topic, &read, &write); err != nil {
  1015. return nil, err
  1016. } else if err := rows.Err(); err != nil {
  1017. return nil, err
  1018. }
  1019. if _, ok := grants[userID]; !ok {
  1020. grants[userID] = make([]Grant, 0)
  1021. }
  1022. grants[userID] = append(grants[userID], Grant{
  1023. TopicPattern: fromSQLWildcard(topic),
  1024. Allow: NewPermission(read, write),
  1025. })
  1026. }
  1027. return grants, nil
  1028. }
  1029. // Grants returns all user-specific access control entries
  1030. func (a *Manager) Grants(username string) ([]Grant, error) {
  1031. rows, err := a.db.Query(selectUserAccessQuery, username)
  1032. if err != nil {
  1033. return nil, err
  1034. }
  1035. defer rows.Close()
  1036. grants := make([]Grant, 0)
  1037. for rows.Next() {
  1038. var topic string
  1039. var read, write bool
  1040. if err := rows.Scan(&topic, &read, &write); err != nil {
  1041. return nil, err
  1042. } else if err := rows.Err(); err != nil {
  1043. return nil, err
  1044. }
  1045. grants = append(grants, Grant{
  1046. TopicPattern: fromSQLWildcard(topic),
  1047. Allow: NewPermission(read, write),
  1048. })
  1049. }
  1050. return grants, nil
  1051. }
  1052. // Reservations returns all user-owned topics, and the associated everyone-access
  1053. func (a *Manager) Reservations(username string) ([]Reservation, error) {
  1054. rows, err := a.db.Query(selectUserReservationsQuery, Everyone, username)
  1055. if err != nil {
  1056. return nil, err
  1057. }
  1058. defer rows.Close()
  1059. reservations := make([]Reservation, 0)
  1060. for rows.Next() {
  1061. var topic string
  1062. var ownerRead, ownerWrite bool
  1063. var everyoneRead, everyoneWrite sql.NullBool
  1064. if err := rows.Scan(&topic, &ownerRead, &ownerWrite, &everyoneRead, &everyoneWrite); err != nil {
  1065. return nil, err
  1066. } else if err := rows.Err(); err != nil {
  1067. return nil, err
  1068. }
  1069. reservations = append(reservations, Reservation{
  1070. Topic: topic,
  1071. Owner: NewPermission(ownerRead, ownerWrite),
  1072. Everyone: NewPermission(everyoneRead.Bool, everyoneWrite.Bool), // false if null
  1073. })
  1074. }
  1075. return reservations, nil
  1076. }
  1077. // HasReservation returns true if the given topic access is owned by the user
  1078. func (a *Manager) HasReservation(username, topic string) (bool, error) {
  1079. rows, err := a.db.Query(selectUserHasReservationQuery, username, topic)
  1080. if err != nil {
  1081. return false, err
  1082. }
  1083. defer rows.Close()
  1084. if !rows.Next() {
  1085. return false, errNoRows
  1086. }
  1087. var count int64
  1088. if err := rows.Scan(&count); err != nil {
  1089. return false, err
  1090. }
  1091. return count > 0, nil
  1092. }
  1093. // ReservationsCount returns the number of reservations owned by this user
  1094. func (a *Manager) ReservationsCount(username string) (int64, error) {
  1095. rows, err := a.db.Query(selectUserReservationsCountQuery, username)
  1096. if err != nil {
  1097. return 0, err
  1098. }
  1099. defer rows.Close()
  1100. if !rows.Next() {
  1101. return 0, errNoRows
  1102. }
  1103. var count int64
  1104. if err := rows.Scan(&count); err != nil {
  1105. return 0, err
  1106. }
  1107. return count, nil
  1108. }
  1109. // ReservationOwner returns user ID of the user that owns this topic, or an
  1110. // empty string if it's not owned by anyone
  1111. func (a *Manager) ReservationOwner(topic string) (string, error) {
  1112. rows, err := a.db.Query(selectUserReservationsOwnerQuery, topic)
  1113. if err != nil {
  1114. return "", err
  1115. }
  1116. defer rows.Close()
  1117. if !rows.Next() {
  1118. return "", nil
  1119. }
  1120. var ownerUserID string
  1121. if err := rows.Scan(&ownerUserID); err != nil {
  1122. return "", err
  1123. }
  1124. return ownerUserID, nil
  1125. }
  1126. // ChangePassword changes a user's password
  1127. func (a *Manager) ChangePassword(username, password string) error {
  1128. hash, err := bcrypt.GenerateFromPassword([]byte(password), a.bcryptCost)
  1129. if err != nil {
  1130. return err
  1131. }
  1132. if _, err := a.db.Exec(updateUserPassQuery, hash, username); err != nil {
  1133. return err
  1134. }
  1135. return nil
  1136. }
  1137. // ChangeRole changes a user's role. When a role is changed from RoleUser to RoleAdmin,
  1138. // all existing access control entries (Grant) are removed, since they are no longer needed.
  1139. func (a *Manager) ChangeRole(username string, role Role) error {
  1140. if !AllowedUsername(username) || !AllowedRole(role) {
  1141. return ErrInvalidArgument
  1142. }
  1143. if _, err := a.db.Exec(updateUserRoleQuery, string(role), username); err != nil {
  1144. return err
  1145. }
  1146. if role == RoleAdmin {
  1147. if _, err := a.db.Exec(deleteUserAccessQuery, username, username); err != nil {
  1148. return err
  1149. }
  1150. }
  1151. return nil
  1152. }
  1153. // ChangeTier changes a user's tier using the tier code. This function does not delete reservations, messages,
  1154. // or attachments, even if the new tier has lower limits in this regard. That has to be done elsewhere.
  1155. func (a *Manager) ChangeTier(username, tier string) error {
  1156. if !AllowedUsername(username) {
  1157. return ErrInvalidArgument
  1158. }
  1159. t, err := a.Tier(tier)
  1160. if err != nil {
  1161. return err
  1162. } else if err := a.checkReservationsLimit(username, t.ReservationLimit); err != nil {
  1163. return err
  1164. }
  1165. if _, err := a.db.Exec(updateUserTierQuery, tier, username); err != nil {
  1166. return err
  1167. }
  1168. return nil
  1169. }
  1170. // ResetTier removes the tier from the given user
  1171. func (a *Manager) ResetTier(username string) error {
  1172. if !AllowedUsername(username) && username != Everyone && username != "" {
  1173. return ErrInvalidArgument
  1174. } else if err := a.checkReservationsLimit(username, 0); err != nil {
  1175. return err
  1176. }
  1177. _, err := a.db.Exec(deleteUserTierQuery, username)
  1178. return err
  1179. }
  1180. func (a *Manager) checkReservationsLimit(username string, reservationsLimit int64) error {
  1181. u, err := a.User(username)
  1182. if err != nil {
  1183. return err
  1184. }
  1185. if u.Tier != nil && reservationsLimit < u.Tier.ReservationLimit {
  1186. reservations, err := a.Reservations(username)
  1187. if err != nil {
  1188. return err
  1189. } else if int64(len(reservations)) > reservationsLimit {
  1190. return ErrTooManyReservations
  1191. }
  1192. }
  1193. return nil
  1194. }
  1195. // AllowReservation tests if a user may create an access control entry for the given topic.
  1196. // If there are any ACL entries that are not owned by the user, an error is returned.
  1197. func (a *Manager) AllowReservation(username string, topic string) error {
  1198. if (!AllowedUsername(username) && username != Everyone) || !AllowedTopic(topic) {
  1199. return ErrInvalidArgument
  1200. }
  1201. rows, err := a.db.Query(selectOtherAccessCountQuery, topic, topic, username)
  1202. if err != nil {
  1203. return err
  1204. }
  1205. defer rows.Close()
  1206. if !rows.Next() {
  1207. return errNoRows
  1208. }
  1209. var otherCount int
  1210. if err := rows.Scan(&otherCount); err != nil {
  1211. return err
  1212. }
  1213. if otherCount > 0 {
  1214. return errTopicOwnedByOthers
  1215. }
  1216. return nil
  1217. }
  1218. // AllowAccess adds or updates an entry in th access control list for a specific user. It controls
  1219. // read/write access to a topic. The parameter topicPattern may include wildcards (*). The ACL entry
  1220. // owner may either be a user (username), or the system (empty).
  1221. func (a *Manager) AllowAccess(username string, topicPattern string, permission Permission) error {
  1222. if !AllowedUsername(username) && username != Everyone {
  1223. return ErrInvalidArgument
  1224. } else if !AllowedTopicPattern(topicPattern) {
  1225. return ErrInvalidArgument
  1226. }
  1227. owner := ""
  1228. if _, err := a.db.Exec(upsertUserAccessQuery, username, toSQLWildcard(topicPattern), permission.IsRead(), permission.IsWrite(), owner, owner); err != nil {
  1229. return err
  1230. }
  1231. return nil
  1232. }
  1233. // ResetAccess removes an access control list entry for a specific username/topic, or (if topic is
  1234. // empty) for an entire user. The parameter topicPattern may include wildcards (*).
  1235. func (a *Manager) ResetAccess(username string, topicPattern string) error {
  1236. if !AllowedUsername(username) && username != Everyone && username != "" {
  1237. return ErrInvalidArgument
  1238. } else if !AllowedTopicPattern(topicPattern) && topicPattern != "" {
  1239. return ErrInvalidArgument
  1240. }
  1241. if username == "" && topicPattern == "" {
  1242. _, err := a.db.Exec(deleteAllAccessQuery, username)
  1243. return err
  1244. } else if topicPattern == "" {
  1245. _, err := a.db.Exec(deleteUserAccessQuery, username, username)
  1246. return err
  1247. }
  1248. _, err := a.db.Exec(deleteTopicAccessQuery, username, username, toSQLWildcard(topicPattern))
  1249. return err
  1250. }
  1251. // AddReservation creates two access control entries for the given topic: one with full read/write access for the
  1252. // given user, and one for Everyone with the permission passed as everyone. The user also owns the entries, and
  1253. // can modify or delete them.
  1254. func (a *Manager) AddReservation(username string, topic string, everyone Permission) error {
  1255. if !AllowedUsername(username) || username == Everyone || !AllowedTopic(topic) {
  1256. return ErrInvalidArgument
  1257. }
  1258. tx, err := a.db.Begin()
  1259. if err != nil {
  1260. return err
  1261. }
  1262. defer tx.Rollback()
  1263. if _, err := tx.Exec(upsertUserAccessQuery, username, topic, true, true, username, username); err != nil {
  1264. return err
  1265. }
  1266. if _, err := tx.Exec(upsertUserAccessQuery, Everyone, topic, everyone.IsRead(), everyone.IsWrite(), username, username); err != nil {
  1267. return err
  1268. }
  1269. return tx.Commit()
  1270. }
  1271. // RemoveReservations deletes the access control entries associated with the given username/topic, as
  1272. // well as all entries with Everyone/topic. This is the counterpart for AddReservation.
  1273. func (a *Manager) RemoveReservations(username string, topics ...string) error {
  1274. if !AllowedUsername(username) || username == Everyone || len(topics) == 0 {
  1275. return ErrInvalidArgument
  1276. }
  1277. for _, topic := range topics {
  1278. if !AllowedTopic(topic) {
  1279. return ErrInvalidArgument
  1280. }
  1281. }
  1282. tx, err := a.db.Begin()
  1283. if err != nil {
  1284. return err
  1285. }
  1286. defer tx.Rollback()
  1287. for _, topic := range topics {
  1288. if _, err := tx.Exec(deleteTopicAccessQuery, username, username, topic); err != nil {
  1289. return err
  1290. }
  1291. if _, err := tx.Exec(deleteTopicAccessQuery, Everyone, Everyone, topic); err != nil {
  1292. return err
  1293. }
  1294. }
  1295. return tx.Commit()
  1296. }
  1297. // DefaultAccess returns the default read/write access if no access control entry matches
  1298. func (a *Manager) DefaultAccess() Permission {
  1299. return a.defaultAccess
  1300. }
  1301. // AddTier creates a new tier in the database
  1302. func (a *Manager) AddTier(tier *Tier) error {
  1303. if tier.ID == "" {
  1304. tier.ID = util.RandomStringPrefix(tierIDPrefix, tierIDLength)
  1305. }
  1306. if _, err := a.db.Exec(insertTierQuery, tier.ID, tier.Code, tier.Name, tier.MessageLimit, int64(tier.MessageExpiryDuration.Seconds()), tier.EmailLimit, tier.CallLimit, tier.ReservationLimit, tier.AttachmentFileSizeLimit, tier.AttachmentTotalSizeLimit, int64(tier.AttachmentExpiryDuration.Seconds()), tier.AttachmentBandwidthLimit, nullString(tier.StripeMonthlyPriceID), nullString(tier.StripeYearlyPriceID)); err != nil {
  1307. return err
  1308. }
  1309. return nil
  1310. }
  1311. // UpdateTier updates a tier's properties in the database
  1312. func (a *Manager) UpdateTier(tier *Tier) error {
  1313. if _, err := a.db.Exec(updateTierQuery, tier.Name, tier.MessageLimit, int64(tier.MessageExpiryDuration.Seconds()), tier.EmailLimit, tier.CallLimit, tier.ReservationLimit, tier.AttachmentFileSizeLimit, tier.AttachmentTotalSizeLimit, int64(tier.AttachmentExpiryDuration.Seconds()), tier.AttachmentBandwidthLimit, nullString(tier.StripeMonthlyPriceID), nullString(tier.StripeYearlyPriceID), tier.Code); err != nil {
  1314. return err
  1315. }
  1316. return nil
  1317. }
  1318. // RemoveTier deletes the tier with the given code
  1319. func (a *Manager) RemoveTier(code string) error {
  1320. if !AllowedTier(code) {
  1321. return ErrInvalidArgument
  1322. }
  1323. // This fails if any user has this tier
  1324. if _, err := a.db.Exec(deleteTierQuery, code); err != nil {
  1325. return err
  1326. }
  1327. return nil
  1328. }
  1329. // ChangeBilling updates a user's billing fields, namely the Stripe customer ID, and subscription information
  1330. func (a *Manager) ChangeBilling(username string, billing *Billing) error {
  1331. if _, err := a.db.Exec(updateBillingQuery, nullString(billing.StripeCustomerID), nullString(billing.StripeSubscriptionID), nullString(string(billing.StripeSubscriptionStatus)), nullString(string(billing.StripeSubscriptionInterval)), nullInt64(billing.StripeSubscriptionPaidUntil.Unix()), nullInt64(billing.StripeSubscriptionCancelAt.Unix()), username); err != nil {
  1332. return err
  1333. }
  1334. return nil
  1335. }
  1336. // Tiers returns a list of all Tier structs
  1337. func (a *Manager) Tiers() ([]*Tier, error) {
  1338. rows, err := a.db.Query(selectTiersQuery)
  1339. if err != nil {
  1340. return nil, err
  1341. }
  1342. defer rows.Close()
  1343. tiers := make([]*Tier, 0)
  1344. for {
  1345. tier, err := a.readTier(rows)
  1346. if err == ErrTierNotFound {
  1347. break
  1348. } else if err != nil {
  1349. return nil, err
  1350. }
  1351. tiers = append(tiers, tier)
  1352. }
  1353. return tiers, nil
  1354. }
  1355. // Tier returns a Tier based on the code, or ErrTierNotFound if it does not exist
  1356. func (a *Manager) Tier(code string) (*Tier, error) {
  1357. rows, err := a.db.Query(selectTierByCodeQuery, code)
  1358. if err != nil {
  1359. return nil, err
  1360. }
  1361. defer rows.Close()
  1362. return a.readTier(rows)
  1363. }
  1364. // TierByStripePrice returns a Tier based on the Stripe price ID, or ErrTierNotFound if it does not exist
  1365. func (a *Manager) TierByStripePrice(priceID string) (*Tier, error) {
  1366. rows, err := a.db.Query(selectTierByPriceIDQuery, priceID, priceID)
  1367. if err != nil {
  1368. return nil, err
  1369. }
  1370. defer rows.Close()
  1371. return a.readTier(rows)
  1372. }
  1373. func (a *Manager) readTier(rows *sql.Rows) (*Tier, error) {
  1374. var id, code, name string
  1375. var stripeMonthlyPriceID, stripeYearlyPriceID sql.NullString
  1376. var messagesLimit, messagesExpiryDuration, emailsLimit, callsLimit, reservationsLimit, attachmentFileSizeLimit, attachmentTotalSizeLimit, attachmentExpiryDuration, attachmentBandwidthLimit sql.NullInt64
  1377. if !rows.Next() {
  1378. return nil, ErrTierNotFound
  1379. }
  1380. if err := rows.Scan(&id, &code, &name, &messagesLimit, &messagesExpiryDuration, &emailsLimit, &callsLimit, &reservationsLimit, &attachmentFileSizeLimit, &attachmentTotalSizeLimit, &attachmentExpiryDuration, &attachmentBandwidthLimit, &stripeMonthlyPriceID, &stripeYearlyPriceID); err != nil {
  1381. return nil, err
  1382. } else if err := rows.Err(); err != nil {
  1383. return nil, err
  1384. }
  1385. // When changed, note readUser() as well
  1386. return &Tier{
  1387. ID: id,
  1388. Code: code,
  1389. Name: name,
  1390. MessageLimit: messagesLimit.Int64,
  1391. MessageExpiryDuration: time.Duration(messagesExpiryDuration.Int64) * time.Second,
  1392. EmailLimit: emailsLimit.Int64,
  1393. CallLimit: callsLimit.Int64,
  1394. ReservationLimit: reservationsLimit.Int64,
  1395. AttachmentFileSizeLimit: attachmentFileSizeLimit.Int64,
  1396. AttachmentTotalSizeLimit: attachmentTotalSizeLimit.Int64,
  1397. AttachmentExpiryDuration: time.Duration(attachmentExpiryDuration.Int64) * time.Second,
  1398. AttachmentBandwidthLimit: attachmentBandwidthLimit.Int64,
  1399. StripeMonthlyPriceID: stripeMonthlyPriceID.String, // May be empty
  1400. StripeYearlyPriceID: stripeYearlyPriceID.String, // May be empty
  1401. }, nil
  1402. }
  1403. // Close closes the underlying database
  1404. func (a *Manager) Close() error {
  1405. return a.db.Close()
  1406. }
  // toSQLWildcard converts a pattern that uses "*" as its wildcard into a SQL
// LIKE pattern by replacing every "*" with "%". All other characters
// (including any existing "%" or "_") are passed through unchanged.
func toSQLWildcard(s string) string {
	const userWildcard, likeWildcard = "*", "%"
	return strings.ReplaceAll(s, userWildcard, likeWildcard)
}
  // fromSQLWildcard is the inverse of toSQLWildcard: it converts a SQL LIKE
// pattern back to the user-facing form by replacing every "%" with "*".
func fromSQLWildcard(s string) string {
	const likeWildcard, userWildcard = "%", "*"
	return strings.ReplaceAll(s, likeWildcard, userWildcard)
}
  1413. func runStartupQueries(db *sql.DB, startupQueries string) error {
  1414. if _, err := db.Exec(startupQueries); err != nil {
  1415. return err
  1416. }
  1417. if _, err := db.Exec(builtinStartupQueries); err != nil {
  1418. return err
  1419. }
  1420. return nil
  1421. }
  1422. func setupDB(db *sql.DB) error {
  1423. // If 'schemaVersion' table does not exist, this must be a new database
  1424. rowsSV, err := db.Query(selectSchemaVersionQuery)
  1425. if err != nil {
  1426. return setupNewDB(db)
  1427. }
  1428. defer rowsSV.Close()
  1429. // If 'schemaVersion' table exists, read version and potentially upgrade
  1430. schemaVersion := 0
  1431. if !rowsSV.Next() {
  1432. return errors.New("cannot determine schema version: database file may be corrupt")
  1433. }
  1434. if err := rowsSV.Scan(&schemaVersion); err != nil {
  1435. return err
  1436. }
  1437. rowsSV.Close()
  1438. // Do migrations
  1439. if schemaVersion == currentSchemaVersion {
  1440. return nil
  1441. } else if schemaVersion > currentSchemaVersion {
  1442. return fmt.Errorf("unexpected schema version: version %d is higher than current version %d", schemaVersion, currentSchemaVersion)
  1443. }
  1444. for i := schemaVersion; i < currentSchemaVersion; i++ {
  1445. fn, ok := migrations[i]
  1446. if !ok {
  1447. return fmt.Errorf("cannot find migration step from schema version %d to %d", i, i+1)
  1448. } else if err := fn(db); err != nil {
  1449. return err
  1450. }
  1451. }
  1452. return nil
  1453. }
  1454. func setupNewDB(db *sql.DB) error {
  1455. if _, err := db.Exec(createTablesQueries); err != nil {
  1456. return err
  1457. }
  1458. if _, err := db.Exec(insertSchemaVersion, currentSchemaVersion); err != nil {
  1459. return err
  1460. }
  1461. return nil
  1462. }
  1463. func migrateFrom1(db *sql.DB) error {
  1464. log.Tag(tag).Info("Migrating user database schema: from 1 to 2")
  1465. tx, err := db.Begin()
  1466. if err != nil {
  1467. return err
  1468. }
  1469. defer tx.Rollback()
  1470. // Rename user -> user_old, and create new tables
  1471. if _, err := tx.Exec(migrate1To2CreateTablesQueries); err != nil {
  1472. return err
  1473. }
  1474. // Insert users from user_old into new user table, with ID and sync_topic
  1475. rows, err := tx.Query(migrate1To2SelectAllOldUsernamesNoTx)
  1476. if err != nil {
  1477. return err
  1478. }
  1479. defer rows.Close()
  1480. usernames := make([]string, 0)
  1481. for rows.Next() {
  1482. var username string
  1483. if err := rows.Scan(&username); err != nil {
  1484. return err
  1485. }
  1486. usernames = append(usernames, username)
  1487. }
  1488. if err := rows.Close(); err != nil {
  1489. return err
  1490. }
  1491. for _, username := range usernames {
  1492. userID := util.RandomStringPrefix(userIDPrefix, userIDLength)
  1493. syncTopic := util.RandomStringPrefix(syncTopicPrefix, syncTopicLength)
  1494. if _, err := tx.Exec(migrate1To2InsertUserNoTx, userID, syncTopic, username); err != nil {
  1495. return err
  1496. }
  1497. }
  1498. // Migrate old "access" table to "user_access" and drop "access" and "user_old"
  1499. if _, err := tx.Exec(migrate1To2InsertFromOldTablesAndDropNoTx); err != nil {
  1500. return err
  1501. }
  1502. if _, err := tx.Exec(updateSchemaVersion, 2); err != nil {
  1503. return err
  1504. }
  1505. if err := tx.Commit(); err != nil {
  1506. return err
  1507. }
  1508. return nil
  1509. }
  1510. func migrateFrom2(db *sql.DB) error {
  1511. log.Tag(tag).Info("Migrating user database schema: from 2 to 3")
  1512. tx, err := db.Begin()
  1513. if err != nil {
  1514. return err
  1515. }
  1516. defer tx.Rollback()
  1517. if _, err := tx.Exec(migrate2To3UpdateQueries); err != nil {
  1518. return err
  1519. }
  1520. if _, err := tx.Exec(updateSchemaVersion, 3); err != nil {
  1521. return err
  1522. }
  1523. return tx.Commit()
  1524. }
  1525. func migrateFrom3(db *sql.DB) error {
  1526. log.Tag(tag).Info("Migrating user database schema: from 3 to 4")
  1527. tx, err := db.Begin()
  1528. if err != nil {
  1529. return err
  1530. }
  1531. defer tx.Rollback()
  1532. if _, err := tx.Exec(migrate3To4UpdateQueries); err != nil {
  1533. return err
  1534. }
  1535. if _, err := tx.Exec(updateSchemaVersion, 4); err != nil {
  1536. return err
  1537. }
  1538. return tx.Commit()
  1539. }
  // nullString maps the empty string to SQL NULL and any other value to a valid
// sql.NullString holding it.
func nullString(s string) sql.NullString {
	return sql.NullString{String: s, Valid: s != ""}
}
  // nullInt64 maps 0 to SQL NULL and any other value (including negatives) to a
// valid sql.NullInt64 holding it.
func nullInt64(v int64) sql.NullInt64 {
	return sql.NullInt64{Int64: v, Valid: v != 0}
}