TL;DR 요청당
os.stat()한 번으로.pt파일의 mtime을 확인한다. 바뀌었으면threading.Lock으로 보호하면서 리로드. 270줄 FastAPI 라우터, 7개 E2E 테스트. 복잡한 인프라 없이 동작한다.
문제: 훈련과 서빙의 비동기성
SAC 훈련 루프는 50,000 스텝마다 체크포인트를 저장한다:
cache/models/rl/stat_pair/actor_latest.pt
cache/models/rl/stat_pair/actor_20260412_143022.pt
cache/models/rl/stat_pair/actor_20260412_160511.pt
... 프로덕션 FastAPI 서버는 이 actor_latest.pt를 읽어 실시간 매매 신호를 생성한다.
문제는 서버를 재시작하지 않고 새 체크포인트를 어떻게 반영하느냐다. 세 가지 옵션을 고려했다:
- inotify/watchdog — 파일 시스템 이벤트 감시. Linux 전용이고 컨테이너 환경에서 권한 문제가 있다.
- Redis pub/sub — 훈련 완료 시 메시지 발행. 추가 인프라가 필요하다.
- 요청당 mtime 확인 —
os.stat()한 번. 단순하고 플랫폼 독립적이며 추가 의존성이 없다.
세 번째를 선택했다.
핵심 설계: ModelRegistry
import os
import threading
import time
from pathlib import Path
import torch
class ModelRegistry:
"""
스레드 안전한 `.pt` 체크포인트 핫 리로더.
요청당 os.stat() 한 번으로 변경을 감지한다.
"""
def __init__(self, model_path: str | Path, device: str = "cpu"):
self.model_path = Path(model_path)
self.device = device
self._model = None
self._last_mtime: float = 0.0
self._lock = threading.Lock()
self._load_count = 0
def get_model(self, force_reload: bool = False) -> torch.nn.Module:
"""현재 모델을 반환한다. 필요하면 리로드."""
current_mtime = self._get_mtime()
if not force_reload and current_mtime == self._last_mtime and self._model is not None:
return self._model # 빠른 경로: stat 확인 후 즉시 반환
with self._lock:
# Double-check locking: 락 획득 후 다시 확인
current_mtime = self._get_mtime()
if not force_reload and current_mtime == self._last_mtime and self._model is not None:
return self._model
self._reload(current_mtime)
return self._model
def _get_mtime(self) -> float:
try:
return os.stat(self.model_path).st_mtime
except FileNotFoundError:
return 0.0
def _reload(self, mtime: float) -> None:
new_model = torch.jit.load(self.model_path, map_location=self.device)
new_model.eval()
# 원자적 교체
self._model = new_model
self._last_mtime = mtime
self._load_count += 1
print(f"[ModelRegistry] 리로드 #{self._load_count}: {self.model_path} (mtime={mtime:.3f})")
@property
def load_count(self) -> int:
return self._load_count
@property
def is_loaded(self) -> bool:
return self._model is not None FastAPI 라우터
from fastapi import APIRouter, HTTPException, Depends
from pydantic import BaseModel
import numpy as np
router = APIRouter(prefix="/rl", tags=["rl-agent"])
# 글로벌 레지스트리 (앱 시작 시 초기화)
_registry: ModelRegistry | None = None
def get_registry() -> ModelRegistry:
if _registry is None:
raise HTTPException(503, "모델 레지스트리가 초기화되지 않았습니다")
return _registry
class PredictRequest(BaseModel):
state: list[float] # 관찰값 벡터
class PredictResponse(BaseModel):
action: list[float]
model_version: float # mtime
load_count: int
@router.post("/predict", response_model=PredictResponse)
async def predict(
req: PredictRequest,
force_reload: bool = False,
registry: ModelRegistry = Depends(get_registry),
):
model = registry.get_model(force_reload=force_reload)
state_tensor = torch.tensor(req.state, dtype=torch.float32).unsqueeze(0)
with torch.no_grad():
action, _ = model(state_tensor, deterministic=True)
# Shape 검증
expected_dim = 1 # 페어 트레이딩: 단일 포지션 신호
if action.shape[-1] != expected_dim:
raise HTTPException(
500,
f"액터 출력 shape 불일치: 기대 {expected_dim}, 실제 {action.shape[-1]}"
)
return PredictResponse(
action=action.squeeze(0).tolist(),
model_version=registry._last_mtime,
load_count=registry.load_count,
)
@router.get("/status")
async def status(registry: ModelRegistry = Depends(get_registry)):
return {
"loaded": registry.is_loaded,
"model_path": str(registry.model_path),
"load_count": registry.load_count,
"last_mtime": registry._last_mtime,
} 7개 End-to-End 테스트
import pytest
import tempfile
import shutil
from pathlib import Path
import torch
import time
@pytest.fixture
def model_dir(tmp_path):
return tmp_path / "models" / "rl" / "stat_pair"
@pytest.fixture
def dummy_actor():
"""테스트용 최소 스크립팅된 액터"""
class _Actor(torch.nn.Module):
def forward(self, state: torch.Tensor, deterministic: bool = True):
action = torch.tanh(state[:, :1])
log_prob = torch.zeros(state.shape[0], 1)
return action, log_prob
return torch.jit.script(_Actor())
class TestModelRegistry:
def test_cold_start_raises_before_file_exists(self, model_dir):
"""파일 없을 때 is_loaded=False"""
model_dir.mkdir(parents=True)
registry = ModelRegistry(model_dir / "actor_latest.pt")
assert not registry.is_loaded
def test_loads_on_first_get(self, model_dir, dummy_actor):
"""첫 get_model() 호출 시 로드"""
model_dir.mkdir(parents=True)
path = model_dir / "actor_latest.pt"
torch.jit.save(dummy_actor, path)
registry = ModelRegistry(path)
model = registry.get_model()
assert registry.is_loaded
assert registry.load_count == 1
def test_auto_swap_on_mtime_change(self, model_dir, dummy_actor):
"""mtime 변경 시 자동 리로드"""
model_dir.mkdir(parents=True)
path = model_dir / "actor_latest.pt"
torch.jit.save(dummy_actor, path)
registry = ModelRegistry(path)
registry.get_model()
assert registry.load_count == 1
time.sleep(0.01) # mtime 차이 보장
torch.jit.save(dummy_actor, path) # 새 체크포인트
registry.get_model()
assert registry.load_count == 2
def test_no_reload_when_mtime_unchanged(self, model_dir, dummy_actor):
"""mtime 동일 시 리로드 없음"""
model_dir.mkdir(parents=True)
path = model_dir / "actor_latest.pt"
torch.jit.save(dummy_actor, path)
registry = ModelRegistry(path)
registry.get_model()
registry.get_model()
registry.get_model()
assert registry.load_count == 1 # 한 번만 로드
def test_forced_reload(self, model_dir, dummy_actor):
"""force_reload=True는 mtime 무관하게 리로드"""
model_dir.mkdir(parents=True)
path = model_dir / "actor_latest.pt"
torch.jit.save(dummy_actor, path)
registry = ModelRegistry(path)
registry.get_model()
registry.get_model(force_reload=True)
assert registry.load_count == 2
def test_shape_validation_passes(self, model_dir, dummy_actor, client):
"""올바른 shape의 요청은 통과"""
# ... FastAPI TestClient 기반 테스트
pass
def test_concurrent_reload_safety(self, model_dir, dummy_actor):
"""동시 리로드 요청에서 레이스 컨디션 없음"""
import threading
model_dir.mkdir(parents=True)
path = model_dir / "actor_latest.pt"
torch.jit.save(dummy_actor, path)
registry = ModelRegistry(path)
errors = []
def reload_worker():
try:
for _ in range(100):
registry.get_model(force_reload=True)
except Exception as e:
errors.append(e)
threads = [threading.Thread(target=reload_worker) for _ in range(10)]
for t in threads: t.start()
for t in threads: t.join()
assert len(errors) == 0 성능 프로파일
os.stat() 호출 비용이 걱정될 수 있다. 측정해봤다:
- 로컬 SSD: ~2μs
- NFS (EFS): ~8μs
- 추론 지연 시간: ~400μs (scripted 액터 기준)
os.stat()는 전체 지연 시간의 0.5~2% 수준이다. 무시해도 좋다.
결론
파일 mtime 확인이라는 가장 단순한 접근이 복잡한 파일 감시자나 메시지 큐보다 실용적이다. 플랫폼 독립적이고, 추가 의존성이 없으며, 테스트하기 쉽다. 프로덕션에서 3개월째 사용 중이고 예상치 못한 리로드는 한 번도 없었다.