Explorar el Código

ACLs and underscores, resolves #840

binwiederhier hace 2 años
padre
commit
a5f0670f7f
Se han modificado 2 ficheros con 248 adiciones y 17 borrados
  1. 47 13
      user/manager.go
  2. 201 4
      user/manager_test.go

+ 47 - 13
user/manager.go

@@ -160,7 +160,7 @@ const (
 		SELECT read, write
 		SELECT read, write
 		FROM user_access a
 		FROM user_access a
 		JOIN user u ON u.id = a.user_id
 		JOIN user u ON u.id = a.user_id
-		WHERE (u.user = ? OR u.user = ?) AND ? LIKE a.topic
+		WHERE (u.user = ? OR u.user = ?) AND ? LIKE a.topic ESCAPE '\'
 		ORDER BY u.user DESC
 		ORDER BY u.user DESC
 	`
 	`
 
 
@@ -235,7 +235,7 @@ const (
 	selectOtherAccessCountQuery = `
 	selectOtherAccessCountQuery = `
 		SELECT COUNT(*)
 		SELECT COUNT(*)
 		FROM user_access
 		FROM user_access
-		WHERE (topic = ? OR ? LIKE topic)
+		WHERE (topic = ? OR ? LIKE topic ESCAPE '\')
 		  AND (owner_user_id IS NULL OR owner_user_id != (SELECT id FROM user WHERE user = ?))
 		  AND (owner_user_id IS NULL OR owner_user_id != (SELECT id FROM user WHERE user = ?))
 	`
 	`
 	deleteAllAccessQuery  = `DELETE FROM user_access`
 	deleteAllAccessQuery  = `DELETE FROM user_access`
@@ -312,7 +312,7 @@ const (
 
 
 // Schema management queries
 // Schema management queries
 const (
 const (
-	currentSchemaVersion     = 4
+	currentSchemaVersion     = 5
 	insertSchemaVersion      = `INSERT INTO schemaVersion VALUES (1, ?)`
 	insertSchemaVersion      = `INSERT INTO schemaVersion VALUES (1, ?)`
 	updateSchemaVersion      = `UPDATE schemaVersion SET version = ? WHERE id = 1`
 	updateSchemaVersion      = `UPDATE schemaVersion SET version = ? WHERE id = 1`
 	selectSchemaVersionQuery = `SELECT version FROM schemaVersion WHERE id = 1`
 	selectSchemaVersionQuery = `SELECT version FROM schemaVersion WHERE id = 1`
@@ -422,6 +422,11 @@ const (
 			FOREIGN KEY (user_id) REFERENCES user (id) ON DELETE CASCADE
 			FOREIGN KEY (user_id) REFERENCES user (id) ON DELETE CASCADE
 		);
 		);
 	`
 	`
+
+	// 4 -> 5
+	migrate4To5UpdateQueries = `
+		UPDATE user_access SET topic = REPLACE(topic, '_', '\_');
+	`
 )
 )
 
 
 var (
 var (
@@ -429,6 +434,7 @@ var (
 		1: migrateFrom1,
 		1: migrateFrom1,
 		2: migrateFrom2,
 		2: migrateFrom2,
 		3: migrateFrom3,
 		3: migrateFrom3,
+		4: migrateFrom4,
 	}
 	}
 )
 )
 
 
@@ -1123,7 +1129,7 @@ func (a *Manager) Reservations(username string) ([]Reservation, error) {
 			return nil, err
 			return nil, err
 		}
 		}
 		reservations = append(reservations, Reservation{
 		reservations = append(reservations, Reservation{
-			Topic:    topic,
+			Topic:    unescapeUnderscore(topic),
 			Owner:    NewPermission(ownerRead, ownerWrite),
 			Owner:    NewPermission(ownerRead, ownerWrite),
 			Everyone: NewPermission(everyoneRead.Bool, everyoneWrite.Bool), // false if null
 			Everyone: NewPermission(everyoneRead.Bool, everyoneWrite.Bool), // false if null
 		})
 		})
@@ -1133,7 +1139,7 @@ func (a *Manager) Reservations(username string) ([]Reservation, error) {
 
 
 // HasReservation returns true if the given topic access is owned by the user
 // HasReservation returns true if the given topic access is owned by the user
 func (a *Manager) HasReservation(username, topic string) (bool, error) {
 func (a *Manager) HasReservation(username, topic string) (bool, error) {
-	rows, err := a.db.Query(selectUserHasReservationQuery, username, topic)
+	rows, err := a.db.Query(selectUserHasReservationQuery, username, escapeUnderscore(topic))
 	if err != nil {
 	if err != nil {
 		return false, err
 		return false, err
 	}
 	}
@@ -1168,7 +1174,7 @@ func (a *Manager) ReservationsCount(username string) (int64, error) {
 // ReservationOwner returns user ID of the user that owns this topic, or an
 // ReservationOwner returns user ID of the user that owns this topic, or an
 // empty string if it's not owned by anyone
 // empty string if it's not owned by anyone
 func (a *Manager) ReservationOwner(topic string) (string, error) {
 func (a *Manager) ReservationOwner(topic string) (string, error) {
-	rows, err := a.db.Query(selectUserReservationsOwnerQuery, topic)
+	rows, err := a.db.Query(selectUserReservationsOwnerQuery, escapeUnderscore(topic))
 	if err != nil {
 	if err != nil {
 		return "", err
 		return "", err
 	}
 	}
@@ -1263,7 +1269,7 @@ func (a *Manager) AllowReservation(username string, topic string) error {
 	if (!AllowedUsername(username) && username != Everyone) || !AllowedTopic(topic) {
 	if (!AllowedUsername(username) && username != Everyone) || !AllowedTopic(topic) {
 		return ErrInvalidArgument
 		return ErrInvalidArgument
 	}
 	}
-	rows, err := a.db.Query(selectOtherAccessCountQuery, topic, topic, username)
+	rows, err := a.db.Query(selectOtherAccessCountQuery, escapeUnderscore(topic), escapeUnderscore(topic), username)
 	if err != nil {
 	if err != nil {
 		return err
 		return err
 	}
 	}
@@ -1328,10 +1334,10 @@ func (a *Manager) AddReservation(username string, topic string, everyone Permiss
 		return err
 		return err
 	}
 	}
 	defer tx.Rollback()
 	defer tx.Rollback()
-	if _, err := tx.Exec(upsertUserAccessQuery, username, topic, true, true, username, username); err != nil {
+	if _, err := tx.Exec(upsertUserAccessQuery, username, escapeUnderscore(topic), true, true, username, username); err != nil {
 		return err
 		return err
 	}
 	}
-	if _, err := tx.Exec(upsertUserAccessQuery, Everyone, topic, everyone.IsRead(), everyone.IsWrite(), username, username); err != nil {
+	if _, err := tx.Exec(upsertUserAccessQuery, Everyone, escapeUnderscore(topic), everyone.IsRead(), everyone.IsWrite(), username, username); err != nil {
 		return err
 		return err
 	}
 	}
 	return tx.Commit()
 	return tx.Commit()
@@ -1354,10 +1360,10 @@ func (a *Manager) RemoveReservations(username string, topics ...string) error {
 	}
 	}
 	defer tx.Rollback()
 	defer tx.Rollback()
 	for _, topic := range topics {
 	for _, topic := range topics {
-		if _, err := tx.Exec(deleteTopicAccessQuery, username, username, topic); err != nil {
+		if _, err := tx.Exec(deleteTopicAccessQuery, username, username, escapeUnderscore(topic)); err != nil {
 			return err
 			return err
 		}
 		}
-		if _, err := tx.Exec(deleteTopicAccessQuery, Everyone, Everyone, topic); err != nil {
+		if _, err := tx.Exec(deleteTopicAccessQuery, Everyone, Everyone, escapeUnderscore(topic)); err != nil {
 			return err
 			return err
 		}
 		}
 	}
 	}
@@ -1484,12 +1490,24 @@ func (a *Manager) Close() error {
 	return a.db.Close()
 	return a.db.Close()
 }
 }
 
 
+// toSQLWildcard converts a wildcard string to a SQL wildcard string. It only allows '*' as wildcards,
+// and escapes '_', assuming '\' as escape character.
 func toSQLWildcard(s string) string {
 func toSQLWildcard(s string) string {
-	return strings.ReplaceAll(s, "*", "%")
+	return escapeUnderscore(strings.ReplaceAll(s, "*", "%"))
 }
 }
 
 
+// fromSQLWildcard converts a SQL wildcard string to a wildcard string. It converts '%' to '*',
+// and removes the '\_' escape character.
 func fromSQLWildcard(s string) string {
 func fromSQLWildcard(s string) string {
-	return strings.ReplaceAll(s, "%", "*")
+	return strings.ReplaceAll(unescapeUnderscore(s), "%", "*")
+}
+
+func escapeUnderscore(s string) string {
+	return strings.ReplaceAll(s, "_", "\\_")
+}
+
+func unescapeUnderscore(s string) string {
+	return strings.ReplaceAll(s, "\\_", "_")
 }
 }
 
 
 func runStartupQueries(db *sql.DB, startupQueries string) error {
 func runStartupQueries(db *sql.DB, startupQueries string) error {
@@ -1627,6 +1645,22 @@ func migrateFrom3(db *sql.DB) error {
 	return tx.Commit()
 	return tx.Commit()
 }
 }
 
 
+func migrateFrom4(db *sql.DB) error {
+	log.Tag(tag).Info("Migrating user database schema: from 4 to 5")
+	tx, err := db.Begin()
+	if err != nil {
+		return err
+	}
+	defer tx.Rollback()
+	if _, err := tx.Exec(migrate4To5UpdateQueries); err != nil {
+		return err
+	}
+	if _, err := tx.Exec(updateSchemaVersion, 5); err != nil {
+		return err
+	}
+	return tx.Commit()
+}
+
 func nullString(s string) sql.NullString {
 func nullString(s string) sql.NullString {
 	if s == "" {
 	if s == "" {
 		return sql.NullString{}
 		return sql.NullString{}

+ 201 - 4
user/manager_test.go

@@ -330,7 +330,7 @@ func TestManager_Reservations(t *testing.T) {
 	a := newTestManager(t, PermissionDenyAll)
 	a := newTestManager(t, PermissionDenyAll)
 	require.Nil(t, a.AddUser("phil", "phil", RoleUser))
 	require.Nil(t, a.AddUser("phil", "phil", RoleUser))
 	require.Nil(t, a.AddUser("ben", "ben", RoleUser))
 	require.Nil(t, a.AddUser("ben", "ben", RoleUser))
-	require.Nil(t, a.AddReservation("ben", "ztopic", PermissionDenyAll))
+	require.Nil(t, a.AddReservation("ben", "ztopic_", PermissionDenyAll))
 	require.Nil(t, a.AddReservation("ben", "readme", PermissionRead))
 	require.Nil(t, a.AddReservation("ben", "readme", PermissionRead))
 	require.Nil(t, a.AllowAccess("ben", "something-else", PermissionRead))
 	require.Nil(t, a.AllowAccess("ben", "something-else", PermissionRead))
 
 
@@ -343,7 +343,7 @@ func TestManager_Reservations(t *testing.T) {
 		Everyone: PermissionRead,
 		Everyone: PermissionRead,
 	}, reservations[0])
 	}, reservations[0])
 	require.Equal(t, Reservation{
 	require.Equal(t, Reservation{
-		Topic:    "ztopic",
+		Topic:    "ztopic_",
 		Owner:    PermissionReadWrite,
 		Owner:    PermissionReadWrite,
 		Everyone: PermissionDenyAll,
 		Everyone: PermissionDenyAll,
 	}, reservations[1])
 	}, reservations[1])
@@ -352,6 +352,14 @@ func TestManager_Reservations(t *testing.T) {
 	require.Nil(t, err)
 	require.Nil(t, err)
 	require.True(t, b)
 	require.True(t, b)
 
 
+	b, err = a.HasReservation("ben", "ztopic_")
+	require.Nil(t, err)
+	require.True(t, b)
+
+	b, err = a.HasReservation("ben", "ztopicX") // _ != X (used to be a SQL wildcard issue)
+	require.Nil(t, err)
+	require.False(t, b)
+
 	b, err = a.HasReservation("notben", "readme")
 	b, err = a.HasReservation("notben", "readme")
 	require.Nil(t, err)
 	require.Nil(t, err)
 	require.False(t, b)
 	require.False(t, b)
@@ -371,11 +379,17 @@ func TestManager_Reservations(t *testing.T) {
 	err = a.AllowReservation("phil", "readme")
 	err = a.AllowReservation("phil", "readme")
 	require.Equal(t, errTopicOwnedByOthers, err)
 	require.Equal(t, errTopicOwnedByOthers, err)
 
 
+	err = a.AllowReservation("phil", "ztopic_")
+	require.Equal(t, errTopicOwnedByOthers, err)
+
+	err = a.AllowReservation("phil", "ztopicX")
+	require.Nil(t, err)
+
 	err = a.AllowReservation("phil", "not-reserved")
 	err = a.AllowReservation("phil", "not-reserved")
 	require.Nil(t, err)
 	require.Nil(t, err)
 
 
 	// Now remove them again
 	// Now remove them again
-	require.Nil(t, a.RemoveReservations("ben", "ztopic", "readme"))
+	require.Nil(t, a.RemoveReservations("ben", "ztopic_", "readme"))
 
 
 	count, err = a.ReservationsCount("ben")
 	count, err = a.ReservationsCount("ben")
 	require.Nil(t, err)
 	require.Nil(t, err)
@@ -978,7 +992,44 @@ func TestUser_PhoneNumberAdd_Multiple_Users_Same_Number(t *testing.T) {
 	require.Nil(t, a.AddPhoneNumber(ben.ID, "+1234567890"))
 	require.Nil(t, a.AddPhoneNumber(ben.ID, "+1234567890"))
 }
 }
 
 
-func TestSqliteCache_Migration_From1(t *testing.T) {
+func TestManager_Topic_Wildcard_With_Asterisk_Underscore(t *testing.T) {
+	f := filepath.Join(t.TempDir(), "user.db")
+	a := newTestManagerFromFile(t, f, "", PermissionDenyAll, DefaultUserPasswordBcryptCost, DefaultUserStatsQueueWriterInterval)
+	require.Nil(t, a.AllowAccess(Everyone, "*_", PermissionRead))
+	require.Nil(t, a.AllowAccess(Everyone, "__*_", PermissionRead))
+	require.Nil(t, a.Authorize(nil, "allowed_", PermissionRead))
+	require.Nil(t, a.Authorize(nil, "__allowed_", PermissionRead))
+	require.Nil(t, a.Authorize(nil, "_allowed_", PermissionRead)) // The "%" in "%\_" matches the first "_"
+	require.Equal(t, ErrUnauthorized, a.Authorize(nil, "notallowed", PermissionRead))
+	require.Equal(t, ErrUnauthorized, a.Authorize(nil, "_notallowed", PermissionRead))
+	require.Equal(t, ErrUnauthorized, a.Authorize(nil, "__notallowed", PermissionRead))
+}
+
+func TestManager_Topic_Wildcard_With_Underscore(t *testing.T) {
+	f := filepath.Join(t.TempDir(), "user.db")
+	a := newTestManagerFromFile(t, f, "", PermissionDenyAll, DefaultUserPasswordBcryptCost, DefaultUserStatsQueueWriterInterval)
+	require.Nil(t, a.AllowAccess(Everyone, "mytopic_", PermissionReadWrite))
+	require.Nil(t, a.Authorize(nil, "mytopic_", PermissionRead))
+	require.Nil(t, a.Authorize(nil, "mytopic_", PermissionWrite))
+	require.Equal(t, ErrUnauthorized, a.Authorize(nil, "mytopicX", PermissionRead))
+	require.Equal(t, ErrUnauthorized, a.Authorize(nil, "mytopicX", PermissionWrite))
+}
+
+func TestToFromSQLWildcard(t *testing.T) {
+	require.Equal(t, "up%", toSQLWildcard("up*"))
+	require.Equal(t, "up\\_%", toSQLWildcard("up_*"))
+	require.Equal(t, "foo", toSQLWildcard("foo"))
+
+	require.Equal(t, "up*", fromSQLWildcard("up%"))
+	require.Equal(t, "up_*", fromSQLWildcard("up\\_%"))
+	require.Equal(t, "foo", fromSQLWildcard("foo"))
+
+	require.Equal(t, "up*", fromSQLWildcard(toSQLWildcard("up*")))
+	require.Equal(t, "up_*", fromSQLWildcard(toSQLWildcard("up_*")))
+	require.Equal(t, "foo", fromSQLWildcard(toSQLWildcard("foo")))
+}
+
+func TestMigrationFrom1(t *testing.T) {
 	filename := filepath.Join(t.TempDir(), "user.db")
 	filename := filepath.Join(t.TempDir(), "user.db")
 	db, err := sql.Open("sqlite3", filename)
 	db, err := sql.Open("sqlite3", filename)
 	require.Nil(t, err)
 	require.Nil(t, err)
@@ -1063,6 +1114,152 @@ func TestSqliteCache_Migration_From1(t *testing.T) {
 	require.Equal(t, PermissionRead, everyoneGrants[0].Allow)
 	require.Equal(t, PermissionRead, everyoneGrants[0].Allow)
 }
 }
 
 
+func TestMigrationFrom4(t *testing.T) {
+	filename := filepath.Join(t.TempDir(), "user.db")
+	db, err := sql.Open("sqlite3", filename)
+	require.Nil(t, err)
+
+	// Create "version 4" schema
+	_, err = db.Exec(`
+		BEGIN;
+		CREATE TABLE IF NOT EXISTS tier (
+			id TEXT PRIMARY KEY,
+			code TEXT NOT NULL,
+			name TEXT NOT NULL,
+			messages_limit INT NOT NULL,
+			messages_expiry_duration INT NOT NULL,
+			emails_limit INT NOT NULL,
+			calls_limit INT NOT NULL,
+			reservations_limit INT NOT NULL,
+			attachment_file_size_limit INT NOT NULL,
+			attachment_total_size_limit INT NOT NULL,
+			attachment_expiry_duration INT NOT NULL,
+			attachment_bandwidth_limit INT NOT NULL,
+			stripe_monthly_price_id TEXT,
+			stripe_yearly_price_id TEXT
+		);
+		CREATE UNIQUE INDEX idx_tier_code ON tier (code);
+		CREATE UNIQUE INDEX idx_tier_stripe_monthly_price_id ON tier (stripe_monthly_price_id);
+		CREATE UNIQUE INDEX idx_tier_stripe_yearly_price_id ON tier (stripe_yearly_price_id);
+		CREATE TABLE IF NOT EXISTS user (
+		    id TEXT PRIMARY KEY,
+			tier_id TEXT,
+			user TEXT NOT NULL,
+			pass TEXT NOT NULL,
+			role TEXT CHECK (role IN ('anonymous', 'admin', 'user')) NOT NULL,
+			prefs JSON NOT NULL DEFAULT '{}',
+			sync_topic TEXT NOT NULL,
+			stats_messages INT NOT NULL DEFAULT (0),
+			stats_emails INT NOT NULL DEFAULT (0),
+			stats_calls INT NOT NULL DEFAULT (0),
+			stripe_customer_id TEXT,
+			stripe_subscription_id TEXT,
+			stripe_subscription_status TEXT,
+			stripe_subscription_interval TEXT,
+			stripe_subscription_paid_until INT,
+			stripe_subscription_cancel_at INT,
+			created INT NOT NULL,
+			deleted INT,
+		    FOREIGN KEY (tier_id) REFERENCES tier (id)
+		);
+		CREATE UNIQUE INDEX idx_user ON user (user);
+		CREATE UNIQUE INDEX idx_user_stripe_customer_id ON user (stripe_customer_id);
+		CREATE UNIQUE INDEX idx_user_stripe_subscription_id ON user (stripe_subscription_id);
+		CREATE TABLE IF NOT EXISTS user_access (
+			user_id TEXT NOT NULL,
+			topic TEXT NOT NULL,
+			read INT NOT NULL,
+			write INT NOT NULL,
+			owner_user_id INT,
+			PRIMARY KEY (user_id, topic),
+			FOREIGN KEY (user_id) REFERENCES user (id) ON DELETE CASCADE,
+		    FOREIGN KEY (owner_user_id) REFERENCES user (id) ON DELETE CASCADE
+		);
+		CREATE TABLE IF NOT EXISTS user_token (
+			user_id TEXT NOT NULL,
+			token TEXT NOT NULL,
+			label TEXT NOT NULL,
+			last_access INT NOT NULL,
+			last_origin TEXT NOT NULL,
+			expires INT NOT NULL,
+			PRIMARY KEY (user_id, token),
+			FOREIGN KEY (user_id) REFERENCES user (id) ON DELETE CASCADE
+		);
+		CREATE TABLE IF NOT EXISTS user_phone (
+			user_id TEXT NOT NULL,
+			phone_number TEXT NOT NULL,
+			PRIMARY KEY (user_id, phone_number),
+			FOREIGN KEY (user_id) REFERENCES user (id) ON DELETE CASCADE
+		);
+		CREATE TABLE IF NOT EXISTS schemaVersion (
+			id INT PRIMARY KEY,
+			version INT NOT NULL
+		);
+		INSERT INTO user (id, user, pass, role, sync_topic, created)
+		VALUES ('u_everyone', '*', '', 'anonymous', '', UNIXEPOCH())
+		ON CONFLICT (id) DO NOTHING;
+		INSERT INTO schemaVersion (id, version) VALUES (1, 4);		
+		COMMIT;
+	`)
+	require.Nil(t, err)
+
+	// Insert a few ACL entries
+	_, err = db.Exec(`
+		BEGIN;
+		INSERT INTO user_access (user_id, topic, read, write) values ('u_everyone', 'mytopic_', 1, 1);
+		INSERT INTO user_access (user_id, topic, read, write) values ('u_everyone', 'up%', 1, 1);
+		INSERT INTO user_access (user_id, topic, read, write) values ('u_everyone', 'down_%', 1, 1);
+		COMMIT;	
+	`)
+	require.Nil(t, err)
+
+	// Create manager to trigger migration
+	a := newTestManagerFromFile(t, filename, "", PermissionDenyAll, bcrypt.MinCost, DefaultUserStatsQueueWriterInterval)
+	checkSchemaVersion(t, a.db)
+
+	// Add another
+	require.Nil(t, a.AllowAccess(Everyone, "left_*", PermissionReadWrite))
+
+	// Check "external view" of grants
+	everyoneGrants, err := a.Grants(Everyone)
+	require.Nil(t, err)
+
+	require.Equal(t, 4, len(everyoneGrants))
+	require.Equal(t, "down_*", everyoneGrants[0].TopicPattern)
+	require.Equal(t, "left_*", everyoneGrants[1].TopicPattern)
+	require.Equal(t, "mytopic_", everyoneGrants[2].TopicPattern)
+	require.Equal(t, "up*", everyoneGrants[3].TopicPattern)
+
+	// Check they are stored correctly in the database
+	rows, err := db.Query(`SELECT topic FROM user_access WHERE user_id = 'u_everyone' ORDER BY topic`)
+	require.Nil(t, err)
+	topicPatterns := make([]string, 0)
+	for rows.Next() {
+		var topicPattern string
+		require.Nil(t, rows.Scan(&topicPattern))
+		topicPatterns = append(topicPatterns, topicPattern)
+	}
+	require.Nil(t, rows.Close())
+	require.Equal(t, 4, len(topicPatterns))
+	require.Equal(t, "down\\_%", topicPatterns[0])
+	require.Equal(t, "left\\_%", topicPatterns[1])
+	require.Equal(t, "mytopic\\_", topicPatterns[2])
+	require.Equal(t, "up%", topicPatterns[3])
+
+	// Check that ACL works as excepted
+	require.Nil(t, a.Authorize(nil, "down_123", PermissionRead))
+	require.Equal(t, ErrUnauthorized, a.Authorize(nil, "downX123", PermissionRead))
+
+	require.Nil(t, a.Authorize(nil, "left_abc", PermissionRead))
+	require.Equal(t, ErrUnauthorized, a.Authorize(nil, "leftX123", PermissionRead))
+
+	require.Nil(t, a.Authorize(nil, "mytopic_", PermissionRead))
+	require.Equal(t, ErrUnauthorized, a.Authorize(nil, "mytopicX", PermissionRead))
+
+	require.Nil(t, a.Authorize(nil, "up123", PermissionRead))
+	require.Nil(t, a.Authorize(nil, "up", PermissionRead)) // % matches 0 or more characters
+}
+
 func checkSchemaVersion(t *testing.T, db *sql.DB) {
 func checkSchemaVersion(t *testing.T, db *sql.DB) {
 	rows, err := db.Query(`SELECT version FROM schemaVersion`)
 	rows, err := db.Query(`SELECT version FROM schemaVersion`)
 	require.Nil(t, err)
 	require.Nil(t, err)