Jelajahi Sumber

Move stuff to server_web_push.go

binwiederhier 2 tahun lalu
induk
melakukan
7f3e4b5f47
4 mengubah file dengan 193 tambahan dan 190 penghapusan
  1. 45 184
      server/server.go
  2. 2 2
      server/server_account.go
  3. 142 0
      server/server_web_push.go
  4. 4 4
      server/server_web_push_test.go

+ 45 - 184
server/server.go

@@ -33,35 +33,33 @@ import (
 	"heckel.io/ntfy/log"
 	"heckel.io/ntfy/user"
 	"heckel.io/ntfy/util"
-
-	"github.com/SherClockHolmes/webpush-go"
 )
 
 // Server is the main server, providing the UI and API for ntfy
 type Server struct {
-	config                   *Config
-	httpServer               *http.Server
-	httpsServer              *http.Server
-	httpMetricsServer        *http.Server
-	httpProfileServer        *http.Server
-	unixListener             net.Listener
-	smtpServer               *smtp.Server
-	smtpServerBackend        *smtpBackend
-	smtpSender               mailer
-	topics                   map[string]*topic
-	visitors                 map[string]*visitor // ip:<ip> or user:<user>
-	firebaseClient           *firebaseClient
-	messages                 int64                               // Total number of messages (persisted if messageCache enabled)
-	messagesHistory          []int64                             // Last n values of the messages counter, used to determine rate
-	userManager              *user.Manager                       // Might be nil!
-	messageCache             *messageCache                       // Database that stores the messages
-	webPushSubscriptionStore *webPushStore                       // Database that stores web push subscriptions
-	fileCache                *fileCache                          // File system based cache that stores attachments
-	stripe                   stripeAPI                           // Stripe API, can be replaced with a mock
-	priceCache               *util.LookupCache[map[string]int64] // Stripe price ID -> price as cents (USD implied!)
-	metricsHandler           http.Handler                        // Handles /metrics if enable-metrics set, and listen-metrics-http not set
-	closeChan                chan bool
-	mu                       sync.RWMutex
+	config            *Config
+	httpServer        *http.Server
+	httpsServer       *http.Server
+	httpMetricsServer *http.Server
+	httpProfileServer *http.Server
+	unixListener      net.Listener
+	smtpServer        *smtp.Server
+	smtpServerBackend *smtpBackend
+	smtpSender        mailer
+	topics            map[string]*topic
+	visitors          map[string]*visitor // ip:<ip> or user:<user>
+	firebaseClient    *firebaseClient
+	messages          int64                               // Total number of messages (persisted if messageCache enabled)
+	messagesHistory   []int64                             // Last n values of the messages counter, used to determine rate
+	userManager       *user.Manager                       // Might be nil!
+	messageCache      *messageCache                       // Database that stores the messages
+	webPush           *webPushStore                       // Database that stores web push subscriptions
+	fileCache         *fileCache                          // File system based cache that stores attachments
+	stripe            stripeAPI                           // Stripe API, can be replaced with a mock
+	priceCache        *util.LookupCache[map[string]int64] // Stripe price ID -> price as cents (USD implied!)
+	metricsHandler    http.Handler                        // Handles /metrics if enable-metrics set, and listen-metrics-http not set
+	closeChan         chan bool
+	mu                sync.RWMutex
 }
 
 // handleFunc extends the normal http.HandlerFunc to be able to easily return errors
@@ -160,9 +158,12 @@ func New(conf *Config) (*Server, error) {
 	if err != nil {
 		return nil, err
 	}
-	webPushSubscriptionStore, err := createWebPushSubscriptionStore(conf)
-	if err != nil {
-		return nil, err
+	var webPush *webPushStore
+	if conf.WebPushEnabled {
+		webPush, err = newWebPushStore(conf.WebPushSubscriptionsFile)
+		if err != nil {
+			return nil, err
+		}
 	}
 	topics, err := messageCache.Topics()
 	if err != nil {
@@ -201,18 +202,18 @@ func New(conf *Config) (*Server, error) {
 		firebaseClient = newFirebaseClient(sender, auther)
 	}
 	s := &Server{
-		config:                   conf,
-		messageCache:             messageCache,
-		webPushSubscriptionStore: webPushSubscriptionStore,
-		fileCache:                fileCache,
-		firebaseClient:           firebaseClient,
-		smtpSender:               mailer,
-		topics:                   topics,
-		userManager:              userManager,
-		messages:                 messages,
-		messagesHistory:          []int64{messages},
-		visitors:                 make(map[string]*visitor),
-		stripe:                   stripe,
+		config:          conf,
+		messageCache:    messageCache,
+		webPush:         webPush,
+		fileCache:       fileCache,
+		firebaseClient:  firebaseClient,
+		smtpSender:      mailer,
+		topics:          topics,
+		userManager:     userManager,
+		messages:        messages,
+		messagesHistory: []int64{messages},
+		visitors:        make(map[string]*visitor),
+		stripe:          stripe,
 	}
 	s.priceCache = util.NewLookupCache(s.fetchStripePrices, conf.StripePriceCacheDuration)
 	return s, nil
@@ -227,14 +228,6 @@ func createMessageCache(conf *Config) (*messageCache, error) {
 	return newMemCache()
 }
 
-func createWebPushSubscriptionStore(conf *Config) (*webPushStore, error) {
-	if !conf.WebPushEnabled {
-		return nil, nil
-	}
-
-	return newWebPushStore(conf.WebPushSubscriptionsFile)
-}
-
 // Run executes the main server. It listens on HTTP (+ HTTPS, if configured), and starts
 // a manager go routine to print stats and prune messages.
 func (s *Server) Run() error {
@@ -364,8 +357,8 @@ func (s *Server) closeDatabases() {
 		s.userManager.Close()
 	}
 	s.messageCache.Close()
-	if s.webPushSubscriptionStore != nil {
-		s.webPushSubscriptionStore.Close()
+	if s.webPush != nil {
+		s.webPush.Close()
 	}
 }
 
@@ -536,9 +529,9 @@ func (s *Server) handleInternal(w http.ResponseWriter, r *http.Request, v *visit
 	} else if r.Method == http.MethodGet && authPathRegex.MatchString(r.URL.Path) {
 		return s.limitRequests(s.authorizeTopicRead(s.handleTopicAuth))(w, r, v)
 	} else if r.Method == http.MethodPost && webPushSubscribePathRegex.MatchString(r.URL.Path) {
-		return s.limitRequestsWithTopic(s.authorizeTopicRead(s.ensureWebPushEnabled(s.handleTopicWebPushSubscribe)))(w, r, v)
+		return s.ensureWebPushEnabled(s.limitRequestsWithTopic(s.authorizeTopicRead(s.handleTopicWebPushSubscribe)))(w, r, v)
 	} else if r.Method == http.MethodPost && webPushUnsubscribePathRegex.MatchString(r.URL.Path) {
-		return s.limitRequestsWithTopic(s.authorizeTopicRead(s.ensureWebPushEnabled(s.handleTopicWebPushUnsubscribe)))(w, r, v)
+		return s.ensureWebPushEnabled(s.limitRequestsWithTopic(s.authorizeTopicRead(s.handleTopicWebPushUnsubscribe)))(w, r, v)
 	} else if r.Method == http.MethodGet && (topicPathRegex.MatchString(r.URL.Path) || externalTopicPathRegex.MatchString(r.URL.Path)) {
 		return s.ensureWebEnabled(s.handleTopic)(w, r, v)
 	}
@@ -578,55 +571,6 @@ func (s *Server) handleAPIWebPushConfig(w http.ResponseWriter, _ *http.Request,
 	return s.writeJSON(w, response)
 }
 
-func (s *Server) handleTopicWebPushSubscribe(w http.ResponseWriter, r *http.Request, v *visitor) error {
-	var username string
-	u := v.User()
-	if u != nil {
-		username = u.Name
-	}
-
-	var sub webPushSubscribePayload
-	err := json.NewDecoder(r.Body).Decode(&sub)
-
-	if err != nil || sub.BrowserSubscription.Endpoint == "" || sub.BrowserSubscription.Keys.P256dh == "" || sub.BrowserSubscription.Keys.Auth == "" {
-		return errHTTPBadRequestWebPushSubscriptionInvalid
-	}
-
-	topic, err := fromContext[*topic](r, contextTopic)
-	if err != nil {
-		return err
-	}
-
-	err = s.webPushSubscriptionStore.AddSubscription(topic.ID, username, sub)
-	if err != nil {
-		return err
-	}
-
-	return s.writeJSON(w, newSuccessResponse())
-}
-
-func (s *Server) handleTopicWebPushUnsubscribe(w http.ResponseWriter, r *http.Request, _ *visitor) error {
-	var payload webPushUnsubscribePayload
-
-	err := json.NewDecoder(r.Body).Decode(&payload)
-
-	if err != nil {
-		return errHTTPBadRequestWebPushSubscriptionInvalid
-	}
-
-	topic, err := fromContext[*topic](r, contextTopic)
-	if err != nil {
-		return err
-	}
-
-	err = s.webPushSubscriptionStore.RemoveSubscription(topic.ID, payload.Endpoint)
-	if err != nil {
-		return err
-	}
-
-	return s.writeJSON(w, newSuccessResponse())
-}
-
 func (s *Server) handleHealth(w http.ResponseWriter, _ *http.Request, _ *visitor) error {
 	response := &apiHealthResponse{
 		Healthy: true,
@@ -977,89 +921,6 @@ func (s *Server) forwardPollRequest(v *visitor, m *message) {
 	}
 }
 
-func (s *Server) publishToWebPushEndpoints(v *visitor, m *message) {
-	subscriptions, err := s.webPushSubscriptionStore.GetSubscriptionsForTopic(m.Topic)
-	if err != nil {
-		logvm(v, m).Err(err).Warn("Unable to publish web push messages")
-		return
-	}
-	
-	ctx := log.Context{"topic": m.Topic, "message_id": m.ID, "total_count": len(subscriptions)}
-
-	// Importing the emojis in the service worker would add unnecessary complexity,
-	// simply do it here for web push notifications instead
-	var titleWithDefault string
-	var formattedTitle string
-
-	emojis, _, err := toEmojis(m.Tags)
-	if err != nil {
-		logvm(v, m).Err(err).Fields(ctx).Debug("Unable to publish web push message")
-		return
-	}
-
-	if m.Title == "" {
-		titleWithDefault = m.Topic
-	} else {
-		titleWithDefault = m.Title
-	}
-
-	if len(emojis) > 0 {
-		formattedTitle = fmt.Sprintf("%s %s", strings.Join(emojis[:], " "), titleWithDefault)
-	} else {
-		formattedTitle = titleWithDefault
-	}
-
-	for i, xi := range subscriptions {
-		go func(i int, sub webPushSubscription) {
-			ctx := log.Context{"endpoint": sub.BrowserSubscription.Endpoint, "username": sub.Username, "topic": m.Topic, "message_id": m.ID}
-
-			payload := &webPushPayload{
-				SubscriptionID: fmt.Sprintf("%s/%s", s.config.BaseURL, m.Topic),
-				Message:        *m,
-				FormattedTitle: formattedTitle,
-			}
-			jsonPayload, err := json.Marshal(payload)
-
-			if err != nil {
-				logvm(v, m).Err(err).Fields(ctx).Debug("Unable to publish web push message")
-				return
-			}
-
-			resp, err := webpush.SendNotification(jsonPayload, &sub.BrowserSubscription, &webpush.Options{
-				Subscriber:      s.config.WebPushEmailAddress,
-				VAPIDPublicKey:  s.config.WebPushPublicKey,
-				VAPIDPrivateKey: s.config.WebPushPrivateKey,
-				// deliverability on iOS isn't great with lower urgency values,
-				// and thus we can't really map lower ntfy priorities to lower urgency values
-				Urgency: webpush.UrgencyHigh,
-			})
-
-			if err != nil {
-				logvm(v, m).Err(err).Fields(ctx).Debug("Unable to publish web push message")
-
-				err = s.webPushSubscriptionStore.ExpireWebPushEndpoint(sub.BrowserSubscription.Endpoint)
-				if err != nil {
-					logvm(v, m).Err(err).Fields(ctx).Warn("Unable to expire subscription")
-				}
-
-				return
-			}
-
-			// May want to handle at least 429 differently, but for now treat all errors the same
-			if !(200 <= resp.StatusCode && resp.StatusCode <= 299) {
-				logvm(v, m).Fields(ctx).Field("response", resp).Debug("Unable to publish web push message")
-
-				err = s.webPushSubscriptionStore.ExpireWebPushEndpoint(sub.BrowserSubscription.Endpoint)
-				if err != nil {
-					logvm(v, m).Err(err).Fields(ctx).Warn("Unable to expire subscription")
-				}
-
-				return
-			}
-		}(i, xi)
-	}
-}
-
 func (s *Server) parsePublishParams(r *http.Request, m *message) (cache bool, firebase bool, email, call string, unifiedpush bool, err *errHTTP) {
 	cache = readBoolParam(r, true, "x-cache", "cache")
 	firebase = readBoolParam(r, true, "x-firebase", "firebase")

+ 2 - 2
server/server_account.go

@@ -170,8 +170,8 @@ func (s *Server) handleAccountDelete(w http.ResponseWriter, r *http.Request, v *
 	if _, err := s.userManager.Authenticate(u.Name, req.Password); err != nil {
 		return errHTTPBadRequestIncorrectPasswordConfirmation
 	}
-	if s.webPushSubscriptionStore != nil {
-		err := s.webPushSubscriptionStore.ExpireWebPushForUser(u.Name)
+	if s.webPush != nil {
+		err := s.webPush.ExpireWebPushForUser(u.Name)
 
 		if err != nil {
 			logvr(v, r).Err(err).Warn("Error removing web push subscriptions for %s", u.Name)

+ 142 - 0
server/server_web_push.go

@@ -0,0 +1,142 @@
+package server
+
+import (
+	"encoding/json"
+	"fmt"
+	"github.com/SherClockHolmes/webpush-go"
+	"heckel.io/ntfy/log"
+	"net/http"
+	"strings"
+)
+
+func (s *Server) handleTopicWebPushSubscribe(w http.ResponseWriter, r *http.Request, v *visitor) error {
+	var username string
+	u := v.User()
+	if u != nil {
+		username = u.Name
+	}
+
+	var sub webPushSubscribePayload
+	err := json.NewDecoder(r.Body).Decode(&sub)
+
+	if err != nil || sub.BrowserSubscription.Endpoint == "" || sub.BrowserSubscription.Keys.P256dh == "" || sub.BrowserSubscription.Keys.Auth == "" {
+		return errHTTPBadRequestWebPushSubscriptionInvalid
+	}
+
+	topic, err := fromContext[*topic](r, contextTopic)
+	if err != nil {
+		return err
+	}
+
+	err = s.webPush.AddSubscription(topic.ID, username, sub)
+	if err != nil {
+		return err
+	}
+
+	return s.writeJSON(w, newSuccessResponse())
+}
+
+func (s *Server) handleTopicWebPushUnsubscribe(w http.ResponseWriter, r *http.Request, _ *visitor) error {
+	var payload webPushUnsubscribePayload
+
+	err := json.NewDecoder(r.Body).Decode(&payload)
+
+	if err != nil {
+		return errHTTPBadRequestWebPushSubscriptionInvalid
+	}
+
+	topic, err := fromContext[*topic](r, contextTopic)
+	if err != nil {
+		return err
+	}
+
+	err = s.webPush.RemoveSubscription(topic.ID, payload.Endpoint)
+	if err != nil {
+		return err
+	}
+
+	return s.writeJSON(w, newSuccessResponse())
+}
+
+func (s *Server) publishToWebPushEndpoints(v *visitor, m *message) {
+	subscriptions, err := s.webPush.GetSubscriptionsForTopic(m.Topic)
+	if err != nil {
+		logvm(v, m).Err(err).Warn("Unable to publish web push messages")
+		return
+	}
+
+	ctx := log.Context{"topic": m.Topic, "message_id": m.ID, "total_count": len(subscriptions)}
+
+	// Importing the emojis in the service worker would add unnecessary complexity,
+	// simply do it here for web push notifications instead
+	var titleWithDefault string
+	var formattedTitle string
+
+	emojis, _, err := toEmojis(m.Tags)
+	if err != nil {
+		logvm(v, m).Err(err).Fields(ctx).Debug("Unable to publish web push message")
+		return
+	}
+
+	if m.Title == "" {
+		titleWithDefault = m.Topic
+	} else {
+		titleWithDefault = m.Title
+	}
+
+	if len(emojis) > 0 {
+		formattedTitle = fmt.Sprintf("%s %s", strings.Join(emojis[:], " "), titleWithDefault)
+	} else {
+		formattedTitle = titleWithDefault
+	}
+
+	for i, xi := range subscriptions {
+		go func(i int, sub webPushSubscription) {
+			ctx := log.Context{"endpoint": sub.BrowserSubscription.Endpoint, "username": sub.Username, "topic": m.Topic, "message_id": m.ID}
+
+			payload := &webPushPayload{
+				SubscriptionID: fmt.Sprintf("%s/%s", s.config.BaseURL, m.Topic),
+				Message:        *m,
+				FormattedTitle: formattedTitle,
+			}
+			jsonPayload, err := json.Marshal(payload)
+
+			if err != nil {
+				logvm(v, m).Err(err).Fields(ctx).Debug("Unable to publish web push message")
+				return
+			}
+
+			resp, err := webpush.SendNotification(jsonPayload, &sub.BrowserSubscription, &webpush.Options{
+				Subscriber:      s.config.WebPushEmailAddress,
+				VAPIDPublicKey:  s.config.WebPushPublicKey,
+				VAPIDPrivateKey: s.config.WebPushPrivateKey,
+				// deliverability on iOS isn't great with lower urgency values,
+				// and thus we can't really map lower ntfy priorities to lower urgency values
+				Urgency: webpush.UrgencyHigh,
+			})
+
+			if err != nil {
+				logvm(v, m).Err(err).Fields(ctx).Debug("Unable to publish web push message")
+
+				err = s.webPush.ExpireWebPushEndpoint(sub.BrowserSubscription.Endpoint)
+				if err != nil {
+					logvm(v, m).Err(err).Fields(ctx).Warn("Unable to expire subscription")
+				}
+
+				return
+			}
+
+			// May want to handle at least 429 differently, but for now treat all errors the same
+			if !(200 <= resp.StatusCode && resp.StatusCode <= 299) {
+				logvm(v, m).Fields(ctx).Field("response", resp).Debug("Unable to publish web push message")
+
+				err = s.webPush.ExpireWebPushEndpoint(sub.BrowserSubscription.Endpoint)
+				if err != nil {
+					logvm(v, m).Err(err).Fields(ctx).Warn("Unable to expire subscription")
+				}
+
+				return
+			}
+		}(i, xi)
+	}
+}

+ 4 - 4
server/server_web_push_test.go

@@ -41,7 +41,7 @@ func TestServer_WebPush_TopicSubscribe(t *testing.T) {
 	require.Equal(t, 200, response.Code)
 	require.Equal(t, `{"success":true}`+"\n", response.Body.String())
 
-	subs, err := s.webPushSubscriptionStore.GetSubscriptionsForTopic("test-topic")
+	subs, err := s.webPush.GetSubscriptionsForTopic("test-topic")
 	if err != nil {
 		t.Fatal(err)
 	}
@@ -68,7 +68,7 @@ func TestServer_WebPush_TopicSubscribeProtected_Allowed(t *testing.T) {
 	require.Equal(t, 200, response.Code)
 	require.Equal(t, `{"success":true}`+"\n", response.Body.String())
 
-	subs, err := s.webPushSubscriptionStore.GetSubscriptionsForTopic("test-topic")
+	subs, err := s.webPush.GetSubscriptionsForTopic("test-topic")
 	if err != nil {
 		t.Fatal(err)
 	}
@@ -187,7 +187,7 @@ func TestServer_WebPush_PublishExpire(t *testing.T) {
 }
 
 func addSubscription(t *testing.T, s *Server, topic string, url string) {
-	err := s.webPushSubscriptionStore.AddSubscription("test-topic", "", webPushSubscribePayload{
+	err := s.webPush.AddSubscription("test-topic", "", webPushSubscribePayload{
 		BrowserSubscription: webpush.Subscription{
 			Endpoint: url,
 			Keys: webpush.Keys{
@@ -203,7 +203,7 @@ func addSubscription(t *testing.T, s *Server, topic string, url string) {
 }
 
 func requireSubscriptionCount(t *testing.T, s *Server, topic string, expectedLength int) {
-	subs, err := s.webPushSubscriptionStore.GetSubscriptionsForTopic("test-topic")
+	subs, err := s.webPush.GetSubscriptionsForTopic("test-topic")
 	if err != nil {
 		t.Fatal(err)
 	}