Просмотр исходного кода

Only set rate visitor if allowed

binwiederhier 3 лет назад
Родитель
Сommit
bfc3983d06
4 измененных файлов с 149 добавлено и 15 удалено
  1. 55 13
      server/server.go
  2. 67 1
      server/server_test.go
  3. 1 0
      server/visitor.go
  4. 26 1
      user/manager.go

+ 55 - 13
server/server.go

@@ -112,7 +112,6 @@ const (
 	encodingBase64           = "base64"                  // Used mainly for binary UnifiedPush messages
 	jsonBodyBytesLimit       = 16384
 	unifiedPushTopicPrefix   = "up" // Temporarily, we rate limit all "up*" topics based on the subscriber
-	rateTopicsWildcard       = "*"  // Allows defining all topics in the request subscriber-rate-limited topics
 )
 
 // WebSocket constants
@@ -977,7 +976,9 @@ func (s *Server) handleSubscribeHTTP(w http.ResponseWriter, r *http.Request, v *
 		}
 		return nil
 	}
-	registerRateVisitors(topics, rateTopics, v)
+	if err := s.maybeSetRateVisitors(r, v, topics, rateTopics); err != nil {
+		return err
+	}
 	w.Header().Set("Access-Control-Allow-Origin", s.config.AccessControlAllowOrigin) // CORS, allow cross-origin requests
 	w.Header().Set("Content-Type", contentType+"; charset=utf-8")                    // Android/Volley client needs charset!
 	if poll {
@@ -1113,7 +1114,9 @@ func (s *Server) handleSubscribeWS(w http.ResponseWriter, r *http.Request, v *vi
 		}
 		return conn.WriteJSON(msg)
 	}
-	registerRateVisitors(topics, rateTopics, v)
+	if err := s.maybeSetRateVisitors(r, v, topics, rateTopics); err != nil {
+		return err
+	}
 	w.Header().Set("Access-Control-Allow-Origin", s.config.AccessControlAllowOrigin) // CORS, allow cross-origin requests
 	if poll {
 		return s.sendOldMessages(topics, since, scheduled, v, sub)
@@ -1156,23 +1159,62 @@ func parseSubscribeParams(r *http.Request) (poll bool, since sinceMarker, schedu
 	return
 }
 
-// registerRateVisitors sets the rate visitor on a topic, indicating that all messages published to that topic
-// will be rate limited against the rate visitor instead of the publishing visitor.
+// maybeSetRateVisitors sets the rate visitor on a topic (v.SetRateVisitor), indicating that all messages published
+// to that topic will be rate limited against the rate visitor instead of the publishing visitor.
+//
+// Setting the rate visitor is ony allowed if
+// - auth-file is not set (everything is open by default)
+// - the topic is reserved, and v.user is the owner
+// - the topic is not reserved, and v.user has write access
 //
 // Note: This TEMPORARILY also registers all topics starting with "up" (= UnifiedPush). This is to ease the transition
 // until the Android app will send the "Rate-Topics" header.
-func registerRateVisitors(topics []*topic, rateTopics []string, v *visitor) {
-	if len(rateTopics) == 1 && rateTopics[0] == rateTopicsWildcard {
-		for _, t := range topics {
-			t.SetRateVisitor(v)
+func (s *Server) maybeSetRateVisitors(r *http.Request, v *visitor, topics []*topic, rateTopics []string) error {
+	// Make a list of topics that we'll actually set the RateVisitor on
+	eligibleRateTopics := make([]*topic, 0)
+	for _, t := range topics {
+		if strings.HasPrefix(t.ID, unifiedPushTopicPrefix) || util.Contains(rateTopics, t.ID) {
+			eligibleRateTopics = append(eligibleRateTopics, t)
 		}
-	} else {
-		for _, t := range topics {
-			if util.Contains(rateTopics, t.ID) || strings.HasPrefix(t.ID, unifiedPushTopicPrefix) {
-				t.SetRateVisitor(v)
+	}
+	if len(eligibleRateTopics) == 0 {
+		return nil
+	}
+
+	// If access controls are turned off, v has access to everything, and we can set the rate visitor
+	if s.userManager == nil {
+		return s.setRateVisitors(r, v, eligibleRateTopics)
+	}
+
+	// If access controls are enabled, only set rate visitor if
+	// - topic is reserved, and v.user is the owner
+	// - topic is not reserved, and v.user has write access
+	writableRateTopics := make([]*topic, 0)
+	for _, t := range topics {
+		ownerUserID, err := s.userManager.ReservationOwner(t.ID)
+		if err != nil {
+			return err
+		}
+		if ownerUserID == "" {
+			if err := s.userManager.Authorize(v.User(), t.ID, user.PermissionWrite); err == nil {
+				writableRateTopics = append(writableRateTopics, t)
 			}
+		} else if ownerUserID == v.MaybeUserID() {
+			writableRateTopics = append(writableRateTopics, t)
 		}
 	}
+	return s.setRateVisitors(r, v, writableRateTopics)
+}
+
+func (s *Server) setRateVisitors(r *http.Request, v *visitor, rateTopics []*topic) error {
+	for _, t := range rateTopics {
+		logvr(v, r).
+			Tag(tagSubscribe).
+			Field("message_topic", t.ID).
+			Debug("Setting visitor as rate visitor for topic %s", t.ID)
+		t.SetRateVisitor(v)
+	}
+	return nil
 }
 
 // sendOldMessages selects old messages from the messageCache and calls sub for each of them. It uses since as the

+ 67 - 1
server/server_test.go

@@ -2040,7 +2040,7 @@ func TestServer_SubscriberRateLimiting_VisitorExpiration(t *testing.T) {
 		r.RemoteAddr = "1.2.3.4"
 	}
 	rr := request(t, s, "GET", "/mytopic/json?poll=1", "", map[string]string{
-		"rate-topics": "*",
+		"rate-topics": "mytopic",
 	}, subscriberFn)
 	require.Equal(t, 200, rr.Code)
 	require.Equal(t, "1.2.3.4", s.topics["mytopic"].rateVisitor.ip.String())
@@ -2065,6 +2065,72 @@ func TestServer_SubscriberRateLimiting_VisitorExpiration(t *testing.T) {
 	require.Nil(t, s.visitors["ip:1.2.3.4"])
 }
 
+func TestServer_SubscriberRateLimiting_ProtectedTopics(t *testing.T) {
+	c := newTestConfigWithAuthFile(t)
+	c.AuthDefault = user.PermissionDenyAll
+	s := newTestServer(t, c)
+
+	// Create some ACLs
+	require.Nil(t, s.userManager.AddTier(&user.Tier{
+		Code:         "test",
+		MessageLimit: 5,
+	}))
+	require.Nil(t, s.userManager.AddUser("ben", "ben", user.RoleUser))
+	require.Nil(t, s.userManager.ChangeTier("ben", "test"))
+	require.Nil(t, s.userManager.AllowAccess("ben", "announcements", user.PermissionReadWrite))
+	require.Nil(t, s.userManager.AllowAccess(user.Everyone, "announcements", user.PermissionRead))
+	require.Nil(t, s.userManager.AllowAccess(user.Everyone, "public_topic", user.PermissionReadWrite))
+
+	require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser))
+	require.Nil(t, s.userManager.ChangeTier("phil", "test"))
+	require.Nil(t, s.userManager.AddReservation("phil", "reserved-for-phil", user.PermissionReadWrite))
+
+	// Set rate visitor as user "phil" on topic
+	// - "reserved-for-phil": Allowed, because I am the owner
+	// - "public_topic": Allowed, because it has read-write permissions for everyone
+	// - "announcements": NOT allowed, because it has read-only permissions for everyone
+	rr := request(t, s, "GET", "/reserved-for-phil,public_topic,announcements/json?poll=1", "", map[string]string{
+		"Authorization": util.BasicAuth("phil", "phil"),
+		"Rate-Topics":   "reserved-for-phil,public_topic,announcements",
+	})
+	require.Equal(t, 200, rr.Code)
+	require.Equal(t, "phil", s.topics["reserved-for-phil"].rateVisitor.user.Name)
+	require.Equal(t, "phil", s.topics["public_topic"].rateVisitor.user.Name)
+	require.Nil(t, s.topics["announcements"].rateVisitor)
+
+	// Set rate visitor as user "ben" on topic
+	// - "reserved-for-phil": NOT allowed, because I am not the owner
+	// - "public_topic": Allowed, because it has read-write permissions for everyone
+	// - "announcements": Allowed, because I have read-write permissions
+	rr = request(t, s, "GET", "/reserved-for-phil,public_topic,announcements/json?poll=1", "", map[string]string{
+		"Authorization": util.BasicAuth("ben", "ben"),
+		"Rate-Topics":   "reserved-for-phil,public_topic,announcements",
+	})
+	require.Equal(t, 200, rr.Code)
+	require.Equal(t, "phil", s.topics["reserved-for-phil"].rateVisitor.user.Name)
+	require.Equal(t, "ben", s.topics["public_topic"].rateVisitor.user.Name)
+	require.Equal(t, "ben", s.topics["announcements"].rateVisitor.user.Name)
+}
+
+func TestServer_SubscriberRateLimiting_ProtectedTopics_WithDefaultReadWrite(t *testing.T) {
+	c := newTestConfigWithAuthFile(t)
+	c.AuthDefault = user.PermissionReadWrite
+	s := newTestServer(t, c)
+
+	// Create some ACLs
+	require.Nil(t, s.userManager.AllowAccess(user.Everyone, "announcements", user.PermissionRead))
+
+	// Set rate visitor as ip:1.2.3.4 on topic
+	// - "up1234": Allowed, because no ACLs and nobody owns the topic
+	// - "announcements": NOT allowed, because it has read-only permissions for everyone
+	rr := request(t, s, "GET", "/up1234,announcements/json?poll=1", "", nil, func(r *http.Request) {
+		r.RemoteAddr = "1.2.3.4"
+	})
+	require.Equal(t, 200, rr.Code)
+	require.Equal(t, "1.2.3.4", s.topics["up1234"].rateVisitor.ip.String())
+	require.Nil(t, s.topics["announcements"].rateVisitor)
+}
+
 func newTestConfig(t *testing.T) *Config {
 	conf := NewConfig()
 	conf.BaseURL = "http://127.0.0.1:12345"

+ 1 - 0
server/visitor.go

@@ -141,6 +141,7 @@ func (v *visitor) Context() log.Context {
 func (v *visitor) contextNoLock() log.Context {
 	info := v.infoLightNoLock()
 	fields := log.Context{
+		"visitor_id":                     visitorID(v.ip, v.user),
 		"visitor_ip":                     v.ip.String(),
 		"visitor_messages":               info.Stats.Messages,
 		"visitor_messages_limit":         info.Limits.MessageLimit,

+ 26 - 1
user/manager.go

@@ -201,7 +201,14 @@ const (
 	selectUserReservationsCountQuery = `
 		SELECT COUNT(*)
 		FROM user_access
-		WHERE user_id = owner_user_id AND owner_user_id = (SELECT id FROM user WHERE user = ?)
+		WHERE user_id = owner_user_id 
+		  AND owner_user_id = (SELECT id FROM user WHERE user = ?)
+	`
+	selectUserReservationsOwnerQuery = `
+		SELECT owner_user_id
+		FROM user_access
+		WHERE topic = ?
+		  AND user_id = owner_user_id
 	`
 	selectUserHasReservationQuery = `
 		SELECT COUNT(*)
@@ -1025,6 +1032,24 @@ func (a *Manager) ReservationsCount(username string) (int64, error) {
 	return count, nil
 }
 
+// ReservationOwner returns user ID of the user that owns this topic, or an
+// empty string if it's not owned by anyone
+func (a *Manager) ReservationOwner(topic string) (string, error) {
+	rows, err := a.db.Query(selectUserReservationsOwnerQuery, topic)
+	if err != nil {
+		return "", err
+	}
+	defer rows.Close()
+	if !rows.Next() {
+		return "", nil
+	}
+	var ownerUserID string
+	if err := rows.Scan(&ownerUserID); err != nil {
+		return "", err
+	}
+	return ownerUserID, nil
+}
+
 // ChangePassword changes a user's password
 func (a *Manager) ChangePassword(username, password string) error {
 	hash, err := bcrypt.GenerateFromPassword([]byte(password), a.bcryptCost)