Pārlūkot izejas kodu

Fix `since=<id>` implementation for multiple topics, closes #336

Philipp Heckel 3 gadi atpakaļ
vecāks
revīzija
d05211648d
4 mainītis faili ar 63 papildinājumiem un 7 dzēšanām
  1. 1 0
      docs/releases.md
  2. 2 2
      server/message_cache.go
  3. 13 5
      server/server.go
  4. 47 0
      server/server_test.go

+ 1 - 0
docs/releases.md

@@ -14,6 +14,7 @@ and the [ntfy Android app](https://github.com/binwiederhier/ntfy-android/release
 
 
 * Return HTTP 500 for GET /_matrix/push/v1/notify when base-url is not configured (no ticket)
 * Return HTTP 500 for GET /_matrix/push/v1/notify when base-url is not configured (no ticket)
 * Disallow setting `upstream-base-url` to the same value as `base-url` ([#334](https://github.com/binwiederhier/ntfy/issues/334), thanks to [@oester](https://github.com/oester) for reporting)
 * Disallow setting `upstream-base-url` to the same value as `base-url` ([#334](https://github.com/binwiederhier/ntfy/issues/334), thanks to [@oester](https://github.com/oester) for reporting)
+* Fix `since=<id>` implementation for multiple topics ([#336](https://github.com/binwiederhier/ntfy/issues/336), thanks to [@karmanyaahm](https://github.com/karmanyaahm) for reporting)
 
 
 ## ntfy Android app v1.14.0 (UNRELEASED)
 ## ntfy Android app v1.14.0 (UNRELEASED)
 
 

+ 2 - 2
server/message_cache.go

@@ -49,7 +49,7 @@ const (
 		VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
 		VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
 	`
 	`
 	pruneMessagesQuery           = `DELETE FROM messages WHERE time < ? AND published = 1`
 	pruneMessagesQuery           = `DELETE FROM messages WHERE time < ? AND published = 1`
-	selectRowIDFromMessageID     = `SELECT id FROM messages WHERE topic = ? AND mid = ?`
+	selectRowIDFromMessageID     = `SELECT id FROM messages WHERE mid = ?` // Do not include topic, see #336 and TestServer_PollSinceID_MultipleTopics
 	selectMessagesSinceTimeQuery = `
 	selectMessagesSinceTimeQuery = `
 		SELECT mid, time, topic, message, title, priority, tags, click, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, sender, encoding
 		SELECT mid, time, topic, message, title, priority, tags, click, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, sender, encoding
 		FROM messages 
 		FROM messages 
@@ -294,7 +294,7 @@ func (c *messageCache) messagesSinceTime(topic string, since sinceMarker, schedu
 }
 }
 
 
 func (c *messageCache) messagesSinceID(topic string, since sinceMarker, scheduled bool) ([]*message, error) {
 func (c *messageCache) messagesSinceID(topic string, since sinceMarker, scheduled bool) ([]*message, error) {
-	idrows, err := c.db.Query(selectRowIDFromMessageID, topic, since.ID())
+	idrows, err := c.db.Query(selectRowIDFromMessageID, since.ID())
 	if err != nil {
 	if err != nil {
 		return nil, err
 		return nil, err
 	}
 	}

+ 13 - 5
server/server.go

@@ -16,6 +16,7 @@ import (
 	"path"
 	"path"
 	"path/filepath"
 	"path/filepath"
 	"regexp"
 	"regexp"
+	"sort"
 	"strconv"
 	"strconv"
 	"strings"
 	"strings"
 	"sync"
 	"sync"
@@ -972,19 +973,26 @@ func parseSubscribeParams(r *http.Request) (poll bool, since sinceMarker, schedu
 	return
 	return
 }
 }
 
 
+// sendOldMessages selects old messages from the messageCache and calls sub for each of them. It uses since as the
+// marker, returning only messages that are newer than the marker.
 func (s *Server) sendOldMessages(topics []*topic, since sinceMarker, scheduled bool, v *visitor, sub subscriber) error {
 func (s *Server) sendOldMessages(topics []*topic, since sinceMarker, scheduled bool, v *visitor, sub subscriber) error {
 	if since.IsNone() {
 	if since.IsNone() {
 		return nil
 		return nil
 	}
 	}
+	messages := make([]*message, 0)
 	for _, t := range topics {
 	for _, t := range topics {
-		messages, err := s.messageCache.Messages(t.ID, since, scheduled)
+		topicMessages, err := s.messageCache.Messages(t.ID, since, scheduled)
 		if err != nil {
 		if err != nil {
 			return err
 			return err
 		}
 		}
-		for _, m := range messages {
-			if err := sub(v, m); err != nil {
-				return err
-			}
+		messages = append(messages, topicMessages...)
+	}
+	sort.Slice(messages, func(i, j int) bool {
+		return messages[i].Time < messages[j].Time
+	})
+	for _, m := range messages {
+		if err := sub(v, m); err != nil {
+			return err
 		}
 		}
 	}
 	}
 	return nil
 	return nil

+ 47 - 0
server/server_test.go

@@ -437,6 +437,53 @@ func TestServer_PublishAndPollSince(t *testing.T) {
 	require.Equal(t, 40008, toHTTPError(t, response.Body.String()).Code)
 	require.Equal(t, 40008, toHTTPError(t, response.Body.String()).Code)
 }
 }
 
 
+func newMessageWithTimestamp(topic, message string, timestamp int64) *message {
+	m := newDefaultMessage(topic, message)
+	m.Time = timestamp
+	return m
+}
+
+func TestServer_PollSinceID_MultipleTopics(t *testing.T) {
+	s := newTestServer(t, newTestConfig(t))
+
+	require.Nil(t, s.messageCache.AddMessage(newMessageWithTimestamp("mytopic1", "test 1", 1655740277)))
+	markerMessage := newMessageWithTimestamp("mytopic2", "test 2", 1655740283)
+	require.Nil(t, s.messageCache.AddMessage(markerMessage))
+	require.Nil(t, s.messageCache.AddMessage(newMessageWithTimestamp("mytopic1", "test 3", 1655740289)))
+	require.Nil(t, s.messageCache.AddMessage(newMessageWithTimestamp("mytopic2", "test 4", 1655740293)))
+	require.Nil(t, s.messageCache.AddMessage(newMessageWithTimestamp("mytopic1", "test 5", 1655740297)))
+	require.Nil(t, s.messageCache.AddMessage(newMessageWithTimestamp("mytopic2", "test 6", 1655740303)))
+
+	response := request(t, s, "GET", fmt.Sprintf("/mytopic1,mytopic2/json?poll=1&since=%s", markerMessage.ID), "", nil)
+	messages := toMessages(t, response.Body.String())
+	require.Equal(t, 4, len(messages))
+	require.Equal(t, "test 3", messages[0].Message)
+	require.Equal(t, "mytopic1", messages[0].Topic)
+	require.Equal(t, "test 4", messages[1].Message)
+	require.Equal(t, "mytopic2", messages[1].Topic)
+	require.Equal(t, "test 5", messages[2].Message)
+	require.Equal(t, "mytopic1", messages[2].Topic)
+	require.Equal(t, "test 6", messages[3].Message)
+	require.Equal(t, "mytopic2", messages[3].Topic)
+}
+
+func TestServer_PollSinceID_MultipleTopics_IDDoesNotMatch(t *testing.T) {
+	s := newTestServer(t, newTestConfig(t))
+
+	require.Nil(t, s.messageCache.AddMessage(newMessageWithTimestamp("mytopic1", "test 3", 1655740289)))
+	require.Nil(t, s.messageCache.AddMessage(newMessageWithTimestamp("mytopic2", "test 4", 1655740293)))
+	require.Nil(t, s.messageCache.AddMessage(newMessageWithTimestamp("mytopic1", "test 5", 1655740297)))
+	require.Nil(t, s.messageCache.AddMessage(newMessageWithTimestamp("mytopic2", "test 6", 1655740303)))
+
+	response := request(t, s, "GET", "/mytopic1,mytopic2/json?poll=1&since=NoMatchForID", "", nil)
+	messages := toMessages(t, response.Body.String())
+	require.Equal(t, 4, len(messages))
+	require.Equal(t, "test 3", messages[0].Message)
+	require.Equal(t, "test 4", messages[1].Message)
+	require.Equal(t, "test 5", messages[2].Message)
+	require.Equal(t, "test 6", messages[3].Message)
+}
+
 func TestServer_PublishViaGET(t *testing.T) {
 func TestServer_PublishViaGET(t *testing.T) {
 	s := newTestServer(t, newTestConfig(t))
 	s := newTestServer(t, newTestConfig(t))