// Package postgres provides PostgreSQL-based implementations of port interfaces. package postgres import ( "context" "database/sql" "errors" "fmt" "strings" "time" "github.com/orchard9/rdev/internal/domain" "github.com/orchard9/rdev/internal/port" ) // AuditLogger implements port.AuditLogger using PostgreSQL. type AuditLogger struct { db *sql.DB } // NewAuditLogger creates a new PostgreSQL audit logger. func NewAuditLogger(db *sql.DB) *AuditLogger { return &AuditLogger{db: db} } // Ensure AuditLogger implements port.AuditLogger at compile time. var _ port.AuditLogger = (*AuditLogger)(nil) // LogCommandStart records the start of a command execution. func (l *AuditLogger) LogCommandStart(ctx context.Context, entry *domain.AuditLogEntry) error { _, err := l.db.ExecContext(ctx, ` INSERT INTO audit_log ( id, api_key_id, command_id, project_id, command_type, args, client_ip, user_agent, started_at, status, output_size_bytes ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11) `, entry.ID, entry.APIKeyID, entry.CommandID, entry.ProjectID, string(entry.CommandType), entry.Args, entry.ClientIP, entry.UserAgent, entry.StartedAt, string(domain.AuditStatusRunning), entry.OutputSizeBytes, ) if err != nil { return fmt.Errorf("insert audit log: %w", err) } return nil } // LogCommandEnd records the completion of a command execution. func (l *AuditLogger) LogCommandEnd(ctx context.Context, commandID string, result *domain.AuditResult) error { completedAt := time.Now() _, err := l.db.ExecContext(ctx, ` UPDATE audit_log SET completed_at = $1, exit_code = $2, duration_ms = $3, status = $4, error_message = $5, output_size_bytes = $6 WHERE command_id = $7 `, completedAt, result.ExitCode, result.DurationMs, string(result.Status), result.ErrorMessage, result.OutputSizeBytes, commandID, ) if err != nil { return fmt.Errorf("update audit log: %w", err) } return nil } // List returns audit log entries matching the given filters. func (l *AuditLogger) List(ctx context.Context, filters domain.AuditFilters) ([]domain.AuditLogEntry, error) { query := strings.Builder{} query.WriteString(` SELECT id, api_key_id, command_id, project_id, command_type, args, client_ip, user_agent, started_at, completed_at, exit_code, duration_ms, status, error_message, output_size_bytes, created_at FROM audit_log WHERE 1=1 `) args := make([]any, 0) argNum := 1 if filters.ProjectID != "" { query.WriteString(fmt.Sprintf(" AND project_id = $%d", argNum)) args = append(args, filters.ProjectID) argNum++ } if filters.APIKeyID != "" { query.WriteString(fmt.Sprintf(" AND api_key_id = $%d", argNum)) args = append(args, filters.APIKeyID) argNum++ } if filters.CommandType != "" { query.WriteString(fmt.Sprintf(" AND command_type = $%d", argNum)) args = append(args, string(filters.CommandType)) argNum++ } if filters.Status != "" { query.WriteString(fmt.Sprintf(" AND status = $%d", argNum)) args = append(args, string(filters.Status)) argNum++ } if filters.StartTime != nil { query.WriteString(fmt.Sprintf(" AND created_at >= $%d", argNum)) args = append(args, *filters.StartTime) argNum++ } if filters.EndTime != nil { query.WriteString(fmt.Sprintf(" AND created_at < $%d", argNum)) args = append(args, *filters.EndTime) argNum++ } query.WriteString(" ORDER BY created_at DESC") if filters.Limit > 0 { query.WriteString(fmt.Sprintf(" LIMIT $%d", argNum)) args = append(args, filters.Limit) argNum++ } if filters.Offset > 0 { query.WriteString(fmt.Sprintf(" OFFSET $%d", argNum)) args = append(args, filters.Offset) } rows, err := l.db.QueryContext(ctx, query.String(), args...) if err != nil { return nil, fmt.Errorf("query audit log: %w", err) } defer func() { _ = rows.Close() }() var entries []domain.AuditLogEntry for rows.Next() { var entry domain.AuditLogEntry var commandType string var status string var completedAt sql.NullTime var exitCode sql.NullInt32 var durationMs sql.NullInt64 var errorMessage sql.NullString if err := rows.Scan( &entry.ID, &entry.APIKeyID, &entry.CommandID, &entry.ProjectID, &commandType, &entry.Args, &entry.ClientIP, &entry.UserAgent, &entry.StartedAt, &completedAt, &exitCode, &durationMs, &status, &errorMessage, &entry.OutputSizeBytes, &entry.CreatedAt, ); err != nil { return nil, fmt.Errorf("scan audit log: %w", err) } entry.CommandType = domain.CommandType(commandType) entry.Status = domain.AuditStatus(status) if completedAt.Valid { entry.CompletedAt = &completedAt.Time } if exitCode.Valid { ec := int(exitCode.Int32) entry.ExitCode = &ec } if durationMs.Valid { dm := durationMs.Int64 entry.DurationMs = &dm } if errorMessage.Valid { entry.ErrorMessage = errorMessage.String } entries = append(entries, entry) } if err := rows.Err(); err != nil { return nil, fmt.Errorf("iterate audit log: %w", err) } return entries, nil } // Get returns a single audit log entry by command ID. func (l *AuditLogger) Get(ctx context.Context, commandID string) (*domain.AuditLogEntry, error) { var entry domain.AuditLogEntry var commandType string var status string var completedAt sql.NullTime var exitCode sql.NullInt32 var durationMs sql.NullInt64 var errorMessage sql.NullString err := l.db.QueryRowContext(ctx, ` SELECT id, api_key_id, command_id, project_id, command_type, args, client_ip, user_agent, started_at, completed_at, exit_code, duration_ms, status, error_message, output_size_bytes, created_at FROM audit_log WHERE command_id = $1 `, commandID).Scan( &entry.ID, &entry.APIKeyID, &entry.CommandID, &entry.ProjectID, &commandType, &entry.Args, &entry.ClientIP, &entry.UserAgent, &entry.StartedAt, &completedAt, &exitCode, &durationMs, &status, &errorMessage, &entry.OutputSizeBytes, &entry.CreatedAt, ) if errors.Is(err, sql.ErrNoRows) { return nil, domain.ErrAuditNotFound } if err != nil { return nil, fmt.Errorf("query audit log: %w", err) } entry.CommandType = domain.CommandType(commandType) entry.Status = domain.AuditStatus(status) if completedAt.Valid { entry.CompletedAt = &completedAt.Time } if exitCode.Valid { ec := int(exitCode.Int32) entry.ExitCode = &ec } if durationMs.Valid { dm := durationMs.Int64 entry.DurationMs = &dm } if errorMessage.Valid { entry.ErrorMessage = errorMessage.String } return &entry, nil }