diff --git a/packages/server/src/lib/oauth-providers.ts b/packages/server/src/lib/oauth-providers.ts index 712daad..d409506 100644 --- a/packages/server/src/lib/oauth-providers.ts +++ b/packages/server/src/lib/oauth-providers.ts @@ -14,7 +14,9 @@ type ProviderUser = { avatarUrl: string | null; }; -const providers: Record = { +export type Provider = 'google' | 'github' | 'apple'; + +const providers: Record = { google: { authUrl: 'https://accounts.google.com/o/oauth2/v2/auth', tokenUrl: 'https://oauth2.googleapis.com/token', @@ -35,14 +37,14 @@ const providers: Record = { }, }; -function getClientId(provider: string): string { +function getClientId(provider: Provider): string { const envKey = provider === 'apple' ? 'APPLE_CLIENT_ID' : `${provider.toUpperCase()}_CLIENT_ID`; const value = process.env[envKey]; if (!value) throw new Error(`Missing env: ${envKey}`); return value; } -function getClientSecret(provider: string): string { +function getClientSecret(provider: Provider): string { if (provider === 'apple') return buildAppleClientSecret(); const envKey = `${provider.toUpperCase()}_CLIENT_SECRET`; const value = process.env[envKey]; @@ -50,28 +52,28 @@ function getClientSecret(provider: string): string { return value; } -function getCallbackUrl(provider: string): string { +function getCallbackUrl(provider: Provider): string { const base = process.env.OAUTH_CALLBACK_BASE_URL || 'http://localhost:3000'; return `${base}/api/auth/oauth/${provider}/callback`; } -// --- State management (CSRF protection) --- const stateStore = new Map(); -setInterval(() => { +const cleanupTimer = setInterval(() => { const now = Date.now(); for (const [key, value] of stateStore) { if (now - value.createdAt > 10 * 60 * 1000) stateStore.delete(key); } }, 5 * 60 * 1000); +cleanupTimer.unref(); -function generateState(provider: string): string { +function generateState(provider: Provider): string { const state = crypto.randomBytes(32).toString('hex'); stateStore.set(state, { provider, createdAt: Date.now() }); return state; } -function validateState(state: string, provider: string): boolean { +function validateState(state: string, provider: Provider): boolean { const entry = stateStore.get(state); if (!entry) return false; if (entry.provider !== provider) return false; @@ -83,7 +85,6 @@ function validateState(state: string, provider: string): boolean { return true; } -// --- Apple client_secret JWT --- function buildAppleClientSecret(): string { const teamId = process.env.APPLE_TEAM_ID; const keyId = process.env.APPLE_KEY_ID; @@ -105,9 +106,7 @@ function buildAppleClientSecret(): string { return `${signingInput}.${sig.toString('base64url')}`; } -// --- Public API --- - -export function buildAuthUrl(provider: string): string { +export function buildAuthUrl(provider: Provider): string { const config = providers[provider]; if (!config) throw new Error(`Unknown provider: ${provider}`); @@ -127,7 +126,7 @@ export function buildAuthUrl(provider: string): string { return `${config.authUrl}?${params.toString()}`; } -export async function exchangeCodeForToken(provider: string, code: string): Promise { +export async function exchangeCodeForToken(provider: Provider, code: string): Promise { const config = providers[provider]; if (!config) throw new Error(`Unknown provider: ${provider}`); @@ -163,7 +162,7 @@ export async function exchangeCodeForToken(provider: string, code: string): Prom return data.access_token as string; } -export async function fetchProviderUser(provider: string, token: string): Promise { +export async function fetchProviderUser(provider: Provider, token: string): Promise { if (provider === 'apple') { return parseAppleIdToken(token); } diff --git a/packages/server/src/routes/oauth.ts b/packages/server/src/routes/oauth.ts index 0d8326b..1d7f2c3 100644 --- a/packages/server/src/routes/oauth.ts +++ b/packages/server/src/routes/oauth.ts @@ -1,17 +1,20 @@ -import { Router, type Router as RouterType } from 'express'; +import { Router, type Router as RouterType, type Response } from 'express'; import { prisma } from '@agent-fox/shared'; import { generateTokenPair } from '../lib/jwt.js'; -import { buildAuthUrl, exchangeCodeForToken, fetchProviderUser, validateState } from '../lib/oauth-providers.js'; +import { buildAuthUrl, exchangeCodeForToken, fetchProviderUser, validateState, type Provider } from '../lib/oauth-providers.js'; const router: RouterType = Router(); -const VALID_PROVIDERS = ['google', 'github', 'apple']; +const VALID_PROVIDERS: Provider[] = ['google', 'github', 'apple']; const FRONTEND_URL = process.env.FRONTEND_URL || 'http://localhost:5173'; -// GET /auth/oauth/:provider — redirect to provider's authorization page +function isValidProvider(value: string): value is Provider { + return (VALID_PROVIDERS as string[]).includes(value); +} + router.get('/:provider', (req, res) => { const { provider } = req.params; - if (!VALID_PROVIDERS.includes(provider)) { + if (!isValidProvider(provider)) { res.status(400).json({ success: false, error: { code: 'INVALID_PROVIDER', message: `Unknown provider: ${provider}` } }); return; } @@ -24,17 +27,31 @@ router.get('/:provider', (req, res) => { } }); -// GET /auth/oauth/:provider/callback — handle provider callback router.get('/:provider/callback', async (req, res) => { const { provider } = req.params; - const { code, state, error: oauthError } = req.query as Record; + const params = req.query as Record; + await handleOAuthCallback(provider, params.code, params.state, params.error, res); +}); +// Apple sends callback as POST (form_post response mode) +router.post('/:provider/callback', async (req, res) => { + const { provider } = req.params; + await handleOAuthCallback(provider, req.body.code, req.body.state, req.body.error, res); +}); + +async function handleOAuthCallback( + provider: string, + code: string | undefined, + state: string | undefined, + oauthError: string | undefined, + res: Response, +) { if (oauthError) { res.redirect(`${FRONTEND_URL}/login/callback?error=${encodeURIComponent(oauthError)}`); return; } - if (!code || !state) { + if (!code || !state || !isValidProvider(provider)) { res.redirect(`${FRONTEND_URL}/login/callback?error=${encodeURIComponent('Missing code or state')}`); return; } @@ -60,51 +77,12 @@ router.get('/:provider/callback', async (req, res) => { console.error(`OAuth callback error (${provider}):`, err); res.redirect(`${FRONTEND_URL}/login/callback?error=${encodeURIComponent('Authentication failed')}`); } -}); - -// Apple sends callback as POST (form_post response mode) -router.post('/:provider/callback', async (req, res) => { - const { provider } = req.params; - const { code, state, error: oauthError } = req.body; - - if (oauthError) { - res.redirect(`${FRONTEND_URL}/login/callback?error=${encodeURIComponent(oauthError)}`); - return; - } - - if (!code || !state) { - res.redirect(`${FRONTEND_URL}/login/callback?error=${encodeURIComponent('Missing code or state')}`); - return; - } - - if (!validateState(state, provider)) { - res.redirect(`${FRONTEND_URL}/login/callback?error=${encodeURIComponent('Invalid or expired state')}`); - return; - } - - try { - const token = await exchangeCodeForToken(provider, code); - const providerUser = await fetchProviderUser(provider, token); - if (!providerUser.email) { - res.redirect(`${FRONTEND_URL}/login/callback?error=${encodeURIComponent('No email returned from provider')}`); - return; - } - - const user = await findOrCreateUser(provider, providerUser); - const tokens = generateTokenPair({ userId: user.id, email: user.email }); - - res.redirect(`${FRONTEND_URL}/login/callback?accessToken=${tokens.accessToken}&refreshToken=${tokens.refreshToken}`); - } catch (err) { - console.error(`OAuth POST callback error (${provider}):`, err); - res.redirect(`${FRONTEND_URL}/login/callback?error=${encodeURIComponent('Authentication failed')}`); - } -}); +} async function findOrCreateUser( provider: string, providerUser: { id: string; email: string; name: string; avatarUrl: string | null }, ) { - // 1. Check existing OAuthAccount const existingOAuth = await prisma.oAuthAccount.findUnique({ where: { provider_providerAccountId: { provider, providerAccountId: providerUser.id } }, include: { user: true }, @@ -119,7 +97,6 @@ async function findOrCreateUser( return existingOAuth.user; } - // 2. Check existing user by email — link OAuth account const existingUser = await prisma.user.findUnique({ where: { email: providerUser.email } }); if (existingUser) { await prisma.oAuthAccount.create({ @@ -134,19 +111,32 @@ async function findOrCreateUser( return existingUser; } - // 3. Create new user + OAuth account - const newUser = await prisma.user.create({ - data: { - email: providerUser.email, - name: providerUser.name, - avatarUrl: providerUser.avatarUrl, - passwordHash: null, - oauthAccounts: { - create: { provider, providerAccountId: providerUser.id }, + try { + const newUser = await prisma.user.create({ + data: { + email: providerUser.email, + name: providerUser.name, + avatarUrl: providerUser.avatarUrl, + passwordHash: null, + oauthAccounts: { + create: { provider, providerAccountId: providerUser.id }, + }, }, - }, - }); - return newUser; + }); + return newUser; + } catch (err: any) { + // Handle race condition: concurrent OAuth with same email + if (err?.code === 'P2002') { + const user = await prisma.user.findUnique({ where: { email: providerUser.email } }); + if (user) { + await prisma.oAuthAccount.create({ + data: { userId: user.id, provider, providerAccountId: providerUser.id }, + }).catch(() => {}); // Ignore if OAuthAccount also raced + return user; + } + } + throw err; + } } export default router; diff --git a/packages/web/src/components/AuthBranding.tsx b/packages/web/src/components/AuthBranding.tsx index de9679c..288e29a 100644 --- a/packages/web/src/components/AuthBranding.tsx +++ b/packages/web/src/components/AuthBranding.tsx @@ -1,36 +1,51 @@ import { useI18n } from '../lib/i18n'; +function Logo({ className }: { className: string }) { + return ( + + + + + + ); +} + +export function MobileBranding() { + const { t } = useI18n(); + + return ( +
+
+ +
+

{t('auth.productName')}

+

{t('auth.slogan')}

+
+ ); +} + export default function AuthBranding() { const { t } = useI18n(); return (
- {/* Decorative circles */}
- {/* Logo */}
- - - - - +
- {/* Product name */}

{t('auth.productName')}

- {/* Slogan */}

{t('auth.slogan')}

- {/* Feature highlights */}
{['auth.feature1', 'auth.feature2', 'auth.feature3'].map((key) => (
diff --git a/packages/web/src/components/OAuthButtons.tsx b/packages/web/src/components/OAuthButtons.tsx index 0de29ee..9a2014a 100644 --- a/packages/web/src/components/OAuthButtons.tsx +++ b/packages/web/src/components/OAuthButtons.tsx @@ -1,6 +1,5 @@ import { useI18n } from '../lib/i18n'; - -const API_BASE = '/api'; +import { API_BASE } from '../lib/api'; function GoogleIcon() { return ( diff --git a/packages/web/src/lib/api.ts b/packages/web/src/lib/api.ts index f94c351..18c9f3b 100644 --- a/packages/web/src/lib/api.ts +++ b/packages/web/src/lib/api.ts @@ -1,4 +1,4 @@ -const API_BASE = '/api'; +export const API_BASE = '/api'; type ApiResponse = { success: boolean; diff --git a/packages/web/src/pages/Login.tsx b/packages/web/src/pages/Login.tsx index e8e23ef..d76f608 100644 --- a/packages/web/src/pages/Login.tsx +++ b/packages/web/src/pages/Login.tsx @@ -2,7 +2,7 @@ import { useState } from 'react'; import { Link, useNavigate, useSearchParams } from 'react-router-dom'; import { useAuth } from '../lib/auth'; import { useI18n } from '../lib/i18n'; -import AuthBranding from '../components/AuthBranding'; +import AuthBranding, { MobileBranding } from '../components/AuthBranding'; import OAuthButtons from '../components/OAuthButtons'; export default function Login() { @@ -63,18 +63,7 @@ export default function Login() { }} />
- {/* Mobile-only brand (visible when left panel is hidden) */} -
-
- - - - - -
-

{t('auth.productName')}

-

{t('auth.slogan')}

-
+ {/* Title (desktop) */}
diff --git a/packages/web/src/pages/Register.tsx b/packages/web/src/pages/Register.tsx index b8ffad1..a4f74a3 100644 --- a/packages/web/src/pages/Register.tsx +++ b/packages/web/src/pages/Register.tsx @@ -2,7 +2,7 @@ import { useState } from 'react'; import { Link, useNavigate } from 'react-router-dom'; import { useAuth } from '../lib/auth'; import { useI18n } from '../lib/i18n'; -import AuthBranding from '../components/AuthBranding'; +import AuthBranding, { MobileBranding } from '../components/AuthBranding'; import OAuthButtons from '../components/OAuthButtons'; export default function Register() { @@ -72,18 +72,7 @@ export default function Register() { }} />
- {/* Mobile-only brand */} -
-
- - - - - -
-

{t('auth.productName')}

-

{t('auth.slogan')}

-
+ {/* Title (desktop) */}