From 779f5aac47c4e257a6d8d82af8958209c8c58d7d Mon Sep 17 00:00:00 2001 From: Steve Cliff Date: Sun, 3 May 2026 22:15:35 +0100 Subject: [PATCH] store: wrap UpsertHostRepoStats in a transaction (concurrency safety) --- internal/store/host_repo_stats.go | 32 +++++++++++++++++++++++-------- 1 file changed, 24 insertions(+), 8 deletions(-) diff --git a/internal/store/host_repo_stats.go b/internal/store/host_repo_stats.go index 0a05d04..1952f68 100644 --- a/internal/store/host_repo_stats.go +++ b/internal/store/host_repo_stats.go @@ -36,6 +36,18 @@ func (s *Store) GetHostRepoStats(ctx context.Context, hostID string) (*HostRepoS return scanHostRepoStats(row) } +// getHostRepoStatsTx is identical to GetHostRepoStats but runs on an +// existing transaction so the fetch-merge-upsert in UpsertHostRepoStats +// is fully serialized. +func getHostRepoStatsTx(ctx context.Context, tx *sql.Tx, hostID string) (*HostRepoStats, error) { + row := tx.QueryRowContext(ctx, + `SELECT host_id, total_size_bytes, raw_size_bytes, unique_files, + snapshot_count, last_check_at, last_check_status, + lock_present, last_prune_at, last_prune_freed_bytes, updated_at + FROM host_repo_stats WHERE host_id = ?`, hostID) + return scanHostRepoStats(row) +} + // scanHostRepoStats scans one row from host_repo_stats. func scanHostRepoStats(row *sql.Row) (*HostRepoStats, error) { var ( @@ -113,12 +125,17 @@ func scanHostRepoStats(row *sql.Row) (*HostRepoStats, error) { // UpsertHostRepoStats writes a partial update — only non-nil pointer // fields (and LastCheckStatus when non-empty) overwrite existing -// columns. Implemented as a row-fetch + merge + INSERT…ON CONFLICT so -// each call is atomic at the application level (sufficient for a -// single-writer server). +// columns. Wrapped in a transaction so concurrent upserts on the same +// host don't lose updates. func (s *Store) UpsertHostRepoStats(ctx context.Context, hostID string, patch HostRepoStats) error { + tx, err := s.db.BeginTx(ctx, nil) + if err != nil { + return fmt.Errorf("store: begin host_repo_stats tx: %w", err) + } + defer func() { _ = tx.Rollback() }() + // Fetch existing row; start from zero if absent. - cur, err := s.GetHostRepoStats(ctx, hostID) + cur, err := getHostRepoStatsTx(ctx, tx, hostID) if err != nil && !errors.Is(err, ErrNotFound) { return err } @@ -163,7 +180,7 @@ func (s *Store) UpsertHostRepoStats(ctx context.Context, hostID string, patch Ho lockPresentInt = 1 } - _, err = s.db.ExecContext(ctx, + if _, err = tx.ExecContext(ctx, `INSERT INTO host_repo_stats (host_id, total_size_bytes, raw_size_bytes, unique_files, snapshot_count, last_check_at, last_check_status, @@ -191,11 +208,10 @@ func (s *Store) UpsertHostRepoStats(ctx context.Context, hostID string, patch Ho nullableTime(cur.LastPruneAt), nullableInt64(cur.LastPruneFreedBytes), now, - ) - if err != nil { + ); err != nil { return fmt.Errorf("store: upsert host_repo_stats: %w", err) } - return nil + return tx.Commit() } // nullableInt64 converts *int64 to a database/sql-compatible nullable value.