Kaynağa Gözat

Tests, client tests WIP

Philipp Heckel 4 yıl önce
ebeveyn
işleme
6a7e9071b6

+ 42 - 0
client/client_test.go

@@ -0,0 +1,42 @@
+package client_test
+
+import (
+	"github.com/stretchr/testify/require"
+	"heckel.io/ntfy/client"
+	"heckel.io/ntfy/server"
+	"net/http"
+	"testing"
+	"time"
+)
+
+func TestClient_Publish(t *testing.T) {
+	s := startTestServer(t)
+	defer s.Stop()
+	c := client.New(newTestConfig())
+
+	time.Sleep(time.Second) // FIXME Wait for port up
+
+	_, err := c.Publish("mytopic", "some message")
+	require.Nil(t, err)
+}
+
+func newTestConfig() *client.Config {
+	c := client.NewConfig()
+	c.DefaultHost = "http://127.0.0.1:12345"
+	return c
+}
+
+func startTestServer(t *testing.T) *server.Server {
+	conf := server.NewConfig()
+	conf.ListenHTTP = ":12345"
+	s, err := server.New(conf)
+	if err != nil {
+		t.Fatal(err)
+	}
+	go func() {
+		if err := s.Run(); err != nil && err != http.ErrServerClosed {
+			panic(err) // 'go vet' complains about 't.Fatal(err)'
+		}
+	}()
+	return s
+}

+ 3 - 2
client/config_test.go

@@ -1,7 +1,8 @@
-package client
+package client_test
 
 import (
 	"github.com/stretchr/testify/require"
+	"heckel.io/ntfy/client"
 	"os"
 	"path/filepath"
 	"testing"
@@ -21,7 +22,7 @@ subscribe:
             priority: high,urgent
 `), 0600))
 
-	conf, err := LoadConfig(filename)
+	conf, err := client.LoadConfig(filename)
 	require.Nil(t, err)
 	require.Equal(t, "http://localhost", conf.DefaultHost)
 	require.Equal(t, 3, len(conf.Subscribe))

+ 2 - 1
cmd/serve.go

@@ -85,7 +85,8 @@ func execServe(c *cli.Context) error {
 	}
 
 	// Run server
-	conf := server.NewConfig(listenHTTP)
+	conf := server.NewConfig()
+	conf.ListenHTTP = listenHTTP
 	conf.ListenHTTPS = listenHTTPS
 	conf.KeyFile = keyFile
 	conf.CertFile = certFile

+ 2 - 2
server/config.go

@@ -52,9 +52,9 @@ type Config struct {
 }
 
 // NewConfig instantiates a default new server config
-func NewConfig(listenHTTP string) *Config {
+func NewConfig() *Config {
 	return &Config{
-		ListenHTTP:                   listenHTTP,
+		ListenHTTP:                   DefaultListenHTTP,
 		ListenHTTPS:                  "",
 		KeyFile:                      "",
 		CertFile:                     "",

+ 3 - 2
server/config_test.go

@@ -7,6 +7,7 @@ import (
 )
 
 func TestConfig_New(t *testing.T) {
-	c := server.NewConfig(":1234")
-	assert.Equal(t, ":1234", c.ListenHTTP)
+	c := server.NewConfig()
+	assert.Equal(t, ":80", c.ListenHTTP)
+	assert.Equal(t, server.DefaultKeepaliveInterval, c.KeepaliveInterval)
 }

+ 51 - 22
server/server.go

@@ -27,13 +27,16 @@ import (
 
 // Server is the main server, providing the UI and API for ntfy
 type Server struct {
-	config   *Config
-	topics   map[string]*topic
-	visitors map[string]*visitor
-	firebase subscriber
-	messages int64
-	cache    cache
-	mu       sync.Mutex
+	config      *Config
+	httpServer  *http.Server
+	httpsServer *http.Server
+	topics      map[string]*topic
+	visitors    map[string]*visitor
+	firebase    subscriber
+	messages    int64
+	cache       cache
+	closeChan   chan bool
+	mu          sync.Mutex
 }
 
 // errHTTP is a generic HTTP error for any non-200 HTTP error
@@ -198,17 +201,35 @@ func (s *Server) Run() error {
 	log.Printf("Listening on %s", listenStr)
 	http.HandleFunc("/", s.handle)
 	errChan := make(chan error)
+	s.mu.Lock()
+	s.closeChan = make(chan bool)
+	s.httpServer = &http.Server{Addr: s.config.ListenHTTP}
 	go func() {
-		errChan <- http.ListenAndServe(s.config.ListenHTTP, nil)
+		errChan <- s.httpServer.ListenAndServe()
 	}()
 	if s.config.ListenHTTPS != "" {
+		s.httpsServer = &http.Server{Addr: s.config.ListenHTTP}
 		go func() {
-			errChan <- http.ListenAndServeTLS(s.config.ListenHTTPS, s.config.CertFile, s.config.KeyFile, nil)
+			errChan <- s.httpsServer.ListenAndServeTLS(s.config.CertFile, s.config.KeyFile)
 		}()
 	}
+	s.mu.Unlock()
 	return <-errChan
 }
 
+// Stop stops HTTP (+HTTPS) server and all managers
+func (s *Server) Stop() {
+	s.mu.Lock()
+	defer s.mu.Unlock()
+	if s.httpServer != nil {
+		s.httpServer.Close()
+	}
+	if s.httpsServer != nil {
+		s.httpsServer.Close()
+	}
+	close(s.closeChan)
+}
+
 func (s *Server) handle(w http.ResponseWriter, r *http.Request) {
 	if err := s.handleInternal(w, r); err != nil {
 		if e, ok := err.(*errHTTP); ok {
@@ -635,21 +656,25 @@ func (s *Server) updateStatsAndPrune() {
 }
 
 func (s *Server) runManager() {
-	func() {
-		ticker := time.NewTicker(s.config.ManagerInterval)
-		for {
-			<-ticker.C
+	for {
+		select {
+		case <-time.After(s.config.ManagerInterval):
 			s.updateStatsAndPrune()
+		case <-s.closeChan:
+			return
 		}
-	}()
+	}
 }
 
 func (s *Server) runAtSender() {
-	ticker := time.NewTicker(s.config.AtSenderInterval)
 	for {
-		<-ticker.C
-		if err := s.sendDelayedMessages(); err != nil {
-			log.Printf("error sending scheduled messages: %s", err.Error())
+		select {
+		case <-time.After(s.config.AtSenderInterval):
+			if err := s.sendDelayedMessages(); err != nil {
+				log.Printf("error sending scheduled messages: %s", err.Error())
+			}
+		case <-s.closeChan:
+			return
 		}
 	}
 }
@@ -658,14 +683,18 @@ func (s *Server) runFirebaseKeepliver() {
 	if s.firebase == nil {
 		return
 	}
-	ticker := time.NewTicker(s.config.FirebaseKeepaliveInterval)
 	for {
-		<-ticker.C
-		if err := s.firebase(newKeepaliveMessage(firebaseControlTopic)); err != nil {
-			log.Printf("error sending Firebase keepalive message: %s", err.Error())
+		select {
+		case <-time.After(s.config.FirebaseKeepaliveInterval):
+			if err := s.firebase(newKeepaliveMessage(firebaseControlTopic)); err != nil {
+				log.Printf("error sending Firebase keepalive message: %s", err.Error())
+			}
+		case <-s.closeChan:
+			return
 		}
 	}
 }
+
 func (s *Server) sendDelayedMessages() error {
 	s.mu.Lock()
 	defer s.mu.Unlock()

+ 1 - 1
server/server_test.go

@@ -488,7 +488,7 @@ func TestServer_SubscribeWithQueryFilters(t *testing.T) {
 }
 
 func newTestConfig(t *testing.T) *Config {
-	conf := NewConfig(":80")
+	conf := NewConfig()
 	conf.CacheFile = filepath.Join(t.TempDir(), "cache.db")
 	return conf
 }