package store import ( "context" "database/sql" "errors" "fmt" "time" ) // CreateEnrollmentToken persists a fresh one-time token. The caller // has already hashed the raw token; the raw form is returned to the // operator (printed in the install snippet) and never persisted. // // encRepoCreds is the AEAD-encrypted blob of {repo_url, repo_username, // repo_password} that ConsumeEnrollmentToken will promote to a // host_credentials row. Empty string = operator chose to set creds // later via PUT /api/hosts/{id}/repo-credentials; the agent will // refuse backup jobs until that lands. func (s *Store) CreateEnrollmentToken(ctx context.Context, tokenHash string, ttl time.Duration, encRepoCreds string) error { now := time.Now().UTC() var enc any = nil if encRepoCreds != "" { enc = encRepoCreds } _, err := s.db.ExecContext(ctx, `INSERT INTO enrollment_tokens (token_hash, created_at, expires_at, enc_repo_creds) VALUES (?, ?, ?, ?)`, tokenHash, now.Format(time.RFC3339Nano), now.Add(ttl).Format(time.RFC3339Nano), enc) if err != nil { return fmt.Errorf("store: create enrollment token: %w", err) } return nil } // ConsumeEnrollmentToken atomically validates a token (must exist, // not be consumed, not be expired), marks it consumed by hostID, and // — if the token carries encrypted repo creds — promotes them to a // host_credentials row in the same tx. The encrypted blob is // re-encrypted by the caller with host_id as additional data; we // don't crack it open here. // // Returns ErrNotFound on any failure. func (s *Store) ConsumeEnrollmentToken(ctx context.Context, tokenHash, hostID, encRepoCredsForHost string) error { now := time.Now().UTC().Format(time.RFC3339Nano) tx, err := s.db.BeginTx(ctx, nil) if err != nil { return fmt.Errorf("store: consume enrollment token: begin: %w", err) } defer func() { _ = tx.Rollback() }() res, err := tx.ExecContext(ctx, `UPDATE enrollment_tokens SET consumed_at = ?, consumed_host = ? WHERE token_hash = ? AND consumed_at IS NULL AND expires_at > ?`, now, hostID, tokenHash, now) if err != nil { return fmt.Errorf("store: consume enrollment token: %w", err) } n, _ := res.RowsAffected() if n == 0 { return ErrNotFound } if encRepoCredsForHost != "" { if _, err := tx.ExecContext(ctx, `INSERT INTO host_credentials (host_id, enc_repo_creds, updated_at) VALUES (?, ?, ?) ON CONFLICT(host_id) DO UPDATE SET enc_repo_creds = excluded.enc_repo_creds, updated_at = excluded.updated_at`, hostID, encRepoCredsForHost, now); err != nil { return fmt.Errorf("store: promote host credentials: %w", err) } } if err := tx.Commit(); err != nil { return fmt.Errorf("store: consume enrollment token: commit: %w", err) } return nil } // GetEnrollmentTokenCreds returns the encrypted repo-creds blob the // operator stashed when creating the token, or ("", ErrNotFound) if // the token is gone / consumed / expired / had no creds attached. // // The caller decrypts using token_hash as the AEAD additional data, // then re-encrypts using host_id as additional data before passing // to ConsumeEnrollmentToken. func (s *Store) GetEnrollmentTokenCreds(ctx context.Context, tokenHash string) (string, error) { now := time.Now().UTC().Format(time.RFC3339Nano) row := s.db.QueryRowContext(ctx, `SELECT enc_repo_creds FROM enrollment_tokens WHERE token_hash = ? AND consumed_at IS NULL AND expires_at > ?`, tokenHash, now) var enc sql.NullString if err := row.Scan(&enc); err != nil { if errors.Is(err, sql.ErrNoRows) { return "", ErrNotFound } return "", fmt.Errorf("store: get enrollment token creds: %w", err) } if !enc.Valid { return "", nil } return enc.String, nil } // PurgeExpiredEnrollmentTokens deletes long-expired token rows. Tokens // retained for ~24h after expiry so audit traces still resolve them. func (s *Store) PurgeExpiredEnrollmentTokens(ctx context.Context) (int64, error) { cutoff := time.Now().Add(-24 * time.Hour).UTC().Format(time.RFC3339Nano) res, err := s.db.ExecContext(ctx, `DELETE FROM enrollment_tokens WHERE expires_at <= ?`, cutoff) if err != nil { return 0, fmt.Errorf("store: purge enrollment tokens: %w", err) } n, _ := res.RowsAffected() return n, nil }