Просмотр исходного кода

Limit number of webpush subscriptions per subscriber IP

binwiederhier 2 лет назад
Родитель
Сommit
341e84f643
5 измененных файлов с 95 добавлено и 16 удалено
  1. 2 4
      cmd/webpush.go
  2. 1 1
      server/server_webpush.go
  3. 2 1
      server/server_webpush_test.go
  4. 47 10
      server/webpush_store.go
  5. 43 0
      server/webpush_store_test.go

+ 2 - 4
cmd/webpush.go

@@ -35,8 +35,7 @@ func generateWebPushKeys(c *cli.Context) error {
 	if err != nil {
 		return err
 	}
-
-	fmt.Fprintf(c.App.ErrWriter, `Web Push keys generated. Add the following lines to your config file:
+	_, err = fmt.Fprintf(c.App.ErrWriter, `Web Push keys generated. Add the following lines to your config file:
 
 web-push-public-key: %s
 web-push-private-key: %s
@@ -45,6 +44,5 @@ web-push-email-address: <email address>
 
 See https://ntfy.sh/docs/config/#web-push for details.
 `, publicKey, privateKey)
-
-	return nil
+	return err
 }

+ 1 - 1
server/server_web_push.go → server/server_webpush.go

@@ -59,7 +59,7 @@ func (s *Server) handleWebPushUpdate(w http.ResponseWriter, r *http.Request, v *
 			}
 		}
 	}
-	if err := s.webPush.UpsertSubscription(req.Endpoint, req.Auth, req.P256dh, v.MaybeUserID(), req.Topics); err != nil {
+	if err := s.webPush.UpsertSubscription(req.Endpoint, req.Auth, req.P256dh, v.MaybeUserID(), v.IP(), req.Topics); err != nil {
 		return err
 	}
 	return s.writeJSON(w, newSuccessResponse())

+ 2 - 1
server/server_web_push_test.go → server/server_webpush_test.go

@@ -9,6 +9,7 @@ import (
 	"io"
 	"net/http"
 	"net/http/httptest"
+	"net/netip"
 	"strings"
 	"sync/atomic"
 	"testing"
@@ -225,7 +226,7 @@ func payloadForTopics(t *testing.T, topics []string, endpoint string) string {
 }
 
 func addSubscription(t *testing.T, s *Server, endpoint string, topics ...string) {
-	require.Nil(t, s.webPush.UpsertSubscription(endpoint, "kSC3T8aN1JCQxxPdrFLrZg", "BMKKbxdUU_xLS7G1Wh5AN8PvWOjCzkCuKZYb8apcqYrDxjOF_2piggBnoJLQYx9IeSD70fNuwawI3e9Y8m3S3PE", "u_123", topics)) // Test auth and p256dh
+	require.Nil(t, s.webPush.UpsertSubscription(endpoint, "kSC3T8aN1JCQxxPdrFLrZg", "BMKKbxdUU_xLS7G1Wh5AN8PvWOjCzkCuKZYb8apcqYrDxjOF_2piggBnoJLQYx9IeSD70fNuwawI3e9Y8m3S3PE", "u_123", netip.MustParseAddr("1.2.3.4"), topics)) // Test auth and p256dh
 }
 
 func requireSubscriptionCount(t *testing.T, s *Server, topic string, expectedLength int) {

+ 47 - 10
server/webpush_store.go

@@ -2,15 +2,23 @@ package server
 
 import (
 	"database/sql"
+	"errors"
 	"heckel.io/ntfy/util"
+	"net/netip"
 	"time"
 
 	_ "github.com/mattn/go-sqlite3" // SQLite driver
 )
 
 const (
-	subscriptionIDPrefix = "wps_"
-	subscriptionIDLength = 10
+	subscriptionIDPrefix             = "wps_"
+	subscriptionIDLength             = 10
+	subscriptionLimitPerSubscriberIP = 10
+)
+
+var (
+	errWebPushNoRows               = errors.New("no rows found")
+	errWebPushTooManySubscriptions = errors.New("too many subscriptions")
 )
 
 const (
@@ -21,11 +29,13 @@ const (
 			endpoint TEXT NOT NULL,
 			key_auth TEXT NOT NULL,
 			key_p256dh TEXT NOT NULL,
-			user_id TEXT NOT NULL,
+			user_id TEXT NOT NULL,		
+			subscriber_ip TEXT NOT NULL,
 			updated_at INT NOT NULL,
 			warned_at INT NOT NULL DEFAULT 0
 		);
 		CREATE UNIQUE INDEX IF NOT EXISTS idx_endpoint ON subscription (endpoint);
+		CREATE INDEX IF NOT EXISTS idx_subscriber_ip ON subscription (subscriber_ip);
 		CREATE TABLE IF NOT EXISTS subscription_topic (
 			subscription_id TEXT NOT NULL,
 			topic TEXT NOT NULL,
@@ -43,8 +53,9 @@ const (
 		PRAGMA foreign_keys = ON;
 	`
 
-	selectWebPushSubscriptionIDByEndpoint   = `SELECT id FROM subscription WHERE endpoint = ?`
-	selectWebPushSubscriptionsForTopicQuery = `
+	selectWebPushSubscriptionIDByEndpoint        = `SELECT id FROM subscription WHERE endpoint = ?`
+	selectWebPushSubscriptionCountBySubscriberIP = `SELECT COUNT(*) FROM subscription WHERE subscriber_ip = ?`
+	selectWebPushSubscriptionsForTopicQuery      = `
 		SELECT id, endpoint, key_auth, key_p256dh, user_id
 		FROM subscription_topic st
 		JOIN subscription s ON s.id = st.subscription_id
@@ -52,10 +63,10 @@ const (
 	`
 	selectWebPushSubscriptionsExpiringSoonQuery = `SELECT id, endpoint, key_auth, key_p256dh, user_id FROM subscription WHERE warned_at = 0 AND updated_at <= ?`
 	insertWebPushSubscriptionQuery              = `
-		INSERT INTO subscription (id, endpoint, key_auth, key_p256dh, user_id, updated_at, warned_at)
-		VALUES (?, ?, ?, ?, ?, ?, ?)
+		INSERT INTO subscription (id, endpoint, key_auth, key_p256dh, user_id, subscriber_ip, updated_at, warned_at)
+		VALUES (?, ?, ?, ?, ?, ?, ?, ?)
 		ON CONFLICT (endpoint) 
-		DO UPDATE SET key_auth = excluded.key_auth, key_p256dh = excluded.key_p256dh, user_id = excluded.user_id, updated_at = excluded.updated_at, warned_at = excluded.warned_at
+		DO UPDATE SET key_auth = excluded.key_auth, key_p256dh = excluded.key_p256dh, user_id = excluded.user_id, subscriber_ip = excluded.subscriber_ip, updated_at = excluded.updated_at, warned_at = excluded.warned_at
 	`
 	updateWebPushSubscriptionWarningSentQuery = `UPDATE subscription SET warned_at = ? WHERE id = ?`
 	deleteWebPushSubscriptionByEndpointQuery  = `DELETE FROM subscription WHERE endpoint = ?`
@@ -119,12 +130,28 @@ func runWebPushStartupQueries(db *sql.DB) error {
 
 // UpsertSubscription adds or updates Web Push subscriptions for the given topics and user ID. It always first deletes all
 // existing entries for a given endpoint.
-func (c *webPushStore) UpsertSubscription(endpoint string, auth, p256dh, userID string, topics []string) error {
+func (c *webPushStore) UpsertSubscription(endpoint string, auth, p256dh, userID string, subscriberIP netip.Addr, topics []string) error {
 	tx, err := c.db.Begin()
 	if err != nil {
 		return err
 	}
 	defer tx.Rollback()
+	// Read number of subscriptions for subscriber IP address
+	rowsCount, err := tx.Query(selectWebPushSubscriptionCountBySubscriberIP, subscriberIP.String())
+	if err != nil {
+		return err
+	}
+	defer rowsCount.Close()
+	var subscriptionCount int
+	if !rowsCount.Next() {
+		return errWebPushNoRows
+	}
+	if err := rowsCount.Scan(&subscriptionCount); err != nil {
+		return err
+	}
+	if err := rowsCount.Close(); err != nil {
+		return err
+	}
 	// Read existing subscription ID for endpoint (or create new ID)
 	rows, err := tx.Query(selectWebPushSubscriptionIDByEndpoint, endpoint)
 	if err != nil {
@@ -137,6 +164,9 @@ func (c *webPushStore) UpsertSubscription(endpoint string, auth, p256dh, userID
 			return err
 		}
 	} else {
+		if subscriptionCount >= subscriptionLimitPerSubscriberIP {
+			return errWebPushTooManySubscriptions
+		}
 		subscriptionID = util.RandomStringPrefix(subscriptionIDPrefix, subscriptionIDLength)
 	}
 	if err := rows.Close(); err != nil {
@@ -144,7 +174,7 @@ func (c *webPushStore) UpsertSubscription(endpoint string, auth, p256dh, userID
 	}
 	// Insert or update subscription
 	updatedAt, warnedAt := time.Now().Unix(), 0
-	if _, err = tx.Exec(insertWebPushSubscriptionQuery, subscriptionID, endpoint, auth, p256dh, userID, updatedAt, warnedAt); err != nil {
+	if _, err = tx.Exec(insertWebPushSubscriptionQuery, subscriptionID, endpoint, auth, p256dh, userID, subscriberIP.String(), updatedAt, warnedAt); err != nil {
 		return err
 	}
 	// Replace all subscription topics
@@ -159,6 +189,7 @@ func (c *webPushStore) UpsertSubscription(endpoint string, auth, p256dh, userID
 	return tx.Commit()
 }
 
+// SubscriptionsForTopic returns all subscriptions for the given topic
 func (c *webPushStore) SubscriptionsForTopic(topic string) ([]*webPushSubscription, error) {
 	rows, err := c.db.Query(selectWebPushSubscriptionsForTopicQuery, topic)
 	if err != nil {
@@ -168,6 +199,7 @@ func (c *webPushStore) SubscriptionsForTopic(topic string) ([]*webPushSubscripti
 	return c.subscriptionsFromRows(rows)
 }
 
+// SubscriptionsExpiring returns all subscriptions that have not been updated for a given time period
 func (c *webPushStore) SubscriptionsExpiring(warnAfter time.Duration) ([]*webPushSubscription, error) {
 	rows, err := c.db.Query(selectWebPushSubscriptionsExpiringSoonQuery, time.Now().Add(-warnAfter).Unix())
 	if err != nil {
@@ -177,6 +209,7 @@ func (c *webPushStore) SubscriptionsExpiring(warnAfter time.Duration) ([]*webPus
 	return c.subscriptionsFromRows(rows)
 }
 
+// MarkExpiryWarningSent marks the given subscriptions as having received a warning about expiring soon
 func (c *webPushStore) MarkExpiryWarningSent(subscriptions []*webPushSubscription) error {
 	tx, err := c.db.Begin()
 	if err != nil {
@@ -209,21 +242,25 @@ func (c *webPushStore) subscriptionsFromRows(rows *sql.Rows) ([]*webPushSubscrip
 	return subscriptions, nil
 }
 
+// RemoveSubscriptionsByEndpoint removes the subscription for the given endpoint
 func (c *webPushStore) RemoveSubscriptionsByEndpoint(endpoint string) error {
 	_, err := c.db.Exec(deleteWebPushSubscriptionByEndpointQuery, endpoint)
 	return err
 }
 
+// RemoveSubscriptionsByUserID removes all subscriptions for the given user ID
 func (c *webPushStore) RemoveSubscriptionsByUserID(userID string) error {
 	_, err := c.db.Exec(deleteWebPushSubscriptionByUserIDQuery, userID)
 	return err
 }
 
+// RemoveExpiredSubscriptions removes all subscriptions that have not been updated for a given time period
 func (c *webPushStore) RemoveExpiredSubscriptions(expireAfter time.Duration) error {
 	_, err := c.db.Exec(deleteWebPushSubscriptionByAgeQuery, time.Now().Add(-expireAfter).Unix())
 	return err
 }
 
+// Close closes the underlying database connection
 func (c *webPushStore) Close() error {
 	return c.db.Close()
 }

+ 43 - 0
server/webpush_store_test.go

@@ -1,7 +1,10 @@
 package server
 
 import (
+	"fmt"
 	"github.com/stretchr/testify/require"
+	"net/netip"
+	"path/filepath"
 	"testing"
 )
 
@@ -10,3 +13,43 @@ func newTestWebPushStore(t *testing.T, filename string) *webPushStore {
 	require.Nil(t, err)
 	return webPush
 }
+
+func TestWebPushStore_UpsertSubscription_SubscriptionsForTopic(t *testing.T) {
+	webPush := newTestWebPushStore(t, filepath.Join(t.TempDir(), "webpush.db"))
+	defer webPush.Close()
+
+	require.Nil(t, webPush.UpsertSubscription(testWebPushEndpoint, "auth-key", "p256dh-key", "u_1234", netip.MustParseAddr("1.2.3.4"), []string{"test-topic", "mytopic"}))
+
+	subs, err := webPush.SubscriptionsForTopic("test-topic")
+	require.Nil(t, err)
+	require.Len(t, subs, 1)
+	require.Equal(t, subs[0].Endpoint, testWebPushEndpoint)
+	require.Equal(t, subs[0].P256dh, "p256dh-key")
+	require.Equal(t, subs[0].Auth, "auth-key")
+	require.Equal(t, subs[0].UserID, "u_1234")
+
+	subs2, err := webPush.SubscriptionsForTopic("mytopic")
+	require.Nil(t, err)
+	require.Len(t, subs2, 1)
+	require.Equal(t, subs[0].Endpoint, subs2[0].Endpoint)
+}
+
+func TestWebPushStore_UpsertSubscription_SubscriberIPLimitReached(t *testing.T) {
+	webPush := newTestWebPushStore(t, filepath.Join(t.TempDir(), "webpush.db"))
+	defer webPush.Close()
+
+	// Insert 10 subscriptions with the same IP address
+	for i := 0; i < 10; i++ {
+		endpoint := fmt.Sprintf(testWebPushEndpoint+"%d", i)
+		require.Nil(t, webPush.UpsertSubscription(endpoint, "auth-key", "p256dh-key", "u_1234", netip.MustParseAddr("1.2.3.4"), []string{"test-topic", "mytopic"}))
+	}
+
+	// Another one for the same endpoint should be fine
+	require.Nil(t, webPush.UpsertSubscription(testWebPushEndpoint+"0", "auth-key", "p256dh-key", "u_1234", netip.MustParseAddr("1.2.3.4"), []string{"test-topic", "mytopic"}))
+
+	// But with a different endpoint it should fail
+	require.Equal(t, errWebPushTooManySubscriptions, webPush.UpsertSubscription(testWebPushEndpoint+"11", "auth-key", "p256dh-key", "u_1234", netip.MustParseAddr("1.2.3.4"), []string{"test-topic", "mytopic"}))
+
+	// But with a different IP address it should be fine again
+	require.Nil(t, webPush.UpsertSubscription(testWebPushEndpoint+"99", "auth-key", "p256dh-key", "u_1234", netip.MustParseAddr("9.9.9.9"), []string{"test-topic", "mytopic"}))
+}