Jelajahi Sumber

Firebase quota limit

Philipp Heckel 3 tahun lalu
induk
melakukan
8283b6be97

+ 12 - 9
server/config.go

@@ -6,15 +6,16 @@ import (
 
 // Defines default config settings (excluding limits, see below)
 const (
-	DefaultListenHTTP                = ":80"
-	DefaultCacheDuration             = 12 * time.Hour
-	DefaultKeepaliveInterval         = 45 * time.Second // Not too frequently to save battery (Android read timeout used to be 77s!)
-	DefaultManagerInterval           = time.Minute
-	DefaultAtSenderInterval          = 10 * time.Second
-	DefaultMinDelay                  = 10 * time.Second
-	DefaultMaxDelay                  = 3 * 24 * time.Hour
-	DefaultFirebaseKeepaliveInterval = 3 * time.Hour    // ~control topic (Android), not too frequently to save battery
-	DefaultFirebasePollInterval      = 20 * time.Minute // ~poll topic (iOS), max. 2-3 times per hour (see docs)
+	DefaultListenHTTP                        = ":80"
+	DefaultCacheDuration                     = 12 * time.Hour
+	DefaultKeepaliveInterval                 = 45 * time.Second // Not too frequently to save battery (Android read timeout used to be 77s!)
+	DefaultManagerInterval                   = time.Minute
+	DefaultAtSenderInterval                  = 10 * time.Second
+	DefaultMinDelay                          = 10 * time.Second
+	DefaultMaxDelay                          = 3 * 24 * time.Hour
+	DefaultFirebaseKeepaliveInterval         = 3 * time.Hour    // ~control topic (Android), not too frequently to save battery
+	DefaultFirebasePollInterval              = 20 * time.Minute // ~poll topic (iOS), max. 2-3 times per hour (see docs)
+	DefaultFirebaseQuotaLimitPenaltyDuration = 10 * time.Minute
 )
 
 // Defines all global and per-visitor limits
@@ -69,6 +70,7 @@ type Config struct {
 	AtSenderInterval                     time.Duration
 	FirebaseKeepaliveInterval            time.Duration
 	FirebasePollInterval                 time.Duration
+	FirebaseQuotaLimitPenaltyDuration    time.Duration
 	UpstreamBaseURL                      string
 	SMTPSenderAddr                       string
 	SMTPSenderUser                       string
@@ -121,6 +123,7 @@ func NewConfig() *Config {
 		AtSenderInterval:                     DefaultAtSenderInterval,
 		FirebaseKeepaliveInterval:            DefaultFirebaseKeepaliveInterval,
 		FirebasePollInterval:                 DefaultFirebasePollInterval,
+		FirebaseQuotaLimitPenaltyDuration:    DefaultFirebaseQuotaLimitPenaltyDuration,
 		TotalTopicLimit:                      DefaultTotalTopicLimit,
 		VisitorSubscriptionLimit:             DefaultVisitorSubscriptionLimit,
 		VisitorAttachmentTotalSizeLimit:      DefaultVisitorAttachmentTotalSizeLimit,

+ 1 - 0
server/errors.go

@@ -59,6 +59,7 @@ var (
 	errHTTPTooManyRequestsLimitSubscriptions         = &errHTTP{42903, http.StatusTooManyRequests, "limit reached: too many active subscriptions, please be nice", "https://ntfy.sh/docs/publish/#limitations"}
 	errHTTPTooManyRequestsLimitTotalTopics           = &errHTTP{42904, http.StatusTooManyRequests, "limit reached: the total number of topics on the server has been reached, please contact the admin", "https://ntfy.sh/docs/publish/#limitations"}
 	errHTTPTooManyRequestsAttachmentBandwidthLimit   = &errHTTP{42905, http.StatusTooManyRequests, "too many requests: daily bandwidth limit reached", "https://ntfy.sh/docs/publish/#limitations"}
+	errHTTPTooManyRequestsFirebaseQuotaReached       = &errHTTP{42906, http.StatusTooManyRequests, "too many requests: Firebase quota for topic reached", "https://ntfy.sh/docs/publish/#limitations"}
 	errHTTPInternalError                             = &errHTTP{50001, http.StatusInternalServerError, "internal server error", ""}
 	errHTTPInternalErrorInvalidFilePath              = &errHTTP{50002, http.StatusInternalServerError, "internal server error: invalid file path", ""}
 )

+ 46 - 51
server/server.go

@@ -7,13 +7,11 @@ import (
 	"embed"
 	"encoding/base64"
 	"encoding/json"
-	"errors"
 	"fmt"
 	"io"
 	"log"
 	"net"
 	"net/http"
-	"net/http/httptest"
 	"net/url"
 	"os"
 	"path"
@@ -221,7 +219,7 @@ func (s *Server) Run() error {
 	}
 	s.mu.Unlock()
 	go s.runManager()
-	go s.runAtSender()
+	go s.runDelayedSender()
 	go s.runFirebaseKeepaliver()
 
 	return <-errChan
@@ -435,7 +433,7 @@ func (s *Server) handlePublish(w http.ResponseWriter, r *http.Request, v *visito
 	}
 	delayed := m.Time > time.Now().Unix()
 	if !delayed {
-		if err := t.Publish(m); err != nil {
+		if err := t.Publish(v, m); err != nil {
 			return err
 		}
 	}
@@ -465,7 +463,7 @@ func (s *Server) handlePublish(w http.ResponseWriter, r *http.Request, v *visito
 }
 
 func (s *Server) sendToFirebase(v *visitor, m *message) {
-	if err := s.firebase(m); err != nil {
+	if err := s.firebase(v, m); err != nil {
 		log.Printf("[%s] FB - Unable to publish to Firebase: %v", v.ip, err.Error())
 	}
 }
@@ -731,7 +729,7 @@ func (s *Server) handleSubscribeHTTP(w http.ResponseWriter, r *http.Request, v *
 		return err
 	}
 	var wlock sync.Mutex
-	sub := func(msg *message) error {
+	sub := func(v *visitor, msg *message) error {
 		if !filters.Pass(msg) {
 			return nil
 		}
@@ -752,7 +750,7 @@ func (s *Server) handleSubscribeHTTP(w http.ResponseWriter, r *http.Request, v *
 	w.Header().Set("Access-Control-Allow-Origin", "*")            // CORS, allow cross-origin requests
 	w.Header().Set("Content-Type", contentType+"; charset=utf-8") // Android/Volley client needs charset!
 	if poll {
-		return s.sendOldMessages(topics, since, scheduled, sub)
+		return s.sendOldMessages(topics, since, scheduled, v, sub)
 	}
 	subscriberIDs := make([]int, 0)
 	for _, t := range topics {
@@ -763,10 +761,10 @@ func (s *Server) handleSubscribeHTTP(w http.ResponseWriter, r *http.Request, v *
 			topics[i].Unsubscribe(subscriberID) // Order!
 		}
 	}()
-	if err := sub(newOpenMessage(topicsStr)); err != nil { // Send out open message
+	if err := sub(v, newOpenMessage(topicsStr)); err != nil { // Send out open message
 		return err
 	}
-	if err := s.sendOldMessages(topics, since, scheduled, sub); err != nil {
+	if err := s.sendOldMessages(topics, since, scheduled, v, sub); err != nil {
 		return err
 	}
 	for {
@@ -775,7 +773,7 @@ func (s *Server) handleSubscribeHTTP(w http.ResponseWriter, r *http.Request, v *
 			return nil
 		case <-time.After(s.config.KeepaliveInterval):
 			v.Keepalive()
-			if err := sub(newKeepaliveMessage(topicsStr)); err != nil { // Send keepalive message
+			if err := sub(v, newKeepaliveMessage(topicsStr)); err != nil { // Send keepalive message
 				return err
 			}
 		}
@@ -849,7 +847,7 @@ func (s *Server) handleSubscribeWS(w http.ResponseWriter, r *http.Request, v *vi
 			}
 		}
 	})
-	sub := func(msg *message) error {
+	sub := func(v *visitor, msg *message) error {
 		if !filters.Pass(msg) {
 			return nil
 		}
@@ -862,7 +860,7 @@ func (s *Server) handleSubscribeWS(w http.ResponseWriter, r *http.Request, v *vi
 	}
 	w.Header().Set("Access-Control-Allow-Origin", "*") // CORS, allow cross-origin requests
 	if poll {
-		return s.sendOldMessages(topics, since, scheduled, sub)
+		return s.sendOldMessages(topics, since, scheduled, v, sub)
 	}
 	subscriberIDs := make([]int, 0)
 	for _, t := range topics {
@@ -873,10 +871,10 @@ func (s *Server) handleSubscribeWS(w http.ResponseWriter, r *http.Request, v *vi
 			topics[i].Unsubscribe(subscriberID) // Order!
 		}
 	}()
-	if err := sub(newOpenMessage(topicsStr)); err != nil { // Send out open message
+	if err := sub(v, newOpenMessage(topicsStr)); err != nil { // Send out open message
 		return err
 	}
-	if err := s.sendOldMessages(topics, since, scheduled, sub); err != nil {
+	if err := s.sendOldMessages(topics, since, scheduled, v, sub); err != nil {
 		return err
 	}
 	err = g.Wait()
@@ -900,7 +898,7 @@ func parseSubscribeParams(r *http.Request) (poll bool, since sinceMarker, schedu
 	return
 }
 
-func (s *Server) sendOldMessages(topics []*topic, since sinceMarker, scheduled bool, sub subscriber) error {
+func (s *Server) sendOldMessages(topics []*topic, since sinceMarker, scheduled bool, v *visitor, sub subscriber) error {
 	if since.IsNone() {
 		return nil
 	}
@@ -910,7 +908,7 @@ func (s *Server) sendOldMessages(topics []*topic, since sinceMarker, scheduled b
 			return err
 		}
 		for _, m := range messages {
-			if err := sub(m); err != nil {
+			if err := sub(v, m); err != nil {
 				return err
 			}
 		}
@@ -1057,23 +1055,7 @@ func (s *Server) updateStatsAndPrune() {
 }
 
 func (s *Server) runSMTPServer() error {
-	sub := func(m *message) error {
-		url := fmt.Sprintf("%s/%s", s.config.BaseURL, m.Topic)
-		req, err := http.NewRequest("PUT", url, strings.NewReader(m.Message))
-		if err != nil {
-			return err
-		}
-		if m.Title != "" {
-			req.Header.Set("Title", m.Title)
-		}
-		rr := httptest.NewRecorder()
-		s.handle(rr, req)
-		if rr.Code != http.StatusOK {
-			return errors.New("error: " + rr.Body.String())
-		}
-		return nil
-	}
-	s.smtpBackend = newMailBackend(s.config, sub)
+	s.smtpBackend = newMailBackend(s.config, s.handle)
 	s.smtpServer = smtp.NewServer(s.smtpBackend)
 	s.smtpServer.Addr = s.config.SMTPServerListen
 	s.smtpServer.Domain = s.config.SMTPServerDomain
@@ -1096,7 +1078,7 @@ func (s *Server) runManager() {
 	}
 }
 
-func (s *Server) runAtSender() {
+func (s *Server) runDelayedSender() {
 	for {
 		select {
 		case <-time.After(s.config.AtSenderInterval):
@@ -1113,14 +1095,15 @@ func (s *Server) runFirebaseKeepaliver() {
 	if s.firebase == nil {
 		return
 	}
+	v := newVisitor(s.config, s.messageCache, "0.0.0.0")
 	for {
 		select {
 		case <-time.After(s.config.FirebaseKeepaliveInterval):
-			if err := s.firebase(newKeepaliveMessage(firebaseControlTopic)); err != nil {
+			if err := s.firebase(v, newKeepaliveMessage(firebaseControlTopic)); err != nil {
 				log.Printf("error sending Firebase keepalive message to %s: %s", firebaseControlTopic, err.Error())
 			}
 		case <-time.After(s.config.FirebasePollInterval):
-			if err := s.firebase(newKeepaliveMessage(firebasePollTopic)); err != nil {
+			if err := s.firebase(v, newKeepaliveMessage(firebasePollTopic)); err != nil {
 				log.Printf("error sending Firebase keepalive message to %s: %s", firebasePollTopic, err.Error())
 			}
 		case <-s.closeChan:
@@ -1130,28 +1113,36 @@ func (s *Server) runFirebaseKeepaliver() {
 }
 
 func (s *Server) sendDelayedMessages() error {
-	s.mu.Lock()
-	defer s.mu.Unlock()
 	messages, err := s.messageCache.MessagesDue()
 	if err != nil {
 		return err
 	}
 	for _, m := range messages {
-		t, ok := s.topics[m.Topic] // If no subscribers, just mark message as published
-		if ok {
-			if err := t.Publish(m); err != nil {
-				log.Printf("unable to publish message %s to topic %s: %v", m.ID, m.Topic, err.Error())
-			}
+		v := s.visitorFromIP("0.0.0.0") // FIXME: get message owner!!
+		if err := s.sendDelayedMessage(v, m); err != nil {
+			log.Printf("error sending delayed message: %s", err.Error())
 		}
-		if s.firebase != nil { // Firebase subscribers may not show up in topics map
-			if err := s.firebase(m); err != nil {
-				log.Printf("unable to publish to Firebase: %v", err.Error())
-			}
+	}
+	return nil
+}
+
+func (s *Server) sendDelayedMessage(v *visitor, m *message) error {
+	s.mu.Lock()
+	defer s.mu.Unlock()
+	t, ok := s.topics[m.Topic] // If no subscribers, just mark message as published
+	if ok {
+		if err := t.Publish(v, m); err != nil {
+			return fmt.Errorf("unable to publish message %s to topic %s: %v", m.ID, m.Topic, err.Error())
 		}
-		if err := s.messageCache.MarkPublished(m); err != nil {
-			return err
+	}
+	if s.firebase != nil { // Firebase subscribers may not show up in topics map
+		if err := s.firebase(v, m); err != nil {
+			return fmt.Errorf("unable to publish to Firebase: %v", err.Error())
 		}
 	}
+	if err := s.messageCache.MarkPublished(m); err != nil {
+		return err
+	}
 	return nil
 }
 
@@ -1290,8 +1281,6 @@ func extractUserPass(r *http.Request) (username string, password string, ok bool
 // visitor creates or retrieves a rate.Limiter for the given visitor.
 // This function was taken from https://www.alexedwards.net/blog/how-to-rate-limit-http-requests (MIT).
 func (s *Server) visitor(r *http.Request) *visitor {
-	s.mu.Lock()
-	defer s.mu.Unlock()
 	remoteAddr := r.RemoteAddr
 	ip, _, err := net.SplitHostPort(remoteAddr)
 	if err != nil {
@@ -1300,6 +1289,12 @@ func (s *Server) visitor(r *http.Request) *visitor {
 	if s.config.BehindProxy && r.Header.Get("X-Forwarded-For") != "" {
 		ip = r.Header.Get("X-Forwarded-For")
 	}
+	return s.visitorFromIP(ip)
+}
+
+func (s *Server) visitorFromIP(ip string) *visitor {
+	s.mu.Lock()
+	defer s.mu.Unlock()
 	v, exists := s.visitors[ip]
 	if !exists {
 		s.visitors[ip] = newVisitor(s.config, s.messageCache, ip)

+ 10 - 1
server/server_firebase.go

@@ -4,6 +4,7 @@ import (
 	"context"
 	"encoding/json"
 	"fmt"
+	"log"
 	"strings"
 
 	firebase "firebase.google.com/go/v4"
@@ -26,12 +27,20 @@ func createFirebaseSubscriber(credentialsFile string, auther auth.Auther) (subsc
 	if err != nil {
 		return nil, err
 	}
-	return func(m *message) error {
+	return func(v *visitor, m *message) error {
+		if err := v.FirebaseAllowed(); err != nil {
+			return errHTTPTooManyRequestsFirebaseQuotaReached
+		}
 		fbm, err := toFirebaseMessage(m, auther)
 		if err != nil {
 			return err
 		}
 		_, err = msg.Send(context.Background(), fbm)
+		if err != nil && messaging.IsQuotaExceeded(err) {
+			log.Printf("[%s] FB quota exceeded when trying to publish to topic %s, temporarily denying FB access", v.ip, m.Topic)
+			v.FirebaseTemporarilyDeny()
+			return errHTTPTooManyRequestsFirebaseQuotaReached
+		}
 		return err
 	}, nil
 }

+ 2 - 1
server/server_test.go

@@ -469,7 +469,8 @@ func TestServer_PublishFirebase(t *testing.T) {
 	require.NotEmpty(t, msg.ID)
 
 	// Keepalive message
-	require.Nil(t, s.firebase(newKeepaliveMessage(firebaseControlTopic)))
+	v := newVisitor(s.config, s.messageCache, "1.2.3.4")
+	require.Nil(t, s.firebase(v, newKeepaliveMessage(firebaseControlTopic)))
 
 	time.Sleep(500 * time.Millisecond) // Time for sends
 }

+ 32 - 10
server/smtp_server.go

@@ -3,10 +3,13 @@ package server
 import (
 	"bytes"
 	"errors"
+	"fmt"
 	"github.com/emersion/go-smtp"
 	"io"
 	"mime"
 	"mime/multipart"
+	"net/http"
+	"net/http/httptest"
 	"net/mail"
 	"strings"
 	"sync"
@@ -23,25 +26,25 @@ var (
 // smtpBackend implements SMTP server methods.
 type smtpBackend struct {
 	config  *Config
-	sub     subscriber
+	handler func(http.ResponseWriter, *http.Request)
 	success int64
 	failure int64
 	mu      sync.Mutex
 }
 
-func newMailBackend(conf *Config, sub subscriber) *smtpBackend {
+func newMailBackend(conf *Config, handler func(http.ResponseWriter, *http.Request)) *smtpBackend {
 	return &smtpBackend{
-		config: conf,
-		sub:    sub,
+		config:  conf,
+		handler: handler,
 	}
 }
 
 func (b *smtpBackend) Login(state *smtp.ConnectionState, username, password string) (smtp.Session, error) {
-	return &smtpSession{backend: b}, nil
+	return &smtpSession{backend: b, remoteAddr: state.RemoteAddr.String()}, nil
 }
 
 func (b *smtpBackend) AnonymousLogin(state *smtp.ConnectionState) (smtp.Session, error) {
-	return &smtpSession{backend: b}, nil
+	return &smtpSession{backend: b, remoteAddr: state.RemoteAddr.String()}, nil
 }
 
 func (b *smtpBackend) Counts() (success int64, failure int64) {
@@ -52,9 +55,10 @@ func (b *smtpBackend) Counts() (success int64, failure int64) {
 
 // smtpSession is returned after EHLO.
 type smtpSession struct {
-	backend *smtpBackend
-	topic   string
-	mu      sync.Mutex
+	backend    *smtpBackend
+	remoteAddr string
+	topic      string
+	mu         sync.Mutex
 }
 
 func (s *smtpSession) AuthPlain(username, password string) error {
@@ -128,7 +132,7 @@ func (s *smtpSession) Data(r io.Reader) error {
 			m.Message = m.Title // Flip them, this makes more sense
 			m.Title = ""
 		}
-		if err := s.backend.sub(m); err != nil {
+		if err := s.publishMessage(m); err != nil {
 			return err
 		}
 		s.backend.mu.Lock()
@@ -138,6 +142,24 @@ func (s *smtpSession) Data(r io.Reader) error {
 	})
 }
 
+func (s *smtpSession) publishMessage(m *message) error {
+	url := fmt.Sprintf("%s/%s", s.backend.config.BaseURL, m.Topic)
+	req, err := http.NewRequest("PUT", url, strings.NewReader(m.Message))
+	req.RemoteAddr = s.remoteAddr // rate limiting!!
+	if err != nil {
+		return err
+	}
+	if m.Title != "" {
+		req.Header.Set("Title", m.Title)
+	}
+	rr := httptest.NewRecorder()
+	s.backend.handler(rr, req)
+	if rr.Code != http.StatusOK {
+		return errors.New("error: " + rr.Body.String())
+	}
+	return nil
+}
+
 func (s *smtpSession) Reset() {
 	s.mu.Lock()
 	s.topic = ""

+ 57 - 40
server/smtp_server_test.go

@@ -3,6 +3,9 @@ package server
 import (
 	"github.com/emersion/go-smtp"
 	"github.com/stretchr/testify/require"
+	"io"
+	"net"
+	"net/http"
 	"strings"
 	"testing"
 )
@@ -27,13 +30,12 @@ Content-Type: text/html; charset="UTF-8"
 <div dir="ltr">what&#39;s up<br clear="all"><div><br></div></div>
 
 --000000000000f3320b05d42915c9--`
-	_, backend := newTestBackend(t, func(m *message) error {
-		require.Equal(t, "mytopic", m.Topic)
-		require.Equal(t, "and one more", m.Title)
-		require.Equal(t, "what's up", m.Message)
-		return nil
+	_, backend := newTestBackend(t, func(w http.ResponseWriter, r *http.Request) {
+		require.Equal(t, "/mytopic", r.URL.Path)
+		require.Equal(t, "and one more", r.Header.Get("Title"))
+		require.Equal(t, "what's up", readAll(t, r.Body))
 	})
-	session, _ := backend.AnonymousLogin(nil)
+	session, _ := backend.AnonymousLogin(fakeConnState(t, "1.2.3.4"))
 	require.Nil(t, session.Mail("phil@example.com", smtp.MailOptions{}))
 	require.Nil(t, session.Rcpt("ntfy-mytopic@ntfy.sh"))
 	require.Nil(t, session.Data(strings.NewReader(email)))
@@ -59,13 +61,12 @@ Content-Type: text/html; charset="UTF-8"
 <div dir="ltr"><br></div>
 
 --000000000000bcf4a405d429f8d4--`
-	_, backend := newTestBackend(t, func(m *message) error {
-		require.Equal(t, "emailtest", m.Topic)
-		require.Equal(t, "", m.Title) // We flipped message and body
-		require.Equal(t, "This email has a subject but no body", m.Message)
-		return nil
+	_, backend := newTestBackend(t, func(w http.ResponseWriter, r *http.Request) {
+		require.Equal(t, "/emailtest", r.URL.Path)
+		require.Equal(t, "", r.Header.Get("Title")) // We flipped message and body
+		require.Equal(t, "This email has a subject but no body", readAll(t, r.Body))
 	})
-	session, _ := backend.AnonymousLogin(nil)
+	session, _ := backend.AnonymousLogin(fakeConnState(t, "1.2.3.4"))
 	require.Nil(t, session.Mail("phil@example.com", smtp.MailOptions{}))
 	require.Nil(t, session.Rcpt("ntfy-emailtest@ntfy.sh"))
 	require.Nil(t, session.Data(strings.NewReader(email)))
@@ -81,14 +82,13 @@ Content-Type: text/plain; charset="UTF-8"
 
 what's up
 `
-	conf, backend := newTestBackend(t, func(m *message) error {
-		require.Equal(t, "mytopic", m.Topic)
-		require.Equal(t, "and one more", m.Title)
-		require.Equal(t, "what's up", m.Message)
-		return nil
+	conf, backend := newTestBackend(t, func(w http.ResponseWriter, r *http.Request) {
+		require.Equal(t, "/mytopic", r.URL.Path)
+		require.Equal(t, "and one more", r.Header.Get("Title"))
+		require.Equal(t, "what's up", readAll(t, r.Body))
 	})
 	conf.SMTPServerAddrPrefix = ""
-	session, _ := backend.AnonymousLogin(nil)
+	session, _ := backend.AnonymousLogin(fakeConnState(t, "1.2.3.4"))
 	require.Nil(t, session.Mail("phil@example.com", smtp.MailOptions{}))
 	require.Nil(t, session.Rcpt("mytopic@ntfy.sh"))
 	require.Nil(t, session.Data(strings.NewReader(email)))
@@ -99,14 +99,13 @@ func TestSmtpBackend_Plaintext_No_ContentType(t *testing.T) {
 
 what's up
 `
-	conf, backend := newTestBackend(t, func(m *message) error {
-		require.Equal(t, "mytopic", m.Topic)
-		require.Equal(t, "Very short mail", m.Title)
-		require.Equal(t, "what's up", m.Message)
-		return nil
+	conf, backend := newTestBackend(t, func(w http.ResponseWriter, r *http.Request) {
+		require.Equal(t, "/mytopic", r.URL.Path)
+		require.Equal(t, "Very short mail", r.Header.Get("Title"))
+		require.Equal(t, "what's up", readAll(t, r.Body))
 	})
 	conf.SMTPServerAddrPrefix = ""
-	session, _ := backend.AnonymousLogin(nil)
+	session, _ := backend.AnonymousLogin(fakeConnState(t, "1.2.3.4"))
 	require.Nil(t, session.Mail("phil@example.com", smtp.MailOptions{}))
 	require.Nil(t, session.Rcpt("mytopic@ntfy.sh"))
 	require.Nil(t, session.Data(strings.NewReader(email)))
@@ -121,11 +120,10 @@ Content-Type: text/plain; charset="UTF-8"
 
 what's up
 `
-	_, backend := newTestBackend(t, func(m *message) error {
-		require.Equal(t, "Three santas 🎅🎅🎅", m.Title)
-		return nil
+	_, backend := newTestBackend(t, func(w http.ResponseWriter, r *http.Request) {
+		require.Equal(t, "Three santas 🎅🎅🎅", r.Header.Get("Title"))
 	})
-	session, _ := backend.AnonymousLogin(nil)
+	session, _ := backend.AnonymousLogin(fakeConnState(t, "1.2.3.4"))
 	require.Nil(t, session.Mail("phil@example.com", smtp.MailOptions{}))
 	require.Nil(t, session.Rcpt("ntfy-mytopic@ntfy.sh"))
 	require.Nil(t, session.Data(strings.NewReader(email)))
@@ -140,7 +138,7 @@ To: mytopic@ntfy.sh
 Content-Type: text/plain; charset="UTF-8"
 
 you know this is a string.
-it's a long string. 
+it's a long string.
 it's supposed to be longer than the max message length
 which is 4096 bytes,
 it used to be 512 bytes, but I increased that for the UnifiedPush support
@@ -204,9 +202,9 @@ BBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBB
 BBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBB
 that should do it
 `
-	conf, backend := newTestBackend(t, func(m *message) error {
+	conf, backend := newTestBackend(t, func(w http.ResponseWriter, r *http.Request) {
 		expected := `you know this is a string.
-it's a long string. 
+it's a long string.
 it's supposed to be longer than the max message length
 which is 4096 bytes,
 it used to be 512 bytes, but I increased that for the UnifiedPush support
@@ -266,13 +264,12 @@ AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA
 ......................................................................
 ......................................................................
 and with BBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBB
-BBBBBBBBBBBBBBBBBBBBBBBB`
+BBBBBBBBBBBBBBBBBBBBBBBBB`
 		require.Equal(t, 4096, len(expected)) // Sanity check
-		require.Equal(t, expected, m.Message)
-		return nil
+		require.Equal(t, expected, readAll(t, r.Body))
 	})
 	conf.SMTPServerAddrPrefix = ""
-	session, _ := backend.AnonymousLogin(nil)
+	session, _ := backend.AnonymousLogin(fakeConnState(t, "1.2.3.4"))
 	require.Nil(t, session.Mail("phil@example.com", smtp.MailOptions{}))
 	require.Nil(t, session.Rcpt("mytopic@ntfy.sh"))
 	require.Nil(t, session.Data(strings.NewReader(email)))
@@ -288,21 +285,41 @@ Content-Type: text/SOMETHINGELSE
 
 what's up
 `
-	conf, backend := newTestBackend(t, func(m *message) error {
-		return nil
+	conf, backend := newTestBackend(t, func(http.ResponseWriter, *http.Request) {
+		// Nothing.
 	})
 	conf.SMTPServerAddrPrefix = ""
-	session, _ := backend.Login(nil, "user", "pass")
+	session, _ := backend.Login(fakeConnState(t, "1.2.3.4"), "user", "pass")
 	require.Nil(t, session.Mail("phil@example.com", smtp.MailOptions{}))
 	require.Nil(t, session.Rcpt("mytopic@ntfy.sh"))
 	require.Equal(t, errUnsupportedContentType, session.Data(strings.NewReader(email)))
 }
 
-func newTestBackend(t *testing.T, sub subscriber) (*Config, *smtpBackend) {
+func newTestBackend(t *testing.T, handler func(http.ResponseWriter, *http.Request)) (*Config, *smtpBackend) {
 	conf := newTestConfig(t)
 	conf.SMTPServerListen = ":25"
 	conf.SMTPServerDomain = "ntfy.sh"
 	conf.SMTPServerAddrPrefix = "ntfy-"
-	backend := newMailBackend(conf, sub)
+	backend := newMailBackend(conf, handler)
 	return conf, backend
 }
+
+func readAll(t *testing.T, rc io.ReadCloser) string {
+	b, err := io.ReadAll(rc)
+	if err != nil {
+		t.Fatal(err)
+	}
+	return string(b)
+}
+
+func fakeConnState(t *testing.T, remoteAddr string) *smtp.ConnectionState {
+	ip, err := net.ResolveIPAddr("ip", remoteAddr)
+	if err != nil {
+		t.Fatal(err)
+	}
+	return &smtp.ConnectionState{
+		Hostname:   "myhostname",
+		LocalAddr:  ip,
+		RemoteAddr: ip,
+	}
+}

+ 3 - 3
server/topic.go

@@ -15,7 +15,7 @@ type topic struct {
 }
 
 // subscriber is a function that is called for every new message on a topic
-type subscriber func(msg *message) error
+type subscriber func(v *visitor, msg *message) error
 
 // newTopic creates a new topic
 func newTopic(id string) *topic {
@@ -42,12 +42,12 @@ func (t *topic) Unsubscribe(id int) {
 }
 
 // Publish asynchronously publishes to all subscribers
-func (t *topic) Publish(m *message) error {
+func (t *topic) Publish(v *visitor, m *message) error {
 	go func() {
 		t.mu.Lock()
 		defer t.mu.Unlock()
 		for _, s := range t.subscribers {
-			if err := s(m); err != nil {
+			if err := s(v, m); err != nil {
 				log.Printf("error publishing message to subscriber")
 			}
 		}

+ 17 - 4
server/visitor.go

@@ -28,6 +28,7 @@ type visitor struct {
 	emails        *rate.Limiter
 	subscriptions util.Limiter
 	bandwidth     util.Limiter
+	firebase      time.Time // Next allowed Firebase message
 	seen          time.Time
 	mu            sync.Mutex
 }
@@ -48,14 +49,11 @@ func newVisitor(conf *Config, messageCache *messageCache, ip string) *visitor {
 		emails:        rate.NewLimiter(rate.Every(conf.VisitorEmailLimitReplenish), conf.VisitorEmailLimitBurst),
 		subscriptions: util.NewFixedLimiter(int64(conf.VisitorSubscriptionLimit)),
 		bandwidth:     util.NewBytesLimiter(conf.VisitorAttachmentDailyBandwidthLimit, 24*time.Hour),
+		firebase:      time.Unix(0, 0),
 		seen:          time.Now(),
 	}
 }
 
-func (v *visitor) IP() string {
-	return v.ip
-}
-
 func (v *visitor) RequestAllowed() error {
 	if !v.requests.Allow() {
 		return errVisitorLimitReached
@@ -63,6 +61,21 @@ func (v *visitor) RequestAllowed() error {
 	return nil
 }
 
+func (v *visitor) FirebaseAllowed() error {
+	v.mu.Lock()
+	defer v.mu.Unlock()
+	if time.Now().Before(v.firebase) {
+		return errVisitorLimitReached
+	}
+	return nil
+}
+
+func (v *visitor) FirebaseTemporarilyDeny() {
+	v.mu.Lock()
+	defer v.mu.Unlock()
+	v.firebase = time.Now().Add(v.config.FirebaseQuotaLimitPenaltyDuration)
+}
+
 func (v *visitor) EmailAllowed() error {
 	if !v.emails.Allow() {
 		return errVisitorLimitReached