binwiederhier 6 месяцев назад
Родитель
Сommit
23ec7702fc
10 измененных файлов с 263 добавлено и 88 удалено
  1. 49 7
      cmd/serve.go
  2. 23 6
      cmd/token.go
  3. 1 0
      server/config.go
  4. 1 0
      server/server.go
  5. 2 0
      server/server.yml
  6. 1 1
      server/server_account.go
  7. 1 1
      server/server_account_test.go
  8. 117 48
      user/manager.go
  9. 49 20
      user/manager_test.go
  10. 19 5
      user/types.go

+ 49 - 7
cmd/serve.go

@@ -50,6 +50,7 @@ var flagsServe = append(
 	altsrc.NewStringFlag(&cli.StringFlag{Name: "auth-default-access", Aliases: []string{"auth_default_access", "p"}, EnvVars: []string{"NTFY_AUTH_DEFAULT_ACCESS"}, Value: "read-write", Usage: "default permissions if no matching entries in the auth database are found"}),
 	altsrc.NewStringSliceFlag(&cli.StringSliceFlag{Name: "auth-users", Aliases: []string{"auth_users"}, EnvVars: []string{"NTFY_AUTH_USERS"}, Usage: "pre-provisioned declarative users"}),
 	altsrc.NewStringSliceFlag(&cli.StringSliceFlag{Name: "auth-access", Aliases: []string{"auth_access"}, EnvVars: []string{"NTFY_AUTH_ACCESS"}, Usage: "pre-provisioned declarative access control entries"}),
+	altsrc.NewStringSliceFlag(&cli.StringSliceFlag{Name: "auth-tokens", Aliases: []string{"auth_tokens"}, EnvVars: []string{"NTFY_AUTH_TOKENS"}, Usage: "pre-provisioned declarative access tokens"}),
 	altsrc.NewStringFlag(&cli.StringFlag{Name: "attachment-cache-dir", Aliases: []string{"attachment_cache_dir"}, EnvVars: []string{"NTFY_ATTACHMENT_CACHE_DIR"}, Usage: "cache directory for attached files"}),
 	altsrc.NewStringFlag(&cli.StringFlag{Name: "attachment-total-size-limit", Aliases: []string{"attachment_total_size_limit", "A"}, EnvVars: []string{"NTFY_ATTACHMENT_TOTAL_SIZE_LIMIT"}, Value: util.FormatSize(server.DefaultAttachmentTotalSizeLimit), Usage: "limit of the on-disk attachment cache"}),
 	altsrc.NewStringFlag(&cli.StringFlag{Name: "attachment-file-size-limit", Aliases: []string{"attachment_file_size_limit", "Y"}, EnvVars: []string{"NTFY_ATTACHMENT_FILE_SIZE_LIMIT"}, Value: util.FormatSize(server.DefaultAttachmentFileSizeLimit), Usage: "per-file attachment size limit (e.g. 300k, 2M, 100M)"}),
@@ -158,6 +159,7 @@ func execServe(c *cli.Context) error {
 	authDefaultAccess := c.String("auth-default-access")
 	authUsersRaw := c.StringSlice("auth-users")
 	authAccessRaw := c.StringSlice("auth-access")
+	authTokensRaw := c.StringSlice("auth-tokens")
 	attachmentCacheDir := c.String("attachment-cache-dir")
 	attachmentTotalSizeLimitStr := c.String("attachment-total-size-limit")
 	attachmentFileSizeLimitStr := c.String("attachment-file-size-limit")
@@ -361,6 +363,10 @@ func execServe(c *cli.Context) error {
 	if err != nil {
 		return err
 	}
+	authTokens, err := parseTokens(authUsers, authTokensRaw)
+	if err != nil {
+		return err
+	}
 
 	// Special case: Unset default
 	if listenHTTP == "-" {
@@ -418,6 +424,7 @@ func execServe(c *cli.Context) error {
 	conf.AuthDefault = authDefault
 	conf.AuthUsers = authUsers
 	conf.AuthAccess = authAccess
+	conf.AuthTokens = authTokens
 	conf.AttachmentCacheDir = attachmentCacheDir
 	conf.AttachmentTotalSizeLimit = attachmentTotalSizeLimit
 	conf.AttachmentFileSizeLimit = attachmentFileSizeLimit
@@ -532,7 +539,7 @@ func parseIPHostPrefix(host string) (prefixes []netip.Prefix, err error) {
 }
 
 func parseUsers(usersRaw []string) ([]*user.User, error) {
-	provisionUsers := make([]*user.User, 0)
+	users := make([]*user.User, 0)
 	for _, userLine := range usersRaw {
 		parts := strings.Split(userLine, ":")
 		if len(parts) != 3 {
@@ -548,19 +555,19 @@ func parseUsers(usersRaw []string) ([]*user.User, error) {
 		} else if !user.AllowedRole(role) {
 			return nil, fmt.Errorf("invalid auth-users: %s, role %s is not allowed, allowed roles are 'admin' or 'user'", userLine, role)
 		}
-		provisionUsers = append(provisionUsers, &user.User{
+		users = append(users, &user.User{
 			Name:        username,
 			Hash:        passwordHash,
 			Role:        role,
 			Provisioned: true,
 		})
 	}
-	return provisionUsers, nil
+	return users, nil
 }
 
-func parseAccess(provisionUsers []*user.User, provisionAccessRaw []string) (map[string][]*user.Grant, error) {
+func parseAccess(users []*user.User, accessRaw []string) (map[string][]*user.Grant, error) {
 	access := make(map[string][]*user.Grant)
-	for _, accessLine := range provisionAccessRaw {
+	for _, accessLine := range accessRaw {
 		parts := strings.Split(accessLine, ":")
 		if len(parts) != 3 {
 			return nil, fmt.Errorf("invalid auth-access: %s, expected format: 'user:topic:permission'", accessLine)
@@ -569,7 +576,7 @@ func parseAccess(provisionUsers []*user.User, provisionAccessRaw []string) (map[
 		if username == userEveryone {
 			username = user.Everyone
 		}
-		provisionUser, exists := util.Find(provisionUsers, func(u *user.User) bool {
+		u, exists := util.Find(users, func(u *user.User) bool {
 			return u.Name == username
 		})
 		if username != user.Everyone {
@@ -577,7 +584,7 @@ func parseAccess(provisionUsers []*user.User, provisionAccessRaw []string) (map[
 				return nil, fmt.Errorf("invalid auth-access: %s, user %s is not provisioned", accessLine, username)
 			} else if !user.AllowedUsername(username) {
 				return nil, fmt.Errorf("invalid auth-access: %s, username %s invalid", accessLine, username)
-			} else if provisionUser.Role != user.RoleUser {
+			} else if u.Role != user.RoleUser {
 				return nil, fmt.Errorf("invalid auth-access: %s, user %s is not a regular user, only regular users can have ACL entries", accessLine, username)
 			}
 		}
@@ -601,6 +608,41 @@ func parseAccess(provisionUsers []*user.User, provisionAccessRaw []string) (map[
 	return access, nil
 }
 
+func parseTokens(users []*user.User, tokensRaw []string) (map[string][]*user.Token, error) {
+	tokens := make(map[string][]*user.Token)
+	for _, tokenLine := range tokensRaw {
+		parts := strings.Split(tokenLine, ":")
+		if len(parts) < 2 || len(parts) > 3 {
+			return nil, fmt.Errorf("invalid auth-tokens: %s, expected format: 'user:token[:label]'", tokenLine)
+		}
+		username := strings.TrimSpace(parts[0])
+		_, exists := util.Find(users, func(u *user.User) bool {
+			return u.Name == username
+		})
+		if !exists {
+			return nil, fmt.Errorf("invalid auth-tokens: %s, user %s is not provisioned", tokenLine, username)
+		} else if !user.AllowedUsername(username) {
+			return nil, fmt.Errorf("invalid auth-tokens: %s, username %s invalid", tokenLine, username)
+		}
+		token := strings.TrimSpace(parts[1])
+		if !user.AllowedToken(token) {
+			return nil, fmt.Errorf("invalid auth-tokens: %s, token %s invalid, use 'ntfy token generate' to generate a random token", tokenLine, token)
+		}
+		var label string
+		if len(parts) > 2 {
+			label = parts[2]
+		}
+		if _, exists := tokens[username]; !exists {
+			tokens[username] = make([]*user.Token, 0)
+		}
+		tokens[username] = append(tokens[username], &user.Token{
+			Value: token,
+			Label: label,
+		})
+	}
+	return tokens, nil
+}
+
 func reloadLogLevel(inputSource altsrc.InputSourceContext) error {
 	newLevelStr, err := inputSource.String("log-level")
 	if err != nil {

+ 23 - 6
cmd/token.go

@@ -72,6 +72,15 @@ Example:
 This is a server-only command. It directly reads from user.db as defined in the server config
 file server.yml. The command only works if 'auth-file' is properly defined.`,
 		},
+		{
+			Name:   "generate",
+			Usage:  "Generates a random token",
+			Action: execTokenGenerate,
+			Description: `Randomly generate a token to be used in provisioned tokens.
+
+This command only generates the token value, but does not persist it anywhere.
+The output can be used in the 'auth-tokens' config option.`,
+		},
 	},
 	Description: `Manage access tokens for individual users.
 
@@ -112,12 +121,12 @@ func execTokenAdd(c *cli.Context) error {
 		return err
 	}
 	u, err := manager.User(username)
-	if err == user.ErrUserNotFound {
+	if errors.Is(err, user.ErrUserNotFound) {
 		return fmt.Errorf("user %s does not exist", username)
 	} else if err != nil {
 		return err
 	}
-	token, err := manager.CreateToken(u.ID, label, expires, netip.IPv4Unspecified())
+	token, err := manager.CreateToken(u.ID, label, expires, netip.IPv4Unspecified(), false)
 	if err != nil {
 		return err
 	}
@@ -141,7 +150,7 @@ func execTokenDel(c *cli.Context) error {
 		return err
 	}
 	u, err := manager.User(username)
-	if err == user.ErrUserNotFound {
+	if errors.Is(err, user.ErrUserNotFound) {
 		return fmt.Errorf("user %s does not exist", username)
 	} else if err != nil {
 		return err
@@ -165,7 +174,7 @@ func execTokenList(c *cli.Context) error {
 	var users []*user.User
 	if username != "" {
 		u, err := manager.User(username)
-		if err == user.ErrUserNotFound {
+		if errors.Is(err, user.ErrUserNotFound) {
 			return fmt.Errorf("user %s does not exist", username)
 		} else if err != nil {
 			return err
@@ -191,7 +200,7 @@ func execTokenList(c *cli.Context) error {
 		usersWithTokens++
 		fmt.Fprintf(c.App.ErrWriter, "user %s\n", u.Name)
 		for _, t := range tokens {
-			var label, expires string
+			var label, expires, provisioned string
 			if t.Label != "" {
 				label = fmt.Sprintf(" (%s)", t.Label)
 			}
@@ -200,7 +209,10 @@ func execTokenList(c *cli.Context) error {
 			} else {
 				expires = fmt.Sprintf("expires %s", t.Expires.Format(time.RFC822))
 			}
-			fmt.Fprintf(c.App.ErrWriter, "- %s%s, %s, accessed from %s at %s\n", t.Value, label, expires, t.LastOrigin.String(), t.LastAccess.Format(time.RFC822))
+			if t.Provisioned {
+				provisioned = " (server config)"
+			}
+			fmt.Fprintf(c.App.ErrWriter, "- %s%s, %s, accessed from %s at %s%s\n", t.Value, label, expires, t.LastOrigin.String(), t.LastAccess.Format(time.RFC822), provisioned)
 		}
 	}
 	if usersWithTokens == 0 {
@@ -208,3 +220,8 @@ func execTokenList(c *cli.Context) error {
 	}
 	return nil
 }
+
+func execTokenGenerate(c *cli.Context) error {
+	fmt.Println(user.GenerateToken())
+	return nil
+}

+ 1 - 0
server/config.go

@@ -97,6 +97,7 @@ type Config struct {
 	AuthDefault                          user.Permission
 	AuthUsers                            []*user.User
 	AuthAccess                           map[string][]*user.Grant
+	AuthTokens                           map[string][]*user.Token
 	AuthBcryptCost                       int
 	AuthStatsQueueWriterInterval         time.Duration
 	AttachmentCacheDir                   string

+ 1 - 0
server/server.go

@@ -203,6 +203,7 @@ func New(conf *Config) (*Server, error) {
 			ProvisionEnabled:    true, // Enable provisioning of users and access
 			Users:               conf.AuthUsers,
 			Access:              conf.AuthAccess,
+			Tokens:              conf.AuthTokens,
 			BcryptCost:          conf.AuthBcryptCost,
 			QueueWriterInterval: conf.AuthStatsQueueWriterInterval,
 		}

+ 2 - 0
server/server.yml

@@ -86,6 +86,8 @@
 #   Each entry is in the format "<username>:<password-hash>:<role>", e.g. "phil:$2a$10$YLiO8U21sX1uhZamTLJXHuxgVC0Z/GKISibrKCLohPgtG7yIxSk4C:user"
 # - auth-access is a list of access control entries that are automatically created when the server starts.
 #   Each entry is in the format "<username>:<topic-pattern>:<access>", e.g. "phil:mytopic:rw" or "phil:phil-*:rw".
+# - auth-tokens is a list of access tokens that are automatically created when the server starts.
+#   Each entry is in the format "<username>:<token>[:<label>]", e.g. "phil:tk_1234567890abcdef1234567890abcdef1234567890abcdef1234567890abcdef:My token".
 #
 # Debian/RPM package users:
 #   Use /var/lib/ntfy/user.db as user database to avoid permission issues. The package

+ 1 - 1
server/server_account.go

@@ -234,7 +234,7 @@ func (s *Server) handleAccountTokenCreate(w http.ResponseWriter, r *http.Request
 			"token_expires": expires,
 		}).
 		Debug("Creating token for user %s", u.Name)
-	token, err := s.userManager.CreateToken(u.ID, label, expires, v.IP())
+	token, err := s.userManager.CreateToken(u.ID, label, expires, v.IP(), false)
 	if err != nil {
 		return err
 	}

+ 1 - 1
server/server_account_test.go

@@ -176,7 +176,7 @@ func TestAccount_ChangeSettings(t *testing.T) {
 
 	require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser, false))
 	u, _ := s.userManager.User("phil")
-	token, _ := s.userManager.CreateToken(u.ID, "", time.Unix(0, 0), netip.IPv4Unspecified())
+	token, _ := s.userManager.CreateToken(u.ID, "", time.Unix(0, 0), netip.IPv4Unspecified(), false)
 
 	rr := request(t, s, "PATCH", "/v1/account/settings", `{"notification": {"sound": "juntos"},"ignored": true}`, map[string]string{
 		"Authorization": util.BasicAuth("phil", "phil"),

+ 117 - 48
user/manager.go

@@ -111,9 +111,11 @@ const (
 			last_access INT NOT NULL,
 			last_origin TEXT NOT NULL,
 			expires INT NOT NULL,
+			provisioned INT NOT NULL,
 			PRIMARY KEY (user_id, token),
 			FOREIGN KEY (user_id) REFERENCES user (id) ON DELETE CASCADE
 		);
+		CREATE UNIQUE INDEX idx_user_token ON user_token (token);
 		CREATE TABLE IF NOT EXISTS user_phone (
 			user_id TEXT NOT NULL,
 			phone_number TEXT NOT NULL,
@@ -181,16 +183,17 @@ const (
 				ELSE 2
 			END, user
 	`
-	selectUserCountQuery         = `SELECT COUNT(*) FROM user`
-	updateUserPassQuery          = `UPDATE user SET pass = ? WHERE user = ?`
-	updateUserRoleQuery          = `UPDATE user SET role = ? WHERE user = ?`
-	updateUserProvisionedQuery   = `UPDATE user SET provisioned = ? WHERE user = ?`
-	updateUserPrefsQuery         = `UPDATE user SET prefs = ? WHERE id = ?`
-	updateUserStatsQuery         = `UPDATE user SET stats_messages = ?, stats_emails = ?, stats_calls = ? WHERE id = ?`
-	updateUserStatsResetAllQuery = `UPDATE user SET stats_messages = 0, stats_emails = 0, stats_calls = 0`
-	updateUserDeletedQuery       = `UPDATE user SET deleted = ? WHERE id = ?`
-	deleteUsersMarkedQuery       = `DELETE FROM user WHERE deleted < ?`
-	deleteUserQuery              = `DELETE FROM user WHERE user = ?`
+	selectUserCountQuery          = `SELECT COUNT(*) FROM user`
+	selectUserIDFromUsernameQuery = `SELECT id FROM user WHERE user = ?`
+	updateUserPassQuery           = `UPDATE user SET pass = ? WHERE user = ?`
+	updateUserRoleQuery           = `UPDATE user SET role = ? WHERE user = ?`
+	updateUserProvisionedQuery    = `UPDATE user SET provisioned = ? WHERE user = ?`
+	updateUserPrefsQuery          = `UPDATE user SET prefs = ? WHERE id = ?`
+	updateUserStatsQuery          = `UPDATE user SET stats_messages = ?, stats_emails = ?, stats_calls = ? WHERE id = ?`
+	updateUserStatsResetAllQuery  = `UPDATE user SET stats_messages = 0, stats_emails = 0, stats_calls = 0`
+	updateUserDeletedQuery        = `UPDATE user SET deleted = ? WHERE id = ?`
+	deleteUsersMarkedQuery        = `DELETE FROM user WHERE deleted < ?`
+	deleteUserQuery               = `DELETE FROM user WHERE user = ?`
 
 	upsertUserAccessQuery = `
 		INSERT INTO user_access (user_id, topic, read, write, owner_user_id, provisioned)
@@ -220,7 +223,7 @@ const (
 	selectUserReservationsCountQuery = `
 		SELECT COUNT(*)
 		FROM user_access
-		WHERE user_id = owner_user_id 
+		WHERE user_id = owner_user_id
 		  AND owner_user_id = (SELECT id FROM user WHERE user = ?)
 	`
 	selectUserReservationsOwnerQuery = `
@@ -255,17 +258,23 @@ const (
 	   	  AND topic = ?
   	`
 
-	selectTokenCountQuery      = `SELECT COUNT(*) FROM user_token WHERE user_id = ?`
-	selectTokensQuery          = `SELECT token, label, last_access, last_origin, expires FROM user_token WHERE user_id = ?`
-	selectTokenQuery           = `SELECT token, label, last_access, last_origin, expires FROM user_token WHERE user_id = ? AND token = ?`
-	insertTokenQuery           = `INSERT INTO user_token (user_id, token, label, last_access, last_origin, expires) VALUES (?, ?, ?, ?, ?, ?)`
-	updateTokenExpiryQuery     = `UPDATE user_token SET expires = ? WHERE user_id = ? AND token = ?`
-	updateTokenLabelQuery      = `UPDATE user_token SET label = ? WHERE user_id = ? AND token = ?`
-	updateTokenLastAccessQuery = `UPDATE user_token SET last_access = ?, last_origin = ? WHERE token = ?`
-	deleteTokenQuery           = `DELETE FROM user_token WHERE user_id = ? AND token = ?`
-	deleteAllTokenQuery        = `DELETE FROM user_token WHERE user_id = ?`
-	deleteExpiredTokensQuery   = `DELETE FROM user_token WHERE expires > 0 AND expires < ?`
-	deleteExcessTokensQuery    = `
+	selectTokenCountQuery = `SELECT COUNT(*) FROM user_token WHERE user_id = ?`
+	selectTokensQuery     = `SELECT token, label, last_access, last_origin, expires, provisioned FROM user_token WHERE user_id = ?`
+	selectTokenQuery      = `SELECT token, label, last_access, last_origin, expires, provisioned FROM user_token WHERE user_id = ? AND token = ?`
+	upsertTokenQuery      = `
+		INSERT INTO user_token (user_id, token, label, last_access, last_origin, expires, provisioned)
+		VALUES (?, ?, ?, ?, ?, ?, ?)
+		ON CONFLICT (user_id, token)
+		DO UPDATE SET label = excluded.label, expires = excluded.expires, provisioned = excluded.provisioned;
+	`
+	updateTokenExpiryQuery       = `UPDATE user_token SET expires = ? WHERE user_id = ? AND token = ?`
+	updateTokenLabelQuery        = `UPDATE user_token SET label = ? WHERE user_id = ? AND token = ?`
+	updateTokenLastAccessQuery   = `UPDATE user_token SET last_access = ?, last_origin = ? WHERE token = ?`
+	deleteTokenQuery             = `DELETE FROM user_token WHERE user_id = ? AND token = ?`
+	deleteAllTokenQuery          = `DELETE FROM user_token WHERE user_id = ?`
+	deleteTokensProvisionedQuery = `DELETE FROM user_token WHERE provisioned = 1`
+	deleteExpiredTokensQuery     = `DELETE FROM user_token WHERE expires > 0 AND expires < ?`
+	deleteExcessTokensQuery      = `
 		DELETE FROM user_token
 		WHERE user_id = ?
 		  AND (user_id, token) NOT IN (
@@ -470,7 +479,7 @@ const (
 		    role,
 		    prefs,
 		    sync_topic,
-		    0,
+		    0, -- provisioned
 		    stats_messages,
 		    stats_emails,
 		    stats_calls,
@@ -480,7 +489,8 @@ const (
 		    stripe_subscription_interval,
 		    stripe_subscription_paid_until,
 		    stripe_subscription_cancel_at,
-		    created, deleted
+		    created,
+		    deleted
 		FROM user_old;
 		DROP TABLE user_old;
 
@@ -500,10 +510,27 @@ const (
 		INSERT INTO user_access SELECT *, 0 FROM user_access_old;
 		DROP TABLE user_access_old;
 
+		-- Alter user_token table: Add provisioned column
+		ALTER TABLE user_token RENAME TO user_token_old;
+		CREATE TABLE IF NOT EXISTS user_token (
+			user_id TEXT NOT NULL,
+			token TEXT NOT NULL,
+			label TEXT NOT NULL,
+			last_access INT NOT NULL,
+			last_origin TEXT NOT NULL,
+			expires INT NOT NULL,
+			provisioned INT NOT NULL,
+			PRIMARY KEY (user_id, token),
+			FOREIGN KEY (user_id) REFERENCES user (id) ON DELETE CASCADE
+		);
+		INSERT INTO user_token SELECT *, 0 FROM user_token_old;
+		DROP TABLE user_token_old;
+
 		-- Recreate indices
 		CREATE UNIQUE INDEX idx_user ON user (user);
 		CREATE UNIQUE INDEX idx_user_stripe_customer_id ON user (stripe_customer_id);
 		CREATE UNIQUE INDEX idx_user_stripe_subscription_id ON user (stripe_subscription_id);
+		CREATE UNIQUE INDEX idx_user_token ON user_token (token);
 
 		-- Re-enable foreign keys
 		PRAGMA foreign_keys=on;
@@ -537,7 +564,8 @@ type Config struct {
 	DefaultAccess       Permission          // Default permission if no ACL matches
 	ProvisionEnabled    bool                // Enable auto-provisioning of users and access grants, disabled for "ntfy user" commands
 	Users               []*User             // Predefined users to create on startup
-	Access              map[string][]*Grant // Predefined access grants to create on startup
+	Access              map[string][]*Grant // Predefined access grants to create on startup (username -> []*Grant)
+	Tokens              map[string][]*Token // Predefined users to create on startup (username -> []*Token)
 	QueueWriterInterval time.Duration       // Interval for the async queue writer to flush stats and token updates to the database
 	BcryptCost          int                 // Cost of generated passwords; lowering makes testing faster
 }
@@ -623,15 +651,15 @@ func (a *Manager) AuthenticateToken(token string) (*User, error) {
 // CreateToken generates a random token for the given user and returns it. The token expires
 // after a fixed duration unless ChangeToken is called. This function also prunes tokens for the
 // given user, if there are too many of them.
-func (a *Manager) CreateToken(userID, label string, expires time.Time, origin netip.Addr) (*Token, error) {
-	token := util.RandomLowerStringPrefix(tokenPrefix, tokenLength) // Lowercase only to support "<topic>+<token>@<domain>" email addresses
-	tx, err := a.db.Begin()
-	if err != nil {
-		return nil, err
-	}
-	defer tx.Rollback()
+func (a *Manager) CreateToken(userID, label string, expires time.Time, origin netip.Addr, provisioned bool) (*Token, error) {
+	return queryTx(a.db, func(tx *sql.Tx) (*Token, error) {
+		return a.createTokenTx(tx, userID, GenerateToken(), label, expires, origin, provisioned)
+	})
+}
+
+func (a *Manager) createTokenTx(tx *sql.Tx, userID, token, label string, expires time.Time, origin netip.Addr, provisioned bool) (*Token, error) {
 	access := time.Now()
-	if _, err := tx.Exec(insertTokenQuery, userID, token, label, access.Unix(), origin.String(), expires.Unix()); err != nil {
+	if _, err := tx.Exec(upsertTokenQuery, userID, token, label, access.Unix(), origin.String(), expires.Unix(), provisioned); err != nil {
 		return nil, err
 	}
 	rows, err := tx.Query(selectTokenCountQuery, userID)
@@ -653,15 +681,13 @@ func (a *Manager) CreateToken(userID, label string, expires time.Time, origin ne
 			return nil, err
 		}
 	}
-	if err := tx.Commit(); err != nil {
-		return nil, err
-	}
 	return &Token{
-		Value:      token,
-		Label:      label,
-		LastAccess: access,
-		LastOrigin: origin,
-		Expires:    expires,
+		Value:       token,
+		Label:       label,
+		LastAccess:  access,
+		LastOrigin:  origin,
+		Expires:     expires,
+		Provisioned: provisioned,
 	}, nil
 }
 
@@ -698,10 +724,11 @@ func (a *Manager) Token(userID, token string) (*Token, error) {
 func (a *Manager) readToken(rows *sql.Rows) (*Token, error) {
 	var token, label, lastOrigin string
 	var lastAccess, expires int64
+	var provisioned bool
 	if !rows.Next() {
 		return nil, ErrTokenNotFound
 	}
-	if err := rows.Scan(&token, &label, &lastAccess, &lastOrigin, &expires); err != nil {
+	if err := rows.Scan(&token, &label, &lastAccess, &lastOrigin, &expires, &provisioned); err != nil {
 		return nil, err
 	} else if err := rows.Err(); err != nil {
 		return nil, err
@@ -711,11 +738,12 @@ func (a *Manager) readToken(rows *sql.Rows) (*Token, error) {
 		lastOriginIP = netip.IPv4Unspecified()
 	}
 	return &Token{
-		Value:      token,
-		Label:      label,
-		LastAccess: time.Unix(lastAccess, 0),
-		LastOrigin: lastOriginIP,
-		Expires:    time.Unix(expires, 0),
+		Value:       token,
+		Label:       label,
+		LastAccess:  time.Unix(lastAccess, 0),
+		LastOrigin:  lastOriginIP,
+		Expires:     time.Unix(expires, 0),
+		Provisioned: provisioned,
 	}, nil
 }
 
@@ -774,7 +802,7 @@ func (a *Manager) PhoneNumbers(userID string) ([]string, error) {
 	phoneNumbers := make([]string, 0)
 	for {
 		phoneNumber, err := a.readPhoneNumber(rows)
-		if err == ErrPhoneNumberNotFound {
+		if errors.Is(err, ErrPhoneNumberNotFound) {
 			break
 		} else if err != nil {
 			return nil, err
@@ -1757,6 +1785,28 @@ func (a *Manager) maybeProvisionUsersAndAccess() error {
 				}
 			}
 		}
+		// Remove and (re-)add provisioned tokens
+		if _, err := tx.Exec(deleteTokensProvisionedQuery); err != nil {
+			return err
+		}
+		for username, tokens := range a.config.Tokens {
+			_, exists := util.Find(a.config.Users, func(u *User) bool {
+				return u.Name == username
+			})
+			if !exists && username != Everyone {
+				return fmt.Errorf("user %s is not a provisioned user, refusing to add tokens", username)
+			}
+			var userID string
+			row := tx.QueryRow(selectUserIDFromUsernameQuery, username)
+			if err := row.Scan(&userID); err != nil {
+				return fmt.Errorf("failed to find provisioned user %s for provisioned tokens", username)
+			}
+			for _, token := range tokens {
+				if _, err = a.createTokenTx(tx, userID, token.Value, token.Label, time.Unix(0, 0), netip.IPv4Unspecified(), true); err != nil {
+					return err
+				}
+			}
+		}
 		return nil
 	})
 }
@@ -1974,3 +2024,22 @@ func execTx(db *sql.DB, f func(tx *sql.Tx) error) error {
 	}
 	return tx.Commit()
 }
+
+// queryTx executes a function in a transaction and returns the result. If the function
+// returns an error, the transaction is rolled back.
+func queryTx[T any](db *sql.DB, f func(tx *sql.Tx) (T, error)) (T, error) {
+	tx, err := db.Begin()
+	if err != nil {
+		var zero T
+		return zero, err
+	}
+	defer tx.Rollback()
+	t, err := f(tx)
+	if err != nil {
+		return t, err
+	}
+	if err := tx.Commit(); err != nil {
+		return t, err
+	}
+	return t, nil
+}

+ 49 - 20
user/manager_test.go

@@ -194,7 +194,7 @@ func TestManager_MarkUserRemoved_RemoveDeletedUsers(t *testing.T) {
 	require.Nil(t, err)
 	require.False(t, u.Deleted)
 
-	token, err := a.CreateToken(u.ID, "", time.Now().Add(time.Hour), netip.IPv4Unspecified())
+	token, err := a.CreateToken(u.ID, "", time.Now().Add(time.Hour), netip.IPv4Unspecified(), false)
 	require.Nil(t, err)
 
 	u, err = a.Authenticate("user", "pass")
@@ -241,7 +241,7 @@ func TestManager_CreateToken_Only_Lower(t *testing.T) {
 	u, err := a.User("user")
 	require.Nil(t, err)
 
-	token, err := a.CreateToken(u.ID, "", time.Now().Add(time.Hour), netip.IPv4Unspecified())
+	token, err := a.CreateToken(u.ID, "", time.Now().Add(time.Hour), netip.IPv4Unspecified(), false)
 	require.Nil(t, err)
 	require.Equal(t, token.Value, strings.ToLower(token.Value))
 }
@@ -523,7 +523,7 @@ func TestManager_Token_Valid(t *testing.T) {
 	require.Nil(t, err)
 
 	// Create token for user
-	token, err := a.CreateToken(u.ID, "some label", time.Now().Add(72*time.Hour), netip.IPv4Unspecified())
+	token, err := a.CreateToken(u.ID, "some label", time.Now().Add(72*time.Hour), netip.IPv4Unspecified(), false)
 	require.Nil(t, err)
 	require.NotEmpty(t, token.Value)
 	require.Equal(t, "some label", token.Label)
@@ -586,12 +586,12 @@ func TestManager_Token_Expire(t *testing.T) {
 	require.Nil(t, err)
 
 	// Create tokens for user
-	token1, err := a.CreateToken(u.ID, "", time.Now().Add(72*time.Hour), netip.IPv4Unspecified())
+	token1, err := a.CreateToken(u.ID, "", time.Now().Add(72*time.Hour), netip.IPv4Unspecified(), false)
 	require.Nil(t, err)
 	require.NotEmpty(t, token1.Value)
 	require.True(t, time.Now().Add(71*time.Hour).Unix() < token1.Expires.Unix())
 
-	token2, err := a.CreateToken(u.ID, "", time.Now().Add(72*time.Hour), netip.IPv4Unspecified())
+	token2, err := a.CreateToken(u.ID, "", time.Now().Add(72*time.Hour), netip.IPv4Unspecified(), false)
 	require.Nil(t, err)
 	require.NotEmpty(t, token2.Value)
 	require.NotEqual(t, token1.Value, token2.Value)
@@ -638,7 +638,7 @@ func TestManager_Token_Extend(t *testing.T) {
 	require.Equal(t, errNoTokenProvided, err)
 
 	// Create token for user
-	token, err := a.CreateToken(u.ID, "", time.Now().Add(72*time.Hour), netip.IPv4Unspecified())
+	token, err := a.CreateToken(u.ID, "", time.Now().Add(72*time.Hour), netip.IPv4Unspecified(), false)
 	require.Nil(t, err)
 	require.NotEmpty(t, token.Value)
 
@@ -668,12 +668,12 @@ func TestManager_Token_MaxCount_AutoDelete(t *testing.T) {
 
 	// Create 2 tokens for phil
 	philTokens := make([]string, 0)
-	token, err := a.CreateToken(phil.ID, "", time.Now().Add(72*time.Hour), netip.IPv4Unspecified())
+	token, err := a.CreateToken(phil.ID, "", time.Now().Add(72*time.Hour), netip.IPv4Unspecified(), false)
 	require.Nil(t, err)
 	require.NotEmpty(t, token.Value)
 	philTokens = append(philTokens, token.Value)
 
-	token, err = a.CreateToken(phil.ID, "", time.Unix(0, 0), netip.IPv4Unspecified())
+	token, err = a.CreateToken(phil.ID, "", time.Unix(0, 0), netip.IPv4Unspecified(), false)
 	require.Nil(t, err)
 	require.NotEmpty(t, token.Value)
 	philTokens = append(philTokens, token.Value)
@@ -682,7 +682,7 @@ func TestManager_Token_MaxCount_AutoDelete(t *testing.T) {
 	baseTime := time.Now().Add(24 * time.Hour)
 	benTokens := make([]string, 0)
 	for i := 0; i < 62; i++ { //
-		token, err := a.CreateToken(ben.ID, "", time.Now().Add(72*time.Hour), netip.IPv4Unspecified())
+		token, err := a.CreateToken(ben.ID, "", time.Now().Add(72*time.Hour), netip.IPv4Unspecified(), false)
 		require.Nil(t, err)
 		require.NotEmpty(t, token.Value)
 		benTokens = append(benTokens, token.Value)
@@ -795,7 +795,7 @@ func TestManager_EnqueueTokenUpdate(t *testing.T) {
 	u, err := a.User("ben")
 	require.Nil(t, err)
 
-	token, err := a.CreateToken(u.ID, "", time.Now().Add(time.Hour), netip.IPv4Unspecified())
+	token, err := a.CreateToken(u.ID, "", time.Now().Add(time.Hour), netip.IPv4Unspecified(), false)
 	require.Nil(t, err)
 
 	// Queue token update
@@ -1112,6 +1112,11 @@ func TestManager_WithProvisionedUsers(t *testing.T) {
 				{TopicPattern: "secret", Permission: PermissionRead},
 			},
 		},
+		Tokens: map[string][]*Token{
+			"philuser": {
+				{Value: "tk_op56p8lz5bf3cxkz9je99v9oc37lo", Label: "Alerts token"},
+			},
+		},
 	}
 	a, err := NewManager(conf)
 	require.Nil(t, err)
@@ -1123,24 +1128,29 @@ func TestManager_WithProvisionedUsers(t *testing.T) {
 	users, err := a.Users()
 	require.Nil(t, err)
 	require.Len(t, users, 4)
-
 	require.Equal(t, "philadmin", users[0].Name)
 	require.Equal(t, RoleAdmin, users[0].Role)
-
 	require.Equal(t, "philmanual", users[1].Name)
 	require.Equal(t, RoleUser, users[1].Role)
+	require.Equal(t, "philuser", users[2].Name)
+	require.Equal(t, RoleUser, users[2].Role)
+	require.Equal(t, "*", users[3].Name)
+	provisionedUserID := users[2].ID // "philuser" is the provisioned user
 
 	grants, err := a.Grants("philuser")
 	require.Nil(t, err)
-	require.Equal(t, "philuser", users[2].Name)
-	require.Equal(t, RoleUser, users[2].Role)
 	require.Equal(t, 2, len(grants))
 	require.Equal(t, "secret", grants[0].TopicPattern)
 	require.Equal(t, PermissionRead, grants[0].Permission)
 	require.Equal(t, "stats", grants[1].TopicPattern)
 	require.Equal(t, PermissionReadWrite, grants[1].Permission)
 
-	require.Equal(t, "*", users[3].Name)
+	tokens, err := a.Tokens(provisionedUserID)
+	require.Nil(t, err)
+	require.Equal(t, 1, len(tokens))
+	require.Equal(t, "tk_op56p8lz5bf3cxkz9je99v9oc37lo", tokens[0].Value)
+	require.Equal(t, "Alerts token", tokens[0].Label)
+	require.True(t, tokens[0].Provisioned)
 
 	// Re-open the DB (second app start)
 	require.Nil(t, a.db.Close())
@@ -1153,6 +1163,11 @@ func TestManager_WithProvisionedUsers(t *testing.T) {
 			{TopicPattern: "secret12", Permission: PermissionRead},
 		},
 	}
+	conf.Tokens = map[string][]*Token{
+		"philuser": {
+			{Value: "tk_op56p8lz5bf3cxkz9je99v9oc3XXX", Label: "Alerts token updated"},
+		},
+	}
 	a, err = NewManager(conf)
 	require.Nil(t, err)
 
@@ -1160,30 +1175,36 @@ func TestManager_WithProvisionedUsers(t *testing.T) {
 	users, err = a.Users()
 	require.Nil(t, err)
 	require.Len(t, users, 3)
-
 	require.Equal(t, "philmanual", users[0].Name)
+	require.Equal(t, "philuser", users[1].Name)
+	require.Equal(t, RoleUser, users[1].Role)
 	require.Equal(t, RoleUser, users[0].Role)
+	require.Equal(t, "*", users[2].Name)
 
 	grants, err = a.Grants("philuser")
 	require.Nil(t, err)
-	require.Equal(t, "philuser", users[1].Name)
-	require.Equal(t, RoleUser, users[1].Role)
 	require.Equal(t, 2, len(grants))
 	require.Equal(t, "secret12", grants[0].TopicPattern)
 	require.Equal(t, PermissionRead, grants[0].Permission)
 	require.Equal(t, "stats12", grants[1].TopicPattern)
 	require.Equal(t, PermissionReadWrite, grants[1].Permission)
 
-	require.Equal(t, "*", users[2].Name)
+	tokens, err = a.Tokens(provisionedUserID)
+	require.Nil(t, err)
+	require.Equal(t, 1, len(tokens))
+	require.Equal(t, "tk_op56p8lz5bf3cxkz9je99v9oc3XXX", tokens[0].Value)
+	require.Equal(t, "Alerts token updated", tokens[0].Label)
+	require.True(t, tokens[0].Provisioned)
 
 	// Re-open the DB again (third app start)
 	require.Nil(t, a.db.Close())
 	conf.Users = []*User{}
 	conf.Access = map[string][]*Grant{}
+	conf.Tokens = map[string][]*Token{}
 	a, err = NewManager(conf)
 	require.Nil(t, err)
 
-	// Check that the provisioned users are there
+	// Check that the provisioned users are all gone
 	users, err = a.Users()
 	require.Nil(t, err)
 	require.Len(t, users, 2)
@@ -1191,6 +1212,14 @@ func TestManager_WithProvisionedUsers(t *testing.T) {
 	require.Equal(t, "philmanual", users[0].Name)
 	require.Equal(t, RoleUser, users[0].Role)
 	require.Equal(t, "*", users[1].Name)
+
+	grants, err = a.Grants("philuser")
+	require.Nil(t, err)
+	require.Equal(t, 0, len(grants))
+
+	tokens, err = a.Tokens(provisionedUserID)
+	require.Nil(t, err)
+	require.Equal(t, 0, len(tokens))
 }
 
 func TestManager_UpdateNonProvisionedUsersToProvisionedUsers(t *testing.T) {

+ 19 - 5
user/types.go

@@ -4,6 +4,7 @@ import (
 	"errors"
 	"github.com/stripe/stripe-go/v74"
 	"heckel.io/ntfy/v2/log"
+	"heckel.io/ntfy/v2/util"
 	"net/netip"
 	"regexp"
 	"strings"
@@ -59,11 +60,12 @@ type Auther interface {
 
 // Token represents a user token, including expiry date
 type Token struct {
-	Value      string
-	Label      string
-	LastAccess time.Time
-	LastOrigin netip.Addr
-	Expires    time.Time
+	Value       string
+	Label       string
+	LastAccess  time.Time
+	LastOrigin  netip.Addr
+	Expires     time.Time
+	Provisioned bool
 }
 
 // TokenUpdate holds information about the last access time and origin IP address of a token
@@ -247,6 +249,7 @@ var (
 	allowedTopicRegex        = regexp.MustCompile(`^[-_A-Za-z0-9]{1,64}$`)  // No '*'
 	allowedTopicPatternRegex = regexp.MustCompile(`^[-_*A-Za-z0-9]{1,64}$`) // Adds '*' for wildcards!
 	allowedTierRegex         = regexp.MustCompile(`^[-_A-Za-z0-9]{1,64}$`)
+	allowedTokenRegex        = regexp.MustCompile(`^tk_[-_A-Za-z0-9]{29}$`) // Must be tokenLength-len(tokenPrefix)
 )
 
 // AllowedRole returns true if the given role can be used for new users
@@ -282,6 +285,17 @@ func AllowedPasswordHash(hash string) error {
 	return nil
 }
 
+// AllowedToken returns true if the given token matches the naming convention
+func AllowedToken(token string) bool {
+	return allowedTokenRegex.MatchString(token)
+}
+
+// GenerateToken generates a new token with a prefix and a fixed length
+// Lowercase only to support "<topic>+<token>@<domain>" email addresses
+func GenerateToken() string {
+	return util.RandomLowerStringPrefix(tokenPrefix, tokenLength)
+}
+
 // Error constants used by the package
 var (
 	ErrUnauthenticated     = errors.New("unauthenticated")