104 lines
3.2 KiB
Go
104 lines
3.2 KiB
Go
package store
|
|
|
|
import (
|
|
"context"
|
|
"database/sql"
|
|
"errors"
|
|
"fmt"
|
|
"time"
|
|
)
|
|
|
|
// CreateSession persists a session row. The token is hashed before
|
|
// insert; the raw token is what the caller hands to the user (cookie).
|
|
func (s *Store) CreateSession(ctx context.Context, sess Session, tokenHash string) error {
|
|
_, err := s.db.ExecContext(ctx,
|
|
`INSERT INTO sessions (id, user_id, created_at, expires_at, ip, ua)
|
|
VALUES (?, ?, ?, ?, ?, ?)`,
|
|
tokenHash,
|
|
sess.UserID,
|
|
sess.CreatedAt.UTC().Format(time.RFC3339Nano),
|
|
sess.ExpiresAt.UTC().Format(time.RFC3339Nano),
|
|
sess.IP, sess.UA)
|
|
if err != nil {
|
|
return fmt.Errorf("store: create session: %w", err)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// LookupSession resolves a token hash to a session row, returning
|
|
// ErrNotFound if the hash is unknown OR the session has expired.
|
|
// We collapse "no row" and "expired" to the same error so the caller
|
|
// can't tell them apart in error messages — that prevents enumeration
|
|
// of valid token hashes.
|
|
func (s *Store) LookupSession(ctx context.Context, tokenHash string) (*Session, error) {
|
|
row := s.db.QueryRowContext(ctx,
|
|
`SELECT id, user_id, created_at, expires_at, ip, ua
|
|
FROM sessions
|
|
WHERE id = ? AND expires_at > ?`,
|
|
tokenHash, time.Now().UTC().Format(time.RFC3339Nano))
|
|
|
|
var sess Session
|
|
var created, expires string
|
|
var ip, ua sql.NullString
|
|
if err := row.Scan(&sess.ID, &sess.UserID, &created, &expires, &ip, &ua); err != nil {
|
|
if errors.Is(err, sql.ErrNoRows) {
|
|
return nil, ErrNotFound
|
|
}
|
|
return nil, fmt.Errorf("store: lookup session: %w", err)
|
|
}
|
|
t, err := time.Parse(time.RFC3339Nano, created)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("store: parse created_at: %w", err)
|
|
}
|
|
sess.CreatedAt = t
|
|
t, err = time.Parse(time.RFC3339Nano, expires)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("store: parse expires_at: %w", err)
|
|
}
|
|
sess.ExpiresAt = t
|
|
if ip.Valid {
|
|
sess.IP = ip.String
|
|
}
|
|
if ua.Valid {
|
|
sess.UA = ua.String
|
|
}
|
|
return &sess, nil
|
|
}
|
|
|
|
// DeleteSession removes a session row by token hash. Used on logout.
|
|
func (s *Store) DeleteSession(ctx context.Context, tokenHash string) error {
|
|
_, err := s.db.ExecContext(ctx, `DELETE FROM sessions WHERE id = ?`, tokenHash)
|
|
if err != nil {
|
|
return fmt.Errorf("store: delete session: %w", err)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// PurgeExpiredSessions deletes session rows past their expires_at.
|
|
// Run periodically from a background goroutine.
|
|
func (s *Store) PurgeExpiredSessions(ctx context.Context) (int64, error) {
|
|
res, err := s.db.ExecContext(ctx,
|
|
`DELETE FROM sessions WHERE expires_at <= ?`,
|
|
time.Now().UTC().Format(time.RFC3339Nano))
|
|
if err != nil {
|
|
return 0, fmt.Errorf("store: purge sessions: %w", err)
|
|
}
|
|
n, _ := res.RowsAffected()
|
|
return n, nil
|
|
}
|
|
|
|
// DeleteSessionsByUserID removes every session row owned by the
|
|
// user. Returns count for caller logging. Used by:
|
|
// - admin "Force logout" button
|
|
// - admin Disable user (sessions outlive the disable flag, so we
|
|
// also clear them so the user gets bounced immediately)
|
|
func (s *Store) DeleteSessionsByUserID(ctx context.Context, userID string) (int64, error) {
|
|
res, err := s.db.ExecContext(ctx,
|
|
`DELETE FROM sessions WHERE user_id = ?`, userID)
|
|
if err != nil {
|
|
return 0, fmt.Errorf("store: delete sessions by user: %w", err)
|
|
}
|
|
n, _ := res.RowsAffected()
|
|
return n, nil
|
|
}
|