store: host_credentials becomes kind-aware (repo + admin slots)
This commit is contained in:
@@ -167,7 +167,7 @@ func (s *Server) handleAgentEnroll(w stdhttp.ResponseWriter, r *stdhttp.Request)
|
||||
// /api/hosts/{id}/repo-credentials. Failing the whole enrolment
|
||||
// here would leave a half-burned token + an orphan host.
|
||||
if encForHost != "" {
|
||||
if err := s.deps.Store.SetHostCredentials(r.Context(), hostID, encForHost); err != nil {
|
||||
if err := s.deps.Store.SetHostCredentials(r.Context(), hostID, store.CredKindRepo, encForHost); err != nil {
|
||||
slog.Error("enrollment: set host credentials failed",
|
||||
"host_id", hostID, "err", err)
|
||||
}
|
||||
|
||||
@@ -39,7 +39,7 @@ func (s *Server) handleGetHostCredentials(w stdhttp.ResponseWriter, r *stdhttp.R
|
||||
writeJSONError(w, stdhttp.StatusBadRequest, "missing_id", "")
|
||||
return
|
||||
}
|
||||
enc, err := s.deps.Store.GetHostCredentials(r.Context(), hostID)
|
||||
enc, err := s.deps.Store.GetHostCredentials(r.Context(), hostID, store.CredKindRepo)
|
||||
if err != nil {
|
||||
if errors.Is(err, store.ErrNotFound) {
|
||||
writeJSONError(w, stdhttp.StatusNotFound, "not_set", "")
|
||||
@@ -107,7 +107,7 @@ func (s *Server) handleSetHostCredentials(w stdhttp.ResponseWriter, r *stdhttp.R
|
||||
|
||||
// Merge with the existing row, if any.
|
||||
existing := repoCredsBlob{}
|
||||
if cur, err := s.deps.Store.GetHostCredentials(r.Context(), hostID); err == nil {
|
||||
if cur, err := s.deps.Store.GetHostCredentials(r.Context(), hostID, store.CredKindRepo); err == nil {
|
||||
plain, err := s.deps.AEAD.Decrypt(cur, []byte("host:"+hostID))
|
||||
if err != nil {
|
||||
writeJSONError(w, stdhttp.StatusInternalServerError, "decrypt_failed", "")
|
||||
@@ -139,7 +139,7 @@ func (s *Server) handleSetHostCredentials(w stdhttp.ResponseWriter, r *stdhttp.R
|
||||
writeJSONError(w, stdhttp.StatusInternalServerError, "internal", "")
|
||||
return
|
||||
}
|
||||
if err := s.deps.Store.SetHostCredentials(r.Context(), hostID, enc); err != nil {
|
||||
if err := s.deps.Store.SetHostCredentials(r.Context(), hostID, store.CredKindRepo, enc); err != nil {
|
||||
writeJSONError(w, stdhttp.StatusInternalServerError, "internal", "")
|
||||
return
|
||||
}
|
||||
@@ -212,7 +212,7 @@ func (s *Server) onAgentHello(ctx context.Context, hostID string, conn *ws.Conn)
|
||||
// them the runner can't talk to the repo). We rely on Restic's
|
||||
// idempotent init for re-runs.
|
||||
func (s *Server) maybeAutoInit(ctx context.Context, hostID string, conn *ws.Conn) {
|
||||
if _, err := s.deps.Store.GetHostCredentials(ctx, hostID); err != nil {
|
||||
if _, err := s.deps.Store.GetHostCredentials(ctx, hostID, store.CredKindRepo); err != nil {
|
||||
// No creds bound yet — operator hasn't supplied them. The next
|
||||
// hello after creds land will pick this up.
|
||||
return
|
||||
@@ -266,7 +266,7 @@ func (s *Server) maybeAutoInit(ctx context.Context, hostID string, conn *ws.Conn
|
||||
// credentials. Silent no-op when the host has nothing on file
|
||||
// (the operator hasn't bound creds to it yet).
|
||||
func (s *Server) pushRepoCredsOnHello(ctx context.Context, hostID string, conn *ws.Conn) {
|
||||
enc, err := s.deps.Store.GetHostCredentials(ctx, hostID)
|
||||
enc, err := s.deps.Store.GetHostCredentials(ctx, hostID, store.CredKindRepo)
|
||||
if err != nil {
|
||||
if !errors.Is(err, store.ErrNotFound) {
|
||||
slog.Warn("on-hello: load host creds", "host_id", hostID, "err", err)
|
||||
|
||||
@@ -5,6 +5,8 @@ import (
|
||||
"encoding/json"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"gitea.dcglab.co.uk/steve/restic-manager/internal/store"
|
||||
)
|
||||
|
||||
// TestEnrollmentTransfersRepoCreds verifies the round-trip:
|
||||
@@ -57,12 +59,12 @@ func TestEnrollmentTransfersRepoCreds(t *testing.T) {
|
||||
hostID, "host42", "linux", "amd64", "2026-01-01T00:00:00Z"); err != nil {
|
||||
t.Fatalf("insert host: %v", err)
|
||||
}
|
||||
if err := st.SetHostCredentials(ctx, hostID, encForHost); err != nil {
|
||||
if err := st.SetHostCredentials(ctx, hostID, store.CredKindRepo, encForHost); err != nil {
|
||||
t.Fatalf("set host credentials: %v", err)
|
||||
}
|
||||
|
||||
// host_credentials row should now hold the host-bound ciphertext.
|
||||
got, err := st.GetHostCredentials(ctx, hostID)
|
||||
got, err := st.GetHostCredentials(ctx, hostID, store.CredKindRepo)
|
||||
if err != nil {
|
||||
t.Fatalf("get host creds: %v", err)
|
||||
}
|
||||
|
||||
@@ -99,7 +99,7 @@ func enrolHostForWS(t *testing.T, srv *Server, st *store.Store, name string) (ho
|
||||
if err != nil {
|
||||
t.Fatalf("encrypt: %v", err)
|
||||
}
|
||||
if err := st.SetHostCredentials(context.Background(), hostID, enc); err != nil {
|
||||
if err := st.SetHostCredentials(context.Background(), hostID, store.CredKindRepo, enc); err != nil {
|
||||
t.Fatalf("set creds: %v", err)
|
||||
}
|
||||
return hostID, token
|
||||
|
||||
@@ -61,7 +61,7 @@ func (s *Server) loadHostRepoPage(r *stdhttp.Request, host store.Host) (*hostRep
|
||||
}
|
||||
|
||||
// Credentials (redacted).
|
||||
enc, err := s.deps.Store.GetHostCredentials(r.Context(), host.ID)
|
||||
enc, err := s.deps.Store.GetHostCredentials(r.Context(), host.ID, store.CredKindRepo)
|
||||
switch {
|
||||
case err == nil:
|
||||
plain, derr := s.deps.AEAD.Decrypt(enc, []byte("host:"+host.ID))
|
||||
@@ -204,7 +204,7 @@ func (s *Server) handleUIRepoCredentialsSave(w stdhttp.ResponseWriter, r *stdhtt
|
||||
|
||||
// Merge with existing blob — same semantics as the JSON PUT.
|
||||
existing := repoCredsBlob{}
|
||||
if cur, err := s.deps.Store.GetHostCredentials(r.Context(), host.ID); err == nil {
|
||||
if cur, err := s.deps.Store.GetHostCredentials(r.Context(), host.ID, store.CredKindRepo); err == nil {
|
||||
if plain, derr := s.deps.AEAD.Decrypt(cur, []byte("host:"+host.ID)); derr == nil {
|
||||
_ = json.Unmarshal(plain, &existing)
|
||||
}
|
||||
@@ -227,7 +227,7 @@ func (s *Server) handleUIRepoCredentialsSave(w stdhttp.ResponseWriter, r *stdhtt
|
||||
stdhttp.Error(w, "internal", stdhttp.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
if err := s.deps.Store.SetHostCredentials(r.Context(), host.ID, enc); err != nil {
|
||||
if err := s.deps.Store.SetHostCredentials(r.Context(), host.ID, store.CredKindRepo, enc); err != nil {
|
||||
slog.Error("ui repo creds: persist", "err", err)
|
||||
stdhttp.Error(w, "internal", stdhttp.StatusInternalServerError)
|
||||
return
|
||||
|
||||
@@ -8,13 +8,23 @@ import (
|
||||
"time"
|
||||
)
|
||||
|
||||
// CredentialKind identifies the role of a host_credentials row.
|
||||
type CredentialKind string
|
||||
|
||||
const (
|
||||
// CredKindRepo is the append-only credential used for every backup.
|
||||
CredKindRepo CredentialKind = "repo"
|
||||
// CredKindAdmin is the delete-capable credential used for prune.
|
||||
CredKindAdmin CredentialKind = "admin"
|
||||
)
|
||||
|
||||
// GetHostCredentials returns the AEAD-encrypted repo creds blob for
|
||||
// the host, or ("", ErrNotFound) if no credential has ever been set.
|
||||
// the host + kind, or ("", ErrNotFound) if no matching row exists.
|
||||
// The caller decrypts using host_id as AEAD additional data.
|
||||
func (s *Store) GetHostCredentials(ctx context.Context, hostID string) (string, error) {
|
||||
func (s *Store) GetHostCredentials(ctx context.Context, hostID string, kind CredentialKind) (string, error) {
|
||||
row := s.db.QueryRowContext(ctx,
|
||||
`SELECT enc_repo_creds FROM host_credentials WHERE host_id = ?`,
|
||||
hostID)
|
||||
`SELECT enc_repo_creds FROM host_credentials WHERE host_id = ? AND kind = ?`,
|
||||
hostID, string(kind))
|
||||
var enc string
|
||||
if err := row.Scan(&enc); err != nil {
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
@@ -25,22 +35,35 @@ func (s *Store) GetHostCredentials(ctx context.Context, hostID string) (string,
|
||||
return enc, nil
|
||||
}
|
||||
|
||||
// SetHostCredentials replaces the host's encrypted repo creds blob.
|
||||
// The caller has already encrypted using host_id as additional data.
|
||||
func (s *Store) SetHostCredentials(ctx context.Context, hostID, encRepoCreds string) error {
|
||||
// SetHostCredentials replaces the host's encrypted repo creds blob for
|
||||
// the given kind. The caller has already encrypted using host_id as
|
||||
// additional data.
|
||||
func (s *Store) SetHostCredentials(ctx context.Context, hostID string, kind CredentialKind, encRepoCreds string) error {
|
||||
if encRepoCreds == "" {
|
||||
return fmt.Errorf("store: empty enc_repo_creds")
|
||||
}
|
||||
now := time.Now().UTC().Format(time.RFC3339Nano)
|
||||
_, err := s.db.ExecContext(ctx,
|
||||
`INSERT INTO host_credentials (host_id, enc_repo_creds, updated_at)
|
||||
VALUES (?, ?, ?)
|
||||
ON CONFLICT(host_id) DO UPDATE SET
|
||||
`INSERT INTO host_credentials (host_id, kind, enc_repo_creds, updated_at)
|
||||
VALUES (?, ?, ?, ?)
|
||||
ON CONFLICT(host_id, kind) DO UPDATE SET
|
||||
enc_repo_creds = excluded.enc_repo_creds,
|
||||
updated_at = excluded.updated_at`,
|
||||
hostID, encRepoCreds, now)
|
||||
hostID, string(kind), encRepoCreds, now)
|
||||
if err != nil {
|
||||
return fmt.Errorf("store: set host credentials: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// DeleteHostCredentials removes the credential row for the given host
|
||||
// and kind. A no-op if the row does not exist.
|
||||
func (s *Store) DeleteHostCredentials(ctx context.Context, hostID string, kind CredentialKind) error {
|
||||
_, err := s.db.ExecContext(ctx,
|
||||
`DELETE FROM host_credentials WHERE host_id = ? AND kind = ?`,
|
||||
hostID, string(kind))
|
||||
if err != nil {
|
||||
return fmt.Errorf("store: delete host credentials: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -0,0 +1,103 @@
|
||||
package store
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// seedHost inserts a minimal host row for testing.
|
||||
func seedHost(t *testing.T, s *Store, hostID string) {
|
||||
t.Helper()
|
||||
_, err := s.DB().Exec(
|
||||
`INSERT INTO hosts (id, name, os, arch, enrolled_at) VALUES (?,?,?,?,?)`,
|
||||
hostID, hostID, "linux", "amd64", "2026-01-01T00:00:00Z")
|
||||
if err != nil {
|
||||
t.Fatalf("seed host %q: %v", hostID, err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHostCredentialsAdminRowSeparate(t *testing.T) {
|
||||
t.Parallel()
|
||||
s := openTestStore(t)
|
||||
ctx := context.Background()
|
||||
|
||||
const hostID = "h-creds-test"
|
||||
seedHost(t, s, hostID)
|
||||
|
||||
const repoBlob = "enc-repo-blob"
|
||||
const adminBlob = "enc-admin-blob"
|
||||
|
||||
// Set repo creds.
|
||||
if err := s.SetHostCredentials(ctx, hostID, CredKindRepo, repoBlob); err != nil {
|
||||
t.Fatalf("set repo creds: %v", err)
|
||||
}
|
||||
// Set admin creds.
|
||||
if err := s.SetHostCredentials(ctx, hostID, CredKindAdmin, adminBlob); err != nil {
|
||||
t.Fatalf("set admin creds: %v", err)
|
||||
}
|
||||
|
||||
// Fetch each by kind and assert they differ.
|
||||
gotRepo, err := s.GetHostCredentials(ctx, hostID, CredKindRepo)
|
||||
if err != nil {
|
||||
t.Fatalf("get repo creds: %v", err)
|
||||
}
|
||||
gotAdmin, err := s.GetHostCredentials(ctx, hostID, CredKindAdmin)
|
||||
if err != nil {
|
||||
t.Fatalf("get admin creds: %v", err)
|
||||
}
|
||||
if gotRepo != repoBlob {
|
||||
t.Errorf("repo creds: got %q, want %q", gotRepo, repoBlob)
|
||||
}
|
||||
if gotAdmin != adminBlob {
|
||||
t.Errorf("admin creds: got %q, want %q", gotAdmin, adminBlob)
|
||||
}
|
||||
if gotRepo == gotAdmin {
|
||||
t.Error("repo and admin blobs must differ")
|
||||
}
|
||||
|
||||
// Delete admin; repo must be unaffected.
|
||||
if err := s.DeleteHostCredentials(ctx, hostID, CredKindAdmin); err != nil {
|
||||
t.Fatalf("delete admin creds: %v", err)
|
||||
}
|
||||
if _, err := s.GetHostCredentials(ctx, hostID, CredKindAdmin); !errors.Is(err, ErrNotFound) {
|
||||
t.Errorf("after delete, expected ErrNotFound for admin; got %v", err)
|
||||
}
|
||||
if got, err := s.GetHostCredentials(ctx, hostID, CredKindRepo); err != nil || got != repoBlob {
|
||||
t.Errorf("repo creds should survive admin delete; got %q, err %v", got, err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHostCredentialsNotFound(t *testing.T) {
|
||||
t.Parallel()
|
||||
s := openTestStore(t)
|
||||
ctx := context.Background()
|
||||
|
||||
_, err := s.GetHostCredentials(ctx, "no-such-host", CredKindRepo)
|
||||
if !errors.Is(err, ErrNotFound) {
|
||||
t.Errorf("expected ErrNotFound, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHostCredentialsUpsert(t *testing.T) {
|
||||
t.Parallel()
|
||||
s := openTestStore(t)
|
||||
ctx := context.Background()
|
||||
|
||||
const hostID = "h-upsert-test"
|
||||
seedHost(t, s, hostID)
|
||||
|
||||
if err := s.SetHostCredentials(ctx, hostID, CredKindRepo, "v1"); err != nil {
|
||||
t.Fatalf("set v1: %v", err)
|
||||
}
|
||||
if err := s.SetHostCredentials(ctx, hostID, CredKindRepo, "v2"); err != nil {
|
||||
t.Fatalf("set v2 (upsert): %v", err)
|
||||
}
|
||||
got, err := s.GetHostCredentials(ctx, hostID, CredKindRepo)
|
||||
if err != nil {
|
||||
t.Fatalf("get: %v", err)
|
||||
}
|
||||
if got != "v2" {
|
||||
t.Errorf("expected v2, got %q", got)
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user