From e9d3f34994b2354f608b10de06b086d938a96808 Mon Sep 17 00:00:00 2001 From: Petru Paler Date: Mon, 29 Dec 2025 15:46:19 +0000 Subject: [PATCH] feat: add selection validation with optimistic locking MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Implements Step 5.3 - selection validation for optimistic locking: - SelectionContext: holds client's filter, resolved_ids, roster_hash, ts_utc - SelectionDiff: shows added/removed animals on mismatch - SelectionValidationResult: validation result with diff if applicable - validate_selection(): re-resolves at ts_utc, compares hashes, returns diff - SelectionMismatchError: exception for unconfirmed mismatches Tests cover: hash match, mismatch detection, diff correctness, confirmed bypass, from_location_id in hash comparison. 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 --- PLAN.md | 10 +- src/animaltrack/selection/__init__.py | 14 +- src/animaltrack/selection/validation.py | 149 ++++++++ tests/test_selection_validation.py | 458 ++++++++++++++++++++++++ 4 files changed, 625 insertions(+), 6 deletions(-) create mode 100644 src/animaltrack/selection/validation.py create mode 100644 tests/test_selection_validation.py diff --git a/PLAN.md b/PLAN.md index 1c76d27..a65891c 100644 --- a/PLAN.md +++ b/PLAN.md @@ -205,11 +205,11 @@ Check off items as completed. Each phase builds on the previous. - [x] **Commit checkpoint** ### Step 5.3: Optimistic Locking -- [ ] Create `selection/validation.py` for selection validation -- [ ] Re-resolve on submit, compute hash, return diff on mismatch -- [ ] Create SelectionContext and SelectionDiff models -- [ ] Write tests: mismatch detected, diff correct, confirmed bypasses -- [ ] **Commit checkpoint** +- [x] Create `selection/validation.py` for selection validation +- [x] Re-resolve on submit, compute hash, return diff on mismatch +- [x] Create SelectionContext and SelectionDiff models +- [x] Write tests: mismatch detected, diff correct, confirmed bypasses +- [x] **Commit checkpoint** --- diff --git a/src/animaltrack/selection/__init__.py b/src/animaltrack/selection/__init__.py index 7fe2ac8..ab37dd4 100644 --- a/src/animaltrack/selection/__init__.py +++ b/src/animaltrack/selection/__init__.py @@ -1,5 +1,5 @@ # ABOUTME: Selection system for resolving animal sets from filters. -# ABOUTME: Provides parser, AST, resolver, and hash for animal selection contexts. +# ABOUTME: Provides parser, AST, resolver, hash, and validation for animal selection contexts. from animaltrack.selection.ast import FieldFilter, FilterAST from animaltrack.selection.hash import compute_roster_hash @@ -10,15 +10,27 @@ from animaltrack.selection.resolver import ( resolve_filter, resolve_selection, ) +from animaltrack.selection.validation import ( + SelectionContext, + SelectionDiff, + SelectionMismatchError, + SelectionValidationResult, + validate_selection, +) __all__ = [ "FieldFilter", "FilterAST", "ParseError", + "SelectionContext", + "SelectionDiff", + "SelectionMismatchError", "SelectionResolverError", "SelectionResult", + "SelectionValidationResult", "compute_roster_hash", "parse_filter", "resolve_filter", "resolve_selection", + "validate_selection", ] diff --git a/src/animaltrack/selection/validation.py b/src/animaltrack/selection/validation.py new file mode 100644 index 0000000..7d1dd61 --- /dev/null +++ b/src/animaltrack/selection/validation.py @@ -0,0 +1,149 @@ +# ABOUTME: Selection validation with optimistic locking for animal operations. +# ABOUTME: Re-resolves at ts_utc, compares hashes, returns diff on mismatch. + +from dataclasses import dataclass +from typing import Any + +from animaltrack.selection.hash import compute_roster_hash +from animaltrack.selection.parser import parse_filter +from animaltrack.selection.resolver import resolve_filter + + +@dataclass +class SelectionContext: + """Context for validating an animal selection. + + Contains client's filter, resolved IDs, and hash for comparison. + """ + + filter: str # DSL filter string + resolved_ids: list[str] # Client's resolved animal IDs + roster_hash: str # Client's computed hash + ts_utc: int # Point-in-time for resolution + from_location_id: str | None # For move operations (included in hash) + confirmed: bool = False # Override on mismatch + resolver_version: str = "v1" # Fixed version string + + +@dataclass +class SelectionDiff: + """Difference between client and server resolved selections. + + Used to inform client what changed since their resolution. + """ + + added: list[str] # IDs in server resolution but not client + removed: list[str] # IDs in client but not server resolution + server_count: int # Server's resolved count + client_count: int # Client's resolved count + + +@dataclass +class SelectionValidationResult: + """Result of selection validation. + + valid=True means the selection can proceed (match or confirmed). + valid=False means mismatch detected and not confirmed. + """ + + valid: bool # True if match or confirmed + resolved_ids: list[str] # IDs to use for event + roster_hash: str # Hash to use + diff: SelectionDiff | None # None if match, populated if mismatch + + +class SelectionMismatchError(Exception): + """Raised when selection validation fails and not confirmed. + + Contains the validation result with diff for client to display. + """ + + def __init__(self, result: SelectionValidationResult) -> None: + self.result = result + super().__init__("Selection mismatch detected") + + +def _compute_diff( + client_ids: list[str], + server_ids: list[str], +) -> SelectionDiff: + """Compute difference between client and server resolved IDs. + + Args: + client_ids: IDs from client's resolution. + server_ids: IDs from server's resolution. + + Returns: + SelectionDiff with added, removed, and counts. + """ + client_set = set(client_ids) + server_set = set(server_ids) + + added = sorted(server_set - client_set) + removed = sorted(client_set - server_set) + + return SelectionDiff( + added=added, + removed=removed, + server_count=len(server_ids), + client_count=len(client_ids), + ) + + +def validate_selection( + db: Any, + context: SelectionContext, +) -> SelectionValidationResult: + """Validate client selection against server resolution at ts_utc. + + Re-resolves the filter at ts_utc, computes roster hash, and compares + with client's hash. Returns valid=True if hashes match or if + confirmed=True. Returns valid=False with diff if mismatch and not + confirmed. + + Args: + db: Database connection. + context: SelectionContext with client's filter, IDs, and hash. + + Returns: + SelectionValidationResult with validation status and diff if applicable. + """ + # Parse and resolve filter at ts_utc + filter_ast = parse_filter(context.filter) + resolution = resolve_filter(db, filter_ast, context.ts_utc) + + # Compute server's hash (including from_location_id if provided) + server_hash = compute_roster_hash( + resolution.animal_ids, + context.from_location_id, + ) + + # Compare hashes + if server_hash == context.roster_hash: + # Match - proceed with client's IDs + return SelectionValidationResult( + valid=True, + resolved_ids=context.resolved_ids, + roster_hash=context.roster_hash, + diff=None, + ) + + # Mismatch - compute diff + diff = _compute_diff(context.resolved_ids, resolution.animal_ids) + + if context.confirmed: + # Client confirmed mismatch - trust their IDs + return SelectionValidationResult( + valid=True, + resolved_ids=context.resolved_ids, + roster_hash=context.roster_hash, + diff=diff, + ) + + # Mismatch not confirmed - return invalid with server's resolution + return SelectionValidationResult( + valid=False, + resolved_ids=resolution.animal_ids, + roster_hash=server_hash, + diff=diff, + ) diff --git a/tests/test_selection_validation.py b/tests/test_selection_validation.py new file mode 100644 index 0000000..e20ee62 --- /dev/null +++ b/tests/test_selection_validation.py @@ -0,0 +1,458 @@ +# ABOUTME: Tests for selection validation with optimistic locking. +# ABOUTME: Tests hash comparison, diff computation, and confirmed bypass. + +import time + +import pytest + +from animaltrack.events.payloads import AnimalCohortCreatedPayload, AnimalMovedPayload +from animaltrack.events.store import EventStore +from animaltrack.projections import ProjectionRegistry +from animaltrack.projections.animal_registry import AnimalRegistryProjection +from animaltrack.projections.event_animals import EventAnimalsProjection +from animaltrack.projections.intervals import IntervalProjection +from animaltrack.selection import compute_roster_hash, parse_filter, resolve_filter +from animaltrack.selection.validation import ( + SelectionContext, + SelectionDiff, + SelectionMismatchError, + SelectionValidationResult, + validate_selection, +) +from animaltrack.services.animal import AnimalService + + +@pytest.fixture +def event_store(seeded_db): + """Create an EventStore for testing.""" + return EventStore(seeded_db) + + +@pytest.fixture +def projection_registry(seeded_db): + """Create a ProjectionRegistry with animal projections registered.""" + registry = ProjectionRegistry() + registry.register(AnimalRegistryProjection(seeded_db)) + registry.register(EventAnimalsProjection(seeded_db)) + registry.register(IntervalProjection(seeded_db)) + return registry + + +@pytest.fixture +def animal_service(seeded_db, event_store, projection_registry): + """Create an AnimalService for testing.""" + return AnimalService(seeded_db, event_store, projection_registry) + + +@pytest.fixture +def strip1_location_id(seeded_db): + """Get Strip 1 location ID from seeds.""" + row = seeded_db.execute("SELECT id FROM locations WHERE name = 'Strip 1'").fetchone() + return row[0] + + +@pytest.fixture +def strip2_location_id(seeded_db): + """Get Strip 2 location ID from seeds.""" + row = seeded_db.execute("SELECT id FROM locations WHERE name = 'Strip 2'").fetchone() + return row[0] + + +def make_cohort_payload( + location_id: str, + count: int = 3, + species: str = "duck", + sex: str = "unknown", + life_stage: str = "adult", +) -> AnimalCohortCreatedPayload: + """Create a cohort payload for testing.""" + return AnimalCohortCreatedPayload( + species=species, + count=count, + life_stage=life_stage, + sex=sex, + location_id=location_id, + origin="purchased", + ) + + +# ============================================================================ +# Tests for SelectionContext and SelectionDiff models +# ============================================================================ + + +class TestSelectionContextModel: + """Tests for SelectionContext dataclass.""" + + def test_creates_with_required_fields(self): + """SelectionContext can be created with required fields.""" + ctx = SelectionContext( + filter="species:duck", + resolved_ids=["id1", "id2"], + roster_hash="abc123", + ts_utc=1000000, + from_location_id=None, + ) + + assert ctx.filter == "species:duck" + assert ctx.resolved_ids == ["id1", "id2"] + assert ctx.roster_hash == "abc123" + assert ctx.ts_utc == 1000000 + assert ctx.from_location_id is None + assert ctx.confirmed is False + assert ctx.resolver_version == "v1" + + def test_creates_with_optional_fields(self): + """SelectionContext can be created with optional fields.""" + ctx = SelectionContext( + filter="", + resolved_ids=["id1"], + roster_hash="abc", + ts_utc=1000, + from_location_id="loc123", + confirmed=True, + resolver_version="v1", + ) + + assert ctx.from_location_id == "loc123" + assert ctx.confirmed is True + + +class TestSelectionDiffModel: + """Tests for SelectionDiff dataclass.""" + + def test_creates_with_all_fields(self): + """SelectionDiff can be created with all fields.""" + diff = SelectionDiff( + added=["id3", "id4"], + removed=["id1"], + server_count=3, + client_count=2, + ) + + assert diff.added == ["id3", "id4"] + assert diff.removed == ["id1"] + assert diff.server_count == 3 + assert diff.client_count == 2 + + +# ============================================================================ +# Tests for validate_selection - hash match +# ============================================================================ + + +class TestValidateSelectionHashMatch: + """Tests for validate_selection when hashes match.""" + + def test_returns_valid_when_hashes_match(self, seeded_db, animal_service, strip1_location_id): + """validate_selection returns valid=True when hashes match.""" + # Create a cohort + payload = make_cohort_payload(strip1_location_id, count=3) + ts_utc = int(time.time() * 1000) + animal_service.create_cohort(payload, ts_utc, "test_user") + + # Resolve at same timestamp to get correct hash + filter_ast = parse_filter("species:duck") + resolution = resolve_filter(seeded_db, filter_ast, ts_utc) + + # Create context with matching hash + ctx = SelectionContext( + filter="species:duck", + resolved_ids=resolution.animal_ids, + roster_hash=resolution.roster_hash, + ts_utc=ts_utc, + from_location_id=None, + ) + + result = validate_selection(seeded_db, ctx) + + assert result.valid is True + assert result.resolved_ids == resolution.animal_ids + assert result.roster_hash == resolution.roster_hash + assert result.diff is None + + def test_returns_valid_with_empty_filter(self, seeded_db, animal_service, strip1_location_id): + """validate_selection works with empty filter (match all).""" + # Create a cohort + payload = make_cohort_payload(strip1_location_id, count=2) + ts_utc = int(time.time() * 1000) + animal_service.create_cohort(payload, ts_utc, "test_user") + + # Resolve with empty filter + filter_ast = parse_filter("") + resolution = resolve_filter(seeded_db, filter_ast, ts_utc) + + ctx = SelectionContext( + filter="", + resolved_ids=resolution.animal_ids, + roster_hash=resolution.roster_hash, + ts_utc=ts_utc, + from_location_id=None, + ) + + result = validate_selection(seeded_db, ctx) + + assert result.valid is True + assert result.diff is None + + +# ============================================================================ +# Tests for validate_selection - hash mismatch +# ============================================================================ + + +class TestValidateSelectionHashMismatch: + """Tests for validate_selection when hashes don't match.""" + + def test_returns_invalid_when_animal_added(self, seeded_db, animal_service, strip1_location_id): + """validate_selection returns valid=False when new animal was added.""" + # Create initial cohort + ts_before = int(time.time() * 1000) + payload1 = make_cohort_payload(strip1_location_id, count=2) + animal_service.create_cohort(payload1, ts_before, "test_user") + + # Client resolves at ts_before + filter_ast = parse_filter("species:duck") + client_resolution = resolve_filter(seeded_db, filter_ast, ts_before) + + # Add another animal after client resolution + ts_after = ts_before + 1000 + payload2 = make_cohort_payload(strip1_location_id, count=1) + event2 = animal_service.create_cohort(payload2, ts_after, "test_user") + new_animal_id = event2.entity_refs["animal_ids"][0] + + # Create context with old hash but at ts_after + ctx = SelectionContext( + filter="species:duck", + resolved_ids=client_resolution.animal_ids, + roster_hash=client_resolution.roster_hash, + ts_utc=ts_after, + from_location_id=None, + ) + + result = validate_selection(seeded_db, ctx) + + assert result.valid is False + assert result.diff is not None + assert new_animal_id in result.diff.added + assert result.diff.removed == [] + assert result.diff.server_count == 3 + assert result.diff.client_count == 2 + + def test_returns_invalid_when_animal_moved_away( + self, seeded_db, animal_service, strip1_location_id, strip2_location_id + ): + """validate_selection returns valid=False when animal moved to different location.""" + # Create cohort at Strip 1 + ts_create = int(time.time() * 1000) + payload = make_cohort_payload(strip1_location_id, count=3) + event = animal_service.create_cohort(payload, ts_create, "test_user") + animal_ids = event.entity_refs["animal_ids"] + + # Client resolves Strip 1 filter + filter_ast = parse_filter("location:'Strip 1'") + client_resolution = resolve_filter(seeded_db, filter_ast, ts_create) + + # Move one animal to Strip 2 + ts_move = ts_create + 1000 + move_payload = AnimalMovedPayload( + resolved_ids=[animal_ids[0]], + from_location_id=strip1_location_id, + to_location_id=strip2_location_id, + ) + animal_service.move_animals(move_payload, ts_move, "test_user") + + # Create context with old hash but at ts_move + ctx = SelectionContext( + filter="location:'Strip 1'", + resolved_ids=client_resolution.animal_ids, + roster_hash=client_resolution.roster_hash, + ts_utc=ts_move, + from_location_id=None, + ) + + result = validate_selection(seeded_db, ctx) + + assert result.valid is False + assert result.diff is not None + assert result.diff.added == [] + assert animal_ids[0] in result.diff.removed + assert result.diff.server_count == 2 + assert result.diff.client_count == 3 + + +class TestValidateSelectionDiffCorrectness: + """Tests for accurate diff computation.""" + + def test_diff_includes_both_added_and_removed( + self, seeded_db, animal_service, strip1_location_id, strip2_location_id + ): + """validate_selection diff correctly shows both added and removed.""" + # Create cohorts at both locations + ts_create = int(time.time() * 1000) + + # 2 ducks at Strip 1 + payload1 = make_cohort_payload(strip1_location_id, count=2, species="duck") + event1 = animal_service.create_cohort(payload1, ts_create, "test_user") + strip1_ducks = event1.entity_refs["animal_ids"] + + # 1 goose at Strip 1 (won't match species:duck filter) + payload2 = make_cohort_payload(strip1_location_id, count=1, species="goose") + event2 = animal_service.create_cohort(payload2, ts_create, "test_user") + strip1_goose = event2.entity_refs["animal_ids"][0] + + # Client thinks they resolved "location:'Strip 1'" but used wrong IDs + # (simulating a bug where client included goose but excluded one duck) + wrong_client_ids = [strip1_ducks[0], strip1_goose] # missing strip1_ducks[1] + wrong_hash = compute_roster_hash(wrong_client_ids) + + ctx = SelectionContext( + filter="species:duck", + resolved_ids=wrong_client_ids, + roster_hash=wrong_hash, + ts_utc=ts_create, + from_location_id=None, + ) + + result = validate_selection(seeded_db, ctx) + + assert result.valid is False + assert result.diff is not None + # Server resolved both ducks, but client had duck[0] + goose + # So: added = duck[1], removed = goose + assert strip1_ducks[1] in result.diff.added + assert strip1_goose in result.diff.removed + + +# ============================================================================ +# Tests for validate_selection - confirmed bypass +# ============================================================================ + + +class TestValidateSelectionConfirmedBypass: + """Tests for confirmed=True bypassing mismatch.""" + + def test_confirmed_true_returns_valid_despite_mismatch( + self, seeded_db, animal_service, strip1_location_id + ): + """validate_selection returns valid=True when confirmed=True despite mismatch.""" + # Create cohort + ts_before = int(time.time() * 1000) + payload1 = make_cohort_payload(strip1_location_id, count=2) + animal_service.create_cohort(payload1, ts_before, "test_user") + + # Client resolves at ts_before + filter_ast = parse_filter("species:duck") + client_resolution = resolve_filter(seeded_db, filter_ast, ts_before) + + # Add another animal + ts_after = ts_before + 1000 + payload2 = make_cohort_payload(strip1_location_id, count=1) + animal_service.create_cohort(payload2, ts_after, "test_user") + + # Create context with old hash, new timestamp, but confirmed=True + ctx = SelectionContext( + filter="species:duck", + resolved_ids=client_resolution.animal_ids, + roster_hash=client_resolution.roster_hash, + ts_utc=ts_after, + from_location_id=None, + confirmed=True, + ) + + result = validate_selection(seeded_db, ctx) + + assert result.valid is True + # Uses client's IDs when confirmed + assert result.resolved_ids == client_resolution.animal_ids + # Still includes diff for transparency + assert result.diff is not None + assert result.diff.server_count == 3 + assert result.diff.client_count == 2 + + +# ============================================================================ +# Tests for from_location_id in hash +# ============================================================================ + + +class TestValidateSelectionFromLocationId: + """Tests for from_location_id inclusion in hash comparison.""" + + def test_from_location_id_included_in_hash_comparison( + self, seeded_db, animal_service, strip1_location_id, strip2_location_id + ): + """validate_selection includes from_location_id in hash when provided.""" + # Create cohort at Strip 1 + ts_utc = int(time.time() * 1000) + payload = make_cohort_payload(strip1_location_id, count=2) + event = animal_service.create_cohort(payload, ts_utc, "test_user") + animal_ids = event.entity_refs["animal_ids"] + + # Compute hash with from_location_id + hash_with_from = compute_roster_hash(sorted(animal_ids), strip1_location_id) + + # Context with from_location_id should match hash_with_from + ctx_with_from = SelectionContext( + filter="", + resolved_ids=sorted(animal_ids), + roster_hash=hash_with_from, + ts_utc=ts_utc, + from_location_id=strip1_location_id, + ) + + result = validate_selection(seeded_db, ctx_with_from) + + assert result.valid is True + assert result.roster_hash == hash_with_from + + def test_mismatched_from_location_id_causes_invalid( + self, seeded_db, animal_service, strip1_location_id, strip2_location_id + ): + """validate_selection returns invalid if from_location_id differs.""" + # Create cohort at Strip 1 + ts_utc = int(time.time() * 1000) + payload = make_cohort_payload(strip1_location_id, count=2) + event = animal_service.create_cohort(payload, ts_utc, "test_user") + animal_ids = event.entity_refs["animal_ids"] + + # Client computed hash without from_location_id + hash_without_from = compute_roster_hash(sorted(animal_ids)) + + # But context says from_location_id is set (hash should include it) + ctx = SelectionContext( + filter="", + resolved_ids=sorted(animal_ids), + roster_hash=hash_without_from, # Wrong hash + ts_utc=ts_utc, + from_location_id=strip1_location_id, # Server will include this + ) + + result = validate_selection(seeded_db, ctx) + + # Hash mismatch because server includes from_location_id + assert result.valid is False + + +# ============================================================================ +# Tests for SelectionMismatchError +# ============================================================================ + + +class TestSelectionMismatchError: + """Tests for SelectionMismatchError exception.""" + + def test_stores_validation_result(self): + """SelectionMismatchError stores the validation result.""" + diff = SelectionDiff(added=["id1"], removed=[], server_count=2, client_count=1) + result = SelectionValidationResult( + valid=False, + resolved_ids=["id1", "id2"], + roster_hash="abc", + diff=diff, + ) + + error = SelectionMismatchError(result) + + assert error.result is result + assert error.result.diff is diff