429 lines
13 KiB
TypeScript

import { Injectable, Inject, UnauthorizedException } from "@nestjs/common";
import { JwtService } from "@nestjs/jwt";
import { ConfigService } from "@nestjs/config";
import { Redis } from "ioredis";
import { Logger } from "nestjs-pino";
import { randomBytes, createHash } from "crypto";
import type { AuthTokens } from "@customer-portal/domain";
import { UsersService } from "@bff/modules/users/users.service";
/**
 * Claims embedded in a signed refresh JWT issued by `AuthTokenService`.
 */
export interface RefreshTokenPayload {
  /** Discriminator that separates refresh tokens from access tokens. */
  type: "refresh";
  /** Id of the user the token was issued to. */
  userId: string;
  /** Identifier set by the issuer; `AuthTokenService` fills it with the refresh-token family id. */
  tokenId: string;
  /** Optional client device identifier supplied at issuance. */
  deviceId?: string;
  /** Optional User-Agent string supplied at issuance. */
  userAgent?: string;
}
/**
 * JSON record kept in Redis for an individual refresh token (keyed by the token's hash).
 */
interface StoredRefreshToken {
  /** `false` once the token has been retired and must no longer be exchanged. */
  valid: boolean;
  /** Owner of the token. */
  userId: string;
  /** Family the token belongs to. */
  familyId: string;
}
/**
 * JSON record kept in Redis for a refresh-token family.
 */
interface StoredRefreshTokenFamily {
  /** Owner of every token in the family. */
  userId: string;
  /** Hash of the family's most recently issued refresh token. */
  tokenHash: string;
  /** ISO-8601 timestamp written when the record is stored. */
  createdAt?: string;
  /** Optional client device identifier. */
  deviceId?: string;
  /** Optional User-Agent string. */
  userAgent?: string;
}
@Injectable()
export class AuthTokenService {
  private readonly ACCESS_TOKEN_EXPIRY = "15m"; // Short-lived access tokens
  private readonly REFRESH_TOKEN_EXPIRY = "7d"; // Longer-lived refresh tokens
  private readonly REFRESH_TOKEN_FAMILY_PREFIX = "refresh_family:";
  private readonly REFRESH_TOKEN_PREFIX = "refresh_token:";
  /** COUNT hint passed to each incremental SCAN call in `revokeAllUserTokens`. */
  private readonly SCAN_BATCH_SIZE = 100;

  constructor(
    private readonly jwtService: JwtService,
    private readonly configService: ConfigService,
    @Inject("REDIS_CLIENT") private readonly redis: Redis,
    @Inject(Logger) private readonly logger: Logger,
    private readonly usersService: UsersService
  ) {}

  /**
   * Generate a new access/refresh token pair that starts a brand-new refresh-token family.
   *
   * @param user - Identity to embed in the tokens; `role` defaults to `"user"` when falsy.
   * @param deviceInfo - Optional device metadata stored alongside the refresh token.
   * @returns The signed tokens plus their ISO-8601 expiry timestamps.
   */
  async generateTokenPair(
    user: {
      id: string;
      email: string;
      role?: string;
    },
    deviceInfo?: {
      deviceId?: string;
      userAgent?: string;
    }
  ): Promise<AuthTokens> {
    return this.issueTokenPair(user, deviceInfo, this.generateTokenId());
  }

  /**
   * Refresh access token using refresh token rotation.
   *
   * The presented token must be a validly signed, unexpired refresh JWT with a live Redis
   * record. On success it is retired (kept as a `valid: false` tombstone), and a new pair is
   * issued in the SAME family. A later replay of the retired token then trips reuse detection
   * and revokes the whole family.
   *
   * If the Redis lookup fails while Redis is not ready, a non-rotating fallback pair is issued.
   * This happens only for a token whose signature, expiry and type were already verified.
   *
   * @throws UnauthorizedException for any invalid, reused or unverifiable token.
   */
  async refreshTokens(
    refreshToken: string,
    deviceInfo?: {
      deviceId?: string;
      userAgent?: string;
    }
  ): Promise<AuthTokens> {
    // Set only after signature/expiry/type checks pass. The outage fallback must never act on
    // an unverified token, otherwise anyone could mint tokens for any userId while Redis is down.
    let verifiedPayload: RefreshTokenPayload | undefined;
    // Distinguishes "Redis could not be queried" from "Redis rejected the token".
    let redisLookupFailed = false;
    try {
      // Verify refresh token (signature + expiry)
      const payload = this.jwtService.verify<RefreshTokenPayload>(refreshToken);
      if (payload.type !== "refresh") {
        this.logger.warn("Token presented to refresh endpoint is not a refresh token", {
          tokenId: payload.tokenId,
        });
        throw new UnauthorizedException("Invalid refresh token");
      }
      verifiedPayload = payload;

      const refreshTokenHash = this.hashToken(refreshToken);
      const tokenKey = `${this.REFRESH_TOKEN_PREFIX}${refreshTokenHash}`;

      // Check if refresh token exists and is valid
      let storedToken: string | null;
      try {
        storedToken = await this.redis.get(tokenKey);
      } catch (error) {
        redisLookupFailed = true;
        this.logger.error("Redis error during token refresh", {
          error: error instanceof Error ? error.message : String(error),
        });
        throw new UnauthorizedException("Token validation temporarily unavailable");
      }
      if (!storedToken) {
        this.logger.warn("Refresh token not found or expired", {
          tokenHash: refreshTokenHash.slice(0, 8),
        });
        throw new UnauthorizedException("Invalid refresh token");
      }
      const tokenRecord = this.parseRefreshTokenRecord(storedToken);
      if (!tokenRecord) {
        this.logger.warn("Stored refresh token payload was invalid JSON", {
          tokenHash: refreshTokenHash.slice(0, 8),
        });
        await this.redis.del(tokenKey);
        throw new UnauthorizedException("Invalid refresh token");
      }
      if (!tokenRecord.valid) {
        this.logger.warn("Refresh token marked as invalid", {
          tokenHash: refreshTokenHash.slice(0, 8),
        });
        // A retired token was replayed: invalidate the entire family
        await this.invalidateTokenFamily(tokenRecord.familyId);
        throw new UnauthorizedException("Invalid refresh token");
      }

      // Get user info from database (using internal method to get role)
      const prismaUser = await this.usersService.findByIdInternal(payload.userId);
      if (!prismaUser) {
        this.logger.warn("User not found during token refresh", { userId: payload.userId });
        throw new UnauthorizedException("User not found");
      }
      const user = {
        id: prismaUser.id,
        email: prismaUser.email,
        role: prismaUser.role,
      };

      // Retire the presented token but keep a tombstone so that a replay is detected as reuse.
      // The tombstone may outlive the JWT itself; that is harmless because verify() rejects
      // expired JWTs before Redis is consulted.
      const retiredRecord: StoredRefreshToken = { ...tokenRecord, valid: false };
      await this.redis.setex(
        tokenKey,
        this.parseExpiryToSeconds(this.REFRESH_TOKEN_EXPIRY),
        JSON.stringify(retiredRecord)
      );

      // Rotate within the same family so reuse detection can revoke the whole chain
      const newTokenPair = await this.issueTokenPair(user, deviceInfo, tokenRecord.familyId);
      this.logger.debug("Refreshed token pair", { userId: payload.userId });
      return newTokenPair;
    } catch (error) {
      this.logger.error("Token refresh failed", {
        error: error instanceof Error ? error.message : String(error),
      });
      if (verifiedPayload && redisLookupFailed && this.redis.status !== "ready") {
        this.logger.warn("Redis unavailable during token refresh; issuing fallback token pair");
        const fallbackUserId = verifiedPayload.userId;
        if (typeof fallbackUserId === "string") {
          const fallbackUser = await this.usersService
            .findByIdInternal(fallbackUserId)
            .catch(() => null);
          if (fallbackUser) {
            return this.generateTokenPair(
              {
                id: fallbackUser.id,
                email: fallbackUser.email,
                role: fallbackUser.role,
              },
              deviceInfo
            );
          }
        }
      }
      throw new UnauthorizedException("Invalid refresh token");
    }
  }

  /**
   * Revoke a specific refresh token and the family it belongs to.
   * Failures are logged and swallowed (best-effort).
   */
  async revokeRefreshToken(refreshToken: string): Promise<void> {
    try {
      const refreshTokenHash = this.hashToken(refreshToken);
      const tokenKey = `${this.REFRESH_TOKEN_PREFIX}${refreshTokenHash}`;
      const storedToken = await this.redis.get(tokenKey);
      if (storedToken) {
        const tokenRecord = this.parseRefreshTokenRecord(storedToken);
        await this.redis.del(tokenKey);
        if (tokenRecord) {
          await this.redis.del(`${this.REFRESH_TOKEN_FAMILY_PREFIX}${tokenRecord.familyId}`);
        }
        this.logger.debug("Revoked refresh token", { tokenHash: refreshTokenHash.slice(0, 8) });
      }
    } catch (error) {
      this.logger.error("Failed to revoke refresh token", {
        error: error instanceof Error ? error.message : String(error),
      });
    }
  }

  /**
   * Revoke all refresh token families (and their current tokens) owned by a user.
   *
   * Uses incremental SCAN rather than KEYS so Redis is not blocked while the keyspace
   * is walked. SCAN may return a key more than once; deleting twice is harmless.
   * Failures are logged and swallowed (best-effort).
   */
  async revokeAllUserTokens(userId: string): Promise<void> {
    try {
      const pattern = `${this.REFRESH_TOKEN_FAMILY_PREFIX}*`;
      let cursor = "0";
      do {
        const [nextCursor, keys] = await this.redis.scan(
          cursor,
          "MATCH",
          pattern,
          "COUNT",
          this.SCAN_BATCH_SIZE
        );
        cursor = nextCursor;
        for (const key of keys) {
          const data = await this.redis.get(key);
          if (!data) continue;
          const family = this.parseRefreshTokenFamilyRecord(data);
          if (family && family.userId === userId) {
            await this.redis.del(key);
            await this.redis.del(`${this.REFRESH_TOKEN_PREFIX}${family.tokenHash}`);
          }
        }
      } while (cursor !== "0");
      this.logger.debug("Revoked all tokens for user", { userId });
    } catch (error) {
      this.logger.error("Failed to revoke all user tokens", {
        error: error instanceof Error ? error.message : String(error),
      });
    }
  }

  /**
   * Sign an access/refresh pair and record the refresh token under the given family.
   *
   * When Redis is not ready, or a write fails, the tokens are still returned but are not
   * stored. Such a refresh token will later be rejected by the Redis lookup.
   */
  private async issueTokenPair(
    user: { id: string; email: string; role?: string },
    deviceInfo: { deviceId?: string; userAgent?: string } | undefined,
    familyId: string
  ): Promise<AuthTokens> {
    const tokenId = this.generateTokenId();
    const accessPayload = {
      sub: user.id,
      email: user.email,
      role: user.role || "user",
      tokenId,
      type: "access",
    };
    const refreshPayload: RefreshTokenPayload = {
      userId: user.id,
      tokenId: familyId,
      deviceId: deviceInfo?.deviceId,
      userAgent: deviceInfo?.userAgent,
      type: "refresh",
    };
    const accessToken = this.jwtService.sign(accessPayload, {
      expiresIn: this.ACCESS_TOKEN_EXPIRY,
    });
    // Within one family the payload is constant and `iat` has 1-second resolution. A unique
    // jti keeps two rotations in the same second from producing byte-identical tokens,
    // which would share a Redis key.
    const refreshToken = this.jwtService.sign(refreshPayload, {
      expiresIn: this.REFRESH_TOKEN_EXPIRY,
      jwtid: tokenId,
    });

    const refreshTokenHash = this.hashToken(refreshToken);
    const refreshExpirySeconds = this.parseExpiryToSeconds(this.REFRESH_TOKEN_EXPIRY);
    if (this.redis.status === "ready") {
      try {
        const familyRecord: StoredRefreshTokenFamily = {
          userId: user.id,
          tokenHash: refreshTokenHash,
          deviceId: deviceInfo?.deviceId,
          userAgent: deviceInfo?.userAgent,
          createdAt: new Date().toISOString(),
        };
        await this.redis.setex(
          `${this.REFRESH_TOKEN_FAMILY_PREFIX}${familyId}`,
          refreshExpirySeconds,
          JSON.stringify(familyRecord)
        );
        const tokenRecord: StoredRefreshToken = {
          familyId,
          userId: user.id,
          valid: true,
        };
        await this.redis.setex(
          `${this.REFRESH_TOKEN_PREFIX}${refreshTokenHash}`,
          refreshExpirySeconds,
          JSON.stringify(tokenRecord)
        );
      } catch (error) {
        this.logger.error("Failed to store refresh token in Redis", {
          error: error instanceof Error ? error.message : String(error),
          userId: user.id,
        });
      }
    } else {
      this.logger.warn("Redis not ready for token issuance; issuing non-rotating tokens", {
        status: this.redis.status,
      });
    }

    this.logger.debug("Generated new token pair", { userId: user.id, tokenId, familyId });
    return {
      accessToken,
      refreshToken,
      expiresAt: this.calculateExpiryDate(this.ACCESS_TOKEN_EXPIRY),
      refreshExpiresAt: this.calculateExpiryDate(this.REFRESH_TOKEN_EXPIRY),
      tokenType: "Bearer",
    };
  }

  /**
   * Delete a family record and its current refresh token (used on detected token reuse).
   * Failures are logged and swallowed.
   */
  private async invalidateTokenFamily(familyId: string): Promise<void> {
    try {
      const familyKey = `${this.REFRESH_TOKEN_FAMILY_PREFIX}${familyId}`;
      const familyData = await this.redis.get(familyKey);
      if (familyData) {
        const family = this.parseRefreshTokenFamilyRecord(familyData);
        await this.redis.del(familyKey);
        if (family) {
          await this.redis.del(`${this.REFRESH_TOKEN_PREFIX}${family.tokenHash}`);
          this.logger.warn("Invalidated token family due to security concern", {
            familyId: familyId.slice(0, 8),
            userId: family.userId,
          });
        }
      }
    } catch (error) {
      this.logger.error("Failed to invalidate token family", {
        error: error instanceof Error ? error.message : String(error),
      });
    }
  }

  /** 256 bits of randomness, hex-encoded (64 chars). */
  private generateTokenId(): string {
    return randomBytes(32).toString("hex");
  }

  /** SHA-256 hex digest; raw tokens are never used as Redis keys. */
  private hashToken(token: string): string {
    return createHash("sha256").update(token).digest("hex");
  }

  /** Parse and shape-check a stored token record; returns null when malformed. */
  private parseRefreshTokenRecord(value: string): StoredRefreshToken | null {
    try {
      const parsed = JSON.parse(value) as Partial<StoredRefreshToken>;
      if (
        parsed &&
        typeof parsed === "object" &&
        typeof parsed.familyId === "string" &&
        typeof parsed.userId === "string" &&
        typeof parsed.valid === "boolean"
      ) {
        return {
          familyId: parsed.familyId,
          userId: parsed.userId,
          valid: parsed.valid,
        };
      }
    } catch (error) {
      this.logger.warn("Failed to parse refresh token record", {
        error: error instanceof Error ? error.message : String(error),
      });
    }
    return null;
  }

  /** Parse and shape-check a stored family record; returns null when malformed. */
  private parseRefreshTokenFamilyRecord(value: string): StoredRefreshTokenFamily | null {
    try {
      const parsed = JSON.parse(value) as Partial<StoredRefreshTokenFamily>;
      if (
        parsed &&
        typeof parsed === "object" &&
        typeof parsed.userId === "string" &&
        typeof parsed.tokenHash === "string"
      ) {
        return {
          userId: parsed.userId,
          tokenHash: parsed.tokenHash,
          deviceId: typeof parsed.deviceId === "string" ? parsed.deviceId : undefined,
          userAgent: typeof parsed.userAgent === "string" ? parsed.userAgent : undefined,
          createdAt: typeof parsed.createdAt === "string" ? parsed.createdAt : undefined,
        };
      }
    } catch (error) {
      this.logger.warn("Failed to parse refresh token family record", {
        error: error instanceof Error ? error.message : String(error),
      });
    }
    return null;
  }

  /**
   * Convert a duration such as "30s", "15m", "2h" or "7d" to milliseconds.
   * Unknown units or a non-numeric amount fall back to 15 minutes instead of NaN.
   * (NaN would make `new Date(...).toISOString()` throw.)
   */
  private parseExpiryToMs(expiry: string): number {
    const defaultMs = 15 * 60 * 1000; // Default 15 minutes
    const unit = expiry.slice(-1);
    const value = Number.parseInt(expiry.slice(0, -1), 10);
    if (!Number.isFinite(value)) {
      return defaultMs;
    }
    switch (unit) {
      case "s":
        return value * 1000;
      case "m":
        return value * 60 * 1000;
      case "h":
        return value * 60 * 60 * 1000;
      case "d":
        return value * 24 * 60 * 60 * 1000;
      default:
        return defaultMs;
    }
  }

  /** Same as `parseExpiryToMs`, truncated to whole seconds (for Redis TTLs). */
  private parseExpiryToSeconds(expiry: string): number {
    return Math.floor(this.parseExpiryToMs(expiry) / 1000);
  }

  /**
   * ISO-8601 timestamp `expiresIn` from now.
   * A number is interpreted as seconds; a string uses `parseExpiryToMs`.
   */
  private calculateExpiryDate(expiresIn: string | number): string {
    const now = new Date();
    if (typeof expiresIn === "number") {
      return new Date(now.getTime() + expiresIn * 1000).toISOString();
    }
    return new Date(now.getTime() + this.parseExpiryToMs(expiresIn)).toISOString();
  }
}