# ABOUTME: Tests for selection validation with optimistic locking.
# ABOUTME: Tests hash comparison, diff computation, and confirmed bypass.

import time

import pytest

from animaltrack.events.payloads import AnimalCohortCreatedPayload, AnimalMovedPayload
from animaltrack.events.store import EventStore
from animaltrack.projections import ProjectionRegistry
from animaltrack.projections.animal_registry import AnimalRegistryProjection
from animaltrack.projections.event_animals import EventAnimalsProjection
from animaltrack.projections.intervals import IntervalProjection
from animaltrack.selection import compute_roster_hash, parse_filter, resolve_filter
from animaltrack.selection.validation import (
    SelectionContext,
    SelectionDiff,
    SelectionMismatchError,
    SelectionValidationResult,
    validate_selection,
)
from animaltrack.services.animal import AnimalService


@pytest.fixture
def event_store(seeded_db):
    """Create an EventStore for testing."""
    return EventStore(seeded_db)


@pytest.fixture
def projection_registry(seeded_db):
    """Create a ProjectionRegistry with animal projections registered."""
    registry = ProjectionRegistry()
    registry.register(AnimalRegistryProjection(seeded_db))
    registry.register(EventAnimalsProjection(seeded_db))
    registry.register(IntervalProjection(seeded_db))
    return registry


@pytest.fixture
def animal_service(seeded_db, event_store, projection_registry):
    """Create an AnimalService for testing."""
    return AnimalService(seeded_db, event_store, projection_registry)


@pytest.fixture
def strip1_location_id(seeded_db):
    """Get Strip 1 location ID from seeds."""
    row = seeded_db.execute("SELECT id FROM locations WHERE name = 'Strip 1'").fetchone()
    # Fail with a clear message instead of a TypeError on row[0] if seeds change.
    assert row is not None, "seed data is missing location 'Strip 1'"
    return row[0]


@pytest.fixture
def strip2_location_id(seeded_db):
    """Get Strip 2 location ID from seeds."""
    row = seeded_db.execute("SELECT id FROM locations WHERE name = 'Strip 2'").fetchone()
    # Fail with a clear message instead of a TypeError on row[0] if seeds change.
    assert row is not None, "seed data is missing location 'Strip 2'"
    return row[0]


def make_cohort_payload(
    location_id: str,
    count: int = 3,
    species: str = "duck",
    sex: str = "unknown",
    life_stage: str = "adult",
) -> AnimalCohortCreatedPayload:
    """Create a cohort payload for testing.

    Args:
        location_id: Location the cohort is created at.
        count: Number of animals in the cohort.
        species: Species of every animal in the cohort.
        sex: Sex of every animal in the cohort.
        life_stage: Life stage of every animal in the cohort.

    Returns:
        An AnimalCohortCreatedPayload with origin "purchased".
    """
    return AnimalCohortCreatedPayload(
        species=species,
        count=count,
        life_stage=life_stage,
        sex=sex,
        location_id=location_id,
        origin="purchased",
    )


# ============================================================================
# Tests for SelectionContext and SelectionDiff models
# ============================================================================


class TestSelectionContextModel:
    """Tests for SelectionContext dataclass."""

    def test_creates_with_required_fields(self):
        """SelectionContext can be created with required fields."""
        ctx = SelectionContext(
            filter="species:duck",
            resolved_ids=["id1", "id2"],
            roster_hash="abc123",
            ts_utc=1000000,
            from_location_id=None,
        )
        assert ctx.filter == "species:duck"
        assert ctx.resolved_ids == ["id1", "id2"]
        assert ctx.roster_hash == "abc123"
        assert ctx.ts_utc == 1000000
        assert ctx.from_location_id is None
        assert ctx.confirmed is False
        assert ctx.resolver_version == "v1"

    def test_creates_with_optional_fields(self):
        """SelectionContext can be created with optional fields."""
        ctx = SelectionContext(
            filter="",
            resolved_ids=["id1"],
            roster_hash="abc",
            ts_utc=1000,
            from_location_id="loc123",
            confirmed=True,
            resolver_version="v1",
        )
        assert ctx.from_location_id == "loc123"
        assert ctx.confirmed is True
        assert ctx.resolver_version == "v1"


class TestSelectionDiffModel:
    """Tests for SelectionDiff dataclass."""

    def test_creates_with_all_fields(self):
        """SelectionDiff can be created with all fields."""
        diff = SelectionDiff(
            added=["id3", "id4"],
            removed=["id1"],
            server_count=3,
            client_count=2,
        )
        assert diff.added == ["id3", "id4"]
        assert diff.removed == ["id1"]
        assert diff.server_count == 3
        assert diff.client_count == 2


# ============================================================================
# Tests for validate_selection - hash match
# ============================================================================


class TestValidateSelectionHashMatch:
    """Tests for validate_selection when hashes match."""

    def test_returns_valid_when_hashes_match(self, seeded_db, animal_service, strip1_location_id):
        """validate_selection returns valid=True when hashes match."""
        # Create a cohort
        payload = make_cohort_payload(strip1_location_id, count=3)
        ts_utc = int(time.time() * 1000)
        animal_service.create_cohort(payload, ts_utc, "test_user")

        # Resolve at same timestamp to get correct hash
        filter_ast = parse_filter("species:duck")
        resolution = resolve_filter(seeded_db, filter_ast, ts_utc)

        # Create context with matching hash
        ctx = SelectionContext(
            filter="species:duck",
            resolved_ids=resolution.animal_ids,
            roster_hash=resolution.roster_hash,
            ts_utc=ts_utc,
            from_location_id=None,
        )

        result = validate_selection(seeded_db, ctx)

        assert result.valid is True
        assert result.resolved_ids == resolution.animal_ids
        assert result.roster_hash == resolution.roster_hash
        assert result.diff is None

    def test_returns_valid_with_empty_filter(self, seeded_db, animal_service, strip1_location_id):
        """validate_selection works with empty filter (match all)."""
        # Create a cohort
        payload = make_cohort_payload(strip1_location_id, count=2)
        ts_utc = int(time.time() * 1000)
        animal_service.create_cohort(payload, ts_utc, "test_user")

        # Resolve with empty filter
        filter_ast = parse_filter("")
        resolution = resolve_filter(seeded_db, filter_ast, ts_utc)

        ctx = SelectionContext(
            filter="",
            resolved_ids=resolution.animal_ids,
            roster_hash=resolution.roster_hash,
            ts_utc=ts_utc,
            from_location_id=None,
        )

        result = validate_selection(seeded_db, ctx)

        assert result.valid is True
        assert result.diff is None


# ============================================================================
# Tests for validate_selection - hash mismatch
# ============================================================================


class TestValidateSelectionHashMismatch:
    """Tests for validate_selection when hashes don't match."""

    def test_returns_invalid_when_animal_added(self, seeded_db, animal_service, strip1_location_id):
        """validate_selection returns valid=False when new animal was added."""
        # Create initial cohort
        ts_before = int(time.time() * 1000)
        payload1 = make_cohort_payload(strip1_location_id, count=2)
        animal_service.create_cohort(payload1, ts_before, "test_user")

        # Client resolves at ts_before
        filter_ast = parse_filter("species:duck")
        client_resolution = resolve_filter(seeded_db, filter_ast, ts_before)

        # Add another animal after client resolution
        ts_after = ts_before + 1000
        payload2 = make_cohort_payload(strip1_location_id, count=1)
        event2 = animal_service.create_cohort(payload2, ts_after, "test_user")
        new_animal_id = event2.entity_refs["animal_ids"][0]

        # Create context with old hash but at ts_after
        ctx = SelectionContext(
            filter="species:duck",
            resolved_ids=client_resolution.animal_ids,
            roster_hash=client_resolution.roster_hash,
            ts_utc=ts_after,
            from_location_id=None,
        )

        result = validate_selection(seeded_db, ctx)

        assert result.valid is False
        assert result.diff is not None
        assert new_animal_id in result.diff.added
        assert result.diff.removed == []
        assert result.diff.server_count == 3
        assert result.diff.client_count == 2

    def test_returns_invalid_when_animal_moved_away(
        self, seeded_db, animal_service, strip1_location_id, strip2_location_id
    ):
        """validate_selection returns valid=False when animal moved to different location."""
        # Create cohort at Strip 1
        ts_create = int(time.time() * 1000)
        payload = make_cohort_payload(strip1_location_id, count=3)
        event = animal_service.create_cohort(payload, ts_create, "test_user")
        animal_ids = event.entity_refs["animal_ids"]

        # Client resolves Strip 1 filter
        filter_ast = parse_filter("location:'Strip 1'")
        client_resolution = resolve_filter(seeded_db, filter_ast, ts_create)

        # Move one animal to Strip 2
        ts_move = ts_create + 1000
        move_payload = AnimalMovedPayload(
            resolved_ids=[animal_ids[0]],
            from_location_id=strip1_location_id,
            to_location_id=strip2_location_id,
        )
        animal_service.move_animals(move_payload, ts_move, "test_user")

        # Create context with old hash but at ts_move
        ctx = SelectionContext(
            filter="location:'Strip 1'",
            resolved_ids=client_resolution.animal_ids,
            roster_hash=client_resolution.roster_hash,
            ts_utc=ts_move,
            from_location_id=None,
        )

        result = validate_selection(seeded_db, ctx)

        assert result.valid is False
        assert result.diff is not None
        assert result.diff.added == []
        assert animal_ids[0] in result.diff.removed
        assert result.diff.server_count == 2
        assert result.diff.client_count == 3


class TestValidateSelectionDiffCorrectness:
    """Tests for accurate diff computation."""

    def test_diff_includes_both_added_and_removed(self, seeded_db, animal_service, strip1_location_id):
        """validate_selection diff correctly shows both added and removed."""
        # All animals are created at Strip 1; only species distinguishes them.
        ts_create = int(time.time() * 1000)

        # 2 ducks at Strip 1 (both match species:duck)
        payload1 = make_cohort_payload(strip1_location_id, count=2, species="duck")
        event1 = animal_service.create_cohort(payload1, ts_create, "test_user")
        strip1_ducks = event1.entity_refs["animal_ids"]

        # 1 goose at Strip 1 (won't match species:duck filter)
        payload2 = make_cohort_payload(strip1_location_id, count=1, species="goose")
        event2 = animal_service.create_cohort(payload2, ts_create, "test_user")
        strip1_goose = event2.entity_refs["animal_ids"][0]

        # Client claims to have resolved "species:duck" but sent the wrong IDs
        # (simulating a client bug: the goose is included, the second duck is missing)
        wrong_client_ids = [strip1_ducks[0], strip1_goose]
        wrong_hash = compute_roster_hash(wrong_client_ids)

        ctx = SelectionContext(
            filter="species:duck",
            resolved_ids=wrong_client_ids,
            roster_hash=wrong_hash,
            ts_utc=ts_create,
            from_location_id=None,
        )

        result = validate_selection(seeded_db, ctx)

        assert result.valid is False
        assert result.diff is not None
        # Server resolved both ducks, but client had duck[0] + goose
        # So: added = duck[1], removed = goose
        assert result.diff.added == [strip1_ducks[1]]
        assert result.diff.removed == [strip1_goose]
        assert result.diff.server_count == 2
        assert result.diff.client_count == 2
# ============================================================================
# Tests for validate_selection - confirmed bypass
# ============================================================================


class TestValidateSelectionConfirmedBypass:
    """Tests for confirmed=True bypassing mismatch."""

    def test_confirmed_true_returns_valid_despite_mismatch(
        self, seeded_db, animal_service, strip1_location_id
    ):
        """validate_selection returns valid=True when confirmed=True despite mismatch."""
        # Create cohort
        ts_before = int(time.time() * 1000)
        payload1 = make_cohort_payload(strip1_location_id, count=2)
        animal_service.create_cohort(payload1, ts_before, "test_user")

        # Client resolves at ts_before
        filter_ast = parse_filter("species:duck")
        client_resolution = resolve_filter(seeded_db, filter_ast, ts_before)

        # Add another animal
        ts_after = ts_before + 1000
        payload2 = make_cohort_payload(strip1_location_id, count=1)
        event2 = animal_service.create_cohort(payload2, ts_after, "test_user")
        new_animal_id = event2.entity_refs["animal_ids"][0]

        # Create context with old hash, new timestamp, but confirmed=True
        ctx = SelectionContext(
            filter="species:duck",
            resolved_ids=client_resolution.animal_ids,
            roster_hash=client_resolution.roster_hash,
            ts_utc=ts_after,
            from_location_id=None,
            confirmed=True,
        )

        result = validate_selection(seeded_db, ctx)

        assert result.valid is True
        # Uses client's IDs when confirmed
        assert result.resolved_ids == client_resolution.animal_ids
        # Still includes diff for transparency, and it describes the real change
        assert result.diff is not None
        assert new_animal_id in result.diff.added
        assert result.diff.removed == []
        assert result.diff.server_count == 3
        assert result.diff.client_count == 2


# ============================================================================
# Tests for from_location_id in hash
# ============================================================================


class TestValidateSelectionFromLocationId:
    """Tests for from_location_id inclusion in hash comparison."""

    def test_from_location_id_included_in_hash_comparison(
        self, seeded_db, animal_service, strip1_location_id
    ):
        """validate_selection includes from_location_id in hash when provided."""
        # Create cohort at Strip 1
        ts_utc = int(time.time() * 1000)
        payload = make_cohort_payload(strip1_location_id, count=2)
        event = animal_service.create_cohort(payload, ts_utc, "test_user")
        animal_ids = event.entity_refs["animal_ids"]

        # Compute hash with from_location_id
        hash_with_from = compute_roster_hash(sorted(animal_ids), strip1_location_id)

        # Context with from_location_id should match hash_with_from
        ctx_with_from = SelectionContext(
            filter="",
            resolved_ids=sorted(animal_ids),
            roster_hash=hash_with_from,
            ts_utc=ts_utc,
            from_location_id=strip1_location_id,
        )

        result = validate_selection(seeded_db, ctx_with_from)

        assert result.valid is True
        assert result.roster_hash == hash_with_from

    def test_mismatched_from_location_id_causes_invalid(
        self, seeded_db, animal_service, strip1_location_id
    ):
        """validate_selection returns invalid if from_location_id differs."""
        # Create cohort at Strip 1
        ts_utc = int(time.time() * 1000)
        payload = make_cohort_payload(strip1_location_id, count=2)
        event = animal_service.create_cohort(payload, ts_utc, "test_user")
        animal_ids = event.entity_refs["animal_ids"]

        # Client computed hash without from_location_id
        hash_without_from = compute_roster_hash(sorted(animal_ids))

        # But context says from_location_id is set (hash should include it)
        ctx = SelectionContext(
            filter="",
            resolved_ids=sorted(animal_ids),
            roster_hash=hash_without_from,  # Wrong hash
            ts_utc=ts_utc,
            from_location_id=strip1_location_id,  # Server will include this
        )

        result = validate_selection(seeded_db, ctx)

        # Hash mismatch because server includes from_location_id
        assert result.valid is False


# ============================================================================
# Tests for SelectionMismatchError
# ============================================================================


class TestSelectionMismatchError:
    """Tests for SelectionMismatchError exception."""

    def test_stores_validation_result(self):
        """SelectionMismatchError stores the validation result."""
        diff = SelectionDiff(added=["id1"], removed=[], server_count=2, client_count=1)
        result = SelectionValidationResult(
            valid=False,
            resolved_ids=["id1", "id2"],
            roster_hash="abc",
            diff=diff,
        )

        error = SelectionMismatchError(result)

        # The exception keeps a reference to the exact result object it was given.
        assert error.result is result
        assert error.result.diff is diff