|
@@ -6,6 +6,7 @@ import (
|
|
|
"errors"
|
|
"errors"
|
|
|
"fmt"
|
|
"fmt"
|
|
|
_ "github.com/mattn/go-sqlite3" // SQLite driver
|
|
_ "github.com/mattn/go-sqlite3" // SQLite driver
|
|
|
|
|
+ "github.com/stripe/stripe-go/v74"
|
|
|
"golang.org/x/crypto/bcrypt"
|
|
"golang.org/x/crypto/bcrypt"
|
|
|
"heckel.io/ntfy/log"
|
|
"heckel.io/ntfy/log"
|
|
|
"heckel.io/ntfy/util"
|
|
"heckel.io/ntfy/util"
|
|
@@ -60,7 +61,9 @@ const (
|
|
|
stats_messages INT NOT NULL DEFAULT (0),
|
|
stats_messages INT NOT NULL DEFAULT (0),
|
|
|
stats_emails INT NOT NULL DEFAULT (0),
|
|
stats_emails INT NOT NULL DEFAULT (0),
|
|
|
stripe_customer_id TEXT,
|
|
stripe_customer_id TEXT,
|
|
|
- stripe_subscription_id TEXT,
|
|
|
|
|
|
|
+ stripe_subscription_id TEXT,
|
|
|
|
|
+ stripe_subscription_status TEXT,
|
|
|
|
|
+ stripe_subscription_paid_until INT,
|
|
|
created_by TEXT NOT NULL,
|
|
created_by TEXT NOT NULL,
|
|
|
created_at INT NOT NULL,
|
|
created_at INT NOT NULL,
|
|
|
last_seen INT NOT NULL,
|
|
last_seen INT NOT NULL,
|
|
@@ -100,20 +103,20 @@ const (
|
|
|
`
|
|
`
|
|
|
|
|
|
|
|
selectUserByNameQuery = `
|
|
selectUserByNameQuery = `
|
|
|
- SELECT u.user, u.pass, u.role, u.prefs, u.sync_topic, u.stats_messages, u.stats_emails, u.stripe_customer_id, u.stripe_subscription_id, p.code, p.name, p.paid, p.messages_limit, p.messages_expiry_duration, p.emails_limit, p.reservations_limit, p.attachment_file_size_limit, p.attachment_total_size_limit, p.attachment_expiry_duration, p.stripe_price_id
|
|
|
|
|
|
|
+ SELECT u.user, u.pass, u.role, u.prefs, u.sync_topic, u.stats_messages, u.stats_emails, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_paid_until, p.code, p.name, p.paid, p.messages_limit, p.messages_expiry_duration, p.emails_limit, p.reservations_limit, p.attachment_file_size_limit, p.attachment_total_size_limit, p.attachment_expiry_duration, p.stripe_price_id
|
|
|
FROM user u
|
|
FROM user u
|
|
|
LEFT JOIN tier p on p.id = u.tier_id
|
|
LEFT JOIN tier p on p.id = u.tier_id
|
|
|
WHERE user = ?
|
|
WHERE user = ?
|
|
|
`
|
|
`
|
|
|
selectUserByTokenQuery = `
|
|
selectUserByTokenQuery = `
|
|
|
- SELECT u.user, u.pass, u.role, u.prefs, u.sync_topic, u.stats_messages, u.stats_emails, u.stripe_customer_id, u.stripe_subscription_id, p.code, p.name, p.paid, p.messages_limit, p.messages_expiry_duration, p.emails_limit, p.reservations_limit, p.attachment_file_size_limit, p.attachment_total_size_limit, p.attachment_expiry_duration, p.stripe_price_id
|
|
|
|
|
|
|
+ SELECT u.user, u.pass, u.role, u.prefs, u.sync_topic, u.stats_messages, u.stats_emails, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_paid_until, p.code, p.name, p.paid, p.messages_limit, p.messages_expiry_duration, p.emails_limit, p.reservations_limit, p.attachment_file_size_limit, p.attachment_total_size_limit, p.attachment_expiry_duration, p.stripe_price_id
|
|
|
FROM user u
|
|
FROM user u
|
|
|
JOIN user_token t on u.id = t.user_id
|
|
JOIN user_token t on u.id = t.user_id
|
|
|
LEFT JOIN tier p on p.id = u.tier_id
|
|
LEFT JOIN tier p on p.id = u.tier_id
|
|
|
WHERE t.token = ? AND t.expires >= ?
|
|
WHERE t.token = ? AND t.expires >= ?
|
|
|
`
|
|
`
|
|
|
selectUserByStripeCustomerIDQuery = `
|
|
selectUserByStripeCustomerIDQuery = `
|
|
|
- SELECT u.user, u.pass, u.role, u.prefs, u.sync_topic, u.stats_messages, u.stats_emails, u.stripe_customer_id, u.stripe_subscription_id, p.code, p.name, p.paid, p.messages_limit, p.messages_expiry_duration, p.emails_limit, p.reservations_limit, p.attachment_file_size_limit, p.attachment_total_size_limit, p.attachment_expiry_duration, p.stripe_price_id
|
|
|
|
|
|
|
+ SELECT u.user, u.pass, u.role, u.prefs, u.sync_topic, u.stats_messages, u.stats_emails, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_paid_until, p.code, p.name, p.paid, p.messages_limit, p.messages_expiry_duration, p.emails_limit, p.reservations_limit, p.attachment_file_size_limit, p.attachment_total_size_limit, p.attachment_expiry_duration, p.stripe_price_id
|
|
|
FROM user u
|
|
FROM user u
|
|
|
LEFT JOIN tier p on p.id = u.tier_id
|
|
LEFT JOIN tier p on p.id = u.tier_id
|
|
|
WHERE u.stripe_customer_id = ?
|
|
WHERE u.stripe_customer_id = ?
|
|
@@ -231,7 +234,11 @@ const (
|
|
|
updateUserTierQuery = `UPDATE user SET tier_id = ? WHERE user = ?`
|
|
updateUserTierQuery = `UPDATE user SET tier_id = ? WHERE user = ?`
|
|
|
deleteUserTierQuery = `UPDATE user SET tier_id = null WHERE user = ?`
|
|
deleteUserTierQuery = `UPDATE user SET tier_id = null WHERE user = ?`
|
|
|
|
|
|
|
|
- updateBillingQuery = `UPDATE user SET stripe_customer_id = ?, stripe_subscription_id = ? WHERE user = ?`
|
|
|
|
|
|
|
+ updateBillingQuery = `
|
|
|
|
|
+ UPDATE user
|
|
|
|
|
+ SET stripe_customer_id = ?, stripe_subscription_id = ?, stripe_subscription_status = ?, stripe_subscription_paid_until = ?
|
|
|
|
|
+ WHERE user = ?
|
|
|
|
|
+ `
|
|
|
)
|
|
)
|
|
|
|
|
|
|
|
// Schema management queries
|
|
// Schema management queries
|
|
@@ -597,14 +604,14 @@ func (a *Manager) userByToken(token string) (*User, error) {
|
|
|
func (a *Manager) readUser(rows *sql.Rows) (*User, error) {
|
|
func (a *Manager) readUser(rows *sql.Rows) (*User, error) {
|
|
|
defer rows.Close()
|
|
defer rows.Close()
|
|
|
var username, hash, role, prefs, syncTopic string
|
|
var username, hash, role, prefs, syncTopic string
|
|
|
- var stripeCustomerID, stripeSubscriptionID, stripePriceID, tierCode, tierName sql.NullString
|
|
|
|
|
|
|
+ var stripeCustomerID, stripeSubscriptionID, stripeSubscriptionStatus, stripePriceID, tierCode, tierName sql.NullString
|
|
|
var paid sql.NullBool
|
|
var paid sql.NullBool
|
|
|
var messages, emails int64
|
|
var messages, emails int64
|
|
|
- var messagesLimit, messagesExpiryDuration, emailsLimit, reservationsLimit, attachmentFileSizeLimit, attachmentTotalSizeLimit, attachmentExpiryDuration sql.NullInt64
|
|
|
|
|
|
|
+ var messagesLimit, messagesExpiryDuration, emailsLimit, reservationsLimit, attachmentFileSizeLimit, attachmentTotalSizeLimit, attachmentExpiryDuration, stripeSubscriptionPaidUntil sql.NullInt64
|
|
|
if !rows.Next() {
|
|
if !rows.Next() {
|
|
|
return nil, ErrUserNotFound
|
|
return nil, ErrUserNotFound
|
|
|
}
|
|
}
|
|
|
- if err := rows.Scan(&username, &hash, &role, &prefs, &syncTopic, &messages, &emails, &stripeCustomerID, &stripeSubscriptionID, &tierCode, &tierName, &paid, &messagesLimit, &messagesExpiryDuration, &emailsLimit, &reservationsLimit, &attachmentFileSizeLimit, &attachmentTotalSizeLimit, &attachmentExpiryDuration, &stripePriceID); err != nil {
|
|
|
|
|
|
|
+ if err := rows.Scan(&username, &hash, &role, &prefs, &syncTopic, &messages, &emails, &stripeCustomerID, &stripeSubscriptionID, &stripeSubscriptionStatus, &stripeSubscriptionPaidUntil, &tierCode, &tierName, &paid, &messagesLimit, &messagesExpiryDuration, &emailsLimit, &reservationsLimit, &attachmentFileSizeLimit, &attachmentTotalSizeLimit, &attachmentExpiryDuration, &stripePriceID); err != nil {
|
|
|
return nil, err
|
|
return nil, err
|
|
|
} else if err := rows.Err(); err != nil {
|
|
} else if err := rows.Err(); err != nil {
|
|
|
return nil, err
|
|
return nil, err
|
|
@@ -619,16 +626,16 @@ func (a *Manager) readUser(rows *sql.Rows) (*User, error) {
|
|
|
Messages: messages,
|
|
Messages: messages,
|
|
|
Emails: emails,
|
|
Emails: emails,
|
|
|
},
|
|
},
|
|
|
|
|
+ Billing: &Billing{
|
|
|
|
|
+ StripeCustomerID: stripeCustomerID.String, // May be empty
|
|
|
|
|
+ StripeSubscriptionID: stripeSubscriptionID.String, // May be empty
|
|
|
|
|
+ StripeSubscriptionStatus: stripe.SubscriptionStatus(stripeSubscriptionStatus.String), // May be empty
|
|
|
|
|
+ StripeSubscriptionPaidUntil: time.Unix(stripeSubscriptionPaidUntil.Int64, 0), // May be zero
|
|
|
|
|
+ },
|
|
|
}
|
|
}
|
|
|
if err := json.Unmarshal([]byte(prefs), user.Prefs); err != nil {
|
|
if err := json.Unmarshal([]byte(prefs), user.Prefs); err != nil {
|
|
|
return nil, err
|
|
return nil, err
|
|
|
}
|
|
}
|
|
|
- if stripeCustomerID.Valid && stripeSubscriptionID.Valid {
|
|
|
|
|
- user.Billing = &Billing{
|
|
|
|
|
- StripeCustomerID: stripeCustomerID.String,
|
|
|
|
|
- StripeSubscriptionID: stripeSubscriptionID.String,
|
|
|
|
|
- }
|
|
|
|
|
- }
|
|
|
|
|
if tierCode.Valid {
|
|
if tierCode.Valid {
|
|
|
// See readTier() when this is changed!
|
|
// See readTier() when this is changed!
|
|
|
user.Tier = &Tier{
|
|
user.Tier = &Tier{
|
|
@@ -868,7 +875,7 @@ func (a *Manager) CreateTier(tier *Tier) error {
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
func (a *Manager) ChangeBilling(user *User) error {
|
|
func (a *Manager) ChangeBilling(user *User) error {
|
|
|
- if _, err := a.db.Exec(updateBillingQuery, user.Billing.StripeCustomerID, user.Billing.StripeSubscriptionID, user.Name); err != nil {
|
|
|
|
|
|
|
+ if _, err := a.db.Exec(updateBillingQuery, nullString(user.Billing.StripeCustomerID), nullString(user.Billing.StripeSubscriptionID), nullString(string(user.Billing.StripeSubscriptionStatus)), nullInt64(user.Billing.StripeSubscriptionPaidUntil.Unix()), user.Name); err != nil {
|
|
|
return err
|
|
return err
|
|
|
}
|
|
}
|
|
|
return nil
|
|
return nil
|
|
@@ -1020,3 +1027,17 @@ func migrateFrom1(db *sql.DB) error {
|
|
|
}
|
|
}
|
|
|
return nil // Update this when a new version is added
|
|
return nil // Update this when a new version is added
|
|
|
}
|
|
}
|
|
|
|
|
+
|
|
|
|
|
+func nullString(s string) sql.NullString {
|
|
|
|
|
+ if s == "" {
|
|
|
|
|
+ return sql.NullString{}
|
|
|
|
|
+ }
|
|
|
|
|
+ return sql.NullString{String: s, Valid: true}
|
|
|
|
|
+}
|
|
|
|
|
+
|
|
|
|
|
+func nullInt64(v int64) sql.NullInt64 {
|
|
|
|
|
+ if v == 0 {
|
|
|
|
|
+ return sql.NullInt64{}
|
|
|
|
|
+ }
|
|
|
|
|
+ return sql.NullInt64{Int64: v, Valid: true}
|
|
|
|
|
+}
|