Explorar o código

WIP: iOS poll_request forwarder

Philipp Heckel %!s(int64=3) %!d(string=hai) anos
pai
achega
6a43c1a126
Modificáronse 5 ficheiros con 85 adicións e 11 borrados
  1. 5 0
      cmd/serve.go
  2. 1 0
      server/config.go
  3. 37 1
      server/server.go
  4. 27 0
      server/server_firebase.go
  5. 15 10
      server/types.go

+ 5 - 0
cmd/serve.go

@@ -41,6 +41,7 @@ var flagsServe = []cli.Flag{
 	altsrc.NewDurationFlag(&cli.DurationFlag{Name: "keepalive-interval", Aliases: []string{"keepalive_interval", "k"}, EnvVars: []string{"NTFY_KEEPALIVE_INTERVAL"}, Value: server.DefaultKeepaliveInterval, Usage: "interval of keepalive messages"}),
 	altsrc.NewDurationFlag(&cli.DurationFlag{Name: "manager-interval", Aliases: []string{"manager_interval", "m"}, EnvVars: []string{"NTFY_MANAGER_INTERVAL"}, Value: server.DefaultManagerInterval, Usage: "interval of for message pruning and stats printing"}),
 	altsrc.NewStringFlag(&cli.StringFlag{Name: "web-root", Aliases: []string{"web_root"}, EnvVars: []string{"NTFY_WEB_ROOT"}, Value: "app", Usage: "sets web root to landing page (home), web app (app) or disabled (disable)"}),
+	altsrc.NewStringFlag(&cli.StringFlag{Name: "forward-poll-url", Aliases: []string{"forward_poll_url"}, EnvVars: []string{"NTFY_FORWARD_POLL_URL"}, Value: "", Usage: ""}),
 	altsrc.NewStringFlag(&cli.StringFlag{Name: "smtp-sender-addr", Aliases: []string{"smtp_sender_addr"}, EnvVars: []string{"NTFY_SMTP_SENDER_ADDR"}, Usage: "SMTP server address (host:port) for outgoing emails"}),
 	altsrc.NewStringFlag(&cli.StringFlag{Name: "smtp-sender-user", Aliases: []string{"smtp_sender_user"}, EnvVars: []string{"NTFY_SMTP_SENDER_USER"}, Usage: "SMTP user (if e-mail sending is enabled)"}),
 	altsrc.NewStringFlag(&cli.StringFlag{Name: "smtp-sender-pass", Aliases: []string{"smtp_sender_pass"}, EnvVars: []string{"NTFY_SMTP_SENDER_PASS"}, Usage: "SMTP password (if e-mail sending is enabled)"}),
@@ -102,6 +103,7 @@ func execServe(c *cli.Context) error {
 	keepaliveInterval := c.Duration("keepalive-interval")
 	managerInterval := c.Duration("manager-interval")
 	webRoot := c.String("web-root")
+	forwardPollURL := c.String("forward-poll-url")
 	smtpSenderAddr := c.String("smtp-sender-addr")
 	smtpSenderUser := c.String("smtp-sender-user")
 	smtpSenderPass := c.String("smtp-sender-pass")
@@ -147,6 +149,8 @@ func execServe(c *cli.Context) error {
 		return errors.New("if set, auth-default-access must start set to 'read-write', 'read-only', 'write-only' or 'deny-all'")
 	} else if !util.InStringList([]string{"app", "home", "disable"}, webRoot) {
 		return errors.New("if set, web-root must be 'home' or 'app'")
+	} else if forwardPollURL != "" && !strings.HasPrefix(forwardPollURL, "http://") && !strings.HasPrefix(forwardPollURL, "https://") {
+		return errors.New("if set, forward-poll-url must start with http:// or https://")
 	}
 
 	webRootIsApp := webRoot == "app"
@@ -215,6 +219,7 @@ func execServe(c *cli.Context) error {
 	conf.KeepaliveInterval = keepaliveInterval
 	conf.ManagerInterval = managerInterval
 	conf.WebRootIsApp = webRootIsApp
+	conf.ForwardPollURL = forwardPollURL
 	conf.SMTPSenderAddr = smtpSenderAddr
 	conf.SMTPSenderUser = smtpSenderUser
 	conf.SMTPSenderPass = smtpSenderPass

+ 1 - 0
server/config.go

@@ -69,6 +69,7 @@ type Config struct {
 	AtSenderInterval                     time.Duration
 	FirebaseKeepaliveInterval            time.Duration
 	FirebasePollInterval                 time.Duration
+	ForwardPollURL                       string
 	SMTPSenderAddr                       string
 	SMTPSenderUser                       string
 	SMTPSenderPass                       string

+ 37 - 1
server/server.go

@@ -3,6 +3,7 @@ package server
 import (
 	"bytes"
 	"context"
+	"crypto/sha256"
 	"embed"
 	"encoding/base64"
 	"encoding/json"
@@ -93,6 +94,7 @@ const (
 	firebaseControlTopic     = "~control"                // See Android if changed
 	firebasePollTopic        = "~poll"                   // See iOS if changed
 	emptyMessageBody         = "triggered"               // Used if message body is empty
+	newMessageBody           = "New message"             // Used in poll requests as generic message
 	defaultAttachmentMessage = "You received a file: %s" // Used if message body is empty, and there is an attachment
 	encodingBase64           = "base64"
 )
@@ -422,6 +424,9 @@ func (s *Server) handlePublish(w http.ResponseWriter, r *http.Request, v *visito
 	if err != nil {
 		return err
 	}
+	if m.PollID != "" {
+		m = newPollRequestMessage(t.ID, m.PollID)
+	}
 	if err := s.handlePublishBody(r, v, m, body, unifiedpush); err != nil {
 		return err
 	}
@@ -448,6 +453,28 @@ func (s *Server) handlePublish(w http.ResponseWriter, r *http.Request, v *visito
 			}
 		}()
 	}
+	if s.config.ForwardPollURL != "" {
+		go func() {
+			topicURL := fmt.Sprintf("%s/%s", s.config.BaseURL, m.Topic)
+			topicHash := fmt.Sprintf("%x", sha256.Sum256([]byte(topicURL)))
+			forwardURL := fmt.Sprintf("%s/%s", s.config.ForwardPollURL, topicHash)
+			log.Printf("forwarding: topicURL %s, to upstream url %s", topicURL, forwardURL)
+			req, err := http.NewRequest("POST", forwardURL, strings.NewReader(""))
+			if err != nil {
+				log.Printf("[%s] FWD - Unable to forward poll request: %v", v.ip, err.Error())
+				return
+			}
+			req.Header.Set("X-Poll-ID", m.ID)
+			response, err := http.DefaultClient.Do(req)
+			if err != nil {
+				log.Printf("[%s] FWD - Unable to forward poll request: %v", v.ip, err.Error())
+				return
+			} else if response.StatusCode != http.StatusOK {
+				log.Printf("[%s] FWD - Unable to forward poll request, unexpected status: %d", v.ip, response.StatusCode)
+				return
+			}
+		}()
+	}
 	if cache {
 		if err := s.messageCache.AddMessage(m); err != nil {
 			return err
@@ -549,6 +576,12 @@ func (s *Server) parsePublishParams(r *http.Request, v *visitor, m *message) (ca
 		firebase = false
 		unifiedpush = true
 	}
+	m.PollID = readParam(r, "x-poll-id", "poll-id", "poll")
+	if m.PollID != "" {
+		unifiedpush = false
+		cache = false
+		email = ""
+	}
 	return cache, firebase, email, unifiedpush, nil
 }
 
@@ -565,7 +598,9 @@ func (s *Server) parsePublishParams(r *http.Request, v *visitor, m *message) (ca
 // 5. curl -T file.txt ntfy.sh/mytopic
 //    If file.txt is > message limit, treat it as an attachment
 func (s *Server) handlePublishBody(r *http.Request, v *visitor, m *message, body *util.PeekedReadCloser, unifiedpush bool) error {
-	if unifiedpush {
+	if m.Event == pollRequestEvent {
+		return nil // Ignore body
+	} else if unifiedpush {
 		return s.handleBodyAsMessageAutoDetect(m, body) // Case 1
 	} else if m.Attachment != nil && m.Attachment.URL != "" {
 		return s.handleBodyAsTextMessage(m, body) // Case 2
@@ -710,6 +745,7 @@ func (s *Server) handleSubscribeHTTP(w http.ResponseWriter, r *http.Request, v *
 	w.Header().Set("Access-Control-Allow-Origin", "*")            // CORS, allow cross-origin requests
 	w.Header().Set("Content-Type", contentType+"; charset=utf-8") // Android/Volley client needs charset!
 	if poll {
+		log.Printf("polling %#v", r.URL)
 		return s.sendOldMessages(topics, since, scheduled, sub)
 	}
 	subscriberIDs := make([]int, 0)

+ 27 - 0
server/server_firebase.go

@@ -4,6 +4,7 @@ import (
 	"context"
 	"encoding/json"
 	"fmt"
+	"log"
 	"strings"
 
 	firebase "firebase.google.com/go"
@@ -64,6 +65,7 @@ func createFirebaseSubscriber(credentialsFile string, auther auth.Auther) (subsc
 		if err != nil {
 			return err
 		}
+		log.Printf("Sending %#v %#v", m, fbm)
 		_, err = msg.Send(context.Background(), fbm)
 		return err
 	}, nil
@@ -98,6 +100,31 @@ func toFirebaseMessage(m *message, auther auth.Auther) (*messaging.Message, erro
 				CustomData: apnsData,
 			},
 		}
+	case pollRequestEvent:
+		data = map[string]string{
+			"id":      m.ID,
+			"time":    fmt.Sprintf("%d", m.Time),
+			"event":   m.Event,
+			"topic":   m.Topic,
+			"message": m.Message,
+			"poll_id": m.PollID,
+		}
+		apnsData := make(map[string]interface{})
+		for k, v := range data {
+			apnsData[k] = v
+		}
+		apnsConfig = &messaging.APNSConfig{
+			Payload: &messaging.APNSPayload{
+				CustomData: apnsData,
+				Aps: &messaging.Aps{
+					MutableContent: true,
+					Alert: &messaging.ApsAlert{
+						Title: m.Title,
+						Body:  maybeTruncateAPNSBodyMessage(m.Message),
+					},
+				},
+			},
+		}
 	case messageEvent:
 		allowForward := true
 		if auther != nil {

+ 15 - 10
server/types.go

@@ -24,13 +24,14 @@ type message struct {
 	Time       int64       `json:"time"`  // Unix time in seconds
 	Event      string      `json:"event"` // One of the above
 	Topic      string      `json:"topic"`
+	Title      string      `json:"title,omitempty"`
+	Message    string      `json:"message,omitempty"`
 	Priority   int         `json:"priority,omitempty"`
 	Tags       []string    `json:"tags,omitempty"`
 	Click      string      `json:"click,omitempty"`
 	Actions    []*action   `json:"actions,omitempty"`
 	Attachment *attachment `json:"attachment,omitempty"`
-	Title      string      `json:"title,omitempty"`
-	Message    string      `json:"message,omitempty"`
+	PollID     string      `json:"poll_id,omitempty"`
 	Encoding   string      `json:"encoding,omitempty"` // empty for raw UTF-8, or "base64" for encoded bytes
 }
 
@@ -84,14 +85,11 @@ type messageEncoder func(msg *message) (string, error)
 // newMessage creates a new message with the current timestamp
 func newMessage(event, topic, msg string) *message {
 	return &message{
-		ID:       util.RandomString(messageIDLength),
-		Time:     time.Now().Unix(),
-		Event:    event,
-		Topic:    topic,
-		Priority: 0,
-		Tags:     nil,
-		Title:    "",
-		Message:  msg,
+		ID:      util.RandomString(messageIDLength),
+		Time:    time.Now().Unix(),
+		Event:   event,
+		Topic:   topic,
+		Message: msg,
 	}
 }
 
@@ -110,6 +108,13 @@ func newDefaultMessage(topic, msg string) *message {
 	return newMessage(messageEvent, topic, msg)
 }
 
+// newPollRequestMessage is a convenience method to create a poll request message
+func newPollRequestMessage(topic, pollID string) *message {
+	m := newMessage(pollRequestEvent, topic, newMessageBody)
+	m.PollID = pollID
+	return m
+}
+
 func validMessageID(s string) bool {
 	return util.ValidRandomString(s, messageIDLength)
 }