webpush_store.go 9.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285
  1. package server
  2. import (
  3. "database/sql"
  4. "errors"
  5. "heckel.io/ntfy/v2/util"
  6. "net/netip"
  7. "time"
  8. _ "github.com/mattn/go-sqlite3" // SQLite driver
  9. )
  10. const (
  11. subscriptionIDPrefix = "wps_"
  12. subscriptionIDLength = 10
  13. subscriptionEndpointLimitPerSubscriberIP = 10
  14. )
  15. var (
  16. errWebPushNoRows = errors.New("no rows found")
  17. errWebPushTooManySubscriptions = errors.New("too many subscriptions")
  18. errWebPushUserIDCannotBeEmpty = errors.New("user ID cannot be empty")
  19. )
// SQL statements used by webPushStore.
const (
	// createWebPushSubscriptionsTableQuery creates the whole schema inside one transaction: the subscription table
	// (endpoint is unique), the subscription_topic table (rows reference subscription.id with ON DELETE CASCADE),
	// their indexes, and the schemaVersion table. It is only run for new databases (see setupNewWebPushDB).
	createWebPushSubscriptionsTableQuery = `
BEGIN;
CREATE TABLE IF NOT EXISTS subscription (
id TEXT PRIMARY KEY,
endpoint TEXT NOT NULL,
key_auth TEXT NOT NULL,
key_p256dh TEXT NOT NULL,
user_id TEXT NOT NULL,
subscriber_ip TEXT NOT NULL,
updated_at INT NOT NULL,
warned_at INT NOT NULL DEFAULT 0
);
CREATE UNIQUE INDEX IF NOT EXISTS idx_endpoint ON subscription (endpoint);
CREATE INDEX IF NOT EXISTS idx_subscriber_ip ON subscription (subscriber_ip);
CREATE TABLE IF NOT EXISTS subscription_topic (
subscription_id TEXT NOT NULL,
topic TEXT NOT NULL,
PRIMARY KEY (subscription_id, topic),
FOREIGN KEY (subscription_id) REFERENCES subscription (id) ON DELETE CASCADE
);
CREATE INDEX IF NOT EXISTS idx_topic ON subscription_topic (topic);
CREATE TABLE IF NOT EXISTS schemaVersion (
id INT PRIMARY KEY,
version INT NOT NULL
);
COMMIT;
`
	// builtinStartupQueries run after the caller-supplied startup queries (see runWebPushStartupQueries).
	// Enabling foreign keys is what makes the ON DELETE CASCADE on subscription_topic take effect.
	// NOTE(review): SQLite's foreign_keys pragma is per connection, and db.Exec on a pooled *sql.DB applies it
	// to only one connection — confirm whether other pooled connections enforce it (RemoveExpiredSubscriptions
	// also removes orphaned topic rows explicitly).
	builtinStartupQueries = `
PRAGMA foreign_keys = ON;
`
	// selectWebPushSubscriptionIDByEndpoint looks up the ID of the subscription with the given endpoint, if any.
	selectWebPushSubscriptionIDByEndpoint = `SELECT id FROM subscription WHERE endpoint = ?`
	// selectWebPushSubscriptionCountBySubscriberIP counts the subscriptions stored for one subscriber IP; used to
	// enforce subscriptionEndpointLimitPerSubscriberIP.
	selectWebPushSubscriptionCountBySubscriberIP = `SELECT COUNT(*) FROM subscription WHERE subscriber_ip = ?`
	// selectWebPushSubscriptionsForTopicQuery returns all subscriptions linked to one topic, ordered by endpoint.
	// The column order must match the Scan call in subscriptionsFromRows.
	selectWebPushSubscriptionsForTopicQuery = `
SELECT id, endpoint, key_auth, key_p256dh, user_id
FROM subscription_topic st
JOIN subscription s ON s.id = st.subscription_id
WHERE st.topic = ?
ORDER BY endpoint
`
	// selectWebPushSubscriptionsExpiringSoonQuery returns subscriptions that have not been warned yet
	// (warned_at = 0) and were last updated at or before the given Unix timestamp. Same column order as above.
	// There is no index on updated_at or warned_at, so this scans the whole table.
	selectWebPushSubscriptionsExpiringSoonQuery = `
SELECT id, endpoint, key_auth, key_p256dh, user_id
FROM subscription
WHERE warned_at = 0 AND updated_at <= ?
`
	// insertWebPushSubscriptionQuery inserts a subscription or, if one with the same endpoint already exists,
	// overwrites every column except id and endpoint (so an existing subscription keeps its ID).
	insertWebPushSubscriptionQuery = `
INSERT INTO subscription (id, endpoint, key_auth, key_p256dh, user_id, subscriber_ip, updated_at, warned_at)
VALUES (?, ?, ?, ?, ?, ?, ?, ?)
ON CONFLICT (endpoint)
DO UPDATE SET key_auth = excluded.key_auth, key_p256dh = excluded.key_p256dh, user_id = excluded.user_id, subscriber_ip = excluded.subscriber_ip, updated_at = excluded.updated_at, warned_at = excluded.warned_at
`
	// updateWebPushSubscriptionWarningSentQuery sets warned_at (Unix seconds) for one subscription ID.
	updateWebPushSubscriptionWarningSentQuery = `UPDATE subscription SET warned_at = ? WHERE id = ?`
	// deleteWebPushSubscriptionByEndpointQuery deletes the subscription with the given endpoint.
	deleteWebPushSubscriptionByEndpointQuery = `DELETE FROM subscription WHERE endpoint = ?`
	// deleteWebPushSubscriptionByUserIDQuery deletes all subscriptions of one user.
	deleteWebPushSubscriptionByUserIDQuery = `DELETE FROM subscription WHERE user_id = ?`
	// deleteWebPushSubscriptionByAgeQuery deletes subscriptions last updated at or before the given Unix timestamp.
	deleteWebPushSubscriptionByAgeQuery = `DELETE FROM subscription WHERE updated_at <= ?` // Full table scan! (no index on updated_at)
	// insertWebPushSubscriptionTopicQuery links a subscription to one topic; (subscription_id, topic) is the
	// primary key, so inserting the same pair twice fails.
	insertWebPushSubscriptionTopicQuery = `INSERT INTO subscription_topic (subscription_id, topic) VALUES (?, ?)`
	// deleteWebPushSubscriptionTopicAllQuery removes all topic links of one subscription.
	deleteWebPushSubscriptionTopicAllQuery = `DELETE FROM subscription_topic WHERE subscription_id = ?`
	// deleteWebPushSubscriptionTopicWithoutSubscription removes topic rows whose subscription no longer exists.
	deleteWebPushSubscriptionTopicWithoutSubscription = `DELETE FROM subscription_topic WHERE subscription_id NOT IN (SELECT id FROM subscription)`
)
  79. // Schema management queries
  80. const (
  81. currentWebPushSchemaVersion = 1
  82. insertWebPushSchemaVersion = `INSERT INTO schemaVersion VALUES (1, ?)`
  83. selectWebPushSchemaVersionQuery = `SELECT version FROM schemaVersion WHERE id = 1`
  84. )
  85. type webPushStore struct {
  86. db *sql.DB
  87. }
  88. func newWebPushStore(filename, startupQueries string) (*webPushStore, error) {
  89. db, err := sql.Open("sqlite3", filename)
  90. if err != nil {
  91. return nil, err
  92. }
  93. if err := setupWebPushDB(db); err != nil {
  94. return nil, err
  95. }
  96. if err := runWebPushStartupQueries(db, startupQueries); err != nil {
  97. return nil, err
  98. }
  99. return &webPushStore{
  100. db: db,
  101. }, nil
  102. }
  103. func setupWebPushDB(db *sql.DB) error {
  104. // If 'schemaVersion' table does not exist, this must be a new database
  105. rows, err := db.Query(selectWebPushSchemaVersionQuery)
  106. if err != nil {
  107. return setupNewWebPushDB(db)
  108. }
  109. return rows.Close()
  110. }
  111. func setupNewWebPushDB(db *sql.DB) error {
  112. if _, err := db.Exec(createWebPushSubscriptionsTableQuery); err != nil {
  113. return err
  114. }
  115. if _, err := db.Exec(insertWebPushSchemaVersion, currentWebPushSchemaVersion); err != nil {
  116. return err
  117. }
  118. return nil
  119. }
  120. func runWebPushStartupQueries(db *sql.DB, startupQueries string) error {
  121. if _, err := db.Exec(startupQueries); err != nil {
  122. return err
  123. }
  124. if _, err := db.Exec(builtinStartupQueries); err != nil {
  125. return err
  126. }
  127. return nil
  128. }
  129. // UpsertSubscription adds or updates Web Push subscriptions for the given topics and user ID. It always first deletes all
  130. // existing entries for a given endpoint.
  131. func (c *webPushStore) UpsertSubscription(endpoint string, auth, p256dh, userID string, subscriberIP netip.Addr, topics []string) error {
  132. tx, err := c.db.Begin()
  133. if err != nil {
  134. return err
  135. }
  136. defer tx.Rollback()
  137. // Read number of subscriptions for subscriber IP address
  138. rowsCount, err := tx.Query(selectWebPushSubscriptionCountBySubscriberIP, subscriberIP.String())
  139. if err != nil {
  140. return err
  141. }
  142. defer rowsCount.Close()
  143. var subscriptionCount int
  144. if !rowsCount.Next() {
  145. return errWebPushNoRows
  146. }
  147. if err := rowsCount.Scan(&subscriptionCount); err != nil {
  148. return err
  149. }
  150. if err := rowsCount.Close(); err != nil {
  151. return err
  152. }
  153. // Read existing subscription ID for endpoint (or create new ID)
  154. rows, err := tx.Query(selectWebPushSubscriptionIDByEndpoint, endpoint)
  155. if err != nil {
  156. return err
  157. }
  158. defer rows.Close()
  159. var subscriptionID string
  160. if rows.Next() {
  161. if err := rows.Scan(&subscriptionID); err != nil {
  162. return err
  163. }
  164. } else {
  165. if subscriptionCount >= subscriptionEndpointLimitPerSubscriberIP {
  166. return errWebPushTooManySubscriptions
  167. }
  168. subscriptionID = util.RandomStringPrefix(subscriptionIDPrefix, subscriptionIDLength)
  169. }
  170. if err := rows.Close(); err != nil {
  171. return err
  172. }
  173. // Insert or update subscription
  174. updatedAt, warnedAt := time.Now().Unix(), 0
  175. if _, err = tx.Exec(insertWebPushSubscriptionQuery, subscriptionID, endpoint, auth, p256dh, userID, subscriberIP.String(), updatedAt, warnedAt); err != nil {
  176. return err
  177. }
  178. // Replace all subscription topics
  179. if _, err := tx.Exec(deleteWebPushSubscriptionTopicAllQuery, subscriptionID); err != nil {
  180. return err
  181. }
  182. for _, topic := range topics {
  183. if _, err = tx.Exec(insertWebPushSubscriptionTopicQuery, subscriptionID, topic); err != nil {
  184. return err
  185. }
  186. }
  187. return tx.Commit()
  188. }
  189. // SubscriptionsForTopic returns all subscriptions for the given topic
  190. func (c *webPushStore) SubscriptionsForTopic(topic string) ([]*webPushSubscription, error) {
  191. rows, err := c.db.Query(selectWebPushSubscriptionsForTopicQuery, topic)
  192. if err != nil {
  193. return nil, err
  194. }
  195. defer rows.Close()
  196. return c.subscriptionsFromRows(rows)
  197. }
  198. // SubscriptionsExpiring returns all subscriptions that have not been updated for a given time period
  199. func (c *webPushStore) SubscriptionsExpiring(warnAfter time.Duration) ([]*webPushSubscription, error) {
  200. rows, err := c.db.Query(selectWebPushSubscriptionsExpiringSoonQuery, time.Now().Add(-warnAfter).Unix())
  201. if err != nil {
  202. return nil, err
  203. }
  204. defer rows.Close()
  205. return c.subscriptionsFromRows(rows)
  206. }
  207. // MarkExpiryWarningSent marks the given subscriptions as having received a warning about expiring soon
  208. func (c *webPushStore) MarkExpiryWarningSent(subscriptions []*webPushSubscription) error {
  209. tx, err := c.db.Begin()
  210. if err != nil {
  211. return err
  212. }
  213. defer tx.Rollback()
  214. for _, subscription := range subscriptions {
  215. if _, err := tx.Exec(updateWebPushSubscriptionWarningSentQuery, time.Now().Unix(), subscription.ID); err != nil {
  216. return err
  217. }
  218. }
  219. return tx.Commit()
  220. }
  221. func (c *webPushStore) subscriptionsFromRows(rows *sql.Rows) ([]*webPushSubscription, error) {
  222. subscriptions := make([]*webPushSubscription, 0)
  223. for rows.Next() {
  224. var id, endpoint, auth, p256dh, userID string
  225. if err := rows.Scan(&id, &endpoint, &auth, &p256dh, &userID); err != nil {
  226. return nil, err
  227. }
  228. subscriptions = append(subscriptions, &webPushSubscription{
  229. ID: id,
  230. Endpoint: endpoint,
  231. Auth: auth,
  232. P256dh: p256dh,
  233. UserID: userID,
  234. })
  235. }
  236. return subscriptions, nil
  237. }
  238. // RemoveSubscriptionsByEndpoint removes the subscription for the given endpoint
  239. func (c *webPushStore) RemoveSubscriptionsByEndpoint(endpoint string) error {
  240. _, err := c.db.Exec(deleteWebPushSubscriptionByEndpointQuery, endpoint)
  241. return err
  242. }
  243. // RemoveSubscriptionsByUserID removes all subscriptions for the given user ID
  244. func (c *webPushStore) RemoveSubscriptionsByUserID(userID string) error {
  245. if userID == "" {
  246. return errWebPushUserIDCannotBeEmpty
  247. }
  248. _, err := c.db.Exec(deleteWebPushSubscriptionByUserIDQuery, userID)
  249. return err
  250. }
  251. // RemoveExpiredSubscriptions removes all subscriptions that have not been updated for a given time period
  252. func (c *webPushStore) RemoveExpiredSubscriptions(expireAfter time.Duration) error {
  253. _, err := c.db.Exec(deleteWebPushSubscriptionByAgeQuery, time.Now().Add(-expireAfter).Unix())
  254. if err != nil {
  255. return err
  256. }
  257. _, err = c.db.Exec(deleteWebPushSubscriptionTopicWithoutSubscription)
  258. return err
  259. }
  260. // Close closes the underlying database connection
  261. func (c *webPushStore) Close() error {
  262. return c.db.Close()
  263. }