Sfoglia il codice sorgente

Simplify(?) templating cases

binwiederhier 1 anno fa
parent
commit
de65d07518

+ 2 - 0
cmd/access_test.go

@@ -10,6 +10,7 @@ import (
 )
 
 func TestCLI_Access_Show(t *testing.T) {
+	t.Parallel()
 	s, conf, port := newTestServerWithAuth(t)
 	defer test.StopServer(t, s, port)
 
@@ -19,6 +20,7 @@ func TestCLI_Access_Show(t *testing.T) {
 }
 
 func TestCLI_Access_Grant_And_Publish(t *testing.T) {
+	t.Parallel()
 	s, conf, port := newTestServerWithAuth(t)
 	defer test.StopServer(t, s, port)
 

+ 1 - 0
cmd/config_loader_test.go

@@ -8,6 +8,7 @@ import (
 )
 
 func TestNewYamlSourceFromFile(t *testing.T) {
+	t.Parallel()
 	filename := filepath.Join(t.TempDir(), "server.yml")
 	contents := `
 # Normal options

+ 3 - 0
cmd/publish_test.go

@@ -17,6 +17,7 @@ import (
 )
 
 func TestCLI_Publish_Subscribe_Poll_Real_Server(t *testing.T) {
+	t.Parallel()
 	testMessage := util.RandomString(10)
 	app, _, _, _ := newTestApp()
 	require.Nil(t, app.Run([]string{"ntfy", "publish", "ntfytest", "ntfy unit test " + testMessage}))
@@ -35,6 +36,7 @@ func TestCLI_Publish_Subscribe_Poll_Real_Server(t *testing.T) {
 }
 
 func TestCLI_Publish_Subscribe_Poll(t *testing.T) {
+	t.Parallel()
 	s, port := test.StartServer(t)
 	defer test.StopServer(t, s, port)
 	topic := fmt.Sprintf("http://127.0.0.1:%d/mytopic", port)
@@ -51,6 +53,7 @@ func TestCLI_Publish_Subscribe_Poll(t *testing.T) {
 }
 
 func TestCLI_Publish_All_The_Things(t *testing.T) {
+	t.Parallel()
 	s, port := test.StartServer(t)
 	defer test.StopServer(t, s, port)
 	topic := fmt.Sprintf("http://127.0.0.1:%d/mytopic", port)

+ 2 - 1
server/errors.go

@@ -117,7 +117,8 @@ var (
 	errHTTPBadRequestWebPushSubscriptionInvalid      = &errHTTP{40038, http.StatusBadRequest, "invalid request: web push payload malformed", "", nil}
 	errHTTPBadRequestWebPushEndpointUnknown          = &errHTTP{40039, http.StatusBadRequest, "invalid request: web push endpoint unknown", "", nil}
 	errHTTPBadRequestWebPushTopicCountTooHigh        = &errHTTP{40040, http.StatusBadRequest, "invalid request: too many web push topic subscriptions", "", nil}
-	errHTTPBadRequestTemplatedMessageTooLarge        = &errHTTP{40041, http.StatusBadRequest, "invalid request: message is too large after replacing template", "", nil}
+	errHTTPBadRequestTemplatedMessageTooLarge        = &errHTTP{40041, http.StatusBadRequest, "invalid request: message or title is too large after replacing template", "", nil}
+	errHTTPBadRequestTemplatedMessageNotJSON         = &errHTTP{40042, http.StatusBadRequest, "invalid request: message body must be JSON if templating is enabled", "", nil}
 	errHTTPNotFound                                  = &errHTTP{40401, http.StatusNotFound, "page not found", "", nil}
 	errHTTPUnauthorized                              = &errHTTP{40101, http.StatusUnauthorized, "unauthorized", "https://ntfy.sh/docs/publish/#authentication", nil}
 	errHTTPForbidden                                 = &errHTTP{40301, http.StatusForbidden, "forbidden", "https://ntfy.sh/docs/publish/#authentication", nil}

+ 43 - 41
server/server.go

@@ -111,6 +111,7 @@ var (
 	urlRegex                                             = regexp.MustCompile(`^https?://`)
 	phoneNumberRegex                                     = regexp.MustCompile(`^\+\d{1,100}$`)
 	templateVarRegex                                     = regexp.MustCompile(`\${([^}]+)}`)
+	templateVarFormat                                    = "${%s}"
 
 	//go:embed site
 	webFs       embed.FS
@@ -125,12 +126,12 @@ var (
 
 const (
 	firebaseControlTopic     = "~control"                // See Android if changed
-	firebasePollTopic        = "~poll"                   // See iOS if changed
+	firebasePollTopic        = "~poll"                   // See iOS if changed (DISABLED for now)
 	emptyMessageBody         = "triggered"               // Used if message body is empty
 	newMessageBody           = "New message"             // Used in poll requests as generic message
 	defaultAttachmentMessage = "You received a file: %s" // Used if message body is empty, and there is an attachment
 	encodingBase64           = "base64"                  // Used mainly for binary UnifiedPush messages
-	jsonBodyBytesLimit       = 16384                     // Max number of bytes for a JSON request body
+	httpBodyBytesLimit       = 32768                     // Max number of bytes for a request bodys (unless MessageLimit is higher)
 	unifiedPushTopicPrefix   = "up"                      // Temporarily, we rate limit all "up*" topics based on the subscriber
 	unifiedPushTopicLength   = 14                        // Length of UnifiedPush topics, including the "up" part
 	messagesHistoryMax       = 10                        // Number of message count values to keep in memory
@@ -675,7 +676,7 @@ func (s *Server) handleFile(w http.ResponseWriter, r *http.Request, v *visitor)
 	//   - avoid abuse (e.g. 1 uploader, 1k downloaders)
 	//   - and also uses the higher bandwidth limits of a paying user
 	m, err := s.messageCache.Message(messageID)
-	if err == errMessageNotFound {
+	if errors.Is(err, errMessageNotFound) {
 		if s.config.CacheBatchTimeout > 0 {
 			// Strange edge case: If we immediately after upload request the file (the web app does this for images),
 			// and messages are persisted asynchronously, retry fetching from the database
@@ -874,7 +875,7 @@ func (s *Server) sendToFirebase(v *visitor, m *message) {
 	logvm(v, m).Tag(tagFirebase).Debug("Publishing to Firebase")
 	if err := s.firebaseClient.Send(v, m); err != nil {
 		minc(metricFirebasePublishedFailure)
-		if err == errFirebaseTemporarilyBanned {
+		if errors.Is(err, errFirebaseTemporarilyBanned) {
 			logvm(v, m).Tag(tagFirebase).Err(err).Debug("Unable to publish to Firebase: %v", err.Error())
 		} else {
 			logvm(v, m).Tag(tagFirebase).Err(err).Warn("Unable to publish to Firebase: %v", err.Error())
@@ -1036,37 +1037,30 @@ func (s *Server) parsePublishParams(r *http.Request, m *message) (cache bool, fi
 //  1. curl -X POST -H "Poll: 1234" ntfy.sh/...
 //     If a message is flagged as poll request, the body does not matter and is discarded
 //  2. curl -T somebinarydata.bin "ntfy.sh/mytopic?up=1"
-//     If body is binary, encode as base64, if not do not encode
+//     If UnifiedPush is enabled, encode as base64 if body is binary, and do not trim
 //  3. curl -H "Attach: http://example.com/file.jpg" ntfy.sh/mytopic
 //     Body must be a message, because we attached an external URL
 //  4. curl -T short.txt -H "Filename: short.txt" ntfy.sh/mytopic
 //     Body must be attachment, because we passed a filename
-//  5. curl -T file.txt ntfy.sh/mytopic
+//  5. curl -H "Template: yes" -T file.txt ntfy.sh/mytopic
+//     If templating is enabled, read up to 32k and treat message body as JSON
+//  6. curl -T file.txt ntfy.sh/mytopic
 //     If file.txt is <= 4096 (message limit) and valid UTF-8, treat it as a message
-//  6. curl -H "Template: yes" -T file.txt ntfy.sh/mytopic
-//     If file.txt is < 4096*2 (message limit*2) and a template is used, try parsing under the assumption
-//     that the message generated by the template will be less than 4096
 //  7. curl -T file.txt ntfy.sh/mytopic
 //     If file.txt is > message limit or template && file.txt > message limit*2, treat it as an attachment
-func (s *Server) handlePublishBody(r *http.Request, v *visitor, m *message, body *util.PeekedReadCloser, template bool, unifiedpush bool) error {
+func (s *Server) handlePublishBody(r *http.Request, v *visitor, m *message, body *util.PeekedReadCloser, template, unifiedpush bool) error {
 	if m.Event == pollRequestEvent { // Case 1
 		return s.handleBodyDiscard(body)
 	} else if unifiedpush {
 		return s.handleBodyAsMessageAutoDetect(m, body) // Case 2
 	} else if m.Attachment != nil && m.Attachment.URL != "" {
-		return s.handleBodyAsTextMessage(m, body, template) // Case 3
+		return s.handleBodyAsTextMessage(m, body) // Case 3
 	} else if m.Attachment != nil && m.Attachment.Name != "" {
 		return s.handleBodyAsAttachment(r, v, m, body) // Case 4
-	} else if !body.LimitReached && utf8.Valid(body.PeekedBytes) {
-		return s.handleBodyAsTextMessage(m, body, template) // Case 5
 	} else if template {
-		templateBody, err := util.Peek(body, s.config.MessageSizeLimit*2)
-		if err != nil {
-			return err
-		}
-		if !templateBody.LimitReached {
-			return s.handleBodyAsTextMessage(m, templateBody, template) // Case 6
-		}
+		return s.handleBodyAsTemplatedTextMessage(m, body) // Case 5
+	} else if !body.LimitReached && utf8.Valid(body.PeekedBytes) {
+		return s.handleBodyAsTextMessage(m, body) // Case 6
 	}
 	return s.handleBodyAsAttachment(r, v, m, body) // Case 7
 }
@@ -1087,34 +1081,32 @@ func (s *Server) handleBodyAsMessageAutoDetect(m *message, body *util.PeekedRead
 	return nil
 }
 
-func replaceGJSONTemplate(template string, source string) string {
-	matches := templateVarRegex.FindAllStringSubmatch(template, -1)
-	for _, v := range matches {
-		query := v[1]
-		if result := gjson.Get(source, query); result.Exists() {
-			template = strings.ReplaceAll(template, fmt.Sprintf("${%s}", query), result.String())
-		}
-	}
-	return template
-}
-
-func (s *Server) handleBodyAsTextMessage(m *message, body *util.PeekedReadCloser, template bool) error {
+func (s *Server) handleBodyAsTextMessage(m *message, body *util.PeekedReadCloser) error {
 	if !utf8.Valid(body.PeekedBytes) {
 		return errHTTPBadRequestMessageNotUTF8.With(m)
 	}
 	if len(body.PeekedBytes) > 0 { // Empty body should not override message (publish via GET!)
-		peekedBody := strings.TrimSpace(string(body.PeekedBytes)) // Truncates the message to the peek limit if required
-		if template && gjson.Valid(peekedBody) {
-			m.Message = replaceGJSONTemplate(m.Message, peekedBody)
-			m.Title = replaceGJSONTemplate(m.Title, peekedBody)
-		} else {
-			m.Message = peekedBody
-		}
+		m.Message = strings.TrimSpace(string(body.PeekedBytes)) // Truncates the message to the peek limit if required
 	}
 	if m.Attachment != nil && m.Attachment.Name != "" && m.Message == "" {
 		m.Message = fmt.Sprintf(defaultAttachmentMessage, m.Attachment.Name)
 	}
-	// Ensure message is less than message limit after templating
+	return nil
+}
+
+func (s *Server) handleBodyAsTemplatedTextMessage(m *message, body *util.PeekedReadCloser) error {
+	body, err := util.Peek(body, httpBodyBytesLimit)
+	if err != nil {
+		return err
+	} else if body.LimitReached {
+		return errHTTPEntityTooLargeJSONBody
+	}
+	peekedBody := strings.TrimSpace(string(body.PeekedBytes))
+	if !gjson.Valid(peekedBody) {
+		return errHTTPBadRequestTemplatedMessageNotJSON
+	}
+	m.Message = replaceGJSONTemplate(m.Message, peekedBody)
+	m.Title = replaceGJSONTemplate(m.Title, peekedBody)
 	if len(m.Message) > s.config.MessageSizeLimit {
 		return errHTTPBadRequestTemplatedMessageTooLarge
 	}
@@ -1163,7 +1155,7 @@ func (s *Server) handleBodyAsAttachment(r *http.Request, v *visitor, m *message,
 		util.NewFixedLimiter(vinfo.Stats.AttachmentTotalSizeRemaining),
 	}
 	m.Attachment.Size, err = s.fileCache.Write(m.ID, body, limiters...)
-	if err == util.ErrLimitReached {
+	if errors.Is(err, util.ErrLimitReached) {
 		return errHTTPEntityTooLargeAttachment.With(m)
 	} else if err != nil {
 		return err
@@ -1171,6 +1163,16 @@ func (s *Server) handleBodyAsAttachment(r *http.Request, v *visitor, m *message,
 	return nil
 }
 
+func replaceGJSONTemplate(template string, source string) string {
+	matches := templateVarRegex.FindAllStringSubmatch(template, -1)
+	for _, m := range matches {
+		if result := gjson.Get(source, m[1]); result.Exists() {
+			template = strings.ReplaceAll(template, fmt.Sprintf(templateVarFormat, m[1]), result.String())
+		}
+	}
+	return template
+}
+
 func (s *Server) handleSubscribeJSON(w http.ResponseWriter, r *http.Request, v *visitor) error {
 	encoder := func(msg *message) (string, error) {
 		var buf bytes.Buffer

+ 12 - 12
server/server_account.go

@@ -28,7 +28,7 @@ func (s *Server) handleAccountCreate(w http.ResponseWriter, r *http.Request, v *
 			return errHTTPTooManyRequestsLimitAccountCreation
 		}
 	}
-	newAccount, err := readJSONWithLimit[apiAccountCreateRequest](r.Body, jsonBodyBytesLimit, false)
+	newAccount, err := readJSONWithLimit[apiAccountCreateRequest](r.Body, httpBodyBytesLimit, false)
 	if err != nil {
 		return err
 	}
@@ -160,7 +160,7 @@ func (s *Server) handleAccountGet(w http.ResponseWriter, r *http.Request, v *vis
 }
 
 func (s *Server) handleAccountDelete(w http.ResponseWriter, r *http.Request, v *visitor) error {
-	req, err := readJSONWithLimit[apiAccountDeleteRequest](r.Body, jsonBodyBytesLimit, false)
+	req, err := readJSONWithLimit[apiAccountDeleteRequest](r.Body, httpBodyBytesLimit, false)
 	if err != nil {
 		return err
 	} else if req.Password == "" {
@@ -192,7 +192,7 @@ func (s *Server) handleAccountDelete(w http.ResponseWriter, r *http.Request, v *
 }
 
 func (s *Server) handleAccountPasswordChange(w http.ResponseWriter, r *http.Request, v *visitor) error {
-	req, err := readJSONWithLimit[apiAccountPasswordChangeRequest](r.Body, jsonBodyBytesLimit, false)
+	req, err := readJSONWithLimit[apiAccountPasswordChangeRequest](r.Body, httpBodyBytesLimit, false)
 	if err != nil {
 		return err
 	} else if req.Password == "" || req.NewPassword == "" {
@@ -210,7 +210,7 @@ func (s *Server) handleAccountPasswordChange(w http.ResponseWriter, r *http.Requ
 }
 
 func (s *Server) handleAccountTokenCreate(w http.ResponseWriter, r *http.Request, v *visitor) error {
-	req, err := readJSONWithLimit[apiAccountTokenIssueRequest](r.Body, jsonBodyBytesLimit, true) // Allow empty body!
+	req, err := readJSONWithLimit[apiAccountTokenIssueRequest](r.Body, httpBodyBytesLimit, true) // Allow empty body!
 	if err != nil {
 		return err
 	}
@@ -246,7 +246,7 @@ func (s *Server) handleAccountTokenCreate(w http.ResponseWriter, r *http.Request
 
 func (s *Server) handleAccountTokenUpdate(w http.ResponseWriter, r *http.Request, v *visitor) error {
 	u := v.User()
-	req, err := readJSONWithLimit[apiAccountTokenUpdateRequest](r.Body, jsonBodyBytesLimit, true) // Allow empty body!
+	req, err := readJSONWithLimit[apiAccountTokenUpdateRequest](r.Body, httpBodyBytesLimit, true) // Allow empty body!
 	if err != nil {
 		return err
 	} else if req.Token == "" {
@@ -302,7 +302,7 @@ func (s *Server) handleAccountTokenDelete(w http.ResponseWriter, r *http.Request
 }
 
 func (s *Server) handleAccountSettingsChange(w http.ResponseWriter, r *http.Request, v *visitor) error {
-	newPrefs, err := readJSONWithLimit[user.Prefs](r.Body, jsonBodyBytesLimit, false)
+	newPrefs, err := readJSONWithLimit[user.Prefs](r.Body, httpBodyBytesLimit, false)
 	if err != nil {
 		return err
 	}
@@ -336,7 +336,7 @@ func (s *Server) handleAccountSettingsChange(w http.ResponseWriter, r *http.Requ
 }
 
 func (s *Server) handleAccountSubscriptionAdd(w http.ResponseWriter, r *http.Request, v *visitor) error {
-	newSubscription, err := readJSONWithLimit[user.Subscription](r.Body, jsonBodyBytesLimit, false)
+	newSubscription, err := readJSONWithLimit[user.Subscription](r.Body, httpBodyBytesLimit, false)
 	if err != nil {
 		return err
 	}
@@ -359,7 +359,7 @@ func (s *Server) handleAccountSubscriptionAdd(w http.ResponseWriter, r *http.Req
 }
 
 func (s *Server) handleAccountSubscriptionChange(w http.ResponseWriter, r *http.Request, v *visitor) error {
-	updatedSubscription, err := readJSONWithLimit[user.Subscription](r.Body, jsonBodyBytesLimit, false)
+	updatedSubscription, err := readJSONWithLimit[user.Subscription](r.Body, httpBodyBytesLimit, false)
 	if err != nil {
 		return err
 	}
@@ -417,7 +417,7 @@ func (s *Server) handleAccountSubscriptionDelete(w http.ResponseWriter, r *http.
 // it is already reserved by someone else.
 func (s *Server) handleAccountReservationAdd(w http.ResponseWriter, r *http.Request, v *visitor) error {
 	u := v.User()
-	req, err := readJSONWithLimit[apiAccountReservationRequest](r.Body, jsonBodyBytesLimit, false)
+	req, err := readJSONWithLimit[apiAccountReservationRequest](r.Body, httpBodyBytesLimit, false)
 	if err != nil {
 		return err
 	}
@@ -532,7 +532,7 @@ func (s *Server) maybeRemoveMessagesAndExcessReservations(r *http.Request, v *vi
 
 func (s *Server) handleAccountPhoneNumberVerify(w http.ResponseWriter, r *http.Request, v *visitor) error {
 	u := v.User()
-	req, err := readJSONWithLimit[apiAccountPhoneNumberVerifyRequest](r.Body, jsonBodyBytesLimit, false)
+	req, err := readJSONWithLimit[apiAccountPhoneNumberVerifyRequest](r.Body, httpBodyBytesLimit, false)
 	if err != nil {
 		return err
 	} else if !phoneNumberRegex.MatchString(req.Number) {
@@ -563,7 +563,7 @@ func (s *Server) handleAccountPhoneNumberVerify(w http.ResponseWriter, r *http.R
 
 func (s *Server) handleAccountPhoneNumberAdd(w http.ResponseWriter, r *http.Request, v *visitor) error {
 	u := v.User()
-	req, err := readJSONWithLimit[apiAccountPhoneNumberAddRequest](r.Body, jsonBodyBytesLimit, false)
+	req, err := readJSONWithLimit[apiAccountPhoneNumberAddRequest](r.Body, httpBodyBytesLimit, false)
 	if err != nil {
 		return err
 	}
@@ -582,7 +582,7 @@ func (s *Server) handleAccountPhoneNumberAdd(w http.ResponseWriter, r *http.Requ
 
 func (s *Server) handleAccountPhoneNumberDelete(w http.ResponseWriter, r *http.Request, v *visitor) error {
 	u := v.User()
-	req, err := readJSONWithLimit[apiAccountPhoneNumberAddRequest](r.Body, jsonBodyBytesLimit, false)
+	req, err := readJSONWithLimit[apiAccountPhoneNumberAddRequest](r.Body, httpBodyBytesLimit, false)
 	if err != nil {
 		return err
 	}

+ 9 - 8
server/server_admin.go

@@ -1,6 +1,7 @@
 package server
 
 import (
+	"errors"
 	"heckel.io/ntfy/v2/user"
 	"net/http"
 )
@@ -38,14 +39,14 @@ func (s *Server) handleUsersGet(w http.ResponseWriter, r *http.Request, v *visit
 }
 
 func (s *Server) handleUsersAdd(w http.ResponseWriter, r *http.Request, v *visitor) error {
-	req, err := readJSONWithLimit[apiUserAddRequest](r.Body, jsonBodyBytesLimit, false)
+	req, err := readJSONWithLimit[apiUserAddRequest](r.Body, httpBodyBytesLimit, false)
 	if err != nil {
 		return err
 	} else if !user.AllowedUsername(req.Username) || req.Password == "" {
 		return errHTTPBadRequest.Wrap("username invalid, or password missing")
 	}
 	u, err := s.userManager.User(req.Username)
-	if err != nil && err != user.ErrUserNotFound {
+	if err != nil && !errors.Is(err, user.ErrUserNotFound) {
 		return err
 	} else if u != nil {
 		return errHTTPConflictUserExists
@@ -53,7 +54,7 @@ func (s *Server) handleUsersAdd(w http.ResponseWriter, r *http.Request, v *visit
 	var tier *user.Tier
 	if req.Tier != "" {
 		tier, err = s.userManager.Tier(req.Tier)
-		if err == user.ErrTierNotFound {
+		if errors.Is(err, user.ErrTierNotFound) {
 			return errHTTPBadRequestTierInvalid
 		} else if err != nil {
 			return err
@@ -71,12 +72,12 @@ func (s *Server) handleUsersAdd(w http.ResponseWriter, r *http.Request, v *visit
 }
 
 func (s *Server) handleUsersDelete(w http.ResponseWriter, r *http.Request, v *visitor) error {
-	req, err := readJSONWithLimit[apiUserDeleteRequest](r.Body, jsonBodyBytesLimit, false)
+	req, err := readJSONWithLimit[apiUserDeleteRequest](r.Body, httpBodyBytesLimit, false)
 	if err != nil {
 		return err
 	}
 	u, err := s.userManager.User(req.Username)
-	if err == user.ErrUserNotFound {
+	if errors.Is(err, user.ErrUserNotFound) {
 		return errHTTPBadRequestUserNotFound
 	} else if err != nil {
 		return err
@@ -93,12 +94,12 @@ func (s *Server) handleUsersDelete(w http.ResponseWriter, r *http.Request, v *vi
 }
 
 func (s *Server) handleAccessAllow(w http.ResponseWriter, r *http.Request, v *visitor) error {
-	req, err := readJSONWithLimit[apiAccessAllowRequest](r.Body, jsonBodyBytesLimit, false)
+	req, err := readJSONWithLimit[apiAccessAllowRequest](r.Body, httpBodyBytesLimit, false)
 	if err != nil {
 		return err
 	}
 	_, err = s.userManager.User(req.Username)
-	if err == user.ErrUserNotFound {
+	if errors.Is(err, user.ErrUserNotFound) {
 		return errHTTPBadRequestUserNotFound
 	} else if err != nil {
 		return err
@@ -114,7 +115,7 @@ func (s *Server) handleAccessAllow(w http.ResponseWriter, r *http.Request, v *vi
 }
 
 func (s *Server) handleAccessReset(w http.ResponseWriter, r *http.Request, v *visitor) error {
-	req, err := readJSONWithLimit[apiAccessResetRequest](r.Body, jsonBodyBytesLimit, false)
+	req, err := readJSONWithLimit[apiAccessResetRequest](r.Body, httpBodyBytesLimit, false)
 	if err != nil {
 		return err
 	}

+ 3 - 3
server/server_payments.go

@@ -115,7 +115,7 @@ func (s *Server) handleAccountBillingSubscriptionCreate(w http.ResponseWriter, r
 	if u.Billing.StripeSubscriptionID != "" {
 		return errHTTPBadRequestBillingSubscriptionExists
 	}
-	req, err := readJSONWithLimit[apiAccountBillingSubscriptionChangeRequest](r.Body, jsonBodyBytesLimit, false)
+	req, err := readJSONWithLimit[apiAccountBillingSubscriptionChangeRequest](r.Body, httpBodyBytesLimit, false)
 	if err != nil {
 		return err
 	}
@@ -245,7 +245,7 @@ func (s *Server) handleAccountBillingSubscriptionUpdate(w http.ResponseWriter, r
 	if u.Billing.StripeSubscriptionID == "" {
 		return errNoBillingSubscription
 	}
-	req, err := readJSONWithLimit[apiAccountBillingSubscriptionChangeRequest](r.Body, jsonBodyBytesLimit, false)
+	req, err := readJSONWithLimit[apiAccountBillingSubscriptionChangeRequest](r.Body, httpBodyBytesLimit, false)
 	if err != nil {
 		return err
 	}
@@ -342,7 +342,7 @@ func (s *Server) handleAccountBillingWebhook(_ http.ResponseWriter, r *http.Requ
 	if stripeSignature == "" {
 		return errHTTPBadRequestBillingRequestInvalid
 	}
-	body, err := util.Peek(r.Body, jsonBodyBytesLimit)
+	body, err := util.Peek(r.Body, httpBodyBytesLimit)
 	if err != nil {
 		return err
 	} else if body.LimitReached {

File diff suppressed because it is too large
+ 10 - 7
server/server_test.go


+ 2 - 2
server/server_webpush.go

@@ -38,7 +38,7 @@ func init() {
 }
 
 func (s *Server) handleWebPushUpdate(w http.ResponseWriter, r *http.Request, v *visitor) error {
-	req, err := readJSONWithLimit[apiWebPushUpdateSubscriptionRequest](r.Body, jsonBodyBytesLimit, false)
+	req, err := readJSONWithLimit[apiWebPushUpdateSubscriptionRequest](r.Body, httpBodyBytesLimit, false)
 	if err != nil || req.Endpoint == "" || req.P256dh == "" || req.Auth == "" {
 		return errHTTPBadRequestWebPushSubscriptionInvalid
 	} else if !webPushAllowedEndpointsRegex.MatchString(req.Endpoint) {
@@ -66,7 +66,7 @@ func (s *Server) handleWebPushUpdate(w http.ResponseWriter, r *http.Request, v *
 }
 
 func (s *Server) handleWebPushDelete(w http.ResponseWriter, r *http.Request, _ *visitor) error {
-	req, err := readJSONWithLimit[apiWebPushUpdateSubscriptionRequest](r.Body, jsonBodyBytesLimit, false)
+	req, err := readJSONWithLimit[apiWebPushUpdateSubscriptionRequest](r.Body, httpBodyBytesLimit, false)
 	if err != nil || req.Endpoint == "" {
 		return errHTTPBadRequestWebPushSubscriptionInvalid
 	}

+ 3 - 2
server/util.go

@@ -2,6 +2,7 @@ package server
 
 import (
 	"context"
+	"errors"
 	"fmt"
 	"heckel.io/ntfy/v2/util"
 	"io"
@@ -104,9 +105,9 @@ func extractIPAddress(r *http.Request, behindProxy bool) netip.Addr {
 
 func readJSONWithLimit[T any](r io.ReadCloser, limit int, allowEmpty bool) (*T, error) {
 	obj, err := util.UnmarshalJSONWithLimit[T](r, limit, allowEmpty)
-	if err == util.ErrUnmarshalJSON {
+	if errors.Is(err, util.ErrUnmarshalJSON) {
 		return nil, errHTTPBadRequestJSONInvalid
-	} else if err == util.ErrTooLargeJSON {
+	} else if errors.Is(err, util.ErrTooLargeJSON) {
 		return nil, errHTTPEntityTooLargeJSONBody
 	} else if err != nil {
 		return nil, err

+ 1 - 1
test/server.go

@@ -16,7 +16,7 @@ func StartServer(t *testing.T) (*server.Server, int) {
 
 // StartServerWithConfig starts a server.Server with a random port and waits for the server to be up
 func StartServerWithConfig(t *testing.T, conf *server.Config) (*server.Server, int) {
-	port := 10000 + rand.Intn(20000)
+	port := 10000 + rand.Intn(30000)
 	conf.ListenHTTP = fmt.Sprintf(":%d", port)
 	conf.AttachmentCacheDir = t.TempDir()
 	conf.CacheFile = filepath.Join(t.TempDir(), "cache.db")

+ 3 - 2
util/peek.go

@@ -2,6 +2,7 @@ package util
 
 import (
 	"bytes"
+	"errors"
 	"io"
 	"strings"
 )
@@ -26,7 +27,7 @@ func Peek(underlying io.ReadCloser, limit int) (*PeekedReadCloser, error) {
 	}
 	peeked := make([]byte, limit)
 	read, err := io.ReadFull(underlying, peeked)
-	if err != nil && err != io.ErrUnexpectedEOF && err != io.EOF {
+	if err != nil && !errors.Is(err, io.ErrUnexpectedEOF) && err != io.EOF {
 		return nil, err
 	}
 	return &PeekedReadCloser{
@@ -44,7 +45,7 @@ func (r *PeekedReadCloser) Read(p []byte) (n int, err error) {
 		return 0, io.EOF
 	}
 	n, err = r.peeked.Read(p)
-	if err == io.EOF {
+	if errors.Is(err, io.EOF) {
 		return r.underlying.Read(p)
 	} else if err != nil {
 		return 0, err

Some files were not shown because too many files changed in this diff