binwiederhier 3 роки тому
батько
коміт
00af52411c

+ 4 - 4
docs/config.md

@@ -504,7 +504,7 @@ or the root domain:
         proxy_send_timeout 3m;
         proxy_read_timeout 3m;
 
-        client_max_body_size 20m; # Must be >= attachment-file-size-limit in /etc/ntfy/server.yml
+        client_max_body_size 0; # Stream request body to backend
       }
     }
     
@@ -540,7 +540,7 @@ or the root domain:
         proxy_send_timeout 3m;
         proxy_read_timeout 3m;
         
-        client_max_body_size 20m; # Must be >= attachment-file-size-limit in /etc/ntfy/server.yml
+        client_max_body_size 0; # Stream request body to backend
       }
     }
     ```
@@ -571,7 +571,7 @@ or the root domain:
         proxy_send_timeout 3m;
         proxy_read_timeout 3m;
 
-        client_max_body_size 20m; # Must be >= attachment-file-size-limit in /etc/ntfy/server.yml
+        client_max_body_size 0; # Stream request body to backend
       }
     }
     
@@ -603,7 +603,7 @@ or the root domain:
         proxy_send_timeout 3m;
         proxy_read_timeout 3m;
 
-        client_max_body_size 20m; # Must be >= attachment-file-size-limit in /etc/ntfy/server.yml
+        client_max_body_size 0; # Stream request body to backend
       }
     }
     ```

+ 4 - 4
server/server.go

@@ -38,7 +38,6 @@ import (
 
 - HIGH Rate limiting: Sensitive endpoints (account/login/change-password/...)
 - HIGH Docs
-- Large uploads for higher tiers (nginx config!)
 - MEDIUM: Test new token endpoints & never-expiring token
 - MEDIUM: Make sure account endpoints make sense for admins
 - MEDIUM: Reservation (UI): Show "This topic is reserved" error message when trying to reserve a reserved topic (Thorben)
@@ -1641,7 +1640,7 @@ func (s *Server) authenticate(r *http.Request) (user *user.User, err error) {
 		return nil, errHTTPUnauthorized
 	}
 	if strings.HasPrefix(value, "Bearer") {
-		return s.authenticateBearerAuth(r, value)
+		return s.authenticateBearerAuth(r, strings.TrimSpace(strings.TrimPrefix(value, "Bearer")))
 	}
 	return s.authenticateBasicAuth(r, value)
 }
@@ -1651,12 +1650,13 @@ func (s *Server) authenticateBasicAuth(r *http.Request, value string) (user *use
 	username, password, ok := r.BasicAuth()
 	if !ok {
 		return nil, errors.New("invalid basic auth")
+	} else if username == "" {
+		return s.authenticateBearerAuth(r, password) // Treat password as token
 	}
 	return s.userManager.Authenticate(username, password)
 }
 
-func (s *Server) authenticateBearerAuth(r *http.Request, value string) (*user.User, error) {
-	token := strings.TrimSpace(strings.TrimPrefix(value, "Bearer"))
+func (s *Server) authenticateBearerAuth(r *http.Request, token string) (*user.User, error) {
 	u, err := s.userManager.AuthenticateToken(token)
 	if err != nil {
 		return nil, err

+ 19 - 1
server/server_account_test.go

@@ -41,6 +41,13 @@ func TestAccount_Signup_Success(t *testing.T) {
 	account, _ := util.UnmarshalJSON[apiAccountResponse](io.NopCloser(rr.Body))
 	require.Equal(t, "phil", account.Username)
 	require.Equal(t, "user", account.Role)
+
+	rr = request(t, s, "GET", "/v1/account", "", map[string]string{
+		"Authorization": util.BasicAuth("", token.Token), // We allow a fake basic auth to make curl-ing easier (curl -u :<token>)
+	})
+	require.Equal(t, 200, rr.Code)
+	account, _ = util.UnmarshalJSON[apiAccountResponse](io.NopCloser(rr.Body))
+	require.Equal(t, "phil", account.Username)
 }
 
 func TestAccount_Signup_UserExists(t *testing.T) {
@@ -247,7 +254,18 @@ func TestAccount_ChangePassword(t *testing.T) {
 
 	require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser))
 
-	rr := request(t, s, "POST", "/v1/account/password", `{"password": "phil", "new_password": "new password"}`, map[string]string{
+	rr := request(t, s, "POST", "/v1/account/password", `{"password": "WRONG", "new_password": ""}`, map[string]string{
+		"Authorization": util.BasicAuth("phil", "phil"),
+	})
+	require.Equal(t, 400, rr.Code)
+
+	rr = request(t, s, "POST", "/v1/account/password", `{"password": "WRONG", "new_password": "new password"}`, map[string]string{
+		"Authorization": util.BasicAuth("phil", "phil"),
+	})
+	require.Equal(t, 400, rr.Code)
+	require.Equal(t, 40030, toHTTPError(t, rr.Body.String()).Code)
+
+	rr = request(t, s, "POST", "/v1/account/password", `{"password": "phil", "new_password": "new password"}`, map[string]string{
 		"Authorization": util.BasicAuth("phil", "phil"),
 	})
 	require.Equal(t, 200, rr.Code)

+ 2 - 0
server/server_payments.go

@@ -229,6 +229,8 @@ func (s *Server) handleAccountBillingSubscriptionUpdate(w http.ResponseWriter, r
 	sub, err := s.stripe.GetSubscription(u.Billing.StripeSubscriptionID)
 	if err != nil {
 		return err
+	} else if sub.Items == nil || len(sub.Items.Data) != 1 {
+		return wrapErrHTTP(errHTTPBadRequestBillingRequestInvalid, "no items, or more than one item")
 	}
 	params := &stripe.SubscriptionParams{
 		CancelAtPeriodEnd: stripe.Bool(false),

+ 139 - 3
server/server_payments_test.go

@@ -304,7 +304,14 @@ func TestPayments_Checkout_Success_And_Increase_Rate_Limits_Reset_Visitor(t *tes
 			},
 		}, nil)
 	stripeMock.
-		On("UpdateCustomer", mock.Anything).
+		On("UpdateCustomer", "acct_5555", &stripe.CustomerParams{
+			Params: stripe.Params{
+				Metadata: map[string]string{
+					"user_id":   u.ID,
+					"user_name": u.Name,
+				},
+			},
+		}).
 		Return(&stripe.Customer{}, nil)
 
 	// Send messages until rate limit of free tier is hit
@@ -517,6 +524,135 @@ func TestPayments_Webhook_Subscription_Updated_Downgrade_From_PastDue_To_Active(
 	require.NoFileExists(t, filepath.Join(s.config.AttachmentCacheDir, z2.ID))
 }
 
+func TestPayments_Subscription_Update_Different_Tier(t *testing.T) {
+	stripeMock := &testStripeAPI{}
+	defer stripeMock.AssertExpectations(t)
+
+	c := newTestConfigWithAuthFile(t)
+	c.StripeSecretKey = "secret key"
+	c.StripeWebhookKey = "webhook key"
+	s := newTestServer(t, c)
+	s.stripe = stripeMock
+
+	// Define how the mock should react
+	stripeMock.
+		On("GetSubscription", "sub_123").
+		Return(&stripe.Subscription{
+			ID: "sub_123",
+			Items: &stripe.SubscriptionItemList{
+				Data: []*stripe.SubscriptionItem{
+					{
+						ID:    "someid_123",
+						Price: &stripe.Price{ID: "price_123"},
+					},
+				},
+			},
+		}, nil)
+	stripeMock.
+		On("UpdateSubscription", "sub_123", &stripe.SubscriptionParams{
+			CancelAtPeriodEnd: stripe.Bool(false),
+			ProrationBehavior: stripe.String(string(stripe.SubscriptionSchedulePhaseProrationBehaviorCreateProrations)),
+			Items: []*stripe.SubscriptionItemsParams{
+				{
+					ID:    stripe.String("someid_123"),
+					Price: stripe.String("price_456"),
+				},
+			},
+		}).
+		Return(&stripe.Subscription{}, nil)
+
+	// Create tier and user
+	require.Nil(t, s.userManager.CreateTier(&user.Tier{
+		ID:            "ti_123",
+		Code:          "pro",
+		StripePriceID: "price_123",
+	}))
+	require.Nil(t, s.userManager.CreateTier(&user.Tier{
+		ID:            "ti_456",
+		Code:          "business",
+		StripePriceID: "price_456",
+	}))
+	require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser))
+	require.Nil(t, s.userManager.ChangeTier("phil", "pro"))
+	require.Nil(t, s.userManager.ChangeBilling("phil", &user.Billing{
+		StripeCustomerID:     "acct_123",
+		StripeSubscriptionID: "sub_123",
+	}))
+
+	// Call endpoint to change subscription
+	rr := request(t, s, "PUT", "/v1/account/billing/subscription", `{"tier":"business"}`, map[string]string{
+		"Authorization": util.BasicAuth("phil", "phil"),
+	})
+	require.Equal(t, 200, rr.Code)
+}
+
+func TestPayments_Subscription_Delete_At_Period_End(t *testing.T) {
+	stripeMock := &testStripeAPI{}
+	defer stripeMock.AssertExpectations(t)
+
+	c := newTestConfigWithAuthFile(t)
+	c.StripeSecretKey = "secret key"
+	c.StripeWebhookKey = "webhook key"
+	s := newTestServer(t, c)
+	s.stripe = stripeMock
+
+	// Define how the mock should react
+	stripeMock.
+		On("UpdateSubscription", "sub_123", mock.MatchedBy(func(s *stripe.SubscriptionParams) bool {
+			return *s.CancelAtPeriodEnd // Is true
+		})).
+		Return(&stripe.Subscription{}, nil)
+
+	// Create user
+	require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser))
+	require.Nil(t, s.userManager.ChangeBilling("phil", &user.Billing{
+		StripeCustomerID:     "acct_123",
+		StripeSubscriptionID: "sub_123",
+	}))
+
+	// Delete subscription
+	rr := request(t, s, "DELETE", "/v1/account/billing/subscription", "", map[string]string{
+		"Authorization": util.BasicAuth("phil", "phil"),
+	})
+	require.Equal(t, 200, rr.Code)
+}
+
+func TestPayments_CreatePortalSession(t *testing.T) {
+	stripeMock := &testStripeAPI{}
+	defer stripeMock.AssertExpectations(t)
+
+	c := newTestConfigWithAuthFile(t)
+	c.StripeSecretKey = "secret key"
+	c.StripeWebhookKey = "webhook key"
+	s := newTestServer(t, c)
+	s.stripe = stripeMock
+
+	// Define how the mock should react
+	stripeMock.
+		On("NewPortalSession", &stripe.BillingPortalSessionParams{
+			Customer:  stripe.String("acct_123"),
+			ReturnURL: stripe.String(s.config.BaseURL),
+		}).
+		Return(&stripe.BillingPortalSession{
+			URL: "https://billing.stripe.com/blablabla",
+		}, nil)
+
+	// Create user
+	require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser))
+	require.Nil(t, s.userManager.ChangeBilling("phil", &user.Billing{
+		StripeCustomerID:     "acct_123",
+		StripeSubscriptionID: "sub_123",
+	}))
+
+	// Create portal session
+	rr := request(t, s, "POST", "/v1/account/billing/portal", "", map[string]string{
+		"Authorization": util.BasicAuth("phil", "phil"),
+	})
+	require.Equal(t, 200, rr.Code)
+	ps, _ := util.UnmarshalJSON[apiAccountBillingPortalRedirectResponse](io.NopCloser(rr.Body))
+	require.Equal(t, "https://billing.stripe.com/blablabla", ps.RedirectURL)
+}
+
 type testStripeAPI struct {
 	mock.Mock
 }
@@ -554,12 +690,12 @@ func (s *testStripeAPI) GetSubscription(id string) (*stripe.Subscription, error)
 }
 
 func (s *testStripeAPI) UpdateCustomer(id string, params *stripe.CustomerParams) (*stripe.Customer, error) {
-	args := s.Called(id)
+	args := s.Called(id, params)
 	return args.Get(0).(*stripe.Customer), args.Error(1)
 }
 
 func (s *testStripeAPI) UpdateSubscription(id string, params *stripe.SubscriptionParams) (*stripe.Subscription, error) {
-	args := s.Called(id)
+	args := s.Called(id, params)
 	return args.Get(0).(*stripe.Subscription), args.Error(1)
 }