store: wrap UpsertHostRepoStats in a transaction (concurrency safety)

This commit is contained in:
2026-05-03 22:15:35 +01:00
parent 84a8c060b6
commit 5200e44536
+24 -8
View File
@@ -36,6 +36,18 @@ func (s *Store) GetHostRepoStats(ctx context.Context, hostID string) (*HostRepoS
return scanHostRepoStats(row) return scanHostRepoStats(row)
} }
// getHostRepoStatsTx is identical to GetHostRepoStats but runs on an
// existing transaction so the fetch-merge-upsert in UpsertHostRepoStats
// is fully serialized.
func getHostRepoStatsTx(ctx context.Context, tx *sql.Tx, hostID string) (*HostRepoStats, error) {
row := tx.QueryRowContext(ctx,
`SELECT host_id, total_size_bytes, raw_size_bytes, unique_files,
snapshot_count, last_check_at, last_check_status,
lock_present, last_prune_at, last_prune_freed_bytes, updated_at
FROM host_repo_stats WHERE host_id = ?`, hostID)
return scanHostRepoStats(row)
}
// scanHostRepoStats scans one row from host_repo_stats. // scanHostRepoStats scans one row from host_repo_stats.
func scanHostRepoStats(row *sql.Row) (*HostRepoStats, error) { func scanHostRepoStats(row *sql.Row) (*HostRepoStats, error) {
var ( var (
@@ -113,12 +125,17 @@ func scanHostRepoStats(row *sql.Row) (*HostRepoStats, error) {
// UpsertHostRepoStats writes a partial update — only non-nil pointer // UpsertHostRepoStats writes a partial update — only non-nil pointer
// fields (and LastCheckStatus when non-empty) overwrite existing // fields (and LastCheckStatus when non-empty) overwrite existing
// columns. Implemented as a row-fetch + merge + INSERT…ON CONFLICT so // columns. Wrapped in a transaction so concurrent upserts on the same
// each call is atomic at the application level (sufficient for a // host don't lose updates.
// single-writer server).
func (s *Store) UpsertHostRepoStats(ctx context.Context, hostID string, patch HostRepoStats) error { func (s *Store) UpsertHostRepoStats(ctx context.Context, hostID string, patch HostRepoStats) error {
tx, err := s.db.BeginTx(ctx, nil)
if err != nil {
return fmt.Errorf("store: begin host_repo_stats tx: %w", err)
}
defer func() { _ = tx.Rollback() }()
// Fetch existing row; start from zero if absent. // Fetch existing row; start from zero if absent.
cur, err := s.GetHostRepoStats(ctx, hostID) cur, err := getHostRepoStatsTx(ctx, tx, hostID)
if err != nil && !errors.Is(err, ErrNotFound) { if err != nil && !errors.Is(err, ErrNotFound) {
return err return err
} }
@@ -163,7 +180,7 @@ func (s *Store) UpsertHostRepoStats(ctx context.Context, hostID string, patch Ho
lockPresentInt = 1 lockPresentInt = 1
} }
_, err = s.db.ExecContext(ctx, if _, err = tx.ExecContext(ctx,
`INSERT INTO host_repo_stats `INSERT INTO host_repo_stats
(host_id, total_size_bytes, raw_size_bytes, unique_files, (host_id, total_size_bytes, raw_size_bytes, unique_files,
snapshot_count, last_check_at, last_check_status, snapshot_count, last_check_at, last_check_status,
@@ -191,11 +208,10 @@ func (s *Store) UpsertHostRepoStats(ctx context.Context, hostID string, patch Ho
nullableTime(cur.LastPruneAt), nullableTime(cur.LastPruneAt),
nullableInt64(cur.LastPruneFreedBytes), nullableInt64(cur.LastPruneFreedBytes),
now, now,
) ); err != nil {
if err != nil {
return fmt.Errorf("store: upsert host_repo_stats: %w", err) return fmt.Errorf("store: upsert host_repo_stats: %w", err)
} }
return nil return tx.Commit()
} }
// nullableInt64 converts *int64 to a database/sql-compatible nullable value. // nullableInt64 converts *int64 to a database/sql-compatible nullable value.