241 lines
6.0 KiB
Go
241 lines
6.0 KiB
Go
package store
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"testing"
|
|
"time"
|
|
)
|
|
|
|
func TestUserCRUD(t *testing.T) {
|
|
t.Parallel()
|
|
s := openTestStore(t)
|
|
ctx := context.Background()
|
|
|
|
now := time.Now().UTC()
|
|
u := User{
|
|
ID: "u1",
|
|
Username: "alice",
|
|
PasswordHash: "$argon2id$...",
|
|
Role: RoleAdmin,
|
|
CreatedAt: now,
|
|
}
|
|
if err := s.CreateUser(ctx, u); err != nil {
|
|
t.Fatalf("create: %v", err)
|
|
}
|
|
|
|
got, err := s.GetUserByUsername(ctx, "alice")
|
|
if err != nil {
|
|
t.Fatalf("get: %v", err)
|
|
}
|
|
if got.ID != "u1" || got.Role != RoleAdmin {
|
|
t.Errorf("unexpected user: %+v", got)
|
|
}
|
|
|
|
// Username uniqueness is enforced by the schema.
|
|
if err := s.CreateUser(ctx, u); err == nil {
|
|
t.Error("duplicate username should fail")
|
|
}
|
|
|
|
if _, err := s.GetUserByUsername(ctx, "bob"); !errors.Is(err, ErrNotFound) {
|
|
t.Errorf("missing user: want ErrNotFound, got %v", err)
|
|
}
|
|
|
|
if err := s.MarkUserLogin(ctx, "u1", now); err != nil {
|
|
t.Fatalf("mark login: %v", err)
|
|
}
|
|
got, _ = s.GetUserByUsername(ctx, "alice")
|
|
if got.LastLoginAt == nil {
|
|
t.Error("last_login_at not updated")
|
|
}
|
|
}
|
|
|
|
func TestCountUsers(t *testing.T) {
|
|
t.Parallel()
|
|
s := openTestStore(t)
|
|
ctx := context.Background()
|
|
|
|
n, _ := s.CountUsers(ctx)
|
|
if n != 0 {
|
|
t.Errorf("fresh db: want 0, got %d", n)
|
|
}
|
|
_ = s.CreateUser(ctx, User{
|
|
ID: "u1", Username: "a", PasswordHash: "x",
|
|
Role: RoleAdmin, CreatedAt: time.Now(),
|
|
})
|
|
n, _ = s.CountUsers(ctx)
|
|
if n != 1 {
|
|
t.Errorf("after insert: want 1, got %d", n)
|
|
}
|
|
}
|
|
|
|
func TestSessionLifecycle(t *testing.T) {
|
|
t.Parallel()
|
|
s := openTestStore(t)
|
|
ctx := context.Background()
|
|
|
|
// Need a user for FK.
|
|
_ = s.CreateUser(ctx, User{
|
|
ID: "u1", Username: "alice", PasswordHash: "x",
|
|
Role: RoleAdmin, CreatedAt: time.Now(),
|
|
})
|
|
|
|
now := time.Now().UTC()
|
|
sess := Session{
|
|
UserID: "u1",
|
|
CreatedAt: now,
|
|
ExpiresAt: now.Add(time.Hour),
|
|
IP: "10.0.0.1",
|
|
UA: "test/1.0",
|
|
}
|
|
hash := "deadbeef" + "00000000000000000000000000000000000000000000000000000000"
|
|
if err := s.CreateSession(ctx, sess, hash); err != nil {
|
|
t.Fatalf("create: %v", err)
|
|
}
|
|
|
|
got, err := s.LookupSession(ctx, hash)
|
|
if err != nil {
|
|
t.Fatalf("lookup: %v", err)
|
|
}
|
|
if got.UserID != "u1" {
|
|
t.Errorf("user mismatch: %s", got.UserID)
|
|
}
|
|
|
|
// Expired sessions should not resolve.
|
|
expiredHash := "expired-hash"
|
|
expired := Session{
|
|
UserID: "u1",
|
|
CreatedAt: now.Add(-2 * time.Hour),
|
|
ExpiresAt: now.Add(-time.Hour),
|
|
}
|
|
if err := s.CreateSession(ctx, expired, expiredHash); err != nil {
|
|
t.Fatalf("create expired: %v", err)
|
|
}
|
|
if _, err := s.LookupSession(ctx, expiredHash); !errors.Is(err, ErrNotFound) {
|
|
t.Errorf("expired session should look like ErrNotFound, got %v", err)
|
|
}
|
|
|
|
if err := s.DeleteSession(ctx, hash); err != nil {
|
|
t.Fatalf("delete: %v", err)
|
|
}
|
|
if _, err := s.LookupSession(ctx, hash); !errors.Is(err, ErrNotFound) {
|
|
t.Errorf("deleted session: want ErrNotFound, got %v", err)
|
|
}
|
|
|
|
n, err := s.PurgeExpiredSessions(ctx)
|
|
if err != nil {
|
|
t.Fatalf("purge: %v", err)
|
|
}
|
|
if n != 1 {
|
|
t.Errorf("purge should remove the 1 expired row, got %d", n)
|
|
}
|
|
}
|
|
|
|
func TestCreateUserLowercasesUsername(t *testing.T) {
|
|
t.Parallel()
|
|
s := openTestStore(t)
|
|
ctx := context.Background()
|
|
now := time.Now().UTC()
|
|
|
|
if err := s.CreateUser(ctx, User{
|
|
ID: "u1", Username: "Alice",
|
|
PasswordHash: "x", Role: RoleAdmin, CreatedAt: now,
|
|
}); err != nil {
|
|
t.Fatalf("create: %v", err)
|
|
}
|
|
got, err := s.GetUserByUsername(ctx, "alice")
|
|
if err != nil {
|
|
t.Fatalf("get lower: %v", err)
|
|
}
|
|
if got.Username != "alice" {
|
|
t.Errorf("stored username: got %q want %q", got.Username, "alice")
|
|
}
|
|
got, err = s.GetUserByUsername(ctx, "ALICE")
|
|
if err != nil {
|
|
t.Fatalf("get upper: %v", err)
|
|
}
|
|
if got.ID != "u1" {
|
|
t.Errorf("upper-case lookup missed: got %+v", got)
|
|
}
|
|
if err := s.CreateUser(ctx, User{
|
|
ID: "u2", Username: "AlIcE",
|
|
PasswordHash: "x", Role: RoleAdmin, CreatedAt: now,
|
|
}); err == nil {
|
|
t.Error("duplicate (different case) should fail")
|
|
}
|
|
}
|
|
|
|
func TestGetUserByOIDCSubject(t *testing.T) {
|
|
t.Parallel()
|
|
s := openTestStore(t)
|
|
ctx := context.Background()
|
|
now := time.Now().UTC()
|
|
sub := "sub-abc-123"
|
|
|
|
if err := s.CreateUser(ctx, User{
|
|
ID: "u1", Username: "alice", PasswordHash: "",
|
|
Role: RoleAdmin, CreatedAt: now,
|
|
AuthSource: "oidc", OIDCSubject: &sub,
|
|
}); err != nil {
|
|
t.Fatalf("create: %v", err)
|
|
}
|
|
got, err := s.GetUserByOIDCSubject(ctx, sub)
|
|
if err != nil {
|
|
t.Fatalf("get by sub: %v", err)
|
|
}
|
|
if got.ID != "u1" || got.AuthSource != "oidc" {
|
|
t.Errorf("unexpected: %+v", got)
|
|
}
|
|
if _, err := s.GetUserByOIDCSubject(ctx, "nope"); !errors.Is(err, ErrNotFound) {
|
|
t.Errorf("missing sub: want ErrNotFound, got %v", err)
|
|
}
|
|
}
|
|
|
|
func TestSetUserOIDCSubject(t *testing.T) {
|
|
t.Parallel()
|
|
s := openTestStore(t)
|
|
ctx := context.Background()
|
|
now := time.Now().UTC()
|
|
|
|
if err := s.CreateUser(ctx, User{
|
|
ID: "u1", Username: "alice", PasswordHash: "x",
|
|
Role: RoleAdmin, CreatedAt: now,
|
|
}); err != nil {
|
|
t.Fatalf("create: %v", err)
|
|
}
|
|
sub := "sub-456"
|
|
if err := s.SetUserOIDCSubject(ctx, "u1", "oidc", sub); err != nil {
|
|
t.Fatalf("set: %v", err)
|
|
}
|
|
got, _ := s.GetUserByID(ctx, "u1")
|
|
if got.AuthSource != "oidc" || got.OIDCSubject == nil || *got.OIDCSubject != sub {
|
|
t.Errorf("after set: %+v", got)
|
|
}
|
|
}
|
|
|
|
func TestEnrollmentTokenSingleUse(t *testing.T) {
|
|
t.Parallel()
|
|
s := openTestStore(t)
|
|
ctx := context.Background()
|
|
|
|
hash := "tok-hash"
|
|
if err := s.CreateEnrollmentToken(ctx, hash, time.Hour, "", ""); err != nil {
|
|
t.Fatalf("create: %v", err)
|
|
}
|
|
|
|
// Need a host for FK.
|
|
_, err := s.DB().Exec(`INSERT INTO hosts (id, name, os, arch, enrolled_at) VALUES (?,?,?,?,?)`,
|
|
"h1", "host1", "linux", "amd64", time.Now().UTC().Format(time.RFC3339Nano))
|
|
if err != nil {
|
|
t.Fatalf("insert host: %v", err)
|
|
}
|
|
|
|
if err := s.ConsumeEnrollmentToken(ctx, hash, "h1"); err != nil {
|
|
t.Fatalf("consume: %v", err)
|
|
}
|
|
// Second consume must fail — the whole point of one-time tokens.
|
|
if err := s.ConsumeEnrollmentToken(ctx, hash, "h1"); !errors.Is(err, ErrNotFound) {
|
|
t.Errorf("re-consume: want ErrNotFound, got %v", err)
|
|
}
|
|
}
|