Files

259 lines
8.8 KiB
Go

package store
import (
"context"
"database/sql"
"errors"
"fmt"
"time"
)
// ErrFleetUpdateRunning is returned by CreateFleetUpdate if another
// fleet update is already in 'running' state. CreateFleetUpdate wraps it
// together with the id of the running update, so callers should match it
// with errors.Is rather than ==. The HTTP layer surfaces this as a 409
// with a structured error code.
var ErrFleetUpdateRunning = errors.New("store: fleet update already running")
// CreateFleetUpdate inserts the parent fleet_updates row and one 'pending'
// fleet_update_hosts row per entry of hostIDs, in the order given
// (position = slice index). Everything happens inside one transaction, so
// either the parent and all children are stored or nothing is.
//
// fu.ID, fu.StartedByUserID and fu.TargetVersion are required. An empty
// fu.Status defaults to "running" and a zero fu.StartedAt defaults to the
// current UTC time; fu is passed by value, so the caller's copy is not
// modified.
//
// If a row with status 'running' already exists, the returned error wraps
// ErrFleetUpdateRunning and includes that row's id. Every other failure,
// including a failed commit, is wrapped with a "store: ..." prefix.
//
// NOTE(review): the running check and the inserts share a transaction, but
// whether that alone excludes two concurrent creators depends on the
// database's isolation level / schema constraints — confirm.
func (st *Store) CreateFleetUpdate(ctx context.Context, fu FleetUpdate, hostIDs []string) error {
	if fu.ID == "" || fu.StartedByUserID == "" || fu.TargetVersion == "" {
		return errors.New("store: fleet update id, user_id, target_version required")
	}
	if fu.Status == "" {
		fu.Status = "running"
	}
	if fu.StartedAt.IsZero() {
		fu.StartedAt = time.Now().UTC()
	}

	tx, err := st.db.BeginTx(ctx, nil)
	if err != nil {
		return fmt.Errorf("store: begin: %w", err)
	}
	// Covers every early return below. After a successful Commit the
	// rollback has nothing to undo, so its error is deliberately ignored.
	defer func() { _ = tx.Rollback() }()

	// Refuse to start while another fleet update is still in flight.
	var existing string
	err = tx.QueryRowContext(ctx,
		`SELECT id FROM fleet_updates WHERE status = 'running' LIMIT 1`).
		Scan(&existing)
	switch {
	case err == nil:
		return fmt.Errorf("%w: %s", ErrFleetUpdateRunning, existing)
	case !errors.Is(err, sql.ErrNoRows):
		return fmt.Errorf("store: check active fleet update: %w", err)
	}

	if _, err := tx.ExecContext(ctx,
		`INSERT INTO fleet_updates (id, started_at, started_by_user_id, target_version, status)
VALUES (?, ?, ?, ?, ?)`,
		fu.ID, fu.StartedAt.UTC().Format(time.RFC3339Nano), fu.StartedByUserID, fu.TargetVersion, fu.Status,
	); err != nil {
		return fmt.Errorf("store: insert fleet_updates: %w", err)
	}

	for i, hid := range hostIDs {
		if _, err := tx.ExecContext(ctx,
			`INSERT INTO fleet_update_hosts (fleet_update_id, host_id, position, status)
VALUES (?, ?, ?, 'pending')`,
			fu.ID, hid, i,
		); err != nil {
			return fmt.Errorf("store: insert fleet_update_hosts: %w", err)
		}
	}

	// Wrap the commit error like every other failure path so callers get a
	// consistent "store: ..." prefix while errors.Is/As still see the cause.
	if err := tx.Commit(); err != nil {
		return fmt.Errorf("store: commit fleet update: %w", err)
	}
	return nil
}
// ActiveFleetUpdate looks up a fleet update whose status is 'running'.
// It returns (nil, nil) when no such row exists, and a wrapped error for
// any other query or scan failure.
//
// NULL current_host_id / halted_reason columns come back as empty strings;
// CompletedAt stays nil when completed_at is NULL.
func (st *Store) ActiveFleetUpdate(ctx context.Context) (*FleetUpdate, error) {
	const query = `SELECT id, started_at, started_by_user_id, target_version, status,
current_host_id, halted_reason, completed_at
FROM fleet_updates WHERE status = 'running' LIMIT 1`

	var (
		rec        FleetUpdate
		startedRaw string
		hostID     sql.NullString
		reason     sql.NullString
		doneRaw    sql.NullString
	)

	row := st.db.QueryRowContext(ctx, query)
	err := row.Scan(&rec.ID, &startedRaw, &rec.StartedByUserID, &rec.TargetVersion, &rec.Status,
		&hostID, &reason, &doneRaw)
	switch {
	case errors.Is(err, sql.ErrNoRows):
		return nil, nil
	case err != nil:
		return nil, fmt.Errorf("store: active fleet update: %w", err)
	}

	// Timestamps are parsed best-effort: the parse error is discarded, so a
	// value that is not valid RFC 3339 yields the zero time.Time.
	rec.StartedAt, _ = time.Parse(time.RFC3339Nano, startedRaw)
	rec.CurrentHostID = hostID.String
	rec.HaltedReason = reason.String
	if doneRaw.Valid {
		done, _ := time.Parse(time.RFC3339Nano, doneRaw.String)
		rec.CompletedAt = &done
	}
	return &rec, nil
}
// GetFleetUpdate loads the fleet update with the given id together with
// its host rows ordered by position.
//
// It returns ErrNotFound when no parent row has that id. On any other
// failure (query, scan, or row iteration) it returns nil results and a
// wrapped error; partial results are never returned alongside an error.
// On success the host slice is non-nil even when it is empty.
//
// The parent and child rows are read with two separate queries outside a
// transaction, so they are not guaranteed to reflect the same instant.
func (st *Store) GetFleetUpdate(ctx context.Context, id string) (*FleetUpdate, []FleetUpdateHost, error) {
	var fu FleetUpdate
	var startedAt string
	var current sql.NullString
	var halted sql.NullString
	var completedAt sql.NullString
	err := st.db.QueryRowContext(ctx,
		`SELECT id, started_at, started_by_user_id, target_version, status,
current_host_id, halted_reason, completed_at
FROM fleet_updates WHERE id = ?`, id).
		Scan(&fu.ID, &startedAt, &fu.StartedByUserID, &fu.TargetVersion, &fu.Status,
			&current, &halted, &completedAt)
	if errors.Is(err, sql.ErrNoRows) {
		return nil, nil, ErrNotFound
	}
	if err != nil {
		return nil, nil, fmt.Errorf("store: get fleet update: %w", err)
	}

	// Timestamps are parsed best-effort: the parse error is discarded, so a
	// malformed value yields the zero time.Time instead of failing the read.
	fu.StartedAt, _ = time.Parse(time.RFC3339Nano, startedAt)
	fu.CurrentHostID = current.String
	fu.HaltedReason = halted.String
	if completedAt.Valid {
		t, _ := time.Parse(time.RFC3339Nano, completedAt.String)
		fu.CompletedAt = &t
	}

	rows, err := st.db.QueryContext(ctx,
		`SELECT host_id, position, status, COALESCE(job_id, ''), COALESCE(failed_reason, '')
FROM fleet_update_hosts
WHERE fleet_update_id = ?
ORDER BY position`, id)
	if err != nil {
		return nil, nil, fmt.Errorf("store: list fleet hosts: %w", err)
	}
	defer func() { _ = rows.Close() }()

	out := []FleetUpdateHost{}
	for rows.Next() {
		fh := FleetUpdateHost{FleetUpdateID: id}
		if err := rows.Scan(&fh.HostID, &fh.Position, &fh.Status, &fh.JobID, &fh.FailedReason); err != nil {
			return nil, nil, fmt.Errorf("store: scan fleet host: %w", err)
		}
		out = append(out, fh)
	}
	// rows.Next returning false can mean "done" or "failed"; rows.Err tells
	// them apart. A failure here means the host list may be truncated, so
	// drop the partial results and report a wrapped error.
	if err := rows.Err(); err != nil {
		return nil, nil, fmt.Errorf("store: iterate fleet hosts: %w", err)
	}
	return &fu, out, nil
}
// ListPendingFleetUpdateHosts returns the rows with status='pending' for
// the fleet update fuID, in position order. The worker calls this to
// pick the next host to dispatch.
//
// The returned slice is non-nil but empty when nothing is pending. Any
// query, scan, or iteration failure yields a nil slice and a wrapped
// "store: ..." error; a partial list is never returned with an error.
func (st *Store) ListPendingFleetUpdateHosts(ctx context.Context, fuID string) ([]FleetUpdateHost, error) {
	rows, err := st.db.QueryContext(ctx,
		`SELECT host_id, position, status, COALESCE(job_id, ''), COALESCE(failed_reason, '')
FROM fleet_update_hosts
WHERE fleet_update_id = ? AND status = 'pending'
ORDER BY position`, fuID)
	if err != nil {
		return nil, fmt.Errorf("store: list pending fleet hosts: %w", err)
	}
	defer func() { _ = rows.Close() }()

	out := []FleetUpdateHost{}
	for rows.Next() {
		fh := FleetUpdateHost{FleetUpdateID: fuID}
		if err := rows.Scan(&fh.HostID, &fh.Position, &fh.Status, &fh.JobID, &fh.FailedReason); err != nil {
			return nil, fmt.Errorf("store: scan pending fleet host: %w", err)
		}
		out = append(out, fh)
	}
	// Distinguish "no more rows" from "iteration failed"; in the latter case
	// the list may be truncated, so it is discarded.
	if err := rows.Err(); err != nil {
		return nil, fmt.Errorf("store: iterate pending fleet hosts: %w", err)
	}
	return out, nil
}
// SetFleetUpdateHostStatus updates the fleet_update_hosts row identified by
// (fuID, hostID): status is overwritten, failed_reason is overwritten with
// nullableString(failedReason), and job_id becomes
// COALESCE(nullableString(jobID), job_id).
//
// nullableString is defined elsewhere; presumably it maps "" to SQL NULL.
// If so, an empty failedReason clears failed_reason, while an empty jobID
// leaves the previously stored job_id untouched.
//
// The number of affected rows is not inspected, so a (fuID, hostID) pair
// that matches nothing is not reported as an error.
func (st *Store) SetFleetUpdateHostStatus(ctx context.Context, fuID, hostID, status, failedReason, jobID string) error {
	const stmt = `UPDATE fleet_update_hosts
SET status = ?, failed_reason = ?, job_id = COALESCE(?, job_id)
WHERE fleet_update_id = ? AND host_id = ?`

	reasonArg := nullableString(failedReason)
	jobArg := nullableString(jobID)
	if _, err := st.db.ExecContext(ctx, stmt, status, reasonArg, jobArg, fuID, hostID); err != nil {
		return fmt.Errorf("store: set fleet host status: %w", err)
	}
	return nil
}
// SetFleetUpdateCurrentHost records hostID as current_host_id on the fleet
// update fuID, i.e. the host the worker is actively waiting on. The value
// is passed through nullableString first; pass an empty string to clear.
//
// The number of affected rows is not inspected, so an unknown fuID is not
// reported as an error.
func (st *Store) SetFleetUpdateCurrentHost(ctx context.Context, fuID, hostID string) error {
	const stmt = `UPDATE fleet_updates SET current_host_id = ? WHERE id = ?`

	hostArg := nullableString(hostID)
	if _, err := st.db.ExecContext(ctx, stmt, hostArg, fuID); err != nil {
		return fmt.Errorf("store: set fleet current host: %w", err)
	}
	return nil
}
// HaltFleetUpdate marks the fleet update fuID as 'halted': it stores
// reason in halted_reason, sets current_host_id to NULL, and writes when
// (converted to UTC, RFC 3339 with nanoseconds) to completed_at.
//
// Only a row that is still 'running' is touched. The number of affected
// rows is not inspected, so calling this on a missing or already-finished
// fleet update returns nil without changing anything.
func (st *Store) HaltFleetUpdate(ctx context.Context, fuID, reason string, when time.Time) error {
	const stmt = `UPDATE fleet_updates
SET status = 'halted', halted_reason = ?, current_host_id = NULL,
completed_at = ?
WHERE id = ? AND status = 'running'`

	stamp := when.UTC().Format(time.RFC3339Nano)
	if _, err := st.db.ExecContext(ctx, stmt, reason, stamp, fuID); err != nil {
		return fmt.Errorf("store: halt fleet update: %w", err)
	}
	return nil
}
// CancelFleetUpdate marks the fleet update fuID as 'cancelled', sets
// current_host_id to NULL, and writes when (converted to UTC, RFC 3339
// with nanoseconds) to completed_at.
//
// The UPDATE only matches a row that is still 'running'. The number of
// affected rows is not inspected, so if the row is missing or no longer
// running this returns nil without changing anything; callers that need to
// know should check the status themselves.
func (st *Store) CancelFleetUpdate(ctx context.Context, fuID string, when time.Time) error {
	const stmt = `UPDATE fleet_updates
SET status = 'cancelled', current_host_id = NULL, completed_at = ?
WHERE id = ? AND status = 'running'`

	stamp := when.UTC().Format(time.RFC3339Nano)
	if _, err := st.db.ExecContext(ctx, stmt, stamp, fuID); err != nil {
		return fmt.Errorf("store: cancel fleet update: %w", err)
	}
	return nil
}
// CompleteFleetUpdate marks the fleet update fuID as 'completed', sets
// current_host_id to NULL, and writes when (converted to UTC, RFC 3339
// with nanoseconds) to completed_at. It is meant to be called once every
// host has reached a terminal state; that condition is not checked here.
//
// Only a row that is still 'running' is touched. The number of affected
// rows is not inspected, so a missing or already-finished fleet update
// yields nil without changing anything.
func (st *Store) CompleteFleetUpdate(ctx context.Context, fuID string, when time.Time) error {
	const stmt = `UPDATE fleet_updates
SET status = 'completed', current_host_id = NULL, completed_at = ?
WHERE id = ? AND status = 'running'`

	stamp := when.UTC().Format(time.RFC3339Nano)
	if _, err := st.db.ExecContext(ctx, stmt, stamp, fuID); err != nil {
		return fmt.Errorf("store: complete fleet update: %w", err)
	}
	return nil
}
// RunningUpdateJobForHost reports the id of an in-flight `update` job for
// hostID, where in-flight means status 'queued' or 'running'. When several
// match, the one with the latest created_at wins. It returns ("", nil) if
// there is no such job and a wrapped error on any other failure.
//
// Used by the host-update HTTP handler to refuse double-dispatch and by
// the fleet worker to dedupe on retry.
func (st *Store) RunningUpdateJobForHost(ctx context.Context, hostID string) (string, error) {
	const query = `SELECT id FROM jobs
WHERE host_id = ? AND kind = 'update' AND status IN ('queued','running')
ORDER BY created_at DESC LIMIT 1`

	var jobID string
	switch err := st.db.QueryRowContext(ctx, query, hostID).Scan(&jobID); {
	case err == nil:
		return jobID, nil
	case errors.Is(err, sql.ErrNoRows):
		// No in-flight job is a normal outcome, not an error.
		return "", nil
	default:
		return "", fmt.Errorf("store: running update job: %w", err)
	}
}