webpush_store.go 8.8 KB


  1. package server
  2. import (
  3. "database/sql"
  4. "errors"
  5. "heckel.io/ntfy/util"
  6. "net/netip"
  7. "time"
  8. _ "github.com/mattn/go-sqlite3" // SQLite driver
  9. )
// Parameters for generated subscription IDs and the per-IP subscription limit.
const (
	// subscriptionIDPrefix is the prefix passed to util.RandomStringPrefix when a new subscription ID is generated.
	subscriptionIDPrefix = "wps_"
	// subscriptionIDLength is the length argument passed to util.RandomStringPrefix when a new subscription ID is generated.
	subscriptionIDLength = 10
	// subscriptionEndpointLimitPerSubscriberIP is the number of stored subscriptions for one subscriber IP at which
	// UpsertSubscription starts rejecting endpoints that are not stored yet.
	subscriptionEndpointLimitPerSubscriberIP = 10
)
// Sentinel errors returned by the Web Push store.
var (
	// errWebPushNoRows is returned by UpsertSubscription when the subscription count query yields no row.
	errWebPushNoRows = errors.New("no rows found")
	// errWebPushTooManySubscriptions is returned by UpsertSubscription when a new endpoint would exceed
	// subscriptionEndpointLimitPerSubscriberIP for the subscriber IP.
	errWebPushTooManySubscriptions = errors.New("too many subscriptions")
	// errWebPushUserIDCannotBeEmpty is returned by RemoveSubscriptionsByUserID when called with an empty user ID.
	errWebPushUserIDCannotBeEmpty = errors.New("user ID cannot be empty")
)
// SQL statements used by the Web Push store.
const (
	// createWebPushSubscriptionsTableQuery creates the subscription, subscription_topic and schemaVersion
	// tables and their indexes, wrapped in a single BEGIN/COMMIT.
	createWebPushSubscriptionsTableQuery = `
BEGIN;
CREATE TABLE IF NOT EXISTS subscription (
id TEXT PRIMARY KEY,
endpoint TEXT NOT NULL,
key_auth TEXT NOT NULL,
key_p256dh TEXT NOT NULL,
user_id TEXT NOT NULL,
subscriber_ip TEXT NOT NULL,
updated_at INT NOT NULL,
warned_at INT NOT NULL DEFAULT 0
);
CREATE UNIQUE INDEX IF NOT EXISTS idx_endpoint ON subscription (endpoint);
CREATE INDEX IF NOT EXISTS idx_subscriber_ip ON subscription (subscriber_ip);
CREATE TABLE IF NOT EXISTS subscription_topic (
subscription_id TEXT NOT NULL,
topic TEXT NOT NULL,
PRIMARY KEY (subscription_id, topic),
FOREIGN KEY (subscription_id) REFERENCES subscription (id) ON DELETE CASCADE
);
CREATE INDEX IF NOT EXISTS idx_topic ON subscription_topic (topic);
CREATE TABLE IF NOT EXISTS schemaVersion (
id INT PRIMARY KEY,
version INT NOT NULL
);
COMMIT;
`
	// builtinStartupQueries is executed by runWebPushStartupQueries each time a store is opened.
	// NOTE(review): SQLite's foreign_keys pragma is a per-connection setting and this is executed once
	// via db.Exec, so it presumably only affects the pooled connection that ran it -- confirm.
	builtinStartupQueries = `
PRAGMA foreign_keys = ON;
`
	// Lookups used by UpsertSubscription: the ID stored for an endpoint, and the number of
	// subscriptions stored for a subscriber IP.
	selectWebPushSubscriptionIDByEndpoint = `SELECT id FROM subscription WHERE endpoint = ?`
	selectWebPushSubscriptionCountBySubscriberIP = `SELECT COUNT(*) FROM subscription WHERE subscriber_ip = ?`
	// selectWebPushSubscriptionsForTopicQuery returns the subscriptions attached to a topic, ordered by endpoint.
	selectWebPushSubscriptionsForTopicQuery = `
SELECT id, endpoint, key_auth, key_p256dh, user_id
FROM subscription_topic st
JOIN subscription s ON s.id = st.subscription_id
WHERE st.topic = ?
ORDER BY endpoint
`
	// selectWebPushSubscriptionsExpiringSoonQuery returns subscriptions that have not been warned yet
	// (warned_at = 0) and whose updated_at is at or before the given Unix timestamp.
	selectWebPushSubscriptionsExpiringSoonQuery = `SELECT id, endpoint, key_auth, key_p256dh, user_id FROM subscription WHERE warned_at = 0 AND updated_at <= ?`
	// insertWebPushSubscriptionQuery inserts a subscription; if the endpoint already exists, every column
	// except id and endpoint is overwritten with the new values.
	insertWebPushSubscriptionQuery = `
INSERT INTO subscription (id, endpoint, key_auth, key_p256dh, user_id, subscriber_ip, updated_at, warned_at)
VALUES (?, ?, ?, ?, ?, ?, ?, ?)
ON CONFLICT (endpoint)
DO UPDATE SET key_auth = excluded.key_auth, key_p256dh = excluded.key_p256dh, user_id = excluded.user_id, subscriber_ip = excluded.subscriber_ip, updated_at = excluded.updated_at, warned_at = excluded.warned_at
`
	// Statements that update or delete subscriptions.
	updateWebPushSubscriptionWarningSentQuery = `UPDATE subscription SET warned_at = ? WHERE id = ?`
	deleteWebPushSubscriptionByEndpointQuery = `DELETE FROM subscription WHERE endpoint = ?`
	deleteWebPushSubscriptionByUserIDQuery = `DELETE FROM subscription WHERE user_id = ?`
	deleteWebPushSubscriptionByAgeQuery = `DELETE FROM subscription WHERE updated_at <= ?` // Full table scan!
	// Statements that replace the topics of a subscription.
	insertWebPushSubscriptionTopicQuery = `INSERT INTO subscription_topic (subscription_id, topic) VALUES (?, ?)`
	deleteWebPushSubscriptionTopicAllQuery = `DELETE FROM subscription_topic WHERE subscription_id = ?`
)
// Schema management queries
const (
	// currentWebPushSchemaVersion is the version written to the schemaVersion table when a new database is set up.
	currentWebPushSchemaVersion = 1
	// insertWebPushSchemaVersion stores the given schema version in the row with id 1.
	insertWebPushSchemaVersion = `INSERT INTO schemaVersion VALUES (1, ?)`
	// selectWebPushSchemaVersionQuery reads the stored schema version; setupWebPushDB runs it to probe
	// whether the schema already exists.
	selectWebPushSchemaVersionQuery = `SELECT version FROM schemaVersion WHERE id = 1`
)
// webPushStore stores Web Push subscriptions and their topics in a SQL database
// that newWebPushStore opens with the "sqlite3" driver.
type webPushStore struct {
	db *sql.DB // database handle; released by Close
}
  83. func newWebPushStore(filename string) (*webPushStore, error) {
  84. db, err := sql.Open("sqlite3", filename)
  85. if err != nil {
  86. return nil, err
  87. }
  88. if err := setupWebPushDB(db); err != nil {
  89. return nil, err
  90. }
  91. if err := runWebPushStartupQueries(db); err != nil {
  92. return nil, err
  93. }
  94. return &webPushStore{
  95. db: db,
  96. }, nil
  97. }
  98. func setupWebPushDB(db *sql.DB) error {
  99. // If 'schemaVersion' table does not exist, this must be a new database
  100. rows, err := db.Query(selectWebPushSchemaVersionQuery)
  101. if err != nil {
  102. return setupNewWebPushDB(db)
  103. }
  104. return rows.Close()
  105. }
  106. func setupNewWebPushDB(db *sql.DB) error {
  107. if _, err := db.Exec(createWebPushSubscriptionsTableQuery); err != nil {
  108. return err
  109. }
  110. if _, err := db.Exec(insertWebPushSchemaVersion, currentWebPushSchemaVersion); err != nil {
  111. return err
  112. }
  113. return nil
  114. }
  115. func runWebPushStartupQueries(db *sql.DB) error {
  116. _, err := db.Exec(builtinStartupQueries)
  117. return err
  118. }
  119. // UpsertSubscription adds or updates Web Push subscriptions for the given topics and user ID. It always first deletes all
  120. // existing entries for a given endpoint.
  121. func (c *webPushStore) UpsertSubscription(endpoint string, auth, p256dh, userID string, subscriberIP netip.Addr, topics []string) error {
  122. tx, err := c.db.Begin()
  123. if err != nil {
  124. return err
  125. }
  126. defer tx.Rollback()
  127. // Read number of subscriptions for subscriber IP address
  128. rowsCount, err := tx.Query(selectWebPushSubscriptionCountBySubscriberIP, subscriberIP.String())
  129. if err != nil {
  130. return err
  131. }
  132. defer rowsCount.Close()
  133. var subscriptionCount int
  134. if !rowsCount.Next() {
  135. return errWebPushNoRows
  136. }
  137. if err := rowsCount.Scan(&subscriptionCount); err != nil {
  138. return err
  139. }
  140. if err := rowsCount.Close(); err != nil {
  141. return err
  142. }
  143. // Read existing subscription ID for endpoint (or create new ID)
  144. rows, err := tx.Query(selectWebPushSubscriptionIDByEndpoint, endpoint)
  145. if err != nil {
  146. return err
  147. }
  148. defer rows.Close()
  149. var subscriptionID string
  150. if rows.Next() {
  151. if err := rows.Scan(&subscriptionID); err != nil {
  152. return err
  153. }
  154. } else {
  155. if subscriptionCount >= subscriptionEndpointLimitPerSubscriberIP {
  156. return errWebPushTooManySubscriptions
  157. }
  158. subscriptionID = util.RandomStringPrefix(subscriptionIDPrefix, subscriptionIDLength)
  159. }
  160. if err := rows.Close(); err != nil {
  161. return err
  162. }
  163. // Insert or update subscription
  164. updatedAt, warnedAt := time.Now().Unix(), 0
  165. if _, err = tx.Exec(insertWebPushSubscriptionQuery, subscriptionID, endpoint, auth, p256dh, userID, subscriberIP.String(), updatedAt, warnedAt); err != nil {
  166. return err
  167. }
  168. // Replace all subscription topics
  169. if _, err := tx.Exec(deleteWebPushSubscriptionTopicAllQuery, subscriptionID); err != nil {
  170. return err
  171. }
  172. for _, topic := range topics {
  173. if _, err = tx.Exec(insertWebPushSubscriptionTopicQuery, subscriptionID, topic); err != nil {
  174. return err
  175. }
  176. }
  177. return tx.Commit()
  178. }
  179. // SubscriptionsForTopic returns all subscriptions for the given topic
  180. func (c *webPushStore) SubscriptionsForTopic(topic string) ([]*webPushSubscription, error) {
  181. rows, err := c.db.Query(selectWebPushSubscriptionsForTopicQuery, topic)
  182. if err != nil {
  183. return nil, err
  184. }
  185. defer rows.Close()
  186. return c.subscriptionsFromRows(rows)
  187. }
  188. // SubscriptionsExpiring returns all subscriptions that have not been updated for a given time period
  189. func (c *webPushStore) SubscriptionsExpiring(warnAfter time.Duration) ([]*webPushSubscription, error) {
  190. rows, err := c.db.Query(selectWebPushSubscriptionsExpiringSoonQuery, time.Now().Add(-warnAfter).Unix())
  191. if err != nil {
  192. return nil, err
  193. }
  194. defer rows.Close()
  195. return c.subscriptionsFromRows(rows)
  196. }
  197. // MarkExpiryWarningSent marks the given subscriptions as having received a warning about expiring soon
  198. func (c *webPushStore) MarkExpiryWarningSent(subscriptions []*webPushSubscription) error {
  199. tx, err := c.db.Begin()
  200. if err != nil {
  201. return err
  202. }
  203. defer tx.Rollback()
  204. for _, subscription := range subscriptions {
  205. if _, err := tx.Exec(updateWebPushSubscriptionWarningSentQuery, time.Now().Unix(), subscription.ID); err != nil {
  206. return err
  207. }
  208. }
  209. return tx.Commit()
  210. }
  211. func (c *webPushStore) subscriptionsFromRows(rows *sql.Rows) ([]*webPushSubscription, error) {
  212. subscriptions := make([]*webPushSubscription, 0)
  213. for rows.Next() {
  214. var id, endpoint, auth, p256dh, userID string
  215. if err := rows.Scan(&id, &endpoint, &auth, &p256dh, &userID); err != nil {
  216. return nil, err
  217. }
  218. subscriptions = append(subscriptions, &webPushSubscription{
  219. ID: id,
  220. Endpoint: endpoint,
  221. Auth: auth,
  222. P256dh: p256dh,
  223. UserID: userID,
  224. })
  225. }
  226. return subscriptions, nil
  227. }
  228. // RemoveSubscriptionsByEndpoint removes the subscription for the given endpoint
  229. func (c *webPushStore) RemoveSubscriptionsByEndpoint(endpoint string) error {
  230. _, err := c.db.Exec(deleteWebPushSubscriptionByEndpointQuery, endpoint)
  231. return err
  232. }
  233. // RemoveSubscriptionsByUserID removes all subscriptions for the given user ID
  234. func (c *webPushStore) RemoveSubscriptionsByUserID(userID string) error {
  235. if userID == "" {
  236. return errWebPushUserIDCannotBeEmpty
  237. }
  238. _, err := c.db.Exec(deleteWebPushSubscriptionByUserIDQuery, userID)
  239. return err
  240. }
  241. // RemoveExpiredSubscriptions removes all subscriptions that have not been updated for a given time period
  242. func (c *webPushStore) RemoveExpiredSubscriptions(expireAfter time.Duration) error {
  243. _, err := c.db.Exec(deleteWebPushSubscriptionByAgeQuery, time.Now().Add(-expireAfter).Unix())
  244. return err
  245. }
  246. // Close closes the underlying database connection
  247. func (c *webPushStore) Close() error {
  248. return c.db.Close()
  249. }