401 lines
9.8 KiB
TypeScript
Raw Normal View History

import type { ApiResponse } from "../response-helpers";
/**
 * Error raised when an API request fails.
 *
 * Keeps a reference to the `Response` associated with the failure and, when
 * the thrower supplied one, the already-decoded response body.
 */
export class ApiError extends Error {
  /** The HTTP response associated with this failure. */
  public readonly response: Response;

  /** Decoded response payload, if one was provided when the error was created. */
  public readonly body?: unknown;

  /**
   * @param message Human-readable description of the failure.
   * @param response The response that caused the error.
   * @param body Optional decoded payload to expose to callers.
   */
  constructor(message: string, response: Response, body?: unknown) {
    super(message);
    this.response = response;
    this.body = body;
    this.name = "ApiError";
  }
}
/**
 * Type guard reporting whether a caught value is an `ApiError` instance.
 *
 * @param error Any value, typically the binding of a `catch` clause.
 * @returns `true` when `error` was constructed by `ApiError`.
 */
export const isApiError = (error: unknown): error is ApiError => {
  return error instanceof ApiError;
};
/**
 * HTTP methods accepted by the client's internal request function.
 * Note that the public `ApiClient` interface in this file only exposes
 * GET, POST, PUT, PATCH and DELETE.
 */
export type HttpMethod = "GET" | "POST" | "PUT" | "PATCH" | "DELETE" | "HEAD" | "OPTIONS";
/** Values substituted for `{name}` placeholders in a request path. */
export type PathParams = Record<string, string | number>;
/** A single scalar value that can be written into the query string. */
export type QueryPrimitive = string | number | boolean;
/**
 * Query-string parameters keyed by name.
 * Array values are serialized as one `key=value` pair per element;
 * `undefined` values are skipped.
 */
export type QueryParams = Record<
string,
QueryPrimitive | QueryPrimitive[] | readonly QueryPrimitive[] | undefined
>;
/** Per-request options accepted by every `ApiClient` method. */
export interface RequestOptions {
/** URL parameters for the request. */
params?: {
/** Values substituted for `{name}` placeholders in the path (URI-encoded). */
path?: PathParams;
/** Values serialized into the URL's query string. */
query?: QueryParams;
};
/**
 * Request payload. `FormData` and `Blob` values are sent as-is; any other
 * non-nullish value is sent as `JSON.stringify(body)`, with `Content-Type`
 * set to `application/json` unless the caller already supplied one.
 */
body?: unknown;
/** Extra request headers; used to initialize the `Headers` sent with the request. */
headers?: Record<string, string>;
/** Abort signal forwarded to `fetch`. */
signal?: AbortSignal;
/** Credentials mode forwarded to `fetch`; defaults to `"include"`. */
credentials?: RequestCredentials;
/** When true, no CSRF token is fetched or attached for this request. */
disableCsrf?: boolean;
}
/**
 * Supplies the value for the `Authorization` header. Called on each request
 * that does not already carry that header; returning `undefined` (or an empty
 * string) leaves the header unset.
 */
export type AuthHeaderResolver = () => string | undefined;
/** Options accepted by `createClient`. */
export interface CreateClientOptions {
/** Base URL for requests; when blank or omitted it is resolved from the environment. */
baseUrl?: string;
/** Optional provider of the `Authorization` header value. */
getAuthHeader?: AuthHeaderResolver;
/**
 * Invoked with every non-OK response. If it returns without throwing, the
 * client still rejects with a generic `ApiError`. Defaults to
 * `defaultHandleError`.
 */
handleError?: (response: Response) => void | Promise<void>;
/** Whether non-safe requests get an `X-CSRF-Token` header; defaults to `true`. */
enableCsrf?: boolean;
}
/**
 * Signature shared by every client method.
 * `T` is the caller's expected type for the parsed response body; the body is
 * cast to `T` without runtime validation.
 */
type ApiMethod = <T = unknown>(path: string, options?: RequestOptions) => Promise<ApiResponse<T>>;
/** Client object returned by `createClient`, with one method per supported HTTP verb. */
export interface ApiClient {
GET: ApiMethod;
POST: ApiMethod;
PUT: ApiMethod;
PATCH: ApiMethod;
DELETE: ApiMethod;
}
/** Name of an environment variable that may hold the API base URL. */
type EnvKey =
  | "NEXT_PUBLIC_API_BASE"
  | "NEXT_PUBLIC_API_URL"
  | "API_BASE_URL"
  | "API_BASE"
  | "API_URL";

/**
 * Environment variables consulted for the base URL.
 * Order matters: the first one holding a non-blank value wins.
 */
const BASE_URL_ENV_KEYS = [
  "NEXT_PUBLIC_API_BASE",
  "NEXT_PUBLIC_API_URL",
  "API_BASE_URL",
  "API_BASE",
  "API_URL",
] as const satisfies readonly EnvKey[];

/** Base URL used when neither the caller nor the environment provides one. */
const DEFAULT_BASE_URL = "http://localhost:4000";
/**
 * Normalizes a base URL so it can be used as a request prefix.
 *
 * - Surrounding whitespace is removed.
 * - Blank input yields `DEFAULT_BASE_URL`.
 * - Trailing slashes are stripped.
 * - A value made up only of slashes ("/", "//", …) normalizes to "/" instead
 *   of collapsing to an empty string.
 *
 * @param value Raw base URL as supplied by the caller or the environment.
 * @returns The normalized base URL; never an empty string.
 */
const normalizeBaseUrl = (value: string): string => {
  const trimmed = value.trim();
  if (!trimmed) {
    return DEFAULT_BASE_URL;
  }
  const withoutTrailingSlashes = trimmed.replace(/\/+$/, "");
  // Only an all-slash value can become empty here; treat it like "/".
  return withoutTrailingSlashes || "/";
};
/**
 * Resolves the base URL from environment variables.
 *
 * The variables named in `BASE_URL_ENV_KEYS` are checked in order and the
 * first non-blank value is returned in normalized form. When `process.env`
 * is unavailable or none of the variables is set, `DEFAULT_BASE_URL` is used.
 *
 * @returns The normalized base URL from the environment, or the default.
 */
const resolveBaseUrlFromEnv = (): string => {
  if (typeof process === "undefined" || !process.env) {
    return DEFAULT_BASE_URL;
  }

  // Each variable is read with a literal property access on purpose: bundlers
  // that inline environment variables at build time (e.g. Next.js for
  // NEXT_PUBLIC_*) only replace `process.env.NAME`, not computed lookups such
  // as `process.env[key]`, which would therefore always be empty in browser
  // bundles.
  const valuesByKey: Record<EnvKey, string | undefined> = {
    NEXT_PUBLIC_API_BASE: process.env.NEXT_PUBLIC_API_BASE,
    NEXT_PUBLIC_API_URL: process.env.NEXT_PUBLIC_API_URL,
    API_BASE_URL: process.env.API_BASE_URL,
    API_BASE: process.env.API_BASE,
    API_URL: process.env.API_URL,
  };

  // BASE_URL_ENV_KEYS still defines the priority order.
  for (const key of BASE_URL_ENV_KEYS) {
    const envValue = valuesByKey[key];
    if (typeof envValue === "string" && envValue.trim()) {
      return normalizeBaseUrl(envValue);
    }
  }
  return DEFAULT_BASE_URL;
};
/**
 * Picks the base URL for a client.
 *
 * A non-blank `baseUrl` argument is normalized and returned; anything else
 * (omitted, non-string, or whitespace-only) defers to the environment lookup.
 *
 * @param baseUrl Optional explicit base URL.
 * @returns The base URL to use for requests.
 */
export const resolveBaseUrl = (baseUrl?: string) => {
  if (typeof baseUrl !== "string") {
    return resolveBaseUrlFromEnv();
  }
  return baseUrl.trim() ? normalizeBaseUrl(baseUrl) : resolveBaseUrlFromEnv();
};
/**
 * Substitutes `{name}` placeholders in a path with URI-encoded parameter values.
 *
 * When `params` is omitted the path is returned unchanged, placeholders
 * included.
 *
 * @param path Path template, e.g. `/users/{id}`.
 * @param params Values keyed by placeholder name.
 * @returns The path with every placeholder replaced.
 * @throws Error if a placeholder has no own entry in `params`, or its value is
 *   `undefined`/`null`.
 */
const applyPathParams = (path: string, params?: PathParams): string => {
  if (!params) {
    return path;
  }
  return path.replace(/\{([^}]+)\}/g, (_match: string, rawKey: string) => {
    // Own-property lookup only: the `in` operator would also accept inherited
    // names such as "toString" or "constructor" and interpolate a function.
    const value: unknown = Object.prototype.hasOwnProperty.call(params, rawKey)
      ? params[rawKey]
      : undefined;
    // A nullish value would otherwise be sent as the literal text "undefined"/"null".
    if (value === undefined || value === null) {
      throw new Error(`Missing path parameter: ${rawKey}`);
    }
    return encodeURIComponent(String(value));
  });
};
/**
 * Serializes query parameters into a URL-encoded string (without a leading `?`).
 *
 * Nullish values are omitted; array values contribute one `key=value` pair per
 * element; everything else is converted with `String()`.
 *
 * @param query Parameters to serialize.
 * @returns The encoded query string, or `""` when there is nothing to emit.
 */
const buildQueryString = (query?: QueryParams): string => {
  const encoded = new URLSearchParams();

  for (const [name, raw] of Object.entries(query ?? {})) {
    if (raw == null) {
      continue;
    }
    // Treat scalars as a one-element list so both shapes share one code path.
    const entries: readonly QueryPrimitive[] = Array.isArray(raw)
      ? (raw as readonly QueryPrimitive[])
      : [raw as QueryPrimitive];
    entries.forEach(entry => {
      encoded.append(name, String(entry));
    });
  }

  return encoded.toString();
};
/**
 * Extracts a human-readable message from a decoded response body.
 *
 * @param body A string body, or an object that may carry a `message` property.
 * @returns The body itself when it is a string, the `message` property when it
 *   is a string, otherwise `null`.
 */
const getBodyMessage = (body: unknown): string | null => {
  if (typeof body === "string") {
    return body;
  }
  if (typeof body !== "object" || body === null || !("message" in body)) {
    return null;
  }
  const candidate = (body as { message?: unknown }).message;
  return typeof candidate === "string" ? candidate : null;
};
/**
 * Default handler for failed responses: turns a non-OK response into a thrown
 * `ApiError`.
 *
 * The error message is taken, in order of preference, from the JSON body's
 * message (as extracted by `getBodyMessage`), from a non-empty text body, or
 * from the status text / status code. The body is read from a clone, so the
 * original response stays unread.
 *
 * @param response The response to inspect; OK responses are ignored.
 * @throws ApiError for every non-OK response.
 */
async function defaultHandleError(response: Response) {
  if (response.ok) {
    return;
  }

  let payload: unknown;
  let message = response.statusText || `Request failed with status ${response.status}`;

  try {
    const copy = response.clone();
    const isJson = copy.headers.get("content-type")?.includes("application/json") ?? false;

    if (isJson) {
      payload = await copy.json();
      // Keep the status-based message unless the body supplies a non-empty one.
      message = getBodyMessage(payload) || message;
    } else {
      const text = await copy.text();
      if (text) {
        payload = text;
        message = text;
      }
    }
  } catch {
    // Body could not be read or parsed; the status-based message is used.
  }

  throw new ApiError(message, response, payload);
}
/**
 * Decodes the body of a successful response based on its `Content-Type`.
 *
 * - 204 responses and responses with `Content-Length: 0` yield `null`.
 * - JSON media types yield the parsed value. This covers `application/json`
 *   and structured-syntax `+json` types such as `application/problem+json`.
 * - `text/*` media types yield the body as a string.
 * - Anything else, or a body that fails to decode, yields `null`.
 *
 * Media types are compared case-insensitively.
 *
 * @param response The response whose body should be read (it is consumed).
 * @returns The decoded body, or `null` when there is nothing usable.
 */
const parseResponseBody = async (response: Response): Promise<unknown> => {
  if (response.status === 204) {
    return null;
  }
  if (response.headers.get("content-length") === "0") {
    return null;
  }

  // Header values keep the sender's casing, so normalize before matching.
  const contentType = (response.headers.get("content-type") ?? "").toLowerCase();
  // Drop parameters such as "; charset=utf-8" to inspect the bare media type.
  const mediaType = contentType.split(";", 1)[0]?.trim() ?? "";

  const isJson = contentType.includes("application/json") || mediaType.endsWith("+json");
  const isText = contentType.includes("text/");
  if (!isJson && !isText) {
    return null;
  }

  try {
    return isJson ? await response.json() : await response.text();
  } catch {
    // Malformed or unreadable body: report "no data" rather than failing.
    return null;
  }
};
/** Shape of the JSON document returned by the CSRF token endpoint. */
interface CsrfTokenPayload {
  success: boolean;
  token: string;
}

/**
 * Type guard for `CsrfTokenPayload`.
 *
 * Checks structure only: `success` must be a boolean and `token` a string.
 * It does not require `success` to be `true` or `token` to be non-empty.
 *
 * @param value Any decoded JSON value.
 * @returns `true` when `value` has the expected shape.
 */
const isCsrfTokenPayload = (value: unknown): value is CsrfTokenPayload => {
  if (typeof value !== "object" || value === null) {
    return false;
  }
  if (!("success" in value) || !("token" in value)) {
    return false;
  }
  const candidate = value as { success: unknown; token: unknown };
  return typeof candidate.success === "boolean" && typeof candidate.token === "string";
};
/**
 * Fetches and caches the CSRF token for one base URL.
 *
 * Concurrent callers share a single in-flight request. A token is cached only
 * if `clearToken()` was not called while its request was in flight.
 */
class CsrfTokenManager {
  /** Last successfully fetched token, or `null` when none is cached. */
  private token: string | null = null;
  /** Request currently in flight, shared by concurrent `getToken()` callers. */
  private tokenPromise: Promise<string> | null = null;

  /** @param baseUrl Base URL that the token endpoint path is appended to. */
  constructor(private readonly baseUrl: string) {}

  /**
   * Returns the cached token, joining or starting a fetch when none is cached.
   *
   * @returns The CSRF token.
   * @throws Error when the token endpoint fails or returns an unusable payload.
   */
  async getToken(): Promise<string> {
    if (this.token) {
      return this.token;
    }
    if (this.tokenPromise) {
      return this.tokenPromise;
    }

    const pending = this.fetchToken();
    this.tokenPromise = pending;
    try {
      const token = await pending;
      // If clearToken() ran while this request was in flight, the result may
      // already be stale: hand it to this caller but do not cache it.
      if (this.tokenPromise === pending) {
        this.token = token;
      }
      return token;
    } finally {
      // Only reset the slot if it still belongs to this request; a newer
      // request may have been started after clearToken().
      if (this.tokenPromise === pending) {
        this.tokenPromise = null;
      }
    }
  }

  /** Drops the cached token and detaches any in-flight request. */
  clearToken(): void {
    this.token = null;
    this.tokenPromise = null;
  }

  /**
   * Requests a fresh token from `<baseUrl>/api/security/csrf/token`.
   *
   * @throws Error on a non-OK status, or when the payload is malformed,
   *   reports `success: false`, or carries an empty token.
   */
  private async fetchToken(): Promise<string> {
    // Strip trailing slashes so a base of "/" yields "/api/..." rather than
    // the protocol-relative "//api/...", which would target a host named "api".
    const prefix = this.baseUrl.replace(/\/+$/, "");
    const response = await fetch(`${prefix}/api/security/csrf/token`, {
      method: "GET",
      credentials: "include",
      headers: {
        Accept: "application/json",
      },
    });
    if (!response.ok) {
      throw new Error(`Failed to fetch CSRF token: ${response.status}`);
    }
    const data: unknown = await response.json();
    // The shape check alone would accept { success: false, token: "" }.
    if (!isCsrfTokenPayload(data) || !data.success || data.token === "") {
      throw new Error("Invalid CSRF token response");
    }
    return data.token;
  }
}
/** HTTP methods for which the client does not attach a CSRF token. */
const SAFE_METHODS = new Set<HttpMethod>([
  "GET",
  "HEAD",
  "OPTIONS",
]);
/**
 * Creates an API client bound to a base URL.
 *
 * Every method resolves path placeholders, appends query parameters, attaches
 * the `Authorization` header (when a resolver is configured and the caller
 * did not set one) and, for non-safe methods, an `X-CSRF-Token` header (unless
 * CSRF handling is disabled). Non-OK responses are passed to `handleError` and
 * always result in a rejection.
 *
 * @param options Client configuration; see `CreateClientOptions`.
 * @returns An object exposing `GET`, `POST`, `PUT`, `PATCH` and `DELETE`.
 */
export function createClient(options: CreateClientOptions = {}): ApiClient {
  const baseUrl = resolveBaseUrl(options.baseUrl);
  const resolveAuthHeader = options.getAuthHeader;
  const handleError = options.handleError ?? defaultHandleError;
  const enableCsrf = options.enableCsrf ?? true;
  const csrfManager = enableCsrf ? new CsrfTokenManager(baseUrl) : null;

  /**
   * Resolves a request path against the configured base URL.
   *
   * `new URL(path, base)` requires an absolute base, so a relative base such
   * as "/" is first resolved against the page origin when one is available.
   */
  const resolveUrl = (resolvedPath: string): URL => {
    const hasScheme = /^[a-z][a-z\d+.-]*:/i.test(baseUrl);
    if (hasScheme) {
      return new URL(resolvedPath, baseUrl);
    }
    const origin = (globalThis as { location?: { origin?: string } }).location?.origin;
    if (!origin) {
      throw new TypeError(
        `Cannot resolve relative base URL "${baseUrl}" without a window location`
      );
    }
    return new URL(resolvedPath, new URL(baseUrl, origin));
  };

  /** Reports whether `fetch` can send the value as-is, without JSON encoding. */
  const isPassThroughBody = (body: unknown): body is BodyInit =>
    body instanceof FormData ||
    body instanceof Blob ||
    body instanceof URLSearchParams ||
    body instanceof ArrayBuffer ||
    ArrayBuffer.isView(body);

  const request = async <T>(
    method: HttpMethod,
    path: string,
    opts: RequestOptions = {}
  ): Promise<ApiResponse<T>> => {
    const resolvedPath = applyPathParams(path, opts.params?.path);
    const url = resolveUrl(resolvedPath);
    const queryString = buildQueryString(opts.params?.query);
    if (queryString) {
      // Append instead of assigning `url.search`, so a query string that is
      // already part of `path` is kept rather than silently discarded.
      new URLSearchParams(queryString).forEach((value, key) => {
        url.searchParams.append(key, value);
      });
    }

    const headers = new Headers(opts.headers);
    const credentials = opts.credentials ?? "include";
    const init: RequestInit = {
      method,
      headers,
      credentials,
      signal: opts.signal,
    };

    const body = opts.body;
    if (body !== undefined && body !== null) {
      if (isPassThroughBody(body)) {
        // fetch chooses the Content-Type for these body kinds itself.
        init.body = body;
      } else {
        if (!headers.has("Content-Type")) {
          headers.set("Content-Type", "application/json");
        }
        init.body = JSON.stringify(body);
      }
    }

    if (resolveAuthHeader && !headers.has("Authorization")) {
      const headerValue = resolveAuthHeader();
      if (headerValue) {
        headers.set("Authorization", headerValue);
      }
    }

    if (
      csrfManager &&
      !opts.disableCsrf &&
      !SAFE_METHODS.has(method) &&
      !headers.has("X-CSRF-Token")
    ) {
      try {
        const csrfToken = await csrfManager.getToken();
        headers.set("X-CSRF-Token", csrfToken);
      } catch (error) {
        // Don't proceed without CSRF protection for mutation endpoints
        console.error("Failed to obtain CSRF token - blocking request", error);
        throw new ApiError(
          "CSRF protection unavailable. Please refresh the page and try again.",
          new Response(null, { status: 403, statusText: "CSRF Token Required" })
        );
      }
    }

    const response = await fetch(url.toString(), init);

    if (!response.ok) {
      if (response.status === 403 && csrfManager) {
        // A 403 that mentions CSRF means the cached token was rejected; drop
        // it so the next request fetches a fresh one. The body is read from a
        // clone so handleError can still consume the original.
        try {
          const bodyText = await response.clone().text();
          if (bodyText.toLowerCase().includes("csrf")) {
            csrfManager.clearToken();
          }
        } catch {
          csrfManager.clearToken();
        }
      }
      await handleError(response);
      // If handleError does not throw, throw a default error to ensure rejection
      throw new ApiError(`Request failed with status ${response.status}`, response);
    }

    const parsedBody = await parseResponseBody(response);
    if (parsedBody === undefined || parsedBody === null) {
      return {};
    }
    return {
      data: parsedBody as T,
    };
  };

  return {
    GET: (path, opts) => request("GET", path, opts),
    POST: (path, opts) => request("POST", path, opts),
    PUT: (path, opts) => request("PUT", path, opts),
    PATCH: (path, opts) => request("PATCH", path, opts),
    DELETE: (path, opts) => request("DELETE", path, opts),
  } satisfies ApiClient;
}