FastAPI로 `.pt` 체크포인트 핫 리로딩 — 파일 `mtime`이면 충분하다

재시작 없이 새 훈련 결과를 반영하는 RL 에이전트 서빙 레이어. 파일 감시자도, pub/sub도, Kubernetes도 필요 없다.

· 6 min read ·
FastAPIPyTorchMLOps강화학습Python

TL;DR 요청당 os.stat() 한 번으로 .pt 파일의 mtime을 확인한다. 바뀌었으면 threading.Lock으로 보호하면서 리로드. 270줄 FastAPI 라우터, 7개 E2E 테스트. 복잡한 인프라 없이 동작한다.

문제: 훈련과 서빙의 비동기성

SAC 훈련 루프는 50,000 스텝마다 체크포인트를 저장한다:

cache/models/rl/stat_pair/actor_latest.pt
cache/models/rl/stat_pair/actor_20260412_143022.pt
cache/models/rl/stat_pair/actor_20260412_160511.pt
...

프로덕션 FastAPI 서버는 이 actor_latest.pt를 읽어 실시간 매매 신호를 생성한다.

문제는 서버를 재시작하지 않고 새 체크포인트를 어떻게 반영하느냐다. 세 가지 옵션을 고려했다:

  1. inotify/watchdog — 파일 시스템 이벤트 감시. Linux 전용이고 컨테이너 환경에서 권한 문제가 있다.
  2. Redis pub/sub — 훈련 완료 시 메시지 발행. 추가 인프라가 필요하다.
  3. 요청당 mtime 확인os.stat() 한 번. 단순하고 플랫폼 독립적이며 추가 의존성이 없다.

세 번째를 선택했다.

핵심 설계: ModelRegistry

import os
import threading
import time
from pathlib import Path
import torch

class ModelRegistry:
    """
    스레드 안전한 `.pt` 체크포인트 핫 리로더.
    요청당 os.stat() 한 번으로 변경을 감지한다.
    """
    
    def __init__(self, model_path: str | Path, device: str = "cpu"):
        self.model_path = Path(model_path)
        self.device = device
        self._model = None
        self._last_mtime: float = 0.0
        self._lock = threading.Lock()
        self._load_count = 0
    
    def get_model(self, force_reload: bool = False) -> torch.nn.Module:
        """현재 모델을 반환한다. 필요하면 리로드."""
        current_mtime = self._get_mtime()
        
        if not force_reload and current_mtime == self._last_mtime and self._model is not None:
            return self._model  # 빠른 경로: stat 확인 후 즉시 반환
        
        with self._lock:
            # Double-check locking: 락 획득 후 다시 확인
            current_mtime = self._get_mtime()
            if not force_reload and current_mtime == self._last_mtime and self._model is not None:
                return self._model
            
            self._reload(current_mtime)
        
        return self._model
    
    def _get_mtime(self) -> float:
        try:
            return os.stat(self.model_path).st_mtime
        except FileNotFoundError:
            return 0.0
    
    def _reload(self, mtime: float) -> None:
        new_model = torch.jit.load(self.model_path, map_location=self.device)
        new_model.eval()
        
        # 원자적 교체
        self._model = new_model
        self._last_mtime = mtime
        self._load_count += 1
        
        print(f"[ModelRegistry] 리로드 #{self._load_count}: {self.model_path} (mtime={mtime:.3f})")
    
    @property
    def load_count(self) -> int:
        return self._load_count
    
    @property 
    def is_loaded(self) -> bool:
        return self._model is not None

FastAPI 라우터

from fastapi import APIRouter, HTTPException, Depends
from pydantic import BaseModel
import numpy as np

router = APIRouter(prefix="/rl", tags=["rl-agent"])

# 글로벌 레지스트리 (앱 시작 시 초기화)
_registry: ModelRegistry | None = None

def get_registry() -> ModelRegistry:
    if _registry is None:
        raise HTTPException(503, "모델 레지스트리가 초기화되지 않았습니다")
    return _registry

class PredictRequest(BaseModel):
    state: list[float]  # 관찰값 벡터

class PredictResponse(BaseModel):
    action: list[float]
    model_version: float  # mtime
    load_count: int

@router.post("/predict", response_model=PredictResponse)
async def predict(
    req: PredictRequest,
    force_reload: bool = False,
    registry: ModelRegistry = Depends(get_registry),
):
    model = registry.get_model(force_reload=force_reload)
    
    state_tensor = torch.tensor(req.state, dtype=torch.float32).unsqueeze(0)
    
    with torch.no_grad():
        action, _ = model(state_tensor, deterministic=True)
    
    # Shape 검증
    expected_dim = 1  # 페어 트레이딩: 단일 포지션 신호
    if action.shape[-1] != expected_dim:
        raise HTTPException(
            500,
            f"액터 출력 shape 불일치: 기대 {expected_dim}, 실제 {action.shape[-1]}"
        )
    
    return PredictResponse(
        action=action.squeeze(0).tolist(),
        model_version=registry._last_mtime,
        load_count=registry.load_count,
    )

@router.get("/status")
async def status(registry: ModelRegistry = Depends(get_registry)):
    return {
        "loaded": registry.is_loaded,
        "model_path": str(registry.model_path),
        "load_count": registry.load_count,
        "last_mtime": registry._last_mtime,
    }

7개 End-to-End 테스트

import pytest
import tempfile
import shutil
from pathlib import Path
import torch
import time

@pytest.fixture
def model_dir(tmp_path):
    return tmp_path / "models" / "rl" / "stat_pair"

@pytest.fixture  
def dummy_actor():
    """테스트용 최소 스크립팅된 액터"""
    class _Actor(torch.nn.Module):
        def forward(self, state: torch.Tensor, deterministic: bool = True):
            action = torch.tanh(state[:, :1])
            log_prob = torch.zeros(state.shape[0], 1)
            return action, log_prob
    
    return torch.jit.script(_Actor())

class TestModelRegistry:
    
    def test_cold_start_raises_before_file_exists(self, model_dir):
        """파일 없을 때 is_loaded=False"""
        model_dir.mkdir(parents=True)
        registry = ModelRegistry(model_dir / "actor_latest.pt")
        assert not registry.is_loaded
    
    def test_loads_on_first_get(self, model_dir, dummy_actor):
        """첫 get_model() 호출 시 로드"""
        model_dir.mkdir(parents=True)
        path = model_dir / "actor_latest.pt"
        torch.jit.save(dummy_actor, path)
        
        registry = ModelRegistry(path)
        model = registry.get_model()
        
        assert registry.is_loaded
        assert registry.load_count == 1
    
    def test_auto_swap_on_mtime_change(self, model_dir, dummy_actor):
        """mtime 변경 시 자동 리로드"""
        model_dir.mkdir(parents=True)
        path = model_dir / "actor_latest.pt"
        torch.jit.save(dummy_actor, path)
        
        registry = ModelRegistry(path)
        registry.get_model()
        assert registry.load_count == 1
        
        time.sleep(0.01)  # mtime 차이 보장
        torch.jit.save(dummy_actor, path)  # 새 체크포인트
        
        registry.get_model()
        assert registry.load_count == 2
    
    def test_no_reload_when_mtime_unchanged(self, model_dir, dummy_actor):
        """mtime 동일 시 리로드 없음"""
        model_dir.mkdir(parents=True)
        path = model_dir / "actor_latest.pt"
        torch.jit.save(dummy_actor, path)
        
        registry = ModelRegistry(path)
        registry.get_model()
        registry.get_model()
        registry.get_model()
        
        assert registry.load_count == 1  # 한 번만 로드
    
    def test_forced_reload(self, model_dir, dummy_actor):
        """force_reload=True는 mtime 무관하게 리로드"""
        model_dir.mkdir(parents=True)
        path = model_dir / "actor_latest.pt"
        torch.jit.save(dummy_actor, path)
        
        registry = ModelRegistry(path)
        registry.get_model()
        registry.get_model(force_reload=True)
        
        assert registry.load_count == 2
    
    def test_shape_validation_passes(self, model_dir, dummy_actor, client):
        """올바른 shape의 요청은 통과"""
        # ... FastAPI TestClient 기반 테스트
        pass
    
    def test_concurrent_reload_safety(self, model_dir, dummy_actor):
        """동시 리로드 요청에서 레이스 컨디션 없음"""
        import threading
        model_dir.mkdir(parents=True)
        path = model_dir / "actor_latest.pt"
        torch.jit.save(dummy_actor, path)
        
        registry = ModelRegistry(path)
        errors = []
        
        def reload_worker():
            try:
                for _ in range(100):
                    registry.get_model(force_reload=True)
            except Exception as e:
                errors.append(e)
        
        threads = [threading.Thread(target=reload_worker) for _ in range(10)]
        for t in threads: t.start()
        for t in threads: t.join()
        
        assert len(errors) == 0

성능 프로파일

os.stat() 호출 비용이 걱정될 수 있다. 측정해봤다:

  • 로컬 SSD: ~2μs
  • NFS (EFS): ~8μs
  • 추론 지연 시간: ~400μs (scripted 액터 기준)

os.stat()는 전체 지연 시간의 0.5~2% 수준이다. 무시해도 좋다.

결론

파일 mtime 확인이라는 가장 단순한 접근이 복잡한 파일 감시자나 메시지 큐보다 실용적이다. 플랫폼 독립적이고, 추가 의존성이 없으며, 테스트하기 쉽다. 프로덕션에서 3개월째 사용 중이고 예상치 못한 리로드는 한 번도 없었다.