Compare commits

...

3 Commits

Author SHA1 Message Date
c8a738d0c4 Fix garth token serialization using Pydantic v2 API
All checks were successful
Deploy / deploy (push) Successful in 1m36s
The garth library uses Pydantic dataclasses for OAuth tokens which don't
have a serialize() method. Use model_dump() instead, and fix expires_at
handling since it's an integer timestamp not a datetime object.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-12 16:49:16 +00:00
2408839b8b Fix 404 error when saving user preferences
Routes using withAuth were creating new unauthenticated PocketBase
clients, causing 404 errors when trying to update records. Modified
withAuth to pass the authenticated pb client to handlers so they can
use it for database operations.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-12 16:45:55 +00:00
df2f52ad50 Install deps for python script. 2026-01-12 15:00:31 +00:00
19 changed files with 125 additions and 128 deletions

View File

@@ -8,12 +8,35 @@
system = "x86_64-linux"; system = "x86_64-linux";
pkgs = nixpkgs.legacyPackages.${system}; pkgs = nixpkgs.legacyPackages.${system};
# Custom Python package: garth (not in nixpkgs)
garth = pkgs.python3Packages.buildPythonPackage {
pname = "garth";
version = "0.5.21";
src = pkgs.fetchPypi {
pname = "garth";
version = "0.5.21";
sha256 = "sha256-jZeVldHU6iOhtGarSmCVXRObcfiG9GSQvhQPzuWE2rQ=";
};
format = "pyproject";
nativeBuildInputs = [ pkgs.python3Packages.hatchling ];
propagatedBuildInputs = with pkgs.python3Packages; [
pydantic
requests-oauthlib
requests
];
doCheck = false;
};
# Python with garth for Garmin auth scripts
pythonWithGarth = pkgs.python3.withPackages (ps: [ garth ]);
# Common packages for development # Common packages for development
commonPackages = with pkgs; [ commonPackages = [
nodejs_24 pkgs.nodejs_24
pnpm pkgs.pnpm
git pkgs.git
pocketbase pkgs.pocketbase
pythonWithGarth
]; ];
in { in {
# Docker image for production deployment # Docker image for production deployment

View File

@@ -10,6 +10,7 @@ Usage:
python3 garmin_auth.py python3 garmin_auth.py
""" """
import json import json
from datetime import datetime
from getpass import getpass from getpass import getpass
try: try:
@@ -26,12 +27,13 @@ password = getpass("Garmin password: ")
garth.login(email, password) garth.login(email, password)
tokens = { tokens = {
"oauth1": garth.client.oauth1_token.serialize(), "oauth1": garth.client.oauth1_token.model_dump(),
"oauth2": garth.client.oauth2_token.serialize(), "oauth2": garth.client.oauth2_token.model_dump(),
"expires_at": garth.client.oauth2_token.expires_at.isoformat() "expires_at": garth.client.oauth2_token.expires_at
} }
print("\n--- Copy everything below this line ---") print("\n--- Copy everything below this line ---")
print(json.dumps(tokens, indent=2)) print(json.dumps(tokens, indent=2))
print("--- Copy everything above this line ---") print("--- Copy everything above this line ---")
print(f"\nTokens expire: {tokens['expires_at']}") expires_dt = datetime.fromtimestamp(tokens['expires_at'])
print(f"\nTokens expire: {expires_dt.isoformat()}")

View File

@@ -12,7 +12,7 @@ import { logger } from "@/lib/logger";
* Clears the user's authentication session by deleting the pb_auth cookie. * Clears the user's authentication session by deleting the pb_auth cookie.
* Returns a success response with redirect URL. * Returns a success response with redirect URL.
*/ */
export async function POST(): Promise<NextResponse> { export async function POST(_request: Request): Promise<NextResponse> {
try { try {
const cookieStore = await cookies(); const cookieStore = await cookies();

View File

@@ -12,14 +12,12 @@ let currentMockUser: User | null = null;
// Track PocketBase update calls // Track PocketBase update calls
const mockPbUpdate = vi.fn().mockResolvedValue({}); const mockPbUpdate = vi.fn().mockResolvedValue({});
// Mock PocketBase // Create mock PocketBase client
vi.mock("@/lib/pocketbase", () => ({ const mockPb = {
createPocketBaseClient: vi.fn(() => ({
collection: vi.fn(() => ({ collection: vi.fn(() => ({
update: mockPbUpdate, update: mockPbUpdate,
})), })),
})), };
}));
// Mock the auth-middleware module // Mock the auth-middleware module
vi.mock("@/lib/auth-middleware", () => ({ vi.mock("@/lib/auth-middleware", () => ({
@@ -28,7 +26,7 @@ vi.mock("@/lib/auth-middleware", () => ({
if (!currentMockUser) { if (!currentMockUser) {
return NextResponse.json({ error: "Unauthorized" }, { status: 401 }); return NextResponse.json({ error: "Unauthorized" }, { status: 401 });
} }
return handler(request, currentMockUser); return handler(request, currentMockUser, mockPb);
}; };
}), }),
})); }));

View File

@@ -5,7 +5,6 @@ import { randomBytes } from "node:crypto";
import { NextResponse } from "next/server"; import { NextResponse } from "next/server";
import { withAuth } from "@/lib/auth-middleware"; import { withAuth } from "@/lib/auth-middleware";
import { createPocketBaseClient } from "@/lib/pocketbase";
/** /**
* Generates a cryptographically secure random 32-character alphanumeric token. * Generates a cryptographically secure random 32-character alphanumeric token.
@@ -17,12 +16,11 @@ function generateToken(): string {
return randomBytes(32).toString("hex").slice(0, 32); return randomBytes(32).toString("hex").slice(0, 32);
} }
export const POST = withAuth(async (_request, user) => { export const POST = withAuth(async (_request, user, pb) => {
// Generate new random token // Generate new random token
const newToken = generateToken(); const newToken = generateToken();
// Update user record with new token // Update user record with new token
const pb = createPocketBaseClient();
await pb.collection("users").update(user.id, { await pb.collection("users").update(user.id, {
calendarToken: newToken, calendarToken: newToken,
}); });

View File

@@ -13,17 +13,13 @@ let currentMockUser: User | null = null;
const mockPbUpdate = vi.fn(); const mockPbUpdate = vi.fn();
const mockPbCreate = vi.fn(); const mockPbCreate = vi.fn();
vi.mock("@/lib/pocketbase", () => ({ // Create mock PocketBase client
createPocketBaseClient: vi.fn(() => ({ const mockPb = {
collection: vi.fn((_name: string) => ({ collection: vi.fn((_name: string) => ({
update: mockPbUpdate, update: mockPbUpdate,
create: mockPbCreate, create: mockPbCreate,
})), })),
})), };
loadAuthFromCookies: vi.fn(),
isAuthenticated: vi.fn(() => currentMockUser !== null),
getCurrentUser: vi.fn(() => currentMockUser),
}));
// Mock the auth-middleware module // Mock the auth-middleware module
vi.mock("@/lib/auth-middleware", () => ({ vi.mock("@/lib/auth-middleware", () => ({
@@ -32,7 +28,7 @@ vi.mock("@/lib/auth-middleware", () => ({
if (!currentMockUser) { if (!currentMockUser) {
return NextResponse.json({ error: "Unauthorized" }, { status: 401 }); return NextResponse.json({ error: "Unauthorized" }, { status: 401 });
} }
return handler(request, currentMockUser); return handler(request, currentMockUser, mockPb);
}; };
}), }),
})); }));

View File

@@ -6,7 +6,6 @@ import { NextResponse } from "next/server";
import { withAuth } from "@/lib/auth-middleware"; import { withAuth } from "@/lib/auth-middleware";
import { getCycleDay, getPhase } from "@/lib/cycle"; import { getCycleDay, getPhase } from "@/lib/cycle";
import { logger } from "@/lib/logger"; import { logger } from "@/lib/logger";
import { createPocketBaseClient } from "@/lib/pocketbase";
interface PeriodLogRequest { interface PeriodLogRequest {
startDate?: string; startDate?: string;
@@ -35,7 +34,7 @@ function isFutureDate(dateStr: string): boolean {
return inputDate > today; return inputDate > today;
} }
export const POST = withAuth(async (request: NextRequest, user) => { export const POST = withAuth(async (request: NextRequest, user, pb) => {
try { try {
const body = (await request.json()) as PeriodLogRequest; const body = (await request.json()) as PeriodLogRequest;
@@ -63,8 +62,6 @@ export const POST = withAuth(async (request: NextRequest, user) => {
); );
} }
const pb = createPocketBaseClient();
// Calculate predicted date based on previous cycle (if exists) // Calculate predicted date based on previous cycle (if exists)
let predictedDateStr: string | null = null; let predictedDateStr: string | null = null;
if (user.lastPeriodDate) { if (user.lastPeriodDate) {

View File

@@ -12,14 +12,12 @@ let currentMockUser: User | null = null;
// Track PocketBase update calls // Track PocketBase update calls
const mockPbUpdate = vi.fn().mockResolvedValue({}); const mockPbUpdate = vi.fn().mockResolvedValue({});
// Mock PocketBase // Create mock PocketBase client
vi.mock("@/lib/pocketbase", () => ({ const mockPb = {
createPocketBaseClient: vi.fn(() => ({
collection: vi.fn(() => ({ collection: vi.fn(() => ({
update: mockPbUpdate, update: mockPbUpdate,
})), })),
})), };
}));
// Track encryption calls // Track encryption calls
const mockEncrypt = vi.fn((plaintext: string) => `encrypted:${plaintext}`); const mockEncrypt = vi.fn((plaintext: string) => `encrypted:${plaintext}`);
@@ -36,7 +34,7 @@ vi.mock("@/lib/auth-middleware", () => ({
if (!currentMockUser) { if (!currentMockUser) {
return NextResponse.json({ error: "Unauthorized" }, { status: 401 }); return NextResponse.json({ error: "Unauthorized" }, { status: 401 });
} }
return handler(request, currentMockUser); return handler(request, currentMockUser, mockPb);
}; };
}), }),
})); }));

View File

@@ -5,9 +5,8 @@ import { NextResponse } from "next/server";
import { withAuth } from "@/lib/auth-middleware"; import { withAuth } from "@/lib/auth-middleware";
import { encrypt } from "@/lib/encryption"; import { encrypt } from "@/lib/encryption";
import { daysUntilExpiry } from "@/lib/garmin"; import { daysUntilExpiry } from "@/lib/garmin";
import { createPocketBaseClient } from "@/lib/pocketbase";
export const POST = withAuth(async (request, user) => { export const POST = withAuth(async (request, user, pb) => {
const body = await request.json(); const body = await request.json();
const { oauth1, oauth2, expires_at } = body; const { oauth1, oauth2, expires_at } = body;
@@ -57,7 +56,6 @@ export const POST = withAuth(async (request, user) => {
const encryptedOauth2 = encrypt(JSON.stringify(oauth2)); const encryptedOauth2 = encrypt(JSON.stringify(oauth2));
// Update user record // Update user record
const pb = createPocketBaseClient();
await pb.collection("users").update(user.id, { await pb.collection("users").update(user.id, {
garminOauth1Token: encryptedOauth1, garminOauth1Token: encryptedOauth1,
garminOauth2Token: encryptedOauth2, garminOauth2Token: encryptedOauth2,
@@ -79,9 +77,7 @@ export const POST = withAuth(async (request, user) => {
}); });
}); });
export const DELETE = withAuth(async (_request, user) => { export const DELETE = withAuth(async (_request, user, pb) => {
const pb = createPocketBaseClient();
await pb.collection("users").update(user.id, { await pb.collection("users").update(user.id, {
garminOauth1Token: "", garminOauth1Token: "",
garminOauth2Token: "", garminOauth2Token: "",

View File

@@ -12,14 +12,12 @@ let currentMockUser: User | null = null;
// Track PocketBase collection calls // Track PocketBase collection calls
const mockGetList = vi.fn(); const mockGetList = vi.fn();
// Mock PocketBase // Create mock PocketBase client
vi.mock("@/lib/pocketbase", () => ({ const mockPb = {
createPocketBaseClient: vi.fn(() => ({
collection: vi.fn(() => ({ collection: vi.fn(() => ({
getList: mockGetList, getList: mockGetList,
})), })),
})), };
}));
// Mock the auth-middleware module // Mock the auth-middleware module
vi.mock("@/lib/auth-middleware", () => ({ vi.mock("@/lib/auth-middleware", () => ({
@@ -28,7 +26,7 @@ vi.mock("@/lib/auth-middleware", () => ({
if (!currentMockUser) { if (!currentMockUser) {
return NextResponse.json({ error: "Unauthorized" }, { status: 401 }); return NextResponse.json({ error: "Unauthorized" }, { status: 401 });
} }
return handler(request, currentMockUser); return handler(request, currentMockUser, mockPb);
}; };
}), }),
})); }));

View File

@@ -3,7 +3,6 @@
import { NextResponse } from "next/server"; import { NextResponse } from "next/server";
import { withAuth } from "@/lib/auth-middleware"; import { withAuth } from "@/lib/auth-middleware";
import { createPocketBaseClient } from "@/lib/pocketbase";
import type { DailyLog } from "@/types"; import type { DailyLog } from "@/types";
// Validation constants // Validation constants
@@ -24,7 +23,7 @@ function isValidDateFormat(dateStr: string): boolean {
return !Number.isNaN(date.getTime()); return !Number.isNaN(date.getTime());
} }
export const GET = withAuth(async (request, user) => { export const GET = withAuth(async (request, user, pb) => {
const { searchParams } = request.nextUrl; const { searchParams } = request.nextUrl;
// Parse and validate page parameter // Parse and validate page parameter
@@ -77,7 +76,6 @@ export const GET = withAuth(async (request, user) => {
const filter = filters.join(" && "); const filter = filters.join(" && ");
// Query PocketBase // Query PocketBase
const pb = createPocketBaseClient();
const result = await pb const result = await pb
.collection("dailyLogs") .collection("dailyLogs")
.getList<DailyLog>(page, limit, { .getList<DailyLog>(page, limit, {

View File

@@ -14,21 +14,8 @@ let lastUpdateCall: {
data: { activeOverrides: OverrideType[] }; data: { activeOverrides: OverrideType[] };
} | null = null; } | null = null;
// Mock the auth-middleware module // Create mock PocketBase client
vi.mock("@/lib/auth-middleware", () => ({ const mockPb = {
withAuth: vi.fn((handler) => {
return async (request: NextRequest) => {
if (!currentMockUser) {
return NextResponse.json({ error: "Unauthorized" }, { status: 401 });
}
return handler(request, currentMockUser);
};
}),
}));
// Mock the pocketbase module
vi.mock("@/lib/pocketbase", () => ({
createPocketBaseClient: vi.fn(() => ({
collection: vi.fn(() => ({ collection: vi.fn(() => ({
update: vi.fn( update: vi.fn(
async (id: string, data: { activeOverrides: OverrideType[] }) => { async (id: string, data: { activeOverrides: OverrideType[] }) => {
@@ -44,7 +31,18 @@ vi.mock("@/lib/pocketbase", () => ({
}, },
), ),
})), })),
})), };
// Mock the auth-middleware module
vi.mock("@/lib/auth-middleware", () => ({
withAuth: vi.fn((handler) => {
return async (request: NextRequest) => {
if (!currentMockUser) {
return NextResponse.json({ error: "Unauthorized" }, { status: 401 });
}
return handler(request, currentMockUser, mockPb);
};
}),
})); }));
import { DELETE, POST } from "./route"; import { DELETE, POST } from "./route";

View File

@@ -5,7 +5,6 @@ import { NextResponse } from "next/server";
import { withAuth } from "@/lib/auth-middleware"; import { withAuth } from "@/lib/auth-middleware";
import { logger } from "@/lib/logger"; import { logger } from "@/lib/logger";
import { createPocketBaseClient } from "@/lib/pocketbase";
import type { OverrideType } from "@/types"; import type { OverrideType } from "@/types";
const VALID_OVERRIDE_TYPES: OverrideType[] = [ const VALID_OVERRIDE_TYPES: OverrideType[] = [
@@ -27,7 +26,7 @@ function isValidOverrideType(value: unknown): value is OverrideType {
* Request body: { override: OverrideType } * Request body: { override: OverrideType }
* Response: { activeOverrides: OverrideType[] } * Response: { activeOverrides: OverrideType[] }
*/ */
export const POST = withAuth(async (request: NextRequest, user) => { export const POST = withAuth(async (request: NextRequest, user, pb) => {
const body = await request.json(); const body = await request.json();
if (!body.override) { if (!body.override) {
@@ -55,7 +54,6 @@ export const POST = withAuth(async (request: NextRequest, user) => {
: [...currentOverrides, overrideToAdd]; : [...currentOverrides, overrideToAdd];
// Update the user record in PocketBase // Update the user record in PocketBase
const pb = createPocketBaseClient();
await pb await pb
.collection("users") .collection("users")
.update(user.id, { activeOverrides: newOverrides }); .update(user.id, { activeOverrides: newOverrides });
@@ -74,7 +72,7 @@ export const POST = withAuth(async (request: NextRequest, user) => {
* Request body: { override: OverrideType } * Request body: { override: OverrideType }
* Response: { activeOverrides: OverrideType[] } * Response: { activeOverrides: OverrideType[] }
*/ */
export const DELETE = withAuth(async (request: NextRequest, user) => { export const DELETE = withAuth(async (request: NextRequest, user, pb) => {
const body = await request.json(); const body = await request.json();
if (!body.override) { if (!body.override) {
@@ -100,7 +98,6 @@ export const DELETE = withAuth(async (request: NextRequest, user) => {
const newOverrides = currentOverrides.filter((o) => o !== overrideToRemove); const newOverrides = currentOverrides.filter((o) => o !== overrideToRemove);
// Update the user record in PocketBase // Update the user record in PocketBase
const pb = createPocketBaseClient();
await pb await pb
.collection("users") .collection("users")
.update(user.id, { activeOverrides: newOverrides }); .update(user.id, { activeOverrides: newOverrides });

View File

@@ -12,9 +12,8 @@ let currentMockUser: User | null = null;
// Module-level variable to control mock daily log in tests // Module-level variable to control mock daily log in tests
let currentMockDailyLog: DailyLog | null = null; let currentMockDailyLog: DailyLog | null = null;
// Mock PocketBase client for database operations // Create mock PocketBase client
vi.mock("@/lib/pocketbase", () => ({ const mockPb = {
createPocketBaseClient: vi.fn(() => ({
collection: vi.fn(() => ({ collection: vi.fn(() => ({
getFirstListItem: vi.fn(async () => { getFirstListItem: vi.fn(async () => {
if (!currentMockDailyLog) { if (!currentMockDailyLog) {
@@ -25,11 +24,7 @@ vi.mock("@/lib/pocketbase", () => ({
return currentMockDailyLog; return currentMockDailyLog;
}), }),
})), })),
})), };
loadAuthFromCookies: vi.fn(),
isAuthenticated: vi.fn(() => currentMockUser !== null),
getCurrentUser: vi.fn(() => currentMockUser),
}));
// Mock the auth-middleware module // Mock the auth-middleware module
vi.mock("@/lib/auth-middleware", () => ({ vi.mock("@/lib/auth-middleware", () => ({
@@ -38,7 +33,7 @@ vi.mock("@/lib/auth-middleware", () => ({
if (!currentMockUser) { if (!currentMockUser) {
return NextResponse.json({ error: "Unauthorized" }, { status: 401 }); return NextResponse.json({ error: "Unauthorized" }, { status: 401 });
} }
return handler(request, currentMockUser); return handler(request, currentMockUser, mockPb);
}; };
}), }),
})); }));

View File

@@ -12,7 +12,6 @@ import {
import { getDecisionWithOverrides } from "@/lib/decision-engine"; import { getDecisionWithOverrides } from "@/lib/decision-engine";
import { logger } from "@/lib/logger"; import { logger } from "@/lib/logger";
import { getNutritionGuidance } from "@/lib/nutrition"; import { getNutritionGuidance } from "@/lib/nutrition";
import { createPocketBaseClient } from "@/lib/pocketbase";
import type { DailyData, DailyLog, HrvStatus } from "@/types"; import type { DailyData, DailyLog, HrvStatus } from "@/types";
// Default biometrics when no Garmin data is available // Default biometrics when no Garmin data is available
@@ -28,7 +27,7 @@ const DEFAULT_BIOMETRICS: {
weekIntensityMinutes: 0, weekIntensityMinutes: 0,
}; };
export const GET = withAuth(async (_request, user) => { export const GET = withAuth(async (_request, user, pb) => {
// Validate required user data // Validate required user data
if (!user.lastPeriodDate) { if (!user.lastPeriodDate) {
return NextResponse.json( return NextResponse.json(
@@ -70,7 +69,6 @@ export const GET = withAuth(async (_request, user) => {
// Try to fetch today's DailyLog for biometrics // Try to fetch today's DailyLog for biometrics
let biometrics = { ...DEFAULT_BIOMETRICS, phaseLimit }; let biometrics = { ...DEFAULT_BIOMETRICS, phaseLimit };
try { try {
const pb = createPocketBaseClient();
const today = new Date().toISOString().split("T")[0]; const today = new Date().toISOString().split("T")[0];
const dailyLog = await pb const dailyLog = await pb
.collection("dailyLogs") .collection("dailyLogs")

View File

@@ -12,14 +12,12 @@ let currentMockUser: User | null = null;
// Track PocketBase update calls // Track PocketBase update calls
const mockPbUpdate = vi.fn().mockResolvedValue({}); const mockPbUpdate = vi.fn().mockResolvedValue({});
// Mock PocketBase // Create mock PocketBase client
vi.mock("@/lib/pocketbase", () => ({ const mockPb = {
createPocketBaseClient: vi.fn(() => ({
collection: vi.fn(() => ({ collection: vi.fn(() => ({
update: mockPbUpdate, update: mockPbUpdate,
})), })),
})), };
}));
// Mock the auth-middleware module // Mock the auth-middleware module
vi.mock("@/lib/auth-middleware", () => ({ vi.mock("@/lib/auth-middleware", () => ({
@@ -28,7 +26,7 @@ vi.mock("@/lib/auth-middleware", () => ({
if (!currentMockUser) { if (!currentMockUser) {
return NextResponse.json({ error: "Unauthorized" }, { status: 401 }); return NextResponse.json({ error: "Unauthorized" }, { status: 401 });
} }
return handler(request, currentMockUser); return handler(request, currentMockUser, mockPb);
}; };
}), }),
})); }));

View File

@@ -4,7 +4,6 @@ import type { NextRequest } from "next/server";
import { NextResponse } from "next/server"; import { NextResponse } from "next/server";
import { withAuth } from "@/lib/auth-middleware"; import { withAuth } from "@/lib/auth-middleware";
import { createPocketBaseClient } from "@/lib/pocketbase";
// Validation constants // Validation constants
const CYCLE_LENGTH_MIN = 21; const CYCLE_LENGTH_MIN = 21;
@@ -16,7 +15,7 @@ const TIME_FORMAT_REGEX = /^([01]\d|2[0-3]):([0-5]\d)$/;
* Returns the authenticated user's profile. * Returns the authenticated user's profile.
* Excludes sensitive fields like encrypted tokens. * Excludes sensitive fields like encrypted tokens.
*/ */
export const GET = withAuth(async (_request, user) => { export const GET = withAuth(async (_request, user, _pb) => {
// Format date for consistent API response // Format date for consistent API response
const lastPeriodDate = user.lastPeriodDate const lastPeriodDate = user.lastPeriodDate
? user.lastPeriodDate.toISOString().split("T")[0] ? user.lastPeriodDate.toISOString().split("T")[0]
@@ -81,7 +80,7 @@ function validateTimezone(value: unknown): string | null {
* Updates the authenticated user's profile. * Updates the authenticated user's profile.
* Allowed fields: cycleLength, notificationTime, timezone * Allowed fields: cycleLength, notificationTime, timezone
*/ */
export const PATCH = withAuth(async (request: NextRequest, user) => { export const PATCH = withAuth(async (request: NextRequest, user, pb) => {
const body = await request.json(); const body = await request.json();
// Build update object with only valid, updatable fields // Build update object with only valid, updatable fields
@@ -132,7 +131,6 @@ export const PATCH = withAuth(async (request: NextRequest, user) => {
} }
// Update the user record in PocketBase // Update the user record in PocketBase
const pb = createPocketBaseClient();
await pb.collection("users").update(user.id, updates); await pb.collection("users").update(user.id, updates);
// Build updated user profile for response // Build updated user profile for response

View File

@@ -113,7 +113,12 @@ describe("withAuth", () => {
const response = await wrappedHandler(mockRequest); const response = await wrappedHandler(mockRequest);
expect(response.status).toBe(200); expect(response.status).toBe(200);
expect(handler).toHaveBeenCalledWith(mockRequest, mockUser, undefined); expect(handler).toHaveBeenCalledWith(
mockRequest,
mockUser,
mockPbClient,
undefined,
);
}); });
it("loads auth from cookies before checking authentication", async () => { it("loads auth from cookies before checking authentication", async () => {
@@ -159,7 +164,7 @@ describe("withAuth", () => {
await wrappedHandler(mockRequest, { params: mockParams }); await wrappedHandler(mockRequest, { params: mockParams });
expect(handler).toHaveBeenCalledWith(mockRequest, mockUser, { expect(handler).toHaveBeenCalledWith(mockRequest, mockUser, mockPbClient, {
params: mockParams, params: mockParams,
}); });
}); });

View File

@@ -4,6 +4,7 @@
import { cookies } from "next/headers"; import { cookies } from "next/headers";
import type { NextRequest } from "next/server"; import type { NextRequest } from "next/server";
import { NextResponse } from "next/server"; import { NextResponse } from "next/server";
import type PocketBase from "pocketbase";
import type { User } from "@/types"; import type { User } from "@/types";
@@ -16,24 +17,27 @@ import {
} from "./pocketbase"; } from "./pocketbase";
/** /**
* Route handler function type that receives the authenticated user. * Route handler function type that receives the authenticated user and PocketBase client.
*/ */
export type AuthenticatedHandler<T = unknown> = ( export type AuthenticatedHandler<T = unknown> = (
request: NextRequest, request: NextRequest,
user: User, user: User,
pb: PocketBase,
context?: { params?: T }, context?: { params?: T },
) => Promise<NextResponse>; ) => Promise<NextResponse>;
/** /**
* Higher-order function that wraps an API route handler with authentication. * Higher-order function that wraps an API route handler with authentication.
* Loads auth from cookies, validates the session, and passes the user to the handler. * Loads auth from cookies, validates the session, and passes the user and
* authenticated PocketBase client to the handler.
* *
* @param handler - The route handler that requires authentication * @param handler - The route handler that requires authentication
* @returns A wrapped handler that checks auth before calling the original handler * @returns A wrapped handler that checks auth before calling the original handler
* *
* @example * @example
* ```ts * ```ts
* export const GET = withAuth(async (request, user) => { * export const GET = withAuth(async (request, user, pb) => {
* const data = await pb.collection("users").getOne(user.id);
* return NextResponse.json({ email: user.email }); * return NextResponse.json({ email: user.email });
* }); * });
* ``` * ```
@@ -66,8 +70,8 @@ export function withAuth<T = unknown>(
return NextResponse.json({ error: "Unauthorized" }, { status: 401 }); return NextResponse.json({ error: "Unauthorized" }, { status: 401 });
} }
// Call the original handler with the user context // Call the original handler with the user context and authenticated pb client
return await handler(request, user, context); return await handler(request, user, pb, context);
} catch (error) { } catch (error) {
logger.error({ err: error }, "Auth middleware error"); logger.error({ err: error }, "Auth middleware error");
return NextResponse.json( return NextResponse.json(