Browse Source

Allow mocking the Stripe API

binwiederhier 3 years ago
parent
commit
4e51a715c1
6 changed files with 224 additions and 29 deletions
  1. 1 0
      go.mod
  2. 1 0
      go.sum
  3. 13 8
      server/server.go
  4. 1 1
      server/server_middleware.go
  5. 78 20
      server/server_payments.go
  6. 130 0
      server/server_payments_test.go

+ 1 - 0
go.mod

@@ -49,6 +49,7 @@ require (
 	github.com/googleapis/gax-go/v2 v2.7.0 // indirect
 	github.com/pmezard/go-difflib v1.0.0 // indirect
 	github.com/russross/blackfriday/v2 v2.1.0 // indirect
+	github.com/stretchr/objx v0.5.0 // indirect
 	github.com/xrash/smetrics v0.0.0-20201216005158-039620a65673 // indirect
 	go.opencensus.io v0.24.0 // indirect
 	golang.org/x/net v0.4.0 // indirect

+ 1 - 0
go.sum

@@ -94,6 +94,7 @@ github.com/russross/blackfriday/v2 v2.1.0 h1:JIOH55/0cWyOuilr9/qlrm0BSXldqnqwMsf
 github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM=
 github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
 github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
+github.com/stretchr/objx v0.5.0 h1:1zr/of2m5FGMsad5YfcqgdqdWrIhu+EBEJRhR1U7z/c=
 github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=
 github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
 github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=

+ 13 - 8
server/server.go

@@ -37,8 +37,6 @@ import (
 /*
 	TODO
 		payments:
-		- send dunning emails when overdue
-		- payment methods
 		- delete subscription when account deleted
 		- delete messages + reserved topics on ResetTier
 
@@ -76,9 +74,10 @@ type Server struct {
 	visitors          map[string]*visitor // ip:<ip> or user:<user>
 	firebaseClient    *firebaseClient
 	messages          int64
-	userManager       *user.Manager // Might be nil!
-	messageCache      *messageCache
-	fileCache         *fileCache
+	userManager       *user.Manager                        // Might be nil!
+	messageCache      *messageCache                        // Database that stores the messages
+	fileCache         *fileCache                           // File system based cache that stores attachments
+	stripe            stripeAPI                            // Stripe API, can be replaced with a mock
 	priceCache        *util.LookupCache[map[string]string] // Stripe price ID -> formatted price
 	closeChan         chan bool
 	mu                sync.Mutex
@@ -160,6 +159,10 @@ func New(conf *Config) (*Server, error) {
 	if conf.SMTPSenderAddr != "" {
 		mailer = &smtpSender{config: conf}
 	}
+	var stripe stripeAPI
+	if conf.StripeSecretKey != "" {
+		stripe = newStripeAPI()
+	}
 	messageCache, err := createMessageCache(conf)
 	if err != nil {
 		return nil, err
@@ -190,7 +193,7 @@ func New(conf *Config) (*Server, error) {
 		}
 		firebaseClient = newFirebaseClient(sender, userManager)
 	}
-	return &Server{
+	s := &Server{
 		config:         conf,
 		messageCache:   messageCache,
 		fileCache:      fileCache,
@@ -199,8 +202,10 @@ func New(conf *Config) (*Server, error) {
 		topics:         topics,
 		userManager:    userManager,
 		visitors:       make(map[string]*visitor),
-		priceCache:     util.NewLookupCache(fetchStripePrices, conf.StripePriceCacheDuration),
-	}, nil
+		stripe:         stripe,
+	}
+	s.priceCache = util.NewLookupCache(s.fetchStripePrices, conf.StripePriceCacheDuration)
+	return s, nil
 }
 
 func createMessageCache(conf *Config) (*messageCache, error) {

+ 1 - 1
server/server_middleware.go

@@ -33,7 +33,7 @@ func (s *Server) ensureUser(next handleFunc) handleFunc {
 
 func (s *Server) ensurePaymentsEnabled(next handleFunc) handleFunc {
 	return func(w http.ResponseWriter, r *http.Request, v *visitor) error {
-		if s.config.StripeSecretKey == "" {
+		if s.config.StripeSecretKey == "" || s.stripe == nil {
 			return errHTTPNotFound
 		}
 		return next(w, r, v)

+ 78 - 20
server/server_payments.go

@@ -96,7 +96,7 @@ func (s *Server) handleAccountBillingSubscriptionCreate(w http.ResponseWriter, r
 	var stripeCustomerID *string
 	if v.user.Billing.StripeCustomerID != "" {
 		stripeCustomerID = &v.user.Billing.StripeCustomerID
-		stripeCustomer, err := customer.Get(v.user.Billing.StripeCustomerID, nil)
+		stripeCustomer, err := s.stripe.GetCustomer(v.user.Billing.StripeCustomerID)
 		if err != nil {
 			return err
 		} else if stripeCustomer.Subscriptions != nil && len(stripeCustomer.Subscriptions.Data) > 0 {
@@ -120,7 +120,7 @@ func (s *Server) handleAccountBillingSubscriptionCreate(w http.ResponseWriter, r
 			Enabled: stripe.Bool(true),
 		},*/
 	}
-	sess, err := session.New(params)
+	sess, err := s.stripe.NewCheckoutSession(params)
 	if err != nil {
 		return err
 	}
@@ -137,14 +137,14 @@ func (s *Server) handleAccountBillingSubscriptionCreateSuccess(w http.ResponseWr
 		return errHTTPInternalErrorInvalidPath
 	}
 	sessionID := matches[1]
-	sess, err := session.Get(sessionID, nil) // FIXME how do I rate limit this?
+	sess, err := s.stripe.GetSession(sessionID) // FIXME How do we rate limit this?
 	if err != nil {
 		log.Warn("Stripe: %s", err)
 		return errHTTPBadRequestBillingRequestInvalid
 	} else if sess.Customer == nil || sess.Subscription == nil || sess.ClientReferenceID == "" {
 		return wrapErrHTTP(errHTTPBadRequestBillingRequestInvalid, "customer or subscription not found")
 	}
-	sub, err := subscription.Get(sess.Subscription.ID, nil)
+	sub, err := s.stripe.GetSubscription(sess.Subscription.ID)
 	if err != nil {
 		return err
 	} else if sub.Items == nil || len(sub.Items.Data) != 1 || sub.Items.Data[0].Price == nil {
@@ -180,7 +180,7 @@ func (s *Server) handleAccountBillingSubscriptionUpdate(w http.ResponseWriter, r
 		return err
 	}
 	log.Info("Stripe: Changing tier and subscription to %s", tier.Code)
-	sub, err := subscription.Get(v.user.Billing.StripeSubscriptionID, nil)
+	sub, err := s.stripe.GetSubscription(v.user.Billing.StripeSubscriptionID)
 	if err != nil {
 		return err
 	}
@@ -194,7 +194,7 @@ func (s *Server) handleAccountBillingSubscriptionUpdate(w http.ResponseWriter, r
 			},
 		},
 	}
-	_, err = subscription.Update(sub.ID, params)
+	_, err = s.stripe.UpdateSubscription(sub.ID, params)
 	if err != nil {
 		return err
 	}
@@ -208,7 +208,7 @@ func (s *Server) handleAccountBillingSubscriptionDelete(w http.ResponseWriter, r
 		params := &stripe.SubscriptionParams{
 			CancelAtPeriodEnd: stripe.Bool(true),
 		}
-		_, err := subscription.Update(v.user.Billing.StripeSubscriptionID, params)
+		_, err := s.stripe.UpdateSubscription(v.user.Billing.StripeSubscriptionID, params)
 		if err != nil {
 			return err
 		}
@@ -224,7 +224,7 @@ func (s *Server) handleAccountBillingPortalSessionCreate(w http.ResponseWriter,
 		Customer:  stripe.String(v.user.Billing.StripeCustomerID),
 		ReturnURL: stripe.String(s.config.BaseURL),
 	}
-	ps, err := portalsession.New(params)
+	ps, err := s.stripe.NewPortalSession(params)
 	if err != nil {
 		return err
 	}
@@ -248,7 +248,7 @@ func (s *Server) handleAccountBillingWebhook(w http.ResponseWriter, r *http.Requ
 	} else if body.LimitReached {
 		return errHTTPEntityTooLargeJSONBody
 	}
-	event, err := webhook.ConstructEvent(body.PeekedBytes, stripeSignature, s.config.StripeWebhookKey)
+	event, err := s.stripe.ConstructWebhookEvent(body.PeekedBytes, stripeSignature, s.config.StripeWebhookKey)
 	if err != nil {
 		return errHTTPBadRequestBillingRequestInvalid
 	} else if event.Data == nil || event.Data.Raw == nil {
@@ -331,24 +331,82 @@ func (s *Server) updateSubscriptionAndTier(u *user.User, customerID, subscriptio
 
 // fetchStripePrices contacts the Stripe API to retrieve all prices. This is used by the server to cache the prices
 // in memory, and ultimately for the web app to display the price table.
-func fetchStripePrices() (map[string]string, error) {
+func (s *Server) fetchStripePrices() (map[string]string, error) {
 	log.Debug("Caching prices from Stripe API")
-	prices := make(map[string]string)
-	iter := price.List(&stripe.PriceListParams{
-		Active: stripe.Bool(true),
-	})
-	for iter.Next() {
-		p := iter.Price()
+	priceMap := make(map[string]string)
+	prices, err := s.stripe.ListPrices(&stripe.PriceListParams{Active: stripe.Bool(true)})
+	if err != nil {
+		log.Warn("Fetching Stripe prices failed: %s", err.Error())
+		return nil, err
+	}
+	for _, p := range prices {
 		if p.UnitAmount%100 == 0 {
-			prices[p.ID] = fmt.Sprintf("$%d", p.UnitAmount/100)
+			priceMap[p.ID] = fmt.Sprintf("$%d", p.UnitAmount/100)
 		} else {
-			prices[p.ID] = fmt.Sprintf("$%.2f", float64(p.UnitAmount)/100)
+			priceMap[p.ID] = fmt.Sprintf("$%.2f", float64(p.UnitAmount)/100)
 		}
-		log.Trace("- Caching price %s = %v", p.ID, prices[p.ID])
+		log.Trace("- Caching price %s = %v", p.ID, priceMap[p.ID])
+	}
+	return priceMap, nil
+}
+
+// stripeAPI is a small interface to facilitate mocking of the Stripe API
+type stripeAPI interface {
+	NewCheckoutSession(params *stripe.CheckoutSessionParams) (*stripe.CheckoutSession, error)
+	NewPortalSession(params *stripe.BillingPortalSessionParams) (*stripe.BillingPortalSession, error)
+	ListPrices(params *stripe.PriceListParams) ([]*stripe.Price, error)
+	GetCustomer(id string) (*stripe.Customer, error)
+	GetSession(id string) (*stripe.CheckoutSession, error)
+	GetSubscription(id string) (*stripe.Subscription, error)
+	UpdateSubscription(id string, params *stripe.SubscriptionParams) (*stripe.Subscription, error)
+	ConstructWebhookEvent(payload []byte, header string, secret string) (stripe.Event, error)
+}
+
+// realStripeAPI is a thin shim around the Stripe functions to facilitate mocking
+type realStripeAPI struct{}
+
+var _ stripeAPI = (*realStripeAPI)(nil)
+
+func newStripeAPI() stripeAPI {
+	return &realStripeAPI{}
+}
+
+func (s *realStripeAPI) NewCheckoutSession(params *stripe.CheckoutSessionParams) (*stripe.CheckoutSession, error) {
+	return session.New(params)
+}
+
+func (s *realStripeAPI) NewPortalSession(params *stripe.BillingPortalSessionParams) (*stripe.BillingPortalSession, error) {
+	return portalsession.New(params)
+}
+
+func (s *realStripeAPI) ListPrices(params *stripe.PriceListParams) ([]*stripe.Price, error) {
+	prices := make([]*stripe.Price, 0)
+	iter := price.List(params)
+	for iter.Next() {
+		prices = append(prices, iter.Price())
 	}
 	if iter.Err() != nil {
-		log.Warn("Fetching Stripe prices failed: %s", iter.Err().Error())
 		return nil, iter.Err()
 	}
 	return prices, nil
 }
+
+func (s *realStripeAPI) GetCustomer(id string) (*stripe.Customer, error) {
+	return customer.Get(id, nil)
+}
+
+func (s *realStripeAPI) GetSession(id string) (*stripe.CheckoutSession, error) {
+	return session.Get(id, nil)
+}
+
+func (s *realStripeAPI) GetSubscription(id string) (*stripe.Subscription, error) {
+	return subscription.Get(id, nil)
+}
+
+func (s *realStripeAPI) UpdateSubscription(id string, params *stripe.SubscriptionParams) (*stripe.Subscription, error) {
+	return subscription.Update(id, params)
+}
+
+func (s *realStripeAPI) ConstructWebhookEvent(payload []byte, header string, secret string) (stripe.Event, error) {
+	return webhook.ConstructEvent(payload, header, secret)
+}

+ 130 - 0
server/server_payments_test.go

@@ -0,0 +1,130 @@
+package server
+
+import (
+	"github.com/stretchr/testify/mock"
+	"github.com/stretchr/testify/require"
+	"github.com/stripe/stripe-go/v74"
+	"heckel.io/ntfy/user"
+	"heckel.io/ntfy/util"
+	"io"
+	"testing"
+)
+
+func TestPayments_SubscriptionCreate_NotAStripeCustomer_Success(t *testing.T) {
+	stripeMock := &testStripeAPI{}
+	defer stripeMock.AssertExpectations(t)
+
+	c := newTestConfigWithAuthFile(t)
+	c.StripeSecretKey = "secret key"
+	c.StripeWebhookKey = "webhook key"
+	s := newTestServer(t, c)
+	s.stripe = stripeMock
+
+	// Define how the mock should react
+	stripeMock.
+		On("NewCheckoutSession", mock.Anything).
+		Return(&stripe.CheckoutSession{URL: "https://billing.stripe.com/abc/def"}, nil)
+
+	// Create tier and user
+	require.Nil(t, s.userManager.CreateTier(&user.Tier{
+		Code:          "pro",
+		StripePriceID: "price_123",
+	}))
+	require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser, "unit-test"))
+
+	// Create subscription
+	response := request(t, s, "POST", "/v1/account/billing/subscription", `{"tier": "pro"}`, map[string]string{
+		"Authorization": util.BasicAuth("phil", "phil"),
+	})
+	require.Equal(t, 200, response.Code)
+	redirectResponse, err := util.UnmarshalJSON[apiAccountBillingSubscriptionCreateResponse](io.NopCloser(response.Body))
+	require.Nil(t, err)
+	require.Equal(t, "https://billing.stripe.com/abc/def", redirectResponse.RedirectURL)
+}
+
+func TestPayments_SubscriptionCreate_StripeCustomer_Success(t *testing.T) {
+	stripeMock := &testStripeAPI{}
+	defer stripeMock.AssertExpectations(t)
+
+	c := newTestConfigWithAuthFile(t)
+	c.StripeSecretKey = "secret key"
+	c.StripeWebhookKey = "webhook key"
+	s := newTestServer(t, c)
+	s.stripe = stripeMock
+
+	// Define how the mock should react
+	stripeMock.
+		On("GetCustomer", "acct_123").
+		Return(&stripe.Customer{Subscriptions: &stripe.SubscriptionList{}}, nil)
+	stripeMock.
+		On("NewCheckoutSession", mock.Anything).
+		Return(&stripe.CheckoutSession{URL: "https://billing.stripe.com/abc/def"}, nil)
+
+	// Create tier and user
+	require.Nil(t, s.userManager.CreateTier(&user.Tier{
+		Code:          "pro",
+		StripePriceID: "price_123",
+	}))
+	require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser, "unit-test"))
+
+	u, err := s.userManager.User("phil")
+	require.Nil(t, err)
+
+	u.Billing.StripeCustomerID = "acct_123"
+	require.Nil(t, s.userManager.ChangeBilling(u))
+
+	// Create subscription
+	response := request(t, s, "POST", "/v1/account/billing/subscription", `{"tier": "pro"}`, map[string]string{
+		"Authorization": util.BasicAuth("phil", "phil"),
+	})
+	require.Equal(t, 200, response.Code)
+	redirectResponse, err := util.UnmarshalJSON[apiAccountBillingSubscriptionCreateResponse](io.NopCloser(response.Body))
+	require.Nil(t, err)
+	require.Equal(t, "https://billing.stripe.com/abc/def", redirectResponse.RedirectURL)
+}
+
+type testStripeAPI struct {
+	mock.Mock
+}
+
+func (s *testStripeAPI) NewCheckoutSession(params *stripe.CheckoutSessionParams) (*stripe.CheckoutSession, error) {
+	args := s.Called(params)
+	return args.Get(0).(*stripe.CheckoutSession), args.Error(1)
+}
+
+func (s *testStripeAPI) NewPortalSession(params *stripe.BillingPortalSessionParams) (*stripe.BillingPortalSession, error) {
+	args := s.Called(params)
+	return args.Get(0).(*stripe.BillingPortalSession), args.Error(1)
+}
+
+func (s *testStripeAPI) ListPrices(params *stripe.PriceListParams) ([]*stripe.Price, error) {
+	args := s.Called(params)
+	return args.Get(0).([]*stripe.Price), args.Error(1)
+}
+
+func (s *testStripeAPI) GetCustomer(id string) (*stripe.Customer, error) {
+	args := s.Called(id)
+	return args.Get(0).(*stripe.Customer), args.Error(1)
+}
+
+func (s *testStripeAPI) GetSession(id string) (*stripe.CheckoutSession, error) {
+	args := s.Called(id)
+	return args.Get(0).(*stripe.CheckoutSession), args.Error(1)
+}
+
+func (s *testStripeAPI) GetSubscription(id string) (*stripe.Subscription, error) {
+	args := s.Called(id)
+	return args.Get(0).(*stripe.Subscription), args.Error(1)
+}
+
+func (s *testStripeAPI) UpdateSubscription(id string, params *stripe.SubscriptionParams) (*stripe.Subscription, error) {
+	args := s.Called(id)
+	return args.Get(0).(*stripe.Subscription), args.Error(1)
+}
+
+func (s *testStripeAPI) ConstructWebhookEvent(payload []byte, header string, secret string) (stripe.Event, error) {
+	args := s.Called(payload, header, secret)
+	return args.Get(0).(stripe.Event), args.Error(1)
+}
+
+var _ stripeAPI = (*testStripeAPI)(nil)