binwiederhier 2 лет назад
Родитель
Сommit
9e0687e142
7 измененных файлов с 32 добавлено и 63 удалено
  1. 1 1
      go.mod
  2. 1 3
      server/server_account.go
  3. 2 8
      server/server_test.go
  4. 7 27
      server/server_web_push.go
  5. 7 10
      server/server_web_push_test.go
  6. 2 2
      server/types.go
  7. 12 12
      server/web_push.go

+ 1 - 1
go.mod

@@ -27,6 +27,7 @@ require github.com/pkg/errors v0.9.1 // indirect
 
 require (
 	firebase.google.com/go/v4 v4.11.0
+	github.com/SherClockHolmes/webpush-go v1.2.0
 	github.com/prometheus/client_golang v1.15.1
 	github.com/stripe/stripe-go/v74 v74.21.0
 )
@@ -39,7 +40,6 @@ require (
 	cloud.google.com/go/longrunning v0.5.0 // indirect
 	github.com/AlekSi/pointer v1.2.0 // indirect
 	github.com/MicahParks/keyfunc v1.9.0 // indirect
-	github.com/SherClockHolmes/webpush-go v1.2.0 // indirect
 	github.com/beorn7/perks v1.0.1 // indirect
 	github.com/cespare/xxhash/v2 v2.2.0 // indirect
 	github.com/davecgh/go-spew v1.1.1 // indirect

+ 1 - 3
server/server_account.go

@@ -171,9 +171,7 @@ func (s *Server) handleAccountDelete(w http.ResponseWriter, r *http.Request, v *
 		return errHTTPBadRequestIncorrectPasswordConfirmation
 	}
 	if s.webPush != nil {
-		err := s.webPush.ExpireWebPushForUser(u.Name)
-
-		if err != nil {
+		if err := s.webPush.RemoveByUserID(u.ID); err != nil {
 			logvr(v, r).Err(err).Warn("Error removing web push subscriptions for %s", u.Name)
 		}
 	}

+ 2 - 8
server/server_test.go

@@ -2620,12 +2620,8 @@ func newTestConfigWithAuthFile(t *testing.T) *Config {
 
 func newTestConfigWithWebPush(t *testing.T) *Config {
 	conf := newTestConfig(t)
-
 	privateKey, publicKey, err := webpush.GenerateVAPIDKeys()
-	if err != nil {
-		t.Fatal(err)
-	}
-
+	require.Nil(t, err)
 	conf.WebPushEnabled = true
 	conf.WebPushSubscriptionsFile = filepath.Join(t.TempDir(), "subscriptions.db")
 	conf.WebPushEmailAddress = "testing@example.com"
@@ -2636,9 +2632,7 @@ func newTestConfigWithWebPush(t *testing.T) *Config {
 
 func newTestServer(t *testing.T, config *Config) *Server {
 	server, err := New(config)
-	if err != nil {
-		t.Fatal(err)
-	}
+	require.Nil(t, err)
 	return server
 }
 

+ 7 - 27
server/server_web_push.go

@@ -10,15 +10,8 @@ import (
 )
 
 func (s *Server) handleTopicWebPushSubscribe(w http.ResponseWriter, r *http.Request, v *visitor) error {
-	var username string
-	u := v.User()
-	if u != nil {
-		username = u.Name
-	}
-
 	var sub webPushSubscribePayload
 	err := json.NewDecoder(r.Body).Decode(&sub)
-
 	if err != nil || sub.BrowserSubscription.Endpoint == "" || sub.BrowserSubscription.Keys.P256dh == "" || sub.BrowserSubscription.Keys.Auth == "" {
 		return errHTTPBadRequestWebPushSubscriptionInvalid
 	}
@@ -27,12 +20,9 @@ func (s *Server) handleTopicWebPushSubscribe(w http.ResponseWriter, r *http.Requ
 	if err != nil {
 		return err
 	}
-
-	err = s.webPush.AddSubscription(topic.ID, username, sub)
-	if err != nil {
+	if err = s.webPush.AddSubscription(topic.ID, v.MaybeUserID(), sub); err != nil {
 		return err
 	}
-
 	return s.writeJSON(w, newSuccessResponse())
 }
 
@@ -59,7 +49,7 @@ func (s *Server) handleTopicWebPushUnsubscribe(w http.ResponseWriter, r *http.Re
 }
 
 func (s *Server) publishToWebPushEndpoints(v *visitor, m *message) {
-	subscriptions, err := s.webPush.GetSubscriptionsForTopic(m.Topic)
+	subscriptions, err := s.webPush.SubscriptionsForTopic(m.Topic)
 	if err != nil {
 		logvm(v, m).Err(err).Warn("Unable to publish web push messages")
 		return
@@ -69,21 +59,17 @@ func (s *Server) publishToWebPushEndpoints(v *visitor, m *message) {
 
 	// Importing the emojis in the service worker would add unnecessary complexity,
 	// simply do it here for web push notifications instead
-	var titleWithDefault string
-	var formattedTitle string
-
+	var titleWithDefault, formattedTitle string
 	emojis, _, err := toEmojis(m.Tags)
 	if err != nil {
 		logvm(v, m).Err(err).Fields(ctx).Debug("Unable to publish web push message")
 		return
 	}
-
 	if m.Title == "" {
 		titleWithDefault = m.Topic
 	} else {
 		titleWithDefault = m.Title
 	}
-
 	if len(emojis) > 0 {
 		formattedTitle = fmt.Sprintf("%s %s", strings.Join(emojis[:], " "), titleWithDefault)
 	} else {
@@ -92,7 +78,7 @@ func (s *Server) publishToWebPushEndpoints(v *visitor, m *message) {
 
 	for i, xi := range subscriptions {
 		go func(i int, sub webPushSubscription) {
-			ctx := log.Context{"endpoint": sub.BrowserSubscription.Endpoint, "username": sub.Username, "topic": m.Topic, "message_id": m.ID}
+			ctx := log.Context{"endpoint": sub.BrowserSubscription.Endpoint, "username": sub.UserID, "topic": m.Topic, "message_id": m.ID}
 
 			payload := &webPushPayload{
 				SubscriptionID: fmt.Sprintf("%s/%s", s.config.BaseURL, m.Topic),
@@ -110,31 +96,25 @@ func (s *Server) publishToWebPushEndpoints(v *visitor, m *message) {
 				Subscriber:      s.config.WebPushEmailAddress,
 				VAPIDPublicKey:  s.config.WebPushPublicKey,
 				VAPIDPrivateKey: s.config.WebPushPrivateKey,
-				// deliverability on iOS isn't great with lower urgency values,
+				// Deliverability on iOS isn't great with lower urgency values,
 				// and thus we can't really map lower ntfy priorities to lower urgency values
 				Urgency: webpush.UrgencyHigh,
 			})
 
 			if err != nil {
 				logvm(v, m).Err(err).Fields(ctx).Debug("Unable to publish web push message")
-
-				err = s.webPush.ExpireWebPushEndpoint(sub.BrowserSubscription.Endpoint)
-				if err != nil {
+				if err := s.webPush.RemoveByEndpoint(sub.BrowserSubscription.Endpoint); err != nil {
 					logvm(v, m).Err(err).Fields(ctx).Warn("Unable to expire subscription")
 				}
-
 				return
 			}
 
 			// May want to handle at least 429 differently, but for now treat all errors the same
 			if !(200 <= resp.StatusCode && resp.StatusCode <= 299) {
 				logvm(v, m).Fields(ctx).Field("response", resp).Debug("Unable to publish web push message")
-
-				err = s.webPush.ExpireWebPushEndpoint(sub.BrowserSubscription.Endpoint)
-				if err != nil {
+				if err := s.webPush.RemoveByEndpoint(sub.BrowserSubscription.Endpoint); err != nil {
 					logvm(v, m).Err(err).Fields(ctx).Warn("Unable to expire subscription")
 				}
-
 				return
 			}
 		}(i, xi)

+ 7 - 10
server/server_web_push_test.go

@@ -5,6 +5,7 @@ import (
 	"io"
 	"net/http"
 	"net/http/httptest"
+	"strings"
 	"sync/atomic"
 	"testing"
 
@@ -41,7 +42,7 @@ func TestServer_WebPush_TopicSubscribe(t *testing.T) {
 	require.Equal(t, 200, response.Code)
 	require.Equal(t, `{"success":true}`+"\n", response.Body.String())
 
-	subs, err := s.webPush.GetSubscriptionsForTopic("test-topic")
+	subs, err := s.webPush.SubscriptionsForTopic("test-topic")
 	if err != nil {
 		t.Fatal(err)
 	}
@@ -50,7 +51,7 @@ func TestServer_WebPush_TopicSubscribe(t *testing.T) {
 	require.Equal(t, subs[0].BrowserSubscription.Endpoint, "https://example.com/webpush")
 	require.Equal(t, subs[0].BrowserSubscription.Keys.P256dh, "p256dh-key")
 	require.Equal(t, subs[0].BrowserSubscription.Keys.Auth, "auth-key")
-	require.Equal(t, subs[0].Username, "")
+	require.Equal(t, subs[0].UserID, "")
 }
 
 func TestServer_WebPush_TopicSubscribeProtected_Allowed(t *testing.T) {
@@ -64,17 +65,13 @@ func TestServer_WebPush_TopicSubscribeProtected_Allowed(t *testing.T) {
 	response := request(t, s, "POST", "/test-topic/web-push/subscribe", webPushSubscribePayloadExample, map[string]string{
 		"Authorization": util.BasicAuth("ben", "ben"),
 	})
-
 	require.Equal(t, 200, response.Code)
 	require.Equal(t, `{"success":true}`+"\n", response.Body.String())
 
-	subs, err := s.webPush.GetSubscriptionsForTopic("test-topic")
-	if err != nil {
-		t.Fatal(err)
-	}
-
+	subs, err := s.webPush.SubscriptionsForTopic("test-topic")
+	require.Nil(t, err)
 	require.Len(t, subs, 1)
-	require.Equal(t, subs[0].Username, "ben")
+	require.True(t, strings.HasPrefix(subs[0].UserID, "u_"))
 }
 
 func TestServer_WebPush_TopicSubscribeProtected_Denied(t *testing.T) {
@@ -203,7 +200,7 @@ func addSubscription(t *testing.T, s *Server, topic string, url string) {
 }
 
 func requireSubscriptionCount(t *testing.T, s *Server, topic string, expectedLength int) {
-	subs, err := s.webPush.GetSubscriptionsForTopic("test-topic")
+	subs, err := s.webPush.SubscriptionsForTopic("test-topic")
 	if err != nil {
 		t.Fatal(err)
 	}

+ 2 - 2
server/types.go

@@ -41,7 +41,7 @@ type message struct {
 	PollID     string      `json:"poll_id,omitempty"`
 	Encoding   string      `json:"encoding,omitempty"` // empty for raw UTF-8, or "base64" for encoded bytes
 	Sender     netip.Addr  `json:"-"`                  // IP address of uploader, used for rate limiting
-	User       string      `json:"-"`                  // Username of the uploader, used to associated attachments
+	User       string      `json:"-"`                  // UserID of the uploader, used to associated attachments
 }
 
 func (m *message) Context() log.Context {
@@ -476,7 +476,7 @@ type webPushPayload struct {
 
 type webPushSubscription struct {
 	BrowserSubscription webpush.Subscription
-	Username            string
+	UserID              string
 }
 
 type webPushSubscribePayload struct {

+ 12 - 12
server/web_push.go

@@ -12,7 +12,7 @@ const (
 		CREATE TABLE IF NOT EXISTS subscriptions (
 			id INTEGER PRIMARY KEY AUTOINCREMENT,
 			topic TEXT NOT NULL,
-			username TEXT,
+			user_id TEXT,
 			endpoint TEXT NOT NULL,
 			key_auth TEXT NOT NULL,
 			key_p256dh TEXT NOT NULL,
@@ -24,14 +24,14 @@ const (
 		COMMIT;
 	`
 	insertWebPushSubscriptionQuery = `
-		INSERT OR REPLACE INTO subscriptions (topic, username, endpoint, key_auth, key_p256dh)
+		INSERT OR REPLACE INTO subscriptions (topic, user_id, endpoint, key_auth, key_p256dh)
 		VALUES (?, ?, ?, ?, ?)
 	`
 	deleteWebPushSubscriptionByEndpointQuery         = `DELETE FROM subscriptions WHERE endpoint = ?`
-	deleteWebPushSubscriptionByUsernameQuery         = `DELETE FROM subscriptions WHERE username = ?`
+	deleteWebPushSubscriptionByUserIDQuery           = `DELETE FROM subscriptions WHERE user_id = ?`
 	deleteWebPushSubscriptionByTopicAndEndpointQuery = `DELETE FROM subscriptions WHERE topic = ? AND endpoint = ?`
 
-	selectWebPushSubscriptionsForTopicQuery = `SELECT endpoint, key_auth, key_p256dh, username FROM subscriptions WHERE topic = ?`
+	selectWebPushSubscriptionsForTopicQuery = `SELECT endpoint, key_auth, key_p256dh, user_id FROM subscriptions WHERE topic = ?`
 
 	selectWebPushSubscriptionsCountQuery = `SELECT COUNT(*) FROM subscriptions`
 )
@@ -69,11 +69,11 @@ func setupNewSubscriptionsDB(db *sql.DB) error {
 	return nil
 }
 
-func (c *webPushStore) AddSubscription(topic string, username string, subscription webPushSubscribePayload) error {
+func (c *webPushStore) AddSubscription(topic string, userID string, subscription webPushSubscribePayload) error {
 	_, err := c.db.Exec(
 		insertWebPushSubscriptionQuery,
 		topic,
-		username,
+		userID,
 		subscription.BrowserSubscription.Endpoint,
 		subscription.BrowserSubscription.Keys.Auth,
 		subscription.BrowserSubscription.Keys.P256dh,
@@ -90,7 +90,7 @@ func (c *webPushStore) RemoveSubscription(topic string, endpoint string) error {
 	return err
 }
 
-func (c *webPushStore) GetSubscriptionsForTopic(topic string) (subscriptions []webPushSubscription, err error) {
+func (c *webPushStore) SubscriptionsForTopic(topic string) (subscriptions []webPushSubscription, err error) {
 	rows, err := c.db.Query(selectWebPushSubscriptionsForTopicQuery, topic)
 	if err != nil {
 		return nil, err
@@ -100,7 +100,7 @@ func (c *webPushStore) GetSubscriptionsForTopic(topic string) (subscriptions []w
 	var data []webPushSubscription
 	for rows.Next() {
 		i := webPushSubscription{}
-		err = rows.Scan(&i.BrowserSubscription.Endpoint, &i.BrowserSubscription.Keys.Auth, &i.BrowserSubscription.Keys.P256dh, &i.Username)
+		err = rows.Scan(&i.BrowserSubscription.Endpoint, &i.BrowserSubscription.Keys.Auth, &i.BrowserSubscription.Keys.P256dh, &i.UserID)
 		if err != nil {
 			return nil, err
 		}
@@ -109,7 +109,7 @@ func (c *webPushStore) GetSubscriptionsForTopic(topic string) (subscriptions []w
 	return data, nil
 }
 
-func (c *webPushStore) ExpireWebPushEndpoint(endpoint string) error {
+func (c *webPushStore) RemoveByEndpoint(endpoint string) error {
 	_, err := c.db.Exec(
 		deleteWebPushSubscriptionByEndpointQuery,
 		endpoint,
@@ -117,10 +117,10 @@ func (c *webPushStore) ExpireWebPushEndpoint(endpoint string) error {
 	return err
 }
 
-func (c *webPushStore) ExpireWebPushForUser(username string) error {
+func (c *webPushStore) RemoveByUserID(userID string) error {
 	_, err := c.db.Exec(
-		deleteWebPushSubscriptionByUsernameQuery,
-		username,
+		deleteWebPushSubscriptionByUserIDQuery,
+		userID,
 	)
 	return err
 }