Преглед изворни кода

Merge pull request #713 from dropdevrahul/issue-712

fix: removes an issue with topic.Subscribe function not checking dupl…
Philipp C. Heckel пре 2 година
родитељ
комит
9eb94a565d
2 измењених фајлова са 47 додато и 3 уклоњено
  1. 17 2
      server/topic.go
  2. 30 1
      server/topic_test.go

+ 17 - 2
server/topic.go

@@ -1,11 +1,12 @@
 package server
 package server
 
 
 import (
 import (
-	"heckel.io/ntfy/log"
-	"heckel.io/ntfy/util"
 	"math/rand"
 	"math/rand"
 	"sync"
 	"sync"
 	"time"
 	"time"
+
+	"heckel.io/ntfy/log"
+	"heckel.io/ntfy/util"
 )
 )
 
 
 const (
 const (
@@ -45,9 +46,23 @@ func newTopic(id string) *topic {
 
 
 // Subscribe subscribes to this topic
 // Subscribe subscribes to this topic
 func (t *topic) Subscribe(s subscriber, userID string, cancel func()) int {
 func (t *topic) Subscribe(s subscriber, userID string, cancel func()) int {
+	max_retries := 5
+	retries := 1
 	t.mu.Lock()
 	t.mu.Lock()
 	defer t.mu.Unlock()
 	defer t.mu.Unlock()
+
 	subscriberID := rand.Int()
 	subscriberID := rand.Int()
+	// simple check for existing id in maps
+	for {
+		_, ok := t.subscribers[subscriberID]
+		if ok && retries <= max_retries {
+			subscriberID = rand.Int()
+			retries++
+		} else {
+			break
+		}
+	}
+
 	t.subscribers[subscriberID] = &topicSubscriber{
 	t.subscribers[subscriberID] = &topicSubscriber{
 		userID:     userID, // May be empty
 		userID:     userID, // May be empty
 		subscriber: s,
 		subscriber: s,

+ 30 - 1
server/topic_test.go

@@ -1,10 +1,12 @@
 package server
 package server
 
 
 import (
 import (
-	"github.com/stretchr/testify/require"
+	"math/rand"
 	"sync/atomic"
 	"sync/atomic"
 	"testing"
 	"testing"
 	"time"
 	"time"
+
+	"github.com/stretchr/testify/require"
 )
 )
 
 
 func TestTopic_CancelSubscribers(t *testing.T) {
 func TestTopic_CancelSubscribers(t *testing.T) {
@@ -39,3 +41,30 @@ func TestTopic_Keepalive(t *testing.T) {
 	require.True(t, to.LastAccess().Unix() >= time.Now().Unix()-2)
 	require.True(t, to.LastAccess().Unix() >= time.Now().Unix()-2)
 	require.True(t, to.LastAccess().Unix() <= time.Now().Unix()+2)
 	require.True(t, to.LastAccess().Unix() <= time.Now().Unix()+2)
 }
 }
+
+func TestTopic_Subscribe_duplicateID(t *testing.T) {
+	t.Parallel()
+
+	to := newTopic("mytopic")
+
+	// fix random seed to force same number generation
+	rand.Seed(1)
+	a := rand.Int()
+	to.subscribers[a] = &topicSubscriber{
+		userID:     "a",
+		subscriber: nil,
+		cancel:     func() {},
+	}
+
+	subFn := func(v *visitor, msg *message) error {
+		return nil
+	}
+
+	// force rand.Int to generate the same id once more
+	rand.Seed(1)
+	id := to.Subscribe(subFn, "b", func() {})
+	res := to.subscribers[id]
+
+	require.False(t, id == a)
+	require.True(t, res.userID == "b")
+}