292 lines
9.6 KiB
TypeScript
Raw Normal View History

import { Injectable, Inject, UnauthorizedException } from "@nestjs/common";
import { JwtService } from "@nestjs/jwt";
import { ConfigService } from "@nestjs/config";
import { Redis } from "ioredis";
import { Logger } from "nestjs-pino";
import { randomBytes, createHash } from "crypto";
import type { AuthTokens } from "@customer-portal/domain";
import { UsersService } from "@bff/modules/users/users.service";
/**
 * Claims embedded in a signed refresh token.
 *
 * Note: despite its name, `tokenId` carries the refresh-token *family* id (the suffix of the
 * `refresh_family:` Redis key), not a per-token identifier.
 */
export interface RefreshTokenPayload {
  /** Family id of the session this refresh token belongs to (see note above). */
  tokenId: string;
  /** Id of the user the token was issued to. */
  userId: string;
  /** Optional User-Agent string supplied by the caller at issuance. */
  userAgent?: string;
  /** Optional client device identifier supplied by the caller at issuance. */
  deviceId?: string;
}
/** Shape of a `refresh_token:<sha256(token)>` record in Redis. */
interface StoredRefreshToken {
  familyId: string;
  userId: string;
  /** `false` once the token has been rotated; presenting it again is treated as reuse. */
  valid: boolean;
}

/** Shape of a `refresh_family:<familyId>` record in Redis. */
interface StoredTokenFamily {
  userId: string;
  /** SHA-256 hash of the family's most recently issued refresh token. */
  tokenHash: string;
  deviceId?: string;
  userAgent?: string;
  /** When this record was last written, i.e. when the family's current token was issued. */
  createdAt: string;
}

/** Minimal user fields needed to mint a token pair. */
interface TokenUser {
  id: string;
  email: string;
  role?: string;
}

/** Optional client metadata recorded alongside a refresh token. */
interface DeviceInfo {
  deviceId?: string;
  userAgent?: string;
}

/** Refresh-token claims as decoded by `JwtService.verify` (includes the `exp` set via `expiresIn`). */
type VerifiedRefreshPayload = RefreshTokenPayload & { exp?: number };

/**
 * Issues, rotates and revokes JWT access/refresh token pairs.
 *
 * Refresh tokens are tracked in Redis with two kinds of keys:
 * - `refresh_family:<familyId>`: one record per login session ("family"), pointing at the
 *   hash of that family's current refresh token.
 * - `refresh_token:<sha256(token)>`: one record per issued refresh token with a `valid` flag.
 *   Rotated tokens are kept with `valid: false` so that replaying an already-used token is
 *   detected and the whole family is revoked.
 */
@Injectable()
export class AuthTokenService {
  private readonly ACCESS_TOKEN_EXPIRY = "15m"; // Short-lived access tokens
  private readonly REFRESH_TOKEN_EXPIRY = "7d"; // Longer-lived refresh tokens
  private readonly REFRESH_TOKEN_FAMILY_PREFIX = "refresh_family:";
  private readonly REFRESH_TOKEN_PREFIX = "refresh_token:";
  /** COUNT hint passed to Redis SCAN when walking family keys. */
  private readonly SCAN_BATCH_SIZE = 100;

  constructor(
    private readonly jwtService: JwtService,
    private readonly configService: ConfigService,
    @Inject("REDIS_CLIENT") private readonly redis: Redis,
    @Inject(Logger) private readonly logger: Logger,
    private readonly usersService: UsersService
  ) {}

  /**
   * Generate a new token pair for a fresh login session.
   *
   * Starts a new refresh-token family. Later calls to `refreshTokens` rotate tokens within
   * that family.
   *
   * @param user - User the tokens are issued for; `role` defaults to `"user"`.
   * @param deviceInfo - Optional client metadata stored with the refresh token.
   * @returns The access/refresh token pair and their expiry timestamps (ISO strings).
   */
  async generateTokenPair(
    user: { id: string; email: string; role?: string },
    deviceInfo?: { deviceId?: string; userAgent?: string }
  ): Promise<AuthTokens> {
    return this.issueTokenPair(user, this.generateTokenId(), deviceInfo);
  }

  /**
   * Exchange a refresh token for a new token pair (refresh token rotation).
   *
   * The presented token is marked as used, and the new refresh token joins the same family.
   * Presenting a token that was already rotated revokes the entire family (reuse detection).
   *
   * NOTE(review): the read-check-mark sequence below is not atomic. Two concurrent refreshes
   * with the same token can both succeed.
   *
   * @throws UnauthorizedException for every failure. The message is "Token validation
   *   temporarily unavailable" when Redis cannot be read, "User not found" when the user no
   *   longer exists, and "Invalid refresh token" otherwise.
   */
  async refreshTokens(
    refreshToken: string,
    deviceInfo?: { deviceId?: string; userAgent?: string }
  ): Promise<AuthTokens> {
    try {
      // Throws on bad signature / expiry; handled by the catch below.
      const payload = this.jwtService.verify(refreshToken) as VerifiedRefreshPayload;
      const refreshTokenHash = this.hashToken(refreshToken);
      const tokenKey = `${this.REFRESH_TOKEN_PREFIX}${refreshTokenHash}`;
      const logHash = refreshTokenHash.slice(0, 8);

      let storedToken: string | null;
      try {
        storedToken = await this.redis.get(tokenKey);
      } catch (error) {
        this.logger.error("Redis error during token refresh", {
          error: this.describeError(error),
        });
        throw new UnauthorizedException("Token validation temporarily unavailable");
      }

      if (!storedToken) {
        this.logger.warn("Refresh token not found or expired", { tokenHash: logHash });
        throw new UnauthorizedException("Invalid refresh token");
      }

      const tokenData = JSON.parse(storedToken) as StoredRefreshToken;
      if (!tokenData.valid) {
        this.logger.warn("Refresh token marked as invalid", { tokenHash: logHash });
        // An already-rotated token was replayed: revoke the whole session.
        await this.invalidateTokenFamily(tokenData.familyId);
        throw new UnauthorizedException("Invalid refresh token");
      }

      // Get user info from database (using internal method to get role)
      const prismaUser = await this.usersService.findByIdInternal(payload.userId);
      if (!prismaUser) {
        this.logger.warn("User not found during token refresh", { userId: payload.userId });
        throw new UnauthorizedException("User not found");
      }
      const user = {
        id: prismaUser.id,
        email: prismaUser.email,
        role: prismaUser.role.toLowerCase(), // Convert UserRole enum to lowercase string
      };

      // Keep the presented token as a "used" marker rather than deleting it. A later replay
      // then hits the `!valid` branch above instead of looking like an unknown token.
      await this.redis.setex(
        tokenKey,
        this.secondsUntil(payload.exp),
        JSON.stringify({ ...tokenData, valid: false })
      );

      const newTokenPair = await this.issueTokenPair(user, tokenData.familyId, deviceInfo);
      this.logger.debug("Refreshed token pair", { userId: payload.userId });
      return newTokenPair;
    } catch (error) {
      // Deliberate rejections above already logged their reason; keep their message.
      if (error instanceof UnauthorizedException) {
        throw error;
      }
      this.logger.error("Token refresh failed", {
        error: this.describeError(error),
      });
      throw new UnauthorizedException("Invalid refresh token");
    }
  }

  /**
   * Revoke a refresh token and the session (family) it belongs to.
   *
   * Best-effort: errors are logged, not thrown.
   */
  async revokeRefreshToken(refreshToken: string): Promise<void> {
    try {
      const refreshTokenHash = this.hashToken(refreshToken);
      const tokenKey = `${this.REFRESH_TOKEN_PREFIX}${refreshTokenHash}`;
      const storedToken = await this.redis.get(tokenKey);
      if (!storedToken) {
        return;
      }
      const tokenData = JSON.parse(storedToken) as StoredRefreshToken;
      await this.redis.del(tokenKey);
      // The presented token may already have been rotated, in which case the session's live
      // token is a different key. Deleting the family also deletes that current token.
      await this.deleteTokenFamily(tokenData.familyId);
      this.logger.debug("Revoked refresh token", { tokenHash: refreshTokenHash.slice(0, 8) });
    } catch (error) {
      this.logger.error("Failed to revoke refresh token", {
        error: this.describeError(error),
      });
    }
  }

  /**
   * Revoke every refresh-token family (session) belonging to `userId`.
   *
   * Walks family keys with cursor-based SCAN rather than KEYS, so Redis is not blocked on
   * large keyspaces. Best-effort: errors are logged, not thrown.
   */
  async revokeAllUserTokens(userId: string): Promise<void> {
    try {
      const pattern = `${this.REFRESH_TOKEN_FAMILY_PREFIX}*`;
      let cursor = "0";
      do {
        const [nextCursor, familyKeys] = await this.redis.scan(
          cursor,
          "MATCH",
          pattern,
          "COUNT",
          this.SCAN_BATCH_SIZE
        );
        cursor = nextCursor;
        if (familyKeys.length === 0) {
          continue;
        }
        const records = await this.redis.mget(...familyKeys);
        for (const [index, familyKey] of familyKeys.entries()) {
          const raw = records[index];
          if (!raw) {
            continue; // expired/deleted between SCAN and MGET
          }
          const family = JSON.parse(raw) as StoredTokenFamily;
          if (family.userId === userId) {
            await this.redis.del(familyKey, `${this.REFRESH_TOKEN_PREFIX}${family.tokenHash}`);
          }
        }
      } while (cursor !== "0");
      this.logger.debug("Revoked all tokens for user", { userId });
    } catch (error) {
      this.logger.error("Failed to revoke all user tokens", {
        error: this.describeError(error),
      });
    }
  }

  /**
   * Mint an access/refresh pair and record the refresh token in Redis as the current token
   * of `familyId`.
   *
   * If Redis storage fails, the pair is still returned. The access token stays usable until
   * it expires, but `refreshTokens` will reject this refresh token because it has no record.
   */
  private async issueTokenPair(
    user: TokenUser,
    familyId: string,
    deviceInfo?: DeviceInfo
  ): Promise<AuthTokens> {
    const tokenId = this.generateTokenId();

    const accessPayload = {
      sub: user.id,
      email: user.email,
      role: user.role || "user",
      tokenId,
      type: "access",
    };
    const refreshPayload: RefreshTokenPayload = {
      userId: user.id,
      tokenId: familyId,
      deviceId: deviceInfo?.deviceId,
      userAgent: deviceInfo?.userAgent,
    };

    const accessToken = this.jwtService.sign(accessPayload, {
      expiresIn: this.ACCESS_TOKEN_EXPIRY,
    });
    // A unique `jti` per refresh token. Without it, two tokens minted for the same family in
    // the same second (`iat` has 1s resolution) would be identical and share a Redis key.
    const refreshToken = this.jwtService.sign(refreshPayload, {
      expiresIn: this.REFRESH_TOKEN_EXPIRY,
      jwtid: this.generateTokenId(),
    });

    const refreshTokenHash = this.hashToken(refreshToken);
    const refreshExpirySeconds = this.parseExpiryToSeconds(this.REFRESH_TOKEN_EXPIRY);
    const family: StoredTokenFamily = {
      userId: user.id,
      tokenHash: refreshTokenHash,
      deviceId: deviceInfo?.deviceId,
      userAgent: deviceInfo?.userAgent,
      createdAt: new Date().toISOString(),
    };
    const tokenRecord: StoredRefreshToken = {
      familyId,
      userId: user.id,
      valid: true,
    };

    try {
      await this.redis.setex(
        `${this.REFRESH_TOKEN_FAMILY_PREFIX}${familyId}`,
        refreshExpirySeconds,
        JSON.stringify(family)
      );
      await this.redis.setex(
        `${this.REFRESH_TOKEN_PREFIX}${refreshTokenHash}`,
        refreshExpirySeconds,
        JSON.stringify(tokenRecord)
      );
    } catch (error) {
      this.logger.error("Failed to store refresh token in Redis", {
        error: this.describeError(error),
        userId: user.id,
      });
      // Best-effort (see method doc): the login/refresh still succeeds, but this refresh
      // token will be rejected later because it has no Redis record.
    }

    this.logger.debug("Generated new token pair", { userId: user.id, tokenId, familyId });
    return {
      accessToken,
      refreshToken,
      expiresAt: this.calculateExpiryDate(this.ACCESS_TOKEN_EXPIRY),
      refreshExpiresAt: this.calculateExpiryDate(this.REFRESH_TOKEN_EXPIRY),
      tokenType: "Bearer",
    };
  }

  /**
   * Revoke a family after suspected refresh-token reuse.
   *
   * Deletes the family record and its current token. Errors are logged, not thrown.
   */
  private async invalidateTokenFamily(familyId: string): Promise<void> {
    try {
      const family = await this.deleteTokenFamily(familyId);
      if (family) {
        this.logger.warn("Invalidated token family due to security concern", {
          familyId: familyId.slice(0, 8),
          userId: family.userId,
        });
      }
    } catch (error) {
      this.logger.error("Failed to invalidate token family", {
        error: this.describeError(error),
      });
    }
  }

  /**
   * Delete a family record and the refresh token it currently points at.
   *
   * @returns The deleted family record, or `null` if it did not exist. Redis errors propagate.
   */
  private async deleteTokenFamily(familyId: string): Promise<StoredTokenFamily | null> {
    const familyKey = `${this.REFRESH_TOKEN_FAMILY_PREFIX}${familyId}`;
    const familyData = await this.redis.get(familyKey);
    if (!familyData) {
      return null;
    }
    const family = JSON.parse(familyData) as StoredTokenFamily;
    await this.redis.del(familyKey, `${this.REFRESH_TOKEN_PREFIX}${family.tokenHash}`);
    return family;
  }

  /**
   * Seconds from now until a JWT `exp` (epoch seconds), clamped to at least 1 so it is a
   * valid SETEX TTL. Falls back to the full refresh lifetime when `exp` is absent.
   */
  private secondsUntil(exp: number | undefined): number {
    if (typeof exp !== "number") {
      return this.parseExpiryToSeconds(this.REFRESH_TOKEN_EXPIRY);
    }
    return Math.max(1, exp - Math.floor(Date.now() / 1000));
  }

  /** Render an unknown thrown value as a log-friendly message. */
  private describeError(error: unknown): string {
    return error instanceof Error ? error.message : String(error);
  }

  /** 32 random bytes, hex-encoded (64 characters). */
  private generateTokenId(): string {
    return randomBytes(32).toString("hex");
  }

  /** Hex SHA-256 of a token; used as the Redis key suffix so raw tokens are never stored. */
  private hashToken(token: string): string {
    return createHash("sha256").update(token).digest("hex");
  }

  /**
   * Convert an expiry string such as "15m" or "7d" (units s/m/h/d) to milliseconds.
   * Unknown units or a non-numeric amount fall back to 15 minutes.
   */
  private parseExpiryToMs(expiry: string): number {
    const defaultMs = 15 * 60 * 1000; // Default 15 minutes
    const unit = expiry.slice(-1);
    const value = Number.parseInt(expiry.slice(0, -1), 10);
    if (!Number.isFinite(value)) {
      return defaultMs;
    }
    switch (unit) {
      case "s": return value * 1000;
      case "m": return value * 60 * 1000;
      case "h": return value * 60 * 60 * 1000;
      case "d": return value * 24 * 60 * 60 * 1000;
      default: return defaultMs;
    }
  }

  /** Same as `parseExpiryToMs`, in whole seconds (rounded down). */
  private parseExpiryToSeconds(expiry: string): number {
    return Math.floor(this.parseExpiryToMs(expiry) / 1000);
  }

  /**
   * ISO timestamp `expiresIn` from now. A number is interpreted as seconds; a string is
   * parsed with `parseExpiryToMs`.
   */
  private calculateExpiryDate(expiresIn: string | number): string {
    const nowMs = Date.now();
    const offsetMs =
      typeof expiresIn === "number" ? expiresIn * 1000 : this.parseExpiryToMs(expiresIn);
    return new Date(nowMs + offsetMs).toISOString();
  }
}