Просмотр исходного кода

Make sure tokens are updated instead of deleted/re-added

binwiederhier 6 месяцев назад
Родитель
Сommit
9f987e66fa
3 измененных файлов с 186 добавлено и 89 удалено
  1. 4 1
      cmd/webpush_test.go
  2. 159 85
      user/manager.go
  3. 23 3
      user/manager_test.go

+ 4 - 1
cmd/webpush_test.go

@@ -1,6 +1,7 @@
 package cmd
 
 import (
+	"path/filepath"
 	"testing"
 
 	"github.com/stretchr/testify/require"
@@ -15,10 +16,12 @@ func TestCLI_WebPush_GenerateKeys(t *testing.T) {
 }
 
 func TestCLI_WebPush_WriteKeysToFile(t *testing.T) {
+	tempDir := t.TempDir()
+	t.Chdir(tempDir)
 	app, _, _, stderr := newTestApp()
 	require.Nil(t, runWebPushCommand(app, server.NewConfig(), "keys", "--output-file=key-file.yaml"))
 	require.Contains(t, stderr.String(), "Web Push keys written to key-file.yaml")
-	require.FileExists(t, "key-file.yaml")
+	require.FileExists(t, filepath.Join(tempDir, "key-file.yaml"))
 }
 
 func runWebPushCommand(app *cli.App, conf *server.Config, args ...string) error {

+ 159 - 85
user/manager.go

@@ -13,6 +13,7 @@ import (
 	"heckel.io/ntfy/v2/util"
 	"net/netip"
 	"path/filepath"
+	"slices"
 	"strings"
 	"sync"
 	"time"
@@ -258,23 +259,24 @@ const (
 	   	  AND topic = ?
   	`
 
-	selectTokenCountQuery = `SELECT COUNT(*) FROM user_token WHERE user_id = ?`
-	selectTokensQuery     = `SELECT token, label, last_access, last_origin, expires, provisioned FROM user_token WHERE user_id = ?`
-	selectTokenQuery      = `SELECT token, label, last_access, last_origin, expires, provisioned FROM user_token WHERE user_id = ? AND token = ?`
-	upsertTokenQuery      = `
+	selectTokenCountQuery           = `SELECT COUNT(*) FROM user_token WHERE user_id = ?`
+	selectTokensQuery               = `SELECT token, label, last_access, last_origin, expires, provisioned FROM user_token WHERE user_id = ?`
+	selectTokenQuery                = `SELECT token, label, last_access, last_origin, expires, provisioned FROM user_token WHERE user_id = ? AND token = ?`
+	selectAllProvisionedTokensQuery = `SELECT token, label, last_access, last_origin, expires, provisioned FROM user_token WHERE provisioned = 1`
+	upsertTokenQuery                = `
 		INSERT INTO user_token (user_id, token, label, last_access, last_origin, expires, provisioned)
 		VALUES (?, ?, ?, ?, ?, ?, ?)
 		ON CONFLICT (user_id, token)
 		DO UPDATE SET label = excluded.label, expires = excluded.expires, provisioned = excluded.provisioned;
 	`
-	updateTokenExpiryQuery       = `UPDATE user_token SET expires = ? WHERE user_id = ? AND token = ?`
-	updateTokenLabelQuery        = `UPDATE user_token SET label = ? WHERE user_id = ? AND token = ?`
-	updateTokenLastAccessQuery   = `UPDATE user_token SET last_access = ?, last_origin = ? WHERE token = ?`
-	deleteTokenQuery             = `DELETE FROM user_token WHERE user_id = ? AND token = ?`
-	deleteAllTokenQuery          = `DELETE FROM user_token WHERE user_id = ?`
-	deleteTokensProvisionedQuery = `DELETE FROM user_token WHERE provisioned = 1`
-	deleteExpiredTokensQuery     = `DELETE FROM user_token WHERE expires > 0 AND expires < ?`
-	deleteExcessTokensQuery      = `
+	updateTokenExpiryQuery      = `UPDATE user_token SET expires = ? WHERE user_id = ? AND token = ?`
+	updateTokenLabelQuery       = `UPDATE user_token SET label = ? WHERE user_id = ? AND token = ?`
+	updateTokenLastAccessQuery  = `UPDATE user_token SET last_access = ?, last_origin = ? WHERE token = ?`
+	deleteTokenQuery            = `DELETE FROM user_token WHERE user_id = ? AND token = ?`
+	deleteProvisionedTokenQuery = `DELETE FROM user_token WHERE token = ?`
+	deleteAllTokenQuery         = `DELETE FROM user_token WHERE user_id = ?`
+	deleteExpiredTokensQuery    = `DELETE FROM user_token WHERE expires > 0 AND expires < ?`
+	deleteExcessTokensQuery     = `
 		DELETE FROM user_token
 		WHERE user_id = ?
 		  AND (user_id, token) NOT IN (
@@ -711,6 +713,25 @@ func (a *Manager) Tokens(userID string) ([]*Token, error) {
 	return tokens, nil
 }
 
+func (a *Manager) allProvisionedTokens() ([]*Token, error) {
+	rows, err := a.db.Query(selectAllProvisionedTokensQuery)
+	if err != nil {
+		return nil, err
+	}
+	defer rows.Close()
+	tokens := make([]*Token, 0)
+	for {
+		token, err := a.readToken(rows)
+		if errors.Is(err, ErrTokenNotFound) {
+			break
+		} else if err != nil {
+			return nil, err
+		}
+		tokens = append(tokens, token)
+	}
+	return tokens, nil
+}
+
 // Token returns a specific token for a user
 func (a *Manager) Token(userID, token string) (*Token, error) {
 	rows, err := a.db.Query(selectTokenQuery, userID, token)
@@ -775,10 +796,16 @@ func (a *Manager) ChangeToken(userID, token string, label *string, expires *time
 
 // RemoveToken deletes the token defined in User.Token
 func (a *Manager) RemoveToken(userID, token string) error {
+	return execTx(a.db, func(tx *sql.Tx) error {
+		return a.removeTokenTx(tx, userID, token)
+	})
+}
+
+func (a *Manager) removeTokenTx(tx *sql.Tx, userID, token string) error {
 	if token == "" {
 		return errNoTokenProvided
 	}
-	if _, err := a.db.Exec(deleteTokenQuery, userID, token); err != nil {
+	if _, err := tx.Exec(deleteTokenQuery, userID, token); err != nil {
 		return err
 	}
 	return nil
@@ -952,13 +979,20 @@ func (a *Manager) writeTokenUpdateQueue() error {
 	log.Tag(tag).Debug("Writing token update queue for %d token(s)", len(tokenQueue))
 	for tokenID, update := range tokenQueue {
 		log.Tag(tag).Trace("Updating token %s with last access time %v", tokenID, update.LastAccess.Unix())
-		if _, err := tx.Exec(updateTokenLastAccessQuery, update.LastAccess.Unix(), update.LastOrigin.String(), tokenID); err != nil {
+		if err := a.updateTokenLastAccessTx(tx, tokenID, update.LastAccess.Unix(), update.LastOrigin.String()); err != nil {
 			return err
 		}
 	}
 	return tx.Commit()
 }
 
+func (a *Manager) updateTokenLastAccessTx(tx *sql.Tx, token string, lastAccess int64, lastOrigin string) error {
+	if _, err := tx.Exec(updateTokenLastAccessQuery, lastAccess, lastOrigin, token); err != nil {
+		return err
+	}
+	return nil
+}
+
 // Authorize returns nil if the given user has access to the given topic using the desired
 // permission. The user param may be nil to signal an anonymous user.
 func (a *Manager) Authorize(user *User, topic string, perm Permission) error {
@@ -1706,7 +1740,7 @@ func (a *Manager) maybeProvisionUsersAndAccess() error {
 	if !a.config.ProvisionEnabled {
 		return nil
 	}
-	users, err := a.Users()
+	existingUsers, err := a.Users()
 	if err != nil {
 		return err
 	}
@@ -1714,92 +1748,132 @@ func (a *Manager) maybeProvisionUsersAndAccess() error {
 		return u.Name
 	})
 	return execTx(a.db, func(tx *sql.Tx) error {
-		// Remove users that are provisioned, but not in the config anymore
-		for _, user := range users {
-			if user.Name == Everyone {
-				continue
-			} else if user.Provisioned && !util.Contains(provisionUsernames, user.Name) {
-				if err := a.removeUserTx(tx, user.Name); err != nil {
-					return fmt.Errorf("failed to remove provisioned user %s: %v", user.Name, err)
-				}
+		if err := a.maybeProvisionUsers(tx, provisionUsernames, existingUsers); err != nil {
+			return fmt.Errorf("failed to provision users: %v", err)
+		}
+		if err := a.maybeProvisionGrants(tx); err != nil {
+			return fmt.Errorf("failed to provision grants: %v", err)
+		}
+		if err := a.maybeProvisionTokens(tx, provisionUsernames); err != nil {
+			return fmt.Errorf("failed to provision tokens: %v", err)
+		}
+		return nil
+	})
+}
+
+// maybeProvisionUsers checks if the users in the config are provisioned, and adds or updates them.
+// It also removes users that are provisioned, but not in the config anymore.
+func (a *Manager) maybeProvisionUsers(tx *sql.Tx, provisionUsernames []string, existingUsers []*User) error {
+	// Remove users that are provisioned, but not in the config anymore
+	for _, user := range existingUsers {
+		if user.Name == Everyone {
+			continue
+		} else if user.Provisioned && !util.Contains(provisionUsernames, user.Name) {
+			if err := a.removeUserTx(tx, user.Name); err != nil {
+				return fmt.Errorf("failed to remove provisioned user %s: %v", user.Name, err)
 			}
 		}
-		// Add or update provisioned users
-		for _, user := range a.config.Users {
-			if user.Name == Everyone {
-				continue
+	}
+	// Add or update provisioned users
+	for _, user := range a.config.Users {
+		if user.Name == Everyone {
+			continue
+		}
+		existingUser, exists := util.Find(existingUsers, func(u *User) bool {
+			return u.Name == user.Name
+		})
+		if !exists {
+			if err := a.addUserTx(tx, user.Name, user.Hash, user.Role, true, true); err != nil && !errors.Is(err, ErrUserExists) {
+				return fmt.Errorf("failed to add provisioned user %s: %v", user.Name, err)
 			}
-			existingUser, exists := util.Find(users, func(u *User) bool {
-				return u.Name == user.Name
-			})
-			if !exists {
-				if err := a.addUserTx(tx, user.Name, user.Hash, user.Role, true, true); err != nil && !errors.Is(err, ErrUserExists) {
-					return fmt.Errorf("failed to add provisioned user %s: %v", user.Name, err)
+		} else {
+			if !existingUser.Provisioned {
+				if err := a.changeProvisionedTx(tx, user.Name, true); err != nil {
+					return fmt.Errorf("failed to change provisioned status for user %s: %v", user.Name, err)
 				}
-			} else {
-				if !existingUser.Provisioned {
-					if err := a.changeProvisionedTx(tx, user.Name, true); err != nil {
-						return fmt.Errorf("failed to change provisioned status for user %s: %v", user.Name, err)
-					}
-				}
-				if existingUser.Hash != user.Hash {
-					if err := a.changePasswordTx(tx, user.Name, user.Hash, true); err != nil {
-						return fmt.Errorf("failed to change password for provisioned user %s: %v", user.Name, err)
-					}
+			}
+			if existingUser.Hash != user.Hash {
+				if err := a.changePasswordTx(tx, user.Name, user.Hash, true); err != nil {
+					return fmt.Errorf("failed to change password for provisioned user %s: %v", user.Name, err)
 				}
-				if existingUser.Role != user.Role {
-					if err := a.changeRoleTx(tx, user.Name, user.Role); err != nil {
-						return fmt.Errorf("failed to change role for provisioned user %s: %v", user.Name, err)
-					}
+			}
+			if existingUser.Role != user.Role {
+				if err := a.changeRoleTx(tx, user.Name, user.Role); err != nil {
+					return fmt.Errorf("failed to change role for provisioned user %s: %v", user.Name, err)
 				}
 			}
 		}
-		// Remove and (re-)add provisioned grants
-		if _, err := tx.Exec(deleteUserAccessProvisionedQuery); err != nil {
-			return err
+	}
+	return nil
+}
+
+// maybyProvisionGrants removes all provisioned grants, and (re-)adds the grants from the config.
+//
+// Unlike users and tokens, grants can be just re-added, because they do not carry any state (such as last
+// access time) or do not have dependent resources (such as grants or tokens).
+func (a *Manager) maybeProvisionGrants(tx *sql.Tx) error {
+	// Remove all provisioned grants
+	if _, err := tx.Exec(deleteUserAccessProvisionedQuery); err != nil {
+		return err
+	}
+	// (Re-)add provisioned grants
+	for username, grants := range a.config.Access {
+		user, exists := util.Find(a.config.Users, func(u *User) bool {
+			return u.Name == username
+		})
+		if !exists && username != Everyone {
+			return fmt.Errorf("user %s is not a provisioned user, refusing to add ACL entry", username)
+		} else if user != nil && user.Role == RoleAdmin {
+			return fmt.Errorf("adding access control entries is not allowed for admin roles for user %s", username)
 		}
-		for username, grants := range a.config.Access {
-			user, exists := util.Find(a.config.Users, func(u *User) bool {
-				return u.Name == username
-			})
-			if !exists && username != Everyone {
-				return fmt.Errorf("user %s is not a provisioned user, refusing to add ACL entry", username)
-			} else if user != nil && user.Role == RoleAdmin {
-				return fmt.Errorf("adding access control entries is not allowed for admin roles for user %s", username)
+		for _, grant := range grants {
+			if err := a.resetAccessTx(tx, username, grant.TopicPattern); err != nil {
+				return fmt.Errorf("failed to reset access for user %s and topic %s: %v", username, grant.TopicPattern, err)
 			}
-			for _, grant := range grants {
-				if err := a.resetAccessTx(tx, username, grant.TopicPattern); err != nil {
-					return fmt.Errorf("failed to reset access for user %s and topic %s: %v", username, grant.TopicPattern, err)
-				}
-				if err := a.allowAccessTx(tx, username, grant.TopicPattern, grant.Permission, true); err != nil {
-					return err
-				}
+			if err := a.allowAccessTx(tx, username, grant.TopicPattern, grant.Permission, true); err != nil {
+				return err
 			}
 		}
-		// Remove and (re-)add provisioned tokens
-		if _, err := tx.Exec(deleteTokensProvisionedQuery); err != nil {
-			return err
+	}
+	return nil
+}
+
+func (a *Manager) maybeProvisionTokens(tx *sql.Tx, provisionUsernames []string) error {
+	// Remove tokens that are provisioned, but not in the config anymore
+	existingTokens, err := a.allProvisionedTokens()
+	if err != nil {
+		return fmt.Errorf("failed to retrieve existing provisioned tokens: %v", err)
+	}
+	var provisionTokens []string
+	for _, userTokens := range a.config.Tokens {
+		for _, token := range userTokens {
+			provisionTokens = append(provisionTokens, token.Value)
 		}
-		for username, tokens := range a.config.Tokens {
-			_, exists := util.Find(a.config.Users, func(u *User) bool {
-				return u.Name == username
-			})
-			if !exists && username != Everyone {
-				return fmt.Errorf("user %s is not a provisioned user, refusing to add tokens", username)
-			}
-			var userID string
-			row := tx.QueryRow(selectUserIDFromUsernameQuery, username)
-			if err := row.Scan(&userID); err != nil {
-				return fmt.Errorf("failed to find provisioned user %s for provisioned tokens", username)
+	}
+	for _, existingToken := range existingTokens {
+		if !slices.Contains(provisionTokens, existingToken.Value) {
+			if _, err := tx.Exec(deleteProvisionedTokenQuery, existingToken.Value); err != nil {
+				return fmt.Errorf("failed to remove provisioned token %s: %v", existingToken.Value, err)
 			}
-			for _, token := range tokens {
-				if _, err = a.createTokenTx(tx, userID, token.Value, token.Label, time.Unix(0, 0), netip.IPv4Unspecified(), true); err != nil {
-					return err
-				}
+		}
+	}
+	// (Re-)add provisioned tokens
+	for username, tokens := range a.config.Tokens {
+		if !slices.Contains(provisionUsernames, username) && username != Everyone {
+			return fmt.Errorf("user %s is not a provisioned user, refusing to add tokens", username)
+		}
+		var userID string
+		row := tx.QueryRow(selectUserIDFromUsernameQuery, username)
+		if err := row.Scan(&userID); err != nil {
+			return fmt.Errorf("failed to find provisioned user %s for provisioned tokens", username)
+		}
+		for _, token := range tokens {
+			if _, err := a.createTokenTx(tx, userID, token.Value, token.Label, time.Unix(0, 0), netip.IPv4Unspecified(), true); err != nil {
+				return err
 			}
 		}
-		return nil
-	})
+	}
+	return nil
 }
 
 // toSQLWildcard converts a wildcard string to a SQL wildcard string. It only allows '*' as wildcards,

+ 23 - 3
user/manager_test.go

@@ -1152,6 +1152,14 @@ func TestManager_WithProvisionedUsers(t *testing.T) {
 	require.Equal(t, "Alerts token", tokens[0].Label)
 	require.True(t, tokens[0].Provisioned)
 
+	// Update the token last access time and origin (so we can check that it is persisted)
+	lastAccessTime := time.Now().Add(time.Hour)
+	lastOrigin := netip.MustParseAddr("1.1.9.9")
+	err = execTx(a.db, func(tx *sql.Tx) error {
+		return a.updateTokenLastAccessTx(tx, tokens[0].Value, lastAccessTime.Unix(), lastOrigin.String())
+	})
+	require.Nil(t, err)
+
 	// Re-open the DB (second app start)
 	require.Nil(t, a.db.Close())
 	conf.Users = []*User{
@@ -1165,7 +1173,8 @@ func TestManager_WithProvisionedUsers(t *testing.T) {
 	}
 	conf.Tokens = map[string][]*Token{
 		"philuser": {
-			{Value: "tk_op56p8lz5bf3cxkz9je99v9oc3XXX", Label: "Alerts token updated"},
+			{Value: "tk_op56p8lz5bf3cxkz9je99v9oc37lo", Label: "Alerts token updated"},
+			{Value: "tk_u48wqendnkx9er21pqqcadlytbutx", Label: "Another token"},
 		},
 	}
 	a, err = NewManager(conf)
@@ -1191,10 +1200,14 @@ func TestManager_WithProvisionedUsers(t *testing.T) {
 
 	tokens, err = a.Tokens(provisionedUserID)
 	require.Nil(t, err)
-	require.Equal(t, 1, len(tokens))
-	require.Equal(t, "tk_op56p8lz5bf3cxkz9je99v9oc3XXX", tokens[0].Value)
+	require.Equal(t, 2, len(tokens))
+	require.Equal(t, "tk_op56p8lz5bf3cxkz9je99v9oc37lo", tokens[0].Value)
 	require.Equal(t, "Alerts token updated", tokens[0].Label)
+	require.Equal(t, lastAccessTime.Unix(), tokens[0].LastAccess.Unix())
+	require.Equal(t, lastOrigin, tokens[0].LastOrigin)
 	require.True(t, tokens[0].Provisioned)
+	require.Equal(t, "tk_u48wqendnkx9er21pqqcadlytbutx", tokens[1].Value)
+	require.Equal(t, "Another token", tokens[1].Label)
 
 	// Re-open the DB again (third app start)
 	require.Nil(t, a.db.Close())
@@ -1220,6 +1233,13 @@ func TestManager_WithProvisionedUsers(t *testing.T) {
 	tokens, err = a.Tokens(provisionedUserID)
 	require.Nil(t, err)
 	require.Equal(t, 0, len(tokens))
+
+	var count int
+	a.db.QueryRow("SELECT COUNT(*) FROM user WHERE provisioned = 1").Scan(&count)
+	require.Equal(t, 0, count)
+	a.db.QueryRow("SELECT COUNT(*) FROM user_grant WHERE provisioned = 1").Scan(&count)
+	require.Equal(t, 0, count)
+	a.db.QueryRow("SELECT COUNT(*) FROM user_token WHERE provisioned = 1").Scan(&count)
 }
 
 func TestManager_UpdateNonProvisionedUsersToProvisionedUsers(t *testing.T) {