commit 83f515c6e1b3207f028c7547e1d45b079aa2dcfa
parent ea9a1c56d36dbadbcd5e321fa2b40c1e6d9ee55b
Author: Jacob R. Edwards <jacob@jacobedwards.org>
Date: Sat, 31 Aug 2024 15:22:41 -0700
Add active subscription verification
This implements an interface to Stripe's webhook system (under the
stripe endpoint) which currently supports the
customer.subscription.created, updated, and deleted events. This
also required adding a webhook key configuration parameter to
StripeConfig.
It also adds a subscription table to the database with all the
required subscription information.
And finally, it introduces payment verification middleware for
select endpoints to verify there is an active subscription for the
user before giving access.
Diffstat:
7 files changed, 208 insertions(+), 0 deletions(-)
diff --git a/cmd/api/main.go b/cmd/api/main.go
@@ -81,6 +81,10 @@ func main() {
}
func setRoutes(env *Env, r *gin.RouterGroup) {
+ // This endpoint uses a separate authorization scheme
+ // enforced within stripeEventHandler
+ r.POST("/stripe", env.StripeEventHandler)
+
r.POST("/tokens", env.Auth.LoginHandler)
r.GET("/tokens", env.Auth.RefreshHandler)
r.GET("/settings", env.GetSettings)
@@ -107,6 +111,11 @@ func setAuthenticatedRoutes(env *Env, r *gin.RouterGroup) {
email.POST("/code", env.VerifyUserEmailCode)
email.GET("/verified", env.VerifiedUserEmail)
+ payed := r.Group("", env.VerifyPayment)
+ setPayedRoutes(env, payed)
+}
+
+func setPayedRoutes(env *Env, r *gin.RouterGroup) {
fp := r.Group("/floorplans/:user")
fp.GET("", env.GetFloorplans)
fp.POST("", env.CreateFloorplan)
@@ -122,6 +131,18 @@ func setAuthenticatedRoutes(env *Env, r *gin.RouterGroup) {
fpdata.PATCH("", env.PatchFloorplanData)
}
+func (e *Env) VerifyPayment(c *gin.Context) {
+ service, err := e.backend.UserService(c.Param("user"))
+ if err != nil {
+ RespondError(c, 500, "Unable to get subscription status")
+ } else if service == nil {
+ RespondError(c, 401, "You must be subscribed to access this resource")
+ } else {
+ c.Set("service_id", service)
+ c.Next()
+ }
+}
+
func noRoute(c *gin.Context) {
RespondError(c, http.StatusNotFound, "Endpoint does not exist")
}
diff --git a/cmd/api/migration/2024-08-31T17:08:02.sql b/cmd/api/migration/2024-08-31T17:08:02.sql
@@ -0,0 +1,18 @@
+CREATE TABLE spaceplanner.subscriptions (
+ -- Stripe subscription id
+ id varchar PRIMARY KEY,
+
+ -- Stripe customer id
+ customer_id varchar
+ REFERENCES spaceplanner.users(stripe_customer_id)
+ ON DELETE CASCADE
+ NOT NULL,
+
+ -- Stripe price id
+ price_id varchar NOT NULL,
+
+ -- Stripe product id
+ product_id varchar NOT NULL,
+
+ active boolean NOT NULL
+);
diff --git a/cmd/api/stripe.go b/cmd/api/stripe.go
@@ -0,0 +1,31 @@
+package main
+
+import (
+ "github.com/gin-gonic/gin"
+ "github.com/stripe/stripe-go/v72"
+ "github.com/stripe/stripe-go/v72/webhook"
+)
+
+func (e *Env) StripeEventHandler(c *gin.Context) {
+ ev := stripe.Event{}
+
+ payload, err := c.GetRawData()
+ if err != nil {
+ RespondError(c, 500, "%s: Unable to read payload", err.Error())
+ return
+ }
+
+ ev, err = webhook.ConstructEvent(payload, c.GetHeader("Stripe-Signature"),
+ e.backend.Config.Stripe.WebhookKey)
+ if err != nil {
+ RespondError(c, 400, "%s: Unable to validate event", err.Error())
+ return
+ }
+
+ if err := e.backend.StripeEventHandler(&ev); err != nil {
+ RespondError(c, 400, "%s: Unable to process event", err.Error())
+ return
+ }
+
+ Respond(c, 200, nil)
+}
diff --git a/internal/backend/config.go b/internal/backend/config.go
@@ -10,6 +10,8 @@ type Config struct {
type StripeConfig struct {
// Stripe API key
Key string `json:"key" binding:"required"`
+ // Stripe webhook signing key
+ WebhookKey string `json:"webhook_key" binding:"required"`
}
type SMTPConfig struct {
diff --git a/internal/backend/env.go b/internal/backend/env.go
@@ -13,6 +13,7 @@ type Env struct {
Stripe *client.API
SMTPAuth smtp.Auth
// Private
+ stripeProcessedEvents map[string]bool
stmts map[string]*sql.Stmt
settingDefs map[string]SettingDef
}
@@ -29,6 +30,7 @@ func NewEnv(c Config) (*Env, error) {
}
e.SMTPAuth = smtp.PlainAuth("", c.SMTP.User, c.SMTP.Password, c.SMTP.Server)
e.Stripe = client.New(c.Stripe.Key, nil)
+ e.stripeProcessedEvents = make(map[string]bool)
e.stmts = make(map[string]*sql.Stmt)
e.settingDefs = make(map[string]SettingDef)
e.Config = c
diff --git a/internal/backend/stripe.go b/internal/backend/stripe.go
@@ -0,0 +1,112 @@
+package backend
+
+import (
+ "encoding/json"
+ "errors"
+ "log"
+ "github.com/stripe/stripe-go/v72"
+)
+
+/*
+ * Don't need customer.subscription.paused without free trial like stuff
+ * <https://docs.stripe.com/billing/subscriptions/pause-payment>
+ */
+func (e *Env) StripeEventHandler(ev *stripe.Event) error {
+ // Not sure whether I need to store in persistent database,
+ // maybe I should
+ v, exists := e.stripeProcessedEvents[ev.ID]
+ if exists && v {
+ log.Printf("StripeEventHandler: %s event already processed, ignoring (and returning success)\n", ev.ID)
+ return nil
+ }
+ e.stripeProcessedEvents[ev.ID] = true
+
+ var err error
+ switch ev.Type {
+ case "customer.subscription.created":
+ err = stripeHandle(e.stripeCreateSub, ev)
+ case "customer.subscription.deleted":
+ err = stripeHandle(e.stripeDeleteSub, ev)
+ case "customer.subscription.updated":
+ err = stripeHandle(e.stripeUpdateSub, ev)
+ default:
+ err = errors.New(ev.Type + ": Event type is not being processed")
+ }
+
+ if err == nil {
+ log.Printf("StripeEventHandler: Handled %s\n", ev.ID)
+ } else {
+ log.Printf("StripeEventHandler: %s: error: %s\n", ev.ID, err.Error())
+ }
+ return err
+}
+
+func (e *Env) stripeCreateSub(s *stripe.Subscription) error {
+ add, err := e.CacheStmt("create_sub",
+ `INSERT INTO spaceplanner.subscriptions (id, customer_id, price_id, product_id, active)
+ VALUES ($1, $2, $3, $4, $5)`)
+ if err != nil {
+ return err
+ }
+
+ price, err := e.subPrice(s)
+ if err != nil {
+ return err
+ }
+
+ log.Printf("stripeHandleCreateSub sub %s for %s (active: %v)\n", s.ID, s.Customer.ID, subIsActive(s))
+ _, err = add.Exec(s.ID, s.Customer.ID, price.ID, price.Product.ID, subIsActive(s))
+ return err
+}
+
+func (e *Env) stripeUpdateSub(s *stripe.Subscription) error {
+ update, err := e.CacheStmt("update_sub",
+ `UPDATE spaceplanner.subscriptions SET (price_id, product_id, active) = ($2, $3, $4)
+ WHERE id = $1`)
+ if err != nil {
+ return err
+ }
+
+ price, err := e.subPrice(s)
+ if err != nil {
+ return err
+ }
+
+ log.Printf("stripeHandleUpdateSub sub %s for %s (active: %v)\n", s.ID, s.Customer.ID, subIsActive(s))
+ _, err = update.Exec(s.ID, price.ID, price.Product.ID, subIsActive(s))
+ return err
+}
+
+func (e *Env) stripeDeleteSub(s *stripe.Subscription) error {
+ delete, err := e.CacheStmt("delete_sub",
+ `DELETE FROM spaceplanner.subscriptions
+ WHERE id = $1`)
+ if err != nil {
+ return err
+ }
+
+ _, err = delete.Exec(s.ID)
+ return err
+
+}
+
+func stripeHandle[T any](handler func(*T) error, ev *stripe.Event) error {
+ var d T
+ if err := json.Unmarshal(ev.Data.Raw, &d); err != nil {
+ return err
+ }
+
+ return handler(&d)
+}
+
+func subIsActive(s *stripe.Subscription) bool {
+ log.Printf("!!! Sub status %v\n", s.Status)
+ return s.Status == stripe.SubscriptionStatusActive
+}
+
+func (e *Env) subPrice(s *stripe.Subscription) (*stripe.Price, error) {
+ if len(s.Items.Data) != 1 {
+ return nil, errors.New("Expected exactly one subscription item")
+ }
+ return s.Items.Data[0].Price, nil
+}
diff --git a/internal/backend/user.go b/internal/backend/user.go
@@ -111,6 +111,28 @@ func (e *Env) LoginUser(username string, password string) (User, error) {
return user, nil;
}
+// In the future this may return a Service
+func (e *Env) UserService(username string) (*string, error) {
+ service, err := e.CacheStmt("get_user_service",
+ `SELECT product_id
+ FROM spaceplanner.subscriptions
+ WHERE active = true AND customer_id = (
+ SELECT stripe_customer_id
+ FROM spaceplanner.users
+ WHERE name = $1
+ )`)
+ if err != nil {
+ return nil, err
+ }
+
+ var prod *string
+ err = service.QueryRow(username).Scan(&prod)
+ if err == sql.ErrNoRows {
+ return nil, nil
+ }
+ return prod, err
+}
+
func (e *Env) VerifyEmail(username string, code *string) (bool, error) {
emails, err := e.UserEmails(username)
if err != nil {