Bläddra i källkod

Combine things, move stuff

Philipp Heckel 4 år sedan
förälder
incheckning
2b6363474e
5 ändrade filer med 231 tillägg och 187 borttagningar
  1. 40 129
      server/server.go
  2. 0 58
      server/server_test.go
  3. 70 0
      server/types.go
  4. 55 0
      server/util.go
  5. 66 0
      server/util_test.go

+ 40 - 129
server/server.go

@@ -32,9 +32,6 @@ import (
 	"unicode/utf8"
 )
 
-// TODO add "max messages in a topic" limit
-// TODO implement "since=<ID>"
-
 // Server is the main server, providing the UI and API for ntfy
 type Server struct {
 	config       *Config
@@ -59,25 +56,6 @@ type indexPage struct {
 	CacheDuration time.Duration
 }
 
-type sinceTime time.Time
-
-func (t sinceTime) IsAll() bool {
-	return t == sinceAllMessages
-}
-
-func (t sinceTime) IsNone() bool {
-	return t == sinceNoMessages
-}
-
-func (t sinceTime) Time() time.Time {
-	return time.Time(t)
-}
-
-var (
-	sinceAllMessages = sinceTime(time.Unix(0, 0))
-	sinceNoMessages  = sinceTime(time.Unix(1, 0))
-)
-
 var (
 	topicRegex       = regexp.MustCompile(`^[-_A-Za-z0-9]{1,64}$`)  // No /!
 	topicPathRegex   = regexp.MustCompile(`^/[-_A-Za-z0-9]{1,64}$`) // Regex must match JS & Android app!
@@ -117,7 +95,6 @@ const (
 	firebaseControlTopic     = "~control"                // See Android if changed
 	emptyMessageBody         = "triggered"               // Used if message body is empty
 	defaultAttachmentMessage = "You received a file: %s" // Used if message body is empty, and there is an attachment
-	fcmMessageLimit          = 4000                      // see maybeTruncateFCMMessage for details
 )
 
 // WebSocket constants
@@ -232,25 +209,6 @@ func createFirebaseSubscriber(conf *Config) (subscriber, error) {
 	}, nil
 }
 
-// maybeTruncateFCMMessage performs best-effort truncation of FCM messages.
-// The docs say the limit is 4000 characters, but during testing it wasn't quite clear
-// what fields matter; so we're just capping the serialized JSON to 4000 bytes.
-func maybeTruncateFCMMessage(m *messaging.Message) *messaging.Message {
-	s, err := json.Marshal(m)
-	if err != nil {
-		return m
-	}
-	if len(s) > fcmMessageLimit {
-		over := len(s) - fcmMessageLimit + 16 // = len("truncated":"1",), sigh ...
-		message, ok := m.Data["message"]
-		if ok && len(message) > over {
-			m.Data["truncated"] = "1"
-			m.Data["message"] = message[:len(message)-over]
-		}
-	}
-	return m
-}
-
 // Run executes the main server. It listens on HTTP (+ HTTPS, if configured), and starts
 // a manager go routine to print stats and prune messages.
 func (s *Server) Run() error {
@@ -391,7 +349,7 @@ func (s *Server) handleHome(w http.ResponseWriter, r *http.Request) error {
 }
 
 func (s *Server) handleTopic(w http.ResponseWriter, r *http.Request) error {
-	unifiedpush := readParam(r, "x-unifiedpush", "unifiedpush", "up") == "1" // see PUT/POST too!
+	unifiedpush := readBoolParam(r, false, "x-unifiedpush", "unifiedpush", "up") // see PUT/POST too!
 	if unifiedpush {
 		w.Header().Set("Content-Type", "application/json")
 		w.Header().Set("Access-Control-Allow-Origin", "*") // CORS, allow cross-origin requests
@@ -497,13 +455,15 @@ func (s *Server) handlePublish(w http.ResponseWriter, r *http.Request, v *visito
 	if err := json.NewEncoder(w).Encode(m); err != nil {
 		return err
 	}
-	s.inc(&s.messages)
+	s.mu.Lock()
+	s.messages++
+	s.mu.Unlock()
 	return nil
 }
 
 func (s *Server) parsePublishParams(r *http.Request, v *visitor, m *message) (cache bool, firebase bool, email string, err error) {
-	cache = readParam(r, "x-cache", "cache") != "no"
-	firebase = readParam(r, "x-firebase", "firebase") != "no"
+	cache = readBoolParam(r, true, "x-cache", "cache")
+	firebase = readBoolParam(r, true, "x-firebase", "firebase")
 	m.Title = readParam(r, "x-title", "title", "t")
 	m.Click = readParam(r, "x-click", "click")
 	filename := readParam(r, "x-filename", "filename", "file", "f")
@@ -574,29 +534,13 @@ func (s *Server) parsePublishParams(r *http.Request, v *visitor, m *message) (ca
 		}
 		m.Time = delay.Unix()
 	}
-	unifiedpush := readParam(r, "x-unifiedpush", "unifiedpush", "up") == "1" // see GET too!
+	unifiedpush := readBoolParam(r, false, "x-unifiedpush", "unifiedpush", "up") // see GET too!
 	if unifiedpush {
 		firebase = false
 	}
 	return cache, firebase, email, nil
 }
 
-func readParam(r *http.Request, names ...string) string {
-	for _, name := range names {
-		value := r.Header.Get(name)
-		if value != "" {
-			return strings.TrimSpace(value)
-		}
-	}
-	for _, name := range names {
-		value := r.URL.Query().Get(strings.ToLower(name))
-		if value != "" {
-			return strings.TrimSpace(value)
-		}
-	}
-	return ""
-}
-
 // handlePublishBody consumes the PUT/POST body and decides whether the body is an attachment or the message.
 //
 // 1. curl -H "Attach: http://example.com/file.jpg" ntfy.sh/mytopic
@@ -680,7 +624,7 @@ func (s *Server) handleSubscribeJSON(w http.ResponseWriter, r *http.Request, v *
 		}
 		return buf.String(), nil
 	}
-	return s.handleSubscribe(w, r, v, "json", "application/x-ndjson", encoder)
+	return s.handleSubscribeHTTP(w, r, v, "application/x-ndjson", encoder)
 }
 
 func (s *Server) handleSubscribeSSE(w http.ResponseWriter, r *http.Request, v *visitor) error {
@@ -694,7 +638,7 @@ func (s *Server) handleSubscribeSSE(w http.ResponseWriter, r *http.Request, v *v
 		}
 		return fmt.Sprintf("data: %s\n", buf.String()), nil
 	}
-	return s.handleSubscribe(w, r, v, "sse", "text/event-stream", encoder)
+	return s.handleSubscribeHTTP(w, r, v, "text/event-stream", encoder)
 }
 
 func (s *Server) handleSubscribeRaw(w http.ResponseWriter, r *http.Request, v *visitor) error {
@@ -704,33 +648,25 @@ func (s *Server) handleSubscribeRaw(w http.ResponseWriter, r *http.Request, v *v
 		}
 		return "\n", nil // "keepalive" and "open" events just send an empty line
 	}
-	return s.handleSubscribe(w, r, v, "raw", "text/plain", encoder)
+	return s.handleSubscribeHTTP(w, r, v, "text/plain", encoder)
 }
 
-func (s *Server) handleSubscribe(w http.ResponseWriter, r *http.Request, v *visitor, format string, contentType string, encoder messageEncoder) error {
+func (s *Server) handleSubscribeHTTP(w http.ResponseWriter, r *http.Request, v *visitor, contentType string, encoder messageEncoder) error {
 	if err := v.SubscriptionAllowed(); err != nil {
 		return errHTTPTooManyRequestsLimitSubscriptions
 	}
 	defer v.RemoveSubscription()
-	topicsStr := strings.TrimSuffix(r.URL.Path[1:], "/"+format) // Hack
-	topicIDs := util.SplitNoEmpty(topicsStr, ",")
-	topics, err := s.topicsFromIDs(topicIDs...)
+	topics, topicsStr, err := s.topicsFromPath(r.URL.Path)
 	if err != nil {
 		return err
 	}
-	poll := readParam(r, "x-poll", "poll", "po") == "1"
-	scheduled := readParam(r, "x-scheduled", "scheduled", "sched") == "1"
-	since, err := parseSince(r, poll)
-	if err != nil {
-		return err
-	}
-	messageFilter, titleFilter, priorityFilter, tagsFilter, err := parseQueryFilters(r)
+	poll, since, scheduled, filters, err := parseSubscribeParams(r)
 	if err != nil {
 		return err
 	}
 	var wlock sync.Mutex
 	sub := func(msg *message) error {
-		if !passesQueryFilter(msg, messageFilter, titleFilter, priorityFilter, tagsFilter) {
+		if !filters.Pass(msg) {
 			return nil
 		}
 		m, err := encoder(msg)
@@ -785,19 +721,11 @@ func (s *Server) handleSubscribeWS(w http.ResponseWriter, r *http.Request, v *vi
 		return errHTTPTooManyRequestsLimitSubscriptions
 	}
 	defer v.RemoveSubscription()
-	topicsStr := strings.TrimSuffix(r.URL.Path[1:], "/ws") // Hack
-	topicIDs := util.SplitNoEmpty(topicsStr, ",")
-	topics, err := s.topicsFromIDs(topicIDs...)
+	topics, topicsStr, err := s.topicsFromPath(r.URL.Path)
 	if err != nil {
 		return err
 	}
-	poll := readParam(r, "x-poll", "poll", "po") == "1"
-	scheduled := readParam(r, "x-scheduled", "scheduled", "sched") == "1"
-	since, err := parseSince(r, poll)
-	if err != nil {
-		return err
-	}
-	messageFilter, titleFilter, priorityFilter, tagsFilter, err := parseQueryFilters(r)
+	poll, since, scheduled, filters, err := parseSubscribeParams(r)
 	if err != nil {
 		return err
 	}
@@ -850,7 +778,7 @@ func (s *Server) handleSubscribeWS(w http.ResponseWriter, r *http.Request, v *vi
 		}
 	})
 	sub := func(msg *message) error {
-		if !passesQueryFilter(msg, messageFilter, titleFilter, priorityFilter, tagsFilter) {
+		if !filters.Pass(msg) {
 			return nil
 		}
 		if err := conn.SetWriteDeadline(time.Now().Add(wsWriteWait)); err != nil {
@@ -884,42 +812,18 @@ func (s *Server) handleSubscribeWS(w http.ResponseWriter, r *http.Request, v *vi
 	return err
 }
 
-func parseQueryFilters(r *http.Request) (messageFilter string, titleFilter string, priorityFilter []int, tagsFilter []string, err error) {
-	messageFilter = readParam(r, "x-message", "message", "m")
-	titleFilter = readParam(r, "x-title", "title", "t")
-	tagsFilter = util.SplitNoEmpty(readParam(r, "x-tags", "tags", "tag", "ta"), ",")
-	priorityFilter = make([]int, 0)
-	for _, p := range util.SplitNoEmpty(readParam(r, "x-priority", "priority", "prio", "p"), ",") {
-		priority, err := util.ParsePriority(p)
-		if err != nil {
-			return "", "", nil, nil, err
-		}
-		priorityFilter = append(priorityFilter, priority)
-	}
-	return
-}
-
-func passesQueryFilter(msg *message, messageFilter string, titleFilter string, priorityFilter []int, tagsFilter []string) bool {
-	if msg.Event != messageEvent {
-		return true // filters only apply to messages
-	}
-	if messageFilter != "" && msg.Message != messageFilter {
-		return false
-	}
-	if titleFilter != "" && msg.Title != titleFilter {
-		return false
-	}
-	messagePriority := msg.Priority
-	if messagePriority == 0 {
-		messagePriority = 3 // For query filters, default priority (3) is the same as "not set" (0)
-	}
-	if len(priorityFilter) > 0 && !util.InIntList(priorityFilter, messagePriority) {
-		return false
+func parseSubscribeParams(r *http.Request) (poll bool, since sinceTime, scheduled bool, filters *queryFilter, err error) {
+	poll = readBoolParam(r, false, "x-poll", "poll", "po")
+	scheduled = readBoolParam(r, false, "x-scheduled", "scheduled", "sched")
+	since, err = parseSince(r, poll)
+	if err != nil {
+		return
 	}
-	if len(tagsFilter) > 0 && !util.InStringListAll(msg.Tags, tagsFilter) {
-		return false
+	filters, err = parseQueryFilters(r)
+	if err != nil {
+		return
 	}
-	return true
+	return
 }
 
 func (s *Server) sendOldMessages(topics []*topic, since sinceTime, scheduled bool, sub subscriber) error {
@@ -980,6 +884,19 @@ func (s *Server) topicFromPath(path string) (*topic, error) {
 	return topics[0], nil
 }
 
+func (s *Server) topicsFromPath(path string) ([]*topic, string, error) {
+	parts := strings.Split(path, "/")
+	if len(parts) < 2 {
+		return nil, "", errHTTPBadRequestTopicInvalid
+	}
+	topicIDs := util.SplitNoEmpty(parts[1], ",")
+	topics, err := s.topicsFromIDs(topicIDs...)
+	if err != nil {
+		return nil, "", errHTTPBadRequestTopicInvalid
+	}
+	return topics, parts[1], nil
+}
+
 func (s *Server) topicsFromIDs(ids ...string) ([]*topic, error) {
 	s.mu.Lock()
 	defer s.mu.Unlock()
@@ -1180,9 +1097,3 @@ func (s *Server) visitor(r *http.Request) *visitor {
 	v.Keepalive()
 	return v
 }
-
-func (s *Server) inc(counter *int64) {
-	s.mu.Lock()
-	defer s.mu.Unlock()
-	*counter++
-}

+ 0 - 58
server/server_test.go

@@ -4,7 +4,6 @@ import (
 	"bufio"
 	"context"
 	"encoding/json"
-	"firebase.google.com/go/messaging"
 	"fmt"
 	"github.com/stretchr/testify/require"
 	"heckel.io/ntfy/util"
@@ -624,63 +623,6 @@ func TestServer_UnifiedPushDiscovery(t *testing.T) {
 	require.Equal(t, `{"unifiedpush":{"version":1}}`+"\n", response.Body.String())
 }
 
-func TestServer_MaybeTruncateFCMMessage(t *testing.T) {
-	origMessage := strings.Repeat("this is a long string", 300)
-	origFCMMessage := &messaging.Message{
-		Topic: "mytopic",
-		Data: map[string]string{
-			"id":       "abcdefg",
-			"time":     "1641324761",
-			"event":    "message",
-			"topic":    "mytopic",
-			"priority": "0",
-			"tags":     "",
-			"title":    "",
-			"message":  origMessage,
-		},
-		Android: &messaging.AndroidConfig{
-			Priority: "high",
-		},
-	}
-	origMessageLength := len(origFCMMessage.Data["message"])
-	serializedOrigFCMMessage, _ := json.Marshal(origFCMMessage)
-	require.Greater(t, len(serializedOrigFCMMessage), fcmMessageLimit) // Pre-condition
-
-	truncatedFCMMessage := maybeTruncateFCMMessage(origFCMMessage)
-	truncatedMessageLength := len(truncatedFCMMessage.Data["message"])
-	serializedTruncatedFCMMessage, _ := json.Marshal(truncatedFCMMessage)
-	require.Equal(t, fcmMessageLimit, len(serializedTruncatedFCMMessage))
-	require.Equal(t, "1", truncatedFCMMessage.Data["truncated"])
-	require.NotEqual(t, origMessageLength, truncatedMessageLength)
-}
-
-func TestServer_MaybeTruncateFCMMessage_NotTooLong(t *testing.T) {
-	origMessage := "not really a long string"
-	origFCMMessage := &messaging.Message{
-		Topic: "mytopic",
-		Data: map[string]string{
-			"id":       "abcdefg",
-			"time":     "1641324761",
-			"event":    "message",
-			"topic":    "mytopic",
-			"priority": "0",
-			"tags":     "",
-			"title":    "",
-			"message":  origMessage,
-		},
-	}
-	origMessageLength := len(origFCMMessage.Data["message"])
-	serializedOrigFCMMessage, _ := json.Marshal(origFCMMessage)
-	require.LessOrEqual(t, len(serializedOrigFCMMessage), fcmMessageLimit) // Pre-condition
-
-	notTruncatedFCMMessage := maybeTruncateFCMMessage(origFCMMessage)
-	notTruncatedMessageLength := len(notTruncatedFCMMessage.Data["message"])
-	serializedNotTruncatedFCMMessage, _ := json.Marshal(notTruncatedFCMMessage)
-	require.Equal(t, origMessageLength, notTruncatedMessageLength)
-	require.Equal(t, len(serializedOrigFCMMessage), len(serializedNotTruncatedFCMMessage))
-	require.Equal(t, "", notTruncatedFCMMessage.Data["truncated"])
-}
-
 func TestServer_PublishAttachment(t *testing.T) {
 	content := util.RandomString(5000) // > 4096
 	s := newTestServer(t, newTestConfig(t))

+ 70 - 0
server/message.go → server/types.go

@@ -2,6 +2,7 @@ package server
 
 import (
 	"heckel.io/ntfy/util"
+	"net/http"
 	"time"
 )
 
@@ -70,3 +71,72 @@ func newKeepaliveMessage(topic string) *message {
 func newDefaultMessage(topic, msg string) *message {
 	return newMessage(messageEvent, topic, msg)
 }
+
+type sinceTime time.Time
+
+func (t sinceTime) IsAll() bool {
+	return t == sinceAllMessages
+}
+
+func (t sinceTime) IsNone() bool {
+	return t == sinceNoMessages
+}
+
+func (t sinceTime) Time() time.Time {
+	return time.Time(t)
+}
+
+var (
+	sinceAllMessages = sinceTime(time.Unix(0, 0))
+	sinceNoMessages  = sinceTime(time.Unix(1, 0))
+)
+
+type queryFilter struct {
+	Message  string
+	Title    string
+	Tags     []string
+	Priority []int
+}
+
+func parseQueryFilters(r *http.Request) (*queryFilter, error) {
+	messageFilter := readParam(r, "x-message", "message", "m")
+	titleFilter := readParam(r, "x-title", "title", "t")
+	tagsFilter := util.SplitNoEmpty(readParam(r, "x-tags", "tags", "tag", "ta"), ",")
+	priorityFilter := make([]int, 0)
+	for _, p := range util.SplitNoEmpty(readParam(r, "x-priority", "priority", "prio", "p"), ",") {
+		priority, err := util.ParsePriority(p)
+		if err != nil {
+			return nil, err
+		}
+		priorityFilter = append(priorityFilter, priority)
+	}
+	return &queryFilter{
+		Message:  messageFilter,
+		Title:    titleFilter,
+		Tags:     tagsFilter,
+		Priority: priorityFilter,
+	}, nil
+}
+
+func (q *queryFilter) Pass(msg *message) bool {
+	if msg.Event != messageEvent {
+		return true // filters only apply to messages
+	}
+	if q.Message != "" && msg.Message != q.Message {
+		return false
+	}
+	if q.Title != "" && msg.Title != q.Title {
+		return false
+	}
+	messagePriority := msg.Priority
+	if messagePriority == 0 {
+		messagePriority = 3 // For query filters, default priority (3) is the same as "not set" (0)
+	}
+	if len(q.Priority) > 0 && !util.InIntList(q.Priority, messagePriority) {
+		return false
+	}
+	if len(q.Tags) > 0 && !util.InStringListAll(msg.Tags, q.Tags) {
+		return false
+	}
+	return true
+}

+ 55 - 0
server/util.go

@@ -0,0 +1,55 @@
+package server
+
+import (
+	"encoding/json"
+	"firebase.google.com/go/messaging"
+	"net/http"
+	"strings"
+)
+
+const (
+	fcmMessageLimit = 4000
+)
+
+// maybeTruncateFCMMessage performs best-effort truncation of FCM messages.
+// The docs say the limit is 4000 characters, but during testing it wasn't quite clear
+// what fields matter; so we're just capping the serialized JSON to 4000 bytes.
+func maybeTruncateFCMMessage(m *messaging.Message) *messaging.Message {
+	s, err := json.Marshal(m)
+	if err != nil {
+		return m
+	}
+	if len(s) > fcmMessageLimit {
+		over := len(s) - fcmMessageLimit + 16 // = len("truncated":"1",), sigh ...
+		message, ok := m.Data["message"]
+		if ok && len(message) > over {
+			m.Data["truncated"] = "1"
+			m.Data["message"] = message[:len(message)-over]
+		}
+	}
+	return m
+}
+
+func readBoolParam(r *http.Request, defaultValue bool, names ...string) bool {
+	value := strings.ToLower(readParam(r, names...))
+	if value == "" {
+		return defaultValue
+	}
+	return value == "1" || value == "yes" || value == "true"
+}
+
+func readParam(r *http.Request, names ...string) string {
+	for _, name := range names {
+		value := r.Header.Get(name)
+		if value != "" {
+			return strings.TrimSpace(value)
+		}
+	}
+	for _, name := range names {
+		value := r.URL.Query().Get(strings.ToLower(name))
+		if value != "" {
+			return strings.TrimSpace(value)
+		}
+	}
+	return ""
+}

+ 66 - 0
server/util_test.go

@@ -0,0 +1,66 @@
+package server
+
+import (
+	"encoding/json"
+	"firebase.google.com/go/messaging"
+	"github.com/stretchr/testify/require"
+	"strings"
+	"testing"
+)
+
+func TestMaybeTruncateFCMMessage(t *testing.T) {
+	origMessage := strings.Repeat("this is a long string", 300)
+	origFCMMessage := &messaging.Message{
+		Topic: "mytopic",
+		Data: map[string]string{
+			"id":       "abcdefg",
+			"time":     "1641324761",
+			"event":    "message",
+			"topic":    "mytopic",
+			"priority": "0",
+			"tags":     "",
+			"title":    "",
+			"message":  origMessage,
+		},
+		Android: &messaging.AndroidConfig{
+			Priority: "high",
+		},
+	}
+	origMessageLength := len(origFCMMessage.Data["message"])
+	serializedOrigFCMMessage, _ := json.Marshal(origFCMMessage)
+	require.Greater(t, len(serializedOrigFCMMessage), fcmMessageLimit) // Pre-condition
+
+	truncatedFCMMessage := maybeTruncateFCMMessage(origFCMMessage)
+	truncatedMessageLength := len(truncatedFCMMessage.Data["message"])
+	serializedTruncatedFCMMessage, _ := json.Marshal(truncatedFCMMessage)
+	require.Equal(t, fcmMessageLimit, len(serializedTruncatedFCMMessage))
+	require.Equal(t, "1", truncatedFCMMessage.Data["truncated"])
+	require.NotEqual(t, origMessageLength, truncatedMessageLength)
+}
+
+func TestMaybeTruncateFCMMessage_NotTooLong(t *testing.T) {
+	origMessage := "not really a long string"
+	origFCMMessage := &messaging.Message{
+		Topic: "mytopic",
+		Data: map[string]string{
+			"id":       "abcdefg",
+			"time":     "1641324761",
+			"event":    "message",
+			"topic":    "mytopic",
+			"priority": "0",
+			"tags":     "",
+			"title":    "",
+			"message":  origMessage,
+		},
+	}
+	origMessageLength := len(origFCMMessage.Data["message"])
+	serializedOrigFCMMessage, _ := json.Marshal(origFCMMessage)
+	require.LessOrEqual(t, len(serializedOrigFCMMessage), fcmMessageLimit) // Pre-condition
+
+	notTruncatedFCMMessage := maybeTruncateFCMMessage(origFCMMessage)
+	notTruncatedMessageLength := len(notTruncatedFCMMessage.Data["message"])
+	serializedNotTruncatedFCMMessage, _ := json.Marshal(notTruncatedFCMMessage)
+	require.Equal(t, origMessageLength, notTruncatedMessageLength)
+	require.Equal(t, len(serializedOrigFCMMessage), len(serializedNotTruncatedFCMMessage))
+	require.Equal(t, "", notTruncatedFCMMessage.Data["truncated"])
+}