feat: implement project access enforcement and management API
All checks were successful
ci/woodpecker/push/woodpecker Pipeline was successful
All checks were successful
ci/woodpecker/push/woodpecker Pipeline was successful
- Fix no-op RequireProjectAccess middleware to enforce project_ids
- Apply project access middleware to all project-scoped routes
- Filter GET /projects by allowed project IDs for restricted keys
- Add GET /me endpoint with key identity, scopes, and project access info
- Add PATCH /keys/{id} for partial key updates (name, scopes, project_ids, allowed_ips, expires_in)
- Add GET/POST/DELETE /projects/{id}/access for project-centric access management
- Auto-grant creating key access when using POST /project/create-and-build
- Accept grant_to_key_ids in create-and-build to grant multiple keys on project creation
- Move newProvisionerWithDeps test helper from production code to test file
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
parent
0f25bd8dbe
commit
4f01015132
@ -88,11 +88,10 @@ type InfraConfig struct {
|
||||
GCSCredentialsPath string // Path to service account JSON (empty = ADC)
|
||||
GCSLocation string // Bucket location (default: "US")
|
||||
|
||||
// Notify provisioner (for project email delivery)
|
||||
// Notify provisioner (for per-project email delivery)
|
||||
NotifyURL string // e.g., "https://notify.orchard9.ai"
|
||||
NotifyAdminKey string // notify_admin_... admin API key
|
||||
NotifyHost string // shared host (e.g., "threesix.ai")
|
||||
NotifyFrom string // from-address (e.g., "noreply@threesix.ai")
|
||||
ResendAPIKey string // re_... Resend API key for per-project domain provisioning
|
||||
}
|
||||
|
||||
func loadConfig() Config {
|
||||
@ -155,6 +154,7 @@ func loadInfraConfig(ctx context.Context, store port.CredentialStore, cfg Config
|
||||
domain.CredKeyRegistryURL,
|
||||
domain.CredKeyNotifyURL,
|
||||
domain.CredKeyNotifyAdminKey,
|
||||
domain.CredKeyResendAPIKey,
|
||||
})
|
||||
if err != nil {
|
||||
logger.Warn("failed to load credentials from store, using env vars", "error", err)
|
||||
@ -201,8 +201,7 @@ func loadInfraConfig(ctx context.Context, store port.CredentialStore, cfg Config
|
||||
// Notify provisioner (credential store with env fallback)
|
||||
NotifyURL: getOrFallback(domain.CredKeyNotifyURL, os.Getenv("NOTIFY_URL")),
|
||||
NotifyAdminKey: getOrFallback(domain.CredKeyNotifyAdminKey, os.Getenv("NOTIFY_ADMIN_KEY")),
|
||||
NotifyHost: envutil.GetEnv("NOTIFY_HOST", "threesix.ai"),
|
||||
NotifyFrom: envutil.GetEnv("NOTIFY_FROM", "noreply@threesix.ai"),
|
||||
ResendAPIKey: getOrFallback(domain.CredKeyResendAPIKey, os.Getenv("RESEND_API_KEY")),
|
||||
}
|
||||
|
||||
// Log which credentials were loaded from store vs env
|
||||
|
||||
@ -241,20 +241,20 @@ func main() {
|
||||
}
|
||||
defer closeProvisioner(storageProvisioner, "gcs", logger)
|
||||
|
||||
// Initialize notify provisioner (optional - for project email delivery)
|
||||
// Initialize notify provisioner (optional - for per-project email delivery)
|
||||
var notifyProvisioner port.NotifyProvisioner
|
||||
if infraCfg.NotifyURL != "" && infraCfg.NotifyAdminKey != "" {
|
||||
np := notifyadapter.NewProvisioner(notifyadapter.Config{
|
||||
BaseURL: infraCfg.NotifyURL,
|
||||
AdminKey: infraCfg.NotifyAdminKey,
|
||||
Host: infraCfg.NotifyHost,
|
||||
From: infraCfg.NotifyFrom,
|
||||
}, logger)
|
||||
ResendAPIKey: infraCfg.ResendAPIKey,
|
||||
BaseDomain: infraCfg.DefaultDomain,
|
||||
}, dnsClient, logger)
|
||||
if err := np.TestConnection(context.Background()); err != nil {
|
||||
logger.Warn("notify provisioner connection test failed, disabling", "error", err)
|
||||
} else {
|
||||
notifyProvisioner = np
|
||||
logger.Info("notify provisioner initialized", "url", infraCfg.NotifyURL, "host", infraCfg.NotifyHost)
|
||||
logger.Info("notify provisioner initialized", "url", infraCfg.NotifyURL)
|
||||
}
|
||||
}
|
||||
|
||||
@ -453,6 +453,8 @@ func main() {
|
||||
|
||||
// Initialize handlers
|
||||
projectsHandler := handlers.NewProjectsHandlerWithService(projectService)
|
||||
meHandler := handlers.NewMeHandler(authService, projectService)
|
||||
projectAccessHandler := handlers.NewProjectAccessHandler(authService)
|
||||
keysHandler := handlers.NewKeysHandler(authService)
|
||||
claudeConfigHandler := handlers.NewClaudeConfigHandlerWithService(projectService, projectRepo, k8sExecutor)
|
||||
auditHandler := handlers.NewAuditHandler(auditLogger)
|
||||
@ -574,7 +576,8 @@ func main() {
|
||||
// Initialize worker pool handlers
|
||||
workersHandler := handlers.NewWorkersHandler(workerService).WithWorkService(workService).WithWorkQueue(workQueueRepo)
|
||||
buildsHandler := handlers.NewBuildsHandler(buildService)
|
||||
createAndBuildHandler := handlers.NewCreateAndBuildHandler(projectInfraService, buildService)
|
||||
createAndBuildHandler := handlers.NewCreateAndBuildHandler(projectInfraService, buildService).
|
||||
WithAuthService(authService)
|
||||
|
||||
sdlcHandler := handlers.NewSDLCHandler(sdlcService)
|
||||
sdlcOrchestratorHandler := handlers.NewSDLCOrchestratorHandler(sdlcOrchestrator)
|
||||
@ -660,6 +663,8 @@ func main() {
|
||||
|
||||
// Register routes
|
||||
projectsHandler.Mount(app.Router())
|
||||
meHandler.Mount(app.Router())
|
||||
projectAccessHandler.Mount(app.Router())
|
||||
keysHandler.Mount(app.Router())
|
||||
claudeConfigHandler.Mount(app.Router())
|
||||
auditHandler.Mount(app.Router())
|
||||
|
||||
@ -146,6 +146,16 @@ spec:
|
||||
secretKeyRef:
|
||||
name: redis-credentials
|
||||
key: REDIS_PASSWORD
|
||||
# Citadel logging integration (environment provisioning + audit shipping)
|
||||
- name: CITADEL_URL
|
||||
value: "http://citadel-community.citadel.svc.cluster.local"
|
||||
- name: CITADEL_API_KEY
|
||||
valueFrom:
|
||||
secretKeyRef:
|
||||
name: rdev-credentials
|
||||
key: CITADEL_API_KEY
|
||||
- name: CITADEL_PLATFORM_TENANT_ID
|
||||
value: "bf874fbf-6150-4aa9-b7bc-db531791bde1"
|
||||
# OpenTelemetry
|
||||
- name: OTEL_EXPORTER_OTLP_ENDPOINT
|
||||
value: "otel-collector.observability.svc.cluster.local:4317"
|
||||
@ -256,7 +266,7 @@ roleRef:
|
||||
name: rdev-api-deployer
|
||||
apiGroup: rbac.authorization.k8s.io
|
||||
---
|
||||
# Ingress for rdev-api
|
||||
# Ingress for rdev-api (masq-ops subdomain, DNS-01 via Cloudflare)
|
||||
apiVersion: networking.k8s.io/v1
|
||||
kind: Ingress
|
||||
metadata:
|
||||
@ -283,3 +293,32 @@ spec:
|
||||
- hosts:
|
||||
- rdev.masq-ops.orchard9.ai
|
||||
secretName: rdev-api-tls
|
||||
---
|
||||
# Ingress for rdev-api (orchard9.ai vanity domain, HTTP-01 via letsencrypt-prod-http01)
|
||||
# orchard9.ai is on GoDaddy; Cloudflare token only covers threesix.ai
|
||||
apiVersion: networking.k8s.io/v1
|
||||
kind: Ingress
|
||||
metadata:
|
||||
name: rdev-api-orchard9
|
||||
namespace: rdev
|
||||
annotations:
|
||||
cert-manager.io/cluster-issuer: letsencrypt-prod-http01
|
||||
traefik.ingress.kubernetes.io/router.entrypoints: websecure
|
||||
traefik.ingress.kubernetes.io/router.tls: "true"
|
||||
spec:
|
||||
ingressClassName: traefik
|
||||
rules:
|
||||
- host: rdev.orchard9.ai
|
||||
http:
|
||||
paths:
|
||||
- backend:
|
||||
service:
|
||||
name: rdev-api
|
||||
port:
|
||||
number: 8080
|
||||
path: /
|
||||
pathType: Prefix
|
||||
tls:
|
||||
- hosts:
|
||||
- rdev.orchard9.ai
|
||||
secretName: rdev-orchard9-tls
|
||||
|
||||
@ -122,6 +122,55 @@ func (r *APIKeyRepository) UpdateLastUsed(ctx context.Context, id domain.APIKeyI
|
||||
return nil
|
||||
}
|
||||
|
||||
// Update applies a partial update to an API key.
|
||||
func (r *APIKeyRepository) Update(ctx context.Context, id domain.APIKeyID, update port.APIKeyUpdate) error {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
|
||||
key, ok := r.keys[id]
|
||||
if !ok || key.RevokedAt != nil {
|
||||
return domain.ErrKeyNotFound
|
||||
}
|
||||
|
||||
if update.Name != nil {
|
||||
key.Name = *update.Name
|
||||
}
|
||||
if update.Scopes != nil {
|
||||
key.Scopes = update.Scopes
|
||||
}
|
||||
if update.ProjectIDs != nil {
|
||||
key.ProjectIDs = *update.ProjectIDs
|
||||
}
|
||||
if update.AllowedIPs != nil {
|
||||
key.AllowedIPs = *update.AllowedIPs
|
||||
}
|
||||
if update.ExpiresAt != nil {
|
||||
key.ExpiresAt = *update.ExpiresAt
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ListByProjectID returns all active keys that have the given project ID in their project_ids.
|
||||
func (r *APIKeyRepository) ListByProjectID(ctx context.Context, projectID domain.ProjectID) ([]*domain.APIKey, error) {
|
||||
r.mu.RLock()
|
||||
defer r.mu.RUnlock()
|
||||
|
||||
var result []*domain.APIKey
|
||||
for _, key := range r.keys {
|
||||
if key.RevokedAt != nil {
|
||||
continue
|
||||
}
|
||||
for _, pid := range key.ProjectIDs {
|
||||
if pid == projectID {
|
||||
result = append(result, key)
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// itoa converts an integer to a string.
|
||||
func itoa(i int) string {
|
||||
if i == 0 {
|
||||
|
||||
@ -12,6 +12,20 @@ import (
|
||||
"time"
|
||||
)
|
||||
|
||||
// notifyAdminAPI is the interface Provisioner uses to call the notify admin API.
|
||||
// Extracted for testability.
|
||||
type notifyAdminAPI interface {
|
||||
createHost(ctx context.Context, hostSlug, strategy string) error
|
||||
deleteHost(ctx context.Context, hostSlug string) error
|
||||
createProvider(ctx context.Context, hostSlug, provider string, config map[string]string, priority, retryAttempts, retryBackoffMs int) error
|
||||
createFromAddress(ctx context.Context, hostSlug, email, displayName string) error
|
||||
createAccount(ctx context.Context, name string) (*accountResponse, error)
|
||||
createSendKey(ctx context.Context, accountID, name string) (*apiKeyResponse, error)
|
||||
grantHostAccess(ctx context.Context, hostSlug, accountID string) error
|
||||
deleteAccount(ctx context.Context, accountID string) error
|
||||
listAccounts(ctx context.Context) ([]accountResponse, error)
|
||||
}
|
||||
|
||||
// adminClient calls the notify admin API to manage accounts and keys.
|
||||
type adminClient struct {
|
||||
baseURL string
|
||||
@ -87,6 +101,51 @@ func (c *adminClient) createSendKey(ctx context.Context, accountID, name string)
|
||||
return &key, nil
|
||||
}
|
||||
|
||||
// createHost creates a new notify sending host with the given slug and sending strategy.
|
||||
func (c *adminClient) createHost(ctx context.Context, hostSlug, strategy string) error {
|
||||
payload := map[string]string{"host": hostSlug, "strategy": strategy}
|
||||
_, err := c.doRequest(ctx, http.MethodPost, "/admin/hosts", payload)
|
||||
if err != nil {
|
||||
return fmt.Errorf("create host %s: %w", hostSlug, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// deleteHost removes a notify host by its slug.
|
||||
func (c *adminClient) deleteHost(ctx context.Context, hostSlug string) error {
|
||||
_, err := c.doRequest(ctx, http.MethodDelete, "/admin/hosts/"+hostSlug, nil)
|
||||
if err != nil {
|
||||
return fmt.Errorf("delete host %s: %w", hostSlug, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// createProvider adds a sending provider to an existing host.
|
||||
func (c *adminClient) createProvider(ctx context.Context, hostSlug, provider string, config map[string]string, priority, retryAttempts, retryBackoffMs int) error {
|
||||
payload := map[string]any{
|
||||
"provider": provider,
|
||||
"config": config,
|
||||
"priority": priority,
|
||||
"retry_attempts": retryAttempts,
|
||||
"retry_backoff_ms": retryBackoffMs,
|
||||
}
|
||||
_, err := c.doRequest(ctx, http.MethodPost, "/admin/hosts/"+hostSlug+"/providers", payload)
|
||||
if err != nil {
|
||||
return fmt.Errorf("create provider %s on host %s: %w", provider, hostSlug, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// createFromAddress registers a from-address on a host.
|
||||
func (c *adminClient) createFromAddress(ctx context.Context, hostSlug, email, displayName string) error {
|
||||
payload := map[string]string{"email": email, "display_name": displayName}
|
||||
_, err := c.doRequest(ctx, http.MethodPost, "/admin/hosts/"+hostSlug+"/from-addresses", payload)
|
||||
if err != nil {
|
||||
return fmt.Errorf("create from-address %s on host %s: %w", email, hostSlug, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// grantHostAccess grants the given account access to send from the specified host slug.
|
||||
func (c *adminClient) grantHostAccess(ctx context.Context, hostSlug, accountID string) error {
|
||||
payload := map[string]string{"account_id": accountID}
|
||||
|
||||
@ -7,15 +7,18 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/orchard9/rdev/internal/domain"
|
||||
"github.com/orchard9/rdev/internal/port"
|
||||
)
|
||||
|
||||
// Provisioner implements port.NotifyProvisioner using the notify admin API.
|
||||
// Each project gets an isolated notify account and send key scoped to the
|
||||
// shared sending host (e.g., "threesix.ai").
|
||||
// Each project gets an isolated sending host (mail.{slug}.{baseDomain}),
|
||||
// a Resend domain with DKIM/SPF DNS records, and a dedicated send key.
|
||||
type Provisioner struct {
|
||||
client *adminClient
|
||||
host string // shared sending host slug (e.g., "threesix.ai")
|
||||
from string // from-address (e.g., "noreply@threesix.ai")
|
||||
client notifyAdminAPI
|
||||
resend resendAPI // nil when ResendAPIKey not configured
|
||||
resendAPIKey string // passed to createProvider; kept separate from resend for interface compatibility
|
||||
dns port.DNSProvider // nil when Cloudflare not configured
|
||||
baseDomain string // e.g., "threesix.ai"
|
||||
logger *slog.Logger
|
||||
}
|
||||
|
||||
@ -23,94 +26,244 @@ type Provisioner struct {
|
||||
type Config struct {
|
||||
BaseURL string // Required: notify service URL (e.g., "https://notify.orchard9.ai")
|
||||
AdminKey string // Required: admin API key (notify_admin_...)
|
||||
Host string // Shared host slug for all projects (e.g., "threesix.ai")
|
||||
From string // Default from-address (e.g., "noreply@threesix.ai")
|
||||
ResendAPIKey string // Optional: Resend API key for per-project domain provisioning
|
||||
BaseDomain string // Base domain for per-project hosts (default: "threesix.ai")
|
||||
}
|
||||
|
||||
// NewProvisioner creates a new notify provisioner.
|
||||
func NewProvisioner(cfg Config, logger *slog.Logger) *Provisioner {
|
||||
host := cfg.Host
|
||||
if host == "" {
|
||||
host = "threesix.ai"
|
||||
func NewProvisioner(cfg Config, dns port.DNSProvider, logger *slog.Logger) *Provisioner {
|
||||
baseDomain := cfg.BaseDomain
|
||||
if baseDomain == "" {
|
||||
baseDomain = "threesix.ai"
|
||||
}
|
||||
from := cfg.From
|
||||
if from == "" {
|
||||
from = "noreply@threesix.ai"
|
||||
}
|
||||
return &Provisioner{
|
||||
p := &Provisioner{
|
||||
client: newAdminClient(cfg.BaseURL, cfg.AdminKey),
|
||||
host: host,
|
||||
from: from,
|
||||
dns: dns,
|
||||
baseDomain: baseDomain,
|
||||
logger: logger,
|
||||
}
|
||||
if cfg.ResendAPIKey != "" {
|
||||
p.resend = newResendClient(cfg.ResendAPIKey)
|
||||
p.resendAPIKey = cfg.ResendAPIKey
|
||||
}
|
||||
return p
|
||||
}
|
||||
|
||||
// CreateProjectNotify provisions a notify account and send key for the project.
|
||||
// CreateProjectNotify provisions a per-project notify host, Resend domain, DNS records,
|
||||
// and notify account with send key.
|
||||
//
|
||||
// Steps:
|
||||
// 1. Create account named "project-{projectID}"
|
||||
// 2. Create send API key via POST /admin/api-keys
|
||||
// 3. Grant account access to the shared host
|
||||
func (p *Provisioner) CreateProjectNotify(ctx context.Context, projectID string) (*domain.NotifyCredentials, error) {
|
||||
// 1. Create notify host mail.{slug}.{baseDomain}
|
||||
// 2. Add Resend provider to the host (skipped if ResendAPIKey not configured)
|
||||
// 3. Register from-address noreply@mail.{slug}.{baseDomain}
|
||||
// 4. Create notify account "project-{projectID}"
|
||||
// 5. Create send key for the account
|
||||
// 6. Grant the account access to the host (non-fatal)
|
||||
// 7. Create Resend domain (non-fatal — skipped if ResendAPIKey not configured)
|
||||
// 8. Add DNS records via Cloudflare (non-fatal — skipped if DNS not configured)
|
||||
// 9. Fire-and-forget async domain verification
|
||||
func (p *Provisioner) CreateProjectNotify(ctx context.Context, projectID, slug string) (*domain.NotifyCredentials, error) {
|
||||
host := "mail." + slug + "." + p.baseDomain
|
||||
from := "noreply@" + host
|
||||
accountName := "project-" + projectID
|
||||
|
||||
// 1. Create account
|
||||
// 1. Create notify host
|
||||
if err := p.client.createHost(ctx, host, "failover"); err != nil {
|
||||
return nil, fmt.Errorf("notify: create host %s for project %s: %w", host, projectID, err)
|
||||
}
|
||||
|
||||
// 2. Add Resend provider to the host (only when Resend is configured)
|
||||
if p.resend != nil {
|
||||
if err := p.client.createProvider(ctx, host, "resend", map[string]string{"api_key": p.resendAPIKey}, 1, 3, 1000); err != nil {
|
||||
p.bestEffortDeleteHost(ctx, host, projectID)
|
||||
return nil, fmt.Errorf("notify: create provider on host %s for project %s: %w", host, projectID, err)
|
||||
}
|
||||
}
|
||||
|
||||
// 3. Register from-address
|
||||
if err := p.client.createFromAddress(ctx, host, from, slug); err != nil {
|
||||
p.bestEffortDeleteHost(ctx, host, projectID)
|
||||
return nil, fmt.Errorf("notify: create from-address %s for project %s: %w", from, projectID, err)
|
||||
}
|
||||
|
||||
// 4. Create account
|
||||
acct, err := p.client.createAccount(ctx, accountName)
|
||||
if err != nil {
|
||||
p.bestEffortDeleteHost(ctx, host, projectID)
|
||||
return nil, fmt.Errorf("notify: create account for project %s: %w", projectID, err)
|
||||
}
|
||||
|
||||
// 2. Create send key (plaintext key only returned here)
|
||||
// 5. Create send key
|
||||
key, err := p.client.createSendKey(ctx, acct.ID, accountName+"-send")
|
||||
if err != nil {
|
||||
// Best-effort cleanup
|
||||
if delErr := p.client.deleteAccount(ctx, acct.ID); delErr != nil {
|
||||
p.logger.Warn("failed to clean up notify account after key creation failure",
|
||||
"account_id", acct.ID,
|
||||
"project_id", projectID,
|
||||
"error", delErr,
|
||||
)
|
||||
}
|
||||
p.bestEffortDeleteAccount(ctx, acct.ID, projectID)
|
||||
p.bestEffortDeleteHost(ctx, host, projectID)
|
||||
return nil, fmt.Errorf("notify: create send key for project %s: %w", projectID, err)
|
||||
}
|
||||
|
||||
// 3. Grant host access
|
||||
if err := p.client.grantHostAccess(ctx, p.host, acct.ID); err != nil {
|
||||
// 6. Grant host access (non-fatal — log warn and continue)
|
||||
if err := p.client.grantHostAccess(ctx, host, acct.ID); err != nil {
|
||||
p.logger.Warn("failed to grant notify host access",
|
||||
"host", p.host,
|
||||
"host", host,
|
||||
"account_id", acct.ID,
|
||||
"project_id", projectID,
|
||||
"error", err,
|
||||
)
|
||||
}
|
||||
|
||||
// 7. Create Resend domain (non-fatal — project still usable, email won't send until fixed)
|
||||
var resendDomainID string
|
||||
var dnsRecords []resendDNSRecord
|
||||
if p.resend != nil {
|
||||
var resendErr error
|
||||
resendDomainID, dnsRecords, resendErr = p.resend.createDomain(ctx, host, "us-east-1")
|
||||
if resendErr != nil {
|
||||
p.logger.Warn("failed to create resend domain — email delivery will not work until resolved",
|
||||
"host", host,
|
||||
"project_id", projectID,
|
||||
"error", resendErr,
|
||||
)
|
||||
} else {
|
||||
p.logger.Info("resend domain created", "host", host, "domain_id", resendDomainID)
|
||||
}
|
||||
}
|
||||
|
||||
// 8. Add DNS records for DKIM/SPF (non-fatal).
|
||||
// Resend returns record names relative to the registered domain; build FQDNs for Cloudflare.
|
||||
// Cloudflare's normalizeName handles FQDNs ending in the zone name correctly.
|
||||
if p.dns != nil && len(dnsRecords) > 0 {
|
||||
for _, rec := range dnsRecords {
|
||||
fqdn := rec.Name + "." + host
|
||||
dnsRec := domain.DNSRecord{
|
||||
Type: rec.Record,
|
||||
Name: fqdn,
|
||||
Content: rec.Value,
|
||||
TTL: 1,
|
||||
}
|
||||
if _, upsertErr := p.dns.UpsertRecord(ctx, dnsRec); upsertErr != nil {
|
||||
p.logger.Warn("failed to upsert notify DNS record",
|
||||
"name", fqdn,
|
||||
"record", rec.Record,
|
||||
"project_id", projectID,
|
||||
"error", upsertErr,
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 9. Fire-and-forget async domain verification
|
||||
if p.resend != nil && resendDomainID != "" {
|
||||
go func() {
|
||||
verifyCtx := context.WithoutCancel(ctx)
|
||||
if err := p.resend.verifyDomain(verifyCtx, resendDomainID); err != nil {
|
||||
p.logger.Warn("async resend domain verification failed",
|
||||
"domain_id", resendDomainID,
|
||||
"host", host,
|
||||
"error", err,
|
||||
)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
p.logger.Info("notify provisioned",
|
||||
"project_id", projectID,
|
||||
"host", host,
|
||||
"resend_domain_id", resendDomainID,
|
||||
)
|
||||
|
||||
return &domain.NotifyCredentials{
|
||||
ProjectID: projectID,
|
||||
AccountID: acct.ID,
|
||||
APIKey: key.Key,
|
||||
Host: p.host,
|
||||
From: p.from,
|
||||
Host: host,
|
||||
From: from,
|
||||
ResendDomainID: resendDomainID,
|
||||
CreatedAt: time.Now(),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// DeleteProjectNotify removes the notify account for the project.
|
||||
func (p *Provisioner) DeleteProjectNotify(ctx context.Context, projectID string) error {
|
||||
// DeleteProjectNotify removes all notify resources for a project.
|
||||
// Failures are logged as warnings — cleanup continues regardless.
|
||||
func (p *Provisioner) DeleteProjectNotify(ctx context.Context, projectID, slug, resendDomainID string) error {
|
||||
host := "mail." + slug + "." + p.baseDomain
|
||||
|
||||
// 1. Delete notify account (cascades keys + host grants)
|
||||
acct, err := p.findAccountByProject(ctx, projectID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("notify: find account for project %s: %w", projectID, err)
|
||||
p.logger.Warn("failed to find notify account during deletion",
|
||||
"project_id", projectID,
|
||||
"error", err,
|
||||
)
|
||||
} else if acct != nil {
|
||||
if err := p.client.deleteAccount(ctx, acct.ID); err != nil {
|
||||
p.logger.Warn("failed to delete notify account",
|
||||
"account_id", acct.ID,
|
||||
"project_id", projectID,
|
||||
"error", err,
|
||||
)
|
||||
}
|
||||
if acct == nil {
|
||||
return nil // Already deleted or never provisioned
|
||||
}
|
||||
|
||||
if err := p.client.deleteAccount(ctx, acct.ID); err != nil {
|
||||
return fmt.Errorf("notify: delete account %s for project %s: %w", acct.ID, projectID, err)
|
||||
// 2. Delete notify host
|
||||
if err := p.client.deleteHost(ctx, host); err != nil {
|
||||
p.logger.Warn("failed to delete notify host",
|
||||
"host", host,
|
||||
"project_id", projectID,
|
||||
"error", err,
|
||||
)
|
||||
}
|
||||
|
||||
// 3. Delete Resend domain
|
||||
if p.resend != nil && resendDomainID != "" {
|
||||
if err := p.resend.deleteDomain(ctx, resendDomainID); err != nil {
|
||||
p.logger.Warn("failed to delete resend domain",
|
||||
"domain_id", resendDomainID,
|
||||
"project_id", projectID,
|
||||
"error", err,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
// 4. Delete Cloudflare DNS records for DKIM/SPF.
|
||||
// Names follow Resend's standard format:
|
||||
// DKIM: resend._domainkey.{host}
|
||||
// SPF MX: send.{host}
|
||||
// SPF TXT: send.{host}
|
||||
// If Resend changes their record naming, manual cleanup may be needed.
|
||||
if p.dns != nil {
|
||||
dkimName := "resend._domainkey." + host
|
||||
if err := p.dns.DeleteRecordByName(ctx, "TXT", dkimName); err != nil {
|
||||
p.logger.Warn("failed to delete DKIM DNS record",
|
||||
"name", dkimName,
|
||||
"project_id", projectID,
|
||||
"error", err,
|
||||
)
|
||||
}
|
||||
spfSendName := "send." + host
|
||||
if err := p.dns.DeleteRecordByName(ctx, "MX", spfSendName); err != nil {
|
||||
p.logger.Warn("failed to delete SPF MX DNS record",
|
||||
"name", spfSendName,
|
||||
"project_id", projectID,
|
||||
"error", err,
|
||||
)
|
||||
}
|
||||
if err := p.dns.DeleteRecordByName(ctx, "TXT", spfSendName); err != nil {
|
||||
p.logger.Warn("failed to delete SPF TXT DNS record",
|
||||
"name", spfSendName,
|
||||
"project_id", projectID,
|
||||
"error", err,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
p.logger.Info("notify resources deleted", "project_id", projectID, "host", host)
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetProjectNotify returns notify credentials for the project, or nil if not provisioned.
|
||||
// Note: APIKey cannot be retrieved after creation — returns empty string.
|
||||
// Note: Only AccountID and CreatedAt are populated — APIKey, Host, and From are not
|
||||
// recoverable after provisioning. Use this method solely to check whether provisioning
|
||||
// has already occurred (non-nil return = already provisioned).
|
||||
func (p *Provisioner) GetProjectNotify(ctx context.Context, projectID string) (*domain.NotifyCredentials, error) {
|
||||
acct, err := p.findAccountByProject(ctx, projectID)
|
||||
if err != nil {
|
||||
@ -123,8 +276,6 @@ func (p *Provisioner) GetProjectNotify(ctx context.Context, projectID string) (*
|
||||
return &domain.NotifyCredentials{
|
||||
ProjectID: projectID,
|
||||
AccountID: acct.ID,
|
||||
Host: p.host,
|
||||
From: p.from,
|
||||
CreatedAt: acct.CreatedAt,
|
||||
}, nil
|
||||
}
|
||||
@ -153,3 +304,25 @@ func (p *Provisioner) findAccountByProject(ctx context.Context, projectID string
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// bestEffortDeleteHost deletes the notify host, logging on failure.
|
||||
func (p *Provisioner) bestEffortDeleteHost(ctx context.Context, host, projectID string) {
|
||||
if err := p.client.deleteHost(ctx, host); err != nil {
|
||||
p.logger.Warn("failed to clean up notify host after provisioning failure",
|
||||
"host", host,
|
||||
"project_id", projectID,
|
||||
"error", err,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
// bestEffortDeleteAccount deletes the notify account, logging on failure.
|
||||
func (p *Provisioner) bestEffortDeleteAccount(ctx context.Context, accountID, projectID string) {
|
||||
if err := p.client.deleteAccount(ctx, accountID); err != nil {
|
||||
p.logger.Warn("failed to clean up notify account after provisioning failure",
|
||||
"account_id", accountID,
|
||||
"project_id", projectID,
|
||||
"error", err,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
511
internal/adapter/notify/provisioner_test.go
Normal file
511
internal/adapter/notify/provisioner_test.go
Normal file
@ -0,0 +1,511 @@
|
||||
package notify
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"log/slog"
|
||||
"os"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/orchard9/rdev/internal/domain"
|
||||
"github.com/orchard9/rdev/internal/port"
|
||||
)
|
||||
|
||||
// --- mock implementations ---
|
||||
|
||||
// mockAdminClient is a controllable implementation of notifyAdminAPI.
|
||||
type mockAdminClient struct {
|
||||
// accounts simulates the account registry
|
||||
accounts []accountResponse
|
||||
|
||||
// Configurable errors per operation
|
||||
createHostErr error
|
||||
deleteHostErr error
|
||||
createProviderErr error
|
||||
createFromAddressErr error
|
||||
createAccountErr error
|
||||
createSendKeyErr error
|
||||
grantHostAccessErr error
|
||||
deleteAccountErr error
|
||||
listAccountsErr error
|
||||
|
||||
// Call counters
|
||||
createHostCalls int
|
||||
deleteHostCalls int
|
||||
createProviderCalls int
|
||||
createFromAddressCalls int
|
||||
createAccountCalls int
|
||||
createSendKeyCalls int
|
||||
grantHostAccessCalls int
|
||||
deleteAccountCalls int
|
||||
}
|
||||
|
||||
func (m *mockAdminClient) createHost(_ context.Context, _, _ string) error {
|
||||
m.createHostCalls++
|
||||
return m.createHostErr
|
||||
}
|
||||
|
||||
func (m *mockAdminClient) deleteHost(_ context.Context, _ string) error {
|
||||
m.deleteHostCalls++
|
||||
return m.deleteHostErr
|
||||
}
|
||||
|
||||
func (m *mockAdminClient) createProvider(_ context.Context, _, _ string, _ map[string]string, _, _, _ int) error {
|
||||
m.createProviderCalls++
|
||||
return m.createProviderErr
|
||||
}
|
||||
|
||||
func (m *mockAdminClient) createFromAddress(_ context.Context, _, _, _ string) error {
|
||||
m.createFromAddressCalls++
|
||||
return m.createFromAddressErr
|
||||
}
|
||||
|
||||
func (m *mockAdminClient) createAccount(_ context.Context, name string) (*accountResponse, error) {
|
||||
m.createAccountCalls++
|
||||
if m.createAccountErr != nil {
|
||||
return nil, m.createAccountErr
|
||||
}
|
||||
acct := &accountResponse{ID: "acct-" + name, Name: name, CreatedAt: time.Now()}
|
||||
m.accounts = append(m.accounts, *acct)
|
||||
return acct, nil
|
||||
}
|
||||
|
||||
func (m *mockAdminClient) createSendKey(_ context.Context, accountID, name string) (*apiKeyResponse, error) {
|
||||
m.createSendKeyCalls++
|
||||
if m.createSendKeyErr != nil {
|
||||
return nil, m.createSendKeyErr
|
||||
}
|
||||
return &apiKeyResponse{
|
||||
ID: 1,
|
||||
Key: "notify_send_test_key",
|
||||
KeyType: "send",
|
||||
AccountID: accountID,
|
||||
Name: name,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (m *mockAdminClient) grantHostAccess(_ context.Context, _, _ string) error {
|
||||
m.grantHostAccessCalls++
|
||||
return m.grantHostAccessErr
|
||||
}
|
||||
|
||||
func (m *mockAdminClient) deleteAccount(_ context.Context, _ string) error {
|
||||
m.deleteAccountCalls++
|
||||
return m.deleteAccountErr
|
||||
}
|
||||
|
||||
func (m *mockAdminClient) listAccounts(_ context.Context) ([]accountResponse, error) {
|
||||
if m.listAccountsErr != nil {
|
||||
return nil, m.listAccountsErr
|
||||
}
|
||||
return m.accounts, nil
|
||||
}
|
||||
|
||||
// mockResendClient is a controllable implementation of resendAPI.
|
||||
type mockResendClient struct {
|
||||
createDomainErr error
|
||||
verifyDomainErr error
|
||||
deleteDomainErr error
|
||||
|
||||
createDomainCalls int
|
||||
verifyDomainCalls int
|
||||
deleteDomainCalls int
|
||||
|
||||
domainID string
|
||||
dnsRecords []resendDNSRecord
|
||||
}
|
||||
|
||||
func (m *mockResendClient) createDomain(_ context.Context, _, _ string) (string, []resendDNSRecord, error) {
|
||||
m.createDomainCalls++
|
||||
if m.createDomainErr != nil {
|
||||
return "", nil, m.createDomainErr
|
||||
}
|
||||
id := m.domainID
|
||||
if id == "" {
|
||||
id = "resend-domain-id-123"
|
||||
}
|
||||
return id, m.dnsRecords, nil
|
||||
}
|
||||
|
||||
func (m *mockResendClient) verifyDomain(_ context.Context, _ string) error {
|
||||
m.verifyDomainCalls++
|
||||
return m.verifyDomainErr
|
||||
}
|
||||
|
||||
func (m *mockResendClient) deleteDomain(_ context.Context, _ string) error {
|
||||
m.deleteDomainCalls++
|
||||
return m.deleteDomainErr
|
||||
}
|
||||
|
||||
// mockDNS is a controllable implementation of port.DNSProvider.
|
||||
type mockDNS struct {
|
||||
upsertErr error
|
||||
deleteByNameErr error
|
||||
upsertCalls []domain.DNSRecord
|
||||
deleteByNameCalls []struct{ recordType, name string }
|
||||
}
|
||||
|
||||
func (m *mockDNS) UpsertRecord(_ context.Context, record domain.DNSRecord) (*domain.DNSRecord, error) {
|
||||
m.upsertCalls = append(m.upsertCalls, record)
|
||||
if m.upsertErr != nil {
|
||||
return nil, m.upsertErr
|
||||
}
|
||||
return &record, nil
|
||||
}
|
||||
|
||||
func (m *mockDNS) DeleteRecordByName(_ context.Context, recordType, name string) error {
|
||||
m.deleteByNameCalls = append(m.deleteByNameCalls, struct{ recordType, name string }{recordType, name})
|
||||
return m.deleteByNameErr
|
||||
}
|
||||
|
||||
func (m *mockDNS) CreateRecord(_ context.Context, r domain.DNSRecord) (*domain.DNSRecord, error) {
|
||||
return &r, nil
|
||||
}
|
||||
func (m *mockDNS) UpdateRecord(_ context.Context, _ string, r domain.DNSRecord) (*domain.DNSRecord, error) {
|
||||
return &r, nil
|
||||
}
|
||||
func (m *mockDNS) DeleteRecord(_ context.Context, _ string) error { return nil }
|
||||
func (m *mockDNS) GetRecord(_ context.Context, _ string) (*domain.DNSRecord, error) { return nil, nil }
|
||||
func (m *mockDNS) ListRecords(_ context.Context, _ string) ([]*domain.DNSRecord, error) {
|
||||
return nil, nil
|
||||
}
|
||||
func (m *mockDNS) FindRecord(_ context.Context, _, _ string) (*domain.DNSRecord, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// --- helpers ---
|
||||
|
||||
func testLogger() *slog.Logger {
|
||||
return slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelError}))
|
||||
}
|
||||
|
||||
// newProvisionerWithDeps creates a Provisioner with injected dependencies for testing.
|
||||
func newProvisionerWithDeps(client notifyAdminAPI, resend resendAPI, resendAPIKey string, dns port.DNSProvider, baseDomain string, logger *slog.Logger) *Provisioner {
|
||||
if baseDomain == "" {
|
||||
baseDomain = "threesix.ai"
|
||||
}
|
||||
return &Provisioner{
|
||||
client: client,
|
||||
resend: resend,
|
||||
resendAPIKey: resendAPIKey,
|
||||
dns: dns,
|
||||
baseDomain: baseDomain,
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
func newTestProvisioner(admin *mockAdminClient, resend *mockResendClient, dns *mockDNS) *Provisioner {
|
||||
var r resendAPI
|
||||
if resend != nil {
|
||||
r = resend
|
||||
}
|
||||
var d interface {
|
||||
UpsertRecord(context.Context, domain.DNSRecord) (*domain.DNSRecord, error)
|
||||
DeleteRecordByName(context.Context, string, string) error
|
||||
CreateRecord(context.Context, domain.DNSRecord) (*domain.DNSRecord, error)
|
||||
UpdateRecord(context.Context, string, domain.DNSRecord) (*domain.DNSRecord, error)
|
||||
DeleteRecord(context.Context, string) error
|
||||
GetRecord(context.Context, string) (*domain.DNSRecord, error)
|
||||
ListRecords(context.Context, string) ([]*domain.DNSRecord, error)
|
||||
FindRecord(context.Context, string, string) (*domain.DNSRecord, error)
|
||||
}
|
||||
if dns != nil {
|
||||
d = dns
|
||||
}
|
||||
return newProvisionerWithDeps(admin, r, "re_test_key", d, "test.example", testLogger())
|
||||
}
|
||||
|
||||
// --- tests ---
|
||||
|
||||
func TestCreateProjectNotify_Success(t *testing.T) {
|
||||
admin := &mockAdminClient{}
|
||||
resend := &mockResendClient{
|
||||
dnsRecords: []resendDNSRecord{
|
||||
{Record: "TXT", Name: "resend._domainkey", Value: "v=DKIM1; p=..."},
|
||||
{Record: "MX", Name: "send", Value: "feedback-smtp.us-east-1.amazonses.com", Priority: 10},
|
||||
},
|
||||
}
|
||||
dns := &mockDNS{}
|
||||
p := newTestProvisioner(admin, resend, dns)
|
||||
|
||||
creds, err := p.CreateProjectNotify(context.Background(), "proj-123", "happy-fox")
|
||||
if err != nil {
|
||||
t.Fatalf("expected no error, got %v", err)
|
||||
}
|
||||
|
||||
if creds.Host != "mail.happy-fox.test.example" {
|
||||
t.Errorf("expected host mail.happy-fox.test.example, got %s", creds.Host)
|
||||
}
|
||||
if creds.From != "noreply@mail.happy-fox.test.example" {
|
||||
t.Errorf("expected from noreply@mail.happy-fox.test.example, got %s", creds.From)
|
||||
}
|
||||
if creds.APIKey != "notify_send_test_key" {
|
||||
t.Errorf("expected send key, got %s", creds.APIKey)
|
||||
}
|
||||
if creds.ResendDomainID != "resend-domain-id-123" {
|
||||
t.Errorf("expected resend domain id, got %s", creds.ResendDomainID)
|
||||
}
|
||||
if creds.ProjectID != "proj-123" {
|
||||
t.Errorf("expected project id proj-123, got %s", creds.ProjectID)
|
||||
}
|
||||
|
||||
// Verify all steps executed
|
||||
if admin.createHostCalls != 1 {
|
||||
t.Errorf("expected 1 createHost call, got %d", admin.createHostCalls)
|
||||
}
|
||||
if admin.createProviderCalls != 1 {
|
||||
t.Errorf("expected 1 createProvider call, got %d", admin.createProviderCalls)
|
||||
}
|
||||
if admin.createFromAddressCalls != 1 {
|
||||
t.Errorf("expected 1 createFromAddress call, got %d", admin.createFromAddressCalls)
|
||||
}
|
||||
if admin.createAccountCalls != 1 {
|
||||
t.Errorf("expected 1 createAccount call, got %d", admin.createAccountCalls)
|
||||
}
|
||||
if admin.createSendKeyCalls != 1 {
|
||||
t.Errorf("expected 1 createSendKey call, got %d", admin.createSendKeyCalls)
|
||||
}
|
||||
if resend.createDomainCalls != 1 {
|
||||
t.Errorf("expected 1 createDomain call, got %d", resend.createDomainCalls)
|
||||
}
|
||||
|
||||
// Verify 2 DNS records were upserted
|
||||
if len(dns.upsertCalls) != 2 {
|
||||
t.Errorf("expected 2 DNS upserts, got %d", len(dns.upsertCalls))
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateProjectNotify_RollsBackOnProviderFailure(t *testing.T) {
|
||||
admin := &mockAdminClient{
|
||||
createProviderErr: errors.New("provider setup failed"),
|
||||
}
|
||||
p := newTestProvisioner(admin, &mockResendClient{}, nil)
|
||||
|
||||
_, err := p.CreateProjectNotify(context.Background(), "proj-123", "happy-fox")
|
||||
if err == nil {
|
||||
t.Fatal("expected error, got nil")
|
||||
}
|
||||
|
||||
if admin.deleteHostCalls != 1 {
|
||||
t.Errorf("expected host rollback, got %d deleteHost calls", admin.deleteHostCalls)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateProjectNotify_RollsBackOnFromAddressFailure(t *testing.T) {
|
||||
admin := &mockAdminClient{
|
||||
createFromAddressErr: errors.New("from address failed"),
|
||||
}
|
||||
p := newTestProvisioner(admin, &mockResendClient{}, nil)
|
||||
|
||||
_, err := p.CreateProjectNotify(context.Background(), "proj-123", "happy-fox")
|
||||
if err == nil {
|
||||
t.Fatal("expected error, got nil")
|
||||
}
|
||||
|
||||
if admin.deleteHostCalls != 1 {
|
||||
t.Errorf("expected host rollback, got %d deleteHost calls", admin.deleteHostCalls)
|
||||
}
|
||||
if admin.deleteAccountCalls != 0 {
|
||||
t.Errorf("account not yet created, should not delete account")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateProjectNotify_RollsBackOnAccountFailure(t *testing.T) {
|
||||
admin := &mockAdminClient{
|
||||
createAccountErr: errors.New("account creation failed"),
|
||||
}
|
||||
p := newTestProvisioner(admin, &mockResendClient{}, nil)
|
||||
|
||||
_, err := p.CreateProjectNotify(context.Background(), "proj-123", "happy-fox")
|
||||
if err == nil {
|
||||
t.Fatal("expected error, got nil")
|
||||
}
|
||||
|
||||
if admin.deleteHostCalls != 1 {
|
||||
t.Errorf("expected host rollback, got %d deleteHost calls", admin.deleteHostCalls)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateProjectNotify_RollsBackOnSendKeyFailure(t *testing.T) {
|
||||
admin := &mockAdminClient{
|
||||
createSendKeyErr: errors.New("send key creation failed"),
|
||||
}
|
||||
p := newTestProvisioner(admin, &mockResendClient{}, nil)
|
||||
|
||||
_, err := p.CreateProjectNotify(context.Background(), "proj-123", "happy-fox")
|
||||
if err == nil {
|
||||
t.Fatal("expected error, got nil")
|
||||
}
|
||||
|
||||
if admin.deleteAccountCalls != 1 {
|
||||
t.Errorf("expected account rollback, got %d deleteAccount calls", admin.deleteAccountCalls)
|
||||
}
|
||||
if admin.deleteHostCalls != 1 {
|
||||
t.Errorf("expected host rollback, got %d deleteHost calls", admin.deleteHostCalls)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateProjectNotify_ResendFailureIsNonFatal(t *testing.T) {
|
||||
admin := &mockAdminClient{}
|
||||
resend := &mockResendClient{
|
||||
createDomainErr: errors.New("resend API down"),
|
||||
}
|
||||
p := newTestProvisioner(admin, resend, nil)
|
||||
|
||||
creds, err := p.CreateProjectNotify(context.Background(), "proj-123", "happy-fox")
|
||||
if err != nil {
|
||||
t.Fatalf("resend failure should be non-fatal, got error: %v", err)
|
||||
}
|
||||
if creds.ResendDomainID != "" {
|
||||
t.Errorf("expected empty resend domain id on failure, got %s", creds.ResendDomainID)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateProjectNotify_WithoutResend_SkipsProviderAndDomain(t *testing.T) {
|
||||
admin := &mockAdminClient{}
|
||||
p := newProvisionerWithDeps(admin, nil, "", nil, "test.example", testLogger())
|
||||
|
||||
creds, err := p.CreateProjectNotify(context.Background(), "proj-123", "happy-fox")
|
||||
if err != nil {
|
||||
t.Fatalf("expected no error without resend, got %v", err)
|
||||
}
|
||||
|
||||
if admin.createProviderCalls != 0 {
|
||||
t.Errorf("createProvider should not be called without ResendAPIKey, got %d calls", admin.createProviderCalls)
|
||||
}
|
||||
if creds.ResendDomainID != "" {
|
||||
t.Errorf("expected no resend domain id without ResendAPIKey, got %s", creds.ResendDomainID)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateProjectNotify_DNSFailureIsNonFatal(t *testing.T) {
|
||||
admin := &mockAdminClient{}
|
||||
resend := &mockResendClient{
|
||||
dnsRecords: []resendDNSRecord{{Record: "TXT", Name: "resend._domainkey", Value: "v=DKIM1"}},
|
||||
}
|
||||
dns := &mockDNS{upsertErr: errors.New("cloudflare down")}
|
||||
p := newTestProvisioner(admin, resend, dns)
|
||||
|
||||
creds, err := p.CreateProjectNotify(context.Background(), "proj-123", "happy-fox")
|
||||
if err != nil {
|
||||
t.Fatalf("DNS failure should be non-fatal, got error: %v", err)
|
||||
}
|
||||
// Project still usable; DNS will need manual fix
|
||||
if creds.Host != "mail.happy-fox.test.example" {
|
||||
t.Errorf("creds should still be returned on DNS failure, got host %s", creds.Host)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDeleteProjectNotify_Success(t *testing.T) {
|
||||
admin := &mockAdminClient{
|
||||
accounts: []accountResponse{
|
||||
{ID: "acct-001", Name: "project-proj-123"},
|
||||
},
|
||||
}
|
||||
resend := &mockResendClient{}
|
||||
dns := &mockDNS{}
|
||||
p := newTestProvisioner(admin, resend, dns)
|
||||
|
||||
err := p.DeleteProjectNotify(context.Background(), "proj-123", "happy-fox", "resend-domain-id-123")
|
||||
if err != nil {
|
||||
t.Fatalf("expected no error, got %v", err)
|
||||
}
|
||||
|
||||
if admin.deleteAccountCalls != 1 {
|
||||
t.Errorf("expected 1 deleteAccount call, got %d", admin.deleteAccountCalls)
|
||||
}
|
||||
if admin.deleteHostCalls != 1 {
|
||||
t.Errorf("expected 1 deleteHost call, got %d", admin.deleteHostCalls)
|
||||
}
|
||||
if resend.deleteDomainCalls != 1 {
|
||||
t.Errorf("expected 1 deleteDomain call, got %d", resend.deleteDomainCalls)
|
||||
}
|
||||
// 3 DNS records: DKIM TXT, SPF MX, SPF TXT
|
||||
if len(dns.deleteByNameCalls) != 3 {
|
||||
t.Errorf("expected 3 DNS deleteByName calls, got %d", len(dns.deleteByNameCalls))
|
||||
}
|
||||
}
|
||||
|
||||
func TestDeleteProjectNotify_NoResendDomainID_SkipsDomainDeletion(t *testing.T) {
|
||||
admin := &mockAdminClient{
|
||||
accounts: []accountResponse{
|
||||
{ID: "acct-001", Name: "project-proj-123"},
|
||||
},
|
||||
}
|
||||
resend := &mockResendClient{}
|
||||
p := newTestProvisioner(admin, resend, nil)
|
||||
|
||||
err := p.DeleteProjectNotify(context.Background(), "proj-123", "happy-fox", "")
|
||||
if err != nil {
|
||||
t.Fatalf("expected no error, got %v", err)
|
||||
}
|
||||
if resend.deleteDomainCalls != 0 {
|
||||
t.Errorf("should skip domain deletion when resendDomainID is empty")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDeleteProjectNotify_AccountNotFound_ContinuesCleanup(t *testing.T) {
|
||||
// Account doesn't exist (never provisioned or already deleted)
|
||||
admin := &mockAdminClient{accounts: []accountResponse{}}
|
||||
resend := &mockResendClient{}
|
||||
p := newTestProvisioner(admin, resend, nil)
|
||||
|
||||
err := p.DeleteProjectNotify(context.Background(), "proj-123", "happy-fox", "resend-domain-id-123")
|
||||
if err != nil {
|
||||
t.Fatalf("expected no error when account not found, got %v", err)
|
||||
}
|
||||
// Should still attempt host and Resend domain deletion
|
||||
if admin.deleteHostCalls != 1 {
|
||||
t.Errorf("expected 1 deleteHost call, got %d", admin.deleteHostCalls)
|
||||
}
|
||||
if resend.deleteDomainCalls != 1 {
|
||||
t.Errorf("expected 1 deleteDomain call, got %d", resend.deleteDomainCalls)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetProjectNotify_NotProvisioned(t *testing.T) {
|
||||
admin := &mockAdminClient{}
|
||||
p := newTestProvisioner(admin, nil, nil)
|
||||
|
||||
creds, err := p.GetProjectNotify(context.Background(), "proj-123")
|
||||
if err != nil {
|
||||
t.Fatalf("expected no error, got %v", err)
|
||||
}
|
||||
if creds != nil {
|
||||
t.Errorf("expected nil for unprovisioned project, got %+v", creds)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetProjectNotify_AlreadyProvisioned(t *testing.T) {
|
||||
admin := &mockAdminClient{
|
||||
accounts: []accountResponse{
|
||||
{ID: "acct-001", Name: "project-proj-123", CreatedAt: time.Now()},
|
||||
},
|
||||
}
|
||||
p := newTestProvisioner(admin, nil, nil)
|
||||
|
||||
creds, err := p.GetProjectNotify(context.Background(), "proj-123")
|
||||
if err != nil {
|
||||
t.Fatalf("expected no error, got %v", err)
|
||||
}
|
||||
if creds == nil {
|
||||
t.Fatal("expected non-nil credentials for provisioned project")
|
||||
}
|
||||
if creds.AccountID != "acct-001" {
|
||||
t.Errorf("expected account id acct-001, got %s", creds.AccountID)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateProjectNotify_HostUsesBaseDomain(t *testing.T) {
|
||||
admin := &mockAdminClient{}
|
||||
p := newProvisionerWithDeps(admin, nil, "", nil, "staging.example.com", testLogger())
|
||||
|
||||
creds, err := p.CreateProjectNotify(context.Background(), "proj-123", "some-slug")
|
||||
if err != nil {
|
||||
t.Fatalf("expected no error, got %v", err)
|
||||
}
|
||||
if creds.Host != "mail.some-slug.staging.example.com" {
|
||||
t.Errorf("expected host to use baseDomain, got %s", creds.Host)
|
||||
}
|
||||
}
|
||||
128
internal/adapter/notify/resend_client.go
Normal file
128
internal/adapter/notify/resend_client.go
Normal file
@ -0,0 +1,128 @@
|
||||
package notify
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"time"
|
||||
)
|
||||
|
||||
const resendBaseURL = "https://api.resend.com"
|
||||
|
||||
// resendAPI is the interface Provisioner uses to call the Resend API.
|
||||
// Extracted for testability.
|
||||
type resendAPI interface {
|
||||
createDomain(ctx context.Context, name, region string) (domainID string, records []resendDNSRecord, err error)
|
||||
verifyDomain(ctx context.Context, domainID string) error
|
||||
deleteDomain(ctx context.Context, domainID string) error
|
||||
}
|
||||
|
||||
// resendClient calls the Resend API for domain management.
|
||||
type resendClient struct {
|
||||
apiKey string
|
||||
httpClient *http.Client
|
||||
}
|
||||
|
||||
// resendDNSRecord is a DNS record returned by Resend after domain creation.
|
||||
// The "record" JSON field contains the DNS record type (e.g., "TXT", "MX").
|
||||
type resendDNSRecord struct {
|
||||
Record string `json:"record"` // DNS record type: "TXT", "MX", "CNAME"
|
||||
Name string `json:"name"` // relative name (e.g., "resend._domainkey")
|
||||
Value string `json:"value"` // record content
|
||||
Priority int `json:"priority,omitempty"`
|
||||
}
|
||||
|
||||
// resendCreateDomainResponse is the shape returned by POST /domains.
|
||||
type resendCreateDomainResponse struct {
|
||||
ID string `json:"id"`
|
||||
Name string `json:"name"`
|
||||
Records []resendDNSRecord `json:"records"`
|
||||
}
|
||||
|
||||
func newResendClient(apiKey string) *resendClient {
|
||||
return &resendClient{
|
||||
apiKey: apiKey,
|
||||
httpClient: &http.Client{
|
||||
Timeout: 30 * time.Second,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// createDomain creates a new Resend domain and returns the domain ID and DNS records to set.
|
||||
func (r *resendClient) createDomain(ctx context.Context, name, region string) (domainID string, records []resendDNSRecord, err error) {
|
||||
payload := map[string]string{"name": name, "region": region}
|
||||
respBody, err := r.doRequest(ctx, http.MethodPost, "/domains", payload)
|
||||
if err != nil {
|
||||
return "", nil, fmt.Errorf("create resend domain %s: %w", name, err)
|
||||
}
|
||||
|
||||
var resp resendCreateDomainResponse
|
||||
if err := json.Unmarshal(respBody, &resp); err != nil {
|
||||
return "", nil, fmt.Errorf("unmarshal resend domain response: %w", err)
|
||||
}
|
||||
return resp.ID, resp.Records, nil
|
||||
}
|
||||
|
||||
// verifyDomain triggers domain verification on Resend.
|
||||
func (r *resendClient) verifyDomain(ctx context.Context, domainID string) error {
|
||||
_, err := r.doRequest(ctx, http.MethodPost, "/domains/"+domainID+"/verify", nil)
|
||||
if err != nil {
|
||||
return fmt.Errorf("verify resend domain %s: %w", domainID, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// deleteDomain removes a Resend domain by ID.
|
||||
func (r *resendClient) deleteDomain(ctx context.Context, domainID string) error {
|
||||
_, err := r.doRequest(ctx, http.MethodDelete, "/domains/"+domainID, nil)
|
||||
if err != nil {
|
||||
return fmt.Errorf("delete resend domain %s: %w", domainID, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// doRequest executes an HTTP request against the Resend API.
|
||||
func (r *resendClient) doRequest(ctx context.Context, method, path string, bodyData any) ([]byte, error) {
|
||||
var reqBody io.Reader
|
||||
if bodyData != nil {
|
||||
jsonBody, err := json.Marshal(bodyData)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("marshal request: %w", err)
|
||||
}
|
||||
reqBody = bytes.NewReader(jsonBody)
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, method, resendBaseURL+path, reqBody)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create request: %w", err)
|
||||
}
|
||||
|
||||
req.Header.Set("Authorization", "Bearer "+r.apiKey)
|
||||
if bodyData != nil {
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
}
|
||||
|
||||
resp, err := r.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("http do: %w", err)
|
||||
}
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
|
||||
if resp.StatusCode == http.StatusNoContent {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
respBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("read response body: %w", err)
|
||||
}
|
||||
|
||||
if resp.StatusCode >= 200 && resp.StatusCode < 300 {
|
||||
return respBody, nil
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("resend API error (HTTP %d): %s", resp.StatusCode, string(respBody))
|
||||
}
|
||||
@ -203,6 +203,118 @@ func (r *APIKeyRepository) UpdateLastUsed(ctx context.Context, id domain.APIKeyI
|
||||
return err
|
||||
}
|
||||
|
||||
// Update applies a partial update to an API key.
|
||||
func (r *APIKeyRepository) Update(ctx context.Context, id domain.APIKeyID, update port.APIKeyUpdate) error {
|
||||
// Build SET clauses dynamically based on non-nil fields
|
||||
setClauses := []string{}
|
||||
args := []any{}
|
||||
argIdx := 1
|
||||
|
||||
if update.Name != nil {
|
||||
setClauses = append(setClauses, fmt.Sprintf("name = $%d", argIdx))
|
||||
args = append(args, *update.Name)
|
||||
argIdx++
|
||||
}
|
||||
if update.Scopes != nil {
|
||||
setClauses = append(setClauses, fmt.Sprintf("scopes = $%d", argIdx))
|
||||
args = append(args, pq.Array(scopesToStrings(update.Scopes)))
|
||||
argIdx++
|
||||
}
|
||||
if update.ProjectIDs != nil {
|
||||
setClauses = append(setClauses, fmt.Sprintf("project_ids = $%d", argIdx))
|
||||
args = append(args, pq.Array(projectIDsToStrings(*update.ProjectIDs)))
|
||||
argIdx++
|
||||
}
|
||||
if update.AllowedIPs != nil {
|
||||
setClauses = append(setClauses, fmt.Sprintf("allowed_ips = $%d", argIdx))
|
||||
args = append(args, pq.Array(*update.AllowedIPs))
|
||||
argIdx++
|
||||
}
|
||||
if update.ExpiresAt != nil {
|
||||
setClauses = append(setClauses, fmt.Sprintf("expires_at = $%d", argIdx))
|
||||
args = append(args, *update.ExpiresAt)
|
||||
argIdx++
|
||||
}
|
||||
|
||||
if len(setClauses) == 0 {
|
||||
return nil // nothing to update
|
||||
}
|
||||
|
||||
args = append(args, string(id))
|
||||
query := fmt.Sprintf("UPDATE api_keys SET %s WHERE id = $%d AND revoked_at IS NULL",
|
||||
joinStrings(setClauses, ", "), argIdx)
|
||||
|
||||
result, err := r.db.ExecContext(ctx, query, args...)
|
||||
if err != nil {
|
||||
return fmt.Errorf("update key: %w", err)
|
||||
}
|
||||
rows, err := result.RowsAffected()
|
||||
if err != nil {
|
||||
return fmt.Errorf("rows affected: %w", err)
|
||||
}
|
||||
if rows == 0 {
|
||||
return domain.ErrKeyNotFound
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ListByProjectID returns all active keys that have the given project ID in their project_ids.
|
||||
func (r *APIKeyRepository) ListByProjectID(ctx context.Context, projectID domain.ProjectID) ([]*domain.APIKey, error) {
|
||||
rows, err := r.db.QueryContext(ctx, `
|
||||
SELECT id, name, key_prefix, scopes, project_ids, allowed_ips, created_at, expires_at, last_used_at, revoked_at, created_by
|
||||
FROM api_keys
|
||||
WHERE $1 = ANY(project_ids) AND revoked_at IS NULL
|
||||
ORDER BY created_at DESC
|
||||
`, string(projectID))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("query keys by project: %w", err)
|
||||
}
|
||||
defer func() { _ = rows.Close() }()
|
||||
|
||||
var keys []*domain.APIKey
|
||||
for rows.Next() {
|
||||
var (
|
||||
key domain.APIKey
|
||||
id string
|
||||
scopeStrings []string
|
||||
projectIDs []string
|
||||
)
|
||||
if err := rows.Scan(
|
||||
&id,
|
||||
&key.Name,
|
||||
&key.KeyPrefix,
|
||||
pq.Array(&scopeStrings),
|
||||
pq.Array(&projectIDs),
|
||||
pq.Array(&key.AllowedIPs),
|
||||
&key.CreatedAt,
|
||||
&key.ExpiresAt,
|
||||
&key.LastUsedAt,
|
||||
&key.RevokedAt,
|
||||
&key.CreatedBy,
|
||||
); err != nil {
|
||||
return nil, fmt.Errorf("scan key: %w", err)
|
||||
}
|
||||
key.ID = domain.APIKeyID(id)
|
||||
key.Scopes = scopesFromStrings(scopeStrings)
|
||||
key.ProjectIDs = projectIDsFromStrings(projectIDs)
|
||||
keys = append(keys, &key)
|
||||
}
|
||||
|
||||
return keys, nil
|
||||
}
|
||||
|
||||
// joinStrings joins string slices with a separator (avoids importing strings in this file).
|
||||
func joinStrings(ss []string, sep string) string {
|
||||
result := ""
|
||||
for i, s := range ss {
|
||||
if i > 0 {
|
||||
result += sep
|
||||
}
|
||||
result += s
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// Helper functions for scope conversion
|
||||
func scopesToStrings(scopes []domain.Scope) []string {
|
||||
ss := make([]string, len(scopes))
|
||||
|
||||
@ -1,4 +1,5 @@
|
||||
import { Routes, Route, useLocation, useNavigate } from 'react-router-dom';
|
||||
import { useState, useEffect } from 'react';
|
||||
import { Routes, Route, useLocation, useNavigate, useSearchParams, Link } from 'react-router-dom';
|
||||
import { AuthProvider, useAuth, ProtectedRoute } from '@{{PROJECT_NAME}}/auth';
|
||||
import { DashboardShell, Sidebar, Header, type NavItem } from '@{{PROJECT_NAME}}/layout';
|
||||
import {
|
||||
@ -17,8 +18,14 @@ import {
|
||||
MessageSquare,
|
||||
Sparkles,
|
||||
Loader2,
|
||||
AlertCircle,
|
||||
} from '@{{PROJECT_NAME}}/ui';
|
||||
import { LoginPage } from './pages/LoginPage';
|
||||
import { RegisterPage } from './pages/RegisterPage';
|
||||
import { ForgotPasswordPage } from './pages/ForgotPasswordPage';
|
||||
import { ResetPasswordPage } from './pages/ResetPasswordPage';
|
||||
import { VerifyEmailPage } from './pages/VerifyEmailPage';
|
||||
import { SessionsPage } from './pages/SessionsPage';
|
||||
import { ChatPage } from './pages/ChatPage';
|
||||
import { GeneratePage } from './pages/GeneratePage';
|
||||
import { MediaPage } from './pages/MediaPage';
|
||||
@ -41,6 +48,8 @@ const pageTitles: Record<string, string> = {
|
||||
'/analytics': 'Analytics',
|
||||
'/users': 'Users',
|
||||
'/settings': 'Settings',
|
||||
'/settings/sessions': 'Sessions',
|
||||
'/settings/verify-email': 'Verify Email',
|
||||
};
|
||||
|
||||
function DashboardPage() {
|
||||
@ -287,6 +296,64 @@ function LoadingScreen() {
|
||||
);
|
||||
}
|
||||
|
||||
function MagicLinkCallbackPage() {
|
||||
const [searchParams] = useSearchParams();
|
||||
const navigate = useNavigate();
|
||||
const { loginWithMagicLink } = useAuth();
|
||||
const [error, setError] = useState<string | null>(null);
|
||||
const [verifying, setVerifying] = useState(true);
|
||||
|
||||
useEffect(() => {
|
||||
const token = searchParams.get('token');
|
||||
const email = searchParams.get('email');
|
||||
|
||||
if (!token || !email) {
|
||||
setError('Invalid magic link. Missing token or email.');
|
||||
setVerifying(false);
|
||||
return;
|
||||
}
|
||||
|
||||
loginWithMagicLink({ email, token })
|
||||
.then(() => {
|
||||
navigate('/', { replace: true });
|
||||
})
|
||||
.catch((err) => {
|
||||
setError(err instanceof Error ? err.message : 'Magic link verification failed.');
|
||||
setVerifying(false);
|
||||
});
|
||||
}, [searchParams, loginWithMagicLink, navigate]);
|
||||
|
||||
return (
|
||||
<div className="min-h-screen flex items-center justify-center bg-[var(--surface-100)]">
|
||||
<Card className="w-full max-w-md">
|
||||
<CardHeader className="text-center">
|
||||
<CardTitle>{verifying ? 'Verifying Magic Link' : 'Verification Failed'}</CardTitle>
|
||||
<CardDescription>
|
||||
{verifying
|
||||
? 'Please wait while we verify your magic link...'
|
||||
: 'We could not verify your magic link.'}
|
||||
</CardDescription>
|
||||
</CardHeader>
|
||||
<CardContent className="flex flex-col items-center gap-4">
|
||||
{verifying ? (
|
||||
<Loader2 className="h-8 w-8 animate-spin text-[var(--accent)]" />
|
||||
) : (
|
||||
<>
|
||||
<div className="flex items-center gap-2 text-[var(--text-error, #ef4444)]">
|
||||
<AlertCircle className="h-5 w-5" />
|
||||
<p className="text-sm">{error}</p>
|
||||
</div>
|
||||
<Link to="/login">
|
||||
<Button variant="outline">Back to Login</Button>
|
||||
</Link>
|
||||
</>
|
||||
)}
|
||||
</CardContent>
|
||||
</Card>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
function AppLayout() {
|
||||
const location = useLocation();
|
||||
const navigate = useNavigate();
|
||||
@ -331,6 +398,8 @@ function AppLayout() {
|
||||
<Route path="/users" element={<UsersPage />} />
|
||||
<Route path="/analytics" element={<AnalyticsPage />} />
|
||||
<Route path="/settings" element={<SettingsPage />} />
|
||||
<Route path="/settings/sessions" element={<SessionsPage />} />
|
||||
<Route path="/settings/verify-email" element={<VerifyEmailPage />} />
|
||||
</Routes>
|
||||
</DashboardShell>
|
||||
);
|
||||
@ -343,6 +412,10 @@ function AppRoutes() {
|
||||
return (
|
||||
<Routes>
|
||||
<Route path="/login" element={<LoginPage />} />
|
||||
<Route path="/register" element={<RegisterPage />} />
|
||||
<Route path="/forgot-password" element={<ForgotPasswordPage />} />
|
||||
<Route path="/reset-password" element={<ResetPasswordPage />} />
|
||||
<Route path="/auth/magic-link/callback" element={<MagicLinkCallbackPage />} />
|
||||
<Route
|
||||
path="/*"
|
||||
element={
|
||||
@ -367,7 +440,7 @@ function App() {
|
||||
const apiBaseUrl = import.meta.env.VITE_API_URL || '';
|
||||
|
||||
return (
|
||||
<AuthProvider loginUrl={`${apiBaseUrl}/api/{{SERVICE_NAME}}/auth/login`}>
|
||||
<AuthProvider authBaseUrl={`${apiBaseUrl}/api/{{SERVICE_NAME}}`}>
|
||||
<AppRoutes />
|
||||
</AuthProvider>
|
||||
);
|
||||
|
||||
@ -0,0 +1,110 @@
|
||||
import { useState } from 'react';
|
||||
import { Link } from 'react-router-dom';
|
||||
import {
|
||||
Button,
|
||||
Card,
|
||||
CardHeader,
|
||||
CardTitle,
|
||||
CardDescription,
|
||||
CardContent,
|
||||
CardFooter,
|
||||
FormField,
|
||||
Alert,
|
||||
AlertDescription,
|
||||
Loader2,
|
||||
} from '@{{PROJECT_NAME}}/ui';
|
||||
|
||||
export function ForgotPasswordPage() {
|
||||
const [isLoading, setIsLoading] = useState(false);
|
||||
const [sent, setSent] = useState(false);
|
||||
const [error, setError] = useState<string | null>(null);
|
||||
|
||||
const apiPrefix = import.meta.env.VITE_API_URL || '';
|
||||
|
||||
const handleSubmit = async (e: React.FormEvent<HTMLFormElement>) => {
|
||||
e.preventDefault();
|
||||
setError(null);
|
||||
setIsLoading(true);
|
||||
|
||||
const formData = new FormData(e.currentTarget);
|
||||
const email = formData.get('email') as string;
|
||||
|
||||
try {
|
||||
const res = await fetch(`${apiPrefix}/api/{{SERVICE_NAME}}/auth/forgot-password`, {
|
||||
method: 'POST',
|
||||
headers: { 'Content-Type': 'application/json' },
|
||||
body: JSON.stringify({ email }),
|
||||
});
|
||||
|
||||
if (!res.ok) {
|
||||
const body = await res.json().catch(() => ({}));
|
||||
throw new Error(body.error?.message || body.message || 'Request failed');
|
||||
}
|
||||
|
||||
setSent(true);
|
||||
} catch (err) {
|
||||
setError(err instanceof Error ? err.message : 'An error occurred');
|
||||
} finally {
|
||||
setIsLoading(false);
|
||||
}
|
||||
};
|
||||
|
||||
return (
|
||||
<div className="min-h-screen flex items-center justify-center bg-[var(--surface-100)] p-4">
|
||||
<Card className="w-full max-w-md">
|
||||
<CardHeader className="text-center">
|
||||
<CardTitle className="text-2xl">Reset your password</CardTitle>
|
||||
<CardDescription>
|
||||
{sent
|
||||
? 'Check your email for a reset link'
|
||||
: "Enter your email and we'll send you a reset link"}
|
||||
</CardDescription>
|
||||
</CardHeader>
|
||||
|
||||
{!sent ? (
|
||||
<form onSubmit={handleSubmit}>
|
||||
<CardContent className="space-y-4">
|
||||
{error && (
|
||||
<Alert variant="destructive">
|
||||
<AlertDescription>{error}</AlertDescription>
|
||||
</Alert>
|
||||
)}
|
||||
|
||||
<FormField
|
||||
label="Email"
|
||||
name="email"
|
||||
type="email"
|
||||
placeholder="you@example.com"
|
||||
required
|
||||
autoComplete="email"
|
||||
autoFocus
|
||||
/>
|
||||
</CardContent>
|
||||
|
||||
<CardFooter className="flex flex-col gap-4">
|
||||
<Button type="submit" className="w-full" disabled={isLoading}>
|
||||
{isLoading && <Loader2 className="mr-2 h-4 w-4 animate-spin" />}
|
||||
Send reset link
|
||||
</Button>
|
||||
<Link to="/login" className="text-sm text-center text-[var(--accent)] hover:underline">
|
||||
Back to sign in
|
||||
</Link>
|
||||
</CardFooter>
|
||||
</form>
|
||||
) : (
|
||||
<CardContent className="space-y-4 text-center">
|
||||
<p className="text-sm text-[var(--text-muted)]">
|
||||
If an account exists with that email, you will receive a password reset link.
|
||||
</p>
|
||||
<p className="text-xs text-[var(--text-muted)]">
|
||||
In dev mode, check the server console for the reset token.
|
||||
</p>
|
||||
<Link to="/login" className="text-sm text-[var(--accent)] hover:underline">
|
||||
Back to sign in
|
||||
</Link>
|
||||
</CardContent>
|
||||
)}
|
||||
</Card>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
@ -1,5 +1,5 @@
|
||||
import { useState } from 'react';
|
||||
import { useNavigate, useLocation } from 'react-router-dom';
|
||||
import { useNavigate, useLocation, Link } from 'react-router-dom';
|
||||
import { useAuth } from '@{{PROJECT_NAME}}/auth';
|
||||
import {
|
||||
Button,
|
||||
@ -17,17 +17,22 @@ import {
|
||||
} from '@{{PROJECT_NAME}}/ui';
|
||||
import { isApiClientError } from '@{{PROJECT_NAME}}/api-client';
|
||||
|
||||
type LoginTab = 'password' | 'otp' | 'magic-link';
|
||||
|
||||
export function LoginPage() {
|
||||
const navigate = useNavigate();
|
||||
const location = useLocation();
|
||||
const { login, isLoading } = useAuth();
|
||||
const { login, sendOTP, loginWithOTP, sendMagicLink, isLoading } = useAuth();
|
||||
const { setErrors, clearErrors, getError } = useFormErrors();
|
||||
const [generalError, setGeneralError] = useState<string | null>(null);
|
||||
const [activeTab, setActiveTab] = useState<LoginTab>('password');
|
||||
const [otpSent, setOtpSent] = useState(false);
|
||||
const [otpEmail, setOtpEmail] = useState('');
|
||||
const [magicLinkSent, setMagicLinkSent] = useState(false);
|
||||
|
||||
// Get the redirect path from location state, default to dashboard
|
||||
const from = (location.state as { from?: string })?.from || '/';
|
||||
|
||||
const handleSubmit = async (e: React.FormEvent<HTMLFormElement>) => {
|
||||
const handlePasswordLogin = async (e: React.FormEvent<HTMLFormElement>) => {
|
||||
e.preventDefault();
|
||||
clearErrors();
|
||||
setGeneralError(null);
|
||||
@ -46,12 +51,70 @@ export function LoginPage() {
|
||||
} else {
|
||||
setGeneralError(error.message);
|
||||
}
|
||||
} else if (error instanceof Error) {
|
||||
setGeneralError(error.message);
|
||||
} else {
|
||||
setGeneralError('An unexpected error occurred. Please try again.');
|
||||
setGeneralError('An unexpected error occurred.');
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
const handleSendOTP = async (e: React.FormEvent<HTMLFormElement>) => {
|
||||
e.preventDefault();
|
||||
clearErrors();
|
||||
setGeneralError(null);
|
||||
|
||||
const formData = new FormData(e.currentTarget);
|
||||
const email = formData.get('email') as string;
|
||||
|
||||
try {
|
||||
await sendOTP(email);
|
||||
setOtpEmail(email);
|
||||
setOtpSent(true);
|
||||
} catch (error) {
|
||||
setGeneralError(error instanceof Error ? error.message : 'Failed to send code');
|
||||
}
|
||||
};
|
||||
|
||||
const handleVerifyOTP = async (e: React.FormEvent<HTMLFormElement>) => {
|
||||
e.preventDefault();
|
||||
clearErrors();
|
||||
setGeneralError(null);
|
||||
|
||||
const formData = new FormData(e.currentTarget);
|
||||
const code = formData.get('code') as string;
|
||||
|
||||
try {
|
||||
await loginWithOTP({ email: otpEmail, code });
|
||||
navigate(from, { replace: true });
|
||||
} catch (error) {
|
||||
setGeneralError(error instanceof Error ? error.message : 'Invalid code');
|
||||
}
|
||||
};
|
||||
|
||||
const handleSendMagicLink = async (e: React.FormEvent<HTMLFormElement>) => {
|
||||
e.preventDefault();
|
||||
clearErrors();
|
||||
setGeneralError(null);
|
||||
|
||||
const formData = new FormData(e.currentTarget);
|
||||
const email = formData.get('email') as string;
|
||||
|
||||
try {
|
||||
await sendMagicLink(email);
|
||||
setMagicLinkSent(true);
|
||||
} catch (error) {
|
||||
setGeneralError(error instanceof Error ? error.message : 'Failed to send link');
|
||||
}
|
||||
};
|
||||
|
||||
const tabClass = (tab: LoginTab) =>
|
||||
`flex-1 py-2 text-sm font-medium text-center rounded-md transition-colors ${
|
||||
activeTab === tab
|
||||
? 'bg-[var(--surface-100)] text-[var(--text-primary)] shadow-sm'
|
||||
: 'text-[var(--text-muted)] hover:text-[var(--text-secondary)]'
|
||||
}`;
|
||||
|
||||
return (
|
||||
<div className="min-h-screen flex items-center justify-center bg-[var(--surface-100)] p-4">
|
||||
<Card className="w-full max-w-md">
|
||||
@ -60,14 +123,29 @@ export function LoginPage() {
|
||||
<CardDescription>Sign in to your {{PROJECT_NAME}} account</CardDescription>
|
||||
</CardHeader>
|
||||
|
||||
<form onSubmit={handleSubmit}>
|
||||
<CardContent className="space-y-4">
|
||||
{/* Tab switcher */}
|
||||
<div className="flex gap-1 p-1 rounded-lg bg-[var(--surface-200)]">
|
||||
<button type="button" className={tabClass('password')} onClick={() => { setActiveTab('password'); clearErrors(); setGeneralError(null); }}>
|
||||
Password
|
||||
</button>
|
||||
<button type="button" className={tabClass('otp')} onClick={() => { setActiveTab('otp'); clearErrors(); setGeneralError(null); setOtpSent(false); }}>
|
||||
OTP
|
||||
</button>
|
||||
<button type="button" className={tabClass('magic-link')} onClick={() => { setActiveTab('magic-link'); clearErrors(); setGeneralError(null); setMagicLinkSent(false); }}>
|
||||
Magic Link
|
||||
</button>
|
||||
</div>
|
||||
|
||||
{generalError && (
|
||||
<Alert variant="destructive">
|
||||
<AlertDescription>{generalError}</AlertDescription>
|
||||
</Alert>
|
||||
)}
|
||||
|
||||
{/* Password tab */}
|
||||
{activeTab === 'password' && (
|
||||
<form onSubmit={handlePasswordLogin} className="space-y-4">
|
||||
<FormField
|
||||
label="Email"
|
||||
name="email"
|
||||
@ -78,7 +156,6 @@ export function LoginPage() {
|
||||
autoComplete="email"
|
||||
autoFocus
|
||||
/>
|
||||
|
||||
<FormField
|
||||
label="Password"
|
||||
name="password"
|
||||
@ -88,21 +165,105 @@ export function LoginPage() {
|
||||
required
|
||||
autoComplete="current-password"
|
||||
/>
|
||||
</CardContent>
|
||||
|
||||
<CardFooter className="flex flex-col gap-4">
|
||||
<div className="text-right">
|
||||
<Link to="/forgot-password" className="text-sm text-[var(--accent)] hover:underline">
|
||||
Forgot password?
|
||||
</Link>
|
||||
</div>
|
||||
<Button type="submit" className="w-full" disabled={isLoading}>
|
||||
{isLoading && <Loader2 className="mr-2 h-4 w-4 animate-spin" />}
|
||||
Sign in
|
||||
</Button>
|
||||
</form>
|
||||
)}
|
||||
|
||||
{/* OTP tab */}
|
||||
{activeTab === 'otp' && !otpSent && (
|
||||
<form onSubmit={handleSendOTP} className="space-y-4">
|
||||
<FormField
|
||||
label="Email"
|
||||
name="email"
|
||||
type="email"
|
||||
placeholder="you@example.com"
|
||||
required
|
||||
autoComplete="email"
|
||||
autoFocus
|
||||
/>
|
||||
<Button type="submit" className="w-full" disabled={isLoading}>
|
||||
{isLoading && <Loader2 className="mr-2 h-4 w-4 animate-spin" />}
|
||||
Send code
|
||||
</Button>
|
||||
</form>
|
||||
)}
|
||||
|
||||
{activeTab === 'otp' && otpSent && (
|
||||
<form onSubmit={handleVerifyOTP} className="space-y-4">
|
||||
<p className="text-sm text-[var(--text-muted)]">
|
||||
A 6-digit code was sent to <strong>{otpEmail}</strong>
|
||||
</p>
|
||||
<FormField
|
||||
label="Code"
|
||||
name="code"
|
||||
type="text"
|
||||
placeholder="000000"
|
||||
required
|
||||
autoComplete="one-time-code"
|
||||
autoFocus
|
||||
/>
|
||||
<Button type="submit" className="w-full" disabled={isLoading}>
|
||||
{isLoading && <Loader2 className="mr-2 h-4 w-4 animate-spin" />}
|
||||
Verify
|
||||
</Button>
|
||||
<button
|
||||
type="button"
|
||||
className="w-full text-sm text-[var(--text-muted)] hover:text-[var(--text-secondary)]"
|
||||
onClick={() => setOtpSent(false)}
|
||||
>
|
||||
Use a different email
|
||||
</button>
|
||||
</form>
|
||||
)}
|
||||
|
||||
{/* Magic Link tab */}
|
||||
{activeTab === 'magic-link' && !magicLinkSent && (
|
||||
<form onSubmit={handleSendMagicLink} className="space-y-4">
|
||||
<FormField
|
||||
label="Email"
|
||||
name="email"
|
||||
type="email"
|
||||
placeholder="you@example.com"
|
||||
required
|
||||
autoComplete="email"
|
||||
autoFocus
|
||||
/>
|
||||
<Button type="submit" className="w-full" disabled={isLoading}>
|
||||
{isLoading && <Loader2 className="mr-2 h-4 w-4 animate-spin" />}
|
||||
Send magic link
|
||||
</Button>
|
||||
</form>
|
||||
)}
|
||||
|
||||
{activeTab === 'magic-link' && magicLinkSent && (
|
||||
<div className="text-center space-y-3 py-4">
|
||||
<p className="text-sm text-[var(--text-primary)]">Check your email</p>
|
||||
<p className="text-sm text-[var(--text-muted)]">
|
||||
We sent a sign-in link to your email. Click it to continue.
|
||||
</p>
|
||||
<p className="text-xs text-[var(--text-muted)]">
|
||||
In dev mode, check the server console for the link token.
|
||||
</p>
|
||||
</div>
|
||||
)}
|
||||
</CardContent>
|
||||
|
||||
<CardFooter className="flex flex-col gap-3">
|
||||
<p className="text-sm text-center text-[var(--text-muted)]">
|
||||
Demo accounts: test@example.com / password123
|
||||
<br />
|
||||
or admin@example.com / admin123
|
||||
Don't have an account?{' '}
|
||||
<Link to="/register" className="text-[var(--accent)] hover:underline">
|
||||
Sign up
|
||||
</Link>
|
||||
</p>
|
||||
</CardFooter>
|
||||
</form>
|
||||
</Card>
|
||||
</div>
|
||||
);
|
||||
|
||||
@ -61,15 +61,15 @@ export function MediaPage() {
|
||||
fetchMedia();
|
||||
}, [fetchMedia]);
|
||||
|
||||
const handleDelete = useCallback(async (path: string) => {
|
||||
const handleDelete = useCallback(async (id: string) => {
|
||||
setDeleteError(null);
|
||||
try {
|
||||
const res = await fetch(`${apiPrefix}/api/{{SERVICE_NAME}}/media/${path}`, {
|
||||
const res = await fetch(`${apiPrefix}/api/{{SERVICE_NAME}}/media/${id}`, {
|
||||
method: 'DELETE',
|
||||
headers: { ...authHeaders },
|
||||
});
|
||||
if (!res.ok) throw new Error(`Delete failed: ${res.status}`);
|
||||
setItems((prev) => prev.filter((item) => item.path !== path));
|
||||
setItems((prev) => prev.filter((item) => item.id !== id));
|
||||
} catch (err) {
|
||||
setDeleteError(err instanceof Error ? err.message : 'Delete failed');
|
||||
}
|
||||
|
||||
@ -0,0 +1,128 @@
|
||||
import { useState } from 'react';
|
||||
import { useNavigate, Link } from 'react-router-dom';
|
||||
import { useAuth } from '@{{PROJECT_NAME}}/auth';
|
||||
import {
|
||||
Button,
|
||||
Card,
|
||||
CardHeader,
|
||||
CardTitle,
|
||||
CardDescription,
|
||||
CardContent,
|
||||
CardFooter,
|
||||
FormField,
|
||||
useFormErrors,
|
||||
Alert,
|
||||
AlertDescription,
|
||||
Loader2,
|
||||
} from '@{{PROJECT_NAME}}/ui';
|
||||
|
||||
export function RegisterPage() {
|
||||
const navigate = useNavigate();
|
||||
const { register, isLoading } = useAuth();
|
||||
const { setErrors, clearErrors, getError } = useFormErrors();
|
||||
const [generalError, setGeneralError] = useState<string | null>(null);
|
||||
|
||||
const handleSubmit = async (e: React.FormEvent<HTMLFormElement>) => {
|
||||
e.preventDefault();
|
||||
clearErrors();
|
||||
setGeneralError(null);
|
||||
|
||||
const formData = new FormData(e.currentTarget);
|
||||
const email = formData.get('email') as string;
|
||||
const password = formData.get('password') as string;
|
||||
const confirmPassword = formData.get('confirmPassword') as string;
|
||||
const name = formData.get('name') as string;
|
||||
|
||||
if (password !== confirmPassword) {
|
||||
setErrors({ confirmPassword: 'Passwords do not match' });
|
||||
return;
|
||||
}
|
||||
|
||||
try {
|
||||
await register({ email, password, name });
|
||||
navigate('/', { replace: true });
|
||||
} catch (error) {
|
||||
if (error instanceof Error) {
|
||||
setGeneralError(error.message);
|
||||
} else {
|
||||
setGeneralError('An unexpected error occurred.');
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
return (
|
||||
<div className="min-h-screen flex items-center justify-center bg-[var(--surface-100)] p-4">
|
||||
<Card className="w-full max-w-md">
|
||||
<CardHeader className="text-center">
|
||||
<CardTitle className="text-2xl">Create an account</CardTitle>
|
||||
<CardDescription>Get started with {{PROJECT_NAME}}</CardDescription>
|
||||
</CardHeader>
|
||||
|
||||
<form onSubmit={handleSubmit}>
|
||||
<CardContent className="space-y-4">
|
||||
{generalError && (
|
||||
<Alert variant="destructive">
|
||||
<AlertDescription>{generalError}</AlertDescription>
|
||||
</Alert>
|
||||
)}
|
||||
|
||||
<FormField
|
||||
label="Name"
|
||||
name="name"
|
||||
type="text"
|
||||
placeholder="Your name"
|
||||
error={getError('name')}
|
||||
autoComplete="name"
|
||||
autoFocus
|
||||
/>
|
||||
|
||||
<FormField
|
||||
label="Email"
|
||||
name="email"
|
||||
type="email"
|
||||
placeholder="you@example.com"
|
||||
error={getError('email')}
|
||||
required
|
||||
autoComplete="email"
|
||||
/>
|
||||
|
||||
<FormField
|
||||
label="Password"
|
||||
name="password"
|
||||
type="password"
|
||||
placeholder="At least 8 characters"
|
||||
error={getError('password')}
|
||||
required
|
||||
autoComplete="new-password"
|
||||
description="Must contain uppercase, lowercase, and a number"
|
||||
/>
|
||||
|
||||
<FormField
|
||||
label="Confirm Password"
|
||||
name="confirmPassword"
|
||||
type="password"
|
||||
placeholder="Repeat your password"
|
||||
error={getError('confirmPassword')}
|
||||
required
|
||||
autoComplete="new-password"
|
||||
/>
|
||||
</CardContent>
|
||||
|
||||
<CardFooter className="flex flex-col gap-4">
|
||||
<Button type="submit" className="w-full" disabled={isLoading}>
|
||||
{isLoading && <Loader2 className="mr-2 h-4 w-4 animate-spin" />}
|
||||
Create account
|
||||
</Button>
|
||||
|
||||
<p className="text-sm text-center text-[var(--text-muted)]">
|
||||
Already have an account?{' '}
|
||||
<Link to="/login" className="text-[var(--accent)] hover:underline">
|
||||
Sign in
|
||||
</Link>
|
||||
</p>
|
||||
</CardFooter>
|
||||
</form>
|
||||
</Card>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
@ -0,0 +1,135 @@
|
||||
import { useState } from 'react';
|
||||
import { useNavigate, useSearchParams, Link } from 'react-router-dom';
|
||||
import {
|
||||
Button,
|
||||
Card,
|
||||
CardHeader,
|
||||
CardTitle,
|
||||
CardDescription,
|
||||
CardContent,
|
||||
CardFooter,
|
||||
FormField,
|
||||
useFormErrors,
|
||||
Alert,
|
||||
AlertDescription,
|
||||
Loader2,
|
||||
} from '@{{PROJECT_NAME}}/ui';
|
||||
|
||||
export function ResetPasswordPage() {
|
||||
const navigate = useNavigate();
|
||||
const [searchParams] = useSearchParams();
|
||||
const { setErrors, clearErrors, getError } = useFormErrors();
|
||||
const [isLoading, setIsLoading] = useState(false);
|
||||
const [generalError, setGeneralError] = useState<string | null>(null);
|
||||
|
||||
const token = searchParams.get('token') || '';
|
||||
const email = searchParams.get('email') || '';
|
||||
const apiPrefix = import.meta.env.VITE_API_URL || '';
|
||||
|
||||
const handleSubmit = async (e: React.FormEvent<HTMLFormElement>) => {
|
||||
e.preventDefault();
|
||||
clearErrors();
|
||||
setGeneralError(null);
|
||||
|
||||
const formData = new FormData(e.currentTarget);
|
||||
const newPassword = formData.get('newPassword') as string;
|
||||
const confirmPassword = formData.get('confirmPassword') as string;
|
||||
|
||||
if (newPassword !== confirmPassword) {
|
||||
setErrors({ confirmPassword: 'Passwords do not match' });
|
||||
return;
|
||||
}
|
||||
|
||||
setIsLoading(true);
|
||||
|
||||
try {
|
||||
const res = await fetch(`${apiPrefix}/api/{{SERVICE_NAME}}/auth/reset-password`, {
|
||||
method: 'POST',
|
||||
headers: { 'Content-Type': 'application/json' },
|
||||
body: JSON.stringify({ email, token, newPassword }),
|
||||
});
|
||||
|
||||
if (!res.ok) {
|
||||
const body = await res.json().catch(() => ({}));
|
||||
throw new Error(body.error?.message || body.message || 'Reset failed');
|
||||
}
|
||||
|
||||
navigate('/login', { state: { message: 'Password reset successfully. Please sign in.' } });
|
||||
} catch (err) {
|
||||
setGeneralError(err instanceof Error ? err.message : 'An error occurred');
|
||||
} finally {
|
||||
setIsLoading(false);
|
||||
}
|
||||
};
|
||||
|
||||
if (!token || !email) {
|
||||
return (
|
||||
<div className="min-h-screen flex items-center justify-center bg-[var(--surface-100)] p-4">
|
||||
<Card className="w-full max-w-md">
|
||||
<CardHeader className="text-center">
|
||||
<CardTitle className="text-2xl">Invalid reset link</CardTitle>
|
||||
<CardDescription>This password reset link is missing required parameters.</CardDescription>
|
||||
</CardHeader>
|
||||
<CardContent className="text-center">
|
||||
<Link to="/forgot-password" className="text-sm text-[var(--accent)] hover:underline">
|
||||
Request a new reset link
|
||||
</Link>
|
||||
</CardContent>
|
||||
</Card>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
return (
|
||||
<div className="min-h-screen flex items-center justify-center bg-[var(--surface-100)] p-4">
|
||||
<Card className="w-full max-w-md">
|
||||
<CardHeader className="text-center">
|
||||
<CardTitle className="text-2xl">Set new password</CardTitle>
|
||||
<CardDescription>Enter your new password below</CardDescription>
|
||||
</CardHeader>
|
||||
|
||||
<form onSubmit={handleSubmit}>
|
||||
<CardContent className="space-y-4">
|
||||
{generalError && (
|
||||
<Alert variant="destructive">
|
||||
<AlertDescription>{generalError}</AlertDescription>
|
||||
</Alert>
|
||||
)}
|
||||
|
||||
<FormField
|
||||
label="New Password"
|
||||
name="newPassword"
|
||||
type="password"
|
||||
placeholder="At least 8 characters"
|
||||
error={getError('newPassword')}
|
||||
required
|
||||
autoComplete="new-password"
|
||||
autoFocus
|
||||
description="Must contain uppercase, lowercase, and a number"
|
||||
/>
|
||||
|
||||
<FormField
|
||||
label="Confirm Password"
|
||||
name="confirmPassword"
|
||||
type="password"
|
||||
placeholder="Repeat your new password"
|
||||
error={getError('confirmPassword')}
|
||||
required
|
||||
autoComplete="new-password"
|
||||
/>
|
||||
</CardContent>
|
||||
|
||||
<CardFooter className="flex flex-col gap-4">
|
||||
<Button type="submit" className="w-full" disabled={isLoading}>
|
||||
{isLoading && <Loader2 className="mr-2 h-4 w-4 animate-spin" />}
|
||||
Reset password
|
||||
</Button>
|
||||
<Link to="/login" className="text-sm text-center text-[var(--accent)] hover:underline">
|
||||
Back to sign in
|
||||
</Link>
|
||||
</CardFooter>
|
||||
</form>
|
||||
</Card>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
@ -0,0 +1,178 @@
|
||||
import { useState, useEffect, useCallback } from 'react';
|
||||
import { useAuth, type Session } from '@{{PROJECT_NAME}}/auth';
|
||||
import {
|
||||
Button,
|
||||
Card,
|
||||
CardContent,
|
||||
Badge,
|
||||
Alert,
|
||||
AlertDescription,
|
||||
Loader2,
|
||||
Trash2,
|
||||
} from '@{{PROJECT_NAME}}/ui';
|
||||
|
||||
export function SessionsPage() {
|
||||
const { getToken } = useAuth();
|
||||
const [sessions, setSessions] = useState<Session[]>([]);
|
||||
const [isLoading, setIsLoading] = useState(true);
|
||||
const [error, setError] = useState<string | null>(null);
|
||||
const [revokingId, setRevokingId] = useState<string | null>(null);
|
||||
|
||||
const apiPrefix = import.meta.env.VITE_API_URL || '';
|
||||
|
||||
const authHeaders = useCallback(() => {
|
||||
const token = getToken();
|
||||
return {
|
||||
'Content-Type': 'application/json',
|
||||
...(token ? { 'Authorization': `Bearer ${token}` } : {}),
|
||||
};
|
||||
}, [getToken]);
|
||||
|
||||
const loadSessions = useCallback(async () => {
|
||||
setError(null);
|
||||
try {
|
||||
const res = await fetch(`${apiPrefix}/api/{{SERVICE_NAME}}/auth/sessions`, {
|
||||
headers: authHeaders(),
|
||||
});
|
||||
if (!res.ok) throw new Error('Failed to load sessions');
|
||||
const data = await res.json();
|
||||
setSessions(data.data || data || []);
|
||||
} catch (err) {
|
||||
setError(err instanceof Error ? err.message : 'Failed to load sessions');
|
||||
} finally {
|
||||
setIsLoading(false);
|
||||
}
|
||||
}, [apiPrefix, authHeaders]);
|
||||
|
||||
useEffect(() => {
|
||||
loadSessions();
|
||||
}, [loadSessions]);
|
||||
|
||||
const revokeSession = async (sessionId: string) => {
|
||||
setRevokingId(sessionId);
|
||||
try {
|
||||
const res = await fetch(`${apiPrefix}/api/{{SERVICE_NAME}}/auth/sessions/${sessionId}`, {
|
||||
method: 'DELETE',
|
||||
headers: authHeaders(),
|
||||
});
|
||||
if (!res.ok) throw new Error('Failed to revoke session');
|
||||
setSessions(prev => prev.filter(s => s.id !== sessionId));
|
||||
} catch (err) {
|
||||
setError(err instanceof Error ? err.message : 'Failed to revoke session');
|
||||
} finally {
|
||||
setRevokingId(null);
|
||||
}
|
||||
};
|
||||
|
||||
const revokeAll = async () => {
|
||||
setIsLoading(true);
|
||||
try {
|
||||
const res = await fetch(`${apiPrefix}/api/{{SERVICE_NAME}}/auth/sessions`, {
|
||||
method: 'DELETE',
|
||||
headers: authHeaders(),
|
||||
});
|
||||
if (!res.ok) throw new Error('Failed to revoke sessions');
|
||||
await loadSessions();
|
||||
} catch (err) {
|
||||
setError(err instanceof Error ? err.message : 'Failed to revoke sessions');
|
||||
setIsLoading(false);
|
||||
}
|
||||
};
|
||||
|
||||
const formatDate = (dateStr: string) => {
|
||||
const date = new Date(dateStr);
|
||||
const now = new Date();
|
||||
const diffMs = now.getTime() - date.getTime();
|
||||
const diffMin = Math.floor(diffMs / 60000);
|
||||
const diffHr = Math.floor(diffMs / 3600000);
|
||||
const diffDay = Math.floor(diffMs / 86400000);
|
||||
|
||||
if (diffMin < 1) return 'Just now';
|
||||
if (diffMin < 60) return `${diffMin}m ago`;
|
||||
if (diffHr < 24) return `${diffHr}h ago`;
|
||||
if (diffDay < 7) return `${diffDay}d ago`;
|
||||
return date.toLocaleDateString();
|
||||
};
|
||||
|
||||
if (isLoading) {
|
||||
return (
|
||||
<div className="flex justify-center py-12">
|
||||
<Loader2 className="h-8 w-8 animate-spin text-[var(--text-muted)]" />
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
return (
|
||||
<div className="space-y-6">
|
||||
<div className="flex items-center justify-between">
|
||||
<div>
|
||||
<h2 className="text-lg font-semibold text-[var(--text-primary)]">Active Sessions</h2>
|
||||
<p className="text-sm text-[var(--text-muted)]">
|
||||
Manage your active login sessions across devices.
|
||||
</p>
|
||||
</div>
|
||||
{sessions.length > 1 && (
|
||||
<Button variant="outline" onClick={revokeAll}>
|
||||
Revoke all other sessions
|
||||
</Button>
|
||||
)}
|
||||
</div>
|
||||
|
||||
{error && (
|
||||
<Alert variant="destructive">
|
||||
<AlertDescription>{error}</AlertDescription>
|
||||
</Alert>
|
||||
)}
|
||||
|
||||
<div className="space-y-3">
|
||||
{sessions.length === 0 ? (
|
||||
<Card>
|
||||
<CardContent className="py-8 text-center text-[var(--text-muted)]">
|
||||
No active sessions found.
|
||||
</CardContent>
|
||||
</Card>
|
||||
) : (
|
||||
sessions.map((session) => (
|
||||
<Card key={session.id}>
|
||||
<CardContent className="flex items-center justify-between py-4">
|
||||
<div className="flex-1">
|
||||
<div className="flex items-center gap-2">
|
||||
<span className="font-medium text-sm text-[var(--text-primary)]">
|
||||
{session.deviceLabel || 'Unknown device'}
|
||||
</span>
|
||||
{session.isCurrent && (
|
||||
<Badge variant="success">Current</Badge>
|
||||
)}
|
||||
</div>
|
||||
<div className="flex items-center gap-3 mt-1">
|
||||
<span className="text-xs text-[var(--text-muted)]">
|
||||
{session.ipAddress || 'Unknown IP'}
|
||||
</span>
|
||||
<span className="text-xs text-[var(--text-muted)]">
|
||||
Last active: {formatDate(session.lastActiveAt)}
|
||||
</span>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{!session.isCurrent && (
|
||||
<Button
|
||||
variant="ghost"
|
||||
size="sm"
|
||||
onClick={() => revokeSession(session.id)}
|
||||
disabled={revokingId === session.id}
|
||||
>
|
||||
{revokingId === session.id ? (
|
||||
<Loader2 className="h-4 w-4 animate-spin" />
|
||||
) : (
|
||||
<Trash2 className="h-4 w-4" />
|
||||
)}
|
||||
</Button>
|
||||
)}
|
||||
</CardContent>
|
||||
</Card>
|
||||
))
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
@ -0,0 +1,161 @@
|
||||
import { useState } from 'react';
|
||||
import { useAuth } from '@{{PROJECT_NAME}}/auth';
|
||||
import {
|
||||
Button,
|
||||
Card,
|
||||
CardHeader,
|
||||
CardTitle,
|
||||
CardDescription,
|
||||
CardContent,
|
||||
FormField,
|
||||
Alert,
|
||||
AlertDescription,
|
||||
Loader2,
|
||||
Check,
|
||||
} from '@{{PROJECT_NAME}}/ui';
|
||||
|
||||
export function VerifyEmailPage() {
|
||||
const { user, getToken } = useAuth();
|
||||
const [isLoading, setIsLoading] = useState(false);
|
||||
const [codeSent, setCodeSent] = useState(false);
|
||||
const [verified, setVerified] = useState(false);
|
||||
const [error, setError] = useState<string | null>(null);
|
||||
|
||||
const apiPrefix = import.meta.env.VITE_API_URL || '';
|
||||
|
||||
const authHeaders = () => {
|
||||
const token = getToken();
|
||||
return {
|
||||
'Content-Type': 'application/json',
|
||||
...(token ? { 'Authorization': `Bearer ${token}` } : {}),
|
||||
};
|
||||
};
|
||||
|
||||
const handleSendCode = async () => {
|
||||
setError(null);
|
||||
setIsLoading(true);
|
||||
|
||||
try {
|
||||
const res = await fetch(`${apiPrefix}/api/{{SERVICE_NAME}}/auth/verify-email/send`, {
|
||||
method: 'POST',
|
||||
headers: authHeaders(),
|
||||
});
|
||||
|
||||
if (!res.ok) {
|
||||
const body = await res.json().catch(() => ({}));
|
||||
throw new Error(body.error?.message || body.message || 'Failed to send code');
|
||||
}
|
||||
|
||||
setCodeSent(true);
|
||||
} catch (err) {
|
||||
setError(err instanceof Error ? err.message : 'An error occurred');
|
||||
} finally {
|
||||
setIsLoading(false);
|
||||
}
|
||||
};
|
||||
|
||||
const handleVerify = async (e: React.FormEvent<HTMLFormElement>) => {
|
||||
e.preventDefault();
|
||||
setError(null);
|
||||
setIsLoading(true);
|
||||
|
||||
const formData = new FormData(e.currentTarget);
|
||||
const code = formData.get('code') as string;
|
||||
|
||||
try {
|
||||
const res = await fetch(`${apiPrefix}/api/{{SERVICE_NAME}}/auth/verify-email`, {
|
||||
method: 'POST',
|
||||
headers: authHeaders(),
|
||||
body: JSON.stringify({ code }),
|
||||
});
|
||||
|
||||
if (!res.ok) {
|
||||
const body = await res.json().catch(() => ({}));
|
||||
throw new Error(body.error?.message || body.message || 'Verification failed');
|
||||
}
|
||||
|
||||
setVerified(true);
|
||||
} catch (err) {
|
||||
setError(err instanceof Error ? err.message : 'An error occurred');
|
||||
} finally {
|
||||
setIsLoading(false);
|
||||
}
|
||||
};
|
||||
|
||||
if (verified) {
|
||||
return (
|
||||
<div className="space-y-6">
|
||||
<Card className="max-w-md mx-auto">
|
||||
<CardHeader className="text-center">
|
||||
<div className="flex justify-center mb-2">
|
||||
<Check className="h-12 w-12 text-green-500" />
|
||||
</div>
|
||||
<CardTitle>Email verified</CardTitle>
|
||||
<CardDescription>Your email address has been verified successfully.</CardDescription>
|
||||
</CardHeader>
|
||||
</Card>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
return (
|
||||
<div className="space-y-6">
|
||||
<Card className="max-w-md mx-auto">
|
||||
<CardHeader className="text-center">
|
||||
<CardTitle>Verify your email</CardTitle>
|
||||
<CardDescription>
|
||||
{codeSent
|
||||
? `Enter the 6-digit code sent to ${user?.email}`
|
||||
: 'Verify your email address to access all features'}
|
||||
</CardDescription>
|
||||
</CardHeader>
|
||||
|
||||
<CardContent className="space-y-4">
|
||||
{error && (
|
||||
<Alert variant="destructive">
|
||||
<AlertDescription>{error}</AlertDescription>
|
||||
</Alert>
|
||||
)}
|
||||
|
||||
{!codeSent ? (
|
||||
<div className="text-center">
|
||||
<p className="text-sm text-[var(--text-muted)] mb-4">
|
||||
We'll send a verification code to <strong>{user?.email}</strong>
|
||||
</p>
|
||||
<Button onClick={handleSendCode} disabled={isLoading}>
|
||||
{isLoading && <Loader2 className="mr-2 h-4 w-4 animate-spin" />}
|
||||
Send verification code
|
||||
</Button>
|
||||
<p className="text-xs text-[var(--text-muted)] mt-2">
|
||||
In dev mode, check the server console for the code.
|
||||
</p>
|
||||
</div>
|
||||
) : (
|
||||
<form onSubmit={handleVerify} className="space-y-4">
|
||||
<FormField
|
||||
label="Verification Code"
|
||||
name="code"
|
||||
type="text"
|
||||
placeholder="000000"
|
||||
required
|
||||
autoComplete="one-time-code"
|
||||
autoFocus
|
||||
/>
|
||||
<Button type="submit" className="w-full" disabled={isLoading}>
|
||||
{isLoading && <Loader2 className="mr-2 h-4 w-4 animate-spin" />}
|
||||
Verify email
|
||||
</Button>
|
||||
<button
|
||||
type="button"
|
||||
className="w-full text-sm text-[var(--text-muted)] hover:text-[var(--text-secondary)]"
|
||||
onClick={handleSendCode}
|
||||
>
|
||||
Resend code
|
||||
</button>
|
||||
</form>
|
||||
)}
|
||||
</CardContent>
|
||||
</Card>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
@ -15,7 +15,15 @@ LOG_FORMAT=text
|
||||
|
||||
# Auth (set AUTH_ENABLED=true to require JWT for protected routes)
|
||||
AUTH_ENABLED=false
|
||||
JWT_SECRET=dev-secret-change-in-production
|
||||
JWT_SECRET=dev-secret-change-in-production # Required — server refuses to start with empty secret
|
||||
REGISTRATION_ENABLED=true
|
||||
|
||||
# Email delivery (notify service)
|
||||
# When NOTIFY_URL is empty, auth codes are logged to stdout (dev mode).
|
||||
# NOTIFY_URL=https://notify.threesix.ai
|
||||
# NOTIFY_API_KEY=notify_send_xxx
|
||||
# NOTIFY_HOST=myapp.threesix.ai
|
||||
# NOTIFY_FROM=noreply@myapp.threesix.ai
|
||||
|
||||
# Database (if needed)
|
||||
# Local dev: PostgreSQL via docker-compose. Production: CockroachDB (platform-provisioned).
|
||||
|
||||
@ -3,6 +3,7 @@ package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"embed"
|
||||
"flag"
|
||||
"fmt"
|
||||
"os"
|
||||
@ -18,17 +19,24 @@ import (
|
||||
"{{GO_MODULE}}/pkg/mediagen"
|
||||
mediagenAdapters "{{GO_MODULE}}/pkg/mediagen/adapters"
|
||||
"{{GO_MODULE}}/pkg/generation"
|
||||
"{{GO_MODULE}}/pkg/notify"
|
||||
"{{GO_MODULE}}/pkg/queue"
|
||||
"{{GO_MODULE}}/pkg/realtime"
|
||||
"{{GO_MODULE}}/pkg/storage"
|
||||
"{{GO_MODULE}}/pkg/textgen"
|
||||
textgenAdapters "{{GO_MODULE}}/pkg/textgen/adapters"
|
||||
emailadapter "{{GO_MODULE}}/services/{{COMPONENT_NAME}}/internal/adapter/email"
|
||||
"{{GO_MODULE}}/services/{{COMPONENT_NAME}}/internal/adapter/memory"
|
||||
"{{GO_MODULE}}/services/{{COMPONENT_NAME}}/internal/adapter/postgres"
|
||||
"{{GO_MODULE}}/services/{{COMPONENT_NAME}}/internal/api"
|
||||
"{{GO_MODULE}}/services/{{COMPONENT_NAME}}/internal/config"
|
||||
"{{GO_MODULE}}/services/{{COMPONENT_NAME}}/internal/port"
|
||||
"{{GO_MODULE}}/services/{{COMPONENT_NAME}}/internal/service"
|
||||
)
|
||||
|
||||
//go:embed migrations/*.sql
|
||||
var migrationsFS embed.FS
|
||||
|
||||
func main() {
|
||||
// Parse flags
|
||||
exportOpenAPI := flag.Bool("export-openapi", false, "Export OpenAPI spec to stdout and exit")
|
||||
@ -52,17 +60,18 @@ func main() {
|
||||
// Create logger
|
||||
logger := logging.Default()
|
||||
|
||||
ctx := context.Background()
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
// Create SSE hub for async event delivery (generation progress, chat, etc.)
|
||||
sseHub := realtime.NewSSEHub(logger.Logger)
|
||||
|
||||
// Initialize storage backend (before queue, since standalone queue handlers use it).
|
||||
// GCS_BUCKET set = production (GCS). Otherwise = dev (in-memory).
|
||||
port := fmt.Sprintf("%d", {{PORT}})
|
||||
listenPort := fmt.Sprintf("%d", {{PORT}})
|
||||
var mediaStore storage.Store
|
||||
if bucket := os.Getenv("GCS_BUCKET"); bucket != "" {
|
||||
gcsStore, err := storage.NewGCSStore(bucket, os.Getenv("GCS_SERVICE_ACCOUNT_JSON"), logger.Logger)
|
||||
gcsStore, err := storage.NewGCSStore(ctx, bucket, os.Getenv("GCS_SERVICE_ACCOUNT_JSON"), logger.Logger)
|
||||
if err != nil {
|
||||
logger.Error("failed to create GCS store", "error", err)
|
||||
os.Exit(1)
|
||||
@ -71,29 +80,97 @@ func main() {
|
||||
mediaStore = gcsStore
|
||||
logger.Info("storage initialized (GCS)", "bucket", bucket)
|
||||
} else {
|
||||
memStore := storage.NewMemoryStore("http://localhost:" + port + "/storage")
|
||||
memStore := storage.NewMemoryStore("http://localhost:" + listenPort + "/storage")
|
||||
mediaStore = memStore
|
||||
logger.Info("storage initialized (in-memory dev mode)")
|
||||
}
|
||||
|
||||
// Select queue backend based on DATABASE_URL availability.
|
||||
// With DATABASE_URL: DB queue + separate worker process (production)
|
||||
// Without DATABASE_URL: in-memory queue + in-process handlers (development)
|
||||
// Select backend based on DATABASE_URL availability.
|
||||
// With DATABASE_URL: Postgres repos + DB queue (production)
|
||||
// Without DATABASE_URL: in-memory repos + in-process AI (development)
|
||||
exampleRepo := memory.NewExampleRepository()
|
||||
var userRepo port.UserRepository
|
||||
var sessionRepo port.SessionRepository
|
||||
var authCodeRepo port.AuthCodeRepository
|
||||
var mediaRepo port.MediaRepository
|
||||
var jobQueue queue.Producer
|
||||
var jobReader queue.JobReader
|
||||
|
||||
if cfg.Database.URL != "" {
|
||||
jobQueue = setupDBQueue(ctx, cfg, sseHub, logger)
|
||||
// Connect to database (shared pool for queue + auth repos).
|
||||
dbPool, err := database.Connect(ctx, cfg.Database.URL, database.Options{
|
||||
MaxOpenConns: cfg.Database.MaxOpenConns,
|
||||
MaxIdleConns: cfg.Database.MaxIdleConns,
|
||||
ConnMaxLifetime: cfg.Database.ConnMaxLifetime,
|
||||
})
|
||||
if err != nil {
|
||||
logger.Error("failed to connect to database", "error", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
logger.Info("connected to database")
|
||||
|
||||
// Verify the database connection is actually alive before proceeding.
|
||||
if err := dbPool.DB.PingContext(ctx); err != nil {
|
||||
logger.Error("database health check failed", "error", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
logger.Info("database health check passed")
|
||||
|
||||
// Run auth migrations.
|
||||
if err := database.RunMigrations(ctx, dbPool, migrationsFS, "migrations"); err != nil {
|
||||
logger.Error("failed to run auth migrations", "error", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
logger.Info("auth migrations complete")
|
||||
|
||||
// Postgres-backed repositories.
|
||||
userRepo = postgres.NewUserRepository(dbPool.DB)
|
||||
sessionRepo = postgres.NewSessionRepository(dbPool.DB)
|
||||
authCodeRepo = postgres.NewAuthCodeRepository(dbPool.DB)
|
||||
mediaRepo = postgres.NewMediaObjectRepository(dbPool.DB)
|
||||
|
||||
// DB-backed queue.
|
||||
jobQueue, jobReader = setupDBQueue(ctx, cfg, dbPool, sseHub, logger)
|
||||
} else {
|
||||
logger.Info("DATABASE_URL not set — running in standalone mode (in-memory queue + in-process AI)")
|
||||
jobQueue = setupStandaloneQueue(ctx, mediaStore, sseHub, logger)
|
||||
userRepo = memory.NewUserRepository()
|
||||
sessionRepo = memory.NewSessionRepository()
|
||||
authCodeRepo = memory.NewAuthCodeRepository()
|
||||
mediaRepo = memory.NewMediaRepository()
|
||||
jobQueue, jobReader = setupStandaloneQueue(ctx, mediaStore, sseHub, logger)
|
||||
}
|
||||
|
||||
// Create adapters (repositories)
|
||||
exampleRepo := memory.NewExampleRepository()
|
||||
userRepo := memory.NewUserRepository()
|
||||
// Validate required config.
|
||||
if cfg.JWTSecret == "" {
|
||||
logger.Error("JWT_SECRET must be set (even in development)")
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
// Create email sender — notify service in production (NOTIFY_URL set), log-only for dev.
|
||||
var emailSender port.EmailSender
|
||||
if cfg.NotifyURL != "" {
|
||||
notifyClient, err := notify.NewClient(notify.Config{
|
||||
URL: cfg.NotifyURL,
|
||||
APIKey: cfg.NotifyAPIKey,
|
||||
Logger: logger.Logger,
|
||||
})
|
||||
if err != nil {
|
||||
logger.Error("failed to create notify client", "error", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
emailSender = emailadapter.NewNotifySender(notifyClient, cfg.NotifyHost, cfg.NotifyFrom, logger)
|
||||
logger.Info("email sender initialized (notify)", "url", cfg.NotifyURL, "host", cfg.NotifyHost)
|
||||
} else {
|
||||
emailSender = emailadapter.NewLogSender(logger)
|
||||
logger.Info("email sender initialized (log-only dev mode)")
|
||||
}
|
||||
|
||||
// Create services (business logic)
|
||||
exampleService := service.NewExampleService(exampleRepo, logger)
|
||||
authService := service.NewAuthService(userRepo, cfg.JWTSecret, logger)
|
||||
authService := service.NewAuthService(
|
||||
userRepo, sessionRepo, authCodeRepo, emailSender,
|
||||
cfg.JWTSecret, cfg.RegistrationEnabled, logger,
|
||||
)
|
||||
|
||||
// Create application
|
||||
application := app.New("{{COMPONENT_NAME}}", app.WithDefaultPort({{PORT}}))
|
||||
@ -108,29 +185,22 @@ func main() {
|
||||
ExampleService: exampleService,
|
||||
AuthService: authService,
|
||||
Queue: jobQueue,
|
||||
JobReader: jobReader,
|
||||
SSEHub: sseHub,
|
||||
Store: mediaStore,
|
||||
MediaRepo: mediaRepo,
|
||||
})
|
||||
|
||||
// Start background cleanup of expired sessions and auth codes.
|
||||
go runCleanup(ctx, sessionRepo, authCodeRepo, logger)
|
||||
|
||||
// Start server
|
||||
application.Run()
|
||||
}
|
||||
|
||||
// setupDBQueue initializes the production queue backend with database + optional Redis.
|
||||
func setupDBQueue(ctx context.Context, cfg *config.Config, sseHub *realtime.SSEHub, logger *logging.Logger) queue.Producer {
|
||||
pool, err := database.Connect(ctx, cfg.Database.URL, database.Options{
|
||||
MaxOpenConns: cfg.Database.MaxOpenConns,
|
||||
MaxIdleConns: cfg.Database.MaxIdleConns,
|
||||
ConnMaxLifetime: cfg.Database.ConnMaxLifetime,
|
||||
})
|
||||
if err != nil {
|
||||
logger.Error("failed to connect to database", "error", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
// Note: pool is not deferred here since it's needed for the lifetime of the process.
|
||||
// The OS reclaims resources on exit.
|
||||
logger.Info("connected to database")
|
||||
|
||||
// setupDBQueue initializes the production queue backend using the shared database pool + optional Redis.
|
||||
// Returns both Producer (for enqueue) and JobReader (for status polling).
|
||||
func setupDBQueue(ctx context.Context, cfg *config.Config, pool *database.Pool, sseHub *realtime.SSEHub, logger *logging.Logger) (queue.Producer, queue.JobReader) {
|
||||
if err := queue.RunMigrations(ctx, pool); err != nil {
|
||||
logger.Error("failed to run queue migrations", "error", err)
|
||||
os.Exit(1)
|
||||
@ -162,12 +232,13 @@ func setupDBQueue(ctx context.Context, cfg *config.Config, sseHub *realtime.SSEH
|
||||
logger.Warn("REDIS_URL not set — SSE events from worker will not be delivered")
|
||||
}
|
||||
|
||||
return jobQueue
|
||||
return jobQueue, jobQueue
|
||||
}
|
||||
|
||||
// setupStandaloneQueue initializes an in-memory queue with in-process AI handlers.
|
||||
// This mode requires no database or Redis — everything runs in a single process.
|
||||
func setupStandaloneQueue(ctx context.Context, store storage.Store, sseHub *realtime.SSEHub, logger *logging.Logger) queue.Producer {
|
||||
// Returns both Producer (for enqueue) and JobReader (for status polling).
|
||||
func setupStandaloneQueue(ctx context.Context, store storage.Store, sseHub *realtime.SSEHub, logger *logging.Logger) (queue.Producer, queue.JobReader) {
|
||||
memQueue := queue.NewMemoryQueue(logger.Logger)
|
||||
|
||||
// LocalPublisher delivers events directly to the SSE hub (no Redis needed).
|
||||
@ -187,7 +258,7 @@ func setupStandaloneQueue(ctx context.Context, store storage.Store, sseHub *real
|
||||
memQueue.RegisterHandler("ai_chat_response", generation.ChatResponseHandler(textgenManager, pub, logger))
|
||||
}
|
||||
|
||||
return memQueue
|
||||
return memQueue, memQueue
|
||||
}
|
||||
|
||||
// initMediagen creates a mediagen manager from available AI provider credentials.
|
||||
@ -290,3 +361,31 @@ func initTextgen(ctx context.Context, logger *logging.Logger) *textgen.Manager {
|
||||
logger.Info("textgen manager initialized")
|
||||
return mgr
|
||||
}
|
||||
|
||||
// runCleanup periodically removes expired sessions and auth codes.
|
||||
// Runs every hour. Stops when ctx is cancelled.
|
||||
func runCleanup(ctx context.Context, sessions port.SessionRepository, codes port.AuthCodeRepository, logger *logging.Logger) {
|
||||
ticker := time.NewTicker(1 * time.Hour)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
sessCount, err := sessions.DeleteExpired(ctx)
|
||||
if err != nil {
|
||||
logger.Warn("failed to cleanup expired sessions", "error", err)
|
||||
} else if sessCount > 0 {
|
||||
logger.Info("cleaned up expired sessions", "count", sessCount)
|
||||
}
|
||||
|
||||
codeCount, err := codes.DeleteExpired(ctx)
|
||||
if err != nil {
|
||||
logger.Warn("failed to cleanup expired auth codes", "error", err)
|
||||
} else if codeCount > 0 {
|
||||
logger.Info("cleaned up expired auth codes", "count", codeCount)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -0,0 +1,79 @@
|
||||
-- 001_create_users.sql
|
||||
-- Auth tables for user management, sessions, and authentication codes.
|
||||
-- Compatible with both PostgreSQL (local dev) and CockroachDB (production).
|
||||
|
||||
-- Core user identity. Email is the primary identifier for humans.
|
||||
CREATE TABLE IF NOT EXISTS users (
|
||||
id TEXT PRIMARY KEY,
|
||||
email TEXT NOT NULL UNIQUE,
|
||||
email_verified BOOL NOT NULL DEFAULT FALSE,
|
||||
name TEXT NOT NULL DEFAULT '',
|
||||
avatar_url TEXT NOT NULL DEFAULT '',
|
||||
status TEXT NOT NULL DEFAULT 'active',
|
||||
last_login_at TIMESTAMPTZ,
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
||||
updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
|
||||
);
|
||||
|
||||
-- Password credentials. Separate table because OAuth-only users have no password.
|
||||
CREATE TABLE IF NOT EXISTS user_passwords (
|
||||
user_id TEXT PRIMARY KEY REFERENCES users(id) ON DELETE CASCADE,
|
||||
password_hash TEXT NOT NULL,
|
||||
updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
|
||||
);
|
||||
|
||||
-- OAuth provider connections (Google, GitHub, Apple, etc.).
|
||||
CREATE TABLE IF NOT EXISTS oauth_connections (
|
||||
id TEXT PRIMARY KEY,
|
||||
user_id TEXT NOT NULL REFERENCES users(id) ON DELETE CASCADE,
|
||||
provider TEXT NOT NULL,
|
||||
provider_user_id TEXT NOT NULL,
|
||||
provider_email TEXT NOT NULL DEFAULT '',
|
||||
access_token TEXT NOT NULL DEFAULT '',
|
||||
refresh_token TEXT NOT NULL DEFAULT '',
|
||||
token_expires_at TIMESTAMPTZ,
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
||||
UNIQUE (provider, provider_user_id)
|
||||
);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS idx_oauth_connections_user_id ON oauth_connections (user_id);
|
||||
|
||||
-- Verification codes for OTP login, magic links, password reset, and email verification.
|
||||
CREATE TABLE IF NOT EXISTS auth_codes (
|
||||
id TEXT PRIMARY KEY,
|
||||
user_id TEXT REFERENCES users(id) ON DELETE CASCADE,
|
||||
email TEXT NOT NULL,
|
||||
code TEXT NOT NULL,
|
||||
purpose TEXT NOT NULL,
|
||||
expires_at TIMESTAMPTZ NOT NULL,
|
||||
used_at TIMESTAMPTZ,
|
||||
ip_address TEXT NOT NULL DEFAULT '',
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
|
||||
);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS idx_auth_codes_email_purpose ON auth_codes (email, purpose, expires_at)
|
||||
WHERE used_at IS NULL;
|
||||
|
||||
-- Sessions track where and when users are logged in.
|
||||
CREATE TABLE IF NOT EXISTS sessions (
|
||||
id TEXT PRIMARY KEY,
|
||||
user_id TEXT NOT NULL REFERENCES users(id) ON DELETE CASCADE,
|
||||
ip_address TEXT NOT NULL DEFAULT '',
|
||||
user_agent TEXT NOT NULL DEFAULT '',
|
||||
device_label TEXT NOT NULL DEFAULT '',
|
||||
last_active_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
||||
expires_at TIMESTAMPTZ NOT NULL,
|
||||
revoked_at TIMESTAMPTZ,
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
|
||||
);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS idx_sessions_user_id ON sessions (user_id)
|
||||
WHERE revoked_at IS NULL;
|
||||
|
||||
-- User roles. Separate table so users can have multiple roles.
|
||||
CREATE TABLE IF NOT EXISTS user_roles (
|
||||
user_id TEXT NOT NULL REFERENCES users(id) ON DELETE CASCADE,
|
||||
role TEXT NOT NULL,
|
||||
granted_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
||||
PRIMARY KEY (user_id, role)
|
||||
);
|
||||
@ -0,0 +1,9 @@
|
||||
-- 002_add_indexes.sql
|
||||
-- Additional indexes for query performance.
|
||||
-- Compatible with both PostgreSQL (local dev) and CockroachDB (production).
|
||||
|
||||
-- Speed up OTP/magic-link/reset token lookup in FindValid queries.
|
||||
-- The existing partial index on (email, purpose, expires_at) doesn't cover the code column,
|
||||
-- so queries filtering on code must scan all matching rows.
|
||||
CREATE INDEX IF NOT EXISTS idx_auth_codes_code ON auth_codes (code)
|
||||
WHERE used_at IS NULL;
|
||||
@ -0,0 +1,22 @@
|
||||
-- 003_create_media_objects.sql
|
||||
-- Media metadata table for tracking uploads, generation provenance, and soft deletes.
|
||||
-- Compatible with both PostgreSQL (local dev) and CockroachDB (production).
|
||||
|
||||
CREATE TABLE IF NOT EXISTS media_objects (
|
||||
id TEXT PRIMARY KEY,
|
||||
user_id TEXT NOT NULL REFERENCES users(id) ON DELETE CASCADE,
|
||||
path TEXT NOT NULL UNIQUE,
|
||||
filename TEXT NOT NULL DEFAULT '',
|
||||
content_type TEXT NOT NULL DEFAULT '',
|
||||
size BIGINT NOT NULL DEFAULT 0,
|
||||
generation_job_id TEXT NOT NULL DEFAULT '',
|
||||
deleted_at TIMESTAMPTZ,
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
||||
updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
|
||||
);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS idx_media_objects_user_id ON media_objects (user_id, created_at DESC)
|
||||
WHERE deleted_at IS NULL;
|
||||
|
||||
CREATE INDEX IF NOT EXISTS idx_media_objects_generation_job ON media_objects (generation_job_id)
|
||||
WHERE generation_job_id != '';
|
||||
@ -0,0 +1,33 @@
|
||||
package email
|
||||
|
||||
import "fmt"
|
||||
|
||||
func subjectForPurpose(purpose string) string {
|
||||
switch purpose {
|
||||
case "login_otp":
|
||||
return "Your login code"
|
||||
case "magic_link":
|
||||
return "Your sign-in link"
|
||||
case "password_reset":
|
||||
return "Reset your password"
|
||||
case "email_verify":
|
||||
return "Verify your email"
|
||||
default:
|
||||
return "Your authentication code"
|
||||
}
|
||||
}
|
||||
|
||||
func bodyForPurpose(purpose, code string) string {
|
||||
switch purpose {
|
||||
case "login_otp":
|
||||
return fmt.Sprintf("Your login code is: %s\n\nThis code expires in 10 minutes.", code)
|
||||
case "magic_link":
|
||||
return fmt.Sprintf("Click this link to sign in:\n\n%s\n\nThis link expires in 15 minutes.", code)
|
||||
case "password_reset":
|
||||
return fmt.Sprintf("Use this code to reset your password:\n\n%s\n\nThis code expires in 1 hour.", code)
|
||||
case "email_verify":
|
||||
return fmt.Sprintf("Your verification code is: %s\n\nThis code expires in 24 hours.", code)
|
||||
default:
|
||||
return fmt.Sprintf("Your code is: %s", code)
|
||||
}
|
||||
}
|
||||
@ -0,0 +1,32 @@
|
||||
// Package email provides email sending adapters for authentication flows.
|
||||
package email
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"{{GO_MODULE}}/pkg/logging"
|
||||
"{{GO_MODULE}}/services/{{COMPONENT_NAME}}/internal/port"
|
||||
)
|
||||
|
||||
// Compile-time interface check.
|
||||
var _ port.EmailSender = (*LogSender)(nil)
|
||||
|
||||
// LogSender logs emails to the console instead of sending them.
|
||||
// Useful for development and testing when no notify service is configured.
|
||||
type LogSender struct {
|
||||
logger *logging.Logger
|
||||
}
|
||||
|
||||
// NewLogSender creates a new log-based email sender.
|
||||
func NewLogSender(logger *logging.Logger) *LogSender {
|
||||
return &LogSender{logger: logger.WithComponent("EmailSender")}
|
||||
}
|
||||
|
||||
func (s *LogSender) SendAuthCode(_ context.Context, email, code, purpose string) error {
|
||||
s.logger.Warn("DEV MODE — email not sent, code logged",
|
||||
"to", email,
|
||||
"purpose", purpose,
|
||||
"code", code,
|
||||
)
|
||||
return nil
|
||||
}
|
||||
@ -0,0 +1,57 @@
|
||||
package email
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"{{GO_MODULE}}/pkg/logging"
|
||||
"{{GO_MODULE}}/pkg/notify"
|
||||
"{{GO_MODULE}}/services/{{COMPONENT_NAME}}/internal/port"
|
||||
)
|
||||
|
||||
// Compile-time interface check.
|
||||
var _ port.EmailSender = (*NotifySender)(nil)
|
||||
|
||||
// NotifySender sends emails via the orchard9 notify service.
|
||||
type NotifySender struct {
|
||||
client *notify.Client
|
||||
host string
|
||||
from string
|
||||
logger *logging.Logger
|
||||
}
|
||||
|
||||
// NewNotifySender creates a new notify-backed email sender.
|
||||
func NewNotifySender(client *notify.Client, host, from string, logger *logging.Logger) *NotifySender {
|
||||
return &NotifySender{
|
||||
client: client,
|
||||
host: host,
|
||||
from: from,
|
||||
logger: logger.WithComponent("EmailSender"),
|
||||
}
|
||||
}
|
||||
|
||||
func (s *NotifySender) SendAuthCode(ctx context.Context, toEmail, code, purpose string) error {
|
||||
resp, err := s.client.SendEmail(ctx, ¬ify.SendRequest{
|
||||
To: toEmail,
|
||||
From: s.from,
|
||||
Content: notify.Content{
|
||||
Subject: subjectForPurpose(purpose),
|
||||
Text: bodyForPurpose(purpose, code),
|
||||
},
|
||||
Meta: notify.Meta{
|
||||
Host: s.host,
|
||||
Category: "critical",
|
||||
Tags: []string{"auth", purpose},
|
||||
},
|
||||
Options: notify.Options{
|
||||
IdempotencyKey: fmt.Sprintf("auth:%s:%s:%s", toEmail, purpose, code),
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
s.logger.Error("failed to send email via notify", "to", toEmail, "purpose", purpose, "error", err)
|
||||
return fmt.Errorf("send email: %w", err)
|
||||
}
|
||||
|
||||
s.logger.Info("email queued via notify", "to", toEmail, "purpose", purpose, "message_id", resp.MessageID)
|
||||
return nil
|
||||
}
|
||||
@ -0,0 +1,76 @@
|
||||
package memory
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"{{GO_MODULE}}/services/{{COMPONENT_NAME}}/internal/domain"
|
||||
"{{GO_MODULE}}/services/{{COMPONENT_NAME}}/internal/port"
|
||||
)
|
||||
|
||||
// Compile-time interface check.
|
||||
var _ port.AuthCodeRepository = (*AuthCodeRepository)(nil)
|
||||
|
||||
// AuthCodeRepository is an in-memory auth code store for standalone development.
|
||||
type AuthCodeRepository struct {
|
||||
mu sync.RWMutex
|
||||
codes map[string]*domain.AuthCode
|
||||
}
|
||||
|
||||
// NewAuthCodeRepository creates a new in-memory auth code repository.
|
||||
func NewAuthCodeRepository() *AuthCodeRepository {
|
||||
return &AuthCodeRepository{
|
||||
codes: make(map[string]*domain.AuthCode),
|
||||
}
|
||||
}
|
||||
|
||||
func (r *AuthCodeRepository) Create(_ context.Context, code *domain.AuthCode) error {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
|
||||
cp := *code
|
||||
r.codes[code.ID] = &cp
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *AuthCodeRepository) FindValid(_ context.Context, email string, code string, purpose domain.AuthCodePurpose) (*domain.AuthCode, error) {
|
||||
r.mu.RLock()
|
||||
defer r.mu.RUnlock()
|
||||
|
||||
for _, c := range r.codes {
|
||||
if c.Email == email && c.Code == code && c.Purpose == purpose && c.IsValid() {
|
||||
cp := *c
|
||||
return &cp, nil
|
||||
}
|
||||
}
|
||||
return nil, domain.ErrInvalidAuthCode
|
||||
}
|
||||
|
||||
func (r *AuthCodeRepository) MarkUsed(_ context.Context, id string) error {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
|
||||
c, ok := r.codes[id]
|
||||
if !ok {
|
||||
return domain.ErrInvalidAuthCode
|
||||
}
|
||||
now := time.Now()
|
||||
c.UsedAt = &now
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *AuthCodeRepository) DeleteExpired(_ context.Context) (int, error) {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
|
||||
now := time.Now()
|
||||
deleted := 0
|
||||
for id, c := range r.codes {
|
||||
if now.After(c.ExpiresAt) {
|
||||
delete(r.codes, id)
|
||||
deleted++
|
||||
}
|
||||
}
|
||||
return deleted, nil
|
||||
}
|
||||
@ -0,0 +1,135 @@
|
||||
package memory
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sort"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"{{GO_MODULE}}/services/{{COMPONENT_NAME}}/internal/domain"
|
||||
"{{GO_MODULE}}/services/{{COMPONENT_NAME}}/internal/port"
|
||||
)
|
||||
|
||||
// Compile-time interface check.
|
||||
var _ port.MediaRepository = (*MediaRepository)(nil)
|
||||
|
||||
// MediaRepository is an in-memory media metadata store for standalone development.
|
||||
type MediaRepository struct {
|
||||
mu sync.RWMutex
|
||||
objects map[domain.MediaObjectID]*domain.MediaObject
|
||||
byPath map[string]domain.MediaObjectID
|
||||
}
|
||||
|
||||
// NewMediaRepository creates a new in-memory media repository.
|
||||
func NewMediaRepository() *MediaRepository {
|
||||
return &MediaRepository{
|
||||
objects: make(map[domain.MediaObjectID]*domain.MediaObject),
|
||||
byPath: make(map[string]domain.MediaObjectID),
|
||||
}
|
||||
}
|
||||
|
||||
func (r *MediaRepository) copyObject(obj *domain.MediaObject) *domain.MediaObject {
|
||||
cp := *obj
|
||||
return &cp
|
||||
}
|
||||
|
||||
func (r *MediaRepository) Create(_ context.Context, obj *domain.MediaObject) error {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
|
||||
r.objects[obj.ID] = r.copyObject(obj)
|
||||
r.byPath[obj.Path] = obj.ID
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *MediaRepository) Get(_ context.Context, id domain.MediaObjectID) (*domain.MediaObject, error) {
|
||||
r.mu.RLock()
|
||||
defer r.mu.RUnlock()
|
||||
|
||||
obj, ok := r.objects[id]
|
||||
if !ok || obj.DeletedAt != nil {
|
||||
return nil, domain.ErrNotFound
|
||||
}
|
||||
return r.copyObject(obj), nil
|
||||
}
|
||||
|
||||
func (r *MediaRepository) ListByUser(_ context.Context, userID domain.UserID, opts port.ListMediaOptions) ([]domain.MediaObject, int, error) {
|
||||
r.mu.RLock()
|
||||
defer r.mu.RUnlock()
|
||||
|
||||
var all []domain.MediaObject
|
||||
for _, obj := range r.objects {
|
||||
if obj.UserID != userID || obj.DeletedAt != nil {
|
||||
continue
|
||||
}
|
||||
if opts.ContentTypePrefix != "" && !strings.HasPrefix(obj.ContentType, opts.ContentTypePrefix) {
|
||||
continue
|
||||
}
|
||||
all = append(all, *r.copyObject(obj))
|
||||
}
|
||||
|
||||
// Sort by created_at DESC
|
||||
sort.Slice(all, func(i, j int) bool {
|
||||
return all[i].CreatedAt.After(all[j].CreatedAt)
|
||||
})
|
||||
|
||||
total := len(all)
|
||||
|
||||
// Apply pagination
|
||||
limit := opts.Limit
|
||||
if limit <= 0 {
|
||||
limit = 50
|
||||
}
|
||||
offset := opts.Offset
|
||||
if offset > len(all) {
|
||||
offset = len(all)
|
||||
}
|
||||
end := offset + limit
|
||||
if end > len(all) {
|
||||
end = len(all)
|
||||
}
|
||||
|
||||
return all[offset:end], total, nil
|
||||
}
|
||||
|
||||
func (r *MediaRepository) SoftDelete(_ context.Context, id domain.MediaObjectID) error {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
|
||||
obj, ok := r.objects[id]
|
||||
if !ok {
|
||||
return domain.ErrNotFound
|
||||
}
|
||||
now := time.Now()
|
||||
obj.DeletedAt = &now
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *MediaRepository) HardDelete(_ context.Context, id domain.MediaObjectID) error {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
|
||||
obj, ok := r.objects[id]
|
||||
if !ok {
|
||||
return domain.ErrNotFound
|
||||
}
|
||||
delete(r.byPath, obj.Path)
|
||||
delete(r.objects, id)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *MediaRepository) GetByPath(_ context.Context, path string) (*domain.MediaObject, error) {
|
||||
r.mu.RLock()
|
||||
defer r.mu.RUnlock()
|
||||
|
||||
id, ok := r.byPath[path]
|
||||
if !ok {
|
||||
return nil, domain.ErrNotFound
|
||||
}
|
||||
obj, ok := r.objects[id]
|
||||
if !ok || obj.DeletedAt != nil {
|
||||
return nil, domain.ErrNotFound
|
||||
}
|
||||
return r.copyObject(obj), nil
|
||||
}
|
||||
@ -0,0 +1,120 @@
|
||||
package memory
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"{{GO_MODULE}}/services/{{COMPONENT_NAME}}/internal/domain"
|
||||
"{{GO_MODULE}}/services/{{COMPONENT_NAME}}/internal/port"
|
||||
)
|
||||
|
||||
// Compile-time interface check.
|
||||
var _ port.SessionRepository = (*SessionRepository)(nil)
|
||||
|
||||
// SessionRepository is an in-memory session store for standalone development.
|
||||
type SessionRepository struct {
|
||||
mu sync.RWMutex
|
||||
sessions map[domain.SessionID]*domain.Session
|
||||
}
|
||||
|
||||
// NewSessionRepository creates a new in-memory session repository.
|
||||
func NewSessionRepository() *SessionRepository {
|
||||
return &SessionRepository{
|
||||
sessions: make(map[domain.SessionID]*domain.Session),
|
||||
}
|
||||
}
|
||||
|
||||
func (r *SessionRepository) copySession(s *domain.Session) *domain.Session {
|
||||
cp := *s
|
||||
return &cp
|
||||
}
|
||||
|
||||
func (r *SessionRepository) Create(_ context.Context, session *domain.Session) error {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
|
||||
r.sessions[session.ID] = r.copySession(session)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *SessionRepository) Get(_ context.Context, id domain.SessionID) (*domain.Session, error) {
|
||||
r.mu.RLock()
|
||||
defer r.mu.RUnlock()
|
||||
|
||||
s, ok := r.sessions[id]
|
||||
if !ok {
|
||||
return nil, domain.ErrSessionNotFound
|
||||
}
|
||||
return r.copySession(s), nil
|
||||
}
|
||||
|
||||
func (r *SessionRepository) ListByUser(_ context.Context, userID domain.UserID) ([]domain.Session, error) {
|
||||
r.mu.RLock()
|
||||
defer r.mu.RUnlock()
|
||||
|
||||
now := time.Now()
|
||||
var result []domain.Session
|
||||
for _, s := range r.sessions {
|
||||
if s.UserID == userID && s.RevokedAt == nil && s.ExpiresAt.After(now) {
|
||||
result = append(result, *r.copySession(s))
|
||||
}
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (r *SessionRepository) UpdateLastActive(_ context.Context, id domain.SessionID) error {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
|
||||
s, ok := r.sessions[id]
|
||||
if !ok {
|
||||
return domain.ErrSessionNotFound
|
||||
}
|
||||
s.LastActiveAt = time.Now()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *SessionRepository) Revoke(_ context.Context, id domain.SessionID) error {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
|
||||
s, ok := r.sessions[id]
|
||||
if !ok {
|
||||
return domain.ErrSessionNotFound
|
||||
}
|
||||
now := time.Now()
|
||||
s.RevokedAt = &now
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *SessionRepository) RevokeAllForUser(_ context.Context, userID domain.UserID, exceptID *domain.SessionID) error {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
|
||||
now := time.Now()
|
||||
for _, s := range r.sessions {
|
||||
if s.UserID == userID && s.RevokedAt == nil {
|
||||
if exceptID != nil && s.ID == *exceptID {
|
||||
continue
|
||||
}
|
||||
s.RevokedAt = &now
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *SessionRepository) DeleteExpired(_ context.Context) (int, error) {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
|
||||
now := time.Now()
|
||||
deleted := 0
|
||||
for id, s := range r.sessions {
|
||||
if now.After(s.ExpiresAt) {
|
||||
delete(r.sessions, id)
|
||||
deleted++
|
||||
}
|
||||
}
|
||||
return deleted, nil
|
||||
}
|
||||
@ -3,90 +3,233 @@ package memory
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"{{GO_MODULE}}/pkg/auth"
|
||||
"{{GO_MODULE}}/services/{{COMPONENT_NAME}}/internal/domain"
|
||||
"{{GO_MODULE}}/services/{{COMPONENT_NAME}}/internal/port"
|
||||
)
|
||||
|
||||
// userEntry stores a user with their password for demo purposes.
|
||||
type userEntry struct {
|
||||
user *auth.User
|
||||
password string
|
||||
}
|
||||
// Compile-time interface check.
|
||||
var _ port.UserRepository = (*UserRepository)(nil)
|
||||
|
||||
// UserRepository is an in-memory user store for demo/testing purposes.
|
||||
// Pre-populated with demo users.
|
||||
// UserRepository is an in-memory user store with bcrypt password hashing.
|
||||
// Pre-populated with demo users for standalone development.
|
||||
type UserRepository struct {
|
||||
mu sync.RWMutex
|
||||
users map[string]*userEntry // keyed by email
|
||||
users map[domain.UserID]*domain.User
|
||||
passwords map[domain.UserID]string // bcrypt hashes
|
||||
roles map[domain.UserID][]string // role lists
|
||||
byEmail map[string]domain.UserID // email → user ID index
|
||||
}
|
||||
|
||||
// NewUserRepository creates a new in-memory user repository with demo users.
|
||||
// NewUserRepository creates a new in-memory user repository seeded with demo users.
|
||||
func NewUserRepository() *UserRepository {
|
||||
repo := &UserRepository{
|
||||
users: make(map[string]*userEntry),
|
||||
users: make(map[domain.UserID]*domain.User),
|
||||
passwords: make(map[domain.UserID]string),
|
||||
roles: make(map[domain.UserID][]string),
|
||||
byEmail: make(map[string]domain.UserID),
|
||||
}
|
||||
|
||||
// Add demo users
|
||||
repo.users["test@example.com"] = &userEntry{
|
||||
user: &auth.User{
|
||||
ID: "usr_test_001",
|
||||
Email: "test@example.com",
|
||||
Roles: []string{"user"},
|
||||
Metadata: map[string]any{
|
||||
"name": "Test User",
|
||||
},
|
||||
},
|
||||
password: "password123",
|
||||
}
|
||||
|
||||
repo.users["admin@example.com"] = &userEntry{
|
||||
user: &auth.User{
|
||||
ID: "usr_admin_001",
|
||||
Email: "admin@example.com",
|
||||
Roles: []string{"admin", "user"},
|
||||
Metadata: map[string]any{
|
||||
"name": "Admin User",
|
||||
},
|
||||
},
|
||||
password: "admin123",
|
||||
}
|
||||
// Seed demo users with bcrypt-hashed passwords.
|
||||
// Passwords meet complexity requirements (min 8 chars, uppercase, lowercase, digit).
|
||||
repo.seedUser("usr_test_001", "test@example.com", "Test User", "Password123", []string{"user"})
|
||||
repo.seedUser("usr_admin_001", "admin@example.com", "Admin User", "Admin1234", []string{"admin", "user"})
|
||||
|
||||
return repo
|
||||
}
|
||||
|
||||
// FindByEmail returns a user by email address.
|
||||
func (r *UserRepository) FindByEmail(ctx context.Context, email string) (*auth.User, error) {
|
||||
func (r *UserRepository) seedUser(id, email, name, password string, userRoles []string) {
|
||||
uid := domain.UserID(id)
|
||||
now := time.Now()
|
||||
|
||||
hash, err := auth.HashPassword(password)
|
||||
if err != nil {
|
||||
panic("failed to hash seed password: " + err.Error())
|
||||
}
|
||||
|
||||
r.users[uid] = &domain.User{
|
||||
ID: uid,
|
||||
Email: email,
|
||||
EmailVerified: true,
|
||||
Name: name,
|
||||
Status: domain.UserStatusActive,
|
||||
Roles: userRoles,
|
||||
CreatedAt: now,
|
||||
UpdatedAt: now,
|
||||
}
|
||||
r.passwords[uid] = hash
|
||||
r.roles[uid] = userRoles
|
||||
r.byEmail[email] = uid
|
||||
}
|
||||
|
||||
func (r *UserRepository) copyUser(u *domain.User) *domain.User {
|
||||
cp := *u
|
||||
cp.Roles = make([]string, len(u.Roles))
|
||||
copy(cp.Roles, u.Roles)
|
||||
return &cp
|
||||
}
|
||||
|
||||
func (r *UserRepository) Create(_ context.Context, user *domain.User) error {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
|
||||
if _, exists := r.byEmail[user.Email]; exists {
|
||||
return domain.ErrDuplicateEmail
|
||||
}
|
||||
|
||||
r.users[user.ID] = r.copyUser(user)
|
||||
r.byEmail[user.Email] = user.ID
|
||||
r.roles[user.ID] = make([]string, len(user.Roles))
|
||||
copy(r.roles[user.ID], user.Roles)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *UserRepository) Get(_ context.Context, id domain.UserID) (*domain.User, error) {
|
||||
r.mu.RLock()
|
||||
defer r.mu.RUnlock()
|
||||
|
||||
entry, ok := r.users[email]
|
||||
u, ok := r.users[id]
|
||||
if !ok {
|
||||
return nil, nil
|
||||
return nil, domain.ErrUserNotFound
|
||||
}
|
||||
return entry.user, nil
|
||||
return r.copyUser(u), nil
|
||||
}
|
||||
|
||||
// FindByID returns a user by ID.
|
||||
func (r *UserRepository) FindByID(ctx context.Context, id string) (*auth.User, error) {
|
||||
func (r *UserRepository) GetByEmail(_ context.Context, email string) (*domain.User, error) {
|
||||
r.mu.RLock()
|
||||
defer r.mu.RUnlock()
|
||||
|
||||
for _, entry := range r.users {
|
||||
if entry.user.ID == id {
|
||||
return entry.user, nil
|
||||
}
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// ValidatePassword checks if the password matches for a user.
|
||||
func (r *UserRepository) ValidatePassword(ctx context.Context, user *auth.User, password string) bool {
|
||||
r.mu.RLock()
|
||||
defer r.mu.RUnlock()
|
||||
|
||||
entry, ok := r.users[user.Email]
|
||||
uid, ok := r.byEmail[email]
|
||||
if !ok {
|
||||
return false
|
||||
return nil, domain.ErrUserNotFound
|
||||
}
|
||||
return entry.password == password
|
||||
return r.copyUser(r.users[uid]), nil
|
||||
}
|
||||
|
||||
func (r *UserRepository) Update(_ context.Context, user *domain.User) error {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
|
||||
existing, ok := r.users[user.ID]
|
||||
if !ok {
|
||||
return domain.ErrUserNotFound
|
||||
}
|
||||
|
||||
// If email changed, update the index.
|
||||
if existing.Email != user.Email {
|
||||
if _, taken := r.byEmail[user.Email]; taken {
|
||||
return domain.ErrDuplicateEmail
|
||||
}
|
||||
delete(r.byEmail, existing.Email)
|
||||
r.byEmail[user.Email] = user.ID
|
||||
}
|
||||
|
||||
user.UpdatedAt = time.Now()
|
||||
r.users[user.ID] = r.copyUser(user)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *UserRepository) UpdateLastLogin(_ context.Context, id domain.UserID) error {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
|
||||
u, ok := r.users[id]
|
||||
if !ok {
|
||||
return domain.ErrUserNotFound
|
||||
}
|
||||
now := time.Now()
|
||||
u.LastLoginAt = &now
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *UserRepository) ExistsByEmail(_ context.Context, email string) (bool, error) {
|
||||
r.mu.RLock()
|
||||
defer r.mu.RUnlock()
|
||||
|
||||
_, ok := r.byEmail[email]
|
||||
return ok, nil
|
||||
}
|
||||
|
||||
func (r *UserRepository) SetPassword(_ context.Context, userID domain.UserID, hash string) error {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
|
||||
if _, ok := r.users[userID]; !ok {
|
||||
return domain.ErrUserNotFound
|
||||
}
|
||||
r.passwords[userID] = hash
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *UserRepository) GetPasswordHash(_ context.Context, userID domain.UserID) (string, error) {
|
||||
r.mu.RLock()
|
||||
defer r.mu.RUnlock()
|
||||
|
||||
hash := r.passwords[userID]
|
||||
return hash, nil
|
||||
}
|
||||
|
||||
func (r *UserRepository) HasPassword(_ context.Context, userID domain.UserID) (bool, error) {
|
||||
r.mu.RLock()
|
||||
defer r.mu.RUnlock()
|
||||
|
||||
_, ok := r.passwords[userID]
|
||||
return ok, nil
|
||||
}
|
||||
|
||||
func (r *UserRepository) AddRole(_ context.Context, userID domain.UserID, role string) error {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
|
||||
u, ok := r.users[userID]
|
||||
if !ok {
|
||||
return domain.ErrUserNotFound
|
||||
}
|
||||
|
||||
for _, existing := range r.roles[userID] {
|
||||
if existing == role {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
r.roles[userID] = append(r.roles[userID], role)
|
||||
u.Roles = make([]string, len(r.roles[userID]))
|
||||
copy(u.Roles, r.roles[userID])
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *UserRepository) RemoveRole(_ context.Context, userID domain.UserID, role string) error {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
|
||||
u, ok := r.users[userID]
|
||||
if !ok {
|
||||
return domain.ErrUserNotFound
|
||||
}
|
||||
|
||||
filtered := make([]string, 0, len(r.roles[userID]))
|
||||
for _, existing := range r.roles[userID] {
|
||||
if existing != role {
|
||||
filtered = append(filtered, existing)
|
||||
}
|
||||
}
|
||||
r.roles[userID] = filtered
|
||||
u.Roles = make([]string, len(filtered))
|
||||
copy(u.Roles, filtered)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *UserRepository) GetRoles(_ context.Context, userID domain.UserID) ([]string, error) {
|
||||
r.mu.RLock()
|
||||
defer r.mu.RUnlock()
|
||||
|
||||
if _, ok := r.users[userID]; !ok {
|
||||
return nil, domain.ErrUserNotFound
|
||||
}
|
||||
|
||||
roles := r.roles[userID]
|
||||
result := make([]string, len(roles))
|
||||
copy(result, roles)
|
||||
return result, nil
|
||||
}
|
||||
|
||||
@ -0,0 +1,120 @@
|
||||
package postgres
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"errors"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/jmoiron/sqlx"
|
||||
|
||||
"{{GO_MODULE}}/services/{{COMPONENT_NAME}}/internal/domain"
|
||||
"{{GO_MODULE}}/services/{{COMPONENT_NAME}}/internal/port"
|
||||
)
|
||||
|
||||
// Compile-time interface check.
|
||||
var _ port.AuthCodeRepository = (*AuthCodeRepository)(nil)
|
||||
|
||||
// authCodeRow maps to the auth_codes table.
|
||||
type authCodeRow struct {
|
||||
ID string `db:"id"`
|
||||
UserID *string `db:"user_id"`
|
||||
Email string `db:"email"`
|
||||
Code string `db:"code"`
|
||||
Purpose string `db:"purpose"`
|
||||
ExpiresAt time.Time `db:"expires_at"`
|
||||
UsedAt *time.Time `db:"used_at"`
|
||||
IPAddress string `db:"ip_address"`
|
||||
CreatedAt time.Time `db:"created_at"`
|
||||
}
|
||||
|
||||
func (r *authCodeRow) toDomain() *domain.AuthCode {
|
||||
ac := &domain.AuthCode{
|
||||
ID: r.ID,
|
||||
Email: r.Email,
|
||||
Code: r.Code,
|
||||
Purpose: domain.AuthCodePurpose(r.Purpose),
|
||||
ExpiresAt: r.ExpiresAt,
|
||||
UsedAt: r.UsedAt,
|
||||
IPAddress: r.IPAddress,
|
||||
CreatedAt: r.CreatedAt,
|
||||
}
|
||||
if r.UserID != nil {
|
||||
uid := domain.UserID(*r.UserID)
|
||||
ac.UserID = &uid
|
||||
}
|
||||
return ac
|
||||
}
|
||||
|
||||
// AuthCodeRepository implements port.AuthCodeRepository with PostgreSQL/CockroachDB.
|
||||
type AuthCodeRepository struct {
|
||||
db *sqlx.DB
|
||||
}
|
||||
|
||||
// NewAuthCodeRepository creates a new Postgres-backed auth code repository.
|
||||
func NewAuthCodeRepository(db *sqlx.DB) *AuthCodeRepository {
|
||||
return &AuthCodeRepository{db: db}
|
||||
}
|
||||
|
||||
func (r *AuthCodeRepository) Create(ctx context.Context, code *domain.AuthCode) error {
|
||||
var userID *string
|
||||
if code.UserID != nil {
|
||||
s := string(*code.UserID)
|
||||
userID = &s
|
||||
}
|
||||
|
||||
_, err := r.db.ExecContext(ctx, `
|
||||
INSERT INTO auth_codes (id, user_id, email, code, purpose, expires_at, ip_address, created_at)
|
||||
VALUES ($1, $2, $3, $4, $5, $6, $7, $8)
|
||||
`, code.ID, userID, code.Email, code.Code, string(code.Purpose),
|
||||
code.ExpiresAt, code.IPAddress, code.CreatedAt)
|
||||
if err != nil {
|
||||
return fmt.Errorf("insert auth code: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *AuthCodeRepository) FindValid(ctx context.Context, email string, code string, purpose domain.AuthCodePurpose) (*domain.AuthCode, error) {
|
||||
var row authCodeRow
|
||||
err := r.db.GetContext(ctx, &row, `
|
||||
SELECT id, user_id, email, code, purpose, expires_at, used_at, ip_address, created_at
|
||||
FROM auth_codes
|
||||
WHERE email = $1 AND code = $2 AND purpose = $3
|
||||
AND used_at IS NULL AND expires_at > NOW()
|
||||
ORDER BY created_at DESC
|
||||
LIMIT 1
|
||||
`, email, code, string(purpose))
|
||||
if err != nil {
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
return nil, domain.ErrInvalidAuthCode
|
||||
}
|
||||
return nil, fmt.Errorf("find valid auth code: %w", err)
|
||||
}
|
||||
return row.toDomain(), nil
|
||||
}
|
||||
|
||||
func (r *AuthCodeRepository) MarkUsed(ctx context.Context, id string) error {
|
||||
_, err := r.db.ExecContext(ctx, `
|
||||
UPDATE auth_codes SET used_at = NOW() WHERE id = $1
|
||||
`, id)
|
||||
if err != nil {
|
||||
return fmt.Errorf("mark auth code used: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *AuthCodeRepository) DeleteExpired(ctx context.Context) (int, error) {
|
||||
result, err := r.db.ExecContext(ctx, `
|
||||
DELETE FROM auth_codes WHERE expires_at < NOW()
|
||||
`)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("delete expired auth codes: %w", err)
|
||||
}
|
||||
|
||||
rows, err := result.RowsAffected()
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("delete expired rows affected: %w", err)
|
||||
}
|
||||
return int(rows), nil
|
||||
}
|
||||
@ -0,0 +1,184 @@
|
||||
package postgres
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"errors"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/jmoiron/sqlx"
|
||||
|
||||
"{{GO_MODULE}}/services/{{COMPONENT_NAME}}/internal/domain"
|
||||
"{{GO_MODULE}}/services/{{COMPONENT_NAME}}/internal/port"
|
||||
)
|
||||
|
||||
// Compile-time interface check.
|
||||
var _ port.MediaRepository = (*MediaObjectRepository)(nil)
|
||||
|
||||
// mediaObjectRow maps to the media_objects table.
|
||||
type mediaObjectRow struct {
|
||||
ID string `db:"id"`
|
||||
UserID string `db:"user_id"`
|
||||
Path string `db:"path"`
|
||||
Filename string `db:"filename"`
|
||||
ContentType string `db:"content_type"`
|
||||
Size int64 `db:"size"`
|
||||
GenerationJobID string `db:"generation_job_id"`
|
||||
DeletedAt *time.Time `db:"deleted_at"`
|
||||
CreatedAt time.Time `db:"created_at"`
|
||||
UpdatedAt time.Time `db:"updated_at"`
|
||||
}
|
||||
|
||||
func (r *mediaObjectRow) toDomain() *domain.MediaObject {
|
||||
return &domain.MediaObject{
|
||||
ID: domain.MediaObjectID(r.ID),
|
||||
UserID: domain.UserID(r.UserID),
|
||||
Path: r.Path,
|
||||
Filename: r.Filename,
|
||||
ContentType: r.ContentType,
|
||||
Size: r.Size,
|
||||
GenerationJobID: r.GenerationJobID,
|
||||
DeletedAt: r.DeletedAt,
|
||||
CreatedAt: r.CreatedAt,
|
||||
UpdatedAt: r.UpdatedAt,
|
||||
}
|
||||
}
|
||||
|
||||
// MediaObjectRepository implements port.MediaRepository with PostgreSQL/CockroachDB.
|
||||
type MediaObjectRepository struct {
|
||||
db *sqlx.DB
|
||||
}
|
||||
|
||||
// NewMediaObjectRepository creates a new Postgres-backed media repository.
|
||||
func NewMediaObjectRepository(db *sqlx.DB) *MediaObjectRepository {
|
||||
return &MediaObjectRepository{db: db}
|
||||
}
|
||||
|
||||
func (r *MediaObjectRepository) Create(ctx context.Context, obj *domain.MediaObject) error {
|
||||
_, err := r.db.ExecContext(ctx, `
|
||||
INSERT INTO media_objects (id, user_id, path, filename, content_type, size, generation_job_id, created_at, updated_at)
|
||||
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)
|
||||
`, string(obj.ID), string(obj.UserID), obj.Path, obj.Filename, obj.ContentType,
|
||||
obj.Size, obj.GenerationJobID, obj.CreatedAt, obj.UpdatedAt)
|
||||
if err != nil {
|
||||
return fmt.Errorf("insert media object: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *MediaObjectRepository) Get(ctx context.Context, id domain.MediaObjectID) (*domain.MediaObject, error) {
|
||||
var row mediaObjectRow
|
||||
err := r.db.GetContext(ctx, &row, `
|
||||
SELECT id, user_id, path, filename, content_type, size, generation_job_id, deleted_at, created_at, updated_at
|
||||
FROM media_objects WHERE id = $1 AND deleted_at IS NULL
|
||||
`, string(id))
|
||||
if err != nil {
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
return nil, domain.ErrNotFound
|
||||
}
|
||||
return nil, fmt.Errorf("get media object: %w", err)
|
||||
}
|
||||
return row.toDomain(), nil
|
||||
}
|
||||
|
||||
func (r *MediaObjectRepository) ListByUser(ctx context.Context, userID domain.UserID, opts port.ListMediaOptions) ([]domain.MediaObject, int, error) {
|
||||
limit := opts.Limit
|
||||
if limit <= 0 {
|
||||
limit = 50
|
||||
}
|
||||
|
||||
// Count total matching records
|
||||
countQuery := `SELECT COUNT(*) FROM media_objects WHERE user_id = $1 AND deleted_at IS NULL`
|
||||
args := []any{string(userID)}
|
||||
argIdx := 2
|
||||
|
||||
if opts.ContentTypePrefix != "" {
|
||||
countQuery += fmt.Sprintf(` AND content_type LIKE $%d`, argIdx)
|
||||
args = append(args, opts.ContentTypePrefix+"%")
|
||||
argIdx++
|
||||
}
|
||||
|
||||
var total int
|
||||
if err := r.db.GetContext(ctx, &total, countQuery, args...); err != nil {
|
||||
return nil, 0, fmt.Errorf("count media objects: %w", err)
|
||||
}
|
||||
|
||||
// Fetch paginated results
|
||||
query := `
|
||||
SELECT id, user_id, path, filename, content_type, size, generation_job_id, deleted_at, created_at, updated_at
|
||||
FROM media_objects
|
||||
WHERE user_id = $1 AND deleted_at IS NULL`
|
||||
|
||||
fetchArgs := []any{string(userID)}
|
||||
fetchIdx := 2
|
||||
|
||||
if opts.ContentTypePrefix != "" {
|
||||
query += fmt.Sprintf(` AND content_type LIKE $%d`, fetchIdx)
|
||||
fetchArgs = append(fetchArgs, opts.ContentTypePrefix+"%")
|
||||
fetchIdx++
|
||||
}
|
||||
|
||||
query += ` ORDER BY created_at DESC`
|
||||
query += fmt.Sprintf(` LIMIT $%d OFFSET $%d`, fetchIdx, fetchIdx+1)
|
||||
fetchArgs = append(fetchArgs, limit, opts.Offset)
|
||||
|
||||
var rows []mediaObjectRow
|
||||
if err := r.db.SelectContext(ctx, &rows, query, fetchArgs...); err != nil {
|
||||
return nil, 0, fmt.Errorf("list media objects: %w", err)
|
||||
}
|
||||
|
||||
objects := make([]domain.MediaObject, len(rows))
|
||||
for i := range rows {
|
||||
objects[i] = *rows[i].toDomain()
|
||||
}
|
||||
return objects, total, nil
|
||||
}
|
||||
|
||||
func (r *MediaObjectRepository) SoftDelete(ctx context.Context, id domain.MediaObjectID) error {
|
||||
result, err := r.db.ExecContext(ctx, `
|
||||
UPDATE media_objects SET deleted_at = NOW(), updated_at = NOW()
|
||||
WHERE id = $1 AND deleted_at IS NULL
|
||||
`, string(id))
|
||||
if err != nil {
|
||||
return fmt.Errorf("soft delete media object: %w", err)
|
||||
}
|
||||
rows, err := result.RowsAffected()
|
||||
if err != nil {
|
||||
return fmt.Errorf("soft delete rows affected: %w", err)
|
||||
}
|
||||
if rows == 0 {
|
||||
return domain.ErrNotFound
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *MediaObjectRepository) HardDelete(ctx context.Context, id domain.MediaObjectID) error {
|
||||
result, err := r.db.ExecContext(ctx, `DELETE FROM media_objects WHERE id = $1`, string(id))
|
||||
if err != nil {
|
||||
return fmt.Errorf("hard delete media object: %w", err)
|
||||
}
|
||||
rows, err := result.RowsAffected()
|
||||
if err != nil {
|
||||
return fmt.Errorf("hard delete rows affected: %w", err)
|
||||
}
|
||||
if rows == 0 {
|
||||
return domain.ErrNotFound
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *MediaObjectRepository) GetByPath(ctx context.Context, path string) (*domain.MediaObject, error) {
|
||||
var row mediaObjectRow
|
||||
err := r.db.GetContext(ctx, &row, `
|
||||
SELECT id, user_id, path, filename, content_type, size, generation_job_id, deleted_at, created_at, updated_at
|
||||
FROM media_objects WHERE path = $1 AND deleted_at IS NULL
|
||||
`, path)
|
||||
if err != nil {
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
return nil, domain.ErrNotFound
|
||||
}
|
||||
return nil, fmt.Errorf("get media object by path: %w", err)
|
||||
}
|
||||
return row.toDomain(), nil
|
||||
}
|
||||
@ -0,0 +1,162 @@
|
||||
package postgres
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"errors"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/jmoiron/sqlx"
|
||||
|
||||
"{{GO_MODULE}}/services/{{COMPONENT_NAME}}/internal/domain"
|
||||
"{{GO_MODULE}}/services/{{COMPONENT_NAME}}/internal/port"
|
||||
)
|
||||
|
||||
// Compile-time interface check.
|
||||
var _ port.SessionRepository = (*SessionRepository)(nil)
|
||||
|
||||
// sessionRow maps to the sessions table.
|
||||
type sessionRow struct {
|
||||
ID string `db:"id"`
|
||||
UserID string `db:"user_id"`
|
||||
IPAddress string `db:"ip_address"`
|
||||
UserAgent string `db:"user_agent"`
|
||||
DeviceLabel string `db:"device_label"`
|
||||
LastActiveAt time.Time `db:"last_active_at"`
|
||||
ExpiresAt time.Time `db:"expires_at"`
|
||||
RevokedAt *time.Time `db:"revoked_at"`
|
||||
CreatedAt time.Time `db:"created_at"`
|
||||
}
|
||||
|
||||
func (r *sessionRow) toDomain() *domain.Session {
|
||||
return &domain.Session{
|
||||
ID: domain.SessionID(r.ID),
|
||||
UserID: domain.UserID(r.UserID),
|
||||
IPAddress: r.IPAddress,
|
||||
UserAgent: r.UserAgent,
|
||||
DeviceLabel: r.DeviceLabel,
|
||||
LastActiveAt: r.LastActiveAt,
|
||||
ExpiresAt: r.ExpiresAt,
|
||||
RevokedAt: r.RevokedAt,
|
||||
CreatedAt: r.CreatedAt,
|
||||
}
|
||||
}
|
||||
|
||||
// SessionRepository implements port.SessionRepository with PostgreSQL/CockroachDB.
|
||||
type SessionRepository struct {
|
||||
db *sqlx.DB
|
||||
}
|
||||
|
||||
// NewSessionRepository creates a new Postgres-backed session repository.
|
||||
func NewSessionRepository(db *sqlx.DB) *SessionRepository {
|
||||
return &SessionRepository{db: db}
|
||||
}
|
||||
|
||||
func (r *SessionRepository) Create(ctx context.Context, session *domain.Session) error {
|
||||
_, err := r.db.ExecContext(ctx, `
|
||||
INSERT INTO sessions (id, user_id, ip_address, user_agent, device_label, last_active_at, expires_at, created_at)
|
||||
VALUES ($1, $2, $3, $4, $5, $6, $7, $8)
|
||||
`, string(session.ID), string(session.UserID), session.IPAddress, session.UserAgent,
|
||||
session.DeviceLabel, session.LastActiveAt, session.ExpiresAt, session.CreatedAt)
|
||||
if err != nil {
|
||||
return fmt.Errorf("insert session: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *SessionRepository) Get(ctx context.Context, id domain.SessionID) (*domain.Session, error) {
|
||||
var row sessionRow
|
||||
err := r.db.GetContext(ctx, &row, `
|
||||
SELECT id, user_id, ip_address, user_agent, device_label, last_active_at, expires_at, revoked_at, created_at
|
||||
FROM sessions WHERE id = $1
|
||||
`, string(id))
|
||||
if err != nil {
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
return nil, domain.ErrSessionNotFound
|
||||
}
|
||||
return nil, fmt.Errorf("get session: %w", err)
|
||||
}
|
||||
return row.toDomain(), nil
|
||||
}
|
||||
|
||||
func (r *SessionRepository) ListByUser(ctx context.Context, userID domain.UserID) ([]domain.Session, error) {
|
||||
var rows []sessionRow
|
||||
err := r.db.SelectContext(ctx, &rows, `
|
||||
SELECT id, user_id, ip_address, user_agent, device_label, last_active_at, expires_at, revoked_at, created_at
|
||||
FROM sessions
|
||||
WHERE user_id = $1 AND revoked_at IS NULL AND expires_at > NOW()
|
||||
ORDER BY last_active_at DESC
|
||||
`, string(userID))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("list sessions: %w", err)
|
||||
}
|
||||
|
||||
sessions := make([]domain.Session, len(rows))
|
||||
for i := range rows {
|
||||
sessions[i] = *rows[i].toDomain()
|
||||
}
|
||||
return sessions, nil
|
||||
}
|
||||
|
||||
func (r *SessionRepository) UpdateLastActive(ctx context.Context, id domain.SessionID) error {
|
||||
_, err := r.db.ExecContext(ctx, `
|
||||
UPDATE sessions SET last_active_at = NOW() WHERE id = $1
|
||||
`, string(id))
|
||||
if err != nil {
|
||||
return fmt.Errorf("update last active: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *SessionRepository) Revoke(ctx context.Context, id domain.SessionID) error {
|
||||
result, err := r.db.ExecContext(ctx, `
|
||||
UPDATE sessions SET revoked_at = NOW() WHERE id = $1 AND revoked_at IS NULL
|
||||
`, string(id))
|
||||
if err != nil {
|
||||
return fmt.Errorf("revoke session: %w", err)
|
||||
}
|
||||
|
||||
rows, err := result.RowsAffected()
|
||||
if err != nil {
|
||||
return fmt.Errorf("revoke session rows affected: %w", err)
|
||||
}
|
||||
if rows == 0 {
|
||||
return domain.ErrSessionNotFound
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *SessionRepository) RevokeAllForUser(ctx context.Context, userID domain.UserID, exceptID *domain.SessionID) error {
|
||||
if exceptID != nil {
|
||||
_, err := r.db.ExecContext(ctx, `
|
||||
UPDATE sessions SET revoked_at = NOW()
|
||||
WHERE user_id = $1 AND revoked_at IS NULL AND id != $2
|
||||
`, string(userID), string(*exceptID))
|
||||
if err != nil {
|
||||
return fmt.Errorf("revoke all sessions except: %w", err)
|
||||
}
|
||||
} else {
|
||||
_, err := r.db.ExecContext(ctx, `
|
||||
UPDATE sessions SET revoked_at = NOW()
|
||||
WHERE user_id = $1 AND revoked_at IS NULL
|
||||
`, string(userID))
|
||||
if err != nil {
|
||||
return fmt.Errorf("revoke all sessions: %w", err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *SessionRepository) DeleteExpired(ctx context.Context) (int, error) {
|
||||
result, err := r.db.ExecContext(ctx, `DELETE FROM sessions WHERE expires_at < NOW()`)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("delete expired sessions: %w", err)
|
||||
}
|
||||
rows, err := result.RowsAffected()
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("delete expired sessions rows: %w", err)
|
||||
}
|
||||
return int(rows), nil
|
||||
}
|
||||
@ -0,0 +1,260 @@
|
||||
package postgres
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"errors"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/jmoiron/sqlx"
|
||||
"github.com/lib/pq"
|
||||
|
||||
"{{GO_MODULE}}/services/{{COMPONENT_NAME}}/internal/domain"
|
||||
"{{GO_MODULE}}/services/{{COMPONENT_NAME}}/internal/port"
|
||||
)
|
||||
|
||||
// Compile-time interface check.
|
||||
var _ port.UserRepository = (*UserRepository)(nil)
|
||||
|
||||
// userRow maps to the users table.
|
||||
type userRow struct {
|
||||
ID string `db:"id"`
|
||||
Email string `db:"email"`
|
||||
EmailVerified bool `db:"email_verified"`
|
||||
Name string `db:"name"`
|
||||
AvatarURL string `db:"avatar_url"`
|
||||
Status string `db:"status"`
|
||||
LastLoginAt *time.Time `db:"last_login_at"`
|
||||
CreatedAt time.Time `db:"created_at"`
|
||||
UpdatedAt time.Time `db:"updated_at"`
|
||||
}
|
||||
|
||||
func (r *userRow) toDomain(roles []string) *domain.User {
|
||||
return &domain.User{
|
||||
ID: domain.UserID(r.ID),
|
||||
Email: r.Email,
|
||||
EmailVerified: r.EmailVerified,
|
||||
Name: r.Name,
|
||||
AvatarURL: r.AvatarURL,
|
||||
Status: domain.UserStatus(r.Status),
|
||||
Roles: roles,
|
||||
LastLoginAt: r.LastLoginAt,
|
||||
CreatedAt: r.CreatedAt,
|
||||
UpdatedAt: r.UpdatedAt,
|
||||
}
|
||||
}
|
||||
|
||||
// UserRepository implements port.UserRepository with PostgreSQL/CockroachDB.
|
||||
type UserRepository struct {
|
||||
db *sqlx.DB
|
||||
}
|
||||
|
||||
// NewUserRepository creates a new Postgres-backed user repository.
|
||||
func NewUserRepository(db *sqlx.DB) *UserRepository {
|
||||
return &UserRepository{db: db}
|
||||
}
|
||||
|
||||
func (r *UserRepository) Create(ctx context.Context, user *domain.User) error {
|
||||
_, err := r.db.ExecContext(ctx, `
|
||||
INSERT INTO users (id, email, email_verified, name, avatar_url, status, created_at, updated_at)
|
||||
VALUES ($1, $2, $3, $4, $5, $6, $7, $8)
|
||||
`, string(user.ID), user.Email, user.EmailVerified, user.Name, user.AvatarURL,
|
||||
string(user.Status), user.CreatedAt, user.UpdatedAt)
|
||||
if err != nil {
|
||||
if isUniqueViolation(err) {
|
||||
return domain.ErrDuplicateEmail
|
||||
}
|
||||
return fmt.Errorf("insert user: %w", err)
|
||||
}
|
||||
|
||||
// Insert roles
|
||||
for _, role := range user.Roles {
|
||||
if _, err := r.db.ExecContext(ctx, `
|
||||
INSERT INTO user_roles (user_id, role) VALUES ($1, $2)
|
||||
ON CONFLICT (user_id, role) DO NOTHING
|
||||
`, string(user.ID), role); err != nil {
|
||||
return fmt.Errorf("insert role: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *UserRepository) Get(ctx context.Context, id domain.UserID) (*domain.User, error) {
|
||||
var row userRow
|
||||
err := r.db.GetContext(ctx, &row, `
|
||||
SELECT id, email, email_verified, name, avatar_url, status, last_login_at, created_at, updated_at
|
||||
FROM users WHERE id = $1
|
||||
`, string(id))
|
||||
if err != nil {
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
return nil, domain.ErrUserNotFound
|
||||
}
|
||||
return nil, fmt.Errorf("get user: %w", err)
|
||||
}
|
||||
|
||||
roles, err := r.GetRoles(ctx, id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return row.toDomain(roles), nil
|
||||
}
|
||||
|
||||
func (r *UserRepository) GetByEmail(ctx context.Context, email string) (*domain.User, error) {
|
||||
var row userRow
|
||||
err := r.db.GetContext(ctx, &row, `
|
||||
SELECT id, email, email_verified, name, avatar_url, status, last_login_at, created_at, updated_at
|
||||
FROM users WHERE email = $1
|
||||
`, email)
|
||||
if err != nil {
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
return nil, domain.ErrUserNotFound
|
||||
}
|
||||
return nil, fmt.Errorf("get user by email: %w", err)
|
||||
}
|
||||
|
||||
roles, err := r.GetRoles(ctx, domain.UserID(row.ID))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return row.toDomain(roles), nil
|
||||
}
|
||||
|
||||
func (r *UserRepository) Update(ctx context.Context, user *domain.User) error {
|
||||
result, err := r.db.ExecContext(ctx, `
|
||||
UPDATE users
|
||||
SET email = $2, email_verified = $3, name = $4, avatar_url = $5,
|
||||
status = $6, updated_at = $7
|
||||
WHERE id = $1
|
||||
`, string(user.ID), user.Email, user.EmailVerified, user.Name,
|
||||
user.AvatarURL, string(user.Status), time.Now())
|
||||
if err != nil {
|
||||
if isUniqueViolation(err) {
|
||||
return domain.ErrDuplicateEmail
|
||||
}
|
||||
return fmt.Errorf("update user: %w", err)
|
||||
}
|
||||
|
||||
rows, err := result.RowsAffected()
|
||||
if err != nil {
|
||||
return fmt.Errorf("update user rows affected: %w", err)
|
||||
}
|
||||
if rows == 0 {
|
||||
return domain.ErrUserNotFound
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *UserRepository) UpdateLastLogin(ctx context.Context, id domain.UserID) error {
|
||||
result, err := r.db.ExecContext(ctx, `
|
||||
UPDATE users SET last_login_at = NOW() WHERE id = $1
|
||||
`, string(id))
|
||||
if err != nil {
|
||||
return fmt.Errorf("update last login: %w", err)
|
||||
}
|
||||
|
||||
rows, err := result.RowsAffected()
|
||||
if err != nil {
|
||||
return fmt.Errorf("update last login rows affected: %w", err)
|
||||
}
|
||||
if rows == 0 {
|
||||
return domain.ErrUserNotFound
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *UserRepository) ExistsByEmail(ctx context.Context, email string) (bool, error) {
|
||||
var exists bool
|
||||
err := r.db.GetContext(ctx, &exists, `SELECT EXISTS(SELECT 1 FROM users WHERE email = $1)`, email)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("exists by email: %w", err)
|
||||
}
|
||||
return exists, nil
|
||||
}
|
||||
|
||||
func (r *UserRepository) SetPassword(ctx context.Context, userID domain.UserID, hash string) error {
|
||||
_, err := r.db.ExecContext(ctx, `
|
||||
INSERT INTO user_passwords (user_id, password_hash, updated_at)
|
||||
VALUES ($1, $2, NOW())
|
||||
ON CONFLICT (user_id) DO UPDATE SET password_hash = $2, updated_at = NOW()
|
||||
`, string(userID), hash)
|
||||
if err != nil {
|
||||
return fmt.Errorf("set password: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *UserRepository) GetPasswordHash(ctx context.Context, userID domain.UserID) (string, error) {
|
||||
var hash string
|
||||
err := r.db.GetContext(ctx, &hash, `
|
||||
SELECT password_hash FROM user_passwords WHERE user_id = $1
|
||||
`, string(userID))
|
||||
if err != nil {
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
return "", nil
|
||||
}
|
||||
return "", fmt.Errorf("get password hash: %w", err)
|
||||
}
|
||||
return hash, nil
|
||||
}
|
||||
|
||||
func (r *UserRepository) HasPassword(ctx context.Context, userID domain.UserID) (bool, error) {
|
||||
var exists bool
|
||||
err := r.db.GetContext(ctx, &exists, `
|
||||
SELECT EXISTS(SELECT 1 FROM user_passwords WHERE user_id = $1)
|
||||
`, string(userID))
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("has password: %w", err)
|
||||
}
|
||||
return exists, nil
|
||||
}
|
||||
|
||||
func (r *UserRepository) AddRole(ctx context.Context, userID domain.UserID, role string) error {
|
||||
_, err := r.db.ExecContext(ctx, `
|
||||
INSERT INTO user_roles (user_id, role) VALUES ($1, $2)
|
||||
ON CONFLICT (user_id, role) DO NOTHING
|
||||
`, string(userID), role)
|
||||
if err != nil {
|
||||
return fmt.Errorf("add role: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *UserRepository) RemoveRole(ctx context.Context, userID domain.UserID, role string) error {
|
||||
_, err := r.db.ExecContext(ctx, `
|
||||
DELETE FROM user_roles WHERE user_id = $1 AND role = $2
|
||||
`, string(userID), role)
|
||||
if err != nil {
|
||||
return fmt.Errorf("remove role: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *UserRepository) GetRoles(ctx context.Context, userID domain.UserID) ([]string, error) {
|
||||
var roles []string
|
||||
err := r.db.SelectContext(ctx, &roles, `
|
||||
SELECT role FROM user_roles WHERE user_id = $1 ORDER BY role
|
||||
`, string(userID))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get roles: %w", err)
|
||||
}
|
||||
if roles == nil {
|
||||
roles = []string{}
|
||||
}
|
||||
return roles, nil
|
||||
}
|
||||
|
||||
// isUniqueViolation checks if a database error is a unique constraint violation.
|
||||
// Works with both PostgreSQL (23505) and CockroachDB.
|
||||
func isUniqueViolation(err error) bool {
|
||||
var pqErr *pq.Error
|
||||
if errors.As(err, &pqErr) {
|
||||
return pqErr.Code == "23505"
|
||||
}
|
||||
return false
|
||||
}
|
||||
@ -2,13 +2,16 @@ package handlers
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"{{GO_MODULE}}/pkg/app"
|
||||
"{{GO_MODULE}}/pkg/auth"
|
||||
"{{GO_MODULE}}/pkg/httperror"
|
||||
"{{GO_MODULE}}/pkg/httpresponse"
|
||||
"{{GO_MODULE}}/pkg/logging"
|
||||
"{{GO_MODULE}}/services/{{COMPONENT_NAME}}/internal/domain"
|
||||
"{{GO_MODULE}}/services/{{COMPONENT_NAME}}/internal/service"
|
||||
)
|
||||
|
||||
@ -18,7 +21,7 @@ type Auth struct {
|
||||
logger *logging.Logger
|
||||
}
|
||||
|
||||
// NewAuth creates a new Auth handler with injected dependencies.
|
||||
// NewAuth creates a new Auth handler.
|
||||
func NewAuth(svc *service.AuthService, logger *logging.Logger) *Auth {
|
||||
return &Auth{
|
||||
svc: svc,
|
||||
@ -26,13 +29,22 @@ func NewAuth(svc *service.AuthService, logger *logging.Logger) *Auth {
|
||||
}
|
||||
}
|
||||
|
||||
// LoginRequest is the request body for login.
|
||||
// --- Request / Response types ---
|
||||
|
||||
// LoginRequest is the request body for password login.
|
||||
type LoginRequest struct {
|
||||
Email string `json:"email" validate:"required,email"`
|
||||
Password string `json:"password" validate:"required,min=1"`
|
||||
}
|
||||
|
||||
// LoginResponse is the response for successful login.
|
||||
// RegisterRequest is the request body for registration.
|
||||
type RegisterRequest struct {
|
||||
Email string `json:"email" validate:"required,email"`
|
||||
Password string `json:"password" validate:"required,min=8"`
|
||||
Name string `json:"name"`
|
||||
}
|
||||
|
||||
// LoginResponse is the response for successful login or registration.
|
||||
type LoginResponse struct {
|
||||
Token string `json:"token"`
|
||||
User UserResponse `json:"user"`
|
||||
@ -43,26 +55,51 @@ type UserResponse struct {
|
||||
ID string `json:"id"`
|
||||
Email string `json:"email"`
|
||||
Name string `json:"name,omitempty"`
|
||||
AvatarURL string `json:"avatarUrl,omitempty"`
|
||||
EmailVerified bool `json:"emailVerified"`
|
||||
Roles []string `json:"roles,omitempty"`
|
||||
}
|
||||
|
||||
// toUserResponse converts an auth.User to UserResponse.
|
||||
func toUserResponse(u *auth.User) UserResponse {
|
||||
name := ""
|
||||
if u.Metadata != nil {
|
||||
if n, ok := u.Metadata["name"].(string); ok {
|
||||
name = n
|
||||
// UpdateProfileRequest is the request body for updating the user profile.
|
||||
type UpdateProfileRequest struct {
|
||||
Name string `json:"name"`
|
||||
AvatarURL string `json:"avatarUrl"`
|
||||
}
|
||||
|
||||
// ChangePasswordRequest is the request body for changing password.
|
||||
type ChangePasswordRequest struct {
|
||||
CurrentPassword string `json:"currentPassword" validate:"required"`
|
||||
NewPassword string `json:"newPassword" validate:"required,min=8"`
|
||||
}
|
||||
|
||||
// RefreshRequest is the request body for refreshing an access token.
|
||||
type RefreshRequest struct {
|
||||
Token string `json:"token" validate:"required"`
|
||||
}
|
||||
|
||||
// toUserResponse converts a domain.User to UserResponse.
|
||||
func toUserResponse(u *domain.User) UserResponse {
|
||||
return UserResponse{
|
||||
ID: u.ID,
|
||||
ID: string(u.ID),
|
||||
Email: u.Email,
|
||||
Name: name,
|
||||
Name: u.Name,
|
||||
AvatarURL: u.AvatarURL,
|
||||
EmailVerified: u.EmailVerified,
|
||||
Roles: u.Roles,
|
||||
}
|
||||
}
|
||||
|
||||
// Login authenticates a user and returns a JWT token.
|
||||
// toLoginResponse creates a LoginResponse from service output.
|
||||
func toLoginResponse(out *service.LoginOutput) LoginResponse {
|
||||
return LoginResponse{
|
||||
Token: out.Token,
|
||||
User: toUserResponse(out.User),
|
||||
}
|
||||
}
|
||||
|
||||
// --- Handlers ---
|
||||
|
||||
// Login authenticates a user with email and password.
|
||||
//
|
||||
// POST /api/{service}/auth/login
|
||||
func (h *Auth) Login(w http.ResponseWriter, r *http.Request) error {
|
||||
@ -71,21 +108,30 @@ func (h *Auth) Login(w http.ResponseWriter, r *http.Request) error {
|
||||
return err
|
||||
}
|
||||
|
||||
output, err := h.svc.Login(r.Context(), service.LoginInput{
|
||||
Email: req.Email,
|
||||
Password: req.Password,
|
||||
})
|
||||
output, err := h.svc.LoginWithPassword(r.Context(), req.Email, req.Password, clientIP(r), r.UserAgent())
|
||||
if err != nil {
|
||||
if errors.Is(err, service.ErrInvalidCredentials) {
|
||||
return httperror.Unauthorized("invalid email or password")
|
||||
return mapAuthError(err)
|
||||
}
|
||||
|
||||
httpresponse.OK(w, r, toLoginResponse(output))
|
||||
return nil
|
||||
}
|
||||
|
||||
// Register creates a new user account.
|
||||
//
|
||||
// POST /api/{service}/auth/register
|
||||
func (h *Auth) Register(w http.ResponseWriter, r *http.Request) error {
|
||||
var req RegisterRequest
|
||||
if err := app.BindAndValidate(r, &req); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
httpresponse.OK(w, r, LoginResponse{
|
||||
Token: output.Token,
|
||||
User: toUserResponse(output.User),
|
||||
})
|
||||
output, err := h.svc.Register(r.Context(), req.Email, req.Password, req.Name, clientIP(r), r.UserAgent())
|
||||
if err != nil {
|
||||
return mapAuthError(err)
|
||||
}
|
||||
|
||||
httpresponse.Created(w, r, toLoginResponse(output))
|
||||
return nil
|
||||
}
|
||||
|
||||
@ -98,30 +144,188 @@ func (h *Auth) Me(w http.ResponseWriter, r *http.Request) error {
|
||||
return httperror.Unauthorized("not authenticated")
|
||||
}
|
||||
|
||||
// Optionally refresh user data from repository
|
||||
freshUser, err := h.svc.GetCurrentUser(r.Context(), user.ID)
|
||||
if err != nil {
|
||||
if errors.Is(err, service.ErrUserNotFound) {
|
||||
return httperror.Unauthorized("user not found")
|
||||
}
|
||||
return err
|
||||
return mapAuthError(err)
|
||||
}
|
||||
|
||||
httpresponse.OK(w, r, toUserResponse(freshUser))
|
||||
return nil
|
||||
}
|
||||
|
||||
// Logout handles user logout.
|
||||
// This is a stateless operation since we use JWTs.
|
||||
// UpdateMe updates the current user's profile.
|
||||
//
|
||||
// POST /api/{service}/auth/logout
|
||||
func (h *Auth) Logout(w http.ResponseWriter, r *http.Request) error {
|
||||
// With JWT-based auth, logout is handled client-side by discarding the token.
|
||||
// This endpoint exists for API completeness and could be extended to:
|
||||
// - Add the token to a blacklist
|
||||
// - Clear server-side sessions if using hybrid auth
|
||||
// - Log the logout event
|
||||
// PUT /api/{service}/auth/me
|
||||
func (h *Auth) UpdateMe(w http.ResponseWriter, r *http.Request) error {
|
||||
user, err := auth.GetUserOrError(r.Context())
|
||||
if err != nil {
|
||||
return httperror.Unauthorized("not authenticated")
|
||||
}
|
||||
|
||||
var req UpdateProfileRequest
|
||||
if err := app.BindAndValidate(r, &req); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
updated, err := h.svc.UpdateProfile(r.Context(), user.ID, req.Name, req.AvatarURL)
|
||||
if err != nil {
|
||||
return mapAuthError(err)
|
||||
}
|
||||
|
||||
httpresponse.OK(w, r, toUserResponse(updated))
|
||||
return nil
|
||||
}
|
||||
|
||||
// ChangePassword changes the current user's password.
|
||||
//
|
||||
// POST /api/{service}/auth/change-password
|
||||
func (h *Auth) ChangePassword(w http.ResponseWriter, r *http.Request) error {
|
||||
user, err := auth.GetUserOrError(r.Context())
|
||||
if err != nil {
|
||||
return httperror.Unauthorized("not authenticated")
|
||||
}
|
||||
|
||||
var req ChangePasswordRequest
|
||||
if err := app.BindAndValidate(r, &req); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := h.svc.ChangePassword(r.Context(), user.ID, req.CurrentPassword, req.NewPassword); err != nil {
|
||||
return mapAuthError(err)
|
||||
}
|
||||
|
||||
httpresponse.NoContent(w)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Logout revokes the current session.
|
||||
//
|
||||
// POST /api/{service}/auth/logout
|
||||
func (h *Auth) Logout(w http.ResponseWriter, r *http.Request) error {
|
||||
user := auth.GetUser(r.Context())
|
||||
if user == nil {
|
||||
httpresponse.NoContent(w)
|
||||
return nil
|
||||
}
|
||||
|
||||
sessionID := ""
|
||||
if user.Metadata != nil {
|
||||
if sid, ok := user.Metadata["sid"].(string); ok {
|
||||
sessionID = sid
|
||||
}
|
||||
}
|
||||
|
||||
if err := h.svc.Logout(r.Context(), sessionID); err != nil {
|
||||
h.logger.Warn("logout session revoke failed", "error", err)
|
||||
}
|
||||
|
||||
httpresponse.NoContent(w)
|
||||
return nil
|
||||
}
|
||||
|
||||
// RefreshToken issues a new access token for an active session.
|
||||
//
|
||||
// POST /api/{service}/auth/refresh
|
||||
func (h *Auth) RefreshToken(w http.ResponseWriter, r *http.Request) error {
|
||||
// The caller sends their current (possibly near-expiry) token.
|
||||
// We parse it to get user ID and session ID, then issue a new one.
|
||||
user := auth.GetUser(r.Context())
|
||||
if user == nil {
|
||||
return httperror.Unauthorized("not authenticated")
|
||||
}
|
||||
|
||||
sessionID := ""
|
||||
if user.Metadata != nil {
|
||||
if sid, ok := user.Metadata["sid"].(string); ok {
|
||||
sessionID = sid
|
||||
}
|
||||
}
|
||||
if sessionID == "" {
|
||||
return httperror.Unauthorized("no session")
|
||||
}
|
||||
|
||||
output, err := h.svc.RefreshToken(r.Context(), sessionID, user.ID)
|
||||
if err != nil {
|
||||
return mapAuthError(err)
|
||||
}
|
||||
|
||||
httpresponse.OK(w, r, toLoginResponse(output))
|
||||
return nil
|
||||
}
|
||||
|
||||
// --- Helpers ---
|
||||
|
||||
// mapAuthError translates domain errors to HTTP errors.
|
||||
func mapAuthError(err error) error {
|
||||
switch {
|
||||
case errors.Is(err, domain.ErrInvalidCredentials):
|
||||
return httperror.Unauthorized("invalid email or password")
|
||||
case errors.Is(err, domain.ErrUserNotFound):
|
||||
return httperror.Unauthorized("invalid email or password")
|
||||
case errors.Is(err, domain.ErrUserSuspended):
|
||||
return httperror.Forbidden("account is suspended")
|
||||
case errors.Is(err, domain.ErrDuplicateEmail):
|
||||
return httperror.Conflict("email already registered")
|
||||
case errors.Is(err, domain.ErrWeakPassword):
|
||||
return httperror.BadRequest(err.Error())
|
||||
case errors.Is(err, domain.ErrRegistrationDisabled):
|
||||
return httperror.Forbidden("registration is currently disabled")
|
||||
case errors.Is(err, domain.ErrNameTooLong), errors.Is(err, domain.ErrEmailTooLong):
|
||||
return httperror.BadRequest(err.Error())
|
||||
case errors.Is(err, domain.ErrInvalidAvatarURL):
|
||||
return httperror.BadRequest("avatar URL must use http or https")
|
||||
case errors.Is(err, domain.ErrSessionNotFound):
|
||||
return httperror.NotFound("session not found")
|
||||
case errors.Is(err, domain.ErrSessionRevoked):
|
||||
return httperror.Unauthorized("session has been revoked")
|
||||
case errors.Is(err, domain.ErrInvalidAuthCode):
|
||||
return httperror.Unauthorized("invalid or expired code")
|
||||
default:
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// clientIP extracts the client IP from the request.
|
||||
// It prefers RemoteAddr (set by the Go HTTP server from the TCP connection) and
|
||||
// only uses X-Forwarded-For/X-Real-Ip when the direct connection is from a
|
||||
// private/loopback address, indicating a trusted reverse proxy.
|
||||
func clientIP(r *http.Request) string {
|
||||
host, _, err := net.SplitHostPort(r.RemoteAddr)
|
||||
if err != nil {
|
||||
host = r.RemoteAddr
|
||||
}
|
||||
|
||||
// Only trust proxy headers when the connection is from a private network.
|
||||
if isPrivateIP(host) {
|
||||
if xff := r.Header.Get("X-Forwarded-For"); xff != "" {
|
||||
parts := strings.SplitN(xff, ",", 2)
|
||||
ip := strings.TrimSpace(parts[0])
|
||||
if ip != "" {
|
||||
return ip
|
||||
}
|
||||
}
|
||||
if xri := r.Header.Get("X-Real-Ip"); xri != "" {
|
||||
return xri
|
||||
}
|
||||
}
|
||||
|
||||
return host
|
||||
}
|
||||
|
||||
// isPrivateIP returns true if the address is loopback or RFC 1918 private.
|
||||
func isPrivateIP(addr string) bool {
|
||||
ip := net.ParseIP(addr)
|
||||
if ip == nil {
|
||||
return false
|
||||
}
|
||||
return ip.IsLoopback() || ip.IsPrivate()
|
||||
}
|
||||
|
||||
// sessionID extracts the session ID from the authenticated user's metadata.
|
||||
func sessionID(user *auth.User) string {
|
||||
if user == nil || user.Metadata == nil {
|
||||
return ""
|
||||
}
|
||||
sid, _ := user.Metadata["sid"].(string)
|
||||
return sid
|
||||
}
|
||||
|
||||
@ -0,0 +1,288 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net/http"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
|
||||
"{{GO_MODULE}}/pkg/app"
|
||||
"{{GO_MODULE}}/pkg/auth"
|
||||
"{{GO_MODULE}}/pkg/httperror"
|
||||
"{{GO_MODULE}}/pkg/httpresponse"
|
||||
"{{GO_MODULE}}/services/{{COMPONENT_NAME}}/internal/domain"
|
||||
)
|
||||
|
||||
// --- Request types for auth flows ---
|
||||
|
||||
// EmailRequest is used by OTP send, magic link, and forgot password.
|
||||
type EmailRequest struct {
|
||||
Email string `json:"email" validate:"required,email"`
|
||||
}
|
||||
|
||||
// OTPVerifyRequest verifies a one-time password.
|
||||
type OTPVerifyRequest struct {
|
||||
Email string `json:"email" validate:"required,email"`
|
||||
Code string `json:"code" validate:"required,len=6"`
|
||||
}
|
||||
|
||||
// MagicLinkVerifyRequest verifies a magic link token.
|
||||
type MagicLinkVerifyRequest struct {
|
||||
Email string `json:"email" validate:"required,email"`
|
||||
Token string `json:"token" validate:"required"`
|
||||
}
|
||||
|
||||
// ResetPasswordRequest sets a new password using a reset token.
|
||||
type ResetPasswordRequest struct {
|
||||
Email string `json:"email" validate:"required,email"`
|
||||
Token string `json:"token" validate:"required"`
|
||||
NewPassword string `json:"newPassword" validate:"required,min=8"`
|
||||
}
|
||||
|
||||
// VerifyEmailRequest verifies an email with a code.
|
||||
type VerifyEmailRequest struct {
|
||||
Code string `json:"code" validate:"required,len=6"`
|
||||
}
|
||||
|
||||
// SessionResponse is a single session in the list.
|
||||
type SessionResponse struct {
|
||||
ID string `json:"id"`
|
||||
IPAddress string `json:"ipAddress"`
|
||||
DeviceLabel string `json:"deviceLabel"`
|
||||
LastActiveAt string `json:"lastActiveAt"`
|
||||
CreatedAt string `json:"createdAt"`
|
||||
IsCurrent bool `json:"isCurrent"`
|
||||
}
|
||||
|
||||
// --- OTP handlers ---
|
||||
|
||||
// SendOTP sends a one-time password to the user's email.
|
||||
//
|
||||
// POST /api/{service}/auth/otp/send
|
||||
func (h *Auth) SendOTP(w http.ResponseWriter, r *http.Request) error {
|
||||
var req EmailRequest
|
||||
if err := app.BindAndValidate(r, &req); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := h.svc.SendOTP(r.Context(), req.Email, clientIP(r)); err != nil {
|
||||
return mapAuthError(err)
|
||||
}
|
||||
|
||||
httpresponse.OK(w, r, map[string]string{"message": "If an account exists, a code has been sent"})
|
||||
return nil
|
||||
}
|
||||
|
||||
// VerifyOTP verifies a one-time password and returns a login token.
|
||||
//
|
||||
// POST /api/{service}/auth/otp/verify
|
||||
func (h *Auth) VerifyOTP(w http.ResponseWriter, r *http.Request) error {
|
||||
var req OTPVerifyRequest
|
||||
if err := app.BindAndValidate(r, &req); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
output, err := h.svc.VerifyOTP(r.Context(), req.Email, req.Code, clientIP(r), r.UserAgent())
|
||||
if err != nil {
|
||||
return mapAuthError(err)
|
||||
}
|
||||
|
||||
httpresponse.OK(w, r, toLoginResponse(output))
|
||||
return nil
|
||||
}
|
||||
|
||||
// --- Magic Link handlers ---
|
||||
|
||||
// SendMagicLink sends a magic link to the user's email.
|
||||
//
|
||||
// POST /api/{service}/auth/magic-link
|
||||
func (h *Auth) SendMagicLink(w http.ResponseWriter, r *http.Request) error {
|
||||
var req EmailRequest
|
||||
if err := app.BindAndValidate(r, &req); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := h.svc.SendMagicLink(r.Context(), req.Email, clientIP(r)); err != nil {
|
||||
return mapAuthError(err)
|
||||
}
|
||||
|
||||
httpresponse.OK(w, r, map[string]string{"message": "If an account exists, a link has been sent"})
|
||||
return nil
|
||||
}
|
||||
|
||||
// VerifyMagicLink verifies a magic link token and returns a login token.
|
||||
//
|
||||
// POST /api/{service}/auth/magic-link/verify
|
||||
func (h *Auth) VerifyMagicLink(w http.ResponseWriter, r *http.Request) error {
|
||||
var req MagicLinkVerifyRequest
|
||||
if err := app.BindAndValidate(r, &req); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
output, err := h.svc.VerifyMagicLink(r.Context(), req.Email, req.Token, clientIP(r), r.UserAgent())
|
||||
if err != nil {
|
||||
return mapAuthError(err)
|
||||
}
|
||||
|
||||
httpresponse.OK(w, r, toLoginResponse(output))
|
||||
return nil
|
||||
}
|
||||
|
||||
// --- Forgot / Reset Password handlers ---
|
||||
|
||||
// ForgotPassword sends a password reset token.
|
||||
//
|
||||
// POST /api/{service}/auth/forgot-password
|
||||
func (h *Auth) ForgotPassword(w http.ResponseWriter, r *http.Request) error {
|
||||
var req EmailRequest
|
||||
if err := app.BindAndValidate(r, &req); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := h.svc.ForgotPassword(r.Context(), req.Email, clientIP(r)); err != nil {
|
||||
return mapAuthError(err)
|
||||
}
|
||||
|
||||
httpresponse.OK(w, r, map[string]string{"message": "If an account exists, a reset link has been sent"})
|
||||
return nil
|
||||
}
|
||||
|
||||
// ResetPassword sets a new password using a reset token.
|
||||
//
|
||||
// POST /api/{service}/auth/reset-password
|
||||
func (h *Auth) ResetPassword(w http.ResponseWriter, r *http.Request) error {
|
||||
var req ResetPasswordRequest
|
||||
if err := app.BindAndValidate(r, &req); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := h.svc.ResetPassword(r.Context(), req.Email, req.Token, req.NewPassword); err != nil {
|
||||
return mapAuthError(err)
|
||||
}
|
||||
|
||||
httpresponse.OK(w, r, map[string]string{"message": "Password has been reset. Please sign in."})
|
||||
return nil
|
||||
}
|
||||
|
||||
// --- Email Verification handlers ---
|
||||
|
||||
// SendVerifyEmail sends a verification code to the current user's email.
|
||||
//
|
||||
// POST /api/{service}/auth/verify-email/send
|
||||
func (h *Auth) SendVerifyEmail(w http.ResponseWriter, r *http.Request) error {
|
||||
user, err := auth.GetUserOrError(r.Context())
|
||||
if err != nil {
|
||||
return httperror.Unauthorized("not authenticated")
|
||||
}
|
||||
|
||||
if err := h.svc.SendVerifyEmail(r.Context(), user.ID); err != nil {
|
||||
return mapAuthError(err)
|
||||
}
|
||||
|
||||
httpresponse.OK(w, r, map[string]string{"message": "Verification code sent"})
|
||||
return nil
|
||||
}
|
||||
|
||||
// VerifyEmail verifies the current user's email with a code.
|
||||
//
|
||||
// POST /api/{service}/auth/verify-email
|
||||
func (h *Auth) VerifyEmail(w http.ResponseWriter, r *http.Request) error {
|
||||
user, err := auth.GetUserOrError(r.Context())
|
||||
if err != nil {
|
||||
return httperror.Unauthorized("not authenticated")
|
||||
}
|
||||
|
||||
var req VerifyEmailRequest
|
||||
if err := app.BindAndValidate(r, &req); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := h.svc.VerifyEmail(r.Context(), user.ID, req.Code); err != nil {
|
||||
return mapAuthError(err)
|
||||
}
|
||||
|
||||
httpresponse.OK(w, r, map[string]string{"message": "Email verified"})
|
||||
return nil
|
||||
}
|
||||
|
||||
// --- Session Management handlers ---
|
||||
|
||||
// ListSessions returns all active sessions for the current user.
|
||||
//
|
||||
// GET /api/{service}/auth/sessions
|
||||
func (h *Auth) ListSessions(w http.ResponseWriter, r *http.Request) error {
|
||||
user, err := auth.GetUserOrError(r.Context())
|
||||
if err != nil {
|
||||
return httperror.Unauthorized("not authenticated")
|
||||
}
|
||||
|
||||
currentSID := sessionID(user)
|
||||
|
||||
sessions, err := h.svc.ListSessions(r.Context(), user.ID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
result := make([]SessionResponse, 0, len(sessions))
|
||||
for _, s := range sessions {
|
||||
result = append(result, SessionResponse{
|
||||
ID: string(s.ID),
|
||||
IPAddress: s.IPAddress,
|
||||
DeviceLabel: s.DeviceLabel,
|
||||
LastActiveAt: s.LastActiveAt.Format("2006-01-02T15:04:05Z07:00"),
|
||||
CreatedAt: s.CreatedAt.Format("2006-01-02T15:04:05Z07:00"),
|
||||
IsCurrent: string(s.ID) == currentSID,
|
||||
})
|
||||
}
|
||||
|
||||
httpresponse.OK(w, r, result)
|
||||
return nil
|
||||
}
|
||||
|
||||
// RevokeSession revokes a specific session.
|
||||
//
|
||||
// DELETE /api/{service}/auth/sessions/{id}
|
||||
func (h *Auth) RevokeSession(w http.ResponseWriter, r *http.Request) error {
|
||||
user, err := auth.GetUserOrError(r.Context())
|
||||
if err != nil {
|
||||
return httperror.Unauthorized("not authenticated")
|
||||
}
|
||||
|
||||
sid := chi.URLParam(r, "id")
|
||||
if sid == "" {
|
||||
return httperror.BadRequest("session id required")
|
||||
}
|
||||
|
||||
if err := h.svc.RevokeSession(r.Context(), user.ID, sid); err != nil {
|
||||
if errors.Is(err, domain.ErrSessionNotFound) {
|
||||
return httperror.NotFound("session not found")
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
httpresponse.NoContent(w)
|
||||
return nil
|
||||
}
|
||||
|
||||
// RevokeAllSessions revokes all sessions except the current one.
|
||||
//
|
||||
// DELETE /api/{service}/auth/sessions
|
||||
func (h *Auth) RevokeAllSessions(w http.ResponseWriter, r *http.Request) error {
|
||||
user, err := auth.GetUserOrError(r.Context())
|
||||
if err != nil {
|
||||
return httperror.Unauthorized("not authenticated")
|
||||
}
|
||||
|
||||
currentSID := sessionID(user)
|
||||
var except *string
|
||||
if currentSID != "" {
|
||||
except = ¤tSID
|
||||
}
|
||||
|
||||
if err := h.svc.LogoutAll(r.Context(), user.ID, except); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
httpresponse.NoContent(w)
|
||||
return nil
|
||||
}
|
||||
@ -10,21 +10,26 @@ import (
|
||||
"{{GO_MODULE}}/pkg/logging"
|
||||
"{{GO_MODULE}}/pkg/queue"
|
||||
"{{GO_MODULE}}/pkg/realtime"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
)
|
||||
|
||||
// Generate handles HTTP requests for AI generation endpoints.
|
||||
// All generation is async: validate request, enqueue job, return 202 with job ID.
|
||||
// The worker processes jobs and sends results via Redis → SSE.
|
||||
// Job status can be polled via GET /generate/jobs/{id} as a fallback to SSE.
|
||||
type Generate struct {
|
||||
queue queue.Producer
|
||||
jobReader queue.JobReader
|
||||
sseHub *realtime.SSEHub
|
||||
logger *logging.Logger
|
||||
}
|
||||
|
||||
// NewGenerate creates a new Generate handler with injected dependencies.
|
||||
func NewGenerate(q queue.Producer, hub *realtime.SSEHub, logger *logging.Logger) *Generate {
|
||||
func NewGenerate(q queue.Producer, jr queue.JobReader, hub *realtime.SSEHub, logger *logging.Logger) *Generate {
|
||||
return &Generate{
|
||||
queue: q,
|
||||
jobReader: jr,
|
||||
sseHub: hub,
|
||||
logger: logger.WithComponent("GenerateHandler"),
|
||||
}
|
||||
@ -177,6 +182,47 @@ func (h *Generate) GenerateText(w http.ResponseWriter, r *http.Request) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Job status (poll fallback for SSE)
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
// GetJobStatus returns the current status of a generation job.
|
||||
// This is a poll-based fallback for clients that can't use SSE.
|
||||
func (h *Generate) GetJobStatus(w http.ResponseWriter, r *http.Request) error {
|
||||
jobID := chi.URLParam(r, "id")
|
||||
if jobID == "" {
|
||||
return httperror.BadRequest("job ID is required")
|
||||
}
|
||||
|
||||
job, err := h.jobReader.GetJob(r.Context(), jobID)
|
||||
if err != nil {
|
||||
if err == queue.ErrJobNotFound {
|
||||
return httperror.NotFound("job not found")
|
||||
}
|
||||
h.logger.Error("failed to get job status", "error", err, "job_id", jobID)
|
||||
return httperror.Internal("failed to get job status")
|
||||
}
|
||||
|
||||
resp := map[string]any{
|
||||
"id": job.ID,
|
||||
"type": job.Type,
|
||||
"status": string(job.Status),
|
||||
"createdAt": job.CreatedAt,
|
||||
}
|
||||
if job.StartedAt != nil {
|
||||
resp["startedAt"] = job.StartedAt
|
||||
}
|
||||
if job.CompletedAt != nil {
|
||||
resp["completedAt"] = job.CompletedAt
|
||||
}
|
||||
if job.Error != "" {
|
||||
resp["error"] = job.Error
|
||||
}
|
||||
|
||||
httpresponse.OK(w, r, resp)
|
||||
return nil
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// SSE Events endpoint
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
@ -1,9 +1,12 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"{{GO_MODULE}}/pkg/app"
|
||||
"{{GO_MODULE}}/pkg/auth"
|
||||
@ -11,20 +14,43 @@ import (
|
||||
"{{GO_MODULE}}/pkg/httpresponse"
|
||||
"{{GO_MODULE}}/pkg/logging"
|
||||
"{{GO_MODULE}}/pkg/storage"
|
||||
"{{GO_MODULE}}/services/{{COMPONENT_NAME}}/internal/domain"
|
||||
"{{GO_MODULE}}/services/{{COMPONENT_NAME}}/internal/port"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
// maxUploadSize is the maximum allowed file size for uploads (500MB).
|
||||
const maxUploadSize = 500 << 20
|
||||
|
||||
// allowedMediaTypes is the allowlist of MIME types permitted for upload.
|
||||
var allowedMediaTypes = map[string]bool{
|
||||
"image/jpeg": true,
|
||||
"image/png": true,
|
||||
"image/gif": true,
|
||||
"image/webp": true,
|
||||
"image/svg+xml": true,
|
||||
"video/mp4": true,
|
||||
"video/webm": true,
|
||||
"video/quicktime": true,
|
||||
"audio/mpeg": true,
|
||||
"audio/wav": true,
|
||||
"audio/ogg": true,
|
||||
"audio/webm": true,
|
||||
"application/pdf": true,
|
||||
}
|
||||
|
||||
// Media handles media upload and library operations.
|
||||
type Media struct {
|
||||
store storage.Store
|
||||
repo port.MediaRepository
|
||||
logger *logging.Logger
|
||||
}
|
||||
|
||||
// NewMedia creates a new media handler.
|
||||
func NewMedia(store storage.Store, logger *logging.Logger) *Media {
|
||||
return &Media{store: store, logger: logger.WithComponent("MediaHandler")}
|
||||
func NewMedia(store storage.Store, repo port.MediaRepository, logger *logging.Logger) *Media {
|
||||
return &Media{store: store, repo: repo, logger: logger.WithComponent("MediaHandler")}
|
||||
}
|
||||
|
||||
// Routes returns the media subrouter.
|
||||
@ -33,31 +59,63 @@ func (h *Media) Routes() http.Handler {
|
||||
r.Post("/upload/init", app.Wrap(h.InitUpload))
|
||||
r.Post("/upload/complete", app.Wrap(h.CompleteUpload))
|
||||
r.Get("/", app.Wrap(h.List))
|
||||
r.Delete("/*", app.Wrap(h.Delete))
|
||||
r.Get("/{id}", app.Wrap(h.GetOne))
|
||||
r.Get("/{id}/url", app.Wrap(h.RefreshURL))
|
||||
r.Delete("/{id}", app.Wrap(h.Delete))
|
||||
return r
|
||||
}
|
||||
|
||||
// sanitizeFilename removes path separators and dangerous characters from filenames.
|
||||
func sanitizeFilename(name string) string {
|
||||
// Remove any directory components
|
||||
name = filepath.Base(name)
|
||||
// Replace any remaining path separators (e.g., from URL encoding)
|
||||
name = strings.ReplaceAll(name, "/", "_")
|
||||
name = strings.ReplaceAll(name, "\\", "_")
|
||||
name = strings.ReplaceAll(name, "..", "_")
|
||||
// Remove null bytes
|
||||
name = strings.ReplaceAll(name, "\x00", "")
|
||||
if name == "" || name == "." {
|
||||
name = "unnamed"
|
||||
}
|
||||
return name
|
||||
}
|
||||
|
||||
// initUploadRequest is the request body for POST /media/upload/init.
|
||||
type initUploadRequest struct {
|
||||
Filename string `json:"filename" validate:"required"`
|
||||
ContentType string `json:"contentType" validate:"required"`
|
||||
Size int64 `json:"size"`
|
||||
}
|
||||
|
||||
// InitUpload returns a presigned URL for direct client-to-storage upload.
|
||||
// The metadata record is created in CompleteUpload after the file is actually stored.
|
||||
func (h *Media) InitUpload(w http.ResponseWriter, r *http.Request) error {
|
||||
var req initUploadRequest
|
||||
if err := app.BindAndValidate(r, &req); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
user := auth.GetUser(r.Context())
|
||||
userID := "anonymous"
|
||||
if user != nil {
|
||||
userID = user.ID
|
||||
// Validate MIME type against allowlist
|
||||
if !allowedMediaTypes[req.ContentType] {
|
||||
return httperror.BadRequest("unsupported file type: " + req.ContentType)
|
||||
}
|
||||
|
||||
// Validate file size if provided
|
||||
if req.Size > maxUploadSize {
|
||||
return httperror.BadRequest(fmt.Sprintf("file too large: %d bytes (max %d)", req.Size, maxUploadSize))
|
||||
}
|
||||
|
||||
user := auth.GetUser(r.Context())
|
||||
if user == nil {
|
||||
return httperror.Unauthorized("authentication required")
|
||||
}
|
||||
|
||||
// Sanitize filename to prevent path traversal
|
||||
safeName := sanitizeFilename(req.Filename)
|
||||
|
||||
// Build object path: media/{userID}/{uuid}/{filename}
|
||||
objectPath := fmt.Sprintf("media/%s/%s/%s", userID, uuid.New().String(), req.Filename)
|
||||
objectPath := fmt.Sprintf("media/%s/%s/%s", user.ID, uuid.New().String(), safeName)
|
||||
|
||||
presigned, err := h.store.UploadPresigned(r.Context(), objectPath, req.ContentType)
|
||||
if err != nil {
|
||||
@ -68,6 +126,7 @@ func (h *Media) InitUpload(w http.ResponseWriter, r *http.Request) error {
|
||||
httpresponse.OK(w, r, map[string]any{
|
||||
"uploadURL": presigned.URL,
|
||||
"objectPath": objectPath,
|
||||
"filename": safeName,
|
||||
"headers": presigned.Headers,
|
||||
"method": presigned.Method,
|
||||
"expires": presigned.Expires,
|
||||
@ -78,84 +137,236 @@ func (h *Media) InitUpload(w http.ResponseWriter, r *http.Request) error {
|
||||
// completeUploadRequest is the request body for POST /media/upload/complete.
|
||||
type completeUploadRequest struct {
|
||||
ObjectPath string `json:"objectPath" validate:"required"`
|
||||
Filename string `json:"filename"`
|
||||
ContentType string `json:"contentType"`
|
||||
Size int64 `json:"size"`
|
||||
}
|
||||
|
||||
// CompleteUpload confirms an upload is done and returns the final URL.
|
||||
// CompleteUpload confirms an upload is done, creates the metadata record, and returns the final URL.
|
||||
func (h *Media) CompleteUpload(w http.ResponseWriter, r *http.Request) error {
|
||||
var req completeUploadRequest
|
||||
if err := app.BindAndValidate(r, &req); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Verify the object path belongs to the authenticated user
|
||||
user := auth.GetUser(r.Context())
|
||||
if user == nil {
|
||||
return httperror.Unauthorized("authentication required")
|
||||
}
|
||||
expectedPrefix := fmt.Sprintf("media/%s/", user.ID)
|
||||
if !strings.HasPrefix(req.ObjectPath, expectedPrefix) {
|
||||
return httperror.Forbidden("cannot complete upload for another user's media")
|
||||
}
|
||||
|
||||
url, err := h.store.GetURL(r.Context(), req.ObjectPath)
|
||||
if err != nil {
|
||||
h.logger.Error("failed to get object URL", "error", err, "path", req.ObjectPath)
|
||||
return httperror.Internal("failed to confirm upload")
|
||||
}
|
||||
|
||||
// Create the metadata record now that the file is in storage.
|
||||
now := time.Now()
|
||||
filename := sanitizeFilename(req.Filename)
|
||||
if filename == "unnamed" {
|
||||
// Extract filename from the object path (last segment)
|
||||
parts := strings.Split(req.ObjectPath, "/")
|
||||
if len(parts) > 0 {
|
||||
filename = parts[len(parts)-1]
|
||||
}
|
||||
}
|
||||
|
||||
mediaObj := &domain.MediaObject{
|
||||
ID: domain.MediaObjectID("med_" + uuid.New().String()),
|
||||
UserID: domain.UserID(user.ID),
|
||||
Path: req.ObjectPath,
|
||||
Filename: filename,
|
||||
ContentType: req.ContentType,
|
||||
Size: req.Size,
|
||||
CreatedAt: now,
|
||||
UpdatedAt: now,
|
||||
}
|
||||
if err := h.repo.Create(r.Context(), mediaObj); err != nil {
|
||||
h.logger.Error("failed to create media record", "error", err)
|
||||
return httperror.Internal("failed to create upload record")
|
||||
}
|
||||
|
||||
httpresponse.OK(w, r, map[string]any{
|
||||
"id": string(mediaObj.ID),
|
||||
"url": url,
|
||||
"path": req.ObjectPath,
|
||||
})
|
||||
return nil
|
||||
}
|
||||
|
||||
// List returns the user's media objects.
|
||||
// List returns the user's media objects with pagination.
|
||||
func (h *Media) List(w http.ResponseWriter, r *http.Request) error {
|
||||
user := auth.GetUser(r.Context())
|
||||
userID := "anonymous"
|
||||
if user != nil {
|
||||
userID = user.ID
|
||||
if user == nil {
|
||||
return httperror.Unauthorized("authentication required")
|
||||
}
|
||||
|
||||
prefix := fmt.Sprintf("media/%s/", userID)
|
||||
|
||||
// Allow filtering by sub-prefix (e.g., ?prefix=images)
|
||||
if subPrefix := r.URL.Query().Get("prefix"); subPrefix != "" {
|
||||
prefix = fmt.Sprintf("media/%s/%s", userID, subPrefix)
|
||||
opts := port.ListMediaOptions{
|
||||
ContentTypePrefix: r.URL.Query().Get("type"),
|
||||
Limit: intQueryParam(r, "limit", 50),
|
||||
Offset: intQueryParam(r, "offset", 0),
|
||||
}
|
||||
|
||||
objects, err := h.store.List(r.Context(), prefix)
|
||||
objects, total, err := h.repo.ListByUser(r.Context(), domain.UserID(user.ID), opts)
|
||||
if err != nil {
|
||||
h.logger.Error("failed to list media", "error", err)
|
||||
return httperror.Internal("failed to list media")
|
||||
}
|
||||
|
||||
if objects == nil {
|
||||
objects = []storage.MediaObject{}
|
||||
// Enrich each object with a fresh signed URL
|
||||
type mediaItem struct {
|
||||
ID string `json:"id"`
|
||||
Path string `json:"path"`
|
||||
URL string `json:"url"`
|
||||
Filename string `json:"filename"`
|
||||
ContentType string `json:"contentType"`
|
||||
Size int64 `json:"size"`
|
||||
CreatedAt time.Time `json:"createdAt"`
|
||||
}
|
||||
|
||||
items := make([]mediaItem, 0, len(objects))
|
||||
for _, obj := range objects {
|
||||
url, urlErr := h.store.GetURL(r.Context(), obj.Path)
|
||||
if urlErr != nil {
|
||||
h.logger.Warn("failed to get URL for media object", "path", obj.Path, "error", urlErr)
|
||||
continue
|
||||
}
|
||||
items = append(items, mediaItem{
|
||||
ID: string(obj.ID),
|
||||
Path: obj.Path,
|
||||
URL: url,
|
||||
Filename: obj.Filename,
|
||||
ContentType: obj.ContentType,
|
||||
Size: obj.Size,
|
||||
CreatedAt: obj.CreatedAt,
|
||||
})
|
||||
}
|
||||
|
||||
httpresponse.OK(w, r, map[string]any{
|
||||
"items": objects,
|
||||
"count": len(objects),
|
||||
"items": items,
|
||||
"total": total,
|
||||
"count": len(items),
|
||||
})
|
||||
return nil
|
||||
}
|
||||
|
||||
// Delete removes a media object.
|
||||
// Users can only delete objects under their own media/{userID}/ prefix.
|
||||
func (h *Media) Delete(w http.ResponseWriter, r *http.Request) error {
|
||||
// Extract path from URL (everything after /media/)
|
||||
path := strings.TrimPrefix(r.URL.Path, "/")
|
||||
if path == "" {
|
||||
return httperror.BadRequest("path is required")
|
||||
// GetOne returns a single media object with a fresh URL.
|
||||
func (h *Media) GetOne(w http.ResponseWriter, r *http.Request) error {
|
||||
id := chi.URLParam(r, "id")
|
||||
if id == "" {
|
||||
return httperror.BadRequest("media ID is required")
|
||||
}
|
||||
|
||||
// Verify the path belongs to the authenticated user
|
||||
user := auth.GetUser(r.Context())
|
||||
if user == nil {
|
||||
return httperror.Unauthorized("authentication required")
|
||||
obj, err := h.repo.Get(r.Context(), domain.MediaObjectID(id))
|
||||
if err != nil {
|
||||
if errors.Is(err, domain.ErrNotFound) {
|
||||
return httperror.NotFound("media object not found")
|
||||
}
|
||||
expectedPrefix := fmt.Sprintf("media/%s/", user.ID)
|
||||
if !strings.HasPrefix(path, expectedPrefix) {
|
||||
return httperror.Internal("failed to get media object")
|
||||
}
|
||||
|
||||
// Verify ownership
|
||||
user := auth.GetUser(r.Context())
|
||||
if user == nil || domain.UserID(user.ID) != obj.UserID {
|
||||
return httperror.Forbidden("access denied")
|
||||
}
|
||||
|
||||
url, err := h.store.GetURL(r.Context(), obj.Path)
|
||||
if err != nil {
|
||||
h.logger.Error("failed to get URL", "error", err, "path", obj.Path)
|
||||
return httperror.Internal("failed to get media URL")
|
||||
}
|
||||
|
||||
httpresponse.OK(w, r, map[string]any{
|
||||
"id": string(obj.ID),
|
||||
"path": obj.Path,
|
||||
"url": url,
|
||||
"filename": obj.Filename,
|
||||
"contentType": obj.ContentType,
|
||||
"size": obj.Size,
|
||||
"createdAt": obj.CreatedAt,
|
||||
})
|
||||
return nil
|
||||
}
|
||||
|
||||
// RefreshURL returns a fresh signed URL for a media object.
|
||||
func (h *Media) RefreshURL(w http.ResponseWriter, r *http.Request) error {
|
||||
id := chi.URLParam(r, "id")
|
||||
if id == "" {
|
||||
return httperror.BadRequest("media ID is required")
|
||||
}
|
||||
|
||||
obj, err := h.repo.Get(r.Context(), domain.MediaObjectID(id))
|
||||
if err != nil {
|
||||
if errors.Is(err, domain.ErrNotFound) {
|
||||
return httperror.NotFound("media object not found")
|
||||
}
|
||||
return httperror.Internal("failed to get media object")
|
||||
}
|
||||
|
||||
// Verify ownership
|
||||
user := auth.GetUser(r.Context())
|
||||
if user == nil || domain.UserID(user.ID) != obj.UserID {
|
||||
return httperror.Forbidden("access denied")
|
||||
}
|
||||
|
||||
url, err := h.store.GetURL(r.Context(), obj.Path)
|
||||
if err != nil {
|
||||
h.logger.Error("failed to refresh URL", "error", err, "path", obj.Path)
|
||||
return httperror.Internal("failed to refresh media URL")
|
||||
}
|
||||
|
||||
httpresponse.OK(w, r, map[string]any{
|
||||
"id": string(obj.ID),
|
||||
"url": url,
|
||||
})
|
||||
return nil
|
||||
}
|
||||
|
||||
// Delete soft-deletes a media object.
|
||||
func (h *Media) Delete(w http.ResponseWriter, r *http.Request) error {
|
||||
id := chi.URLParam(r, "id")
|
||||
if id == "" {
|
||||
return httperror.BadRequest("media ID is required")
|
||||
}
|
||||
|
||||
obj, err := h.repo.Get(r.Context(), domain.MediaObjectID(id))
|
||||
if err != nil {
|
||||
if errors.Is(err, domain.ErrNotFound) {
|
||||
return httperror.NotFound("media object not found")
|
||||
}
|
||||
return httperror.Internal("failed to get media object")
|
||||
}
|
||||
|
||||
// Verify ownership
|
||||
user := auth.GetUser(r.Context())
|
||||
if user == nil || domain.UserID(user.ID) != obj.UserID {
|
||||
return httperror.Forbidden("cannot delete another user's media")
|
||||
}
|
||||
|
||||
if err := h.store.Delete(r.Context(), path); err != nil {
|
||||
h.logger.Error("failed to delete media", "error", err, "path", path)
|
||||
if err := h.repo.SoftDelete(r.Context(), domain.MediaObjectID(id)); err != nil {
|
||||
h.logger.Error("failed to delete media", "error", err, "id", id)
|
||||
return httperror.Internal("failed to delete media")
|
||||
}
|
||||
|
||||
httpresponse.OK(w, r, map[string]any{"deleted": path})
|
||||
httpresponse.OK(w, r, map[string]any{"deleted": id})
|
||||
return nil
|
||||
}
|
||||
|
||||
// intQueryParam parses an integer query parameter with a default value.
|
||||
func intQueryParam(r *http.Request, key string, defaultVal int) int {
|
||||
val := r.URL.Query().Get(key)
|
||||
if val == "" {
|
||||
return defaultVal
|
||||
}
|
||||
var n int
|
||||
if _, err := fmt.Sscanf(val, "%d", &n); err != nil || n < 0 {
|
||||
return defaultVal
|
||||
}
|
||||
return n
|
||||
}
|
||||
|
||||
@ -2,13 +2,17 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"{{GO_MODULE}}/pkg/app"
|
||||
"{{GO_MODULE}}/pkg/auth"
|
||||
"{{GO_MODULE}}/pkg/middleware"
|
||||
"{{GO_MODULE}}/pkg/queue"
|
||||
"{{GO_MODULE}}/pkg/realtime"
|
||||
"{{GO_MODULE}}/pkg/storage"
|
||||
"{{GO_MODULE}}/services/{{COMPONENT_NAME}}/internal/api/handlers"
|
||||
"{{GO_MODULE}}/services/{{COMPONENT_NAME}}/internal/config"
|
||||
"{{GO_MODULE}}/services/{{COMPONENT_NAME}}/internal/port"
|
||||
"{{GO_MODULE}}/services/{{COMPONENT_NAME}}/internal/service"
|
||||
)
|
||||
|
||||
@ -26,9 +30,9 @@ func RegisterRoutes(application *app.App, deps *Dependencies) {
|
||||
healthHandler := handlers.NewHealth(logger)
|
||||
exampleHandler := handlers.NewExample(deps.ExampleService, logger)
|
||||
authHandler := handlers.NewAuth(deps.AuthService, logger)
|
||||
generateHandler := handlers.NewGenerate(deps.Queue, deps.SSEHub, logger)
|
||||
generateHandler := handlers.NewGenerate(deps.Queue, deps.JobReader, deps.SSEHub, logger)
|
||||
chatHandler := handlers.NewChat(deps.Queue, deps.SSEHub, logger)
|
||||
mediaHandler := handlers.NewMedia(deps.Store, logger)
|
||||
mediaHandler := handlers.NewMedia(deps.Store, deps.MediaRepo, logger)
|
||||
|
||||
// Build and mount OpenAPI spec
|
||||
spec := NewServiceSpec()
|
||||
@ -45,17 +49,56 @@ func RegisterRoutes(application *app.App, deps *Dependencies) {
|
||||
application.Route("/api/{{COMPONENT_NAME}}", func(r app.Router) {
|
||||
r.Get("/health", healthHandler.Check)
|
||||
|
||||
// ----- Auth routes -----
|
||||
// Public auth routes
|
||||
r.Post("/auth/login", app.Wrap(authHandler.Login))
|
||||
r.Post("/auth/logout", app.Wrap(authHandler.Logout))
|
||||
// ----- Public auth routes (rate-limited) -----
|
||||
// Auth attempts: 20/min per IP (login, register, verify, reset).
|
||||
authAttemptLimit := middleware.RateLimit(middleware.RateLimitConfig{Requests: 20, Window: time.Minute})
|
||||
// Code sends: 5/min per IP (prevents email bombing via OTP/magic-link/forgot-password).
|
||||
codeSendLimit := middleware.RateLimit(middleware.RateLimitConfig{Requests: 5, Window: time.Minute})
|
||||
|
||||
// Protected auth routes
|
||||
r.Group(func(r app.Router) {
|
||||
r.Use(authAttemptLimit)
|
||||
r.Post("/auth/login", app.Wrap(authHandler.Login))
|
||||
r.Post("/auth/register", app.Wrap(authHandler.Register))
|
||||
r.Post("/auth/otp/verify", app.Wrap(authHandler.VerifyOTP))
|
||||
r.Post("/auth/magic-link/verify", app.Wrap(authHandler.VerifyMagicLink))
|
||||
r.Post("/auth/reset-password", app.Wrap(authHandler.ResetPassword))
|
||||
})
|
||||
r.Group(func(r app.Router) {
|
||||
r.Use(codeSendLimit)
|
||||
r.Post("/auth/otp/send", app.Wrap(authHandler.SendOTP))
|
||||
r.Post("/auth/magic-link", app.Wrap(authHandler.SendMagicLink))
|
||||
r.Post("/auth/forgot-password", app.Wrap(authHandler.ForgotPassword))
|
||||
})
|
||||
|
||||
// Refresh accepts expired tokens (still validates signature).
|
||||
// The service layer checks session validity to prevent abuse.
|
||||
r.Group(func(r app.Router) {
|
||||
r.Use(auth.Middleware(auth.MiddlewareConfig{
|
||||
Validator: jwtValidator,
|
||||
AllowExpired: true,
|
||||
}))
|
||||
r.Post("/auth/refresh", app.Wrap(authHandler.RefreshToken))
|
||||
})
|
||||
|
||||
// Session checker for revocation enforcement.
|
||||
sessionChecker := deps.AuthService.CheckSession
|
||||
|
||||
// ----- Protected auth routes -----
|
||||
r.Group(func(r app.Router) {
|
||||
r.Use(auth.Middleware(auth.MiddlewareConfig{
|
||||
Validator: jwtValidator,
|
||||
}))
|
||||
r.Use(auth.SessionCheck(sessionChecker))
|
||||
|
||||
r.Get("/auth/me", app.Wrap(authHandler.Me))
|
||||
r.Put("/auth/me", app.Wrap(authHandler.UpdateMe))
|
||||
r.Post("/auth/change-password", app.Wrap(authHandler.ChangePassword))
|
||||
r.Post("/auth/logout", app.Wrap(authHandler.Logout))
|
||||
r.Post("/auth/verify-email/send", app.Wrap(authHandler.SendVerifyEmail))
|
||||
r.Post("/auth/verify-email", app.Wrap(authHandler.VerifyEmail))
|
||||
r.Get("/auth/sessions", app.Wrap(authHandler.ListSessions))
|
||||
r.Delete("/auth/sessions", app.Wrap(authHandler.RevokeAllSessions))
|
||||
r.Delete("/auth/sessions/{id}", app.Wrap(authHandler.RevokeSession))
|
||||
})
|
||||
|
||||
// ----- SSE Events -----
|
||||
@ -87,6 +130,7 @@ func RegisterRoutes(application *app.App, deps *Dependencies) {
|
||||
r.Use(auth.Middleware(auth.MiddlewareConfig{
|
||||
Validator: jwtValidator,
|
||||
}))
|
||||
r.Use(auth.SessionCheck(sessionChecker))
|
||||
|
||||
// Chat messaging
|
||||
r.Post("/chat/messages", app.Wrap(chatHandler.SendMessage))
|
||||
@ -95,6 +139,7 @@ func RegisterRoutes(application *app.App, deps *Dependencies) {
|
||||
r.Post("/generate/image", app.Wrap(generateHandler.GenerateImage))
|
||||
r.Post("/generate/video", app.Wrap(generateHandler.GenerateVideo))
|
||||
r.Post("/generate/text", app.Wrap(generateHandler.GenerateText))
|
||||
r.Get("/generate/jobs/{id}", app.Wrap(generateHandler.GetJobStatus))
|
||||
|
||||
// Media library (upload, list, delete)
|
||||
r.Mount("/media", mediaHandler.Routes())
|
||||
@ -107,6 +152,8 @@ type Dependencies struct {
|
||||
ExampleService *service.ExampleService
|
||||
AuthService *service.AuthService
|
||||
Queue queue.Producer
|
||||
JobReader queue.JobReader
|
||||
SSEHub *realtime.SSEHub
|
||||
Store storage.Store
|
||||
MediaRepo port.MediaRepository
|
||||
}
|
||||
|
||||
@ -18,13 +18,26 @@ type Config struct {
|
||||
// Auth
|
||||
AuthEnabled bool
|
||||
JWTSecret string
|
||||
RegistrationEnabled bool
|
||||
|
||||
// Redis for cross-process SSE event delivery
|
||||
RedisURL string
|
||||
|
||||
// Notify service for email delivery (OTP, magic links, password reset, etc.)
|
||||
// When NotifyURL is empty, emails are logged to stdout (dev mode).
|
||||
NotifyURL string
|
||||
NotifyAPIKey string
|
||||
NotifyHost string
|
||||
NotifyFrom string
|
||||
}
|
||||
|
||||
// Load reads configuration from environment variables.
|
||||
func Load() *Config {
|
||||
regEnabled := true
|
||||
if v := os.Getenv("REGISTRATION_ENABLED"); v != "" {
|
||||
regEnabled = strings.EqualFold(v, "true")
|
||||
}
|
||||
|
||||
return &Config{
|
||||
AppConfig: config.ReadAppConfig(),
|
||||
Server: config.ReadServerConfig(),
|
||||
@ -33,6 +46,19 @@ func Load() *Config {
|
||||
|
||||
AuthEnabled: strings.EqualFold(os.Getenv("AUTH_ENABLED"), "true"),
|
||||
JWTSecret: os.Getenv("JWT_SECRET"),
|
||||
RegistrationEnabled: regEnabled,
|
||||
RedisURL: os.Getenv("REDIS_URL"),
|
||||
|
||||
NotifyURL: os.Getenv("NOTIFY_URL"),
|
||||
NotifyAPIKey: os.Getenv("NOTIFY_API_KEY"),
|
||||
NotifyHost: os.Getenv("NOTIFY_HOST"),
|
||||
NotifyFrom: getEnvDefault("NOTIFY_FROM", "noreply@{{PROJECT_NAME}}.com"),
|
||||
}
|
||||
}
|
||||
|
||||
func getEnvDefault(key, defaultVal string) string {
|
||||
if v := os.Getenv(key); v != "" {
|
||||
return v
|
||||
}
|
||||
return defaultVal
|
||||
}
|
||||
|
||||
@ -0,0 +1,32 @@
|
||||
package domain
|
||||
|
||||
import "time"
|
||||
|
||||
// AuthCodePurpose identifies what an auth code is used for.
|
||||
type AuthCodePurpose string
|
||||
|
||||
const (
|
||||
PurposeLoginOTP AuthCodePurpose = "login_otp"
|
||||
PurposeMagicLink AuthCodePurpose = "magic_link"
|
||||
PurposePasswordReset AuthCodePurpose = "password_reset"
|
||||
PurposeEmailVerify AuthCodePurpose = "email_verify"
|
||||
)
|
||||
|
||||
// AuthCode is a single-use, time-limited code for authentication flows.
|
||||
// Used by OTP login, magic links, password reset, and email verification.
|
||||
type AuthCode struct {
|
||||
ID string
|
||||
UserID *UserID // Nullable for magic link signup
|
||||
Email string
|
||||
Code string
|
||||
Purpose AuthCodePurpose
|
||||
ExpiresAt time.Time
|
||||
UsedAt *time.Time
|
||||
IPAddress string
|
||||
CreatedAt time.Time
|
||||
}
|
||||
|
||||
// IsValid returns true if the code has not been used and has not expired.
|
||||
func (c *AuthCode) IsValid() bool {
|
||||
return c.UsedAt == nil && time.Now().Before(c.ExpiresAt)
|
||||
}
|
||||
@ -18,4 +18,19 @@ var (
|
||||
|
||||
// ErrInvalidExampleName indicates the example name is invalid.
|
||||
ErrInvalidExampleName = errors.New("invalid example name")
|
||||
|
||||
// Auth errors
|
||||
ErrUserNotFound = errors.New("user not found")
|
||||
ErrDuplicateEmail = errors.New("email already registered")
|
||||
ErrInvalidCredentials = errors.New("invalid email or password")
|
||||
ErrSessionNotFound = errors.New("session not found")
|
||||
ErrSessionRevoked = errors.New("session has been revoked")
|
||||
ErrInvalidAuthCode = errors.New("invalid or expired code")
|
||||
ErrExpiredAuthCode = errors.New("code has expired")
|
||||
ErrWeakPassword = errors.New("password does not meet requirements")
|
||||
ErrUserSuspended = errors.New("account is suspended")
|
||||
ErrRegistrationDisabled = errors.New("registration is disabled")
|
||||
ErrNameTooLong = errors.New("name exceeds maximum length")
|
||||
ErrEmailTooLong = errors.New("email exceeds maximum length")
|
||||
ErrInvalidAvatarURL = errors.New("avatar URL must use http or https")
|
||||
)
|
||||
|
||||
@ -0,0 +1,27 @@
|
||||
package domain
|
||||
|
||||
import "time"
|
||||
|
||||
// MediaObjectID is a typed media object identifier with prefix "med_".
|
||||
type MediaObjectID string
|
||||
|
||||
// MediaObject tracks a stored media file with ownership and metadata.
|
||||
// The actual file is stored in GCS (production) or MemoryStore (dev).
|
||||
// This record enables querying, soft deletes, and provenance tracking.
|
||||
type MediaObject struct {
|
||||
ID MediaObjectID
|
||||
UserID UserID
|
||||
Path string // Storage path (e.g., "media/usr_123/uuid/photo.png")
|
||||
Filename string // Original filename
|
||||
ContentType string
|
||||
Size int64
|
||||
GenerationJobID string // Non-empty if created by AI generation
|
||||
DeletedAt *time.Time
|
||||
CreatedAt time.Time
|
||||
UpdatedAt time.Time
|
||||
}
|
||||
|
||||
// IsDeleted returns true if the media object has been soft-deleted.
|
||||
func (m *MediaObject) IsDeleted() bool {
|
||||
return m.DeletedAt != nil
|
||||
}
|
||||
@ -0,0 +1,25 @@
|
||||
package domain
|
||||
|
||||
import "time"
|
||||
|
||||
// SessionID is a typed session identifier with prefix "ses_".
|
||||
type SessionID string
|
||||
|
||||
// Session tracks a user login with device and location information.
|
||||
// The session ID is embedded in the JWT token for revocation support.
|
||||
type Session struct {
|
||||
ID SessionID
|
||||
UserID UserID
|
||||
IPAddress string
|
||||
UserAgent string
|
||||
DeviceLabel string
|
||||
LastActiveAt time.Time
|
||||
ExpiresAt time.Time
|
||||
RevokedAt *time.Time
|
||||
CreatedAt time.Time
|
||||
}
|
||||
|
||||
// IsActive returns true if the session has not been revoked and has not expired.
|
||||
func (s *Session) IsActive() bool {
|
||||
return s.RevokedAt == nil && time.Now().Before(s.ExpiresAt)
|
||||
}
|
||||
@ -0,0 +1,52 @@
|
||||
package domain
|
||||
|
||||
import "time"
|
||||
|
||||
// UserID is a typed user identifier with prefix "usr_".
|
||||
type UserID string
|
||||
|
||||
// UserStatus represents the account state.
|
||||
type UserStatus string
|
||||
|
||||
const (
|
||||
UserStatusActive UserStatus = "active"
|
||||
UserStatusSuspended UserStatus = "suspended"
|
||||
UserStatusDeactivated UserStatus = "deactivated"
|
||||
)
|
||||
|
||||
// User is the full domain model for a registered user.
|
||||
// This is the database-backed identity, separate from auth.User which is the
|
||||
// lightweight JWT-derived identity carried in request context.
|
||||
type User struct {
|
||||
ID UserID
|
||||
Email string
|
||||
EmailVerified bool
|
||||
Name string
|
||||
AvatarURL string
|
||||
Status UserStatus
|
||||
Roles []string
|
||||
LastLoginAt *time.Time
|
||||
CreatedAt time.Time
|
||||
UpdatedAt time.Time
|
||||
}
|
||||
|
||||
// Validation constants for user fields.
|
||||
const (
|
||||
MaxNameLen = 100
|
||||
MaxEmailLen = 254 // RFC 5321
|
||||
)
|
||||
|
||||
// NewUser creates a new user with default values.
|
||||
func NewUser(id UserID, email, name string) *User {
|
||||
now := time.Now()
|
||||
return &User{
|
||||
ID: id,
|
||||
Email: email,
|
||||
EmailVerified: false,
|
||||
Name: name,
|
||||
Status: UserStatusActive,
|
||||
Roles: []string{"user"},
|
||||
CreatedAt: now,
|
||||
UpdatedAt: now,
|
||||
}
|
||||
}
|
||||
@ -0,0 +1,24 @@
|
||||
package port
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"{{GO_MODULE}}/services/{{COMPONENT_NAME}}/internal/domain"
|
||||
)
|
||||
|
||||
// AuthCodeRepository defines the interface for auth code persistence.
|
||||
type AuthCodeRepository interface {
|
||||
// Create persists a new auth code.
|
||||
Create(ctx context.Context, code *domain.AuthCode) error
|
||||
|
||||
// FindValid returns an unused, non-expired code matching the criteria.
|
||||
// Returns domain.ErrInvalidAuthCode if no valid code exists.
|
||||
FindValid(ctx context.Context, email string, code string, purpose domain.AuthCodePurpose) (*domain.AuthCode, error)
|
||||
|
||||
// MarkUsed sets the used_at timestamp on a code, making it single-use.
|
||||
MarkUsed(ctx context.Context, id string) error
|
||||
|
||||
// DeleteExpired removes codes that have passed their expiry time.
|
||||
// Returns the number of codes deleted.
|
||||
DeleteExpired(ctx context.Context) (int, error)
|
||||
}
|
||||
@ -0,0 +1,11 @@
|
||||
package port
|
||||
|
||||
import "context"
|
||||
|
||||
// EmailSender sends emails for authentication flows (OTP, magic link, password reset, etc.).
|
||||
type EmailSender interface {
|
||||
// SendAuthCode sends an authentication code to the given email.
|
||||
// purpose identifies the flow (e.g. "login_otp", "magic_link", "password_reset", "email_verify").
|
||||
// code is the token or OTP to include in the email.
|
||||
SendAuthCode(ctx context.Context, email, code, purpose string) error
|
||||
}
|
||||
@ -0,0 +1,38 @@
|
||||
package port
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"{{GO_MODULE}}/services/{{COMPONENT_NAME}}/internal/domain"
|
||||
)
|
||||
|
||||
// MediaRepository defines the interface for media metadata persistence.
|
||||
type MediaRepository interface {
|
||||
// Create persists a new media object record.
|
||||
Create(ctx context.Context, obj *domain.MediaObject) error
|
||||
|
||||
// Get returns a media object by ID. Returns domain.ErrNotFound if not found or soft-deleted.
|
||||
Get(ctx context.Context, id domain.MediaObjectID) (*domain.MediaObject, error)
|
||||
|
||||
// ListByUser returns non-deleted media objects for a user, ordered by created_at DESC.
|
||||
ListByUser(ctx context.Context, userID domain.UserID, opts ListMediaOptions) ([]domain.MediaObject, int, error)
|
||||
|
||||
// SoftDelete marks a media object as deleted without removing it.
|
||||
SoftDelete(ctx context.Context, id domain.MediaObjectID) error
|
||||
|
||||
// HardDelete permanently removes a media object record.
|
||||
HardDelete(ctx context.Context, id domain.MediaObjectID) error
|
||||
|
||||
// GetByPath returns a media object by its storage path. Returns domain.ErrNotFound if not found.
|
||||
GetByPath(ctx context.Context, path string) (*domain.MediaObject, error)
|
||||
}
|
||||
|
||||
// ListMediaOptions controls filtering and pagination for media queries.
|
||||
type ListMediaOptions struct {
|
||||
// ContentTypePrefix filters by MIME type prefix (e.g., "image/", "video/").
|
||||
ContentTypePrefix string
|
||||
// Limit is the maximum number of results (0 = default 50).
|
||||
Limit int
|
||||
// Offset is the pagination offset.
|
||||
Offset int
|
||||
}
|
||||
@ -0,0 +1,33 @@
|
||||
package port
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"{{GO_MODULE}}/services/{{COMPONENT_NAME}}/internal/domain"
|
||||
)
|
||||
|
||||
// SessionRepository defines the interface for session persistence.
|
||||
type SessionRepository interface {
|
||||
// Create persists a new session record.
|
||||
Create(ctx context.Context, session *domain.Session) error
|
||||
|
||||
// Get returns a session by ID. Returns domain.ErrSessionNotFound if not found.
|
||||
Get(ctx context.Context, id domain.SessionID) (*domain.Session, error)
|
||||
|
||||
// ListByUser returns all active (non-revoked) sessions for a user.
|
||||
ListByUser(ctx context.Context, userID domain.UserID) ([]domain.Session, error)
|
||||
|
||||
// UpdateLastActive updates the last_active_at timestamp for a session.
|
||||
UpdateLastActive(ctx context.Context, id domain.SessionID) error
|
||||
|
||||
// Revoke marks a session as revoked by setting revoked_at.
|
||||
Revoke(ctx context.Context, id domain.SessionID) error
|
||||
|
||||
// RevokeAllForUser revokes all sessions for a user.
|
||||
// If exceptID is non-nil, that session is kept active.
|
||||
RevokeAllForUser(ctx context.Context, userID domain.UserID, exceptID *domain.SessionID) error
|
||||
|
||||
// DeleteExpired removes sessions that have passed their expiry time.
|
||||
// Returns the number of sessions deleted.
|
||||
DeleteExpired(ctx context.Context) (int, error)
|
||||
}
|
||||
@ -3,21 +3,49 @@ package port
|
||||
import (
|
||||
"context"
|
||||
|
||||
"{{GO_MODULE}}/pkg/auth"
|
||||
"{{GO_MODULE}}/services/{{COMPONENT_NAME}}/internal/domain"
|
||||
)
|
||||
|
||||
// UserRepository defines the interface for user lookup operations.
|
||||
// Used by AuthService for authentication.
|
||||
// UserRepository defines the interface for user persistence.
|
||||
type UserRepository interface {
|
||||
// FindByEmail returns a user by email address.
|
||||
// Returns nil if not found (no error).
|
||||
FindByEmail(ctx context.Context, email string) (*auth.User, error)
|
||||
// Create persists a new user.
|
||||
Create(ctx context.Context, user *domain.User) error
|
||||
|
||||
// FindByID returns a user by ID.
|
||||
// Returns nil if not found (no error).
|
||||
FindByID(ctx context.Context, id string) (*auth.User, error)
|
||||
// Get returns a user by ID. Returns domain.ErrUserNotFound if not found.
|
||||
Get(ctx context.Context, id domain.UserID) (*domain.User, error)
|
||||
|
||||
// ValidatePassword checks if the password matches for a user.
|
||||
// Returns true if valid, false otherwise.
|
||||
ValidatePassword(ctx context.Context, user *auth.User, password string) bool
|
||||
// GetByEmail returns a user by email. Returns domain.ErrUserNotFound if not found.
|
||||
GetByEmail(ctx context.Context, email string) (*domain.User, error)
|
||||
|
||||
// Update persists changes to an existing user.
|
||||
Update(ctx context.Context, user *domain.User) error
|
||||
|
||||
// UpdateLastLogin sets the last_login_at timestamp.
|
||||
UpdateLastLogin(ctx context.Context, id domain.UserID) error
|
||||
|
||||
// ExistsByEmail returns true if a user with the given email exists.
|
||||
ExistsByEmail(ctx context.Context, email string) (bool, error)
|
||||
|
||||
// Password operations (separate from user CRUD because OAuth-only users have no password)
|
||||
|
||||
// SetPassword stores a bcrypt hash for a user. Creates or replaces existing.
|
||||
SetPassword(ctx context.Context, userID domain.UserID, hash string) error
|
||||
|
||||
// GetPasswordHash returns the bcrypt hash for a user.
|
||||
// Returns empty string and nil error if user has no password set.
|
||||
GetPasswordHash(ctx context.Context, userID domain.UserID) (string, error)
|
||||
|
||||
// HasPassword returns true if the user has a password set.
|
||||
HasPassword(ctx context.Context, userID domain.UserID) (bool, error)
|
||||
|
||||
// Role operations
|
||||
|
||||
// AddRole grants a role to a user. No-op if already granted.
|
||||
AddRole(ctx context.Context, userID domain.UserID, role string) error
|
||||
|
||||
// RemoveRole revokes a role from a user. No-op if not granted.
|
||||
RemoveRole(ctx context.Context, userID domain.UserID, role string) error
|
||||
|
||||
// GetRoles returns all roles for a user.
|
||||
GetRoles(ctx context.Context, userID domain.UserID) ([]string, error)
|
||||
}
|
||||
|
||||
@ -2,96 +2,603 @@ package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
"fmt"
|
||||
"math/big"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"{{GO_MODULE}}/pkg/auth"
|
||||
"{{GO_MODULE}}/pkg/logging"
|
||||
"{{GO_MODULE}}/services/{{COMPONENT_NAME}}/internal/domain"
|
||||
"{{GO_MODULE}}/services/{{COMPONENT_NAME}}/internal/port"
|
||||
)
|
||||
|
||||
// Auth errors.
|
||||
var (
|
||||
ErrInvalidCredentials = errors.New("invalid email or password")
|
||||
ErrUserNotFound = errors.New("user not found")
|
||||
const (
|
||||
// TokenLifetime is the access token duration (short-lived, requires refresh).
|
||||
TokenLifetime = 15 * time.Minute
|
||||
// SessionLifetime is how long a session stays valid before requiring re-login.
|
||||
SessionLifetime = 30 * 24 * time.Hour // 30 days
|
||||
// OTPExpiry is how long a one-time password is valid.
|
||||
OTPExpiry = 10 * time.Minute
|
||||
// MagicLinkExpiry is how long a magic link token is valid.
|
||||
MagicLinkExpiry = 15 * time.Minute
|
||||
// PasswordResetExpiry is how long a password reset token is valid.
|
||||
PasswordResetExpiry = 1 * time.Hour
|
||||
// EmailVerifyExpiry is how long an email verification code is valid.
|
||||
EmailVerifyExpiry = 24 * time.Hour
|
||||
)
|
||||
|
||||
// AuthService handles authentication logic.
|
||||
// AuthService handles all authentication and identity flows.
|
||||
type AuthService struct {
|
||||
userRepo port.UserRepository
|
||||
users port.UserRepository
|
||||
sessions port.SessionRepository
|
||||
codes port.AuthCodeRepository
|
||||
email port.EmailSender
|
||||
jwtSecret []byte
|
||||
issuer string
|
||||
registrationEnabled bool
|
||||
logger *logging.Logger
|
||||
}
|
||||
|
||||
// NewAuthService creates a new auth service.
|
||||
func NewAuthService(userRepo port.UserRepository, jwtSecret string, logger *logging.Logger) *AuthService {
|
||||
func NewAuthService(
|
||||
users port.UserRepository,
|
||||
sessions port.SessionRepository,
|
||||
codes port.AuthCodeRepository,
|
||||
email port.EmailSender,
|
||||
jwtSecret string,
|
||||
registrationEnabled bool,
|
||||
logger *logging.Logger,
|
||||
) *AuthService {
|
||||
return &AuthService{
|
||||
userRepo: userRepo,
|
||||
users: users,
|
||||
sessions: sessions,
|
||||
codes: codes,
|
||||
email: email,
|
||||
jwtSecret: []byte(jwtSecret),
|
||||
issuer: "{{PROJECT_NAME}}",
|
||||
registrationEnabled: registrationEnabled,
|
||||
logger: logger.WithService("AuthService"),
|
||||
}
|
||||
}
|
||||
|
||||
// LoginInput contains the data needed to log in.
|
||||
type LoginInput struct {
|
||||
Email string
|
||||
Password string
|
||||
}
|
||||
|
||||
// LoginOutput contains the login result.
|
||||
// LoginOutput is the result of a successful login or registration.
|
||||
type LoginOutput struct {
|
||||
Token string
|
||||
User *auth.User
|
||||
User *domain.User
|
||||
}
|
||||
|
||||
// Login authenticates a user and returns a JWT token.
|
||||
func (s *AuthService) Login(ctx context.Context, input LoginInput) (*LoginOutput, error) {
|
||||
// Find user by email
|
||||
user, err := s.userRepo.FindByEmail(ctx, input.Email)
|
||||
// Register creates a new user account with email and password.
|
||||
func (s *AuthService) Register(ctx context.Context, email, password, name, ip, userAgent string) (*LoginOutput, error) {
|
||||
if !s.registrationEnabled {
|
||||
return nil, domain.ErrRegistrationDisabled
|
||||
}
|
||||
|
||||
if err := auth.ValidatePasswordStrength(password); err != nil {
|
||||
return nil, fmt.Errorf("%w: %w", domain.ErrWeakPassword, err)
|
||||
}
|
||||
|
||||
name = strings.TrimSpace(name)
|
||||
if len(name) > domain.MaxNameLen {
|
||||
return nil, domain.ErrNameTooLong
|
||||
}
|
||||
if len(email) > domain.MaxEmailLen {
|
||||
return nil, domain.ErrEmailTooLong
|
||||
}
|
||||
|
||||
exists, err := s.users.ExistsByEmail(ctx, email)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if user == nil {
|
||||
s.logger.Warn("login attempt for unknown email", "email", input.Email)
|
||||
return nil, ErrInvalidCredentials
|
||||
if exists {
|
||||
return nil, domain.ErrDuplicateEmail
|
||||
}
|
||||
|
||||
// Validate password
|
||||
if !s.userRepo.ValidatePassword(ctx, user, input.Password) {
|
||||
s.logger.Warn("invalid password attempt", "email", input.Email)
|
||||
return nil, ErrInvalidCredentials
|
||||
hash, err := auth.HashPassword(password)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("hashing password: %w", err)
|
||||
}
|
||||
|
||||
// Generate JWT token
|
||||
token, err := auth.GenerateTokenWithIssuer(
|
||||
s.jwtSecret,
|
||||
user,
|
||||
24*time.Hour, // 24 hour expiration
|
||||
s.issuer,
|
||||
s.issuer, // audience = issuer for simplicity
|
||||
)
|
||||
userID := domain.UserID("usr_" + generateID())
|
||||
user := domain.NewUser(userID, email, name)
|
||||
|
||||
if err := s.users.Create(ctx, user); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := s.users.SetPassword(ctx, userID, hash); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
s.logger.Info("user registered", "user_id", string(userID), "email", email)
|
||||
|
||||
return s.createSession(ctx, user, ip, userAgent)
|
||||
}
|
||||
|
||||
// LoginWithPassword authenticates a user with email and password.
|
||||
func (s *AuthService) LoginWithPassword(ctx context.Context, email, password, ip, userAgent string) (*LoginOutput, error) {
|
||||
user, err := s.users.GetByEmail(ctx, email)
|
||||
if err != nil {
|
||||
if errors.Is(err, domain.ErrUserNotFound) {
|
||||
return nil, domain.ErrInvalidCredentials
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if user.Status == domain.UserStatusSuspended {
|
||||
return nil, domain.ErrUserSuspended
|
||||
}
|
||||
|
||||
hash, err := s.users.GetPasswordHash(ctx, user.ID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if hash == "" || !auth.CheckPassword(password, hash) {
|
||||
s.logger.Warn("invalid password attempt", "email", email)
|
||||
return nil, domain.ErrInvalidCredentials
|
||||
}
|
||||
|
||||
_ = s.users.UpdateLastLogin(ctx, user.ID)
|
||||
s.logger.Info("user logged in", "user_id", string(user.ID), "email", email)
|
||||
|
||||
return s.createSession(ctx, user, ip, userAgent)
|
||||
}
|
||||
|
||||
// RefreshToken issues a new access token if the session is still active.
|
||||
func (s *AuthService) RefreshToken(ctx context.Context, sessionID string, userID string) (*LoginOutput, error) {
|
||||
sid := domain.SessionID(sessionID)
|
||||
session, err := s.sessions.Get(ctx, sid)
|
||||
if err != nil {
|
||||
return nil, domain.ErrSessionNotFound
|
||||
}
|
||||
if !session.IsActive() {
|
||||
return nil, domain.ErrSessionRevoked
|
||||
}
|
||||
|
||||
user, err := s.users.Get(ctx, domain.UserID(userID))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if user.Status == domain.UserStatusSuspended {
|
||||
return nil, domain.ErrUserSuspended
|
||||
}
|
||||
|
||||
_ = s.sessions.UpdateLastActive(ctx, sid)
|
||||
|
||||
token, err := s.generateToken(user, sessionID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
s.logger.Info("user logged in", "user_id", user.ID, "email", user.Email)
|
||||
|
||||
return &LoginOutput{
|
||||
Token: token,
|
||||
User: user,
|
||||
}, nil
|
||||
return &LoginOutput{Token: token, User: user}, nil
|
||||
}
|
||||
|
||||
// GetCurrentUser returns the user for the given ID.
|
||||
func (s *AuthService) GetCurrentUser(ctx context.Context, userID string) (*auth.User, error) {
|
||||
user, err := s.userRepo.FindByID(ctx, userID)
|
||||
// Logout revokes the current session.
|
||||
func (s *AuthService) Logout(ctx context.Context, sessionID string) error {
|
||||
if sessionID == "" {
|
||||
return nil
|
||||
}
|
||||
return s.sessions.Revoke(ctx, domain.SessionID(sessionID))
|
||||
}
|
||||
|
||||
// LogoutAll revokes all sessions for a user, optionally keeping one.
|
||||
func (s *AuthService) LogoutAll(ctx context.Context, userID string, exceptSessionID *string) error {
|
||||
var except *domain.SessionID
|
||||
if exceptSessionID != nil {
|
||||
sid := domain.SessionID(*exceptSessionID)
|
||||
except = &sid
|
||||
}
|
||||
return s.sessions.RevokeAllForUser(ctx, domain.UserID(userID), except)
|
||||
}
|
||||
|
||||
// CheckSession returns whether a session is active (not revoked, not expired).
|
||||
// Used as auth.SessionChecker for the SessionCheck middleware.
|
||||
func (s *AuthService) CheckSession(ctx context.Context, sessionID string) (bool, error) {
|
||||
session, err := s.sessions.Get(ctx, domain.SessionID(sessionID))
|
||||
if err != nil {
|
||||
return false, nil
|
||||
}
|
||||
return session.IsActive(), nil
|
||||
}
|
||||
|
||||
// ListSessions returns all active sessions for a user.
|
||||
func (s *AuthService) ListSessions(ctx context.Context, userID string) ([]domain.Session, error) {
|
||||
return s.sessions.ListByUser(ctx, domain.UserID(userID))
|
||||
}
|
||||
|
||||
// RevokeSession revokes a specific session for a user.
|
||||
func (s *AuthService) RevokeSession(ctx context.Context, userID, sessionID string) error {
|
||||
session, err := s.sessions.Get(ctx, domain.SessionID(sessionID))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if session.UserID != domain.UserID(userID) {
|
||||
return domain.ErrSessionNotFound
|
||||
}
|
||||
return s.sessions.Revoke(ctx, domain.SessionID(sessionID))
|
||||
}
|
||||
|
||||
// GetCurrentUser returns the full user for the given ID.
|
||||
func (s *AuthService) GetCurrentUser(ctx context.Context, userID string) (*domain.User, error) {
|
||||
return s.users.Get(ctx, domain.UserID(userID))
|
||||
}
|
||||
|
||||
// UpdateProfile updates a user's name and avatar.
|
||||
func (s *AuthService) UpdateProfile(ctx context.Context, userID, name, avatarURL string) (*domain.User, error) {
|
||||
user, err := s.users.Get(ctx, domain.UserID(userID))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if user == nil {
|
||||
return nil, ErrUserNotFound
|
||||
|
||||
if name != "" {
|
||||
name = strings.TrimSpace(name)
|
||||
if len(name) > domain.MaxNameLen {
|
||||
return nil, domain.ErrNameTooLong
|
||||
}
|
||||
user.Name = name
|
||||
}
|
||||
if avatarURL != "" {
|
||||
if err := validateAvatarURL(avatarURL); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
user.AvatarURL = avatarURL
|
||||
}
|
||||
|
||||
if err := s.users.Update(ctx, user); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return user, nil
|
||||
}
|
||||
|
||||
// ChangePassword changes a user's password after verifying the current one.
|
||||
func (s *AuthService) ChangePassword(ctx context.Context, userID, currentPassword, newPassword string) error {
|
||||
uid := domain.UserID(userID)
|
||||
|
||||
hash, err := s.users.GetPasswordHash(ctx, uid)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if hash == "" || !auth.CheckPassword(currentPassword, hash) {
|
||||
return domain.ErrInvalidCredentials
|
||||
}
|
||||
|
||||
if err := auth.ValidatePasswordStrength(newPassword); err != nil {
|
||||
return fmt.Errorf("%w: %w", domain.ErrWeakPassword, err)
|
||||
}
|
||||
|
||||
newHash, err := auth.HashPassword(newPassword)
|
||||
if err != nil {
|
||||
return fmt.Errorf("hashing password: %w", err)
|
||||
}
|
||||
|
||||
return s.users.SetPassword(ctx, uid, newHash)
|
||||
}
|
||||
|
||||
// SendOTP generates and logs a one-time password for the given email.
|
||||
// In production, this would send an email. In dev mode, the code is logged to stdout.
|
||||
func (s *AuthService) SendOTP(ctx context.Context, email, ip string) error {
|
||||
user, err := s.users.GetByEmail(ctx, email)
|
||||
if err != nil {
|
||||
if errors.Is(err, domain.ErrUserNotFound) {
|
||||
// Don't reveal whether email exists
|
||||
s.logger.Info("OTP requested for unknown email", "email", email)
|
||||
return nil
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
code := generateOTP()
|
||||
uid := user.ID
|
||||
authCode := &domain.AuthCode{
|
||||
ID: "acd_" + generateID(),
|
||||
UserID: &uid,
|
||||
Email: email,
|
||||
Code: code,
|
||||
Purpose: domain.PurposeLoginOTP,
|
||||
ExpiresAt: time.Now().Add(OTPExpiry),
|
||||
IPAddress: ip,
|
||||
CreatedAt: time.Now(),
|
||||
}
|
||||
|
||||
if err := s.codes.Create(ctx, authCode); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
s.logger.Info("auth code created", "purpose", "login_otp", "email", email, "code_id", authCode.ID)
|
||||
if err := s.email.SendAuthCode(ctx, email, code, string(domain.PurposeLoginOTP)); err != nil {
|
||||
s.logger.Error("failed to send OTP email", "email", email, "error", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// VerifyOTP verifies a one-time password and returns a login token.
|
||||
func (s *AuthService) VerifyOTP(ctx context.Context, email, code, ip, userAgent string) (*LoginOutput, error) {
|
||||
authCode, err := s.codes.FindValid(ctx, email, code, domain.PurposeLoginOTP)
|
||||
if err != nil {
|
||||
return nil, domain.ErrInvalidAuthCode
|
||||
}
|
||||
|
||||
if err := s.codes.MarkUsed(ctx, authCode.ID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
user, err := s.users.GetByEmail(ctx, email)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
_ = s.users.UpdateLastLogin(ctx, user.ID)
|
||||
s.logger.Info("user logged in via OTP", "user_id", string(user.ID), "email", email)
|
||||
|
||||
return s.createSession(ctx, user, ip, userAgent)
|
||||
}
|
||||
|
||||
// SendMagicLink generates and logs a magic link token.
|
||||
func (s *AuthService) SendMagicLink(ctx context.Context, email, ip string) error {
|
||||
// Magic links can work for existing users.
|
||||
// Don't reveal whether email exists — but propagate infrastructure errors.
|
||||
user, err := s.users.GetByEmail(ctx, email)
|
||||
if err != nil && !errors.Is(err, domain.ErrUserNotFound) {
|
||||
return err
|
||||
}
|
||||
|
||||
token := generateHexToken()
|
||||
var uid *domain.UserID
|
||||
if user != nil {
|
||||
uid = &user.ID
|
||||
}
|
||||
|
||||
authCode := &domain.AuthCode{
|
||||
ID: "acd_" + generateID(),
|
||||
UserID: uid,
|
||||
Email: email,
|
||||
Code: token,
|
||||
Purpose: domain.PurposeMagicLink,
|
||||
ExpiresAt: time.Now().Add(MagicLinkExpiry),
|
||||
IPAddress: ip,
|
||||
CreatedAt: time.Now(),
|
||||
}
|
||||
|
||||
if err := s.codes.Create(ctx, authCode); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
s.logger.Info("auth code created", "purpose", "magic_link", "email", email, "code_id", authCode.ID)
|
||||
if err := s.email.SendAuthCode(ctx, email, token, string(domain.PurposeMagicLink)); err != nil {
|
||||
s.logger.Error("failed to send magic link email", "email", email, "error", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// VerifyMagicLink verifies a magic link token and returns a login token.
|
||||
func (s *AuthService) VerifyMagicLink(ctx context.Context, email, token, ip, userAgent string) (*LoginOutput, error) {
|
||||
authCode, err := s.codes.FindValid(ctx, email, token, domain.PurposeMagicLink)
|
||||
if err != nil {
|
||||
return nil, domain.ErrInvalidAuthCode
|
||||
}
|
||||
|
||||
if err := s.codes.MarkUsed(ctx, authCode.ID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
user, err := s.users.GetByEmail(ctx, email)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
_ = s.users.UpdateLastLogin(ctx, user.ID)
|
||||
s.logger.Info("user logged in via magic link", "user_id", string(user.ID), "email", email)
|
||||
|
||||
return s.createSession(ctx, user, ip, userAgent)
|
||||
}
|
||||
|
||||
// ForgotPassword generates a password reset token.
|
||||
func (s *AuthService) ForgotPassword(ctx context.Context, email, ip string) error {
|
||||
user, err := s.users.GetByEmail(ctx, email)
|
||||
if err != nil {
|
||||
if errors.Is(err, domain.ErrUserNotFound) {
|
||||
// Don't reveal whether email exists
|
||||
s.logger.Info("password reset requested for unknown email", "email", email)
|
||||
return nil
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
token := generateHexToken()
|
||||
uid := user.ID
|
||||
authCode := &domain.AuthCode{
|
||||
ID: "acd_" + generateID(),
|
||||
UserID: &uid,
|
||||
Email: email,
|
||||
Code: token,
|
||||
Purpose: domain.PurposePasswordReset,
|
||||
ExpiresAt: time.Now().Add(PasswordResetExpiry),
|
||||
IPAddress: ip,
|
||||
CreatedAt: time.Now(),
|
||||
}
|
||||
|
||||
if err := s.codes.Create(ctx, authCode); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
s.logger.Info("auth code created", "purpose", "password_reset", "email", email, "code_id", authCode.ID)
|
||||
if err := s.email.SendAuthCode(ctx, email, token, string(domain.PurposePasswordReset)); err != nil {
|
||||
s.logger.Error("failed to send password reset email", "email", email, "error", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ResetPassword sets a new password using a reset token and revokes all sessions.
|
||||
func (s *AuthService) ResetPassword(ctx context.Context, email, token, newPassword string) error {
|
||||
authCode, err := s.codes.FindValid(ctx, email, token, domain.PurposePasswordReset)
|
||||
if err != nil {
|
||||
return domain.ErrInvalidAuthCode
|
||||
}
|
||||
|
||||
if err := auth.ValidatePasswordStrength(newPassword); err != nil {
|
||||
return fmt.Errorf("%w: %w", domain.ErrWeakPassword, err)
|
||||
}
|
||||
|
||||
user, err := s.users.GetByEmail(ctx, email)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
hash, err := auth.HashPassword(newPassword)
|
||||
if err != nil {
|
||||
return fmt.Errorf("hashing password: %w", err)
|
||||
}
|
||||
|
||||
if err := s.users.SetPassword(ctx, user.ID, hash); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := s.codes.MarkUsed(ctx, authCode.ID); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Revoke all sessions — user must re-login with new password.
|
||||
_ = s.sessions.RevokeAllForUser(ctx, user.ID, nil)
|
||||
s.logger.Info("password reset completed", "user_id", string(user.ID), "email", email)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// SendVerifyEmail generates an email verification code.
|
||||
func (s *AuthService) SendVerifyEmail(ctx context.Context, userID string) error {
|
||||
user, err := s.users.Get(ctx, domain.UserID(userID))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if user.EmailVerified {
|
||||
return nil
|
||||
}
|
||||
|
||||
code := generateOTP()
|
||||
uid := user.ID
|
||||
authCode := &domain.AuthCode{
|
||||
ID: "acd_" + generateID(),
|
||||
UserID: &uid,
|
||||
Email: user.Email,
|
||||
Code: code,
|
||||
Purpose: domain.PurposeEmailVerify,
|
||||
ExpiresAt: time.Now().Add(EmailVerifyExpiry),
|
||||
CreatedAt: time.Now(),
|
||||
}
|
||||
|
||||
if err := s.codes.Create(ctx, authCode); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
s.logger.Info("auth code created", "purpose", "email_verify", "email", user.Email, "code_id", authCode.ID)
|
||||
if err := s.email.SendAuthCode(ctx, user.Email, code, string(domain.PurposeEmailVerify)); err != nil {
|
||||
s.logger.Error("failed to send email verification", "email", user.Email, "error", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// VerifyEmail marks the user's email as verified.
|
||||
func (s *AuthService) VerifyEmail(ctx context.Context, userID, code string) error {
|
||||
user, err := s.users.Get(ctx, domain.UserID(userID))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
authCode, err := s.codes.FindValid(ctx, user.Email, code, domain.PurposeEmailVerify)
|
||||
if err != nil {
|
||||
return domain.ErrInvalidAuthCode
|
||||
}
|
||||
|
||||
if err := s.codes.MarkUsed(ctx, authCode.ID); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
user.EmailVerified = true
|
||||
if err := s.users.Update(ctx, user); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
s.logger.Info("email verified", "user_id", userID, "email", user.Email)
|
||||
return nil
|
||||
}
|
||||
|
||||
// createSession creates a session record and generates a JWT.
|
||||
func (s *AuthService) createSession(ctx context.Context, user *domain.User, ip, userAgent string) (*LoginOutput, error) {
|
||||
sessionID := "ses_" + generateID()
|
||||
now := time.Now()
|
||||
|
||||
session := &domain.Session{
|
||||
ID: domain.SessionID(sessionID),
|
||||
UserID: user.ID,
|
||||
IPAddress: ip,
|
||||
UserAgent: userAgent,
|
||||
DeviceLabel: auth.ParseDeviceLabel(userAgent),
|
||||
LastActiveAt: now,
|
||||
ExpiresAt: now.Add(SessionLifetime),
|
||||
CreatedAt: now,
|
||||
}
|
||||
|
||||
if err := s.sessions.Create(ctx, session); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
token, err := s.generateToken(user, sessionID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &LoginOutput{Token: token, User: user}, nil
|
||||
}
|
||||
|
||||
// generateToken creates a JWT for the user with the given session ID.
|
||||
func (s *AuthService) generateToken(user *domain.User, sessionID string) (string, error) {
|
||||
authUser := &auth.User{
|
||||
ID: string(user.ID),
|
||||
Email: user.Email,
|
||||
Roles: user.Roles,
|
||||
}
|
||||
return auth.GenerateTokenWithSession(
|
||||
s.jwtSecret, authUser, TokenLifetime, s.issuer, s.issuer, sessionID,
|
||||
)
|
||||
}
|
||||
|
||||
// generateID returns a random hex string suitable for entity IDs.
|
||||
func generateID() string {
|
||||
b := make([]byte, 16)
|
||||
if _, err := rand.Read(b); err != nil {
|
||||
panic("crypto/rand failed: " + err.Error())
|
||||
}
|
||||
return hex.EncodeToString(b)
|
||||
}
|
||||
|
||||
// generateOTP returns a 6-digit numeric one-time password.
|
||||
func generateOTP() string {
|
||||
n, err := rand.Int(rand.Reader, big.NewInt(1000000))
|
||||
if err != nil {
|
||||
panic("crypto/rand failed: " + err.Error())
|
||||
}
|
||||
return fmt.Sprintf("%06d", n.Int64())
|
||||
}
|
||||
|
||||
// validateAvatarURL checks that the URL uses http or https.
|
||||
func validateAvatarURL(rawURL string) error {
|
||||
parsed, err := url.Parse(rawURL)
|
||||
if err != nil {
|
||||
return domain.ErrInvalidAvatarURL
|
||||
}
|
||||
if parsed.Scheme != "http" && parsed.Scheme != "https" {
|
||||
return domain.ErrInvalidAvatarURL
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// generateHexToken returns a 32-character hex token for magic links and resets.
|
||||
func generateHexToken() string {
|
||||
b := make([]byte, 16)
|
||||
if _, err := rand.Read(b); err != nil {
|
||||
panic("crypto/rand failed: " + err.Error())
|
||||
}
|
||||
return hex.EncodeToString(b)
|
||||
}
|
||||
|
||||
@ -197,7 +197,7 @@ func main() {
|
||||
// GCS_BUCKET is injected by the platform; if absent, store is nil (media not persisted).
|
||||
var mediaStore storage.Store
|
||||
if bucket := os.Getenv("GCS_BUCKET"); bucket != "" {
|
||||
gcsStore, err := storage.NewGCSStore(bucket, os.Getenv("GCS_SERVICE_ACCOUNT_JSON"), logger.Logger)
|
||||
gcsStore, err := storage.NewGCSStore(ctx, bucket, os.Getenv("GCS_SERVICE_ACCOUNT_JSON"), logger.Logger)
|
||||
if err != nil {
|
||||
logger.Warn("failed to create GCS store, generated media will not be persisted", "error", err)
|
||||
} else {
|
||||
|
||||
@ -41,6 +41,21 @@ You design database schemas and optimize queries for {{PROJECT_NAME}}. Every ser
|
||||
- Composite indexes: most selective column first
|
||||
- Name format: `idx_{table}_{columns}`
|
||||
|
||||
## Auth Tables (built-in)
|
||||
|
||||
These tables are auto-created by `001_create_users.sql`:
|
||||
|
||||
| Table | Purpose | Key Columns |
|
||||
|-------|---------|-------------|
|
||||
| `users` | Core identity | `id TEXT PK`, `email UNIQUE`, `email_verified`, `status` |
|
||||
| `user_passwords` | Bcrypt hashes | `user_id TEXT PK FK`, `password_hash` |
|
||||
| `sessions` | Login tracking | `user_id FK`, `ip_address`, `device_label`, `revoked_at` |
|
||||
| `auth_codes` | OTP/magic/reset | `email`, `code`, `purpose`, `expires_at`, `used_at` |
|
||||
| `user_roles` | Role assignments | `(user_id, role) PK` |
|
||||
| `oauth_connections` | OAuth providers | `(provider, provider_user_id) UNIQUE` |
|
||||
|
||||
Key indexes: `idx_auth_codes_email_purpose` (partial, WHERE used_at IS NULL), `idx_sessions_user_id` (partial, WHERE revoked_at IS NULL).
|
||||
|
||||
## Migration Rules
|
||||
|
||||
- NEVER modify committed migrations
|
||||
|
||||
@ -10,26 +10,45 @@ You enforce security best practices across {{PROJECT_NAME}}. Authentication is c
|
||||
|
||||
## Authentication
|
||||
|
||||
### JWT Pattern
|
||||
- Tokens issued by auth service
|
||||
- Other services validate tokens via middleware
|
||||
- Short-lived access tokens + longer refresh tokens
|
||||
- Never store tokens in localStorage (use httpOnly cookies)
|
||||
### JWT Token Lifecycle
|
||||
- **Access tokens:** 15 minutes, signed with `JWT_SECRET`
|
||||
- **Session ID:** Embedded as `sid` claim for revocation support
|
||||
- **Refresh:** POST `/auth/refresh` issues new token, same session
|
||||
- **Revocation:** Revoking a session invalidates all tokens for that session
|
||||
|
||||
### Middleware
|
||||
### Password Security
|
||||
- **Hashing:** bcrypt cost 12 (`pkg/auth/password.go`)
|
||||
- **Strength:** Min 8 chars, max 72 (bcrypt limit), requires uppercase + lowercase + digit
|
||||
- **Storage:** Separate `user_passwords` table (OAuth-only users have no password row)
|
||||
|
||||
### Auth Codes
|
||||
- **OTP login:** 6-digit numeric, 10-minute expiry
|
||||
- **Magic links:** 32-char hex token, 15-minute expiry
|
||||
- **Password reset:** 32-char hex token, 1-hour expiry
|
||||
- **Email verification:** 6-digit numeric, 24-hour expiry
|
||||
- All codes are single-use (marked with `used_at` timestamp)
|
||||
- In dev mode (`NOTIFY_URL` unset): codes logged to stdout
|
||||
- In production: emails sent via the notify service, which handles provider routing, retries, and suppression
|
||||
|
||||
### Session Management
|
||||
- Sessions track IP address, user agent, device label
|
||||
- 30-day session lifetime
|
||||
- Revokable individually or all-at-once (except current)
|
||||
- `SessionCheck` middleware (opt-in) validates session on every request
|
||||
|
||||
### Middleware Stack
|
||||
```go
|
||||
func AuthMiddleware(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
token := extractToken(r)
|
||||
claims, err := validateToken(token)
|
||||
if err != nil {
|
||||
httpresponse.Unauthorized(w, "invalid token")
|
||||
return
|
||||
}
|
||||
ctx := context.WithValue(r.Context(), userKey, claims)
|
||||
next.ServeHTTP(w, r.WithContext(ctx))
|
||||
})
|
||||
}
|
||||
// Auth middleware validates JWT and sets user in context
|
||||
r.Use(auth.Middleware(auth.MiddlewareConfig{
|
||||
Validator: jwtValidator,
|
||||
}))
|
||||
|
||||
// Optional: enforce session revocation
|
||||
r.Use(auth.SessionCheck(checker))
|
||||
|
||||
// Require specific roles
|
||||
r.Use(auth.RequireRole("admin"))
|
||||
```
|
||||
```
|
||||
|
||||
## Input Validation
|
||||
@ -54,7 +73,7 @@ func AuthMiddleware(next http.Handler) http.Handler {
|
||||
|------|-----------|
|
||||
| SQL Injection | Parameterized queries only |
|
||||
| XSS | Sanitize input, escape output |
|
||||
| CSRF | CSRF tokens for state-changing requests |
|
||||
| CSRF | Not applicable — all auth uses Bearer tokens in Authorization header, not cookies |
|
||||
| Auth Bypass | Middleware on every protected route |
|
||||
| Secret Exposure | .env in .gitignore, no hardcoding |
|
||||
| Mass Assignment | Explicit field mapping (no bind-all) |
|
||||
|
||||
@ -0,0 +1,146 @@
|
||||
# Authentication & User Management
|
||||
|
||||
Complete auth system with registration, login, sessions, and verification flows.
|
||||
|
||||
## Architecture
|
||||
|
||||
```
|
||||
Frontend (AuthProvider) → HTTP → Auth Handlers → AuthService → Repositories (Memory | Postgres)
|
||||
```
|
||||
|
||||
- **pkg/auth/** — JWT validation, middleware, password hashing, session checking (shared)
|
||||
- **service/internal/domain/** — User, Session, AuthCode domain models
|
||||
- **service/internal/port/** — Repository interfaces (UserRepository, SessionRepository, AuthCodeRepository)
|
||||
- **service/internal/adapter/memory/** — In-memory implementations for standalone dev
|
||||
- **service/internal/adapter/postgres/** — PostgreSQL/CockroachDB implementations for production
|
||||
- **service/internal/service/auth.go** — Business logic (AuthService)
|
||||
- **service/internal/api/handlers/auth.go** — Core HTTP handlers (login, register, profile)
|
||||
- **service/internal/api/handlers/auth_flows.go** — Flow handlers (OTP, magic link, sessions, reset)
|
||||
|
||||
## Standalone Mode (No DATABASE_URL)
|
||||
|
||||
When `DATABASE_URL` is not set, the service runs with in-memory adapters:
|
||||
- Two demo users seeded: `test@example.com` / `Password123`, `admin@example.com` / `Admin1234`
|
||||
- Auth codes (OTP, magic links, reset tokens) logged to stdout (no notify/email needed)
|
||||
- Sessions stored in memory (lost on restart)
|
||||
- No external dependencies required
|
||||
|
||||
## Token Lifecycle
|
||||
|
||||
- **Access token:** 15 minutes, JWT with embedded session ID (`sid` claim)
|
||||
- **Refresh:** POST `/auth/refresh` with valid token returns new token (same session)
|
||||
- **Session:** 30-day lifetime, tracked in sessions table
|
||||
- **Revocation:** Revoking a session invalidates all tokens for that session
|
||||
|
||||
## Environment Variables
|
||||
|
||||
| Variable | Default | Description |
|
||||
|----------|---------|-------------|
|
||||
| `JWT_SECRET` | `""` | Secret for signing JWT tokens |
|
||||
| `REGISTRATION_ENABLED` | `true` | Allow new user registration |
|
||||
| `DATABASE_URL` | `""` | If set, use Postgres repos; otherwise in-memory |
|
||||
| `NOTIFY_URL` | `""` | Notify service URL. If set, emails sent via notify; otherwise logged to stdout |
|
||||
| `NOTIFY_API_KEY` | `""` | Per-project notify send key (`notify_send_xxx`) |
|
||||
| `NOTIFY_HOST` | `""` | Sending domain (e.g. `myapp.threesix.ai`) |
|
||||
| `NOTIFY_FROM` | `noreply@{project}.com` | Registered sender address |
|
||||
|
||||
## Auth Flows
|
||||
|
||||
### Password Login
|
||||
```
|
||||
POST /auth/login { email, password } → { token, user }
|
||||
```
|
||||
|
||||
### Registration
|
||||
```
|
||||
POST /auth/register { email, password, name } → { token, user }
|
||||
```
|
||||
|
||||
### OTP Login
|
||||
```
|
||||
POST /auth/otp/send { email } → 200 (code logged to stdout in dev)
|
||||
POST /auth/otp/verify { email, code } → { token, user }
|
||||
```
|
||||
|
||||
### Magic Link
|
||||
```
|
||||
POST /auth/magic-link { email } → 200 (token logged to stdout in dev)
|
||||
POST /auth/magic-link/verify { email, token } → { token, user }
|
||||
```
|
||||
|
||||
### Password Reset
|
||||
```
|
||||
POST /auth/forgot-password { email } → 200 (token logged to stdout in dev)
|
||||
POST /auth/reset-password { email, token, newPassword } → 200
|
||||
```
|
||||
|
||||
### Email Verification (requires auth)
|
||||
```
|
||||
POST /auth/verify-email/send → 200 (code logged to stdout in dev)
|
||||
POST /auth/verify-email { code } → 200
|
||||
```
|
||||
|
||||
### Session Management (requires auth)
|
||||
```
|
||||
GET /auth/sessions → [{ id, deviceLabel, ipAddress, lastActiveAt, isCurrent }]
|
||||
DELETE /auth/sessions/{id} → 204
|
||||
DELETE /auth/sessions → 204 (revoke all except current)
|
||||
```
|
||||
|
||||
### Profile (requires auth)
|
||||
```
|
||||
GET /auth/me → { user }
|
||||
PUT /auth/me { name, avatarUrl } → { user }
|
||||
POST /auth/change-password { currentPassword, newPassword } → 200
|
||||
POST /auth/logout → 204
|
||||
```
|
||||
|
||||
## Frontend Integration
|
||||
|
||||
The `@{{PROJECT_NAME}}/auth` package provides `AuthProvider` and `useAuth()` hook:
|
||||
|
||||
```tsx
|
||||
// In App.tsx
|
||||
<AuthProvider authBaseUrl={`${apiBaseUrl}/api/service-name`}>
|
||||
<App />
|
||||
</AuthProvider>
|
||||
|
||||
// In components
|
||||
const { user, login, register, logout, sendOTP, loginWithOTP } = useAuth();
|
||||
```
|
||||
|
||||
Auto-refresh schedules token renewal at 80% of token lifetime.
|
||||
|
||||
## Adding Session Revocation Middleware
|
||||
|
||||
To enforce session revocation on every request (opt-in):
|
||||
|
||||
```go
|
||||
import "{{GO_MODULE}}/pkg/auth"
|
||||
|
||||
checker := func(ctx context.Context, sid string) (bool, error) {
|
||||
session, err := sessionRepo.Get(ctx, domain.SessionID(sid))
|
||||
if err != nil { return false, nil }
|
||||
return session.IsActive(), nil
|
||||
}
|
||||
|
||||
r.Use(auth.SessionCheck(checker))
|
||||
```
|
||||
|
||||
## Password Requirements
|
||||
|
||||
- Minimum 8 characters, maximum 72 (bcrypt limit)
|
||||
- Must contain uppercase, lowercase, and digit
|
||||
- Hashed with bcrypt cost 12
|
||||
|
||||
## Database Tables
|
||||
|
||||
When `DATABASE_URL` is set, these tables are auto-created:
|
||||
- `users` — Core identity (email, name, status)
|
||||
- `user_passwords` — Bcrypt hashes (separate for OAuth-only users)
|
||||
- `sessions` — Login sessions with IP/device tracking
|
||||
- `auth_codes` — OTP, magic link, reset, and verification codes
|
||||
- `user_roles` — Many-to-many user roles
|
||||
- `oauth_connections` — Schema placeholder for future OAuth provider links (table exists but no handlers/adapters yet)
|
||||
|
||||
> **Note:** The `oauth_connections` table is created by the migration but has no corresponding handlers, service methods, or adapters. It's a schema placeholder — implementing OAuth requires building the full handler → service → adapter chain. See the composable monorepo templates guide for adding new auth providers.
|
||||
@ -10,6 +10,7 @@
|
||||
| **Build a feature** | [feature-development.md](.claude/guides/feature-development.md) |
|
||||
| **Backend API patterns** | [backend/api-patterns.md](.claude/guides/backend/api-patterns.md) |
|
||||
| **Frontend design system** | [frontend/design-system.md](.claude/guides/frontend/design-system.md) |
|
||||
| **Auth & user management** | [auth.md](.claude/guides/auth.md) |
|
||||
| **Event channels** | [events.md](.claude/guides/events.md) |
|
||||
| **Media pipeline** | [media.md](.claude/guides/media.md) |
|
||||
| **Deploy** | [ops/deploying.md](.claude/guides/ops/deploying.md) |
|
||||
@ -46,6 +47,9 @@
|
||||
- **Media generation:** Same pattern - POST queues job, returns ID, results via SSE. Video takes 2-5 min; never block HTTP. Text generation streams `ai_chat_chunk` events token-by-token.
|
||||
- **Media storage:** Backend returns complete URLs. Never construct storage paths in frontend. Variants (thumbnail, optimized) auto-generated.
|
||||
- **No fake progress:** Never simulate progress with timers. Real progress comes from real events.
|
||||
- **Auth tokens:** 15-minute access tokens with embedded session ID (`sid`). Refresh via POST `/auth/refresh`. Session revocation invalidates all tokens for that session.
|
||||
- **Passwords:** Bcrypt cost 12, min 8 chars, max 72. Hashing lives in `pkg/auth/password.go`. Never store plaintext.
|
||||
- **Auth codes:** OTP/magic link/reset codes are single-use and time-limited. In dev mode (`NOTIFY_URL` unset), codes are logged to stdout. In production, emails go through the notify service (`NOTIFY_URL`/`NOTIFY_API_KEY`/`NOTIFY_HOST`/`NOTIFY_FROM`).
|
||||
|
||||
## Architecture
|
||||
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
import * as React from 'react';
|
||||
import { createContext, useContext, useCallback, useMemo, useEffect, useState } from 'react';
|
||||
import type { User, AuthState, LoginCredentials } from './types';
|
||||
import { createContext, useContext, useCallback, useMemo, useEffect, useState, useRef } from 'react';
|
||||
import type { User, AuthState, LoginCredentials, RegisterCredentials, OTPVerifyCredentials, MagicLinkVerifyCredentials } from './types';
|
||||
|
||||
const TOKEN_STORAGE_KEY = 'auth_token';
|
||||
const USER_STORAGE_KEY = 'auth_user';
|
||||
@ -9,18 +9,30 @@ const USER_STORAGE_KEY = 'auth_user';
|
||||
* Authentication context value.
|
||||
*/
|
||||
export interface AuthContextValue extends AuthState {
|
||||
/** Log in with credentials */
|
||||
/** Log in with email and password */
|
||||
login: (credentials: LoginCredentials) => Promise<void>;
|
||||
/** Log in with a token directly */
|
||||
loginWithToken: (token: string, user?: User) => void;
|
||||
/** Register a new account */
|
||||
register: (credentials: RegisterCredentials) => Promise<void>;
|
||||
/** Log out the current user */
|
||||
logout: () => void;
|
||||
logout: () => Promise<void>;
|
||||
/** Get the current access token */
|
||||
getToken: () => string | null;
|
||||
/** Check if user has a specific role */
|
||||
hasRole: (role: string) => boolean;
|
||||
/** Check if user has a specific scope */
|
||||
hasScope: (scope: string) => boolean;
|
||||
/** Send an OTP code to an email */
|
||||
sendOTP: (email: string) => Promise<void>;
|
||||
/** Log in with an OTP code */
|
||||
loginWithOTP: (credentials: OTPVerifyCredentials) => Promise<void>;
|
||||
/** Send a magic link to an email */
|
||||
sendMagicLink: (email: string) => Promise<void>;
|
||||
/** Log in with a magic link token */
|
||||
loginWithMagicLink: (credentials: MagicLinkVerifyCredentials) => Promise<void>;
|
||||
/** Refresh the access token */
|
||||
refreshToken: () => Promise<void>;
|
||||
}
|
||||
|
||||
const AuthContext = createContext<AuthContextValue | null>(null);
|
||||
@ -30,10 +42,16 @@ const AuthContext = createContext<AuthContextValue | null>(null);
|
||||
*/
|
||||
export interface AuthProviderProps {
|
||||
children: React.ReactNode;
|
||||
/** API endpoint for login */
|
||||
/** API base URL for auth endpoints (e.g. "/api/my-service") */
|
||||
authBaseUrl?: string;
|
||||
/** API endpoint for login (defaults to authBaseUrl + "/auth/login") */
|
||||
loginUrl?: string;
|
||||
/** API endpoint for logout */
|
||||
logoutUrl?: string;
|
||||
/** API endpoint for registration */
|
||||
registerUrl?: string;
|
||||
/** API endpoint for token refresh */
|
||||
refreshUrl?: string;
|
||||
/** Custom login handler */
|
||||
onLogin?: (credentials: LoginCredentials) => Promise<{ token: string; user: User }>;
|
||||
/** Custom logout handler */
|
||||
@ -44,32 +62,28 @@ export interface AuthProviderProps {
|
||||
|
||||
/**
|
||||
* AuthProvider manages authentication state and provides auth methods.
|
||||
*
|
||||
* @example
|
||||
* // Basic usage
|
||||
* <AuthProvider loginUrl="/api/auth/login">
|
||||
* <App />
|
||||
* </AuthProvider>
|
||||
*
|
||||
* @example
|
||||
* // With custom handlers
|
||||
* <AuthProvider
|
||||
* onLogin={async (creds) => {
|
||||
* const res = await myAuthService.login(creds);
|
||||
* return { token: res.token, user: res.user };
|
||||
* }}
|
||||
* >
|
||||
* <App />
|
||||
* </AuthProvider>
|
||||
*/
|
||||
export function AuthProvider({
|
||||
children,
|
||||
loginUrl = '/api/auth/login',
|
||||
logoutUrl = '/api/auth/logout',
|
||||
authBaseUrl,
|
||||
loginUrl,
|
||||
logoutUrl,
|
||||
registerUrl,
|
||||
refreshUrl,
|
||||
onLogin,
|
||||
onLogout,
|
||||
storage = 'localStorage',
|
||||
}: AuthProviderProps) {
|
||||
// Derive URLs from authBaseUrl if individual URLs not provided
|
||||
const resolvedLoginUrl = loginUrl || (authBaseUrl ? `${authBaseUrl}/auth/login` : '/api/auth/login');
|
||||
const resolvedLogoutUrl = logoutUrl || (authBaseUrl ? `${authBaseUrl}/auth/logout` : '/api/auth/logout');
|
||||
const resolvedRegisterUrl = registerUrl || (authBaseUrl ? `${authBaseUrl}/auth/register` : '/api/auth/register');
|
||||
const resolvedRefreshUrl = refreshUrl || (authBaseUrl ? `${authBaseUrl}/auth/refresh` : '/api/auth/refresh');
|
||||
const resolvedOtpSendUrl = authBaseUrl ? `${authBaseUrl}/auth/otp/send` : '/api/auth/otp/send';
|
||||
const resolvedOtpVerifyUrl = authBaseUrl ? `${authBaseUrl}/auth/otp/verify` : '/api/auth/otp/verify';
|
||||
const resolvedMagicLinkUrl = authBaseUrl ? `${authBaseUrl}/auth/magic-link` : '/api/auth/magic-link';
|
||||
const resolvedMagicLinkVerifyUrl = authBaseUrl ? `${authBaseUrl}/auth/magic-link/verify` : '/api/auth/magic-link/verify';
|
||||
|
||||
const [state, setState] = useState<AuthState>({
|
||||
user: null,
|
||||
isLoading: true,
|
||||
@ -77,12 +91,86 @@ export function AuthProvider({
|
||||
error: null,
|
||||
});
|
||||
|
||||
const refreshTimerRef = useRef<ReturnType<typeof setTimeout> | null>(null);
|
||||
|
||||
// Get storage implementation
|
||||
const getStorage = useCallback(() => {
|
||||
if (storage === 'none') return null;
|
||||
return storage === 'sessionStorage' ? sessionStorage : localStorage;
|
||||
}, [storage]);
|
||||
|
||||
// Store token and user
|
||||
const persistAuth = useCallback((token: string, user: User) => {
|
||||
const store = getStorage();
|
||||
if (store) {
|
||||
store.setItem(TOKEN_STORAGE_KEY, token);
|
||||
store.setItem(USER_STORAGE_KEY, JSON.stringify(user));
|
||||
}
|
||||
}, [getStorage]);
|
||||
|
||||
// Clear stored auth
|
||||
const clearAuth = useCallback(() => {
|
||||
const store = getStorage();
|
||||
if (store) {
|
||||
store.removeItem(TOKEN_STORAGE_KEY);
|
||||
store.removeItem(USER_STORAGE_KEY);
|
||||
}
|
||||
}, [getStorage]);
|
||||
|
||||
// Schedule token refresh (at 80% of token lifetime)
|
||||
const scheduleRefresh = useCallback((token: string) => {
|
||||
if (refreshTimerRef.current) {
|
||||
clearTimeout(refreshTimerRef.current);
|
||||
}
|
||||
try {
|
||||
const payload = JSON.parse(atob(token.split('.')[1]));
|
||||
const exp = payload.exp * 1000;
|
||||
const iat = payload.iat * 1000;
|
||||
const lifetime = exp - iat;
|
||||
const refreshAt = iat + lifetime * 0.8;
|
||||
const delay = refreshAt - Date.now();
|
||||
|
||||
if (delay > 0) {
|
||||
refreshTimerRef.current = setTimeout(async () => {
|
||||
try {
|
||||
const store = getStorage();
|
||||
const currentToken = store?.getItem(TOKEN_STORAGE_KEY);
|
||||
if (!currentToken) return;
|
||||
|
||||
const response = await fetch(resolvedRefreshUrl, {
|
||||
method: 'POST',
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
'Authorization': `Bearer ${currentToken}`,
|
||||
},
|
||||
});
|
||||
|
||||
if (response.ok) {
|
||||
const data = await response.json();
|
||||
const newToken = data.data?.token || data.token;
|
||||
const newUser = data.data?.user || data.user;
|
||||
persistAuth(newToken, newUser);
|
||||
setState(s => ({ ...s, user: newUser }));
|
||||
scheduleRefresh(newToken);
|
||||
} else if (response.status === 401) {
|
||||
// Session revoked or token invalid — force logout.
|
||||
clearAuth();
|
||||
setState({ user: null, isLoading: false, isAuthenticated: false, error: null });
|
||||
}
|
||||
} catch {
|
||||
// Refresh failed (network error) — clear auth to prevent silent expiry.
|
||||
clearAuth();
|
||||
setState({ user: null, isLoading: false, isAuthenticated: false, error: null });
|
||||
}
|
||||
}, delay);
|
||||
}
|
||||
} catch {
|
||||
// Corrupted token in storage — clear it and force logout.
|
||||
clearAuth();
|
||||
setState({ user: null, isLoading: false, isAuthenticated: false, error: null });
|
||||
}
|
||||
}, [getStorage, persistAuth, clearAuth, resolvedRefreshUrl]);
|
||||
|
||||
// Initialize auth state from storage
|
||||
useEffect(() => {
|
||||
const store = getStorage();
|
||||
@ -97,14 +185,70 @@ export function AuthProvider({
|
||||
if (token && userJson) {
|
||||
try {
|
||||
const user = JSON.parse(userJson) as User;
|
||||
|
||||
// Check if the stored token is already expired
|
||||
let tokenExpired = false;
|
||||
let deeplyExpired = false;
|
||||
try {
|
||||
const payload = JSON.parse(atob(token.split('.')[1]));
|
||||
const exp = payload.exp * 1000;
|
||||
tokenExpired = exp < Date.now();
|
||||
// Session is deeply expired if token expired over 30 days ago
|
||||
deeplyExpired = exp + 30 * 24 * 60 * 60 * 1000 < Date.now();
|
||||
} catch {
|
||||
// Corrupted token — treat as expired
|
||||
tokenExpired = true;
|
||||
deeplyExpired = true;
|
||||
}
|
||||
|
||||
if (deeplyExpired) {
|
||||
// Session expired beyond recovery — clear auth
|
||||
store.removeItem(TOKEN_STORAGE_KEY);
|
||||
store.removeItem(USER_STORAGE_KEY);
|
||||
setState({ user: null, isLoading: false, isAuthenticated: false, error: null });
|
||||
} else if (tokenExpired) {
|
||||
// Token expired but session may still be valid — attempt refresh
|
||||
setState((s) => ({ ...s, isLoading: true }));
|
||||
fetch(resolvedRefreshUrl, {
|
||||
method: 'POST',
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
'Authorization': `Bearer ${token}`,
|
||||
},
|
||||
})
|
||||
.then((response) => {
|
||||
if (response.ok) {
|
||||
return response.json().then((data) => {
|
||||
const newToken = data.data?.token || data.token;
|
||||
const newUser = data.data?.user || data.user;
|
||||
store.setItem(TOKEN_STORAGE_KEY, newToken);
|
||||
store.setItem(USER_STORAGE_KEY, JSON.stringify(newUser));
|
||||
setState({ user: newUser, isLoading: false, isAuthenticated: true, error: null });
|
||||
scheduleRefresh(newToken);
|
||||
});
|
||||
}
|
||||
// Refresh failed (401, etc.) — clear auth
|
||||
store.removeItem(TOKEN_STORAGE_KEY);
|
||||
store.removeItem(USER_STORAGE_KEY);
|
||||
setState({ user: null, isLoading: false, isAuthenticated: false, error: null });
|
||||
})
|
||||
.catch(() => {
|
||||
// Network error during refresh — clear auth
|
||||
store.removeItem(TOKEN_STORAGE_KEY);
|
||||
store.removeItem(USER_STORAGE_KEY);
|
||||
setState({ user: null, isLoading: false, isAuthenticated: false, error: null });
|
||||
});
|
||||
} else {
|
||||
// Token is still valid — restore authenticated state
|
||||
setState({
|
||||
user,
|
||||
isLoading: false,
|
||||
isAuthenticated: true,
|
||||
error: null,
|
||||
});
|
||||
scheduleRefresh(token);
|
||||
}
|
||||
} catch {
|
||||
// Invalid stored data, clear it
|
||||
store.removeItem(TOKEN_STORAGE_KEY);
|
||||
store.removeItem(USER_STORAGE_KEY);
|
||||
setState((s) => ({ ...s, isLoading: false }));
|
||||
@ -112,7 +256,31 @@ export function AuthProvider({
|
||||
} else {
|
||||
setState((s) => ({ ...s, isLoading: false }));
|
||||
}
|
||||
}, [getStorage]);
|
||||
}, [getStorage, scheduleRefresh, resolvedRefreshUrl]);
|
||||
|
||||
// Cleanup timer
|
||||
useEffect(() => {
|
||||
return () => {
|
||||
if (refreshTimerRef.current) {
|
||||
clearTimeout(refreshTimerRef.current);
|
||||
}
|
||||
};
|
||||
}, []);
|
||||
|
||||
// Helper to handle auth response (login, register, OTP verify, magic link verify)
|
||||
const handleAuthResponse = useCallback(
|
||||
(token: string, user: User) => {
|
||||
persistAuth(token, user);
|
||||
setState({
|
||||
user,
|
||||
isLoading: false,
|
||||
isAuthenticated: true,
|
||||
error: null,
|
||||
});
|
||||
scheduleRefresh(token);
|
||||
},
|
||||
[persistAuth, scheduleRefresh]
|
||||
);
|
||||
|
||||
// Login with credentials
|
||||
const login = useCallback(
|
||||
@ -124,13 +292,11 @@ export function AuthProvider({
|
||||
let user: User;
|
||||
|
||||
if (onLogin) {
|
||||
// Use custom login handler
|
||||
const result = await onLogin(credentials);
|
||||
token = result.token;
|
||||
user = result.user;
|
||||
} else {
|
||||
// Use default API login
|
||||
const response = await fetch(loginUrl, {
|
||||
const response = await fetch(resolvedLoginUrl, {
|
||||
method: 'POST',
|
||||
headers: { 'Content-Type': 'application/json' },
|
||||
body: JSON.stringify(credentials),
|
||||
@ -147,19 +313,7 @@ export function AuthProvider({
|
||||
user = data.data?.user || data.user;
|
||||
}
|
||||
|
||||
// Store token and user
|
||||
const store = getStorage();
|
||||
if (store) {
|
||||
store.setItem(TOKEN_STORAGE_KEY, token);
|
||||
store.setItem(USER_STORAGE_KEY, JSON.stringify(user));
|
||||
}
|
||||
|
||||
setState({
|
||||
user,
|
||||
isLoading: false,
|
||||
isAuthenticated: true,
|
||||
error: null,
|
||||
});
|
||||
handleAuthResponse(token, user);
|
||||
} catch (error) {
|
||||
setState({
|
||||
user: null,
|
||||
@ -170,7 +324,7 @@ export function AuthProvider({
|
||||
throw error;
|
||||
}
|
||||
},
|
||||
[loginUrl, onLogin, getStorage]
|
||||
[resolvedLoginUrl, onLogin, handleAuthResponse]
|
||||
);
|
||||
|
||||
// Login with token directly
|
||||
@ -190,25 +344,68 @@ export function AuthProvider({
|
||||
isAuthenticated: true,
|
||||
error: null,
|
||||
});
|
||||
scheduleRefresh(token);
|
||||
},
|
||||
[getStorage]
|
||||
[getStorage, scheduleRefresh]
|
||||
);
|
||||
|
||||
// Register
|
||||
const register = useCallback(
|
||||
async (credentials: RegisterCredentials) => {
|
||||
setState((s) => ({ ...s, isLoading: true, error: null }));
|
||||
|
||||
try {
|
||||
const response = await fetch(resolvedRegisterUrl, {
|
||||
method: 'POST',
|
||||
headers: { 'Content-Type': 'application/json' },
|
||||
body: JSON.stringify(credentials),
|
||||
});
|
||||
|
||||
if (!response.ok) {
|
||||
const errBody = await response.json().catch(() => ({}));
|
||||
const errMsg = errBody.error?.message || errBody.message || 'Registration failed';
|
||||
throw new Error(errMsg);
|
||||
}
|
||||
|
||||
const data = await response.json();
|
||||
const token = data.data?.token || data.token;
|
||||
const user = data.data?.user || data.user;
|
||||
|
||||
handleAuthResponse(token, user);
|
||||
} catch (error) {
|
||||
setState({
|
||||
user: null,
|
||||
isLoading: false,
|
||||
isAuthenticated: false,
|
||||
error: error instanceof Error ? error : new Error('Registration failed'),
|
||||
});
|
||||
throw error;
|
||||
}
|
||||
},
|
||||
[resolvedRegisterUrl, handleAuthResponse]
|
||||
);
|
||||
|
||||
// Logout
|
||||
const logout = useCallback(async () => {
|
||||
if (refreshTimerRef.current) {
|
||||
clearTimeout(refreshTimerRef.current);
|
||||
}
|
||||
|
||||
try {
|
||||
if (onLogout) {
|
||||
await onLogout();
|
||||
} else if (logoutUrl) {
|
||||
await fetch(logoutUrl, { method: 'POST' }).catch(() => {});
|
||||
} else {
|
||||
const store = getStorage();
|
||||
const token = store?.getItem(TOKEN_STORAGE_KEY);
|
||||
if (token) {
|
||||
await fetch(resolvedLogoutUrl, {
|
||||
method: 'POST',
|
||||
headers: { 'Authorization': `Bearer ${token}` },
|
||||
}).catch(() => {});
|
||||
}
|
||||
}
|
||||
} finally {
|
||||
const store = getStorage();
|
||||
if (store) {
|
||||
store.removeItem(TOKEN_STORAGE_KEY);
|
||||
store.removeItem(USER_STORAGE_KEY);
|
||||
}
|
||||
|
||||
clearAuth();
|
||||
setState({
|
||||
user: null,
|
||||
isLoading: false,
|
||||
@ -216,7 +413,134 @@ export function AuthProvider({
|
||||
error: null,
|
||||
});
|
||||
}
|
||||
}, [logoutUrl, onLogout, getStorage]);
|
||||
}, [resolvedLogoutUrl, onLogout, getStorage, clearAuth]);
|
||||
|
||||
// Send OTP
|
||||
const sendOTP = useCallback(
|
||||
async (email: string) => {
|
||||
const response = await fetch(resolvedOtpSendUrl, {
|
||||
method: 'POST',
|
||||
headers: { 'Content-Type': 'application/json' },
|
||||
body: JSON.stringify({ email }),
|
||||
});
|
||||
if (!response.ok) {
|
||||
const errBody = await response.json().catch(() => ({}));
|
||||
throw new Error(errBody.error?.message || errBody.message || 'Failed to send code');
|
||||
}
|
||||
},
|
||||
[resolvedOtpSendUrl]
|
||||
);
|
||||
|
||||
// Login with OTP
|
||||
const loginWithOTP = useCallback(
|
||||
async (credentials: OTPVerifyCredentials) => {
|
||||
setState((s) => ({ ...s, isLoading: true, error: null }));
|
||||
|
||||
try {
|
||||
const response = await fetch(resolvedOtpVerifyUrl, {
|
||||
method: 'POST',
|
||||
headers: { 'Content-Type': 'application/json' },
|
||||
body: JSON.stringify(credentials),
|
||||
});
|
||||
|
||||
if (!response.ok) {
|
||||
const errBody = await response.json().catch(() => ({}));
|
||||
throw new Error(errBody.error?.message || errBody.message || 'Invalid code');
|
||||
}
|
||||
|
||||
const data = await response.json();
|
||||
const token = data.data?.token || data.token;
|
||||
const user = data.data?.user || data.user;
|
||||
|
||||
handleAuthResponse(token, user);
|
||||
} catch (error) {
|
||||
setState((s) => ({
|
||||
...s,
|
||||
isLoading: false,
|
||||
error: error instanceof Error ? error : new Error('OTP verification failed'),
|
||||
}));
|
||||
throw error;
|
||||
}
|
||||
},
|
||||
[resolvedOtpVerifyUrl, handleAuthResponse]
|
||||
);
|
||||
|
||||
// Send magic link
|
||||
const sendMagicLink = useCallback(
|
||||
async (email: string) => {
|
||||
const response = await fetch(resolvedMagicLinkUrl, {
|
||||
method: 'POST',
|
||||
headers: { 'Content-Type': 'application/json' },
|
||||
body: JSON.stringify({ email }),
|
||||
});
|
||||
if (!response.ok) {
|
||||
const errBody = await response.json().catch(() => ({}));
|
||||
throw new Error(errBody.error?.message || errBody.message || 'Failed to send link');
|
||||
}
|
||||
},
|
||||
[resolvedMagicLinkUrl]
|
||||
);
|
||||
|
||||
// Login with magic link token
|
||||
const loginWithMagicLink = useCallback(
|
||||
async (credentials: MagicLinkVerifyCredentials) => {
|
||||
setState((s) => ({ ...s, isLoading: true, error: null }));
|
||||
|
||||
try {
|
||||
const response = await fetch(resolvedMagicLinkVerifyUrl, {
|
||||
method: 'POST',
|
||||
headers: { 'Content-Type': 'application/json' },
|
||||
body: JSON.stringify(credentials),
|
||||
});
|
||||
|
||||
if (!response.ok) {
|
||||
const errBody = await response.json().catch(() => ({}));
|
||||
throw new Error(errBody.error?.message || errBody.message || 'Invalid link');
|
||||
}
|
||||
|
||||
const data = await response.json();
|
||||
const token = data.data?.token || data.token;
|
||||
const user = data.data?.user || data.user;
|
||||
|
||||
handleAuthResponse(token, user);
|
||||
} catch (error) {
|
||||
setState((s) => ({
|
||||
...s,
|
||||
isLoading: false,
|
||||
error: error instanceof Error ? error : new Error('Magic link verification failed'),
|
||||
}));
|
||||
throw error;
|
||||
}
|
||||
},
|
||||
[resolvedMagicLinkVerifyUrl, handleAuthResponse]
|
||||
);
|
||||
|
||||
// Refresh token
|
||||
const refreshTokenFn = useCallback(async () => {
|
||||
const store = getStorage();
|
||||
const currentToken = store?.getItem(TOKEN_STORAGE_KEY);
|
||||
if (!currentToken) return;
|
||||
|
||||
const response = await fetch(resolvedRefreshUrl, {
|
||||
method: 'POST',
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
'Authorization': `Bearer ${currentToken}`,
|
||||
},
|
||||
});
|
||||
|
||||
if (!response.ok) {
|
||||
throw new Error('Token refresh failed');
|
||||
}
|
||||
|
||||
const data = await response.json();
|
||||
const newToken = data.data?.token || data.token;
|
||||
const newUser = data.data?.user || data.user;
|
||||
|
||||
persistAuth(newToken, newUser);
|
||||
setState(s => ({ ...s, user: newUser }));
|
||||
scheduleRefresh(newToken);
|
||||
}, [getStorage, resolvedRefreshUrl, persistAuth, scheduleRefresh]);
|
||||
|
||||
// Get token
|
||||
const getToken = useCallback(() => {
|
||||
@ -245,12 +569,18 @@ export function AuthProvider({
|
||||
...state,
|
||||
login,
|
||||
loginWithToken,
|
||||
register,
|
||||
logout,
|
||||
getToken,
|
||||
hasRole,
|
||||
hasScope,
|
||||
sendOTP,
|
||||
loginWithOTP,
|
||||
sendMagicLink,
|
||||
loginWithMagicLink,
|
||||
refreshToken: refreshTokenFn,
|
||||
}),
|
||||
[state, login, loginWithToken, logout, getToken, hasRole, hasScope]
|
||||
[state, login, loginWithToken, register, logout, getToken, hasRole, hasScope, sendOTP, loginWithOTP, sendMagicLink, loginWithMagicLink, refreshTokenFn]
|
||||
);
|
||||
|
||||
return <AuthContext.Provider value={value}>{children}</AuthContext.Provider>;
|
||||
|
||||
@ -1,3 +1,11 @@
|
||||
export { AuthProvider, useAuth, type AuthContextValue } from './AuthProvider';
|
||||
export { ProtectedRoute } from './ProtectedRoute';
|
||||
export type { User, AuthState, LoginCredentials } from './types';
|
||||
export type {
|
||||
User,
|
||||
AuthState,
|
||||
LoginCredentials,
|
||||
RegisterCredentials,
|
||||
OTPVerifyCredentials,
|
||||
MagicLinkVerifyCredentials,
|
||||
Session,
|
||||
} from './types';
|
||||
|
||||
@ -5,6 +5,8 @@ export interface User {
|
||||
id: string;
|
||||
email?: string;
|
||||
name?: string;
|
||||
avatarUrl?: string;
|
||||
emailVerified?: boolean;
|
||||
roles?: string[];
|
||||
scopes?: string[];
|
||||
metadata?: Record<string, unknown>;
|
||||
@ -32,12 +34,48 @@ export interface LoginCredentials {
|
||||
password: string;
|
||||
}
|
||||
|
||||
/**
|
||||
* Registration credentials.
|
||||
*/
|
||||
export interface RegisterCredentials {
|
||||
email: string;
|
||||
password: string;
|
||||
name?: string;
|
||||
}
|
||||
|
||||
/**
|
||||
* OTP verification credentials.
|
||||
*/
|
||||
export interface OTPVerifyCredentials {
|
||||
email: string;
|
||||
code: string;
|
||||
}
|
||||
|
||||
/**
|
||||
* Magic link verification credentials.
|
||||
*/
|
||||
export interface MagicLinkVerifyCredentials {
|
||||
email: string;
|
||||
token: string;
|
||||
}
|
||||
|
||||
/**
|
||||
* A login session with device and location information.
|
||||
*/
|
||||
export interface Session {
|
||||
id: string;
|
||||
ipAddress: string;
|
||||
deviceLabel: string;
|
||||
lastActiveAt: string;
|
||||
createdAt: string;
|
||||
isCurrent: boolean;
|
||||
}
|
||||
|
||||
/**
|
||||
* Token response from the authentication API.
|
||||
* Matches the backend LoginResponse shape: { token, user }.
|
||||
*/
|
||||
export interface TokenResponse {
|
||||
access_token: string;
|
||||
refresh_token?: string;
|
||||
token_type: string;
|
||||
expires_in: number;
|
||||
token: string;
|
||||
user: User;
|
||||
}
|
||||
|
||||
@ -4,8 +4,10 @@ import { useState } from 'react';
|
||||
import { Trash2, Image, Video, ExternalLink } from 'lucide-react';
|
||||
|
||||
export interface MediaItem {
|
||||
id: string;
|
||||
path: string;
|
||||
url: string;
|
||||
filename: string;
|
||||
contentType: string;
|
||||
size: number;
|
||||
createdAt: string;
|
||||
@ -14,8 +16,8 @@ export interface MediaItem {
|
||||
export interface MediaLibraryProps {
|
||||
/** Media items to display */
|
||||
items: MediaItem[];
|
||||
/** Called when a media item is deleted */
|
||||
onDelete?: (path: string) => void;
|
||||
/** Called when a media item is deleted (by ID) */
|
||||
onDelete?: (id: string) => void;
|
||||
/** Called when a media item is selected */
|
||||
onSelect?: (item: MediaItem) => void;
|
||||
/** Whether delete operations are in progress */
|
||||
@ -45,7 +47,7 @@ export function MediaLibrary({
|
||||
isDeleting = false,
|
||||
emptyMessage = 'No media files yet. Upload or generate some!',
|
||||
}: MediaLibraryProps) {
|
||||
const [selectedPath, setSelectedPath] = useState<string | null>(null);
|
||||
const [selectedId, setSelectedId] = useState<string | null>(null);
|
||||
|
||||
if (items.length === 0) {
|
||||
return (
|
||||
@ -60,15 +62,15 @@ export function MediaLibrary({
|
||||
<div className="grid grid-cols-2 md:grid-cols-3 lg:grid-cols-4 gap-4">
|
||||
{items.map((item) => (
|
||||
<div
|
||||
key={item.path}
|
||||
key={item.id}
|
||||
className={`
|
||||
group relative rounded-lg overflow-hidden border transition-colors cursor-pointer
|
||||
${selectedPath === item.path
|
||||
${selectedId === item.id
|
||||
? 'border-[var(--accent)] ring-2 ring-[var(--accent)]/20'
|
||||
: 'border-[var(--border-muted)] hover:border-[var(--border-default)]'}
|
||||
`}
|
||||
onClick={() => {
|
||||
setSelectedPath(item.path);
|
||||
setSelectedId(item.id);
|
||||
onSelect?.(item);
|
||||
}}
|
||||
>
|
||||
@ -77,7 +79,7 @@ export function MediaLibrary({
|
||||
{isImage(item.contentType) ? (
|
||||
<img
|
||||
src={item.url}
|
||||
alt={item.path.split('/').pop() || ''}
|
||||
alt={item.filename || item.path.split('/').pop() || ''}
|
||||
className="w-full h-full object-cover"
|
||||
loading="lazy"
|
||||
/>
|
||||
@ -100,7 +102,7 @@ export function MediaLibrary({
|
||||
{/* Info bar */}
|
||||
<div className="p-2 bg-[var(--surface-100)]">
|
||||
<p className="text-xs text-[var(--text-primary)] truncate">
|
||||
{item.path.split('/').pop()}
|
||||
{item.filename || item.path.split('/').pop()}
|
||||
</p>
|
||||
<p className="text-xs text-[var(--text-muted)]">
|
||||
{formatSize(item.size)}
|
||||
@ -121,7 +123,7 @@ export function MediaLibrary({
|
||||
</a>
|
||||
{onDelete && (
|
||||
<button
|
||||
onClick={(e) => { e.stopPropagation(); onDelete(item.path); }}
|
||||
onClick={(e) => { e.stopPropagation(); onDelete(item.id); }}
|
||||
disabled={isDeleting}
|
||||
className="p-1.5 rounded bg-black/60 text-white hover:bg-red-600"
|
||||
title="Delete"
|
||||
|
||||
@ -41,6 +41,8 @@ type JWTClaims struct {
|
||||
Roles []string `json:"roles,omitempty"`
|
||||
// Scopes are the permitted scopes
|
||||
Scopes []string `json:"scopes,omitempty"`
|
||||
// SessionID links the token to a specific login session
|
||||
SessionID string `json:"sid,omitempty"`
|
||||
}
|
||||
|
||||
// JWTValidator validates JWT tokens.
|
||||
@ -119,6 +121,7 @@ func (v *JWTValidator) Validate(ctx context.Context, tokenString string) (*User,
|
||||
Email: claims.Email,
|
||||
Roles: claims.Roles,
|
||||
Scopes: claims.Scopes,
|
||||
Metadata: make(map[string]any),
|
||||
}
|
||||
|
||||
// Fallback to subject if no user ID
|
||||
@ -126,6 +129,11 @@ func (v *JWTValidator) Validate(ctx context.Context, tokenString string) (*User,
|
||||
user.ID = claims.Subject
|
||||
}
|
||||
|
||||
// Store session ID in metadata for downstream checks
|
||||
if claims.SessionID != "" {
|
||||
user.Metadata["sid"] = claims.SessionID
|
||||
}
|
||||
|
||||
return user, nil
|
||||
}
|
||||
|
||||
@ -155,8 +163,80 @@ func GenerateToken(secret []byte, user *User, expiresIn time.Duration) (string,
|
||||
return token.SignedString(secret)
|
||||
}
|
||||
|
||||
// ValidateAllowExpired is like Validate but tolerates expired tokens.
|
||||
// It still validates the signature, issuer, and audience — only the expiry check is relaxed.
|
||||
// Used for token refresh: the caller presents an expired access token to prove identity,
|
||||
// and the service layer checks session validity separately.
|
||||
func (v *JWTValidator) ValidateAllowExpired(ctx context.Context, tokenString string) (*User, error) {
|
||||
// Use a generous leeway to accept expired tokens for refresh purposes.
|
||||
token, err := jwt.ParseWithClaims(tokenString, &JWTClaims{}, func(token *jwt.Token) (any, error) {
|
||||
switch token.Method.(type) {
|
||||
case *jwt.SigningMethodHMAC:
|
||||
if v.secret == nil {
|
||||
return nil, fmt.Errorf("%w: HMAC secret not configured", ErrInvalidToken)
|
||||
}
|
||||
return v.secret, nil
|
||||
case *jwt.SigningMethodRSA, *jwt.SigningMethodECDSA:
|
||||
if v.publicKey == nil {
|
||||
return nil, fmt.Errorf("%w: public key not configured", ErrInvalidToken)
|
||||
}
|
||||
return v.publicKey, nil
|
||||
default:
|
||||
return nil, fmt.Errorf("%w: unexpected signing method %v", ErrInvalidToken, token.Header["alg"])
|
||||
}
|
||||
}, jwt.WithLeeway(30*24*time.Hour)) // 30 days — matches session lifetime
|
||||
|
||||
if err != nil {
|
||||
// Even with generous leeway, signature failures still fail.
|
||||
return nil, fmt.Errorf("%w: %w", ErrInvalidToken, err)
|
||||
}
|
||||
|
||||
claims, ok := token.Claims.(*JWTClaims)
|
||||
if !ok || !token.Valid {
|
||||
return nil, ErrInvalidClaims
|
||||
}
|
||||
|
||||
if v.issuer != "" && claims.Issuer != v.issuer {
|
||||
return nil, fmt.Errorf("%w: invalid issuer", ErrInvalidClaims)
|
||||
}
|
||||
if v.audience != "" {
|
||||
found := false
|
||||
for _, aud := range claims.Audience {
|
||||
if aud == v.audience {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
return nil, fmt.Errorf("%w: invalid audience", ErrInvalidClaims)
|
||||
}
|
||||
}
|
||||
|
||||
user := &User{
|
||||
ID: claims.UserID,
|
||||
Email: claims.Email,
|
||||
Roles: claims.Roles,
|
||||
Scopes: claims.Scopes,
|
||||
Metadata: make(map[string]any),
|
||||
}
|
||||
if user.ID == "" {
|
||||
user.ID = claims.Subject
|
||||
}
|
||||
if claims.SessionID != "" {
|
||||
user.Metadata["sid"] = claims.SessionID
|
||||
}
|
||||
|
||||
return user, nil
|
||||
}
|
||||
|
||||
// GenerateTokenWithIssuer creates a new JWT token with issuer and audience claims.
|
||||
func GenerateTokenWithIssuer(secret []byte, user *User, expiresIn time.Duration, issuer, audience string) (string, error) {
|
||||
return GenerateTokenWithSession(secret, user, expiresIn, issuer, audience, "")
|
||||
}
|
||||
|
||||
// GenerateTokenWithSession creates a JWT token with a session ID embedded.
|
||||
// The session ID links the token to a login session for revocation support.
|
||||
func GenerateTokenWithSession(secret []byte, user *User, expiresIn time.Duration, issuer, audience, sessionID string) (string, error) {
|
||||
now := time.Now()
|
||||
|
||||
claims := JWTClaims{
|
||||
@ -172,6 +252,7 @@ func GenerateTokenWithIssuer(secret []byte, user *User, expiresIn time.Duration,
|
||||
Email: user.Email,
|
||||
Roles: user.Roles,
|
||||
Scopes: user.Scopes,
|
||||
SessionID: sessionID,
|
||||
}
|
||||
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
|
||||
|
||||
@ -2,6 +2,7 @@ package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
@ -19,6 +20,10 @@ type MiddlewareConfig struct {
|
||||
// Optional returns 401 only when a token is provided but invalid.
|
||||
// If no token is provided, the request continues without authentication.
|
||||
Optional bool
|
||||
// AllowExpired accepts expired tokens (still validates signature).
|
||||
// Use for token refresh endpoints where the caller presents an expired
|
||||
// access token to prove identity, and session validity is checked separately.
|
||||
AllowExpired bool
|
||||
// SkipPaths are paths that skip authentication entirely
|
||||
SkipPaths []string
|
||||
}
|
||||
@ -71,6 +76,14 @@ func Middleware(cfg MiddlewareConfig) func(http.Handler) http.Handler {
|
||||
|
||||
// Validate token
|
||||
user, err := cfg.Validator.Validate(r.Context(), token)
|
||||
if err != nil {
|
||||
// If AllowExpired is set and the token is expired (but signature valid),
|
||||
// re-validate with relaxed expiry for refresh flows.
|
||||
if cfg.AllowExpired && errors.Is(err, ErrExpiredToken) {
|
||||
if jwtVal, ok := cfg.Validator.(*JWTValidator); ok {
|
||||
user, err = jwtVal.ValidateAllowExpired(r.Context(), token)
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
if cfg.Optional {
|
||||
// Token invalid/expired but auth is optional — continue without user context.
|
||||
@ -81,6 +94,7 @@ func Middleware(cfg MiddlewareConfig) func(http.Handler) http.Handler {
|
||||
httpresponse.Unauthorized(w, r, "invalid credentials")
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// Store user and token in context
|
||||
ctx := SetUser(r.Context(), user)
|
||||
@ -237,3 +251,58 @@ func RequireScopeErr(ctx context.Context, scope string) error {
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// SessionChecker is a function that checks whether a session is still active.
|
||||
// Returns true if the session is active, false if revoked/expired.
|
||||
// Implementations should query the session store.
|
||||
type SessionChecker func(ctx context.Context, sessionID string) (bool, error)
|
||||
|
||||
// SessionCheck middleware validates that the JWT's embedded session is still active.
|
||||
// It extracts the "sid" from the authenticated user's Metadata and calls the checker.
|
||||
// If the session has been revoked, the request is rejected with 401.
|
||||
//
|
||||
// This middleware must be applied AFTER auth.Middleware (which sets the user in context).
|
||||
// It is opt-in — services that don't need session revocation can skip it.
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// checker := func(ctx context.Context, sid string) (bool, error) {
|
||||
// session, err := sessionRepo.Get(ctx, domain.SessionID(sid))
|
||||
// if err != nil { return false, nil }
|
||||
// return session.IsActive(), nil
|
||||
// }
|
||||
// r.Use(auth.SessionCheck(checker))
|
||||
func SessionCheck(checker SessionChecker) func(http.Handler) http.Handler {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
user := GetUser(r.Context())
|
||||
if user == nil {
|
||||
// No user in context — let downstream middleware handle it.
|
||||
next.ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
// Extract session ID from JWT metadata.
|
||||
sid, _ := user.Metadata["sid"].(string)
|
||||
if sid == "" {
|
||||
// Token has no session ID (e.g., old token before sessions were added).
|
||||
// Allow through — backward compatible.
|
||||
next.ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
active, err := checker(r.Context(), sid)
|
||||
if err != nil {
|
||||
// Session check failed — fail open is dangerous, fail closed.
|
||||
httpresponse.Unauthorized(w, r, "session validation failed")
|
||||
return
|
||||
}
|
||||
if !active {
|
||||
httpresponse.Unauthorized(w, r, "session has been revoked")
|
||||
return
|
||||
}
|
||||
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@ -0,0 +1,71 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"unicode"
|
||||
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
)
|
||||
|
||||
const (
|
||||
// BcryptCost is the bcrypt hashing cost. 12 balances security and performance.
|
||||
BcryptCost = 12
|
||||
|
||||
// MinPasswordLength is the minimum allowed password length.
|
||||
MinPasswordLength = 8
|
||||
// MaxPasswordLength is the maximum allowed password length (bcrypt limit).
|
||||
MaxPasswordLength = 72
|
||||
)
|
||||
|
||||
var (
|
||||
// ErrPasswordTooShort is returned when a password is below the minimum length.
|
||||
ErrPasswordTooShort = errors.New("password must be at least 8 characters")
|
||||
// ErrPasswordTooLong is returned when a password exceeds the bcrypt limit.
|
||||
ErrPasswordTooLong = errors.New("password must be at most 72 characters")
|
||||
// ErrPasswordWeak is returned when a password lacks required character types.
|
||||
ErrPasswordWeak = errors.New("password must contain at least one uppercase letter, one lowercase letter, and one number")
|
||||
)
|
||||
|
||||
// HashPassword hashes a plaintext password using bcrypt.
|
||||
func HashPassword(password string) (string, error) {
|
||||
hash, err := bcrypt.GenerateFromPassword([]byte(password), BcryptCost)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return string(hash), nil
|
||||
}
|
||||
|
||||
// CheckPassword compares a plaintext password against a bcrypt hash.
|
||||
// Returns true if the password matches.
|
||||
func CheckPassword(password, hash string) bool {
|
||||
return bcrypt.CompareHashAndPassword([]byte(hash), []byte(password)) == nil
|
||||
}
|
||||
|
||||
// ValidatePasswordStrength checks that a password meets minimum complexity requirements.
|
||||
// Returns nil if the password is acceptable.
|
||||
func ValidatePasswordStrength(password string) error {
|
||||
if len(password) < MinPasswordLength {
|
||||
return ErrPasswordTooShort
|
||||
}
|
||||
if len(password) > MaxPasswordLength {
|
||||
return ErrPasswordTooLong
|
||||
}
|
||||
|
||||
var hasUpper, hasLower, hasDigit bool
|
||||
for _, r := range password {
|
||||
switch {
|
||||
case unicode.IsUpper(r):
|
||||
hasUpper = true
|
||||
case unicode.IsLower(r):
|
||||
hasLower = true
|
||||
case unicode.IsDigit(r):
|
||||
hasDigit = true
|
||||
}
|
||||
}
|
||||
|
||||
if !hasUpper || !hasLower || !hasDigit {
|
||||
return ErrPasswordWeak
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@ -0,0 +1,68 @@
|
||||
package auth
|
||||
|
||||
import "strings"
|
||||
|
||||
// ParseDeviceLabel extracts a human-readable device label from a user agent string.
|
||||
// Returns something like "Chrome on macOS", "Safari on iPhone", "Firefox on Windows".
|
||||
func ParseDeviceLabel(userAgent string) string {
|
||||
if userAgent == "" {
|
||||
return "Unknown device"
|
||||
}
|
||||
|
||||
browser := parseBrowser(userAgent)
|
||||
os := parseOS(userAgent)
|
||||
|
||||
if browser == "" && os == "" {
|
||||
return "Unknown device"
|
||||
}
|
||||
if browser == "" {
|
||||
return os
|
||||
}
|
||||
if os == "" {
|
||||
return browser
|
||||
}
|
||||
return browser + " on " + os
|
||||
}
|
||||
|
||||
func parseBrowser(ua string) string {
|
||||
// Order matters — check more specific before generic.
|
||||
switch {
|
||||
case strings.Contains(ua, "Edg/") || strings.Contains(ua, "Edge/"):
|
||||
return "Edge"
|
||||
case strings.Contains(ua, "OPR/") || strings.Contains(ua, "Opera"):
|
||||
return "Opera"
|
||||
case strings.Contains(ua, "Brave"):
|
||||
return "Brave"
|
||||
case strings.Contains(ua, "Vivaldi"):
|
||||
return "Vivaldi"
|
||||
case strings.Contains(ua, "Chrome/") && !strings.Contains(ua, "Chromium"):
|
||||
return "Chrome"
|
||||
case strings.Contains(ua, "Firefox/"):
|
||||
return "Firefox"
|
||||
case strings.Contains(ua, "Safari/") && !strings.Contains(ua, "Chrome"):
|
||||
return "Safari"
|
||||
default:
|
||||
return ""
|
||||
}
|
||||
}
|
||||
|
||||
func parseOS(ua string) string {
|
||||
switch {
|
||||
case strings.Contains(ua, "iPhone"):
|
||||
return "iPhone"
|
||||
case strings.Contains(ua, "iPad"):
|
||||
return "iPad"
|
||||
case strings.Contains(ua, "Android"):
|
||||
return "Android"
|
||||
case strings.Contains(ua, "Mac OS X") || strings.Contains(ua, "Macintosh"):
|
||||
return "macOS"
|
||||
case strings.Contains(ua, "Windows"):
|
||||
return "Windows"
|
||||
case strings.Contains(ua, "Linux"):
|
||||
return "Linux"
|
||||
case strings.Contains(ua, "CrOS"):
|
||||
return "ChromeOS"
|
||||
default:
|
||||
return ""
|
||||
}
|
||||
}
|
||||
@ -350,7 +350,9 @@ func downloadURL(ctx context.Context, url string) ([]byte, error) {
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("download: status %d", resp.StatusCode)
|
||||
}
|
||||
data, err := io.ReadAll(resp.Body)
|
||||
// Limit body to 500MB to prevent OOM from unexpected large responses.
|
||||
const maxBodySize = 500 << 20
|
||||
data, err := io.ReadAll(io.LimitReader(resp.Body, maxBodySize))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("read body: %w", err)
|
||||
}
|
||||
|
||||
@ -16,6 +16,7 @@ require (
|
||||
github.com/redis/go-redis/v9 v9.7.0
|
||||
github.com/spf13/viper v1.19.0
|
||||
google.golang.org/api v0.192.0
|
||||
golang.org/x/crypto v0.21.0
|
||||
google.golang.org/genai v1.46.0
|
||||
)
|
||||
|
||||
@ -38,7 +39,6 @@ require (
|
||||
github.com/subosito/gotenv v1.6.0 // indirect
|
||||
go.uber.org/atomic v1.9.0 // indirect
|
||||
go.uber.org/multierr v1.9.0 // indirect
|
||||
golang.org/x/crypto v0.21.0 // indirect
|
||||
golang.org/x/exp v0.0.0-20230905200255-921286631fa9 // indirect
|
||||
golang.org/x/net v0.23.0 // indirect
|
||||
golang.org/x/sys v0.18.0 // indirect
|
||||
|
||||
@ -0,0 +1,132 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"net"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// RateLimitConfig configures per-IP rate limiting.
|
||||
type RateLimitConfig struct {
|
||||
// Requests is the maximum number of requests allowed per window.
|
||||
Requests int
|
||||
// Window is the time window for the rate limit.
|
||||
Window time.Duration
|
||||
}
|
||||
|
||||
// ipEntry tracks request timestamps for a single IP.
|
||||
type ipEntry struct {
|
||||
timestamps []time.Time
|
||||
}
|
||||
|
||||
// rateLimiter implements a sliding window rate limiter.
|
||||
type rateLimiter struct {
|
||||
mu sync.Mutex
|
||||
entries map[string]*ipEntry
|
||||
config RateLimitConfig
|
||||
}
|
||||
|
||||
// RateLimit returns middleware that limits requests per IP using a sliding window.
|
||||
// When the limit is exceeded, it responds with 429 Too Many Requests.
|
||||
func RateLimit(cfg RateLimitConfig) func(http.Handler) http.Handler {
|
||||
rl := &rateLimiter{
|
||||
entries: make(map[string]*ipEntry),
|
||||
config: cfg,
|
||||
}
|
||||
|
||||
// Periodically evict stale entries to prevent unbounded growth.
|
||||
go rl.cleanup()
|
||||
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
ip := rateLimitClientIP(r)
|
||||
if !rl.allow(ip) {
|
||||
http.Error(w, `{"error":{"message":"too many requests, please try again later"}}`, http.StatusTooManyRequests)
|
||||
return
|
||||
}
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// allow checks if the IP is within its rate limit and records the request.
|
||||
func (rl *rateLimiter) allow(ip string) bool {
|
||||
rl.mu.Lock()
|
||||
defer rl.mu.Unlock()
|
||||
|
||||
now := time.Now()
|
||||
cutoff := now.Add(-rl.config.Window)
|
||||
|
||||
entry, ok := rl.entries[ip]
|
||||
if !ok {
|
||||
entry = &ipEntry{}
|
||||
rl.entries[ip] = entry
|
||||
}
|
||||
|
||||
// Remove timestamps outside the window.
|
||||
valid := entry.timestamps[:0]
|
||||
for _, t := range entry.timestamps {
|
||||
if t.After(cutoff) {
|
||||
valid = append(valid, t)
|
||||
}
|
||||
}
|
||||
entry.timestamps = valid
|
||||
|
||||
if len(entry.timestamps) >= rl.config.Requests {
|
||||
return false
|
||||
}
|
||||
|
||||
entry.timestamps = append(entry.timestamps, now)
|
||||
return true
|
||||
}
|
||||
|
||||
// cleanup removes stale entries every 5 minutes.
|
||||
func (rl *rateLimiter) cleanup() {
|
||||
ticker := time.NewTicker(5 * time.Minute)
|
||||
defer ticker.Stop()
|
||||
|
||||
for range ticker.C {
|
||||
rl.mu.Lock()
|
||||
now := time.Now()
|
||||
cutoff := now.Add(-rl.config.Window)
|
||||
for ip, entry := range rl.entries {
|
||||
valid := entry.timestamps[:0]
|
||||
for _, t := range entry.timestamps {
|
||||
if t.After(cutoff) {
|
||||
valid = append(valid, t)
|
||||
}
|
||||
}
|
||||
if len(valid) == 0 {
|
||||
delete(rl.entries, ip)
|
||||
} else {
|
||||
entry.timestamps = valid
|
||||
}
|
||||
}
|
||||
rl.mu.Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
// rateLimitClientIP extracts the client IP, trusting proxy headers only from private IPs.
|
||||
func rateLimitClientIP(r *http.Request) string {
|
||||
host, _, err := net.SplitHostPort(r.RemoteAddr)
|
||||
if err != nil {
|
||||
host = r.RemoteAddr
|
||||
}
|
||||
|
||||
ip := net.ParseIP(host)
|
||||
if ip != nil && (ip.IsLoopback() || ip.IsPrivate()) {
|
||||
if xff := r.Header.Get("X-Forwarded-For"); xff != "" {
|
||||
parts := strings.SplitN(xff, ",", 2)
|
||||
if fwd := strings.TrimSpace(parts[0]); fwd != "" {
|
||||
return fwd
|
||||
}
|
||||
}
|
||||
if xri := r.Header.Get("X-Real-Ip"); xri != "" {
|
||||
return xri
|
||||
}
|
||||
}
|
||||
|
||||
return host
|
||||
}
|
||||
@ -0,0 +1,184 @@
|
||||
// Package notify provides a Go client for the orchard9 notify service.
|
||||
//
|
||||
// The notify service handles email delivery with provider routing, retries,
|
||||
// delivery tracking, and suppression handling.
|
||||
//
|
||||
// Basic usage:
|
||||
//
|
||||
// client, err := notify.NewClient(notify.Config{
|
||||
// URL: os.Getenv("NOTIFY_URL"),
|
||||
// APIKey: os.Getenv("NOTIFY_API_KEY"),
|
||||
// })
|
||||
// if err != nil {
|
||||
// log.Fatal(err)
|
||||
// }
|
||||
//
|
||||
// resp, err := client.SendEmail(ctx, ¬ify.SendRequest{
|
||||
// To: "user@example.com",
|
||||
// From: "noreply@myapp.threesix.ai",
|
||||
// Content: notify.Content{Subject: "Hello", Text: "World"},
|
||||
// Meta: notify.Meta{Host: "myapp.threesix.ai"},
|
||||
// })
|
||||
package notify
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"time"
|
||||
|
||||
"{{GO_MODULE}}/pkg/httpclient"
|
||||
)
|
||||
|
||||
const (
|
||||
defaultTimeout = 30 * time.Second
|
||||
defaultMaxRetries = 3
|
||||
)
|
||||
|
||||
// Config holds configuration for the notify client.
|
||||
type Config struct {
|
||||
URL string // Required: base URL (e.g. "https://notify.threesix.ai")
|
||||
APIKey string // Required: send API key (notify_send_xxx)
|
||||
Timeout time.Duration // Optional: defaults to 30s
|
||||
MaxRetries int // Optional: defaults to 3
|
||||
Logger *slog.Logger // Optional: defaults to slog.Default()
|
||||
}
|
||||
|
||||
// Client is the notify API client.
|
||||
type Client struct {
|
||||
httpClient *httpclient.Client
|
||||
baseURL string
|
||||
apiKey string
|
||||
logger *slog.Logger
|
||||
}
|
||||
|
||||
// NewClient creates a new notify API client.
|
||||
func NewClient(config Config) (*Client, error) {
|
||||
if config.URL == "" {
|
||||
return nil, fmt.Errorf("%w: URL is required", ErrInvalidConfig)
|
||||
}
|
||||
if config.APIKey == "" {
|
||||
return nil, fmt.Errorf("%w: API key is required", ErrInvalidConfig)
|
||||
}
|
||||
|
||||
if _, err := url.Parse(config.URL); err != nil {
|
||||
return nil, fmt.Errorf("%w: invalid URL: %v", ErrInvalidConfig, err)
|
||||
}
|
||||
|
||||
if config.Timeout == 0 {
|
||||
config.Timeout = defaultTimeout
|
||||
}
|
||||
if config.MaxRetries == 0 {
|
||||
config.MaxRetries = defaultMaxRetries
|
||||
}
|
||||
if config.Logger == nil {
|
||||
config.Logger = slog.Default()
|
||||
}
|
||||
|
||||
return &Client{
|
||||
httpClient: httpclient.New(httpclient.Config{
|
||||
Timeout: config.Timeout,
|
||||
MaxRetries: config.MaxRetries,
|
||||
Logger: config.Logger,
|
||||
}),
|
||||
baseURL: config.URL,
|
||||
apiKey: config.APIKey,
|
||||
logger: config.Logger,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// SendEmail sends an email through the notify service.
|
||||
// Returns the message ID and status on success (202 Accepted).
|
||||
func (c *Client) SendEmail(ctx context.Context, req *SendRequest) (*SendResponse, error) {
|
||||
respBody, err := c.doRequest(ctx, http.MethodPost, "/email", req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var resp SendResponse
|
||||
if err := json.Unmarshal(respBody, &resp); err != nil {
|
||||
return nil, fmt.Errorf("unmarshal response: %w", err)
|
||||
}
|
||||
return &resp, nil
|
||||
}
|
||||
|
||||
// GetMessage retrieves the full status of a sent message.
|
||||
func (c *Client) GetMessage(ctx context.Context, id string) (*Message, error) {
|
||||
respBody, err := c.doRequest(ctx, http.MethodGet, "/messages/"+id, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var msg Message
|
||||
if err := json.Unmarshal(respBody, &msg); err != nil {
|
||||
return nil, fmt.Errorf("unmarshal response: %w", err)
|
||||
}
|
||||
return &msg, nil
|
||||
}
|
||||
|
||||
// doRequest is a helper for making HTTP requests to the notify API.
|
||||
func (c *Client) doRequest(ctx context.Context, method, path string, bodyData interface{}) ([]byte, error) {
|
||||
var reqBody io.Reader
|
||||
if bodyData != nil {
|
||||
jsonBody, err := json.Marshal(bodyData)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("marshal request: %w", err)
|
||||
}
|
||||
reqBody = bytes.NewReader(jsonBody)
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, method, c.baseURL+path, reqBody)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create request: %w", err)
|
||||
}
|
||||
|
||||
req.Header.Set("Authorization", "Bearer "+c.apiKey)
|
||||
if bodyData != nil {
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
}
|
||||
|
||||
resp, err := c.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
respBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("read response body: %w", err)
|
||||
}
|
||||
|
||||
// 2xx = success
|
||||
if resp.StatusCode >= 200 && resp.StatusCode < 300 {
|
||||
return respBody, nil
|
||||
}
|
||||
|
||||
// Parse error response
|
||||
var errResp ErrorResponse
|
||||
baseErr := classifyError(resp.StatusCode, "")
|
||||
|
||||
if err := json.Unmarshal(respBody, &errResp); err != nil {
|
||||
return nil, &APIError{
|
||||
StatusCode: resp.StatusCode,
|
||||
Message: string(respBody),
|
||||
err: baseErr,
|
||||
}
|
||||
}
|
||||
|
||||
// Re-classify with the error code from the response.
|
||||
if errResp.Code != "" {
|
||||
baseErr = classifyError(resp.StatusCode, errResp.Code)
|
||||
}
|
||||
|
||||
return nil, &APIError{
|
||||
StatusCode: resp.StatusCode,
|
||||
Message: errResp.Error,
|
||||
Code: errResp.Code,
|
||||
err: baseErr,
|
||||
}
|
||||
}
|
||||
@ -0,0 +1,69 @@
|
||||
package notify
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
// Sentinel errors for programmatic handling.
|
||||
var (
|
||||
ErrInvalidConfig = errors.New("invalid configuration")
|
||||
ErrUnauthorized = errors.New("unauthorized")
|
||||
ErrForbidden = errors.New("forbidden")
|
||||
ErrSuppressed = errors.New("recipient suppressed")
|
||||
ErrHostNotFound = errors.New("host not found")
|
||||
ErrFromNotFound = errors.New("from address not registered")
|
||||
ErrValidation = errors.New("validation error")
|
||||
ErrServerError = errors.New("server error")
|
||||
ErrRateLimit = errors.New("rate limit exceeded")
|
||||
)
|
||||
|
||||
// APIError represents an error returned by the notify API.
|
||||
type APIError struct {
|
||||
StatusCode int
|
||||
Message string
|
||||
Code string
|
||||
err error
|
||||
}
|
||||
|
||||
func (e *APIError) Error() string {
|
||||
if e.Code != "" {
|
||||
return fmt.Sprintf("notify api error (status %d): [%s] %s", e.StatusCode, e.Code, e.Message)
|
||||
}
|
||||
return fmt.Sprintf("notify api error (status %d): %s", e.StatusCode, e.Message)
|
||||
}
|
||||
|
||||
func (e *APIError) Unwrap() error {
|
||||
return e.err
|
||||
}
|
||||
|
||||
// classifyError maps an HTTP status and optional error code to a sentinel error.
|
||||
func classifyError(statusCode int, code string) error {
|
||||
// Check specific error codes first.
|
||||
switch code {
|
||||
case "suppressed":
|
||||
return ErrSuppressed
|
||||
case "host_not_found":
|
||||
return ErrHostNotFound
|
||||
case "from_not_found":
|
||||
return ErrFromNotFound
|
||||
}
|
||||
|
||||
// Fall back to HTTP status.
|
||||
switch {
|
||||
case statusCode == 401:
|
||||
return ErrUnauthorized
|
||||
case statusCode == 403:
|
||||
return ErrForbidden
|
||||
case statusCode == 422:
|
||||
return ErrValidation
|
||||
case statusCode == 429:
|
||||
return ErrRateLimit
|
||||
case statusCode >= 400 && statusCode < 500:
|
||||
return ErrValidation
|
||||
case statusCode >= 500:
|
||||
return ErrServerError
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
@ -0,0 +1,57 @@
|
||||
package notify
|
||||
|
||||
// SendRequest is the payload for POST /email.
|
||||
type SendRequest struct {
|
||||
To string `json:"to"`
|
||||
From string `json:"from"`
|
||||
ReplyTo string `json:"reply_to,omitempty"`
|
||||
CC string `json:"cc,omitempty"`
|
||||
BCC string `json:"bcc,omitempty"`
|
||||
Content Content `json:"content"`
|
||||
Meta Meta `json:"meta"`
|
||||
Options Options `json:"options,omitempty"`
|
||||
}
|
||||
|
||||
// Content holds the email subject and body.
|
||||
type Content struct {
|
||||
Subject string `json:"subject"`
|
||||
HTML string `json:"html,omitempty"`
|
||||
Text string `json:"text,omitempty"`
|
||||
}
|
||||
|
||||
// Meta contains routing metadata for the notify service.
|
||||
type Meta struct {
|
||||
Host string `json:"host"`
|
||||
Category string `json:"category,omitempty"`
|
||||
Tags []string `json:"tags,omitempty"`
|
||||
}
|
||||
|
||||
// Options contains delivery options.
|
||||
type Options struct {
|
||||
IdempotencyKey string `json:"idempotency_key,omitempty"`
|
||||
}
|
||||
|
||||
// SendResponse is returned by POST /email (202 Accepted).
|
||||
type SendResponse struct {
|
||||
MessageID string `json:"message_id"`
|
||||
Status string `json:"status"`
|
||||
}
|
||||
|
||||
// Message is the full status of a sent message (GET /messages/{id}).
|
||||
type Message struct {
|
||||
ID string `json:"id"`
|
||||
To string `json:"to"`
|
||||
From string `json:"from"`
|
||||
Subject string `json:"subject"`
|
||||
Status string `json:"status"`
|
||||
Provider string `json:"provider,omitempty"`
|
||||
CreatedAt string `json:"created_at"`
|
||||
SentAt string `json:"sent_at,omitempty"`
|
||||
}
|
||||
|
||||
// ErrorResponse is the error envelope from the notify API.
|
||||
type ErrorResponse struct {
|
||||
Error string `json:"error"`
|
||||
Code string `json:"code,omitempty"`
|
||||
Details map[string]string `json:"details,omitempty"`
|
||||
}
|
||||
@ -12,9 +12,10 @@ import (
|
||||
|
||||
// MemoryQueue is an in-memory job queue that dispatches jobs to registered handlers
|
||||
// in goroutines. Use this for local development when no database is available.
|
||||
// Implements Producer so the service handlers can enqueue jobs without caring about the backend.
|
||||
// Implements Producer and JobReader so the service handlers can enqueue and query jobs.
|
||||
type MemoryQueue struct {
|
||||
handlers map[string]Handler
|
||||
jobs map[string]*Job
|
||||
mu sync.RWMutex
|
||||
logger *slog.Logger
|
||||
}
|
||||
@ -26,6 +27,7 @@ func NewMemoryQueue(logger *slog.Logger) *MemoryQueue {
|
||||
}
|
||||
return &MemoryQueue{
|
||||
handlers: make(map[string]Handler),
|
||||
jobs: make(map[string]*Job),
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
@ -55,6 +57,7 @@ func (q *MemoryQueue) EnqueueWithOptions(_ context.Context, job Job) (string, er
|
||||
if job.CreatedAt.IsZero() {
|
||||
job.CreatedAt = time.Now().UTC()
|
||||
}
|
||||
job.Status = StatusPending
|
||||
|
||||
q.mu.RLock()
|
||||
handler, ok := q.handlers[job.Type]
|
||||
@ -64,17 +67,35 @@ func (q *MemoryQueue) EnqueueWithOptions(_ context.Context, job Job) (string, er
|
||||
return "", fmt.Errorf("no handler registered for job type %q", job.Type)
|
||||
}
|
||||
|
||||
// Store job for status tracking.
|
||||
q.mu.Lock()
|
||||
q.jobs[job.ID] = &job
|
||||
q.mu.Unlock()
|
||||
|
||||
q.logger.Info("dispatching in-memory job", "job_id", job.ID, "job_type", job.Type)
|
||||
|
||||
// Process in background goroutine (mirrors worker behavior).
|
||||
go func() {
|
||||
q.mu.Lock()
|
||||
job.Status = StatusRunning
|
||||
now := time.Now().UTC()
|
||||
job.StartedAt = &now
|
||||
q.mu.Unlock()
|
||||
|
||||
if err := handler(context.Background(), &job); err != nil {
|
||||
q.mu.Lock()
|
||||
job.Status = StatusFailed
|
||||
job.Error = err.Error()
|
||||
completed := time.Now().UTC()
|
||||
job.CompletedAt = &completed
|
||||
q.mu.Unlock()
|
||||
q.logger.Error("in-memory job failed", "job_id", job.ID, "job_type", job.Type, "error", err)
|
||||
} else {
|
||||
q.mu.Lock()
|
||||
job.Status = StatusCompleted
|
||||
completed := time.Now().UTC()
|
||||
job.CompletedAt = &completed
|
||||
q.mu.Unlock()
|
||||
q.logger.Info("in-memory job completed", "job_id", job.ID, "job_type", job.Type)
|
||||
}
|
||||
}()
|
||||
@ -82,5 +103,21 @@ func (q *MemoryQueue) EnqueueWithOptions(_ context.Context, job Job) (string, er
|
||||
return job.ID, nil
|
||||
}
|
||||
|
||||
// Compile-time check that MemoryQueue implements Producer.
|
||||
var _ Producer = (*MemoryQueue)(nil)
|
||||
// GetJob returns a job by ID. Returns ErrJobNotFound if the job doesn't exist.
|
||||
func (q *MemoryQueue) GetJob(_ context.Context, jobID string) (*Job, error) {
|
||||
q.mu.RLock()
|
||||
defer q.mu.RUnlock()
|
||||
|
||||
job, ok := q.jobs[jobID]
|
||||
if !ok {
|
||||
return nil, ErrJobNotFound
|
||||
}
|
||||
cp := *job
|
||||
return &cp, nil
|
||||
}
|
||||
|
||||
// Compile-time checks.
|
||||
var (
|
||||
_ Producer = (*MemoryQueue)(nil)
|
||||
_ JobReader = (*MemoryQueue)(nil)
|
||||
)
|
||||
|
||||
@ -32,8 +32,11 @@ type DBQueue struct {
|
||||
logger *logging.Logger
|
||||
}
|
||||
|
||||
// Ensure DBQueue implements Queue at compile time.
|
||||
var _ Queue = (*DBQueue)(nil)
|
||||
// Ensure DBQueue implements Queue and JobReader at compile time.
|
||||
var (
|
||||
_ Queue = (*DBQueue)(nil)
|
||||
_ JobReader = (*DBQueue)(nil)
|
||||
)
|
||||
|
||||
// NewQueue creates a queue backed by a SQL database (PostgreSQL or CockroachDB).
|
||||
func NewQueue(db *sqlx.DB, logger *logging.Logger) *DBQueue {
|
||||
|
||||
@ -100,6 +100,13 @@ type Queue interface {
|
||||
Consumer
|
||||
}
|
||||
|
||||
// JobReader provides read-only access to job status.
|
||||
// Used by handlers to expose job status via API without requiring full queue access.
|
||||
type JobReader interface {
|
||||
// GetJob returns a job by ID. Returns ErrJobNotFound if the job doesn't exist.
|
||||
GetJob(ctx context.Context, jobID string) (*Job, error)
|
||||
}
|
||||
|
||||
// Handler processes a single job.
|
||||
// Return nil for success, error for failure (triggers retry if attempts remain).
|
||||
type Handler func(ctx context.Context, job *Job) error
|
||||
|
||||
@ -20,13 +20,13 @@ type GCSStore struct {
|
||||
|
||||
// NewGCSStore creates a GCS-backed store.
|
||||
// credentialsJSON may be empty to use Application Default Credentials.
|
||||
func NewGCSStore(bucket string, credentialsJSON string, logger *slog.Logger) (*GCSStore, error) {
|
||||
func NewGCSStore(ctx context.Context, bucket string, credentialsJSON string, logger *slog.Logger) (*GCSStore, error) {
|
||||
var opts []option.ClientOption
|
||||
if credentialsJSON != "" {
|
||||
opts = append(opts, option.WithCredentialsJSON([]byte(credentialsJSON)))
|
||||
}
|
||||
|
||||
client, err := gcsstorage.NewClient(context.Background(), opts...)
|
||||
client, err := gcsstorage.NewClient(ctx, opts...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("storage: failed to create GCS client: %w", err)
|
||||
}
|
||||
|
||||
@ -2,6 +2,7 @@ package storage
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
@ -88,7 +89,9 @@ func (s *MemoryStore) List(_ context.Context, prefix string) ([]MediaObject, err
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// ServeHTTP serves stored objects and accepts PUT uploads (for dev presigned URL flow).
|
||||
// ServeHTTP serves stored objects and accepts PUT uploads for the dev presigned URL flow.
|
||||
// This handler is dev-only — in production, clients upload directly to GCS via presigned URLs.
|
||||
// PUT requests are limited to 100MB to prevent accidental OOM in development.
|
||||
// Mount at /storage/ in the application router.
|
||||
func (s *MemoryStore) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
// Strip /storage/ prefix to get the object path.
|
||||
@ -112,8 +115,16 @@ func (s *MemoryStore) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
_, _ = w.Write(obj.data)
|
||||
|
||||
case http.MethodPut:
|
||||
// Limit upload size to 100MB (dev mode only — production uses GCS presigned URLs with their own limits).
|
||||
const maxUploadSize = 100 << 20
|
||||
r.Body = http.MaxBytesReader(w, r.Body, maxUploadSize)
|
||||
data, err := io.ReadAll(r.Body)
|
||||
if err != nil {
|
||||
var maxBytesErr *http.MaxBytesError
|
||||
if errors.As(err, &maxBytesErr) {
|
||||
http.Error(w, "file too large (max 100MB)", http.StatusRequestEntityTooLarge)
|
||||
return
|
||||
}
|
||||
http.Error(w, "read body failed", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
@ -6,6 +6,8 @@ import (
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
"github.com/orchard9/rdev/internal/domain"
|
||||
"github.com/orchard9/rdev/pkg/api"
|
||||
)
|
||||
|
||||
@ -173,10 +175,11 @@ func RequireProjectAccess(projectIDParam string) func(http.Handler) http.Handler
|
||||
return
|
||||
}
|
||||
|
||||
// Get project ID from URL
|
||||
// Using chi's URLParam would require importing chi here
|
||||
// Instead, we'll extract from path in the handler
|
||||
// This middleware just validates the key has project restrictions
|
||||
projectID := domain.ProjectID(chi.URLParam(r, projectIDParam))
|
||||
if !apiKey.HasProjectAccess(projectID) {
|
||||
api.WriteForbidden(w, r, "Access denied to this project")
|
||||
return
|
||||
}
|
||||
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
|
||||
@ -6,6 +6,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/orchard9/rdev/internal/domain"
|
||||
"github.com/orchard9/rdev/internal/port"
|
||||
"github.com/orchard9/rdev/internal/service"
|
||||
)
|
||||
|
||||
@ -113,3 +114,13 @@ func (s *Service) Get(ctx context.Context, id string) (*APIKey, error) {
|
||||
func (s *Service) Revoke(ctx context.Context, id string) error {
|
||||
return s.svc.Revoke(ctx, domain.APIKeyID(id))
|
||||
}
|
||||
|
||||
// Update applies a partial update to an API key.
|
||||
func (s *Service) Update(ctx context.Context, id string, update port.APIKeyUpdate) error {
|
||||
return s.svc.Update(ctx, domain.APIKeyID(id), update)
|
||||
}
|
||||
|
||||
// ListByProjectID returns all active keys that have the given project ID in their project_ids.
|
||||
func (s *Service) ListByProjectID(ctx context.Context, projectID domain.ProjectID) ([]*APIKey, error) {
|
||||
return s.svc.ListByProjectID(ctx, projectID)
|
||||
}
|
||||
|
||||
@ -73,4 +73,8 @@ const (
|
||||
CredKeyNotifyAPIKey = "NOTIFY_API_KEY"
|
||||
CredKeyNotifyHost = "NOTIFY_HOST"
|
||||
CredKeyNotifyFrom = "NOTIFY_FROM"
|
||||
CredKeyNotifyResendDomainID = "NOTIFY_RESEND_DOMAIN_ID"
|
||||
|
||||
// Resend (email provider for per-project domain provisioning)
|
||||
CredKeyResendAPIKey = "RESEND_API_KEY"
|
||||
)
|
||||
|
||||
@ -14,12 +14,15 @@ type NotifyCredentials struct {
|
||||
// APIKey is the notify send key (notify_send_...) for sending emails.
|
||||
APIKey string
|
||||
|
||||
// Host is the shared sending host (e.g., "threesix.ai").
|
||||
// Host is the per-project sending host (e.g., "mail.{slug}.threesix.ai").
|
||||
Host string
|
||||
|
||||
// From is the from-address for outgoing email (e.g., "noreply@threesix.ai").
|
||||
// From is the from-address for outgoing email (e.g., "noreply@mail.{slug}.threesix.ai").
|
||||
From string
|
||||
|
||||
// ResendDomainID is the Resend domain UUID (used for deletion).
|
||||
ResendDomainID string
|
||||
|
||||
// CreatedAt is when the credentials were provisioned.
|
||||
CreatedAt time.Time
|
||||
}
|
||||
|
||||
@ -31,12 +31,12 @@ func NewBuildsHandler(buildService *service.BuildService) *BuildsHandler {
|
||||
// Mount registers the build routes.
|
||||
func (h *BuildsHandler) Mount(r api.Router) {
|
||||
// Project-scoped build endpoints
|
||||
r.With(auth.RequireScope(auth.ScopeBuildWrite, auth.ScopeAdmin)).
|
||||
r.With(auth.RequireScope(auth.ScopeBuildWrite, auth.ScopeAdmin), auth.RequireProjectAccess("id")).
|
||||
Post("/projects/{id}/builds", h.StartBuild)
|
||||
r.With(auth.RequireScope(auth.ScopeBuildRead, auth.ScopeAdmin)).
|
||||
r.With(auth.RequireScope(auth.ScopeBuildRead, auth.ScopeAdmin), auth.RequireProjectAccess("id")).
|
||||
Get("/projects/{id}/builds", h.ListBuilds)
|
||||
|
||||
// Build detail by task ID
|
||||
// Build detail by task ID (no project ID in URL, no project access check needed)
|
||||
r.With(auth.RequireScope(auth.ScopeBuildRead, auth.ScopeAdmin)).
|
||||
Get("/builds/{taskId}", h.GetBuild)
|
||||
}
|
||||
|
||||
@ -28,6 +28,8 @@ func NewCheckoutHandler(checkoutService *service.CheckoutService) *CheckoutHandl
|
||||
// Mount registers the checkout routes.
|
||||
func (h *CheckoutHandler) Mount(r api.Router) {
|
||||
r.Route("/projects/{id}/checkout", func(r chi.Router) {
|
||||
r.Use(auth.RequireProjectAccess("id"))
|
||||
|
||||
// Branch listing (read access)
|
||||
r.With(auth.RequireScope(auth.ScopeProjectsRead, auth.ScopeAdmin)).
|
||||
Get("/branches", h.ListBranches)
|
||||
|
||||
@ -9,6 +9,7 @@ import (
|
||||
"github.com/orchard9/rdev/internal/auth"
|
||||
"github.com/orchard9/rdev/internal/domain"
|
||||
"github.com/orchard9/rdev/internal/logging"
|
||||
"github.com/orchard9/rdev/internal/port"
|
||||
"github.com/orchard9/rdev/internal/service"
|
||||
"github.com/orchard9/rdev/internal/validate"
|
||||
"github.com/orchard9/rdev/pkg/api"
|
||||
@ -18,6 +19,7 @@ import (
|
||||
type CreateAndBuildHandler struct {
|
||||
infraService *service.ProjectInfraService
|
||||
buildService *service.BuildService
|
||||
authService *auth.Service
|
||||
}
|
||||
|
||||
// NewCreateAndBuildHandler creates a new create-and-build handler.
|
||||
@ -31,6 +33,12 @@ func NewCreateAndBuildHandler(
|
||||
}
|
||||
}
|
||||
|
||||
// WithAuthService sets an auth service for auto-granting project access to the creating key.
|
||||
func (h *CreateAndBuildHandler) WithAuthService(authService *auth.Service) *CreateAndBuildHandler {
|
||||
h.authService = authService
|
||||
return h
|
||||
}
|
||||
|
||||
// Mount registers the create-and-build route.
|
||||
func (h *CreateAndBuildHandler) Mount(r api.Router) {
|
||||
// Requires both project execute (create) and build write (start build)
|
||||
@ -52,6 +60,9 @@ type CreateAndBuildRequest struct {
|
||||
AutoCommit bool `json:"auto_commit"`
|
||||
AutoPush bool `json:"auto_push"`
|
||||
CallbackURL string `json:"callback_url,omitempty"`
|
||||
|
||||
// Access control: additional key IDs to grant access to the new project
|
||||
GrantToKeyIDs []string `json:"grant_to_key_ids,omitempty"`
|
||||
}
|
||||
|
||||
// CreateAndBuildResponse is the response for POST /project/create-and-build.
|
||||
@ -127,6 +138,42 @@ func (h *CreateAndBuildHandler) CreateAndBuild(w http.ResponseWriter, r *http.Re
|
||||
return
|
||||
}
|
||||
|
||||
// Auto-grant: if creating key is restricted (non-admin with explicit project_ids), add the new project
|
||||
if h.authService != nil {
|
||||
log := logging.FromContext(ctx).WithHandler("CreateAndBuild")
|
||||
if apiKey := auth.GetAPIKey(ctx); apiKey != nil &&
|
||||
!apiKey.HasScope(domain.ScopeAdmin) && apiKey.ProjectIDs != nil {
|
||||
newIDs := append(apiKey.ProjectIDs, domain.ProjectID(projectResult.ProjectID))
|
||||
if err := h.authService.Update(ctx, string(apiKey.ID), port.APIKeyUpdate{ProjectIDs: &newIDs}); err != nil {
|
||||
log.Warn("failed to auto-grant creating key access to new project",
|
||||
logging.FieldError, err.Error(),
|
||||
logging.FieldProjectID, projectResult.ProjectID,
|
||||
)
|
||||
// non-fatal: project still usable, admin can grant access manually
|
||||
}
|
||||
}
|
||||
|
||||
// Grant to additional key IDs specified in request
|
||||
for _, keyID := range req.GrantToKeyIDs {
|
||||
key, err := h.authService.Get(ctx, keyID)
|
||||
if err != nil || key == nil || !key.IsActive() {
|
||||
log.Warn("failed to grant access: key not found or inactive", "key_id", keyID)
|
||||
continue
|
||||
}
|
||||
// Unrestricted or admin keys already have access
|
||||
if key.ProjectIDs == nil || key.HasScope(domain.ScopeAdmin) {
|
||||
continue
|
||||
}
|
||||
newIDs := append(key.ProjectIDs, domain.ProjectID(projectResult.ProjectID))
|
||||
if err := h.authService.Update(ctx, keyID, port.APIKeyUpdate{ProjectIDs: &newIDs}); err != nil {
|
||||
log.Warn("failed to grant access to key",
|
||||
"key_id", keyID,
|
||||
logging.FieldError, err.Error(),
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Step 2: Enqueue the build task
|
||||
spec := domain.BuildSpec{
|
||||
Prompt: req.Prompt,
|
||||
|
||||
@ -4,9 +4,12 @@ import (
|
||||
"errors"
|
||||
"net"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
"github.com/orchard9/rdev/internal/auth"
|
||||
"github.com/orchard9/rdev/internal/domain"
|
||||
"github.com/orchard9/rdev/internal/port"
|
||||
"github.com/orchard9/rdev/internal/validate"
|
||||
"github.com/orchard9/rdev/pkg/api"
|
||||
)
|
||||
@ -28,6 +31,7 @@ func (h *KeysHandler) Mount(r api.Router) {
|
||||
r.With(auth.RequireScope(auth.ScopeKeysRead, auth.ScopeAdmin)).Get("/", h.List)
|
||||
r.With(auth.RequireScope(auth.ScopeKeysWrite, auth.ScopeAdmin)).Post("/", h.Create)
|
||||
r.With(auth.RequireScope(auth.ScopeKeysRead, auth.ScopeAdmin)).Get("/{id}", h.Get)
|
||||
r.With(auth.RequireScope(auth.ScopeKeysWrite, auth.ScopeAdmin)).Patch("/{id}", h.Update)
|
||||
r.With(auth.RequireScope(auth.ScopeKeysWrite, auth.ScopeAdmin)).Delete("/{id}", h.Revoke)
|
||||
})
|
||||
}
|
||||
@ -239,3 +243,97 @@ func (h *KeysHandler) Revoke(w http.ResponseWriter, r *http.Request) {
|
||||
"id": id,
|
||||
})
|
||||
}
|
||||
|
||||
// UpdateKeyRequest is the JSON body for PATCH /keys/{id}.
|
||||
// A null JSON value for project_ids or allowed_ips sets them to unrestricted.
|
||||
type UpdateKeyRequest struct {
|
||||
Name *string `json:"name,omitempty"`
|
||||
Scopes []string `json:"scopes,omitempty"`
|
||||
ProjectIDs *[]string `json:"project_ids"` // null = unrestricted; array = restrict to these projects
|
||||
AllowedIPs *[]string `json:"allowed_ips"` // null = no restriction; array = restrict to these IPs
|
||||
ExpiresIn *string `json:"expires_in,omitempty"` // "30d", "60d", "90d", "1y", "never"
|
||||
}
|
||||
|
||||
// Update modifies a mutable API key fields.
|
||||
// PATCH /keys/{id}
|
||||
func (h *KeysHandler) Update(w http.ResponseWriter, r *http.Request) {
|
||||
id := chi.URLParam(r, "id")
|
||||
|
||||
var req UpdateKeyRequest
|
||||
if err := api.DecodeJSON(r, &req); err != nil {
|
||||
api.WriteBadRequest(w, r, "invalid request body")
|
||||
return
|
||||
}
|
||||
|
||||
update := port.APIKeyUpdate{}
|
||||
|
||||
if req.Name != nil {
|
||||
if *req.Name == "" {
|
||||
api.WriteBadRequest(w, r, "name cannot be empty")
|
||||
return
|
||||
}
|
||||
update.Name = req.Name
|
||||
}
|
||||
|
||||
if req.Scopes != nil {
|
||||
scopes := auth.ScopesFromStrings(req.Scopes)
|
||||
if !auth.ValidateScopes(scopes) {
|
||||
api.WriteBadRequest(w, r, "invalid scope(s)")
|
||||
return
|
||||
}
|
||||
update.Scopes = scopes
|
||||
}
|
||||
|
||||
if req.ProjectIDs != nil {
|
||||
pids := make([]domain.ProjectID, len(*req.ProjectIDs))
|
||||
for i, s := range *req.ProjectIDs {
|
||||
pids[i] = domain.ProjectID(s)
|
||||
}
|
||||
update.ProjectIDs = &pids
|
||||
}
|
||||
|
||||
if req.AllowedIPs != nil {
|
||||
for _, cidr := range *req.AllowedIPs {
|
||||
if err := validateCIDROrIP(cidr); err != nil {
|
||||
api.WriteBadRequest(w, r, "invalid allowed_ips: "+cidr+" is not a valid CIDR or IP address")
|
||||
return
|
||||
}
|
||||
}
|
||||
update.AllowedIPs = req.AllowedIPs
|
||||
}
|
||||
|
||||
if req.ExpiresIn != nil {
|
||||
expiresIn, err := auth.ParseExpiration(*req.ExpiresIn)
|
||||
if err != nil {
|
||||
api.WriteBadRequest(w, r, err.Error())
|
||||
return
|
||||
}
|
||||
if expiresIn == 0 {
|
||||
// "never" — remove expiry
|
||||
var nilTime *time.Time
|
||||
update.ExpiresAt = &nilTime
|
||||
} else {
|
||||
t := time.Now().Add(expiresIn)
|
||||
tPtr := &t
|
||||
update.ExpiresAt = &tPtr
|
||||
}
|
||||
}
|
||||
|
||||
if err := h.authService.Update(r.Context(), id, update); err != nil {
|
||||
if errors.Is(err, auth.ErrKeyNotFound) {
|
||||
api.WriteNotFound(w, r, "Key not found")
|
||||
return
|
||||
}
|
||||
api.WriteInternalError(w, r, "Failed to update key")
|
||||
return
|
||||
}
|
||||
|
||||
// Return updated key
|
||||
key, err := h.authService.Get(r.Context(), id)
|
||||
if err != nil {
|
||||
api.WriteInternalError(w, r, "Failed to fetch updated key")
|
||||
return
|
||||
}
|
||||
|
||||
api.WriteSuccess(w, r, apiKeyToResponse(key))
|
||||
}
|
||||
|
||||
99
internal/handlers/me.go
Normal file
99
internal/handlers/me.go
Normal file
@ -0,0 +1,99 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/orchard9/rdev/internal/auth"
|
||||
"github.com/orchard9/rdev/internal/domain"
|
||||
"github.com/orchard9/rdev/internal/service"
|
||||
"github.com/orchard9/rdev/pkg/api"
|
||||
)
|
||||
|
||||
// MeHandler handles the /me endpoint.
|
||||
type MeHandler struct {
|
||||
authService *auth.Service
|
||||
projectService *service.ProjectService
|
||||
}
|
||||
|
||||
// NewMeHandler creates a new me handler.
|
||||
func NewMeHandler(authService *auth.Service, projectService *service.ProjectService) *MeHandler {
|
||||
return &MeHandler{
|
||||
authService: authService,
|
||||
projectService: projectService,
|
||||
}
|
||||
}
|
||||
|
||||
// Mount registers the /me route.
|
||||
func (h *MeHandler) Mount(r api.Router) {
|
||||
r.Get("/me", h.Get)
|
||||
}
|
||||
|
||||
// MeResponse is the JSON response for GET /me.
|
||||
type MeResponse struct {
|
||||
ID string `json:"id"`
|
||||
Name string `json:"name"`
|
||||
KeyPrefix string `json:"key_prefix"`
|
||||
Scopes []string `json:"scopes"`
|
||||
ProjectAccess string `json:"project_access"` // "unrestricted" | "restricted"
|
||||
Projects []ProjectSummary `json:"projects,omitempty"`
|
||||
AllowedIPs []string `json:"allowed_ips,omitempty"`
|
||||
CreatedAt string `json:"created_at"`
|
||||
ExpiresAt *string `json:"expires_at,omitempty"`
|
||||
Active bool `json:"active"`
|
||||
}
|
||||
|
||||
// ProjectSummary is a lightweight project view for embedding in /me.
|
||||
type ProjectSummary struct {
|
||||
ID string `json:"id"`
|
||||
Name string `json:"name"`
|
||||
Status string `json:"status"`
|
||||
}
|
||||
|
||||
// Get returns the current key's identity, scopes, and project access.
|
||||
// GET /me
|
||||
func (h *MeHandler) Get(w http.ResponseWriter, r *http.Request) {
|
||||
ctx, cancel := context.WithTimeout(r.Context(), TimeoutFastLookup)
|
||||
defer cancel()
|
||||
|
||||
apiKey := auth.GetAPIKey(ctx)
|
||||
if apiKey == nil {
|
||||
api.WriteUnauthorized(w, r, "Not authenticated")
|
||||
return
|
||||
}
|
||||
|
||||
resp := MeResponse{
|
||||
ID: string(apiKey.ID),
|
||||
Name: apiKey.Name,
|
||||
KeyPrefix: apiKey.KeyPrefix,
|
||||
Scopes: auth.ScopesToStrings(apiKey.Scopes),
|
||||
ProjectAccess: "unrestricted",
|
||||
AllowedIPs: apiKey.AllowedIPs,
|
||||
CreatedAt: apiKey.CreatedAt.Format(time.RFC3339),
|
||||
Active: apiKey.IsActive(),
|
||||
}
|
||||
|
||||
// Populate projects list when key is restricted (non-admin with explicit project_ids)
|
||||
if apiKey.ProjectIDs != nil && !apiKey.HasScope(domain.ScopeAdmin) {
|
||||
resp.ProjectAccess = "restricted"
|
||||
if h.projectService != nil {
|
||||
projects, _ := h.projectService.List(ctx, apiKey.ProjectIDs)
|
||||
resp.Projects = make([]ProjectSummary, len(projects))
|
||||
for i, p := range projects {
|
||||
resp.Projects[i] = ProjectSummary{
|
||||
ID: string(p.ID),
|
||||
Name: p.Name,
|
||||
Status: string(p.Status),
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if apiKey.ExpiresAt != nil {
|
||||
s := apiKey.ExpiresAt.Format(time.RFC3339)
|
||||
resp.ExpiresAt = &s
|
||||
}
|
||||
|
||||
api.WriteSuccess(w, r, resp)
|
||||
}
|
||||
205
internal/handlers/project_access.go
Normal file
205
internal/handlers/project_access.go
Normal file
@ -0,0 +1,205 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net/http"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
"github.com/orchard9/rdev/internal/auth"
|
||||
"github.com/orchard9/rdev/internal/domain"
|
||||
"github.com/orchard9/rdev/internal/port"
|
||||
"github.com/orchard9/rdev/internal/validate"
|
||||
"github.com/orchard9/rdev/pkg/api"
|
||||
)
|
||||
|
||||
// ProjectAccessHandler handles project-centric key access management endpoints.
|
||||
type ProjectAccessHandler struct {
|
||||
authService *auth.Service
|
||||
}
|
||||
|
||||
// NewProjectAccessHandler creates a new project access handler.
|
||||
func NewProjectAccessHandler(authService *auth.Service) *ProjectAccessHandler {
|
||||
return &ProjectAccessHandler{authService: authService}
|
||||
}
|
||||
|
||||
// Mount registers the project access routes.
|
||||
func (h *ProjectAccessHandler) Mount(r api.Router) {
|
||||
r.Route("/projects/{id}/access", func(r chi.Router) {
|
||||
r.With(auth.RequireScope(auth.ScopeAdmin)).Get("/", h.List)
|
||||
r.With(auth.RequireScope(auth.ScopeAdmin)).Post("/", h.Grant)
|
||||
r.With(auth.RequireScope(auth.ScopeAdmin)).Delete("/{keyId}", h.Revoke)
|
||||
})
|
||||
}
|
||||
|
||||
// ProjectAccessResponse is the response for GET /projects/{id}/access.
|
||||
type ProjectAccessResponse struct {
|
||||
ProjectID string `json:"project_id"`
|
||||
Keys []KeyResponse `json:"keys"`
|
||||
UnrestrictedCount int `json:"unrestricted_keys"`
|
||||
}
|
||||
|
||||
// GrantAccessRequest is the JSON body for POST /projects/{id}/access.
|
||||
type GrantAccessRequest struct {
|
||||
KeyID string `json:"key_id"`
|
||||
}
|
||||
|
||||
// List returns all keys with access to a project.
|
||||
// GET /projects/{id}/access
|
||||
func (h *ProjectAccessHandler) List(w http.ResponseWriter, r *http.Request) {
|
||||
projectID := chi.URLParam(r, "id")
|
||||
|
||||
ctx, cancel := context.WithTimeout(r.Context(), TimeoutFastLookup)
|
||||
defer cancel()
|
||||
|
||||
// Keys explicitly granted this project
|
||||
keys, err := h.authService.ListByProjectID(ctx, domain.ProjectID(projectID))
|
||||
if err != nil {
|
||||
api.WriteInternalError(w, r, "failed to list project access")
|
||||
return
|
||||
}
|
||||
|
||||
// Count unrestricted keys (nil project_ids) from all keys
|
||||
allKeys, err := h.authService.List(ctx)
|
||||
if err != nil {
|
||||
api.WriteInternalError(w, r, "failed to list keys")
|
||||
return
|
||||
}
|
||||
unrestrictedCount := 0
|
||||
for _, k := range allKeys {
|
||||
if k.ProjectIDs == nil && k.IsActive() {
|
||||
unrestrictedCount++
|
||||
}
|
||||
}
|
||||
|
||||
keyResponses := make([]KeyResponse, len(keys))
|
||||
for i, k := range keys {
|
||||
keyResponses[i] = apiKeyToResponse(k)
|
||||
}
|
||||
|
||||
api.WriteSuccess(w, r, ProjectAccessResponse{
|
||||
ProjectID: projectID,
|
||||
Keys: keyResponses,
|
||||
UnrestrictedCount: unrestrictedCount,
|
||||
})
|
||||
}
|
||||
|
||||
// Grant adds a project to a key's project_ids list.
|
||||
// POST /projects/{id}/access
|
||||
func (h *ProjectAccessHandler) Grant(w http.ResponseWriter, r *http.Request) {
|
||||
projectID := chi.URLParam(r, "id")
|
||||
|
||||
var req GrantAccessRequest
|
||||
if err := api.DecodeJSON(r, &req); err != nil {
|
||||
api.WriteBadRequest(w, r, "invalid request body")
|
||||
return
|
||||
}
|
||||
|
||||
v := validate.New()
|
||||
v.Required(req.KeyID, "key_id")
|
||||
if err := v.Error(); err != nil {
|
||||
api.WriteBadRequest(w, r, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(r.Context(), TimeoutStandard)
|
||||
defer cancel()
|
||||
|
||||
key, err := h.authService.Get(ctx, req.KeyID)
|
||||
if err != nil {
|
||||
if errors.Is(err, auth.ErrKeyNotFound) {
|
||||
api.WriteNotFound(w, r, "key not found")
|
||||
return
|
||||
}
|
||||
api.WriteInternalError(w, r, "failed to get key")
|
||||
return
|
||||
}
|
||||
|
||||
if !key.IsActive() {
|
||||
api.WriteBadRequest(w, r, "key is not active")
|
||||
return
|
||||
}
|
||||
|
||||
// Unrestricted keys already have access
|
||||
if key.ProjectIDs == nil {
|
||||
api.WriteBadRequest(w, r, "key already has unrestricted access to all projects")
|
||||
return
|
||||
}
|
||||
|
||||
// Admin-scoped keys already have full access
|
||||
if key.HasScope(domain.ScopeAdmin) {
|
||||
api.WriteBadRequest(w, r, "key with admin scope already has full access")
|
||||
return
|
||||
}
|
||||
|
||||
// Check if already granted
|
||||
pid := domain.ProjectID(projectID)
|
||||
for _, existing := range key.ProjectIDs {
|
||||
if existing == pid {
|
||||
api.WriteSuccess(w, r, map[string]string{
|
||||
"status": "already_granted",
|
||||
"project_id": projectID,
|
||||
"key_id": req.KeyID,
|
||||
})
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// Append and update
|
||||
newIDs := append(key.ProjectIDs, pid)
|
||||
if err := h.authService.Update(ctx, req.KeyID, port.APIKeyUpdate{ProjectIDs: &newIDs}); err != nil {
|
||||
api.WriteInternalError(w, r, "failed to grant access")
|
||||
return
|
||||
}
|
||||
|
||||
api.WriteSuccess(w, r, map[string]string{
|
||||
"status": "granted",
|
||||
"project_id": projectID,
|
||||
"key_id": req.KeyID,
|
||||
})
|
||||
}
|
||||
|
||||
// Revoke removes a project from a key's project_ids list.
|
||||
// DELETE /projects/{id}/access/{keyId}
|
||||
func (h *ProjectAccessHandler) Revoke(w http.ResponseWriter, r *http.Request) {
|
||||
projectID := chi.URLParam(r, "id")
|
||||
keyID := chi.URLParam(r, "keyId")
|
||||
|
||||
ctx, cancel := context.WithTimeout(r.Context(), TimeoutStandard)
|
||||
defer cancel()
|
||||
|
||||
key, err := h.authService.Get(ctx, keyID)
|
||||
if err != nil {
|
||||
if errors.Is(err, auth.ErrKeyNotFound) {
|
||||
api.WriteNotFound(w, r, "key not found")
|
||||
return
|
||||
}
|
||||
api.WriteInternalError(w, r, "failed to get key")
|
||||
return
|
||||
}
|
||||
|
||||
if key.ProjectIDs == nil {
|
||||
api.WriteBadRequest(w, r, "key has unrestricted access; use PATCH /keys/{id} to restrict it first")
|
||||
return
|
||||
}
|
||||
|
||||
// Filter out the project ID
|
||||
pid := domain.ProjectID(projectID)
|
||||
newIDs := make([]domain.ProjectID, 0, len(key.ProjectIDs))
|
||||
for _, existing := range key.ProjectIDs {
|
||||
if existing != pid {
|
||||
newIDs = append(newIDs, existing)
|
||||
}
|
||||
}
|
||||
|
||||
if err := h.authService.Update(ctx, keyID, port.APIKeyUpdate{ProjectIDs: &newIDs}); err != nil {
|
||||
api.WriteInternalError(w, r, "failed to revoke access")
|
||||
return
|
||||
}
|
||||
|
||||
api.WriteSuccess(w, r, map[string]string{
|
||||
"status": "revoked",
|
||||
"project_id": projectID,
|
||||
"key_id": keyID,
|
||||
})
|
||||
}
|
||||
@ -50,15 +50,15 @@ func NewProjectsHandlerWithService(projectService *service.ProjectService) *Proj
|
||||
// Mount registers the projects routes.
|
||||
func (h *ProjectsHandler) Mount(r api.Router) {
|
||||
r.Route("/projects", func(r chi.Router) {
|
||||
// Read operations
|
||||
// Read operations — List does not use RequireProjectAccess (filtered in service layer)
|
||||
r.With(auth.RequireScope(auth.ScopeProjectsRead, auth.ScopeAdmin)).Get("/", h.List)
|
||||
r.With(auth.RequireScope(auth.ScopeProjectsRead, auth.ScopeAdmin)).Get("/{id}", h.Get)
|
||||
r.With(auth.RequireScope(auth.ScopeProjectsRead, auth.ScopeAdmin), auth.RequireProjectAccess("id")).Get("/{id}", h.Get)
|
||||
|
||||
// Execute operations
|
||||
r.With(auth.RequireScope(auth.ScopeProjectsExecute, auth.ScopeAdmin)).Post("/{id}/claude", h.RunClaude)
|
||||
r.With(auth.RequireScope(auth.ScopeProjectsExecute, auth.ScopeAdmin)).Post("/{id}/shell", h.RunShell)
|
||||
r.With(auth.RequireScope(auth.ScopeProjectsExecute, auth.ScopeAdmin)).Post("/{id}/git", h.RunGit)
|
||||
r.With(auth.RequireScope(auth.ScopeProjectsExecute, auth.ScopeAdmin)).Get("/{id}/events", h.Events)
|
||||
r.With(auth.RequireScope(auth.ScopeProjectsExecute, auth.ScopeAdmin), auth.RequireProjectAccess("id")).Post("/{id}/claude", h.RunClaude)
|
||||
r.With(auth.RequireScope(auth.ScopeProjectsExecute, auth.ScopeAdmin), auth.RequireProjectAccess("id")).Post("/{id}/shell", h.RunShell)
|
||||
r.With(auth.RequireScope(auth.ScopeProjectsExecute, auth.ScopeAdmin), auth.RequireProjectAccess("id")).Post("/{id}/git", h.RunGit)
|
||||
r.With(auth.RequireScope(auth.ScopeProjectsExecute, auth.ScopeAdmin), auth.RequireProjectAccess("id")).Get("/{id}/events", h.Events)
|
||||
})
|
||||
}
|
||||
|
||||
@ -108,7 +108,7 @@ func getClientIP(r *http.Request) string {
|
||||
return addr
|
||||
}
|
||||
|
||||
// List returns all available projects.
|
||||
// List returns available projects, filtered to the key's allowed projects.
|
||||
// GET /projects
|
||||
func (h *ProjectsHandler) List(w http.ResponseWriter, r *http.Request) {
|
||||
ctx, cancel := context.WithTimeout(r.Context(), TimeoutFastLookup)
|
||||
@ -116,7 +116,12 @@ func (h *ProjectsHandler) List(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
// Use new service if available
|
||||
if h.projectService != nil {
|
||||
projects, err := h.projectService.List(ctx)
|
||||
// Determine allowed project IDs from the API key (nil = unrestricted)
|
||||
var allowedIDs []domain.ProjectID
|
||||
if apiKey := auth.GetAPIKey(r.Context()); apiKey != nil && !apiKey.HasScope(domain.ScopeAdmin) {
|
||||
allowedIDs = apiKey.ProjectIDs // nil means unrestricted for this key
|
||||
}
|
||||
projects, err := h.projectService.List(ctx, allowedIDs)
|
||||
if err != nil {
|
||||
api.WriteInternalError(w, r, "failed to list projects")
|
||||
return
|
||||
|
||||
@ -28,6 +28,8 @@ func NewSDLCHandler(sdlcService *service.SDLCService) *SDLCHandler {
|
||||
// Mount registers all SDLC routes under /projects/{id}/sdlc/.
|
||||
func (h *SDLCHandler) Mount(r api.Router) {
|
||||
r.Route("/projects/{id}/sdlc", func(r chi.Router) {
|
||||
r.Use(auth.RequireProjectAccess("id"))
|
||||
|
||||
// State (read)
|
||||
r.With(auth.RequireScope(auth.ScopeProjectsRead, auth.ScopeAdmin)).Get("/state", h.GetState)
|
||||
r.With(auth.RequireScope(auth.ScopeProjectsRead, auth.ScopeAdmin)).Get("/next", h.GetNext)
|
||||
|
||||
@ -39,6 +39,8 @@ func NewSessionsHandler(
|
||||
// Mount registers the session routes.
|
||||
func (h *SessionsHandler) Mount(r api.Router) {
|
||||
r.Route("/projects/{id}/sessions", func(r chi.Router) {
|
||||
r.Use(auth.RequireProjectAccess("id"))
|
||||
|
||||
// List sessions (read access)
|
||||
r.With(auth.RequireScope(auth.ScopeSessionsRead, auth.ScopeProjectsRead, auth.ScopeAdmin)).
|
||||
Get("/", h.List)
|
||||
|
||||
@ -40,7 +40,7 @@ func (h *VerifyHandler) Mount(r api.Router) {
|
||||
r.With(auth.RequireScope(auth.ScopeVerifyRead, auth.ScopeAdmin)).Get("/{taskId}/stream", h.Stream)
|
||||
r.With(auth.RequireScope(auth.ScopeVerifyWrite, auth.ScopeAdmin)).Delete("/{taskId}", h.Cancel)
|
||||
})
|
||||
r.With(auth.RequireScope(auth.ScopeVerifyRead, auth.ScopeAdmin)).Get("/projects/{id}/verify", h.ListByProject)
|
||||
r.With(auth.RequireScope(auth.ScopeVerifyRead, auth.ScopeAdmin), auth.RequireProjectAccess("id")).Get("/projects/{id}/verify", h.ListByProject)
|
||||
}
|
||||
|
||||
// SubmitVerifyRequest is the request body for POST /verify.
|
||||
|
||||
@ -2,10 +2,21 @@ package port
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"github.com/orchard9/rdev/internal/domain"
|
||||
)
|
||||
|
||||
// APIKeyUpdate contains mutable fields for updating an API key.
|
||||
// A nil pointer means "don't change" that field.
|
||||
type APIKeyUpdate struct {
|
||||
Name *string
|
||||
Scopes []domain.Scope // nil = don't change; non-nil = replace
|
||||
ProjectIDs *[]domain.ProjectID // nil ptr = don't change; ptr to nil slice = unrestricted
|
||||
AllowedIPs *[]string // nil ptr = don't change; ptr to nil slice = no restriction
|
||||
ExpiresAt **time.Time // nil ptr = don't change; ptr to nil ptr = remove expiry
|
||||
}
|
||||
|
||||
// APIKeyRepository defines operations for managing API keys.
|
||||
type APIKeyRepository interface {
|
||||
// Create stores a new API key.
|
||||
@ -25,4 +36,10 @@ type APIKeyRepository interface {
|
||||
|
||||
// UpdateLastUsed updates the last used timestamp for a key.
|
||||
UpdateLastUsed(ctx context.Context, id domain.APIKeyID) error
|
||||
|
||||
// Update applies a partial update to an API key.
|
||||
Update(ctx context.Context, id domain.APIKeyID, update APIKeyUpdate) error
|
||||
|
||||
// ListByProjectID returns all active keys that have the given project ID in their project_ids.
|
||||
ListByProjectID(ctx context.Context, projectID domain.ProjectID) ([]*domain.APIKey, error)
|
||||
}
|
||||
|
||||
@ -6,14 +6,17 @@ import (
|
||||
"github.com/orchard9/rdev/internal/domain"
|
||||
)
|
||||
|
||||
// NotifyProvisioner manages per-project email delivery accounts on the notify service.
|
||||
// NotifyProvisioner manages per-project email delivery on the notify service.
|
||||
// Each project gets its own isolated sending host (mail.{slug}.threesix.ai),
|
||||
// Resend domain with DKIM/SPF, and a dedicated notify account with send key.
|
||||
type NotifyProvisioner interface {
|
||||
// CreateProjectNotify creates a notify account and send key for a project.
|
||||
// Grants the account access to the shared host and returns credentials.
|
||||
CreateProjectNotify(ctx context.Context, projectID string) (*domain.NotifyCredentials, error)
|
||||
// CreateProjectNotify provisions a notify host, Resend domain, DNS records,
|
||||
// and account with send key for the project.
|
||||
CreateProjectNotify(ctx context.Context, projectID, slug string) (*domain.NotifyCredentials, error)
|
||||
|
||||
// DeleteProjectNotify removes the notify account for a project.
|
||||
DeleteProjectNotify(ctx context.Context, projectID string) error
|
||||
// DeleteProjectNotify removes all notify resources for a project:
|
||||
// the notify account, the per-project host, the Resend domain, and DNS records.
|
||||
DeleteProjectNotify(ctx context.Context, projectID, slug, resendDomainID string) error
|
||||
|
||||
// GetProjectNotify returns notify credentials for a project, or nil if not provisioned.
|
||||
GetProjectNotify(ctx context.Context, projectID string) (*domain.NotifyCredentials, error)
|
||||
|
||||
@ -148,6 +148,26 @@ func (m *mockAPIKeyRepository) UpdateLastUsed(ctx context.Context, id domain.API
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockAPIKeyRepository) Update(ctx context.Context, id domain.APIKeyID, update port.APIKeyUpdate) error {
|
||||
if _, ok := m.keys[id]; !ok {
|
||||
return domain.ErrKeyNotFound
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockAPIKeyRepository) ListByProjectID(ctx context.Context, projectID domain.ProjectID) ([]*domain.APIKey, error) {
|
||||
var result []*domain.APIKey
|
||||
for _, k := range m.keys {
|
||||
for _, pid := range k.ProjectIDs {
|
||||
if pid == projectID {
|
||||
result = append(result, k)
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
type mockStreamPublisher struct {
|
||||
subscribers map[string][]chan port.StreamEvent
|
||||
}
|
||||
|
||||
@ -119,6 +119,16 @@ func (s *APIKeyService) UpdateLastUsed(ctx context.Context, id domain.APIKeyID)
|
||||
return s.repo.UpdateLastUsed(ctx, id)
|
||||
}
|
||||
|
||||
// Update applies a partial update to an API key.
|
||||
func (s *APIKeyService) Update(ctx context.Context, id domain.APIKeyID, update port.APIKeyUpdate) error {
|
||||
return s.repo.Update(ctx, id, update)
|
||||
}
|
||||
|
||||
// ListByProjectID returns all active keys that have the given project ID in their project_ids.
|
||||
func (s *APIKeyService) ListByProjectID(ctx context.Context, projectID domain.ProjectID) ([]*domain.APIKey, error) {
|
||||
return s.repo.ListByProjectID(ctx, projectID)
|
||||
}
|
||||
|
||||
// Validate checks a raw API key and returns the associated APIKey if valid.
|
||||
// It checks for admin key, looks up by hash, and verifies the key is active.
|
||||
// On success it asynchronously updates the last-used timestamp.
|
||||
|
||||
@ -6,6 +6,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/orchard9/rdev/internal/domain"
|
||||
"github.com/orchard9/rdev/internal/port"
|
||||
)
|
||||
|
||||
// MockAPIKeyRepository implements port.APIKeyRepository for testing.
|
||||
@ -83,6 +84,23 @@ func (m *MockAPIKeyRepository) UpdateLastUsed(ctx context.Context, id domain.API
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *MockAPIKeyRepository) Update(ctx context.Context, id domain.APIKeyID, update port.APIKeyUpdate) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *MockAPIKeyRepository) ListByProjectID(ctx context.Context, projectID domain.ProjectID) ([]*domain.APIKey, error) {
|
||||
var result []*domain.APIKey
|
||||
for _, k := range m.keys {
|
||||
for _, pid := range k.ProjectIDs {
|
||||
if pid == projectID {
|
||||
result = append(result, k)
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func TestAPIKeyService_Create(t *testing.T) {
|
||||
repo := NewMockAPIKeyRepository()
|
||||
svc := NewAPIKeyService(repo, "admin-secret")
|
||||
|
||||
@ -182,12 +182,16 @@ func (s *ComponentService) fetchProjectCredentials(ctx context.Context, projectI
|
||||
"REDIS_PREFIX",
|
||||
domain.CredKeyGCSBucket,
|
||||
domain.CredKeyGCSServiceAccountJSON,
|
||||
domain.CredKeyNotifyAPIKey,
|
||||
domain.CredKeyNotifyHost,
|
||||
domain.CredKeyNotifyFrom,
|
||||
}
|
||||
|
||||
// Global credentials (stored without project prefix, shared across all projects)
|
||||
globalKeys := []string{
|
||||
domain.CredKeyLaozhangAPIKey,
|
||||
domain.CredKeyGeminiAPIKey,
|
||||
domain.CredKeyNotifyURL,
|
||||
}
|
||||
|
||||
secrets := make(map[string]string)
|
||||
|
||||
@ -473,7 +473,7 @@ func (s *ProjectInfraService) provisionResources(ctx context.Context, result *Cr
|
||||
if existing != nil {
|
||||
log.Info("notify already provisioned, skipping", logging.FieldProjectID, projectID)
|
||||
} else {
|
||||
notifyCreds, err := s.notifyProvisioner.CreateProjectNotify(ctx, projectID)
|
||||
notifyCreds, err := s.notifyProvisioner.CreateProjectNotify(ctx, projectID, result.Slug)
|
||||
if err != nil {
|
||||
log.Error("failed to provision notify", logging.FieldProjectID, projectID, logging.FieldError, err)
|
||||
result.NextSteps = append(result.NextSteps, "Notify provisioning failed - contact admin")
|
||||
@ -491,10 +491,14 @@ func (s *ProjectInfraService) provisionResources(ctx context.Context, result *Cr
|
||||
storeErr = err
|
||||
log.Error("failed to store NOTIFY_FROM", logging.FieldProjectID, projectID, logging.FieldError, err)
|
||||
}
|
||||
if err := s.storeCredential(ctx, projectID, domain.CredentialCategoryNotify, domain.CredKeyNotifyResendDomainID, notifyCreds.ResendDomainID); err != nil {
|
||||
storeErr = err
|
||||
log.Error("failed to store NOTIFY_RESEND_DOMAIN_ID", logging.FieldProjectID, projectID, logging.FieldError, err)
|
||||
}
|
||||
|
||||
if storeErr != nil {
|
||||
log.Warn("rolling back notify due to credential storage failure", logging.FieldProjectID, projectID)
|
||||
if rollbackErr := s.notifyProvisioner.DeleteProjectNotify(ctx, projectID); rollbackErr != nil {
|
||||
if rollbackErr := s.notifyProvisioner.DeleteProjectNotify(ctx, projectID, result.Slug, notifyCreds.ResendDomainID); rollbackErr != nil {
|
||||
log.Error("failed to rollback notify account", logging.FieldProjectID, projectID, logging.FieldError, rollbackErr)
|
||||
result.NextSteps = append(result.NextSteps, "Notify created but credentials not stored - manual cleanup required")
|
||||
} else {
|
||||
@ -884,9 +888,15 @@ func (s *ProjectInfraService) DeleteProject(ctx context.Context, projectID strin
|
||||
}
|
||||
}
|
||||
|
||||
// 5. Delete provisioned notify account
|
||||
// 5. Delete provisioned notify account (look up slug + resendDomainID from credential store)
|
||||
if s.notifyProvisioner != nil {
|
||||
if err := s.notifyProvisioner.DeleteProjectNotify(ctx, projectID); err != nil {
|
||||
notifySlug := status.Slug
|
||||
var resendDomainID string
|
||||
if s.credentialStore != nil {
|
||||
cred, _ := s.credentialStore.Get(ctx, projectID+":"+domain.CredKeyNotifyResendDomainID)
|
||||
resendDomainID = cred
|
||||
}
|
||||
if err := s.notifyProvisioner.DeleteProjectNotify(ctx, projectID, notifySlug, resendDomainID); err != nil {
|
||||
log.Warn("failed to delete project notify account", logging.FieldError, err)
|
||||
}
|
||||
}
|
||||
|
||||
@ -73,14 +73,37 @@ type AuditContext struct {
|
||||
UserAgent string
|
||||
}
|
||||
|
||||
// List returns all available projects with refreshed status.
|
||||
func (s *ProjectService) List(ctx context.Context) ([]domain.Project, error) {
|
||||
// List returns available projects with refreshed status.
|
||||
// allowedProjectIDs restricts results to specific projects; nil means unrestricted.
|
||||
func (s *ProjectService) List(ctx context.Context, allowedProjectIDs []domain.ProjectID) ([]domain.Project, error) {
|
||||
log := logging.FromContext(ctx).WithService("ProjectService")
|
||||
// Refresh status from Kubernetes
|
||||
if err := s.projects.RefreshStatus(ctx); err != nil {
|
||||
log.Warn("failed to refresh project status", logging.FieldError, err)
|
||||
}
|
||||
return s.projects.List(ctx)
|
||||
|
||||
projects, err := s.projects.List(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// nil = unrestricted (admin or no project_ids restriction)
|
||||
if allowedProjectIDs == nil {
|
||||
return projects, nil
|
||||
}
|
||||
|
||||
// Filter to only allowed projects
|
||||
allowed := make(map[domain.ProjectID]bool, len(allowedProjectIDs))
|
||||
for _, id := range allowedProjectIDs {
|
||||
allowed[id] = true
|
||||
}
|
||||
filtered := projects[:0]
|
||||
for _, p := range projects {
|
||||
if allowed[p.ID] {
|
||||
filtered = append(filtered, p)
|
||||
}
|
||||
}
|
||||
return filtered, nil
|
||||
}
|
||||
|
||||
// Get returns a specific project by ID.
|
||||
|
||||
@ -167,7 +167,7 @@ func TestProjectService_List(t *testing.T) {
|
||||
|
||||
svc := NewProjectService(repo, nil, nil)
|
||||
|
||||
projects, err := svc.List(context.Background())
|
||||
projects, err := svc.List(context.Background(), nil)
|
||||
if err != nil {
|
||||
t.Fatalf("List() error = %v", err)
|
||||
}
|
||||
@ -190,7 +190,7 @@ func TestProjectService_List_RefreshError(t *testing.T) {
|
||||
svc := NewProjectService(repo, nil, nil)
|
||||
|
||||
// Should still return projects even if refresh fails
|
||||
projects, err := svc.List(context.Background())
|
||||
projects, err := svc.List(context.Background(), nil)
|
||||
if err != nil {
|
||||
t.Fatalf("List() error = %v", err)
|
||||
}
|
||||
|
||||
Loading…
Reference in New Issue
Block a user