# ABOUTME: Tests for EventAnimalsProjection.
# ABOUTME: Validates event_animals link table updates on animal events.

from animaltrack.events.types import ANIMAL_COHORT_CREATED
from animaltrack.models.events import Event
from animaltrack.projections.event_animals import EventAnimalsProjection


def make_cohort_event(
    event_id: str,
    animal_ids: list[str],
    location_id: str = "01ARZ3NDEKTSV4RRFFQ69G5FAV",
    ts_utc: int = 1704067200000,
) -> Event:
    """Create a test AnimalCohortCreated event.

    Args:
        event_id: Identifier to give the event.
        animal_ids: Animals in the cohort; also used for ``payload["count"]``.
        location_id: Location placed in both ``entity_refs`` and ``payload``.
        ts_utc: Event timestamp (presumably epoch milliseconds — the default
            is 2024-01-01T00:00:00Z in ms).

    Returns:
        An ``Event`` of type ANIMAL_COHORT_CREATED with version 1.
    """
    return Event(
        id=event_id,
        type=ANIMAL_COHORT_CREATED,
        ts_utc=ts_utc,
        actor="test_user",
        entity_refs={
            "location_id": location_id,
            "animal_ids": animal_ids,
        },
        payload={
            "species": "duck",
            "count": len(animal_ids),
            "life_stage": "adult",
            "sex": "unknown",
            "location_id": location_id,
            "origin": "purchased",
            "notes": None,
        },
        version=1,
    )


def _strip1_location_id(db) -> str:
    """Return the id of the seeded 'Strip 1' location.

    Fails with a clear assertion (rather than a TypeError on ``None[0]``)
    if the fixture did not seed that location.
    """
    row = db.execute("SELECT id FROM locations WHERE name = 'Strip 1'").fetchone()
    assert row is not None, "seeded_db is missing the 'Strip 1' location"
    return row[0]


def _count_event_animals(db) -> int:
    """Return the total number of rows currently in event_animals."""
    return db.execute("SELECT COUNT(*) FROM event_animals").fetchone()[0]


def _fetch_link_column(db, column: str, animal_id: str):
    """Return ``column`` of the event_animals row for ``animal_id``.

    ``column`` is only ever a fixed literal from this module, never user
    input, so interpolating it into the SQL is safe here.
    Fails with a clear assertion if no link row exists for the animal.
    """
    row = db.execute(
        f"SELECT {column} FROM event_animals WHERE animal_id = ?",
        (animal_id,),
    ).fetchone()
    assert row is not None, f"no event_animals row for animal {animal_id}"
    return row[0]


class TestEventAnimalsProjectionEventTypes:
    """Tests for get_event_types method."""

    def test_handles_animal_cohort_created(self, seeded_db):
        """Projection handles AnimalCohortCreated event type."""
        projection = EventAnimalsProjection(seeded_db)
        assert ANIMAL_COHORT_CREATED in projection.get_event_types()


class TestEventAnimalsProjectionApply:
    """Tests for apply()."""

    def test_creates_event_animal_link_for_each_animal(self, seeded_db):
        """Apply creates one row in event_animals per animal_id."""
        location_id = _strip1_location_id(seeded_db)
        animal_ids = [
            "01ARZ3NDEKTSV4RRFFQ69G5A01",
            "01ARZ3NDEKTSV4RRFFQ69G5A02",
            "01ARZ3NDEKTSV4RRFFQ69G5A03",
        ]
        event_id = "01ARZ3NDEKTSV4RRFFQ69G5001"

        projection = EventAnimalsProjection(seeded_db)
        event = make_cohort_event(event_id, animal_ids, location_id=location_id)
        projection.apply(event)

        # Exactly one link row per animal in the cohort.
        assert _count_event_animals(seeded_db) == len(animal_ids)

        # Each animal is linked back to the applied event.
        for animal_id in animal_ids:
            assert _fetch_link_column(seeded_db, "event_id", animal_id) == event_id

    def test_event_animal_link_has_correct_event_id(self, seeded_db):
        """Event animal link has correct event_id."""
        location_id = _strip1_location_id(seeded_db)
        animal_ids = ["01ARZ3NDEKTSV4RRFFQ69G5A01"]
        event_id = "01ARZ3NDEKTSV4RRFFQ69G5001"

        projection = EventAnimalsProjection(seeded_db)
        event = make_cohort_event(event_id, animal_ids, location_id=location_id)
        projection.apply(event)

        assert _fetch_link_column(seeded_db, "event_id", animal_ids[0]) == event_id

    def test_event_animal_link_has_correct_ts_utc(self, seeded_db):
        """Event animal link has correct ts_utc."""
        location_id = _strip1_location_id(seeded_db)
        animal_ids = ["01ARZ3NDEKTSV4RRFFQ69G5A01"]
        event_id = "01ARZ3NDEKTSV4RRFFQ69G5001"
        # Deliberately differs from make_cohort_event's default so the test
        # proves the link copies the event's own timestamp.
        ts_utc = 1704070800000

        projection = EventAnimalsProjection(seeded_db)
        event = make_cohort_event(event_id, animal_ids, location_id=location_id, ts_utc=ts_utc)
        projection.apply(event)

        assert _fetch_link_column(seeded_db, "ts_utc", animal_ids[0]) == ts_utc


class TestEventAnimalsProjectionRevert:
    """Tests for revert()."""

    def test_removes_event_animal_links(self, seeded_db):
        """Revert deletes rows from event_animals."""
        location_id = _strip1_location_id(seeded_db)
        animal_ids = [
            "01ARZ3NDEKTSV4RRFFQ69G5A01",
            "01ARZ3NDEKTSV4RRFFQ69G5A02",
        ]
        event_id = "01ARZ3NDEKTSV4RRFFQ69G5001"

        projection = EventAnimalsProjection(seeded_db)
        event = make_cohort_event(event_id, animal_ids, location_id=location_id)
        projection.apply(event)

        # Precondition: apply created the links.
        assert _count_event_animals(seeded_db) == 2

        projection.revert(event)

        # All links from the reverted event are gone.
        assert _count_event_animals(seeded_db) == 0

    def test_revert_only_affects_specific_event(self, seeded_db):
        """Revert only removes links for the specific event."""
        location_id = _strip1_location_id(seeded_db)
        projection = EventAnimalsProjection(seeded_db)

        # First event
        animal_ids_1 = ["01ARZ3NDEKTSV4RRFFQ69G5A01"]
        event_id_1 = "01ARZ3NDEKTSV4RRFFQ69G5001"
        event1 = make_cohort_event(event_id_1, animal_ids_1, location_id=location_id)
        projection.apply(event1)

        # Second event, later timestamp
        animal_ids_2 = ["01ARZ3NDEKTSV4RRFFQ69G5A02"]
        event_id_2 = "01ARZ3NDEKTSV4RRFFQ69G5002"
        event2 = make_cohort_event(
            event_id_2, animal_ids_2, location_id=location_id, ts_utc=1704067300000
        )
        projection.apply(event2)

        # Precondition: both events' links exist.
        assert _count_event_animals(seeded_db) == 2

        # Revert only event1
        projection.revert(event1)

        # Exactly event2's link survives, for event2's animal.
        assert _count_event_animals(seeded_db) == 1
        row = seeded_db.execute("SELECT event_id, animal_id FROM event_animals").fetchone()
        assert row is not None
        assert row[0] == event_id_2
        assert row[1] == animal_ids_2[0]