429 lines
13 KiB
TypeScript
429 lines
13 KiB
TypeScript
import { Injectable, Inject, UnauthorizedException } from "@nestjs/common";
|
|
import { JwtService } from "@nestjs/jwt";
|
|
import { ConfigService } from "@nestjs/config";
|
|
import { Redis } from "ioredis";
|
|
import { Logger } from "nestjs-pino";
|
|
import { randomBytes, createHash } from "crypto";
|
|
import type { AuthTokens } from "@customer-portal/domain";
|
|
import { UsersService } from "@bff/modules/users/users.service";
|
|
|
|
/**
 * Claims signed into every refresh token issued by AuthTokenService.
 */
export interface RefreshTokenPayload {
  /** Id of the user the token was issued to. */
  userId: string;
  /** Identifier carried in the token; AuthTokenService populates it with the token family id. */
  tokenId: string;
  /** Optional client device identifier supplied when the token was issued. */
  deviceId?: string;
  /** Optional client user-agent string supplied when the token was issued. */
  userAgent?: string;
  /** Discriminator that separates refresh tokens from access tokens. */
  type: "refresh";
}
|
|
|
|
/**
 * JSON value stored in Redis under `refresh_token:<sha256 hex of the refresh token>`.
 */
interface StoredRefreshToken {
  /** Family the token belongs to; keys the matching `refresh_family:` record. */
  familyId: string;
  /** Owner of the token. */
  userId: string;
  /** Whether the token may still be exchanged; `false` is treated as a reuse attempt. */
  valid: boolean;
}
|
|
|
|
/**
 * JSON value stored in Redis under `refresh_family:<familyId>`.
 */
interface StoredRefreshTokenFamily {
  /** Owner of the family. */
  userId: string;
  /** SHA-256 hex hash of the refresh token recorded for this family. */
  tokenHash: string;
  /** Optional client device identifier captured at issuance. */
  deviceId?: string;
  /** Optional client user-agent string captured at issuance. */
  userAgent?: string;
  /** ISO-8601 timestamp written when the record was created. */
  createdAt?: string;
}
|
|
|
|
@Injectable()
export class AuthTokenService {
  private readonly ACCESS_TOKEN_EXPIRY = "15m"; // Short-lived access tokens
  private readonly REFRESH_TOKEN_EXPIRY = "7d"; // Longer-lived refresh tokens
  private readonly REFRESH_TOKEN_FAMILY_PREFIX = "refresh_family:";
  private readonly REFRESH_TOKEN_PREFIX = "refresh_token:";
  // Keys requested per SCAN round-trip when revoking all of a user's tokens.
  private readonly REVOKE_SCAN_BATCH_SIZE = 100;

  constructor(
    private readonly jwtService: JwtService,
    private readonly configService: ConfigService,
    @Inject("REDIS_CLIENT") private readonly redis: Redis,
    @Inject(Logger) private readonly logger: Logger,
    private readonly usersService: UsersService
  ) {}

  /**
   * Generate a new token pair with refresh token rotation.
   *
   * Signs a short-lived access token and a longer-lived refresh token. When Redis is ready,
   * it also records the refresh token's SHA-256 hash so `refreshTokens` can validate and
   * rotate it later. Every call starts a new token family.
   *
   * If Redis is not ready, or writing to it fails, the problem is logged and the tokens are
   * still returned. Such a refresh token has no Redis record, so `refreshTokens` will reject
   * it while Redis is healthy.
   *
   * @param user - Identity to embed; `role` falls back to `"user"` when missing or empty.
   * @param deviceInfo - Optional client metadata copied into the refresh token and family record.
   * @returns Both tokens, their ISO-8601 expiry timestamps and the `"Bearer"` token type.
   */
  async generateTokenPair(
    user: {
      id: string;
      email: string;
      role?: string;
    },
    deviceInfo?: {
      deviceId?: string;
      userAgent?: string;
    }
  ): Promise<AuthTokens> {
    const tokenId = this.generateTokenId();
    const familyId = this.generateTokenId();

    const accessPayload = {
      sub: user.id,
      email: user.email,
      role: user.role || "user",
      tokenId,
      type: "access",
    };

    const refreshPayload: RefreshTokenPayload = {
      userId: user.id,
      tokenId: familyId,
      deviceId: deviceInfo?.deviceId,
      userAgent: deviceInfo?.userAgent,
      type: "refresh",
    };

    const accessToken = this.jwtService.sign(accessPayload, {
      expiresIn: this.ACCESS_TOKEN_EXPIRY,
    });
    const refreshToken = this.jwtService.sign(refreshPayload, {
      expiresIn: this.REFRESH_TOKEN_EXPIRY,
    });

    // Compute the advertised expiries right after signing, before any Redis round-trips, so
    // they do not drift later than the `exp` claims embedded in the tokens.
    const accessExpiresAt = this.calculateExpiryDate(this.ACCESS_TOKEN_EXPIRY);
    const refreshExpiresAt = this.calculateExpiryDate(this.REFRESH_TOKEN_EXPIRY);

    await this.storeRefreshToken(refreshToken, familyId, user.id, deviceInfo);

    this.logger.debug("Generated new token pair", { userId: user.id, tokenId, familyId });

    return {
      accessToken,
      refreshToken,
      expiresAt: accessExpiresAt,
      refreshExpiresAt,
      tokenType: "Bearer",
    };
  }

  /**
   * Refresh access token using refresh token rotation.
   *
   * The presented token must verify as a JWT with `type: "refresh"` and have a valid Redis
   * record. The record is consumed exactly once: if several requests present the same token
   * concurrently, only one succeeds. A brand-new pair is then issued for the user's current
   * email and role.
   *
   * If anything fails while Redis is not ready, a degraded fallback still issues a pair,
   * provided the token's signature, expiry and `type` verify. Revocation cannot be checked
   * in that mode.
   *
   * @param refreshToken - Raw refresh token string from the client.
   * @param deviceInfo - Optional client metadata for the newly issued refresh token.
   * @throws UnauthorizedException whenever no token pair can be issued.
   */
  async refreshTokens(
    refreshToken: string,
    deviceInfo?: {
      deviceId?: string;
      userAgent?: string;
    }
  ): Promise<AuthTokens> {
    try {
      const payload = this.jwtService.verify<RefreshTokenPayload>(refreshToken);

      if (payload.type !== "refresh") {
        this.logger.warn("Token presented to refresh endpoint is not a refresh token", {
          tokenId: payload.tokenId,
        });
        throw new UnauthorizedException("Invalid refresh token");
      }

      const refreshTokenHash = this.hashToken(refreshToken);
      const tokenKey = `${this.REFRESH_TOKEN_PREFIX}${refreshTokenHash}`;

      let storedToken: string | null;
      try {
        storedToken = await this.redis.get(tokenKey);
      } catch (error) {
        this.logger.error("Redis error during token refresh", {
          error: error instanceof Error ? error.message : String(error),
        });
        throw new UnauthorizedException("Token validation temporarily unavailable");
      }

      if (!storedToken) {
        this.logger.warn("Refresh token not found or expired", {
          tokenHash: refreshTokenHash.slice(0, 8),
        });
        throw new UnauthorizedException("Invalid refresh token");
      }

      const tokenRecord = this.parseRefreshTokenRecord(storedToken);
      if (!tokenRecord) {
        this.logger.warn("Stored refresh token payload was invalid JSON", {
          tokenHash: refreshTokenHash.slice(0, 8),
        });
        await this.redis.del(tokenKey);
        throw new UnauthorizedException("Invalid refresh token");
      }

      // NOTE(review): nothing in this service writes `valid: false` (rotation deletes the record
      // instead), so this branch only fires if another writer marks records invalid.
      if (!tokenRecord.valid) {
        this.logger.warn("Refresh token marked as invalid", {
          tokenHash: refreshTokenHash.slice(0, 8),
        });
        // Invalidate entire token family on reuse attempt
        await this.invalidateTokenFamily(tokenRecord.familyId);
        throw new UnauthorizedException("Invalid refresh token");
      }

      // Get user info from database (using internal method to get role)
      const prismaUser = await this.usersService.findByIdInternal(payload.userId);
      if (!prismaUser) {
        this.logger.warn("User not found during token refresh", { userId: payload.userId });
        throw new UnauthorizedException("User not found");
      }

      const user = {
        id: prismaUser.id,
        email: prismaUser.email,
        role: prismaUser.role,
      };

      // DEL returns the number of keys it removed, and Redis runs commands one at a time.
      // Of several concurrent requests that all read the record above, only one can see 1 here.
      // The others must not mint a second pair from the same token.
      const consumed = await this.redis.del(tokenKey);
      if (consumed === 0) {
        this.logger.warn("Refresh token already consumed by a concurrent request", {
          tokenHash: refreshTokenHash.slice(0, 8),
        });
        throw new UnauthorizedException("Invalid refresh token");
      }

      // Every pair gets its own family (generateTokenPair creates a fresh familyId), so the
      // consumed token's family record now points at a deleted token. Drop it rather than let
      // it linger for the full refresh TTL; best-effort, so a failure here does not block rotation.
      try {
        await this.redis.del(`${this.REFRESH_TOKEN_FAMILY_PREFIX}${tokenRecord.familyId}`);
      } catch (error) {
        this.logger.warn("Failed to delete rotated refresh token family", {
          error: error instanceof Error ? error.message : String(error),
        });
      }

      const newTokenPair = await this.generateTokenPair(user, deviceInfo);

      this.logger.debug("Refreshed token pair", { userId: payload.userId });

      return newTokenPair;
    } catch (error) {
      this.logger.error("Token refresh failed", {
        error: error instanceof Error ? error.message : String(error),
      });

      if (this.redis.status !== "ready") {
        const fallbackPair = await this.issueFallbackTokenPair(refreshToken, deviceInfo);
        if (fallbackPair) {
          return fallbackPair;
        }
      }

      throw new UnauthorizedException("Invalid refresh token");
    }
  }

  /**
   * Revoke a specific refresh token.
   *
   * Deletes the token's Redis record and, when the record parses, its family record.
   * Best-effort: errors are logged, not thrown.
   */
  async revokeRefreshToken(refreshToken: string): Promise<void> {
    try {
      const refreshTokenHash = this.hashToken(refreshToken);
      const tokenKey = `${this.REFRESH_TOKEN_PREFIX}${refreshTokenHash}`;
      const storedToken = await this.redis.get(tokenKey);

      if (storedToken) {
        const tokenRecord = this.parseRefreshTokenRecord(storedToken);
        await this.redis.del(tokenKey);

        if (tokenRecord) {
          await this.redis.del(`${this.REFRESH_TOKEN_FAMILY_PREFIX}${tokenRecord.familyId}`);
        }

        this.logger.debug("Revoked refresh token", { tokenHash: refreshTokenHash.slice(0, 8) });
      }
    } catch (error) {
      this.logger.error("Failed to revoke refresh token", {
        error: error instanceof Error ? error.message : String(error),
      });
    }
  }

  /**
   * Revoke all refresh tokens for a user.
   *
   * Walks every `refresh_family:` record with incremental SCAN rather than KEYS, which would
   * block the Redis server for the whole keyspace scan. For each family owned by `userId`,
   * it deletes the family record and the refresh token record it references.
   * Best-effort: errors are logged, not thrown.
   */
  async revokeAllUserTokens(userId: string): Promise<void> {
    try {
      const pattern = `${this.REFRESH_TOKEN_FAMILY_PREFIX}*`;
      let cursor = "0";

      do {
        const [nextCursor, keys] = await this.redis.scan(
          cursor,
          "MATCH",
          pattern,
          "COUNT",
          this.REVOKE_SCAN_BATCH_SIZE
        );
        cursor = nextCursor;

        // SCAN may return a key more than once; the deletes below are idempotent.
        for (const key of keys) {
          const data = await this.redis.get(key);
          if (!data) {
            continue;
          }

          const family = this.parseRefreshTokenFamilyRecord(data);
          if (family && family.userId === userId) {
            await this.redis.del(key);
            await this.redis.del(`${this.REFRESH_TOKEN_PREFIX}${family.tokenHash}`);
          }
        }
      } while (cursor !== "0");

      this.logger.debug("Revoked all tokens for user", { userId });
    } catch (error) {
      this.logger.error("Failed to revoke all user tokens", {
        error: error instanceof Error ? error.message : String(error),
      });
    }
  }

  /**
   * Persist the family record and the per-token record for a freshly signed refresh token.
   *
   * Both writes are queued in one MULTI/EXEC transaction so they are applied together.
   * A dropped connection between them can no longer leave only one key written.
   * Best-effort: failures are logged and never thrown, so token issuance still succeeds.
   */
  private async storeRefreshToken(
    refreshToken: string,
    familyId: string,
    userId: string,
    deviceInfo?: {
      deviceId?: string;
      userAgent?: string;
    }
  ): Promise<void> {
    if (this.redis.status !== "ready") {
      this.logger.warn("Redis not ready for token issuance; issuing non-rotating tokens", {
        status: this.redis.status,
      });
      return;
    }

    const refreshTokenHash = this.hashToken(refreshToken);
    const ttlSeconds = this.parseExpiryToSeconds(this.REFRESH_TOKEN_EXPIRY);

    const familyRecord: StoredRefreshTokenFamily = {
      userId,
      tokenHash: refreshTokenHash,
      deviceId: deviceInfo?.deviceId,
      userAgent: deviceInfo?.userAgent,
      createdAt: new Date().toISOString(),
    };
    const tokenRecord: StoredRefreshToken = {
      familyId,
      userId,
      valid: true,
    };

    try {
      const results = await this.redis
        .multi()
        .setex(
          `${this.REFRESH_TOKEN_FAMILY_PREFIX}${familyId}`,
          ttlSeconds,
          JSON.stringify(familyRecord)
        )
        .setex(
          `${this.REFRESH_TOKEN_PREFIX}${refreshTokenHash}`,
          ttlSeconds,
          JSON.stringify(tokenRecord)
        )
        .exec();

      if (results === null) {
        throw new Error("Redis transaction was aborted");
      }
      // EXEC reports per-command failures in its result array instead of rejecting.
      const failedCommand = results.find(([commandError]) => commandError !== null);
      if (failedCommand) {
        throw failedCommand[0] ?? new Error("Redis command failed");
      }
    } catch (error) {
      this.logger.error("Failed to store refresh token in Redis", {
        error: error instanceof Error ? error.message : String(error),
        userId,
      });
    }
  }

  /**
   * Degraded-mode refresh, used only after a failure while Redis is not ready.
   *
   * The token must still pass full JWT verification (signature and expiry) and carry
   * `type: "refresh"`. Merely decoding it would let anyone forge a `userId` and obtain real
   * tokens while Redis is down. Revocation state cannot be consulted here.
   *
   * @returns A new token pair, or `null` if the token or its user cannot be validated.
   */
  private async issueFallbackTokenPair(
    refreshToken: string,
    deviceInfo?: {
      deviceId?: string;
      userAgent?: string;
    }
  ): Promise<AuthTokens | null> {
    this.logger.warn("Redis unavailable during token refresh; issuing fallback token pair");

    let payload: RefreshTokenPayload;
    try {
      payload = this.jwtService.verify<RefreshTokenPayload>(refreshToken);
    } catch {
      return null;
    }

    if (payload.type !== "refresh" || typeof payload.userId !== "string") {
      return null;
    }

    const fallbackUser = await this.usersService
      .findByIdInternal(payload.userId)
      .catch(() => null);

    if (!fallbackUser) {
      return null;
    }

    return this.generateTokenPair(
      {
        id: fallbackUser.id,
        email: fallbackUser.email,
        role: fallbackUser.role,
      },
      deviceInfo
    );
  }

  /**
   * Delete a family record and the refresh token record it references, logging the
   * invalidation. Best-effort: errors are logged, not thrown.
   */
  private async invalidateTokenFamily(familyId: string): Promise<void> {
    try {
      const familyKey = `${this.REFRESH_TOKEN_FAMILY_PREFIX}${familyId}`;
      const familyData = await this.redis.get(familyKey);
      if (familyData) {
        const family = this.parseRefreshTokenFamilyRecord(familyData);
        await this.redis.del(familyKey);

        if (family) {
          await this.redis.del(`${this.REFRESH_TOKEN_PREFIX}${family.tokenHash}`);

          this.logger.warn("Invalidated token family due to security concern", {
            familyId: familyId.slice(0, 8),
            userId: family.userId,
          });
        }
      }
    } catch (error) {
      this.logger.error("Failed to invalidate token family", {
        error: error instanceof Error ? error.message : String(error),
      });
    }
  }

  /** 32 random bytes as a 64-character hex string. */
  private generateTokenId(): string {
    return randomBytes(32).toString("hex");
  }

  /** SHA-256 hex digest; used so raw refresh tokens are never stored in Redis. */
  private hashToken(token: string): string {
    return createHash("sha256").update(token).digest("hex");
  }

  /**
   * Parse a `refresh_token:` record.
   *
   * @returns The record, or `null` when the JSON is malformed or a field has the wrong type.
   */
  private parseRefreshTokenRecord(value: string): StoredRefreshToken | null {
    try {
      const parsed: unknown = JSON.parse(value);
      if (parsed && typeof parsed === "object") {
        const candidate = parsed as Partial<StoredRefreshToken>;
        if (
          typeof candidate.familyId === "string" &&
          typeof candidate.userId === "string" &&
          typeof candidate.valid === "boolean"
        ) {
          return {
            familyId: candidate.familyId,
            userId: candidate.userId,
            valid: candidate.valid,
          };
        }
      }
    } catch (error) {
      this.logger.warn("Failed to parse refresh token record", {
        error: error instanceof Error ? error.message : String(error),
      });
    }

    return null;
  }

  /**
   * Parse a `refresh_family:` record.
   *
   * @returns The record, or `null` when the JSON is malformed or `userId`/`tokenHash` are
   *   not strings. Optional fields of the wrong type are dropped.
   */
  private parseRefreshTokenFamilyRecord(value: string): StoredRefreshTokenFamily | null {
    try {
      const parsed: unknown = JSON.parse(value);
      if (parsed && typeof parsed === "object") {
        const candidate = parsed as Partial<StoredRefreshTokenFamily>;
        if (typeof candidate.userId === "string" && typeof candidate.tokenHash === "string") {
          return {
            userId: candidate.userId,
            tokenHash: candidate.tokenHash,
            deviceId: typeof candidate.deviceId === "string" ? candidate.deviceId : undefined,
            userAgent: typeof candidate.userAgent === "string" ? candidate.userAgent : undefined,
            createdAt: typeof candidate.createdAt === "string" ? candidate.createdAt : undefined,
          };
        }
      }
    } catch (error) {
      this.logger.warn("Failed to parse refresh token family record", {
        error: error instanceof Error ? error.message : String(error),
      });
    }

    return null;
  }

  /**
   * Convert a duration such as `"15m"` or `"7d"` (an integer followed by s/m/h/d) to
   * milliseconds. An unknown unit or a non-numeric amount falls back to 15 minutes instead
   * of producing NaN, which would later make `toISOString()` throw.
   */
  private parseExpiryToMs(expiry: string): number {
    const defaultMs = 15 * 60 * 1000; // Default 15 minutes
    const unit = expiry.slice(-1);
    const value = Number.parseInt(expiry.slice(0, -1), 10);

    if (Number.isNaN(value)) {
      return defaultMs;
    }

    switch (unit) {
      case "s":
        return value * 1000;
      case "m":
        return value * 60 * 1000;
      case "h":
        return value * 60 * 60 * 1000;
      case "d":
        return value * 24 * 60 * 60 * 1000;
      default:
        return defaultMs;
    }
  }

  /** Same as `parseExpiryToMs`, rounded down to whole seconds (for Redis TTLs). */
  private parseExpiryToSeconds(expiry: string): number {
    return Math.floor(this.parseExpiryToMs(expiry) / 1000);
  }

  /**
   * ISO-8601 timestamp `expiresIn` from now. A number is interpreted as seconds; a string
   * uses the `parseExpiryToMs` format.
   */
  private calculateExpiryDate(expiresIn: string | number): string {
    const now = new Date();
    if (typeof expiresIn === "number") {
      return new Date(now.getTime() + expiresIn * 1000).toISOString();
    }
    return new Date(now.getTime() + this.parseExpiryToMs(expiresIn)).toISOString();
  }
}
|