Files
animaltrack/tests/test_selection_validation.py
Petru Paler e9d3f34994 feat: add selection validation with optimistic locking
Implements Step 5.3 - selection validation for optimistic locking:

- SelectionContext: holds client's filter, resolved_ids, roster_hash, ts_utc
- SelectionDiff: shows added/removed animals on mismatch
- SelectionValidationResult: validation result with diff if applicable
- validate_selection(): re-resolves at ts_utc, compares hashes, returns diff
- SelectionMismatchError: exception for unconfirmed mismatches

Tests cover: hash match, mismatch detection, diff correctness, confirmed bypass,
from_location_id in hash comparison.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2025-12-29 15:46:19 +00:00

459 lines
16 KiB
Python

# ABOUTME: Tests for selection validation with optimistic locking.
# ABOUTME: Tests hash comparison, diff computation, and confirmed bypass.
import time
import pytest
from animaltrack.events.payloads import AnimalCohortCreatedPayload, AnimalMovedPayload
from animaltrack.events.store import EventStore
from animaltrack.projections import ProjectionRegistry
from animaltrack.projections.animal_registry import AnimalRegistryProjection
from animaltrack.projections.event_animals import EventAnimalsProjection
from animaltrack.projections.intervals import IntervalProjection
from animaltrack.selection import compute_roster_hash, parse_filter, resolve_filter
from animaltrack.selection.validation import (
SelectionContext,
SelectionDiff,
SelectionMismatchError,
SelectionValidationResult,
validate_selection,
)
from animaltrack.services.animal import AnimalService
@pytest.fixture
def event_store(seeded_db):
    """Provide an EventStore bound to the seeded test database."""
    store = EventStore(seeded_db)
    return store
@pytest.fixture
def projection_registry(seeded_db):
    """Provide a ProjectionRegistry wired with the animal projections.

    The animal registry, event-animals and interval projections are each
    constructed against the seeded database and registered in that order.
    """
    registry = ProjectionRegistry()
    for projection_cls in (
        AnimalRegistryProjection,
        EventAnimalsProjection,
        IntervalProjection,
    ):
        registry.register(projection_cls(seeded_db))
    return registry
@pytest.fixture
def animal_service(seeded_db, event_store, projection_registry):
    """Provide an AnimalService wired to the test event store and projections."""
    service = AnimalService(seeded_db, event_store, projection_registry)
    return service
@pytest.fixture
def strip1_location_id(seeded_db):
    """Return the ID of the seeded 'Strip 1' location.

    Fails with a descriptive message, rather than a bare TypeError from
    subscripting None, if the seed data does not contain the location.
    """
    row = seeded_db.execute("SELECT id FROM locations WHERE name = 'Strip 1'").fetchone()
    assert row is not None, "seed data is missing the 'Strip 1' location"
    return row[0]
@pytest.fixture
def strip2_location_id(seeded_db):
    """Return the ID of the seeded 'Strip 2' location.

    Fails with a descriptive message, rather than a bare TypeError from
    subscripting None, if the seed data does not contain the location.
    """
    row = seeded_db.execute("SELECT id FROM locations WHERE name = 'Strip 2'").fetchone()
    assert row is not None, "seed data is missing the 'Strip 2' location"
    return row[0]
def make_cohort_payload(
    location_id: str,
    count: int = 3,
    species: str = "duck",
    sex: str = "unknown",
    life_stage: str = "adult",
) -> AnimalCohortCreatedPayload:
    """Build an AnimalCohortCreatedPayload describing a purchased cohort.

    Args:
        location_id: Location at which the cohort is created.
        count: Number of animals in the cohort.
        species: Species recorded for the cohort.
        sex: Sex recorded for the cohort.
        life_stage: Life stage recorded for the cohort.

    Returns:
        The payload, with ``origin`` always set to ``"purchased"``.
    """
    fields = {
        "species": species,
        "count": count,
        "life_stage": life_stage,
        "sex": sex,
        "location_id": location_id,
        "origin": "purchased",
    }
    return AnimalCohortCreatedPayload(**fields)
# ============================================================================
# Tests for SelectionContext and SelectionDiff models
# ============================================================================
class TestSelectionContextModel:
    """Unit tests for constructing the SelectionContext dataclass."""

    def test_creates_with_required_fields(self):
        """Supplying only the required fields leaves the optional ones at their defaults."""
        required = {
            "filter": "species:duck",
            "resolved_ids": ["id1", "id2"],
            "roster_hash": "abc123",
            "ts_utc": 1000000,
        }
        context = SelectionContext(**required, from_location_id=None)

        # Explicitly supplied values round-trip unchanged.
        for field_name, expected in required.items():
            assert getattr(context, field_name) == expected
        assert context.from_location_id is None
        # Defaults for the fields that were not passed.
        assert context.confirmed is False
        assert context.resolver_version == "v1"

    def test_creates_with_optional_fields(self):
        """Optional fields passed explicitly are stored on the context."""
        context = SelectionContext(
            filter="",
            resolved_ids=["id1"],
            roster_hash="abc",
            ts_utc=1000,
            from_location_id="loc123",
            confirmed=True,
            resolver_version="v1",
        )

        assert context.from_location_id == "loc123"
        assert context.confirmed is True
class TestSelectionDiffModel:
    """Unit tests for constructing the SelectionDiff dataclass."""

    def test_creates_with_all_fields(self):
        """Every field given to SelectionDiff is exposed as an attribute."""
        fields = {
            "added": ["id3", "id4"],
            "removed": ["id1"],
            "server_count": 3,
            "client_count": 2,
        }
        diff = SelectionDiff(**fields)

        for name, expected in fields.items():
            assert getattr(diff, name) == expected
# ============================================================================
# Tests for validate_selection - hash match
# ============================================================================
class TestValidateSelectionHashMatch:
    """validate_selection accepts a context whose roster hash matches the server's."""

    @staticmethod
    def _matching_context(seeded_db, filter_text, ts_utc):
        """Resolve ``filter_text`` at ``ts_utc`` and wrap the result in a context.

        Returns ``(resolution, context)``. The context carries exactly the IDs
        and hash of that resolution, with no from_location_id.
        """
        resolution = resolve_filter(seeded_db, parse_filter(filter_text), ts_utc)
        context = SelectionContext(
            filter=filter_text,
            resolved_ids=resolution.animal_ids,
            roster_hash=resolution.roster_hash,
            ts_utc=ts_utc,
            from_location_id=None,
        )
        return resolution, context

    def test_returns_valid_when_hashes_match(self, seeded_db, animal_service, strip1_location_id):
        """A context built from a resolution at the same ts_utc is valid with no diff."""
        ts_utc = int(time.time() * 1000)
        animal_service.create_cohort(
            make_cohort_payload(strip1_location_id, count=3), ts_utc, "test_user"
        )
        resolution, context = self._matching_context(seeded_db, "species:duck", ts_utc)

        result = validate_selection(seeded_db, context)

        assert result.valid is True
        assert result.resolved_ids == resolution.animal_ids
        assert result.roster_hash == resolution.roster_hash
        assert result.diff is None

    def test_returns_valid_with_empty_filter(self, seeded_db, animal_service, strip1_location_id):
        """An empty (match-all) filter validates when its hash matches."""
        ts_utc = int(time.time() * 1000)
        animal_service.create_cohort(
            make_cohort_payload(strip1_location_id, count=2), ts_utc, "test_user"
        )
        _, context = self._matching_context(seeded_db, "", ts_utc)

        result = validate_selection(seeded_db, context)

        assert result.valid is True
        assert result.diff is None
# ============================================================================
# Tests for validate_selection - hash mismatch
# ============================================================================
class TestValidateSelectionHashMismatch:
    """validate_selection rejects a stale roster hash and reports what changed."""

    @staticmethod
    def _stale_context(filter_text, client_resolution, ts_utc):
        """Build a context pairing an earlier client resolution with a later timestamp."""
        return SelectionContext(
            filter=filter_text,
            resolved_ids=client_resolution.animal_ids,
            roster_hash=client_resolution.roster_hash,
            ts_utc=ts_utc,
            from_location_id=None,
        )

    def test_returns_invalid_when_animal_added(self, seeded_db, animal_service, strip1_location_id):
        """An animal created after the client resolved appears in diff.added."""
        ts_before = int(time.time() * 1000)
        ts_after = ts_before + 1000

        animal_service.create_cohort(
            make_cohort_payload(strip1_location_id, count=2), ts_before, "test_user"
        )
        # The client's snapshot, taken before the extra animal exists.
        client_resolution = resolve_filter(seeded_db, parse_filter("species:duck"), ts_before)

        late_event = animal_service.create_cohort(
            make_cohort_payload(strip1_location_id, count=1), ts_after, "test_user"
        )
        new_animal_id = late_event.entity_refs["animal_ids"][0]

        result = validate_selection(
            seeded_db, self._stale_context("species:duck", client_resolution, ts_after)
        )

        assert result.valid is False
        diff = result.diff
        assert diff is not None
        assert new_animal_id in diff.added
        assert diff.removed == []
        assert (diff.server_count, diff.client_count) == (3, 2)

    def test_returns_invalid_when_animal_moved_away(
        self, seeded_db, animal_service, strip1_location_id, strip2_location_id
    ):
        """An animal moved out of the filtered location appears in diff.removed."""
        strip1_filter = "location:'Strip 1'"
        ts_create = int(time.time() * 1000)
        ts_move = ts_create + 1000

        created = animal_service.create_cohort(
            make_cohort_payload(strip1_location_id, count=3), ts_create, "test_user"
        )
        animal_ids = created.entity_refs["animal_ids"]
        moved_id = animal_ids[0]

        # The client's view of Strip 1 while all three animals are still there.
        client_resolution = resolve_filter(seeded_db, parse_filter(strip1_filter), ts_create)

        animal_service.move_animals(
            AnimalMovedPayload(
                resolved_ids=[moved_id],
                from_location_id=strip1_location_id,
                to_location_id=strip2_location_id,
            ),
            ts_move,
            "test_user",
        )

        result = validate_selection(
            seeded_db, self._stale_context(strip1_filter, client_resolution, ts_move)
        )

        assert result.valid is False
        diff = result.diff
        assert diff is not None
        assert diff.added == []
        assert moved_id in diff.removed
        assert (diff.server_count, diff.client_count) == (2, 3)
class TestValidateSelectionDiffCorrectness:
    """Tests that the mismatch diff lists exactly the added and removed IDs."""

    def test_diff_includes_both_added_and_removed(
        self, seeded_db, animal_service, strip1_location_id
    ):
        """The diff reports an omitted match as added and a non-match as removed.

        The client submits a ``species:duck`` roster that wrongly contains the
        goose and omits one duck. The server's re-resolution must therefore add
        the omitted duck and remove the goose.
        """
        ts_create = int(time.time() * 1000)

        # Two ducks and one goose, all at Strip 1; only the ducks match species:duck.
        duck_event = animal_service.create_cohort(
            make_cohort_payload(strip1_location_id, count=2, species="duck"),
            ts_create,
            "test_user",
        )
        strip1_ducks = duck_event.entity_refs["animal_ids"]
        goose_event = animal_service.create_cohort(
            make_cohort_payload(strip1_location_id, count=1, species="goose"),
            ts_create,
            "test_user",
        )
        strip1_goose = goose_event.entity_refs["animal_ids"][0]

        # Simulate a buggy client for "species:duck": it included the goose
        # and dropped the second duck.
        wrong_client_ids = [strip1_ducks[0], strip1_goose]
        ctx = SelectionContext(
            filter="species:duck",
            resolved_ids=wrong_client_ids,
            roster_hash=compute_roster_hash(wrong_client_ids),
            ts_utc=ts_create,
            from_location_id=None,
        )

        result = validate_selection(seeded_db, ctx)

        assert result.valid is False
        assert result.diff is not None
        # Server resolved both ducks; client had duck[0] + goose.
        assert result.diff.added == [strip1_ducks[1]]
        assert result.diff.removed == [strip1_goose]
        assert result.diff.server_count == 2
        assert result.diff.client_count == 2
# ============================================================================
# Tests for validate_selection - confirmed bypass
# ============================================================================
class TestValidateSelectionConfirmedBypass:
    """A context with confirmed=True is accepted even when the roster has drifted."""

    def test_confirmed_true_returns_valid_despite_mismatch(
        self, seeded_db, animal_service, strip1_location_id
    ):
        """confirmed=True yields valid=True, keeps the client's IDs, and still carries a diff."""
        duck_filter = "species:duck"
        ts_before = int(time.time() * 1000)
        ts_after = ts_before + 1000

        animal_service.create_cohort(
            make_cohort_payload(strip1_location_id, count=2), ts_before, "test_user"
        )
        client_resolution = resolve_filter(seeded_db, parse_filter(duck_filter), ts_before)

        # A further duck appears after the client resolved, so the hashes diverge.
        animal_service.create_cohort(
            make_cohort_payload(strip1_location_id, count=1), ts_after, "test_user"
        )

        confirmed_ctx = SelectionContext(
            filter=duck_filter,
            resolved_ids=client_resolution.animal_ids,
            roster_hash=client_resolution.roster_hash,
            ts_utc=ts_after,
            from_location_id=None,
            confirmed=True,
        )
        result = validate_selection(seeded_db, confirmed_ctx)

        assert result.valid is True
        # When confirmed, the client's IDs are what comes back.
        assert result.resolved_ids == client_resolution.animal_ids
        # The diff is still populated for transparency.
        diff = result.diff
        assert diff is not None
        assert diff.server_count == 3
        assert diff.client_count == 2
# ============================================================================
# Tests for from_location_id in hash
# ============================================================================
class TestValidateSelectionFromLocationId:
    """Tests that from_location_id participates in the roster-hash comparison."""

    def test_from_location_id_included_in_hash_comparison(
        self, seeded_db, animal_service, strip1_location_id
    ):
        """A client hash that folds in from_location_id matches the server's hash."""
        ts_utc = int(time.time() * 1000)
        payload = make_cohort_payload(strip1_location_id, count=2)
        event = animal_service.create_cohort(payload, ts_utc, "test_user")
        animal_ids = sorted(event.entity_refs["animal_ids"])

        # Client-side hash computed with from_location_id included.
        hash_with_from = compute_roster_hash(animal_ids, strip1_location_id)

        ctx_with_from = SelectionContext(
            filter="",
            resolved_ids=animal_ids,
            roster_hash=hash_with_from,
            ts_utc=ts_utc,
            from_location_id=strip1_location_id,
        )

        result = validate_selection(seeded_db, ctx_with_from)

        assert result.valid is True
        assert result.roster_hash == hash_with_from

    def test_mismatched_from_location_id_causes_invalid(
        self, seeded_db, animal_service, strip1_location_id
    ):
        """A hash computed without from_location_id is rejected when the context sets it.

        The client's IDs are exactly the server's, so the diff must be empty.
        That isolates the mismatch to the server folding from_location_id
        into its hash.
        """
        ts_utc = int(time.time() * 1000)
        payload = make_cohort_payload(strip1_location_id, count=2)
        event = animal_service.create_cohort(payload, ts_utc, "test_user")
        animal_ids = sorted(event.entity_refs["animal_ids"])

        # Client computed its hash without from_location_id...
        hash_without_from = compute_roster_hash(animal_ids)

        # ...but the context sets from_location_id, which the server includes.
        ctx = SelectionContext(
            filter="",
            resolved_ids=animal_ids,
            roster_hash=hash_without_from,
            ts_utc=ts_utc,
            from_location_id=strip1_location_id,
        )

        result = validate_selection(seeded_db, ctx)

        assert result.valid is False
        assert result.diff is not None
        # Same IDs on both sides: nothing added or removed.
        assert result.diff.added == []
        assert result.diff.removed == []
        assert result.diff.server_count == result.diff.client_count == 2
# ============================================================================
# Tests for SelectionMismatchError
# ============================================================================
class TestSelectionMismatchError:
    """SelectionMismatchError carries the failed validation result."""

    def test_stores_validation_result(self):
        """The result handed to the constructor is exposed unchanged as ``.result``."""
        mismatch_diff = SelectionDiff(added=["id1"], removed=[], server_count=2, client_count=1)
        failed = SelectionValidationResult(
            valid=False,
            resolved_ids=["id1", "id2"],
            roster_hash="abc",
            diff=mismatch_diff,
        )

        err = SelectionMismatchError(failed)

        assert err.result is failed
        assert err.result.diff is mismatch_diff