feat: implement FastHTML app shell with auth/CSRF middleware (Step 7.1)

Add web layer foundation:
- FastHTML app factory with Beforeware pattern
- Auth middleware validating trusted proxy IPs and X-Oidc-Username header
- CSRF dual-token validation (cookie + header + Origin/Referer)
- Request ID generation (ULID) and NDJSON request logging
- Role-based permission helpers (can_edit_event, can_delete_event)

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
2025-12-29 19:52:15 +00:00
parent eb9dc8eadd
commit 84225d865f
10 changed files with 1579 additions and 9 deletions

View File

@@ -0,0 +1,6 @@
# ABOUTME: Web layer package for AnimalTrack.
# ABOUTME: Exports create_app for FastHTML application creation.
from animaltrack.web.app import create_app
__all__ = ["create_app"]

View File

@@ -0,0 +1,97 @@
# ABOUTME: FastHTML application factory for AnimalTrack.
# ABOUTME: Configures middleware, routes, and HTMX extensions.
from __future__ import annotations
from fasthtml.common import Beforeware, fast_app
from starlette.requests import Request
from starlette.responses import PlainTextResponse
from animaltrack.config import Settings
from animaltrack.db import get_db
from animaltrack.web.middleware import (
auth_before,
csrf_before,
request_id_before,
)
def create_app(
settings: Settings | None = None,
db=None,
):
"""Create and configure the FastHTML application.
Args:
settings: Application settings. Loads from env if None.
db: Database connection. Creates from settings if None.
Returns:
Tuple of (app, rt) - the FastHTML app and route decorator.
"""
# Load settings if not provided
if settings is None:
settings = Settings()
# Create db connection if not provided
if db is None:
db = get_db(settings.db_path)
# Create beforeware function that combines all middleware
def before(req: Request, sess):
"""Combined beforeware function for all middleware."""
# Generate request ID first
request_id_before(req)
# Auth middleware
auth_resp = auth_before(req, settings, db)
if auth_resp is not None:
return auth_resp
# CSRF middleware
csrf_resp = csrf_before(req, settings)
if csrf_resp is not None:
return csrf_resp
return None
# Configure beforeware with skip patterns
beforeware = Beforeware(
before,
skip=[
r"/favicon\.ico",
r"/static/.*",
r".*\.css",
r".*\.js",
r"/healthz",
],
)
# Create FastHTML app with HTMX extensions
app, rt = fast_app(
before=beforeware,
exts=["head-support", "preload"],
)
# Store settings and db on app state for access in routes
app.state.settings = settings
app.state.db = db
# Register healthz route (excluded from beforeware)
@rt("/healthz")
def healthz():
"""Health check endpoint."""
# Verify database is writable
try:
db.execute("SELECT 1")
return PlainTextResponse("OK", status_code=200)
except Exception as e:
return PlainTextResponse(f"Database error: {e}", status_code=503)
# Placeholder index route (will be replaced with real UI later)
@rt("/")
def index():
"""Placeholder index route."""
return PlainTextResponse("AnimalTrack", status_code=200)
return app, rt

140
src/animaltrack/web/auth.py Normal file
View File

@@ -0,0 +1,140 @@
# ABOUTME: Authentication and authorization helpers for route handlers.
# ABOUTME: Provides user extraction, role checks, and permission decorators.
from collections.abc import Callable
from functools import wraps
from typing import Any
from starlette.requests import Request
from starlette.responses import Response
from animaltrack.models.reference import User, UserRole
from animaltrack.web.exceptions import AuthenticationError, AuthorizationError
def get_current_user(req: Request) -> User | None:
"""Get the authenticated user from request scope.
Args:
req: The Starlette request object.
Returns:
User object if authenticated, None otherwise.
"""
return req.scope.get("auth")
def is_admin(user: User) -> bool:
"""Check if user has admin role.
Args:
user: The user to check.
Returns:
True if user has admin role, False otherwise.
"""
return user.role == UserRole.ADMIN
def is_recorder(user: User) -> bool:
"""Check if user has recorder role.
Args:
user: The user to check.
Returns:
True if user has recorder role, False otherwise.
"""
return user.role == UserRole.RECORDER
def can_edit_event(user: User, event_actor: str) -> bool:
"""Check if user can edit an event.
Admins can edit any event.
Recorders can only edit their own events.
Args:
user: The user attempting to edit.
event_actor: The username of the event's creator.
Returns:
True if user can edit the event, False otherwise.
"""
if is_admin(user):
return True
return user.username == event_actor
def can_delete_event(user: User, event_actor: str, has_dependents: bool) -> bool:
"""Check if user can delete an event.
Admins can delete any event, including cascade delete with dependents.
Recorders can only delete their own events without dependents.
Args:
user: The user attempting to delete.
event_actor: The username of the event's creator.
has_dependents: Whether the event has dependent events.
Returns:
True if user can delete the event, False otherwise.
"""
if is_admin(user):
return True
# Recorder can only delete own events without dependents
return user.username == event_actor and not has_dependents
def require_auth(handler: Callable) -> Callable:
"""Decorator that requires authentication.
Returns 401 response if no authenticated user in request scope.
Args:
handler: The route handler to wrap.
Returns:
Wrapped handler that checks for authentication.
"""
@wraps(handler)
async def wrapper(req: Request, *args: Any, **kwargs: Any) -> Response:
user = get_current_user(req)
if user is None:
raise AuthenticationError("Authentication required")
return await handler(req, *args, **kwargs)
return wrapper
def require_role(*roles: UserRole) -> Callable[[Callable], Callable]:
"""Decorator factory that requires specific role(s).
Returns 401 if not authenticated, 403 if role not in allowed roles.
Args:
roles: One or more UserRole values that are allowed.
Returns:
Decorator that checks for required role.
Example:
@require_role(UserRole.ADMIN)
async def admin_only_route(req):
...
"""
def decorator(handler: Callable) -> Callable:
@wraps(handler)
async def wrapper(req: Request, *args: Any, **kwargs: Any) -> Response:
user = get_current_user(req)
if user is None:
raise AuthenticationError("Authentication required")
if user.role not in roles:
raise AuthorizationError(f"Required role: {', '.join(r.value for r in roles)}")
return await handler(req, *args, **kwargs)
return wrapper
return decorator

View File

@@ -0,0 +1,42 @@
# ABOUTME: HTTP exception types for the web layer.
# ABOUTME: Maps domain errors to appropriate HTTP status codes.
class AuthenticationError(Exception):
"""Raised when authentication fails.
HTTP Status: 401 Unauthorized
Causes: Missing auth header, unknown user, inactive user.
"""
pass
class AuthorizationError(Exception):
"""Raised when authorization fails.
HTTP Status: 403 Forbidden
Causes: User lacks required role, cannot edit/delete event.
"""
pass
class CSRFValidationError(Exception):
"""Raised when CSRF validation fails.
HTTP Status: 403 Forbidden
Causes: Missing/mismatched token, invalid Origin/Referer.
"""
pass
class UntrustedProxyError(Exception):
"""Raised when request comes from untrusted proxy.
HTTP Status: 403 Forbidden
Causes: Request IP not in TRUSTED_PROXY_IPS.
"""
pass

View File

@@ -0,0 +1,277 @@
# ABOUTME: Middleware functions for authentication, CSRF, and request logging.
# ABOUTME: Implements Beforeware pattern for FastHTML request processing.
import json
import time
from urllib.parse import urlparse
from starlette.requests import Request
from starlette.responses import PlainTextResponse, Response
from animaltrack.config import Settings
from animaltrack.id_gen import generate_id
from animaltrack.repositories.users import UserRepository
# Safe HTTP methods that don't require CSRF protection
SAFE_METHODS = frozenset({"GET", "HEAD", "OPTIONS"})
def is_safe_method(method: str) -> bool:
"""Check if HTTP method is safe (doesn't require CSRF protection).
Args:
method: The HTTP method (e.g., "GET", "POST").
Returns:
True if method is safe, False otherwise.
"""
return method.upper() in SAFE_METHODS
def validate_csrf_token(cookie_token: str | None, header_token: str | None) -> bool:
"""Validate that CSRF tokens match.
Args:
cookie_token: Token from cookie.
header_token: Token from header.
Returns:
True if tokens match and are non-empty, False otherwise.
"""
if cookie_token is None or header_token is None:
return False
if not cookie_token or not header_token:
return False
return cookie_token == header_token
def validate_origin(origin: str | None, expected_host: str) -> bool:
"""Validate Origin header matches expected host.
Args:
origin: The Origin header value (e.g., "https://example.com").
expected_host: The expected host (e.g., "example.com" or "example.com:3366").
Returns:
True if origin matches expected host, False otherwise.
"""
if origin is None or not origin:
return False
try:
parsed = urlparse(origin)
origin_host = parsed.netloc
return origin_host == expected_host
except Exception:
return False
def validate_referer(referer: str | None, expected_host: str) -> bool:
"""Validate Referer header matches expected host.
Args:
referer: The Referer header value (e.g., "https://example.com/page").
expected_host: The expected host (e.g., "example.com" or "example.com:3366").
Returns:
True if referer host matches expected host, False otherwise.
"""
if referer is None or not referer:
return False
try:
parsed = urlparse(referer)
# parsed.netloc includes port if present
referer_host = parsed.netloc
if not referer_host:
return False
return referer_host == expected_host
except Exception:
return False
def get_client_ip(req: Request) -> str:
"""Extract client IP from request, respecting X-Forwarded-For.
Args:
req: The Starlette request object.
Returns:
The client IP address.
"""
# Check X-Forwarded-For header first (set by reverse proxy)
forwarded_for = req.headers.get("x-forwarded-for")
if forwarded_for:
# X-Forwarded-For can contain multiple IPs: "client, proxy1, proxy2"
# The first one is the original client
return forwarded_for.split(",")[0].strip()
# Fall back to direct connection IP
if req.client:
return req.client.host
return "unknown"
def is_trusted_proxy(req: Request, settings: Settings) -> bool:
"""Check if request comes from a trusted proxy IP.
Args:
req: The Starlette request object.
settings: Application settings with trusted_proxy_ips.
Returns:
True if request is from trusted proxy, False otherwise.
"""
trusted_ips = settings.trusted_proxy_ips
if not trusted_ips:
# If no trusted IPs configured, reject all (fail-secure)
return False
# Get the immediate connection IP (not X-Forwarded-For)
if req.client:
client_ip = req.client.host
else:
return False
return client_ip in trusted_ips
def get_expected_host(req: Request, settings: Settings) -> str:
"""Get the expected host for Origin/Referer validation.
Uses X-Forwarded-Host if present, otherwise request host.
Args:
req: The Starlette request object.
settings: Application settings.
Returns:
The expected host string.
"""
forwarded_host = req.headers.get("x-forwarded-host")
if forwarded_host:
return forwarded_host
return req.headers.get("host", "")
def request_id_before(req: Request) -> None:
"""Generate unique request_id and attach to request scope.
Args:
req: The Starlette request object.
"""
req.scope["request_id"] = generate_id()
req.scope["request_start_time"] = time.time()
def auth_before(req: Request, settings: Settings, db) -> Response | None:
"""Extract and validate authentication from proxy headers.
Validates:
- Request comes from trusted proxy IP
- Auth header is present
- User exists and is active in database
Args:
req: The Starlette request object.
settings: Application settings.
db: Database connection.
Returns:
None to continue processing, or Response to short-circuit.
"""
# Check trusted proxy
if not is_trusted_proxy(req, settings):
return PlainTextResponse("Forbidden: Request not from trusted proxy", status_code=403)
# Extract username from auth header
username = req.headers.get(settings.auth_header_name.lower())
if not username:
return PlainTextResponse("Unauthorized: Missing auth header", status_code=401)
# Look up user in database
user_repo = UserRepository(db)
user = user_repo.get(username)
if user is None:
return PlainTextResponse("Unauthorized: Unknown user", status_code=401)
if not user.active:
return PlainTextResponse("Unauthorized: Inactive user", status_code=401)
# Store user in scope for access by handlers
req.scope["auth"] = user
return None
def csrf_before(req: Request, settings: Settings) -> Response | None:
"""Validate CSRF token on unsafe HTTP methods.
Validates:
1. CSRF cookie present and matches header
2. Origin or Referer matches expected host
Args:
req: The Starlette request object.
settings: Application settings.
Returns:
None to continue processing, or Response to short-circuit.
"""
# Skip CSRF check for safe methods
if is_safe_method(req.method):
return None
# Get CSRF tokens
cookie_token = req.cookies.get(settings.csrf_cookie_name)
header_token = req.headers.get("x-csrf-token")
# Validate tokens match
if not validate_csrf_token(cookie_token, header_token):
return PlainTextResponse("Forbidden: Invalid CSRF token", status_code=403)
# Validate Origin or Referer
expected_host = get_expected_host(req, settings)
origin = req.headers.get("origin")
referer = req.headers.get("referer")
if not validate_origin(origin, expected_host):
# Fall back to Referer check
if not validate_referer(referer, expected_host):
return PlainTextResponse("Forbidden: Invalid Origin/Referer", status_code=403)
return None
def logging_after(
req: Request, resp: Response, settings: Settings, event_id: str | None = None
) -> None:
"""Log request in NDJSON format after response.
Format: {ts, level, route, actor, ip, method, status, duration_ms, request_id, event_id}
Args:
req: The Starlette request object.
resp: The response object.
settings: Application settings.
event_id: Optional event ID if an event was created.
"""
start_time = req.scope.get("request_start_time", time.time())
duration_ms = int((time.time() - start_time) * 1000)
user = req.scope.get("auth")
actor = user.username if user else None
log_entry = {
"ts": int(time.time() * 1000),
"level": "info",
"route": req.url.path,
"actor": actor,
"ip": get_client_ip(req),
"method": req.method,
"status": resp.status_code,
"duration_ms": duration_ms,
"request_id": req.scope.get("request_id"),
}
if event_id:
log_entry["event_id"] = event_id
# Output as NDJSON (one JSON object per line)
print(json.dumps(log_entry), flush=True)