server_payments_test.go 4.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130
  1. package server
  2. import (
  3. "github.com/stretchr/testify/mock"
  4. "github.com/stretchr/testify/require"
  5. "github.com/stripe/stripe-go/v74"
  6. "heckel.io/ntfy/user"
  7. "heckel.io/ntfy/util"
  8. "io"
  9. "testing"
  10. )
  11. func TestPayments_SubscriptionCreate_NotAStripeCustomer_Success(t *testing.T) {
  12. stripeMock := &testStripeAPI{}
  13. defer stripeMock.AssertExpectations(t)
  14. c := newTestConfigWithAuthFile(t)
  15. c.StripeSecretKey = "secret key"
  16. c.StripeWebhookKey = "webhook key"
  17. s := newTestServer(t, c)
  18. s.stripe = stripeMock
  19. // Define how the mock should react
  20. stripeMock.
  21. On("NewCheckoutSession", mock.Anything).
  22. Return(&stripe.CheckoutSession{URL: "https://billing.stripe.com/abc/def"}, nil)
  23. // Create tier and user
  24. require.Nil(t, s.userManager.CreateTier(&user.Tier{
  25. Code: "pro",
  26. StripePriceID: "price_123",
  27. }))
  28. require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser, "unit-test"))
  29. // Create subscription
  30. response := request(t, s, "POST", "/v1/account/billing/subscription", `{"tier": "pro"}`, map[string]string{
  31. "Authorization": util.BasicAuth("phil", "phil"),
  32. })
  33. require.Equal(t, 200, response.Code)
  34. redirectResponse, err := util.UnmarshalJSON[apiAccountBillingSubscriptionCreateResponse](io.NopCloser(response.Body))
  35. require.Nil(t, err)
  36. require.Equal(t, "https://billing.stripe.com/abc/def", redirectResponse.RedirectURL)
  37. }
  38. func TestPayments_SubscriptionCreate_StripeCustomer_Success(t *testing.T) {
  39. stripeMock := &testStripeAPI{}
  40. defer stripeMock.AssertExpectations(t)
  41. c := newTestConfigWithAuthFile(t)
  42. c.StripeSecretKey = "secret key"
  43. c.StripeWebhookKey = "webhook key"
  44. s := newTestServer(t, c)
  45. s.stripe = stripeMock
  46. // Define how the mock should react
  47. stripeMock.
  48. On("GetCustomer", "acct_123").
  49. Return(&stripe.Customer{Subscriptions: &stripe.SubscriptionList{}}, nil)
  50. stripeMock.
  51. On("NewCheckoutSession", mock.Anything).
  52. Return(&stripe.CheckoutSession{URL: "https://billing.stripe.com/abc/def"}, nil)
  53. // Create tier and user
  54. require.Nil(t, s.userManager.CreateTier(&user.Tier{
  55. Code: "pro",
  56. StripePriceID: "price_123",
  57. }))
  58. require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser, "unit-test"))
  59. u, err := s.userManager.User("phil")
  60. require.Nil(t, err)
  61. u.Billing.StripeCustomerID = "acct_123"
  62. require.Nil(t, s.userManager.ChangeBilling(u))
  63. // Create subscription
  64. response := request(t, s, "POST", "/v1/account/billing/subscription", `{"tier": "pro"}`, map[string]string{
  65. "Authorization": util.BasicAuth("phil", "phil"),
  66. })
  67. require.Equal(t, 200, response.Code)
  68. redirectResponse, err := util.UnmarshalJSON[apiAccountBillingSubscriptionCreateResponse](io.NopCloser(response.Body))
  69. require.Nil(t, err)
  70. require.Equal(t, "https://billing.stripe.com/abc/def", redirectResponse.RedirectURL)
  71. }
  72. type testStripeAPI struct {
  73. mock.Mock
  74. }
  75. func (s *testStripeAPI) NewCheckoutSession(params *stripe.CheckoutSessionParams) (*stripe.CheckoutSession, error) {
  76. args := s.Called(params)
  77. return args.Get(0).(*stripe.CheckoutSession), args.Error(1)
  78. }
  79. func (s *testStripeAPI) NewPortalSession(params *stripe.BillingPortalSessionParams) (*stripe.BillingPortalSession, error) {
  80. args := s.Called(params)
  81. return args.Get(0).(*stripe.BillingPortalSession), args.Error(1)
  82. }
  83. func (s *testStripeAPI) ListPrices(params *stripe.PriceListParams) ([]*stripe.Price, error) {
  84. args := s.Called(params)
  85. return args.Get(0).([]*stripe.Price), args.Error(1)
  86. }
  87. func (s *testStripeAPI) GetCustomer(id string) (*stripe.Customer, error) {
  88. args := s.Called(id)
  89. return args.Get(0).(*stripe.Customer), args.Error(1)
  90. }
  91. func (s *testStripeAPI) GetSession(id string) (*stripe.CheckoutSession, error) {
  92. args := s.Called(id)
  93. return args.Get(0).(*stripe.CheckoutSession), args.Error(1)
  94. }
  95. func (s *testStripeAPI) GetSubscription(id string) (*stripe.Subscription, error) {
  96. args := s.Called(id)
  97. return args.Get(0).(*stripe.Subscription), args.Error(1)
  98. }
  99. func (s *testStripeAPI) UpdateSubscription(id string, params *stripe.SubscriptionParams) (*stripe.Subscription, error) {
  100. args := s.Called(id)
  101. return args.Get(0).(*stripe.Subscription), args.Error(1)
  102. }
  103. func (s *testStripeAPI) ConstructWebhookEvent(payload []byte, header string, secret string) (stripe.Event, error) {
  104. args := s.Called(payload, header, secret)
  105. return args.Get(0).(stripe.Event), args.Error(1)
  106. }
  107. var _ stripeAPI = (*testStripeAPI)(nil)