feat: add CIDR/netmask support for trusted proxy IPs
TRUSTED_PROXY_IPS now accepts CIDR notation (e.g., 192.168.1.0/24) in addition to exact IP addresses. Supports both IPv4 and IPv6. 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
@@ -1,6 +1,7 @@
|
||||
# ABOUTME: Application configuration loaded from environment variables.
|
||||
# ABOUTME: Uses Pydantic Settings for validation and type coercion.
|
||||
|
||||
import ipaddress
|
||||
import logging
|
||||
from functools import cached_property
|
||||
|
||||
@@ -44,6 +45,43 @@ class Settings(BaseSettings):
|
||||
"""Parse trusted proxy IPs from comma-separated raw string."""
|
||||
return _parse_comma_separated(self.trusted_proxy_ips_raw)
|
||||
|
||||
@cached_property
|
||||
def trusted_proxy_networks(
|
||||
self,
|
||||
) -> list[ipaddress.IPv4Network | ipaddress.IPv6Network]:
|
||||
"""Parse trusted proxy IPs/CIDRs into network objects.
|
||||
|
||||
Plain IPs become /32 (IPv4) or /128 (IPv6) networks.
|
||||
CIDR notation is parsed directly.
|
||||
Entries that cannot be parsed as IP/CIDR are skipped (handled by
|
||||
trusted_proxy_literals for backwards compatibility).
|
||||
"""
|
||||
networks = []
|
||||
for entry in self.trusted_proxy_ips:
|
||||
try:
|
||||
# ip_network with strict=False allows "192.168.1.1/24" to work
|
||||
# (normalizes to "192.168.1.0/24")
|
||||
network = ipaddress.ip_network(entry, strict=False)
|
||||
networks.append(network)
|
||||
except ValueError:
|
||||
# Not a valid IP/network - will be handled by trusted_proxy_literals
|
||||
pass
|
||||
return networks
|
||||
|
||||
@cached_property
|
||||
def trusted_proxy_literals(self) -> frozenset[str]:
|
||||
"""Get non-IP entries for exact string matching.
|
||||
|
||||
For backwards compatibility with entries like "testclient" in tests.
|
||||
"""
|
||||
literals = set()
|
||||
for entry in self.trusted_proxy_ips:
|
||||
try:
|
||||
ipaddress.ip_network(entry, strict=False)
|
||||
except ValueError:
|
||||
literals.add(entry)
|
||||
return frozenset(literals)
|
||||
|
||||
@field_validator("log_level", mode="before")
|
||||
@classmethod
|
||||
def normalize_and_validate_log_level(cls, v: str) -> str:
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
# ABOUTME: Middleware functions for authentication, CSRF, and request logging.
|
||||
# ABOUTME: Implements Beforeware pattern for FastHTML request processing.
|
||||
|
||||
import ipaddress
|
||||
import json
|
||||
import time
|
||||
from urllib.parse import urlparse
|
||||
@@ -112,7 +113,7 @@ def get_client_ip(req: Request) -> str:
|
||||
|
||||
|
||||
def is_trusted_proxy(req: Request, settings: Settings) -> bool:
|
||||
"""Check if request comes from a trusted proxy IP.
|
||||
"""Check if request comes from a trusted proxy IP or CIDR range.
|
||||
|
||||
Args:
|
||||
req: The Starlette request object.
|
||||
@@ -121,18 +122,35 @@ def is_trusted_proxy(req: Request, settings: Settings) -> bool:
|
||||
Returns:
|
||||
True if request is from trusted proxy, False otherwise.
|
||||
"""
|
||||
trusted_ips = settings.trusted_proxy_ips
|
||||
if not trusted_ips:
|
||||
trusted_networks = settings.trusted_proxy_networks
|
||||
trusted_literals = settings.trusted_proxy_literals
|
||||
|
||||
if not trusted_networks and not trusted_literals:
|
||||
# If no trusted IPs configured, reject all (fail-secure)
|
||||
return False
|
||||
|
||||
# Get the immediate connection IP (not X-Forwarded-For)
|
||||
if req.client:
|
||||
client_ip = req.client.host
|
||||
client_ip_str = req.client.host
|
||||
else:
|
||||
return False
|
||||
|
||||
return client_ip in trusted_ips
|
||||
# Check literal matches first (for backwards compatibility with "testclient" etc)
|
||||
if client_ip_str in trusted_literals:
|
||||
return True
|
||||
|
||||
# Try to parse as IP address for network matching
|
||||
try:
|
||||
client_ip = ipaddress.ip_address(client_ip_str)
|
||||
except ValueError:
|
||||
# Not a valid IP and not in literals
|
||||
return False
|
||||
|
||||
for network in trusted_networks:
|
||||
if client_ip in network:
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
|
||||
def get_expected_host(req: Request, settings: Settings) -> str:
|
||||
|
||||
Reference in New Issue
Block a user