|
|
@@ -111,9 +111,11 @@ const (
|
|
|
last_access INT NOT NULL,
|
|
|
last_origin TEXT NOT NULL,
|
|
|
expires INT NOT NULL,
|
|
|
+ provisioned INT NOT NULL,
|
|
|
PRIMARY KEY (user_id, token),
|
|
|
FOREIGN KEY (user_id) REFERENCES user (id) ON DELETE CASCADE
|
|
|
);
|
|
|
+ CREATE UNIQUE INDEX idx_user_token ON user_token (token);
|
|
|
CREATE TABLE IF NOT EXISTS user_phone (
|
|
|
user_id TEXT NOT NULL,
|
|
|
phone_number TEXT NOT NULL,
|
|
|
@@ -181,16 +183,17 @@ const (
|
|
|
ELSE 2
|
|
|
END, user
|
|
|
`
|
|
|
- selectUserCountQuery = `SELECT COUNT(*) FROM user`
|
|
|
- updateUserPassQuery = `UPDATE user SET pass = ? WHERE user = ?`
|
|
|
- updateUserRoleQuery = `UPDATE user SET role = ? WHERE user = ?`
|
|
|
- updateUserProvisionedQuery = `UPDATE user SET provisioned = ? WHERE user = ?`
|
|
|
- updateUserPrefsQuery = `UPDATE user SET prefs = ? WHERE id = ?`
|
|
|
- updateUserStatsQuery = `UPDATE user SET stats_messages = ?, stats_emails = ?, stats_calls = ? WHERE id = ?`
|
|
|
- updateUserStatsResetAllQuery = `UPDATE user SET stats_messages = 0, stats_emails = 0, stats_calls = 0`
|
|
|
- updateUserDeletedQuery = `UPDATE user SET deleted = ? WHERE id = ?`
|
|
|
- deleteUsersMarkedQuery = `DELETE FROM user WHERE deleted < ?`
|
|
|
- deleteUserQuery = `DELETE FROM user WHERE user = ?`
|
|
|
+ selectUserCountQuery = `SELECT COUNT(*) FROM user`
|
|
|
+ selectUserIDFromUsernameQuery = `SELECT id FROM user WHERE user = ?`
|
|
|
+ updateUserPassQuery = `UPDATE user SET pass = ? WHERE user = ?`
|
|
|
+ updateUserRoleQuery = `UPDATE user SET role = ? WHERE user = ?`
|
|
|
+ updateUserProvisionedQuery = `UPDATE user SET provisioned = ? WHERE user = ?`
|
|
|
+ updateUserPrefsQuery = `UPDATE user SET prefs = ? WHERE id = ?`
|
|
|
+ updateUserStatsQuery = `UPDATE user SET stats_messages = ?, stats_emails = ?, stats_calls = ? WHERE id = ?`
|
|
|
+ updateUserStatsResetAllQuery = `UPDATE user SET stats_messages = 0, stats_emails = 0, stats_calls = 0`
|
|
|
+ updateUserDeletedQuery = `UPDATE user SET deleted = ? WHERE id = ?`
|
|
|
+ deleteUsersMarkedQuery = `DELETE FROM user WHERE deleted < ?`
|
|
|
+ deleteUserQuery = `DELETE FROM user WHERE user = ?`
|
|
|
|
|
|
upsertUserAccessQuery = `
|
|
|
INSERT INTO user_access (user_id, topic, read, write, owner_user_id, provisioned)
|
|
|
@@ -220,7 +223,7 @@ const (
|
|
|
selectUserReservationsCountQuery = `
|
|
|
SELECT COUNT(*)
|
|
|
FROM user_access
|
|
|
- WHERE user_id = owner_user_id
|
|
|
+ WHERE user_id = owner_user_id
|
|
|
AND owner_user_id = (SELECT id FROM user WHERE user = ?)
|
|
|
`
|
|
|
selectUserReservationsOwnerQuery = `
|
|
|
@@ -255,17 +258,23 @@ const (
|
|
|
AND topic = ?
|
|
|
`
|
|
|
|
|
|
- selectTokenCountQuery = `SELECT COUNT(*) FROM user_token WHERE user_id = ?`
|
|
|
- selectTokensQuery = `SELECT token, label, last_access, last_origin, expires FROM user_token WHERE user_id = ?`
|
|
|
- selectTokenQuery = `SELECT token, label, last_access, last_origin, expires FROM user_token WHERE user_id = ? AND token = ?`
|
|
|
- insertTokenQuery = `INSERT INTO user_token (user_id, token, label, last_access, last_origin, expires) VALUES (?, ?, ?, ?, ?, ?)`
|
|
|
- updateTokenExpiryQuery = `UPDATE user_token SET expires = ? WHERE user_id = ? AND token = ?`
|
|
|
- updateTokenLabelQuery = `UPDATE user_token SET label = ? WHERE user_id = ? AND token = ?`
|
|
|
- updateTokenLastAccessQuery = `UPDATE user_token SET last_access = ?, last_origin = ? WHERE token = ?`
|
|
|
- deleteTokenQuery = `DELETE FROM user_token WHERE user_id = ? AND token = ?`
|
|
|
- deleteAllTokenQuery = `DELETE FROM user_token WHERE user_id = ?`
|
|
|
- deleteExpiredTokensQuery = `DELETE FROM user_token WHERE expires > 0 AND expires < ?`
|
|
|
- deleteExcessTokensQuery = `
|
|
|
+ selectTokenCountQuery = `SELECT COUNT(*) FROM user_token WHERE user_id = ?`
|
|
|
+ selectTokensQuery = `SELECT token, label, last_access, last_origin, expires, provisioned FROM user_token WHERE user_id = ?`
|
|
|
+ selectTokenQuery = `SELECT token, label, last_access, last_origin, expires, provisioned FROM user_token WHERE user_id = ? AND token = ?`
|
|
|
+ upsertTokenQuery = `
|
|
|
+ INSERT INTO user_token (user_id, token, label, last_access, last_origin, expires, provisioned)
|
|
|
+ VALUES (?, ?, ?, ?, ?, ?, ?)
|
|
|
+ ON CONFLICT (user_id, token)
|
|
|
+ DO UPDATE SET label = excluded.label, expires = excluded.expires, provisioned = excluded.provisioned;
|
|
|
+ `
|
|
|
+ updateTokenExpiryQuery = `UPDATE user_token SET expires = ? WHERE user_id = ? AND token = ?`
|
|
|
+ updateTokenLabelQuery = `UPDATE user_token SET label = ? WHERE user_id = ? AND token = ?`
|
|
|
+ updateTokenLastAccessQuery = `UPDATE user_token SET last_access = ?, last_origin = ? WHERE token = ?`
|
|
|
+ deleteTokenQuery = `DELETE FROM user_token WHERE user_id = ? AND token = ?`
|
|
|
+ deleteAllTokenQuery = `DELETE FROM user_token WHERE user_id = ?`
|
|
|
+ deleteTokensProvisionedQuery = `DELETE FROM user_token WHERE provisioned = 1`
|
|
|
+ deleteExpiredTokensQuery = `DELETE FROM user_token WHERE expires > 0 AND expires < ?`
|
|
|
+ deleteExcessTokensQuery = `
|
|
|
DELETE FROM user_token
|
|
|
WHERE user_id = ?
|
|
|
AND (user_id, token) NOT IN (
|
|
|
@@ -470,7 +479,7 @@ const (
|
|
|
role,
|
|
|
prefs,
|
|
|
sync_topic,
|
|
|
- 0,
|
|
|
+ 0, -- provisioned
|
|
|
stats_messages,
|
|
|
stats_emails,
|
|
|
stats_calls,
|
|
|
@@ -480,7 +489,8 @@ const (
|
|
|
stripe_subscription_interval,
|
|
|
stripe_subscription_paid_until,
|
|
|
stripe_subscription_cancel_at,
|
|
|
- created, deleted
|
|
|
+ created,
|
|
|
+ deleted
|
|
|
FROM user_old;
|
|
|
DROP TABLE user_old;
|
|
|
|
|
|
@@ -500,10 +510,27 @@ const (
|
|
|
INSERT INTO user_access SELECT *, 0 FROM user_access_old;
|
|
|
DROP TABLE user_access_old;
|
|
|
|
|
|
+ -- Alter user_token table: Add provisioned column
|
|
|
+ ALTER TABLE user_token RENAME TO user_token_old;
|
|
|
+ CREATE TABLE IF NOT EXISTS user_token (
|
|
|
+ user_id TEXT NOT NULL,
|
|
|
+ token TEXT NOT NULL,
|
|
|
+ label TEXT NOT NULL,
|
|
|
+ last_access INT NOT NULL,
|
|
|
+ last_origin TEXT NOT NULL,
|
|
|
+ expires INT NOT NULL,
|
|
|
+ provisioned INT NOT NULL,
|
|
|
+ PRIMARY KEY (user_id, token),
|
|
|
+ FOREIGN KEY (user_id) REFERENCES user (id) ON DELETE CASCADE
|
|
|
+ );
|
|
|
+ INSERT INTO user_token SELECT *, 0 FROM user_token_old;
|
|
|
+ DROP TABLE user_token_old;
|
|
|
+
|
|
|
-- Recreate indices
|
|
|
CREATE UNIQUE INDEX idx_user ON user (user);
|
|
|
CREATE UNIQUE INDEX idx_user_stripe_customer_id ON user (stripe_customer_id);
|
|
|
CREATE UNIQUE INDEX idx_user_stripe_subscription_id ON user (stripe_subscription_id);
|
|
|
+ CREATE UNIQUE INDEX idx_user_token ON user_token (token);
|
|
|
|
|
|
-- Re-enable foreign keys
|
|
|
PRAGMA foreign_keys=on;
|
|
|
@@ -537,7 +564,8 @@ type Config struct {
|
|
|
DefaultAccess Permission // Default permission if no ACL matches
|
|
|
ProvisionEnabled bool // Enable auto-provisioning of users and access grants, disabled for "ntfy user" commands
|
|
|
Users []*User // Predefined users to create on startup
|
|
|
- Access map[string][]*Grant // Predefined access grants to create on startup
|
|
|
+ Access map[string][]*Grant // Predefined access grants to create on startup (username -> []*Grant)
|
|
|
+ Tokens map[string][]*Token // Predefined users to create on startup (username -> []*Token)
|
|
|
QueueWriterInterval time.Duration // Interval for the async queue writer to flush stats and token updates to the database
|
|
|
BcryptCost int // Cost of generated passwords; lowering makes testing faster
|
|
|
}
|
|
|
@@ -623,15 +651,15 @@ func (a *Manager) AuthenticateToken(token string) (*User, error) {
|
|
|
// CreateToken generates a random token for the given user and returns it. The token expires
|
|
|
// after a fixed duration unless ChangeToken is called. This function also prunes tokens for the
|
|
|
// given user, if there are too many of them.
|
|
|
-func (a *Manager) CreateToken(userID, label string, expires time.Time, origin netip.Addr) (*Token, error) {
|
|
|
- token := util.RandomLowerStringPrefix(tokenPrefix, tokenLength) // Lowercase only to support "<topic>+<token>@<domain>" email addresses
|
|
|
- tx, err := a.db.Begin()
|
|
|
- if err != nil {
|
|
|
- return nil, err
|
|
|
- }
|
|
|
- defer tx.Rollback()
|
|
|
+func (a *Manager) CreateToken(userID, label string, expires time.Time, origin netip.Addr, provisioned bool) (*Token, error) {
|
|
|
+ return queryTx(a.db, func(tx *sql.Tx) (*Token, error) {
|
|
|
+ return a.createTokenTx(tx, userID, GenerateToken(), label, expires, origin, provisioned)
|
|
|
+ })
|
|
|
+}
|
|
|
+
|
|
|
+func (a *Manager) createTokenTx(tx *sql.Tx, userID, token, label string, expires time.Time, origin netip.Addr, provisioned bool) (*Token, error) {
|
|
|
access := time.Now()
|
|
|
- if _, err := tx.Exec(insertTokenQuery, userID, token, label, access.Unix(), origin.String(), expires.Unix()); err != nil {
|
|
|
+ if _, err := tx.Exec(upsertTokenQuery, userID, token, label, access.Unix(), origin.String(), expires.Unix(), provisioned); err != nil {
|
|
|
return nil, err
|
|
|
}
|
|
|
rows, err := tx.Query(selectTokenCountQuery, userID)
|
|
|
@@ -653,15 +681,13 @@ func (a *Manager) CreateToken(userID, label string, expires time.Time, origin ne
|
|
|
return nil, err
|
|
|
}
|
|
|
}
|
|
|
- if err := tx.Commit(); err != nil {
|
|
|
- return nil, err
|
|
|
- }
|
|
|
return &Token{
|
|
|
- Value: token,
|
|
|
- Label: label,
|
|
|
- LastAccess: access,
|
|
|
- LastOrigin: origin,
|
|
|
- Expires: expires,
|
|
|
+ Value: token,
|
|
|
+ Label: label,
|
|
|
+ LastAccess: access,
|
|
|
+ LastOrigin: origin,
|
|
|
+ Expires: expires,
|
|
|
+ Provisioned: provisioned,
|
|
|
}, nil
|
|
|
}
|
|
|
|
|
|
@@ -698,10 +724,11 @@ func (a *Manager) Token(userID, token string) (*Token, error) {
|
|
|
func (a *Manager) readToken(rows *sql.Rows) (*Token, error) {
|
|
|
var token, label, lastOrigin string
|
|
|
var lastAccess, expires int64
|
|
|
+ var provisioned bool
|
|
|
if !rows.Next() {
|
|
|
return nil, ErrTokenNotFound
|
|
|
}
|
|
|
- if err := rows.Scan(&token, &label, &lastAccess, &lastOrigin, &expires); err != nil {
|
|
|
+ if err := rows.Scan(&token, &label, &lastAccess, &lastOrigin, &expires, &provisioned); err != nil {
|
|
|
return nil, err
|
|
|
} else if err := rows.Err(); err != nil {
|
|
|
return nil, err
|
|
|
@@ -711,11 +738,12 @@ func (a *Manager) readToken(rows *sql.Rows) (*Token, error) {
|
|
|
lastOriginIP = netip.IPv4Unspecified()
|
|
|
}
|
|
|
return &Token{
|
|
|
- Value: token,
|
|
|
- Label: label,
|
|
|
- LastAccess: time.Unix(lastAccess, 0),
|
|
|
- LastOrigin: lastOriginIP,
|
|
|
- Expires: time.Unix(expires, 0),
|
|
|
+ Value: token,
|
|
|
+ Label: label,
|
|
|
+ LastAccess: time.Unix(lastAccess, 0),
|
|
|
+ LastOrigin: lastOriginIP,
|
|
|
+ Expires: time.Unix(expires, 0),
|
|
|
+ Provisioned: provisioned,
|
|
|
}, nil
|
|
|
}
|
|
|
|
|
|
@@ -774,7 +802,7 @@ func (a *Manager) PhoneNumbers(userID string) ([]string, error) {
|
|
|
phoneNumbers := make([]string, 0)
|
|
|
for {
|
|
|
phoneNumber, err := a.readPhoneNumber(rows)
|
|
|
- if err == ErrPhoneNumberNotFound {
|
|
|
+ if errors.Is(err, ErrPhoneNumberNotFound) {
|
|
|
break
|
|
|
} else if err != nil {
|
|
|
return nil, err
|
|
|
@@ -1757,6 +1785,28 @@ func (a *Manager) maybeProvisionUsersAndAccess() error {
|
|
|
}
|
|
|
}
|
|
|
}
|
|
|
+ // Remove and (re-)add provisioned tokens
|
|
|
+ if _, err := tx.Exec(deleteTokensProvisionedQuery); err != nil {
|
|
|
+ return err
|
|
|
+ }
|
|
|
+ for username, tokens := range a.config.Tokens {
|
|
|
+ _, exists := util.Find(a.config.Users, func(u *User) bool {
|
|
|
+ return u.Name == username
|
|
|
+ })
|
|
|
+ if !exists && username != Everyone {
|
|
|
+ return fmt.Errorf("user %s is not a provisioned user, refusing to add tokens", username)
|
|
|
+ }
|
|
|
+ var userID string
|
|
|
+ row := tx.QueryRow(selectUserIDFromUsernameQuery, username)
|
|
|
+ if err := row.Scan(&userID); err != nil {
|
|
|
+ return fmt.Errorf("failed to find provisioned user %s for provisioned tokens", username)
|
|
|
+ }
|
|
|
+ for _, token := range tokens {
|
|
|
+ if _, err = a.createTokenTx(tx, userID, token.Value, token.Label, time.Unix(0, 0), netip.IPv4Unspecified(), true); err != nil {
|
|
|
+ return err
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }
|
|
|
return nil
|
|
|
})
|
|
|
}
|
|
|
@@ -1974,3 +2024,22 @@ func execTx(db *sql.DB, f func(tx *sql.Tx) error) error {
|
|
|
}
|
|
|
return tx.Commit()
|
|
|
}
|
|
|
+
|
|
|
+// queryTx executes a function in a transaction and returns the result. If the function
|
|
|
+// returns an error, the transaction is rolled back.
|
|
|
+func queryTx[T any](db *sql.DB, f func(tx *sql.Tx) (T, error)) (T, error) {
|
|
|
+ tx, err := db.Begin()
|
|
|
+ if err != nil {
|
|
|
+ var zero T
|
|
|
+ return zero, err
|
|
|
+ }
|
|
|
+ defer tx.Rollback()
|
|
|
+ t, err := f(tx)
|
|
|
+ if err != nil {
|
|
|
+ return t, err
|
|
|
+ }
|
|
|
+ if err := tx.Commit(); err != nil {
|
|
|
+ return t, err
|
|
|
+ }
|
|
|
+ return t, nil
|
|
|
+}
|