Kaynağa Gözat

Kill existing subscribers when topic is reserved

binwiederhier 3 yıl önce
ebeveyn
işleme
bce71cb196
5 değiştirilmiş dosya ile 168 ekleme ve 35 silme
  1. 43 21
      server/server.go
  2. 9 5
      server/server_account.go
  3. 69 0
      server/server_account_test.go
  4. 35 9
      server/topic.go
  5. 12 0
      server/visitor.go

+ 43 - 21
server/server.go

@@ -38,11 +38,13 @@ import (
 TODO
 --
 
-- Reservation: Kill existing subscribers when topic is reserved (deadcade)
 - Rate limiting: Sensitive endpoints (account/login/change-password/...)
 - Reservation (UI): Show "This topic is reserved" error message when trying to reserve a reserved topic (Thorben)
 - Reservation (UI): Ask for confirmation when removing reservation (deadcade)
 - Reservation icons (UI)
+- reservation table delete button: dialog "keep or delete messages?"
+- UI: Flickering upgrade banner when logging in
+- JS constants
 
 races:
 - v.user --> see publishSyncEventAsync() test
@@ -63,11 +65,6 @@ Limits & rate limiting:
 Make sure account endpoints make sense for admins
 
 
-UI:
--
-- reservation table delete button: dialog "keep or delete messages?"
-- flicker of upgrade banner
-- JS constants
 Sync:
 	- sync problems with "deleteAfter=0" and "displayName="
 
@@ -359,7 +356,7 @@ func (s *Server) handle(w http.ResponseWriter, r *http.Request) {
 			log.Info("%s Connection closed with HTTP %d (ntfy error %d): %s", logHTTPPrefix(v, r), httpErr.HTTPCode, httpErr.Code, err.Error())
 		}
 		w.Header().Set("Content-Type", "application/json")
-		w.Header().Set("Access-Control-Allow-Origin", "*") // CORS, allow cross-origin requests
+		w.Header().Set("Access-Control-Allow-Origin", s.config.AccessControlAllowOrigin) // CORS, allow cross-origin requests
 		w.WriteHeader(httpErr.HTTPCode)
 		io.WriteString(w, httpErr.JSON()+"\n")
 	}
@@ -461,7 +458,7 @@ func (s *Server) handleTopic(w http.ResponseWriter, r *http.Request, v *visitor)
 	unifiedpush := readBoolParam(r, false, "x-unifiedpush", "unifiedpush", "up") // see PUT/POST too!
 	if unifiedpush {
 		w.Header().Set("Content-Type", "application/json")
-		w.Header().Set("Access-Control-Allow-Origin", "*") // CORS, allow cross-origin requests
+		w.Header().Set("Access-Control-Allow-Origin", s.config.AccessControlAllowOrigin) // CORS, allow cross-origin requests
 		_, err := io.WriteString(w, `{"unifiedpush":{"version":1}}`+"\n")
 		return err
 	}
@@ -538,7 +535,7 @@ func (s *Server) handleFile(w http.ResponseWriter, r *http.Request, v *visitor)
 		}
 	}
 	w.Header().Set("Content-Length", fmt.Sprintf("%d", stat.Size()))
-	w.Header().Set("Access-Control-Allow-Origin", "*") // CORS, allow cross-origin requests
+	w.Header().Set("Access-Control-Allow-Origin", s.config.AccessControlAllowOrigin) // CORS, allow cross-origin requests
 	if r.Method == http.MethodGet {
 		f, err := os.Open(file)
 		if err != nil {
@@ -969,14 +966,16 @@ func (s *Server) handleSubscribeHTTP(w http.ResponseWriter, r *http.Request, v *
 		}
 		return nil
 	}
-	w.Header().Set("Access-Control-Allow-Origin", "*")            // CORS, allow cross-origin requests
-	w.Header().Set("Content-Type", contentType+"; charset=utf-8") // Android/Volley client needs charset!
+	w.Header().Set("Access-Control-Allow-Origin", s.config.AccessControlAllowOrigin) // CORS, allow cross-origin requests
+	w.Header().Set("Content-Type", contentType+"; charset=utf-8")                    // Android/Volley client needs charset!
 	if poll {
 		return s.sendOldMessages(topics, since, scheduled, v, sub)
 	}
+	ctx, cancel := context.WithCancel(context.Background())
+	defer cancel()
 	subscriberIDs := make([]int, 0)
 	for _, t := range topics {
-		subscriberIDs = append(subscriberIDs, t.Subscribe(sub))
+		subscriberIDs = append(subscriberIDs, t.Subscribe(sub, v.MaybeUserID(), cancel))
 	}
 	defer func() {
 		for i, subscriberID := range subscriberIDs {
@@ -991,6 +990,8 @@ func (s *Server) handleSubscribeHTTP(w http.ResponseWriter, r *http.Request, v *
 	}
 	for {
 		select {
+		case <-ctx.Done():
+			return nil
 		case <-r.Context().Done():
 			return nil
 		case <-time.After(s.config.KeepaliveInterval):
@@ -1033,8 +1034,20 @@ func (s *Server) handleSubscribeWS(w http.ResponseWriter, r *http.Request, v *vi
 		return err
 	}
 	defer conn.Close()
+
+	// Subscription connections can be canceled externally, see topic.CancelSubscribers
+	subscriberContext, cancel := context.WithCancel(context.Background())
+	defer cancel()
+
+	// Use errgroup to run WebSocket reader and writer in Go routines
 	var wlock sync.Mutex
-	g, ctx := errgroup.WithContext(context.Background())
+	g, gctx := errgroup.WithContext(context.Background())
+	g.Go(func() error {
+		<-subscriberContext.Done()
+		log.Trace("%s Cancel received, closing subscriber connection", logHTTPPrefix(v, r))
+		conn.Close()
+		return &websocket.CloseError{Code: websocket.CloseNormalClosure, Text: "subscription was canceled"}
+	})
 	g.Go(func() error {
 		pongWait := s.config.KeepaliveInterval + wsPongWait
 		conn.SetReadLimit(wsReadLimit)
@@ -1050,6 +1063,11 @@ func (s *Server) handleSubscribeWS(w http.ResponseWriter, r *http.Request, v *vi
 			if err != nil {
 				return err
 			}
+			select {
+			case <-gctx.Done():
+				return nil
+			default:
+			}
 		}
 	})
 	g.Go(func() error {
@@ -1064,7 +1082,7 @@ func (s *Server) handleSubscribeWS(w http.ResponseWriter, r *http.Request, v *vi
 		}
 		for {
 			select {
-			case <-ctx.Done():
+			case <-gctx.Done():
 				return nil
 			case <-time.After(s.config.KeepaliveInterval):
 				v.Keepalive()
@@ -1085,13 +1103,13 @@ func (s *Server) handleSubscribeWS(w http.ResponseWriter, r *http.Request, v *vi
 		}
 		return conn.WriteJSON(msg)
 	}
-	w.Header().Set("Access-Control-Allow-Origin", "*") // CORS, allow cross-origin requests
+	w.Header().Set("Access-Control-Allow-Origin", s.config.AccessControlAllowOrigin) // CORS, allow cross-origin requests
 	if poll {
 		return s.sendOldMessages(topics, since, scheduled, v, sub)
 	}
 	subscriberIDs := make([]int, 0)
 	for _, t := range topics {
-		subscriberIDs = append(subscriberIDs, t.Subscribe(sub))
+		subscriberIDs = append(subscriberIDs, t.Subscribe(sub, v.MaybeUserID(), cancel))
 	}
 	defer func() {
 		for i, subscriberID := range subscriberIDs {
@@ -1193,11 +1211,7 @@ func (s *Server) topicFromPath(path string) (*topic, error) {
 	if len(parts) < 2 {
 		return nil, errHTTPBadRequestTopicInvalid
 	}
-	topics, err := s.topicsFromIDs(parts[1])
-	if err != nil {
-		return nil, err
-	}
-	return topics[0], nil
+	return s.topicFromID(parts[1])
 }
 
 func (s *Server) topicsFromPath(path string) ([]*topic, string, error) {
@@ -1232,6 +1246,14 @@ func (s *Server) topicsFromIDs(ids ...string) ([]*topic, error) {
 	return topics, nil
 }
 
+func (s *Server) topicFromID(id string) (*topic, error) {
+	topics, err := s.topicsFromIDs(id)
+	if err != nil {
+		return nil, err
+	}
+	return topics[0], nil
+}
+
 func (s *Server) execManager() {
 	log.Debug("Manager: Starting")
 	defer log.Debug("Manager: Finished")

+ 9 - 5
server/server_account.go

@@ -2,7 +2,6 @@ package server
 
 import (
 	"encoding/json"
-	"errors"
 	"heckel.io/ntfy/log"
 	"heckel.io/ntfy/user"
 	"heckel.io/ntfy/util"
@@ -331,6 +330,7 @@ func (s *Server) handleAccountReservationAdd(w http.ResponseWriter, r *http.Requ
 	if v.user.Tier == nil {
 		return errHTTPUnauthorized
 	}
+	// CHeck if we are allowed to reserve this topic
 	if err := s.userManager.CheckAllowAccess(v.user.Name, req.Topic); err != nil {
 		return errHTTPConflictTopicReserved
 	}
@@ -346,9 +346,16 @@ func (s *Server) handleAccountReservationAdd(w http.ResponseWriter, r *http.Requ
 			return errHTTPTooManyRequestsLimitReservations
 		}
 	}
+	// Actually add the reservation
 	if err := s.userManager.AddReservation(v.user.Name, req.Topic, everyone); err != nil {
 		return err
 	}
+	// Kill existing subscribers
+	t, err := s.topicFromID(req.Topic)
+	if err != nil {
+		return err
+	}
+	t.CancelSubscribers(v.user.ID)
 	return s.writeJSON(w, newSuccessResponse())
 }
 
@@ -402,13 +409,10 @@ func (s *Server) publishSyncEvent(v *visitor) error {
 		return nil
 	}
 	log.Trace("Publishing sync event to user %s's sync topic %s", v.user.Name, v.user.SyncTopic)
-	topics, err := s.topicsFromIDs(v.user.SyncTopic)
+	syncTopic, err := s.topicFromID(v.user.SyncTopic)
 	if err != nil {
 		return err
-	} else if len(topics) == 0 {
-		return errors.New("cannot retrieve sync topic")
 	}
-	syncTopic := topics[0]
 	messageBytes, err := json.Marshal(&apiAccountSyncTopicResponse{Event: syncTopicAccountSyncEvent})
 	if err != nil {
 		return err

+ 69 - 0
server/server_account_test.go

@@ -496,3 +496,72 @@ func TestAccount_Reservation_PublishByAnonymousFails(t *testing.T) {
 	rr = request(t, s, "POST", "/mytopic", `Howdy`, nil)
 	require.Equal(t, 403, rr.Code)
 }
+
+func TestAccount_Reservation_Add_Kills_Other_Subscribers(t *testing.T) {
+	conf := newTestConfigWithAuthFile(t)
+	conf.AuthDefault = user.PermissionReadWrite
+	conf.EnableSignup = true
+	s := newTestServer(t, conf)
+
+	// Create user with tier
+	rr := request(t, s, "POST", "/v1/account", `{"username":"phil", "password":"mypass"}`, nil)
+	require.Equal(t, 200, rr.Code)
+
+	require.Nil(t, s.userManager.CreateTier(&user.Tier{
+		Code:              "pro",
+		MessagesLimit:     20,
+		ReservationsLimit: 2,
+	}))
+	require.Nil(t, s.userManager.ChangeTier("phil", "pro"))
+
+	// Subscribe anonymously
+	anonCh, userCh := make(chan bool), make(chan bool)
+	go func() {
+		rr := request(t, s, "GET", "/mytopic/json", ``, nil)
+		require.Equal(t, 200, rr.Code)
+		messages := toMessages(t, rr.Body.String())
+		require.Equal(t, 2, len(messages)) // This is the meat. We should NOT receive the second message!
+		require.Equal(t, "open", messages[0].Event)
+		require.Equal(t, "message before reservation", messages[1].Message)
+		anonCh <- true
+	}()
+
+	// Subscribe with user
+	go func() {
+		rr := request(t, s, "GET", "/mytopic/json", ``, map[string]string{
+			"Authorization": util.BasicAuth("phil", "mypass"),
+		})
+		require.Equal(t, 200, rr.Code)
+		messages := toMessages(t, rr.Body.String())
+		require.Equal(t, 3, len(messages))
+		require.Equal(t, "open", messages[0].Event)
+		require.Equal(t, "message before reservation", messages[1].Message)
+		require.Equal(t, "message after reservation", messages[2].Message)
+		userCh <- true
+	}()
+
+	// Publish message (before reservation)
+	time.Sleep(700 * time.Millisecond) // Wait for subscribers
+	rr = request(t, s, "POST", "/mytopic", "message before reservation", nil)
+	require.Equal(t, 200, rr.Code)
+	time.Sleep(700 * time.Millisecond) // Wait for subscribers to receive message
+
+	// Reserve a topic
+	rr = request(t, s, "POST", "/v1/account/reservation", `{"topic": "mytopic", "everyone":"deny-all"}`, map[string]string{
+		"Authorization": util.BasicAuth("phil", "mypass"),
+	})
+	require.Equal(t, 200, rr.Code)
+
+	// Everyone but phil should be killed
+	<-anonCh
+
+	// Publish a message
+	rr = request(t, s, "POST", "/mytopic", "message after reservation", map[string]string{
+		"Authorization": util.BasicAuth("phil", "mypass"),
+	})
+	require.Equal(t, 200, rr.Code)
+
+	// Kill user Go routine
+	s.topics["mytopic"].CancelSubscribers("<invalid>")
+	<-userCh
+}

+ 35 - 9
server/topic.go

@@ -10,10 +10,16 @@ import (
 // can publish a message
 type topic struct {
 	ID          string
-	subscribers map[int]subscriber
+	subscribers map[int]*topicSubscriber
 	mu          sync.Mutex
 }
 
+type topicSubscriber struct {
+	userID     string // User ID associated with this subscription, may be empty
+	subscriber subscriber
+	cancel     func()
+}
+
 // subscriber is a function that is called for every new message on a topic
 type subscriber func(v *visitor, msg *message) error
 
@@ -21,16 +27,20 @@ type subscriber func(v *visitor, msg *message) error
 func newTopic(id string) *topic {
 	return &topic{
 		ID:          id,
-		subscribers: make(map[int]subscriber),
+		subscribers: make(map[int]*topicSubscriber),
 	}
 }
 
 // Subscribe subscribes to this topic
-func (t *topic) Subscribe(s subscriber) int {
+func (t *topic) Subscribe(s subscriber, userID string, cancel func()) int {
 	t.mu.Lock()
 	defer t.mu.Unlock()
 	subscriberID := rand.Int()
-	t.subscribers[subscriberID] = s
+	t.subscribers[subscriberID] = &topicSubscriber{
+		userID:     userID, // May be empty
+		subscriber: s,
+		cancel:     cancel,
+	}
 	return subscriberID
 }
 
@@ -56,7 +66,7 @@ func (t *topic) Publish(v *visitor, m *message) error {
 					if err := s(v, m); err != nil {
 						log.Warn("%s Error forwarding to subscriber", logMessagePrefix(v, m))
 					}
-				}(s)
+				}(s.subscriber)
 			}
 		} else {
 			log.Trace("%s No stream or WebSocket subscribers, not forwarding", logMessagePrefix(v, m))
@@ -72,13 +82,29 @@ func (t *topic) SubscribersCount() int {
 	return len(t.subscribers)
 }
 
+// CancelSubscribers calls the cancel function for all subscribers, forcing
+func (t *topic) CancelSubscribers(exceptUserID string) {
+	t.mu.Lock()
+	defer t.mu.Unlock()
+	for _, s := range t.subscribers {
+		if s.userID != exceptUserID {
+			log.Trace("Canceling subscriber %s", s.userID)
+			s.cancel()
+		}
+	}
+}
+
 // subscribersCopy returns a shallow copy of the subscribers map
-func (t *topic) subscribersCopy() map[int]subscriber {
+func (t *topic) subscribersCopy() map[int]*topicSubscriber {
 	t.mu.Lock()
 	defer t.mu.Unlock()
-	subscribers := make(map[int]subscriber)
-	for k, v := range t.subscribers {
-		subscribers[k] = v
+	subscribers := make(map[int]*topicSubscriber)
+	for k, sub := range t.subscribers {
+		subscribers[k] = &topicSubscriber{
+			userID:     sub.userID,
+			subscriber: sub.subscriber,
+			cancel:     sub.cancel,
+		}
 	}
 	return subscribers
 }

+ 12 - 0
server/visitor.go

@@ -228,12 +228,24 @@ func (v *visitor) ResetStats() {
 	}
 }
 
+// SetUser sets the visitors user to the given value
 func (v *visitor) SetUser(u *user.User) {
 	v.mu.Lock()
 	defer v.mu.Unlock()
 	v.user = u
 }
 
+// MaybeUserID returns the user ID of the visitor (if any). If this is an anonymous visitor,
+// an empty string is returned.
+func (v *visitor) MaybeUserID() string {
+	v.mu.Lock()
+	defer v.mu.Unlock()
+	if v.user != nil {
+		return v.user.ID
+	}
+	return ""
+}
+
 func (v *visitor) Limits() *visitorLimits {
 	v.mu.Lock()
 	defer v.mu.Unlock()