Просмотр исходного кода

Deleting account deletes subscription

binwiederhier 3 лет назад
Родитель
Сommit
45b97c7054
5 измененных файлов с 67 добавлено и 1 удалено
  1. 3 1
      server/server.go
  2. 10 0
      server/server_account.go
  3. 5 0
      server/server_payments.go
  4. 47 0
      server/server_payments_test.go
  5. 2 0
      server/visitor.go

+ 3 - 1
server/server.go

@@ -36,8 +36,10 @@ import (
 
 /*
 	TODO
+		races:
+		- v.user --> see publishSyncEventAsync() test
+
 		payments:
-		- delete subscription when account deleted
 		- delete messages + reserved topics on ResetTier
 
 		Limits & rate limiting:

+ 10 - 0
server/server_account.go

@@ -119,6 +119,16 @@ func (s *Server) handleAccountGet(w http.ResponseWriter, _ *http.Request, v *vis
 }
 
 func (s *Server) handleAccountDelete(w http.ResponseWriter, _ *http.Request, v *visitor) error {
+	if v.user.Billing.StripeCustomerID != "" {
+		log.Info("Deleting user %s (billing customer: %s, billing subscription: %s)", v.user.Name, v.user.Billing.StripeCustomerID, v.user.Billing.StripeSubscriptionID)
+		if v.user.Billing.StripeSubscriptionID != "" {
+			if _, err := s.stripe.CancelSubscription(v.user.Billing.StripeSubscriptionID); err != nil {
+				return err
+			}
+		}
+	} else {
+		log.Info("Deleting user %s", v.user.Name)
+	}
 	if err := s.userManager.RemoveUser(v.user.Name); err != nil {
 		return err
 	}

+ 5 - 0
server/server_payments.go

@@ -359,6 +359,7 @@ type stripeAPI interface {
 	GetSession(id string) (*stripe.CheckoutSession, error)
 	GetSubscription(id string) (*stripe.Subscription, error)
 	UpdateSubscription(id string, params *stripe.SubscriptionParams) (*stripe.Subscription, error)
+	CancelSubscription(id string) (*stripe.Subscription, error)
 	ConstructWebhookEvent(payload []byte, header string, secret string) (stripe.Event, error)
 }
 
@@ -407,6 +408,10 @@ func (s *realStripeAPI) UpdateSubscription(id string, params *stripe.Subscriptio
 	return subscription.Update(id, params)
 }
 
+func (s *realStripeAPI) CancelSubscription(id string) (*stripe.Subscription, error) {
+	return subscription.Cancel(id, nil)
+}
+
 func (s *realStripeAPI) ConstructWebhookEvent(payload []byte, header string, secret string) (stripe.Event, error) {
 	return webhook.ConstructEvent(payload, header, secret)
 }

+ 47 - 0
server/server_payments_test.go

@@ -83,6 +83,48 @@ func TestPayments_SubscriptionCreate_StripeCustomer_Success(t *testing.T) {
 	require.Equal(t, "https://billing.stripe.com/abc/def", redirectResponse.RedirectURL)
 }
 
+func TestPayments_AccountDelete_Cancels_Subscription(t *testing.T) {
+	stripeMock := &testStripeAPI{}
+	defer stripeMock.AssertExpectations(t)
+
+	c := newTestConfigWithAuthFile(t)
+	c.EnableSignup = true
+	c.StripeSecretKey = "secret key"
+	c.StripeWebhookKey = "webhook key"
+	s := newTestServer(t, c)
+	s.stripe = stripeMock
+
+	// Define how the mock should react
+	stripeMock.
+		On("CancelSubscription", "sub_123").
+		Return(&stripe.Subscription{}, nil)
+
+	// Create tier and user
+	require.Nil(t, s.userManager.CreateTier(&user.Tier{
+		Code:          "pro",
+		StripePriceID: "price_123",
+	}))
+	require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser, "unit-test"))
+
+	u, err := s.userManager.User("phil")
+	require.Nil(t, err)
+
+	u.Billing.StripeCustomerID = "acct_123"
+	u.Billing.StripeSubscriptionID = "sub_123"
+	require.Nil(t, s.userManager.ChangeBilling(u))
+
+	// Delete account
+	rr := request(t, s, "DELETE", "/v1/account", "", map[string]string{
+		"Authorization": util.BasicAuth("phil", "phil"),
+	})
+	require.Equal(t, 200, rr.Code)
+
+	rr = request(t, s, "GET", "/v1/account", "", map[string]string{
+		"Authorization": util.BasicAuth("phil", "mypass"),
+	})
+	require.Equal(t, 401, rr.Code)
+}
+
 type testStripeAPI struct {
 	mock.Mock
 }
@@ -122,6 +164,11 @@ func (s *testStripeAPI) UpdateSubscription(id string, params *stripe.Subscriptio
 	return args.Get(0).(*stripe.Subscription), args.Error(1)
 }
 
+func (s *testStripeAPI) CancelSubscription(id string) (*stripe.Subscription, error) {
+	args := s.Called(id)
+	return args.Get(0).(*stripe.Subscription), args.Error(1)
+}
+
 func (s *testStripeAPI) ConstructWebhookEvent(payload []byte, header string, secret string) (stripe.Event, error) {
 	args := s.Called(payload, header, secret)
 	return args.Get(0).(stripe.Event), args.Error(1)

+ 2 - 0
server/visitor.go

@@ -213,6 +213,8 @@ func (v *visitor) ResetStats() {
 }
 
 func (v *visitor) Limits() *visitorLimits {
+	v.mu.Lock()
+	defer v.mu.Unlock()
 	limits := defaultVisitorLimits(v.config)
 	if v.user != nil && v.user.Tier != nil {
 		limits.Basis = visitorLimitBasisTier