refactor: simplify OAuth routes, add type safety, deduplicate UI components

- Extract handleOAuthCallback to eliminate GET/POST duplication in oauth.ts
- Add P2002 race condition handling in findOrCreateUser
- Add .unref() to stateStore cleanup timer to not block process exit
- Use Provider union type instead of bare strings throughout OAuth code
- Export API_BASE from api.ts, reuse in OAuthButtons
- Extract MobileBranding component to deduplicate Login/Register mobile brand
- Extract shared Logo component in AuthBranding
- Remove unnecessary WHAT comments

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
2026-04-03 13:25:50 +08:00
parent 0bab0ecb93
commit eacaa5be05
7 changed files with 95 additions and 114 deletions

View File

@@ -14,7 +14,9 @@ type ProviderUser = {
avatarUrl: string | null;
};
const providers: Record<string, ProviderConfig> = {
export type Provider = 'google' | 'github' | 'apple';
const providers: Record<Provider, ProviderConfig> = {
google: {
authUrl: 'https://accounts.google.com/o/oauth2/v2/auth',
tokenUrl: 'https://oauth2.googleapis.com/token',
@@ -35,14 +37,14 @@ const providers: Record<string, ProviderConfig> = {
},
};
function getClientId(provider: string): string {
function getClientId(provider: Provider): string {
const envKey = provider === 'apple' ? 'APPLE_CLIENT_ID' : `${provider.toUpperCase()}_CLIENT_ID`;
const value = process.env[envKey];
if (!value) throw new Error(`Missing env: ${envKey}`);
return value;
}
function getClientSecret(provider: string): string {
function getClientSecret(provider: Provider): string {
if (provider === 'apple') return buildAppleClientSecret();
const envKey = `${provider.toUpperCase()}_CLIENT_SECRET`;
const value = process.env[envKey];
@@ -50,28 +52,28 @@ function getClientSecret(provider: string): string {
return value;
}
function getCallbackUrl(provider: string): string {
function getCallbackUrl(provider: Provider): string {
const base = process.env.OAUTH_CALLBACK_BASE_URL || 'http://localhost:3000';
return `${base}/api/auth/oauth/${provider}/callback`;
}
// --- State management (CSRF protection) ---
const stateStore = new Map<string, { provider: string; createdAt: number }>();
setInterval(() => {
const cleanupTimer = setInterval(() => {
const now = Date.now();
for (const [key, value] of stateStore) {
if (now - value.createdAt > 10 * 60 * 1000) stateStore.delete(key);
}
}, 5 * 60 * 1000);
cleanupTimer.unref();
function generateState(provider: string): string {
function generateState(provider: Provider): string {
const state = crypto.randomBytes(32).toString('hex');
stateStore.set(state, { provider, createdAt: Date.now() });
return state;
}
function validateState(state: string, provider: string): boolean {
function validateState(state: string, provider: Provider): boolean {
const entry = stateStore.get(state);
if (!entry) return false;
if (entry.provider !== provider) return false;
@@ -83,7 +85,6 @@ function validateState(state: string, provider: string): boolean {
return true;
}
// --- Apple client_secret JWT ---
function buildAppleClientSecret(): string {
const teamId = process.env.APPLE_TEAM_ID;
const keyId = process.env.APPLE_KEY_ID;
@@ -105,9 +106,7 @@ function buildAppleClientSecret(): string {
return `${signingInput}.${sig.toString('base64url')}`;
}
// --- Public API ---
export function buildAuthUrl(provider: string): string {
export function buildAuthUrl(provider: Provider): string {
const config = providers[provider];
if (!config) throw new Error(`Unknown provider: ${provider}`);
@@ -127,7 +126,7 @@ export function buildAuthUrl(provider: string): string {
return `${config.authUrl}?${params.toString()}`;
}
export async function exchangeCodeForToken(provider: string, code: string): Promise<string> {
export async function exchangeCodeForToken(provider: Provider, code: string): Promise<string> {
const config = providers[provider];
if (!config) throw new Error(`Unknown provider: ${provider}`);
@@ -163,7 +162,7 @@ export async function exchangeCodeForToken(provider: string, code: string): Prom
return data.access_token as string;
}
export async function fetchProviderUser(provider: string, token: string): Promise<ProviderUser> {
export async function fetchProviderUser(provider: Provider, token: string): Promise<ProviderUser> {
if (provider === 'apple') {
return parseAppleIdToken(token);
}

View File

@@ -1,17 +1,20 @@
import { Router, type Router as RouterType } from 'express';
import { Router, type Router as RouterType, type Response } from 'express';
import { prisma } from '@agent-fox/shared';
import { generateTokenPair } from '../lib/jwt.js';
import { buildAuthUrl, exchangeCodeForToken, fetchProviderUser, validateState } from '../lib/oauth-providers.js';
import { buildAuthUrl, exchangeCodeForToken, fetchProviderUser, validateState, type Provider } from '../lib/oauth-providers.js';
const router: RouterType = Router();
const VALID_PROVIDERS = ['google', 'github', 'apple'];
const VALID_PROVIDERS: Provider[] = ['google', 'github', 'apple'];
const FRONTEND_URL = process.env.FRONTEND_URL || 'http://localhost:5173';
// GET /auth/oauth/:provider — redirect to provider's authorization page
function isValidProvider(value: string): value is Provider {
return (VALID_PROVIDERS as string[]).includes(value);
}
router.get('/:provider', (req, res) => {
const { provider } = req.params;
if (!VALID_PROVIDERS.includes(provider)) {
if (!isValidProvider(provider)) {
res.status(400).json({ success: false, error: { code: 'INVALID_PROVIDER', message: `Unknown provider: ${provider}` } });
return;
}
@@ -24,17 +27,31 @@ router.get('/:provider', (req, res) => {
}
});
// GET /auth/oauth/:provider/callback — handle provider callback
router.get('/:provider/callback', async (req, res) => {
const { provider } = req.params;
const { code, state, error: oauthError } = req.query as Record<string, string>;
const params = req.query as Record<string, string>;
await handleOAuthCallback(provider, params.code, params.state, params.error, res);
});
// Apple sends callback as POST (form_post response mode)
router.post('/:provider/callback', async (req, res) => {
const { provider } = req.params;
await handleOAuthCallback(provider, req.body.code, req.body.state, req.body.error, res);
});
async function handleOAuthCallback(
provider: string,
code: string | undefined,
state: string | undefined,
oauthError: string | undefined,
res: Response,
) {
if (oauthError) {
res.redirect(`${FRONTEND_URL}/login/callback?error=${encodeURIComponent(oauthError)}`);
return;
}
if (!code || !state) {
if (!code || !state || !isValidProvider(provider)) {
res.redirect(`${FRONTEND_URL}/login/callback?error=${encodeURIComponent('Missing code or state')}`);
return;
}
@@ -60,51 +77,12 @@ router.get('/:provider/callback', async (req, res) => {
console.error(`OAuth callback error (${provider}):`, err);
res.redirect(`${FRONTEND_URL}/login/callback?error=${encodeURIComponent('Authentication failed')}`);
}
});
// Apple sends callback as POST (form_post response mode)
router.post('/:provider/callback', async (req, res) => {
const { provider } = req.params;
const { code, state, error: oauthError } = req.body;
if (oauthError) {
res.redirect(`${FRONTEND_URL}/login/callback?error=${encodeURIComponent(oauthError)}`);
return;
}
if (!code || !state) {
res.redirect(`${FRONTEND_URL}/login/callback?error=${encodeURIComponent('Missing code or state')}`);
return;
}
if (!validateState(state, provider)) {
res.redirect(`${FRONTEND_URL}/login/callback?error=${encodeURIComponent('Invalid or expired state')}`);
return;
}
try {
const token = await exchangeCodeForToken(provider, code);
const providerUser = await fetchProviderUser(provider, token);
if (!providerUser.email) {
res.redirect(`${FRONTEND_URL}/login/callback?error=${encodeURIComponent('No email returned from provider')}`);
return;
}
const user = await findOrCreateUser(provider, providerUser);
const tokens = generateTokenPair({ userId: user.id, email: user.email });
res.redirect(`${FRONTEND_URL}/login/callback?accessToken=${tokens.accessToken}&refreshToken=${tokens.refreshToken}`);
} catch (err) {
console.error(`OAuth POST callback error (${provider}):`, err);
res.redirect(`${FRONTEND_URL}/login/callback?error=${encodeURIComponent('Authentication failed')}`);
}
});
}
async function findOrCreateUser(
provider: string,
providerUser: { id: string; email: string; name: string; avatarUrl: string | null },
) {
// 1. Check existing OAuthAccount
const existingOAuth = await prisma.oAuthAccount.findUnique({
where: { provider_providerAccountId: { provider, providerAccountId: providerUser.id } },
include: { user: true },
@@ -119,7 +97,6 @@ async function findOrCreateUser(
return existingOAuth.user;
}
// 2. Check existing user by email — link OAuth account
const existingUser = await prisma.user.findUnique({ where: { email: providerUser.email } });
if (existingUser) {
await prisma.oAuthAccount.create({
@@ -134,19 +111,32 @@ async function findOrCreateUser(
return existingUser;
}
// 3. Create new user + OAuth account
const newUser = await prisma.user.create({
data: {
email: providerUser.email,
name: providerUser.name,
avatarUrl: providerUser.avatarUrl,
passwordHash: null,
oauthAccounts: {
create: { provider, providerAccountId: providerUser.id },
try {
const newUser = await prisma.user.create({
data: {
email: providerUser.email,
name: providerUser.name,
avatarUrl: providerUser.avatarUrl,
passwordHash: null,
oauthAccounts: {
create: { provider, providerAccountId: providerUser.id },
},
},
},
});
return newUser;
});
return newUser;
} catch (err: any) {
// Handle race condition: concurrent OAuth with same email
if (err?.code === 'P2002') {
const user = await prisma.user.findUnique({ where: { email: providerUser.email } });
if (user) {
await prisma.oAuthAccount.create({
data: { userId: user.id, provider, providerAccountId: providerUser.id },
}).catch(() => {}); // Ignore if OAuthAccount also raced
return user;
}
}
throw err;
}
}
export default router;