Browse Source

Add tests, fix firebase

binwiederhier 1 month ago
parent
commit
44f20f6b4c

+ 10 - 0
server/server_firebase.go

@@ -143,6 +143,15 @@ func toFirebaseMessage(m *message, auther user.Auther) (*messaging.Message, erro
 			"poll_id": m.PollID,
 			"poll_id": m.PollID,
 		}
 		}
 		apnsConfig = createAPNSAlertConfig(m, data)
 		apnsConfig = createAPNSAlertConfig(m, data)
+	case messageDeleteEvent, messageClearEvent:
+		data = map[string]string{
+			"id":          m.ID,
+			"time":        fmt.Sprintf("%d", m.Time),
+			"event":       m.Event,
+			"topic":       m.Topic,
+			"sequence_id": m.SequenceID,
+		}
+		apnsConfig = createAPNSBackgroundConfig(data)
 	case messageEvent:
 	case messageEvent:
 		if auther != nil {
 		if auther != nil {
 			// If "anonymous read" for a topic is not allowed, we cannot send the message along
 			// If "anonymous read" for a topic is not allowed, we cannot send the message along
@@ -161,6 +170,7 @@ func toFirebaseMessage(m *message, auther user.Auther) (*messaging.Message, erro
 			"time":         fmt.Sprintf("%d", m.Time),
 			"time":         fmt.Sprintf("%d", m.Time),
 			"event":        m.Event,
 			"event":        m.Event,
 			"topic":        m.Topic,
 			"topic":        m.Topic,
+			"sequence_id":  m.SequenceID,
 			"priority":     fmt.Sprintf("%d", m.Priority),
 			"priority":     fmt.Sprintf("%d", m.Priority),
 			"tags":         strings.Join(m.Tags, ","),
 			"tags":         strings.Join(m.Tags, ","),
 			"click":        m.Click,
 			"click":        m.Click,

+ 3 - 0
server/server_firebase_test.go

@@ -177,6 +177,7 @@ func TestToFirebaseMessage_Message_Normal_Allowed(t *testing.T) {
 				"time":               fmt.Sprintf("%d", m.Time),
 				"time":               fmt.Sprintf("%d", m.Time),
 				"event":              "message",
 				"event":              "message",
 				"topic":              "mytopic",
 				"topic":              "mytopic",
+				"sequence_id":        "",
 				"priority":           "4",
 				"priority":           "4",
 				"tags":               strings.Join(m.Tags, ","),
 				"tags":               strings.Join(m.Tags, ","),
 				"click":              "https://google.com",
 				"click":              "https://google.com",
@@ -199,6 +200,7 @@ func TestToFirebaseMessage_Message_Normal_Allowed(t *testing.T) {
 		"time":               fmt.Sprintf("%d", m.Time),
 		"time":               fmt.Sprintf("%d", m.Time),
 		"event":              "message",
 		"event":              "message",
 		"topic":              "mytopic",
 		"topic":              "mytopic",
+		"sequence_id":        "",
 		"priority":           "4",
 		"priority":           "4",
 		"tags":               strings.Join(m.Tags, ","),
 		"tags":               strings.Join(m.Tags, ","),
 		"click":              "https://google.com",
 		"click":              "https://google.com",
@@ -232,6 +234,7 @@ func TestToFirebaseMessage_Message_Normal_Not_Allowed(t *testing.T) {
 		"time":         fmt.Sprintf("%d", m.Time),
 		"time":         fmt.Sprintf("%d", m.Time),
 		"event":        "poll_request",
 		"event":        "poll_request",
 		"topic":        "mytopic",
 		"topic":        "mytopic",
+		"sequence_id":  "",
 		"message":      "New message",
 		"message":      "New message",
 		"title":        "",
 		"title":        "",
 		"tags":         "",
 		"tags":         "",

+ 208 - 2
server/server_test.go

@@ -8,8 +8,6 @@ import (
 	"encoding/base64"
 	"encoding/base64"
 	"encoding/json"
 	"encoding/json"
 	"fmt"
 	"fmt"
-	"golang.org/x/crypto/bcrypt"
-	"heckel.io/ntfy/v2/user"
 	"io"
 	"io"
 	"net/http"
 	"net/http"
 	"net/http/httptest"
 	"net/http/httptest"
@@ -24,7 +22,9 @@ import (
 	"time"
 	"time"
 
 
 	"github.com/stretchr/testify/require"
 	"github.com/stretchr/testify/require"
+	"golang.org/x/crypto/bcrypt"
 	"heckel.io/ntfy/v2/log"
 	"heckel.io/ntfy/v2/log"
+	"heckel.io/ntfy/v2/user"
 	"heckel.io/ntfy/v2/util"
 	"heckel.io/ntfy/v2/util"
 )
 )
 
 
@@ -3289,6 +3289,212 @@ func TestServer_MessageTemplate_Until100_000(t *testing.T) {
 	require.Contains(t, toHTTPError(t, response.Body.String()).Message, "too many iterations")
 	require.Contains(t, toHTTPError(t, response.Body.String()).Message, "too many iterations")
 }
 }
 
 
+func TestServer_DeleteMessage(t *testing.T) {
+	t.Parallel()
+	s := newTestServer(t, newTestConfig(t))
+
+	// Publish a message with a sequence ID
+	response := request(t, s, "PUT", "/mytopic/seq123", "original message", nil)
+	require.Equal(t, 200, response.Code)
+	msg := toMessage(t, response.Body.String())
+	require.Equal(t, "seq123", msg.SequenceID)
+	require.Equal(t, "message", msg.Event)
+
+	// Delete the message using DELETE method
+	response = request(t, s, "DELETE", "/mytopic/seq123", "", nil)
+	require.Equal(t, 200, response.Code)
+	deleteMsg := toMessage(t, response.Body.String())
+	require.Equal(t, "seq123", deleteMsg.SequenceID)
+	require.Equal(t, "message_delete", deleteMsg.Event)
+
+	// Poll and verify both messages are returned
+	response = request(t, s, "GET", "/mytopic/json?poll=1", "", nil)
+	require.Equal(t, 200, response.Code)
+	lines := strings.Split(strings.TrimSpace(response.Body.String()), "\n")
+	require.Equal(t, 2, len(lines))
+
+	msg1 := toMessage(t, lines[0])
+	msg2 := toMessage(t, lines[1])
+	require.Equal(t, "message", msg1.Event)
+	require.Equal(t, "message_delete", msg2.Event)
+	require.Equal(t, "seq123", msg1.SequenceID)
+	require.Equal(t, "seq123", msg2.SequenceID)
+}
+
+func TestServer_ClearMessage(t *testing.T) {
+	t.Parallel()
+	s := newTestServer(t, newTestConfig(t))
+
+	// Publish a message with a sequence ID
+	response := request(t, s, "PUT", "/mytopic/seq456", "original message", nil)
+	require.Equal(t, 200, response.Code)
+	msg := toMessage(t, response.Body.String())
+	require.Equal(t, "seq456", msg.SequenceID)
+	require.Equal(t, "message", msg.Event)
+
+	// Clear the message using PUT /topic/seq/clear
+	response = request(t, s, "PUT", "/mytopic/seq456/clear", "", nil)
+	require.Equal(t, 200, response.Code)
+	clearMsg := toMessage(t, response.Body.String())
+	require.Equal(t, "seq456", clearMsg.SequenceID)
+	require.Equal(t, "message_clear", clearMsg.Event)
+
+	// Poll and verify both messages are returned
+	response = request(t, s, "GET", "/mytopic/json?poll=1", "", nil)
+	require.Equal(t, 200, response.Code)
+	lines := strings.Split(strings.TrimSpace(response.Body.String()), "\n")
+	require.Equal(t, 2, len(lines))
+
+	msg1 := toMessage(t, lines[0])
+	msg2 := toMessage(t, lines[1])
+	require.Equal(t, "message", msg1.Event)
+	require.Equal(t, "message_clear", msg2.Event)
+	require.Equal(t, "seq456", msg1.SequenceID)
+	require.Equal(t, "seq456", msg2.SequenceID)
+}
+
+func TestServer_ClearMessage_ReadEndpoint(t *testing.T) {
+	// Test that /topic/seq/read also works
+	t.Parallel()
+	s := newTestServer(t, newTestConfig(t))
+
+	// Publish a message
+	response := request(t, s, "PUT", "/mytopic/seq789", "original message", nil)
+	require.Equal(t, 200, response.Code)
+
+	// Clear using /read endpoint
+	response = request(t, s, "PUT", "/mytopic/seq789/read", "", nil)
+	require.Equal(t, 200, response.Code)
+	clearMsg := toMessage(t, response.Body.String())
+	require.Equal(t, "seq789", clearMsg.SequenceID)
+	require.Equal(t, "message_clear", clearMsg.Event)
+}
+
+func TestServer_UpdateMessage(t *testing.T) {
+	t.Parallel()
+	s := newTestServer(t, newTestConfig(t))
+
+	// Publish original message
+	response := request(t, s, "PUT", "/mytopic/update-seq", "original message", nil)
+	require.Equal(t, 200, response.Code)
+	msg1 := toMessage(t, response.Body.String())
+	require.Equal(t, "update-seq", msg1.SequenceID)
+	require.Equal(t, "original message", msg1.Message)
+
+	// Update the message (same sequence ID, new content)
+	response = request(t, s, "PUT", "/mytopic/update-seq", "updated message", nil)
+	require.Equal(t, 200, response.Code)
+	msg2 := toMessage(t, response.Body.String())
+	require.Equal(t, "update-seq", msg2.SequenceID)
+	require.Equal(t, "updated message", msg2.Message)
+	require.NotEqual(t, msg1.ID, msg2.ID) // Different message IDs
+
+	// Poll and verify both versions are returned
+	response = request(t, s, "GET", "/mytopic/json?poll=1", "", nil)
+	require.Equal(t, 200, response.Code)
+	lines := strings.Split(strings.TrimSpace(response.Body.String()), "\n")
+	require.Equal(t, 2, len(lines))
+
+	polledMsg1 := toMessage(t, lines[0])
+	polledMsg2 := toMessage(t, lines[1])
+	require.Equal(t, "original message", polledMsg1.Message)
+	require.Equal(t, "updated message", polledMsg2.Message)
+	require.Equal(t, "update-seq", polledMsg1.SequenceID)
+	require.Equal(t, "update-seq", polledMsg2.SequenceID)
+}
+
+func TestServer_UpdateMessage_UsingMessageID(t *testing.T) {
+	t.Parallel()
+	s := newTestServer(t, newTestConfig(t))
+
+	// Publish original message without a sequence ID
+	response := request(t, s, "PUT", "/mytopic", "original message", nil)
+	require.Equal(t, 200, response.Code)
+	msg1 := toMessage(t, response.Body.String())
+	require.NotEmpty(t, msg1.ID)
+	require.Empty(t, msg1.SequenceID) // No sequence ID provided
+	require.Equal(t, "original message", msg1.Message)
+
+	// Update the message using the message ID as the sequence ID
+	response = request(t, s, "PUT", "/mytopic/"+msg1.ID, "updated message", nil)
+	require.Equal(t, 200, response.Code)
+	msg2 := toMessage(t, response.Body.String())
+	require.Equal(t, msg1.ID, msg2.SequenceID) // Message ID is now used as sequence ID
+	require.Equal(t, "updated message", msg2.Message)
+	require.NotEqual(t, msg1.ID, msg2.ID) // Different message IDs
+
+	// Poll and verify both versions are returned
+	response = request(t, s, "GET", "/mytopic/json?poll=1", "", nil)
+	require.Equal(t, 200, response.Code)
+	lines := strings.Split(strings.TrimSpace(response.Body.String()), "\n")
+	require.Equal(t, 2, len(lines))
+
+	polledMsg1 := toMessage(t, lines[0])
+	polledMsg2 := toMessage(t, lines[1])
+	require.Equal(t, "original message", polledMsg1.Message)
+	require.Equal(t, "updated message", polledMsg2.Message)
+	require.Empty(t, polledMsg1.SequenceID)          // Original has no sequence ID
+	require.Equal(t, msg1.ID, polledMsg2.SequenceID) // Update uses original message ID as sequence ID
+}
+
+func TestServer_DeleteAndClear_InvalidSequenceID(t *testing.T) {
+	t.Parallel()
+	s := newTestServer(t, newTestConfig(t))
+
+	// Test invalid sequence ID for delete (returns 404 because route doesn't match)
+	response := request(t, s, "DELETE", "/mytopic/invalid*seq", "", nil)
+	require.Equal(t, 404, response.Code)
+
+	// Test invalid sequence ID for clear (returns 404 because route doesn't match)
+	response = request(t, s, "PUT", "/mytopic/invalid*seq/clear", "", nil)
+	require.Equal(t, 404, response.Code)
+}
+
+func TestServer_DeleteMessage_WithFirebase(t *testing.T) {
+	sender := newTestFirebaseSender(10)
+	s := newTestServer(t, newTestConfig(t))
+	s.firebaseClient = newFirebaseClient(sender, &testAuther{Allow: true})
+
+	// Publish a message
+	response := request(t, s, "PUT", "/mytopic/firebase-seq", "test message", nil)
+	require.Equal(t, 200, response.Code)
+
+	time.Sleep(100 * time.Millisecond) // Firebase publishing happens
+	require.Equal(t, 1, len(sender.Messages()))
+	require.Equal(t, "message", sender.Messages()[0].Data["event"])
+
+	// Delete the message
+	response = request(t, s, "DELETE", "/mytopic/firebase-seq", "", nil)
+	require.Equal(t, 200, response.Code)
+
+	time.Sleep(100 * time.Millisecond) // Firebase publishing happens
+	require.Equal(t, 2, len(sender.Messages()))
+	require.Equal(t, "message_delete", sender.Messages()[1].Data["event"])
+	require.Equal(t, "firebase-seq", sender.Messages()[1].Data["sequence_id"])
+}
+
+func TestServer_ClearMessage_WithFirebase(t *testing.T) {
+	sender := newTestFirebaseSender(10)
+	s := newTestServer(t, newTestConfig(t))
+	s.firebaseClient = newFirebaseClient(sender, &testAuther{Allow: true})
+
+	// Publish a message
+	response := request(t, s, "PUT", "/mytopic/firebase-clear-seq", "test message", nil)
+	require.Equal(t, 200, response.Code)
+
+	time.Sleep(100 * time.Millisecond)
+	require.Equal(t, 1, len(sender.Messages()))
+
+	// Clear the message
+	response = request(t, s, "PUT", "/mytopic/firebase-clear-seq/clear", "", nil)
+	require.Equal(t, 200, response.Code)
+
+	time.Sleep(100 * time.Millisecond)
+	require.Equal(t, 2, len(sender.Messages()))
+	require.Equal(t, "message_clear", sender.Messages()[1].Data["event"])
+	require.Equal(t, "firebase-clear-seq", sender.Messages()[1].Data["sequence_id"])
+}
+
 func newTestConfig(t *testing.T) *Config {
 func newTestConfig(t *testing.T) *Config {
 	conf := NewConfig()
 	conf := NewConfig()
 	conf.BaseURL = "http://127.0.0.1:12345"
 	conf.BaseURL = "http://127.0.0.1:12345"

+ 4 - 3
web/src/app/SubscriptionManager.js

@@ -202,9 +202,10 @@ class SubscriptionManager {
 
 
   /** Adds/replaces notifications, will not throw if they exist */
   /** Adds/replaces notifications, will not throw if they exist */
   async addNotifications(subscriptionId, notifications) {
   async addNotifications(subscriptionId, notifications) {
-    const notificationsWithSubscriptionId = notifications.map((notification) => {
-      return { ...messageWithSequenceId(notification), subscriptionId };
-    });
+    const notificationsWithSubscriptionId = notifications.map((notification) => ({
+      ...messageWithSequenceId(notification),
+      subscriptionId,
+    }));
     const lastNotificationId = notifications.at(-1).id;
     const lastNotificationId = notifications.at(-1).id;
     await this.db.notifications.bulkPut(notificationsWithSubscriptionId);
     await this.db.notifications.bulkPut(notificationsWithSubscriptionId);
     await this.db.subscriptions.update(subscriptionId, {
     await this.db.subscriptions.update(subscriptionId, {

+ 1 - 1
web/src/app/notificationUtils.js

@@ -50,7 +50,7 @@ export const isImage = (attachment) => {
 export const icon = "/static/images/ntfy.png";
 export const icon = "/static/images/ntfy.png";
 export const badge = "/static/images/mask-icon.svg";
 export const badge = "/static/images/mask-icon.svg";
 
 
-export const toNotificationParams = ({ subscriptionId, message, defaultTitle, topicRoute }) => {
+export const toNotificationParams = ({ message, defaultTitle, topicRoute }) => {
   const image = isImage(message.attachment) ? message.attachment.url : undefined;
   const image = isImage(message.attachment) ? message.attachment.url : undefined;
 
 
   // https://developer.mozilla.org/en-US/docs/Web/API/Notifications_API
   // https://developer.mozilla.org/en-US/docs/Web/API/Notifications_API

+ 3 - 3
web/src/app/utils.js

@@ -104,10 +104,10 @@ export const maybeActionErrors = (notification) => {
 };
 };
 
 
 export const messageWithSequenceId = (message) => {
 export const messageWithSequenceId = (message) => {
-  if (!message.sequenceId) {
-    message.sequenceId = message.sequence_id || message.id;
+  if (message.sequenceId) {
+    return message;
   }
   }
-  return message;
+  return { ...message, sequenceId: message.sequence_id || message.id };
 };
 };
 
 
 export const shuffle = (arr) => {
 export const shuffle = (arr) => {