diff --git a/src/animaltrack/projections/animal_registry.py b/src/animaltrack/projections/animal_registry.py index 175bf9e..ca7bb0e 100644 --- a/src/animaltrack/projections/animal_registry.py +++ b/src/animaltrack/projections/animal_registry.py @@ -3,7 +3,11 @@ from typing import Any -from animaltrack.events.types import ANIMAL_COHORT_CREATED, ANIMAL_MOVED +from animaltrack.events.types import ( + ANIMAL_ATTRIBUTES_UPDATED, + ANIMAL_COHORT_CREATED, + ANIMAL_MOVED, +) from animaltrack.models.events import Event from animaltrack.projections.base import Projection @@ -26,7 +30,7 @@ class AnimalRegistryProjection(Projection): def get_event_types(self) -> list[str]: """Return the event types this projection handles.""" - return [ANIMAL_COHORT_CREATED, ANIMAL_MOVED] + return [ANIMAL_COHORT_CREATED, ANIMAL_MOVED, ANIMAL_ATTRIBUTES_UPDATED] def apply(self, event: Event) -> None: """Apply an event to update registry tables.""" @@ -34,6 +38,8 @@ class AnimalRegistryProjection(Projection): self._apply_cohort_created(event) elif event.type == ANIMAL_MOVED: self._apply_animal_moved(event) + elif event.type == ANIMAL_ATTRIBUTES_UPDATED: + self._apply_attributes_updated(event) def revert(self, event: Event) -> None: """Revert an event from registry tables.""" @@ -41,6 +47,8 @@ class AnimalRegistryProjection(Projection): self._revert_cohort_created(event) elif event.type == ANIMAL_MOVED: self._revert_animal_moved(event) + elif event.type == ANIMAL_ATTRIBUTES_UPDATED: + self._revert_attributes_updated(event) def _apply_cohort_created(self, event: Event) -> None: """Create animals in registry from cohort event. @@ -187,3 +195,110 @@ class AnimalRegistryProjection(Projection): """, (from_location_id, animal_id), ) + + def _apply_attributes_updated(self, event: Event) -> None: + """Update animal attributes in registry tables. + + For each animal, updates the changed attributes in both + animal_registry and live_animals_by_location tables. + """ + animal_ids = event.entity_refs.get("animal_ids", []) + changed_attrs = event.entity_refs.get("changed_attrs", {}) + ts_utc = event.ts_utc + + for animal_id in animal_ids: + animal_changes = changed_attrs.get(animal_id, {}) + if not animal_changes: + continue + + # Build dynamic SQL for animal_registry + set_clauses = ["last_event_utc = ?"] + values = [ts_utc] + + for attr, values_dict in animal_changes.items(): + set_clauses.append(f"{attr} = ?") + values.append(values_dict["new"]) + + values.append(animal_id) + + self.db.execute( + f""" + UPDATE animal_registry + SET {", ".join(set_clauses)} + WHERE animal_id = ? + """, + values, + ) + + # Build dynamic SQL for live_animals_by_location + set_clauses_live = [] + values_live = [] + + for attr, values_dict in animal_changes.items(): + set_clauses_live.append(f"{attr} = ?") + values_live.append(values_dict["new"]) + + values_live.append(animal_id) + + if set_clauses_live: + self.db.execute( + f""" + UPDATE live_animals_by_location + SET {", ".join(set_clauses_live)} + WHERE animal_id = ? + """, + values_live, + ) + + def _revert_attributes_updated(self, event: Event) -> None: + """Revert attribute updates, restoring previous values. + + Uses changed_attrs from entity_refs to restore + the previous attribute values. + """ + animal_ids = event.entity_refs.get("animal_ids", []) + changed_attrs = event.entity_refs.get("changed_attrs", {}) + + for animal_id in animal_ids: + animal_changes = changed_attrs.get(animal_id, {}) + if not animal_changes: + continue + + # Build dynamic SQL for animal_registry (restore old values) + set_clauses = [] + values = [] + + for attr, values_dict in animal_changes.items(): + set_clauses.append(f"{attr} = ?") + values.append(values_dict["old"]) + + values.append(animal_id) + + self.db.execute( + f""" + UPDATE animal_registry + SET {", ".join(set_clauses)} + WHERE animal_id = ? + """, + values, + ) + + # Build dynamic SQL for live_animals_by_location + set_clauses_live = [] + values_live = [] + + for attr, values_dict in animal_changes.items(): + set_clauses_live.append(f"{attr} = ?") + values_live.append(values_dict["old"]) + + values_live.append(animal_id) + + if set_clauses_live: + self.db.execute( + f""" + UPDATE live_animals_by_location + SET {", ".join(set_clauses_live)} + WHERE animal_id = ? + """, + values_live, + ) diff --git a/src/animaltrack/projections/event_animals.py b/src/animaltrack/projections/event_animals.py index 626ace9..0fc46f9 100644 --- a/src/animaltrack/projections/event_animals.py +++ b/src/animaltrack/projections/event_animals.py @@ -3,7 +3,11 @@ from typing import Any -from animaltrack.events.types import ANIMAL_COHORT_CREATED, ANIMAL_MOVED +from animaltrack.events.types import ( + ANIMAL_ATTRIBUTES_UPDATED, + ANIMAL_COHORT_CREATED, + ANIMAL_MOVED, +) from animaltrack.models.events import Event from animaltrack.projections.base import Projection @@ -26,7 +30,7 @@ class EventAnimalsProjection(Projection): def get_event_types(self) -> list[str]: """Return the event types this projection handles.""" - return [ANIMAL_COHORT_CREATED, ANIMAL_MOVED] + return [ANIMAL_COHORT_CREATED, ANIMAL_MOVED, ANIMAL_ATTRIBUTES_UPDATED] def apply(self, event: Event) -> None: """Link event to affected animals.""" diff --git a/src/animaltrack/projections/intervals.py b/src/animaltrack/projections/intervals.py index 170dc68..db8a698 100644 --- a/src/animaltrack/projections/intervals.py +++ b/src/animaltrack/projections/intervals.py @@ -3,7 +3,11 @@ from typing import Any -from animaltrack.events.types import ANIMAL_COHORT_CREATED, ANIMAL_MOVED +from animaltrack.events.types import ( + ANIMAL_ATTRIBUTES_UPDATED, + ANIMAL_COHORT_CREATED, + ANIMAL_MOVED, +) from animaltrack.models.events import Event from animaltrack.projections.base import Projection @@ -29,7 +33,7 @@ class IntervalProjection(Projection): def get_event_types(self) -> list[str]: """Return the event types this projection handles.""" - return [ANIMAL_COHORT_CREATED, ANIMAL_MOVED] + return [ANIMAL_COHORT_CREATED, ANIMAL_MOVED, ANIMAL_ATTRIBUTES_UPDATED] def apply(self, event: Event) -> None: """Create intervals for event.""" @@ -37,6 +41,8 @@ class IntervalProjection(Projection): self._apply_cohort_created(event) elif event.type == ANIMAL_MOVED: self._apply_animal_moved(event) + elif event.type == ANIMAL_ATTRIBUTES_UPDATED: + self._apply_attributes_updated(event) def revert(self, event: Event) -> None: """Remove intervals created by event.""" @@ -44,6 +50,8 @@ class IntervalProjection(Projection): self._revert_cohort_created(event) elif event.type == ANIMAL_MOVED: self._revert_animal_moved(event) + elif event.type == ANIMAL_ATTRIBUTES_UPDATED: + self._revert_attributes_updated(event) def _apply_cohort_created(self, event: Event) -> None: """Create initial intervals for new animals. @@ -170,3 +178,79 @@ class IntervalProjection(Projection): """, (animal_id, from_location_id, ts_utc), ) + + def _apply_attributes_updated(self, event: Event) -> None: + """Close old attribute intervals and open new ones for changed attrs. + + For each animal: + - For each changed attribute in the payload set: + - Close the current open interval with end_utc=ts_utc + - Create a new open interval with the new value + - Only create intervals for actually changed values + """ + animal_ids = event.entity_refs.get("animal_ids", []) + changed_attrs = event.entity_refs.get("changed_attrs", {}) + ts_utc = event.ts_utc + + for animal_id in animal_ids: + animal_changes = changed_attrs.get(animal_id, {}) + for attr, values in animal_changes.items(): + old_value = values["old"] + new_value = values["new"] + + # Close the old interval + self.db.execute( + """ + UPDATE animal_attr_intervals + SET end_utc = ? + WHERE animal_id = ? AND attr = ? AND value = ? AND end_utc IS NULL + """, + (ts_utc, animal_id, attr, old_value), + ) + + # Create new interval + self.db.execute( + """ + INSERT INTO animal_attr_intervals + (animal_id, attr, value, start_utc, end_utc) + VALUES (?, ?, ?, ?, NULL) + """, + (animal_id, attr, new_value, ts_utc), + ) + + def _revert_attributes_updated(self, event: Event) -> None: + """Revert attributes by removing new intervals and reopening old ones. + + For each animal: + - For each changed attribute: + - Delete the new interval + - Reopen the old interval by setting end_utc=NULL + """ + animal_ids = event.entity_refs.get("animal_ids", []) + changed_attrs = event.entity_refs.get("changed_attrs", {}) + ts_utc = event.ts_utc + + for animal_id in animal_ids: + animal_changes = changed_attrs.get(animal_id, {}) + for attr, values in animal_changes.items(): + old_value = values["old"] + new_value = values["new"] + + # Delete the new interval + self.db.execute( + """ + DELETE FROM animal_attr_intervals + WHERE animal_id = ? AND attr = ? AND value = ? AND start_utc = ? + """, + (animal_id, attr, new_value, ts_utc), + ) + + # Reopen the old interval + self.db.execute( + """ + UPDATE animal_attr_intervals + SET end_utc = NULL + WHERE animal_id = ? AND attr = ? AND value = ? AND end_utc = ? + """, + (animal_id, attr, old_value, ts_utc), + ) diff --git a/src/animaltrack/services/animal.py b/src/animaltrack/services/animal.py index f40f8dd..7fa7a3d 100644 --- a/src/animaltrack/services/animal.py +++ b/src/animaltrack/services/animal.py @@ -4,10 +4,18 @@ from typing import Any from animaltrack.db import transaction -from animaltrack.events.payloads import AnimalCohortCreatedPayload, AnimalMovedPayload +from animaltrack.events.payloads import ( + AnimalAttributesUpdatedPayload, + AnimalCohortCreatedPayload, + AnimalMovedPayload, +) from animaltrack.events.processor import process_event from animaltrack.events.store import EventStore -from animaltrack.events.types import ANIMAL_COHORT_CREATED, ANIMAL_MOVED +from animaltrack.events.types import ( + ANIMAL_ATTRIBUTES_UPDATED, + ANIMAL_COHORT_CREATED, + ANIMAL_MOVED, +) from animaltrack.id_gen import generate_id from animaltrack.models.events import Event from animaltrack.projections import ProjectionRegistry @@ -246,3 +254,119 @@ class AnimalService: raise ValidationError(msg) return from_location_id + + def update_attributes( + self, + payload: AnimalAttributesUpdatedPayload, + ts_utc: int, + actor: str, + nonce: str | None = None, + route: str | None = None, + ) -> Event: + """Update attributes for animals. + + Creates an AnimalAttributesUpdated event and processes it through + all registered projections. All operations happen atomically + within a transaction. + + Args: + payload: Validated attributes update payload with resolved_ids and set. + ts_utc: Timestamp in milliseconds since epoch. + actor: The user performing the update. + nonce: Optional idempotency nonce. + route: Required if nonce provided. + + Returns: + The created event. + + Raises: + ValidationError: If validation fails. + """ + # Validate at least one attribute is being set + attr_set = payload.set + if attr_set.sex is None and attr_set.life_stage is None and attr_set.repro_status is None: + msg = "Must provide at least one attribute to update" + raise ValidationError(msg) + + # Validate all animals exist and compute changes + changed_attrs = self._compute_attribute_changes(payload.resolved_ids, attr_set) + + # Build entity_refs with animal IDs and changed_attrs + entity_refs = { + "animal_ids": payload.resolved_ids, + "changed_attrs": changed_attrs, + } + + with transaction(self.db): + # Append event to store + event = self.event_store.append_event( + event_type=ANIMAL_ATTRIBUTES_UPDATED, + ts_utc=ts_utc, + actor=actor, + entity_refs=entity_refs, + payload=payload.model_dump(), + nonce=nonce, + route=route, + ) + + # Process event through projections + process_event(event, self.registry) + + return event + + def _compute_attribute_changes( + self, + animal_ids: list[str], + attr_set: Any, + ) -> dict[str, dict[str, dict[str, str]]]: + """Compute which attributes are actually changing for each animal. + + Args: + animal_ids: List of animal IDs to check. + attr_set: AttributeSet with new values. + + Returns: + Dict mapping animal_id -> {attr -> {"old": ..., "new": ...}} + Only includes attributes that are actually changing. + + Raises: + ValidationError: If any animal doesn't exist. + """ + changed_attrs: dict[str, dict[str, dict[str, str]]] = {} + + for animal_id in animal_ids: + row = self.db.execute( + "SELECT sex, life_stage, repro_status FROM animal_registry WHERE animal_id = ?", + (animal_id,), + ).fetchone() + + if row is None: + msg = f"Animal {animal_id} not found" + raise ValidationError(msg) + + current_sex, current_life_stage, current_repro_status = row + animal_changes: dict[str, dict[str, str]] = {} + + # Check each attribute if it's set and different + if attr_set.sex is not None and attr_set.sex.value != current_sex: + animal_changes["sex"] = {"old": current_sex, "new": attr_set.sex.value} + + if attr_set.life_stage is not None and attr_set.life_stage.value != current_life_stage: + animal_changes["life_stage"] = { + "old": current_life_stage, + "new": attr_set.life_stage.value, + } + + if ( + attr_set.repro_status is not None + and attr_set.repro_status.value != current_repro_status + ): + animal_changes["repro_status"] = { + "old": current_repro_status, + "new": attr_set.repro_status.value, + } + + if animal_changes: + changed_attrs[animal_id] = animal_changes + + return changed_attrs diff --git a/tests/test_service_animal.py b/tests/test_service_animal.py index 4ca9236..46a9a1d 100644 --- a/tests/test_service_animal.py +++ b/tests/test_service_animal.py @@ -492,3 +492,251 @@ class TestAnimalServiceMoveValidation: with pytest.raises(ValidationError, match="not found"): animal_service.move_animals(move_payload, int(time.time() * 1000), "test_user") + + +# ============================================================================= +# update_attributes Tests +# ============================================================================= + + +def make_attrs_payload( + resolved_ids: list[str], + sex: str | None = None, + life_stage: str | None = None, + repro_status: str | None = None, +): + """Create an attributes update payload for testing.""" + from animaltrack.events.payloads import AnimalAttributesUpdatedPayload, AttributeSet + + attr_set = AttributeSet(sex=sex, life_stage=life_stage, repro_status=repro_status) + return AnimalAttributesUpdatedPayload( + resolved_ids=resolved_ids, + set=attr_set, + ) + + +class TestAnimalServiceUpdateAttributes: + """Tests for update_attributes().""" + + def test_creates_animal_attributes_updated_event( + self, seeded_db, animal_service, valid_location_id + ): + """update_attributes creates an AnimalAttributesUpdated event.""" + from animaltrack.events.types import ANIMAL_ATTRIBUTES_UPDATED + + # First create a cohort + cohort_payload = make_payload(valid_location_id, count=2) + ts_utc = int(time.time() * 1000) + cohort_event = animal_service.create_cohort(cohort_payload, ts_utc, "test_user") + animal_ids = cohort_event.entity_refs["animal_ids"] + + # Update attributes + attrs_payload = make_attrs_payload(animal_ids, sex="female") + attrs_ts = ts_utc + 1000 + attrs_event = animal_service.update_attributes(attrs_payload, attrs_ts, "test_user") + + assert attrs_event.type == ANIMAL_ATTRIBUTES_UPDATED + assert attrs_event.actor == "test_user" + assert attrs_event.ts_utc == attrs_ts + + def test_event_has_animal_ids_in_entity_refs( + self, seeded_db, animal_service, valid_location_id + ): + """Event entity_refs contains animal_ids list.""" + cohort_payload = make_payload(valid_location_id, count=3) + ts_utc = int(time.time() * 1000) + cohort_event = animal_service.create_cohort(cohort_payload, ts_utc, "test_user") + animal_ids = cohort_event.entity_refs["animal_ids"] + + attrs_payload = make_attrs_payload(animal_ids, life_stage="juvenile") + attrs_event = animal_service.update_attributes(attrs_payload, ts_utc + 1000, "test_user") + + assert "animal_ids" in attrs_event.entity_refs + assert set(attrs_event.entity_refs["animal_ids"]) == set(animal_ids) + + def test_updates_sex_in_registry(self, seeded_db, animal_service, valid_location_id): + """Animals have updated sex in animal_registry table.""" + cohort_payload = make_payload(valid_location_id, count=2, sex="unknown") + ts_utc = int(time.time() * 1000) + cohort_event = animal_service.create_cohort(cohort_payload, ts_utc, "test_user") + animal_ids = cohort_event.entity_refs["animal_ids"] + + attrs_payload = make_attrs_payload(animal_ids, sex="male") + animal_service.update_attributes(attrs_payload, ts_utc + 1000, "test_user") + + # Check each animal has updated sex + for animal_id in animal_ids: + row = seeded_db.execute( + "SELECT sex FROM animal_registry WHERE animal_id = ?", + (animal_id,), + ).fetchone() + assert row[0] == "male" + + def test_updates_life_stage_in_registry(self, seeded_db, animal_service, valid_location_id): + """Animals have updated life_stage in animal_registry table.""" + cohort_payload = make_payload(valid_location_id, count=1, life_stage="juvenile") + ts_utc = int(time.time() * 1000) + cohort_event = animal_service.create_cohort(cohort_payload, ts_utc, "test_user") + animal_ids = cohort_event.entity_refs["animal_ids"] + + attrs_payload = make_attrs_payload(animal_ids, life_stage="adult") + animal_service.update_attributes(attrs_payload, ts_utc + 1000, "test_user") + + row = seeded_db.execute( + "SELECT life_stage FROM animal_registry WHERE animal_id = ?", + (animal_ids[0],), + ).fetchone() + assert row[0] == "adult" + + def test_updates_repro_status_in_registry(self, seeded_db, animal_service, valid_location_id): + """Animals have updated repro_status in animal_registry table.""" + cohort_payload = make_payload(valid_location_id, count=1) + ts_utc = int(time.time() * 1000) + cohort_event = animal_service.create_cohort(cohort_payload, ts_utc, "test_user") + animal_ids = cohort_event.entity_refs["animal_ids"] + + attrs_payload = make_attrs_payload(animal_ids, repro_status="intact") + animal_service.update_attributes(attrs_payload, ts_utc + 1000, "test_user") + + row = seeded_db.execute( + "SELECT repro_status FROM animal_registry WHERE animal_id = ?", + (animal_ids[0],), + ).fetchone() + assert row[0] == "intact" + + def test_creates_attr_intervals_for_changed_attrs_only( + self, seeded_db, animal_service, valid_location_id + ): + """Only changed attrs create new intervals.""" + # Create cohort with sex=unknown, life_stage=adult + cohort_payload = make_payload(valid_location_id, count=1, sex="unknown", life_stage="adult") + ts_utc = int(time.time() * 1000) + cohort_event = animal_service.create_cohort(cohort_payload, ts_utc, "test_user") + animal_id = cohort_event.entity_refs["animal_ids"][0] + + # Initial intervals: sex, life_stage, repro_status, status = 4 + initial_count = seeded_db.execute( + "SELECT COUNT(*) FROM animal_attr_intervals WHERE animal_id = ?", + (animal_id,), + ).fetchone()[0] + assert initial_count == 4 + + # Update only sex (life_stage stays the same) + attrs_payload = make_attrs_payload([animal_id], sex="female") + animal_service.update_attributes(attrs_payload, ts_utc + 1000, "test_user") + + # Should have 5 intervals: 4 initial + 1 new for sex (old one closed) + new_count = seeded_db.execute( + "SELECT COUNT(*) FROM animal_attr_intervals WHERE animal_id = ?", + (animal_id,), + ).fetchone()[0] + assert new_count == 5 + + # Verify old sex interval was closed + closed_sex = seeded_db.execute( + """SELECT end_utc FROM animal_attr_intervals + WHERE animal_id = ? AND attr = 'sex' AND value = 'unknown'""", + (animal_id,), + ).fetchone() + assert closed_sex[0] == ts_utc + 1000 + + # Verify new sex interval is open + open_sex = seeded_db.execute( + """SELECT end_utc FROM animal_attr_intervals + WHERE animal_id = ? AND attr = 'sex' AND value = 'female'""", + (animal_id,), + ).fetchone() + assert open_sex[0] is None + + def test_event_animal_links_created(self, seeded_db, animal_service, valid_location_id): + """Event-animal links are created for attrs event.""" + cohort_payload = make_payload(valid_location_id, count=4) + ts_utc = int(time.time() * 1000) + cohort_event = animal_service.create_cohort(cohort_payload, ts_utc, "test_user") + animal_ids = cohort_event.entity_refs["animal_ids"] + + attrs_payload = make_attrs_payload(animal_ids, sex="female") + attrs_event = animal_service.update_attributes(attrs_payload, ts_utc + 1000, "test_user") + + # Check event_animals has 4 rows for the attrs event + count = seeded_db.execute( + "SELECT COUNT(*) FROM event_animals WHERE event_id = ?", + (attrs_event.id,), + ).fetchone()[0] + assert count == 4 + + def test_updates_multiple_attrs_at_once(self, seeded_db, animal_service, valid_location_id): + """Multiple attributes can be updated at once.""" + cohort_payload = make_payload( + valid_location_id, count=1, sex="unknown", life_stage="hatchling" + ) + ts_utc = int(time.time() * 1000) + cohort_event = animal_service.create_cohort(cohort_payload, ts_utc, "test_user") + animal_id = cohort_event.entity_refs["animal_ids"][0] + + # Update both sex and life_stage + attrs_payload = make_attrs_payload([animal_id], sex="female", life_stage="juvenile") + animal_service.update_attributes(attrs_payload, ts_utc + 1000, "test_user") + + # Check both were updated in registry + row = seeded_db.execute( + "SELECT sex, life_stage FROM animal_registry WHERE animal_id = ?", + (animal_id,), + ).fetchone() + assert row[0] == "female" + assert row[1] == "juvenile" + + # Should have 6 intervals: 4 initial + 2 new (sex + life_stage) + count = seeded_db.execute( + "SELECT COUNT(*) FROM animal_attr_intervals WHERE animal_id = ?", + (animal_id,), + ).fetchone()[0] + assert count == 6 + + def test_noop_when_value_unchanged(self, seeded_db, animal_service, valid_location_id): + """No new intervals created when value is already the same.""" + cohort_payload = make_payload(valid_location_id, count=1, sex="female") + ts_utc = int(time.time() * 1000) + cohort_event = animal_service.create_cohort(cohort_payload, ts_utc, "test_user") + animal_id = cohort_event.entity_refs["animal_ids"][0] + + initial_count = seeded_db.execute( + "SELECT COUNT(*) FROM animal_attr_intervals WHERE animal_id = ?", + (animal_id,), + ).fetchone()[0] + + # Update sex to same value + attrs_payload = make_attrs_payload([animal_id], sex="female") + animal_service.update_attributes(attrs_payload, ts_utc + 1000, "test_user") + + # Should have same number of intervals + new_count = seeded_db.execute( + "SELECT COUNT(*) FROM animal_attr_intervals WHERE animal_id = ?", + (animal_id,), + ).fetchone()[0] + assert new_count == initial_count + + +class TestAnimalServiceUpdateAttributesValidation: + """Tests for update_attributes() validation.""" + + def test_rejects_nonexistent_animal(self, seeded_db, animal_service): + """Raises ValidationError for non-existent animal_id.""" + fake_animal_id = "01ARZ3NDEKTSV4RRFFQ69G5XXX" + attrs_payload = make_attrs_payload([fake_animal_id], sex="female") + + with pytest.raises(ValidationError, match="not found"): + animal_service.update_attributes(attrs_payload, int(time.time() * 1000), "test_user") + + def test_rejects_empty_attribute_set(self, seeded_db, animal_service, valid_location_id): + """Raises ValidationError when no attributes are being updated.""" + cohort_payload = make_payload(valid_location_id, count=1) + ts_utc = int(time.time() * 1000) + cohort_event = animal_service.create_cohort(cohort_payload, ts_utc, "test_user") + animal_ids = cohort_event.entity_refs["animal_ids"] + + # Create payload with no attrs set + attrs_payload = make_attrs_payload(animal_ids) + + with pytest.raises(ValidationError, match="at least one attribute"): + animal_service.update_attributes(attrs_payload, ts_utc + 1000, "test_user")