feat: implement FastHTML app shell with auth/CSRF middleware (Step 7.1)

Add web layer foundation:
- FastHTML app factory with Beforeware pattern
- Auth middleware validating trusted proxy IPs and X-Oidc-Username header
- CSRF dual-token validation (cookie + header + Origin/Referer)
- Request ID generation (ULID) and NDJSON request logging
- Role-based permission helpers (can_edit_event, can_delete_event)

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
2025-12-29 19:52:15 +00:00
parent eb9dc8eadd
commit 84225d865f
10 changed files with 1579 additions and 9 deletions

18
PLAN.md
View File

@@ -249,15 +249,15 @@ Check off items as completed. Each phase builds on the previous.
## Phase 7: HTTP API
### Step 7.1: FastHTML App Shell
- [ ] Create `web/app.py` with FastHTML setup
- [ ] Configure HTMX extensions (head-support, preload, etc.)
- [ ] Create `web/middleware.py`:
- [ ] Auth middleware (X-Oidc-Username, TRUSTED_PROXY_IPS)
- [ ] CSRF middleware (cookie + header + Origin/Referer)
- [ ] Request logging (NDJSON format)
- [ ] Request ID generation
- [ ] Create `web/auth.py` with get_current_user, require_role
- [ ] Write tests: auth extraction, CSRF validation, untrusted IP rejection
- [x] Create `web/app.py` with FastHTML setup
- [x] Configure HTMX extensions (head-support, preload, etc.)
- [x] Create `web/middleware.py`:
- [x] Auth middleware (X-Oidc-Username, TRUSTED_PROXY_IPS)
- [x] CSRF middleware (cookie + header + Origin/Referer)
- [x] Request logging (NDJSON format)
- [x] Request ID generation
- [x] Create `web/auth.py` with get_current_user, require_role
- [x] Write tests: auth extraction, CSRF validation, untrusted IP rejection
- [ ] **Commit checkpoint**
### Step 7.2: Health & Static Assets

View File

@@ -0,0 +1,6 @@
# ABOUTME: Web layer package for AnimalTrack.
# ABOUTME: Exports create_app for FastHTML application creation.
from animaltrack.web.app import create_app
__all__ = ["create_app"]

View File

@@ -0,0 +1,97 @@
# ABOUTME: FastHTML application factory for AnimalTrack.
# ABOUTME: Configures middleware, routes, and HTMX extensions.
from __future__ import annotations
from fasthtml.common import Beforeware, fast_app
from starlette.requests import Request
from starlette.responses import PlainTextResponse
from animaltrack.config import Settings
from animaltrack.db import get_db
from animaltrack.web.middleware import (
auth_before,
csrf_before,
request_id_before,
)
def create_app(
    settings: Settings | None = None,
    db=None,
):
    """Build the FastHTML application with its beforeware and base routes.

    Args:
        settings: Application settings. A fresh ``Settings()`` is built when None.
        db: Database connection. Opened via ``get_db(settings.db_path)`` when None.

    Returns:
        Tuple of (app, rt) - the FastHTML app and its route decorator.
    """
    settings = Settings() if settings is None else settings
    db = get_db(settings.db_path) if db is None else db

    def before(req: Request, sess):
        """Run request-ID, auth and CSRF steps in order; first rejection wins."""
        # The request ID is attached before anything can short-circuit.
        request_id_before(req)
        rejection = auth_before(req, settings, db)
        if rejection is None:
            rejection = csrf_before(req, settings)
        return rejection

    # Paths matching these patterns bypass the beforeware entirely.
    skip_patterns = [
        r"/favicon\.ico",
        r"/static/.*",
        r".*\.css",
        r".*\.js",
        r"/healthz",
    ]

    app, rt = fast_app(
        before=Beforeware(before, skip=skip_patterns),
        exts=["head-support", "preload"],
    )

    # Route handlers reach shared objects through app.state.
    app.state.settings = settings
    app.state.db = db

    @rt("/healthz")
    def healthz():
        """Health check endpoint (skipped by the beforeware)."""
        # A trivial query confirms the connection responds.
        # NOTE(review): the 503 body echoes the exception text to the caller.
        try:
            db.execute("SELECT 1")
            return PlainTextResponse("OK", status_code=200)
        except Exception as e:
            return PlainTextResponse(f"Database error: {e}", status_code=503)

    @rt("/")
    def index():
        """Placeholder index route until the real UI exists."""
        return PlainTextResponse("AnimalTrack", status_code=200)

    return app, rt

140
src/animaltrack/web/auth.py Normal file
View File

@@ -0,0 +1,140 @@
# ABOUTME: Authentication and authorization helpers for route handlers.
# ABOUTME: Provides user extraction, role checks, and permission decorators.
from collections.abc import Callable
from functools import wraps
from typing import Any
from starlette.requests import Request
from starlette.responses import Response
from animaltrack.models.reference import User, UserRole
from animaltrack.web.exceptions import AuthenticationError, AuthorizationError
def get_current_user(req: Request) -> User | None:
    """Return the user stored under the ``"auth"`` key of the request scope.

    Args:
        req: The Starlette request object.

    Returns:
        The value stored under ``"auth"``, or None when the key is absent.
    """
    scope = req.scope
    return scope.get("auth")
def is_admin(user: User) -> bool:
    """Tell whether the given user holds the admin role.

    Args:
        user: The user whose role is inspected.

    Returns:
        True when ``user.role`` equals ``UserRole.ADMIN``, False otherwise.
    """
    role = user.role
    return role == UserRole.ADMIN
def is_recorder(user: User) -> bool:
    """Tell whether the given user holds the recorder role.

    Args:
        user: The user whose role is inspected.

    Returns:
        True when ``user.role`` equals ``UserRole.RECORDER``, False otherwise.
    """
    role = user.role
    return role == UserRole.RECORDER
def can_edit_event(user: User, event_actor: str) -> bool:
    """Decide whether ``user`` may edit an event created by ``event_actor``.

    Admins may edit every event; anyone else may edit only events whose
    creator username matches their own.

    Args:
        user: The user attempting the edit.
        event_actor: Username of the event's creator.

    Returns:
        True if the edit is permitted, False otherwise.
    """
    return is_admin(user) or user.username == event_actor
def can_delete_event(user: User, event_actor: str, has_dependents: bool) -> bool:
    """Decide whether ``user`` may delete an event created by ``event_actor``.

    Admins may delete every event, dependents included. Anyone else may
    delete only their own events, and only when nothing depends on them.

    Args:
        user: The user attempting the delete.
        event_actor: Username of the event's creator.
        has_dependents: Whether other events depend on this one.

    Returns:
        True if the delete is permitted, False otherwise.
    """
    if not is_admin(user):
        # Non-admins: must own the event and it must have no dependents.
        return user.username == event_actor and not has_dependents
    return True
def require_auth(handler: Callable) -> Callable:
    """Decorator that requires an authenticated user on the request.

    The wrapper raises AuthenticationError when the request scope holds no
    authenticated user. Both ``async def`` and plain ``def`` handlers are
    supported: the handler's result is awaited only if it is awaitable.

    Args:
        handler: The route handler to wrap.

    Returns:
        Async wrapper that checks for authentication before calling handler.

    Raises:
        AuthenticationError: From the wrapper, when no user is authenticated.
    """
    import inspect

    @wraps(handler)
    async def wrapper(req: Request, *args: Any, **kwargs: Any) -> Response:
        if get_current_user(req) is None:
            raise AuthenticationError("Authentication required")
        result = handler(req, *args, **kwargs)
        # Sync handlers return the response directly; awaiting it would raise.
        if inspect.isawaitable(result):
            result = await result
        return result

    return wrapper
def require_role(*roles: UserRole) -> Callable[[Callable], Callable]:
    """Decorator factory that requires one of the given role(s).

    The wrapper raises AuthenticationError when no user is authenticated and
    AuthorizationError when the user's role is not among ``roles``. Both
    ``async def`` and plain ``def`` handlers are supported: the handler's
    result is awaited only if it is awaitable.

    Args:
        roles: One or more UserRole values that are allowed.

    Returns:
        Decorator that wraps a handler with the role check.

    Example:
        @require_role(UserRole.ADMIN)
        async def admin_only_route(req):
            ...
    """
    import inspect

    def decorator(handler: Callable) -> Callable:
        @wraps(handler)
        async def wrapper(req: Request, *args: Any, **kwargs: Any) -> Response:
            user = get_current_user(req)
            if user is None:
                raise AuthenticationError("Authentication required")
            if user.role not in roles:
                raise AuthorizationError(f"Required role: {', '.join(r.value for r in roles)}")
            result = handler(req, *args, **kwargs)
            # Sync handlers return the response directly; awaiting it would raise.
            if inspect.isawaitable(result):
                result = await result
            return result

        return wrapper

    return decorator

View File

@@ -0,0 +1,42 @@
# ABOUTME: HTTP exception types for the web layer.
# ABOUTME: Maps domain errors to appropriate HTTP status codes.
class AuthenticationError(Exception):
    """Signals that the caller could not be authenticated.

    HTTP Status: 401 Unauthorized
    Causes: Missing auth header, unknown user, inactive user.
    """
class AuthorizationError(Exception):
    """Signals that an authenticated caller is not allowed to proceed.

    HTTP Status: 403 Forbidden
    Causes: User lacks required role, cannot edit/delete event.
    """
class CSRFValidationError(Exception):
    """Signals that a request failed CSRF validation.

    HTTP Status: 403 Forbidden
    Causes: Missing/mismatched token, invalid Origin/Referer.
    """
class UntrustedProxyError(Exception):
    """Signals that a request arrived from an untrusted proxy.

    HTTP Status: 403 Forbidden
    Causes: Request IP not in TRUSTED_PROXY_IPS.
    """

View File

@@ -0,0 +1,277 @@
# ABOUTME: Middleware functions for authentication, CSRF, and request logging.
# ABOUTME: Implements Beforeware pattern for FastHTML request processing.
import json
import time
from urllib.parse import urlparse
from starlette.requests import Request
from starlette.responses import PlainTextResponse, Response
from animaltrack.config import Settings
from animaltrack.id_gen import generate_id
from animaltrack.repositories.users import UserRepository
# HTTP methods treated as safe, i.e. exempt from CSRF validation
SAFE_METHODS = frozenset({"GET", "HEAD", "OPTIONS"})


def is_safe_method(method: str) -> bool:
    """Report whether an HTTP method is exempt from CSRF protection.

    The comparison is case-insensitive.

    Args:
        method: The HTTP method (e.g., "GET", "POST").

    Returns:
        True if the upper-cased method is in SAFE_METHODS, False otherwise.
    """
    normalized = method.upper()
    return normalized in SAFE_METHODS
def validate_csrf_token(cookie_token: str | None, header_token: str | None) -> bool:
"""Validate that CSRF tokens match.
Args:
cookie_token: Token from cookie.
header_token: Token from header.
Returns:
True if tokens match and are non-empty, False otherwise.
"""
if cookie_token is None or header_token is None:
return False
if not cookie_token or not header_token:
return False
return cookie_token == header_token
def validate_origin(origin: str | None, expected_host: str) -> bool:
"""Validate Origin header matches expected host.
Args:
origin: The Origin header value (e.g., "https://example.com").
expected_host: The expected host (e.g., "example.com" or "example.com:3366").
Returns:
True if origin matches expected host, False otherwise.
"""
if origin is None or not origin:
return False
try:
parsed = urlparse(origin)
origin_host = parsed.netloc
return origin_host == expected_host
except Exception:
return False
def validate_referer(referer: str | None, expected_host: str) -> bool:
"""Validate Referer header matches expected host.
Args:
referer: The Referer header value (e.g., "https://example.com/page").
expected_host: The expected host (e.g., "example.com" or "example.com:3366").
Returns:
True if referer host matches expected host, False otherwise.
"""
if referer is None or not referer:
return False
try:
parsed = urlparse(referer)
# parsed.netloc includes port if present
referer_host = parsed.netloc
if not referer_host:
return False
return referer_host == expected_host
except Exception:
return False
def get_client_ip(req: Request) -> str:
    """Work out the client IP, preferring the X-Forwarded-For header.

    Args:
        req: The Starlette request object.

    Returns:
        The first entry of X-Forwarded-For when the header is set, else the
        direct connection host, else the literal "unknown".
    """
    xff = req.headers.get("x-forwarded-for")
    if xff:
        # Header format is "client, proxy1, proxy2"; take the leading entry.
        leading, _, _ = xff.partition(",")
        return leading.strip()

    peer = req.client
    return peer.host if peer else "unknown"
def is_trusted_proxy(req: Request, settings: Settings) -> bool:
    """Decide whether the direct peer of this request is a trusted proxy.

    Only the immediate connection address is consulted; X-Forwarded-For is
    deliberately ignored here.

    Args:
        req: The Starlette request object.
        settings: Application settings with trusted_proxy_ips.

    Returns:
        True if the peer host is listed in trusted_proxy_ips, False otherwise.
    """
    allowed = settings.trusted_proxy_ips
    if not allowed:
        # Fail secure: with nothing configured, nobody is trusted.
        return False

    peer = req.client
    if not peer:
        return False
    return peer.host in allowed
def get_expected_host(req: Request, settings: Settings) -> str:
    """Determine the host that Origin/Referer values are checked against.

    A non-empty X-Forwarded-Host header takes precedence; otherwise the Host
    header is used, defaulting to an empty string when it is absent.

    Args:
        req: The Starlette request object.
        settings: Application settings (not read by this function).

    Returns:
        The expected host string.
    """
    headers = req.headers
    return headers.get("x-forwarded-host") or headers.get("host", "")
def request_id_before(req: Request) -> None:
    """Stamp the request scope with a fresh request ID and a start time.

    Writes ``request_id`` (from ``generate_id()``) and ``request_start_time``
    (from ``time.time()``) into ``req.scope``.

    Args:
        req: The Starlette request object.
    """
    scope = req.scope
    scope["request_id"] = generate_id()
    scope["request_start_time"] = time.time()
def auth_before(req: Request, settings: Settings, db) -> Response | None:
    """Authenticate the request from the proxy-supplied username header.

    Checks, in order:
    - the request comes from a trusted proxy IP (else 403)
    - the configured auth header is present and non-empty (else 401)
    - the named user exists in the database and is active (else 401)

    On success the user object is stored in ``req.scope["auth"]``.

    Args:
        req: The Starlette request object.
        settings: Application settings.
        db: Database connection.

    Returns:
        None to continue processing, or Response to short-circuit.
    """

    def reject(message: str, status: int) -> Response:
        return PlainTextResponse(message, status_code=status)

    if not is_trusted_proxy(req, settings):
        return reject("Forbidden: Request not from trusted proxy", 403)

    # The header name is lower-cased before the lookup.
    header_key = settings.auth_header_name.lower()
    username = req.headers.get(header_key)
    if not username:
        return reject("Unauthorized: Missing auth header", 401)

    user = UserRepository(db).get(username)
    if user is None:
        return reject("Unauthorized: Unknown user", 401)
    if not user.active:
        return reject("Unauthorized: Inactive user", 401)

    # Handlers read the authenticated user from the scope.
    req.scope["auth"] = user
    return None
def csrf_before(req: Request, settings: Settings) -> Response | None:
    """Enforce CSRF protection on requests using unsafe HTTP methods.

    For unsafe methods the request must satisfy both:
    1. The CSRF cookie and the X-CSRF-Token header are present and equal.
    2. The Origin header, or failing that the Referer header, matches the
       expected host.

    Args:
        req: The Starlette request object.
        settings: Application settings.

    Returns:
        None to continue processing, or Response to short-circuit.
    """
    if is_safe_method(req.method):
        return None

    cookie_token = req.cookies.get(settings.csrf_cookie_name)
    header_token = req.headers.get("x-csrf-token")
    if not validate_csrf_token(cookie_token, header_token):
        return PlainTextResponse("Forbidden: Invalid CSRF token", status_code=403)

    expected_host = get_expected_host(req, settings)
    origin = req.headers.get("origin")
    referer = req.headers.get("referer")
    # Referer is only consulted when the Origin check does not pass.
    if validate_origin(origin, expected_host) or validate_referer(referer, expected_host):
        return None
    return PlainTextResponse("Forbidden: Invalid Origin/Referer", status_code=403)
def logging_after(
    req: Request, resp: Response, settings: Settings, event_id: str | None = None
) -> None:
    """Emit one NDJSON log line describing a completed request.

    Format: {ts, level, route, actor, ip, method, status, duration_ms, request_id, event_id}

    ``event_id`` is only included when it is truthy. ``duration_ms`` is
    measured from the ``request_start_time`` value in the request scope.

    Args:
        req: The Starlette request object.
        resp: The response object.
        settings: Application settings (not read by this function).
        event_id: Optional event ID if an event was created.
    """
    scope = req.scope
    started = scope.get("request_start_time", time.time())
    elapsed_ms = int((time.time() - started) * 1000)

    user = scope.get("auth")

    entry = {
        "ts": int(time.time() * 1000),
        "level": "info",
        "route": req.url.path,
        "actor": user.username if user else None,
        "ip": get_client_ip(req),
        "method": req.method,
        "status": resp.status_code,
        "duration_ms": elapsed_ms,
        "request_id": scope.get("request_id"),
    }
    if event_id:
        entry["event_id"] = event_id

    # NDJSON: exactly one JSON object per output line
    print(json.dumps(entry), flush=True)

141
tests/test_web_app.py Normal file
View File

@@ -0,0 +1,141 @@
# ABOUTME: Integration tests for FastHTML application creation.
# ABOUTME: Tests app factory, middleware wiring, and route configuration.
import os
import pytest
def make_test_settings(
    csrf_secret: str = "test-secret",
    trusted_proxy_ips: str = "127.0.0.1",
    auth_header_name: str = "X-Oidc-Username",
):
    """Build a Settings object from temporarily overridden env vars.

    The process environment is restored to its prior contents before
    returning, whether or not Settings() succeeds.
    """
    from animaltrack.config import Settings

    saved_env = os.environ.copy()
    overrides = {
        "CSRF_SECRET": csrf_secret,
        "TRUSTED_PROXY_IPS": trusted_proxy_ips,
        "AUTH_HEADER_NAME": auth_header_name,
    }
    try:
        os.environ.update(overrides)
        return Settings()
    finally:
        os.environ.clear()
        os.environ.update(saved_env)
class TestCreateApp:
    """Checks on what the create_app factory returns and stores."""

    def test_creates_app_with_provided_settings(self, seeded_db):
        """Settings and db passed to create_app end up on app.state unchanged."""
        from animaltrack.web.app import create_app

        cfg = make_test_settings()
        app, rt = create_app(settings=cfg, db=seeded_db)

        assert app is not None
        assert rt is not None
        assert app.state.settings is cfg
        assert app.state.db is seeded_db

    def test_app_has_db_on_state(self, seeded_db):
        """The database handle is exposed as app.state.db."""
        from animaltrack.web.app import create_app

        app, _rt = create_app(settings=make_test_settings(), db=seeded_db)

        assert hasattr(app.state, "db")
        assert app.state.db is seeded_db

    def test_app_has_settings_on_state(self, seeded_db):
        """The settings object is exposed as app.state.settings."""
        from animaltrack.web.app import create_app

        app, _rt = create_app(settings=make_test_settings(), db=seeded_db)

        assert hasattr(app.state, "settings")
        assert app.state.settings.csrf_secret == "test-secret"
class TestAppWithTestClient:
    """Integration tests using Starlette TestClient."""

    @pytest.fixture
    def client(self, seeded_db):
        """Create a test client for the app."""
        from starlette.testclient import TestClient

        from animaltrack.web.app import create_app

        # TestClient uses 'testclient' as the host, so we need to trust it
        settings = make_test_settings(trusted_proxy_ips="testclient")
        app, rt = create_app(settings=settings, db=seeded_db)
        return TestClient(app, raise_server_exceptions=False)

    def test_healthz_returns_200(self, client):
        """GET /healthz returns 200 OK."""
        resp = client.get("/healthz")
        assert resp.status_code == 200

    def test_unauthenticated_route_returns_401(self, client):
        """Protected route without auth returns 401."""
        # Any route that requires auth
        resp = client.get("/")
        assert resp.status_code == 401

    def test_authenticated_request_succeeds(self, client):
        """Request with valid auth header succeeds."""
        resp = client.get(
            "/",
            headers={"X-Oidc-Username": "ppetru"},
        )
        # Should get a valid response (200 or 404 if route not implemented yet)
        # The key is it shouldn't be 401
        assert resp.status_code != 401

    def test_untrusted_proxy_returns_403(self, seeded_db):
        """Request from untrusted IP returns 403."""
        from starlette.testclient import TestClient

        from animaltrack.web.app import create_app

        # Configure with a different trusted IP (not 'testclient')
        settings = make_test_settings(trusted_proxy_ips="10.0.0.1")
        app, rt = create_app(settings=settings, db=seeded_db)
        client = TestClient(app, raise_server_exceptions=False)
        resp = client.get("/", headers={"X-Oidc-Username": "ppetru"})
        # Should fail because TestClient uses host 'testclient', not in trusted list
        assert resp.status_code == 403
        assert b"not from trusted proxy" in resp.content

    def test_csrf_required_on_post(self, client):
        """POST without CSRF token returns 403."""
        # POST to / route - should fail CSRF check before reaching handler
        resp = client.post(
            "/",
            headers={"X-Oidc-Username": "ppetru"},
        )
        assert resp.status_code == 403
        assert b"CSRF" in resp.content

    def test_csrf_with_valid_tokens_succeeds(self, client):
        """POST with matching CSRF tokens and a valid Origin is not rejected."""
        csrf_token = "test-csrf-token-123"
        resp = client.post(
            "/",
            headers={
                "X-Oidc-Username": "ppetru",
                "X-CSRF-Token": csrf_token,
                "Origin": "http://testserver",
            },
            cookies={"csrf_token": csrf_token},
        )
        # Assert on the status alone. The Origin/Referer rejection is also a
        # 403 but its body does not contain "CSRF", so a body-based check
        # would let that failure pass unnoticed.
        assert resp.status_code != 403, resp.content

393
tests/test_web_auth.py Normal file
View File

@@ -0,0 +1,393 @@
# ABOUTME: Tests for web authentication and authorization helpers.
# ABOUTME: Covers role checks, permission logic, and user extraction.
import os
import pytest
from animaltrack.models.reference import User, UserRole
def make_test_settings(
    csrf_secret: str = "test-secret",
    trusted_proxy_ips: str = "",
    auth_header_name: str = "X-Oidc-Username",
):
    """Build a Settings object from temporarily overridden env vars.

    Settings reads the environment, so the three variables are set just long
    enough to construct it; the prior environment is restored afterwards.
    """
    from animaltrack.config import Settings

    saved_env = os.environ.copy()
    overrides = {
        "CSRF_SECRET": csrf_secret,
        "TRUSTED_PROXY_IPS": trusted_proxy_ips,
        "AUTH_HEADER_NAME": auth_header_name,
    }
    try:
        os.environ.update(overrides)
        return Settings()
    finally:
        os.environ.clear()
        os.environ.update(saved_env)
# Fixtures for test users
@pytest.fixture
def admin_user():
    """Provide an active user with the admin role."""
    stamp = 1000000
    return User(
        username="admin",
        role=UserRole.ADMIN,
        active=True,
        created_at_utc=stamp,
        updated_at_utc=stamp,
    )
@pytest.fixture
def recorder_user():
    """Provide an active user with the recorder role."""
    stamp = 1000000
    return User(
        username="recorder",
        role=UserRole.RECORDER,
        active=True,
        created_at_utc=stamp,
        updated_at_utc=stamp,
    )
@pytest.fixture
def inactive_user():
    """Provide a recorder-role user whose active flag is False."""
    stamp = 1000000
    return User(
        username="inactive",
        role=UserRole.RECORDER,
        active=False,
        created_at_utc=stamp,
        updated_at_utc=stamp,
    )
class TestIsAdmin:
    """Behaviour of the is_admin role helper."""

    def test_returns_true_for_admin(self, admin_user):
        """An admin-role user is reported as admin."""
        from animaltrack.web import auth

        assert auth.is_admin(admin_user) is True

    def test_returns_false_for_recorder(self, recorder_user):
        """A recorder-role user is not reported as admin."""
        from animaltrack.web import auth

        assert auth.is_admin(recorder_user) is False
class TestIsRecorder:
    """Behaviour of the is_recorder role helper."""

    def test_returns_true_for_recorder(self, recorder_user):
        """A recorder-role user is reported as recorder."""
        from animaltrack.web import auth

        assert auth.is_recorder(recorder_user) is True

    def test_returns_false_for_admin(self, admin_user):
        """An admin-role user is not reported as recorder."""
        from animaltrack.web import auth

        assert auth.is_recorder(admin_user) is False
class TestCanEditEvent:
    """Permission rules enforced by can_edit_event."""

    def test_admin_can_edit_any_event(self, admin_user):
        """Admins may edit events regardless of who created them."""
        from animaltrack.web import auth

        for creator in ("other_user", "admin"):
            assert auth.can_edit_event(admin_user, creator) is True

    def test_recorder_can_edit_own_events(self, recorder_user):
        """Recorders may edit events they created themselves."""
        from animaltrack.web import auth

        assert auth.can_edit_event(recorder_user, "recorder") is True

    def test_recorder_cannot_edit_other_events(self, recorder_user):
        """Recorders may not edit events created by someone else."""
        from animaltrack.web import auth

        assert auth.can_edit_event(recorder_user, "other_user") is False
class TestCanDeleteEvent:
    """Permission rules enforced by can_delete_event."""

    def test_admin_can_delete_any_event_without_dependents(self, admin_user):
        """Admins may delete dependent-free events from any creator."""
        from animaltrack.web import auth

        for creator in ("other_user", "admin"):
            assert auth.can_delete_event(admin_user, creator, has_dependents=False) is True

    def test_admin_can_delete_any_event_with_dependents(self, admin_user):
        """Admins may delete events even when other events depend on them."""
        from animaltrack.web import auth

        for creator in ("other_user", "admin"):
            assert auth.can_delete_event(admin_user, creator, has_dependents=True) is True

    def test_recorder_can_delete_own_event_without_dependents(self, recorder_user):
        """Recorders may delete their own dependent-free events."""
        from animaltrack.web import auth

        assert auth.can_delete_event(recorder_user, "recorder", has_dependents=False) is True

    def test_recorder_cannot_delete_own_event_with_dependents(self, recorder_user):
        """Recorders may not delete their own events that have dependents."""
        from animaltrack.web import auth

        assert auth.can_delete_event(recorder_user, "recorder", has_dependents=True) is False

    def test_recorder_cannot_delete_other_events(self, recorder_user):
        """Recorders may not delete someone else's events, dependents or not."""
        from animaltrack.web import auth

        for dependents in (False, True):
            assert (
                auth.can_delete_event(recorder_user, "other_user", has_dependents=dependents)
                is False
            )
# HTTP-level auth middleware tests
class TestGetClientIp:
    """How get_client_ip picks the client address."""

    @staticmethod
    def _fake_request(headers, client_host):
        """Build a mock request; ``client_host=None`` means no client info."""
        from unittest.mock import MagicMock

        req = MagicMock()
        req.headers = headers
        req.client = None if client_host is None else MagicMock(host=client_host)
        return req

    def test_extracts_from_x_forwarded_for(self):
        """A single-entry X-Forwarded-For wins over the connection address."""
        from animaltrack.web.middleware import get_client_ip

        req = self._fake_request({"x-forwarded-for": "203.0.113.50"}, "10.0.0.1")

        assert get_client_ip(req) == "203.0.113.50"

    def test_extracts_first_ip_from_chain(self):
        """With several entries in the header, the first one is returned."""
        from animaltrack.web.middleware import get_client_ip

        req = self._fake_request(
            {"x-forwarded-for": "203.0.113.50, 10.0.0.2, 10.0.0.1"}, "10.0.0.1"
        )

        assert get_client_ip(req) == "203.0.113.50"

    def test_falls_back_to_client_host(self):
        """Without the header, the direct connection host is returned."""
        from animaltrack.web.middleware import get_client_ip

        req = self._fake_request({}, "192.168.1.100")

        assert get_client_ip(req) == "192.168.1.100"

    def test_returns_unknown_when_no_client(self):
        """Without header or client info, the literal 'unknown' is returned."""
        from animaltrack.web.middleware import get_client_ip

        req = self._fake_request({}, None)

        assert get_client_ip(req) == "unknown"
class TestIsTrustedProxy:
    """How is_trusted_proxy judges the connecting address."""

    @staticmethod
    def _fake_request(client_host):
        """Build a mock request; ``client_host=None`` means no client info."""
        from unittest.mock import MagicMock

        req = MagicMock()
        req.client = None if client_host is None else MagicMock(host=client_host)
        return req

    def test_accepts_trusted_ip(self):
        """A peer listed in the trusted IPs is accepted."""
        from animaltrack.web.middleware import is_trusted_proxy

        req = self._fake_request("127.0.0.1")
        cfg = make_test_settings(trusted_proxy_ips="127.0.0.1,10.0.0.1")

        assert is_trusted_proxy(req, cfg) is True

    def test_rejects_untrusted_ip(self):
        """A peer missing from the trusted IPs is rejected."""
        from animaltrack.web.middleware import is_trusted_proxy

        req = self._fake_request("192.168.1.1")
        cfg = make_test_settings(trusted_proxy_ips="127.0.0.1,10.0.0.1")

        assert is_trusted_proxy(req, cfg) is False

    def test_empty_trusted_list_rejects_all(self):
        """With no trusted IPs configured, every peer is rejected."""
        from animaltrack.web.middleware import is_trusted_proxy

        req = self._fake_request("127.0.0.1")
        cfg = make_test_settings(trusted_proxy_ips="")

        assert is_trusted_proxy(req, cfg) is False

    def test_rejects_when_no_client(self):
        """A request without client info is rejected."""
        from animaltrack.web.middleware import is_trusted_proxy

        req = self._fake_request(None)
        cfg = make_test_settings(trusted_proxy_ips="127.0.0.1")

        assert is_trusted_proxy(req, cfg) is False
class TestAuthBefore:
    """Outcomes of the auth_before middleware step."""

    @staticmethod
    def _fake_request(host, headers, scope=None):
        """Build a mock request; ``scope`` is only assigned when provided."""
        from unittest.mock import MagicMock

        req = MagicMock()
        req.client = MagicMock(host=host)
        req.headers = headers
        if scope is not None:
            req.scope = scope
        return req

    def test_rejects_untrusted_proxy(self, seeded_db):
        """A peer outside the trusted list gets a 403."""
        from animaltrack.web.middleware import auth_before

        req = self._fake_request("192.168.1.1", {"x-oidc-username": "ppetru"})
        cfg = make_test_settings(trusted_proxy_ips="127.0.0.1")

        resp = auth_before(req, cfg, seeded_db)

        assert resp is not None
        assert resp.status_code == 403
        assert b"not from trusted proxy" in resp.body

    def test_rejects_missing_auth_header(self, seeded_db):
        """A trusted peer without the auth header gets a 401."""
        from animaltrack.web.middleware import auth_before

        req = self._fake_request("127.0.0.1", {})  # no auth header at all
        cfg = make_test_settings(trusted_proxy_ips="127.0.0.1")

        resp = auth_before(req, cfg, seeded_db)

        assert resp is not None
        assert resp.status_code == 401
        assert b"Missing auth header" in resp.body

    def test_rejects_unknown_user(self, seeded_db):
        """A username absent from the database gets a 401."""
        from animaltrack.web.middleware import auth_before

        req = self._fake_request("127.0.0.1", {"x-oidc-username": "nonexistent_user"})
        cfg = make_test_settings(trusted_proxy_ips="127.0.0.1")

        resp = auth_before(req, cfg, seeded_db)

        assert resp is not None
        assert resp.status_code == 401
        assert b"Unknown user" in resp.body

    def test_rejects_inactive_user(self, seeded_db):
        """A user stored with active=False gets a 401."""
        from animaltrack.repositories.users import UserRepository
        from animaltrack.web.middleware import auth_before

        # Store an inactive user for the lookup to find
        UserRepository(seeded_db).upsert(
            User(
                username="inactive_test",
                role=UserRole.RECORDER,
                active=False,
                created_at_utc=1000000,
                updated_at_utc=1000000,
            )
        )
        req = self._fake_request("127.0.0.1", {"x-oidc-username": "inactive_test"})
        cfg = make_test_settings(trusted_proxy_ips="127.0.0.1")

        resp = auth_before(req, cfg, seeded_db)

        assert resp is not None
        assert resp.status_code == 401
        assert b"Inactive user" in resp.body

    def test_sets_user_in_scope(self, seeded_db):
        """On success the user lands in req.scope['auth'] and None is returned."""
        from animaltrack.web.middleware import auth_before

        # "ppetru" comes from the seed data
        req = self._fake_request("127.0.0.1", {"x-oidc-username": "ppetru"}, scope={})
        cfg = make_test_settings(trusted_proxy_ips="127.0.0.1")

        resp = auth_before(req, cfg, seeded_db)

        assert resp is None  # processing continues
        assert "auth" in req.scope
        stored = req.scope["auth"]
        assert stored.username == "ppetru"
        assert stored.role == UserRole.ADMIN

    def test_respects_custom_auth_header_name(self, seeded_db):
        """The header named by AUTH_HEADER_NAME is the one consulted."""
        from animaltrack.web.middleware import auth_before

        req = self._fake_request("127.0.0.1", {"x-custom-user": "ppetru"}, scope={})
        cfg = make_test_settings(
            trusted_proxy_ips="127.0.0.1",
            auth_header_name="X-Custom-User",
        )

        resp = auth_before(req, cfg, seeded_db)

        assert resp is None  # processing continues
        assert "auth" in req.scope
        assert req.scope["auth"].username == "ppetru"

172
tests/test_web_csrf.py Normal file
View File

@@ -0,0 +1,172 @@
# ABOUTME: Tests for CSRF validation logic.
# ABOUTME: Covers token matching, origin/referer validation, and safe methods.
class TestValidateCSRFToken:
    """Token-matching rules of validate_csrf_token."""

    def test_accepts_matching_tokens(self):
        """Identical non-empty cookie and header tokens are valid."""
        from animaltrack.web import middleware

        assert middleware.validate_csrf_token("abc123", "abc123") is True

    def test_rejects_mismatched_tokens(self):
        """Differing cookie and header tokens are invalid."""
        from animaltrack.web import middleware

        assert middleware.validate_csrf_token("abc123", "xyz789") is False

    def test_rejects_empty_cookie_token(self):
        """An empty cookie token is invalid."""
        from animaltrack.web import middleware

        assert middleware.validate_csrf_token("", "abc123") is False

    def test_rejects_empty_header_token(self):
        """An empty header token is invalid."""
        from animaltrack.web import middleware

        assert middleware.validate_csrf_token("abc123", "") is False

    def test_rejects_none_tokens(self):
        """A None on either side, or both, is invalid."""
        from animaltrack.web import middleware

        for cookie, header in ((None, "abc123"), ("abc123", None), (None, None)):
            assert middleware.validate_csrf_token(cookie, header) is False
class TestIsSafeMethod:
    """Checks which HTTP methods is_safe_method reports as safe."""

    def test_get_is_safe(self):
        """GET counts as safe."""
        from animaltrack.web.middleware import is_safe_method

        verdict = is_safe_method("GET")
        assert verdict is True

    def test_head_is_safe(self):
        """HEAD counts as safe."""
        from animaltrack.web.middleware import is_safe_method

        verdict = is_safe_method("HEAD")
        assert verdict is True

    def test_options_is_safe(self):
        """OPTIONS counts as safe."""
        from animaltrack.web.middleware import is_safe_method

        verdict = is_safe_method("OPTIONS")
        assert verdict is True

    def test_post_is_not_safe(self):
        """POST does not count as safe."""
        from animaltrack.web.middleware import is_safe_method

        verdict = is_safe_method("POST")
        assert verdict is False

    def test_put_is_not_safe(self):
        """PUT does not count as safe."""
        from animaltrack.web.middleware import is_safe_method

        verdict = is_safe_method("PUT")
        assert verdict is False

    def test_delete_is_not_safe(self):
        """DELETE does not count as safe."""
        from animaltrack.web.middleware import is_safe_method

        verdict = is_safe_method("DELETE")
        assert verdict is False

    def test_patch_is_not_safe(self):
        """PATCH does not count as safe."""
        from animaltrack.web.middleware import is_safe_method

        verdict = is_safe_method("PATCH")
        assert verdict is False

    def test_case_insensitive(self):
        """Letter case of the method name does not affect the result."""
        from animaltrack.web.middleware import is_safe_method

        for safe_spelling in ("get", "Get"):
            assert is_safe_method(safe_spelling) is True
        assert is_safe_method("post") is False
class TestValidateOrigin:
    """Tests for Origin/Referer header validation.

    Each test calls validate_origin(origin, expected_host) and checks the
    boolean result.
    """

    def test_accepts_matching_origin(self):
        """Valid when Origin matches expected host."""
        from animaltrack.web.middleware import validate_origin

        assert validate_origin("https://example.com", "example.com") is True

    def test_accepts_matching_origin_with_port(self):
        """Valid when Origin matches expected host with port."""
        from animaltrack.web.middleware import validate_origin

        assert validate_origin("https://example.com:3366", "example.com:3366") is True

    def test_rejects_different_origin(self):
        """Invalid when Origin doesn't match expected host."""
        from animaltrack.web.middleware import validate_origin

        assert validate_origin("https://evil.com", "example.com") is False

    def test_rejects_subdomain_mismatch(self):
        """Invalid when Origin is a subdomain of expected host."""
        from animaltrack.web.middleware import validate_origin

        assert validate_origin("https://sub.example.com", "example.com") is False

    # Named "rejects" (not "accepts") because the expected result is False,
    # in line with the test_rejects_* names used for the Referer tests.
    def test_rejects_none_origin(self):
        """None origin returns False (will check Referer instead)."""
        from animaltrack.web.middleware import validate_origin

        assert validate_origin(None, "example.com") is False

    def test_rejects_empty_origin(self):
        """Empty origin returns False."""
        from animaltrack.web.middleware import validate_origin

        assert validate_origin("", "example.com") is False
class TestValidateReferer:
    """Checks validate_referer against matching, foreign, missing and malformed values."""

    def test_accepts_matching_referer(self):
        """A Referer whose host equals the expected host is accepted."""
        from animaltrack.web.middleware import validate_referer

        verdict = validate_referer("https://example.com/page", "example.com")
        assert verdict is True

    def test_accepts_matching_referer_with_port(self):
        """A Referer whose host and port equal the expected host is accepted."""
        from animaltrack.web.middleware import validate_referer

        verdict = validate_referer("https://example.com:3366/page", "example.com:3366")
        assert verdict is True

    def test_rejects_different_referer(self):
        """A Referer pointing at another host is rejected."""
        from animaltrack.web.middleware import validate_referer

        verdict = validate_referer("https://evil.com/page", "example.com")
        assert verdict is False

    def test_rejects_none_referer(self):
        """A None Referer is rejected."""
        from animaltrack.web.middleware import validate_referer

        verdict = validate_referer(None, "example.com")
        assert verdict is False

    def test_rejects_empty_referer(self):
        """An empty Referer is rejected."""
        from animaltrack.web.middleware import validate_referer

        verdict = validate_referer("", "example.com")
        assert verdict is False

    def test_rejects_malformed_referer(self):
        """A Referer that is not a URL is rejected."""
        from animaltrack.web.middleware import validate_referer

        verdict = validate_referer("not-a-url", "example.com")
        assert verdict is False

View File

@@ -0,0 +1,302 @@
# ABOUTME: Tests for request ID generation and logging middleware.
# ABOUTME: Covers ULID generation, scope storage, and NDJSON log format.
import json
import os
import time
from unittest.mock import MagicMock
def make_test_settings(
    csrf_secret: str = "test-secret",
    trusted_proxy_ips: str = "",
    auth_header_name: str = "X-Oidc-Username",
):
    """Create Settings for testing by setting env vars temporarily.

    Args:
        csrf_secret: Value placed in the CSRF_SECRET environment variable.
        trusted_proxy_ips: Value placed in TRUSTED_PROXY_IPS.
        auth_header_name: Value placed in AUTH_HEADER_NAME.

    Returns:
        The Settings instance constructed while the overrides are in place.
    """
    from unittest.mock import patch

    from animaltrack.config import Settings

    overrides = {
        "CSRF_SECRET": csrf_secret,
        "TRUSTED_PROXY_IPS": trusted_proxy_ips,
        "AUTH_HEADER_NAME": auth_header_name,
    }
    # patch.dict puts os.environ back to its previous contents when the
    # block exits, including when Settings() raises.
    with patch.dict(os.environ, overrides):
        return Settings()
class TestRequestIdBefore:
    """Checks what request_id_before writes into the request scope."""

    def test_generates_request_id(self):
        """A non-None request_id entry appears in the scope."""
        from animaltrack.web.middleware import request_id_before

        request = MagicMock(scope={})

        request_id_before(request)

        assert "request_id" in request.scope
        assert request.scope["request_id"] is not None

    def test_request_id_is_ulid(self):
        """The request ID is 26 characters drawn from the ULID base32 alphabet."""
        from animaltrack.web.middleware import request_id_before

        request = MagicMock(scope={})

        request_id_before(request)

        generated = request.scope["request_id"]
        assert len(generated) == 26
        # Base32 alphabet used by ULIDs
        alphabet = "0123456789ABCDEFGHJKMNPQRSTVWXYZ"
        assert all(char in alphabet for char in generated.upper())

    def test_generates_unique_ids(self):
        """One hundred calls produce one hundred distinct request IDs."""
        from animaltrack.web.middleware import request_id_before

        def next_id():
            request = MagicMock(scope={})
            request_id_before(request)
            return request.scope["request_id"]

        seen = {next_id() for _ in range(100)}
        assert len(seen) == 100  # No duplicates

    def test_sets_start_time(self):
        """request_start_time is stored in scope and lies within the call window."""
        from animaltrack.web.middleware import request_id_before

        request = MagicMock(scope={})

        lower_bound = time.time()
        request_id_before(request)
        upper_bound = time.time()

        assert "request_start_time" in request.scope
        assert lower_bound <= request.scope["request_start_time"] <= upper_bound
class TestLoggingAfter:
    """Tests for request logging middleware.

    Every test builds a mock request/response pair, calls logging_after and
    parses the line captured from stdout as JSON.
    """

    @staticmethod
    def _make_request(scope, path="/test", method="GET", headers=None, client_host="127.0.0.1"):
        """Build a mock request with scope, url.path, method, headers and client.host set."""
        return MagicMock(
            scope=scope,
            url=MagicMock(path=path),
            method=method,
            headers={} if headers is None else headers,
            client=MagicMock(host=client_host),
        )

    @staticmethod
    def _log_and_parse(capsys, req, status_code=200, **log_kwargs):
        """Call logging_after for req and return the captured stdout parsed as JSON."""
        from animaltrack.web.middleware import logging_after

        resp = MagicMock(status_code=status_code)
        settings = make_test_settings()
        logging_after(req, resp, settings, **log_kwargs)
        return json.loads(capsys.readouterr().out.strip())

    def test_logs_in_ndjson_format(self, capsys):
        """Log line is valid JSON."""
        req = self._make_request({"request_id": "TEST123", "request_start_time": time.time()})

        # json.loads inside the helper raises if the output is not valid JSON
        parsed = self._log_and_parse(capsys, req)

        assert isinstance(parsed, dict)

    def test_log_includes_required_fields(self, capsys):
        """Log includes: ts, level, route, actor, ip, method, status, duration_ms."""
        from animaltrack.models.reference import User, UserRole

        user = User(
            username="testuser",
            role=UserRole.RECORDER,
            active=True,
            created_at_utc=1000000,
            updated_at_utc=1000000,
        )
        req = self._make_request(
            {
                "request_id": "TEST123",
                "request_start_time": time.time() - 0.1,  # 100ms ago
                "auth": user,
            },
            path="/test/route",
            method="POST",
        )

        parsed = self._log_and_parse(capsys, req, status_code=201)

        assert "ts" in parsed
        assert parsed["level"] == "info"
        assert parsed["route"] == "/test/route"
        assert parsed["actor"] == "testuser"
        # ip is listed as a required field above, so check that it is present
        assert "ip" in parsed
        assert parsed["method"] == "POST"
        assert parsed["status"] == 201
        assert "duration_ms" in parsed
        assert parsed["duration_ms"] >= 0

    def test_log_includes_request_id(self, capsys):
        """Log includes request_id field."""
        req = self._make_request(
            {"request_id": "01ARZ3NDEKTSV4RRFFQ69G5FAV", "request_start_time": time.time()}
        )

        parsed = self._log_and_parse(capsys, req)

        assert parsed["request_id"] == "01ARZ3NDEKTSV4RRFFQ69G5FAV"

    def test_log_includes_event_id_when_provided(self, capsys):
        """Log includes event_id field when provided."""
        req = self._make_request(
            {"request_id": "TEST123", "request_start_time": time.time()},
            path="/actions/product-collected",
            method="POST",
        )

        parsed = self._log_and_parse(capsys, req, event_id="01ARZ3NDEKTSV4RRFFQ69G5EVT")

        assert parsed["event_id"] == "01ARZ3NDEKTSV4RRFFQ69G5EVT"

    def test_duration_ms_accurate(self, capsys):
        """duration_ms reflects actual request time."""
        # Started 150ms ago
        req = self._make_request(
            {"request_id": "TEST123", "request_start_time": time.time() - 0.150}
        )

        parsed = self._log_and_parse(capsys, req)

        # Should be approximately 150ms (with some tolerance)
        assert 100 <= parsed["duration_ms"] <= 250

    def test_logs_x_forwarded_for_ip(self, capsys):
        """Log ip field uses X-Forwarded-For when present."""
        req = self._make_request(
            {"request_id": "TEST123", "request_start_time": time.time()},
            headers={"x-forwarded-for": "203.0.113.50, 10.0.0.1"},
            client_host="10.0.0.1",  # Proxy IP
        )

        parsed = self._log_and_parse(capsys, req)

        # Should use the client IP from X-Forwarded-For, not the proxy IP
        assert parsed["ip"] == "203.0.113.50"

    def test_actor_is_none_when_unauthenticated(self, capsys):
        """Log actor is None when no user in scope."""
        req = self._make_request(
            {"request_id": "TEST123", "request_start_time": time.time()},
            path="/healthz",
        )

        parsed = self._log_and_parse(capsys, req)

        assert parsed["actor"] is None

    def test_timestamp_is_milliseconds(self, capsys):
        """Log ts field is milliseconds since epoch."""
        from animaltrack.web.middleware import logging_after

        req = self._make_request({"request_id": "TEST123", "request_start_time": time.time()})
        resp = MagicMock(status_code=200)
        settings = make_test_settings()

        # Sample the clock immediately around the call to keep the window tight
        before_ms = int(time.time() * 1000)
        logging_after(req, resp, settings)
        after_ms = int(time.time() * 1000)

        parsed = json.loads(capsys.readouterr().out.strip())
        # ts should be between before and after (in milliseconds)
        assert before_ms <= parsed["ts"] <= after_ms
        # Should be a reasonable timestamp (year 2020+)
        assert parsed["ts"] > 1577836800000  # 2020-01-01