| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444 |
- package server
- import (
- "encoding/json"
- "github.com/stretchr/testify/mock"
- "github.com/stretchr/testify/require"
- "github.com/stripe/stripe-go/v74"
- "heckel.io/ntfy/user"
- "heckel.io/ntfy/util"
- "io"
- "path/filepath"
- "strings"
- "testing"
- "time"
- )
- func TestPayments_Tiers(t *testing.T) {
- stripeMock := &testStripeAPI{}
- defer stripeMock.AssertExpectations(t)
- c := newTestConfigWithAuthFile(t)
- c.StripeSecretKey = "secret key"
- c.StripeWebhookKey = "webhook key"
- c.VisitorRequestLimitReplenish = 12 * time.Hour
- c.CacheDuration = 13 * time.Hour
- c.AttachmentFileSizeLimit = 111
- c.VisitorAttachmentTotalSizeLimit = 222
- c.AttachmentExpiryDuration = 123 * time.Second
- s := newTestServer(t, c)
- s.stripe = stripeMock
- // Define how the mock should react
- stripeMock.
- On("ListPrices", mock.Anything).
- Return([]*stripe.Price{
- {ID: "price_123", UnitAmount: 500},
- {ID: "price_456", UnitAmount: 1000},
- {ID: "price_999", UnitAmount: 9999},
- }, nil)
- // Create tiers
- require.Nil(t, s.userManager.CreateTier(&user.Tier{
- ID: "ti_1",
- Code: "admin",
- Name: "Admin",
- }))
- require.Nil(t, s.userManager.CreateTier(&user.Tier{
- ID: "ti_123",
- Code: "pro",
- Name: "Pro",
- MessagesLimit: 1000,
- MessagesExpiryDuration: time.Hour,
- EmailsLimit: 123,
- ReservationsLimit: 777,
- AttachmentFileSizeLimit: 999,
- AttachmentTotalSizeLimit: 888,
- AttachmentExpiryDuration: time.Minute,
- StripePriceID: "price_123",
- }))
- require.Nil(t, s.userManager.CreateTier(&user.Tier{
- ID: "ti_444",
- Code: "business",
- Name: "Business",
- MessagesLimit: 2000,
- MessagesExpiryDuration: 10 * time.Hour,
- EmailsLimit: 123123,
- ReservationsLimit: 777333,
- AttachmentFileSizeLimit: 999111,
- AttachmentTotalSizeLimit: 888111,
- AttachmentExpiryDuration: time.Hour,
- StripePriceID: "price_456",
- }))
- response := request(t, s, "GET", "/v1/tiers", "", nil)
- require.Equal(t, 200, response.Code)
- var tiers []apiAccountBillingTier
- require.Nil(t, json.NewDecoder(response.Body).Decode(&tiers))
- require.Equal(t, 3, len(tiers))
- // Free tier
- tier := tiers[0]
- require.Equal(t, "", tier.Code)
- require.Equal(t, "", tier.Name)
- require.Equal(t, "ip", tier.Limits.Basis)
- require.Equal(t, int64(0), tier.Limits.Reservations)
- require.Equal(t, int64(2), tier.Limits.Messages) // :-(
- require.Equal(t, int64(13*3600), tier.Limits.MessagesExpiryDuration)
- require.Equal(t, int64(24), tier.Limits.Emails)
- require.Equal(t, int64(111), tier.Limits.AttachmentFileSize)
- require.Equal(t, int64(222), tier.Limits.AttachmentTotalSize)
- require.Equal(t, int64(123), tier.Limits.AttachmentExpiryDuration)
- // Admin tier is not included, because it is not paid!
- tier = tiers[1]
- require.Equal(t, "pro", tier.Code)
- require.Equal(t, "Pro", tier.Name)
- require.Equal(t, "tier", tier.Limits.Basis)
- require.Equal(t, int64(777), tier.Limits.Reservations)
- require.Equal(t, int64(1000), tier.Limits.Messages)
- require.Equal(t, int64(3600), tier.Limits.MessagesExpiryDuration)
- require.Equal(t, int64(123), tier.Limits.Emails)
- require.Equal(t, int64(999), tier.Limits.AttachmentFileSize)
- require.Equal(t, int64(888), tier.Limits.AttachmentTotalSize)
- require.Equal(t, int64(60), tier.Limits.AttachmentExpiryDuration)
- tier = tiers[2]
- require.Equal(t, "business", tier.Code)
- require.Equal(t, "Business", tier.Name)
- require.Equal(t, "tier", tier.Limits.Basis)
- require.Equal(t, int64(777333), tier.Limits.Reservations)
- require.Equal(t, int64(2000), tier.Limits.Messages)
- require.Equal(t, int64(36000), tier.Limits.MessagesExpiryDuration)
- require.Equal(t, int64(123123), tier.Limits.Emails)
- require.Equal(t, int64(999111), tier.Limits.AttachmentFileSize)
- require.Equal(t, int64(888111), tier.Limits.AttachmentTotalSize)
- require.Equal(t, int64(3600), tier.Limits.AttachmentExpiryDuration)
- }
- func TestPayments_SubscriptionCreate_NotAStripeCustomer_Success(t *testing.T) {
- stripeMock := &testStripeAPI{}
- defer stripeMock.AssertExpectations(t)
- c := newTestConfigWithAuthFile(t)
- c.StripeSecretKey = "secret key"
- c.StripeWebhookKey = "webhook key"
- s := newTestServer(t, c)
- s.stripe = stripeMock
- // Define how the mock should react
- stripeMock.
- On("NewCheckoutSession", mock.Anything).
- Return(&stripe.CheckoutSession{URL: "https://billing.stripe.com/abc/def"}, nil)
- // Create tier and user
- require.Nil(t, s.userManager.CreateTier(&user.Tier{
- Code: "pro",
- StripePriceID: "price_123",
- }))
- require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser))
- // Create subscription
- response := request(t, s, "POST", "/v1/account/billing/subscription", `{"tier": "pro"}`, map[string]string{
- "Authorization": util.BasicAuth("phil", "phil"),
- })
- require.Equal(t, 200, response.Code)
- redirectResponse, err := util.UnmarshalJSON[apiAccountBillingSubscriptionCreateResponse](io.NopCloser(response.Body))
- require.Nil(t, err)
- require.Equal(t, "https://billing.stripe.com/abc/def", redirectResponse.RedirectURL)
- }
- func TestPayments_SubscriptionCreate_StripeCustomer_Success(t *testing.T) {
- stripeMock := &testStripeAPI{}
- defer stripeMock.AssertExpectations(t)
- c := newTestConfigWithAuthFile(t)
- c.StripeSecretKey = "secret key"
- c.StripeWebhookKey = "webhook key"
- s := newTestServer(t, c)
- s.stripe = stripeMock
- // Define how the mock should react
- stripeMock.
- On("GetCustomer", "acct_123").
- Return(&stripe.Customer{Subscriptions: &stripe.SubscriptionList{}}, nil)
- stripeMock.
- On("NewCheckoutSession", mock.Anything).
- Return(&stripe.CheckoutSession{URL: "https://billing.stripe.com/abc/def"}, nil)
- // Create tier and user
- require.Nil(t, s.userManager.CreateTier(&user.Tier{
- Code: "pro",
- StripePriceID: "price_123",
- }))
- require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser))
- u, err := s.userManager.User("phil")
- require.Nil(t, err)
- billing := &user.Billing{
- StripeCustomerID: "acct_123",
- }
- require.Nil(t, s.userManager.ChangeBilling(u.Name, billing))
- // Create subscription
- response := request(t, s, "POST", "/v1/account/billing/subscription", `{"tier": "pro"}`, map[string]string{
- "Authorization": util.BasicAuth("phil", "phil"),
- })
- require.Equal(t, 200, response.Code)
- redirectResponse, err := util.UnmarshalJSON[apiAccountBillingSubscriptionCreateResponse](io.NopCloser(response.Body))
- require.Nil(t, err)
- require.Equal(t, "https://billing.stripe.com/abc/def", redirectResponse.RedirectURL)
- }
- func TestPayments_AccountDelete_Cancels_Subscription(t *testing.T) {
- stripeMock := &testStripeAPI{}
- defer stripeMock.AssertExpectations(t)
- c := newTestConfigWithAuthFile(t)
- c.EnableSignup = true
- c.StripeSecretKey = "secret key"
- c.StripeWebhookKey = "webhook key"
- s := newTestServer(t, c)
- s.stripe = stripeMock
- // Define how the mock should react
- stripeMock.
- On("CancelSubscription", "sub_123").
- Return(&stripe.Subscription{}, nil)
- // Create tier and user
- require.Nil(t, s.userManager.CreateTier(&user.Tier{
- Code: "pro",
- StripePriceID: "price_123",
- }))
- require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser))
- u, err := s.userManager.User("phil")
- require.Nil(t, err)
- billing := &user.Billing{
- StripeCustomerID: "acct_123",
- StripeSubscriptionID: "sub_123",
- }
- require.Nil(t, s.userManager.ChangeBilling(u.Name, billing))
- // Delete account
- rr := request(t, s, "DELETE", "/v1/account", `{"password": "phil"}`, map[string]string{
- "Authorization": util.BasicAuth("phil", "phil"),
- })
- require.Equal(t, 200, rr.Code)
- rr = request(t, s, "GET", "/v1/account", "", map[string]string{
- "Authorization": util.BasicAuth("phil", "mypass"),
- })
- require.Equal(t, 401, rr.Code)
- }
// TestPayments_Webhook_Subscription_Updated_Downgrade_From_PastDue_To_Active feeds the
// subscriptionUpdatedEventJSON event through the Stripe webhook endpoint. The user starts
// on the "pro" tier with a past_due subscription and ends up on "starter" with an active
// one, losing the reservation (and its messages/attachments) that no longer fits.
func TestPayments_Webhook_Subscription_Updated_Downgrade_From_PastDue_To_Active(t *testing.T) {
	// This tests incoming webhooks from Stripe to update a subscription:
	// - All Stripe columns are updated in the user table
	// - The tier is switched to the one whose StripePriceID matches the event's price
	// - When downgrading, excess reservations are deleted, including messages and attachments in
	//   the corresponding topics
	stripeMock := &testStripeAPI{}
	defer stripeMock.AssertExpectations(t)
	c := newTestConfigWithAuthFile(t)
	c.StripeSecretKey = "secret key"
	c.StripeWebhookKey = "webhook key"
	s := newTestServer(t, c)
	s.stripe = stripeMock

	// Define how the mock should react: the request below must carry the "stripe signature"
	// header, and the configured webhook key must be used as the secret
	stripeMock.
		On("ConstructWebhookEvent", mock.Anything, "stripe signature", "webhook key").
		Return(jsonToStripeEvent(t, subscriptionUpdatedEventJSON), nil)

	// Create a user on the "pro" tier (3 reservations allowed) with 2 reservations; the
	// "starter" tier allows only 1, so downgrading must drop one of them
	require.Nil(t, s.userManager.CreateTier(&user.Tier{
		Code:                     "starter",
		StripePriceID:            "price_1234", // same price ID as in subscriptionUpdatedEventJSON
		ReservationsLimit:        1,            // fewer than the 2 reservations created below
		MessagesLimit:            100,
		MessagesExpiryDuration:   time.Hour,
		AttachmentExpiryDuration: time.Hour,
		AttachmentFileSizeLimit:  1000000,
		AttachmentTotalSizeLimit: 1000000,
	}))
	require.Nil(t, s.userManager.CreateTier(&user.Tier{
		Code:                     "pro",
		StripePriceID:            "price_1111", // differs from the event's price, so the user moves off "pro"
		ReservationsLimit:        3,            // room for both reservations before the downgrade
		MessagesLimit:            200,
		MessagesExpiryDuration:   time.Hour,
		AttachmentExpiryDuration: time.Hour,
		AttachmentFileSizeLimit:  1000000,
		AttachmentTotalSizeLimit: 1000000,
	}))
	require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser))
	require.Nil(t, s.userManager.ChangeTier("phil", "pro"))
	require.Nil(t, s.userManager.AddReservation("phil", "atopic", user.PermissionDenyAll))
	require.Nil(t, s.userManager.AddReservation("phil", "ztopic", user.PermissionDenyAll))

	// Add billing details (past_due, with old timestamps that the webhook must overwrite)
	u, err := s.userManager.User("phil")
	require.Nil(t, err)
	billing := &user.Billing{
		StripeCustomerID:            "acct_5555",
		StripeSubscriptionID:        "sub_1234",
		StripeSubscriptionStatus:    stripe.SubscriptionStatusPastDue,
		StripeSubscriptionPaidUntil: time.Unix(123, 0),
		StripeSubscriptionCancelAt:  time.Unix(456, 0),
	}
	require.Nil(t, s.userManager.ChangeBilling(u.Name, billing))

	// Add some messages to "atopic" and "ztopic", everything in "ztopic" will be deleted.
	// The 5000-byte bodies end up as attachment files, checked via FileExists.
	rr := request(t, s, "PUT", "/atopic", "some aaa message", map[string]string{
		"Authorization": util.BasicAuth("phil", "phil"),
	})
	require.Equal(t, 200, rr.Code)
	rr = request(t, s, "PUT", "/atopic", strings.Repeat("a", 5000), map[string]string{
		"Authorization": util.BasicAuth("phil", "phil"),
	})
	require.Equal(t, 200, rr.Code)
	a2 := toMessage(t, rr.Body.String())
	require.FileExists(t, filepath.Join(s.config.AttachmentCacheDir, a2.ID))
	rr = request(t, s, "PUT", "/ztopic", "some zzz message", map[string]string{
		"Authorization": util.BasicAuth("phil", "phil"),
	})
	require.Equal(t, 200, rr.Code)
	rr = request(t, s, "PUT", "/ztopic", strings.Repeat("z", 5000), map[string]string{
		"Authorization": util.BasicAuth("phil", "phil"),
	})
	require.Equal(t, 200, rr.Code)
	z2 := toMessage(t, rr.Body.String())
	require.FileExists(t, filepath.Join(s.config.AttachmentCacheDir, z2.ID))

	// Call the webhook: This does all the magic
	rr = request(t, s, "POST", "/v1/account/billing/webhook", "dummy", map[string]string{
		"Stripe-Signature": "stripe signature",
	})
	require.Equal(t, 200, rr.Code)

	// Verify that database columns were updated to the values in subscriptionUpdatedEventJSON
	u, err = s.userManager.User("phil")
	require.Nil(t, err)
	require.Equal(t, "starter", u.Tier.Code) // Not "pro"
	require.Equal(t, "acct_5555", u.Billing.StripeCustomerID)
	require.Equal(t, "sub_1234", u.Billing.StripeSubscriptionID)
	require.Equal(t, stripe.SubscriptionStatusActive, u.Billing.StripeSubscriptionStatus) // Not "past_due"
	require.Equal(t, int64(1674268231), u.Billing.StripeSubscriptionPaidUntil.Unix())     // Updated
	require.Equal(t, int64(1674299999), u.Billing.StripeSubscriptionCancelAt.Unix())      // Updated

	// Verify that reservations were deleted; the test expects "atopic" to be the one kept
	// (NOTE(review): presumably reservations are kept in topic order — confirm)
	r, err := s.userManager.Reservations("phil")
	require.Nil(t, err)
	require.Equal(t, 1, len(r)) // "ztopic" reservation was deleted
	require.Equal(t, "atopic", r[0].Topic)

	// Verify that messages and attachments were deleted.
	// NOTE(review): the one-second sleep appears to be required before execManager removes
	// the dropped topic's messages/attachments — confirm why; it slows the test and may be
	// timing-sensitive.
	time.Sleep(time.Second)
	s.execManager()
	ms, err := s.messageCache.Messages("atopic", sinceAllMessages, false)
	require.Nil(t, err)
	require.Equal(t, 2, len(ms))
	require.FileExists(t, filepath.Join(s.config.AttachmentCacheDir, a2.ID))
	ms, err = s.messageCache.Messages("ztopic", sinceAllMessages, false)
	require.Nil(t, err)
	require.Equal(t, 0, len(ms))
	require.NoFileExists(t, filepath.Join(s.config.AttachmentCacheDir, z2.ID))
}
- type testStripeAPI struct {
- mock.Mock
- }
- var _ stripeAPI = (*testStripeAPI)(nil)
- func (s *testStripeAPI) NewCheckoutSession(params *stripe.CheckoutSessionParams) (*stripe.CheckoutSession, error) {
- args := s.Called(params)
- return args.Get(0).(*stripe.CheckoutSession), args.Error(1)
- }
- func (s *testStripeAPI) NewPortalSession(params *stripe.BillingPortalSessionParams) (*stripe.BillingPortalSession, error) {
- args := s.Called(params)
- return args.Get(0).(*stripe.BillingPortalSession), args.Error(1)
- }
- func (s *testStripeAPI) ListPrices(params *stripe.PriceListParams) ([]*stripe.Price, error) {
- args := s.Called(params)
- return args.Get(0).([]*stripe.Price), args.Error(1)
- }
- func (s *testStripeAPI) GetCustomer(id string) (*stripe.Customer, error) {
- args := s.Called(id)
- return args.Get(0).(*stripe.Customer), args.Error(1)
- }
- func (s *testStripeAPI) GetSession(id string) (*stripe.CheckoutSession, error) {
- args := s.Called(id)
- return args.Get(0).(*stripe.CheckoutSession), args.Error(1)
- }
- func (s *testStripeAPI) GetSubscription(id string) (*stripe.Subscription, error) {
- args := s.Called(id)
- return args.Get(0).(*stripe.Subscription), args.Error(1)
- }
- func (s *testStripeAPI) UpdateCustomer(id string, params *stripe.CustomerParams) (*stripe.Customer, error) {
- args := s.Called(id)
- return args.Get(0).(*stripe.Customer), args.Error(1)
- }
- func (s *testStripeAPI) UpdateSubscription(id string, params *stripe.SubscriptionParams) (*stripe.Subscription, error) {
- args := s.Called(id)
- return args.Get(0).(*stripe.Subscription), args.Error(1)
- }
- func (s *testStripeAPI) CancelSubscription(id string) (*stripe.Subscription, error) {
- args := s.Called(id)
- return args.Get(0).(*stripe.Subscription), args.Error(1)
- }
- func (s *testStripeAPI) ConstructWebhookEvent(payload []byte, header string, secret string) (stripe.Event, error) {
- args := s.Called(payload, header, secret)
- return args.Get(0).(stripe.Event), args.Error(1)
- }
- func jsonToStripeEvent(t *testing.T, v string) stripe.Event {
- var e stripe.Event
- if err := json.Unmarshal([]byte(v), &e); err != nil {
- t.Fatal(err)
- }
- return e
- }
// subscriptionUpdatedEventJSON is a minimal Stripe "customer.subscription.updated" event
// used by the webhook downgrade test. It reports subscription sub_1234 (customer
// acct_5555) with status "active". Its current_period_end and cancel_at timestamps are the
// values that test expects in StripeSubscriptionPaidUntil and StripeSubscriptionCancelAt.
// It references price_1234, which the test assigns to the "starter" tier.
const subscriptionUpdatedEventJSON = `
{
"type": "customer.subscription.updated",
"data": {
"object": {
"id": "sub_1234",
"customer": "acct_5555",
"status": "active",
"current_period_end": 1674268231,
"cancel_at": 1674299999,
"items": {
"data": [
{
"price": {
"id": "price_1234"
}
}
]
}
}
}
}`
|