|
|
@@ -8,7 +8,6 @@ import (
|
|
|
"fmt"
|
|
|
"heckel.io/ntfy/user"
|
|
|
"io"
|
|
|
- "log"
|
|
|
"math/rand"
|
|
|
"net/http"
|
|
|
"net/http/httptest"
|
|
|
@@ -22,9 +21,14 @@ import (
|
|
|
"github.com/stretchr/testify/assert"
|
|
|
|
|
|
"github.com/stretchr/testify/require"
|
|
|
+ "heckel.io/ntfy/log"
|
|
|
"heckel.io/ntfy/util"
|
|
|
)
|
|
|
|
|
|
+func init() {
|
|
|
+ // log.SetLevel(log.DebugLevel)
|
|
|
+}
|
|
|
+
|
|
|
func TestServer_PublishAndPoll(t *testing.T) {
|
|
|
s := newTestServer(t, newTestConfig(t))
|
|
|
|
|
|
@@ -742,16 +746,31 @@ func TestServer_Auth_ViaQuery(t *testing.T) {
|
|
|
require.Equal(t, 401, response.Code)
|
|
|
}
|
|
|
|
|
|
-func TestServer_StatsResetter(t *testing.T) {
|
|
|
+func TestServer_StatsResetter_User_Without_Tier(t *testing.T) {
|
|
|
+ // This tests the stats resetter for
|
|
|
+ // - an anonymous user
|
|
|
+ // - a user without a tier (treated like the same as the anonymous user)
|
|
|
+ // - a user with a tier
|
|
|
+
|
|
|
c := newTestConfigWithAuthFile(t)
|
|
|
- c.AuthDefault = user.PermissionDenyAll
|
|
|
c.VisitorStatsResetTime = time.Now().Add(2 * time.Second)
|
|
|
s := newTestServer(t, c)
|
|
|
go s.runStatsResetter()
|
|
|
|
|
|
+ // Create user with tier (tieruser) and user without tier (phil)
|
|
|
+ require.Nil(t, s.userManager.CreateTier(&user.Tier{
|
|
|
+ Code: "test",
|
|
|
+ MessageLimit: 5,
|
|
|
+ MessageExpiryDuration: -5 * time.Second, // Second, what a hack!
|
|
|
+ }))
|
|
|
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser))
|
|
|
- require.Nil(t, s.userManager.AllowAccess("phil", "mytopic", user.PermissionReadWrite))
|
|
|
+ require.Nil(t, s.userManager.AddUser("tieruser", "tieruser", user.RoleUser))
|
|
|
+ require.Nil(t, s.userManager.ChangeTier("tieruser", "test"))
|
|
|
+
|
|
|
+ // Send an anonymous message
|
|
|
+ response := request(t, s, "PUT", "/mytopic", "test", nil)
|
|
|
|
|
|
+ // Send messages from user without tier (phil)
|
|
|
for i := 0; i < 5; i++ {
|
|
|
response := request(t, s, "PUT", "/mytopic", "test", map[string]string{
|
|
|
"Authorization": util.BasicAuth("phil", "phil"),
|
|
|
@@ -759,30 +778,66 @@ func TestServer_StatsResetter(t *testing.T) {
|
|
|
require.Equal(t, 200, response.Code)
|
|
|
}
|
|
|
|
|
|
- response := request(t, s, "GET", "/v1/account", "", map[string]string{
|
|
|
+ // Send messages from user with tier
|
|
|
+ for i := 0; i < 2; i++ {
|
|
|
+ response := request(t, s, "PUT", "/mytopic", "test", map[string]string{
|
|
|
+ "Authorization": util.BasicAuth("tieruser", "tieruser"),
|
|
|
+ })
|
|
|
+ require.Equal(t, 200, response.Code)
|
|
|
+ }
|
|
|
+
|
|
|
+ // User stats show 6 messages (for user without tier)
|
|
|
+ response = request(t, s, "GET", "/v1/account", "", map[string]string{
|
|
|
"Authorization": util.BasicAuth("phil", "phil"),
|
|
|
})
|
|
|
require.Equal(t, 200, response.Code)
|
|
|
+ account, err := util.UnmarshalJSON[apiAccountResponse](io.NopCloser(response.Body))
|
|
|
+ require.Nil(t, err)
|
|
|
+ require.Equal(t, int64(6), account.Stats.Messages)
|
|
|
+
|
|
|
+ // User stats show 6 messages (for anonymous visitor)
|
|
|
+ response = request(t, s, "GET", "/v1/account", "", nil)
|
|
|
+ require.Equal(t, 200, response.Code)
|
|
|
+ account, err = util.UnmarshalJSON[apiAccountResponse](io.NopCloser(response.Body))
|
|
|
+ require.Nil(t, err)
|
|
|
+ require.Equal(t, int64(6), account.Stats.Messages)
|
|
|
|
|
|
- // User stats show 10 messages
|
|
|
+ // User stats show 2 messages (for user with tier)
|
|
|
response = request(t, s, "GET", "/v1/account", "", map[string]string{
|
|
|
- "Authorization": util.BasicAuth("phil", "phil"),
|
|
|
+ "Authorization": util.BasicAuth("tieruser", "tieruser"),
|
|
|
})
|
|
|
require.Equal(t, 200, response.Code)
|
|
|
- account, err := util.UnmarshalJSON[apiAccountResponse](io.NopCloser(response.Body))
|
|
|
+ account, err = util.UnmarshalJSON[apiAccountResponse](io.NopCloser(response.Body))
|
|
|
require.Nil(t, err)
|
|
|
- require.Equal(t, int64(5), account.Stats.Messages)
|
|
|
+ require.Equal(t, int64(2), account.Stats.Messages)
|
|
|
|
|
|
// Wait for stats resetter to run
|
|
|
time.Sleep(2200 * time.Millisecond)
|
|
|
|
|
|
// User stats show 0 messages now!
|
|
|
+ response = request(t, s, "GET", "/v1/account", "", map[string]string{
|
|
|
+ "Authorization": util.BasicAuth("phil", "phil"),
|
|
|
+ })
|
|
|
+ require.Equal(t, 200, response.Code)
|
|
|
+ account, err = util.UnmarshalJSON[apiAccountResponse](io.NopCloser(response.Body))
|
|
|
+ require.Nil(t, err)
|
|
|
+ require.Equal(t, int64(0), account.Stats.Messages)
|
|
|
+
|
|
|
+ // Since this is a user without a tier, the anonymous user should have the same stats
|
|
|
response = request(t, s, "GET", "/v1/account", "", nil)
|
|
|
require.Equal(t, 200, response.Code)
|
|
|
account, err = util.UnmarshalJSON[apiAccountResponse](io.NopCloser(response.Body))
|
|
|
require.Nil(t, err)
|
|
|
require.Equal(t, int64(0), account.Stats.Messages)
|
|
|
|
|
|
+ // User stats show 0 messages (for user with tier)
|
|
|
+ response = request(t, s, "GET", "/v1/account", "", map[string]string{
|
|
|
+ "Authorization": util.BasicAuth("tieruser", "tieruser"),
|
|
|
+ })
|
|
|
+ require.Equal(t, 200, response.Code)
|
|
|
+ account, err = util.UnmarshalJSON[apiAccountResponse](io.NopCloser(response.Body))
|
|
|
+ require.Nil(t, err)
|
|
|
+ require.Equal(t, int64(0), account.Stats.Messages)
|
|
|
}
|
|
|
|
|
|
type testMailer struct {
|
|
|
@@ -1133,9 +1188,9 @@ func TestServer_PublishWithTierBasedMessageLimitAndExpiry(t *testing.T) {
|
|
|
|
|
|
// Create tier with certain limits
|
|
|
require.Nil(t, s.userManager.CreateTier(&user.Tier{
|
|
|
- Code: "test",
|
|
|
- MessagesLimit: 5,
|
|
|
- MessagesExpiryDuration: -5 * time.Second, // Second, what a hack!
|
|
|
+ Code: "test",
|
|
|
+ MessageLimit: 5,
|
|
|
+ MessageExpiryDuration: -5 * time.Second, // Second, what a hack!
|
|
|
}))
|
|
|
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser))
|
|
|
require.Nil(t, s.userManager.ChangeTier("phil", "test"))
|
|
|
@@ -1363,8 +1418,8 @@ func TestServer_PublishAttachmentWithTierBasedExpiry(t *testing.T) {
|
|
|
sevenDays := time.Duration(604800) * time.Second
|
|
|
require.Nil(t, s.userManager.CreateTier(&user.Tier{
|
|
|
Code: "test",
|
|
|
- MessagesLimit: 10,
|
|
|
- MessagesExpiryDuration: sevenDays,
|
|
|
+ MessageLimit: 10,
|
|
|
+ MessageExpiryDuration: sevenDays,
|
|
|
AttachmentFileSizeLimit: 50_000,
|
|
|
AttachmentTotalSizeLimit: 200_000,
|
|
|
AttachmentExpiryDuration: sevenDays, // 7 days
|
|
|
@@ -1407,8 +1462,8 @@ func TestServer_PublishAttachmentWithTierBasedBandwidthLimit(t *testing.T) {
|
|
|
// Create tier with certain limits
|
|
|
require.Nil(t, s.userManager.CreateTier(&user.Tier{
|
|
|
Code: "test",
|
|
|
- MessagesLimit: 10,
|
|
|
- MessagesExpiryDuration: time.Hour,
|
|
|
+ MessageLimit: 10,
|
|
|
+ MessageExpiryDuration: time.Hour,
|
|
|
AttachmentFileSizeLimit: 50_000,
|
|
|
AttachmentTotalSizeLimit: 200_000,
|
|
|
AttachmentExpiryDuration: time.Hour,
|
|
|
@@ -1450,7 +1505,7 @@ func TestServer_PublishAttachmentWithTierBasedLimits(t *testing.T) {
|
|
|
// Create tier with certain limits
|
|
|
require.Nil(t, s.userManager.CreateTier(&user.Tier{
|
|
|
Code: "test",
|
|
|
- MessagesLimit: 100,
|
|
|
+ MessageLimit: 100,
|
|
|
AttachmentFileSizeLimit: 50_000,
|
|
|
AttachmentTotalSizeLimit: 200_000,
|
|
|
AttachmentExpiryDuration: 30 * time.Second,
|
|
|
@@ -1574,7 +1629,7 @@ func TestServer_Visitor_XForwardedFor_None(t *testing.T) {
|
|
|
r, _ := http.NewRequest("GET", "/bla", nil)
|
|
|
r.RemoteAddr = "8.9.10.11"
|
|
|
r.Header.Set("X-Forwarded-For", " ") // Spaces, not empty!
|
|
|
- v, err := s.visitor(r)
|
|
|
+ v, err := s.maybeAuthenticate(r)
|
|
|
require.Nil(t, err)
|
|
|
require.Equal(t, "8.9.10.11", v.ip.String())
|
|
|
}
|
|
|
@@ -1586,7 +1641,7 @@ func TestServer_Visitor_XForwardedFor_Single(t *testing.T) {
|
|
|
r, _ := http.NewRequest("GET", "/bla", nil)
|
|
|
r.RemoteAddr = "8.9.10.11"
|
|
|
r.Header.Set("X-Forwarded-For", "1.1.1.1")
|
|
|
- v, err := s.visitor(r)
|
|
|
+ v, err := s.maybeAuthenticate(r)
|
|
|
require.Nil(t, err)
|
|
|
require.Equal(t, "1.1.1.1", v.ip.String())
|
|
|
}
|
|
|
@@ -1598,7 +1653,7 @@ func TestServer_Visitor_XForwardedFor_Multiple(t *testing.T) {
|
|
|
r, _ := http.NewRequest("GET", "/bla", nil)
|
|
|
r.RemoteAddr = "8.9.10.11"
|
|
|
r.Header.Set("X-Forwarded-For", "1.2.3.4 , 2.4.4.2,234.5.2.1 ")
|
|
|
- v, err := s.visitor(r)
|
|
|
+ v, err := s.maybeAuthenticate(r)
|
|
|
require.Nil(t, err)
|
|
|
require.Equal(t, "234.5.2.1", v.ip.String())
|
|
|
}
|
|
|
@@ -1611,7 +1666,7 @@ func TestServer_PublishWhileUpdatingStatsWithLotsOfMessages(t *testing.T) {
|
|
|
s := newTestServer(t, c)
|
|
|
|
|
|
// Add lots of messages
|
|
|
- log.Printf("Adding %d messages", count)
|
|
|
+ log.Info("Adding %d messages", count)
|
|
|
start := time.Now()
|
|
|
messages := make([]*message, 0)
|
|
|
for i := 0; i < count; i++ {
|
|
|
@@ -1621,31 +1676,31 @@ func TestServer_PublishWhileUpdatingStatsWithLotsOfMessages(t *testing.T) {
|
|
|
messages = append(messages, newDefaultMessage(topicID, "some message"))
|
|
|
}
|
|
|
require.Nil(t, s.messageCache.addMessages(messages))
|
|
|
- log.Printf("Done: Adding %d messages; took %s", count, time.Since(start).Round(time.Millisecond))
|
|
|
+ log.Info("Done: Adding %d messages; took %s", count, time.Since(start).Round(time.Millisecond))
|
|
|
|
|
|
// Update stats
|
|
|
statsChan := make(chan bool)
|
|
|
go func() {
|
|
|
- log.Printf("Updating stats")
|
|
|
+ log.Info("Updating stats")
|
|
|
start := time.Now()
|
|
|
s.execManager()
|
|
|
- log.Printf("Done: Updating stats; took %s", time.Since(start).Round(time.Millisecond))
|
|
|
+ log.Info("Done: Updating stats; took %s", time.Since(start).Round(time.Millisecond))
|
|
|
statsChan <- true
|
|
|
}()
|
|
|
time.Sleep(50 * time.Millisecond) // Make sure it starts first
|
|
|
|
|
|
// Publish message (during stats update)
|
|
|
- log.Printf("Publishing message")
|
|
|
+ log.Info("Publishing message")
|
|
|
start = time.Now()
|
|
|
response := request(t, s, "PUT", "/mytopic", "some body", nil)
|
|
|
m := toMessage(t, response.Body.String())
|
|
|
assert.Equal(t, "some body", m.Message)
|
|
|
assert.True(t, time.Since(start) < 100*time.Millisecond)
|
|
|
- log.Printf("Done: Publishing message; took %s", time.Since(start).Round(time.Millisecond))
|
|
|
+ log.Info("Done: Publishing message; took %s", time.Since(start).Round(time.Millisecond))
|
|
|
|
|
|
// Wait for all goroutines
|
|
|
<-statsChan
|
|
|
- log.Printf("Done: Waiting for all locks")
|
|
|
+ log.Info("Done: Waiting for all locks")
|
|
|
}
|
|
|
|
|
|
func newTestConfig(t *testing.T) *Config {
|