# ABOUTME: Tests for the database module.
# ABOUTME: Validates connection factory, pragmas, and transaction handling.

import threading
import time

import pytest


class TestGetDb:
    """Test the get_db connection factory."""

    def test_returns_database_connection(self, temp_db_path):
        """get_db should return a usable database connection."""
        from animaltrack.db import get_db

        db = get_db(temp_db_path)
        assert db is not None

        # Should be able to execute a simple query
        result = db.execute("SELECT 1").fetchone()
        assert result[0] == 1

    def test_pragma_journal_mode_wal(self, temp_db_path):
        """Journal mode should be set to WAL."""
        from animaltrack.db import get_db

        db = get_db(temp_db_path)
        result = db.execute("PRAGMA journal_mode").fetchone()
        assert result[0].lower() == "wal"

    def test_pragma_synchronous_full(self, temp_db_path):
        """Synchronous mode should be set to FULL (2)."""
        from animaltrack.db import get_db

        db = get_db(temp_db_path)
        result = db.execute("PRAGMA synchronous").fetchone()
        # FULL = 2
        assert result[0] == 2

    def test_pragma_foreign_keys_on(self, temp_db_path):
        """Foreign keys should be enabled."""
        from animaltrack.db import get_db

        db = get_db(temp_db_path)
        result = db.execute("PRAGMA foreign_keys").fetchone()
        assert result[0] == 1

    def test_pragma_busy_timeout(self, temp_db_path):
        """Busy timeout should be set to 5000ms."""
        from animaltrack.db import get_db

        db = get_db(temp_db_path)
        result = db.execute("PRAGMA busy_timeout").fetchone()
        assert result[0] == 5000


class TestTransaction:
    """Test the transaction context manager."""

    def test_commits_on_success(self, temp_db_path):
        """Transaction should commit changes on successful completion."""
        from animaltrack.db import get_db, transaction

        db = get_db(temp_db_path)
        db.execute("CREATE TABLE test_table (id INTEGER PRIMARY KEY, value TEXT)")

        with transaction(db):
            db.execute("INSERT INTO test_table (value) VALUES ('test')")

        # Read through a separate connection: the writing connection can see its
        # own uncommitted rows, so only another connection proves the commit.
        reader = get_db(temp_db_path)
        result = reader.execute("SELECT value FROM test_table").fetchone()
        assert result is not None, "row not visible to other connections; not committed"
        assert result[0] == "test"

    def test_rollback_on_exception(self, temp_db_path):
        """Transaction should rollback changes on exception."""
        from animaltrack.db import get_db, transaction

        db = get_db(temp_db_path)
        db.execute("CREATE TABLE test_table (id INTEGER PRIMARY KEY, value TEXT)")

        with pytest.raises(ValueError):
            with transaction(db):
                db.execute("INSERT INTO test_table (value) VALUES ('test')")
                raise ValueError("Simulated error")

        # Should be rolled back - even the writing connection must not see the row
        result = db.execute("SELECT value FROM test_table").fetchone()
        assert result is None

    def test_begin_immediate_blocks_concurrent_writes(self, temp_db_path):
        """BEGIN IMMEDIATE should block concurrent write transactions.

        While db1 holds an open write transaction, a second connection trying to
        start its own write transaction must wait (busy_timeout is 5000ms) rather
        than proceed. Once db1 commits, the waiting writer should complete and
        both rows should be present.
        """
        from animaltrack.db import get_db, transaction

        db1 = get_db(temp_db_path)
        db2 = get_db(temp_db_path)
        db1.execute("CREATE TABLE test_table (id INTEGER PRIMARY KEY, value TEXT)")

        # Exceptions raised in a worker thread do not reach the test, so record
        # any failure here and assert on it from the main thread.
        results = {"error": None}

        def try_concurrent_write():
            try:
                # Should wait on the write lock held by db1's open transaction
                with transaction(db2):
                    db2.execute("INSERT INTO test_table (value) VALUES ('from_thread')")
            except Exception as e:
                results["error"] = str(e)

        # daemon=True so a stuck writer cannot keep the test process alive.
        thread = threading.Thread(target=try_concurrent_write, daemon=True)

        # Start transaction on db1 but don't commit yet
        with transaction(db1):
            db1.execute("INSERT INTO test_table (value) VALUES ('from_main')")

            # Try to start another write transaction from another connection
            thread.start()

            # Give it a moment to try to acquire the lock. It must still be
            # waiting, because db1 has not committed yet.
            time.sleep(0.1)
            writer_was_blocked = thread.is_alive()

        # db1 has committed; the waiting writer should now acquire the lock and
        # finish well within busy_timeout. Join before asserting so the thread
        # is never left running when an assertion fails.
        thread.join(timeout=5)

        assert writer_was_blocked, "concurrent writer was not blocked by the open transaction"
        assert not thread.is_alive(), "concurrent writer did not finish after commit"
        assert results["error"] is None, f"concurrent writer failed: {results['error']}"

        # Both rows should exist
        result = db1.execute("SELECT COUNT(*) FROM test_table").fetchone()
        assert result[0] == 2  # Both inserts succeeded

    def test_nested_transaction_raises_error(self, temp_db_path):
        """Nested transactions should raise an error."""
        from animaltrack.db import get_db, transaction

        db = get_db(temp_db_path)

        with pytest.raises(RuntimeError, match="[Nn]ested"):
            with transaction(db):
                with transaction(db):
                    pass