webpush_store.go 8.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275
  1. package server
  2. import (
  3. "database/sql"
  4. "errors"
  5. "heckel.io/ntfy/util"
  6. "net/netip"
  7. "time"
  8. _ "github.com/mattn/go-sqlite3" // SQLite driver
  9. )
  10. const (
  11. subscriptionIDPrefix = "wps_"
  12. subscriptionIDLength = 10
  13. subscriptionEndpointLimitPerSubscriberIP = 10
  14. )
  15. var (
  16. errWebPushNoRows = errors.New("no rows found")
  17. errWebPushTooManySubscriptions = errors.New("too many subscriptions")
  18. errWebPushUserIDCannotBeEmpty = errors.New("user ID cannot be empty")
  19. )
// SQL used by webPushStore: the schema for a fresh database and every query
// the store issues.
const (
	// createWebPushSubscriptionsTableQuery creates the whole schema for a new
	// database inside one BEGIN/COMMIT: the subscription table (one row per push
	// endpoint, enforced by the unique idx_endpoint), the subscription_topic join
	// table (cascades on subscription delete), and the schemaVersion table.
	// Every statement uses IF NOT EXISTS.
	createWebPushSubscriptionsTableQuery = `
BEGIN;
CREATE TABLE IF NOT EXISTS subscription (
id TEXT PRIMARY KEY,
endpoint TEXT NOT NULL,
key_auth TEXT NOT NULL,
key_p256dh TEXT NOT NULL,
user_id TEXT NOT NULL,
subscriber_ip TEXT NOT NULL,
updated_at INT NOT NULL,
warned_at INT NOT NULL DEFAULT 0
);
CREATE UNIQUE INDEX IF NOT EXISTS idx_endpoint ON subscription (endpoint);
CREATE INDEX IF NOT EXISTS idx_subscriber_ip ON subscription (subscriber_ip);
CREATE TABLE IF NOT EXISTS subscription_topic (
subscription_id TEXT NOT NULL,
topic TEXT NOT NULL,
PRIMARY KEY (subscription_id, topic),
FOREIGN KEY (subscription_id) REFERENCES subscription (id) ON DELETE CASCADE
);
CREATE INDEX IF NOT EXISTS idx_topic ON subscription_topic (topic);
CREATE TABLE IF NOT EXISTS schemaVersion (
id INT PRIMARY KEY,
version INT NOT NULL
);
COMMIT;
`
	// builtinStartupQueries runs on every open. The ON DELETE CASCADE above only
	// fires while foreign keys are enabled. NOTE(review): in SQLite this pragma is
	// per-connection, and database/sql pools connections — confirm that every
	// pooled connection gets it.
	builtinStartupQueries = `
PRAGMA foreign_keys = ON;
`
	selectWebPushSubscriptionIDByEndpoint = `SELECT id FROM subscription WHERE endpoint = ?`
	selectWebPushSubscriptionCountBySubscriberIP = `SELECT COUNT(*) FROM subscription WHERE subscriber_ip = ?`
	// Column order (id, endpoint, key_auth, key_p256dh, user_id) must match the
	// Scan in subscriptionsFromRows.
	selectWebPushSubscriptionsForTopicQuery = `
SELECT id, endpoint, key_auth, key_p256dh, user_id
FROM subscription_topic st
JOIN subscription s ON s.id = st.subscription_id
WHERE st.topic = ?
ORDER BY endpoint
`
	// Subscriptions not yet warned (warned_at = 0) whose last update is at or
	// before the given Unix timestamp.
	selectWebPushSubscriptionsExpiringSoonQuery = `
SELECT id, endpoint, key_auth, key_p256dh, user_id
FROM subscription
WHERE warned_at = 0 AND updated_at <= ?
`
	// Upsert keyed on the unique endpoint index. On conflict, every column except
	// id is overwritten, so an existing subscription keeps its original ID.
	insertWebPushSubscriptionQuery = `
INSERT INTO subscription (id, endpoint, key_auth, key_p256dh, user_id, subscriber_ip, updated_at, warned_at)
VALUES (?, ?, ?, ?, ?, ?, ?, ?)
ON CONFLICT (endpoint)
DO UPDATE SET key_auth = excluded.key_auth, key_p256dh = excluded.key_p256dh, user_id = excluded.user_id, subscriber_ip = excluded.subscriber_ip, updated_at = excluded.updated_at, warned_at = excluded.warned_at
`
	updateWebPushSubscriptionWarningSentQuery = `UPDATE subscription SET warned_at = ? WHERE id = ?`
	deleteWebPushSubscriptionByEndpointQuery = `DELETE FROM subscription WHERE endpoint = ?`
	deleteWebPushSubscriptionByUserIDQuery = `DELETE FROM subscription WHERE user_id = ?`
	// The schema has no index on updated_at (only endpoint, subscriber_ip and topic
	// are indexed), hence the full scan.
	deleteWebPushSubscriptionByAgeQuery = `DELETE FROM subscription WHERE updated_at <= ?` // Full table scan!
	insertWebPushSubscriptionTopicQuery = `INSERT INTO subscription_topic (subscription_id, topic) VALUES (?, ?)`
	deleteWebPushSubscriptionTopicAllQuery = `DELETE FROM subscription_topic WHERE subscription_id = ?`
)
  78. // Schema management queries
  79. const (
  80. currentWebPushSchemaVersion = 1
  81. insertWebPushSchemaVersion = `INSERT INTO schemaVersion VALUES (1, ?)`
  82. selectWebPushSchemaVersionQuery = `SELECT version FROM schemaVersion WHERE id = 1`
  83. )
  84. type webPushStore struct {
  85. db *sql.DB
  86. }
  87. func newWebPushStore(filename string) (*webPushStore, error) {
  88. db, err := sql.Open("sqlite3", filename)
  89. if err != nil {
  90. return nil, err
  91. }
  92. if err := setupWebPushDB(db); err != nil {
  93. return nil, err
  94. }
  95. if err := runWebPushStartupQueries(db); err != nil {
  96. return nil, err
  97. }
  98. return &webPushStore{
  99. db: db,
  100. }, nil
  101. }
  102. func setupWebPushDB(db *sql.DB) error {
  103. // If 'schemaVersion' table does not exist, this must be a new database
  104. rows, err := db.Query(selectWebPushSchemaVersionQuery)
  105. if err != nil {
  106. return setupNewWebPushDB(db)
  107. }
  108. return rows.Close()
  109. }
  110. func setupNewWebPushDB(db *sql.DB) error {
  111. if _, err := db.Exec(createWebPushSubscriptionsTableQuery); err != nil {
  112. return err
  113. }
  114. if _, err := db.Exec(insertWebPushSchemaVersion, currentWebPushSchemaVersion); err != nil {
  115. return err
  116. }
  117. return nil
  118. }
  119. func runWebPushStartupQueries(db *sql.DB) error {
  120. _, err := db.Exec(builtinStartupQueries)
  121. return err
  122. }
  123. // UpsertSubscription adds or updates Web Push subscriptions for the given topics and user ID. It always first deletes all
  124. // existing entries for a given endpoint.
  125. func (c *webPushStore) UpsertSubscription(endpoint string, auth, p256dh, userID string, subscriberIP netip.Addr, topics []string) error {
  126. tx, err := c.db.Begin()
  127. if err != nil {
  128. return err
  129. }
  130. defer tx.Rollback()
  131. // Read number of subscriptions for subscriber IP address
  132. rowsCount, err := tx.Query(selectWebPushSubscriptionCountBySubscriberIP, subscriberIP.String())
  133. if err != nil {
  134. return err
  135. }
  136. defer rowsCount.Close()
  137. var subscriptionCount int
  138. if !rowsCount.Next() {
  139. return errWebPushNoRows
  140. }
  141. if err := rowsCount.Scan(&subscriptionCount); err != nil {
  142. return err
  143. }
  144. if err := rowsCount.Close(); err != nil {
  145. return err
  146. }
  147. // Read existing subscription ID for endpoint (or create new ID)
  148. rows, err := tx.Query(selectWebPushSubscriptionIDByEndpoint, endpoint)
  149. if err != nil {
  150. return err
  151. }
  152. defer rows.Close()
  153. var subscriptionID string
  154. if rows.Next() {
  155. if err := rows.Scan(&subscriptionID); err != nil {
  156. return err
  157. }
  158. } else {
  159. if subscriptionCount >= subscriptionEndpointLimitPerSubscriberIP {
  160. return errWebPushTooManySubscriptions
  161. }
  162. subscriptionID = util.RandomStringPrefix(subscriptionIDPrefix, subscriptionIDLength)
  163. }
  164. if err := rows.Close(); err != nil {
  165. return err
  166. }
  167. // Insert or update subscription
  168. updatedAt, warnedAt := time.Now().Unix(), 0
  169. if _, err = tx.Exec(insertWebPushSubscriptionQuery, subscriptionID, endpoint, auth, p256dh, userID, subscriberIP.String(), updatedAt, warnedAt); err != nil {
  170. return err
  171. }
  172. // Replace all subscription topics
  173. if _, err := tx.Exec(deleteWebPushSubscriptionTopicAllQuery, subscriptionID); err != nil {
  174. return err
  175. }
  176. for _, topic := range topics {
  177. if _, err = tx.Exec(insertWebPushSubscriptionTopicQuery, subscriptionID, topic); err != nil {
  178. return err
  179. }
  180. }
  181. return tx.Commit()
  182. }
  183. // SubscriptionsForTopic returns all subscriptions for the given topic
  184. func (c *webPushStore) SubscriptionsForTopic(topic string) ([]*webPushSubscription, error) {
  185. rows, err := c.db.Query(selectWebPushSubscriptionsForTopicQuery, topic)
  186. if err != nil {
  187. return nil, err
  188. }
  189. defer rows.Close()
  190. return c.subscriptionsFromRows(rows)
  191. }
  192. // SubscriptionsExpiring returns all subscriptions that have not been updated for a given time period
  193. func (c *webPushStore) SubscriptionsExpiring(warnAfter time.Duration) ([]*webPushSubscription, error) {
  194. rows, err := c.db.Query(selectWebPushSubscriptionsExpiringSoonQuery, time.Now().Add(-warnAfter).Unix())
  195. if err != nil {
  196. return nil, err
  197. }
  198. defer rows.Close()
  199. return c.subscriptionsFromRows(rows)
  200. }
  201. // MarkExpiryWarningSent marks the given subscriptions as having received a warning about expiring soon
  202. func (c *webPushStore) MarkExpiryWarningSent(subscriptions []*webPushSubscription) error {
  203. tx, err := c.db.Begin()
  204. if err != nil {
  205. return err
  206. }
  207. defer tx.Rollback()
  208. for _, subscription := range subscriptions {
  209. if _, err := tx.Exec(updateWebPushSubscriptionWarningSentQuery, time.Now().Unix(), subscription.ID); err != nil {
  210. return err
  211. }
  212. }
  213. return tx.Commit()
  214. }
  215. func (c *webPushStore) subscriptionsFromRows(rows *sql.Rows) ([]*webPushSubscription, error) {
  216. subscriptions := make([]*webPushSubscription, 0)
  217. for rows.Next() {
  218. var id, endpoint, auth, p256dh, userID string
  219. if err := rows.Scan(&id, &endpoint, &auth, &p256dh, &userID); err != nil {
  220. return nil, err
  221. }
  222. subscriptions = append(subscriptions, &webPushSubscription{
  223. ID: id,
  224. Endpoint: endpoint,
  225. Auth: auth,
  226. P256dh: p256dh,
  227. UserID: userID,
  228. })
  229. }
  230. return subscriptions, nil
  231. }
  232. // RemoveSubscriptionsByEndpoint removes the subscription for the given endpoint
  233. func (c *webPushStore) RemoveSubscriptionsByEndpoint(endpoint string) error {
  234. _, err := c.db.Exec(deleteWebPushSubscriptionByEndpointQuery, endpoint)
  235. return err
  236. }
  237. // RemoveSubscriptionsByUserID removes all subscriptions for the given user ID
  238. func (c *webPushStore) RemoveSubscriptionsByUserID(userID string) error {
  239. if userID == "" {
  240. return errWebPushUserIDCannotBeEmpty
  241. }
  242. _, err := c.db.Exec(deleteWebPushSubscriptionByUserIDQuery, userID)
  243. return err
  244. }
  245. // RemoveExpiredSubscriptions removes all subscriptions that have not been updated for a given time period
  246. func (c *webPushStore) RemoveExpiredSubscriptions(expireAfter time.Duration) error {
  247. _, err := c.db.Exec(deleteWebPushSubscriptionByAgeQuery, time.Now().Add(-expireAfter).Unix())
  248. return err
  249. }
  250. // Close closes the underlying database connection
  251. func (c *webPushStore) Close() error {
  252. return c.db.Close()
  253. }