Browse Source

Fix all the tests

binwiederhier 3 years ago
parent
commit
a2e474c375
4 changed files with 44 additions and 35 deletions
  1. 1 2
      server/server.go
  2. 38 28
      user/manager.go
  3. 4 4
      user/manager_test.go
  4. 1 1
      user/types.go

+ 1 - 2
server/server.go

@@ -40,8 +40,6 @@ import (
 		reserve topics
 		purge accounts that were not logged into in X
 		reset daily limits for users
-		"user list" shows * twice
-		"ntfy access everyone user4topic <bla>" twice -> UNIQUE constraint error
 		Account usage not updated "in real time"
 		Attachment expiration based on plan
 		Plan: Keep 10000 messages or keep X days?
@@ -66,6 +64,7 @@ import (
 		- Expire tokens
 		- userManager can be nil
 		- visitor with/without user
+		- userManager.<NEWSTUFF>
 */
 
 // Server is the main server, providing the UI and API for ntfy

+ 38 - 28
user/manager.go

@@ -93,16 +93,30 @@ const (
 
 // Manager-related queries
 const (
-	insertUserQuery         = `INSERT INTO user (user, pass, role) VALUES (?, ?, ?)`
-	selectUsernamesQuery    = `SELECT user FROM user ORDER BY role, user`
+	insertUserQuery      = `INSERT INTO user (user, pass, role) VALUES (?, ?, ?)`
+	selectUsernamesQuery = `
+		SELECT user 
+		FROM user 
+		ORDER BY
+			CASE role
+				WHEN 'admin' THEN 1
+				WHEN 'anonymous' THEN 3
+				ELSE 2
+			END, user
+	`
 	updateUserPassQuery     = `UPDATE user SET pass = ? WHERE user = ?`
 	updateUserRoleQuery     = `UPDATE user SET role = ? WHERE user = ?`
 	updateUserSettingsQuery = `UPDATE user SET settings = ? WHERE user = ?`
 	updateUserStatsQuery    = `UPDATE user SET messages = ?, emails = ? WHERE user = ?`
 	deleteUserQuery         = `DELETE FROM user WHERE user = ?`
 
-	upsertUserAccessQuery  = `INSERT INTO user_access (user_id, topic, read, write) VALUES ((SELECT id FROM user WHERE user = ?), ?, ?, ?)`
-	selectUserAccessQuery  = `SELECT topic, read, write FROM user_access WHERE user_id = (SELECT id FROM user WHERE user = ?)`
+	upsertUserAccessQuery = `
+		INSERT INTO user_access (user_id, topic, read, write) 
+		VALUES ((SELECT id FROM user WHERE user = ?), ?, ?, ?)
+		ON CONFLICT (user_id, topic) 
+		DO UPDATE SET read=excluded.read, write=excluded.write
+	`
+	selectUserAccessQuery  = `SELECT topic, read, write FROM user_access WHERE user_id = (SELECT id FROM user WHERE user = ?) ORDER BY write DESC, read DESC, topic`
 	deleteAllAccessQuery   = `DELETE FROM user_access`
 	deleteUserAccessQuery  = `DELETE FROM user_access WHERE user_id = (SELECT id FROM user WHERE user = ?)`
 	deleteTopicAccessQuery = `DELETE FROM user_access WHERE user_id = (SELECT id FROM user WHERE user = ?) AND topic = ?`
@@ -152,7 +166,7 @@ func NewManager(filename string, defaultRead, defaultWrite bool) (*Manager, erro
 	return manager, nil
 }
 
-// Authenticate checks username and password and returns a user if correct. The method
+// Authenticate checks username and password and returns a User if correct. The method
 // returns in constant-ish time, regardless of whether the user exists or the password is
 // correct or incorrect.
 func (a *Manager) Authenticate(username, password string) (*User, error) {
@@ -171,6 +185,8 @@ func (a *Manager) Authenticate(username, password string) (*User, error) {
 	return user, nil
 }
 
+// AuthenticateToken checks if the token exists and returns the associated User if it does.
+// The method sets the User.Token value to the token that was used for authentication.
 func (a *Manager) AuthenticateToken(token string) (*User, error) {
 	user, err := a.userByToken(token)
 	if err != nil {
@@ -180,9 +196,10 @@ func (a *Manager) AuthenticateToken(token string) (*User, error) {
 	return user, nil
 }
 
+// CreateToken generates a random token for the given user and returns it. The token expires
+// after a fixed duration unless ExtendToken is called.
 func (a *Manager) CreateToken(user *User) (*Token, error) {
-	token := util.RandomString(tokenLength)
-	expires := time.Now().Add(userTokenExpiryDuration)
+	token, expires := util.RandomString(tokenLength), time.Now().Add(userTokenExpiryDuration)
 	if _, err := a.db.Exec(insertTokenQuery, user.Name, token, expires.Unix()); err != nil {
 		return nil, err
 	}
@@ -192,6 +209,7 @@ func (a *Manager) CreateToken(user *User) (*Token, error) {
 	}, nil
 }
 
+// ExtendToken sets the new expiry date for a token, thereby extending its use further into the future.
 func (a *Manager) ExtendToken(user *User) (*Token, error) {
 	newExpires := time.Now().Add(userTokenExpiryDuration)
 	if _, err := a.db.Exec(updateTokenExpiryQuery, newExpires.Unix(), user.Name, user.Token); err != nil {
@@ -203,6 +221,7 @@ func (a *Manager) ExtendToken(user *User) (*Token, error) {
 	}, nil
 }
 
+// RemoveToken deletes the token defined in User.Token
 func (a *Manager) RemoveToken(user *User) error {
 	if user.Token == "" {
 		return ErrUnauthorized
@@ -213,6 +232,7 @@ func (a *Manager) RemoveToken(user *User) error {
 	return nil
 }
 
+// RemoveExpiredTokens deletes all expired tokens from the database
 func (a *Manager) RemoveExpiredTokens() error {
 	if _, err := a.db.Exec(deleteExpiredTokensQuery, time.Now().Unix()); err != nil {
 		return err
@@ -370,20 +390,23 @@ func (a *Manager) Users() ([]*User, error) {
 		}
 		users = append(users, user)
 	}
-	everyone, err := a.everyoneUser()
-	if err != nil {
-		return nil, err
-	}
-	users = append(users, everyone)
+	/*sort.Slice(users, func(i, j int) bool {
+		if users[i].Role != users[j].Role {
+			return true
+		}
+		if users[i].Name == Everyone || users[j].Name == Everyone {
+			return users[i].Name != Everyone
+		} else if string(users[i].Role) < string(users[j].Role) {
+			return true
+		}
+		return users[i].Name < users[j].Name
+	})*/
 	return users, nil
 }
 
 // User returns the user with the given username if it exists, or ErrNotFound otherwise.
 // You may also pass Everyone to retrieve the anonymous user and its Grant list.
 func (a *Manager) User(username string) (*User, error) {
-	if username == Everyone {
-		return a.everyoneUser()
-	}
 	rows, err := a.db.Query(selectUserByNameQuery, username)
 	if err != nil {
 		return nil, err
@@ -446,19 +469,6 @@ func (a *Manager) readUser(rows *sql.Rows) (*User, error) {
 	return user, nil
 }
 
-func (a *Manager) everyoneUser() (*User, error) {
-	grants, err := a.readGrants(Everyone)
-	if err != nil {
-		return nil, err
-	}
-	return &User{
-		Name:   Everyone,
-		Hash:   "",
-		Role:   RoleAnonymous,
-		Grants: grants,
-	}, nil
-}
-
 func (a *Manager) readGrants(username string) ([]Grant, error) {
 	rows, err := a.db.Query(selectUserAccessQuery, username)
 	if err != nil {

+ 4 - 4
user/manager_test.go

@@ -37,8 +37,8 @@ func TestSQLiteAuth_FullScenario_Default_DenyAll(t *testing.T) {
 	require.Equal(t, user.RoleUser, ben.Role)
 	require.Equal(t, []user.Grant{
 		{"mytopic", true, true},
-		{"readme", true, false},
 		{"writeme", false, true},
+		{"readme", true, false},
 		{"everyonewrite", false, false},
 	}, ben.Grants)
 
@@ -146,8 +146,8 @@ func TestSQLiteAuth_UserManagement(t *testing.T) {
 	require.Equal(t, user.RoleUser, ben.Role)
 	require.Equal(t, []user.Grant{
 		{"mytopic", true, true},
-		{"readme", true, false},
 		{"writeme", false, true},
+		{"readme", true, false},
 		{"everyonewrite", false, false},
 	}, ben.Grants)
 
@@ -157,12 +157,12 @@ func TestSQLiteAuth_UserManagement(t *testing.T) {
 	require.Equal(t, "", everyone.Hash)
 	require.Equal(t, user.RoleAnonymous, everyone.Role)
 	require.Equal(t, []user.Grant{
-		{"announcements", true, false},
 		{"everyonewrite", true, true},
+		{"announcements", true, false},
 	}, everyone.Grants)
 
 	// Ben: Before revoking
-	require.Nil(t, a.AllowAccess("ben", "mytopic", true, true))
+	require.Nil(t, a.AllowAccess("ben", "mytopic", true, true)) // Overwrite!
 	require.Nil(t, a.AllowAccess("ben", "readme", true, false))
 	require.Nil(t, a.AllowAccess("ben", "writeme", false, true))
 	require.Nil(t, a.Authorize(ben, "mytopic", user.PermissionRead))

+ 1 - 1
user/types.go

@@ -96,7 +96,7 @@ type Role string
 
 // User roles
 const (
-	RoleAdmin     = Role("admin")
+	RoleAdmin     = Role("admin") // Some queries have these values hardcoded!
 	RoleUser      = Role("user")
 	RoleAnonymous = Role("anonymous")
 )