|
|
@@ -96,7 +96,7 @@ func (s *Server) handleAccountBillingSubscriptionCreate(w http.ResponseWriter, r
|
|
|
var stripeCustomerID *string
|
|
|
if v.user.Billing.StripeCustomerID != "" {
|
|
|
stripeCustomerID = &v.user.Billing.StripeCustomerID
|
|
|
- stripeCustomer, err := customer.Get(v.user.Billing.StripeCustomerID, nil)
|
|
|
+ stripeCustomer, err := s.stripe.GetCustomer(v.user.Billing.StripeCustomerID)
|
|
|
if err != nil {
|
|
|
return err
|
|
|
} else if stripeCustomer.Subscriptions != nil && len(stripeCustomer.Subscriptions.Data) > 0 {
|
|
|
@@ -120,7 +120,7 @@ func (s *Server) handleAccountBillingSubscriptionCreate(w http.ResponseWriter, r
|
|
|
Enabled: stripe.Bool(true),
|
|
|
},*/
|
|
|
}
|
|
|
- sess, err := session.New(params)
|
|
|
+ sess, err := s.stripe.NewCheckoutSession(params)
|
|
|
if err != nil {
|
|
|
return err
|
|
|
}
|
|
|
@@ -137,14 +137,14 @@ func (s *Server) handleAccountBillingSubscriptionCreateSuccess(w http.ResponseWr
|
|
|
return errHTTPInternalErrorInvalidPath
|
|
|
}
|
|
|
sessionID := matches[1]
|
|
|
- sess, err := session.Get(sessionID, nil) // FIXME how do I rate limit this?
|
|
|
+ sess, err := s.stripe.GetSession(sessionID) // FIXME How do we rate limit this?
|
|
|
if err != nil {
|
|
|
log.Warn("Stripe: %s", err)
|
|
|
return errHTTPBadRequestBillingRequestInvalid
|
|
|
} else if sess.Customer == nil || sess.Subscription == nil || sess.ClientReferenceID == "" {
|
|
|
return wrapErrHTTP(errHTTPBadRequestBillingRequestInvalid, "customer or subscription not found")
|
|
|
}
|
|
|
- sub, err := subscription.Get(sess.Subscription.ID, nil)
|
|
|
+ sub, err := s.stripe.GetSubscription(sess.Subscription.ID)
|
|
|
if err != nil {
|
|
|
return err
|
|
|
} else if sub.Items == nil || len(sub.Items.Data) != 1 || sub.Items.Data[0].Price == nil {
|
|
|
@@ -180,7 +180,7 @@ func (s *Server) handleAccountBillingSubscriptionUpdate(w http.ResponseWriter, r
|
|
|
return err
|
|
|
}
|
|
|
log.Info("Stripe: Changing tier and subscription to %s", tier.Code)
|
|
|
- sub, err := subscription.Get(v.user.Billing.StripeSubscriptionID, nil)
|
|
|
+ sub, err := s.stripe.GetSubscription(v.user.Billing.StripeSubscriptionID)
|
|
|
if err != nil {
|
|
|
return err
|
|
|
}
|
|
|
@@ -194,7 +194,7 @@ func (s *Server) handleAccountBillingSubscriptionUpdate(w http.ResponseWriter, r
|
|
|
},
|
|
|
},
|
|
|
}
|
|
|
- _, err = subscription.Update(sub.ID, params)
|
|
|
+ _, err = s.stripe.UpdateSubscription(sub.ID, params)
|
|
|
if err != nil {
|
|
|
return err
|
|
|
}
|
|
|
@@ -208,7 +208,7 @@ func (s *Server) handleAccountBillingSubscriptionDelete(w http.ResponseWriter, r
|
|
|
params := &stripe.SubscriptionParams{
|
|
|
CancelAtPeriodEnd: stripe.Bool(true),
|
|
|
}
|
|
|
- _, err := subscription.Update(v.user.Billing.StripeSubscriptionID, params)
|
|
|
+ _, err := s.stripe.UpdateSubscription(v.user.Billing.StripeSubscriptionID, params)
|
|
|
if err != nil {
|
|
|
return err
|
|
|
}
|
|
|
@@ -224,7 +224,7 @@ func (s *Server) handleAccountBillingPortalSessionCreate(w http.ResponseWriter,
|
|
|
Customer: stripe.String(v.user.Billing.StripeCustomerID),
|
|
|
ReturnURL: stripe.String(s.config.BaseURL),
|
|
|
}
|
|
|
- ps, err := portalsession.New(params)
|
|
|
+ ps, err := s.stripe.NewPortalSession(params)
|
|
|
if err != nil {
|
|
|
return err
|
|
|
}
|
|
|
@@ -248,7 +248,7 @@ func (s *Server) handleAccountBillingWebhook(w http.ResponseWriter, r *http.Requ
|
|
|
} else if body.LimitReached {
|
|
|
return errHTTPEntityTooLargeJSONBody
|
|
|
}
|
|
|
- event, err := webhook.ConstructEvent(body.PeekedBytes, stripeSignature, s.config.StripeWebhookKey)
|
|
|
+ event, err := s.stripe.ConstructWebhookEvent(body.PeekedBytes, stripeSignature, s.config.StripeWebhookKey)
|
|
|
if err != nil {
|
|
|
return errHTTPBadRequestBillingRequestInvalid
|
|
|
} else if event.Data == nil || event.Data.Raw == nil {
|
|
|
@@ -331,24 +331,82 @@ func (s *Server) updateSubscriptionAndTier(u *user.User, customerID, subscriptio
|
|
|
|
|
|
// fetchStripePrices contacts the Stripe API to retrieve all prices. This is used by the server to cache the prices
|
|
|
// in memory, and ultimately for the web app to display the price table.
|
|
|
-func fetchStripePrices() (map[string]string, error) {
|
|
|
+func (s *Server) fetchStripePrices() (map[string]string, error) {
|
|
|
log.Debug("Caching prices from Stripe API")
|
|
|
- prices := make(map[string]string)
|
|
|
- iter := price.List(&stripe.PriceListParams{
|
|
|
- Active: stripe.Bool(true),
|
|
|
- })
|
|
|
- for iter.Next() {
|
|
|
- p := iter.Price()
|
|
|
+ priceMap := make(map[string]string)
|
|
|
+ prices, err := s.stripe.ListPrices(&stripe.PriceListParams{Active: stripe.Bool(true)})
|
|
|
+ if err != nil {
|
|
|
+ log.Warn("Fetching Stripe prices failed: %s", err.Error())
|
|
|
+ return nil, err
|
|
|
+ }
|
|
|
+ for _, p := range prices {
|
|
|
if p.UnitAmount%100 == 0 {
|
|
|
- prices[p.ID] = fmt.Sprintf("$%d", p.UnitAmount/100)
|
|
|
+ priceMap[p.ID] = fmt.Sprintf("$%d", p.UnitAmount/100)
|
|
|
} else {
|
|
|
- prices[p.ID] = fmt.Sprintf("$%.2f", float64(p.UnitAmount)/100)
|
|
|
+ priceMap[p.ID] = fmt.Sprintf("$%.2f", float64(p.UnitAmount)/100)
|
|
|
}
|
|
|
- log.Trace("- Caching price %s = %v", p.ID, prices[p.ID])
|
|
|
+ log.Trace("- Caching price %s = %v", p.ID, priceMap[p.ID])
|
|
|
+ }
|
|
|
+ return priceMap, nil
|
|
|
+}
|
|
|
+
|
|
|
+// stripeAPI is a small interface to facilitate mocking of the Stripe API
|
|
|
+type stripeAPI interface {
|
|
|
+ NewCheckoutSession(params *stripe.CheckoutSessionParams) (*stripe.CheckoutSession, error)
|
|
|
+ NewPortalSession(params *stripe.BillingPortalSessionParams) (*stripe.BillingPortalSession, error)
|
|
|
+ ListPrices(params *stripe.PriceListParams) ([]*stripe.Price, error)
|
|
|
+ GetCustomer(id string) (*stripe.Customer, error)
|
|
|
+ GetSession(id string) (*stripe.CheckoutSession, error)
|
|
|
+ GetSubscription(id string) (*stripe.Subscription, error)
|
|
|
+ UpdateSubscription(id string, params *stripe.SubscriptionParams) (*stripe.Subscription, error)
|
|
|
+ ConstructWebhookEvent(payload []byte, header string, secret string) (stripe.Event, error)
|
|
|
+}
|
|
|
+
|
|
|
+// realStripeAPI is a thin shim around the Stripe functions to facilitate mocking
|
|
|
+type realStripeAPI struct{}
|
|
|
+
|
|
|
+var _ stripeAPI = (*realStripeAPI)(nil)
|
|
|
+
|
|
|
+func newStripeAPI() stripeAPI {
|
|
|
+ return &realStripeAPI{}
|
|
|
+}
|
|
|
+
|
|
|
+func (s *realStripeAPI) NewCheckoutSession(params *stripe.CheckoutSessionParams) (*stripe.CheckoutSession, error) {
|
|
|
+ return session.New(params)
|
|
|
+}
|
|
|
+
|
|
|
+func (s *realStripeAPI) NewPortalSession(params *stripe.BillingPortalSessionParams) (*stripe.BillingPortalSession, error) {
|
|
|
+ return portalsession.New(params)
|
|
|
+}
|
|
|
+
|
|
|
+func (s *realStripeAPI) ListPrices(params *stripe.PriceListParams) ([]*stripe.Price, error) {
|
|
|
+ prices := make([]*stripe.Price, 0)
|
|
|
+ iter := price.List(params)
|
|
|
+ for iter.Next() {
|
|
|
+ prices = append(prices, iter.Price())
|
|
|
}
|
|
|
if iter.Err() != nil {
|
|
|
- log.Warn("Fetching Stripe prices failed: %s", iter.Err().Error())
|
|
|
return nil, iter.Err()
|
|
|
}
|
|
|
return prices, nil
|
|
|
}
|
|
|
+
|
|
|
+func (s *realStripeAPI) GetCustomer(id string) (*stripe.Customer, error) {
|
|
|
+ return customer.Get(id, nil)
|
|
|
+}
|
|
|
+
|
|
|
+func (s *realStripeAPI) GetSession(id string) (*stripe.CheckoutSession, error) {
|
|
|
+ return session.Get(id, nil)
|
|
|
+}
|
|
|
+
|
|
|
+func (s *realStripeAPI) GetSubscription(id string) (*stripe.Subscription, error) {
|
|
|
+ return subscription.Get(id, nil)
|
|
|
+}
|
|
|
+
|
|
|
+func (s *realStripeAPI) UpdateSubscription(id string, params *stripe.SubscriptionParams) (*stripe.Subscription, error) {
|
|
|
+ return subscription.Update(id, params)
|
|
|
+}
|
|
|
+
|
|
|
+func (s *realStripeAPI) ConstructWebhookEvent(payload []byte, header string, secret string) (stripe.Event, error) {
|
|
|
+ return webhook.ConstructEvent(payload, header, secret)
|
|
|
+}
|