binwiederhier пре 3 година
родитељ
комит
e1a4a74905

+ 2 - 0
cmd/access_test.go

@@ -79,7 +79,9 @@ user * (role: anonymous, tier: none)
 func runAccessCommand(app *cli.App, conf *server.Config, args ...string) error {
 	userArgs := []string{
 		"ntfy",
+		"--log-level=ERROR",
 		"access",
+		"--config=" + conf.File, // Dummy config file to avoid lookups of real file
 		"--auth-file=" + conf.AuthFile,
 		"--auth-default-access=" + conf.AuthDefault.String(),
 	}

+ 1 - 0
cmd/serve.go

@@ -253,6 +253,7 @@ func execServe(c *cli.Context) error {
 
 	// Run server
 	conf := server.NewConfig()
+	conf.File = config
 	conf.BaseURL = baseURL
 	conf.ListenHTTP = listenHTTP
 	conf.ListenHTTPS = listenHTTPS

+ 2 - 0
cmd/tier_test.go

@@ -38,7 +38,9 @@ func TestCLI_Tier_AddListChangeDelete(t *testing.T) {
 func runTierCommand(app *cli.App, conf *server.Config, args ...string) error {
 	userArgs := []string{
 		"ntfy",
+		"--log-level=ERROR",
 		"tier",
+		"--config=" + conf.File, // Dummy config file to avoid lookups of real file
 		"--auth-file=" + conf.AuthFile,
 		"--auth-default-access=" + conf.AuthDefault.String(),
 	}

+ 2 - 0
cmd/token_test.go

@@ -41,7 +41,9 @@ func TestCLI_Token_AddListRemove(t *testing.T) {
 func runTokenCommand(app *cli.App, conf *server.Config, args ...string) error {
 	userArgs := []string{
 		"ntfy",
+		"--log-level=ERROR",
 		"token",
+		"--config=" + conf.File, // Dummy config file to avoid lookups of real file
 		"--auth-file=" + conf.AuthFile,
 	}
 	return app.Run(append(userArgs, args...))

+ 6 - 0
cmd/user_test.go

@@ -6,6 +6,7 @@ import (
 	"heckel.io/ntfy/server"
 	"heckel.io/ntfy/test"
 	"heckel.io/ntfy/user"
+	"os"
 	"path/filepath"
 	"testing"
 )
@@ -113,7 +114,10 @@ func TestCLI_User_Delete(t *testing.T) {
 }
 
 func newTestServerWithAuth(t *testing.T) (s *server.Server, conf *server.Config, port int) {
+	configFile := filepath.Join(t.TempDir(), "server-dummy.yml")
+	require.Nil(t, os.WriteFile(configFile, []byte(""), 0600)) // Dummy config file to avoid lookup of real server.yml
 	conf = server.NewConfig()
+	conf.File = configFile
 	conf.AuthFile = filepath.Join(t.TempDir(), "user.db")
 	conf.AuthDefault = user.PermissionDenyAll
 	s, port = test.StartServerWithConfig(t, conf)
@@ -123,7 +127,9 @@ func newTestServerWithAuth(t *testing.T) (s *server.Server, conf *server.Config,
 func runUserCommand(app *cli.App, conf *server.Config, args ...string) error {
 	userArgs := []string{
 		"ntfy",
+		"--log-level=ERROR",
 		"user",
+		"--config=" + conf.File, // Dummy config file to avoid lookups of real file
 		"--auth-file=" + conf.AuthFile,
 		"--auth-default-access=" + conf.AuthDefault.String(),
 	}

+ 4 - 2
log/event.go

@@ -82,8 +82,10 @@ func (e *Event) Time(t time.Time) *Event {
 
 // Err adds an "error" field to the log event
 func (e *Event) Err(err error) *Event {
-	if c, ok := err.(Contexter); ok {
-		return e.Fields(c.Context())
+	if err == nil {
+		return e
+	} else if c, ok := err.(Contexter); ok {
+		return e.With(c)
 	}
 	return e.Field(errorField, err.Error())
 }

+ 8 - 0
server/config.go

@@ -49,6 +49,8 @@ const (
 	DefaultVisitorEmailLimitReplenish           = time.Hour
 	DefaultVisitorAccountCreationLimitBurst     = 3
 	DefaultVisitorAccountCreationLimitReplenish = 24 * time.Hour
+	DefaultVisitorAuthFailureLimitBurst         = 10
+	DefaultVisitorAuthFailureLimitReplenish     = time.Minute
 	DefaultVisitorAttachmentTotalSizeLimit      = 100 * 1024 * 1024 // 100 MB
 	DefaultVisitorAttachmentDailyBandwidthLimit = 500 * 1024 * 1024 // 500 MB
 )
@@ -60,6 +62,7 @@ var (
 
 // Config is the main config struct for the application. Use New to instantiate a default config struct.
 type Config struct {
+	File                                 string // Config file, only used for testing
 	BaseURL                              string
 	ListenHTTP                           string
 	ListenHTTPS                          string
@@ -113,6 +116,8 @@ type Config struct {
 	VisitorEmailLimitReplenish           time.Duration
 	VisitorAccountCreationLimitBurst     int
 	VisitorAccountCreationLimitReplenish time.Duration
+	VisitorAuthFailureLimitBurst         int
+	VisitorAuthFailureLimitReplenish     time.Duration
 	VisitorStatsResetTime                time.Time // Time of the day at which to reset visitor stats
 	BehindProxy                          bool
 	StripeSecretKey                      string
@@ -129,6 +134,7 @@ type Config struct {
 // NewConfig instantiates a default new server config
 func NewConfig() *Config {
 	return &Config{
+		File:                                 "", // Only used for testing
 		BaseURL:                              "",
 		ListenHTTP:                           DefaultListenHTTP,
 		ListenHTTPS:                          "",
@@ -182,6 +188,8 @@ func NewConfig() *Config {
 		VisitorEmailLimitReplenish:           DefaultVisitorEmailLimitReplenish,
 		VisitorAccountCreationLimitBurst:     DefaultVisitorAccountCreationLimitBurst,
 		VisitorAccountCreationLimitReplenish: DefaultVisitorAccountCreationLimitReplenish,
+		VisitorAuthFailureLimitBurst:         DefaultVisitorAuthFailureLimitBurst,
+		VisitorAuthFailureLimitReplenish:     DefaultVisitorAuthFailureLimitReplenish,
 		VisitorStatsResetTime:                DefaultVisitorStatsResetTime,
 		BehindProxy:                          false,
 		StripeSecretKey:                      "",

+ 1 - 0
server/errors.go

@@ -87,6 +87,7 @@ var (
 	errHTTPTooManyRequestsLimitAccountCreation       = &errHTTP{42906, http.StatusTooManyRequests, "limit reached: too many accounts created", "https://ntfy.sh/docs/publish/#limitations"} // FIXME document limit
 	errHTTPTooManyRequestsLimitReservations          = &errHTTP{42907, http.StatusTooManyRequests, "limit reached: too many topic reservations for this user", ""}
 	errHTTPTooManyRequestsLimitMessages              = &errHTTP{42908, http.StatusTooManyRequests, "limit reached: daily message quota reached", "https://ntfy.sh/docs/publish/#limitations"}
+	errHTTPTooManyRequestsLimitAuthFailure           = &errHTTP{42909, http.StatusTooManyRequests, "limit reached: too many auth failures", "https://ntfy.sh/docs/publish/#limitations"} // FIXME document limit
 	errHTTPInternalError                             = &errHTTP{50001, http.StatusInternalServerError, "internal server error", ""}
 	errHTTPInternalErrorInvalidPath                  = &errHTTP{50002, http.StatusInternalServerError, "internal server error: invalid path", ""}
 	errHTTPInternalErrorMissingBaseURL               = &errHTTP{50003, http.StatusInternalServerError, "internal server error: base-url must be be configured for this feature", "https://ntfy.sh/docs/config/"}

+ 47 - 36
server/server.go

@@ -34,9 +34,9 @@ import (
 
 /*
 
-- HIGH Rate limiting: Sensitive endpoints (account/login/change-password/...)
-- HIGH Account limit creation triggers when account is taken!
 - HIGH Docs
+  - tiers
+  - api
 - HIGH Self-review
 - MEDIUM: Test for expiring messages after reservation removal
 - MEDIUM: Test new token endpoints & never-expiring token
@@ -1540,18 +1540,6 @@ func (s *Server) sendDelayedMessage(v *visitor, m *message) error {
 	return nil
 }
 
-func (s *Server) limitRequests(next handleFunc) handleFunc {
-	return func(w http.ResponseWriter, r *http.Request, v *visitor) error {
-		if util.ContainsIP(s.config.VisitorRequestExemptIPAddrs, v.ip) {
-			return next(w, r, v)
-		} else if err := v.RequestAllowed(); err != nil {
-			logvr(v, r).Err(err).Trace("Request not allowed by rate limiter")
-			return errHTTPTooManyRequestsLimitRequests
-		}
-		return next(w, r, v)
-	}
-}
-
 // transformBodyJSON peeks the request body, reads the JSON, and converts it to headers
 // before passing it on to the next handler. This is meant to be used in combination with handlePublish.
 func (s *Server) transformBodyJSON(next handleFunc) handleFunc {
@@ -1648,43 +1636,65 @@ func (s *Server) autorizeTopic(next handleFunc, perm user.Permission) handleFunc
 	}
 }
 
-// maybeAuthenticate creates or retrieves a rate.Limiter for the given visitor.
-// Note that this function will always return a visitor, even if an error occurs.
-func (s *Server) maybeAuthenticate(r *http.Request) (v *visitor, err error) {
+// maybeAuthenticate reads the "Authorization" header and will try to authenticate the user
+// if it is set.
+//
+//   - If the header is not set, an IP-based visitor is returned
+//   - If the header is set, authenticate will be called to check the username/password (Basic auth),
+//     or the token (Bearer auth), and read the user from the database
+//
+// This function will ALWAYS return a visitor, even if an error occurs (e.g. unauthorized), so
+// that subsequent logging calls still have a visitor context.
+func (s *Server) maybeAuthenticate(r *http.Request) (*visitor, error) {
+	// Read "Authorization" header value, and exit out early if it's not set
 	ip := extractIPAddress(r, s.config.BehindProxy)
-	var u *user.User // may stay nil if no auth header!
-	if u, err = s.authenticate(r); err != nil {
-		logr(r).Err(err).Debug("Authentication failed: %s", err.Error())
-		err = errHTTPUnauthorized // Always return visitor, even when error occurs!
+	vip := s.visitor(ip, nil)
+	header, err := readAuthHeader(r)
+	if err != nil {
+		return vip, err
+	} else if header == "" {
+		return vip, nil
+	} else if s.userManager == nil {
+		return vip, errHTTPUnauthorized
 	}
-	v = s.visitor(ip, u)
-	v.SetUser(u)  // Update visitor user with latest from database!
-	return v, err // Always return visitor, even when error occurs!
+	// If we're trying to auth, check the rate limiter first
+	if !vip.AuthAllowed() {
+		return vip, errHTTPTooManyRequestsLimitAuthFailure // Always return visitor, even when error occurs!
+	}
+	u, err := s.authenticate(r, header)
+	if err != nil {
+		vip.AuthFailed()
+		logr(r).Err(err).Debug("Authentication failed")
+		return vip, errHTTPUnauthorized // Always return visitor, even when error occurs!
+	}
+	// Authentication with user was successful
+	return s.visitor(ip, u), nil
 }
 
 // authenticate a user based on basic auth username/password (Authorization: Basic ...), or token auth (Authorization: Bearer ...).
 // The Authorization header can be passed as a header or the ?auth=... query param. The latter is required only to
 // support the WebSocket JavaScript class, which does not support passing headers during the initial request. The auth
-// query param is effectively double base64 encoded. Its format is base64(Basic base64(user:pass)).
-func (s *Server) authenticate(r *http.Request) (user *user.User, err error) {
+// query param is effectively doubly base64 encoded. Its format is base64(Basic base64(user:pass)).
+func (s *Server) authenticate(r *http.Request, header string) (user *user.User, err error) {
+	if strings.HasPrefix(header, "Bearer") {
+		return s.authenticateBearerAuth(r, strings.TrimSpace(strings.TrimPrefix(header, "Bearer")))
+	}
+	return s.authenticateBasicAuth(r, header)
+}
+
+// readAuthHeader reads the raw value of the Authorization header, either from the actual HTTP header,
+// or from the ?auth... query parameter
+func readAuthHeader(r *http.Request) (string, error) {
 	value := strings.TrimSpace(r.Header.Get("Authorization"))
 	queryParam := readQueryParam(r, "authorization", "auth")
 	if queryParam != "" {
 		a, err := base64.RawURLEncoding.DecodeString(queryParam)
 		if err != nil {
-			return nil, err
+			return "", err
 		}
 		value = strings.TrimSpace(string(a))
 	}
-	if value == "" {
-		return nil, nil
-	} else if s.userManager == nil {
-		return nil, errHTTPUnauthorized
-	}
-	if strings.HasPrefix(value, "Bearer") {
-		return s.authenticateBearerAuth(r, strings.TrimSpace(strings.TrimPrefix(value, "Bearer")))
-	}
-	return s.authenticateBasicAuth(r, value)
+	return value, nil
 }
 
 func (s *Server) authenticateBasicAuth(r *http.Request, value string) (user *user.User, err error) {
@@ -1721,6 +1731,7 @@ func (s *Server) visitor(ip netip.Addr, user *user.User) *visitor {
 		return s.visitors[id]
 	}
 	v.Keepalive()
+	v.SetUser(user) // Always update with the latest user, may be nil!
 	return v
 }
 

+ 1 - 0
server/server_account.go

@@ -41,6 +41,7 @@ func (s *Server) handleAccountCreate(w http.ResponseWriter, r *http.Request, v *
 	if err := s.userManager.AddUser(newAccount.Username, newAccount.Password, user.RoleUser); err != nil {
 		return err
 	}
+	v.AccountCreated()
 	return s.writeJSON(w, newSuccessResponse())
 }
 

+ 1 - 1
server/server_firebase.go

@@ -39,7 +39,7 @@ func newFirebaseClient(sender firebaseSender, auther user.Auther) *firebaseClien
 }
 
 func (c *firebaseClient) Send(v *visitor, m *message) error {
-	if err := v.FirebaseAllowed(); err != nil {
+	if !v.FirebaseAllowed() {
 		return errFirebaseTemporarilyBanned
 	}
 	fbm, err := toFirebaseMessage(m, c.auther)

+ 12 - 0
server/server_middleware.go

@@ -1,9 +1,21 @@
 package server
 
 import (
+	"heckel.io/ntfy/util"
 	"net/http"
 )
 
+func (s *Server) limitRequests(next handleFunc) handleFunc {
+	return func(w http.ResponseWriter, r *http.Request, v *visitor) error {
+		if util.ContainsIP(s.config.VisitorRequestExemptIPAddrs, v.ip) {
+			return next(w, r, v)
+		} else if !v.RequestAllowed() {
+			return errHTTPTooManyRequestsLimitRequests
+		}
+		return next(w, r, v)
+	}
+}
+
 func (s *Server) ensureWebEnabled(next handleFunc) handleFunc {
 	return func(w http.ResponseWriter, r *http.Request, v *visitor) error {
 		if !s.config.EnableWeb {

+ 4 - 4
server/server_payments_test.go

@@ -374,13 +374,13 @@ func TestPayments_Checkout_Success_And_Increase_Rate_Limits_Reset_Visitor(t *tes
 	var wg sync.WaitGroup
 	for i := 0; i < 209; i++ {
 		wg.Add(1)
-		go func() {
+		go func(i int) {
+			defer wg.Done()
 			rr := request(t, s, "PUT", "/mytopic", "some message", map[string]string{
 				"Authorization": util.BasicAuth("phil", "phil"),
 			})
-			require.Equal(t, 200, rr.Code)
-			wg.Done()
-		}()
+			require.Equal(t, 200, rr.Code, "Failed on %d", i)
+		}(i)
 	}
 	wg.Wait()
 	rr = request(t, s, "PUT", "/mytopic", "some message", map[string]string{

+ 18 - 0
server/server_test.go

@@ -733,6 +733,24 @@ func TestServer_Auth_Fail_CannotPublish(t *testing.T) {
 	require.Equal(t, 403, response.Code) // Anonymous read not allowed
 }
 
+func TestServer_Auth_Fail_Rate_Limiting(t *testing.T) {
+	c := newTestConfigWithAuthFile(t)
+	s := newTestServer(t, c)
+
+	for i := 0; i < 10; i++ {
+		response := request(t, s, "PUT", "/announcements", "test", map[string]string{
+			"Authorization": util.BasicAuth("phil", "phil"),
+		})
+		require.Equal(t, 401, response.Code)
+	}
+
+	response := request(t, s, "PUT", "/announcements", "test", map[string]string{
+		"Authorization": util.BasicAuth("phil", "phil"),
+	})
+	require.Equal(t, 429, response.Code)
+	require.Equal(t, 42909, toHTTPError(t, response.Body.String()).Code)
+}
+
 func TestServer_Auth_ViaQuery(t *testing.T) {
 	c := newTestConfigWithAuthFile(t)
 	c.AuthDefault = user.PermissionDenyAll

+ 42 - 17
server/visitor.go

@@ -64,6 +64,7 @@ type visitor struct {
 	subscriptionLimiter *util.FixedLimiter // Fixed limiter for active subscriptions (ongoing connections)
 	bandwidthLimiter    *util.RateLimiter  // Limiter for attachment bandwidth downloads
 	accountLimiter      *rate.Limiter      // Rate limiter for account creation, may be nil
+	authLimiter         *rate.Limiter      // Limiter for incorrect login attempts
 	firebase            time.Time          // Next allowed Firebase message
 	seen                time.Time          // Last seen time of this visitor (needed for removal of stale visitors)
 	mu                  sync.Mutex
@@ -130,6 +131,7 @@ func newVisitor(conf *Config, messageCache *messageCache, userManager *user.Mana
 		emailsLimiter:       nil, // Set in resetLimiters
 		bandwidthLimiter:    nil, // Set in resetLimiters
 		accountLimiter:      nil, // Set in resetLimiters, may be nil
+		authLimiter:         nil, // Set in resetLimiters, may be nil
 	}
 	v.resetLimitersNoLock(messages, emails, false)
 	return v
@@ -154,6 +156,10 @@ func (v *visitor) contextNoLock() log.Context {
 		"visitor_request_limiter_limit":  v.requestLimiter.Limit(),
 		"visitor_request_limiter_tokens": v.requestLimiter.Tokens(),
 	}
+	if v.authLimiter != nil {
+		fields["visitor_auth_limiter_limit"] = v.authLimiter.Limit()
+		fields["visitor_auth_limiter_tokens"] = v.authLimiter.Tokens()
+	}
 	if v.user != nil {
 		fields["user_id"] = v.user.ID
 		fields["user_name"] = v.user.Name
@@ -182,28 +188,16 @@ func visitorExtendedInfoContext(info *visitorInfo) log.Context {
 	}
 
 }
-func (v *visitor) RequestAllowed() error {
-	v.mu.Lock() // limiters could be replaced!
-	defer v.mu.Unlock()
-	if !v.requestLimiter.Allow() {
-		return errVisitorLimitReached
-	}
-	return nil
-}
-
-func (v *visitor) RequestLimiter() *rate.Limiter {
+func (v *visitor) RequestAllowed() bool {
 	v.mu.Lock() // limiters could be replaced!
 	defer v.mu.Unlock()
-	return v.requestLimiter
+	return v.requestLimiter.Allow()
 }
 
-func (v *visitor) FirebaseAllowed() error {
+func (v *visitor) FirebaseAllowed() bool {
 	v.mu.Lock()
 	defer v.mu.Unlock()
-	if time.Now().Before(v.firebase) {
-		return errVisitorLimitReached
-	}
-	return nil
+	return !time.Now().Before(v.firebase)
 }
 
 func (v *visitor) FirebaseTemporarilyDeny() {
@@ -230,15 +224,44 @@ func (v *visitor) SubscriptionAllowed() bool {
 	return v.subscriptionLimiter.Allow()
 }
 
+// AuthAllowed returns true if an auth request can be attempted (> 1 token available)
+func (v *visitor) AuthAllowed() bool {
+	v.mu.Lock() // limiters could be replaced!
+	defer v.mu.Unlock()
+	if v.authLimiter == nil {
+		return true
+	}
+	return v.authLimiter.Tokens() > 1
+}
+
+// AuthFailed records an auth failure
+func (v *visitor) AuthFailed() {
+	v.mu.Lock() // limiters could be replaced!
+	defer v.mu.Unlock()
+	if v.authLimiter != nil {
+		v.authLimiter.Allow()
+	}
+}
+
+// AccountCreationAllowed returns true if a new account can be created
 func (v *visitor) AccountCreationAllowed() bool {
 	v.mu.Lock() // limiters could be replaced!
 	defer v.mu.Unlock()
-	if v.accountLimiter == nil || (v.accountLimiter != nil && !v.accountLimiter.Allow()) {
+	if v.accountLimiter == nil || (v.accountLimiter != nil && v.accountLimiter.Tokens() < 1) {
 		return false
 	}
 	return true
 }
 
+// AccountCreated decreases the account limiter. This is to be called after an account was created.
+func (v *visitor) AccountCreated() {
+	v.mu.Lock() // limiters could be replaced!
+	defer v.mu.Unlock()
+	if v.accountLimiter != nil {
+		v.accountLimiter.Allow()
+	}
+}
+
 func (v *visitor) BandwidthAllowed(bytes int64) bool {
 	v.mu.Lock() // limiters could be replaced!
 	defer v.mu.Unlock()
@@ -336,8 +359,10 @@ func (v *visitor) resetLimitersNoLock(messages, emails int64, enqueueUpdate bool
 	v.bandwidthLimiter = util.NewBytesLimiter(int(limits.AttachmentBandwidthLimit), oneDay)
 	if v.user == nil {
 		v.accountLimiter = rate.NewLimiter(rate.Every(v.config.VisitorAccountCreationLimitReplenish), v.config.VisitorAccountCreationLimitBurst)
+		v.authLimiter = rate.NewLimiter(rate.Every(v.config.VisitorAuthFailureLimitReplenish), v.config.VisitorAuthFailureLimitBurst)
 	} else {
 		v.accountLimiter = nil // Users cannot create accounts when logged in
+		v.authLimiter = nil    // Users are already logged in, no need to limit requests
 	}
 	if enqueueUpdate && v.user != nil {
 		go v.userManager.EnqueueStats(v.user.ID, &user.Stats{

+ 1 - 0
user/manager.go

@@ -372,6 +372,7 @@ func (a *Manager) AuthenticateToken(token string) (*User, error) {
 	}
 	user, err := a.userByToken(token)
 	if err != nil {
+		log.Tag(tagManager).Field("token", token).Err(err).Trace("Authentication of token failed")
 		return nil, ErrUnauthenticated
 	}
 	user.Token = token