Переглянути джерело

Add /auth endpoint and tests

Philipp Heckel 4 роки тому
батько
коміт
e61a0c2f78
2 змінених файлів з 115 додано та 2 видалено
  1. 12 2
      server/server.go
  2. 103 0
      server/server_test.go

+ 12 - 2
server/server.go

@@ -69,6 +69,7 @@ var (
 	ssePathRegex     = regexp.MustCompile(`^/[-_A-Za-z0-9]{1,64}(,[-_A-Za-z0-9]{1,64})*/sse$`)
 	ssePathRegex     = regexp.MustCompile(`^/[-_A-Za-z0-9]{1,64}(,[-_A-Za-z0-9]{1,64})*/sse$`)
 	rawPathRegex     = regexp.MustCompile(`^/[-_A-Za-z0-9]{1,64}(,[-_A-Za-z0-9]{1,64})*/raw$`)
 	rawPathRegex     = regexp.MustCompile(`^/[-_A-Za-z0-9]{1,64}(,[-_A-Za-z0-9]{1,64})*/raw$`)
 	wsPathRegex      = regexp.MustCompile(`^/[-_A-Za-z0-9]{1,64}(,[-_A-Za-z0-9]{1,64})*/ws$`)
 	wsPathRegex      = regexp.MustCompile(`^/[-_A-Za-z0-9]{1,64}(,[-_A-Za-z0-9]{1,64})*/ws$`)
+	authPathRegex    = regexp.MustCompile(`^/[-_A-Za-z0-9]{1,64}(,[-_A-Za-z0-9]{1,64})*/auth$`)
 	publishPathRegex = regexp.MustCompile(`^/[-_A-Za-z0-9]{1,64}(,[-_A-Za-z0-9]{1,64})*/(publish|send|trigger)$`)
 	publishPathRegex = regexp.MustCompile(`^/[-_A-Za-z0-9]{1,64}(,[-_A-Za-z0-9]{1,64})*/(publish|send|trigger)$`)
 
 
 	staticRegex      = regexp.MustCompile(`^/static/.+`)
 	staticRegex      = regexp.MustCompile(`^/static/.+`)
@@ -331,7 +332,7 @@ func (s *Server) handleInternal(w http.ResponseWriter, r *http.Request) error {
 	} else if r.Method == http.MethodGet && r.URL.Path == "/example.html" {
 	} else if r.Method == http.MethodGet && r.URL.Path == "/example.html" {
 		return s.handleExample(w, r)
 		return s.handleExample(w, r)
 	} else if r.Method == http.MethodHead && r.URL.Path == "/" {
 	} else if r.Method == http.MethodHead && r.URL.Path == "/" {
-		return s.handleEmpty(w, r)
+		return s.handleEmpty(w, r, v)
 	} else if r.Method == http.MethodGet && staticRegex.MatchString(r.URL.Path) {
 	} else if r.Method == http.MethodGet && staticRegex.MatchString(r.URL.Path) {
 		return s.handleStatic(w, r)
 		return s.handleStatic(w, r)
 	} else if r.Method == http.MethodGet && docsRegex.MatchString(r.URL.Path) {
 	} else if r.Method == http.MethodGet && docsRegex.MatchString(r.URL.Path) {
@@ -354,6 +355,8 @@ func (s *Server) handleInternal(w http.ResponseWriter, r *http.Request) error {
 		return s.limitRequests(s.authRead(s.handleSubscribeRaw))(w, r, v)
 		return s.limitRequests(s.authRead(s.handleSubscribeRaw))(w, r, v)
 	} else if r.Method == http.MethodGet && wsPathRegex.MatchString(r.URL.Path) {
 	} else if r.Method == http.MethodGet && wsPathRegex.MatchString(r.URL.Path) {
 		return s.limitRequests(s.authRead(s.handleSubscribeWS))(w, r, v)
 		return s.limitRequests(s.authRead(s.handleSubscribeWS))(w, r, v)
+	} else if r.Method == http.MethodGet && authPathRegex.MatchString(r.URL.Path) {
+		return s.limitRequests(s.authRead(s.handleTopicAuth))(w, r, v)
 	}
 	}
 	return errHTTPNotFound
 	return errHTTPNotFound
 }
 }
@@ -376,10 +379,17 @@ func (s *Server) handleTopic(w http.ResponseWriter, r *http.Request) error {
 	return s.handleHome(w, r)
 	return s.handleHome(w, r)
 }
 }
 
 
-func (s *Server) handleEmpty(_ http.ResponseWriter, _ *http.Request) error {
+func (s *Server) handleEmpty(_ http.ResponseWriter, _ *http.Request, _ *visitor) error {
 	return nil
 	return nil
 }
 }
 
 
+func (s *Server) handleTopicAuth(w http.ResponseWriter, _ *http.Request, _ *visitor) error {
+	w.Header().Set("Content-Type", "application/json")
+	w.Header().Set("Access-Control-Allow-Origin", "*") // CORS, allow cross-origin requests
+	_, err := io.WriteString(w, `{"success":true}`+"\n")
+	return err
+}
+
 func (s *Server) handleExample(w http.ResponseWriter, _ *http.Request) error {
 func (s *Server) handleExample(w http.ResponseWriter, _ *http.Request) error {
 	_, err := io.WriteString(w, exampleSource)
 	_, err := io.WriteString(w, exampleSource)
 	return err
 	return err

+ 103 - 0
server/server_test.go

@@ -7,6 +7,7 @@ import (
 	"encoding/json"
 	"encoding/json"
 	"fmt"
 	"fmt"
 	"github.com/stretchr/testify/require"
 	"github.com/stretchr/testify/require"
+	"heckel.io/ntfy/auth"
 	"heckel.io/ntfy/util"
 	"heckel.io/ntfy/util"
 	"math/rand"
 	"math/rand"
 	"net/http"
 	"net/http"
@@ -524,6 +525,104 @@ func TestServer_SubscribeWithQueryFilters(t *testing.T) {
 	require.Equal(t, keepaliveEvent, messages[2].Event)
 	require.Equal(t, keepaliveEvent, messages[2].Event)
 }
 }
 
 
+func TestServer_Auth_Success_Admin(t *testing.T) {
+	c := newTestConfig(t)
+	c.AuthFile = filepath.Join(t.TempDir(), "user.db")
+	s := newTestServer(t, c)
+
+	manager := s.auth.(auth.Manager)
+	require.Nil(t, manager.AddUser("phil", "phil", auth.RoleAdmin))
+
+	response := request(t, s, "GET", "/mytopic/auth", "", map[string]string{
+		"Authorization": basicAuth("phil:phil"),
+	})
+	require.Equal(t, 200, response.Code)
+	require.Equal(t, `{"success":true}`+"\n", response.Body.String())
+}
+
+func TestServer_Auth_Success_User(t *testing.T) {
+	c := newTestConfig(t)
+	c.AuthFile = filepath.Join(t.TempDir(), "user.db")
+	c.AuthDefaultRead = false
+	c.AuthDefaultWrite = false
+	s := newTestServer(t, c)
+
+	manager := s.auth.(auth.Manager)
+	require.Nil(t, manager.AddUser("ben", "ben", auth.RoleUser))
+	require.Nil(t, manager.AllowAccess("ben", "mytopic", true, true)) // Not mytopic!
+
+	response := request(t, s, "GET", "/mytopic/auth", "", map[string]string{
+		"Authorization": basicAuth("ben:ben"),
+	})
+	require.Equal(t, 200, response.Code)
+}
+
+func TestServer_Auth_Fail_InvalidPass(t *testing.T) {
+	c := newTestConfig(t)
+	c.AuthFile = filepath.Join(t.TempDir(), "user.db")
+	c.AuthDefaultRead = false
+	c.AuthDefaultWrite = false
+	s := newTestServer(t, c)
+
+	manager := s.auth.(auth.Manager)
+	require.Nil(t, manager.AddUser("phil", "phil", auth.RoleAdmin))
+
+	response := request(t, s, "GET", "/mytopic/auth", "", map[string]string{
+		"Authorization": basicAuth("phil:INVALID"),
+	})
+	require.Equal(t, 401, response.Code)
+}
+
+func TestServer_Auth_Fail_Unauthorized(t *testing.T) {
+	c := newTestConfig(t)
+	c.AuthFile = filepath.Join(t.TempDir(), "user.db")
+	c.AuthDefaultRead = false
+	c.AuthDefaultWrite = false
+	s := newTestServer(t, c)
+
+	manager := s.auth.(auth.Manager)
+	require.Nil(t, manager.AddUser("ben", "ben", auth.RoleUser))
+	require.Nil(t, manager.AllowAccess("ben", "sometopic", true, true)) // Not mytopic!
+
+	response := request(t, s, "GET", "/mytopic/auth", "", map[string]string{
+		"Authorization": basicAuth("ben:ben"),
+	})
+	require.Equal(t, 403, response.Code)
+}
+
+func TestServer_Auth_Fail_CannotPublish(t *testing.T) {
+	c := newTestConfig(t)
+	c.AuthFile = filepath.Join(t.TempDir(), "user.db")
+	c.AuthDefaultRead = true  // Open by default
+	c.AuthDefaultWrite = true // Open by default
+	s := newTestServer(t, c)
+
+	manager := s.auth.(auth.Manager)
+	require.Nil(t, manager.AddUser("phil", "phil", auth.RoleAdmin))
+	require.Nil(t, manager.AllowAccess(auth.Everyone, "private", false, false))
+	require.Nil(t, manager.AllowAccess(auth.Everyone, "announcements", true, false))
+
+	response := request(t, s, "PUT", "/mytopic", "test", nil)
+	require.Equal(t, 200, response.Code)
+
+	response = request(t, s, "GET", "/mytopic/json?poll=1", "", nil)
+	require.Equal(t, 200, response.Code)
+
+	response = request(t, s, "PUT", "/announcements", "test", nil)
+	require.Equal(t, 403, response.Code) // Cannot write as anonymous
+
+	response = request(t, s, "PUT", "/announcements", "test", map[string]string{
+		"Authorization": basicAuth("phil:phil"),
+	})
+	require.Equal(t, 200, response.Code)
+
+	response = request(t, s, "GET", "/announcements/json?poll=1", "", nil)
+	require.Equal(t, 200, response.Code) // Anonymous read allowed
+
+	response = request(t, s, "GET", "/private/json?poll=1", "", nil)
+	require.Equal(t, 403, response.Code) // Anonymous read not allowed
+}
+
 /*
 /*
 func TestServer_Curl_Publish_Poll(t *testing.T) {
 func TestServer_Curl_Publish_Poll(t *testing.T) {
 	s, port := test.StartServer(t)
 	s, port := test.StartServer(t)
@@ -988,3 +1087,7 @@ func firebaseServiceAccountFile(t *testing.T) string {
 	t.SkipNow()
 	t.SkipNow()
 	return ""
 	return ""
 }
 }
+
+func basicAuth(s string) string {
+	return fmt.Sprintf("Basic %s", base64.StdEncoding.EncodeToString([]byte(s)))
+}