diff --git a/PLAN.md b/PLAN.md index 76d4133..dd12288 100644 --- a/PLAN.md +++ b/PLAN.md @@ -249,15 +249,15 @@ Check off items as completed. Each phase builds on the previous. ## Phase 7: HTTP API ### Step 7.1: FastHTML App Shell -- [ ] Create `web/app.py` with FastHTML setup -- [ ] Configure HTMX extensions (head-support, preload, etc.) -- [ ] Create `web/middleware.py`: - - [ ] Auth middleware (X-Oidc-Username, TRUSTED_PROXY_IPS) - - [ ] CSRF middleware (cookie + header + Origin/Referer) - - [ ] Request logging (NDJSON format) - - [ ] Request ID generation -- [ ] Create `web/auth.py` with get_current_user, require_role -- [ ] Write tests: auth extraction, CSRF validation, untrusted IP rejection +- [x] Create `web/app.py` with FastHTML setup +- [x] Configure HTMX extensions (head-support, preload, etc.) +- [x] Create `web/middleware.py`: + - [x] Auth middleware (X-Oidc-Username, TRUSTED_PROXY_IPS) + - [x] CSRF middleware (cookie + header + Origin/Referer) + - [x] Request logging (NDJSON format) + - [x] Request ID generation +- [x] Create `web/auth.py` with get_current_user, require_role +- [x] Write tests: auth extraction, CSRF validation, untrusted IP rejection - [ ] **Commit checkpoint** ### Step 7.2: Health & Static Assets diff --git a/src/animaltrack/web/__init__.py b/src/animaltrack/web/__init__.py new file mode 100644 index 0000000..f64fee3 --- /dev/null +++ b/src/animaltrack/web/__init__.py @@ -0,0 +1,6 @@ +# ABOUTME: Web layer package for AnimalTrack. +# ABOUTME: Exports create_app for FastHTML application creation. + +from animaltrack.web.app import create_app + +__all__ = ["create_app"] diff --git a/src/animaltrack/web/app.py b/src/animaltrack/web/app.py new file mode 100644 index 0000000..3e69018 --- /dev/null +++ b/src/animaltrack/web/app.py @@ -0,0 +1,97 @@ +# ABOUTME: FastHTML application factory for AnimalTrack. +# ABOUTME: Configures middleware, routes, and HTMX extensions. + +from __future__ import annotations + +from fasthtml.common import Beforeware, fast_app +from starlette.requests import Request +from starlette.responses import PlainTextResponse + +from animaltrack.config import Settings +from animaltrack.db import get_db +from animaltrack.web.middleware import ( + auth_before, + csrf_before, + request_id_before, +) + + +def create_app( + settings: Settings | None = None, + db=None, +): + """Create and configure the FastHTML application. + + Args: + settings: Application settings. Loads from env if None. + db: Database connection. Creates from settings if None. + + Returns: + Tuple of (app, rt) - the FastHTML app and route decorator. + """ + # Load settings if not provided + if settings is None: + settings = Settings() + + # Create db connection if not provided + if db is None: + db = get_db(settings.db_path) + + # Create beforeware function that combines all middleware + def before(req: Request, sess): + """Combined beforeware function for all middleware.""" + # Generate request ID first + request_id_before(req) + + # Auth middleware + auth_resp = auth_before(req, settings, db) + if auth_resp is not None: + return auth_resp + + # CSRF middleware + csrf_resp = csrf_before(req, settings) + if csrf_resp is not None: + return csrf_resp + + return None + + # Configure beforeware with skip patterns + beforeware = Beforeware( + before, + skip=[ + r"/favicon\.ico", + r"/static/.*", + r".*\.css", + r".*\.js", + r"/healthz", + ], + ) + + # Create FastHTML app with HTMX extensions + app, rt = fast_app( + before=beforeware, + exts=["head-support", "preload"], + ) + + # Store settings and db on app state for access in routes + app.state.settings = settings + app.state.db = db + + # Register healthz route (excluded from beforeware) + @rt("/healthz") + def healthz(): + """Health check endpoint.""" + # Verify database is writable + try: + db.execute("SELECT 1") + return PlainTextResponse("OK", status_code=200) + except Exception as e: + return PlainTextResponse(f"Database error: {e}", status_code=503) + + # Placeholder index route (will be replaced with real UI later) + @rt("/") + def index(): + """Placeholder index route.""" + return PlainTextResponse("AnimalTrack", status_code=200) + + return app, rt diff --git a/src/animaltrack/web/auth.py b/src/animaltrack/web/auth.py new file mode 100644 index 0000000..6f1cedd --- /dev/null +++ b/src/animaltrack/web/auth.py @@ -0,0 +1,140 @@ +# ABOUTME: Authentication and authorization helpers for route handlers. +# ABOUTME: Provides user extraction, role checks, and permission decorators. + +from collections.abc import Callable +from functools import wraps +from typing import Any + +from starlette.requests import Request +from starlette.responses import Response + +from animaltrack.models.reference import User, UserRole +from animaltrack.web.exceptions import AuthenticationError, AuthorizationError + + +def get_current_user(req: Request) -> User | None: + """Get the authenticated user from request scope. + + Args: + req: The Starlette request object. + + Returns: + User object if authenticated, None otherwise. + """ + return req.scope.get("auth") + + +def is_admin(user: User) -> bool: + """Check if user has admin role. + + Args: + user: The user to check. + + Returns: + True if user has admin role, False otherwise. + """ + return user.role == UserRole.ADMIN + + +def is_recorder(user: User) -> bool: + """Check if user has recorder role. + + Args: + user: The user to check. + + Returns: + True if user has recorder role, False otherwise. + """ + return user.role == UserRole.RECORDER + + +def can_edit_event(user: User, event_actor: str) -> bool: + """Check if user can edit an event. + + Admins can edit any event. + Recorders can only edit their own events. + + Args: + user: The user attempting to edit. + event_actor: The username of the event's creator. + + Returns: + True if user can edit the event, False otherwise. + """ + if is_admin(user): + return True + return user.username == event_actor + + +def can_delete_event(user: User, event_actor: str, has_dependents: bool) -> bool: + """Check if user can delete an event. + + Admins can delete any event, including cascade delete with dependents. + Recorders can only delete their own events without dependents. + + Args: + user: The user attempting to delete. + event_actor: The username of the event's creator. + has_dependents: Whether the event has dependent events. + + Returns: + True if user can delete the event, False otherwise. + """ + if is_admin(user): + return True + # Recorder can only delete own events without dependents + return user.username == event_actor and not has_dependents + + +def require_auth(handler: Callable) -> Callable: + """Decorator that requires authentication. + + Returns 401 response if no authenticated user in request scope. + + Args: + handler: The route handler to wrap. + + Returns: + Wrapped handler that checks for authentication. + """ + + @wraps(handler) + async def wrapper(req: Request, *args: Any, **kwargs: Any) -> Response: + user = get_current_user(req) + if user is None: + raise AuthenticationError("Authentication required") + return await handler(req, *args, **kwargs) + + return wrapper + + +def require_role(*roles: UserRole) -> Callable[[Callable], Callable]: + """Decorator factory that requires specific role(s). + + Returns 401 if not authenticated, 403 if role not in allowed roles. + + Args: + roles: One or more UserRole values that are allowed. + + Returns: + Decorator that checks for required role. + + Example: + @require_role(UserRole.ADMIN) + async def admin_only_route(req): + ... + """ + + def decorator(handler: Callable) -> Callable: + @wraps(handler) + async def wrapper(req: Request, *args: Any, **kwargs: Any) -> Response: + user = get_current_user(req) + if user is None: + raise AuthenticationError("Authentication required") + if user.role not in roles: + raise AuthorizationError(f"Required role: {', '.join(r.value for r in roles)}") + return await handler(req, *args, **kwargs) + + return wrapper + + return decorator diff --git a/src/animaltrack/web/exceptions.py b/src/animaltrack/web/exceptions.py new file mode 100644 index 0000000..4a1b941 --- /dev/null +++ b/src/animaltrack/web/exceptions.py @@ -0,0 +1,42 @@ +# ABOUTME: HTTP exception types for the web layer. +# ABOUTME: Maps domain errors to appropriate HTTP status codes. + + +class AuthenticationError(Exception): + """Raised when authentication fails. + + HTTP Status: 401 Unauthorized + Causes: Missing auth header, unknown user, inactive user. + """ + + pass + + +class AuthorizationError(Exception): + """Raised when authorization fails. + + HTTP Status: 403 Forbidden + Causes: User lacks required role, cannot edit/delete event. + """ + + pass + + +class CSRFValidationError(Exception): + """Raised when CSRF validation fails. + + HTTP Status: 403 Forbidden + Causes: Missing/mismatched token, invalid Origin/Referer. + """ + + pass + + +class UntrustedProxyError(Exception): + """Raised when request comes from untrusted proxy. + + HTTP Status: 403 Forbidden + Causes: Request IP not in TRUSTED_PROXY_IPS. + """ + + pass diff --git a/src/animaltrack/web/middleware.py b/src/animaltrack/web/middleware.py new file mode 100644 index 0000000..0144866 --- /dev/null +++ b/src/animaltrack/web/middleware.py @@ -0,0 +1,277 @@ +# ABOUTME: Middleware functions for authentication, CSRF, and request logging. +# ABOUTME: Implements Beforeware pattern for FastHTML request processing. + +import json +import time +from urllib.parse import urlparse + +from starlette.requests import Request +from starlette.responses import PlainTextResponse, Response + +from animaltrack.config import Settings +from animaltrack.id_gen import generate_id +from animaltrack.repositories.users import UserRepository + +# Safe HTTP methods that don't require CSRF protection +SAFE_METHODS = frozenset({"GET", "HEAD", "OPTIONS"}) + + +def is_safe_method(method: str) -> bool: + """Check if HTTP method is safe (doesn't require CSRF protection). + + Args: + method: The HTTP method (e.g., "GET", "POST"). + + Returns: + True if method is safe, False otherwise. + """ + return method.upper() in SAFE_METHODS + + +def validate_csrf_token(cookie_token: str | None, header_token: str | None) -> bool: + """Validate that CSRF tokens match. + + Args: + cookie_token: Token from cookie. + header_token: Token from header. + + Returns: + True if tokens match and are non-empty, False otherwise. + """ + if cookie_token is None or header_token is None: + return False + if not cookie_token or not header_token: + return False + return cookie_token == header_token + + +def validate_origin(origin: str | None, expected_host: str) -> bool: + """Validate Origin header matches expected host. + + Args: + origin: The Origin header value (e.g., "https://example.com"). + expected_host: The expected host (e.g., "example.com" or "example.com:3366"). + + Returns: + True if origin matches expected host, False otherwise. + """ + if origin is None or not origin: + return False + try: + parsed = urlparse(origin) + origin_host = parsed.netloc + return origin_host == expected_host + except Exception: + return False + + +def validate_referer(referer: str | None, expected_host: str) -> bool: + """Validate Referer header matches expected host. + + Args: + referer: The Referer header value (e.g., "https://example.com/page"). + expected_host: The expected host (e.g., "example.com" or "example.com:3366"). + + Returns: + True if referer host matches expected host, False otherwise. + """ + if referer is None or not referer: + return False + try: + parsed = urlparse(referer) + # parsed.netloc includes port if present + referer_host = parsed.netloc + if not referer_host: + return False + return referer_host == expected_host + except Exception: + return False + + +def get_client_ip(req: Request) -> str: + """Extract client IP from request, respecting X-Forwarded-For. + + Args: + req: The Starlette request object. + + Returns: + The client IP address. + """ + # Check X-Forwarded-For header first (set by reverse proxy) + forwarded_for = req.headers.get("x-forwarded-for") + if forwarded_for: + # X-Forwarded-For can contain multiple IPs: "client, proxy1, proxy2" + # The first one is the original client + return forwarded_for.split(",")[0].strip() + + # Fall back to direct connection IP + if req.client: + return req.client.host + return "unknown" + + +def is_trusted_proxy(req: Request, settings: Settings) -> bool: + """Check if request comes from a trusted proxy IP. + + Args: + req: The Starlette request object. + settings: Application settings with trusted_proxy_ips. + + Returns: + True if request is from trusted proxy, False otherwise. + """ + trusted_ips = settings.trusted_proxy_ips + if not trusted_ips: + # If no trusted IPs configured, reject all (fail-secure) + return False + + # Get the immediate connection IP (not X-Forwarded-For) + if req.client: + client_ip = req.client.host + else: + return False + + return client_ip in trusted_ips + + +def get_expected_host(req: Request, settings: Settings) -> str: + """Get the expected host for Origin/Referer validation. + + Uses X-Forwarded-Host if present, otherwise request host. + + Args: + req: The Starlette request object. + settings: Application settings. + + Returns: + The expected host string. + """ + forwarded_host = req.headers.get("x-forwarded-host") + if forwarded_host: + return forwarded_host + return req.headers.get("host", "") + + +def request_id_before(req: Request) -> None: + """Generate unique request_id and attach to request scope. + + Args: + req: The Starlette request object. + """ + req.scope["request_id"] = generate_id() + req.scope["request_start_time"] = time.time() + + +def auth_before(req: Request, settings: Settings, db) -> Response | None: + """Extract and validate authentication from proxy headers. + + Validates: + - Request comes from trusted proxy IP + - Auth header is present + - User exists and is active in database + + Args: + req: The Starlette request object. + settings: Application settings. + db: Database connection. + + Returns: + None to continue processing, or Response to short-circuit. + """ + # Check trusted proxy + if not is_trusted_proxy(req, settings): + return PlainTextResponse("Forbidden: Request not from trusted proxy", status_code=403) + + # Extract username from auth header + username = req.headers.get(settings.auth_header_name.lower()) + if not username: + return PlainTextResponse("Unauthorized: Missing auth header", status_code=401) + + # Look up user in database + user_repo = UserRepository(db) + user = user_repo.get(username) + if user is None: + return PlainTextResponse("Unauthorized: Unknown user", status_code=401) + if not user.active: + return PlainTextResponse("Unauthorized: Inactive user", status_code=401) + + # Store user in scope for access by handlers + req.scope["auth"] = user + return None + + +def csrf_before(req: Request, settings: Settings) -> Response | None: + """Validate CSRF token on unsafe HTTP methods. + + Validates: + 1. CSRF cookie present and matches header + 2. Origin or Referer matches expected host + + Args: + req: The Starlette request object. + settings: Application settings. + + Returns: + None to continue processing, or Response to short-circuit. + """ + # Skip CSRF check for safe methods + if is_safe_method(req.method): + return None + + # Get CSRF tokens + cookie_token = req.cookies.get(settings.csrf_cookie_name) + header_token = req.headers.get("x-csrf-token") + + # Validate tokens match + if not validate_csrf_token(cookie_token, header_token): + return PlainTextResponse("Forbidden: Invalid CSRF token", status_code=403) + + # Validate Origin or Referer + expected_host = get_expected_host(req, settings) + origin = req.headers.get("origin") + referer = req.headers.get("referer") + + if not validate_origin(origin, expected_host): + # Fall back to Referer check + if not validate_referer(referer, expected_host): + return PlainTextResponse("Forbidden: Invalid Origin/Referer", status_code=403) + + return None + + +def logging_after( + req: Request, resp: Response, settings: Settings, event_id: str | None = None +) -> None: + """Log request in NDJSON format after response. + + Format: {ts, level, route, actor, ip, method, status, duration_ms, request_id, event_id} + + Args: + req: The Starlette request object. + resp: The response object. + settings: Application settings. + event_id: Optional event ID if an event was created. + """ + start_time = req.scope.get("request_start_time", time.time()) + duration_ms = int((time.time() - start_time) * 1000) + + user = req.scope.get("auth") + actor = user.username if user else None + + log_entry = { + "ts": int(time.time() * 1000), + "level": "info", + "route": req.url.path, + "actor": actor, + "ip": get_client_ip(req), + "method": req.method, + "status": resp.status_code, + "duration_ms": duration_ms, + "request_id": req.scope.get("request_id"), + } + + if event_id: + log_entry["event_id"] = event_id + + # Output as NDJSON (one JSON object per line) + print(json.dumps(log_entry), flush=True) diff --git a/tests/test_web_app.py b/tests/test_web_app.py new file mode 100644 index 0000000..5d37d07 --- /dev/null +++ b/tests/test_web_app.py @@ -0,0 +1,141 @@ +# ABOUTME: Integration tests for FastHTML application creation. +# ABOUTME: Tests app factory, middleware wiring, and route configuration. + +import os + +import pytest + + +def make_test_settings( + csrf_secret: str = "test-secret", + trusted_proxy_ips: str = "127.0.0.1", + auth_header_name: str = "X-Oidc-Username", +): + """Create Settings for testing by setting env vars temporarily.""" + from animaltrack.config import Settings + + old_env = os.environ.copy() + try: + os.environ["CSRF_SECRET"] = csrf_secret + os.environ["TRUSTED_PROXY_IPS"] = trusted_proxy_ips + os.environ["AUTH_HEADER_NAME"] = auth_header_name + return Settings() + finally: + os.environ.clear() + os.environ.update(old_env) + + +class TestCreateApp: + """Tests for the create_app factory function.""" + + def test_creates_app_with_provided_settings(self, seeded_db): + """create_app(settings=...) uses provided settings.""" + from animaltrack.web.app import create_app + + settings = make_test_settings() + + app, rt = create_app(settings=settings, db=seeded_db) + + assert app is not None + assert rt is not None + assert app.state.settings is settings + assert app.state.db is seeded_db + + def test_app_has_db_on_state(self, seeded_db): + """Database accessible via app.state.db.""" + from animaltrack.web.app import create_app + + settings = make_test_settings() + app, rt = create_app(settings=settings, db=seeded_db) + + assert hasattr(app.state, "db") + assert app.state.db is seeded_db + + def test_app_has_settings_on_state(self, seeded_db): + """Settings accessible via app.state.settings.""" + from animaltrack.web.app import create_app + + settings = make_test_settings() + app, rt = create_app(settings=settings, db=seeded_db) + + assert hasattr(app.state, "settings") + assert app.state.settings.csrf_secret == "test-secret" + + +class TestAppWithTestClient: + """Integration tests using Starlette TestClient.""" + + @pytest.fixture + def client(self, seeded_db): + """Create a test client for the app.""" + from starlette.testclient import TestClient + + from animaltrack.web.app import create_app + + # TestClient uses 'testclient' as the host, so we need to trust it + settings = make_test_settings(trusted_proxy_ips="testclient") + app, rt = create_app(settings=settings, db=seeded_db) + return TestClient(app, raise_server_exceptions=False) + + def test_healthz_returns_200(self, client): + """GET /healthz returns 200 OK.""" + resp = client.get("/healthz") + assert resp.status_code == 200 + + def test_unauthenticated_route_returns_401(self, client): + """Protected route without auth returns 401.""" + # Any route that requires auth + resp = client.get("/") + assert resp.status_code == 401 + + def test_authenticated_request_succeeds(self, client): + """Request with valid auth header succeeds.""" + resp = client.get( + "/", + headers={"X-Oidc-Username": "ppetru"}, + ) + # Should get a valid response (200 or 404 if route not implemented yet) + # The key is it shouldn't be 401 + assert resp.status_code != 401 + + def test_untrusted_proxy_returns_403(self, seeded_db): + """Request from untrusted IP returns 403.""" + from starlette.testclient import TestClient + + from animaltrack.web.app import create_app + + # Configure with a different trusted IP (not 'testclient') + settings = make_test_settings(trusted_proxy_ips="10.0.0.1") + app, rt = create_app(settings=settings, db=seeded_db) + client = TestClient(app, raise_server_exceptions=False) + + resp = client.get("/", headers={"X-Oidc-Username": "ppetru"}) + # Should fail because TestClient uses host 'testclient', not in trusted list + assert resp.status_code == 403 + assert b"not from trusted proxy" in resp.content + + def test_csrf_required_on_post(self, client): + """POST without CSRF token returns 403.""" + # POST to / route - should fail CSRF check before reaching handler + resp = client.post( + "/", + headers={"X-Oidc-Username": "ppetru"}, + ) + assert resp.status_code == 403 + assert b"CSRF" in resp.content + + def test_csrf_with_valid_tokens_succeeds(self, client): + """POST with matching CSRF tokens proceeds.""" + csrf_token = "test-csrf-token-123" + resp = client.post( + "/", + headers={ + "X-Oidc-Username": "ppetru", + "X-CSRF-Token": csrf_token, + "Origin": "http://testserver", + }, + cookies={"csrf_token": csrf_token}, + ) + # Should get through CSRF check (200 or 405 if method not allowed) + # The key is it shouldn't be 403 CSRF error + assert resp.status_code != 403 or b"CSRF" not in resp.content diff --git a/tests/test_web_auth.py b/tests/test_web_auth.py new file mode 100644 index 0000000..7fdd485 --- /dev/null +++ b/tests/test_web_auth.py @@ -0,0 +1,393 @@ +# ABOUTME: Tests for web authentication and authorization helpers. +# ABOUTME: Covers role checks, permission logic, and user extraction. + +import os + +import pytest + +from animaltrack.models.reference import User, UserRole + + +def make_test_settings( + csrf_secret: str = "test-secret", + trusted_proxy_ips: str = "", + auth_header_name: str = "X-Oidc-Username", +): + """Create Settings for testing by setting env vars temporarily.""" + from animaltrack.config import Settings + + # Settings loads from env, so we set env vars temporarily + old_env = os.environ.copy() + try: + os.environ["CSRF_SECRET"] = csrf_secret + os.environ["TRUSTED_PROXY_IPS"] = trusted_proxy_ips + os.environ["AUTH_HEADER_NAME"] = auth_header_name + return Settings() + finally: + os.environ.clear() + os.environ.update(old_env) + + +# Fixtures for test users +@pytest.fixture +def admin_user(): + """Create an admin user for testing.""" + return User( + username="admin", + role=UserRole.ADMIN, + active=True, + created_at_utc=1000000, + updated_at_utc=1000000, + ) + + +@pytest.fixture +def recorder_user(): + """Create a recorder user for testing.""" + return User( + username="recorder", + role=UserRole.RECORDER, + active=True, + created_at_utc=1000000, + updated_at_utc=1000000, + ) + + +@pytest.fixture +def inactive_user(): + """Create an inactive user for testing.""" + return User( + username="inactive", + role=UserRole.RECORDER, + active=False, + created_at_utc=1000000, + updated_at_utc=1000000, + ) + + +class TestIsAdmin: + """Tests for is_admin helper function.""" + + def test_returns_true_for_admin(self, admin_user): + """is_admin returns True for admin role.""" + from animaltrack.web.auth import is_admin + + assert is_admin(admin_user) is True + + def test_returns_false_for_recorder(self, recorder_user): + """is_admin returns False for recorder role.""" + from animaltrack.web.auth import is_admin + + assert is_admin(recorder_user) is False + + +class TestIsRecorder: + """Tests for is_recorder helper function.""" + + def test_returns_true_for_recorder(self, recorder_user): + """is_recorder returns True for recorder role.""" + from animaltrack.web.auth import is_recorder + + assert is_recorder(recorder_user) is True + + def test_returns_false_for_admin(self, admin_user): + """is_recorder returns False for admin role.""" + from animaltrack.web.auth import is_recorder + + assert is_recorder(admin_user) is False + + +class TestCanEditEvent: + """Tests for can_edit_event permission check.""" + + def test_admin_can_edit_any_event(self, admin_user): + """Admin can edit events created by any user.""" + from animaltrack.web.auth import can_edit_event + + assert can_edit_event(admin_user, "other_user") is True + assert can_edit_event(admin_user, "admin") is True + + def test_recorder_can_edit_own_events(self, recorder_user): + """Recorder can edit their own events.""" + from animaltrack.web.auth import can_edit_event + + assert can_edit_event(recorder_user, "recorder") is True + + def test_recorder_cannot_edit_other_events(self, recorder_user): + """Recorder cannot edit events created by others.""" + from animaltrack.web.auth import can_edit_event + + assert can_edit_event(recorder_user, "other_user") is False + + +class TestCanDeleteEvent: + """Tests for can_delete_event permission check.""" + + def test_admin_can_delete_any_event_without_dependents(self, admin_user): + """Admin can delete any event without dependents.""" + from animaltrack.web.auth import can_delete_event + + assert can_delete_event(admin_user, "other_user", has_dependents=False) is True + assert can_delete_event(admin_user, "admin", has_dependents=False) is True + + def test_admin_can_delete_any_event_with_dependents(self, admin_user): + """Admin can cascade delete events with dependents.""" + from animaltrack.web.auth import can_delete_event + + assert can_delete_event(admin_user, "other_user", has_dependents=True) is True + assert can_delete_event(admin_user, "admin", has_dependents=True) is True + + def test_recorder_can_delete_own_event_without_dependents(self, recorder_user): + """Recorder can delete their own events without dependents.""" + from animaltrack.web.auth import can_delete_event + + assert can_delete_event(recorder_user, "recorder", has_dependents=False) is True + + def test_recorder_cannot_delete_own_event_with_dependents(self, recorder_user): + """Recorder cannot delete their own events if they have dependents.""" + from animaltrack.web.auth import can_delete_event + + assert can_delete_event(recorder_user, "recorder", has_dependents=True) is False + + def test_recorder_cannot_delete_other_events(self, recorder_user): + """Recorder cannot delete events created by others.""" + from animaltrack.web.auth import can_delete_event + + assert can_delete_event(recorder_user, "other_user", has_dependents=False) is False + assert can_delete_event(recorder_user, "other_user", has_dependents=True) is False + + +# HTTP-level auth middleware tests + + +class TestGetClientIp: + """Tests for client IP extraction.""" + + def test_extracts_from_x_forwarded_for(self): + """Extracts client IP from X-Forwarded-For header.""" + from unittest.mock import MagicMock + + from animaltrack.web.middleware import get_client_ip + + req = MagicMock() + req.headers = {"x-forwarded-for": "203.0.113.50"} + req.client = MagicMock(host="10.0.0.1") + + assert get_client_ip(req) == "203.0.113.50" + + def test_extracts_first_ip_from_chain(self): + """Extracts first IP when multiple proxies in chain.""" + from unittest.mock import MagicMock + + from animaltrack.web.middleware import get_client_ip + + req = MagicMock() + req.headers = {"x-forwarded-for": "203.0.113.50, 10.0.0.2, 10.0.0.1"} + req.client = MagicMock(host="10.0.0.1") + + assert get_client_ip(req) == "203.0.113.50" + + def test_falls_back_to_client_host(self): + """Falls back to direct client IP when no X-Forwarded-For.""" + from unittest.mock import MagicMock + + from animaltrack.web.middleware import get_client_ip + + req = MagicMock() + req.headers = {} + req.client = MagicMock(host="192.168.1.100") + + assert get_client_ip(req) == "192.168.1.100" + + def test_returns_unknown_when_no_client(self): + """Returns 'unknown' when no client info available.""" + from unittest.mock import MagicMock + + from animaltrack.web.middleware import get_client_ip + + req = MagicMock() + req.headers = {} + req.client = None + + assert get_client_ip(req) == "unknown" + + +class TestIsTrustedProxy: + """Tests for trusted proxy validation.""" + + def test_accepts_trusted_ip(self): + """Request from trusted IP proceeds.""" + from unittest.mock import MagicMock + + from animaltrack.web.middleware import is_trusted_proxy + + req = MagicMock() + req.client = MagicMock(host="127.0.0.1") + + settings = make_test_settings(trusted_proxy_ips="127.0.0.1,10.0.0.1") + + assert is_trusted_proxy(req, settings) is True + + def test_rejects_untrusted_ip(self): + """Returns False for untrusted IP.""" + from unittest.mock import MagicMock + + from animaltrack.web.middleware import is_trusted_proxy + + req = MagicMock() + req.client = MagicMock(host="192.168.1.1") + + settings = make_test_settings(trusted_proxy_ips="127.0.0.1,10.0.0.1") + + assert is_trusted_proxy(req, settings) is False + + def test_empty_trusted_list_rejects_all(self): + """When no trusted IPs configured, all requests rejected.""" + from unittest.mock import MagicMock + + from animaltrack.web.middleware import is_trusted_proxy + + req = MagicMock() + req.client = MagicMock(host="127.0.0.1") + + settings = make_test_settings(trusted_proxy_ips="") + + assert is_trusted_proxy(req, settings) is False + + def test_rejects_when_no_client(self): + """Returns False when no client info available.""" + from unittest.mock import MagicMock + + from animaltrack.web.middleware import is_trusted_proxy + + req = MagicMock() + req.client = None + + settings = make_test_settings(trusted_proxy_ips="127.0.0.1") + + assert is_trusted_proxy(req, settings) is False + + +class TestAuthBefore: + """Tests for auth_before middleware function.""" + + def test_rejects_untrusted_proxy(self, seeded_db): + """Returns 403 when request from untrusted IP.""" + from unittest.mock import MagicMock + + from animaltrack.web.middleware import auth_before + + req = MagicMock() + req.client = MagicMock(host="192.168.1.1") + req.headers = {"x-oidc-username": "ppetru"} + + settings = make_test_settings(trusted_proxy_ips="127.0.0.1") + + resp = auth_before(req, settings, seeded_db) + assert resp is not None + assert resp.status_code == 403 + assert b"not from trusted proxy" in resp.body + + def test_rejects_missing_auth_header(self, seeded_db): + """Returns 401 when auth header is missing.""" + from unittest.mock import MagicMock + + from animaltrack.web.middleware import auth_before + + req = MagicMock() + req.client = MagicMock(host="127.0.0.1") + req.headers = {} # No auth header + + settings = make_test_settings(trusted_proxy_ips="127.0.0.1") + + resp = auth_before(req, settings, seeded_db) + assert resp is not None + assert resp.status_code == 401 + assert b"Missing auth header" in resp.body + + def test_rejects_unknown_user(self, seeded_db): + """Returns 401 when user not in database.""" + from unittest.mock import MagicMock + + from animaltrack.web.middleware import auth_before + + req = MagicMock() + req.client = MagicMock(host="127.0.0.1") + req.headers = {"x-oidc-username": "nonexistent_user"} + + settings = make_test_settings(trusted_proxy_ips="127.0.0.1") + + resp = auth_before(req, settings, seeded_db) + assert resp is not None + assert resp.status_code == 401 + assert b"Unknown user" in resp.body + + def test_rejects_inactive_user(self, seeded_db): + """Returns 401 when user is inactive.""" + from unittest.mock import MagicMock + + from animaltrack.models.reference import User, UserRole + from animaltrack.repositories.users import UserRepository + from animaltrack.web.middleware import auth_before + + # Create an inactive user + user_repo = UserRepository(seeded_db) + inactive = User( + username="inactive_test", + role=UserRole.RECORDER, + active=False, + created_at_utc=1000000, + updated_at_utc=1000000, + ) + user_repo.upsert(inactive) + + req = MagicMock() + req.client = MagicMock(host="127.0.0.1") + req.headers = {"x-oidc-username": "inactive_test"} + + settings = make_test_settings(trusted_proxy_ips="127.0.0.1") + + resp = auth_before(req, settings, seeded_db) + assert resp is not None + assert resp.status_code == 401 + assert b"Inactive user" in resp.body + + def test_sets_user_in_scope(self, seeded_db): + """User object is set in req.scope['auth'] on success.""" + from unittest.mock import MagicMock + + from animaltrack.web.middleware import auth_before + + req = MagicMock() + req.client = MagicMock(host="127.0.0.1") + req.headers = {"x-oidc-username": "ppetru"} # From seeds + req.scope = {} + + settings = make_test_settings(trusted_proxy_ips="127.0.0.1") + + resp = auth_before(req, settings, seeded_db) + assert resp is None # Continue processing + assert "auth" in req.scope + assert req.scope["auth"].username == "ppetru" + assert req.scope["auth"].role == UserRole.ADMIN + + def test_respects_custom_auth_header_name(self, seeded_db): + """Uses AUTH_HEADER_NAME setting for header lookup.""" + from unittest.mock import MagicMock + + from animaltrack.web.middleware import auth_before + + req = MagicMock() + req.client = MagicMock(host="127.0.0.1") + req.headers = {"x-custom-user": "ppetru"} + req.scope = {} + + settings = make_test_settings( + trusted_proxy_ips="127.0.0.1", + auth_header_name="X-Custom-User", + ) + + resp = auth_before(req, settings, seeded_db) + assert resp is None # Continue processing + assert "auth" in req.scope + assert req.scope["auth"].username == "ppetru" diff --git a/tests/test_web_csrf.py b/tests/test_web_csrf.py new file mode 100644 index 0000000..0fdeef7 --- /dev/null +++ b/tests/test_web_csrf.py @@ -0,0 +1,172 @@ +# ABOUTME: Tests for CSRF validation logic. +# ABOUTME: Covers token matching, origin/referer validation, and safe methods. + + +class TestValidateCSRFToken: + """Tests for CSRF token validation logic.""" + + def test_accepts_matching_tokens(self): + """Valid when cookie and header tokens match.""" + from animaltrack.web.middleware import validate_csrf_token + + assert validate_csrf_token("abc123", "abc123") is True + + def test_rejects_mismatched_tokens(self): + """Invalid when cookie and header tokens differ.""" + from animaltrack.web.middleware import validate_csrf_token + + assert validate_csrf_token("abc123", "xyz789") is False + + def test_rejects_empty_cookie_token(self): + """Invalid when cookie token is empty.""" + from animaltrack.web.middleware import validate_csrf_token + + assert validate_csrf_token("", "abc123") is False + + def test_rejects_empty_header_token(self): + """Invalid when header token is empty.""" + from animaltrack.web.middleware import validate_csrf_token + + assert validate_csrf_token("abc123", "") is False + + def test_rejects_none_tokens(self): + """Invalid when either token is None.""" + from animaltrack.web.middleware import validate_csrf_token + + assert validate_csrf_token(None, "abc123") is False + assert validate_csrf_token("abc123", None) is False + assert validate_csrf_token(None, None) is False + + +class TestIsSafeMethod: + """Tests for HTTP safe method detection.""" + + def test_get_is_safe(self): + """GET is a safe method.""" + from animaltrack.web.middleware import is_safe_method + + assert is_safe_method("GET") is True + + def test_head_is_safe(self): + """HEAD is a safe method.""" + from animaltrack.web.middleware import is_safe_method + + assert is_safe_method("HEAD") is True + + def test_options_is_safe(self): + """OPTIONS is a safe method.""" + from animaltrack.web.middleware import is_safe_method + + assert is_safe_method("OPTIONS") is True + + def test_post_is_not_safe(self): + """POST is not a safe method.""" + from animaltrack.web.middleware import is_safe_method + + assert is_safe_method("POST") is False + + def test_put_is_not_safe(self): + """PUT is not a safe method.""" + from animaltrack.web.middleware import is_safe_method + + assert is_safe_method("PUT") is False + + def test_delete_is_not_safe(self): + """DELETE is not a safe method.""" + from animaltrack.web.middleware import is_safe_method + + assert is_safe_method("DELETE") is False + + def test_patch_is_not_safe(self): + """PATCH is not a safe method.""" + from animaltrack.web.middleware import is_safe_method + + assert is_safe_method("PATCH") is False + + def test_case_insensitive(self): + """Method check is case-insensitive.""" + from animaltrack.web.middleware import is_safe_method + + assert is_safe_method("get") is True + assert is_safe_method("Get") is True + assert is_safe_method("post") is False + + +class TestValidateOrigin: + """Tests for Origin/Referer header validation.""" + + def test_accepts_matching_origin(self): + """Valid when Origin matches expected host.""" + from animaltrack.web.middleware import validate_origin + + assert validate_origin("https://example.com", "example.com") is True + + def test_accepts_matching_origin_with_port(self): + """Valid when Origin matches expected host with port.""" + from animaltrack.web.middleware import validate_origin + + assert validate_origin("https://example.com:3366", "example.com:3366") is True + + def test_rejects_different_origin(self): + """Invalid when Origin doesn't match expected host.""" + from animaltrack.web.middleware import validate_origin + + assert validate_origin("https://evil.com", "example.com") is False + + def test_rejects_subdomain_mismatch(self): + """Invalid when Origin is a subdomain of expected host.""" + from animaltrack.web.middleware import validate_origin + + assert validate_origin("https://sub.example.com", "example.com") is False + + def test_accepts_none_origin(self): + """None origin returns False (will check Referer instead).""" + from animaltrack.web.middleware import validate_origin + + assert validate_origin(None, "example.com") is False + + def test_accepts_empty_origin(self): + """Empty origin returns False.""" + from animaltrack.web.middleware import validate_origin + + assert validate_origin("", "example.com") is False + + +class TestValidateReferer: + """Tests for Referer header validation.""" + + def test_accepts_matching_referer(self): + """Valid when Referer host matches expected host.""" + from animaltrack.web.middleware import validate_referer + + assert validate_referer("https://example.com/page", "example.com") is True + + def test_accepts_matching_referer_with_port(self): + """Valid when Referer matches expected host with port.""" + from animaltrack.web.middleware import validate_referer + + assert validate_referer("https://example.com:3366/page", "example.com:3366") is True + + def test_rejects_different_referer(self): + """Invalid when Referer host doesn't match expected host.""" + from animaltrack.web.middleware import validate_referer + + assert validate_referer("https://evil.com/page", "example.com") is False + + def test_rejects_none_referer(self): + """Invalid when Referer is None.""" + from animaltrack.web.middleware import validate_referer + + assert validate_referer(None, "example.com") is False + + def test_rejects_empty_referer(self): + """Invalid when Referer is empty.""" + from animaltrack.web.middleware import validate_referer + + assert validate_referer("", "example.com") is False + + def test_rejects_malformed_referer(self): + """Invalid when Referer is malformed.""" + from animaltrack.web.middleware import validate_referer + + assert validate_referer("not-a-url", "example.com") is False diff --git a/tests/test_web_middleware.py b/tests/test_web_middleware.py new file mode 100644 index 0000000..ae7216a --- /dev/null +++ b/tests/test_web_middleware.py @@ -0,0 +1,302 @@ +# ABOUTME: Tests for request ID generation and logging middleware. +# ABOUTME: Covers ULID generation, scope storage, and NDJSON log format. + +import json +import os +import time +from unittest.mock import MagicMock + + +def make_test_settings( + csrf_secret: str = "test-secret", + trusted_proxy_ips: str = "", + auth_header_name: str = "X-Oidc-Username", +): + """Create Settings for testing by setting env vars temporarily.""" + from animaltrack.config import Settings + + old_env = os.environ.copy() + try: + os.environ["CSRF_SECRET"] = csrf_secret + os.environ["TRUSTED_PROXY_IPS"] = trusted_proxy_ips + os.environ["AUTH_HEADER_NAME"] = auth_header_name + return Settings() + finally: + os.environ.clear() + os.environ.update(old_env) + + +class TestRequestIdBefore: + """Tests for request ID generation middleware.""" + + def test_generates_request_id(self): + """Generates a request_id in the scope.""" + from animaltrack.web.middleware import request_id_before + + req = MagicMock() + req.scope = {} + + request_id_before(req) + + assert "request_id" in req.scope + assert req.scope["request_id"] is not None + + def test_request_id_is_ulid(self): + """Request ID is a valid 26-char ULID.""" + from animaltrack.web.middleware import request_id_before + + req = MagicMock() + req.scope = {} + + request_id_before(req) + + request_id = req.scope["request_id"] + assert len(request_id) == 26 + # ULIDs are base32, should only contain valid chars + valid_chars = "0123456789ABCDEFGHJKMNPQRSTVWXYZ" + assert all(c in valid_chars for c in request_id.upper()) + + def test_generates_unique_ids(self): + """Each request gets a unique request_id.""" + from animaltrack.web.middleware import request_id_before + + ids = set() + for _ in range(100): + req = MagicMock() + req.scope = {} + request_id_before(req) + ids.add(req.scope["request_id"]) + + assert len(ids) == 100 # All unique + + def test_sets_start_time(self): + """Sets request_start_time in scope for duration calculation.""" + from animaltrack.web.middleware import request_id_before + + req = MagicMock() + req.scope = {} + + before = time.time() + request_id_before(req) + after = time.time() + + assert "request_start_time" in req.scope + assert before <= req.scope["request_start_time"] <= after + + +class TestLoggingAfter: + """Tests for request logging middleware.""" + + def test_logs_in_ndjson_format(self, capsys): + """Log line is valid JSON.""" + from animaltrack.web.middleware import logging_after + + req = MagicMock() + req.scope = {"request_id": "TEST123", "request_start_time": time.time()} + req.url = MagicMock(path="/test") + req.method = "GET" + req.headers = {} + req.client = MagicMock(host="127.0.0.1") + + resp = MagicMock() + resp.status_code = 200 + + settings = make_test_settings() + + logging_after(req, resp, settings) + + captured = capsys.readouterr() + log_line = captured.out.strip() + + # Should be valid JSON + parsed = json.loads(log_line) + assert isinstance(parsed, dict) + + def test_log_includes_required_fields(self, capsys): + """Log includes: ts, level, route, actor, ip, method, status, duration_ms.""" + from animaltrack.models.reference import User, UserRole + from animaltrack.web.middleware import logging_after + + user = User( + username="testuser", + role=UserRole.RECORDER, + active=True, + created_at_utc=1000000, + updated_at_utc=1000000, + ) + + req = MagicMock() + req.scope = { + "request_id": "TEST123", + "request_start_time": time.time() - 0.1, # 100ms ago + "auth": user, + } + req.url = MagicMock(path="/test/route") + req.method = "POST" + req.headers = {} + req.client = MagicMock(host="127.0.0.1") + + resp = MagicMock() + resp.status_code = 201 + + settings = make_test_settings() + + logging_after(req, resp, settings) + + captured = capsys.readouterr() + parsed = json.loads(captured.out.strip()) + + assert "ts" in parsed + assert parsed["level"] == "info" + assert parsed["route"] == "/test/route" + assert parsed["actor"] == "testuser" + assert parsed["method"] == "POST" + assert parsed["status"] == 201 + assert "duration_ms" in parsed + assert parsed["duration_ms"] >= 0 + + def test_log_includes_request_id(self, capsys): + """Log includes request_id field.""" + from animaltrack.web.middleware import logging_after + + req = MagicMock() + req.scope = {"request_id": "01ARZ3NDEKTSV4RRFFQ69G5FAV", "request_start_time": time.time()} + req.url = MagicMock(path="/test") + req.method = "GET" + req.headers = {} + req.client = MagicMock(host="127.0.0.1") + + resp = MagicMock() + resp.status_code = 200 + + settings = make_test_settings() + + logging_after(req, resp, settings) + + captured = capsys.readouterr() + parsed = json.loads(captured.out.strip()) + + assert parsed["request_id"] == "01ARZ3NDEKTSV4RRFFQ69G5FAV" + + def test_log_includes_event_id_when_provided(self, capsys): + """Log includes event_id field when provided.""" + from animaltrack.web.middleware import logging_after + + req = MagicMock() + req.scope = {"request_id": "TEST123", "request_start_time": time.time()} + req.url = MagicMock(path="/actions/product-collected") + req.method = "POST" + req.headers = {} + req.client = MagicMock(host="127.0.0.1") + + resp = MagicMock() + resp.status_code = 200 + + settings = make_test_settings() + + logging_after(req, resp, settings, event_id="01ARZ3NDEKTSV4RRFFQ69G5EVT") + + captured = capsys.readouterr() + parsed = json.loads(captured.out.strip()) + + assert parsed["event_id"] == "01ARZ3NDEKTSV4RRFFQ69G5EVT" + + def test_duration_ms_accurate(self, capsys): + """duration_ms reflects actual request time.""" + from animaltrack.web.middleware import logging_after + + req = MagicMock() + # Started 150ms ago + req.scope = {"request_id": "TEST123", "request_start_time": time.time() - 0.150} + req.url = MagicMock(path="/test") + req.method = "GET" + req.headers = {} + req.client = MagicMock(host="127.0.0.1") + + resp = MagicMock() + resp.status_code = 200 + + settings = make_test_settings() + + logging_after(req, resp, settings) + + captured = capsys.readouterr() + parsed = json.loads(captured.out.strip()) + + # Should be approximately 150ms (with some tolerance) + assert 100 <= parsed["duration_ms"] <= 250 + + def test_logs_x_forwarded_for_ip(self, capsys): + """Log ip field uses X-Forwarded-For when present.""" + from animaltrack.web.middleware import logging_after + + req = MagicMock() + req.scope = {"request_id": "TEST123", "request_start_time": time.time()} + req.url = MagicMock(path="/test") + req.method = "GET" + req.headers = {"x-forwarded-for": "203.0.113.50, 10.0.0.1"} + req.client = MagicMock(host="10.0.0.1") # Proxy IP + + resp = MagicMock() + resp.status_code = 200 + + settings = make_test_settings() + + logging_after(req, resp, settings) + + captured = capsys.readouterr() + parsed = json.loads(captured.out.strip()) + + # Should use the client IP from X-Forwarded-For, not the proxy IP + assert parsed["ip"] == "203.0.113.50" + + def test_actor_is_none_when_unauthenticated(self, capsys): + """Log actor is None when no user in scope.""" + from animaltrack.web.middleware import logging_after + + req = MagicMock() + req.scope = {"request_id": "TEST123", "request_start_time": time.time()} + req.url = MagicMock(path="/healthz") + req.method = "GET" + req.headers = {} + req.client = MagicMock(host="127.0.0.1") + + resp = MagicMock() + resp.status_code = 200 + + settings = make_test_settings() + + logging_after(req, resp, settings) + + captured = capsys.readouterr() + parsed = json.loads(captured.out.strip()) + + assert parsed["actor"] is None + + def test_timestamp_is_milliseconds(self, capsys): + """Log ts field is milliseconds since epoch.""" + from animaltrack.web.middleware import logging_after + + req = MagicMock() + req.scope = {"request_id": "TEST123", "request_start_time": time.time()} + req.url = MagicMock(path="/test") + req.method = "GET" + req.headers = {} + req.client = MagicMock(host="127.0.0.1") + + resp = MagicMock() + resp.status_code = 200 + + settings = make_test_settings() + + before_ms = int(time.time() * 1000) + logging_after(req, resp, settings) + after_ms = int(time.time() * 1000) + + captured = capsys.readouterr() + parsed = json.loads(captured.out.strip()) + + # ts should be between before and after (in milliseconds) + assert before_ms <= parsed["ts"] <= after_ms + # Should be a reasonable timestamp (year 2020+) + assert parsed["ts"] > 1577836800000 # 2020-01-01