138 lines
4.1 KiB
TypeScript
Raw Normal View History

import { Injectable } from "@nestjs/common";
import { ConfigService } from "@nestjs/config";
import { SignJWT, decodeJwt, jwtVerify, errors, type JWTPayload } from "jose";
import { parseJwtExpiry } from "../../utils/jwt-expiry.util.js";
/**
 * HS256 JWT signing/verification service built on `jose`.
 *
 * Configuration is read once in the constructor via `ConfigService`:
 * - `JWT_SECRET` (required): signs new tokens and is the first key tried on verify.
 * - `JWT_SECRET_PREVIOUS` (optional): comma-separated older secrets that are still
 *   accepted for verification (secret rotation). Entries are trimmed, so a previous
 *   secret cannot contain a comma or meaningful leading/trailing whitespace.
 * - `JWT_ISSUER` (optional): when non-blank, written as `iss` on signed tokens and
 *   passed as the expected issuer on verify.
 * - `JWT_AUDIENCE` (optional): one value or a comma-separated list; written as `aud`
 *   on signed tokens and passed as the expected audience on verify.
 */
@Injectable()
export class JoseJwtService {
  /** Key derived from `JWT_SECRET`; used for signing and tried first when verifying. */
  private readonly signingKey: Uint8Array;
  /** The signing key followed by each distinct previous secret, in configured order. */
  private readonly verificationKeys: Uint8Array[];
  private readonly issuer?: string;
  private readonly audience?: string | string[];

  /**
   * @throws Error if `JWT_SECRET` is missing or empty.
   */
  constructor(private readonly configService: ConfigService) {
    const secret = configService.get<string>("JWT_SECRET");
    if (!secret) {
      throw new Error("JWT_SECRET is required in environment variables");
    }
    const encoder = new TextEncoder();
    this.signingKey = encoder.encode(secret);

    // Drop the current secret and any repeats so each distinct key is tried at most once.
    // Set preserves insertion order, so previous keys are still tried in configured order.
    const previousSecrets = new Set(
      this.parsePreviousSecrets(configService.get<string | undefined>("JWT_SECRET_PREVIOUS")),
    );
    previousSecrets.delete(secret);
    this.verificationKeys = [
      this.signingKey,
      ...Array.from(previousSecrets, s => encoder.encode(s)),
    ];

    const issuer = configService.get<string | undefined>("JWT_ISSUER");
    this.issuer = issuer && issuer.trim().length > 0 ? issuer.trim() : undefined;
    this.audience = this.parseAudience(configService.get<string | undefined>("JWT_AUDIENCE"));
  }

  /** Splits a comma-separated config value into trimmed, non-empty entries. */
  private splitCommaList(raw: string | undefined): string[] {
    if (!raw) return [];
    return raw
      .split(",")
      .map(part => part.trim())
      .filter(part => part.length > 0);
  }

  /** Parses `JWT_SECRET_PREVIOUS` into a list of secrets (possibly empty). */
  private parsePreviousSecrets(raw: string | undefined): string[] {
    return this.splitCommaList(raw);
  }

  /**
   * Parses `JWT_AUDIENCE`.
   *
   * @returns `undefined` when unset/blank, the single value when one entry is given,
   *   otherwise the array of entries.
   */
  private parseAudience(raw: string | undefined): string | string[] | undefined {
    const parts = this.splitCommaList(raw);
    if (parts.length === 0) return undefined;
    return parts.length === 1 ? parts[0] : parts;
  }

  /**
   * Signs `payload` as an HS256 JWT with the current secret.
   *
   * Sets `iat` to now and `exp` to now + lifetime (in seconds), adds `iss`/`aud` when
   * configured, and copies a non-empty string `payload.tokenId` into the `jti` claim.
   *
   * @param payload - Claims to embed in the token.
   * @param expiresIn - Lifetime in seconds, or a string handed to `parseJwtExpiry`.
   * @returns The compact-serialized token.
   * @throws Error if the resolved lifetime is not a finite number.
   */
  async sign(payload: JWTPayload, expiresIn: string | number): Promise<string> {
    const expiresInSeconds = typeof expiresIn === "number" ? expiresIn : parseJwtExpiry(expiresIn);
    // NaN/Infinity cannot be represented in JSON (they serialize as null) and, depending on
    // the jose version, either throw deep inside jose or yield a token that never verifies.
    // Fail at issuance instead. Zero/negative lifetimes remain allowed (already-expired tokens).
    if (!Number.isFinite(expiresInSeconds)) {
      throw new Error(`Invalid JWT expiry: ${String(expiresIn)}`);
    }
    const nowSeconds = Math.floor(Date.now() / 1000);
    const tokenId = (payload as { tokenId?: unknown }).tokenId;

    let builder = new SignJWT(payload)
      .setProtectedHeader({ alg: "HS256" })
      .setIssuedAt(nowSeconds)
      .setExpirationTime(nowSeconds + expiresInSeconds);
    if (this.issuer) {
      builder = builder.setIssuer(this.issuer);
    }
    if (this.audience) {
      builder = builder.setAudience(this.audience);
    }
    // Mirror an application-level token id into the standard JWT ID claim.
    if (typeof tokenId === "string" && tokenId.length > 0) {
      builder = builder.setJti(tokenId);
    }
    return builder.sign(this.signingKey);
  }

  /**
   * Verifies an HS256 token and returns its payload.
   *
   * Keys are tried in order: the current secret, then each previous secret. Only a
   * signature mismatch moves on to the next key. Any other error is rethrown
   * immediately, because trying a different key cannot change it. If every key fails,
   * the error from the last attempt is thrown.
   *
   * @throws The `jose` error from the failing attempt (e.g. `errors.JWTExpired`).
   */
  async verify<T extends JWTPayload>(token: string): Promise<T> {
    const options = {
      algorithms: ["HS256"],
      issuer: this.issuer,
      audience: this.audience,
    };
    const lastIndex = this.verificationKeys.length - 1;
    let lastError: unknown;
    for (const [index, key] of this.verificationKeys.entries()) {
      try {
        const { payload } = await jwtVerify(token, key, options);
        return payload as T;
      } catch (err) {
        lastError = err;
        if (index === lastIndex) {
          break;
        }
        // A signature mismatch may just mean the token was signed with an older secret.
        if (!(err instanceof errors.JWSSignatureVerificationFailed)) {
          throw err;
        }
      }
    }
    if (lastError instanceof Error) {
      throw lastError;
    }
    // Only reachable if a non-Error value was thrown on the final attempt.
    throw new Error("JWT verification failed");
  }

  /**
   * Like {@link verify}, but when verification fails with `errors.JWTExpired` the
   * token's payload is returned (decoded without further checks) instead of throwing.
   * Every other failure is rethrown.
   *
   * NOTE(review): this relies on jose raising `JWTExpired` only after the signature
   * (and issuer/audience) checks have passed — confirm against the installed jose version.
   */
  async verifyAllowExpired<T extends JWTPayload>(token: string): Promise<T | null> {
    try {
      return await this.verify<T>(token);
    } catch (err) {
      if (err instanceof errors.JWTExpired) {
        return this.decode<T>(token);
      }
      throw err;
    }
  }

  /**
   * Decodes the payload WITHOUT any signature or claim verification (no key is used).
   * Do not trust the result for authorization decisions.
   *
   * @returns The payload, or `null` if the token cannot be decoded.
   */
  decode<T extends JWTPayload>(token: string): T | null {
    try {
      return decodeJwt(token) as T;
    } catch {
      return null;
    }
  }
}