webpush_store.go 8.6 KB


  1. package server
  2. import (
  3. "database/sql"
  4. "errors"
  5. "heckel.io/ntfy/util"
  6. "net/netip"
  7. "time"
  8. _ "github.com/mattn/go-sqlite3" // SQLite driver
  9. )
  10. const (
  11. subscriptionIDPrefix = "wps_"
  12. subscriptionIDLength = 10
  13. subscriptionLimitPerSubscriberIP = 10
  14. )
  15. var (
  16. errWebPushNoRows = errors.New("no rows found")
  17. errWebPushTooManySubscriptions = errors.New("too many subscriptions")
  18. )
// SQL statements used by webPushStore. A subscription row is keyed by a generated ID and is unique per
// endpoint (idx_endpoint); subscription_topic maps each subscription to the topics it is subscribed to.
const (
	// createWebPushSubscriptionsTableQuery creates the full schema for a new database inside one
	// transaction. Every CREATE uses IF NOT EXISTS, so running it against an existing schema is harmless.
	// Only endpoint, subscriber_ip (subscription) and topic (subscription_topic) are indexed.
	createWebPushSubscriptionsTableQuery = `
BEGIN;
CREATE TABLE IF NOT EXISTS subscription (
id TEXT PRIMARY KEY,
endpoint TEXT NOT NULL,
key_auth TEXT NOT NULL,
key_p256dh TEXT NOT NULL,
user_id TEXT NOT NULL,
subscriber_ip TEXT NOT NULL,
updated_at INT NOT NULL,
warned_at INT NOT NULL DEFAULT 0
);
CREATE UNIQUE INDEX IF NOT EXISTS idx_endpoint ON subscription (endpoint);
CREATE INDEX IF NOT EXISTS idx_subscriber_ip ON subscription (subscriber_ip);
CREATE TABLE IF NOT EXISTS subscription_topic (
subscription_id TEXT NOT NULL,
topic TEXT NOT NULL,
PRIMARY KEY (subscription_id, topic),
FOREIGN KEY (subscription_id) REFERENCES subscription (id) ON DELETE CASCADE
);
CREATE INDEX IF NOT EXISTS idx_topic ON subscription_topic (topic);
CREATE TABLE IF NOT EXISTS schemaVersion (
id INT PRIMARY KEY,
version INT NOT NULL
);
COMMIT;
`
	// builtinStartupQueries run on every startup. Foreign keys must be enabled for the ON DELETE CASCADE
	// on subscription_topic to remove topic rows when a subscription is deleted.
	// NOTE(review): SQLite's foreign_keys pragma is per connection, while *sql.DB keeps a connection pool;
	// confirm every pooled connection actually has it enabled.
	builtinStartupQueries = `
PRAGMA foreign_keys = ON;
`
	// selectWebPushSubscriptionIDByEndpoint looks up the ID of the (unique) subscription for an endpoint.
	selectWebPushSubscriptionIDByEndpoint = `SELECT id FROM subscription WHERE endpoint = ?`
	// selectWebPushSubscriptionCountBySubscriberIP counts the subscriptions registered from one IP.
	selectWebPushSubscriptionCountBySubscriberIP = `SELECT COUNT(*) FROM subscription WHERE subscriber_ip = ?`
	// selectWebPushSubscriptionsForTopicQuery returns the subscriptions linked to a topic. The column order
	// (id, endpoint, key_auth, key_p256dh, user_id) is what subscriptionsFromRows scans.
	selectWebPushSubscriptionsForTopicQuery = `
SELECT id, endpoint, key_auth, key_p256dh, user_id
FROM subscription_topic st
JOIN subscription s ON s.id = st.subscription_id
WHERE st.topic = ?
`
	// selectWebPushSubscriptionsExpiringSoonQuery returns not-yet-warned subscriptions (warned_at = 0)
	// last updated at or before the given Unix timestamp. Neither column is indexed.
	selectWebPushSubscriptionsExpiringSoonQuery = `SELECT id, endpoint, key_auth, key_p256dh, user_id FROM subscription WHERE warned_at = 0 AND updated_at <= ?`
	// insertWebPushSubscriptionQuery inserts a subscription or, if the endpoint already exists, overwrites
	// everything except its id, so an existing subscription keeps its ID.
	insertWebPushSubscriptionQuery = `
INSERT INTO subscription (id, endpoint, key_auth, key_p256dh, user_id, subscriber_ip, updated_at, warned_at)
VALUES (?, ?, ?, ?, ?, ?, ?, ?)
ON CONFLICT (endpoint)
DO UPDATE SET key_auth = excluded.key_auth, key_p256dh = excluded.key_p256dh, user_id = excluded.user_id, subscriber_ip = excluded.subscriber_ip, updated_at = excluded.updated_at, warned_at = excluded.warned_at
`
	// updateWebPushSubscriptionWarningSentQuery sets warned_at (a Unix timestamp) for one subscription ID.
	updateWebPushSubscriptionWarningSentQuery = `UPDATE subscription SET warned_at = ? WHERE id = ?`
	// deleteWebPushSubscriptionByEndpointQuery deletes the subscription for one endpoint.
	deleteWebPushSubscriptionByEndpointQuery = `DELETE FROM subscription WHERE endpoint = ?`
	// deleteWebPushSubscriptionByUserIDQuery deletes all subscriptions of one user (user_id is not indexed).
	deleteWebPushSubscriptionByUserIDQuery = `DELETE FROM subscription WHERE user_id = ?`
	// deleteWebPushSubscriptionByAgeQuery deletes subscriptions last updated at or before a Unix timestamp.
	deleteWebPushSubscriptionByAgeQuery = `DELETE FROM subscription WHERE updated_at <= ?` // Full table scan! (updated_at is not indexed)
	// insertWebPushSubscriptionTopicQuery links a subscription to one topic.
	insertWebPushSubscriptionTopicQuery = `INSERT INTO subscription_topic (subscription_id, topic) VALUES (?, ?)`
	// deleteWebPushSubscriptionTopicAllQuery removes all topic links of one subscription.
	deleteWebPushSubscriptionTopicAllQuery = `DELETE FROM subscription_topic WHERE subscription_id = ?`
)
  72. // Schema management queries
  73. const (
  74. currentWebPushSchemaVersion = 1
  75. insertWebPushSchemaVersion = `INSERT INTO schemaVersion VALUES (1, ?)`
  76. selectWebPushSchemaVersionQuery = `SELECT version FROM schemaVersion WHERE id = 1`
  77. )
  78. type webPushStore struct {
  79. db *sql.DB
  80. }
  81. func newWebPushStore(filename string) (*webPushStore, error) {
  82. db, err := sql.Open("sqlite3", filename)
  83. if err != nil {
  84. return nil, err
  85. }
  86. if err := setupWebPushDB(db); err != nil {
  87. return nil, err
  88. }
  89. if err := runWebPushStartupQueries(db); err != nil {
  90. return nil, err
  91. }
  92. return &webPushStore{
  93. db: db,
  94. }, nil
  95. }
  96. func setupWebPushDB(db *sql.DB) error {
  97. // If 'schemaVersion' table does not exist, this must be a new database
  98. rows, err := db.Query(selectWebPushSchemaVersionQuery)
  99. if err != nil {
  100. return setupNewWebPushDB(db)
  101. }
  102. return rows.Close()
  103. }
  104. func setupNewWebPushDB(db *sql.DB) error {
  105. if _, err := db.Exec(createWebPushSubscriptionsTableQuery); err != nil {
  106. return err
  107. }
  108. if _, err := db.Exec(insertWebPushSchemaVersion, currentWebPushSchemaVersion); err != nil {
  109. return err
  110. }
  111. return nil
  112. }
  113. func runWebPushStartupQueries(db *sql.DB) error {
  114. _, err := db.Exec(builtinStartupQueries)
  115. return err
  116. }
  117. // UpsertSubscription adds or updates Web Push subscriptions for the given topics and user ID. It always first deletes all
  118. // existing entries for a given endpoint.
  119. func (c *webPushStore) UpsertSubscription(endpoint string, auth, p256dh, userID string, subscriberIP netip.Addr, topics []string) error {
  120. tx, err := c.db.Begin()
  121. if err != nil {
  122. return err
  123. }
  124. defer tx.Rollback()
  125. // Read number of subscriptions for subscriber IP address
  126. rowsCount, err := tx.Query(selectWebPushSubscriptionCountBySubscriberIP, subscriberIP.String())
  127. if err != nil {
  128. return err
  129. }
  130. defer rowsCount.Close()
  131. var subscriptionCount int
  132. if !rowsCount.Next() {
  133. return errWebPushNoRows
  134. }
  135. if err := rowsCount.Scan(&subscriptionCount); err != nil {
  136. return err
  137. }
  138. if err := rowsCount.Close(); err != nil {
  139. return err
  140. }
  141. // Read existing subscription ID for endpoint (or create new ID)
  142. rows, err := tx.Query(selectWebPushSubscriptionIDByEndpoint, endpoint)
  143. if err != nil {
  144. return err
  145. }
  146. defer rows.Close()
  147. var subscriptionID string
  148. if rows.Next() {
  149. if err := rows.Scan(&subscriptionID); err != nil {
  150. return err
  151. }
  152. } else {
  153. if subscriptionCount >= subscriptionLimitPerSubscriberIP {
  154. return errWebPushTooManySubscriptions
  155. }
  156. subscriptionID = util.RandomStringPrefix(subscriptionIDPrefix, subscriptionIDLength)
  157. }
  158. if err := rows.Close(); err != nil {
  159. return err
  160. }
  161. // Insert or update subscription
  162. updatedAt, warnedAt := time.Now().Unix(), 0
  163. if _, err = tx.Exec(insertWebPushSubscriptionQuery, subscriptionID, endpoint, auth, p256dh, userID, subscriberIP.String(), updatedAt, warnedAt); err != nil {
  164. return err
  165. }
  166. // Replace all subscription topics
  167. if _, err := tx.Exec(deleteWebPushSubscriptionTopicAllQuery, subscriptionID); err != nil {
  168. return err
  169. }
  170. for _, topic := range topics {
  171. if _, err = tx.Exec(insertWebPushSubscriptionTopicQuery, subscriptionID, topic); err != nil {
  172. return err
  173. }
  174. }
  175. return tx.Commit()
  176. }
  177. // SubscriptionsForTopic returns all subscriptions for the given topic
  178. func (c *webPushStore) SubscriptionsForTopic(topic string) ([]*webPushSubscription, error) {
  179. rows, err := c.db.Query(selectWebPushSubscriptionsForTopicQuery, topic)
  180. if err != nil {
  181. return nil, err
  182. }
  183. defer rows.Close()
  184. return c.subscriptionsFromRows(rows)
  185. }
  186. // SubscriptionsExpiring returns all subscriptions that have not been updated for a given time period
  187. func (c *webPushStore) SubscriptionsExpiring(warnAfter time.Duration) ([]*webPushSubscription, error) {
  188. rows, err := c.db.Query(selectWebPushSubscriptionsExpiringSoonQuery, time.Now().Add(-warnAfter).Unix())
  189. if err != nil {
  190. return nil, err
  191. }
  192. defer rows.Close()
  193. return c.subscriptionsFromRows(rows)
  194. }
  195. // MarkExpiryWarningSent marks the given subscriptions as having received a warning about expiring soon
  196. func (c *webPushStore) MarkExpiryWarningSent(subscriptions []*webPushSubscription) error {
  197. tx, err := c.db.Begin()
  198. if err != nil {
  199. return err
  200. }
  201. defer tx.Rollback()
  202. for _, subscription := range subscriptions {
  203. if _, err := tx.Exec(updateWebPushSubscriptionWarningSentQuery, time.Now().Unix(), subscription.ID); err != nil {
  204. return err
  205. }
  206. }
  207. return tx.Commit()
  208. }
  209. func (c *webPushStore) subscriptionsFromRows(rows *sql.Rows) ([]*webPushSubscription, error) {
  210. subscriptions := make([]*webPushSubscription, 0)
  211. for rows.Next() {
  212. var id, endpoint, auth, p256dh, userID string
  213. if err := rows.Scan(&id, &endpoint, &auth, &p256dh, &userID); err != nil {
  214. return nil, err
  215. }
  216. subscriptions = append(subscriptions, &webPushSubscription{
  217. ID: id,
  218. Endpoint: endpoint,
  219. Auth: auth,
  220. P256dh: p256dh,
  221. UserID: userID,
  222. })
  223. }
  224. return subscriptions, nil
  225. }
  226. // RemoveSubscriptionsByEndpoint removes the subscription for the given endpoint
  227. func (c *webPushStore) RemoveSubscriptionsByEndpoint(endpoint string) error {
  228. _, err := c.db.Exec(deleteWebPushSubscriptionByEndpointQuery, endpoint)
  229. return err
  230. }
  231. // RemoveSubscriptionsByUserID removes all subscriptions for the given user ID
  232. func (c *webPushStore) RemoveSubscriptionsByUserID(userID string) error {
  233. _, err := c.db.Exec(deleteWebPushSubscriptionByUserIDQuery, userID)
  234. return err
  235. }
  236. // RemoveExpiredSubscriptions removes all subscriptions that have not been updated for a given time period
  237. func (c *webPushStore) RemoveExpiredSubscriptions(expireAfter time.Duration) error {
  238. _, err := c.db.Exec(deleteWebPushSubscriptionByAgeQuery, time.Now().Add(-expireAfter).Unix())
  239. return err
  240. }
  241. // Close closes the underlying database connection
  242. func (c *webPushStore) Close() error {
  243. return c.db.Close()
  244. }