binwiederhier 3 лет назад
Родитель
Сommit
8bf64d8723
5 измененных файлов с 169 добавлено и 6 удалено
  1. 1 1
      server/server_account.go
  2. 2 2
      user/manager.go
  3. 116 1
      user/manager_test.go
  4. 3 2
      user/types.go
  5. 47 0
      user/types_test.go

+ 1 - 1
server/server_account.go

@@ -447,7 +447,7 @@ func (s *Server) handleAccountReservationAdd(w http.ResponseWriter, r *http.Requ
 	// Check if we are allowed to reserve this topic
 	if u.IsUser() && u.Tier == nil {
 		return errHTTPUnauthorized
-	} else if err := s.userManager.CheckAllowAccess(u.Name, req.Topic); err != nil {
+	} else if err := s.userManager.AllowReservation(u.Name, req.Topic); err != nil {
 		return errHTTPConflictTopicReserved
 	} else if u.IsUser() {
 		hasReservation, err := s.userManager.HasReservation(u.Name, req.Topic)

+ 2 - 2
user/manager.go

@@ -1017,9 +1017,9 @@ func (a *Manager) checkReservationsLimit(username string, reservationsLimit int6
 	return nil
 }
 
-// CheckAllowAccess tests if a user may create an access control entry for the given topic.
+// AllowReservation tests if a user may create an access control entry for the given topic.
 // If there are any ACL entries that are not owned by the user, an error is returned.
-func (a *Manager) CheckAllowAccess(username string, topic string) error {
+func (a *Manager) AllowReservation(username string, topic string) error {
 	if (!AllowedUsername(username) && username != Everyone) || !AllowedTopic(topic) {
 		return ErrInvalidArgument
 	}

+ 116 - 1
user/manager_test.go

@@ -106,6 +106,30 @@ func TestManager_AddUser_Timing(t *testing.T) {
 	require.GreaterOrEqual(t, time.Now().UnixMilli()-start, minBcryptTimingMillis)
 }
 
+func TestManager_AddUser_And_Query(t *testing.T) {
+	a := newTestManagerFromFile(t, filepath.Join(t.TempDir(), "user.db"), "", PermissionDenyAll, DefaultUserPasswordBcryptCost, DefaultUserStatsQueueWriterInterval)
+	require.Nil(t, a.AddUser("user", "pass", RoleAdmin))
+	require.Nil(t, a.ChangeBilling("user", &Billing{
+		StripeCustomerID:            "acct_123",
+		StripeSubscriptionID:        "sub_123",
+		StripeSubscriptionStatus:    "active",
+		StripeSubscriptionPaidUntil: time.Now().Add(time.Hour),
+		StripeSubscriptionCancelAt:  time.Unix(0, 0),
+	}))
+
+	u, err := a.User("user")
+	require.Nil(t, err)
+	require.Equal(t, "user", u.Name)
+
+	u2, err := a.UserByID(u.ID)
+	require.Nil(t, err)
+	require.Equal(t, u.Name, u2.Name)
+
+	u3, err := a.UserByStripeCustomer("acct_123")
+	require.Nil(t, err)
+	require.Equal(t, u.ID, u3.ID)
+}
+
 func TestManager_Authenticate_Timing(t *testing.T) {
 	a := newTestManagerFromFile(t, filepath.Join(t.TempDir(), "user.db"), "", PermissionDenyAll, DefaultUserPasswordBcryptCost, DefaultUserStatsQueueWriterInterval)
 	require.Nil(t, a.AddUser("user", "pass", RoleAdmin))
@@ -311,6 +335,7 @@ func TestManager_ChangeRole(t *testing.T) {
 
 func TestManager_Reservations(t *testing.T) {
 	a := newTestManager(t, PermissionDenyAll)
+	require.Nil(t, a.AddUser("phil", "phil", RoleUser))
 	require.Nil(t, a.AddUser("ben", "ben", RoleUser))
 	require.Nil(t, a.AddReservation("ben", "ztopic", PermissionDenyAll))
 	require.Nil(t, a.AddReservation("ben", "readme", PermissionRead))
@@ -329,6 +354,32 @@ func TestManager_Reservations(t *testing.T) {
 		Owner:    PermissionReadWrite,
 		Everyone: PermissionDenyAll,
 	}, reservations[1])
+
+	b, err := a.HasReservation("ben", "readme")
+	require.Nil(t, err)
+	require.True(t, b)
+
+	b, err = a.HasReservation("notben", "readme")
+	require.Nil(t, err)
+	require.False(t, b)
+
+	b, err = a.HasReservation("ben", "something-else")
+	require.Nil(t, err)
+	require.False(t, b)
+
+	count, err := a.ReservationsCount("ben")
+	require.Nil(t, err)
+	require.Equal(t, int64(2), count)
+
+	count, err = a.ReservationsCount("phil")
+	require.Nil(t, err)
+	require.Equal(t, int64(0), count)
+
+	err = a.AllowReservation("phil", "readme")
+	require.Equal(t, errTopicOwnedByOthers, err)
+
+	err = a.AllowReservation("phil", "not-reserved")
+	require.Nil(t, err)
 }
 
 func TestManager_ChangeRoleFromTierUserToAdmin(t *testing.T) {
@@ -414,11 +465,24 @@ func TestManager_Token_Valid(t *testing.T) {
 	require.Equal(t, token.Value, token2.Value)
 	require.Equal(t, "some label", token2.Label)
 
+	tokens, err := a.Tokens(u.ID)
+	require.Nil(t, err)
+	require.Equal(t, 1, len(tokens))
+	require.Equal(t, "some label", tokens[0].Label)
+
+	tokens, err = a.Tokens("u_notauser")
+	require.Nil(t, err)
+	require.Equal(t, 0, len(tokens))
+
 	// Remove token and auth again
 	require.Nil(t, a.RemoveToken(u2.ID, u2.Token))
 	u3, err := a.AuthenticateToken(token.Value)
 	require.Equal(t, ErrUnauthenticated, err)
 	require.Nil(t, u3)
+
+	tokens, err = a.Tokens(u.ID)
+	require.Nil(t, err)
+	require.Equal(t, 0, len(tokens))
 }
 
 func TestManager_Token_Invalid(t *testing.T) {
@@ -434,6 +498,12 @@ func TestManager_Token_Invalid(t *testing.T) {
 	require.Equal(t, ErrUnauthenticated, err)
 }
 
+func TestManager_Token_NotFound(t *testing.T) {
+	a := newTestManager(t, PermissionDenyAll)
+	_, err := a.Token("u_bla", "notfound")
+	require.Equal(t, ErrTokenNotFound, err)
+}
+
 func TestManager_Token_Expire(t *testing.T) {
 	a := newTestManager(t, PermissionDenyAll)
 	require.Nil(t, a.AddUser("ben", "ben", RoleUser))
@@ -552,7 +622,7 @@ func TestManager_Token_MaxCount_AutoDelete(t *testing.T) {
 	require.Equal(t, 20, count)
 }
 
-func TestManager_EnqueueStats(t *testing.T) {
+func TestManager_EnqueueStats_ResetStats(t *testing.T) {
 	a, err := NewManager(filepath.Join(t.TempDir(), "db"), "", PermissionReadWrite, bcrypt.MinCost, 1500*time.Millisecond)
 	require.Nil(t, err)
 	require.Nil(t, a.AddUser("ben", "ben", RoleUser))
@@ -580,6 +650,51 @@ func TestManager_EnqueueStats(t *testing.T) {
 	require.Nil(t, err)
 	require.Equal(t, int64(11), u.Stats.Messages)
 	require.Equal(t, int64(2), u.Stats.Emails)
+
+	// Now reset stats (enqueued stats will be thrown out)
+	a.EnqueueUserStats(u.ID, &Stats{
+		Messages: 99,
+		Emails:   23,
+	})
+	require.Nil(t, a.ResetStats())
+
+	u, err = a.User("ben")
+	require.Nil(t, err)
+	require.Equal(t, int64(0), u.Stats.Messages)
+	require.Equal(t, int64(0), u.Stats.Emails)
+}
+
+func TestManager_EnqueueTokenUpdate(t *testing.T) {
+	a, err := NewManager(filepath.Join(t.TempDir(), "db"), "", PermissionReadWrite, bcrypt.MinCost, 500*time.Millisecond)
+	require.Nil(t, err)
+	require.Nil(t, a.AddUser("ben", "ben", RoleUser))
+
+	// Create user and token
+	u, err := a.User("ben")
+	require.Nil(t, err)
+
+	token, err := a.CreateToken(u.ID, "", time.Now().Add(time.Hour), netip.IPv4Unspecified())
+	require.Nil(t, err)
+
+	// Queue token update
+	a.EnqueueTokenUpdate(token.Value, &TokenUpdate{
+		LastAccess: time.Unix(111, 0).UTC(),
+		LastOrigin: netip.MustParseAddr("1.2.3.3"),
+	})
+
+	// Token has not changed yet.
+	token2, err := a.Token(u.ID, token.Value)
+	require.Nil(t, err)
+	require.Equal(t, token.LastAccess.Unix(), token2.LastAccess.Unix())
+	require.Equal(t, token.LastOrigin, token2.LastOrigin)
+
+	// After a second or so they should be persisted
+	time.Sleep(time.Second)
+
+	token3, err := a.Token(u.ID, token.Value)
+	require.Nil(t, err)
+	require.Equal(t, time.Unix(111, 0).UTC().Unix(), token3.LastAccess.Unix())
+	require.Equal(t, netip.MustParseAddr("1.2.3.3"), token3.LastOrigin)
 }
 
 func TestManager_ChangeSettings(t *testing.T) {

+ 3 - 2
user/types.go

@@ -6,6 +6,7 @@ import (
 	"heckel.io/ntfy/log"
 	"net/netip"
 	"regexp"
+	"strings"
 	"time"
 )
 
@@ -97,7 +98,7 @@ type Tier struct {
 func (t *Tier) Context() log.Context {
 	return log.Context{
 		"tier_id":         t.ID,
-		"tier_name":       t.Name,
+		"tier_code":       t.Code,
 		"stripe_price_id": t.StripePriceID,
 	}
 }
@@ -170,7 +171,7 @@ func NewPermission(read, write bool) Permission {
 
 // ParsePermission parses the string representation and returns a Permission
 func ParsePermission(s string) (Permission, error) {
-	switch s {
+	switch strings.ToLower(s) {
 	case "read-write", "rw":
 		return NewPermission(true, true), nil
 	case "read-only", "read", "ro":

+ 47 - 0
user/types_test.go

@@ -10,4 +10,51 @@ func TestPermission(t *testing.T) {
 	require.Equal(t, PermissionRead, NewPermission(true, false))
 	require.Equal(t, PermissionWrite, NewPermission(false, true))
 	require.Equal(t, PermissionDenyAll, NewPermission(false, false))
+	require.True(t, PermissionReadWrite.IsReadWrite())
+	require.True(t, PermissionReadWrite.IsRead())
+	require.True(t, PermissionReadWrite.IsWrite())
+	require.True(t, PermissionRead.IsRead())
+	require.True(t, PermissionWrite.IsWrite())
+}
+
+func TestParsePermission(t *testing.T) {
+	_, err := ParsePermission("no")
+	require.NotNil(t, err)
+
+	p, err := ParsePermission("read-write")
+	require.Nil(t, err)
+	require.Equal(t, PermissionReadWrite, p)
+
+	p, err = ParsePermission("rw")
+	require.Nil(t, err)
+	require.Equal(t, PermissionReadWrite, p)
+
+	p, err = ParsePermission("read-only")
+	require.Nil(t, err)
+	require.Equal(t, PermissionRead, p)
+
+	p, err = ParsePermission("WRITE")
+	require.Nil(t, err)
+	require.Equal(t, PermissionWrite, p)
+
+	p, err = ParsePermission("deny-all")
+	require.Nil(t, err)
+	require.Equal(t, PermissionDenyAll, p)
+}
+
+func TestAllowedTier(t *testing.T) {
+	require.False(t, AllowedTier("  no"))
+	require.True(t, AllowedTier("yes"))
+}
+
+func TestTierContext(t *testing.T) {
+	tier := &Tier{
+		ID:            "ti_abc",
+		Code:          "pro",
+		StripePriceID: "price_123",
+	}
+	context := tier.Context()
+	require.Equal(t, "ti_abc", context["tier_id"])
+	require.Equal(t, "pro", context["tier_code"])
+	require.Equal(t, "price_123", context["stripe_price_id"])
 }