feat: implement FastHTML app shell with auth/CSRF middleware (Step 7.1)
Add web layer foundation: - FastHTML app factory with Beforeware pattern - Auth middleware validating trusted proxy IPs and X-Oidc-Username header - CSRF dual-token validation (cookie + header + Origin/Referer) - Request ID generation (ULID) and NDJSON request logging - Role-based permission helpers (can_edit_event, can_delete_event) 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
18
PLAN.md
18
PLAN.md
@@ -249,15 +249,15 @@ Check off items as completed. Each phase builds on the previous.
|
||||
## Phase 7: HTTP API
|
||||
|
||||
### Step 7.1: FastHTML App Shell
|
||||
- [ ] Create `web/app.py` with FastHTML setup
|
||||
- [ ] Configure HTMX extensions (head-support, preload, etc.)
|
||||
- [ ] Create `web/middleware.py`:
|
||||
- [ ] Auth middleware (X-Oidc-Username, TRUSTED_PROXY_IPS)
|
||||
- [ ] CSRF middleware (cookie + header + Origin/Referer)
|
||||
- [ ] Request logging (NDJSON format)
|
||||
- [ ] Request ID generation
|
||||
- [ ] Create `web/auth.py` with get_current_user, require_role
|
||||
- [ ] Write tests: auth extraction, CSRF validation, untrusted IP rejection
|
||||
- [x] Create `web/app.py` with FastHTML setup
|
||||
- [x] Configure HTMX extensions (head-support, preload, etc.)
|
||||
- [x] Create `web/middleware.py`:
|
||||
- [x] Auth middleware (X-Oidc-Username, TRUSTED_PROXY_IPS)
|
||||
- [x] CSRF middleware (cookie + header + Origin/Referer)
|
||||
- [x] Request logging (NDJSON format)
|
||||
- [x] Request ID generation
|
||||
- [x] Create `web/auth.py` with get_current_user, require_role
|
||||
- [x] Write tests: auth extraction, CSRF validation, untrusted IP rejection
|
||||
- [ ] **Commit checkpoint**
|
||||
|
||||
### Step 7.2: Health & Static Assets
|
||||
|
||||
6
src/animaltrack/web/__init__.py
Normal file
6
src/animaltrack/web/__init__.py
Normal file
@@ -0,0 +1,6 @@
|
||||
# ABOUTME: Web layer package for AnimalTrack.
|
||||
# ABOUTME: Exports create_app for FastHTML application creation.
|
||||
|
||||
from animaltrack.web.app import create_app
|
||||
|
||||
__all__ = ["create_app"]
|
||||
97
src/animaltrack/web/app.py
Normal file
97
src/animaltrack/web/app.py
Normal file
@@ -0,0 +1,97 @@
|
||||
# ABOUTME: FastHTML application factory for AnimalTrack.
|
||||
# ABOUTME: Configures middleware, routes, and HTMX extensions.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from fasthtml.common import Beforeware, fast_app
|
||||
from starlette.requests import Request
|
||||
from starlette.responses import PlainTextResponse
|
||||
|
||||
from animaltrack.config import Settings
|
||||
from animaltrack.db import get_db
|
||||
from animaltrack.web.middleware import (
|
||||
auth_before,
|
||||
csrf_before,
|
||||
request_id_before,
|
||||
)
|
||||
|
||||
|
||||
def create_app(
|
||||
settings: Settings | None = None,
|
||||
db=None,
|
||||
):
|
||||
"""Create and configure the FastHTML application.
|
||||
|
||||
Args:
|
||||
settings: Application settings. Loads from env if None.
|
||||
db: Database connection. Creates from settings if None.
|
||||
|
||||
Returns:
|
||||
Tuple of (app, rt) - the FastHTML app and route decorator.
|
||||
"""
|
||||
# Load settings if not provided
|
||||
if settings is None:
|
||||
settings = Settings()
|
||||
|
||||
# Create db connection if not provided
|
||||
if db is None:
|
||||
db = get_db(settings.db_path)
|
||||
|
||||
# Create beforeware function that combines all middleware
|
||||
def before(req: Request, sess):
|
||||
"""Combined beforeware function for all middleware."""
|
||||
# Generate request ID first
|
||||
request_id_before(req)
|
||||
|
||||
# Auth middleware
|
||||
auth_resp = auth_before(req, settings, db)
|
||||
if auth_resp is not None:
|
||||
return auth_resp
|
||||
|
||||
# CSRF middleware
|
||||
csrf_resp = csrf_before(req, settings)
|
||||
if csrf_resp is not None:
|
||||
return csrf_resp
|
||||
|
||||
return None
|
||||
|
||||
# Configure beforeware with skip patterns
|
||||
beforeware = Beforeware(
|
||||
before,
|
||||
skip=[
|
||||
r"/favicon\.ico",
|
||||
r"/static/.*",
|
||||
r".*\.css",
|
||||
r".*\.js",
|
||||
r"/healthz",
|
||||
],
|
||||
)
|
||||
|
||||
# Create FastHTML app with HTMX extensions
|
||||
app, rt = fast_app(
|
||||
before=beforeware,
|
||||
exts=["head-support", "preload"],
|
||||
)
|
||||
|
||||
# Store settings and db on app state for access in routes
|
||||
app.state.settings = settings
|
||||
app.state.db = db
|
||||
|
||||
# Register healthz route (excluded from beforeware)
|
||||
@rt("/healthz")
|
||||
def healthz():
|
||||
"""Health check endpoint."""
|
||||
# Verify database is writable
|
||||
try:
|
||||
db.execute("SELECT 1")
|
||||
return PlainTextResponse("OK", status_code=200)
|
||||
except Exception as e:
|
||||
return PlainTextResponse(f"Database error: {e}", status_code=503)
|
||||
|
||||
# Placeholder index route (will be replaced with real UI later)
|
||||
@rt("/")
|
||||
def index():
|
||||
"""Placeholder index route."""
|
||||
return PlainTextResponse("AnimalTrack", status_code=200)
|
||||
|
||||
return app, rt
|
||||
140
src/animaltrack/web/auth.py
Normal file
140
src/animaltrack/web/auth.py
Normal file
@@ -0,0 +1,140 @@
|
||||
# ABOUTME: Authentication and authorization helpers for route handlers.
|
||||
# ABOUTME: Provides user extraction, role checks, and permission decorators.
|
||||
|
||||
from collections.abc import Callable
|
||||
from functools import wraps
|
||||
from typing import Any
|
||||
|
||||
from starlette.requests import Request
|
||||
from starlette.responses import Response
|
||||
|
||||
from animaltrack.models.reference import User, UserRole
|
||||
from animaltrack.web.exceptions import AuthenticationError, AuthorizationError
|
||||
|
||||
|
||||
def get_current_user(req: Request) -> User | None:
|
||||
"""Get the authenticated user from request scope.
|
||||
|
||||
Args:
|
||||
req: The Starlette request object.
|
||||
|
||||
Returns:
|
||||
User object if authenticated, None otherwise.
|
||||
"""
|
||||
return req.scope.get("auth")
|
||||
|
||||
|
||||
def is_admin(user: User) -> bool:
|
||||
"""Check if user has admin role.
|
||||
|
||||
Args:
|
||||
user: The user to check.
|
||||
|
||||
Returns:
|
||||
True if user has admin role, False otherwise.
|
||||
"""
|
||||
return user.role == UserRole.ADMIN
|
||||
|
||||
|
||||
def is_recorder(user: User) -> bool:
|
||||
"""Check if user has recorder role.
|
||||
|
||||
Args:
|
||||
user: The user to check.
|
||||
|
||||
Returns:
|
||||
True if user has recorder role, False otherwise.
|
||||
"""
|
||||
return user.role == UserRole.RECORDER
|
||||
|
||||
|
||||
def can_edit_event(user: User, event_actor: str) -> bool:
|
||||
"""Check if user can edit an event.
|
||||
|
||||
Admins can edit any event.
|
||||
Recorders can only edit their own events.
|
||||
|
||||
Args:
|
||||
user: The user attempting to edit.
|
||||
event_actor: The username of the event's creator.
|
||||
|
||||
Returns:
|
||||
True if user can edit the event, False otherwise.
|
||||
"""
|
||||
if is_admin(user):
|
||||
return True
|
||||
return user.username == event_actor
|
||||
|
||||
|
||||
def can_delete_event(user: User, event_actor: str, has_dependents: bool) -> bool:
|
||||
"""Check if user can delete an event.
|
||||
|
||||
Admins can delete any event, including cascade delete with dependents.
|
||||
Recorders can only delete their own events without dependents.
|
||||
|
||||
Args:
|
||||
user: The user attempting to delete.
|
||||
event_actor: The username of the event's creator.
|
||||
has_dependents: Whether the event has dependent events.
|
||||
|
||||
Returns:
|
||||
True if user can delete the event, False otherwise.
|
||||
"""
|
||||
if is_admin(user):
|
||||
return True
|
||||
# Recorder can only delete own events without dependents
|
||||
return user.username == event_actor and not has_dependents
|
||||
|
||||
|
||||
def require_auth(handler: Callable) -> Callable:
|
||||
"""Decorator that requires authentication.
|
||||
|
||||
Returns 401 response if no authenticated user in request scope.
|
||||
|
||||
Args:
|
||||
handler: The route handler to wrap.
|
||||
|
||||
Returns:
|
||||
Wrapped handler that checks for authentication.
|
||||
"""
|
||||
|
||||
@wraps(handler)
|
||||
async def wrapper(req: Request, *args: Any, **kwargs: Any) -> Response:
|
||||
user = get_current_user(req)
|
||||
if user is None:
|
||||
raise AuthenticationError("Authentication required")
|
||||
return await handler(req, *args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
def require_role(*roles: UserRole) -> Callable[[Callable], Callable]:
|
||||
"""Decorator factory that requires specific role(s).
|
||||
|
||||
Returns 401 if not authenticated, 403 if role not in allowed roles.
|
||||
|
||||
Args:
|
||||
roles: One or more UserRole values that are allowed.
|
||||
|
||||
Returns:
|
||||
Decorator that checks for required role.
|
||||
|
||||
Example:
|
||||
@require_role(UserRole.ADMIN)
|
||||
async def admin_only_route(req):
|
||||
...
|
||||
"""
|
||||
|
||||
def decorator(handler: Callable) -> Callable:
|
||||
@wraps(handler)
|
||||
async def wrapper(req: Request, *args: Any, **kwargs: Any) -> Response:
|
||||
user = get_current_user(req)
|
||||
if user is None:
|
||||
raise AuthenticationError("Authentication required")
|
||||
if user.role not in roles:
|
||||
raise AuthorizationError(f"Required role: {', '.join(r.value for r in roles)}")
|
||||
return await handler(req, *args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
return decorator
|
||||
42
src/animaltrack/web/exceptions.py
Normal file
42
src/animaltrack/web/exceptions.py
Normal file
@@ -0,0 +1,42 @@
|
||||
# ABOUTME: HTTP exception types for the web layer.
|
||||
# ABOUTME: Maps domain errors to appropriate HTTP status codes.
|
||||
|
||||
|
||||
class AuthenticationError(Exception):
|
||||
"""Raised when authentication fails.
|
||||
|
||||
HTTP Status: 401 Unauthorized
|
||||
Causes: Missing auth header, unknown user, inactive user.
|
||||
"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class AuthorizationError(Exception):
|
||||
"""Raised when authorization fails.
|
||||
|
||||
HTTP Status: 403 Forbidden
|
||||
Causes: User lacks required role, cannot edit/delete event.
|
||||
"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class CSRFValidationError(Exception):
|
||||
"""Raised when CSRF validation fails.
|
||||
|
||||
HTTP Status: 403 Forbidden
|
||||
Causes: Missing/mismatched token, invalid Origin/Referer.
|
||||
"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class UntrustedProxyError(Exception):
|
||||
"""Raised when request comes from untrusted proxy.
|
||||
|
||||
HTTP Status: 403 Forbidden
|
||||
Causes: Request IP not in TRUSTED_PROXY_IPS.
|
||||
"""
|
||||
|
||||
pass
|
||||
277
src/animaltrack/web/middleware.py
Normal file
277
src/animaltrack/web/middleware.py
Normal file
@@ -0,0 +1,277 @@
|
||||
# ABOUTME: Middleware functions for authentication, CSRF, and request logging.
|
||||
# ABOUTME: Implements Beforeware pattern for FastHTML request processing.
|
||||
|
||||
import json
|
||||
import time
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from starlette.requests import Request
|
||||
from starlette.responses import PlainTextResponse, Response
|
||||
|
||||
from animaltrack.config import Settings
|
||||
from animaltrack.id_gen import generate_id
|
||||
from animaltrack.repositories.users import UserRepository
|
||||
|
||||
# Safe HTTP methods that don't require CSRF protection
|
||||
SAFE_METHODS = frozenset({"GET", "HEAD", "OPTIONS"})
|
||||
|
||||
|
||||
def is_safe_method(method: str) -> bool:
|
||||
"""Check if HTTP method is safe (doesn't require CSRF protection).
|
||||
|
||||
Args:
|
||||
method: The HTTP method (e.g., "GET", "POST").
|
||||
|
||||
Returns:
|
||||
True if method is safe, False otherwise.
|
||||
"""
|
||||
return method.upper() in SAFE_METHODS
|
||||
|
||||
|
||||
def validate_csrf_token(cookie_token: str | None, header_token: str | None) -> bool:
|
||||
"""Validate that CSRF tokens match.
|
||||
|
||||
Args:
|
||||
cookie_token: Token from cookie.
|
||||
header_token: Token from header.
|
||||
|
||||
Returns:
|
||||
True if tokens match and are non-empty, False otherwise.
|
||||
"""
|
||||
if cookie_token is None or header_token is None:
|
||||
return False
|
||||
if not cookie_token or not header_token:
|
||||
return False
|
||||
return cookie_token == header_token
|
||||
|
||||
|
||||
def validate_origin(origin: str | None, expected_host: str) -> bool:
|
||||
"""Validate Origin header matches expected host.
|
||||
|
||||
Args:
|
||||
origin: The Origin header value (e.g., "https://example.com").
|
||||
expected_host: The expected host (e.g., "example.com" or "example.com:3366").
|
||||
|
||||
Returns:
|
||||
True if origin matches expected host, False otherwise.
|
||||
"""
|
||||
if origin is None or not origin:
|
||||
return False
|
||||
try:
|
||||
parsed = urlparse(origin)
|
||||
origin_host = parsed.netloc
|
||||
return origin_host == expected_host
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
|
||||
def validate_referer(referer: str | None, expected_host: str) -> bool:
|
||||
"""Validate Referer header matches expected host.
|
||||
|
||||
Args:
|
||||
referer: The Referer header value (e.g., "https://example.com/page").
|
||||
expected_host: The expected host (e.g., "example.com" or "example.com:3366").
|
||||
|
||||
Returns:
|
||||
True if referer host matches expected host, False otherwise.
|
||||
"""
|
||||
if referer is None or not referer:
|
||||
return False
|
||||
try:
|
||||
parsed = urlparse(referer)
|
||||
# parsed.netloc includes port if present
|
||||
referer_host = parsed.netloc
|
||||
if not referer_host:
|
||||
return False
|
||||
return referer_host == expected_host
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
|
||||
def get_client_ip(req: Request) -> str:
|
||||
"""Extract client IP from request, respecting X-Forwarded-For.
|
||||
|
||||
Args:
|
||||
req: The Starlette request object.
|
||||
|
||||
Returns:
|
||||
The client IP address.
|
||||
"""
|
||||
# Check X-Forwarded-For header first (set by reverse proxy)
|
||||
forwarded_for = req.headers.get("x-forwarded-for")
|
||||
if forwarded_for:
|
||||
# X-Forwarded-For can contain multiple IPs: "client, proxy1, proxy2"
|
||||
# The first one is the original client
|
||||
return forwarded_for.split(",")[0].strip()
|
||||
|
||||
# Fall back to direct connection IP
|
||||
if req.client:
|
||||
return req.client.host
|
||||
return "unknown"
|
||||
|
||||
|
||||
def is_trusted_proxy(req: Request, settings: Settings) -> bool:
|
||||
"""Check if request comes from a trusted proxy IP.
|
||||
|
||||
Args:
|
||||
req: The Starlette request object.
|
||||
settings: Application settings with trusted_proxy_ips.
|
||||
|
||||
Returns:
|
||||
True if request is from trusted proxy, False otherwise.
|
||||
"""
|
||||
trusted_ips = settings.trusted_proxy_ips
|
||||
if not trusted_ips:
|
||||
# If no trusted IPs configured, reject all (fail-secure)
|
||||
return False
|
||||
|
||||
# Get the immediate connection IP (not X-Forwarded-For)
|
||||
if req.client:
|
||||
client_ip = req.client.host
|
||||
else:
|
||||
return False
|
||||
|
||||
return client_ip in trusted_ips
|
||||
|
||||
|
||||
def get_expected_host(req: Request, settings: Settings) -> str:
|
||||
"""Get the expected host for Origin/Referer validation.
|
||||
|
||||
Uses X-Forwarded-Host if present, otherwise request host.
|
||||
|
||||
Args:
|
||||
req: The Starlette request object.
|
||||
settings: Application settings.
|
||||
|
||||
Returns:
|
||||
The expected host string.
|
||||
"""
|
||||
forwarded_host = req.headers.get("x-forwarded-host")
|
||||
if forwarded_host:
|
||||
return forwarded_host
|
||||
return req.headers.get("host", "")
|
||||
|
||||
|
||||
def request_id_before(req: Request) -> None:
|
||||
"""Generate unique request_id and attach to request scope.
|
||||
|
||||
Args:
|
||||
req: The Starlette request object.
|
||||
"""
|
||||
req.scope["request_id"] = generate_id()
|
||||
req.scope["request_start_time"] = time.time()
|
||||
|
||||
|
||||
def auth_before(req: Request, settings: Settings, db) -> Response | None:
|
||||
"""Extract and validate authentication from proxy headers.
|
||||
|
||||
Validates:
|
||||
- Request comes from trusted proxy IP
|
||||
- Auth header is present
|
||||
- User exists and is active in database
|
||||
|
||||
Args:
|
||||
req: The Starlette request object.
|
||||
settings: Application settings.
|
||||
db: Database connection.
|
||||
|
||||
Returns:
|
||||
None to continue processing, or Response to short-circuit.
|
||||
"""
|
||||
# Check trusted proxy
|
||||
if not is_trusted_proxy(req, settings):
|
||||
return PlainTextResponse("Forbidden: Request not from trusted proxy", status_code=403)
|
||||
|
||||
# Extract username from auth header
|
||||
username = req.headers.get(settings.auth_header_name.lower())
|
||||
if not username:
|
||||
return PlainTextResponse("Unauthorized: Missing auth header", status_code=401)
|
||||
|
||||
# Look up user in database
|
||||
user_repo = UserRepository(db)
|
||||
user = user_repo.get(username)
|
||||
if user is None:
|
||||
return PlainTextResponse("Unauthorized: Unknown user", status_code=401)
|
||||
if not user.active:
|
||||
return PlainTextResponse("Unauthorized: Inactive user", status_code=401)
|
||||
|
||||
# Store user in scope for access by handlers
|
||||
req.scope["auth"] = user
|
||||
return None
|
||||
|
||||
|
||||
def csrf_before(req: Request, settings: Settings) -> Response | None:
|
||||
"""Validate CSRF token on unsafe HTTP methods.
|
||||
|
||||
Validates:
|
||||
1. CSRF cookie present and matches header
|
||||
2. Origin or Referer matches expected host
|
||||
|
||||
Args:
|
||||
req: The Starlette request object.
|
||||
settings: Application settings.
|
||||
|
||||
Returns:
|
||||
None to continue processing, or Response to short-circuit.
|
||||
"""
|
||||
# Skip CSRF check for safe methods
|
||||
if is_safe_method(req.method):
|
||||
return None
|
||||
|
||||
# Get CSRF tokens
|
||||
cookie_token = req.cookies.get(settings.csrf_cookie_name)
|
||||
header_token = req.headers.get("x-csrf-token")
|
||||
|
||||
# Validate tokens match
|
||||
if not validate_csrf_token(cookie_token, header_token):
|
||||
return PlainTextResponse("Forbidden: Invalid CSRF token", status_code=403)
|
||||
|
||||
# Validate Origin or Referer
|
||||
expected_host = get_expected_host(req, settings)
|
||||
origin = req.headers.get("origin")
|
||||
referer = req.headers.get("referer")
|
||||
|
||||
if not validate_origin(origin, expected_host):
|
||||
# Fall back to Referer check
|
||||
if not validate_referer(referer, expected_host):
|
||||
return PlainTextResponse("Forbidden: Invalid Origin/Referer", status_code=403)
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def logging_after(
|
||||
req: Request, resp: Response, settings: Settings, event_id: str | None = None
|
||||
) -> None:
|
||||
"""Log request in NDJSON format after response.
|
||||
|
||||
Format: {ts, level, route, actor, ip, method, status, duration_ms, request_id, event_id}
|
||||
|
||||
Args:
|
||||
req: The Starlette request object.
|
||||
resp: The response object.
|
||||
settings: Application settings.
|
||||
event_id: Optional event ID if an event was created.
|
||||
"""
|
||||
start_time = req.scope.get("request_start_time", time.time())
|
||||
duration_ms = int((time.time() - start_time) * 1000)
|
||||
|
||||
user = req.scope.get("auth")
|
||||
actor = user.username if user else None
|
||||
|
||||
log_entry = {
|
||||
"ts": int(time.time() * 1000),
|
||||
"level": "info",
|
||||
"route": req.url.path,
|
||||
"actor": actor,
|
||||
"ip": get_client_ip(req),
|
||||
"method": req.method,
|
||||
"status": resp.status_code,
|
||||
"duration_ms": duration_ms,
|
||||
"request_id": req.scope.get("request_id"),
|
||||
}
|
||||
|
||||
if event_id:
|
||||
log_entry["event_id"] = event_id
|
||||
|
||||
# Output as NDJSON (one JSON object per line)
|
||||
print(json.dumps(log_entry), flush=True)
|
||||
141
tests/test_web_app.py
Normal file
141
tests/test_web_app.py
Normal file
@@ -0,0 +1,141 @@
|
||||
# ABOUTME: Integration tests for FastHTML application creation.
|
||||
# ABOUTME: Tests app factory, middleware wiring, and route configuration.
|
||||
|
||||
import os
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
def make_test_settings(
|
||||
csrf_secret: str = "test-secret",
|
||||
trusted_proxy_ips: str = "127.0.0.1",
|
||||
auth_header_name: str = "X-Oidc-Username",
|
||||
):
|
||||
"""Create Settings for testing by setting env vars temporarily."""
|
||||
from animaltrack.config import Settings
|
||||
|
||||
old_env = os.environ.copy()
|
||||
try:
|
||||
os.environ["CSRF_SECRET"] = csrf_secret
|
||||
os.environ["TRUSTED_PROXY_IPS"] = trusted_proxy_ips
|
||||
os.environ["AUTH_HEADER_NAME"] = auth_header_name
|
||||
return Settings()
|
||||
finally:
|
||||
os.environ.clear()
|
||||
os.environ.update(old_env)
|
||||
|
||||
|
||||
class TestCreateApp:
|
||||
"""Tests for the create_app factory function."""
|
||||
|
||||
def test_creates_app_with_provided_settings(self, seeded_db):
|
||||
"""create_app(settings=...) uses provided settings."""
|
||||
from animaltrack.web.app import create_app
|
||||
|
||||
settings = make_test_settings()
|
||||
|
||||
app, rt = create_app(settings=settings, db=seeded_db)
|
||||
|
||||
assert app is not None
|
||||
assert rt is not None
|
||||
assert app.state.settings is settings
|
||||
assert app.state.db is seeded_db
|
||||
|
||||
def test_app_has_db_on_state(self, seeded_db):
|
||||
"""Database accessible via app.state.db."""
|
||||
from animaltrack.web.app import create_app
|
||||
|
||||
settings = make_test_settings()
|
||||
app, rt = create_app(settings=settings, db=seeded_db)
|
||||
|
||||
assert hasattr(app.state, "db")
|
||||
assert app.state.db is seeded_db
|
||||
|
||||
def test_app_has_settings_on_state(self, seeded_db):
|
||||
"""Settings accessible via app.state.settings."""
|
||||
from animaltrack.web.app import create_app
|
||||
|
||||
settings = make_test_settings()
|
||||
app, rt = create_app(settings=settings, db=seeded_db)
|
||||
|
||||
assert hasattr(app.state, "settings")
|
||||
assert app.state.settings.csrf_secret == "test-secret"
|
||||
|
||||
|
||||
class TestAppWithTestClient:
|
||||
"""Integration tests using Starlette TestClient."""
|
||||
|
||||
@pytest.fixture
|
||||
def client(self, seeded_db):
|
||||
"""Create a test client for the app."""
|
||||
from starlette.testclient import TestClient
|
||||
|
||||
from animaltrack.web.app import create_app
|
||||
|
||||
# TestClient uses 'testclient' as the host, so we need to trust it
|
||||
settings = make_test_settings(trusted_proxy_ips="testclient")
|
||||
app, rt = create_app(settings=settings, db=seeded_db)
|
||||
return TestClient(app, raise_server_exceptions=False)
|
||||
|
||||
def test_healthz_returns_200(self, client):
|
||||
"""GET /healthz returns 200 OK."""
|
||||
resp = client.get("/healthz")
|
||||
assert resp.status_code == 200
|
||||
|
||||
def test_unauthenticated_route_returns_401(self, client):
|
||||
"""Protected route without auth returns 401."""
|
||||
# Any route that requires auth
|
||||
resp = client.get("/")
|
||||
assert resp.status_code == 401
|
||||
|
||||
def test_authenticated_request_succeeds(self, client):
|
||||
"""Request with valid auth header succeeds."""
|
||||
resp = client.get(
|
||||
"/",
|
||||
headers={"X-Oidc-Username": "ppetru"},
|
||||
)
|
||||
# Should get a valid response (200 or 404 if route not implemented yet)
|
||||
# The key is it shouldn't be 401
|
||||
assert resp.status_code != 401
|
||||
|
||||
def test_untrusted_proxy_returns_403(self, seeded_db):
|
||||
"""Request from untrusted IP returns 403."""
|
||||
from starlette.testclient import TestClient
|
||||
|
||||
from animaltrack.web.app import create_app
|
||||
|
||||
# Configure with a different trusted IP (not 'testclient')
|
||||
settings = make_test_settings(trusted_proxy_ips="10.0.0.1")
|
||||
app, rt = create_app(settings=settings, db=seeded_db)
|
||||
client = TestClient(app, raise_server_exceptions=False)
|
||||
|
||||
resp = client.get("/", headers={"X-Oidc-Username": "ppetru"})
|
||||
# Should fail because TestClient uses host 'testclient', not in trusted list
|
||||
assert resp.status_code == 403
|
||||
assert b"not from trusted proxy" in resp.content
|
||||
|
||||
def test_csrf_required_on_post(self, client):
|
||||
"""POST without CSRF token returns 403."""
|
||||
# POST to / route - should fail CSRF check before reaching handler
|
||||
resp = client.post(
|
||||
"/",
|
||||
headers={"X-Oidc-Username": "ppetru"},
|
||||
)
|
||||
assert resp.status_code == 403
|
||||
assert b"CSRF" in resp.content
|
||||
|
||||
def test_csrf_with_valid_tokens_succeeds(self, client):
|
||||
"""POST with matching CSRF tokens proceeds."""
|
||||
csrf_token = "test-csrf-token-123"
|
||||
resp = client.post(
|
||||
"/",
|
||||
headers={
|
||||
"X-Oidc-Username": "ppetru",
|
||||
"X-CSRF-Token": csrf_token,
|
||||
"Origin": "http://testserver",
|
||||
},
|
||||
cookies={"csrf_token": csrf_token},
|
||||
)
|
||||
# Should get through CSRF check (200 or 405 if method not allowed)
|
||||
# The key is it shouldn't be 403 CSRF error
|
||||
assert resp.status_code != 403 or b"CSRF" not in resp.content
|
||||
393
tests/test_web_auth.py
Normal file
393
tests/test_web_auth.py
Normal file
@@ -0,0 +1,393 @@
|
||||
# ABOUTME: Tests for web authentication and authorization helpers.
|
||||
# ABOUTME: Covers role checks, permission logic, and user extraction.
|
||||
|
||||
import os
|
||||
|
||||
import pytest
|
||||
|
||||
from animaltrack.models.reference import User, UserRole
|
||||
|
||||
|
||||
def make_test_settings(
|
||||
csrf_secret: str = "test-secret",
|
||||
trusted_proxy_ips: str = "",
|
||||
auth_header_name: str = "X-Oidc-Username",
|
||||
):
|
||||
"""Create Settings for testing by setting env vars temporarily."""
|
||||
from animaltrack.config import Settings
|
||||
|
||||
# Settings loads from env, so we set env vars temporarily
|
||||
old_env = os.environ.copy()
|
||||
try:
|
||||
os.environ["CSRF_SECRET"] = csrf_secret
|
||||
os.environ["TRUSTED_PROXY_IPS"] = trusted_proxy_ips
|
||||
os.environ["AUTH_HEADER_NAME"] = auth_header_name
|
||||
return Settings()
|
||||
finally:
|
||||
os.environ.clear()
|
||||
os.environ.update(old_env)
|
||||
|
||||
|
||||
# Fixtures for test users
|
||||
@pytest.fixture
|
||||
def admin_user():
|
||||
"""Create an admin user for testing."""
|
||||
return User(
|
||||
username="admin",
|
||||
role=UserRole.ADMIN,
|
||||
active=True,
|
||||
created_at_utc=1000000,
|
||||
updated_at_utc=1000000,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def recorder_user():
|
||||
"""Create a recorder user for testing."""
|
||||
return User(
|
||||
username="recorder",
|
||||
role=UserRole.RECORDER,
|
||||
active=True,
|
||||
created_at_utc=1000000,
|
||||
updated_at_utc=1000000,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def inactive_user():
|
||||
"""Create an inactive user for testing."""
|
||||
return User(
|
||||
username="inactive",
|
||||
role=UserRole.RECORDER,
|
||||
active=False,
|
||||
created_at_utc=1000000,
|
||||
updated_at_utc=1000000,
|
||||
)
|
||||
|
||||
|
||||
class TestIsAdmin:
|
||||
"""Tests for is_admin helper function."""
|
||||
|
||||
def test_returns_true_for_admin(self, admin_user):
|
||||
"""is_admin returns True for admin role."""
|
||||
from animaltrack.web.auth import is_admin
|
||||
|
||||
assert is_admin(admin_user) is True
|
||||
|
||||
def test_returns_false_for_recorder(self, recorder_user):
|
||||
"""is_admin returns False for recorder role."""
|
||||
from animaltrack.web.auth import is_admin
|
||||
|
||||
assert is_admin(recorder_user) is False
|
||||
|
||||
|
||||
class TestIsRecorder:
|
||||
"""Tests for is_recorder helper function."""
|
||||
|
||||
def test_returns_true_for_recorder(self, recorder_user):
|
||||
"""is_recorder returns True for recorder role."""
|
||||
from animaltrack.web.auth import is_recorder
|
||||
|
||||
assert is_recorder(recorder_user) is True
|
||||
|
||||
def test_returns_false_for_admin(self, admin_user):
|
||||
"""is_recorder returns False for admin role."""
|
||||
from animaltrack.web.auth import is_recorder
|
||||
|
||||
assert is_recorder(admin_user) is False
|
||||
|
||||
|
||||
class TestCanEditEvent:
|
||||
"""Tests for can_edit_event permission check."""
|
||||
|
||||
def test_admin_can_edit_any_event(self, admin_user):
|
||||
"""Admin can edit events created by any user."""
|
||||
from animaltrack.web.auth import can_edit_event
|
||||
|
||||
assert can_edit_event(admin_user, "other_user") is True
|
||||
assert can_edit_event(admin_user, "admin") is True
|
||||
|
||||
def test_recorder_can_edit_own_events(self, recorder_user):
|
||||
"""Recorder can edit their own events."""
|
||||
from animaltrack.web.auth import can_edit_event
|
||||
|
||||
assert can_edit_event(recorder_user, "recorder") is True
|
||||
|
||||
def test_recorder_cannot_edit_other_events(self, recorder_user):
|
||||
"""Recorder cannot edit events created by others."""
|
||||
from animaltrack.web.auth import can_edit_event
|
||||
|
||||
assert can_edit_event(recorder_user, "other_user") is False
|
||||
|
||||
|
||||
class TestCanDeleteEvent:
|
||||
"""Tests for can_delete_event permission check."""
|
||||
|
||||
def test_admin_can_delete_any_event_without_dependents(self, admin_user):
|
||||
"""Admin can delete any event without dependents."""
|
||||
from animaltrack.web.auth import can_delete_event
|
||||
|
||||
assert can_delete_event(admin_user, "other_user", has_dependents=False) is True
|
||||
assert can_delete_event(admin_user, "admin", has_dependents=False) is True
|
||||
|
||||
def test_admin_can_delete_any_event_with_dependents(self, admin_user):
|
||||
"""Admin can cascade delete events with dependents."""
|
||||
from animaltrack.web.auth import can_delete_event
|
||||
|
||||
assert can_delete_event(admin_user, "other_user", has_dependents=True) is True
|
||||
assert can_delete_event(admin_user, "admin", has_dependents=True) is True
|
||||
|
||||
def test_recorder_can_delete_own_event_without_dependents(self, recorder_user):
|
||||
"""Recorder can delete their own events without dependents."""
|
||||
from animaltrack.web.auth import can_delete_event
|
||||
|
||||
assert can_delete_event(recorder_user, "recorder", has_dependents=False) is True
|
||||
|
||||
def test_recorder_cannot_delete_own_event_with_dependents(self, recorder_user):
|
||||
"""Recorder cannot delete their own events if they have dependents."""
|
||||
from animaltrack.web.auth import can_delete_event
|
||||
|
||||
assert can_delete_event(recorder_user, "recorder", has_dependents=True) is False
|
||||
|
||||
def test_recorder_cannot_delete_other_events(self, recorder_user):
|
||||
"""Recorder cannot delete events created by others."""
|
||||
from animaltrack.web.auth import can_delete_event
|
||||
|
||||
assert can_delete_event(recorder_user, "other_user", has_dependents=False) is False
|
||||
assert can_delete_event(recorder_user, "other_user", has_dependents=True) is False
|
||||
|
||||
|
||||
# HTTP-level auth middleware tests
|
||||
|
||||
|
||||
class TestGetClientIp:
|
||||
"""Tests for client IP extraction."""
|
||||
|
||||
def test_extracts_from_x_forwarded_for(self):
|
||||
"""Extracts client IP from X-Forwarded-For header."""
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from animaltrack.web.middleware import get_client_ip
|
||||
|
||||
req = MagicMock()
|
||||
req.headers = {"x-forwarded-for": "203.0.113.50"}
|
||||
req.client = MagicMock(host="10.0.0.1")
|
||||
|
||||
assert get_client_ip(req) == "203.0.113.50"
|
||||
|
||||
def test_extracts_first_ip_from_chain(self):
|
||||
"""Extracts first IP when multiple proxies in chain."""
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from animaltrack.web.middleware import get_client_ip
|
||||
|
||||
req = MagicMock()
|
||||
req.headers = {"x-forwarded-for": "203.0.113.50, 10.0.0.2, 10.0.0.1"}
|
||||
req.client = MagicMock(host="10.0.0.1")
|
||||
|
||||
assert get_client_ip(req) == "203.0.113.50"
|
||||
|
||||
def test_falls_back_to_client_host(self):
|
||||
"""Falls back to direct client IP when no X-Forwarded-For."""
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from animaltrack.web.middleware import get_client_ip
|
||||
|
||||
req = MagicMock()
|
||||
req.headers = {}
|
||||
req.client = MagicMock(host="192.168.1.100")
|
||||
|
||||
assert get_client_ip(req) == "192.168.1.100"
|
||||
|
||||
def test_returns_unknown_when_no_client(self):
|
||||
"""Returns 'unknown' when no client info available."""
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from animaltrack.web.middleware import get_client_ip
|
||||
|
||||
req = MagicMock()
|
||||
req.headers = {}
|
||||
req.client = None
|
||||
|
||||
assert get_client_ip(req) == "unknown"
|
||||
|
||||
|
||||
class TestIsTrustedProxy:
|
||||
"""Tests for trusted proxy validation."""
|
||||
|
||||
def test_accepts_trusted_ip(self):
|
||||
"""Request from trusted IP proceeds."""
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from animaltrack.web.middleware import is_trusted_proxy
|
||||
|
||||
req = MagicMock()
|
||||
req.client = MagicMock(host="127.0.0.1")
|
||||
|
||||
settings = make_test_settings(trusted_proxy_ips="127.0.0.1,10.0.0.1")
|
||||
|
||||
assert is_trusted_proxy(req, settings) is True
|
||||
|
||||
def test_rejects_untrusted_ip(self):
|
||||
"""Returns False for untrusted IP."""
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from animaltrack.web.middleware import is_trusted_proxy
|
||||
|
||||
req = MagicMock()
|
||||
req.client = MagicMock(host="192.168.1.1")
|
||||
|
||||
settings = make_test_settings(trusted_proxy_ips="127.0.0.1,10.0.0.1")
|
||||
|
||||
assert is_trusted_proxy(req, settings) is False
|
||||
|
||||
def test_empty_trusted_list_rejects_all(self):
|
||||
"""When no trusted IPs configured, all requests rejected."""
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from animaltrack.web.middleware import is_trusted_proxy
|
||||
|
||||
req = MagicMock()
|
||||
req.client = MagicMock(host="127.0.0.1")
|
||||
|
||||
settings = make_test_settings(trusted_proxy_ips="")
|
||||
|
||||
assert is_trusted_proxy(req, settings) is False
|
||||
|
||||
def test_rejects_when_no_client(self):
|
||||
"""Returns False when no client info available."""
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from animaltrack.web.middleware import is_trusted_proxy
|
||||
|
||||
req = MagicMock()
|
||||
req.client = None
|
||||
|
||||
settings = make_test_settings(trusted_proxy_ips="127.0.0.1")
|
||||
|
||||
assert is_trusted_proxy(req, settings) is False
|
||||
|
||||
|
||||
class TestAuthBefore:
|
||||
"""Tests for auth_before middleware function."""
|
||||
|
||||
def test_rejects_untrusted_proxy(self, seeded_db):
|
||||
"""Returns 403 when request from untrusted IP."""
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from animaltrack.web.middleware import auth_before
|
||||
|
||||
req = MagicMock()
|
||||
req.client = MagicMock(host="192.168.1.1")
|
||||
req.headers = {"x-oidc-username": "ppetru"}
|
||||
|
||||
settings = make_test_settings(trusted_proxy_ips="127.0.0.1")
|
||||
|
||||
resp = auth_before(req, settings, seeded_db)
|
||||
assert resp is not None
|
||||
assert resp.status_code == 403
|
||||
assert b"not from trusted proxy" in resp.body
|
||||
|
||||
def test_rejects_missing_auth_header(self, seeded_db):
|
||||
"""Returns 401 when auth header is missing."""
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from animaltrack.web.middleware import auth_before
|
||||
|
||||
req = MagicMock()
|
||||
req.client = MagicMock(host="127.0.0.1")
|
||||
req.headers = {} # No auth header
|
||||
|
||||
settings = make_test_settings(trusted_proxy_ips="127.0.0.1")
|
||||
|
||||
resp = auth_before(req, settings, seeded_db)
|
||||
assert resp is not None
|
||||
assert resp.status_code == 401
|
||||
assert b"Missing auth header" in resp.body
|
||||
|
||||
def test_rejects_unknown_user(self, seeded_db):
|
||||
"""Returns 401 when user not in database."""
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from animaltrack.web.middleware import auth_before
|
||||
|
||||
req = MagicMock()
|
||||
req.client = MagicMock(host="127.0.0.1")
|
||||
req.headers = {"x-oidc-username": "nonexistent_user"}
|
||||
|
||||
settings = make_test_settings(trusted_proxy_ips="127.0.0.1")
|
||||
|
||||
resp = auth_before(req, settings, seeded_db)
|
||||
assert resp is not None
|
||||
assert resp.status_code == 401
|
||||
assert b"Unknown user" in resp.body
|
||||
|
||||
def test_rejects_inactive_user(self, seeded_db):
|
||||
"""Returns 401 when user is inactive."""
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from animaltrack.models.reference import User, UserRole
|
||||
from animaltrack.repositories.users import UserRepository
|
||||
from animaltrack.web.middleware import auth_before
|
||||
|
||||
# Create an inactive user
|
||||
user_repo = UserRepository(seeded_db)
|
||||
inactive = User(
|
||||
username="inactive_test",
|
||||
role=UserRole.RECORDER,
|
||||
active=False,
|
||||
created_at_utc=1000000,
|
||||
updated_at_utc=1000000,
|
||||
)
|
||||
user_repo.upsert(inactive)
|
||||
|
||||
req = MagicMock()
|
||||
req.client = MagicMock(host="127.0.0.1")
|
||||
req.headers = {"x-oidc-username": "inactive_test"}
|
||||
|
||||
settings = make_test_settings(trusted_proxy_ips="127.0.0.1")
|
||||
|
||||
resp = auth_before(req, settings, seeded_db)
|
||||
assert resp is not None
|
||||
assert resp.status_code == 401
|
||||
assert b"Inactive user" in resp.body
|
||||
|
||||
def test_sets_user_in_scope(self, seeded_db):
|
||||
"""User object is set in req.scope['auth'] on success."""
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from animaltrack.web.middleware import auth_before
|
||||
|
||||
req = MagicMock()
|
||||
req.client = MagicMock(host="127.0.0.1")
|
||||
req.headers = {"x-oidc-username": "ppetru"} # From seeds
|
||||
req.scope = {}
|
||||
|
||||
settings = make_test_settings(trusted_proxy_ips="127.0.0.1")
|
||||
|
||||
resp = auth_before(req, settings, seeded_db)
|
||||
assert resp is None # Continue processing
|
||||
assert "auth" in req.scope
|
||||
assert req.scope["auth"].username == "ppetru"
|
||||
assert req.scope["auth"].role == UserRole.ADMIN
|
||||
|
||||
def test_respects_custom_auth_header_name(self, seeded_db):
|
||||
"""Uses AUTH_HEADER_NAME setting for header lookup."""
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from animaltrack.web.middleware import auth_before
|
||||
|
||||
req = MagicMock()
|
||||
req.client = MagicMock(host="127.0.0.1")
|
||||
req.headers = {"x-custom-user": "ppetru"}
|
||||
req.scope = {}
|
||||
|
||||
settings = make_test_settings(
|
||||
trusted_proxy_ips="127.0.0.1",
|
||||
auth_header_name="X-Custom-User",
|
||||
)
|
||||
|
||||
resp = auth_before(req, settings, seeded_db)
|
||||
assert resp is None # Continue processing
|
||||
assert "auth" in req.scope
|
||||
assert req.scope["auth"].username == "ppetru"
|
||||
172
tests/test_web_csrf.py
Normal file
172
tests/test_web_csrf.py
Normal file
@@ -0,0 +1,172 @@
|
||||
# ABOUTME: Tests for CSRF validation logic.
|
||||
# ABOUTME: Covers token matching, origin/referer validation, and safe methods.
|
||||
|
||||
|
||||
class TestValidateCSRFToken:
|
||||
"""Tests for CSRF token validation logic."""
|
||||
|
||||
def test_accepts_matching_tokens(self):
|
||||
"""Valid when cookie and header tokens match."""
|
||||
from animaltrack.web.middleware import validate_csrf_token
|
||||
|
||||
assert validate_csrf_token("abc123", "abc123") is True
|
||||
|
||||
def test_rejects_mismatched_tokens(self):
|
||||
"""Invalid when cookie and header tokens differ."""
|
||||
from animaltrack.web.middleware import validate_csrf_token
|
||||
|
||||
assert validate_csrf_token("abc123", "xyz789") is False
|
||||
|
||||
def test_rejects_empty_cookie_token(self):
|
||||
"""Invalid when cookie token is empty."""
|
||||
from animaltrack.web.middleware import validate_csrf_token
|
||||
|
||||
assert validate_csrf_token("", "abc123") is False
|
||||
|
||||
def test_rejects_empty_header_token(self):
|
||||
"""Invalid when header token is empty."""
|
||||
from animaltrack.web.middleware import validate_csrf_token
|
||||
|
||||
assert validate_csrf_token("abc123", "") is False
|
||||
|
||||
def test_rejects_none_tokens(self):
|
||||
"""Invalid when either token is None."""
|
||||
from animaltrack.web.middleware import validate_csrf_token
|
||||
|
||||
assert validate_csrf_token(None, "abc123") is False
|
||||
assert validate_csrf_token("abc123", None) is False
|
||||
assert validate_csrf_token(None, None) is False
|
||||
|
||||
|
||||
class TestIsSafeMethod:
|
||||
"""Tests for HTTP safe method detection."""
|
||||
|
||||
def test_get_is_safe(self):
|
||||
"""GET is a safe method."""
|
||||
from animaltrack.web.middleware import is_safe_method
|
||||
|
||||
assert is_safe_method("GET") is True
|
||||
|
||||
def test_head_is_safe(self):
|
||||
"""HEAD is a safe method."""
|
||||
from animaltrack.web.middleware import is_safe_method
|
||||
|
||||
assert is_safe_method("HEAD") is True
|
||||
|
||||
def test_options_is_safe(self):
|
||||
"""OPTIONS is a safe method."""
|
||||
from animaltrack.web.middleware import is_safe_method
|
||||
|
||||
assert is_safe_method("OPTIONS") is True
|
||||
|
||||
def test_post_is_not_safe(self):
|
||||
"""POST is not a safe method."""
|
||||
from animaltrack.web.middleware import is_safe_method
|
||||
|
||||
assert is_safe_method("POST") is False
|
||||
|
||||
def test_put_is_not_safe(self):
|
||||
"""PUT is not a safe method."""
|
||||
from animaltrack.web.middleware import is_safe_method
|
||||
|
||||
assert is_safe_method("PUT") is False
|
||||
|
||||
def test_delete_is_not_safe(self):
|
||||
"""DELETE is not a safe method."""
|
||||
from animaltrack.web.middleware import is_safe_method
|
||||
|
||||
assert is_safe_method("DELETE") is False
|
||||
|
||||
def test_patch_is_not_safe(self):
|
||||
"""PATCH is not a safe method."""
|
||||
from animaltrack.web.middleware import is_safe_method
|
||||
|
||||
assert is_safe_method("PATCH") is False
|
||||
|
||||
def test_case_insensitive(self):
|
||||
"""Method check is case-insensitive."""
|
||||
from animaltrack.web.middleware import is_safe_method
|
||||
|
||||
assert is_safe_method("get") is True
|
||||
assert is_safe_method("Get") is True
|
||||
assert is_safe_method("post") is False
|
||||
|
||||
|
||||
class TestValidateOrigin:
|
||||
"""Tests for Origin/Referer header validation."""
|
||||
|
||||
def test_accepts_matching_origin(self):
|
||||
"""Valid when Origin matches expected host."""
|
||||
from animaltrack.web.middleware import validate_origin
|
||||
|
||||
assert validate_origin("https://example.com", "example.com") is True
|
||||
|
||||
def test_accepts_matching_origin_with_port(self):
|
||||
"""Valid when Origin matches expected host with port."""
|
||||
from animaltrack.web.middleware import validate_origin
|
||||
|
||||
assert validate_origin("https://example.com:3366", "example.com:3366") is True
|
||||
|
||||
def test_rejects_different_origin(self):
|
||||
"""Invalid when Origin doesn't match expected host."""
|
||||
from animaltrack.web.middleware import validate_origin
|
||||
|
||||
assert validate_origin("https://evil.com", "example.com") is False
|
||||
|
||||
def test_rejects_subdomain_mismatch(self):
|
||||
"""Invalid when Origin is a subdomain of expected host."""
|
||||
from animaltrack.web.middleware import validate_origin
|
||||
|
||||
assert validate_origin("https://sub.example.com", "example.com") is False
|
||||
|
||||
def test_accepts_none_origin(self):
|
||||
"""None origin returns False (will check Referer instead)."""
|
||||
from animaltrack.web.middleware import validate_origin
|
||||
|
||||
assert validate_origin(None, "example.com") is False
|
||||
|
||||
def test_accepts_empty_origin(self):
|
||||
"""Empty origin returns False."""
|
||||
from animaltrack.web.middleware import validate_origin
|
||||
|
||||
assert validate_origin("", "example.com") is False
|
||||
|
||||
|
||||
class TestValidateReferer:
|
||||
"""Tests for Referer header validation."""
|
||||
|
||||
def test_accepts_matching_referer(self):
|
||||
"""Valid when Referer host matches expected host."""
|
||||
from animaltrack.web.middleware import validate_referer
|
||||
|
||||
assert validate_referer("https://example.com/page", "example.com") is True
|
||||
|
||||
def test_accepts_matching_referer_with_port(self):
|
||||
"""Valid when Referer matches expected host with port."""
|
||||
from animaltrack.web.middleware import validate_referer
|
||||
|
||||
assert validate_referer("https://example.com:3366/page", "example.com:3366") is True
|
||||
|
||||
def test_rejects_different_referer(self):
|
||||
"""Invalid when Referer host doesn't match expected host."""
|
||||
from animaltrack.web.middleware import validate_referer
|
||||
|
||||
assert validate_referer("https://evil.com/page", "example.com") is False
|
||||
|
||||
def test_rejects_none_referer(self):
|
||||
"""Invalid when Referer is None."""
|
||||
from animaltrack.web.middleware import validate_referer
|
||||
|
||||
assert validate_referer(None, "example.com") is False
|
||||
|
||||
def test_rejects_empty_referer(self):
|
||||
"""Invalid when Referer is empty."""
|
||||
from animaltrack.web.middleware import validate_referer
|
||||
|
||||
assert validate_referer("", "example.com") is False
|
||||
|
||||
def test_rejects_malformed_referer(self):
|
||||
"""Invalid when Referer is malformed."""
|
||||
from animaltrack.web.middleware import validate_referer
|
||||
|
||||
assert validate_referer("not-a-url", "example.com") is False
|
||||
302
tests/test_web_middleware.py
Normal file
302
tests/test_web_middleware.py
Normal file
@@ -0,0 +1,302 @@
|
||||
# ABOUTME: Tests for request ID generation and logging middleware.
|
||||
# ABOUTME: Covers ULID generation, scope storage, and NDJSON log format.
|
||||
|
||||
import json
|
||||
import os
|
||||
import time
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
|
||||
def make_test_settings(
|
||||
csrf_secret: str = "test-secret",
|
||||
trusted_proxy_ips: str = "",
|
||||
auth_header_name: str = "X-Oidc-Username",
|
||||
):
|
||||
"""Create Settings for testing by setting env vars temporarily."""
|
||||
from animaltrack.config import Settings
|
||||
|
||||
old_env = os.environ.copy()
|
||||
try:
|
||||
os.environ["CSRF_SECRET"] = csrf_secret
|
||||
os.environ["TRUSTED_PROXY_IPS"] = trusted_proxy_ips
|
||||
os.environ["AUTH_HEADER_NAME"] = auth_header_name
|
||||
return Settings()
|
||||
finally:
|
||||
os.environ.clear()
|
||||
os.environ.update(old_env)
|
||||
|
||||
|
||||
class TestRequestIdBefore:
|
||||
"""Tests for request ID generation middleware."""
|
||||
|
||||
def test_generates_request_id(self):
|
||||
"""Generates a request_id in the scope."""
|
||||
from animaltrack.web.middleware import request_id_before
|
||||
|
||||
req = MagicMock()
|
||||
req.scope = {}
|
||||
|
||||
request_id_before(req)
|
||||
|
||||
assert "request_id" in req.scope
|
||||
assert req.scope["request_id"] is not None
|
||||
|
||||
def test_request_id_is_ulid(self):
|
||||
"""Request ID is a valid 26-char ULID."""
|
||||
from animaltrack.web.middleware import request_id_before
|
||||
|
||||
req = MagicMock()
|
||||
req.scope = {}
|
||||
|
||||
request_id_before(req)
|
||||
|
||||
request_id = req.scope["request_id"]
|
||||
assert len(request_id) == 26
|
||||
# ULIDs are base32, should only contain valid chars
|
||||
valid_chars = "0123456789ABCDEFGHJKMNPQRSTVWXYZ"
|
||||
assert all(c in valid_chars for c in request_id.upper())
|
||||
|
||||
def test_generates_unique_ids(self):
|
||||
"""Each request gets a unique request_id."""
|
||||
from animaltrack.web.middleware import request_id_before
|
||||
|
||||
ids = set()
|
||||
for _ in range(100):
|
||||
req = MagicMock()
|
||||
req.scope = {}
|
||||
request_id_before(req)
|
||||
ids.add(req.scope["request_id"])
|
||||
|
||||
assert len(ids) == 100 # All unique
|
||||
|
||||
def test_sets_start_time(self):
|
||||
"""Sets request_start_time in scope for duration calculation."""
|
||||
from animaltrack.web.middleware import request_id_before
|
||||
|
||||
req = MagicMock()
|
||||
req.scope = {}
|
||||
|
||||
before = time.time()
|
||||
request_id_before(req)
|
||||
after = time.time()
|
||||
|
||||
assert "request_start_time" in req.scope
|
||||
assert before <= req.scope["request_start_time"] <= after
|
||||
|
||||
|
||||
class TestLoggingAfter:
|
||||
"""Tests for request logging middleware."""
|
||||
|
||||
def test_logs_in_ndjson_format(self, capsys):
|
||||
"""Log line is valid JSON."""
|
||||
from animaltrack.web.middleware import logging_after
|
||||
|
||||
req = MagicMock()
|
||||
req.scope = {"request_id": "TEST123", "request_start_time": time.time()}
|
||||
req.url = MagicMock(path="/test")
|
||||
req.method = "GET"
|
||||
req.headers = {}
|
||||
req.client = MagicMock(host="127.0.0.1")
|
||||
|
||||
resp = MagicMock()
|
||||
resp.status_code = 200
|
||||
|
||||
settings = make_test_settings()
|
||||
|
||||
logging_after(req, resp, settings)
|
||||
|
||||
captured = capsys.readouterr()
|
||||
log_line = captured.out.strip()
|
||||
|
||||
# Should be valid JSON
|
||||
parsed = json.loads(log_line)
|
||||
assert isinstance(parsed, dict)
|
||||
|
||||
def test_log_includes_required_fields(self, capsys):
|
||||
"""Log includes: ts, level, route, actor, ip, method, status, duration_ms."""
|
||||
from animaltrack.models.reference import User, UserRole
|
||||
from animaltrack.web.middleware import logging_after
|
||||
|
||||
user = User(
|
||||
username="testuser",
|
||||
role=UserRole.RECORDER,
|
||||
active=True,
|
||||
created_at_utc=1000000,
|
||||
updated_at_utc=1000000,
|
||||
)
|
||||
|
||||
req = MagicMock()
|
||||
req.scope = {
|
||||
"request_id": "TEST123",
|
||||
"request_start_time": time.time() - 0.1, # 100ms ago
|
||||
"auth": user,
|
||||
}
|
||||
req.url = MagicMock(path="/test/route")
|
||||
req.method = "POST"
|
||||
req.headers = {}
|
||||
req.client = MagicMock(host="127.0.0.1")
|
||||
|
||||
resp = MagicMock()
|
||||
resp.status_code = 201
|
||||
|
||||
settings = make_test_settings()
|
||||
|
||||
logging_after(req, resp, settings)
|
||||
|
||||
captured = capsys.readouterr()
|
||||
parsed = json.loads(captured.out.strip())
|
||||
|
||||
assert "ts" in parsed
|
||||
assert parsed["level"] == "info"
|
||||
assert parsed["route"] == "/test/route"
|
||||
assert parsed["actor"] == "testuser"
|
||||
assert parsed["method"] == "POST"
|
||||
assert parsed["status"] == 201
|
||||
assert "duration_ms" in parsed
|
||||
assert parsed["duration_ms"] >= 0
|
||||
|
||||
def test_log_includes_request_id(self, capsys):
|
||||
"""Log includes request_id field."""
|
||||
from animaltrack.web.middleware import logging_after
|
||||
|
||||
req = MagicMock()
|
||||
req.scope = {"request_id": "01ARZ3NDEKTSV4RRFFQ69G5FAV", "request_start_time": time.time()}
|
||||
req.url = MagicMock(path="/test")
|
||||
req.method = "GET"
|
||||
req.headers = {}
|
||||
req.client = MagicMock(host="127.0.0.1")
|
||||
|
||||
resp = MagicMock()
|
||||
resp.status_code = 200
|
||||
|
||||
settings = make_test_settings()
|
||||
|
||||
logging_after(req, resp, settings)
|
||||
|
||||
captured = capsys.readouterr()
|
||||
parsed = json.loads(captured.out.strip())
|
||||
|
||||
assert parsed["request_id"] == "01ARZ3NDEKTSV4RRFFQ69G5FAV"
|
||||
|
||||
def test_log_includes_event_id_when_provided(self, capsys):
|
||||
"""Log includes event_id field when provided."""
|
||||
from animaltrack.web.middleware import logging_after
|
||||
|
||||
req = MagicMock()
|
||||
req.scope = {"request_id": "TEST123", "request_start_time": time.time()}
|
||||
req.url = MagicMock(path="/actions/product-collected")
|
||||
req.method = "POST"
|
||||
req.headers = {}
|
||||
req.client = MagicMock(host="127.0.0.1")
|
||||
|
||||
resp = MagicMock()
|
||||
resp.status_code = 200
|
||||
|
||||
settings = make_test_settings()
|
||||
|
||||
logging_after(req, resp, settings, event_id="01ARZ3NDEKTSV4RRFFQ69G5EVT")
|
||||
|
||||
captured = capsys.readouterr()
|
||||
parsed = json.loads(captured.out.strip())
|
||||
|
||||
assert parsed["event_id"] == "01ARZ3NDEKTSV4RRFFQ69G5EVT"
|
||||
|
||||
def test_duration_ms_accurate(self, capsys):
|
||||
"""duration_ms reflects actual request time."""
|
||||
from animaltrack.web.middleware import logging_after
|
||||
|
||||
req = MagicMock()
|
||||
# Started 150ms ago
|
||||
req.scope = {"request_id": "TEST123", "request_start_time": time.time() - 0.150}
|
||||
req.url = MagicMock(path="/test")
|
||||
req.method = "GET"
|
||||
req.headers = {}
|
||||
req.client = MagicMock(host="127.0.0.1")
|
||||
|
||||
resp = MagicMock()
|
||||
resp.status_code = 200
|
||||
|
||||
settings = make_test_settings()
|
||||
|
||||
logging_after(req, resp, settings)
|
||||
|
||||
captured = capsys.readouterr()
|
||||
parsed = json.loads(captured.out.strip())
|
||||
|
||||
# Should be approximately 150ms (with some tolerance)
|
||||
assert 100 <= parsed["duration_ms"] <= 250
|
||||
|
||||
def test_logs_x_forwarded_for_ip(self, capsys):
|
||||
"""Log ip field uses X-Forwarded-For when present."""
|
||||
from animaltrack.web.middleware import logging_after
|
||||
|
||||
req = MagicMock()
|
||||
req.scope = {"request_id": "TEST123", "request_start_time": time.time()}
|
||||
req.url = MagicMock(path="/test")
|
||||
req.method = "GET"
|
||||
req.headers = {"x-forwarded-for": "203.0.113.50, 10.0.0.1"}
|
||||
req.client = MagicMock(host="10.0.0.1") # Proxy IP
|
||||
|
||||
resp = MagicMock()
|
||||
resp.status_code = 200
|
||||
|
||||
settings = make_test_settings()
|
||||
|
||||
logging_after(req, resp, settings)
|
||||
|
||||
captured = capsys.readouterr()
|
||||
parsed = json.loads(captured.out.strip())
|
||||
|
||||
# Should use the client IP from X-Forwarded-For, not the proxy IP
|
||||
assert parsed["ip"] == "203.0.113.50"
|
||||
|
||||
def test_actor_is_none_when_unauthenticated(self, capsys):
|
||||
"""Log actor is None when no user in scope."""
|
||||
from animaltrack.web.middleware import logging_after
|
||||
|
||||
req = MagicMock()
|
||||
req.scope = {"request_id": "TEST123", "request_start_time": time.time()}
|
||||
req.url = MagicMock(path="/healthz")
|
||||
req.method = "GET"
|
||||
req.headers = {}
|
||||
req.client = MagicMock(host="127.0.0.1")
|
||||
|
||||
resp = MagicMock()
|
||||
resp.status_code = 200
|
||||
|
||||
settings = make_test_settings()
|
||||
|
||||
logging_after(req, resp, settings)
|
||||
|
||||
captured = capsys.readouterr()
|
||||
parsed = json.loads(captured.out.strip())
|
||||
|
||||
assert parsed["actor"] is None
|
||||
|
||||
def test_timestamp_is_milliseconds(self, capsys):
|
||||
"""Log ts field is milliseconds since epoch."""
|
||||
from animaltrack.web.middleware import logging_after
|
||||
|
||||
req = MagicMock()
|
||||
req.scope = {"request_id": "TEST123", "request_start_time": time.time()}
|
||||
req.url = MagicMock(path="/test")
|
||||
req.method = "GET"
|
||||
req.headers = {}
|
||||
req.client = MagicMock(host="127.0.0.1")
|
||||
|
||||
resp = MagicMock()
|
||||
resp.status_code = 200
|
||||
|
||||
settings = make_test_settings()
|
||||
|
||||
before_ms = int(time.time() * 1000)
|
||||
logging_after(req, resp, settings)
|
||||
after_ms = int(time.time() * 1000)
|
||||
|
||||
captured = capsys.readouterr()
|
||||
parsed = json.loads(captured.out.strip())
|
||||
|
||||
# ts should be between before and after (in milliseconds)
|
||||
assert before_ms <= parsed["ts"] <= after_ms
|
||||
# Should be a reasonable timestamp (year 2020+)
|
||||
assert parsed["ts"] > 1577836800000 # 2020-01-01
|
||||
Reference in New Issue
Block a user