瀏覽代碼

Fix previous fix

binwiederhier 2 年之前
父節點
當前提交
f58c1e4c84
共有 5 個文件被更改,包括 38 次插入43 次删除
  1. 29 33
      client/client.go
  2. 1 1
      client/client_test.go
  3. 0 1
      cmd/app.go
  4. 0 4
      cmd/publish.go
  5. 8 4
      cmd/subscribe.go

+ 29 - 33
client/client.go

@@ -11,23 +11,25 @@ import (
 	"heckel.io/ntfy/util"
 	"io"
 	"net/http"
+	"regexp"
 	"strings"
 	"sync"
 	"time"
 )
 
-// Event type constants
 const (
-	MessageEvent     = "message"
-	KeepaliveEvent   = "keepalive"
-	OpenEvent        = "open"
-	PollRequestEvent = "poll_request"
+	// MessageEvent identifies a message event
+	MessageEvent = "message"
 )
 
 const (
 	maxResponseBytes = 4096
 )
 
+var (
+	topicRegex = regexp.MustCompile(`^[-_A-Za-z0-9]{1,64}$`) // Same as in server/server.go
+)
+
 // Client is the ntfy client that can be used to publish and subscribe to ntfy topics
 type Client struct {
 	Messages      chan *Message
@@ -96,7 +98,10 @@ func (c *Client) Publish(topic, message string, options ...PublishOption) (*Mess
 // To pass title, priority and tags, check out WithTitle, WithPriority, WithTagsList, WithDelay, WithNoCache,
 // WithNoFirebase, and the generic WithHeader.
 func (c *Client) PublishReader(topic string, body io.Reader, options ...PublishOption) (*Message, error) {
-	topicURL := c.expandTopicURL(topic)
+	topicURL, err := c.expandTopicURL(topic)
+	if err != nil {
+		return nil, err
+	}
 	req, err := http.NewRequest("POST", topicURL, body)
 	if err != nil {
 		return nil, err
@@ -136,11 +141,14 @@ func (c *Client) PublishReader(topic string, body io.Reader, options ...PublishO
 // By default, all messages will be returned, but you can change this behavior using a SubscribeOption.
 // See WithSince, WithSinceAll, WithSinceUnixTime, WithScheduled, and the generic WithQueryParam.
 func (c *Client) Poll(topic string, options ...SubscribeOption) ([]*Message, error) {
+	topicURL, err := c.expandTopicURL(topic)
+	if err != nil {
+		return nil, err
+	}
 	ctx := context.Background()
 	messages := make([]*Message, 0)
 	msgChan := make(chan *Message)
 	errChan := make(chan error)
-	topicURL := c.expandTopicURL(topic)
 	log.Debug("%s Polling from topic", util.ShortTopicURL(topicURL))
 	options = append(options, WithPoll())
 	go func() {
@@ -169,15 +177,18 @@ func (c *Client) Poll(topic string, options ...SubscribeOption) ([]*Message, err
 // Example:
 //
 //	c := client.New(client.NewConfig())
-//	subscriptionID := c.Subscribe("mytopic")
+//	subscriptionID, _ := c.Subscribe("mytopic")
 //	for m := range c.Messages {
 //	  fmt.Printf("New message: %s", m.Message)
 //	}
-func (c *Client) Subscribe(topic string, options ...SubscribeOption) string {
+func (c *Client) Subscribe(topic string, options ...SubscribeOption) (string, error) {
+	topicURL, err := c.expandTopicURL(topic)
+	if err != nil {
+		return "", err
+	}
 	c.mu.Lock()
 	defer c.mu.Unlock()
 	subscriptionID := util.RandomString(10)
-	topicURL := c.expandTopicURL(topic)
 	log.Debug("%s Subscribing to topic", util.ShortTopicURL(topicURL))
 	ctx, cancel := context.WithCancel(context.Background())
 	c.subscriptions[subscriptionID] = &subscription{
@@ -186,7 +197,7 @@ func (c *Client) Subscribe(topic string, options ...SubscribeOption) string {
 		cancel:   cancel,
 	}
 	go handleSubscribeConnLoop(ctx, c.Messages, topicURL, subscriptionID, options...)
-	return subscriptionID
+	return subscriptionID, nil
 }
 
 // Unsubscribe unsubscribes from a topic that has been previously subscribed to using the unique
@@ -202,31 +213,16 @@ func (c *Client) Unsubscribe(subscriptionID string) {
 	sub.cancel()
 }
 
-// UnsubscribeAll unsubscribes from a topic that has been previously subscribed with Subscribe.
-// If there are multiple subscriptions matching the topic, all of them are unsubscribed from.
-//
-// A topic can be either a full URL (e.g. https://myhost.lan/mytopic), a short URL which is then prepended https://
-// (e.g. myhost.lan -> https://myhost.lan), or a short name which is expanded using the default host in the
-// config (e.g. mytopic -> https://ntfy.sh/mytopic).
-func (c *Client) UnsubscribeAll(topic string) {
-	c.mu.Lock()
-	defer c.mu.Unlock()
-	topicURL := c.expandTopicURL(topic)
-	for _, sub := range c.subscriptions {
-		if sub.topicURL == topicURL {
-			delete(c.subscriptions, sub.ID)
-			sub.cancel()
-		}
-	}
-}
-
-func (c *Client) expandTopicURL(topic string) string {
+func (c *Client) expandTopicURL(topic string) (string, error) {
 	if strings.HasPrefix(topic, "http://") || strings.HasPrefix(topic, "https://") {
-		return topic
+		return topic, nil
 	} else if strings.Contains(topic, "/") {
-		return fmt.Sprintf("https://%s", topic)
+		return fmt.Sprintf("https://%s", topic), nil
+	}
+	if !topicRegex.MatchString(topic) {
+		return "", fmt.Errorf("invalid topic name: %s", topic)
 	}
-	return fmt.Sprintf("%s/%s", c.config.DefaultHost, topic)
+	return fmt.Sprintf("%s/%s", c.config.DefaultHost, topic), nil
 }
 
 func handleSubscribeConnLoop(ctx context.Context, msgChan chan *Message, topicURL, subcriptionID string, options ...SubscribeOption) {

+ 1 - 1
client/client_test.go

@@ -21,7 +21,7 @@ func TestClient_Publish_Subscribe(t *testing.T) {
 	defer test.StopServer(t, s, port)
 	c := client.New(newTestConfig(port))
 
-	subscriptionID := c.Subscribe("mytopic")
+	subscriptionID, _ := c.Subscribe("mytopic")
 	time.Sleep(time.Second)
 
 	msg, err := c.Publish("mytopic", "some message")

+ 0 - 1
cmd/app.go

@@ -29,7 +29,6 @@ var flagsDefault = []cli.Flag{
 
 var (
 	logLevelOverrideRegex = regexp.MustCompile(`(?i)^([^=\s]+)(?:\s*=\s*(\S+))?\s*->\s*(TRACE|DEBUG|INFO|WARN|ERROR)$`)
-	topicRegex            = regexp.MustCompile(`^[-_A-Za-z0-9]{1,64}$`) // Same as in server/server.go
 )
 
 // New creates a new CLI application

+ 0 - 4
cmd/publish.go

@@ -249,10 +249,6 @@ func parseTopicMessageCommand(c *cli.Context) (topic string, message string, com
 	if c.String("message") != "" {
 		message = c.String("message")
 	}
-	if !topicRegex.MatchString(topic) {
-		err = fmt.Errorf("topic %s contains invalid characters", topic)
-		return
-	}
 	return
 }
 

+ 8 - 4
cmd/subscribe.go

@@ -108,8 +108,6 @@ func execSubscribe(c *cli.Context) error {
 	// Checks
 	if user != "" && token != "" {
 		return errors.New("cannot set both --user and --token")
-	} else if !topicRegex.MatchString(topic) {
-		return fmt.Errorf("topic %s contains invalid characters", topic)
 	}
 
 	if !fromConfig {
@@ -196,7 +194,10 @@ func doSubscribe(c *cli.Context, cl *client.Client, conf *client.Config, topic,
 			topicOptions = append(topicOptions, auth)
 		}
 
-		subscriptionID := cl.Subscribe(s.Topic, topicOptions...)
+		subscriptionID, err := cl.Subscribe(s.Topic, topicOptions...)
+		if err != nil {
+			return err
+		}
 		if s.Command != "" {
 			cmds[subscriptionID] = s.Command
 		} else if conf.DefaultCommand != "" {
@@ -206,7 +207,10 @@ func doSubscribe(c *cli.Context, cl *client.Client, conf *client.Config, topic,
 		}
 	}
 	if topic != "" {
-		subscriptionID := cl.Subscribe(topic, options...)
+		subscriptionID, err := cl.Subscribe(topic, options...)
+		if err != nil {
+			return err
+		}
 		cmds[subscriptionID] = command
 	}
 	for m := range cl.Messages {