Browse Source

Merge pull request #95 from binwiederhier/websockets

Websockets; working
Philipp C. Heckel 4 years ago
parent
commit
828a286809
17 changed files with 591 additions and 238 deletions
  1. 20 8
      client/client.go
  2. 1 3
      cmd/app_test.go
  3. 36 0
      cmd/publish_test.go
  4. 68 0
      cmd/serve_test.go
  5. 6 2
      docs/config.md
  6. 53 4
      docs/subscribe/api.md
  7. 2 0
      go.mod
  8. 3 0
      go.sum
  9. 1 1
      server/config.go
  10. 49 0
      server/errors.go
  11. 155 160
      server/server.go
  12. 3 2
      server/server.yml
  13. 0 58
      server/server_test.go
  14. 70 0
      server/types.go
  15. 55 0
      server/util.go
  16. 66 0
      server/util_test.go
  17. 3 0
      test/server.go

+ 20 - 8
client/client.go

@@ -36,14 +36,16 @@ type Client struct {
 
 
 // Message is a struct that represents a ntfy message
 // Message is a struct that represents a ntfy message
 type Message struct { // TODO combine with server.message
 type Message struct { // TODO combine with server.message
-	ID       string
-	Event    string
-	Time     int64
-	Topic    string
-	Message  string
-	Title    string
-	Priority int
-	Tags     []string
+	ID         string
+	Event      string
+	Time       int64
+	Topic      string
+	Message    string
+	Title      string
+	Priority   int
+	Tags       []string
+	Click      string
+	Attachment *Attachment
 
 
 	// Additional fields
 	// Additional fields
 	TopicURL       string
 	TopicURL       string
@@ -51,6 +53,16 @@ type Message struct { // TODO combine with server.message
 	Raw            string
 	Raw            string
 }
 }
 
 
+// Attachment represents a message attachment
+type Attachment struct {
+	Name    string `json:"name"`
+	Type    string `json:"type,omitempty"`
+	Size    int64  `json:"size,omitempty"`
+	Expires int64  `json:"expires,omitempty"`
+	URL     string `json:"url"`
+	Owner   string `json:"-"` // IP address of uploader, used for rate limiting
+}
+
 type subscription struct {
 type subscription struct {
 	ID       string
 	ID       string
 	topicURL string
 	topicURL string

+ 1 - 3
cmd/app_test.go

@@ -5,8 +5,6 @@ import (
 	"encoding/json"
 	"encoding/json"
 	"github.com/urfave/cli/v2"
 	"github.com/urfave/cli/v2"
 	"heckel.io/ntfy/client"
 	"heckel.io/ntfy/client"
-	"io"
-	"log"
 	"os"
 	"os"
 	"strings"
 	"strings"
 	"testing"
 	"testing"
@@ -15,7 +13,7 @@ import (
 // This only contains helpers so far
 // This only contains helpers so far
 
 
 func TestMain(m *testing.M) {
 func TestMain(m *testing.M) {
-	log.SetOutput(io.Discard)
+	// log.SetOutput(io.Discard)
 	os.Exit(m.Run())
 	os.Exit(m.Run())
 }
 }
 
 

+ 36 - 0
cmd/publish_test.go

@@ -34,3 +34,39 @@ func TestCLI_Publish_Subscribe_Poll(t *testing.T) {
 	m = toMessage(t, stdout.String())
 	m = toMessage(t, stdout.String())
 	require.Equal(t, "some message", m.Message)
 	require.Equal(t, "some message", m.Message)
 }
 }
+
+func TestCLI_Publish_All_The_Things(t *testing.T) {
+	s, port := test.StartServer(t)
+	defer test.StopServer(t, s, port)
+	topic := fmt.Sprintf("http://127.0.0.1:%d/mytopic", port)
+
+	app, _, stdout, _ := newTestApp()
+	require.Nil(t, app.Run([]string{
+		"ntfy", "publish",
+		"--title", "this is a title",
+		"--priority", "high",
+		"--tags", "tag1,tag2",
+		// No --delay, --email
+		"--click", "https://ntfy.sh",
+		"--attach", "https://f-droid.org/F-Droid.apk",
+		"--filename", "fdroid.apk",
+		"--no-cache",
+		"--no-firebase",
+		topic,
+		"some message",
+	}))
+	m := toMessage(t, stdout.String())
+	require.Equal(t, "message", m.Event)
+	require.Equal(t, "mytopic", m.Topic)
+	require.Equal(t, "some message", m.Message)
+	require.Equal(t, "this is a title", m.Title)
+	require.Equal(t, 4, m.Priority)
+	require.Equal(t, []string{"tag1", "tag2"}, m.Tags)
+	require.Equal(t, "https://ntfy.sh", m.Click)
+	require.Equal(t, "https://f-droid.org/F-Droid.apk", m.Attachment.URL)
+	require.Equal(t, "fdroid.apk", m.Attachment.Name)
+	require.Equal(t, int64(0), m.Attachment.Size)
+	require.Equal(t, "", m.Attachment.Owner)
+	require.Equal(t, int64(0), m.Attachment.Expires)
+	require.Equal(t, "", m.Attachment.Type)
+}

+ 68 - 0
cmd/serve_test.go

@@ -0,0 +1,68 @@
+package cmd
+
+import (
+	"fmt"
+	"github.com/gorilla/websocket"
+	"github.com/stretchr/testify/require"
+	"heckel.io/ntfy/client"
+	"heckel.io/ntfy/test"
+	"heckel.io/ntfy/util"
+	"math/rand"
+	"os/exec"
+	"path/filepath"
+	"testing"
+	"time"
+)
+
+func init() {
+	rand.Seed(time.Now().UnixMilli())
+}
+
+func TestCLI_Serve_Unix_Curl(t *testing.T) {
+	sockFile := filepath.Join(t.TempDir(), "ntfy.sock")
+	go func() {
+		app, _, _, _ := newTestApp()
+		err := app.Run([]string{"ntfy", "serve", "--listen-http=-", "--listen-unix=" + sockFile})
+		require.Nil(t, err)
+	}()
+	for i := 0; i < 40 && !util.FileExists(sockFile); i++ {
+		time.Sleep(50 * time.Millisecond)
+	}
+	require.True(t, util.FileExists(sockFile))
+
+	cmd := exec.Command("curl", "-s", "--unix-socket", sockFile, "-d", "this is a message", "localhost/mytopic")
+	out, err := cmd.Output()
+	require.Nil(t, err)
+	m := toMessage(t, string(out))
+	require.Equal(t, "this is a message", m.Message)
+}
+
+func TestCLI_Serve_WebSocket(t *testing.T) {
+	port := 10000 + rand.Intn(20000)
+	go func() {
+		app, _, _, _ := newTestApp()
+		err := app.Run([]string{"ntfy", "serve", fmt.Sprintf("--listen-http=:%d", port)})
+		require.Nil(t, err)
+	}()
+	test.WaitForPortUp(t, port)
+
+	ws, _, err := websocket.DefaultDialer.Dial(fmt.Sprintf("ws://127.0.0.1:%d/mytopic/ws", port), nil)
+	require.Nil(t, err)
+
+	messageType, data, err := ws.ReadMessage()
+	require.Nil(t, err)
+	require.Equal(t, websocket.TextMessage, messageType)
+	require.Equal(t, "open", toMessage(t, string(data)).Event)
+
+	c := client.New(client.NewConfig())
+	_, err = c.Publish(fmt.Sprintf("http://127.0.0.1:%d/mytopic", port), "my message")
+	require.Nil(t, err)
+
+	messageType, data, err = ws.ReadMessage()
+	require.Nil(t, err)
+	require.Equal(t, websocket.TextMessage, messageType)
+
+	m := toMessage(t, string(data))
+	require.Equal(t, "my message", m.Message)
+	require.Equal(t, "mytopic", m.Topic)
+}

+ 6 - 2
docs/config.md

@@ -243,6 +243,8 @@ or the root domain:
         proxy_redirect off;
         proxy_redirect off;
      
      
         proxy_set_header Host $http_host;
         proxy_set_header Host $http_host;
+        proxy_set_header Upgrade $http_upgrade;
+        proxy_set_header Connection "upgrade";
         proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
         proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
     
     
         proxy_connect_timeout 3m;
         proxy_connect_timeout 3m;
@@ -274,6 +276,8 @@ or the root domain:
         proxy_redirect off;
         proxy_redirect off;
      
      
         proxy_set_header Host $http_host;
         proxy_set_header Host $http_host;
+        proxy_set_header Upgrade $http_upgrade;
+        proxy_set_header Connection "upgrade";
         proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
         proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
     
     
         proxy_connect_timeout 3m;
         proxy_connect_timeout 3m;
@@ -549,7 +553,7 @@ variable before running the `ntfy` command (e.g. `export NTFY_LISTEN_HTTP=:80`).
 | `smtp-server-listen`                       | `NTFY_SMTP_SERVER_LISTEN`                       | `[ip]:port`      | -       | Defines the IP address and port the SMTP server will listen on, e.g. `:25` or `1.2.3.4:25`                                                                                                                                      |
 | `smtp-server-listen`                       | `NTFY_SMTP_SERVER_LISTEN`                       | `[ip]:port`      | -       | Defines the IP address and port the SMTP server will listen on, e.g. `:25` or `1.2.3.4:25`                                                                                                                                      |
 | `smtp-server-domain`                       | `NTFY_SMTP_SERVER_DOMAIN`                       | *domain name*    | -       | SMTP server e-mail domain, e.g. `ntfy.sh`                                                                                                                                                                                       |
 | `smtp-server-domain`                       | `NTFY_SMTP_SERVER_DOMAIN`                       | *domain name*    | -       | SMTP server e-mail domain, e.g. `ntfy.sh`                                                                                                                                                                                       |
 | `smtp-server-addr-prefix`                  | `NTFY_SMTP_SERVER_ADDR_PREFIX`                  | `[ip]:port`      | -       | Optional prefix for the e-mail addresses to prevent spam, e.g. `ntfy-`                                                                                                                                                          |
 | `smtp-server-addr-prefix`                  | `NTFY_SMTP_SERVER_ADDR_PREFIX`                  | `[ip]:port`      | -       | Optional prefix for the e-mail addresses to prevent spam, e.g. `ntfy-`                                                                                                                                                          |
-| `keepalive-interval`                       | `NTFY_KEEPALIVE_INTERVAL`                       | *duration*       | 55s     | Interval in which keepalive messages are sent to the client. This is to prevent intermediaries closing the connection for inactivity. Note that the Android app has a hardcoded timeout at 77s, so it should be less than that. |
+| `keepalive-interval`                       | `NTFY_KEEPALIVE_INTERVAL`                       | *duration*       | 45s     | Interval in which keepalive messages are sent to the client. This is to prevent intermediaries closing the connection for inactivity. Note that the Android app has a hardcoded timeout at 77s, so it should be less than that. |
 | `manager-interval`                         | `$NTFY_MANAGER_INTERVAL`                        | *duration*       | 1m      | Interval in which the manager prunes old messages, deletes topics and prints the stats.                                                                                                                                         |
 | `manager-interval`                         | `$NTFY_MANAGER_INTERVAL`                        | *duration*       | 1m      | Interval in which the manager prunes old messages, deletes topics and prints the stats.                                                                                                                                         |
 | `global-topic-limit`                       | `NTFY_GLOBAL_TOPIC_LIMIT`                       | *number*         | 15,000  | Rate limiting: Total number of topics before the server rejects new topics.                                                                                                                                                     |
 | `global-topic-limit`                       | `NTFY_GLOBAL_TOPIC_LIMIT`                       | *number*         | 15,000  | Rate limiting: Total number of topics before the server rejects new topics.                                                                                                                                                     |
 | `visitor-subscription-limit`               | `NTFY_VISITOR_SUBSCRIPTION_LIMIT`               | *number*         | 30      | Rate limiting: Number of subscriptions per visitor (IP address)                                                                                                                                                                 |
 | `visitor-subscription-limit`               | `NTFY_VISITOR_SUBSCRIPTION_LIMIT`               | *number*         | 30      | Rate limiting: Number of subscriptions per visitor (IP address)                                                                                                                                                                 |
@@ -597,7 +601,7 @@ OPTIONS:
    --attachment-total-size-limit value, -A value     limit of the on-disk attachment cache (default: 5G) [$NTFY_ATTACHMENT_TOTAL_SIZE_LIMIT]
    --attachment-total-size-limit value, -A value     limit of the on-disk attachment cache (default: 5G) [$NTFY_ATTACHMENT_TOTAL_SIZE_LIMIT]
    --attachment-file-size-limit value, -Y value      per-file attachment size limit (e.g. 300k, 2M, 100M) (default: 15M) [$NTFY_ATTACHMENT_FILE_SIZE_LIMIT]
    --attachment-file-size-limit value, -Y value      per-file attachment size limit (e.g. 300k, 2M, 100M) (default: 15M) [$NTFY_ATTACHMENT_FILE_SIZE_LIMIT]
    --attachment-expiry-duration value, -X value      duration after which uploaded attachments will be deleted (e.g. 3h, 20h) (default: 3h) [$NTFY_ATTACHMENT_EXPIRY_DURATION]
    --attachment-expiry-duration value, -X value      duration after which uploaded attachments will be deleted (e.g. 3h, 20h) (default: 3h) [$NTFY_ATTACHMENT_EXPIRY_DURATION]
-   --keepalive-interval value, -k value              interval of keepalive messages (default: 55s) [$NTFY_KEEPALIVE_INTERVAL]
+   --keepalive-interval value, -k value              interval of keepalive messages (default: 45s) [$NTFY_KEEPALIVE_INTERVAL]
    --manager-interval value, -m value                interval of for message pruning and stats printing (default: 1m0s) [$NTFY_MANAGER_INTERVAL]
    --manager-interval value, -m value                interval of for message pruning and stats printing (default: 1m0s) [$NTFY_MANAGER_INTERVAL]
    --smtp-sender-addr value                          SMTP server address (host:port) for outgoing emails [$NTFY_SMTP_SENDER_ADDR]
    --smtp-sender-addr value                          SMTP server address (host:port) for outgoing emails [$NTFY_SMTP_SENDER_ADDR]
    --smtp-sender-user value                          SMTP user (if e-mail sending is enabled) [$NTFY_SMTP_SENDER_USER]
    --smtp-sender-user value                          SMTP user (if e-mail sending is enabled) [$NTFY_SMTP_SENDER_USER]

+ 53 - 4
docs/subscribe/api.md

@@ -3,7 +3,11 @@ You can create and subscribe to a topic in the [web UI](web.md), via the [phone
 or in your own app or script by subscribing the API. This page describes how to subscribe via API. You may also want to 
 or in your own app or script by subscribing the API. This page describes how to subscribe via API. You may also want to 
 check out the page that describes how to [publish messages](../publish.md).
 check out the page that describes how to [publish messages](../publish.md).
 
 
-The subscription API relies on a simple HTTP GET request with a streaming HTTP response, i.e **you open a GET request and
+You can consume the subscription API as either a **[simple HTTP stream (JSON, SSE or raw)](#http-stream)**, or 
+**[via WebSockets](#websockets)**. Both are incredibly simple to use.
+
+## HTTP stream
+The HTTP stream-based API relies on a simple GET request with a streaming HTTP response, i.e **you open a GET request and
 the connection stays open forever**, sending messages back as they come in. There are three different API endpoints, which 
 the connection stays open forever**, sending messages back as they come in. There are three different API endpoints, which 
 only differ in the response format:
 only differ in the response format:
 
 
@@ -12,7 +16,7 @@ only differ in the response format:
   can be used with [EventSource](https://developer.mozilla.org/en-US/docs/Web/API/EventSource)
   can be used with [EventSource](https://developer.mozilla.org/en-US/docs/Web/API/EventSource)
 * [Raw stream](#subscribe-as-raw-stream): `<topic>/raw` returns messages as raw text, with one line per message
 * [Raw stream](#subscribe-as-raw-stream): `<topic>/raw` returns messages as raw text, with one line per message
 
 
-## Subscribe as JSON stream
+### Subscribe as JSON stream
 Here are a few examples of how to consume the JSON endpoint (`<topic>/json`). For almost all languages, **this is the 
 Here are a few examples of how to consume the JSON endpoint (`<topic>/json`). For almost all languages, **this is the 
 recommended way to subscribe to a topic**. The notable exception is JavaScript, for which the 
 recommended way to subscribe to a topic**. The notable exception is JavaScript, for which the 
 [SSE/EventSource stream](#subscribe-as-sse-stream) is much easier to work with.
 [SSE/EventSource stream](#subscribe-as-sse-stream) is much easier to work with.
@@ -80,7 +84,7 @@ recommended way to subscribe to a topic**. The notable exception is JavaScript,
     fclose($fp);
     fclose($fp);
     ```
     ```
 
 
-## Subscribe as SSE stream
+### Subscribe as SSE stream
 Using [EventSource](https://developer.mozilla.org/en-US/docs/Web/API/EventSource) in JavaScript, you can consume
 Using [EventSource](https://developer.mozilla.org/en-US/docs/Web/API/EventSource) in JavaScript, you can consume
 notifications via a [Server-Sent Events (SSE)](https://en.wikipedia.org/wiki/Server-sent_events) stream. It's incredibly 
 notifications via a [Server-Sent Events (SSE)](https://en.wikipedia.org/wiki/Server-sent_events) stream. It's incredibly 
 easy to use. Here's what it looks like. You may also want to check out the [live example](/example.html).
 easy to use. Here's what it looks like. You may also want to check out the [live example](/example.html).
@@ -125,7 +129,7 @@ easy to use. Here's what it looks like. You may also want to check out the [live
     };
     };
     ```
     ```
 
 
-## Subscribe as raw stream
+### Subscribe as raw stream
 The `/raw` endpoint will output one line per message, and **will only include the message body**. It's useful for extremely
 The `/raw` endpoint will output one line per message, and **will only include the message body**. It's useful for extremely
 simple scripts, and doesn't include all the data. Additional fields such as [priority](../publish.md#message-priority), 
 simple scripts, and doesn't include all the data. Additional fields such as [priority](../publish.md#message-priority), 
 [tags](../publish.md#tags--emojis--) or [message title](../publish.md#message-title) are not included in this output 
 [tags](../publish.md#tags--emojis--) or [message title](../publish.md#message-title) are not included in this output 
@@ -184,6 +188,51 @@ format. Keepalive messages are sent as empty lines.
     fclose($fp);
     fclose($fp);
     ```
     ```
 
 
+## WebSockets
+You may also subscribe to topics via [WebSockets](https://en.wikipedia.org/wiki/WebSocket), which is also widely 
+supported in many languages. Most notably, WebSockets are natively supported in JavaScript. On the command line, 
+I recommend [websocat](https://github.com/vi/websocat), a fantastic tool similar to `socat` or `curl`, but specifically
+for WebSockets.  
+
+The WebSockets endpoint is available at `<topic>/ws` and returns messages as JSON objects similar to the 
+[JSON stream endpoint](#subscribe-as-json-stream). 
+
+=== "Command line (websocat)"
+    ```
+    $ websocat wss://ntfy.sh/mytopic/ws
+    {"id":"qRHUCCvjj8","time":1642307388,"event":"open","topic":"mytopic"}
+    {"id":"eOWoUBJ14x","time":1642307754,"event":"message","topic":"mytopic","message":"hi there"}
+    ```
+
+=== "HTTP"
+    ``` http
+    GET /disk-alerts/ws HTTP/1.1
+    Host: ntfy.sh
+    Upgrade: websocket
+    Connection: Upgrade
+
+    HTTP/1.1 101 Switching Protocols
+    Upgrade: websocket
+    Connection: Upgrade
+    ...
+    ```
+
+=== "Go"
+    ``` go
+    import "github.com/gorilla/websocket"
+	ws, _, _ := websocket.DefaultDialer.Dial("wss://ntfy.sh/mytopic/ws", nil)
+	messageType, data, err := ws.ReadMessage()
+    ...
+    ```
+
+=== "JavaScript"
+    ``` javascript
+    const socket = new WebSocket('wss://ntfy.sh/mytopic/ws');
+    socket.addEventListener('message', function (event) {
+        console.log(event.data);
+    });
+    ```
+
 ## Advanced features
 ## Advanced features
 
 
 ### Poll for messages
 ### Poll for messages

+ 2 - 0
go.mod

@@ -35,11 +35,13 @@ require (
 	github.com/golang/protobuf v1.5.2 // indirect
 	github.com/golang/protobuf v1.5.2 // indirect
 	github.com/google/go-cmp v0.5.6 // indirect
 	github.com/google/go-cmp v0.5.6 // indirect
 	github.com/googleapis/gax-go/v2 v2.1.1 // indirect
 	github.com/googleapis/gax-go/v2 v2.1.1 // indirect
+	github.com/gorilla/websocket v1.4.2 // indirect
 	github.com/pkg/errors v0.9.1 // indirect
 	github.com/pkg/errors v0.9.1 // indirect
 	github.com/pmezard/go-difflib v1.0.0 // indirect
 	github.com/pmezard/go-difflib v1.0.0 // indirect
 	github.com/russross/blackfriday/v2 v2.1.0 // indirect
 	github.com/russross/blackfriday/v2 v2.1.0 // indirect
 	go.opencensus.io v0.23.0 // indirect
 	go.opencensus.io v0.23.0 // indirect
 	golang.org/x/net v0.0.0-20210813160813-60bc85c4be6d // indirect
 	golang.org/x/net v0.0.0-20210813160813-60bc85c4be6d // indirect
+	golang.org/x/sync v0.0.0-20210220032951-036812b2e83c // indirect
 	golang.org/x/sys v0.0.0-20211210111614-af8b64212486 // indirect
 	golang.org/x/sys v0.0.0-20211210111614-af8b64212486 // indirect
 	golang.org/x/text v0.3.7 // indirect
 	golang.org/x/text v0.3.7 // indirect
 	golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 // indirect
 	golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 // indirect

+ 3 - 0
go.sum

@@ -189,6 +189,8 @@ github.com/googleapis/gax-go/v2 v2.0.5/go.mod h1:DWXyrwAJ9X0FpwwEdw+IPEYBICEFu5m
 github.com/googleapis/gax-go/v2 v2.1.0/go.mod h1:Q3nei7sK6ybPYH7twZdmQpAd1MKb7pfu6SK+H1/DsU0=
 github.com/googleapis/gax-go/v2 v2.1.0/go.mod h1:Q3nei7sK6ybPYH7twZdmQpAd1MKb7pfu6SK+H1/DsU0=
 github.com/googleapis/gax-go/v2 v2.1.1 h1:dp3bWCh+PPO1zjRRiCSczJav13sBvG4UhNyVTa1KqdU=
 github.com/googleapis/gax-go/v2 v2.1.1 h1:dp3bWCh+PPO1zjRRiCSczJav13sBvG4UhNyVTa1KqdU=
 github.com/googleapis/gax-go/v2 v2.1.1/go.mod h1:hddJymUZASv3XPyGkUpKj8pPO47Rmb0eJc8R6ouapiM=
 github.com/googleapis/gax-go/v2 v2.1.1/go.mod h1:hddJymUZASv3XPyGkUpKj8pPO47Rmb0eJc8R6ouapiM=
+github.com/gorilla/websocket v1.4.2 h1:+/TMaTYc4QFitKJxsQ7Yye35DkWvkdLcvGKqM+x0Ufc=
+github.com/gorilla/websocket v1.4.2/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
 github.com/grpc-ecosystem/grpc-gateway v1.16.0/go.mod h1:BDjrQk3hbvj6Nolgz8mAMFbcEtjT1g+wF4CSlocrBnw=
 github.com/grpc-ecosystem/grpc-gateway v1.16.0/go.mod h1:BDjrQk3hbvj6Nolgz8mAMFbcEtjT1g+wF4CSlocrBnw=
 github.com/hashicorp/golang-lru v0.5.0/go.mod h1:/m3WP610KZHVQ1SGc6re/UDhFvYD7pJ4Ao+sR/qLZy8=
 github.com/hashicorp/golang-lru v0.5.0/go.mod h1:/m3WP610KZHVQ1SGc6re/UDhFvYD7pJ4Ao+sR/qLZy8=
 github.com/hashicorp/golang-lru v0.5.1/go.mod h1:/m3WP610KZHVQ1SGc6re/UDhFvYD7pJ4Ao+sR/qLZy8=
 github.com/hashicorp/golang-lru v0.5.1/go.mod h1:/m3WP610KZHVQ1SGc6re/UDhFvYD7pJ4Ao+sR/qLZy8=
@@ -356,6 +358,7 @@ golang.org/x/sync v0.0.0-20200317015054-43a5402ce75a/go.mod h1:RxMgew5VJxzue5/jJ
 golang.org/x/sync v0.0.0-20200625203802-6e8e738ad208/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
 golang.org/x/sync v0.0.0-20200625203802-6e8e738ad208/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
 golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
 golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
 golang.org/x/sync v0.0.0-20201207232520-09787c993a3a/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
 golang.org/x/sync v0.0.0-20201207232520-09787c993a3a/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
+golang.org/x/sync v0.0.0-20210220032951-036812b2e83c h1:5KslGYwFpkhGh+Q16bwMP3cOontH8FOep7tGV86Y7SQ=
 golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
 golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
 golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
 golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
 golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
 golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=

+ 1 - 1
server/config.go

@@ -8,7 +8,7 @@ import (
 const (
 const (
 	DefaultListenHTTP                = ":80"
 	DefaultListenHTTP                = ":80"
 	DefaultCacheDuration             = 12 * time.Hour
 	DefaultCacheDuration             = 12 * time.Hour
-	DefaultKeepaliveInterval         = 55 * time.Second // Not too frequently to save battery (Android read timeout is 77s!)
+	DefaultKeepaliveInterval         = 45 * time.Second // Not too frequently to save battery (Android read timeout used to be 77s!)
 	DefaultManagerInterval           = time.Minute
 	DefaultManagerInterval           = time.Minute
 	DefaultAtSenderInterval          = 10 * time.Second
 	DefaultAtSenderInterval          = 10 * time.Second
 	DefaultMinDelay                  = 10 * time.Second
 	DefaultMinDelay                  = 10 * time.Second

+ 49 - 0
server/errors.go

@@ -0,0 +1,49 @@
+package server
+
+import (
+	"encoding/json"
+	"net/http"
+)
+
+// errHTTP is a generic HTTP error for any non-200 HTTP error
+type errHTTP struct {
+	Code     int    `json:"code,omitempty"`
+	HTTPCode int    `json:"http"`
+	Message  string `json:"error"`
+	Link     string `json:"link,omitempty"`
+}
+
+func (e errHTTP) Error() string {
+	return e.Message
+}
+
+func (e errHTTP) JSON() string {
+	b, _ := json.Marshal(&e)
+	return string(b)
+}
+
+var (
+	errHTTPBadRequestEmailDisabled                   = &errHTTP{40001, http.StatusBadRequest, "e-mail notifications are not enabled", "https://ntfy.sh/docs/config/#e-mail-notifications"}
+	errHTTPBadRequestDelayNoCache                    = &errHTTP{40002, http.StatusBadRequest, "cannot disable cache for delayed message", ""}
+	errHTTPBadRequestDelayNoEmail                    = &errHTTP{40003, http.StatusBadRequest, "delayed e-mail notifications are not supported", ""}
+	errHTTPBadRequestDelayCannotParse                = &errHTTP{40004, http.StatusBadRequest, "invalid delay parameter: unable to parse delay", "https://ntfy.sh/docs/publish/#scheduled-delivery"}
+	errHTTPBadRequestDelayTooSmall                   = &errHTTP{40005, http.StatusBadRequest, "invalid delay parameter: too small, please refer to the docs", "https://ntfy.sh/docs/publish/#scheduled-delivery"}
+	errHTTPBadRequestDelayTooLarge                   = &errHTTP{40006, http.StatusBadRequest, "invalid delay parameter: too large, please refer to the docs", "https://ntfy.sh/docs/publish/#scheduled-delivery"}
+	errHTTPBadRequestPriorityInvalid                 = &errHTTP{40007, http.StatusBadRequest, "invalid priority parameter", "https://ntfy.sh/docs/publish/#message-priority"}
+	errHTTPBadRequestSinceInvalid                    = &errHTTP{40008, http.StatusBadRequest, "invalid since parameter", "https://ntfy.sh/docs/subscribe/api/#fetch-cached-messages"}
+	errHTTPBadRequestTopicInvalid                    = &errHTTP{40009, http.StatusBadRequest, "invalid topic: path invalid", ""}
+	errHTTPBadRequestTopicDisallowed                 = &errHTTP{40010, http.StatusBadRequest, "invalid topic: topic name is disallowed", ""}
+	errHTTPBadRequestMessageNotUTF8                  = &errHTTP{40011, http.StatusBadRequest, "invalid message: message must be UTF-8 encoded", ""}
+	errHTTPBadRequestAttachmentTooLarge              = &errHTTP{40012, http.StatusBadRequest, "invalid request: attachment too large, or bandwidth limit reached", ""}
+	errHTTPBadRequestAttachmentURLInvalid            = &errHTTP{40013, http.StatusBadRequest, "invalid request: attachment URL is invalid", ""}
+	errHTTPBadRequestAttachmentsDisallowed           = &errHTTP{40014, http.StatusBadRequest, "invalid request: attachments not allowed", ""}
+	errHTTPBadRequestAttachmentsExpiryBeforeDelivery = &errHTTP{40015, http.StatusBadRequest, "invalid request: attachment expiry before delayed delivery date", ""}
+	errHTTPNotFound                                  = &errHTTP{40401, http.StatusNotFound, "page not found", ""}
+	errHTTPTooManyRequestsLimitRequests              = &errHTTP{42901, http.StatusTooManyRequests, "limit reached: too many requests, please be nice", "https://ntfy.sh/docs/publish/#limitations"}
+	errHTTPTooManyRequestsLimitEmails                = &errHTTP{42902, http.StatusTooManyRequests, "limit reached: too many emails, please be nice", "https://ntfy.sh/docs/publish/#limitations"}
+	errHTTPTooManyRequestsLimitSubscriptions         = &errHTTP{42903, http.StatusTooManyRequests, "limit reached: too many active subscriptions, please be nice", "https://ntfy.sh/docs/publish/#limitations"}
+	errHTTPTooManyRequestsLimitTotalTopics           = &errHTTP{42904, http.StatusTooManyRequests, "limit reached: the total number of topics on the server has been reached, please contact the admin", "https://ntfy.sh/docs/publish/#limitations"}
+	errHTTPTooManyRequestsAttachmentBandwidthLimit   = &errHTTP{42905, http.StatusTooManyRequests, "too many requests: daily bandwidth limit reached", "https://ntfy.sh/docs/publish/#limitations"}
+	errHTTPInternalError                             = &errHTTP{50001, http.StatusInternalServerError, "internal server error", ""}
+	errHTTPInternalErrorInvalidFilePath              = &errHTTP{50002, http.StatusInternalServerError, "internal server error: invalid file path", ""}
+)

+ 155 - 160
server/server.go

@@ -10,6 +10,8 @@ import (
 	"firebase.google.com/go/messaging"
 	"firebase.google.com/go/messaging"
 	"fmt"
 	"fmt"
 	"github.com/emersion/go-smtp"
 	"github.com/emersion/go-smtp"
+	"github.com/gorilla/websocket"
+	"golang.org/x/sync/errgroup"
 	"google.golang.org/api/option"
 	"google.golang.org/api/option"
 	"heckel.io/ntfy/util"
 	"heckel.io/ntfy/util"
 	"html/template"
 	"html/template"
@@ -30,9 +32,6 @@ import (
 	"unicode/utf8"
 	"unicode/utf8"
 )
 )
 
 
-// TODO add "max messages in a topic" limit
-// TODO implement "since=<ID>"
-
 // Server is the main server, providing the UI and API for ntfy
 // Server is the main server, providing the UI and API for ntfy
 type Server struct {
 type Server struct {
 	config       *Config
 	config       *Config
@@ -52,53 +51,18 @@ type Server struct {
 	mu           sync.Mutex
 	mu           sync.Mutex
 }
 }
 
 
-// errHTTP is a generic HTTP error for any non-200 HTTP error
-type errHTTP struct {
-	Code     int    `json:"code,omitempty"`
-	HTTPCode int    `json:"http"`
-	Message  string `json:"error"`
-	Link     string `json:"link,omitempty"`
-}
-
-func (e errHTTP) Error() string {
-	return e.Message
-}
-
-func (e errHTTP) JSON() string {
-	b, _ := json.Marshal(&e)
-	return string(b)
-}
-
 type indexPage struct {
 type indexPage struct {
 	Topic         string
 	Topic         string
 	CacheDuration time.Duration
 	CacheDuration time.Duration
 }
 }
 
 
-type sinceTime time.Time
-
-func (t sinceTime) IsAll() bool {
-	return t == sinceAllMessages
-}
-
-func (t sinceTime) IsNone() bool {
-	return t == sinceNoMessages
-}
-
-func (t sinceTime) Time() time.Time {
-	return time.Time(t)
-}
-
-var (
-	sinceAllMessages = sinceTime(time.Unix(0, 0))
-	sinceNoMessages  = sinceTime(time.Unix(1, 0))
-)
-
 var (
 var (
 	topicRegex       = regexp.MustCompile(`^[-_A-Za-z0-9]{1,64}$`)  // No /!
 	topicRegex       = regexp.MustCompile(`^[-_A-Za-z0-9]{1,64}$`)  // No /!
 	topicPathRegex   = regexp.MustCompile(`^/[-_A-Za-z0-9]{1,64}$`) // Regex must match JS & Android app!
 	topicPathRegex   = regexp.MustCompile(`^/[-_A-Za-z0-9]{1,64}$`) // Regex must match JS & Android app!
 	jsonPathRegex    = regexp.MustCompile(`^/[-_A-Za-z0-9]{1,64}(,[-_A-Za-z0-9]{1,64})*/json$`)
 	jsonPathRegex    = regexp.MustCompile(`^/[-_A-Za-z0-9]{1,64}(,[-_A-Za-z0-9]{1,64})*/json$`)
 	ssePathRegex     = regexp.MustCompile(`^/[-_A-Za-z0-9]{1,64}(,[-_A-Za-z0-9]{1,64})*/sse$`)
 	ssePathRegex     = regexp.MustCompile(`^/[-_A-Za-z0-9]{1,64}(,[-_A-Za-z0-9]{1,64})*/sse$`)
 	rawPathRegex     = regexp.MustCompile(`^/[-_A-Za-z0-9]{1,64}(,[-_A-Za-z0-9]{1,64})*/raw$`)
 	rawPathRegex     = regexp.MustCompile(`^/[-_A-Za-z0-9]{1,64}(,[-_A-Za-z0-9]{1,64})*/raw$`)
+	wsPathRegex      = regexp.MustCompile(`^/[-_A-Za-z0-9]{1,64}(,[-_A-Za-z0-9]{1,64})*/ws$`)
 	publishPathRegex = regexp.MustCompile(`^/[-_A-Za-z0-9]{1,64}(,[-_A-Za-z0-9]{1,64})*/(publish|send|trigger)$`)
 	publishPathRegex = regexp.MustCompile(`^/[-_A-Za-z0-9]{1,64}(,[-_A-Za-z0-9]{1,64})*/(publish|send|trigger)$`)
 
 
 	staticRegex      = regexp.MustCompile(`^/static/.+`)
 	staticRegex      = regexp.MustCompile(`^/static/.+`)
@@ -125,37 +89,20 @@ var (
 	//go:embed docs
 	//go:embed docs
 	docsStaticFs     embed.FS
 	docsStaticFs     embed.FS
 	docsStaticCached = &util.CachingEmbedFS{ModTime: time.Now(), FS: docsStaticFs}
 	docsStaticCached = &util.CachingEmbedFS{ModTime: time.Now(), FS: docsStaticFs}
-
-	errHTTPBadRequestEmailDisabled                   = &errHTTP{40001, http.StatusBadRequest, "e-mail notifications are not enabled", "https://ntfy.sh/docs/config/#e-mail-notifications"}
-	errHTTPBadRequestDelayNoCache                    = &errHTTP{40002, http.StatusBadRequest, "cannot disable cache for delayed message", ""}
-	errHTTPBadRequestDelayNoEmail                    = &errHTTP{40003, http.StatusBadRequest, "delayed e-mail notifications are not supported", ""}
-	errHTTPBadRequestDelayCannotParse                = &errHTTP{40004, http.StatusBadRequest, "invalid delay parameter: unable to parse delay", "https://ntfy.sh/docs/publish/#scheduled-delivery"}
-	errHTTPBadRequestDelayTooSmall                   = &errHTTP{40005, http.StatusBadRequest, "invalid delay parameter: too small, please refer to the docs", "https://ntfy.sh/docs/publish/#scheduled-delivery"}
-	errHTTPBadRequestDelayTooLarge                   = &errHTTP{40006, http.StatusBadRequest, "invalid delay parameter: too large, please refer to the docs", "https://ntfy.sh/docs/publish/#scheduled-delivery"}
-	errHTTPBadRequestPriorityInvalid                 = &errHTTP{40007, http.StatusBadRequest, "invalid priority parameter", "https://ntfy.sh/docs/publish/#message-priority"}
-	errHTTPBadRequestSinceInvalid                    = &errHTTP{40008, http.StatusBadRequest, "invalid since parameter", "https://ntfy.sh/docs/subscribe/api/#fetch-cached-messages"}
-	errHTTPBadRequestTopicInvalid                    = &errHTTP{40009, http.StatusBadRequest, "invalid topic: path invalid", ""}
-	errHTTPBadRequestTopicDisallowed                 = &errHTTP{40010, http.StatusBadRequest, "invalid topic: topic name is disallowed", ""}
-	errHTTPBadRequestMessageNotUTF8                  = &errHTTP{40011, http.StatusBadRequest, "invalid message: message must be UTF-8 encoded", ""}
-	errHTTPBadRequestAttachmentTooLarge              = &errHTTP{40012, http.StatusBadRequest, "invalid request: attachment too large, or bandwidth limit reached", ""}
-	errHTTPBadRequestAttachmentURLInvalid            = &errHTTP{40013, http.StatusBadRequest, "invalid request: attachment URL is invalid", ""}
-	errHTTPBadRequestAttachmentsDisallowed           = &errHTTP{40014, http.StatusBadRequest, "invalid request: attachments not allowed", ""}
-	errHTTPBadRequestAttachmentsExpiryBeforeDelivery = &errHTTP{40015, http.StatusBadRequest, "invalid request: attachment expiry before delayed delivery date", ""}
-	errHTTPNotFound                                  = &errHTTP{40401, http.StatusNotFound, "page not found", ""}
-	errHTTPTooManyRequestsLimitRequests              = &errHTTP{42901, http.StatusTooManyRequests, "limit reached: too many requests, please be nice", "https://ntfy.sh/docs/publish/#limitations"}
-	errHTTPTooManyRequestsLimitEmails                = &errHTTP{42902, http.StatusTooManyRequests, "limit reached: too many emails, please be nice", "https://ntfy.sh/docs/publish/#limitations"}
-	errHTTPTooManyRequestsLimitSubscriptions         = &errHTTP{42903, http.StatusTooManyRequests, "limit reached: too many active subscriptions, please be nice", "https://ntfy.sh/docs/publish/#limitations"}
-	errHTTPTooManyRequestsLimitTotalTopics           = &errHTTP{42904, http.StatusTooManyRequests, "limit reached: the total number of topics on the server has been reached, please contact the admin", "https://ntfy.sh/docs/publish/#limitations"}
-	errHTTPTooManyRequestsAttachmentBandwidthLimit   = &errHTTP{42905, http.StatusTooManyRequests, "too many requests: daily bandwidth limit reached", "https://ntfy.sh/docs/publish/#limitations"}
-	errHTTPInternalError                             = &errHTTP{50001, http.StatusInternalServerError, "internal server error", ""}
-	errHTTPInternalErrorInvalidFilePath              = &errHTTP{50002, http.StatusInternalServerError, "internal server error: invalid file path", ""}
 )
 )
 
 
 const (
 const (
 	firebaseControlTopic     = "~control"                // See Android if changed
 	firebaseControlTopic     = "~control"                // See Android if changed
 	emptyMessageBody         = "triggered"               // Used if message body is empty
 	emptyMessageBody         = "triggered"               // Used if message body is empty
 	defaultAttachmentMessage = "You received a file: %s" // Used if message body is empty, and there is an attachment
 	defaultAttachmentMessage = "You received a file: %s" // Used if message body is empty, and there is an attachment
-	fcmMessageLimit          = 4000                      // see maybeTruncateFCMMessage for details
+)
+
+// WebSocket constants
+const (
+	wsWriteWait  = 2 * time.Second
+	wsBufferSize = 1024
+	wsReadLimit  = 64 // We only ever receive PINGs
+	wsPongWait   = 15 * time.Second
 )
 )
 
 
 // New instantiates a new Server. It creates the cache and adds a Firebase
 // New instantiates a new Server. It creates the cache and adds a Firebase
@@ -262,25 +209,6 @@ func createFirebaseSubscriber(conf *Config) (subscriber, error) {
 	}, nil
 	}, nil
 }
 }
 
 
-// maybeTruncateFCMMessage performs best-effort truncation of FCM messages.
-// The docs say the limit is 4000 characters, but during testing it wasn't quite clear
-// what fields matter; so we're just capping the serialized JSON to 4000 bytes.
-func maybeTruncateFCMMessage(m *messaging.Message) *messaging.Message {
-	s, err := json.Marshal(m)
-	if err != nil {
-		return m
-	}
-	if len(s) > fcmMessageLimit {
-		over := len(s) - fcmMessageLimit + 16 // = len("truncated":"1",), sigh ...
-		message, ok := m.Data["message"]
-		if ok && len(message) > over {
-			m.Data["truncated"] = "1"
-			m.Data["message"] = message[:len(message)-over]
-		}
-	}
-	return m
-}
-
 // Run executes the main server. It listens on HTTP (+ HTTPS, if configured), and starts
 // Run executes the main server. It listens on HTTP (+ HTTPS, if configured), and starts
 // a manager go routine to print stats and prune messages.
 // a manager go routine to print stats and prune messages.
 func (s *Server) Run() error {
 func (s *Server) Run() error {
@@ -364,16 +292,19 @@ func (s *Server) Stop() {
 
 
 func (s *Server) handle(w http.ResponseWriter, r *http.Request) {
 func (s *Server) handle(w http.ResponseWriter, r *http.Request) {
 	if err := s.handleInternal(w, r); err != nil {
 	if err := s.handleInternal(w, r); err != nil {
-		var e *errHTTP
-		var ok bool
-		if e, ok = err.(*errHTTP); !ok {
-			e = errHTTPInternalError
+		if websocket.IsWebSocketUpgrade(r) {
+			log.Printf("[%s] WS %s %s - %s", r.RemoteAddr, r.Method, r.URL.Path, err.Error())
+			return // Do not attempt to write to upgraded connection
 		}
 		}
-		log.Printf("[%s] %s - %d - %d - %s", r.RemoteAddr, r.Method, e.HTTPCode, e.Code, err.Error())
+		httpErr, ok := err.(*errHTTP)
+		if !ok {
+			httpErr = errHTTPInternalError
+		}
+		log.Printf("[%s] HTTP %s %s - %d - %d - %s", r.RemoteAddr, r.Method, r.URL.Path, httpErr.HTTPCode, httpErr.Code, err.Error())
 		w.Header().Set("Content-Type", "application/json")
 		w.Header().Set("Content-Type", "application/json")
 		w.Header().Set("Access-Control-Allow-Origin", "*") // CORS, allow cross-origin requests
 		w.Header().Set("Access-Control-Allow-Origin", "*") // CORS, allow cross-origin requests
-		w.WriteHeader(e.HTTPCode)
-		io.WriteString(w, e.JSON()+"\n")
+		w.WriteHeader(httpErr.HTTPCode)
+		io.WriteString(w, httpErr.JSON()+"\n")
 	}
 	}
 }
 }
 
 
@@ -404,6 +335,8 @@ func (s *Server) handleInternal(w http.ResponseWriter, r *http.Request) error {
 		return s.withRateLimit(w, r, s.handleSubscribeSSE)
 		return s.withRateLimit(w, r, s.handleSubscribeSSE)
 	} else if r.Method == http.MethodGet && rawPathRegex.MatchString(r.URL.Path) {
 	} else if r.Method == http.MethodGet && rawPathRegex.MatchString(r.URL.Path) {
 		return s.withRateLimit(w, r, s.handleSubscribeRaw)
 		return s.withRateLimit(w, r, s.handleSubscribeRaw)
+	} else if r.Method == http.MethodGet && wsPathRegex.MatchString(r.URL.Path) {
+		return s.withRateLimit(w, r, s.handleSubscribeWS)
 	}
 	}
 	return errHTTPNotFound
 	return errHTTPNotFound
 }
 }
@@ -416,7 +349,7 @@ func (s *Server) handleHome(w http.ResponseWriter, r *http.Request) error {
 }
 }
 
 
 func (s *Server) handleTopic(w http.ResponseWriter, r *http.Request) error {
 func (s *Server) handleTopic(w http.ResponseWriter, r *http.Request) error {
-	unifiedpush := readParam(r, "x-unifiedpush", "unifiedpush", "up") == "1" // see PUT/POST too!
+	unifiedpush := readBoolParam(r, false, "x-unifiedpush", "unifiedpush", "up") // see PUT/POST too!
 	if unifiedpush {
 	if unifiedpush {
 		w.Header().Set("Content-Type", "application/json")
 		w.Header().Set("Content-Type", "application/json")
 		w.Header().Set("Access-Control-Allow-Origin", "*") // CORS, allow cross-origin requests
 		w.Header().Set("Access-Control-Allow-Origin", "*") // CORS, allow cross-origin requests
@@ -522,13 +455,15 @@ func (s *Server) handlePublish(w http.ResponseWriter, r *http.Request, v *visito
 	if err := json.NewEncoder(w).Encode(m); err != nil {
 	if err := json.NewEncoder(w).Encode(m); err != nil {
 		return err
 		return err
 	}
 	}
-	s.inc(&s.messages)
+	s.mu.Lock()
+	s.messages++
+	s.mu.Unlock()
 	return nil
 	return nil
 }
 }
 
 
 func (s *Server) parsePublishParams(r *http.Request, v *visitor, m *message) (cache bool, firebase bool, email string, err error) {
 func (s *Server) parsePublishParams(r *http.Request, v *visitor, m *message) (cache bool, firebase bool, email string, err error) {
-	cache = readParam(r, "x-cache", "cache") != "no"
-	firebase = readParam(r, "x-firebase", "firebase") != "no"
+	cache = readBoolParam(r, true, "x-cache", "cache")
+	firebase = readBoolParam(r, true, "x-firebase", "firebase")
 	m.Title = readParam(r, "x-title", "title", "t")
 	m.Title = readParam(r, "x-title", "title", "t")
 	m.Click = readParam(r, "x-click", "click")
 	m.Click = readParam(r, "x-click", "click")
 	filename := readParam(r, "x-filename", "filename", "file", "f")
 	filename := readParam(r, "x-filename", "filename", "file", "f")
@@ -599,29 +534,13 @@ func (s *Server) parsePublishParams(r *http.Request, v *visitor, m *message) (ca
 		}
 		}
 		m.Time = delay.Unix()
 		m.Time = delay.Unix()
 	}
 	}
-	unifiedpush := readParam(r, "x-unifiedpush", "unifiedpush", "up") == "1" // see GET too!
+	unifiedpush := readBoolParam(r, false, "x-unifiedpush", "unifiedpush", "up") // see GET too!
 	if unifiedpush {
 	if unifiedpush {
 		firebase = false
 		firebase = false
 	}
 	}
 	return cache, firebase, email, nil
 	return cache, firebase, email, nil
 }
 }
 
 
-func readParam(r *http.Request, names ...string) string {
-	for _, name := range names {
-		value := r.Header.Get(name)
-		if value != "" {
-			return strings.TrimSpace(value)
-		}
-	}
-	for _, name := range names {
-		value := r.URL.Query().Get(strings.ToLower(name))
-		if value != "" {
-			return strings.TrimSpace(value)
-		}
-	}
-	return ""
-}
-
 // handlePublishBody consumes the PUT/POST body and decides whether the body is an attachment or the message.
 // handlePublishBody consumes the PUT/POST body and decides whether the body is an attachment or the message.
 //
 //
 // 1. curl -H "Attach: http://example.com/file.jpg" ntfy.sh/mytopic
 // 1. curl -H "Attach: http://example.com/file.jpg" ntfy.sh/mytopic
@@ -705,7 +624,7 @@ func (s *Server) handleSubscribeJSON(w http.ResponseWriter, r *http.Request, v *
 		}
 		}
 		return buf.String(), nil
 		return buf.String(), nil
 	}
 	}
-	return s.handleSubscribe(w, r, v, "json", "application/x-ndjson", encoder)
+	return s.handleSubscribeHTTP(w, r, v, "application/x-ndjson", encoder)
 }
 }
 
 
 func (s *Server) handleSubscribeSSE(w http.ResponseWriter, r *http.Request, v *visitor) error {
 func (s *Server) handleSubscribeSSE(w http.ResponseWriter, r *http.Request, v *visitor) error {
@@ -719,7 +638,7 @@ func (s *Server) handleSubscribeSSE(w http.ResponseWriter, r *http.Request, v *v
 		}
 		}
 		return fmt.Sprintf("data: %s\n", buf.String()), nil
 		return fmt.Sprintf("data: %s\n", buf.String()), nil
 	}
 	}
-	return s.handleSubscribe(w, r, v, "sse", "text/event-stream", encoder)
+	return s.handleSubscribeHTTP(w, r, v, "text/event-stream", encoder)
 }
 }
 
 
 func (s *Server) handleSubscribeRaw(w http.ResponseWriter, r *http.Request, v *visitor) error {
 func (s *Server) handleSubscribeRaw(w http.ResponseWriter, r *http.Request, v *visitor) error {
@@ -729,33 +648,25 @@ func (s *Server) handleSubscribeRaw(w http.ResponseWriter, r *http.Request, v *v
 		}
 		}
 		return "\n", nil // "keepalive" and "open" events just send an empty line
 		return "\n", nil // "keepalive" and "open" events just send an empty line
 	}
 	}
-	return s.handleSubscribe(w, r, v, "raw", "text/plain", encoder)
+	return s.handleSubscribeHTTP(w, r, v, "text/plain", encoder)
 }
 }
 
 
-func (s *Server) handleSubscribe(w http.ResponseWriter, r *http.Request, v *visitor, format string, contentType string, encoder messageEncoder) error {
+func (s *Server) handleSubscribeHTTP(w http.ResponseWriter, r *http.Request, v *visitor, contentType string, encoder messageEncoder) error {
 	if err := v.SubscriptionAllowed(); err != nil {
 	if err := v.SubscriptionAllowed(); err != nil {
 		return errHTTPTooManyRequestsLimitSubscriptions
 		return errHTTPTooManyRequestsLimitSubscriptions
 	}
 	}
 	defer v.RemoveSubscription()
 	defer v.RemoveSubscription()
-	topicsStr := strings.TrimSuffix(r.URL.Path[1:], "/"+format) // Hack
-	topicIDs := util.SplitNoEmpty(topicsStr, ",")
-	topics, err := s.topicsFromIDs(topicIDs...)
+	topics, topicsStr, err := s.topicsFromPath(r.URL.Path)
 	if err != nil {
 	if err != nil {
 		return err
 		return err
 	}
 	}
-	poll := readParam(r, "x-poll", "poll", "po") == "1"
-	scheduled := readParam(r, "x-scheduled", "scheduled", "sched") == "1"
-	since, err := parseSince(r, poll)
-	if err != nil {
-		return err
-	}
-	messageFilter, titleFilter, priorityFilter, tagsFilter, err := parseQueryFilters(r)
+	poll, since, scheduled, filters, err := parseSubscribeParams(r)
 	if err != nil {
 	if err != nil {
 		return err
 		return err
 	}
 	}
 	var wlock sync.Mutex
 	var wlock sync.Mutex
 	sub := func(msg *message) error {
 	sub := func(msg *message) error {
-		if !passesQueryFilter(msg, messageFilter, titleFilter, priorityFilter, tagsFilter) {
+		if !filters.Pass(msg) {
 			return nil
 			return nil
 		}
 		}
 		m, err := encoder(msg)
 		m, err := encoder(msg)
@@ -805,42 +716,119 @@ func (s *Server) handleSubscribe(w http.ResponseWriter, r *http.Request, v *visi
 	}
 	}
 }
 }
 
 
-func parseQueryFilters(r *http.Request) (messageFilter string, titleFilter string, priorityFilter []int, tagsFilter []string, err error) {
-	messageFilter = readParam(r, "x-message", "message", "m")
-	titleFilter = readParam(r, "x-title", "title", "t")
-	tagsFilter = util.SplitNoEmpty(readParam(r, "x-tags", "tags", "tag", "ta"), ",")
-	priorityFilter = make([]int, 0)
-	for _, p := range util.SplitNoEmpty(readParam(r, "x-priority", "priority", "prio", "p"), ",") {
-		priority, err := util.ParsePriority(p)
-		if err != nil {
-			return "", "", nil, nil, err
+func (s *Server) handleSubscribeWS(w http.ResponseWriter, r *http.Request, v *visitor) error {
+	if err := v.SubscriptionAllowed(); err != nil {
+		return errHTTPTooManyRequestsLimitSubscriptions
+	}
+	defer v.RemoveSubscription()
+	topics, topicsStr, err := s.topicsFromPath(r.URL.Path)
+	if err != nil {
+		return err
+	}
+	poll, since, scheduled, filters, err := parseSubscribeParams(r)
+	if err != nil {
+		return err
+	}
+	upgrader := &websocket.Upgrader{
+		ReadBufferSize:  wsBufferSize,
+		WriteBufferSize: wsBufferSize,
+		CheckOrigin: func(r *http.Request) bool {
+			return true // We're open for business!
+		},
+	}
+	conn, err := upgrader.Upgrade(w, r, nil)
+	if err != nil {
+		return err
+	}
+	defer conn.Close()
+	var wlock sync.Mutex
+	g, ctx := errgroup.WithContext(context.Background())
+	g.Go(func() error {
+		pongWait := s.config.KeepaliveInterval + wsPongWait
+		conn.SetReadLimit(wsReadLimit)
+		if err := conn.SetReadDeadline(time.Now().Add(pongWait)); err != nil {
+			return err
+		}
+		conn.SetPongHandler(func(appData string) error {
+			return conn.SetReadDeadline(time.Now().Add(pongWait))
+		})
+		for {
+			_, _, err := conn.NextReader()
+			if err != nil {
+				return err
+			}
+		}
+	})
+	g.Go(func() error {
+		ping := func() error {
+			wlock.Lock()
+			defer wlock.Unlock()
+			if err := conn.SetWriteDeadline(time.Now().Add(wsWriteWait)); err != nil {
+				return err
+			}
+			return conn.WriteMessage(websocket.PingMessage, nil)
 		}
 		}
-		priorityFilter = append(priorityFilter, priority)
+		for {
+			select {
+			case <-ctx.Done():
+				return nil
+			case <-time.After(s.config.KeepaliveInterval):
+				v.Keepalive()
+				if err := ping(); err != nil {
+					return err
+				}
+			}
+		}
+	})
+	sub := func(msg *message) error {
+		if !filters.Pass(msg) {
+			return nil
+		}
+		wlock.Lock()
+		defer wlock.Unlock()
+		if err := conn.SetWriteDeadline(time.Now().Add(wsWriteWait)); err != nil {
+			return err
+		}
+		return conn.WriteJSON(msg)
 	}
 	}
-	return
-}
-
-func passesQueryFilter(msg *message, messageFilter string, titleFilter string, priorityFilter []int, tagsFilter []string) bool {
-	if msg.Event != messageEvent {
-		return true // filters only apply to messages
+	w.Header().Set("Access-Control-Allow-Origin", "*") // CORS, allow cross-origin requests
+	if poll {
+		return s.sendOldMessages(topics, since, scheduled, sub)
 	}
 	}
-	if messageFilter != "" && msg.Message != messageFilter {
-		return false
+	subscriberIDs := make([]int, 0)
+	for _, t := range topics {
+		subscriberIDs = append(subscriberIDs, t.Subscribe(sub))
 	}
 	}
-	if titleFilter != "" && msg.Title != titleFilter {
-		return false
+	defer func() {
+		for i, subscriberID := range subscriberIDs {
+			topics[i].Unsubscribe(subscriberID) // Order!
+		}
+	}()
+	if err := sub(newOpenMessage(topicsStr)); err != nil { // Send out open message
+		return err
+	}
+	if err := s.sendOldMessages(topics, since, scheduled, sub); err != nil {
+		return err
 	}
 	}
-	messagePriority := msg.Priority
-	if messagePriority == 0 {
-		messagePriority = 3 // For query filters, default priority (3) is the same as "not set" (0)
+	err = g.Wait()
+	if err != nil && websocket.IsCloseError(err, websocket.CloseNormalClosure, websocket.CloseGoingAway) {
+		return nil // Normal closures are not errors
 	}
 	}
-	if len(priorityFilter) > 0 && !util.InIntList(priorityFilter, messagePriority) {
-		return false
+	return err
+}
+
+func parseSubscribeParams(r *http.Request) (poll bool, since sinceTime, scheduled bool, filters *queryFilter, err error) {
+	poll = readBoolParam(r, false, "x-poll", "poll", "po")
+	scheduled = readBoolParam(r, false, "x-scheduled", "scheduled", "sched")
+	since, err = parseSince(r, poll)
+	if err != nil {
+		return
 	}
 	}
-	if len(tagsFilter) > 0 && !util.InStringListAll(msg.Tags, tagsFilter) {
-		return false
+	filters, err = parseQueryFilters(r)
+	if err != nil {
+		return
 	}
 	}
-	return true
+	return
 }
 }
 
 
 func (s *Server) sendOldMessages(topics []*topic, since sinceTime, scheduled bool, sub subscriber) error {
 func (s *Server) sendOldMessages(topics []*topic, since sinceTime, scheduled bool, sub subscriber) error {
@@ -901,6 +889,19 @@ func (s *Server) topicFromPath(path string) (*topic, error) {
 	return topics[0], nil
 	return topics[0], nil
 }
 }
 
 
+func (s *Server) topicsFromPath(path string) ([]*topic, string, error) {
+	parts := strings.Split(path, "/")
+	if len(parts) < 2 {
+		return nil, "", errHTTPBadRequestTopicInvalid
+	}
+	topicIDs := util.SplitNoEmpty(parts[1], ",")
+	topics, err := s.topicsFromIDs(topicIDs...)
+	if err != nil {
+		return nil, "", errHTTPBadRequestTopicInvalid
+	}
+	return topics, parts[1], nil
+}
+
 func (s *Server) topicsFromIDs(ids ...string) ([]*topic, error) {
 func (s *Server) topicsFromIDs(ids ...string) ([]*topic, error) {
 	s.mu.Lock()
 	s.mu.Lock()
 	defer s.mu.Unlock()
 	defer s.mu.Unlock()
@@ -1101,9 +1102,3 @@ func (s *Server) visitor(r *http.Request) *visitor {
 	v.Keepalive()
 	v.Keepalive()
 	return v
 	return v
 }
 }
-
-func (s *Server) inc(counter *int64) {
-	s.mu.Lock()
-	defer s.mu.Unlock()
-	*counter++
-}

+ 3 - 2
server/server.yml

@@ -6,8 +6,9 @@
 # base-url:
 # base-url:
 
 
 # Listen address for the HTTP & HTTPS web server. If "listen-https" is set, you must also
 # Listen address for the HTTP & HTTPS web server. If "listen-https" is set, you must also
-# set "key-file" and "cert-file". Format: <hostname>:<port>
+# set "key-file" and "cert-file". Format: [<ip>]:<port>, e.g. "1.2.3.4:8080".
 #
 #
+# To listen on all interfaces, you may omit the IP address, e.g. ":443".
 # To disable HTTP, set "listen-http" to "-".
 # To disable HTTP, set "listen-http" to "-".
 #
 #
 # listen-http: ":80"
 # listen-http: ":80"
@@ -98,7 +99,7 @@
 #
 #
 # Note that the Android app has a hardcoded timeout at 77s, so it should be less than that.
 # Note that the Android app has a hardcoded timeout at 77s, so it should be less than that.
 #
 #
-# keepalive-interval: "30s"
+# keepalive-interval: "45s"
 
 
 # Interval in which the manager prunes old messages, deletes topics
 # Interval in which the manager prunes old messages, deletes topics
 # and prints the stats.
 # and prints the stats.

+ 0 - 58
server/server_test.go

@@ -4,7 +4,6 @@ import (
 	"bufio"
 	"bufio"
 	"context"
 	"context"
 	"encoding/json"
 	"encoding/json"
-	"firebase.google.com/go/messaging"
 	"fmt"
 	"fmt"
 	"github.com/stretchr/testify/require"
 	"github.com/stretchr/testify/require"
 	"heckel.io/ntfy/util"
 	"heckel.io/ntfy/util"
@@ -624,63 +623,6 @@ func TestServer_UnifiedPushDiscovery(t *testing.T) {
 	require.Equal(t, `{"unifiedpush":{"version":1}}`+"\n", response.Body.String())
 	require.Equal(t, `{"unifiedpush":{"version":1}}`+"\n", response.Body.String())
 }
 }
 
 
-func TestServer_MaybeTruncateFCMMessage(t *testing.T) {
-	origMessage := strings.Repeat("this is a long string", 300)
-	origFCMMessage := &messaging.Message{
-		Topic: "mytopic",
-		Data: map[string]string{
-			"id":       "abcdefg",
-			"time":     "1641324761",
-			"event":    "message",
-			"topic":    "mytopic",
-			"priority": "0",
-			"tags":     "",
-			"title":    "",
-			"message":  origMessage,
-		},
-		Android: &messaging.AndroidConfig{
-			Priority: "high",
-		},
-	}
-	origMessageLength := len(origFCMMessage.Data["message"])
-	serializedOrigFCMMessage, _ := json.Marshal(origFCMMessage)
-	require.Greater(t, len(serializedOrigFCMMessage), fcmMessageLimit) // Pre-condition
-
-	truncatedFCMMessage := maybeTruncateFCMMessage(origFCMMessage)
-	truncatedMessageLength := len(truncatedFCMMessage.Data["message"])
-	serializedTruncatedFCMMessage, _ := json.Marshal(truncatedFCMMessage)
-	require.Equal(t, fcmMessageLimit, len(serializedTruncatedFCMMessage))
-	require.Equal(t, "1", truncatedFCMMessage.Data["truncated"])
-	require.NotEqual(t, origMessageLength, truncatedMessageLength)
-}
-
-func TestServer_MaybeTruncateFCMMessage_NotTooLong(t *testing.T) {
-	origMessage := "not really a long string"
-	origFCMMessage := &messaging.Message{
-		Topic: "mytopic",
-		Data: map[string]string{
-			"id":       "abcdefg",
-			"time":     "1641324761",
-			"event":    "message",
-			"topic":    "mytopic",
-			"priority": "0",
-			"tags":     "",
-			"title":    "",
-			"message":  origMessage,
-		},
-	}
-	origMessageLength := len(origFCMMessage.Data["message"])
-	serializedOrigFCMMessage, _ := json.Marshal(origFCMMessage)
-	require.LessOrEqual(t, len(serializedOrigFCMMessage), fcmMessageLimit) // Pre-condition
-
-	notTruncatedFCMMessage := maybeTruncateFCMMessage(origFCMMessage)
-	notTruncatedMessageLength := len(notTruncatedFCMMessage.Data["message"])
-	serializedNotTruncatedFCMMessage, _ := json.Marshal(notTruncatedFCMMessage)
-	require.Equal(t, origMessageLength, notTruncatedMessageLength)
-	require.Equal(t, len(serializedOrigFCMMessage), len(serializedNotTruncatedFCMMessage))
-	require.Equal(t, "", notTruncatedFCMMessage.Data["truncated"])
-}
-
 func TestServer_PublishAttachment(t *testing.T) {
 func TestServer_PublishAttachment(t *testing.T) {
 	content := util.RandomString(5000) // > 4096
 	content := util.RandomString(5000) // > 4096
 	s := newTestServer(t, newTestConfig(t))
 	s := newTestServer(t, newTestConfig(t))

+ 70 - 0
server/message.go → server/types.go

@@ -2,6 +2,7 @@ package server
 
 
 import (
 import (
 	"heckel.io/ntfy/util"
 	"heckel.io/ntfy/util"
+	"net/http"
 	"time"
 	"time"
 )
 )
 
 
@@ -70,3 +71,72 @@ func newKeepaliveMessage(topic string) *message {
 func newDefaultMessage(topic, msg string) *message {
 func newDefaultMessage(topic, msg string) *message {
 	return newMessage(messageEvent, topic, msg)
 	return newMessage(messageEvent, topic, msg)
 }
 }
+
+type sinceTime time.Time
+
+func (t sinceTime) IsAll() bool {
+	return t == sinceAllMessages
+}
+
+func (t sinceTime) IsNone() bool {
+	return t == sinceNoMessages
+}
+
+func (t sinceTime) Time() time.Time {
+	return time.Time(t)
+}
+
+var (
+	sinceAllMessages = sinceTime(time.Unix(0, 0))
+	sinceNoMessages  = sinceTime(time.Unix(1, 0))
+)
+
+type queryFilter struct {
+	Message  string
+	Title    string
+	Tags     []string
+	Priority []int
+}
+
+func parseQueryFilters(r *http.Request) (*queryFilter, error) {
+	messageFilter := readParam(r, "x-message", "message", "m")
+	titleFilter := readParam(r, "x-title", "title", "t")
+	tagsFilter := util.SplitNoEmpty(readParam(r, "x-tags", "tags", "tag", "ta"), ",")
+	priorityFilter := make([]int, 0)
+	for _, p := range util.SplitNoEmpty(readParam(r, "x-priority", "priority", "prio", "p"), ",") {
+		priority, err := util.ParsePriority(p)
+		if err != nil {
+			return nil, err
+		}
+		priorityFilter = append(priorityFilter, priority)
+	}
+	return &queryFilter{
+		Message:  messageFilter,
+		Title:    titleFilter,
+		Tags:     tagsFilter,
+		Priority: priorityFilter,
+	}, nil
+}
+
+func (q *queryFilter) Pass(msg *message) bool {
+	if msg.Event != messageEvent {
+		return true // filters only apply to messages
+	}
+	if q.Message != "" && msg.Message != q.Message {
+		return false
+	}
+	if q.Title != "" && msg.Title != q.Title {
+		return false
+	}
+	messagePriority := msg.Priority
+	if messagePriority == 0 {
+		messagePriority = 3 // For query filters, default priority (3) is the same as "not set" (0)
+	}
+	if len(q.Priority) > 0 && !util.InIntList(q.Priority, messagePriority) {
+		return false
+	}
+	if len(q.Tags) > 0 && !util.InStringListAll(msg.Tags, q.Tags) {
+		return false
+	}
+	return true
+}

+ 55 - 0
server/util.go

@@ -0,0 +1,55 @@
+package server
+
+import (
+	"encoding/json"
+	"firebase.google.com/go/messaging"
+	"net/http"
+	"strings"
+)
+
+const (
+	fcmMessageLimit = 4000
+)
+
+// maybeTruncateFCMMessage performs best-effort truncation of FCM messages.
+// The docs say the limit is 4000 characters, but during testing it wasn't quite clear
+// what fields matter; so we're just capping the serialized JSON to 4000 bytes.
+func maybeTruncateFCMMessage(m *messaging.Message) *messaging.Message {
+	s, err := json.Marshal(m)
+	if err != nil {
+		return m
+	}
+	if len(s) > fcmMessageLimit {
+		over := len(s) - fcmMessageLimit + 16 // = len("truncated":"1",), sigh ...
+		message, ok := m.Data["message"]
+		if ok && len(message) > over {
+			m.Data["truncated"] = "1"
+			m.Data["message"] = message[:len(message)-over]
+		}
+	}
+	return m
+}
+
+func readBoolParam(r *http.Request, defaultValue bool, names ...string) bool {
+	value := strings.ToLower(readParam(r, names...))
+	if value == "" {
+		return defaultValue
+	}
+	return value == "1" || value == "yes" || value == "true"
+}
+
+func readParam(r *http.Request, names ...string) string {
+	for _, name := range names {
+		value := r.Header.Get(name)
+		if value != "" {
+			return strings.TrimSpace(value)
+		}
+	}
+	for _, name := range names {
+		value := r.URL.Query().Get(strings.ToLower(name))
+		if value != "" {
+			return strings.TrimSpace(value)
+		}
+	}
+	return ""
+}

+ 66 - 0
server/util_test.go

@@ -0,0 +1,66 @@
+package server
+
+import (
+	"encoding/json"
+	"firebase.google.com/go/messaging"
+	"github.com/stretchr/testify/require"
+	"strings"
+	"testing"
+)
+
+func TestMaybeTruncateFCMMessage(t *testing.T) {
+	origMessage := strings.Repeat("this is a long string", 300)
+	origFCMMessage := &messaging.Message{
+		Topic: "mytopic",
+		Data: map[string]string{
+			"id":       "abcdefg",
+			"time":     "1641324761",
+			"event":    "message",
+			"topic":    "mytopic",
+			"priority": "0",
+			"tags":     "",
+			"title":    "",
+			"message":  origMessage,
+		},
+		Android: &messaging.AndroidConfig{
+			Priority: "high",
+		},
+	}
+	origMessageLength := len(origFCMMessage.Data["message"])
+	serializedOrigFCMMessage, _ := json.Marshal(origFCMMessage)
+	require.Greater(t, len(serializedOrigFCMMessage), fcmMessageLimit) // Pre-condition
+
+	truncatedFCMMessage := maybeTruncateFCMMessage(origFCMMessage)
+	truncatedMessageLength := len(truncatedFCMMessage.Data["message"])
+	serializedTruncatedFCMMessage, _ := json.Marshal(truncatedFCMMessage)
+	require.Equal(t, fcmMessageLimit, len(serializedTruncatedFCMMessage))
+	require.Equal(t, "1", truncatedFCMMessage.Data["truncated"])
+	require.NotEqual(t, origMessageLength, truncatedMessageLength)
+}
+
+func TestMaybeTruncateFCMMessage_NotTooLong(t *testing.T) {
+	origMessage := "not really a long string"
+	origFCMMessage := &messaging.Message{
+		Topic: "mytopic",
+		Data: map[string]string{
+			"id":       "abcdefg",
+			"time":     "1641324761",
+			"event":    "message",
+			"topic":    "mytopic",
+			"priority": "0",
+			"tags":     "",
+			"title":    "",
+			"message":  origMessage,
+		},
+	}
+	origMessageLength := len(origFCMMessage.Data["message"])
+	serializedOrigFCMMessage, _ := json.Marshal(origFCMMessage)
+	require.LessOrEqual(t, len(serializedOrigFCMMessage), fcmMessageLimit) // Pre-condition
+
+	notTruncatedFCMMessage := maybeTruncateFCMMessage(origFCMMessage)
+	notTruncatedMessageLength := len(notTruncatedFCMMessage.Data["message"])
+	serializedNotTruncatedFCMMessage, _ := json.Marshal(notTruncatedFCMMessage)
+	require.Equal(t, origMessageLength, notTruncatedMessageLength)
+	require.Equal(t, len(serializedOrigFCMMessage), len(serializedNotTruncatedFCMMessage))
+	require.Equal(t, "", notTruncatedFCMMessage.Data["truncated"])
+}

+ 3 - 0
test/server.go

@@ -5,6 +5,7 @@ import (
 	"heckel.io/ntfy/server"
 	"heckel.io/ntfy/server"
 	"math/rand"
 	"math/rand"
 	"net/http"
 	"net/http"
+	"path/filepath"
 	"testing"
 	"testing"
 	"time"
 	"time"
 )
 )
@@ -22,6 +23,8 @@ func StartServer(t *testing.T) (*server.Server, int) {
 func StartServerWithConfig(t *testing.T, conf *server.Config) (*server.Server, int) {
 func StartServerWithConfig(t *testing.T, conf *server.Config) (*server.Server, int) {
 	port := 10000 + rand.Intn(20000)
 	port := 10000 + rand.Intn(20000)
 	conf.ListenHTTP = fmt.Sprintf(":%d", port)
 	conf.ListenHTTP = fmt.Sprintf(":%d", port)
+	conf.AttachmentCacheDir = t.TempDir()
+	conf.CacheFile = filepath.Join(t.TempDir(), "cache.db")
 	s, err := server.New(conf)
 	s, err := server.New(conf)
 	if err != nil {
 	if err != nil {
 		t.Fatal(err)
 		t.Fatal(err)