Sfoglia il codice sorgente

Allow HEAD requests for file attachments

Philipp Heckel 3 anni fa
parent
commit
2b42cea1a3
2 ha cambiato i file con 22 aggiunte e 9 eliminazioni
  1. 15 9
      server/server.go
  2. 7 0
      server/server_test.go

+ 15 - 9
server/server.go

@@ -8,7 +8,6 @@ import (
 	"encoding/base64"
 	"encoding/json"
 	"fmt"
-	"heckel.io/ntfy/log"
 	"io"
 	"net"
 	"net/http"
@@ -23,6 +22,8 @@ import (
 	"time"
 	"unicode/utf8"
 
+	"heckel.io/ntfy/log"
+
 	"github.com/emersion/go-smtp"
 	"github.com/gorilla/websocket"
 	"golang.org/x/sync/errgroup"
@@ -289,7 +290,7 @@ func (s *Server) handleInternal(w http.ResponseWriter, r *http.Request, v *visit
 		return s.ensureWebEnabled(s.handleStatic)(w, r, v)
 	} else if r.Method == http.MethodGet && docsRegex.MatchString(r.URL.Path) {
 		return s.ensureWebEnabled(s.handleDocs)(w, r, v)
-	} else if r.Method == http.MethodGet && fileRegex.MatchString(r.URL.Path) && s.config.AttachmentCacheDir != "" {
+	} else if (r.Method == http.MethodGet || r.Method == http.MethodHead) && fileRegex.MatchString(r.URL.Path) && s.config.AttachmentCacheDir != "" {
 		return s.limitRequests(s.handleFile)(w, r, v)
 	} else if r.Method == http.MethodOptions {
 		return s.ensureWebEnabled(s.handleOptions)(w, r, v)
@@ -405,18 +406,23 @@ func (s *Server) handleFile(w http.ResponseWriter, r *http.Request, v *visitor)
 	if err != nil {
 		return errHTTPNotFound
 	}
-	if err := v.BandwidthLimiter().Allow(stat.Size()); err != nil {
-		return errHTTPTooManyRequestsAttachmentBandwidthLimit
+	if r.Method == http.MethodGet {
+		if err := v.BandwidthLimiter().Allow(stat.Size()); err != nil {
+			return errHTTPTooManyRequestsAttachmentBandwidthLimit
+		}
 	}
 	w.Header().Set("Content-Length", fmt.Sprintf("%d", stat.Size()))
 	w.Header().Set("Access-Control-Allow-Origin", "*") // CORS, allow cross-origin requests
-	f, err := os.Open(file)
-	if err != nil {
+	if r.Method == http.MethodGet {
+		f, err := os.Open(file)
+		if err != nil {
+			return err
+		}
+		defer f.Close()
+		_, err = io.Copy(util.NewContentTypeWriter(w, r.URL.Path), f)
 		return err
 	}
-	defer f.Close()
-	_, err = io.Copy(util.NewContentTypeWriter(w, r.URL.Path), f)
-	return err
+	return nil
 }
 
 func (s *Server) handlePublish(w http.ResponseWriter, r *http.Request, v *visitor) error {

+ 7 - 0
server/server_test.go

@@ -1026,12 +1026,19 @@ func TestServer_PublishAttachment(t *testing.T) {
 	require.Equal(t, "", msg.Sender) // Should never be returned
 	require.FileExists(t, filepath.Join(s.config.AttachmentCacheDir, msg.ID))
 
+	// GET
 	path := strings.TrimPrefix(msg.Attachment.URL, "http://127.0.0.1:12345")
 	response = request(t, s, "GET", path, "", nil)
 	require.Equal(t, 200, response.Code)
 	require.Equal(t, "5000", response.Header().Get("Content-Length"))
 	require.Equal(t, content, response.Body.String())
 
+	// HEAD
+	response = request(t, s, "HEAD", path, "", nil)
+	require.Equal(t, 200, response.Code)
+	require.Equal(t, "5000", response.Header().Get("Content-Length"))
+	require.Equal(t, "", response.Body.String())
+
 	// Slightly unrelated cross-test: make sure we add an owner for internal attachments
 	size, err := s.messageCache.AttachmentBytesUsed("9.9.9.9") // See request()
 	require.Nil(t, err)