From 719d1e6ce77f171468eff0ca825d6ac0c166cbde Mon Sep 17 00:00:00 2001 From: Petru Paler Date: Wed, 31 Dec 2025 14:35:27 +0000 Subject: [PATCH] feat: implement user defaults persistence (Step 9.3) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Add user_defaults table and repository for persisting form defaults across sessions. Feed and egg forms now load/save user preferences. Changes: - Add migration 0009-user-defaults.sql with table schema - Add UserDefault model and UserDefaultsRepository - Integrate defaults into feed route (location, feed_type, amount) - Integrate defaults into egg route (location) - Add repository unit tests and route integration tests 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 --- PLAN.md | 12 +- migrations/0009-user-defaults.sql | 17 ++ src/animaltrack/models/reference.py | 27 +++ src/animaltrack/repositories/__init__.py | 2 + src/animaltrack/repositories/user_defaults.py | 77 ++++++ src/animaltrack/web/routes/eggs.py | 23 ++ src/animaltrack/web/routes/feed.py | 33 +++ tests/test_repositories_user_defaults.py | 158 +++++++++++++ tests/test_user_defaults_integration.py | 223 ++++++++++++++++++ 9 files changed, 566 insertions(+), 6 deletions(-) create mode 100644 migrations/0009-user-defaults.sql create mode 100644 src/animaltrack/repositories/user_defaults.py create mode 100644 tests/test_repositories_user_defaults.py create mode 100644 tests/test_user_defaults_integration.py diff --git a/PLAN.md b/PLAN.md index 659160a..cfa52a4 100644 --- a/PLAN.md +++ b/PLAN.md @@ -359,12 +359,12 @@ Check off items as completed. Each phase builds on the previous. - [x] Write tests: sale creates event, unit price calculated - [x] **Commit checkpoint**: 0eef3ed -### Step 9.3: User Defaults -- [ ] Create migration for user_defaults table -- [ ] Create `repositories/user_defaults.py` -- [ ] Integrate defaults into form rendering -- [ ] Write tests: defaults saved and loaded -- [ ] **Commit checkpoint** +### Step 9.3: User Defaults ✅ +- [x] Create migration for user_defaults table +- [x] Create `repositories/user_defaults.py` +- [x] Integrate defaults into form rendering +- [x] Write tests: defaults saved and loaded +- [x] **Commit checkpoint** --- diff --git a/migrations/0009-user-defaults.sql b/migrations/0009-user-defaults.sql new file mode 100644 index 0000000..39e379f --- /dev/null +++ b/migrations/0009-user-defaults.sql @@ -0,0 +1,17 @@ +-- ABOUTME: Migration for user_defaults table +-- ABOUTME: Stores per-user form defaults (location, feed type, etc.) that persist across sessions + +CREATE TABLE user_defaults ( + username TEXT NOT NULL REFERENCES users(username), + action TEXT NOT NULL CHECK(action IN ('collect_egg','feed_given')), + location_id TEXT, + species TEXT, + animal_filter TEXT, + feed_type_code TEXT, + amount_kg INTEGER, + bag_size_kg INTEGER, + updated_at_utc INTEGER NOT NULL, + PRIMARY KEY (username, action) +); + +CREATE INDEX idx_user_defaults_username ON user_defaults(username); diff --git a/src/animaltrack/models/reference.py b/src/animaltrack/models/reference.py index b3421b1..36970f2 100644 --- a/src/animaltrack/models/reference.py +++ b/src/animaltrack/models/reference.py @@ -2,6 +2,7 @@ # ABOUTME: These models validate data before database insertion and provide type safety. from enum import Enum +from typing import Literal from pydantic import BaseModel, Field, field_validator @@ -127,3 +128,29 @@ class User(BaseModel): msg = "Timestamp must be non-negative" raise ValueError(msg) return v + + +UserDefaultAction = Literal["collect_egg", "feed_given"] + + +class UserDefault(BaseModel): + """User default form values that persist across sessions.""" + + username: str + action: UserDefaultAction + location_id: str | None = None + species: str | None = None + animal_filter: str | None = None + feed_type_code: str | None = None + amount_kg: int | None = None + bag_size_kg: int | None = None + updated_at_utc: int + + @field_validator("updated_at_utc") + @classmethod + def timestamp_must_be_non_negative(cls, v: int) -> int: + """Timestamps must be >= 0 (milliseconds since Unix epoch).""" + if v < 0: + msg = "Timestamp must be non-negative" + raise ValueError(msg) + return v diff --git a/src/animaltrack/repositories/__init__.py b/src/animaltrack/repositories/__init__.py index b34e9c3..1ed5fb1 100644 --- a/src/animaltrack/repositories/__init__.py +++ b/src/animaltrack/repositories/__init__.py @@ -5,6 +5,7 @@ from animaltrack.repositories.feed_types import FeedTypeRepository from animaltrack.repositories.locations import LocationRepository from animaltrack.repositories.products import ProductRepository from animaltrack.repositories.species import SpeciesRepository +from animaltrack.repositories.user_defaults import UserDefaultsRepository from animaltrack.repositories.users import UserRepository __all__ = [ @@ -12,5 +13,6 @@ __all__ = [ "LocationRepository", "ProductRepository", "SpeciesRepository", + "UserDefaultsRepository", "UserRepository", ] diff --git a/src/animaltrack/repositories/user_defaults.py b/src/animaltrack/repositories/user_defaults.py new file mode 100644 index 0000000..0606469 --- /dev/null +++ b/src/animaltrack/repositories/user_defaults.py @@ -0,0 +1,77 @@ +# ABOUTME: Repository for user form defaults. +# ABOUTME: Provides get and upsert operations for the user_defaults table. + +from typing import Any + +from animaltrack.models.reference import UserDefault, UserDefaultAction + + +class UserDefaultsRepository: + """Repository for managing user default form values.""" + + def __init__(self, db: Any) -> None: + """Initialize repository with database connection. + + Args: + db: A fastlite database connection. + """ + self.db = db + + def get(self, username: str, action: UserDefaultAction) -> UserDefault | None: + """Get user defaults for a specific action. + + Args: + username: The username. + action: The action type ('collect_egg' or 'feed_given'). + + Returns: + The UserDefault if found, None otherwise. + """ + row = self.db.execute( + """ + SELECT username, action, location_id, species, animal_filter, + feed_type_code, amount_kg, bag_size_kg, updated_at_utc + FROM user_defaults + WHERE username = ? AND action = ? + """, + (username, action), + ).fetchone() + if row is None: + return None + return UserDefault( + username=row[0], + action=row[1], + location_id=row[2], + species=row[3], + animal_filter=row[4], + feed_type_code=row[5], + amount_kg=row[6], + bag_size_kg=row[7], + updated_at_utc=row[8], + ) + + def upsert(self, defaults: UserDefault) -> None: + """Insert or update user defaults. + + Args: + defaults: The user defaults to upsert. + """ + self.db.execute( + """ + INSERT OR REPLACE INTO user_defaults + (username, action, location_id, species, animal_filter, + feed_type_code, amount_kg, bag_size_kg, updated_at_utc) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?) + """, + ( + defaults.username, + defaults.action, + defaults.location_id, + defaults.species, + defaults.animal_filter, + defaults.feed_type_code, + defaults.amount_kg, + defaults.bag_size_kg, + defaults.updated_at_utc, + ), + ) diff --git a/src/animaltrack/web/routes/eggs.py b/src/animaltrack/web/routes/eggs.py index 30bdd27..0bf7ce2 100644 --- a/src/animaltrack/web/routes/eggs.py +++ b/src/animaltrack/web/routes/eggs.py @@ -13,12 +13,15 @@ from starlette.responses import HTMLResponse from animaltrack.events.payloads import ProductCollectedPayload from animaltrack.events.store import EventStore +from animaltrack.models.reference import UserDefault from animaltrack.projections import EventLogProjection, ProjectionRegistry from animaltrack.projections.animal_registry import AnimalRegistryProjection from animaltrack.projections.event_animals import EventAnimalsProjection from animaltrack.projections.intervals import IntervalProjection from animaltrack.projections.products import ProductsProjection from animaltrack.repositories.locations import LocationRepository +from animaltrack.repositories.user_defaults import UserDefaultsRepository +from animaltrack.repositories.users import UserRepository from animaltrack.services.products import ProductService, ValidationError from animaltrack.web.templates import page from animaltrack.web.templates.eggs import egg_form @@ -58,6 +61,15 @@ def egg_index(request: Request): # Check for pre-selected location from query params selected_location_id = request.query_params.get("location_id") + # If no query param, load from user defaults + if not selected_location_id: + auth = request.scope.get("auth") + username = auth.username if auth else None + if username: + defaults = UserDefaultsRepository(db).get(username, "collect_egg") + if defaults: + selected_location_id = defaults.location_id + return page( egg_form(locations, selected_location_id=selected_location_id, action=product_collected), title="Egg - AnimalTrack", @@ -137,6 +149,17 @@ async def product_collected(request: Request): except ValidationError as e: return _render_error_form(locations, location_id, str(e)) + # Save user defaults (only if user exists in database) + if UserRepository(db).get(actor): + UserDefaultsRepository(db).upsert( + UserDefault( + username=actor, + action="collect_egg", + location_id=location_id, + updated_at_utc=ts_utc, + ) + ) + # Success: re-render form with location sticking, qty cleared response = HTMLResponse( content=to_xml( diff --git a/src/animaltrack/web/routes/feed.py b/src/animaltrack/web/routes/feed.py index bf8cff3..d1f7fa3 100644 --- a/src/animaltrack/web/routes/feed.py +++ b/src/animaltrack/web/routes/feed.py @@ -12,10 +12,13 @@ from starlette.responses import HTMLResponse from animaltrack.events.payloads import FeedGivenPayload, FeedPurchasedPayload from animaltrack.events.store import EventStore +from animaltrack.models.reference import UserDefault from animaltrack.projections import EventLogProjection, ProjectionRegistry from animaltrack.projections.feed import FeedInventoryProjection from animaltrack.repositories.feed_types import FeedTypeRepository from animaltrack.repositories.locations import LocationRepository +from animaltrack.repositories.user_defaults import UserDefaultsRepository +from animaltrack.repositories.users import UserRepository from animaltrack.services.feed import FeedService, ValidationError from animaltrack.web.templates import page from animaltrack.web.templates.feed import feed_page @@ -49,11 +52,28 @@ def feed_index(request: Request): if active_tab not in ("give", "purchase"): active_tab = "give" + # Load user defaults + auth = request.scope.get("auth") + username = auth.username if auth else None + selected_location_id = None + selected_feed_type_code = None + default_amount_kg = None + + if username: + defaults = UserDefaultsRepository(db).get(username, "feed_given") + if defaults: + selected_location_id = defaults.location_id + selected_feed_type_code = defaults.feed_type_code + default_amount_kg = defaults.amount_kg + return page( feed_page( locations, feed_types, active_tab=active_tab, + selected_location_id=selected_location_id, + selected_feed_type_code=selected_feed_type_code, + default_amount_kg=default_amount_kg, give_action=feed_given, purchase_action=feed_purchased, ), @@ -173,6 +193,19 @@ async def feed_given(request: Request): if balance is not None and balance < 0: balance_warning = f"Warning: Inventory balance is now negative ({balance} kg)" + # Save user defaults (only if user exists in database) + if UserRepository(db).get(actor): + UserDefaultsRepository(db).upsert( + UserDefault( + username=actor, + action="feed_given", + location_id=location_id, + feed_type_code=feed_type_code, + amount_kg=amount_kg, + updated_at_utc=ts_utc, + ) + ) + # Success: re-render form with location/type sticking, amount reset response = HTMLResponse( content=str( diff --git a/tests/test_repositories_user_defaults.py b/tests/test_repositories_user_defaults.py new file mode 100644 index 0000000..2e241ac --- /dev/null +++ b/tests/test_repositories_user_defaults.py @@ -0,0 +1,158 @@ +# ABOUTME: Tests for UserDefaultsRepository. +# ABOUTME: Validates CRUD operations for user form defaults. + +import time + +import pytest + +from animaltrack.models.reference import UserDefault +from animaltrack.repositories.user_defaults import UserDefaultsRepository + + +@pytest.fixture +def now_utc(): + """Current time in milliseconds since epoch.""" + return int(time.time() * 1000) + + +class TestUserDefaultsRepository: + """Tests for UserDefaultsRepository.""" + + def test_get_returns_none_for_missing(self, seeded_db): + """get returns None when no defaults exist.""" + repo = UserDefaultsRepository(seeded_db) + result = repo.get("ppetru", "collect_egg") + assert result is None + + def test_upsert_creates_new_record(self, seeded_db, now_utc): + """upsert creates a new defaults record.""" + repo = UserDefaultsRepository(seeded_db) + defaults = UserDefault( + username="ppetru", + action="collect_egg", + location_id="01ARZ3NDEKTSV4RRFFQ69G5FAV", + updated_at_utc=now_utc, + ) + repo.upsert(defaults) + + result = repo.get("ppetru", "collect_egg") + assert result is not None + assert result.username == "ppetru" + assert result.action == "collect_egg" + assert result.location_id == "01ARZ3NDEKTSV4RRFFQ69G5FAV" + + def test_upsert_updates_existing_record(self, seeded_db, now_utc): + """upsert updates an existing defaults record.""" + repo = UserDefaultsRepository(seeded_db) + defaults = UserDefault( + username="ppetru", + action="feed_given", + location_id="01ARZ3NDEKTSV4RRFFQ69G5FAV", + feed_type_code="layer", + amount_kg=20, + updated_at_utc=now_utc, + ) + repo.upsert(defaults) + + updated = UserDefault( + username="ppetru", + action="feed_given", + location_id="01ARZ3NDEKTSV4RRFFQ69G5ABC", + feed_type_code="grower", + amount_kg=25, + updated_at_utc=now_utc + 1000, + ) + repo.upsert(updated) + + result = repo.get("ppetru", "feed_given") + assert result.location_id == "01ARZ3NDEKTSV4RRFFQ69G5ABC" + assert result.feed_type_code == "grower" + assert result.amount_kg == 25 + + def test_get_returns_all_fields(self, seeded_db, now_utc): + """get returns all stored fields correctly.""" + repo = UserDefaultsRepository(seeded_db) + defaults = UserDefault( + username="ppetru", + action="feed_given", + location_id="01ARZ3NDEKTSV4RRFFQ69G5FAV", + species="duck", + animal_filter="location:strip1", + feed_type_code="layer", + amount_kg=20, + bag_size_kg=25, + updated_at_utc=now_utc, + ) + repo.upsert(defaults) + + result = repo.get("ppetru", "feed_given") + assert result.location_id == "01ARZ3NDEKTSV4RRFFQ69G5FAV" + assert result.species == "duck" + assert result.animal_filter == "location:strip1" + assert result.feed_type_code == "layer" + assert result.amount_kg == 20 + assert result.bag_size_kg == 25 + + def test_different_actions_are_independent(self, seeded_db, now_utc): + """Different actions for same user are stored independently.""" + repo = UserDefaultsRepository(seeded_db) + egg_defaults = UserDefault( + username="ppetru", + action="collect_egg", + location_id="01ARZ3NDEKTSV4RRFFQ69G5EGG", + updated_at_utc=now_utc, + ) + feed_defaults = UserDefault( + username="ppetru", + action="feed_given", + location_id="01ARZ3NDEKTSV4RRFFQ69G5FED", + updated_at_utc=now_utc, + ) + repo.upsert(egg_defaults) + repo.upsert(feed_defaults) + + egg_result = repo.get("ppetru", "collect_egg") + feed_result = repo.get("ppetru", "feed_given") + + assert egg_result.location_id == "01ARZ3NDEKTSV4RRFFQ69G5EGG" + assert feed_result.location_id == "01ARZ3NDEKTSV4RRFFQ69G5FED" + + def test_different_users_are_independent(self, seeded_db, now_utc): + """Different users have independent defaults.""" + repo = UserDefaultsRepository(seeded_db) + ppetru_defaults = UserDefault( + username="ppetru", + action="collect_egg", + location_id="01ARZ3NDEKTSV4RRFFQ69G5PPP", + updated_at_utc=now_utc, + ) + ines_defaults = UserDefault( + username="ines", + action="collect_egg", + location_id="01ARZ3NDEKTSV4RRFFQ69G5III", + updated_at_utc=now_utc, + ) + repo.upsert(ppetru_defaults) + repo.upsert(ines_defaults) + + ppetru_result = repo.get("ppetru", "collect_egg") + ines_result = repo.get("ines", "collect_egg") + + assert ppetru_result.location_id == "01ARZ3NDEKTSV4RRFFQ69G5PPP" + assert ines_result.location_id == "01ARZ3NDEKTSV4RRFFQ69G5III" + + def test_null_fields_preserved(self, seeded_db, now_utc): + """Null fields are stored and retrieved correctly.""" + repo = UserDefaultsRepository(seeded_db) + defaults = UserDefault( + username="ppetru", + action="collect_egg", + location_id=None, + species=None, + updated_at_utc=now_utc, + ) + repo.upsert(defaults) + + result = repo.get("ppetru", "collect_egg") + assert result.location_id is None + assert result.species is None diff --git a/tests/test_user_defaults_integration.py b/tests/test_user_defaults_integration.py new file mode 100644 index 0000000..57019a7 --- /dev/null +++ b/tests/test_user_defaults_integration.py @@ -0,0 +1,223 @@ +# ABOUTME: Integration tests for user defaults feature. +# ABOUTME: Verifies that form defaults are saved and loaded correctly. + +import os +import time + +import pytest +from starlette.testclient import TestClient + +from animaltrack.events.payloads import FeedPurchasedPayload +from animaltrack.events.store import EventStore +from animaltrack.models.reference import UserDefault +from animaltrack.projections import ProjectionRegistry +from animaltrack.projections.feed import FeedInventoryProjection +from animaltrack.repositories.user_defaults import UserDefaultsRepository +from animaltrack.services.feed import FeedService + + +def make_test_settings( + csrf_secret: str = "test-secret", + trusted_proxy_ips: str = "127.0.0.1", + dev_mode: bool = False, # Disable dev_mode to test real auth +): + """Create Settings for testing by setting env vars temporarily.""" + from animaltrack.config import Settings + + old_env = os.environ.copy() + try: + os.environ["CSRF_SECRET"] = csrf_secret + os.environ["TRUSTED_PROXY_IPS"] = trusted_proxy_ips + os.environ["DEV_MODE"] = str(dev_mode).lower() + return Settings() + finally: + os.environ.clear() + os.environ.update(old_env) + + +@pytest.fixture +def client(seeded_db): + """Create a test client for the app with real auth enabled.""" + from animaltrack.web.app import create_app + + settings = make_test_settings(trusted_proxy_ips="testclient", dev_mode=False) + app, rt = create_app(settings=settings, db=seeded_db) + return TestClient(app, raise_server_exceptions=True) + + +@pytest.fixture +def location_strip1_id(seeded_db): + """Get Strip 1 location ID from seeded data.""" + row = seeded_db.execute("SELECT id FROM locations WHERE name = 'Strip 1'").fetchone() + return row[0] + + +@pytest.fixture +def location_strip2_id(seeded_db): + """Get Strip 2 location ID from seeded data.""" + row = seeded_db.execute("SELECT id FROM locations WHERE name = 'Strip 2'").fetchone() + return row[0] + + +@pytest.fixture +def feed_purchase_in_db(seeded_db): + """Create a feed purchase so give_feed can work.""" + event_store = EventStore(seeded_db) + registry = ProjectionRegistry() + registry.register(FeedInventoryProjection(seeded_db)) + feed_service = FeedService(seeded_db, event_store, registry) + + payload = FeedPurchasedPayload( + feed_type_code="layer", + bag_size_kg=20, + bags_count=5, + bag_price_cents=2400, + ) + ts_utc = int(time.time() * 1000) - 86400000 + feed_service.purchase_feed(payload, ts_utc, "ppetru") + return payload + + +def make_csrf_headers(csrf_token: str = "test-csrf-token"): + """Make headers with CSRF token for POST requests.""" + return { + "X-CSRF-Token": csrf_token, + "Origin": "http://testserver", # Match TestClient's default host + } + + +class TestFeedUserDefaults: + """Tests for feed form user defaults.""" + + def test_defaults_saved_on_successful_give( + self, client, seeded_db, location_strip1_id, feed_purchase_in_db + ): + """Successful feed-given saves user defaults.""" + csrf_token = "test-csrf-token" + response = client.post( + "/actions/feed-given", + data={ + "location_id": location_strip1_id, + "feed_type_code": "layer", + "amount_kg": "15", + }, + headers={ + "X-Oidc-Username": "ppetru", + **make_csrf_headers(csrf_token), + }, + cookies={"csrf_token": csrf_token}, + ) + assert response.status_code == 200 + + # Verify defaults were saved + defaults = UserDefaultsRepository(seeded_db).get("ppetru", "feed_given") + assert defaults is not None + assert defaults.location_id == location_strip1_id + assert defaults.feed_type_code == "layer" + assert defaults.amount_kg == 15 + + def test_defaults_loaded_on_feed_page(self, client, seeded_db, location_strip1_id): + """GET /feed loads saved user defaults.""" + # First set some defaults + now_utc = int(time.time() * 1000) + UserDefaultsRepository(seeded_db).upsert( + UserDefault( + username="ppetru", + action="feed_given", + location_id=location_strip1_id, + feed_type_code="grower", + amount_kg=25, + updated_at_utc=now_utc, + ) + ) + + # Load the feed page + response = client.get( + "/feed", + headers={"X-Oidc-Username": "ppetru"}, + ) + assert response.status_code == 200 + + # Check that the form has pre-selected values + content = response.text + assert f'value="{location_strip1_id}"' in content or "selected" in content + assert "grower" in content + + def test_no_defaults_for_unknown_user( + self, client, seeded_db, location_strip1_id, feed_purchase_in_db + ): + """Unknown users are rejected by auth middleware.""" + csrf_token = "test-csrf-token" + response = client.post( + "/actions/feed-given", + data={ + "location_id": location_strip1_id, + "feed_type_code": "layer", + "amount_kg": "10", + }, + headers={ + "X-Oidc-Username": "unknown_user", + **make_csrf_headers(csrf_token), + }, + cookies={"csrf_token": csrf_token}, + ) + # Unknown user is rejected by auth middleware + assert response.status_code == 401 + + # Verify no defaults were saved + defaults = UserDefaultsRepository(seeded_db).get("unknown_user", "feed_given") + assert defaults is None + + +class TestEggUserDefaults: + """Tests for egg form user defaults.""" + + def test_defaults_loaded_on_egg_page(self, client, seeded_db, location_strip1_id): + """GET / loads saved user defaults for egg collection.""" + # First set some defaults + now_utc = int(time.time() * 1000) + UserDefaultsRepository(seeded_db).upsert( + UserDefault( + username="ppetru", + action="collect_egg", + location_id=location_strip1_id, + updated_at_utc=now_utc, + ) + ) + + # Load the egg page + response = client.get( + "/", + headers={"X-Oidc-Username": "ppetru"}, + ) + assert response.status_code == 200 + + # Check that the form has pre-selected location + content = response.text + assert location_strip1_id in content + + def test_query_param_overrides_defaults( + self, client, seeded_db, location_strip1_id, location_strip2_id + ): + """Query param location_id overrides saved defaults.""" + # Set defaults to Strip 1 + now_utc = int(time.time() * 1000) + UserDefaultsRepository(seeded_db).upsert( + UserDefault( + username="ppetru", + action="collect_egg", + location_id=location_strip1_id, + updated_at_utc=now_utc, + ) + ) + + # Load the egg page with Strip 2 in query params + response = client.get( + f"/?location_id={location_strip2_id}", + headers={"X-Oidc-Username": "ppetru"}, + ) + assert response.status_code == 200 + + # Query param should take precedence - Strip 2 should be selected + content = response.text + assert location_strip2_id in content