commit 8ccf324d76ea9df82b191ba61b76fc3ce016331f
parent 801183b1ad3b2984a19f05dbff20d606a0bb826f
Author: Jacob R. Edwards <jacob@jacobedwards.org>
Date: Sun, 4 Aug 2024 21:43:58 -0700
Allow sql transactions for database-related functions
This gives callers much greater flexibility.
Diffstat:
3 files changed, 34 insertions(+), 36 deletions(-)
diff --git a/cmd/api/users.go b/cmd/api/users.go
@@ -57,7 +57,7 @@ func (e *Env) GetUserSettings(c *gin.Context) {
return
}
- settings, err := e.backend.GetUserSettings(user)
+ settings, err := e.backend.GetUserSettings(nil, user)
if err != nil {
RespondError(c, 400, "Unable to get settings: %s", err.Error())
return
@@ -74,7 +74,7 @@ func (e *Env) GetUserSetting(c *gin.Context) {
return
}
- setting, err := e.backend.GetUserSettings(user, name)
+ setting, err := e.backend.GetUserSettings(nil, user, name)
if err != nil {
RespondError(c, 400, "Unable to get settings: %s", err.Error())
return
@@ -101,12 +101,12 @@ func (e *Env) SetUserSettings(c *gin.Context) {
RespondError(c, 400, "Unable to begin transaction: %s", err.Error())
return
}
- if _, err := e.backend.TxDeleteUserSettings(tx, user); err != nil {
+ if _, err := e.backend.DeleteUserSettings(tx, user); err != nil {
tx.Rollback()
RespondError(c, 400, "Unable to truncate settings: %s", err.Error())
return
}
- if err := e.backend.TxUpdateUserSettings(tx, user, settings); err != nil {
+ if err := e.backend.UpdateUserSettings(tx, user, settings); err != nil {
RespondError(c, 400, "Unable to set settings: %s", err.Error())
return
}
@@ -134,7 +134,7 @@ func (e *Env) SetUserSetting(c *gin.Context) {
setting := make(map[string]backend.Setting)
setting[name] = value
- if err := e.backend.UpdateUserSettings(user, setting);
+ if err := e.backend.UpdateUserSettings(nil, user, setting);
err != nil {
RespondError(c, 400, "Unable to update %q: %s", setting, err.Error())
}
diff --git a/internal/backend/env.go b/internal/backend/env.go
@@ -36,6 +36,19 @@ func (e *Env) CacheStmt(name, sql string) (*sql.Stmt, error) {
return stmt, nil
}
+func (e *Env) CacheTxStmt(tx *sql.Tx, name, sql string) (*sql.Stmt, error) {
+ if tx == nil {
+ return e.CacheStmt(name, sql)
+ }
+
+ stmt, err := e.CacheStmt(name, sql)
+ if err != nil {
+ return nil, err
+ }
+ return tx.Stmt(stmt), nil
+}
+
+
func (e *Env) Free() {
for _, s := range e.stmts {
s.Close()
diff --git a/internal/backend/user.go b/internal/backend/user.go
@@ -80,27 +80,18 @@ func (e *Env) LoginUser(username string, password string) (User, error) {
return user, nil;
}
-func (e *Env) UpdateUserSettings(username string, settings map[string]Setting) error {
- tx, err := e.DB.Begin()
- if err != nil {
- return err
- }
- err = e.TxUpdateUserSettings(tx, username, settings)
- if err != nil {
- tx.Rollback()
- return err
- }
- return tx.Commit()
-}
-func (e *Env) TxUpdateUserSettings(tx *sql.Tx, username string, settings map[string]Setting) error {
+func (e *Env) UpdateUserSettings(tx *sql.Tx, username string, settings map[string]Setting) error {
stmt, err := e.CacheStmt("set_user_settings", `INSERT INTO user_settings VALUES ($1, $2, $3, $4, $5)
ON CONFLICT (username, name) DO UPDATE
SET (strval, numval, boolval) = (EXCLUDED.strval, EXCLUDED.numval, EXCLUDED.boolval)`)
if err != nil {
return err
}
- stmt = tx.Stmt(stmt)
+
+ if tx != nil {
+ stmt = tx.Stmt(stmt)
+ }
for name, setting := range settings {
r := toRawSetting(username, name, setting)
@@ -113,7 +104,7 @@ func (e *Env) TxUpdateUserSettings(tx *sql.Tx, username string, settings map[str
return nil
}
-func (e *Env) GetUserSettings(username string, names ...string) (map[string]Setting, error) {
+func (e *Env) GetUserSettings(tx *sql.Tx, username string, names ...string) (map[string]Setting, error) {
stmt, err := e.CacheStmt("get_user_settings", `SELECT * FROM user_settings WHERE
user_settings.username = $1 AND
(ARRAY_LENGTH($2::varchar[], 1) IS NULL OR user_settings.name = ANY ($2))`)
@@ -121,6 +112,10 @@ func (e *Env) GetUserSettings(username string, names ...string) (map[string]Sett
return nil, err
}
+ if tx != nil {
+ stmt = tx.Stmt(stmt)
+ }
+
rows, err := stmt.Query(username, pq.Array(names))
if err != nil {
return nil, err
@@ -129,21 +124,7 @@ func (e *Env) GetUserSettings(username string, names ...string) (map[string]Sett
return collectSettings(rows)
}
-func (e *Env) DeleteUserSettings(username string, names ...string) (map[string]Setting, error) {
- var deleted map[string]Setting
- tx, err := e.DB.Begin()
- if err != nil {
- return deleted, err
- }
- deleted, err = e.TxDeleteUserSettings(tx, username, names...)
- if err != nil {
- tx.Rollback()
- return deleted, err
- }
- return deleted, tx.Commit()
-}
-
-func (e *Env) TxDeleteUserSettings(tx *sql.Tx, username string, names ...string) (map[string]Setting, error) {
+func (e *Env) DeleteUserSettings(tx *sql.Tx, username string, names ...string) (map[string]Setting, error) {
var deleted map[string]Setting
stmt, err := e.CacheStmt("del_user_settings", `DELETE FROM user_settings WHERE user_settings.username = $1 AND
(ARRAY_LENGTH($2::string[], 1) IS NULL OR user_settings.name = ANY ($2)) RETURNING`)
@@ -151,7 +132,11 @@ func (e *Env) TxDeleteUserSettings(tx *sql.Tx, username string, names ...string)
return deleted, err
}
- rows, err := tx.Stmt(stmt).Query(username, names)
+ if tx != nil {
+ tx.Stmt(stmt)
+ }
+
+ rows, err := stmt.Query(username, names)
if err != nil {
return deleted, err
}