server_payments_test.go 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444
  1. package server
  2. import (
  3. "encoding/json"
  4. "github.com/stretchr/testify/mock"
  5. "github.com/stretchr/testify/require"
  6. "github.com/stripe/stripe-go/v74"
  7. "heckel.io/ntfy/user"
  8. "heckel.io/ntfy/util"
  9. "io"
  10. "path/filepath"
  11. "strings"
  12. "testing"
  13. "time"
  14. )
  15. func TestPayments_Tiers(t *testing.T) {
  16. stripeMock := &testStripeAPI{}
  17. defer stripeMock.AssertExpectations(t)
  18. c := newTestConfigWithAuthFile(t)
  19. c.StripeSecretKey = "secret key"
  20. c.StripeWebhookKey = "webhook key"
  21. c.VisitorRequestLimitReplenish = 12 * time.Hour
  22. c.CacheDuration = 13 * time.Hour
  23. c.AttachmentFileSizeLimit = 111
  24. c.VisitorAttachmentTotalSizeLimit = 222
  25. c.AttachmentExpiryDuration = 123 * time.Second
  26. s := newTestServer(t, c)
  27. s.stripe = stripeMock
  28. // Define how the mock should react
  29. stripeMock.
  30. On("ListPrices", mock.Anything).
  31. Return([]*stripe.Price{
  32. {ID: "price_123", UnitAmount: 500},
  33. {ID: "price_456", UnitAmount: 1000},
  34. {ID: "price_999", UnitAmount: 9999},
  35. }, nil)
  36. // Create tiers
  37. require.Nil(t, s.userManager.CreateTier(&user.Tier{
  38. ID: "ti_1",
  39. Code: "admin",
  40. Name: "Admin",
  41. }))
  42. require.Nil(t, s.userManager.CreateTier(&user.Tier{
  43. ID: "ti_123",
  44. Code: "pro",
  45. Name: "Pro",
  46. MessagesLimit: 1000,
  47. MessagesExpiryDuration: time.Hour,
  48. EmailsLimit: 123,
  49. ReservationsLimit: 777,
  50. AttachmentFileSizeLimit: 999,
  51. AttachmentTotalSizeLimit: 888,
  52. AttachmentExpiryDuration: time.Minute,
  53. StripePriceID: "price_123",
  54. }))
  55. require.Nil(t, s.userManager.CreateTier(&user.Tier{
  56. ID: "ti_444",
  57. Code: "business",
  58. Name: "Business",
  59. MessagesLimit: 2000,
  60. MessagesExpiryDuration: 10 * time.Hour,
  61. EmailsLimit: 123123,
  62. ReservationsLimit: 777333,
  63. AttachmentFileSizeLimit: 999111,
  64. AttachmentTotalSizeLimit: 888111,
  65. AttachmentExpiryDuration: time.Hour,
  66. StripePriceID: "price_456",
  67. }))
  68. response := request(t, s, "GET", "/v1/tiers", "", nil)
  69. require.Equal(t, 200, response.Code)
  70. var tiers []apiAccountBillingTier
  71. require.Nil(t, json.NewDecoder(response.Body).Decode(&tiers))
  72. require.Equal(t, 3, len(tiers))
  73. // Free tier
  74. tier := tiers[0]
  75. require.Equal(t, "", tier.Code)
  76. require.Equal(t, "", tier.Name)
  77. require.Equal(t, "ip", tier.Limits.Basis)
  78. require.Equal(t, int64(0), tier.Limits.Reservations)
  79. require.Equal(t, int64(2), tier.Limits.Messages) // :-(
  80. require.Equal(t, int64(13*3600), tier.Limits.MessagesExpiryDuration)
  81. require.Equal(t, int64(24), tier.Limits.Emails)
  82. require.Equal(t, int64(111), tier.Limits.AttachmentFileSize)
  83. require.Equal(t, int64(222), tier.Limits.AttachmentTotalSize)
  84. require.Equal(t, int64(123), tier.Limits.AttachmentExpiryDuration)
  85. // Admin tier is not included, because it is not paid!
  86. tier = tiers[1]
  87. require.Equal(t, "pro", tier.Code)
  88. require.Equal(t, "Pro", tier.Name)
  89. require.Equal(t, "tier", tier.Limits.Basis)
  90. require.Equal(t, int64(777), tier.Limits.Reservations)
  91. require.Equal(t, int64(1000), tier.Limits.Messages)
  92. require.Equal(t, int64(3600), tier.Limits.MessagesExpiryDuration)
  93. require.Equal(t, int64(123), tier.Limits.Emails)
  94. require.Equal(t, int64(999), tier.Limits.AttachmentFileSize)
  95. require.Equal(t, int64(888), tier.Limits.AttachmentTotalSize)
  96. require.Equal(t, int64(60), tier.Limits.AttachmentExpiryDuration)
  97. tier = tiers[2]
  98. require.Equal(t, "business", tier.Code)
  99. require.Equal(t, "Business", tier.Name)
  100. require.Equal(t, "tier", tier.Limits.Basis)
  101. require.Equal(t, int64(777333), tier.Limits.Reservations)
  102. require.Equal(t, int64(2000), tier.Limits.Messages)
  103. require.Equal(t, int64(36000), tier.Limits.MessagesExpiryDuration)
  104. require.Equal(t, int64(123123), tier.Limits.Emails)
  105. require.Equal(t, int64(999111), tier.Limits.AttachmentFileSize)
  106. require.Equal(t, int64(888111), tier.Limits.AttachmentTotalSize)
  107. require.Equal(t, int64(3600), tier.Limits.AttachmentExpiryDuration)
  108. }
  109. func TestPayments_SubscriptionCreate_NotAStripeCustomer_Success(t *testing.T) {
  110. stripeMock := &testStripeAPI{}
  111. defer stripeMock.AssertExpectations(t)
  112. c := newTestConfigWithAuthFile(t)
  113. c.StripeSecretKey = "secret key"
  114. c.StripeWebhookKey = "webhook key"
  115. s := newTestServer(t, c)
  116. s.stripe = stripeMock
  117. // Define how the mock should react
  118. stripeMock.
  119. On("NewCheckoutSession", mock.Anything).
  120. Return(&stripe.CheckoutSession{URL: "https://billing.stripe.com/abc/def"}, nil)
  121. // Create tier and user
  122. require.Nil(t, s.userManager.CreateTier(&user.Tier{
  123. Code: "pro",
  124. StripePriceID: "price_123",
  125. }))
  126. require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser))
  127. // Create subscription
  128. response := request(t, s, "POST", "/v1/account/billing/subscription", `{"tier": "pro"}`, map[string]string{
  129. "Authorization": util.BasicAuth("phil", "phil"),
  130. })
  131. require.Equal(t, 200, response.Code)
  132. redirectResponse, err := util.UnmarshalJSON[apiAccountBillingSubscriptionCreateResponse](io.NopCloser(response.Body))
  133. require.Nil(t, err)
  134. require.Equal(t, "https://billing.stripe.com/abc/def", redirectResponse.RedirectURL)
  135. }
  136. func TestPayments_SubscriptionCreate_StripeCustomer_Success(t *testing.T) {
  137. stripeMock := &testStripeAPI{}
  138. defer stripeMock.AssertExpectations(t)
  139. c := newTestConfigWithAuthFile(t)
  140. c.StripeSecretKey = "secret key"
  141. c.StripeWebhookKey = "webhook key"
  142. s := newTestServer(t, c)
  143. s.stripe = stripeMock
  144. // Define how the mock should react
  145. stripeMock.
  146. On("GetCustomer", "acct_123").
  147. Return(&stripe.Customer{Subscriptions: &stripe.SubscriptionList{}}, nil)
  148. stripeMock.
  149. On("NewCheckoutSession", mock.Anything).
  150. Return(&stripe.CheckoutSession{URL: "https://billing.stripe.com/abc/def"}, nil)
  151. // Create tier and user
  152. require.Nil(t, s.userManager.CreateTier(&user.Tier{
  153. Code: "pro",
  154. StripePriceID: "price_123",
  155. }))
  156. require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser))
  157. u, err := s.userManager.User("phil")
  158. require.Nil(t, err)
  159. billing := &user.Billing{
  160. StripeCustomerID: "acct_123",
  161. }
  162. require.Nil(t, s.userManager.ChangeBilling(u.Name, billing))
  163. // Create subscription
  164. response := request(t, s, "POST", "/v1/account/billing/subscription", `{"tier": "pro"}`, map[string]string{
  165. "Authorization": util.BasicAuth("phil", "phil"),
  166. })
  167. require.Equal(t, 200, response.Code)
  168. redirectResponse, err := util.UnmarshalJSON[apiAccountBillingSubscriptionCreateResponse](io.NopCloser(response.Body))
  169. require.Nil(t, err)
  170. require.Equal(t, "https://billing.stripe.com/abc/def", redirectResponse.RedirectURL)
  171. }
  172. func TestPayments_AccountDelete_Cancels_Subscription(t *testing.T) {
  173. stripeMock := &testStripeAPI{}
  174. defer stripeMock.AssertExpectations(t)
  175. c := newTestConfigWithAuthFile(t)
  176. c.EnableSignup = true
  177. c.StripeSecretKey = "secret key"
  178. c.StripeWebhookKey = "webhook key"
  179. s := newTestServer(t, c)
  180. s.stripe = stripeMock
  181. // Define how the mock should react
  182. stripeMock.
  183. On("CancelSubscription", "sub_123").
  184. Return(&stripe.Subscription{}, nil)
  185. // Create tier and user
  186. require.Nil(t, s.userManager.CreateTier(&user.Tier{
  187. Code: "pro",
  188. StripePriceID: "price_123",
  189. }))
  190. require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser))
  191. u, err := s.userManager.User("phil")
  192. require.Nil(t, err)
  193. billing := &user.Billing{
  194. StripeCustomerID: "acct_123",
  195. StripeSubscriptionID: "sub_123",
  196. }
  197. require.Nil(t, s.userManager.ChangeBilling(u.Name, billing))
  198. // Delete account
  199. rr := request(t, s, "DELETE", "/v1/account", `{"password": "phil"}`, map[string]string{
  200. "Authorization": util.BasicAuth("phil", "phil"),
  201. })
  202. require.Equal(t, 200, rr.Code)
  203. rr = request(t, s, "GET", "/v1/account", "", map[string]string{
  204. "Authorization": util.BasicAuth("phil", "mypass"),
  205. })
  206. require.Equal(t, 401, rr.Code)
  207. }
  208. func TestPayments_Webhook_Subscription_Updated_Downgrade_From_PastDue_To_Active(t *testing.T) {
  209. // This tests incoming webhooks from Stripe to update a subscription:
  210. // - All Stripe columns are updated in the user table
  211. // - When downgrading, excess reservations are deleted, including messages and attachments in
  212. // the corresponding topics
  213. stripeMock := &testStripeAPI{}
  214. defer stripeMock.AssertExpectations(t)
  215. c := newTestConfigWithAuthFile(t)
  216. c.StripeSecretKey = "secret key"
  217. c.StripeWebhookKey = "webhook key"
  218. s := newTestServer(t, c)
  219. s.stripe = stripeMock
  220. // Define how the mock should react
  221. stripeMock.
  222. On("ConstructWebhookEvent", mock.Anything, "stripe signature", "webhook key").
  223. Return(jsonToStripeEvent(t, subscriptionUpdatedEventJSON), nil)
  224. // Create a user with a Stripe subscription and 3 reservations
  225. require.Nil(t, s.userManager.CreateTier(&user.Tier{
  226. Code: "starter",
  227. StripePriceID: "price_1234", // !
  228. ReservationsLimit: 1, // !
  229. MessagesLimit: 100,
  230. MessagesExpiryDuration: time.Hour,
  231. AttachmentExpiryDuration: time.Hour,
  232. AttachmentFileSizeLimit: 1000000,
  233. AttachmentTotalSizeLimit: 1000000,
  234. }))
  235. require.Nil(t, s.userManager.CreateTier(&user.Tier{
  236. Code: "pro",
  237. StripePriceID: "price_1111", // !
  238. ReservationsLimit: 3, // !
  239. MessagesLimit: 200,
  240. MessagesExpiryDuration: time.Hour,
  241. AttachmentExpiryDuration: time.Hour,
  242. AttachmentFileSizeLimit: 1000000,
  243. AttachmentTotalSizeLimit: 1000000,
  244. }))
  245. require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser))
  246. require.Nil(t, s.userManager.ChangeTier("phil", "pro"))
  247. require.Nil(t, s.userManager.AddReservation("phil", "atopic", user.PermissionDenyAll))
  248. require.Nil(t, s.userManager.AddReservation("phil", "ztopic", user.PermissionDenyAll))
  249. // Add billing details
  250. u, err := s.userManager.User("phil")
  251. require.Nil(t, err)
  252. billing := &user.Billing{
  253. StripeCustomerID: "acct_5555",
  254. StripeSubscriptionID: "sub_1234",
  255. StripeSubscriptionStatus: stripe.SubscriptionStatusPastDue,
  256. StripeSubscriptionPaidUntil: time.Unix(123, 0),
  257. StripeSubscriptionCancelAt: time.Unix(456, 0),
  258. }
  259. require.Nil(t, s.userManager.ChangeBilling(u.Name, billing))
  260. // Add some messages to "atopic" and "ztopic", everything in "ztopic" will be deleted
  261. rr := request(t, s, "PUT", "/atopic", "some aaa message", map[string]string{
  262. "Authorization": util.BasicAuth("phil", "phil"),
  263. })
  264. require.Equal(t, 200, rr.Code)
  265. rr = request(t, s, "PUT", "/atopic", strings.Repeat("a", 5000), map[string]string{
  266. "Authorization": util.BasicAuth("phil", "phil"),
  267. })
  268. require.Equal(t, 200, rr.Code)
  269. a2 := toMessage(t, rr.Body.String())
  270. require.FileExists(t, filepath.Join(s.config.AttachmentCacheDir, a2.ID))
  271. rr = request(t, s, "PUT", "/ztopic", "some zzz message", map[string]string{
  272. "Authorization": util.BasicAuth("phil", "phil"),
  273. })
  274. require.Equal(t, 200, rr.Code)
  275. rr = request(t, s, "PUT", "/ztopic", strings.Repeat("z", 5000), map[string]string{
  276. "Authorization": util.BasicAuth("phil", "phil"),
  277. })
  278. require.Equal(t, 200, rr.Code)
  279. z2 := toMessage(t, rr.Body.String())
  280. require.FileExists(t, filepath.Join(s.config.AttachmentCacheDir, z2.ID))
  281. // Call the webhook: This does all the magic
  282. rr = request(t, s, "POST", "/v1/account/billing/webhook", "dummy", map[string]string{
  283. "Stripe-Signature": "stripe signature",
  284. })
  285. require.Equal(t, 200, rr.Code)
  286. // Verify that database columns were updated
  287. u, err = s.userManager.User("phil")
  288. require.Nil(t, err)
  289. require.Equal(t, "starter", u.Tier.Code) // Not "pro"
  290. require.Equal(t, "acct_5555", u.Billing.StripeCustomerID)
  291. require.Equal(t, "sub_1234", u.Billing.StripeSubscriptionID)
  292. require.Equal(t, stripe.SubscriptionStatusActive, u.Billing.StripeSubscriptionStatus) // Not "past_due"
  293. require.Equal(t, int64(1674268231), u.Billing.StripeSubscriptionPaidUntil.Unix()) // Updated
  294. require.Equal(t, int64(1674299999), u.Billing.StripeSubscriptionCancelAt.Unix()) // Updated
  295. // Verify that reservations were deleted
  296. r, err := s.userManager.Reservations("phil")
  297. require.Nil(t, err)
  298. require.Equal(t, 1, len(r)) // "ztopic" reservation was deleted
  299. require.Equal(t, "atopic", r[0].Topic)
  300. // Verify that messages and attachments were deleted
  301. time.Sleep(time.Second)
  302. s.execManager()
  303. ms, err := s.messageCache.Messages("atopic", sinceAllMessages, false)
  304. require.Nil(t, err)
  305. require.Equal(t, 2, len(ms))
  306. require.FileExists(t, filepath.Join(s.config.AttachmentCacheDir, a2.ID))
  307. ms, err = s.messageCache.Messages("ztopic", sinceAllMessages, false)
  308. require.Nil(t, err)
  309. require.Equal(t, 0, len(ms))
  310. require.NoFileExists(t, filepath.Join(s.config.AttachmentCacheDir, z2.ID))
  311. }
  312. type testStripeAPI struct {
  313. mock.Mock
  314. }
  315. var _ stripeAPI = (*testStripeAPI)(nil)
  316. func (s *testStripeAPI) NewCheckoutSession(params *stripe.CheckoutSessionParams) (*stripe.CheckoutSession, error) {
  317. args := s.Called(params)
  318. return args.Get(0).(*stripe.CheckoutSession), args.Error(1)
  319. }
  320. func (s *testStripeAPI) NewPortalSession(params *stripe.BillingPortalSessionParams) (*stripe.BillingPortalSession, error) {
  321. args := s.Called(params)
  322. return args.Get(0).(*stripe.BillingPortalSession), args.Error(1)
  323. }
  324. func (s *testStripeAPI) ListPrices(params *stripe.PriceListParams) ([]*stripe.Price, error) {
  325. args := s.Called(params)
  326. return args.Get(0).([]*stripe.Price), args.Error(1)
  327. }
  328. func (s *testStripeAPI) GetCustomer(id string) (*stripe.Customer, error) {
  329. args := s.Called(id)
  330. return args.Get(0).(*stripe.Customer), args.Error(1)
  331. }
  332. func (s *testStripeAPI) GetSession(id string) (*stripe.CheckoutSession, error) {
  333. args := s.Called(id)
  334. return args.Get(0).(*stripe.CheckoutSession), args.Error(1)
  335. }
  336. func (s *testStripeAPI) GetSubscription(id string) (*stripe.Subscription, error) {
  337. args := s.Called(id)
  338. return args.Get(0).(*stripe.Subscription), args.Error(1)
  339. }
  340. func (s *testStripeAPI) UpdateCustomer(id string, params *stripe.CustomerParams) (*stripe.Customer, error) {
  341. args := s.Called(id)
  342. return args.Get(0).(*stripe.Customer), args.Error(1)
  343. }
  344. func (s *testStripeAPI) UpdateSubscription(id string, params *stripe.SubscriptionParams) (*stripe.Subscription, error) {
  345. args := s.Called(id)
  346. return args.Get(0).(*stripe.Subscription), args.Error(1)
  347. }
  348. func (s *testStripeAPI) CancelSubscription(id string) (*stripe.Subscription, error) {
  349. args := s.Called(id)
  350. return args.Get(0).(*stripe.Subscription), args.Error(1)
  351. }
  352. func (s *testStripeAPI) ConstructWebhookEvent(payload []byte, header string, secret string) (stripe.Event, error) {
  353. args := s.Called(payload, header, secret)
  354. return args.Get(0).(stripe.Event), args.Error(1)
  355. }
  356. func jsonToStripeEvent(t *testing.T, v string) stripe.Event {
  357. var e stripe.Event
  358. if err := json.Unmarshal([]byte(v), &e); err != nil {
  359. t.Fatal(err)
  360. }
  361. return e
  362. }
  // subscriptionUpdatedEventJSON is a minimal Stripe "customer.subscription.updated" event for
  // subscription "sub_1234" of customer "acct_5555". It sets status "active", current_period_end
  // 1674268231, cancel_at 1674299999, and a single item priced "price_1234".
  // NOTE(review): in this paste every line of the raw string carries a "  NNN. " line-number
  // prefix, which makes the literal invalid JSON; strip the prefixes when restoring the file.
  363. const subscriptionUpdatedEventJSON = `
  364. {
  365. "type": "customer.subscription.updated",
  366. "data": {
  367. "object": {
  368. "id": "sub_1234",
  369. "customer": "acct_5555",
  370. "status": "active",
  371. "current_period_end": 1674268231,
  372. "cancel_at": 1674299999,
  373. "items": {
  374. "data": [
  375. {
  376. "price": {
  377. "id": "price_1234"
  378. }
  379. }
  380. ]
  381. }
  382. }
  383. }
  384. }`