ソースを参照

Code review (round 1)

binwiederhier 3 年 前
コミット
b37cf02a6e
12 ファイル変更80 行追加25 行削除
  1. 2 2
      cmd/tier.go
  2. 6 2
      log/event.go
  3. 1 1
      log/log.go
  4. 6 0
      log/log_test.go
  5. 3 1
      log/types.go
  6. 2 2
      server/config.go
  7. 1 0
      server/server.go
  8. 25 7
      server/server.yml
  9. 1 1
      user/types.go
  10. 13 0
      user/types_test.go
  11. 13 9
      util/lookup_cache.go
  12. 7 0
      util/util_test.go

+ 2 - 2
cmd/tier.go

@@ -33,7 +33,7 @@ var (
 var cmdTier = &cli.Command{
 var cmdTier = &cli.Command{
 	Name:      "tier",
 	Name:      "tier",
 	Usage:     "Manage/show tiers",
 	Usage:     "Manage/show tiers",
-	UsageText: "ntfy tier [list|add|remove] ...",
+	UsageText: "ntfy tier [list|add|change|remove] ...",
 	Flags:     flagsTier,
 	Flags:     flagsTier,
 	Before:    initConfigFileInputSourceFunc("config", flagsUser, initLogFunc),
 	Before:    initConfigFileInputSourceFunc("config", flagsUser, initLogFunc),
 	Category:  categoryServer,
 	Category:  categoryServer,
@@ -58,7 +58,7 @@ var cmdTier = &cli.Command{
 			},
 			},
 			Description: `Add a new tier to the ntfy user database.
 			Description: `Add a new tier to the ntfy user database.
 
 
-Tiers can be used to grant users higher limits based, such as daily message limits, attachment size, or
+Tiers can be used to grant users higher limits, such as daily message limits, attachment size, or
 make it possible for users to reserve topics.
 make it possible for users to reserve topics.
 
 
 This is a server-only command. It directly reads from the user.db as defined in the server config
 This is a server-only command. It directly reads from the user.db as defined in the server config

+ 6 - 2
log/event.go

@@ -197,8 +197,12 @@ func (e *Event) globalLevelWithOverride() Level {
 	}
 	}
 	for field, override := range ov {
 	for field, override := range ov {
 		value, exists := e.fields[field]
 		value, exists := e.fields[field]
-		if exists && value == override.value {
-			return override.level
+		if exists {
+			if value == override.value {
+				return override.level
+			} else if fmt.Sprintf("%v", value) == override.value {
+				return override.level
+			}
 		}
 		}
 	}
 	}
 	return l
 	return l

+ 1 - 1
log/log.go

@@ -93,7 +93,7 @@ func SetLevel(newLevel Level) {
 }
 }
 
 
 // SetLevelOverride adds a log override for the given field
 // SetLevelOverride adds a log override for the given field
-func SetLevelOverride(field string, value any, level Level) {
+func SetLevelOverride(field string, value string, level Level) {
 	mu.Lock()
 	mu.Lock()
 	defer mu.Unlock()
 	defer mu.Unlock()
 	overrides[field] = &levelOverride{value: value, level: level}
 	overrides[field] = &levelOverride{value: value, level: level}

+ 6 - 0
log/log_test.go

@@ -29,6 +29,7 @@ func TestLog_TagContextFieldFields(t *testing.T) {
 	SetOutput(&out)
 	SetOutput(&out)
 	SetFormat(JSONFormat)
 	SetFormat(JSONFormat)
 	SetLevelOverride("tag", "stripe", DebugLevel)
 	SetLevelOverride("tag", "stripe", DebugLevel)
+	SetLevelOverride("number", "5", DebugLevel)
 
 
 	Tag("mytag").
 	Tag("mytag").
 		Field("field2", 123).
 		Field("field2", 123).
@@ -49,8 +50,13 @@ func TestLog_TagContextFieldFields(t *testing.T) {
 		Time(time.Unix(456, 123000000).UTC()).
 		Time(time.Unix(456, 123000000).UTC()).
 		Debug("Subscription status %s", "active")
 		Debug("Subscription status %s", "active")
 
 
+	Field("number", 5).
+		Time(time.Unix(777, 001000000).UTC()).
+		Debug("The number 5 is an int, but the level override is a string")
+
 	expected := `{"time":"1970-01-01T00:02:03.999Z","level":"INFO","message":"hi there phil","field1":"value1","field2":123,"tag":"mytag"}
 	expected := `{"time":"1970-01-01T00:02:03.999Z","level":"INFO","message":"hi there phil","field1":"value1","field2":123,"tag":"mytag"}
 {"time":"1970-01-01T00:07:36.123Z","level":"DEBUG","message":"Subscription status active","error":"some error","error_code":123,"stripe_customer_id":"acct_123","stripe_subscription_id":"sub_123","tag":"stripe","user_id":"u_abc","visitor_ip":"1.2.3.4"}
 {"time":"1970-01-01T00:07:36.123Z","level":"DEBUG","message":"Subscription status active","error":"some error","error_code":123,"stripe_customer_id":"acct_123","stripe_subscription_id":"sub_123","tag":"stripe","user_id":"u_abc","visitor_ip":"1.2.3.4"}
+{"time":"1970-01-01T00:12:57Z","level":"DEBUG","message":"The number 5 is an int, but the level override is a string","number":5}
 `
 `
 	require.Equal(t, expected, out.String())
 	require.Equal(t, expected, out.String())
 }
 }

+ 3 - 1
log/types.go

@@ -55,6 +55,8 @@ func ToLevel(s string) Level {
 		return WarnLevel
 		return WarnLevel
 	case "ERROR":
 	case "ERROR":
 		return ErrorLevel
 		return ErrorLevel
+	case "FATAL":
+		return FatalLevel
 	default:
 	default:
 		return InfoLevel
 		return InfoLevel
 	}
 	}
@@ -101,6 +103,6 @@ type Contexter interface {
 type Context map[string]any
 type Context map[string]any
 
 
 type levelOverride struct {
 type levelOverride struct {
-	value any
+	value string
 	level Level
 	level Level
 }
 }

+ 2 - 2
server/config.go

@@ -19,7 +19,7 @@ const (
 	DefaultFirebaseKeepaliveInterval            = 3 * time.Hour    // ~control topic (Android), not too frequently to save battery
 	DefaultFirebaseKeepaliveInterval            = 3 * time.Hour    // ~control topic (Android), not too frequently to save battery
 	DefaultFirebasePollInterval                 = 20 * time.Minute // ~poll topic (iOS), max. 2-3 times per hour (see docs)
 	DefaultFirebasePollInterval                 = 20 * time.Minute // ~poll topic (iOS), max. 2-3 times per hour (see docs)
 	DefaultFirebaseQuotaExceededPenaltyDuration = 10 * time.Minute // Time that over-users are locked out of Firebase if it returns "quota exceeded"
 	DefaultFirebaseQuotaExceededPenaltyDuration = 10 * time.Minute // Time that over-users are locked out of Firebase if it returns "quota exceeded"
-	DefaultStripePriceCacheDuration             = time.Hour        // Time to keep Stripe prices cached in memory before a refresh is needed
+	DefaultStripePriceCacheDuration             = 3 * time.Hour    // Time to keep Stripe prices cached in memory before a refresh is needed
 )
 )
 
 
 // Defines all global and per-visitor limits
 // Defines all global and per-visitor limits
@@ -150,7 +150,7 @@ func NewConfig() *Config {
 		CacheBatchTimeout:                    0,
 		CacheBatchTimeout:                    0,
 		AuthFile:                             "",
 		AuthFile:                             "",
 		AuthStartupQueries:                   "",
 		AuthStartupQueries:                   "",
-		AuthDefault:                          user.NewPermission(true, true),
+		AuthDefault:                          user.PermissionReadWrite,
 		AuthBcryptCost:                       user.DefaultUserPasswordBcryptCost,
 		AuthBcryptCost:                       user.DefaultUserPasswordBcryptCost,
 		AuthStatsQueueWriterInterval:         user.DefaultUserStatsQueueWriterInterval,
 		AuthStatsQueueWriterInterval:         user.DefaultUserStatsQueueWriterInterval,
 		AttachmentCacheDir:                   "",
 		AttachmentCacheDir:                   "",

+ 1 - 0
server/server.go

@@ -39,6 +39,7 @@ import (
   - api
   - api
 - HIGH Self-review
 - HIGH Self-review
 - MEDIUM: Test for expiring messages after reservation removal
 - MEDIUM: Test for expiring messages after reservation removal
+- MEDIUM: disallowed-topics
 - MEDIUM: Test new token endpoints & never-expiring token
 - MEDIUM: Test new token endpoints & never-expiring token
 - LOW: UI: Flickering upgrade banner when logging in
 - LOW: UI: Flickering upgrade banner when logging in
 
 

+ 25 - 7
server/server.yml

@@ -233,13 +233,31 @@
 # stripe-secret-key:
 # stripe-secret-key:
 # stripe-webhook-key:
 # stripe-webhook-key:
 
 
-# Log level, can be "trace", "debug", "info", "warn" or "error"
-# This option can be hot-reloaded by calling "kill -HUP $pid" or "systemctl reload ntfy".
-#
-# FIXME
-#
-# Be aware that "debug" (and particularly "trace"") can be VERY CHATTY. Only turn them on for
-# debugging purposes, or your disk will fill up quickly.
+# Logging options
+#
+# By default, ntfy logs to the console (stderr), with a "info" log level, and in a human-readable text format.
+# ntfy supports five different log levels, can also write to a file, log as JSON, and even supports granular
+# log level overrides for easier debugging. Some options (log-level and log-level-overrides) can be hot reloaded
+# by calling "kill -HUP $pid" or "systemctl reload ntfy".
+#
+# - log-format defines the output format, can be "text" (default) or "json"
+# - log-file is a filename to write logs to. If this is not set, ntfy logs to stderr.
+# - log-level defines the default log level, can be one of "trace", "debug", "info" (default), "warn" or "error".
+#   Be aware that "debug" (and particularly "trace") can be VERY CHATTY. Only turn them on briefly for debugging purposes.
+# - log-level-overrides lets you override the log level if certain fields match. This is incredibly powerful
+#   for debugging certain parts of the system (e.g. only the account management, or only a certain visitor).
+#   This is an array of strings in the format "field=value -> level", e.g. "tag=manager -> trace".
+#   Warning: Using log-level-overrides has a performance penalty. Only use it for temporary debugging.
+#
+# Example (good for production):
+#   log-level: info
+#   log-format: json
+#   log-file: /var/log/ntfy.log
+#
+# Example level overrides (for debugging, only use temporarily):
+#   log-level-overrides:
+#      - "tag=manager -> trace"
+#      - "visitor_ip=1.2.3.4 -> debug"
 #
 #
 # log-level: info
 # log-level: info
 # log-level-overrides:
 # log-level-overrides:

+ 1 - 1
user/types.go

@@ -40,7 +40,7 @@ func (u *User) Admin() bool {
 
 
 // User returns true if the user is a regular user, not an admin
 // User returns true if the user is a regular user, not an admin
 func (u *User) User() bool {
 func (u *User) User() bool {
-	return !u.Admin()
+	return u != nil && u.Role == RoleUser
 }
 }
 
 
 // Auther is an interface for authentication and authorization
 // Auther is an interface for authentication and authorization

+ 13 - 0
user/types_test.go

@@ -0,0 +1,13 @@
+package user
+
+import (
+	"github.com/stretchr/testify/require"
+	"testing"
+)
+
+func TestPermission(t *testing.T) {
+	require.Equal(t, PermissionReadWrite, NewPermission(true, true))
+	require.Equal(t, PermissionRead, NewPermission(true, false))
+	require.Equal(t, PermissionWrite, NewPermission(false, true))
+	require.Equal(t, PermissionDenyAll, NewPermission(false, false))
+}

+ 13 - 9
util/lookup_cache.go

@@ -10,14 +10,14 @@ import (
 //
 //
 // Example:
 // Example:
 //
 //
-//	    lookup := func() (string, error) {
-//		   r, _ := http.Get("...")
-//		   s, _ := io.ReadAll(r.Body)
-//		   return string(s), nil
-//		}
-//		c := NewLookupCache[string](lookup, time.Hour)
-//		fmt.Println(c.Get()) // Fetches the string via HTTP
-//		fmt.Println(c.Get()) // Uses cached value
+//	lookup := func() (string, error) {
+//	   r, _ := http.Get("...")
+//	   s, _ := io.ReadAll(r.Body)
+//	   return string(s), nil
+//	}
+//	c := NewLookupCache[string](lookup, time.Hour)
+//	fmt.Println(c.Get()) // Fetches the string via HTTP
+//	fmt.Println(c.Get()) // Uses cached value
 type LookupCache[T any] struct {
 type LookupCache[T any] struct {
 	value   *T
 	value   *T
 	lookup  func() (T, error)
 	lookup  func() (T, error)
@@ -26,8 +26,12 @@ type LookupCache[T any] struct {
 	mu      sync.Mutex
 	mu      sync.Mutex
 }
 }
 
 
+// LookupFunc is a function that is called by the LookupCache if the underlying
+// value is out-of-date. It returns the new value, or an error.
+type LookupFunc[T any] func() (T, error)
+
 // NewLookupCache creates a new LookupCache with a given time-to-live (TTL)
 // NewLookupCache creates a new LookupCache with a given time-to-live (TTL)
-func NewLookupCache[T any](lookup func() (T, error), ttl time.Duration) *LookupCache[T] {
+func NewLookupCache[T any](lookup LookupFunc[T], ttl time.Duration) *LookupCache[T] {
 	return &LookupCache[T]{
 	return &LookupCache[T]{
 		value:  nil,
 		value:  nil,
 		lookup: lookup,
 		lookup: lookup,

+ 7 - 0
util/util_test.go

@@ -2,6 +2,7 @@ package util
 
 
 import (
 import (
 	"errors"
 	"errors"
+	"golang.org/x/time/rate"
 	"io"
 	"io"
 	"net/netip"
 	"net/netip"
 	"os"
 	"os"
@@ -245,6 +246,12 @@ func TestMinMax(t *testing.T) {
 	require.Equal(t, 50, MinMax(50, 10, 99))
 	require.Equal(t, 50, MinMax(50, 10, 99))
 }
 }
 
 
+func TestMax(t *testing.T) {
+	require.Equal(t, 9, Max(1, 9))
+	require.Equal(t, 9, Max(9, 1))
+	require.Equal(t, rate.Every(time.Minute), Max(rate.Every(time.Hour), rate.Every(time.Minute)))
+}
+
 func TestPointerFunctions(t *testing.T) {
 func TestPointerFunctions(t *testing.T) {
 	i, s, ti := Int(99), String("abc"), Time(time.Unix(99, 0))
 	i, s, ti := Int(99), String("abc"), Time(time.Unix(99, 0))
 	require.Equal(t, 99, *i)
 	require.Equal(t, 99, *i)