|
|
@@ -9,12 +9,6 @@ import (
|
|
|
"encoding/json"
|
|
|
"errors"
|
|
|
"fmt"
|
|
|
- "github.com/emersion/go-smtp"
|
|
|
- "github.com/gorilla/websocket"
|
|
|
- "golang.org/x/sync/errgroup"
|
|
|
- "heckel.io/ntfy/log"
|
|
|
- "heckel.io/ntfy/user"
|
|
|
- "heckel.io/ntfy/util"
|
|
|
"io"
|
|
|
"net"
|
|
|
"net/http"
|
|
|
@@ -30,6 +24,13 @@ import (
|
|
|
"sync"
|
|
|
"time"
|
|
|
"unicode/utf8"
|
|
|
+
|
|
|
+ "github.com/emersion/go-smtp"
|
|
|
+ "github.com/gorilla/websocket"
|
|
|
+ "golang.org/x/sync/errgroup"
|
|
|
+ "heckel.io/ntfy/log"
|
|
|
+ "heckel.io/ntfy/user"
|
|
|
+ "heckel.io/ntfy/util"
|
|
|
)
|
|
|
|
|
|
/*
|
|
|
@@ -605,23 +606,23 @@ func (s *Server) handlePublishWithoutResponse(r *http.Request, v *visitor) (*mes
|
|
|
if err != nil {
|
|
|
return nil, err
|
|
|
}
|
|
|
- v_old := v
|
|
|
- if strings.HasPrefix(t.ID, subscriberBilledTopicPrefix) {
|
|
|
- v = t.getBillee()
|
|
|
- if v == nil {
|
|
|
- return nil, errHTTPWontStoreMessage
|
|
|
- }
|
|
|
+ vRate := v
|
|
|
+ if topicCountsAgainst := t.Billee(); topicCountsAgainst != nil {
|
|
|
+ vRate = topicCountsAgainst
|
|
|
}
|
|
|
|
|
|
- if !v.MessageAllowed() {
|
|
|
- return nil, errHTTPTooManyRequestsLimitMessages
|
|
|
+ if !vRate.MessageAllowed() {
|
|
|
+ vRate = v
|
|
|
+ if !v.MessageAllowed() {
|
|
|
+ return nil, errHTTPTooManyRequestsLimitMessages
|
|
|
+ }
|
|
|
}
|
|
|
body, err := util.Peek(r.Body, s.config.MessageLimit)
|
|
|
if err != nil {
|
|
|
return nil, err
|
|
|
}
|
|
|
m := newDefaultMessage(t.ID, "")
|
|
|
- cache, firebase, email, unifiedpush, err := s.parsePublishParams(r, v, m)
|
|
|
+ cache, firebase, email, unifiedpush, err := s.parsePublishParams(r, vRate, m)
|
|
|
if err != nil {
|
|
|
return nil, err
|
|
|
}
|
|
|
@@ -630,7 +631,7 @@ func (s *Server) handlePublishWithoutResponse(r *http.Request, v *visitor) (*mes
|
|
|
}
|
|
|
m.Sender = v.IP()
|
|
|
m.User = v.MaybeUserID()
|
|
|
- m.Expires = time.Unix(m.Time, 0).Add(v.Limits().MessageExpiryDuration).Unix()
|
|
|
+ m.Expires = time.Unix(m.Time, 0).Add(vRate.Limits().MessageExpiryDuration).Unix()
|
|
|
if err := s.handlePublishBody(r, v, m, body, unifiedpush); err != nil {
|
|
|
return nil, err
|
|
|
}
|
|
|
@@ -638,18 +639,18 @@ func (s *Server) handlePublishWithoutResponse(r *http.Request, v *visitor) (*mes
|
|
|
m.Message = emptyMessageBody
|
|
|
}
|
|
|
delayed := m.Time > time.Now().Unix()
|
|
|
- logvrm(v, r, m).
|
|
|
+ logvrm(vRate, r, m).
|
|
|
Tag(tagPublish).
|
|
|
Fields(log.Context{
|
|
|
- "message_delayed": delayed,
|
|
|
- "message_firebase": firebase,
|
|
|
- "message_unifiedpush": unifiedpush,
|
|
|
- "message_email": email,
|
|
|
+ "message_delayed": delayed,
|
|
|
+ "message_firebase": firebase,
|
|
|
+ "message_unifiedpush": unifiedpush,
|
|
|
+ "message_email": email,
|
|
|
+ "message_subscriber_rate_limited": vRate != v,
|
|
|
}).
|
|
|
Debug("Received message")
|
|
|
- //Where should I log the original visitor vs the billing visitor
|
|
|
if log.IsTrace() {
|
|
|
- logvrm(v_old, r, m).
|
|
|
+ logvrm(vRate, r, m).
|
|
|
Tag(tagPublish).
|
|
|
Field("message_body", util.MaybeMarshalJSON(m)).
|
|
|
Trace("Message body")
|
|
|
@@ -659,7 +660,7 @@ func (s *Server) handlePublishWithoutResponse(r *http.Request, v *visitor) (*mes
|
|
|
return nil, err
|
|
|
}
|
|
|
if s.firebaseClient != nil && firebase {
|
|
|
- go s.sendToFirebase(v, m)
|
|
|
+ go s.sendToFirebase(vRate, m)
|
|
|
}
|
|
|
if s.smtpSender != nil && email != "" {
|
|
|
go s.sendEmail(v, m, email)
|
|
|
@@ -745,7 +746,7 @@ func (s *Server) forwardPollRequest(v *visitor, m *message) {
|
|
|
}
|
|
|
}
|
|
|
|
|
|
-func (s *Server) parsePublishParams(r *http.Request, v *visitor, m *message) (cache bool, firebase bool, email string, unifiedpush bool, err error) {
|
|
|
+func (s *Server) parsePublishParams(r *http.Request, vRate *visitor, m *message) (cache bool, firebase bool, email string, unifiedpush bool, err error) {
|
|
|
cache = readBoolParam(r, true, "x-cache", "cache")
|
|
|
firebase = readBoolParam(r, true, "x-firebase", "firebase")
|
|
|
m.Title = readParam(r, "x-title", "title", "t")
|
|
|
@@ -785,7 +786,7 @@ func (s *Server) parsePublishParams(r *http.Request, v *visitor, m *message) (ca
|
|
|
}
|
|
|
email = readParam(r, "x-email", "x-e-mail", "email", "e-mail", "mail", "e")
|
|
|
if email != "" {
|
|
|
- if !v.EmailAllowed() {
|
|
|
+ if !vRate.EmailAllowed() {
|
|
|
return false, false, "", false, errHTTPTooManyRequestsLimitEmails
|
|
|
}
|
|
|
}
|
|
|
@@ -800,13 +801,7 @@ func (s *Server) parsePublishParams(r *http.Request, v *visitor, m *message) (ca
|
|
|
if err != nil {
|
|
|
return false, false, "", false, errHTTPBadRequestPriorityInvalid
|
|
|
}
|
|
|
- tagsStr := readParam(r, "x-tags", "tags", "tag", "ta")
|
|
|
- if tagsStr != "" {
|
|
|
- m.Tags = make([]string, 0)
|
|
|
- for _, s := range util.SplitNoEmpty(tagsStr, ",") {
|
|
|
- m.Tags = append(m.Tags, strings.TrimSpace(s))
|
|
|
- }
|
|
|
- }
|
|
|
+ m.Tags = readCommaSeperatedParam(r, "x-tags", "tags", "tag", "ta")
|
|
|
delayStr := readParam(r, "x-delay", "delay", "x-at", "at", "x-in", "in")
|
|
|
if delayStr != "" {
|
|
|
if !cache {
|
|
|
@@ -996,7 +991,7 @@ func (s *Server) handleSubscribeHTTP(w http.ResponseWriter, r *http.Request, v *
|
|
|
if err != nil {
|
|
|
return err
|
|
|
}
|
|
|
- poll, since, scheduled, filters, err := parseSubscribeParams(r)
|
|
|
+ poll, since, scheduled, filters, subscriberRateTopics, err := parseSubscribeParams(r)
|
|
|
if err != nil {
|
|
|
return err
|
|
|
}
|
|
|
@@ -1035,7 +1030,8 @@ func (s *Server) handleSubscribeHTTP(w http.ResponseWriter, r *http.Request, v *
|
|
|
defer cancel()
|
|
|
subscriberIDs := make([]int, 0)
|
|
|
for _, t := range topics {
|
|
|
- subscriberIDs = append(subscriberIDs, t.Subscribe(sub, v, cancel))
|
|
|
+ subscriberRateLimited := util.Contains(subscriberRateTopics, t.ID) || strings.HasPrefix(t.ID, subscriberBilledTopicPrefix) // temporarily do prefix as well
|
|
|
+ subscriberIDs = append(subscriberIDs, t.Subscribe(sub, v, cancel, subscriberRateLimited))
|
|
|
}
|
|
|
defer func() {
|
|
|
for i, subscriberID := range subscriberIDs {
|
|
|
@@ -1078,7 +1074,7 @@ func (s *Server) handleSubscribeWS(w http.ResponseWriter, r *http.Request, v *vi
|
|
|
if err != nil {
|
|
|
return err
|
|
|
}
|
|
|
- poll, since, scheduled, filters, err := parseSubscribeParams(r)
|
|
|
+ poll, since, scheduled, filters, subscriberRateTopics, err := parseSubscribeParams(r)
|
|
|
if err != nil {
|
|
|
return err
|
|
|
}
|
|
|
@@ -1167,7 +1163,8 @@ func (s *Server) handleSubscribeWS(w http.ResponseWriter, r *http.Request, v *vi
|
|
|
}
|
|
|
subscriberIDs := make([]int, 0)
|
|
|
for _, t := range topics {
|
|
|
- subscriberIDs = append(subscriberIDs, t.Subscribe(sub, v, cancel))
|
|
|
+ subscriberRateLimited := util.Contains(subscriberRateTopics, t.ID) || strings.HasPrefix(t.ID, subscriberBilledTopicPrefix) // temporarily do prefix as well
|
|
|
+ subscriberIDs = append(subscriberIDs, t.Subscribe(sub, v, cancel, subscriberRateLimited))
|
|
|
}
|
|
|
defer func() {
|
|
|
for i, subscriberID := range subscriberIDs {
|
|
|
@@ -1188,7 +1185,7 @@ func (s *Server) handleSubscribeWS(w http.ResponseWriter, r *http.Request, v *vi
|
|
|
return err
|
|
|
}
|
|
|
|
|
|
-func parseSubscribeParams(r *http.Request) (poll bool, since sinceMarker, scheduled bool, filters *queryFilter, err error) {
|
|
|
+func parseSubscribeParams(r *http.Request) (poll bool, since sinceMarker, scheduled bool, filters *queryFilter, subscriberTopics []string, err error) {
|
|
|
poll = readBoolParam(r, false, "x-poll", "poll", "po")
|
|
|
scheduled = readBoolParam(r, false, "x-scheduled", "scheduled", "sched")
|
|
|
since, err = parseSince(r, poll)
|
|
|
@@ -1199,6 +1196,8 @@ func parseSubscribeParams(r *http.Request) (poll bool, since sinceMarker, schedu
|
|
|
if err != nil {
|
|
|
return
|
|
|
}
|
|
|
+
|
|
|
+ subscriberTopics = readCommaSeperatedParam(r, "subscriber-rate-limit-topics", "x-subscriber-rate-limit-topics", "srlt")
|
|
|
return
|
|
|
}
|
|
|
|