import { Injectable } from "@nestjs/common";
import { ConfigService } from "@nestjs/config";
import { SignJWT, decodeJwt, jwtVerify, errors, type JWTPayload } from "jose";
import { parseJwtExpiry } from "../../utils/jwt-expiry.util.js";

/**
 * Signs and verifies HS256 JSON Web Tokens using the `jose` library.
 *
 * Configuration is read once, in the constructor, from these keys:
 * - `JWT_SECRET` (required): secret used to sign new tokens and tried first
 *   when verifying.
 * - `JWT_SECRET_PREVIOUS` (optional): comma-separated secrets that are still
 *   accepted for verification but never used for signing.
 * - `JWT_ISSUER` (optional): when set, stamped on signed tokens and passed to
 *   verification as the expected issuer.
 * - `JWT_AUDIENCE` (optional): comma-separated audience value(s), handled the
 *   same way as the issuer.
 */
@Injectable()
export class JoseJwtService {
  private readonly signingKey: Uint8Array;
  private readonly verificationKeys: Uint8Array[];
  private readonly issuer: string | undefined;
  private readonly audience: string | string[] | undefined;

  /**
   * @param configService Source of the `JWT_*` configuration values.
   * @throws Error when `JWT_SECRET` is missing or empty.
   */
  constructor(configService: ConfigService) {
    const secret = configService.get<string>("JWT_SECRET");
    if (!secret) {
      throw new Error("JWT_SECRET is required in environment variables");
    }

    const encoder = new TextEncoder();
    this.signingKey = encoder.encode(secret);

    // Drop the current secret and repeated entries so no key is tried twice.
    const previousSecrets = new Set(
      this.parsePreviousSecrets(configService.get<string>("JWT_SECRET_PREVIOUS")).filter(
        (candidate) => candidate !== secret,
      ),
    );
    this.verificationKeys = [
      this.signingKey,
      ...[...previousSecrets].map((previous) => encoder.encode(previous)),
    ];

    // A blank or whitespace-only issuer is treated as "not configured".
    const issuer = configService.get<string>("JWT_ISSUER")?.trim();
    this.issuer = issuer ? issuer : undefined;

    this.audience = this.parseAudience(configService.get<string>("JWT_AUDIENCE"));
  }

  /**
   * Splits a comma-separated configuration value into trimmed, non-empty items.
   *
   * @param raw Raw configuration value, possibly undefined or blank.
   * @returns The items in their original order; empty when nothing usable is present.
   */
  private parseList(raw: string | undefined): string[] {
    if (!raw) return [];
    return raw
      .split(",")
      .map((item) => item.trim())
      .filter(Boolean);
  }

  /**
   * Parses the comma-separated list of previously used secrets.
   *
   * @param raw Raw `JWT_SECRET_PREVIOUS` value.
   * @returns Trimmed, non-empty secrets; empty when the value is missing or blank.
   */
  private parsePreviousSecrets(raw: string | undefined): string[] {
    return this.parseList(raw);
  }

  /**
   * Parses the configured audience.
   *
   * @param raw Raw `JWT_AUDIENCE` value.
   * @returns `undefined` when nothing is configured, a single string for one
   *   entry, or an array when several comma-separated entries are present.
   */
  private parseAudience(raw: string | undefined): string | string[] | undefined {
    const parts = this.parseList(raw);
    if (parts.length === 0) return undefined;
    return parts.length === 1 ? parts[0] : parts;
  }

  /**
   * Signs `payload` as an HS256 JWT with the current signing key.
   *
   * `iat` is set to the current time and `exp` to `iat` plus the requested
   * lifetime. The configured issuer and audience are added when present, and a
   * non-empty string `tokenId` in the payload is mirrored into the `jti` claim.
   *
   * @param payload Claims to embed in the token.
   * @param expiresIn Lifetime in seconds when numeric; strings are converted
   *   with `parseJwtExpiry`.
   * @returns The compact serialized token.
   * @throws TypeError when the lifetime does not resolve to a finite number.
   */
  async sign(payload: JWTPayload, expiresIn: string | number): Promise<string> {
    const expiresInSeconds =
      typeof expiresIn === "number" ? expiresIn : parseJwtExpiry(expiresIn);
    // Fail early instead of handing a NaN/Infinity expiration time to the signer.
    if (!Number.isFinite(expiresInSeconds)) {
      throw new TypeError(`Invalid JWT expiry: ${String(expiresIn)}`);
    }

    const nowSeconds = Math.floor(Date.now() / 1000);
    const tokenId = (payload as { tokenId?: unknown }).tokenId;

    let builder = new SignJWT(payload)
      .setProtectedHeader({ alg: "HS256" })
      .setIssuedAt(nowSeconds)
      .setExpirationTime(nowSeconds + expiresInSeconds);

    if (this.issuer) {
      builder = builder.setIssuer(this.issuer);
    }
    if (this.audience) {
      builder = builder.setAudience(this.audience);
    }
    // Optional: set standard JWT ID when a tokenId is present in the payload.
    if (typeof tokenId === "string" && tokenId.length > 0) {
      builder = builder.setJti(tokenId);
    }

    return builder.sign(this.signingKey);
  }

  /**
   * Verifies `token` against the current key and then each previous key.
   *
   * The next key is tried only after a signature verification failure; any
   * other failure is rethrown immediately. Whatever the last key throws is
   * rethrown as-is.
   *
   * @param token Compact serialized JWT.
   * @returns The verified payload, typed as `T` by assertion only (the claims
   *   are not checked against `T` at runtime).
   * @throws The error raised by `jwtVerify`, or a generic `Error` when no
   *   `Error` instance was captured.
   */
  async verify<T extends JWTPayload = JWTPayload>(token: string): Promise<T> {
    const options = {
      algorithms: ["HS256"] as string[],
      // Only include issuer/audience when configured, rather than passing undefined.
      ...(this.issuer === undefined ? {} : { issuer: this.issuer }),
      ...(this.audience === undefined ? {} : { audience: this.audience }),
    };

    const lastIndex = this.verificationKeys.length - 1;
    let lastError: unknown;

    for (const [index, key] of this.verificationKeys.entries()) {
      try {
        const { payload } = await jwtVerify(token, key, options);
        return payload as T;
      } catch (err) {
        lastError = err;
        if (index === lastIndex) {
          break;
        }
        // Only try the next key on signature-related failures.
        if (!(err instanceof errors.JWSSignatureVerificationFailed)) {
          throw err;
        }
      }
    }

    if (lastError instanceof Error) {
      throw lastError;
    }
    throw new Error("JWT verification failed");
  }

  /**
   * Verifies `token`, but returns its decoded payload when the only failure is
   * expiry.
   *
   * NOTE(review): the fallback returns the output of `decode`, which performs
   * no verification itself. This presumably relies on `jose` checking the
   * signature before the `exp` claim — confirm for the jose version in use.
   *
   * @param token Compact serialized JWT.
   * @returns The payload, or `null` when the expired token cannot be decoded.
   * @throws Any verification error other than `errors.JWTExpired`.
   */
  async verifyAllowExpired<T extends JWTPayload = JWTPayload>(token: string): Promise<T | null> {
    try {
      return await this.verify<T>(token);
    } catch (err) {
      if (err instanceof errors.JWTExpired) {
        return this.decode<T>(token);
      }
      throw err;
    }
  }

  /**
   * Decodes the payload of `token` with `decodeJwt`, without verifying it.
   *
   * @param token Compact serialized JWT.
   * @returns The decoded payload, or `null` when decoding throws.
   */
  decode<T extends JWTPayload = JWTPayload>(token: string): T | null {
    try {
      return decodeJwt(token) as T;
    } catch {
      return null;
    }
  }
}