|
|
@@ -1,6 +1,7 @@
|
|
|
package server
|
|
|
|
|
|
import (
|
|
|
+ "bytes"
|
|
|
"encoding/json"
|
|
|
"errors"
|
|
|
"fmt"
|
|
|
@@ -11,19 +12,15 @@ import (
|
|
|
"github.com/stripe/stripe-go/v74/price"
|
|
|
"github.com/stripe/stripe-go/v74/subscription"
|
|
|
"github.com/stripe/stripe-go/v74/webhook"
|
|
|
- "github.com/tidwall/gjson"
|
|
|
"heckel.io/ntfy/log"
|
|
|
"heckel.io/ntfy/user"
|
|
|
"heckel.io/ntfy/util"
|
|
|
+ "io"
|
|
|
"net/http"
|
|
|
"net/netip"
|
|
|
"time"
|
|
|
)
|
|
|
|
|
|
-const (
|
|
|
- stripeBodyBytesLimit = 16384
|
|
|
-)
|
|
|
-
|
|
|
var (
|
|
|
errNotAPaidTier = errors.New("tier does not have billing price identifier")
|
|
|
errMultipleBillingSubscriptions = errors.New("cannot have multiple billing subscriptions")
|
|
|
@@ -52,23 +49,15 @@ func (s *Server) handleBillingTiersGet(w http.ResponseWriter, _ *http.Request, _
|
|
|
},
|
|
|
},
|
|
|
}
|
|
|
+ prices, err := s.priceCache.Value()
|
|
|
+ if err != nil {
|
|
|
+ return err
|
|
|
+ }
|
|
|
for _, tier := range tiers {
|
|
|
- if tier.StripePriceID == "" {
|
|
|
+ priceStr, ok := prices[tier.StripePriceID]
|
|
|
+ if tier.StripePriceID == "" || !ok {
|
|
|
continue
|
|
|
}
|
|
|
- priceStr, ok := s.priceCache[tier.StripePriceID]
|
|
|
- if !ok {
|
|
|
- p, err := price.Get(tier.StripePriceID, nil)
|
|
|
- if err != nil {
|
|
|
- return err
|
|
|
- }
|
|
|
- if p.UnitAmount%100 == 0 {
|
|
|
- priceStr = fmt.Sprintf("$%d", p.UnitAmount/100)
|
|
|
- } else {
|
|
|
- priceStr = fmt.Sprintf("$%.2f", float64(p.UnitAmount)/100)
|
|
|
- }
|
|
|
- s.priceCache[tier.StripePriceID] = priceStr // FIXME race, make this sync.Map or something
|
|
|
- }
|
|
|
response = append(response, &apiAccountBillingTier{
|
|
|
Code: tier.Code,
|
|
|
Name: tier.Name,
|
|
|
@@ -84,12 +73,7 @@ func (s *Server) handleBillingTiersGet(w http.ResponseWriter, _ *http.Request, _
|
|
|
},
|
|
|
})
|
|
|
}
|
|
|
- w.Header().Set("Content-Type", "application/json")
|
|
|
- w.Header().Set("Access-Control-Allow-Origin", "*") // FIXME remove this
|
|
|
- if err := json.NewEncoder(w).Encode(response); err != nil {
|
|
|
- return err
|
|
|
- }
|
|
|
- return nil
|
|
|
+ return s.writeJSON(w, response)
|
|
|
}
|
|
|
|
|
|
// handleAccountBillingSubscriptionCreate creates a Stripe checkout flow to create a user subscription. The tier
|
|
|
@@ -143,12 +127,7 @@ func (s *Server) handleAccountBillingSubscriptionCreate(w http.ResponseWriter, r
|
|
|
response := &apiAccountBillingSubscriptionCreateResponse{
|
|
|
RedirectURL: sess.URL,
|
|
|
}
|
|
|
- w.Header().Set("Content-Type", "application/json")
|
|
|
- w.Header().Set("Access-Control-Allow-Origin", "*") // FIXME remove this
|
|
|
- if err := json.NewEncoder(w).Encode(response); err != nil {
|
|
|
- return err
|
|
|
- }
|
|
|
- return nil
|
|
|
+ return s.writeJSON(w, response)
|
|
|
}
|
|
|
|
|
|
func (s *Server) handleAccountBillingSubscriptionCreateSuccess(w http.ResponseWriter, r *http.Request, _ *visitor) error {
|
|
|
@@ -219,12 +198,7 @@ func (s *Server) handleAccountBillingSubscriptionUpdate(w http.ResponseWriter, r
|
|
|
if err != nil {
|
|
|
return err
|
|
|
}
|
|
|
- w.Header().Set("Content-Type", "application/json")
|
|
|
- w.Header().Set("Access-Control-Allow-Origin", "*") // FIXME remove this
|
|
|
- if err := json.NewEncoder(w).Encode(newSuccessResponse()); err != nil {
|
|
|
- return err
|
|
|
- }
|
|
|
- return nil
|
|
|
+ return s.writeJSON(w, newSuccessResponse())
|
|
|
}
|
|
|
|
|
|
// handleAccountBillingSubscriptionDelete facilitates downgrading a paid user to a tier-less user,
|
|
|
@@ -239,12 +213,7 @@ func (s *Server) handleAccountBillingSubscriptionDelete(w http.ResponseWriter, r
|
|
|
return err
|
|
|
}
|
|
|
}
|
|
|
- w.Header().Set("Content-Type", "application/json")
|
|
|
- w.Header().Set("Access-Control-Allow-Origin", "*") // FIXME remove this
|
|
|
- if err := json.NewEncoder(w).Encode(newSuccessResponse()); err != nil {
|
|
|
- return err
|
|
|
- }
|
|
|
- return nil
|
|
|
+ return s.writeJSON(w, newSuccessResponse())
|
|
|
}
|
|
|
|
|
|
func (s *Server) handleAccountBillingPortalSessionCreate(w http.ResponseWriter, r *http.Request, v *visitor) error {
|
|
|
@@ -262,12 +231,7 @@ func (s *Server) handleAccountBillingPortalSessionCreate(w http.ResponseWriter,
|
|
|
response := &apiAccountBillingPortalRedirectResponse{
|
|
|
RedirectURL: ps.URL,
|
|
|
}
|
|
|
- w.Header().Set("Content-Type", "application/json")
|
|
|
- w.Header().Set("Access-Control-Allow-Origin", "*") // FIXME remove this
|
|
|
- if err := json.NewEncoder(w).Encode(response); err != nil {
|
|
|
- return err
|
|
|
- }
|
|
|
- return nil
|
|
|
+ return s.writeJSON(w, response)
|
|
|
}
|
|
|
|
|
|
// handleAccountBillingWebhook handles incoming Stripe webhooks. It mainly keeps the local user database in sync
|
|
|
@@ -278,7 +242,7 @@ func (s *Server) handleAccountBillingWebhook(w http.ResponseWriter, r *http.Requ
|
|
|
if stripeSignature == "" {
|
|
|
return errHTTPBadRequestBillingRequestInvalid
|
|
|
}
|
|
|
- body, err := util.Peek(r.Body, stripeBodyBytesLimit)
|
|
|
+ body, err := util.Peek(r.Body, jsonBodyBytesLimit)
|
|
|
if err != nil {
|
|
|
return err
|
|
|
} else if body.LimitReached {
|
|
|
@@ -302,25 +266,23 @@ func (s *Server) handleAccountBillingWebhook(w http.ResponseWriter, r *http.Requ
|
|
|
}
|
|
|
|
|
|
func (s *Server) handleAccountBillingWebhookSubscriptionUpdated(event json.RawMessage) error {
|
|
|
- subscriptionID := gjson.GetBytes(event, "id")
|
|
|
- customerID := gjson.GetBytes(event, "customer")
|
|
|
- status := gjson.GetBytes(event, "status")
|
|
|
- currentPeriodEnd := gjson.GetBytes(event, "current_period_end")
|
|
|
- cancelAt := gjson.GetBytes(event, "cancel_at")
|
|
|
- priceID := gjson.GetBytes(event, "items.data.0.price.id")
|
|
|
- if !subscriptionID.Exists() || !status.Exists() || !currentPeriodEnd.Exists() || !cancelAt.Exists() || !priceID.Exists() {
|
|
|
+ r, err := util.UnmarshalJSON[apiStripeSubscriptionUpdatedEvent](io.NopCloser(bytes.NewReader(event)))
|
|
|
+ if err != nil {
|
|
|
+ return err
|
|
|
+ } else if r.ID == "" || r.Customer == "" || r.Status == "" || r.CurrentPeriodEnd == 0 || r.Items == nil || len(r.Items.Data) != 1 || r.Items.Data[0].Price == nil || r.Items.Data[0].Price.ID == "" {
|
|
|
return errHTTPBadRequestBillingRequestInvalid
|
|
|
}
|
|
|
- log.Info("Stripe: customer %s: Updating subscription to status %s, with price %s", customerID.String(), status, priceID)
|
|
|
- u, err := s.userManager.UserByStripeCustomer(customerID.String())
|
|
|
+ subscriptionID, priceID := r.ID, r.Items.Data[0].Price.ID
|
|
|
+ log.Info("Stripe: customer %s: Updating subscription to status %s, with price %s", r.Customer, r.Status, priceID)
|
|
|
+ u, err := s.userManager.UserByStripeCustomer(r.Customer)
|
|
|
if err != nil {
|
|
|
return err
|
|
|
}
|
|
|
- tier, err := s.userManager.TierByStripePrice(priceID.String())
|
|
|
+ tier, err := s.userManager.TierByStripePrice(priceID)
|
|
|
if err != nil {
|
|
|
return err
|
|
|
}
|
|
|
- if err := s.updateSubscriptionAndTier(u, customerID.String(), subscriptionID.String(), status.String(), currentPeriodEnd.Int(), cancelAt.Int(), tier.Code); err != nil {
|
|
|
+ if err := s.updateSubscriptionAndTier(u, r.Customer, subscriptionID, r.Status, r.CurrentPeriodEnd, r.CancelAt, tier.Code); err != nil {
|
|
|
return err
|
|
|
}
|
|
|
s.publishSyncEventAsync(s.visitorFromUser(u, netip.IPv4Unspecified()))
|
|
|
@@ -328,16 +290,18 @@ func (s *Server) handleAccountBillingWebhookSubscriptionUpdated(event json.RawMe
|
|
|
}
|
|
|
|
|
|
func (s *Server) handleAccountBillingWebhookSubscriptionDeleted(event json.RawMessage) error {
|
|
|
- customerID := gjson.GetBytes(event, "customer")
|
|
|
- if !customerID.Exists() {
|
|
|
+ r, err := util.UnmarshalJSON[apiStripeSubscriptionDeletedEvent](io.NopCloser(bytes.NewReader(event)))
|
|
|
+ if err != nil {
|
|
|
+ return err
|
|
|
+ } else if r.Customer == "" {
|
|
|
return errHTTPBadRequestBillingRequestInvalid
|
|
|
}
|
|
|
- log.Info("Stripe: customer %s: subscription deleted, downgrading to unpaid tier", customerID.String())
|
|
|
- u, err := s.userManager.UserByStripeCustomer(customerID.String())
|
|
|
+ log.Info("Stripe: customer %s: subscription deleted, downgrading to unpaid tier", r.Customer)
|
|
|
+ u, err := s.userManager.UserByStripeCustomer(r.Customer)
|
|
|
if err != nil {
|
|
|
return err
|
|
|
}
|
|
|
- if err := s.updateSubscriptionAndTier(u, customerID.String(), "", "", 0, 0, ""); err != nil {
|
|
|
+ if err := s.updateSubscriptionAndTier(u, r.Customer, "", "", 0, 0, ""); err != nil {
|
|
|
return err
|
|
|
}
|
|
|
s.publishSyncEventAsync(s.visitorFromUser(u, netip.IPv4Unspecified()))
|
|
|
@@ -364,3 +328,27 @@ func (s *Server) updateSubscriptionAndTier(u *user.User, customerID, subscriptio
|
|
|
}
|
|
|
return nil
|
|
|
}
|
|
|
+
|
|
|
+// fetchStripePrices contacts the Stripe API to retrieve all prices. This is used by the server to cache the prices
|
|
|
+// in memory, and ultimately for the web app to display the price table.
|
|
|
+func fetchStripePrices() (map[string]string, error) {
|
|
|
+ log.Debug("Caching prices from Stripe API")
|
|
|
+ prices := make(map[string]string)
|
|
|
+ iter := price.List(&stripe.PriceListParams{
|
|
|
+ Active: stripe.Bool(true),
|
|
|
+ })
|
|
|
+ for iter.Next() {
|
|
|
+ p := iter.Price()
|
|
|
+ if p.UnitAmount%100 == 0 {
|
|
|
+ prices[p.ID] = fmt.Sprintf("$%d", p.UnitAmount/100)
|
|
|
+ } else {
|
|
|
+ prices[p.ID] = fmt.Sprintf("$%.2f", float64(p.UnitAmount)/100)
|
|
|
+ }
|
|
|
+ log.Trace("- Caching price %s = %v", p.ID, prices[p.ID])
|
|
|
+ }
|
|
|
+ if iter.Err() != nil {
|
|
|
+ log.Warn("Fetching Stripe prices failed: %s", iter.Err().Error())
|
|
|
+ return nil, iter.Err()
|
|
|
+ }
|
|
|
+ return prices, nil
|
|
|
+}
|