Authentication Security Backend Node.js

Multi-Factor Authentication (MFA): Identity Verification Strategies for Modern Web Apps

พลากร วรมงคล
April 13, 2025 · 13 min read

“A comprehensive guide to implementing multi-factor authentication, covering TOTP setup, backup codes, SMS verification, hardware security keys, adaptive MFA, and recovery strategies.”

Deep Dive: Multi-Factor Authentication (MFA)

What Is Multi-Factor Authentication?

Multi-factor authentication (MFA) is a security mechanism that requires users to verify their identity through two or more independent factors before access is granted. The underlying principle: even if one factor is compromised (for example, a leaked password), an attacker still cannot access the account without the additional factor.

MFA is not a standalone authentication strategy. It is a security layer added on top of a primary method (typically password-based login). Its strength comes from combining factors from different categories: something you know (password, PIN), something you have (phone, hardware key), or something you are (fingerprint, face). Using two factors from the same category (for example, a password plus a security question) does not count as true MFA.

The most common production MFA implementation today is TOTP (Time-based One-Time Password) through an authenticator app such as Google Authenticator or Authy, usually paired with backup recovery codes. Hardware security keys (FIDO2/U2F) offer the strongest protection and are increasingly adopted for high-value accounts.

Core Principles

  • Defense in depth: a single compromised factor does not grant access. An attacker must compromise multiple independent channels.
  • Factor independence: each factor must come from a different category (knowledge, possession, inherence) to be effective.
  • Risk-proportional: enforce MFA in proportion to risk: always for administrator accounts and financial operations, optional for general users.
  • Recovery planning: every MFA system must include a recovery path for when users lose their second factor.
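
The risk-proportional principle can be sketched as a small policy function. This is only an illustration: the `Role` and `Action` values, the `FINANCIAL_ACTIONS` set, and the `requiresMfa` name are hypothetical, and a real policy would likely weigh more signals than these.

```typescript
type Role = 'admin' | 'user';
type Action = 'login' | 'viewProfile' | 'transferFunds';

interface AccessContext {
  role: Role;
  action: Action;
  userOptedIn: boolean; // whether a general user has chosen to enable MFA
}

// Hypothetical set of actions treated as financial operations.
const FINANCIAL_ACTIONS: ReadonlySet<Action> = new Set<Action>(['transferFunds']);

// Always require MFA for admins and financial operations;
// for everything else, respect the user's own choice.
function requiresMfa(ctx: AccessContext): boolean {
  if (ctx.role === 'admin') return true;
  if (FINANCIAL_ACTIONS.has(ctx.action)) return true;
  return ctx.userOptedIn;
}
```

Keeping the rule in one function makes it easy to audit which accounts and actions can ever skip the second factor.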

TOTP Authentication Flow

sequenceDiagram
    participant U as User
    participant B as Browser
    participant S as Server
    participant DB as Database
    participant A as Authenticator App

    U->>B: Enter email + password
    B->>S: POST /login {email, password}
    S->>DB: Validate credentials
    DB-->>S: Valid — MFA enabled for this user
    S-->>B: 200 {mfaRequired: true, mfaToken: "temp_xyz"}

    Note over U,A: User opens authenticator app
    A->>A: Generate TOTP: HMAC-SHA1(secret, floor(time/30))
    A-->>U: Display 6-digit code: 482951

    U->>B: Enter TOTP code: 482951
    B->>S: POST /login/mfa {mfaToken: "temp_xyz", code: "482951"}
    S->>DB: Retrieve user's TOTP secret
    S->>S: Verify: TOTP(secret, time) === "482951"
    S->>S: Issue session or JWT
    S-->>B: 200 {accessToken, refreshToken} — Fully authenticated
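
The `mfaToken` in the flow above ties the second step to a login whose password check already succeeded. One possible way to build such a token is an HMAC-signed payload with a short expiry, sketched below. The function names, the 5-minute lifetime, and the token layout are all assumptions made for illustration, not something the diagram prescribes.

```typescript
import { createHmac, timingSafeEqual } from 'node:crypto';

const MFA_TOKEN_TTL_MS = 5 * 60 * 1000; // assumed 5-minute lifetime

function sign(payloadHex: string, key: string): string {
  return createHmac('sha256', key).update(payloadHex).digest('hex');
}

// Issued after the password check succeeds; carries the user ID and an expiry.
function issueMfaToken(userId: string, key: string, nowMs: number): string {
  const payload = JSON.stringify({ userId, exp: nowMs + MFA_TOKEN_TTL_MS });
  const payloadHex = Buffer.from(payload, 'utf8').toString('hex');
  return `${payloadHex}.${sign(payloadHex, key)}`;
}

// Returns the user ID if the signature is valid and the token has not expired, else null.
function readMfaToken(token: string, key: string, nowMs: number): string | null {
  const [payloadHex, signature] = token.split('.');
  if (!payloadHex || !signature) return null;

  const expected = Buffer.from(sign(payloadHex, key), 'utf8');
  const given = Buffer.from(signature, 'utf8');
  // timingSafeEqual throws on length mismatch, so compare lengths first.
  if (given.length !== expected.length || !timingSafeEqual(given, expected)) return null;

  const claims = JSON.parse(Buffer.from(payloadHex, 'hex').toString('utf8'));
  if (typeof claims.userId !== 'string' || typeof claims.exp !== 'number') return null;
  return nowMs < claims.exp ? claims.userId : null;
}
```

A stateless token like this can be replayed until it expires. Storing issued tokens server-side and deleting them on first use would make each one single-use, at the cost of a lookup.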

Now let's explore each authentication factor and how to implement it in detail.

Understanding the Three Factors of Authentication

Authentication factors fall into three categories, each offering different security guarantees:

1. Knowledge Factors (Something You Know)

Information that only the user should know: passwords, PINs, security questions. These are easy to implement but vulnerable to phishing, social engineering, and brute-force attacks.

Example: your online banking password

2. Possession Factors (Something You Have)

Physical or digital items under the user's control: phones, hardware keys, authenticator apps, or SIM cards. These are hard to compromise remotely.

Example: your phone receiving an SMS code, or a security key

3. Inherence Factors (Something You Are)

Biometric authentication: fingerprints, face recognition, voice patterns. These are unique to each individual and difficult to forge.

Example: fingerprint unlock on your phone

Strong MFA typically combines at least two factors from different categories. The most common pattern today is knowledge (password) + possession (authenticator app or hardware key).
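
The "different categories" rule can be checked mechanically. The sketch below assumes a hypothetical mapping from method names to categories. The method names are illustrative, and unknown methods are ignored.

```typescript
type FactorCategory = 'knowledge' | 'possession' | 'inherence';

// Hypothetical mapping of verification methods to their factor category.
const FACTOR_CATEGORY: Record<string, FactorCategory> = {
  password: 'knowledge',
  pin: 'knowledge',
  securityQuestion: 'knowledge',
  totp: 'possession',
  smsCode: 'possession',
  hardwareKey: 'possession',
  fingerprint: 'inherence',
  faceScan: 'inherence',
};

// True only when the methods span at least two distinct categories.
function isTrueMfa(methods: string[]): boolean {
  const categories = new Set<FactorCategory>();
  for (const method of methods) {
    const category = FACTOR_CATEGORY[method];
    if (category) categories.add(category);
  }
  return categories.size >= 2;
}

console.log(isTrueMfa(['password', 'securityQuestion'])); // false
console.log(isTrueMfa(['password', 'totp']));             // true
```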

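TOTP: The Gold Standard of MFA

Time-based One-Time Passwords (TOTP) are the most widely supported form of MFA. Unlike SMS codes, which can be intercepted or redirected through SIM swapping, TOTP codes are generated on the user's device, so an attacker would have to compromise the device itself or the shared secret.

How TOTP Works

TOTP relies on HMAC-SHA1 and time synchronization:

  1. Shared Secret: during setup, the server generates a random secret (typically 20 to 32 bytes) and shares it with the client.
  2. Time Steps: the current Unix timestamp is divided into 30-second intervals.
  3. HMAC Generation: HMAC-SHA1(secret, time_step) produces a 20-byte hash.
  4. Code Extraction: dynamic truncation selects 4 bytes of the hash, at an offset given by the low 4 bits of its last byte. The resulting 31-bit number modulo 1,000,000 is the 6-digit one-time code.

Here is the process expressed mathematically:

T = floor(current_unix_time / 30)                      // Time step, encoded as an 8-byte big-endian counter
H = HMAC-SHA1(secret, T)                               // 20-byte hash
offset = H[19] & 0x0F                                  // Low 4 bits of the last byte
code = (H[offset..offset+3] & 0x7FFFFFFF) % 1,000,000  // 31-bit value reduced to 6 digits

The elegance is that the server and the client can each compute and check the code independently, because they share the secret and the time.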
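
To make the steps concrete, here is a minimal from-scratch sketch using only Node's built-in `crypto` module, following RFC 4226 (HOTP) and RFC 6238 (TOTP). The function names `hotp` and `totp` are chosen for illustration. In production, a maintained library is the safer choice.

```typescript
import { createHmac } from 'node:crypto';

// HOTP (RFC 4226): HMAC-SHA1 over an 8-byte big-endian counter, then dynamic truncation.
function hotp(secret: Buffer, counter: number, digits: number = 6): string {
  const counterBytes = Buffer.alloc(8);
  counterBytes.writeBigUInt64BE(BigInt(counter));

  const hash = createHmac('sha1', secret).update(counterBytes).digest(); // 20 bytes

  // The low 4 bits of the last byte select where the 4-byte window starts.
  const offset = hash.readUInt8(hash.length - 1) & 0x0f;
  // Masking the top bit leaves a 31-bit non-negative integer.
  const truncated = hash.readUInt32BE(offset) & 0x7fffffff;

  return (truncated % 10 ** digits).toString().padStart(digits, '0');
}

// TOTP (RFC 6238): HOTP with the counter derived from the clock.
function totp(secret: Buffer, unixSeconds: number, stepSeconds: number = 30): string {
  return hotp(secret, Math.floor(unixSeconds / stepSeconds));
}

// RFC 6238 test secret at t = 59 seconds (time step 1).
const demoCode = totp(Buffer.from('12345678901234567890', 'ascii'), 59);
console.log(demoCode); // 287082
```

The printed value matches the published test vectors: RFC 6238 lists `94287082` as the 8-digit code at t = 59, and its last six digits are the HOTP value for counter 1 in RFC 4226.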

Implementing TOTP with Speakeasy

Let's build a complete TOTP implementation:

import speakeasy from 'speakeasy';
import QRCode from 'qrcode';
import { PrismaClient } from '@prisma/client';
import crypto from 'crypto';

const prisma = new PrismaClient();

// Generate TOTP secret and return QR code
export async function initializeTOTP(userId: string, userEmail: string) {
  // Generate a high-entropy secret
  const secret = speakeasy.generateSecret({
    name: `YourApp (${userEmail})`,
    issuer: 'YourApp',
    length: 32
  });

  if (!secret.otpauth_url) {
    throw new Error('Failed to generate secret');
  }

  // Generate QR code
  const qrCodeDataUrl = await QRCode.toDataURL(secret.otpauth_url);

  // Store temporary secret (not verified yet)
  await prisma.mfaSetup.create({
    data: {
      userId,
      secret: secret.base32,
      verified: false,
      createdAt: new Date(),
      expiresAt: new Date(Date.now() + 15 * 60 * 1000) // 15 minute window
    }
  });

  return {
    secret: secret.base32,
    qrCodeUrl: qrCodeDataUrl,
    manualEntryKey: secret.base32
  };
}

// Verify TOTP code during setup (user must provide a valid code)
export async function verifyTOTPSetup(userId: string, token: string) {
  const setup = await prisma.mfaSetup.findFirst({
    where: {
      userId,
      verified: false,
      expiresAt: { gt: new Date() }
    }
  });

  if (!setup) {
    throw new Error('TOTP setup expired or not found');
  }

  // Verify with 1-code window tolerance (±1 interval)
  const verified = speakeasy.totp.verify({
    secret: setup.secret,
    encoding: 'base32',
    token,
    window: 1 // Allow codes from 30s before/after current time
  });

  if (!verified) {
    throw new Error('Invalid TOTP code');
  }

  // Activate TOTP
  await prisma.mfaSetup.update({
    where: { id: setup.id },
    data: { verified: true }
  });

  // Update user record
  await prisma.user.update({
    where: { id: userId },
    data: {
      totpEnabled: true,
      totpSecret: setup.secret
    }
  });

  return { success: true };
}

// Verify TOTP code during login
export async function verifyTOTPLogin(userId: string, token: string): Promise<boolean> {
  const user = await prisma.user.findUnique({
    where: { id: userId }
  });

  if (!user?.totpSecret) {
    return false;
  }

  const verified = speakeasy.totp.verify({
    secret: user.totpSecret,
    encoding: 'base32',
    token,
    window: 1
  });

  if (verified) {
    // Log MFA verification for audit trail
    await prisma.auditLog.create({
      data: {
        userId,
        action: 'MFA_VERIFIED',
        ipAddress: '',
        userAgent: '',
        timestamp: new Date()
      }
    });
  }

  return verified;
}

// Spring Boot + GoogleAuth (com.warrenstrange:googleauth)
import com.warrenstrange.googleauth.GoogleAuthenticator;
import com.warrenstrange.googleauth.GoogleAuthenticatorKey;
import net.glxn.qrgen.javase.QRCode;
import org.springframework.stereotype.Service;
import org.springframework.transaction.annotation.Transactional;
import java.time.Instant;
import java.util.Base64;
import java.util.Map;

@Service
public class TotpService {

    private final GoogleAuthenticator gAuth = new GoogleAuthenticator();
    private final MfaSetupRepository mfaSetupRepo;
    private final UserRepository userRepo;
    private final AuditLogRepository auditLogRepo;

    public TotpService(MfaSetupRepository mfaSetupRepo,
                       UserRepository userRepo,
                       AuditLogRepository auditLogRepo) {
        this.mfaSetupRepo = mfaSetupRepo;
        this.userRepo = userRepo;
        this.auditLogRepo = auditLogRepo;
    }

    // Generate TOTP secret and return QR code
    @Transactional
    public Map<String, String> initializeTOTP(String userId, String userEmail) {
        GoogleAuthenticatorKey key = gAuth.createCredentials();
        String secret = key.getKey();

        String otpauthUrl = String.format(
            "otpauth://totp/YourApp%%3A%%20%s?secret=%s&issuer=YourApp",
            userEmail, secret
        );

        // Generate QR code as Base64
        byte[] qrBytes = QRCode.from(otpauthUrl).withSize(200, 200).stream().toByteArray();
        String qrCodeDataUrl = "data:image/png;base64," + Base64.getEncoder().encodeToString(qrBytes);

        // Store temporary secret (not verified yet)
        MfaSetup setup = new MfaSetup();
        setup.setUserId(userId);
        setup.setSecret(secret);
        setup.setVerified(false);
        setup.setCreatedAt(Instant.now());
        setup.setExpiresAt(Instant.now().plusSeconds(15 * 60)); // 15-minute window
        mfaSetupRepo.save(setup);

        return Map.of(
            "secret", secret,
            "qrCodeUrl", qrCodeDataUrl,
            "manualEntryKey", secret
        );
    }

    // Verify TOTP code during setup
    @Transactional
    public void verifyTOTPSetup(String userId, int token) {
        MfaSetup setup = mfaSetupRepo
            .findFirstByUserIdAndVerifiedFalseAndExpiresAtAfter(userId, Instant.now())
            .orElseThrow(() -> new RuntimeException("TOTP setup expired or not found"));

        boolean verified = gAuth.authorize(setup.getSecret(), token);
        if (!verified) {
            throw new RuntimeException("Invalid TOTP code");
        }

        setup.setVerified(true);
        mfaSetupRepo.save(setup);

        User user = userRepo.findById(userId)
            .orElseThrow(() -> new RuntimeException("User not found"));
        user.setTotpEnabled(true);
        user.setTotpSecret(setup.getSecret());
        userRepo.save(user);
    }

    // Verify TOTP code during login
    @Transactional
    public boolean verifyTOTPLogin(String userId, int token) {
        User user = userRepo.findById(userId)
            .orElseThrow(() -> new RuntimeException("User not found"));

        if (user.getTotpSecret() == null) return false;

        boolean verified = gAuth.authorize(user.getTotpSecret(), token);
        if (verified) {
            AuditLog log = new AuditLog();
            log.setUserId(userId);
            log.setAction("MFA_VERIFIED");
            log.setIpAddress("");
            log.setUserAgent("");
            log.setTimestamp(Instant.now());
            auditLogRepo.save(log);
        }
        return verified;
    }
}

# FastAPI + pyotp + qrcode
import pyotp
import qrcode
import io
import base64
import hashlib
from datetime import datetime, timedelta, timezone
from sqlalchemy.orm import Session
from models import MfaSetup, User, AuditLog

def initialize_totp(user_id: str, user_email: str, db: Session) -> dict:
    # Generate a high-entropy secret
    secret = pyotp.random_base32()
    totp = pyotp.TOTP(secret)
    otpauth_url = totp.provisioning_uri(name=user_email, issuer_name="YourApp")

    # Generate QR code as data URL
    img = qrcode.make(otpauth_url)
    buf = io.BytesIO()
    img.save(buf, format="PNG")
    qr_code_data_url = "data:image/png;base64," + base64.b64encode(buf.getvalue()).decode()

    # Store temporary secret (not verified yet)
    expires_at = datetime.now(timezone.utc) + timedelta(minutes=15)
    setup = MfaSetup(
        user_id=user_id,
        secret=secret,
        verified=False,
        created_at=datetime.now(timezone.utc),
        expires_at=expires_at,
    )
    db.add(setup)
    db.commit()

    return {"secret": secret, "qr_code_url": qr_code_data_url, "manual_entry_key": secret}


def verify_totp_setup(user_id: str, token: str, db: Session) -> dict:
    setup = (
        db.query(MfaSetup)
        .filter(
            MfaSetup.user_id == user_id,
            MfaSetup.verified == False,
            MfaSetup.expires_at > datetime.now(timezone.utc),
        )
        .first()
    )
    if not setup:
        raise ValueError("TOTP setup expired or not found")

    totp = pyotp.TOTP(setup.secret)
    # valid_window=1 allows ±1 interval tolerance
    if not totp.verify(token, valid_window=1):
        raise ValueError("Invalid TOTP code")

    setup.verified = True
    user = db.query(User).filter(User.id == user_id).first()
    user.totp_enabled = True
    user.totp_secret = setup.secret
    db.commit()

    return {"success": True}


def verify_totp_login(user_id: str, token: str, db: Session) -> bool:
    user = db.query(User).filter(User.id == user_id).first()
    if not user or not user.totp_secret:
        return False

    totp = pyotp.TOTP(user.totp_secret)
    verified = totp.verify(token, valid_window=1)

    if verified:
        log = AuditLog(
            user_id=user_id,
            action="MFA_VERIFIED",
            ip_address="",
            user_agent="",
            timestamp=datetime.now(timezone.utc),
        )
        db.add(log)
        db.commit()

    return verified

// ASP.NET Core + OtpNet + QRCoder
using Microsoft.EntityFrameworkCore; // FirstOrDefaultAsync, FindAsync
using OtpNet;
using QRCoder;
using System.Security.Cryptography;

public class TotpService
{
    private readonly AppDbContext _db;

    public TotpService(AppDbContext db) => _db = db;

    // Generate TOTP secret and return QR code
    public async Task<TotpSetupResult> InitializeTotpAsync(string userId, string userEmail)
    {
        // Generate a high-entropy secret (20 bytes = 160 bits)
        var secretBytes = RandomNumberGenerator.GetBytes(20);
        var secret = Base32Encoding.ToString(secretBytes);

        var otpauthUrl = $"otpauth://totp/YourApp%3A{Uri.EscapeDataString(userEmail)}" +
                         $"?secret={secret}&issuer=YourApp";

        // Generate QR code as data URL
        using var qrGenerator = new QRCodeGenerator();
        using var qrData = qrGenerator.CreateQrCode(otpauthUrl, QRCodeGenerator.ECCLevel.Q);
        using var qrCode = new PngByteQRCode(qrData);
        var qrBytes = qrCode.GetGraphic(5);
        var qrCodeDataUrl = "data:image/png;base64," + Convert.ToBase64String(qrBytes);

        // Store temporary secret (not verified yet)
        var setup = new MfaSetup
        {
            UserId = userId,
            Secret = secret,
            Verified = false,
            CreatedAt = DateTime.UtcNow,
            ExpiresAt = DateTime.UtcNow.AddMinutes(15)
        };
        _db.MfaSetups.Add(setup);
        await _db.SaveChangesAsync();

        return new TotpSetupResult
        {
            Secret = secret,
            QrCodeUrl = qrCodeDataUrl,
            ManualEntryKey = secret
        };
    }

    // Verify TOTP code during setup
    public async Task VerifyTotpSetupAsync(string userId, string token)
    {
        var setup = await _db.MfaSetups
            .Where(s => s.UserId == userId && !s.Verified && s.ExpiresAt > DateTime.UtcNow)
            .FirstOrDefaultAsync()
            ?? throw new InvalidOperationException("TOTP setup expired or not found");

        var secretBytes = Base32Encoding.ToBytes(setup.Secret);
        var totp = new Totp(secretBytes);
        // VerifyTotp with window of 1 (±1 interval)
        bool verified = totp.VerifyTotp(token, out _, new VerificationWindow(1, 1));
        if (!verified)
            throw new InvalidOperationException("Invalid TOTP code");

        setup.Verified = true;
        var user = await _db.Users.FindAsync(userId)
            ?? throw new InvalidOperationException("User not found");
        user.TotpEnabled = true;
        user.TotpSecret = setup.Secret;
        await _db.SaveChangesAsync();
    }

    // Verify TOTP code during login
    public async Task<bool> VerifyTotpLoginAsync(string userId, string token)
    {
        var user = await _db.Users.FindAsync(userId);
        if (user?.TotpSecret == null) return false;

        var secretBytes = Base32Encoding.ToBytes(user.TotpSecret);
        var totp = new Totp(secretBytes);
        bool verified = totp.VerifyTotp(token, out _, new VerificationWindow(1, 1));

        if (verified)
        {
            _db.AuditLogs.Add(new AuditLog
            {
                UserId = userId,
                Action = "MFA_VERIFIED",
                IpAddress = "",
                UserAgent = "",
                Timestamp = DateTime.UtcNow
            });
            await _db.SaveChangesAsync();
        }
        return verified;
    }
}
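
All four implementations accept codes from one time step before or after the current one. A code stays valid for that whole window, so a captured code could be replayed unless the server remembers the last accepted step. The sketch below shows one way to combine window tolerance with replay rejection using only Node's `crypto` module. `verifyWithWindow` and `lastUsedStep` are hypothetical names, and where `lastUsedStep` is persisted is left to the caller.

```typescript
import { createHmac, timingSafeEqual } from 'node:crypto';

const STEP_SECONDS = 30;

// 6-digit code for a given time step (HMAC-SHA1 plus dynamic truncation, per RFC 4226/6238).
function codeForStep(secret: Buffer, step: number): string {
  const counter = Buffer.alloc(8);
  counter.writeBigUInt64BE(BigInt(step));
  const hash = createHmac('sha1', secret).update(counter).digest();
  const offset = hash.readUInt8(hash.length - 1) & 0x0f;
  const value = (hash.readUInt32BE(offset) & 0x7fffffff) % 1000000;
  return value.toString().padStart(6, '0');
}

// Returns the matched time step, or null if no step in the window matches.
// lastUsedStep is a per-user value the caller would persist after each success.
function verifyWithWindow(
  secret: Buffer,
  code: string,
  unixSeconds: number,
  lastUsedStep: number | null,
  window: number = 1
): number | null {
  const current = Math.floor(unixSeconds / STEP_SECONDS);
  const given = Buffer.from(code, 'utf8');

  for (let step = current - window; step <= current + window; step++) {
    if (step < 0) continue; // before the Unix epoch
    if (lastUsedStep !== null && step <= lastUsedStep) continue; // already used: reject replay
    const expected = Buffer.from(codeForStep(secret, step), 'utf8');
    // timingSafeEqual throws on length mismatch, so compare lengths first.
    if (given.length === expected.length && timingSafeEqual(given, expected)) {
      return step;
    }
  }
  return null;
}

const secret = Buffer.from('12345678901234567890', 'ascii'); // RFC 6238 test secret
const first = verifyWithWindow(secret, '287082', 89, null);   // code from step 1, checked during step 2
const replay = verifyWithWindow(secret, '287082', 89, first); // same code submitted again
console.log(first, replay); // 1 null
```

Returning the matched step rather than a boolean lets the caller store it and pass it back in as `lastUsedStep` on the next attempt.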

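Integrating with Express.js

import express, { Request, Response } from 'express';
import { authenticateUser } from './auth';
// Assumed module path for the TOTP functions defined above
import { initializeTOTP, verifyTOTPSetup, verifyTOTPLogin } from './totp';

const app = express();
app.use(express.json()); // Parse JSON bodies so req.body is populated
// Assumes authenticateUser is middleware that populates req.user
app.use('/api/mfa', authenticateUser);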

// Step 1: User initiates TOTP setup
app.post('/api/mfa/totp/setup', async (req: Request, res: Response) => {
  try {
    const userId = req.user?.id;
    if (!userId) {
      return res.status(401).json({ error: 'Unauthorized' });
    }

    const { secret, qrCodeUrl, manualEntryKey } = await initializeTOTP(
      userId,
      req.user.email
    );

    res.json({
      secret: manualEntryKey, // For manual entry
      qrCode: qrCodeUrl,
      message: 'Scan the QR code with your authenticator app'
    });
  } catch (error) {
    res.status(500).json({ error: 'Failed to initialize TOTP' });
  }
});

// Step 2: User confirms TOTP setup with a valid code
app.post('/api/mfa/totp/verify-setup', async (req: Request, res: Response) => {
  try {
    const userId = req.user?.id;
    const { token } = req.body;

    if (!userId || !token) {
      return res.status(400).json({ error: 'Missing required fields' });
    }

    await verifyTOTPSetup(userId, token);
    res.json({ success: true, message: 'TOTP enabled successfully' });
  } catch (error) {
    res.status(400).json({ error: 'Invalid code or setup expired' });
  }
});

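// Step 3: Verify TOTP during login
app.post('/api/auth/verify-mfa', async (req: Request, res: Response) => {
  try {
    // Accept the short-lived mfaToken issued after the password check, not a raw userId.
    // Trusting a client-supplied userId would let a caller skip the first factor.
    const { mfaToken, token } = req.body;

    // resolveMfaToken is a hypothetical helper (not defined in this article): it validates
    // the temporary token and returns the user ID it was issued for, or null.
    const userId = await resolveMfaToken(mfaToken);
    if (!userId) {
      return res.status(401).json({ error: 'Invalid or expired MFA token' });
    }

    const valid = await verifyTOTPLogin(userId, token);
    if (!valid) {
      return res.status(401).json({ error: 'Invalid TOTP code' });
    }

    // Issue session/JWT after MFA verification
    const session = await createSession(userId);
    res.json({ sessionToken: session.token });
  } catch (error) {
    res.status(500).json({ error: 'MFA verification failed' });
  }
});

// Spring Boot REST Controller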
@RestController
@RequestMapping("/api")
public class MfaController {

    private final TotpService totpService;
    private final SessionService sessionService;

    public MfaController(TotpService totpService, SessionService sessionService) {
        this.totpService = totpService;
        this.sessionService = sessionService;
    }

    // Step 1: User initiates TOTP setup
    @PostMapping("/mfa/totp/setup")
    public ResponseEntity<?> setupTotp(@AuthenticationPrincipal UserPrincipal principal) {
        if (principal == null) {
            return ResponseEntity.status(401).body(Map.of("error", "Unauthorized"));
        }
        try {
            var result = totpService.initializeTOTP(principal.getId(), principal.getEmail());
            return ResponseEntity.ok(Map.of(
                "secret", result.get("manualEntryKey"),
                "qrCode", result.get("qrCodeUrl"),
                "message", "Scan the QR code with your authenticator app"
            ));
        } catch (Exception e) {
            return ResponseEntity.status(500).body(Map.of("error", "Failed to initialize TOTP"));
        }
    }

    // Step 2: User confirms TOTP setup with a valid code
    @PostMapping("/mfa/totp/verify-setup")
    public ResponseEntity<?> verifyTotpSetup(
            @AuthenticationPrincipal UserPrincipal principal,
            @RequestBody Map<String, String> body) {
        if (principal == null || body.get("token") == null) {
            return ResponseEntity.status(400).body(Map.of("error", "Missing required fields"));
        }
        try {
            int token = Integer.parseInt(body.get("token"));
            totpService.verifyTOTPSetup(principal.getId(), token);
            return ResponseEntity.ok(Map.of("success", true, "message", "TOTP enabled successfully"));
        } catch (Exception e) {
            return ResponseEntity.status(400).body(Map.of("error", "Invalid code or setup expired"));
        }
    }

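    // Step 3: Verify TOTP during login
    @PostMapping("/auth/verify-mfa")
    public ResponseEntity<?> verifyMfa(@RequestBody Map<String, String> body) {
        try {
            // Resolve the user from the short-lived mfaToken issued after the password check.
            // Trusting a client-supplied userId would let a caller skip the first factor.
            // resolveMfaToken is an assumed SessionService method that returns null
            // for an invalid or expired token.
            String userId = sessionService.resolveMfaToken(body.get("mfaToken"));
            if (userId == null) {
                return ResponseEntity.status(401).body(Map.of("error", "Invalid or expired MFA token"));
            }
            int token = Integer.parseInt(body.get("token"));
            boolean valid = totpService.verifyTOTPLogin(userId, token);
            if (!valid) {
                return ResponseEntity.status(401).body(Map.of("error", "Invalid TOTP code"));
            }
            var session = sessionService.createSession(userId);
            return ResponseEntity.ok(Map.of("sessionToken", session.getToken()));
        } catch (NumberFormatException e) {
            // A non-numeric code is a client error, not a server failure
            return ResponseEntity.status(400).body(Map.of("error", "Invalid code format"));
        } catch (Exception e) {
            return ResponseEntity.status(500).body(Map.of("error", "MFA verification failed"));
        }
    }
}

# FastAPI router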
from fastapi import APIRouter, Depends, HTTPException, status
from sqlalchemy.orm import Session
from pydantic import BaseModel
from database import get_db
from auth import get_current_user
# Assumed module paths: the TOTP functions defined above and hypothetical session helpers
from totp_service import initialize_totp, verify_totp_setup, verify_totp_login
from sessions import create_session, resolve_mfa_token

router = APIRouter()

class TokenBody(BaseModel):
    token: str

class VerifyMfaBody(BaseModel):
    mfa_token: str  # Short-lived token issued after the password check
    token: str

# Step 1: User initiates TOTP setup
@router.post("/mfa/totp/setup")
async def setup_totp(
    current_user=Depends(get_current_user),
    db: Session = Depends(get_db),
):
    result = initialize_totp(current_user.id, current_user.email, db)
    return {
        "secret": result["manual_entry_key"],
        "qr_code": result["qr_code_url"],
        "message": "Scan the QR code with your authenticator app",
    }

# Step 2: User confirms TOTP setup with a valid code
@router.post("/mfa/totp/verify-setup")
async def verify_totp_setup_endpoint(
    body: TokenBody,
    current_user=Depends(get_current_user),
    db: Session = Depends(get_db),
):
    try:
        verify_totp_setup(current_user.id, body.token, db)
        return {"success": True, "message": "TOTP enabled successfully"}
    except ValueError:
        raise HTTPException(status_code=400, detail="Invalid code or setup expired")

# Step 3: Verify TOTP during login
@router.post("/auth/verify-mfa")
async def verify_mfa(body: VerifyMfaBody, db: Session = Depends(get_db)):
    # Resolve the user from the short-lived MFA token issued after the password check.
    # A client-supplied user_id would let a caller skip the first factor.
    # resolve_mfa_token is a hypothetical helper that returns None if the token is invalid.
    user_id = resolve_mfa_token(body.mfa_token)
    if user_id is None:
        raise HTTPException(status_code=401, detail="Invalid or expired MFA token")
    valid = verify_totp_login(user_id, body.token, db)
    if not valid:
        raise HTTPException(status_code=401, detail="Invalid TOTP code")
    session = create_session(user_id)
    return {"session_token": session.token}

// ASP.NET Core Controller
[ApiController]
[Route("api")]
public class MfaController : ControllerBase
{
    private readonly TotpService _totpService;
    private readonly SessionService _sessionService;

    public MfaController(TotpService totpService, SessionService sessionService)
    {
        _totpService = totpService;
        _sessionService = sessionService;
    }

    // Step 1: User initiates TOTP setup
    [HttpPost("mfa/totp/setup")]
    [Authorize]
    public async Task<IActionResult> SetupTotp()
    {
        var userId = User.FindFirstValue(ClaimTypes.NameIdentifier);
        var email = User.FindFirstValue(ClaimTypes.Email);
        try
        {
            var result = await _totpService.InitializeTotpAsync(userId!, email!);
            return Ok(new
            {
                secret = result.ManualEntryKey,
                qrCode = result.QrCodeUrl,
                message = "Scan the QR code with your authenticator app"
            });
        }
        catch
        {
            return StatusCode(500, new { error = "Failed to initialize TOTP" });
        }
    }

    // Step 2: User confirms TOTP setup with a valid code
    [HttpPost("mfa/totp/verify-setup")]
    [Authorize]
    public async Task<IActionResult> VerifyTotpSetup([FromBody] TokenRequest request)
    {
        var userId = User.FindFirstValue(ClaimTypes.NameIdentifier);
        if (userId == null || request.Token == null)
            return BadRequest(new { error = "Missing required fields" });
        try
        {
            await _totpService.VerifyTotpSetupAsync(userId, request.Token);
            return Ok(new { success = true, message = "TOTP enabled successfully" });
        }
        catch
        {
            return BadRequest(new { error = "Invalid code or setup expired" });
        }
    }

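    // Step 3: Verify TOTP during login
    [HttpPost("auth/verify-mfa")]
    public async Task<IActionResult> VerifyMfa([FromBody] VerifyMfaRequest request)
    {
        try
        {
            // Assumes VerifyMfaRequest carries MfaToken (issued after the password check) and Token,
            // and that SessionService exposes a ResolveMfaTokenAsync helper returning null when
            // the token is invalid or expired. A client-supplied user ID must not be trusted here.
            var userId = await _sessionService.ResolveMfaTokenAsync(request.MfaToken);
            if (userId == null)
                return Unauthorized(new { error = "Invalid or expired MFA token" });
            var valid = await _totpService.VerifyTotpLoginAsync(userId, request.Token);
            if (!valid)
                return Unauthorized(new { error = "Invalid TOTP code" });
            var session = await _sessionService.CreateSessionAsync(userId);
            return Ok(new { sessionToken = session.Token });
        }
        catch
        {
            return StatusCode(500, new { error = "MFA verification failed" });
        }
    }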
}

Backup and Recovery Codes

What happens when users lose their phone? Backup codes provide a recovery path without compromising security.

Generating and Storing Backup Codes

export async function generateBackupCodes(userId: string, count: number = 10) {
  const codes = Array.from({ length: count }).map(() => {
    // Generate 8-character hex codes (32 bits of entropy each)
    return crypto
      .randomBytes(6)
      .toString('hex')
      .substring(0, 8)
      .toUpperCase();
  });

  // Hash codes before storage (they should be one-way)
  const hashedCodes = codes.map(code => ({
    hash: crypto.createHash('sha256').update(code).digest('hex'),
    used: false,
    createdAt: new Date(),
    userId
  }));

  // Store hashed codes
  await prisma.backupCode.createMany({
    data: hashedCodes
  });

  // Return unhashed codes for display (only time user sees them)
  return codes;
}

// Verify and consume a backup code
export async function verifyBackupCode(userId: string, code: string): Promise<boolean> {
  // Normalize input so lowercase or padded entries still match the stored hash
  const normalized = code.trim().toUpperCase();
  const codeHash = crypto.createHash('sha256').update(normalized).digest('hex');

  // Mark the code as used in a single conditional update (one-time only).
  // A concurrent request reusing the same code matches no rows and gets count 0.
  const result = await prisma.backupCode.updateMany({
    where: {
      userId,
      hash: codeHash,
      used: false
    },
    data: { used: true, usedAt: new Date() }
  });

  if (result.count === 0) {
    return false;
  }

  // Log usage for audit trail
  await prisma.auditLog.create({
    data: {
      userId,
      action: 'BACKUP_CODE_USED',
      timestamp: new Date()
    }
  });

  return true;
}

// Check if user has remaining backup codes
export async function getBackupCodeCount(userId: string): Promise<number> {
  return prisma.backupCode.count({
    where: {
      userId,
      used: false
    }
  });
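}

// Spring Boot
import java.nio.charset.StandardCharsets;
import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;
import java.security.SecureRandom;
import java.time.Instant;
import java.util.ArrayList;
import java.util.HexFormat; // Java 17+
import java.util.List;
import java.util.Optional;
import org.springframework.stereotype.Service;
import org.springframework.transaction.annotation.Transactional;

@Service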
public class BackupCodeService {

    private final BackupCodeRepository backupCodeRepo;
    private final AuditLogRepository auditLogRepo;

    public BackupCodeService(BackupCodeRepository backupCodeRepo,
                             AuditLogRepository auditLogRepo) {
        this.backupCodeRepo = backupCodeRepo;
        this.auditLogRepo = auditLogRepo;
    }

    @Transactional
    public List<String> generateBackupCodes(String userId, int count) {
        List<String> codes = new ArrayList<>();
        List<BackupCode> entities = new ArrayList<>();

        for (int i = 0; i < count; i++) {
            // Generate 8-character hex code
            byte[] bytes = new byte[6];
            new SecureRandom().nextBytes(bytes);
            String code = HexFormat.of().formatHex(bytes).substring(0, 8).toUpperCase();
            codes.add(code);

            String hash = sha256Hex(code);
            BackupCode entity = new BackupCode();
            entity.setUserId(userId);
            entity.setHash(hash);
            entity.setUsed(false);
            entity.setCreatedAt(Instant.now());
            entities.add(entity);
        }

        backupCodeRepo.saveAll(entities);
        return codes; // Return plaintext codes for one-time display
    }

    @Transactional
    public boolean verifyBackupCode(String userId, String code) {
        String codeHash = sha256Hex(code);

        Optional<BackupCode> backupCode = backupCodeRepo
            .findFirstByUserIdAndHashAndUsedFalse(userId, codeHash);

        if (backupCode.isEmpty()) return false;

        BackupCode bc = backupCode.get();
        bc.setUsed(true);
        bc.setUsedAt(Instant.now());
        backupCodeRepo.save(bc);

        AuditLog log = new AuditLog();
        log.setUserId(userId);
        log.setAction("BACKUP_CODE_USED");
        log.setTimestamp(Instant.now());
        auditLogRepo.save(log);

        return true;
    }

    public long getBackupCodeCount(String userId) {
        return backupCodeRepo.countByUserIdAndUsedFalse(userId);
    }

    private String sha256Hex(String input) {
        try {
            MessageDigest digest = MessageDigest.getInstance("SHA-256");
            byte[] hash = digest.digest(input.getBytes(StandardCharsets.UTF_8));
            return HexFormat.of().formatHex(hash);
        } catch (NoSuchAlgorithmException e) {
            throw new RuntimeException(e);
        }
    }
}

# Python + SQLAlchemy
import secrets
import hashlib
from datetime import datetime, timezone
from sqlalchemy.orm import Session
from models import BackupCode, AuditLog

def generate_backup_codes(user_id: str, db: Session, count: int = 10) -> list[str]:
    codes = []
    entities = []

    for _ in range(count):
        # Generate 8-character hex code
        code = secrets.token_hex(6)[:8].upper()
        codes.append(code)

        code_hash = hashlib.sha256(code.encode()).hexdigest()
        entities.append(BackupCode(
            user_id=user_id,
            hash=code_hash,
            used=False,
            created_at=datetime.now(timezone.utc),
        ))

    db.add_all(entities)
    db.commit()
    return codes  # Return plaintext codes for one-time display


def verify_backup_code(user_id: str, code: str, db: Session) -> bool:
    code_hash = hashlib.sha256(code.encode()).hexdigest()

    backup_code = (
        db.query(BackupCode)
        .filter(
            BackupCode.user_id == user_id,
            BackupCode.hash == code_hash,
            BackupCode.used == False,
        )
        .first()
    )
    if not backup_code:
        return False

    backup_code.used = True
    backup_code.used_at = datetime.now(timezone.utc)

    log = AuditLog(
        user_id=user_id,
        action="BACKUP_CODE_USED",
        timestamp=datetime.now(timezone.utc),
    )
    db.add(log)
    db.commit()
    return True


def get_backup_code_count(user_id: str, db: Session) -> int:
    return db.query(BackupCode).filter(
        BackupCode.user_id == user_id,
        BackupCode.used == False,
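    ).count()

// ASP.NET Core + EF Core
using Microsoft.EntityFrameworkCore; // CountAsync, FirstOrDefaultAsync
using System.Security.Cryptography;
using System.Text;

public class BackupCodeService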
{
    private readonly AppDbContext _db;

    public BackupCodeService(AppDbContext db) => _db = db;

    public async Task<List<string>> GenerateBackupCodesAsync(string userId, int count = 10)
    {
        var codes = new List<string>();
        var entities = new List<BackupCode>();

        for (int i = 0; i < count; i++)
        {
            // Generate 8-character hex code
            var bytes = RandomNumberGenerator.GetBytes(6);
            var code = Convert.ToHexString(bytes)[..8].ToUpper();
            codes.Add(code);

            entities.Add(new BackupCode
            {
                UserId = userId,
                Hash = Sha256Hex(code),
                Used = false,
                CreatedAt = DateTime.UtcNow
            });
        }

        _db.BackupCodes.AddRange(entities);
        await _db.SaveChangesAsync();
        return codes; // Return plaintext codes for one-time display
    }

    public async Task<bool> VerifyBackupCodeAsync(string userId, string code)
    {
        // Normalize input so it matches the stored uppercase code
        var codeHash = Sha256Hex(code.Trim().ToUpperInvariant());

        var backupCode = await _db.BackupCodes
            .Where(bc => bc.UserId == userId && bc.Hash == codeHash && !bc.Used)
            .FirstOrDefaultAsync();

        if (backupCode == null) return false;

        backupCode.Used = true;
        backupCode.UsedAt = DateTime.UtcNow;

        _db.AuditLogs.Add(new AuditLog
        {
            UserId = userId,
            Action = "BACKUP_CODE_USED",
            Timestamp = DateTime.UtcNow
        });
        await _db.SaveChangesAsync();
        return true;
    }

    public async Task<int> GetBackupCodeCountAsync(string userId) =>
        await _db.BackupCodes.CountAsync(bc => bc.UserId == userId && !bc.Used);

    private static string Sha256Hex(string input)
    {
        var hash = SHA256.HashData(Encoding.UTF8.GetBytes(input));
        return Convert.ToHexString(hash).ToLower();
    }
}
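
The services above read a matching unused code and then mark it as used in a second step. Under concurrent requests, two logins could both pass the check before either write lands. One way to close that gap is to make the check and the write a single conditional `UPDATE`. The following is a minimal sketch, assuming a SQLite table named `backup_codes` with `user_id`, `hash`, and `used` columns; the table layout and the function name `consume_backup_code` are illustrative assumptions, not part of the original services.

```python
import hashlib
import sqlite3


def consume_backup_code(conn: sqlite3.Connection, user_id: str, code: str) -> bool:
    """Atomically mark a matching unused code as used.

    Returns True only for the request that actually flipped the row.
    The backup_codes table layout is an assumption made for this sketch.
    """
    # Normalize the same way codes are generated (uppercase hex, no padding spaces).
    normalized = code.strip().upper()
    code_hash = hashlib.sha256(normalized.encode()).hexdigest()

    # The "used = 0" condition makes check-and-mark a single statement,
    # so a second concurrent request updates zero rows.
    cur = conn.execute(
        "UPDATE backup_codes SET used = 1 "
        "WHERE user_id = ? AND hash = ? AND used = 0",
        (user_id, code_hash),
    )
    conn.commit()
    return cur.rowcount > 0
```

The same idea carries over to an ORM: issue an update filtered on `used = false` and treat an affected-row count of zero as a failed verification.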

Endpoint for displaying backup codes during setup

app.post('/api/mfa/backup-codes/generate', async (req: Request, res: Response) => {
  try {
    const userId = req.user?.id;
    if (!userId) {
      return res.status(401).json({ error: 'Unauthorized' });
    }

    const codes = await generateBackupCodes(userId);

    res.json({
      codes,
      message: 'Save these codes in a secure location. Each code can only be used once.'
    });
  } catch (error) {
    res.status(500).json({ error: 'Failed to generate backup codes' });
  }
});

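// Verify backup code during login
// NOTE: userId is read from the request body here for brevity. In production it should
// come from server-side state created after the password check succeeded, and this
// endpoint should be rate limited.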
app.post('/api/auth/verify-backup-code', async (req: Request, res: Response) => {
  try {
    const { userId, code } = req.body;

    const valid = await verifyBackupCode(userId, code);
    if (!valid) {
      return res.status(401).json({ error: 'Invalid or expired backup code' });
    }

    const session = await createSession(userId);
    res.json({ sessionToken: session.token });
  } catch (error) {
    res.status(500).json({ error: 'Backup code verification failed' });
  }
});
@RestController
@RequestMapping("/api")
public class BackupCodeController {

    private final BackupCodeService backupCodeService;
    private final SessionService sessionService;

    public BackupCodeController(BackupCodeService backupCodeService,
                                SessionService sessionService) {
        this.backupCodeService = backupCodeService;
        this.sessionService = sessionService;
    }

    @PostMapping("/mfa/backup-codes/generate")
    @PreAuthorize("isAuthenticated()") // Spring Security has no @Authorize annotation
    public ResponseEntity<?> generateBackupCodes(
            @AuthenticationPrincipal UserPrincipal principal) {
        try {
            List<String> codes = backupCodeService.generateBackupCodes(principal.getId(), 10);
            return ResponseEntity.ok(Map.of(
                "codes", codes,
                "message", "Save these codes in a secure location. Each code can only be used once."
            ));
        } catch (Exception e) {
            return ResponseEntity.status(500).body(Map.of("error", "Failed to generate backup codes"));
        }
    }

    @PostMapping("/auth/verify-backup-code")
    public ResponseEntity<?> verifyBackupCode(@RequestBody Map<String, String> body) {
        try {
            String userId = body.get("userId");
            String code = body.get("code");
            boolean valid = backupCodeService.verifyBackupCode(userId, code);
            if (!valid) {
                return ResponseEntity.status(401).body(Map.of("error", "Invalid or expired backup code"));
            }
            var session = sessionService.createSession(userId);
            return ResponseEntity.ok(Map.of("sessionToken", session.getToken()));
        } catch (Exception e) {
            return ResponseEntity.status(500).body(Map.of("error", "Backup code verification failed"));
        }
    }
}
from fastapi import APIRouter, Depends, HTTPException
from sqlalchemy.orm import Session
from pydantic import BaseModel
from database import get_db
from auth import get_current_user

router = APIRouter()

class VerifyBackupCodeBody(BaseModel):
    user_id: str
    code: str

@router.post("/mfa/backup-codes/generate")
async def generate_backup_codes_endpoint(
    current_user=Depends(get_current_user),
    db: Session = Depends(get_db),
):
    codes = generate_backup_codes(current_user.id, db)
    return {
        "codes": codes,
        "message": "Save these codes in a secure location. Each code can only be used once.",
    }

@router.post("/auth/verify-backup-code")
async def verify_backup_code_endpoint(
    body: VerifyBackupCodeBody,
    db: Session = Depends(get_db),
):
    valid = verify_backup_code(body.user_id, body.code, db)
    if not valid:
        raise HTTPException(status_code=401, detail="Invalid or expired backup code")
    session = create_session(body.user_id)
    return {"session_token": session.token}
[ApiController]
[Route("api")]
public class BackupCodeController : ControllerBase
{
    private readonly BackupCodeService _backupCodeService;
    private readonly SessionService _sessionService;

    public BackupCodeController(BackupCodeService backupCodeService,
                                SessionService sessionService)
    {
        _backupCodeService = backupCodeService;
        _sessionService = sessionService;
    }

    [HttpPost("mfa/backup-codes/generate")]
    [Authorize]
    public async Task<IActionResult> GenerateBackupCodes()
    {
        var userId = User.FindFirstValue(ClaimTypes.NameIdentifier)!;
        var codes = await _backupCodeService.GenerateBackupCodesAsync(userId);
        return Ok(new
        {
            codes,
            message = "Save these codes in a secure location. Each code can only be used once."
        });
    }

    [HttpPost("auth/verify-backup-code")]
    public async Task<IActionResult> VerifyBackupCode([FromBody] VerifyBackupCodeRequest request)
    {
        var valid = await _backupCodeService.VerifyBackupCodeAsync(request.UserId, request.Code);
        if (!valid)
            return Unauthorized(new { error = "Invalid or expired backup code" });
        var session = await _sessionService.CreateSessionAsync(request.UserId);
        return Ok(new { sessionToken = session.Token });
    }
}
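
The verify endpoints above take `userId` straight from the request body, which means a caller who has never passed the password step could still attempt backup codes for any account. A common remedy is to hand the client a short-lived signed token once the first factor succeeds and to derive the user from that token during the second step. The sketch below uses only the standard library; `SECRET_KEY`, `issue_pending_mfa_token`, and `read_pending_mfa_token` are hypothetical names, and a real deployment would more likely store this state in a server-side session.

```python
import base64
import hashlib
import hmac
import time
from typing import Optional

# Assumption: in a real service this comes from configuration, not source code.
SECRET_KEY = b"replace-with-a-random-server-side-secret"


def issue_pending_mfa_token(user_id: str, ttl_seconds: int = 300,
                            now: Optional[float] = None) -> str:
    """Sign 'user_id|expiry' so the client cannot change either field."""
    current = time.time() if now is None else now
    payload = f"{user_id}|{int(current + ttl_seconds)}".encode()
    sig = hmac.new(SECRET_KEY, payload, hashlib.sha256).digest()
    return ".".join(base64.urlsafe_b64encode(p).decode() for p in (payload, sig))


def read_pending_mfa_token(token: str, now: Optional[float] = None) -> Optional[str]:
    """Return the user id if the signature is valid and the token is unexpired."""
    try:
        payload_b64, sig_b64 = token.split(".")
        payload = base64.urlsafe_b64decode(payload_b64)
        sig = base64.urlsafe_b64decode(sig_b64)
        user_id, expires = payload.decode().rsplit("|", 1)
        expires_at = int(expires)
    except (ValueError, UnicodeDecodeError):
        return None  # malformed token

    expected = hmac.new(SECRET_KEY, payload, hashlib.sha256).digest()
    if not hmac.compare_digest(sig, expected):
        return None  # payload was altered or signed with another key

    current = time.time() if now is None else now
    return user_id if current < expires_at else None
```

With this in place, the verify endpoint would accept the token and the code, call `read_pending_mfa_token`, and refuse the request when it returns `None`.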

SMS-Based MFA (Legacy but Necessary)

Although SMS is vulnerable to SIM swapping and interception, it remains widely supported. Use it only as a fallback, not as the primary MFA method.
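
To make the "fallback only" rule concrete, the login flow can rank a user's enrolled methods and offer SMS last. This is a minimal sketch; the method identifiers and the function `plan_mfa_challenge` are assumptions made for illustration, not names used elsewhere in this article.

```python
# Strongest first; SMS is deliberately last so it is only offered as a fallback.
METHOD_PREFERENCE = ["webauthn", "totp", "sms"]


def plan_mfa_challenge(enrolled: set) -> dict:
    """Pick the primary method and fallbacks from a user's enrolled methods."""
    ordered = [m for m in METHOD_PREFERENCE if m in enrolled]
    return {
        "primary": ordered[0] if ordered else None,
        "fallbacks": ordered[1:],
        # SMS as the only factor is a signal to nudge the user toward TOTP or a key.
        "suggest_stronger_method": ordered == ["sms"],
    }
```

A user enrolled in both TOTP and SMS is challenged with TOTP first and sees SMS only as an alternative, while an SMS-only user is flagged so the application can prompt them to enroll a stronger method.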

Implementation with Rate Limiting

import crypto from 'crypto';
import twilio from 'twilio';
import Bottleneck from 'bottleneck';

const twilioClient = twilio(process.env.TWILIO_ACCOUNT_SID, process.env.TWILIO_AUTH_TOKEN);

// Rate limiter: max 3 SMS per phone number per hour (one limiter per key)
const smsLimiters = new Bottleneck.Group({
  maxConcurrent: 1,
  minTime: 1000,
  reservoir: 3,
  reservoirRefreshAmount: 3,
  reservoirRefreshInterval: 60 * 60 * 1000,
  // Keep idle limiters longer than the window so the reservoir is not reset early
  timeout: 2 * 60 * 60 * 1000
});

export async function sendSMSCode(userId: string, phoneNumber: string) {
  const limiter = smsLimiters.key(phoneNumber);

  // Bottleneck queues jobs when the reservoir is empty, so reject explicitly instead
  const remaining = await limiter.currentReservoir();
  if (remaining !== null && remaining <= 0) {
    throw new Error('Too many SMS requests. Please try again later.');
  }

  return limiter.schedule(async () => {
    // randomInt's upper bound is exclusive, so use 1000000 to include 999999
    const code = crypto.randomInt(100000, 1000000).toString();

    // Store code with expiration (5 minutes)
    await prisma.smsChallenge.create({
      data: {
        userId,
        code,
        phoneNumber: maskPhoneNumber(phoneNumber),
        expiresAt: new Date(Date.now() + 5 * 60 * 1000),
        attempts: 0
      }
    });

    // Send SMS
    await twilioClient.messages.create({
      body: `Your authentication code is: ${code}. Valid for 5 minutes.`,
      from: process.env.TWILIO_PHONE_NUMBER,
      to: phoneNumber
    });

    return { success: true };
  });
}

export async function verifySMSCode(userId: string, code: string): Promise<boolean> {
  const challenge = await prisma.smsChallenge.findFirst({
    where: {
      userId,
      verifiedAt: null, // a code that was already accepted cannot be replayed
      expiresAt: { gt: new Date() },
      attempts: { lt: 5 } // Max 5 attempts
    },
    orderBy: { createdAt: 'desc' }
  });

  if (!challenge) {
    return false;
  }

  // Increment attempts atomically
  await prisma.smsChallenge.update({
    where: { id: challenge.id },
    data: { attempts: { increment: 1 } }
  });

  // Verify code with a timing-safe comparison.
  // timingSafeEqual throws on a length mismatch, so compare lengths first.
  const expected = Buffer.from(challenge.code);
  const received = Buffer.from(String(code));
  const isValid =
    expected.length === received.length && crypto.timingSafeEqual(expected, received);

  if (isValid) {
    // Mark as used
    await prisma.smsChallenge.update({
      where: { id: challenge.id },
      data: { verifiedAt: new Date() }
    });
  }

  return isValid;
}

// Exported because the SMS endpoint also uses it to build its response message
export function maskPhoneNumber(phone: string): string {
  return phone.replace(/\d(?=\d{4})/g, '*');
}
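// Spring Boot + Twilio SDK + Bucket4j rate limiting
import com.twilio.Twilio;
import com.twilio.rest.api.v2010.account.Message;
import com.twilio.type.PhoneNumber;
import io.github.bucket4j.Bandwidth;
import io.github.bucket4j.Bucket;
import org.springframework.stereotype.Service;
import org.springframework.transaction.annotation.Transactional;
import java.nio.charset.StandardCharsets;
import java.security.MessageDigest;
import java.security.SecureRandom;
import java.time.Duration;
import java.time.Instant;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;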

@Service
public class SmsService {

    private final SmsChallengeRepository smsChallengeRepo;
    private final Map<String, Bucket> rateLimitBuckets = new ConcurrentHashMap<>();

    public SmsService(SmsChallengeRepository smsChallengeRepo) {
        this.smsChallengeRepo = smsChallengeRepo;
        Twilio.init(System.getenv("TWILIO_ACCOUNT_SID"), System.getenv("TWILIO_AUTH_TOKEN"));
    }

    // Rate limiter: max 3 SMS per userId per hour
    private Bucket getBucketForUser(String userId) {
        return rateLimitBuckets.computeIfAbsent(userId, k ->
            Bucket.builder()
                .addLimit(Bandwidth.simple(3, Duration.ofHours(1)))
                .build()
        );
    }

    @Transactional
    public void sendSmsCode(String userId, String phoneNumber) {
        Bucket bucket = getBucketForUser(userId);
        if (!bucket.tryConsume(1)) {
            throw new RuntimeException("Too many SMS requests. Please try again later.");
        }

        // 6-digit code
        String code = String.format("%06d", new SecureRandom().nextInt(900000) + 100000);

        SmsChallenge challenge = new SmsChallenge();
        challenge.setUserId(userId);
        challenge.setCode(code);
        challenge.setPhoneNumber(maskPhoneNumber(phoneNumber));
        challenge.setExpiresAt(Instant.now().plusSeconds(5 * 60));
        challenge.setAttempts(0);
        smsChallengeRepo.save(challenge);

        Message.creator(
            new PhoneNumber(phoneNumber),
            new PhoneNumber(System.getenv("TWILIO_PHONE_NUMBER")),
            "Your authentication code is: " + code + ". Valid for 5 minutes."
        ).create();
    }

    @Transactional
    public boolean verifySmsCode(String userId, String code) {
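        // Latest unexpired, not yet verified challenge with attempts remaining
        SmsChallenge challenge = smsChallengeRepo
            .findFirstByUserIdAndVerifiedAtIsNullAndExpiresAtAfterAndAttemptsLessThanOrderByCreatedAtDesc(
                userId, Instant.now(), 5)
            .orElse(null);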

        if (challenge == null) return false;

        challenge.setAttempts(challenge.getAttempts() + 1);
        smsChallengeRepo.save(challenge);

        // Timing-safe comparison
        boolean isValid = MessageDigest.isEqual(
            challenge.getCode().getBytes(StandardCharsets.UTF_8),
            code.getBytes(StandardCharsets.UTF_8)
        );

        if (isValid) {
            challenge.setVerifiedAt(Instant.now());
            smsChallengeRepo.save(challenge);
        }
        return isValid;
    }

    // Public because SmsController calls it when building the response message
    public String maskPhoneNumber(String phone) {
        return phone.replaceAll("\\d(?=\\d{4})", "*");
    }
}
# FastAPI + Twilio + simple in-memory rate limiting
from twilio.rest import Client as TwilioClient
from datetime import datetime, timedelta, timezone
import os
import secrets
import hmac
from sqlalchemy.orm import Session
from models import SmsChallenge

twilio_client = TwilioClient(
    os.environ["TWILIO_ACCOUNT_SID"],
    os.environ["TWILIO_AUTH_TOKEN"],
)

# Simple in-memory rate limiter (use Redis in production)
_sms_rate_limits: dict[str, list[datetime]] = {}

def _check_rate_limit(user_id: str, max_count: int = 3, window_hours: int = 1) -> bool:
    now = datetime.now(timezone.utc)
    cutoff = now - timedelta(hours=window_hours)
    times = [t for t in _sms_rate_limits.get(user_id, []) if t > cutoff]
    if len(times) >= max_count:
        return False
    times.append(now)
    _sms_rate_limits[user_id] = times
    return True


def send_sms_code(user_id: str, phone_number: str, db: Session) -> None:
    if not _check_rate_limit(user_id):
        raise ValueError("Too many SMS requests. Please try again later.")

    code = str(secrets.randbelow(900000) + 100000)  # 6-digit code

    challenge = SmsChallenge(
        user_id=user_id,
        code=code,
        phone_number=mask_phone_number(phone_number),
        expires_at=datetime.now(timezone.utc) + timedelta(minutes=5),
        attempts=0,
    )
    db.add(challenge)
    db.commit()

    twilio_client.messages.create(
        body=f"Your authentication code is: {code}. Valid for 5 minutes.",
        from_=os.environ["TWILIO_PHONE_NUMBER"],
        to=phone_number,
    )


def verify_sms_code(user_id: str, code: str, db: Session) -> bool:
    challenge = (
        db.query(SmsChallenge)
        .filter(
            SmsChallenge.user_id == user_id,
            SmsChallenge.verified_at.is_(None),  # reject codes that were already accepted
            SmsChallenge.expires_at > datetime.now(timezone.utc),
            SmsChallenge.attempts < 5,
        )
        .order_by(SmsChallenge.created_at.desc())
        .first()
    )
    if not challenge:
        return False

    challenge.attempts += 1

    # Timing-safe comparison
    is_valid = hmac.compare_digest(challenge.code.encode(), code.encode())

    if is_valid:
        challenge.verified_at = datetime.now(timezone.utc)
    db.commit()
    return is_valid


def mask_phone_number(phone: str) -> str:
    import re
    return re.sub(r'\d(?=\d{4})', '*', phone)
// ASP.NET Core + Twilio + in-process rate limiter
using Twilio;
using Twilio.Rest.Api.V2010.Account;
using Twilio.Types;
using System.Text;
using System.Security.Cryptography;
using System.Collections.Concurrent;
using Microsoft.EntityFrameworkCore;

public class SmsService
{
    private readonly AppDbContext _db;
    // Static so the limits survive across per-request (scoped) service instances
    private static readonly ConcurrentDictionary<string, List<DateTime>> _rateLimits = new();

    public SmsService(AppDbContext db)
    {
        _db = db;
        TwilioClient.Init(
            Environment.GetEnvironmentVariable("TWILIO_ACCOUNT_SID"),
            Environment.GetEnvironmentVariable("TWILIO_AUTH_TOKEN")
        );
    }

    private bool CheckRateLimit(string userId, int maxCount = 3, int windowHours = 1)
    {
        var now = DateTime.UtcNow;
        var cutoff = now.AddHours(-windowHours);
        var times = _rateLimits.GetOrAdd(userId, _ => new List<DateTime>());
        lock (times)
        {
            times.RemoveAll(t => t < cutoff);
            if (times.Count >= maxCount) return false;
            times.Add(now);
        }
        return true;
    }

    public async Task SendSmsCodeAsync(string userId, string phoneNumber)
    {
        if (!CheckRateLimit(userId))
            throw new InvalidOperationException("Too many SMS requests. Please try again later.");

        // 6-digit code; GetInt32's upper bound is exclusive, so use 1000000 to include 999999
        var code = RandomNumberGenerator.GetInt32(100000, 1000000).ToString();

        _db.SmsChallenges.Add(new SmsChallenge
        {
            UserId = userId,
            Code = code,
            PhoneNumber = MaskPhoneNumber(phoneNumber),
            ExpiresAt = DateTime.UtcNow.AddMinutes(5),
            Attempts = 0
        });
        await _db.SaveChangesAsync();

        await MessageResource.CreateAsync(
            body: $"Your authentication code is: {code}. Valid for 5 minutes.",
            from: new PhoneNumber(Environment.GetEnvironmentVariable("TWILIO_PHONE_NUMBER")),
            to: new PhoneNumber(phoneNumber)
        );
    }

    public async Task<bool> VerifySmsCodeAsync(string userId, string code)
    {
        var challenge = await _db.SmsChallenges
            .Where(c => c.UserId == userId
                     && c.ExpiresAt > DateTime.UtcNow
                     && c.Attempts < 5)
            .OrderByDescending(c => c.CreatedAt)
            .FirstOrDefaultAsync();

        if (challenge == null) return false;

        challenge.Attempts++;

        // Timing-safe comparison
        bool isValid = CryptographicOperations.FixedTimeEquals(
            Encoding.UTF8.GetBytes(challenge.Code),
            Encoding.UTF8.GetBytes(code)
        );

        if (isValid) challenge.VerifiedAt = DateTime.UtcNow;
        await _db.SaveChangesAsync();
        return isValid;
    }

    // Public because SmsController calls it when building the response message
    public static string MaskPhoneNumber(string phone) =>
        System.Text.RegularExpressions.Regex.Replace(phone, @"\d(?=\d{4})", "*");
}
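
The services above store the SMS code in plaintext and compare it with a timing-safe function. A possible hardening step, sketched here under stated assumptions, is to store only a keyed digest of the code so that a database leak does not reveal live codes. `PEPPER`, `new_sms_code`, `digest_code`, and `codes_match` are hypothetical names introduced for this example, and the zero-padded code range is a design variation rather than what the services above do.

```python
import hashlib
import hmac
import secrets

# Assumption: a server-side secret loaded from configuration.
PEPPER = b"replace-with-a-random-server-side-secret"


def new_sms_code() -> str:
    """Return a uniformly random six-digit code, zero-padded (000000-999999)."""
    return f"{secrets.randbelow(10**6):06d}"


def digest_code(code: str) -> str:
    """Keyed SHA-256 digest of the code; this is what would be stored."""
    return hmac.new(PEPPER, code.encode(), hashlib.sha256).hexdigest()


def codes_match(stored_digest: str, submitted: str) -> bool:
    """Timing-safe comparison between the stored digest and a submitted code."""
    return hmac.compare_digest(stored_digest, digest_code(submitted.strip()))
```

Because both sides of the comparison are fixed-length hex digests, a submitted code of the wrong length simply fails the comparison instead of raising an error.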

SMS MFA Endpoint

app.post('/api/mfa/sms/send', async (req: Request, res: Response) => {
  try {
    const userId = req.user?.id;
    if (!userId) {
      return res.status(401).json({ error: 'Unauthorized' });
    }

    const user = await prisma.user.findUnique({ where: { id: userId } });
    if (!user?.phoneNumber) {
      return res.status(400).json({ error: 'Phone number not configured' });
    }

    await sendSMSCode(userId, user.phoneNumber);

    res.json({
      success: true,
      message: `Code sent to ${maskPhoneNumber(user.phoneNumber)}`
    });
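  } catch (error) {
    // Only rate-limit failures map to 429; anything else is a server error
    if (error instanceof Error && error.message.startsWith('Too many SMS requests')) {
      return res.status(429).json({ error: 'Too many requests' });
    }
    res.status(500).json({ error: 'Failed to send SMS code' });
  }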
});

app.post('/api/auth/verify-sms', async (req: Request, res: Response) => {
  try {
    const { userId, code } = req.body;

    const valid = await verifySMSCode(userId, code);
    if (!valid) {
      return res.status(401).json({ error: 'Invalid or expired code' });
    }

    const session = await createSession(userId);
    res.json({ sessionToken: session.token });
  } catch (error) {
    res.status(500).json({ error: 'SMS verification failed' });
  }
});
@RestController
@RequestMapping("/api")
public class SmsController {

    private final SmsService smsService;
    private final UserRepository userRepo;
    private final SessionService sessionService;

    public SmsController(SmsService smsService, UserRepository userRepo,
                         SessionService sessionService) {
        this.smsService = smsService;
        this.userRepo = userRepo;
        this.sessionService = sessionService;
    }

    @PostMapping("/mfa/sms/send")
    @PreAuthorize("isAuthenticated()") // Spring Security has no @Authorize annotation
    public ResponseEntity<?> sendSms(@AuthenticationPrincipal UserPrincipal principal) {
        try {
            User user = userRepo.findById(principal.getId())
                .orElseThrow(() -> new RuntimeException("User not found"));
            if (user.getPhoneNumber() == null) {
                return ResponseEntity.status(400).body(Map.of("error", "Phone number not configured"));
            }
            smsService.sendSmsCode(principal.getId(), user.getPhoneNumber());
            return ResponseEntity.ok(Map.of(
                "success", true,
                "message", "Code sent to " + smsService.maskPhoneNumber(user.getPhoneNumber())
            ));
        } catch (RuntimeException e) {
            return ResponseEntity.status(429).body(Map.of("error", "Too many requests"));
        }
    }

    @PostMapping("/auth/verify-sms")
    public ResponseEntity<?> verifySms(@RequestBody Map<String, String> body) {
        try {
            boolean valid = smsService.verifySmsCode(body.get("userId"), body.get("code"));
            if (!valid) {
                return ResponseEntity.status(401).body(Map.of("error", "Invalid or expired code"));
            }
            var session = sessionService.createSession(body.get("userId"));
            return ResponseEntity.ok(Map.of("sessionToken", session.getToken()));
        } catch (Exception e) {
            return ResponseEntity.status(500).body(Map.of("error", "SMS verification failed"));
        }
    }
}
from fastapi import APIRouter, Depends, HTTPException
from sqlalchemy.orm import Session
from pydantic import BaseModel
from database import get_db
from auth import get_current_user

router = APIRouter()

class VerifySmsBody(BaseModel):
    user_id: str
    code: str

@router.post("/mfa/sms/send")
async def send_sms_endpoint(
    current_user=Depends(get_current_user),
    db: Session = Depends(get_db),
):
    user = db.query(User).filter(User.id == current_user.id).first()
    if not user or not user.phone_number:
        raise HTTPException(status_code=400, detail="Phone number not configured")
    try:
        send_sms_code(current_user.id, user.phone_number, db)
        return {
            "success": True,
            "message": f"Code sent to {mask_phone_number(user.phone_number)}",
        }
    except ValueError:
        raise HTTPException(status_code=429, detail="Too many requests")

@router.post("/auth/verify-sms")
async def verify_sms_endpoint(body: VerifySmsBody, db: Session = Depends(get_db)):
    valid = verify_sms_code(body.user_id, body.code, db)
    if not valid:
        raise HTTPException(status_code=401, detail="Invalid or expired code")
    session = create_session(body.user_id)
    return {"session_token": session.token}
[ApiController]
[Route("api")]
public class SmsController : ControllerBase
{
    private readonly SmsService _smsService;
    private readonly AppDbContext _db;
    private readonly SessionService _sessionService;

    public SmsController(SmsService smsService, AppDbContext db, SessionService sessionService)
    {
        _smsService = smsService;
        _db = db;
        _sessionService = sessionService;
    }

    [HttpPost("mfa/sms/send")]
    [Authorize]
    public async Task<IActionResult> SendSms()
    {
        var userId = User.FindFirstValue(ClaimTypes.NameIdentifier)!;
        var user = await _db.Users.FindAsync(userId);
        if (user?.PhoneNumber == null)
            return BadRequest(new { error = "Phone number not configured" });
        try
        {
            await _smsService.SendSmsCodeAsync(userId, user.PhoneNumber);
            return Ok(new { success = true, message = $"Code sent to {SmsService.MaskPhoneNumber(user.PhoneNumber)}" });
        }
        catch (InvalidOperationException)
        {
            return StatusCode(429, new { error = "Too many requests" });
        }
    }

    [HttpPost("auth/verify-sms")]
    public async Task<IActionResult> VerifySms([FromBody] VerifySmsRequest request)
    {
        var valid = await _smsService.VerifySmsCodeAsync(request.UserId, request.Code);
        if (!valid)
            return Unauthorized(new { error = "Invalid or expired code" });
        var session = await _sessionService.CreateSessionAsync(request.UserId);
        return Ok(new { sessionToken = session.Token });
    }
}

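Why SMS is weak: SIM swapping (an attacker convinces the carrier to move your number to a SIM they control), network interception, and delay attacks all put SMS codes at risk. Standards bodies such as NIST now discourage relying on SMS as the only second factor; NIST SP 800-63B treats SMS (PSTN) delivery as a restricted authenticator.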

Hardware Security Keys (FIDO2/WebAuthn)

Hardware keys are the strongest form of MFA. They use public-key cryptography and are resistant to phishing.
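
Phishing resistance comes largely from origin binding: the browser writes the page's origin and the server's challenge into `clientDataJSON`, and the authenticator's signature covers a hash of that data. A response produced on a look-alike domain therefore carries the wrong origin and is rejected. The sketch below shows only that comparison; the function name `client_data_matches` is an assumption, and a real server should rely on a WebAuthn library, which also verifies the signature, the RP ID hash, and the authenticator flags.

```python
import json


def client_data_matches(client_data_json: bytes, expected_challenge: str,
                        expected_origin: str,
                        expected_type: str = "webauthn.create") -> bool:
    """Check the type, challenge, and origin fields of a decoded clientDataJSON.

    expected_challenge is the base64url string the server issued.
    """
    try:
        data = json.loads(client_data_json)
    except ValueError:
        return False  # not valid JSON
    if not isinstance(data, dict):
        return False
    return (
        data.get("type") == expected_type
        and data.get("challenge") == expected_challenge
        and data.get("origin") == expected_origin
    )
```

A response captured on `https://yourdomain.com.evil.example` fails this check even if the user completed the ceremony there, because the browser, not the page, sets the origin field.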

WebAuthn Registration Implementation

import { generateRegistrationOptions, verifyRegistrationResponse } from '@simplewebauthn/server';
import { isoBase64URL } from '@simplewebauthn/server/helpers';

// Step 1: Begin WebAuthn registration
export async function beginWebAuthnRegistration(userId: string, userName: string) {
  // Newer SimpleWebAuthn releases return a Promise here; await is harmless on older ones
  const options = await generateRegistrationOptions({
    rpID: process.env.WEBAUTHN_RP_ID || 'yourdomain.com',
    rpName: 'Your Application',
    userID: userId,
    userName,
    attestationType: 'direct',
    supportedAlgorithmIDs: [-7, -257] // ES256, RS256
  });

  // Store challenge temporarily, exactly as issued (already base64url), so it can be
  // passed unchanged to verifyRegistrationResponse as expectedChallenge
  await prisma.webAuthnChallenge.create({
    data: {
      userId,
      challenge: options.challenge,
      expiresAt: new Date(Date.now() + 10 * 60 * 1000) // 10 minute window
    }
  });

  return options;
}

// Step 2: Verify WebAuthn registration
export async function completeWebAuthnRegistration(
  userId: string,
  credential: RegistrationResponseJSON
) {
  const challenge = await prisma.webAuthnChallenge.findFirst({
    where: {
      userId,
      expiresAt: { gt: new Date() }
    },
    orderBy: { createdAt: 'desc' }
  });

  if (!challenge) {
    throw new Error('Challenge expired or not found');
  }

  try {
    const verification = await verifyRegistrationResponse({
      response: credential,
      expectedChallenge: challenge.challenge,
      expectedOrigin: process.env.WEBAUTHN_ORIGIN || 'https://yourdomain.com',
      expectedRPID: process.env.WEBAUTHN_RP_ID || 'yourdomain.com'
    });

    if (!verification.verified) {
      throw new Error('WebAuthn registration verification failed');
    }

    // Store credential
    const credentialIdBuffer = isoBase64URL.toBuffer(credential.id);

    await prisma.webAuthnCredential.create({
      data: {
        userId,
        credentialId: credentialIdBuffer.toString('base64'),
        publicKey: Buffer.from(verification.registrationInfo!.credentialPublicKey).toString('base64'),
        counter: verification.registrationInfo!.counter,
        transports: credential.response.transports || [],
        aaguid: verification.registrationInfo!.aaguid || '',
        deviceName: 'Security Key'
      }
    });

    // Mark challenge as used
    await prisma.webAuthnChallenge.update({
      where: { id: challenge.id },
      data: { usedAt: new Date() }
    });

    return { success: true };
  } catch (error) {
    throw error;
  }
}
// Spring Boot + webauthn4j
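import com.webauthn4j.WebAuthnManager;
import com.webauthn4j.data.*;
import com.webauthn4j.data.attestation.statement.COSEAlgorithmIdentifier;
import com.webauthn4j.data.client.Origin;
import com.webauthn4j.data.client.challenge.DefaultChallenge;
import com.webauthn4j.server.ServerProperty;
import org.springframework.stereotype.Service;
import org.springframework.transaction.annotation.Transactional;
import java.nio.charset.StandardCharsets;
import java.security.SecureRandom;
import java.time.Instant;
import java.util.Base64;
import java.util.List;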

@Service
public class WebAuthnService {

    private final WebAuthnChallengeRepository challengeRepo;
    private final WebAuthnCredentialRepository credentialRepo;
    private final WebAuthnManager webAuthnManager = WebAuthnManager.createNonStrictWebAuthnManager();

    public WebAuthnService(WebAuthnChallengeRepository challengeRepo,
                           WebAuthnCredentialRepository credentialRepo) {
        this.challengeRepo = challengeRepo;
        this.credentialRepo = credentialRepo;
    }

    // Step 1: Begin WebAuthn registration
    public PublicKeyCredentialCreationOptions beginRegistration(String userId, String userName) {
        byte[] challengeBytes = new byte[32];
        new SecureRandom().nextBytes(challengeBytes);
        String challengeB64 = Base64.getEncoder().encodeToString(challengeBytes);

        // Store challenge
        WebAuthnChallenge challenge = new WebAuthnChallenge();
        challenge.setUserId(userId);
        challenge.setChallenge(challengeB64);
        challenge.setExpiresAt(Instant.now().plusSeconds(10 * 60));
        challengeRepo.save(challenge);

        // Build creation options
        PublicKeyCredentialRpEntity rp = new PublicKeyCredentialRpEntity(
            System.getenv("WEBAUTHN_RP_ID"), "Your Application");
        PublicKeyCredentialUserEntity user = new PublicKeyCredentialUserEntity(
            userId.getBytes(StandardCharsets.UTF_8), userName, userName);

        return new PublicKeyCredentialCreationOptions(
            rp, user, new DefaultChallenge(challengeBytes),
            List.of(
                new PublicKeyCredentialParameters(PublicKeyCredentialType.PUBLIC_KEY, COSEAlgorithmIdentifier.ES256),
                new PublicKeyCredentialParameters(PublicKeyCredentialType.PUBLIC_KEY, COSEAlgorithmIdentifier.RS256)
            )
        );
    }

    // Step 2: Verify WebAuthn registration
    @Transactional
    public void completeRegistration(String userId, RegistrationRequest registrationRequest) {
        WebAuthnChallenge storedChallenge = challengeRepo
            .findFirstByUserIdAndExpiresAtAfterOrderByCreatedAtDesc(userId, Instant.now())
            .orElseThrow(() -> new RuntimeException("Challenge expired or not found"));

        byte[] challengeBytes = Base64.getDecoder().decode(storedChallenge.getChallenge());
        Origin origin = new Origin(System.getenv("WEBAUTHN_ORIGIN"));
        String rpId = System.getenv("WEBAUTHN_RP_ID");

        RegistrationData data = webAuthnManager.parse(registrationRequest);
        RegistrationParameters params = new RegistrationParameters(
            new ServerProperty(origin, rpId, new DefaultChallenge(challengeBytes), null),
            null, false, true
        );
        webAuthnManager.validate(data, params);

        // Store credential
        WebAuthnCredential credential = new WebAuthnCredential();
        credential.setUserId(userId);
        credential.setCredentialId(Base64.getEncoder().encodeToString(
            data.getAttestationObject().getAuthenticatorData().getAttestedCredentialData().getCredentialId()));
        credential.setPublicKey(Base64.getEncoder().encodeToString(
            data.getAttestationObject().getAuthenticatorData().getAttestedCredentialData()
                .getCOSEKey().getBytes()));
        credential.setCounter(data.getAttestationObject().getAuthenticatorData().getSignCount());
        credential.setDeviceName("Security Key");
        credentialRepo.save(credential);

        storedChallenge.setUsedAt(Instant.now());
        challengeRepo.save(storedChallenge);
    }
}
# FastAPI + py_webauthn
from webauthn import generate_registration_options, verify_registration_response
from webauthn.helpers.structs import (
    RegistrationCredential, PublicKeyCredentialDescriptor,
)
from webauthn.helpers.cose import COSEAlgorithmIdentifier
import os, base64, secrets
from datetime import datetime, timedelta, timezone
from sqlalchemy.orm import Session
from models import WebAuthnChallenge, WebAuthnCredential

RP_ID = os.environ.get("WEBAUTHN_RP_ID", "yourdomain.com")
RP_NAME = "Your Application"
ORIGIN = os.environ.get("WEBAUTHN_ORIGIN", "https://yourdomain.com")


def begin_webauthn_registration(user_id: str, user_name: str, db: Session):
    options = generate_registration_options(
        rp_id=RP_ID,
        rp_name=RP_NAME,
        user_id=user_id,
        user_name=user_name,
        supported_pub_key_algs=[
            COSEAlgorithmIdentifier.ECDSA_SHA_256,
            COSEAlgorithmIdentifier.RSASSA_PKCS1_v1_5_SHA_256,
        ],
    )
    challenge_b64 = base64.b64encode(options.challenge).decode()
    db.add(WebAuthnChallenge(
        user_id=user_id,
        challenge=challenge_b64,
        expires_at=datetime.now(timezone.utc) + timedelta(minutes=10),
    ))
    db.commit()
    return options


def complete_webauthn_registration(
    user_id: str, credential: RegistrationCredential, db: Session
):
    stored = (
        db.query(WebAuthnChallenge)
        .filter(
            WebAuthnChallenge.user_id == user_id,
            WebAuthnChallenge.expires_at > datetime.now(timezone.utc),
        )
        .order_by(WebAuthnChallenge.created_at.desc())
        .first()
    )
    if not stored:
        raise ValueError("Challenge expired or not found")

    # py_webauthn signals failure by raising (InvalidRegistrationResponse),
    # not through a "verified" flag on the result
    try:
        verification = verify_registration_response(
            credential=credential,
            expected_challenge=base64.b64decode(stored.challenge),
            expected_rp_id=RP_ID,
            expected_origin=ORIGIN,
        )
    except Exception as exc:
        raise ValueError("WebAuthn registration verification failed") from exc

    db.add(WebAuthnCredential(
        user_id=user_id,
        credential_id=base64.b64encode(verification.credential_id).decode(),
        public_key=base64.b64encode(verification.credential_public_key).decode(),
        counter=verification.sign_count,
        device_name="Security Key",
    ))
    stored.used_at = datetime.now(timezone.utc)
    db.commit()
    return {"success": True}
// ASP.NET Core + Fido2NetLib
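using System.Text;                    // Encoding
using Fido2NetLib;
using Fido2NetLib.Objects;
using Microsoft.EntityFrameworkCore;  // FirstOrDefaultAsync, AnyAsync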

public class WebAuthnService
{
    private readonly AppDbContext _db;
    private readonly IFido2 _fido2;

    public WebAuthnService(AppDbContext db, IFido2 fido2)
    {
        _db = db;
        _fido2 = fido2;
    }

    // Step 1: Begin WebAuthn registration
    public async Task<CredentialCreateOptions> BeginRegistrationAsync(string userId, string userName)
    {
        var user = new Fido2User
        {
            Id = Encoding.UTF8.GetBytes(userId),
            Name = userName,
            DisplayName = userName
        };

        var options = _fido2.RequestNewCredential(
            user,
            new List<PublicKeyCredentialDescriptor>(),
            AuthenticatorSelection.Default,
            AttestationConveyancePreference.Direct
        );

        _db.WebAuthnChallenges.Add(new WebAuthnChallenge
        {
            UserId = userId,
            Challenge = Convert.ToBase64String(options.Challenge),
            ExpiresAt = DateTime.UtcNow.AddMinutes(10)
        });
        await _db.SaveChangesAsync();

        return options;
    }

    // Step 2: Verify WebAuthn registration
    public async Task CompleteRegistrationAsync(string userId, AuthenticatorAttestationRawResponse attestationResponse)
    {
        var stored = await _db.WebAuthnChallenges
            .Where(c => c.UserId == userId && c.UsedAt == null && c.ExpiresAt > DateTime.UtcNow)
            .OrderByDescending(c => c.CreatedAt)
            .FirstOrDefaultAsync()
            ?? throw new InvalidOperationException("Challenge expired or not found");

        var options = CredentialCreateOptions.FromJson(
            $"{{\"challenge\":\"{stored.Challenge}\"}}");

        IsCredentialIdUniqueToUserAsyncDelegate callback = async (args, ct) =>
            !await _db.WebAuthnCredentials.AnyAsync(c => c.CredentialId == Convert.ToBase64String(args.CredentialId));

        var result = await _fido2.MakeNewCredentialAsync(attestationResponse, options, callback);

        _db.WebAuthnCredentials.Add(new WebAuthnCredential
        {
            UserId = userId,
            CredentialId = Convert.ToBase64String(result.Result!.CredentialId),
            PublicKey = Convert.ToBase64String(result.Result.PublicKey),
            Counter = result.Result.Counter,
            DeviceName = "Security Key"
        });
        stored.UsedAt = DateTime.UtcNow;
        await _db.SaveChangesAsync();
    }
}

WebAuthn Authentication

import { generateAuthenticationOptions, verifyAuthenticationResponse } from '@simplewebauthn/server';
// The type's export location depends on the library version; older releases ship it in '@simplewebauthn/types'
import type { AuthenticationResponseJSON } from '@simplewebauthn/server';

// Step 1: Begin WebAuthn authentication
export async function beginWebAuthnAuthentication(userId: string) {
  const credentials = await prisma.webAuthnCredential.findMany({
    where: { userId }
  });

  if (credentials.length === 0) {
    throw new Error('No security keys registered');
  }

  // Recent library versions return a Promise here; awaiting is harmless on older ones
  const options = await generateAuthenticationOptions({
    rpID: process.env.WEBAUTHN_RP_ID || 'yourdomain.com',
    allowCredentials: credentials.map(cred => ({
      id: cred.credentialId,
      type: 'public-key',
      transports: cred.transports
    }))
  });

  // Store the challenge exactly as issued (base64url) so it matches the value
  // that verifyAuthenticationResponse compares against in step 2
  await prisma.webAuthnChallenge.create({
    data: {
      userId,
      challenge: options.challenge,
      expiresAt: new Date(Date.now() + 10 * 60 * 1000)
    }
  });

  return options;
}

// Step 2: Verify WebAuthn authentication
export async function completeWebAuthnAuthentication(
  userId: string,
  credential: AuthenticationResponseJSON
) {
  const challenge = await prisma.webAuthnChallenge.findFirst({
    where: {
      userId,
      usedAt: null,
      expiresAt: { gt: new Date() }
    },
    orderBy: { createdAt: 'desc' }
  });

  if (!challenge) {
    throw new Error('Challenge expired');
  }

  const webAuthnCredential = await prisma.webAuthnCredential.findFirst({
    where: {
      userId,
      credentialId: credential.id
    }
  });

  if (!webAuthnCredential) {
    throw new Error('Credential not found');
  }

  try {
    const verification = await verifyAuthenticationResponse({
      response: credential,
      expectedChallenge: challenge.challenge,
      expectedOrigin: process.env.WEBAUTHN_ORIGIN || 'https://yourdomain.com',
      expectedRPID: process.env.WEBAUTHN_RP_ID || 'yourdomain.com',
      credential: {
        id: webAuthnCredential.credentialId,
        publicKey: Buffer.from(webAuthnCredential.publicKey, 'base64'),
        counter: webAuthnCredential.counter,
        transports: webAuthnCredential.transports
      }
    });

    if (!verification.verified) {
      throw new Error('Authentication failed');
    }

    // Persist the new signature counter so a cloned key can be detected on later logins
    await prisma.webAuthnCredential.update({
      where: { id: webAuthnCredential.id },
      data: { counter: verification.authenticationInfo.newCounter }
    });

    // Mark challenge as used
    await prisma.webAuthnChallenge.update({
      where: { id: challenge.id },
      data: { usedAt: new Date() }
    });

    return { success: true };
  } catch (error) {
    throw error;
  }
}
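
The counter update above is what lets a relying party notice a cloned authenticator: the WebAuthn specification expects the signature counter to increase with each assertion, so a value that fails to advance is a warning sign. The following is a minimal sketch of that comparison, not a replacement for the library's own verification. `checkSignCount` and the `CounterCheck` labels are hypothetical names introduced here for illustration.

```typescript
// Hypothetical helper: classifies a received signature counter against the stored one.
type CounterCheck = 'ok' | 'unsupported' | 'possible_clone';

function checkSignCount(storedCount: number, receivedCount: number): CounterCheck {
  // Authenticators that do not implement a counter report 0 every time.
  if (storedCount === 0 && receivedCount === 0) {
    return 'unsupported';
  }
  // Otherwise the counter must strictly increase; equal or lower values
  // suggest that two copies of the same credential are in use.
  return receivedCount > storedCount ? 'ok' : 'possible_clone';
}
```

How to react to `possible_clone` is a policy choice; options include rejecting the login or flagging the credential for review.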
// Spring Boot + webauthn4j: authentication service
@Service
public class WebAuthnAuthService {

    private final WebAuthnChallengeRepository challengeRepo;
    private final WebAuthnCredentialRepository credentialRepo;
    private final WebAuthnManager webAuthnManager = WebAuthnManager.createNonStrictWebAuthnManager();

    public WebAuthnAuthService(WebAuthnChallengeRepository challengeRepo,
                               WebAuthnCredentialRepository credentialRepo) {
        this.challengeRepo = challengeRepo;
        this.credentialRepo = credentialRepo;
    }

    // Step 1: Begin WebAuthn authentication
    public PublicKeyCredentialRequestOptions beginAuthentication(String userId) {
        List<WebAuthnCredential> credentials = credentialRepo.findByUserId(userId);
        if (credentials.isEmpty()) throw new RuntimeException("No security keys registered");

        byte[] challengeBytes = new byte[32];
        new SecureRandom().nextBytes(challengeBytes);
        String challengeB64 = Base64.getEncoder().encodeToString(challengeBytes);

        WebAuthnChallenge challenge = new WebAuthnChallenge();
        challenge.setUserId(userId);
        challenge.setChallenge(challengeB64);
        challenge.setExpiresAt(Instant.now().plusSeconds(10 * 60));
        challengeRepo.save(challenge);

        List<PublicKeyCredentialDescriptor> allowCredentials = credentials.stream()
            .map(c -> new PublicKeyCredentialDescriptor(
                PublicKeyCredentialType.PUBLIC_KEY,
                Base64.getDecoder().decode(c.getCredentialId()),
                null))
            .toList();

        return new PublicKeyCredentialRequestOptions(
            new DefaultChallenge(challengeBytes),
            60000L,
            System.getenv("WEBAUTHN_RP_ID"),
            allowCredentials,
            UserVerificationRequirement.PREFERRED,
            null
        );
    }

    // Step 2: Verify WebAuthn authentication
    @Transactional
    public void completeAuthentication(String userId, AuthenticationRequest authRequest) {
        WebAuthnChallenge stored = challengeRepo
            .findFirstByUserIdAndUsedAtNullAndExpiresAtAfterOrderByCreatedAtDesc(userId, Instant.now())
            .orElseThrow(() -> new RuntimeException("Challenge expired"));

        String credentialIdB64 = Base64.getEncoder().encodeToString(authRequest.getCredentialId());
        WebAuthnCredential webAuthnCredential = credentialRepo
            .findByUserIdAndCredentialId(userId, credentialIdB64)
            .orElseThrow(() -> new RuntimeException("Credential not found"));

        byte[] challengeBytes = Base64.getDecoder().decode(stored.getChallenge());
        Origin origin = new Origin(System.getenv("WEBAUTHN_ORIGIN"));
        String rpId = System.getenv("WEBAUTHN_RP_ID");

        AuthenticatorData authenticatorData = new AuthenticatorData(
            authRequest.getAuthenticatorData());
        ServerProperty serverProperty = new ServerProperty(
            origin, rpId, new DefaultChallenge(challengeBytes), null);

        AuthenticationParameters params = new AuthenticationParameters(
            serverProperty,
            new AttestedCredentialData(
                null,
                authRequest.getCredentialId(),
                COSEKeyUtil.parse(Base64.getDecoder().decode(webAuthnCredential.getPublicKey()))
            ),
            webAuthnCredential.getCounter(),
            true, true
        );

        webAuthnManager.validate(new AuthenticationData(
            authRequest.getCredentialId(),
            authRequest.getUserHandle(),
            authenticatorData,
            authRequest.getClientDataJSON(),
            authRequest.getSignature()
        ), params);

        // Persist the new signature counter so a cloned key can be detected on later logins
        webAuthnCredential.setCounter(authenticatorData.getSignCount());
        credentialRepo.save(webAuthnCredential);

        stored.setUsedAt(Instant.now());
        challengeRepo.save(stored);
    }
}
def begin_webauthn_authentication(user_id: str, db: Session):
    from webauthn import generate_authentication_options
    from webauthn.helpers.structs import PublicKeyCredentialDescriptor

    credentials = db.query(WebAuthnCredential).filter(
        WebAuthnCredential.user_id == user_id
    ).all()
    if not credentials:
        raise ValueError("No security keys registered")

    allow_credentials = [
        PublicKeyCredentialDescriptor(id=base64.b64decode(c.credential_id))
        for c in credentials
    ]
    options = generate_authentication_options(
        rp_id=RP_ID,
        allow_credentials=allow_credentials,
    )
    challenge_b64 = base64.b64encode(options.challenge).decode()
    db.add(WebAuthnChallenge(
        user_id=user_id,
        challenge=challenge_b64,
        expires_at=datetime.now(timezone.utc) + timedelta(minutes=10),
    ))
    db.commit()
    return options


def complete_webauthn_authentication(
    user_id: str, credential, db: Session
):
    from webauthn import verify_authentication_response

    stored = (
        db.query(WebAuthnChallenge)
        .filter(
            WebAuthnChallenge.user_id == user_id,
            WebAuthnChallenge.used_at.is_(None),
            WebAuthnChallenge.expires_at > datetime.now(timezone.utc),
        )
        .order_by(WebAuthnChallenge.created_at.desc())
        .first()
    )
    if not stored:
        raise ValueError("Challenge expired")

    credential_id_b64 = base64.b64encode(credential.raw_id).decode()
    webauthn_cred = db.query(WebAuthnCredential).filter(
        WebAuthnCredential.user_id == user_id,
        WebAuthnCredential.credential_id == credential_id_b64,
    ).first()
    if not webauthn_cred:
        raise ValueError("Credential not found")

    verification = verify_authentication_response(
        credential=credential,
        expected_challenge=base64.b64decode(stored.challenge),
        expected_rp_id=RP_ID,
        expected_origin=ORIGIN,
        credential_public_key=base64.b64decode(webauthn_cred.public_key),
        credential_current_sign_count=webauthn_cred.counter,
    )
    if not verification.verified:
        raise ValueError("Authentication failed")

    # Persist the new signature counter so a cloned key can be detected on later logins
    webauthn_cred.counter = verification.new_sign_count
    stored.used_at = datetime.now(timezone.utc)
    db.commit()
    return {"success": True}
// Continued in WebAuthnService (Fido2NetLib)
public async Task<AssertionOptions> BeginAuthenticationAsync(string userId)
{
    var credentials = await _db.WebAuthnCredentials
        .Where(c => c.UserId == userId)
        .ToListAsync();

    if (!credentials.Any())
        throw new InvalidOperationException("No security keys registered");

    var allowedCredentials = credentials
        .Select(c => new PublicKeyCredentialDescriptor(Convert.FromBase64String(c.CredentialId)))
        .ToList();

    var options = _fido2.GetAssertionOptions(allowedCredentials, UserVerificationRequirement.Preferred);

    _db.WebAuthnChallenges.Add(new WebAuthnChallenge
    {
        UserId = userId,
        Challenge = Convert.ToBase64String(options.Challenge),
        ExpiresAt = DateTime.UtcNow.AddMinutes(10)
    });
    await _db.SaveChangesAsync();
    return options;
}

public async Task CompleteAuthenticationAsync(string userId, AuthenticatorAssertionRawResponse assertionResponse)
{
    var stored = await _db.WebAuthnChallenges
        .Where(c => c.UserId == userId && c.UsedAt == null && c.ExpiresAt > DateTime.UtcNow)
        .OrderByDescending(c => c.CreatedAt)
        .FirstOrDefaultAsync()
        ?? throw new InvalidOperationException("Challenge expired");

    var credentialIdB64 = Convert.ToBase64String(assertionResponse.Id);
    var webAuthnCred = await _db.WebAuthnCredentials
        .Where(c => c.UserId == userId && c.CredentialId == credentialIdB64)
        .FirstOrDefaultAsync()
        ?? throw new InvalidOperationException("Credential not found");

    var options = AssertionOptions.FromJson($"{{\"challenge\":\"{stored.Challenge}\"}}");

    IsUserHandleOwnerOfCredentialIdAsync callback = async (args, ct) =>
        await _db.WebAuthnCredentials.AnyAsync(
            c => c.CredentialId == Convert.ToBase64String(args.CredentialId)
              && c.UserId == userId);

    var result = await _fido2.MakeAssertionAsync(
        assertionResponse, options,
        Convert.FromBase64String(webAuthnCred.PublicKey),
        webAuthnCred.Counter, callback);

    // Persist the new signature counter so a cloned key can be detected on later logins
    webAuthnCred.Counter = result.Counter;
    stored.UsedAt = DateTime.UtcNow;
    await _db.SaveChangesAsync();
}

Express.js Endpoints for WebAuthn

app.post('/api/mfa/webauthn/register/begin', async (req: Request, res: Response) => {
  try {
    const userId = req.user?.id;
    if (!userId) {
      return res.status(401).json({ error: 'Unauthorized' });
    }

    const options = await beginWebAuthnRegistration(userId, req.user.email);
    res.json(options);
  } catch (error) {
    res.status(500).json({ error: 'Failed to initiate registration' });
  }
});

app.post('/api/mfa/webauthn/register/complete', async (req: Request, res: Response) => {
  try {
    const userId = req.user?.id;
    const { credential } = req.body;

    if (!userId || !credential) {
      return res.status(400).json({ error: 'Missing required fields' });
    }

    await completeWebAuthnRegistration(userId, credential);
    res.json({ success: true, message: 'Security key registered successfully' });
  } catch (error) {
    res.status(400).json({ error: 'Registration failed' });
  }
});

app.post('/api/auth/webauthn/begin', async (req: Request, res: Response) => {
  try {
    const { userId } = req.body;
    const options = await beginWebAuthnAuthentication(userId);
    res.json(options);
  } catch (error) {
    res.status(400).json({ error: 'Authentication initiation failed' });
  }
});

app.post('/api/auth/webauthn/complete', async (req: Request, res: Response) => {
  try {
    const { userId, credential } = req.body;
    await completeWebAuthnAuthentication(userId, credential);

    const session = await createSession(userId);
    res.json({ sessionToken: session.token });
  } catch (error) {
    res.status(401).json({ error: 'Authentication failed' });
  }
});
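
The authentication endpoints above pass `req.body` fields straight to the service layer. One option is to check the body's shape first, as in this rough sketch. `WebAuthnCompleteBody` and `isWebAuthnCompleteBody` are hypothetical names, and the guard only checks the top-level fields of the browser's assertion JSON (`id`, `rawId`, `type`, `response`); it does not verify anything cryptographically.

```typescript
// Hypothetical request shape for POST /api/auth/webauthn/complete.
interface WebAuthnCompleteBody {
  userId: string;
  credential: {
    id: string;
    rawId: string;
    type: 'public-key';
    response: Record<string, unknown>;
  };
}

// Type guard: returns true only when the untrusted body has the expected top-level shape.
function isWebAuthnCompleteBody(body: unknown): body is WebAuthnCompleteBody {
  if (typeof body !== 'object' || body === null) return false;
  const record = body as Record<string, unknown>;

  const userId = record['userId'];
  if (typeof userId !== 'string' || userId.length === 0) return false;

  const credential = record['credential'];
  if (typeof credential !== 'object' || credential === null) return false;
  const c = credential as Record<string, unknown>;

  return (
    typeof c['id'] === 'string' &&
    typeof c['rawId'] === 'string' &&
    c['type'] === 'public-key' &&
    typeof c['response'] === 'object' &&
    c['response'] !== null
  );
}
```

A handler could call the guard before touching the database and answer with a 400 status when it returns `false`.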
@RestController
@RequestMapping("/api")
public class WebAuthnController {

    private final WebAuthnService webAuthnService;
    private final WebAuthnAuthService webAuthnAuthService;
    private final SessionService sessionService;

    public WebAuthnController(WebAuthnService webAuthnService,
                              WebAuthnAuthService webAuthnAuthService,
                              SessionService sessionService) {
        this.webAuthnService = webAuthnService;
        this.webAuthnAuthService = webAuthnAuthService;
        this.sessionService = sessionService;
    }

    @PostMapping("/mfa/webauthn/register/begin")
    @PreAuthorize("isAuthenticated()") // Spring Security annotation; @Authorize belongs to ASP.NET
    public ResponseEntity<?> beginRegistration(@AuthenticationPrincipal UserPrincipal principal) {
        try {
            var options = webAuthnService.beginRegistration(principal.getId(), principal.getEmail());
            return ResponseEntity.ok(options);
        } catch (Exception e) {
            return ResponseEntity.status(500).body(Map.of("error", "Failed to initiate registration"));
        }
    }

    @PostMapping("/mfa/webauthn/register/complete")
    @PreAuthorize("isAuthenticated()") // Spring Security annotation; @Authorize belongs to ASP.NET
    public ResponseEntity<?> completeRegistration(
            @AuthenticationPrincipal UserPrincipal principal,
            @RequestBody RegistrationRequest registrationRequest) {
        try {
            webAuthnService.completeRegistration(principal.getId(), registrationRequest);
            return ResponseEntity.ok(Map.of("success", true, "message", "Security key registered successfully"));
        } catch (Exception e) {
            return ResponseEntity.status(400).body(Map.of("error", "Registration failed"));
        }
    }

    @PostMapping("/auth/webauthn/begin")
    public ResponseEntity<?> beginAuthentication(@RequestBody Map<String, String> body) {
        try {
            var options = webAuthnAuthService.beginAuthentication(body.get("userId"));
            return ResponseEntity.ok(options);
        } catch (Exception e) {
            return ResponseEntity.status(400).body(Map.of("error", "Authentication initiation failed"));
        }
    }

    @PostMapping("/auth/webauthn/complete")
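    public ResponseEntity<?> completeAuthentication(@RequestBody WebAuthnCompleteRequest request) {
        try {
            // WebAuthnCompleteRequest is an assumed DTO (userId + credential) bound from the JSON body;
            // casting a raw Map entry to AuthenticationRequest would throw ClassCastException
            String userId = request.getUserId();
            webAuthnAuthService.completeAuthentication(userId, request.getCredential());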
            var session = sessionService.createSession(userId);
            return ResponseEntity.ok(Map.of("sessionToken", session.getToken()));
        } catch (Exception e) {
            return ResponseEntity.status(401).body(Map.of("error", "Authentication failed"));
        }
    }
}
from fastapi import APIRouter, Depends, HTTPException
from sqlalchemy.orm import Session
from database import get_db
from auth import get_current_user

router = APIRouter()

@router.post("/mfa/webauthn/register/begin")
async def webauthn_register_begin(
    current_user=Depends(get_current_user),
    db: Session = Depends(get_db),
):
    try:
        options = begin_webauthn_registration(current_user.id, current_user.email, db)
        return options
    except Exception:
        raise HTTPException(status_code=500, detail="Failed to initiate registration")

@router.post("/mfa/webauthn/register/complete")
async def webauthn_register_complete(
    credential: dict,  # annotated so FastAPI reads it from the JSON body, not the query string
    current_user=Depends(get_current_user),
    db: Session = Depends(get_db),
):
    try:
        complete_webauthn_registration(current_user.id, credential, db)
        return {"success": True, "message": "Security key registered successfully"}
    except Exception:
        raise HTTPException(status_code=400, detail="Registration failed")

@router.post("/auth/webauthn/begin")
async def webauthn_auth_begin(body: dict, db: Session = Depends(get_db)):
    try:
        options = begin_webauthn_authentication(body["user_id"], db)
        return options
    except Exception:
        raise HTTPException(status_code=400, detail="Authentication initiation failed")

@router.post("/auth/webauthn/complete")
async def webauthn_auth_complete(body: dict, db: Session = Depends(get_db)):
    try:
        complete_webauthn_authentication(body["user_id"], body["credential"], db)
        session = create_session(body["user_id"])
        return {"session_token": session.token}
    except Exception:
        raise HTTPException(status_code=401, detail="Authentication failed")
[ApiController]
[Route("api")]
public class WebAuthnController : ControllerBase
{
    private readonly WebAuthnService _webAuthnService;
    private readonly SessionService _sessionService;

    public WebAuthnController(WebAuthnService webAuthnService, SessionService sessionService)
    {
        _webAuthnService = webAuthnService;
        _sessionService = sessionService;
    }

    [HttpPost("mfa/webauthn/register/begin")]
    [Authorize]
    public async Task<IActionResult> BeginRegistration()
    {
        var userId = User.FindFirstValue(ClaimTypes.NameIdentifier)!;
        var email = User.FindFirstValue(ClaimTypes.Email)!;
        var options = await _webAuthnService.BeginRegistrationAsync(userId, email);
        return Ok(options);
    }

    [HttpPost("mfa/webauthn/register/complete")]
    [Authorize]
    public async Task<IActionResult> CompleteRegistration(
        [FromBody] AuthenticatorAttestationRawResponse attestationResponse)
    {
        var userId = User.FindFirstValue(ClaimTypes.NameIdentifier)!;
        try
        {
            await _webAuthnService.CompleteRegistrationAsync(userId, attestationResponse);
            return Ok(new { success = true, message = "Security key registered successfully" });
        }
        catch
        {
            return BadRequest(new { error = "Registration failed" });
        }
    }

    [HttpPost("auth/webauthn/begin")]
    public async Task<IActionResult> BeginAuthentication([FromBody] BeginAuthRequest request)
    {
        try
        {
            var options = await _webAuthnService.BeginAuthenticationAsync(request.UserId);
            return Ok(options);
        }
        catch
        {
            return BadRequest(new { error = "Authentication initiation failed" });
        }
    }

    [HttpPost("auth/webauthn/complete")]
    public async Task<IActionResult> CompleteAuthentication(
        [FromBody] WebAuthnCompleteRequest request)
    {
        try
        {
            await _webAuthnService.CompleteAuthenticationAsync(request.UserId, request.Credential);
            var session = await _sessionService.CreateSessionAsync(request.UserId);
            return Ok(new { sessionToken = session.Token });
        }
        catch
        {
            return Unauthorized(new { error = "Authentication failed" });
        }
    }
}

Adaptive/Risk-Based MFA

Not every login needs the same level of verification. Adaptive MFA adds friction only when it is needed.
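
As a rough illustration of adding friction only when needed, a login handler could map a 0 to 100 risk score to an action. The sketch below is an assumption-heavy example rather than a prescribed policy: `decideLoginAction`, `RiskPolicy`, and the `deny` tier with its 70 cut-off are invented for illustration, while the 30 cut-off mirrors the default threshold used in this article's code.

```typescript
// Hypothetical policy types; only the default MFA threshold of 30 comes from the article.
type LoginAction = 'allow' | 'require_mfa' | 'deny';

interface RiskPolicy {
  mfaThreshold: number;  // scores above this trigger a second factor
  denyThreshold: number; // scores above this reject the attempt (assumed tier)
}

const DEFAULT_POLICY: RiskPolicy = { mfaThreshold: 30, denyThreshold: 70 };

function decideLoginAction(
  riskScore: number,
  isPrivileged: boolean,
  policy: RiskPolicy = DEFAULT_POLICY
): LoginAction {
  if (riskScore > policy.denyThreshold) return 'deny';
  // Administrator and other high-value accounts always get the second factor.
  if (isPrivileged || riskScore > policy.mfaThreshold) return 'require_mfa';
  return 'allow';
}
```

Keeping the thresholds in a policy object makes it straightforward to tune them per tenant or per account type without touching the decision logic.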

Risk Scoring System

interface RiskFactors {
  newDeviceId: boolean;
  unusualLocation: boolean;
  unusualTime: boolean;
  travelImpossible: boolean;
  ipReputation: number;
  failedAttempts: number;
}

async function calculateRiskScore(userId: string, context: {
  ipAddress: string;
  userAgent: string;
  timestamp: Date;
}): Promise<{ score: number; factors: RiskFactors }> {
  const user = await prisma.user.findUnique({ where: { id: userId } });
  if (!user) throw new Error('User not found');

  let score = 0;
  const factors: RiskFactors = {
    newDeviceId: false,
    unusualLocation: false,
    unusualTime: false,
    travelImpossible: false,
    ipReputation: 0,
    failedAttempts: 0
  };

  // 1. Check device recognition
  const existingDevice = await prisma.device.findFirst({
    where: {
      userId,
      userAgent: context.userAgent
    }
  });

  if (!existingDevice) {
    score += 25;
    factors.newDeviceId = true;
  } else {
    // Update last seen
    await prisma.device.update({
      where: { id: existingDevice.id },
      data: { lastSeenAt: context.timestamp }
    });
  }

  // 2. Check location anomaly (using IP geolocation)
  const geoLocation = await getIPGeolocation(context.ipAddress);
  const lastLogin = await prisma.loginEvent.findFirst({
    where: { userId },
    orderBy: { timestamp: 'desc' },
    take: 1
  });

  if (lastLogin) {
    const distance = calculateDistance(
      { lat: lastLogin.latitude, lon: lastLogin.longitude },
      { lat: geoLocation.latitude, lon: geoLocation.longitude }
    );
    const timeDiff = (context.timestamp.getTime() - lastLogin.timestamp.getTime()) / 1000 / 3600; // hours

    // Impossible travel: an implied speed above ~900 km/h (roughly airliner cruising speed) is suspicious
    const maxPossibleDistance = timeDiff * 900; // km reachable in timeDiff hours
    if (distance > maxPossibleDistance) {
      score += 40;
      factors.travelImpossible = true;
    } else if (distance > 100 && timeDiff < 6) {
      score += 20;
      factors.unusualLocation = true;
    }
  }

  // 3. Check time-based anomaly
  const hour = context.timestamp.getHours();
  const lastLogins = await prisma.loginEvent.findMany({
    where: { userId },
    orderBy: { timestamp: 'desc' },
    take: 10
  });

  const typicalHours = lastLogins
    .map(l => new Date(l.timestamp).getHours())
    .filter(h => h >= 6 && h <= 23);

  if (typicalHours.length > 5 && !typicalHours.includes(hour)) {
    score += 15;
    factors.unusualTime = true;
  }

  // 4. Check IP reputation (integration with threat intelligence)
  const ipRep = await checkIPReputation(context.ipAddress);
  if (ipRep.riskScore > 50) {
    score += Math.min(30, ipRep.riskScore / 2);
    factors.ipReputation = ipRep.riskScore;
  }

  // 5. Check recent failed attempts
  const recentFailures = await prisma.failedLogin.count({
    where: {
      userId,
      timestamp: {
        gt: new Date(Date.now() - 15 * 60 * 1000) // Last 15 minutes
      }
    }
  });

  if (recentFailures > 0) {
    score += Math.min(30, recentFailures * 10);
    factors.failedAttempts = recentFailures;
  }

  return { score: Math.min(100, score), factors };
}

async function shouldRequireMFA(userId: string, riskScore: number): Promise<boolean> {
  const user = await prisma.user.findUnique({
    where: { id: userId },
    include: { mfaSettings: true }
  });

  if (!user?.mfaSettings) {
    return false;
  }

  // User's MFA threshold (e.g., "require MFA when risk > 30");
  // `??` keeps an explicit threshold of 0 instead of silently replacing it with 30
  return riskScore > (user.mfaSettings.riskThreshold ?? 30);
}
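
The scoring function above calls `calculateDistance` without showing it. One common way to implement it is the haversine great-circle formula, sketched below under the assumption that coordinates are decimal degrees and distances are in kilometres. `isImpossibleTravel` is a hypothetical helper that restates the 900 km/h check and additionally treats a non-positive elapsed time as suspicious whenever the location changed.

```typescript
interface Coordinates {
  lat: number; // decimal degrees
  lon: number; // decimal degrees
}

const EARTH_RADIUS_KM = 6371; // mean Earth radius

// Great-circle distance between two points using the haversine formula.
function calculateDistance(a: Coordinates, b: Coordinates): number {
  const toRad = (deg: number): number => (deg * Math.PI) / 180;
  const dLat = toRad(b.lat - a.lat);
  const dLon = toRad(b.lon - a.lon);
  const h =
    Math.sin(dLat / 2) ** 2 +
    Math.cos(toRad(a.lat)) * Math.cos(toRad(b.lat)) * Math.sin(dLon / 2) ** 2;
  // Clamp to 1 to guard against floating-point overshoot before asin.
  return 2 * EARTH_RADIUS_KM * Math.asin(Math.min(1, Math.sqrt(h)));
}

// Hypothetical helper: true when the distance could not be covered in the elapsed time.
function isImpossibleTravel(
  distanceKm: number,
  elapsedHours: number,
  maxSpeedKmh: number = 900
): boolean {
  // Zero or negative elapsed time (clock skew, out-of-order events) allows no movement at all.
  if (elapsedHours <= 0) return distanceKm > 0;
  return distanceKm > elapsedHours * maxSpeedKmh;
}
```

IP geolocation is often only accurate to the city level, so small distances are better treated as noise than as evidence of travel.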
@Service
public class RiskScoringService {

    private final UserRepository userRepo;
    private final DeviceRepository deviceRepo;
    private final LoginEventRepository loginEventRepo;
    private final FailedLoginRepository failedLoginRepo;
    private final IpGeolocationService geoService;
    private final IpReputationService ipRepService;

    public RiskScoringService(UserRepository userRepo, DeviceRepository deviceRepo,
                              LoginEventRepository loginEventRepo,
                              FailedLoginRepository failedLoginRepo,
                              IpGeolocationService geoService,
                              IpReputationService ipRepService) {
        this.userRepo = userRepo;
        this.deviceRepo = deviceRepo;
        this.loginEventRepo = loginEventRepo;
        this.failedLoginRepo = failedLoginRepo;
        this.geoService = geoService;
        this.ipRepService = ipRepService;
    }

    @Transactional
    public RiskResult calculateRiskScore(String userId, String ipAddress, String userAgent, Instant timestamp) {
        User user = userRepo.findById(userId)
            .orElseThrow(() -> new RuntimeException("User not found"));

        int score = 0;
        RiskFactors factors = new RiskFactors();

        // 1. Check device recognition
        Optional<Device> existingDevice = deviceRepo.findByUserIdAndUserAgent(userId, userAgent);
        if (existingDevice.isEmpty()) {
            score += 25;
            factors.setNewDeviceId(true);
        } else {
            existingDevice.get().setLastSeenAt(timestamp);
            deviceRepo.save(existingDevice.get());
        }

        // 2. Check location anomaly
        GeoLocation geoLocation = geoService.lookup(ipAddress);
        Optional<LoginEvent> lastLogin = loginEventRepo
            .findFirstByUserIdOrderByTimestampDesc(userId);

        if (lastLogin.isPresent()) {
            double distance = calculateDistance(
                lastLogin.get().getLatitude(), lastLogin.get().getLongitude(),
                geoLocation.getLatitude(), geoLocation.getLongitude());
            double timeDiffHours = Duration.between(lastLogin.get().getTimestamp(), timestamp)
                .toMinutes() / 60.0;
            double maxPossibleDistance = timeDiffHours * 900;

            if (distance > maxPossibleDistance) {
                score += 40;
                factors.setTravelImpossible(true);
            } else if (distance > 100 && timeDiffHours < 6) {
                score += 20;
                factors.setUnusualLocation(true);
            }
        }

        // 3. Check time-based anomaly
        int hour = LocalTime.ofInstant(timestamp, ZoneOffset.UTC).getHour();
        List<Integer> typicalHours = loginEventRepo.findTop10ByUserIdOrderByTimestampDesc(userId)
            .stream()
            .map(e -> LocalTime.ofInstant(e.getTimestamp(), ZoneOffset.UTC).getHour())
            .filter(h -> h >= 6 && h <= 23)
            .toList();
        if (typicalHours.size() > 5 && !typicalHours.contains(hour)) {
            score += 15;
            factors.setUnusualTime(true);
        }

        // 4. Check IP reputation
        IpReputation ipRep = ipRepService.check(ipAddress);
        if (ipRep.getRiskScore() > 50) {
            score += Math.min(30, ipRep.getRiskScore() / 2);
            factors.setIpReputation(ipRep.getRiskScore());
        }

        // 5. Check recent failed attempts
        long recentFailures = failedLoginRepo.countByUserIdAndTimestampAfter(
            userId, Instant.now().minusSeconds(15 * 60));
        if (recentFailures > 0) {
            score += Math.min(30, (int)(recentFailures * 10));
            factors.setFailedAttempts((int) recentFailures);
        }

        return new RiskResult(Math.min(100, score), factors);
    }

    public boolean shouldRequireMFA(String userId, int riskScore) {
        return userRepo.findById(userId)
            .map(u -> u.getMfaSettings() != null
                && riskScore > (u.getMfaSettings().getRiskThreshold() != null
                    ? u.getMfaSettings().getRiskThreshold() : 30))
            .orElse(false);
    }
}
from dataclasses import dataclass, field
from datetime import datetime, timedelta, timezone
from sqlalchemy.orm import Session
from models import User, Device, LoginEvent, FailedLogin

@dataclass
class RiskFactors:
    new_device_id: bool = False
    unusual_location: bool = False
    unusual_time: bool = False
    travel_impossible: bool = False
    ip_reputation: int = 0
    failed_attempts: int = 0

@dataclass
class RiskResult:
    score: int
    factors: RiskFactors


def calculate_risk_score(
    user_id: str, ip_address: str, user_agent: str, timestamp: datetime, db: Session
) -> RiskResult:
    user = db.query(User).filter(User.id == user_id).first()
    if not user:
        raise ValueError("User not found")

    score = 0
    factors = RiskFactors()

    # 1. Check device recognition
    existing_device = db.query(Device).filter(
        Device.user_id == user_id, Device.user_agent == user_agent
    ).first()
    if not existing_device:
        score += 25
        factors.new_device_id = True
    else:
        existing_device.last_seen_at = timestamp
        db.commit()

    # 2. Check location anomaly
    geo = get_ip_geolocation(ip_address)
    last_login = (
        db.query(LoginEvent)
        .filter(LoginEvent.user_id == user_id)
        .order_by(LoginEvent.timestamp.desc())
        .first()
    )
    if last_login:
        distance = calculate_distance(
            (last_login.latitude, last_login.longitude),
            (geo["latitude"], geo["longitude"]),
        )
        time_diff_hours = (timestamp - last_login.timestamp).total_seconds() / 3600
        max_possible = time_diff_hours * 900
        if distance > max_possible:
            score += 40
            factors.travel_impossible = True
        elif distance > 100 and time_diff_hours < 6:
            score += 20
            factors.unusual_location = True

    # 3. Check time-based anomaly
    hour = timestamp.hour
    last_logins = (
        db.query(LoginEvent)
        .filter(LoginEvent.user_id == user_id)
        .order_by(LoginEvent.timestamp.desc())
        .limit(10)
        .all()
    )
    typical_hours = [l.timestamp.hour for l in last_logins if 6 <= l.timestamp.hour <= 23]
    if len(typical_hours) > 5 and hour not in typical_hours:
        score += 15
        factors.unusual_time = True

    # 4. Check IP reputation
    ip_rep = check_ip_reputation(ip_address)
    if ip_rep["risk_score"] > 50:
        score += min(30, ip_rep["risk_score"] // 2)
        factors.ip_reputation = ip_rep["risk_score"]

    # 5. Check recent failed attempts
    cutoff = datetime.now(timezone.utc) - timedelta(minutes=15)
    recent_failures = db.query(FailedLogin).filter(
        FailedLogin.user_id == user_id, FailedLogin.timestamp > cutoff
    ).count()
    if recent_failures > 0:
        score += min(30, recent_failures * 10)
        factors.failed_attempts = recent_failures

    return RiskResult(score=min(100, score), factors=factors)


def should_require_mfa(user_id: str, risk_score: int, db: Session) -> bool:
    user = db.query(User).filter(User.id == user_id).first()
    if not user or not user.mfa_settings:
        return False
    threshold = user.mfa_settings.risk_threshold or 30
    return risk_score > threshold
public class RiskScoringService
{
    private readonly AppDbContext _db;
    private readonly IIpGeolocationService _geoService;
    private readonly IIpReputationService _ipRepService;

    public RiskScoringService(AppDbContext db, IIpGeolocationService geoService,
                              IIpReputationService ipRepService)
    {
        _db = db;
        _geoService = geoService;
        _ipRepService = ipRepService;
    }

    public async Task<RiskResult> CalculateRiskScoreAsync(
        string userId, string ipAddress, string userAgent, DateTime timestamp)
    {
        var user = await _db.Users.FindAsync(userId)
            ?? throw new InvalidOperationException("User not found");

        int score = 0;
        var factors = new RiskFactors();

        // 1. Check device recognition
        var existingDevice = await _db.Devices
            .FirstOrDefaultAsync(d => d.UserId == userId && d.UserAgent == userAgent);
        if (existingDevice == null)
        {
            score += 25;
            factors.NewDeviceId = true;
        }
        else
        {
            existingDevice.LastSeenAt = timestamp;
            await _db.SaveChangesAsync();
        }

        // 2. Check location anomaly
        var geo = await _geoService.LookupAsync(ipAddress);
        var lastLogin = await _db.LoginEvents
            .Where(e => e.UserId == userId)
            .OrderByDescending(e => e.Timestamp)
            .FirstOrDefaultAsync();

        if (lastLogin != null)
        {
            double distance = CalculateDistance(
                lastLogin.Latitude, lastLogin.Longitude,
                geo.Latitude, geo.Longitude);
            double timeDiffHours = (timestamp - lastLogin.Timestamp).TotalHours;
            double maxPossible = timeDiffHours * 900;

            if (distance > maxPossible)
            {
                score += 40;
                factors.TravelImpossible = true;
            }
            else if (distance > 100 && timeDiffHours < 6)
            {
                score += 20;
                factors.UnusualLocation = true;
            }
        }

        // 3. Check time-based anomaly
        int hour = timestamp.Hour;
        var recentHours = await _db.LoginEvents
            .Where(e => e.UserId == userId)
            .OrderByDescending(e => e.Timestamp)
            .Take(10)
            .Select(e => e.Timestamp.Hour)
            .Where(h => h >= 6 && h <= 23)
            .ToListAsync();
        if (recentHours.Count > 5 && !recentHours.Contains(hour))
        {
            score += 15;
            factors.UnusualTime = true;
        }

        // 4. Check IP reputation
        var ipRep = await _ipRepService.CheckAsync(ipAddress);
        if (ipRep.RiskScore > 50)
        {
            score += Math.Min(30, ipRep.RiskScore / 2);
            factors.IpReputation = ipRep.RiskScore;
        }

        // 5. Check recent failed attempts
        var cutoff = DateTime.UtcNow.AddMinutes(-15);
        int recentFailures = await _db.FailedLogins
            .CountAsync(f => f.UserId == userId && f.Timestamp > cutoff);
        if (recentFailures > 0)
        {
            score += Math.Min(30, recentFailures * 10);
            factors.FailedAttempts = recentFailures;
        }

        return new RiskResult(Math.Min(100, score), factors);
    }

    public async Task<bool> ShouldRequireMfaAsync(string userId, int riskScore)
    {
        var user = await _db.Users
            .Include(u => u.MfaSettings)
            .FirstOrDefaultAsync(u => u.Id == userId);

        if (user?.MfaSettings == null) return false;
        return riskScore > (user.MfaSettings.RiskThreshold ?? 30);
    }
}
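
The location check in the services above calls a distance helper (`calculate_distance` / `CalculateDistance`) that is not shown. The following is a minimal sketch of what such a helper could look like, assuming distances are measured in kilometres and that the 900 multiplier stands for a rough airliner speed in km/h. The names `GeoPoint`, `haversineKm`, and `isImpossibleTravel` are hypothetical and are not taken from the article's codebase.

```typescript
// Hypothetical helpers: the article's own distance function is not shown.
const EARTH_RADIUS_KM = 6371;

interface GeoPoint {
  latitude: number;  // degrees
  longitude: number; // degrees
}

function toRadians(degrees: number): number {
  return (degrees * Math.PI) / 180;
}

// Great-circle distance in kilometres using the haversine formula.
function haversineKm(a: GeoPoint, b: GeoPoint): number {
  const dLat = toRadians(b.latitude - a.latitude);
  const dLon = toRadians(b.longitude - a.longitude);
  const h =
    Math.sin(dLat / 2) ** 2 +
    Math.cos(toRadians(a.latitude)) *
      Math.cos(toRadians(b.latitude)) *
      Math.sin(dLon / 2) ** 2;
  // Clamp to 1 to guard against floating-point overshoot before asin.
  return 2 * EARTH_RADIUS_KM * Math.asin(Math.min(1, Math.sqrt(h)));
}

// True when covering the distance in the elapsed time would require
// travelling faster than maxSpeedKmh (assumed default: 900 km/h).
function isImpossibleTravel(
  from: GeoPoint,
  to: GeoPoint,
  hoursElapsed: number,
  maxSpeedKmh = 900
): boolean {
  const reachableKm = Math.max(0, hoursElapsed) * maxSpeedKmh;
  return haversineKm(from, to) > reachableKm;
}
```

IP geolocation is often only accurate to the city level, so in practice it may be worth adding a tolerance before flagging travel as impossible.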

Adaptive MFA in the Login Flow

app.post('/api/auth/login', async (req: Request, res: Response) => {
  try {
    const { email, password, deviceFingerprint } = req.body;
    const ipAddress = req.ip || '';
    const userAgent = req.get('user-agent') || '';
    const timestamp = new Date();

    // Verify credentials
    const user = await authenticateUser(email, password);
    if (!user) {
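      // Attribute the failure to the account (if one exists) so the per-user
      // failed-attempt check in calculateRiskScore can count it
      const attempted = await prisma.user.findUnique({ where: { email } });
      await prisma.failedLogin.create({
        data: { userId: attempted?.id ?? '', ipAddress, timestamp }
      });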
      return res.status(401).json({ error: 'Invalid credentials' });
    }

    // Calculate risk
    const { score: riskScore, factors } = await calculateRiskScore(user.id, {
      ipAddress,
      userAgent,
      timestamp
    });

    // Log login event
    const geoLocation = await getIPGeolocation(ipAddress);
    await prisma.loginEvent.create({
      data: {
        userId: user.id,
        ipAddress,
        latitude: geoLocation.latitude,
        longitude: geoLocation.longitude,
        riskScore,
        timestamp,
        factors: JSON.stringify(factors)
      }
    });

    // Determine if MFA is required
    const requireMFA = await shouldRequireMFA(user.id, riskScore);

    if (requireMFA) {
      // Generate temporary auth token for MFA challenge
      const tempToken = jwt.sign(
        { userId: user.id, type: 'mfa_challenge' },
        process.env.JWT_SECRET!,
        { expiresIn: '10m' }
      );

      return res.json({
        requiresMFA: true,
        mfaToken: tempToken,
        availableMethods: user.mfaMethods, // ['totp', 'webauthn', 'sms']
        riskFactors: factors
      });
    }

    // No MFA required, issue session
    const session = await createSession(user.id);
    res.json({ sessionToken: session.token });
  } catch (error) {
    res.status(500).json({ error: 'Login failed' });
  }
});
@RestController
@RequestMapping("/api/auth")
public class AdaptiveLoginController {

    private final AuthService authService;
    private final RiskScoringService riskScoringService;
    private final IpGeolocationService geoService;
    private final LoginEventRepository loginEventRepo;
    private final FailedLoginRepository failedLoginRepo;
    private final SessionService sessionService;
    private final JwtService jwtService;

    // constructor omitted for brevity

    @PostMapping("/login")
    public ResponseEntity<?> login(
            @RequestBody LoginRequest body,
            HttpServletRequest request) {
        try {
            String ipAddress = request.getRemoteAddr();
            String userAgent = request.getHeader("User-Agent");
            Instant timestamp = Instant.now();

            User user = authService.authenticateUser(body.getEmail(), body.getPassword());
            if (user == null) {
                failedLoginRepo.save(new FailedLogin("", ipAddress, timestamp));
                return ResponseEntity.status(401).body(Map.of("error", "Invalid credentials"));
            }

            RiskResult risk = riskScoringService.calculateRiskScore(
                user.getId(), ipAddress, userAgent, timestamp);

            GeoLocation geo = geoService.lookup(ipAddress);
            loginEventRepo.save(new LoginEvent(user.getId(), ipAddress,
                geo.getLatitude(), geo.getLongitude(), risk.getScore(), timestamp));

            boolean requireMfa = riskScoringService.shouldRequireMFA(user.getId(), risk.getScore());

            if (requireMfa) {
                String tempToken = jwtService.sign(
                    Map.of("userId", user.getId(), "type", "mfa_challenge"),
                    Duration.ofMinutes(10));
                return ResponseEntity.ok(Map.of(
                    "requiresMFA", true,
                    "mfaToken", tempToken,
                    "availableMethods", user.getMfaMethods(),
                    "riskFactors", risk.getFactors()
                ));
            }

            var session = sessionService.createSession(user.getId());
            return ResponseEntity.ok(Map.of("sessionToken", session.getToken()));
        } catch (Exception e) {
            return ResponseEntity.status(500).body(Map.of("error", "Login failed"));
        }
    }
}
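from datetime import datetime, timezone

from fastapi import APIRouter, Request, Depends, HTTPException
from sqlalchemy.orm import Session
from pydantic import BaseModel
from database import get_db
from models import User, FailedLogin, LoginEvent

router = APIRouter()

class LoginBody(BaseModel):
    email: str
    password: str

@router.post("/auth/login")
async def adaptive_login(body: LoginBody, request: Request, db: Session = Depends(get_db)):
    try:
        # request.client can be None (for example under some test clients)
        ip_address = request.client.host if request.client else ""
        user_agent = request.headers.get("user-agent", "")
        timestamp = datetime.now(timezone.utc)

        user = authenticate_user(body.email, body.password, db)
        if not user:
            # Attribute the failure to the account (if one exists) so the
            # per-user failed-attempt check in calculate_risk_score can count it
            attempted = db.query(User).filter(User.email == body.email).first()
            db.add(FailedLogin(
                user_id=attempted.id if attempted else "",
                ip_address=ip_address,
                timestamp=timestamp,
            ))
            db.commit()
            raise HTTPException(status_code=401, detail="Invalid credentials")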

        risk = calculate_risk_score(user.id, ip_address, user_agent, timestamp, db)

        geo = get_ip_geolocation(ip_address)
        db.add(LoginEvent(
            user_id=user.id,
            ip_address=ip_address,
            latitude=geo["latitude"],
            longitude=geo["longitude"],
            risk_score=risk.score,
            timestamp=timestamp,
            factors=str(risk.factors),
        ))
        db.commit()

        require_mfa = should_require_mfa(user.id, risk.score, db)
        if require_mfa:
            temp_token = create_jwt(
                {"user_id": user.id, "type": "mfa_challenge"}, expires_minutes=10
            )
            return {
                "requires_mfa": True,
                "mfa_token": temp_token,
                "available_methods": user.mfa_methods,
                "risk_factors": risk.factors,
            }

        session = create_session(user.id)
        return {"session_token": session.token}
    except HTTPException:
        raise
    except Exception:
        raise HTTPException(status_code=500, detail="Login failed")
[ApiController]
[Route("api/auth")]
public class AdaptiveLoginController : ControllerBase
{
    private readonly IAuthService _authService;
    private readonly RiskScoringService _riskService;
    private readonly IIpGeolocationService _geoService;
    private readonly AppDbContext _db;
    private readonly SessionService _sessionService;
    private readonly IJwtService _jwtService;

    // constructor omitted for brevity

    [HttpPost("login")]
    public async Task<IActionResult> Login([FromBody] LoginRequest body)
    {
        try
        {
            var ipAddress = HttpContext.Connection.RemoteIpAddress?.ToString() ?? "";
            var userAgent = Request.Headers.UserAgent.ToString();
            var timestamp = DateTime.UtcNow;

            var user = await _authService.AuthenticateAsync(body.Email, body.Password);
            if (user == null)
            {
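                // Attribute the failure to the account (if one exists) so the
                // per-user failed-attempt check in the risk scorer can count it
                var attempted = await _db.Users.FirstOrDefaultAsync(u => u.Email == body.Email);
                _db.FailedLogins.Add(new FailedLogin { UserId = attempted?.Id ?? "", IpAddress = ipAddress, Timestamp = timestamp });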
                await _db.SaveChangesAsync();
                return Unauthorized(new { error = "Invalid credentials" });
            }

            var risk = await _riskService.CalculateRiskScoreAsync(
                user.Id, ipAddress, userAgent, timestamp);

            var geo = await _geoService.LookupAsync(ipAddress);
            _db.LoginEvents.Add(new LoginEvent
            {
                UserId = user.Id, IpAddress = ipAddress,
                Latitude = geo.Latitude, Longitude = geo.Longitude,
                RiskScore = risk.Score, Timestamp = timestamp
            });
            await _db.SaveChangesAsync();

            bool requireMfa = await _riskService.ShouldRequireMfaAsync(user.Id, risk.Score);
            if (requireMfa)
            {
                var tempToken = _jwtService.Sign(
                    new { userId = user.Id, type = "mfa_challenge" },
                    TimeSpan.FromMinutes(10));
                return Ok(new
                {
                    requiresMFA = true,
                    mfaToken = tempToken,
                    availableMethods = user.MfaMethods,
                    riskFactors = risk.Factors
                });
            }

            var session = await _sessionService.CreateSessionAsync(user.Id);
            return Ok(new { sessionToken = session.Token });
        }
        catch
        {
            return StatusCode(500, new { error = "Login failed" });
        }
    }
}
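
The login handlers above issue a short-lived token with `type: 'mfa_challenge'`. The endpoint that later accepts the second factor should confirm that it received this kind of token and not some other token signed with the same secret. Below is a minimal sketch of that check, assuming the signature has already been verified by a JWT library and the decoded claims are passed in. `TokenClaims`, `ChallengeCheck`, and `checkMfaChallengeClaims` are hypothetical names used for illustration only.

```typescript
// Decoded claims of an already signature-verified token (assumed shape).
interface TokenClaims {
  userId: string;
  type: string;
  exp: number; // expiry in seconds since the Unix epoch, as in a standard JWT
}

type ChallengeCheck =
  | { ok: true; userId: string }
  | { ok: false; reason: "wrong_type" | "expired" };

// Accept only unexpired tokens that were issued for the MFA challenge step.
function checkMfaChallengeClaims(
  claims: TokenClaims,
  nowMs: number = Date.now()
): ChallengeCheck {
  if (claims.type !== "mfa_challenge") {
    return { ok: false, reason: "wrong_type" };
  }
  if (claims.exp * 1000 <= nowMs) {
    return { ok: false, reason: "expired" };
  }
  return { ok: true, userId: claims.userId };
}
```

The reverse check matters just as much: endpoints protected by a full session should reject any token whose `type` is `mfa_challenge`, otherwise a password alone would be enough to reach them.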

Account Recovery When the MFA Device Is Lost

Losing access to an MFA device is a nightmare scenario for a user. Do you have a recovery process ready?

Recovery Code System

export async function initiateAccountRecovery(email: string) {
  const user = await prisma.user.findUnique({ where: { email } });
  if (!user) {
    // Don't reveal if user exists
    return { success: true };
  }

  const recoveryCode = crypto.randomBytes(32).toString('hex');
  const codeHash = crypto.createHash('sha256').update(recoveryCode).digest('hex');

  await prisma.recoveryToken.create({
    data: {
      userId: user.id,
      tokenHash: codeHash,
      expiresAt: new Date(Date.now() + 24 * 60 * 60 * 1000), // 24 hours
      usedAt: null
    }
  });

  // Send email with recovery link
  await sendRecoveryEmail(user.email, recoveryCode);

  return { success: true };
}

export async function verifyRecoveryCode(recoveryCode: string): Promise<string | null> {
  const codeHash = crypto.createHash('sha256').update(recoveryCode).digest('hex');

  const token = await prisma.recoveryToken.findFirst({
    where: {
      tokenHash: codeHash,
      expiresAt: { gt: new Date() },
      usedAt: null
    }
  });

  if (!token) {
    return null;
  }

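  // Mark as used. The usedAt: null guard makes the claim atomic, so two
  // concurrent requests cannot both redeem the same code.
  const claimed = await prisma.recoveryToken.updateMany({
    where: { id: token.id, usedAt: null },
    data: { usedAt: new Date() }
  });
  if (claimed.count === 0) {
    return null;
  }

  return token.userId;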
}

export async function resetMFAForUser(userId: string) {
  // Disable all MFA methods
  await prisma.user.update({
    where: { id: userId },
    data: {
      totpEnabled: false,
      totpSecret: null
    }
  });

  await prisma.webAuthnCredential.deleteMany({ where: { userId } });
  await prisma.backupCode.deleteMany({ where: { userId } });

  // Create new backup codes
  const newCodes = await generateBackupCodes(userId);

  // Log recovery event
  await prisma.auditLog.create({
    data: {
      userId,
      action: 'MFA_RESET_RECOVERY',
      severity: 'HIGH',
      timestamp: new Date()
    }
  });

  return newCodes;
}
@Service
public class AccountRecoveryService {

    private final UserRepository userRepo;
    private final RecoveryTokenRepository recoveryTokenRepo;
    private final WebAuthnCredentialRepository webAuthnRepo;
    private final BackupCodeRepository backupCodeRepo;
    private final AuditLogRepository auditLogRepo;
    private final BackupCodeService backupCodeService;
    private final EmailService emailService;

    // constructor omitted for brevity

    @Transactional
    public void initiateAccountRecovery(String email) {
        User user = userRepo.findByEmail(email).orElse(null);
        if (user == null) return; // Don't reveal if user exists

        byte[] recoveryBytes = new byte[32];
        new SecureRandom().nextBytes(recoveryBytes);
        String recoveryCode = HexFormat.of().formatHex(recoveryBytes);
        String codeHash = sha256Hex(recoveryCode);

        RecoveryToken token = new RecoveryToken();
        token.setUserId(user.getId());
        token.setTokenHash(codeHash);
        token.setExpiresAt(Instant.now().plusSeconds(24 * 3600));
        recoveryTokenRepo.save(token);

        emailService.sendRecoveryEmail(user.getEmail(), recoveryCode);
    }

    @Transactional
    public Optional<String> verifyRecoveryCode(String recoveryCode) {
        String codeHash = sha256Hex(recoveryCode);

        RecoveryToken token = recoveryTokenRepo
            .findFirstByTokenHashAndExpiresAtAfterAndUsedAtNull(codeHash, Instant.now())
            .orElse(null);
        if (token == null) return Optional.empty();

        token.setUsedAt(Instant.now());
        recoveryTokenRepo.save(token);
        return Optional.of(token.getUserId());
    }

    @Transactional
    public List<String> resetMFAForUser(String userId) {
        // Disable all MFA methods
        User user = userRepo.findById(userId)
            .orElseThrow(() -> new RuntimeException("User not found"));
        user.setTotpEnabled(false);
        user.setTotpSecret(null);
        userRepo.save(user);

        webAuthnRepo.deleteByUserId(userId);
        backupCodeRepo.deleteByUserId(userId);

        List<String> newCodes = backupCodeService.generateBackupCodes(userId, 10);

        AuditLog log = new AuditLog();
        log.setUserId(userId);
        log.setAction("MFA_RESET_RECOVERY");
        log.setSeverity("HIGH");
        log.setTimestamp(Instant.now());
        auditLogRepo.save(log);

        return newCodes;
    }

    private String sha256Hex(String input) {
        try {
            MessageDigest md = MessageDigest.getInstance("SHA-256");
            return HexFormat.of().formatHex(md.digest(input.getBytes(StandardCharsets.UTF_8)));
        } catch (NoSuchAlgorithmException e) { throw new RuntimeException(e); }
    }
}
import hashlib
import secrets
from datetime import datetime, timedelta, timezone
from sqlalchemy.orm import Session
from models import User, RecoveryToken, WebAuthnCredential, BackupCode, AuditLog

def initiate_account_recovery(email: str, db: Session) -> dict:
    user = db.query(User).filter(User.email == email).first()
    if not user:
        return {"success": True}  # Don't reveal if user exists

    recovery_code = secrets.token_hex(32)
    code_hash = hashlib.sha256(recovery_code.encode()).hexdigest()

    db.add(RecoveryToken(
        user_id=user.id,
        token_hash=code_hash,
        expires_at=datetime.now(timezone.utc) + timedelta(hours=24),
        used_at=None,
    ))
    db.commit()

    send_recovery_email(user.email, recovery_code)
    return {"success": True}


def verify_recovery_code(recovery_code: str, db: Session) -> str | None:
    code_hash = hashlib.sha256(recovery_code.encode()).hexdigest()

    token = (
        db.query(RecoveryToken)
        .filter(
            RecoveryToken.token_hash == code_hash,
            RecoveryToken.expires_at > datetime.now(timezone.utc),
            RecoveryToken.used_at.is_(None),
        )
        .first()
    )
    if not token:
        return None

    token.used_at = datetime.now(timezone.utc)
    db.commit()
    return token.user_id


def reset_mfa_for_user(user_id: str, db: Session) -> list[str]:
    # Disable all MFA methods
    user = db.query(User).filter(User.id == user_id).first()
    if not user:
        raise ValueError("User not found")
    user.totp_enabled = False
    user.totp_secret = None

    db.query(WebAuthnCredential).filter(WebAuthnCredential.user_id == user_id).delete()
    db.query(BackupCode).filter(BackupCode.user_id == user_id).delete()

    new_codes = generate_backup_codes(user_id, db)

    db.add(AuditLog(
        user_id=user_id,
        action="MFA_RESET_RECOVERY",
        severity="HIGH",
        timestamp=datetime.now(timezone.utc),
    ))
    db.commit()
    return new_codes
public class AccountRecoveryService
{
    private readonly AppDbContext _db;
    private readonly BackupCodeService _backupCodeService;
    private readonly IEmailService _emailService;

    public AccountRecoveryService(AppDbContext db, BackupCodeService backupCodeService,
                                  IEmailService emailService)
    {
        _db = db;
        _backupCodeService = backupCodeService;
        _emailService = emailService;
    }

    public async Task InitiateAccountRecoveryAsync(string email)
    {
        var user = await _db.Users.FirstOrDefaultAsync(u => u.Email == email);
        if (user == null) return; // Don't reveal if user exists

        var recoveryBytes = RandomNumberGenerator.GetBytes(32);
        var recoveryCode = Convert.ToHexString(recoveryBytes).ToLower();
        var codeHash = Sha256Hex(recoveryCode);

        _db.RecoveryTokens.Add(new RecoveryToken
        {
            UserId = user.Id,
            TokenHash = codeHash,
            ExpiresAt = DateTime.UtcNow.AddHours(24),
            UsedAt = null
        });
        await _db.SaveChangesAsync();
        await _emailService.SendRecoveryEmailAsync(user.Email, recoveryCode);
    }

    public async Task<string?> VerifyRecoveryCodeAsync(string recoveryCode)
    {
        var codeHash = Sha256Hex(recoveryCode);

        var token = await _db.RecoveryTokens
            .Where(t => t.TokenHash == codeHash
                     && t.ExpiresAt > DateTime.UtcNow
                     && t.UsedAt == null)
            .FirstOrDefaultAsync();

        if (token == null) return null;

        token.UsedAt = DateTime.UtcNow;
        await _db.SaveChangesAsync();
        return token.UserId;
    }

    public async Task<List<string>> ResetMfaForUserAsync(string userId)
    {
        var user = await _db.Users.FindAsync(userId)
            ?? throw new InvalidOperationException("User not found");

        user.TotpEnabled = false;
        user.TotpSecret = null;

        _db.WebAuthnCredentials.RemoveRange(
            _db.WebAuthnCredentials.Where(c => c.UserId == userId));
        _db.BackupCodes.RemoveRange(
            _db.BackupCodes.Where(c => c.UserId == userId));

        var newCodes = await _backupCodeService.GenerateBackupCodesAsync(userId);

        _db.AuditLogs.Add(new AuditLog
        {
            UserId = userId,
            Action = "MFA_RESET_RECOVERY",
            Severity = "HIGH",
            Timestamp = DateTime.UtcNow
        });
        await _db.SaveChangesAsync();
        return newCodes;
    }

    private static string Sha256Hex(string input)
    {
        var hash = SHA256.HashData(Encoding.UTF8.GetBytes(input));
        return Convert.ToHexString(hash).ToLower();
    }
}
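
The recovery services above share three properties: only a SHA-256 hash of the code is stored, the code expires after 24 hours, and it can be redeemed once. The sketch below isolates those properties in an in-memory store so they can be reasoned about and tested without a database. `InMemoryRecoveryTokens` and its methods are hypothetical names, and a real deployment would keep the records in persistent storage as the article's code does.

```typescript
import { createHash, randomBytes } from "crypto";

interface StoredToken {
  userId: string;
  expiresAt: number;     // milliseconds since the Unix epoch
  usedAt: number | null; // null until the code is redeemed
}

class InMemoryRecoveryTokens {
  // Keyed by the SHA-256 hash, so the plaintext code is never stored.
  private tokens = new Map<string, StoredToken>();

  private static hash(code: string): string {
    return createHash("sha256").update(code).digest("hex");
  }

  // Returns the plaintext code, which is what would be emailed to the user.
  issue(userId: string, now: number, ttlMs = 24 * 60 * 60 * 1000): string {
    const code = randomBytes(32).toString("hex");
    this.tokens.set(InMemoryRecoveryTokens.hash(code), {
      userId,
      expiresAt: now + ttlMs,
      usedAt: null,
    });
    return code;
  }

  // Returns the user id on the first valid redemption, otherwise null.
  redeem(code: string, now: number): string | null {
    const token = this.tokens.get(InMemoryRecoveryTokens.hash(code));
    if (!token || token.usedAt !== null || token.expiresAt <= now) {
      return null;
    }
    token.usedAt = now;
    return token.userId;
  }
}
```

Passing the clock in as `now` is a design choice made here for testability; the article's code reads the current time directly.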

Recovery Endpoint

app.post('/api/auth/recovery/initiate', async (req: Request, res: Response) => {
  try {
    const { email } = req.body;
    await initiateAccountRecovery(email);

    res.json({
      success: true,
      message: 'Recovery instructions sent to email'
    });
  } catch (error) {
    res.status(500).json({ error: 'Recovery initiation failed' });
  }
});

app.post('/api/auth/recovery/verify', async (req: Request, res: Response) => {
  try {
    const { code } = req.body;
    const userId = await verifyRecoveryCode(code);

    if (!userId) {
      return res.status(400).json({ error: 'Invalid or expired recovery code' });
    }

    // Reset MFA and get new backup codes
    const newCodes = await resetMFAForUser(userId);

    res.json({
      success: true,
      message: 'MFA has been reset',
      backupCodes: newCodes,
      warning: 'Save these codes immediately. They are only shown once.'
    });
  } catch (error) {
    res.status(500).json({ error: 'Recovery verification failed' });
  }
});
@RestController
@RequestMapping("/api/auth/recovery")
public class RecoveryController {

    private final AccountRecoveryService recoveryService;

    public RecoveryController(AccountRecoveryService recoveryService) {
        this.recoveryService = recoveryService;
    }

    @PostMapping("/initiate")
    public ResponseEntity<?> initiateRecovery(@RequestBody Map<String, String> body) {
        try {
            recoveryService.initiateAccountRecovery(body.get("email"));
            return ResponseEntity.ok(Map.of(
                "success", true,
                "message", "Recovery instructions sent to email"
            ));
        } catch (Exception e) {
            return ResponseEntity.status(500).body(Map.of("error", "Recovery initiation failed"));
        }
    }

    @PostMapping("/verify")
    public ResponseEntity<?> verifyRecovery(@RequestBody Map<String, String> body) {
        try {
            Optional<String> userId = recoveryService.verifyRecoveryCode(body.get("code"));
            if (userId.isEmpty()) {
                return ResponseEntity.status(400).body(Map.of("error", "Invalid or expired recovery code"));
            }
            List<String> newCodes = recoveryService.resetMFAForUser(userId.get());
            return ResponseEntity.ok(Map.of(
                "success", true,
                "message", "MFA has been reset",
                "backupCodes", newCodes,
                "warning", "Save these codes immediately. They are only shown once."
            ));
        } catch (Exception e) {
            return ResponseEntity.status(500).body(Map.of("error", "Recovery verification failed"));
        }
    }
}
from fastapi import APIRouter, Depends, HTTPException
from sqlalchemy.orm import Session
from pydantic import BaseModel
from database import get_db

router = APIRouter(prefix="/api/auth/recovery")

class InitiateRecoveryBody(BaseModel):
    email: str

class VerifyRecoveryBody(BaseModel):
    code: str

@router.post("/initiate")
async def initiate_recovery(body: InitiateRecoveryBody, db: Session = Depends(get_db)):
    initiate_account_recovery(body.email, db)
    return {"success": True, "message": "Recovery instructions sent to email"}

@router.post("/verify")
async def verify_recovery(body: VerifyRecoveryBody, db: Session = Depends(get_db)):
    user_id = verify_recovery_code(body.code, db)
    if not user_id:
        raise HTTPException(status_code=400, detail="Invalid or expired recovery code")

    new_codes = reset_mfa_for_user(user_id, db)
    return {
        "success": True,
        "message": "MFA has been reset",
        "backup_codes": new_codes,
        "warning": "Save these codes immediately. They are only shown once.",
    }
[ApiController]
[Route("api/auth/recovery")]
public class RecoveryController : ControllerBase
{
    private readonly AccountRecoveryService _recoveryService;

    public RecoveryController(AccountRecoveryService recoveryService) =>
        _recoveryService = recoveryService;

    [HttpPost("initiate")]
    public async Task<IActionResult> InitiateRecovery([FromBody] InitiateRecoveryRequest request)
    {
        await _recoveryService.InitiateAccountRecoveryAsync(request.Email);
        return Ok(new { success = true, message = "Recovery instructions sent to email" });
    }

    [HttpPost("verify")]
    public async Task<IActionResult> VerifyRecovery([FromBody] VerifyRecoveryRequest request)
    {
        var userId = await _recoveryService.VerifyRecoveryCodeAsync(request.Code);
        if (userId == null)
            return BadRequest(new { error = "Invalid or expired recovery code" });

        var newCodes = await _recoveryService.ResetMfaForUserAsync(userId);
        return Ok(new
        {
            success = true,
            message = "MFA has been reset",
            backupCodes = newCodes,
            warning = "Save these codes immediately. They are only shown once."
        });
    }
}

Defending Against MFA Attacks

MFA Fatigue (Push Bombing)

The attacker sends dozens of push notifications, hoping you will approve one of them out of frustration.

// Rate limit push notifications
const pushNotificationLimiter = new RateLimiter({
  maxConcurrent: 1,
  minTime: 5000, // 5 seconds between pushes
  reservoir: 5,
  reservoirRefreshAmount: 5,
  reservoirRefreshInterval: 60 * 60 * 1000 // 5 per hour
});

export async function sendMFAPushNotification(userId: string, context: {
  ipAddress: string;
  deviceName: string;
}) {
  try {
    await pushNotificationLimiter.schedule(async () => {
      const notification = {
        userId,
        title: 'Login Request',
        body: `Login attempt from ${context.deviceName}`,
        action: `https://yourdomain.com/approve-login/${generateToken()}`,
        expiresAt: new Date(Date.now() + 5 * 60 * 1000),
        approved: false
      };

      await sendPushNotification(userId, notification);

      // Log for fatigue detection
      await prisma.mfaPushLog.create({
        data: {
          userId,
          sentAt: new Date(),
          ipAddress: context.ipAddress
        }
      });
    });
  } catch (error) {
    if (error instanceof Error && error.message.includes('Rate limit exceeded')) {
      // Alert: potential MFA fatigue attack
      await alertSecurityTeam({
        severity: 'HIGH',
        message: `MFA fatigue detected for user ${userId}`
      });
    }
  }
}

// Detect fatigue patterns
export async function detectMFAFatigue(userId: string): Promise<boolean> {
  const recentPushes = await prisma.mfaPushLog.count({
    where: {
      userId,
      sentAt: {
        gt: new Date(Date.now() - 15 * 60 * 1000) // Last 15 minutes
      }
    }
  });

  // Alert if more than 3 push attempts in 15 minutes
  if (recentPushes > 3) {
    await alertSecurityTeam({
      severity: 'HIGH',
      userId,
      message: `Potential MFA fatigue attack detected (${recentPushes} pushes)`
    });

    // Require additional verification
    return true;
  }

  return false;
}
@Service
public class MfaFatigueService {

    private final MfaPushLogRepository pushLogRepo;
    private final PushNotificationService pushService;
    private final SecurityAlertService alertService;
    private final Map<String, Bucket> rateLimitBuckets = new ConcurrentHashMap<>();

    public MfaFatigueService(MfaPushLogRepository pushLogRepo,
                             PushNotificationService pushService,
                             SecurityAlertService alertService) {
        this.pushLogRepo = pushLogRepo;
        this.pushService = pushService;
        this.alertService = alertService;
    }

    // Rate limiter: max 5 pushes per user per hour. Unlike the TypeScript limiter
    // above, this bucket does not enforce a minimum gap between pushes.
    private Bucket getBucketForUser(String userId) {
        return rateLimitBuckets.computeIfAbsent(userId, k ->
            Bucket.builder()
                .addLimit(Bandwidth.simple(5, Duration.ofHours(1)))
                .build()
        );
    }

    @Transactional
    public void sendMfaPushNotification(String userId, String ipAddress, String deviceName) {
        Bucket bucket = getBucketForUser(userId);
        if (!bucket.tryConsume(1)) {
            alertService.alert("HIGH", "MFA fatigue detected for user " + userId);
            return;
        }

        MfaPushNotification notification = new MfaPushNotification(
            userId,
            "Login Request",
            "Login attempt from " + deviceName,
            "https://yourdomain.com/approve-login/" + generateToken(),
            Instant.now().plusSeconds(5 * 60),
            false
        );
        pushService.send(userId, notification);

        MfaPushLog log = new MfaPushLog();
        log.setUserId(userId);
        log.setSentAt(Instant.now());
        log.setIpAddress(ipAddress);
        pushLogRepo.save(log);
    }

    public boolean detectMfaFatigue(String userId) {
        long recentPushes = pushLogRepo.countByUserIdAndSentAtAfter(
            userId, Instant.now().minusSeconds(15 * 60));

        if (recentPushes > 3) {
            alertService.alert("HIGH",
                String.format("Potential MFA fatigue attack detected (%d pushes)", recentPushes));
            return true;
        }
        return false;
    }
}
from datetime import datetime, timedelta, timezone
from collections import defaultdict
import threading
from sqlalchemy.orm import Session
from models import MfaPushLog

_rate_limit_lock = threading.Lock()
_push_timestamps: dict[str, list[datetime]] = defaultdict(list)

def _check_push_rate_limit(user_id: str, max_count: int = 5, window_hours: int = 1) -> bool:
    now = datetime.now(timezone.utc)
    cutoff = now - timedelta(hours=window_hours)
    with _rate_limit_lock:
        times = [t for t in _push_timestamps[user_id] if t > cutoff]
        if len(times) >= max_count:
            return False
        times.append(now)
        _push_timestamps[user_id] = times
    return True


def send_mfa_push_notification(
    user_id: str, ip_address: str, device_name: str, db: Session
) -> None:
    if not _check_push_rate_limit(user_id):
        alert_security_team(severity="HIGH", message=f"MFA fatigue detected for user {user_id}")
        return

    notification = {
        "user_id": user_id,
        "title": "Login Request",
        "body": f"Login attempt from {device_name}",
        "action": f"https://yourdomain.com/approve-login/{generate_token()}",
        "expires_at": datetime.now(timezone.utc) + timedelta(minutes=5),
        "approved": False,
    }
    send_push_notification(user_id, notification)

    db.add(MfaPushLog(
        user_id=user_id,
        sent_at=datetime.now(timezone.utc),
        ip_address=ip_address,
    ))
    db.commit()


def detect_mfa_fatigue(user_id: str, db: Session) -> bool:
    cutoff = datetime.now(timezone.utc) - timedelta(minutes=15)
    recent_pushes = db.query(MfaPushLog).filter(
        MfaPushLog.user_id == user_id,
        MfaPushLog.sent_at > cutoff,
    ).count()

    if recent_pushes > 3:
        alert_security_team(
            severity="HIGH",
            user_id=user_id,
            message=f"Potential MFA fatigue attack detected ({recent_pushes} pushes)",
        )
        return True
    return False
public class MfaFatigueService
{
    private readonly AppDbContext _db;
    private readonly IPushNotificationService _pushService;
    private readonly ISecurityAlertService _alertService;
    private readonly ConcurrentDictionary<string, List<DateTime>> _rateLimits = new();

    public MfaFatigueService(AppDbContext db, IPushNotificationService pushService,
                             ISecurityAlertService alertService)
    {
        _db = db;
        _pushService = pushService;
        _alertService = alertService;
    }

    private bool CheckPushRateLimit(string userId, int maxCount = 5, int windowHours = 1)
    {
        var now = DateTime.UtcNow;
        var cutoff = now.AddHours(-windowHours);
        var times = _rateLimits.GetOrAdd(userId, _ => new List<DateTime>());
        lock (times)
        {
            times.RemoveAll(t => t < cutoff);
            if (times.Count >= maxCount) return false;
            times.Add(now);
        }
        return true;
    }

    public async Task SendMfaPushNotificationAsync(string userId, string ipAddress, string deviceName)
    {
        if (!CheckPushRateLimit(userId))
        {
            await _alertService.AlertAsync("HIGH", $"MFA fatigue detected for user {userId}");
            return;
        }

        var notification = new MfaPushNotification
        {
            UserId = userId,
            Title = "Login Request",
            Body = $"Login attempt from {deviceName}",
            Action = $"https://yourdomain.com/approve-login/{GenerateToken()}",
            ExpiresAt = DateTime.UtcNow.AddMinutes(5),
            Approved = false
        };
        await _pushService.SendAsync(userId, notification);

        _db.MfaPushLogs.Add(new MfaPushLog
        {
            UserId = userId,
            SentAt = DateTime.UtcNow,
            IpAddress = ipAddress
        });
        await _db.SaveChangesAsync();
    }

    public async Task<bool> DetectMfaFatigueAsync(string userId)
    {
        var cutoff = DateTime.UtcNow.AddMinutes(-15);
        int recentPushes = await _db.MfaPushLogs
            .CountAsync(l => l.UserId == userId && l.SentAt > cutoff);

        if (recentPushes > 3)
        {
            await _alertService.AlertAsync("HIGH",
                $"Potential MFA fatigue attack detected ({recentPushes} pushes)");
            return true;
        }
        return false;
    }
}

Phishing Protection for MFA Codes

export async function validateMFACodeContext(
  userId: string,
  code: string,
  context: { ipAddress: string; userAgent: string }
) {
  // Check if code was requested from same IP/device
  const latestChallenge = await prisma.mfaChallenge.findFirst({
    where: { userId },
    orderBy: { createdAt: 'desc' }
  });

  if (!latestChallenge) {
    throw new Error('No active MFA challenge');
  }

  // Check if code is being submitted from different IP
  if (latestChallenge.initiatorIP !== context.ipAddress) {
    await prisma.auditLog.create({
      data: {
        userId,
        action: 'MFA_CODE_FROM_DIFFERENT_IP',
        severity: 'HIGH',
        details: JSON.stringify({
          initiatedFrom: latestChallenge.initiatorIP,
          submittedFrom: context.ipAddress
        }),
        timestamp: new Date()
      }
    });

    // Reject and alert
    throw new Error('MFA code submitted from different location');
  }

  return true;
}
@Service
public class MfaPhishingProtectionService {

    private final MfaChallengeRepository challengeRepo;
    private final AuditLogRepository auditLogRepo;

    public MfaPhishingProtectionService(MfaChallengeRepository challengeRepo,
                                        AuditLogRepository auditLogRepo) {
        this.challengeRepo = challengeRepo;
        this.auditLogRepo = auditLogRepo;
    }

    public void validateMfaCodeContext(String userId, String ipAddress) {
        MfaChallenge latestChallenge = challengeRepo
            .findFirstByUserIdOrderByCreatedAtDesc(userId)
            .orElseThrow(() -> new RuntimeException("No active MFA challenge"));

        if (!latestChallenge.getInitiatorIP().equals(ipAddress)) {
            AuditLog log = new AuditLog();
            log.setUserId(userId);
            log.setAction("MFA_CODE_FROM_DIFFERENT_IP");
            log.setSeverity("HIGH");
            log.setDetails(String.format(
                "{\"initiatedFrom\":\"%s\",\"submittedFrom\":\"%s\"}",
                latestChallenge.getInitiatorIP(), ipAddress));
            log.setTimestamp(Instant.now());
            auditLogRepo.save(log);

            throw new RuntimeException("MFA code submitted from different location");
        }
    }
}
from sqlalchemy.orm import Session
from models import MfaChallenge, AuditLog
import json
from datetime import datetime, timezone

def validate_mfa_code_context(
    user_id: str, ip_address: str, db: Session
) -> bool:
    latest_challenge = (
        db.query(MfaChallenge)
        .filter(MfaChallenge.user_id == user_id)
        .order_by(MfaChallenge.created_at.desc())
        .first()
    )
    if not latest_challenge:
        raise ValueError("No active MFA challenge")

    if latest_challenge.initiator_ip != ip_address:
        db.add(AuditLog(
            user_id=user_id,
            action="MFA_CODE_FROM_DIFFERENT_IP",
            severity="HIGH",
            details=json.dumps({
                "initiated_from": latest_challenge.initiator_ip,
                "submitted_from": ip_address,
            }),
            timestamp=datetime.now(timezone.utc),
        ))
        db.commit()
        raise ValueError("MFA code submitted from different location")

    return True
public class MfaPhishingProtectionService
{
    private readonly AppDbContext _db;

    public MfaPhishingProtectionService(AppDbContext db) => _db = db;

    public async Task ValidateMfaCodeContextAsync(string userId, string ipAddress)
    {
        var latestChallenge = await _db.MfaChallenges
            .Where(c => c.UserId == userId)
            .OrderByDescending(c => c.CreatedAt)
            .FirstOrDefaultAsync()
            ?? throw new InvalidOperationException("No active MFA challenge");

        if (latestChallenge.InitiatorIP != ipAddress)
        {
            _db.AuditLogs.Add(new AuditLog
            {
                UserId = userId,
                Action = "MFA_CODE_FROM_DIFFERENT_IP",
                Severity = "HIGH",
                Details = System.Text.Json.JsonSerializer.Serialize(new
                {
                    initiatedFrom = latestChallenge.InitiatorIP,
                    submittedFrom = ipAddress
                }),
                Timestamp = DateTime.UtcNow
            });
            await _db.SaveChangesAsync();

            throw new InvalidOperationException("MFA code submitted from different location");
        }
    }
}
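
The context checks above compare only the initiating IP address. A challenge that has expired or has already been used should arguably be rejected as well. The following is a minimal sketch of that combined check as a pure function. The `expiresAt` and `consumedAt` fields and the `checkChallenge` name are assumptions for illustration and are not part of the schema shown in this article.

```typescript
// Hypothetical challenge shape: `expiresAt` and `consumedAt` are assumed
// fields, not confirmed by the schema used in the examples above.
interface MfaChallengeLike {
  initiatorIP: string;
  expiresAt: Date;
  consumedAt: Date | null;
}

type ChallengeCheck =
  | { ok: true }
  | { ok: false; reason: "ALREADY_USED" | "EXPIRED" | "IP_MISMATCH" };

// Decide whether a stored challenge may still be answered by this submitter.
// Pure function: the caller loads the challenge and handles audit logging.
export function checkChallenge(
  challenge: MfaChallengeLike,
  submitterIP: string,
  now: Date = new Date()
): ChallengeCheck {
  if (challenge.consumedAt !== null) {
    return { ok: false, reason: "ALREADY_USED" };
  }
  if (challenge.expiresAt.getTime() <= now.getTime()) {
    return { ok: false, reason: "EXPIRED" };
  }
  if (challenge.initiatorIP !== submitterIP) {
    return { ok: false, reason: "IP_MISMATCH" };
  }
  return { ok: true };
}
```

Returning a reason instead of throwing lets the caller decide how strict to be. For example, an IP mismatch could trigger a step-up prompt rather than a hard rejection, since mobile users can change IP address legitimately in the middle of a login.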

Production Checklist

  • Secret Management: TOTP secrets are encrypted at rest and never written to logs
  • Timing-Safe Comparisons: Use crypto.timingSafeEqual() when verifying codes
  • Rate Limiting: Cap the maximum number of attempts, with exponential backoff on failures
  • Audit Logging: Every MFA action is logged with timestamp, IP, device, and outcome
  • Challenge Expiration: All temporary challenges expire after 10-15 minutes
  • Secure Channel: HTTPS only; never put MFA codes in URLs
  • Device Tracking: Recognized trusted devices are prompted for MFA less often
  • Recovery Paths: Backup codes, recovery tokens, support procedures
  • Backup Code Rotation: Generate fresh codes after account recovery
  • Security Key Attestation: Verify hardware key authenticity (optional but recommended)
  • Counter Checks: WebAuthn credential counters help detect cloned keys
  • Session Binding: MFA-verified sessions are bound to a specific device/IP
  • Abuse Detection: Alert on multiple failed attempts, impossible travel, and unusual patterns
  • Incident Response: Clear procedures for security key compromise and account takeover
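
Two of the checklist items, timing-safe comparison and exponential backoff, can be sketched in a few lines using Node.js's built-in `crypto` module. This is only an illustration under stated assumptions: the helper names `codesMatch` and `backoffDelayMs`, the three free attempts, the one-second base delay, and the 15-minute cap are all hypothetical choices, not values taken from this article.

```typescript
import { createHash, timingSafeEqual } from "node:crypto";

// Compare a submitted code against the expected one in constant time.
// timingSafeEqual throws if the buffers differ in length, so both sides
// are hashed first to obtain equal-length inputs.
export function codesMatch(submitted: string, expected: string): boolean {
  const a = createHash("sha256").update(submitted, "utf8").digest();
  const b = createHash("sha256").update(expected, "utf8").digest();
  return timingSafeEqual(a, b);
}

// Exponential backoff (assumed policy): no delay for the first
// `freeAttempts` failures, then baseMs, 2*baseMs, 4*baseMs, ...
// capped at maxMs.
export function backoffDelayMs(
  failedAttempts: number,
  freeAttempts = 3,
  baseMs = 1_000,
  maxMs = 15 * 60 * 1_000
): number {
  if (failedAttempts <= freeAttempts) return 0;
  const exponent = failedAttempts - freeAttempts - 1;
  return Math.min(maxMs, baseMs * 2 ** exponent);
}
```

A verification handler would typically call `backoffDelayMs` with the user's current failure count before checking the code, and refuse to evaluate it until the delay has elapsed.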

Compliance Considerations

PCI-DSS (Payment Card Industry)

  • Requires MFA for any access to cardholder data
  • Requires “strong cryptography” (which rules out SMS-only setups)
  • TOTP or hardware keys are recommended

HIPAA (Healthcare)

  • MFA is expected for administrative access (the Security Rule requires person or entity authentication rather than naming MFA explicitly)
  • Second factors must be “approved”
  • Extensive audit logging is required

SOC 2 (Service Organization Control)

  • MFA for all administrators
  • Review and test MFA mechanisms on a regular basis
  • Incident logging for MFA-related events

GDPR (General Data Protection Regulation)

  • MFA helps demonstrate appropriate security measures, which can reduce liability after a data breach
  • Document MFA mechanisms in your privacy documentation
  • Backup codes must be stored and destroyed securely

Conclusion

Implementing robust MFA requires thoughtful layering: TOTP for convenience, hardware keys for security-conscious users, SMS as a fallback, and adaptive MFA to keep friction low. Always provide recovery paths: a perfectly secure system that locks users out forever serves no one. The goal is defense in depth: make the compromise of a single factor insufficient, and you have won half the battle.


Written by พลากร วรมงคล

Software Engineer Specialist with over 20 years of experience. Writes about architecture, performance, and building production systems.
