binwiederhier 1 kuukausi sitten
vanhempi
sitoutus
37d71051de
3 muutettua tiedostoa jossa 36 lisäystä ja 52 poistoa
  1. 24 51
      server/server.go
  2. 1 1
      server/server_webpush.go
  3. 11 0
      server/types.go

+ 24 - 51
server/server.go

@@ -879,7 +879,7 @@ func (s *Server) handlePublish(w http.ResponseWriter, r *http.Request, v *visito
 		return err
 	}
 	minc(metricMessagesPublishedSuccess)
-	return s.writeJSON(w, m)
+	return s.writeJSON(w, m.forJSON())
 }
 
 func (s *Server) handlePublishMatrix(w http.ResponseWriter, r *http.Request, v *visitor) error {
@@ -908,50 +908,14 @@ func (s *Server) handlePublishMatrix(w http.ResponseWriter, r *http.Request, v *
 }
 
 func (s *Server) handleDelete(w http.ResponseWriter, r *http.Request, v *visitor) error {
-	t, err := fromContext[*topic](r, contextTopic)
-	if err != nil {
-		return err
-	}
-	vrate, err := fromContext[*visitor](r, contextRateVisitor)
-	if err != nil {
-		return err
-	}
-	if !util.ContainsIP(s.config.VisitorRequestExemptPrefixes, v.ip) && !vrate.MessageAllowed() {
-		return errHTTPTooManyRequestsLimitMessages.With(t)
-	}
-	sequenceID, e := s.sequenceIDFromPath(r.URL.Path)
-	if e != nil {
-		return e.With(t)
-	}
-	// Create a delete message with event type message_delete
-	m := newActionMessage(messageDeleteEvent, t.ID, sequenceID)
-	m.Sender = v.IP()
-	m.User = v.MaybeUserID()
-	m.Expires = time.Unix(m.Time, 0).Add(v.Limits().MessageExpiryDuration).Unix()
-	// Publish to subscribers
-	if err := t.Publish(v, m); err != nil {
-		return err
-	}
-	// Send to Firebase for Android clients
-	if s.firebaseClient != nil {
-		go s.sendToFirebase(v, m)
-	}
-	// Send to web push endpoints
-	if s.config.WebPushPublicKey != "" {
-		go s.publishToWebPushEndpoints(v, m)
-	}
-	// Add to message cache
-	if err := s.messageCache.AddMessage(m); err != nil {
-		return err
-	}
-	logvrm(v, r, m).Tag(tagPublish).Debug("Deleted message with sequence ID %s", sequenceID)
-	s.mu.Lock()
-	s.messages++
-	s.mu.Unlock()
-	return s.writeJSON(w, m)
+	return s.handleActionMessage(w, r, v, messageDeleteEvent, s.sequenceIDFromPath)
 }
 
 func (s *Server) handleMarkRead(w http.ResponseWriter, r *http.Request, v *visitor) error {
+	return s.handleActionMessage(w, r, v, messageReadEvent, s.sequenceIDFromMarkReadPath)
+}
+
+func (s *Server) handleActionMessage(w http.ResponseWriter, r *http.Request, v *visitor, event string, extractSequenceID func(string) (string, *errHTTP)) error {
 	t, err := fromContext[*topic](r, contextTopic)
 	if err != nil {
 		return err
@@ -963,12 +927,12 @@ func (s *Server) handleMarkRead(w http.ResponseWriter, r *http.Request, v *visit
 	if !util.ContainsIP(s.config.VisitorRequestExemptPrefixes, v.ip) && !vrate.MessageAllowed() {
 		return errHTTPTooManyRequestsLimitMessages.With(t)
 	}
-	sequenceID, e := s.sequenceIDFromPath(r.URL.Path)
+	sequenceID, e := extractSequenceID(r.URL.Path)
 	if e != nil {
 		return e.With(t)
 	}
-	// Create a read message with event type message_read
-	m := newActionMessage(messageReadEvent, t.ID, sequenceID)
+	// Create an action message with the given event type
+	m := newActionMessage(event, t.ID, sequenceID)
 	m.Sender = v.IP()
 	m.User = v.MaybeUserID()
 	m.Expires = time.Unix(m.Time, 0).Add(v.Limits().MessageExpiryDuration).Unix()
@@ -988,11 +952,11 @@ func (s *Server) handleMarkRead(w http.ResponseWriter, r *http.Request, v *visit
 	if err := s.messageCache.AddMessage(m); err != nil {
 		return err
 	}
-	logvrm(v, r, m).Tag(tagPublish).Debug("Marked message as read with sequence ID %s", sequenceID)
+	logvrm(v, r, m).Tag(tagPublish).Debug("Published %s for sequence ID %s", event, sequenceID)
 	s.mu.Lock()
 	s.messages++
 	s.mu.Unlock()
-	return s.writeJSON(w, m)
+	return s.writeJSON(w, m.forJSON())
 }
 
 func (s *Server) sendToFirebase(v *visitor, m *message) {
@@ -1384,7 +1348,7 @@ func (s *Server) handleBodyAsAttachment(r *http.Request, v *visitor, m *message,
 func (s *Server) handleSubscribeJSON(w http.ResponseWriter, r *http.Request, v *visitor) error {
 	encoder := func(msg *message) (string, error) {
 		var buf bytes.Buffer
-		if err := json.NewEncoder(&buf).Encode(msg); err != nil {
+		if err := json.NewEncoder(&buf).Encode(msg.forJSON()); err != nil {
 			return "", err
 		}
 		return buf.String(), nil
@@ -1395,10 +1359,10 @@ func (s *Server) handleSubscribeJSON(w http.ResponseWriter, r *http.Request, v *
 func (s *Server) handleSubscribeSSE(w http.ResponseWriter, r *http.Request, v *visitor) error {
 	encoder := func(msg *message) (string, error) {
 		var buf bytes.Buffer
-		if err := json.NewEncoder(&buf).Encode(msg); err != nil {
+		if err := json.NewEncoder(&buf).Encode(msg.forJSON()); err != nil {
 			return "", err
 		}
-		if msg.Event != messageEvent {
+		if msg.Event != messageEvent && msg.Event != messageDeleteEvent && msg.Event != messageReadEvent {
 			return fmt.Sprintf("event: %s\ndata: %s\n", msg.Event, buf.String()), nil // Browser's .onmessage() does not fire on this!
 		}
 		return fmt.Sprintf("data: %s\n", buf.String()), nil
@@ -1808,7 +1772,7 @@ func (s *Server) topicsFromPath(path string) ([]*topic, string, error) {
 	return topics, parts[1], nil
 }
 
-// sequenceIDFromPath returns the sequence ID from a POST path like /mytopic/sequenceIdHere
+// sequenceIDFromPath returns the sequence ID from a path like /mytopic/sequenceIdHere
 func (s *Server) sequenceIDFromPath(path string) (string, *errHTTP) {
 	parts := strings.Split(path, "/")
 	if len(parts) < 3 {
@@ -1817,6 +1781,15 @@ func (s *Server) sequenceIDFromPath(path string) (string, *errHTTP) {
 	return parts[2], nil
 }
 
+// sequenceIDFromMarkReadPath returns the sequence ID from a path like /mytopic/sequenceIdHere/read
+func (s *Server) sequenceIDFromMarkReadPath(path string) (string, *errHTTP) {
+	parts := strings.Split(path, "/")
+	if len(parts) < 4 || parts[3] != "read" {
+		return "", errHTTPBadRequestSequenceIDInvalid
+	}
+	return parts[2], nil
+}
+
 // topicsFromIDs returns the topics with the given IDs, creating them if they don't exist.
 func (s *Server) topicsFromIDs(ids ...string) ([]*topic, error) {
 	s.mu.Lock()

+ 1 - 1
server/server_webpush.go

@@ -89,7 +89,7 @@ func (s *Server) publishToWebPushEndpoints(v *visitor, m *message) {
 		return
 	}
 	log.Tag(tagWebPush).With(v, m).Debug("Publishing web push message to %d subscribers", len(subscriptions))
-	payload, err := json.Marshal(newWebPushPayload(fmt.Sprintf("%s/%s", s.config.BaseURL, m.Topic), m))
+	payload, err := json.Marshal(newWebPushPayload(fmt.Sprintf("%s/%s", s.config.BaseURL, m.Topic), m.forJSON()))
 	if err != nil {
 		log.Tag(tagWebPush).Err(err).With(v, m).Warn("Unable to marshal expiring payload")
 		return

+ 11 - 0
server/types.go

@@ -65,6 +65,17 @@ func (m *message) Context() log.Context {
 	return fields
 }
 
+// forJSON returns a copy of the message suitable for JSON output.
+// It clears the SequenceID if it equals the ID to reduce redundancy.
+func (m *message) forJSON() *message {
+	if m.SequenceID == m.ID {
+		clone := *m
+		clone.SequenceID = ""
+		return &clone
+	}
+	return m
+}
+
 type attachment struct {
 	Name    string `json:"name"`
 	Type    string `json:"type,omitempty"`