FastAPI で `.pt` チェックポイントをホットリロード — ファイル `mtime` だけで十分

再起動なしに新しい訓練結果を反映する RL エージェントのサービングレイヤー。ファイルウォッチャーも pub/sub も Kubernetes も不要。

· 3 min read ·
FastAPIPyTorchMLOps強化学習Python

TL;DR リクエストごとに os.stat() 1 回で .pt ファイルの mtime を確認する。変わっていれば threading.Lock で保護しながらリロード。270 行の FastAPI ルーター、7 つの E2E テスト。複雑なインフラなしで動作する。

問題: 訓練とサービングの非同期性

SAC 訓練ループは 50,000 ステップごとにチェックポイントを保存する:

cache/models/rl/stat_pair/actor_latest.pt
cache/models/rl/stat_pair/actor_20260412_143022.pt
cache/models/rl/stat_pair/actor_20260412_160511.pt
...

本番の FastAPI サーバーはこの actor_latest.pt を読み込んでリアルタイムの売買シグナルを生成する。

問題はサーバーを再起動せずに新しいチェックポイントをどう反映するかだ。3 つの選択肢を検討した:

  1. inotify/watchdog — ファイルシステムイベントの監視。Linux 専用でコンテナ環境では権限の問題がある。
  2. Redis pub/sub — 訓練完了時にメッセージを発行。追加インフラが必要だ。
  3. リクエストごとの mtime 確認os.stat() 1 回。シンプルでプラットフォーム非依存、追加依存性なし。

3 番目を選んだ。

コア設計: ModelRegistry

import os
import threading
import time
from pathlib import Path
import torch

class ModelRegistry:
    """
    スレッドセーフな `.pt` チェックポイントホットリローダー。
    リクエストごとに os.stat() 1 回で変更を検出する。
    """
    
    def __init__(self, model_path: str | Path, device: str = "cpu"):
        self.model_path = Path(model_path)
        self.device = device
        self._model = None
        self._last_mtime: float = 0.0
        self._lock = threading.Lock()
        self._load_count = 0
    
    def get_model(self, force_reload: bool = False) -> torch.nn.Module:
        """現在のモデルを返す。必要であればリロード。"""
        current_mtime = self._get_mtime()
        
        if not force_reload and current_mtime == self._last_mtime and self._model is not None:
            return self._model  # 高速パス: stat 確認後すぐに返す
        
        with self._lock:
            # Double-check locking: ロック取得後に再確認
            current_mtime = self._get_mtime()
            if not force_reload and current_mtime == self._last_mtime and self._model is not None:
                return self._model
            
            self._reload(current_mtime)
        
        return self._model
    
    def _get_mtime(self) -> float:
        try:
            return os.stat(self.model_path).st_mtime
        except FileNotFoundError:
            return 0.0
    
    def _reload(self, mtime: float) -> None:
        new_model = torch.jit.load(self.model_path, map_location=self.device)
        new_model.eval()
        
        # アトミックな置き換え
        self._model = new_model
        self._last_mtime = mtime
        self._load_count += 1
        
        print(f"[ModelRegistry] リロード #{self._load_count}: {self.model_path} (mtime={mtime:.3f})")
    
    @property
    def load_count(self) -> int:
        return self._load_count
    
    @property 
    def is_loaded(self) -> bool:
        return self._model is not None

FastAPI ルーター

from fastapi import APIRouter, HTTPException, Depends
from pydantic import BaseModel

router = APIRouter(prefix="/rl", tags=["rl-agent"])

_registry: ModelRegistry | None = None

def get_registry() -> ModelRegistry:
    if _registry is None:
        raise HTTPException(503, "モデルレジストリが初期化されていません")
    return _registry

class PredictRequest(BaseModel):
    state: list[float]

class PredictResponse(BaseModel):
    action: list[float]
    model_version: float  # mtime
    load_count: int

@router.post("/predict", response_model=PredictResponse)
async def predict(
    req: PredictRequest,
    force_reload: bool = False,
    registry: ModelRegistry = Depends(get_registry),
):
    model = registry.get_model(force_reload=force_reload)
    
    state_tensor = torch.tensor(req.state, dtype=torch.float32).unsqueeze(0)
    
    with torch.no_grad():
        action, _ = model(state_tensor, deterministic=True)
    
    # Shape 検証
    expected_dim = 1
    if action.shape[-1] != expected_dim:
        raise HTTPException(
            500,
            f"アクター出力の shape 不一致: 期待 {expected_dim}、実際 {action.shape[-1]}"
        )
    
    return PredictResponse(
        action=action.squeeze(0).tolist(),
        model_version=registry._last_mtime,
        load_count=registry.load_count,
    )

@router.get("/status")
async def status(registry: ModelRegistry = Depends(get_registry)):
    return {
        "loaded": registry.is_loaded,
        "model_path": str(registry.model_path),
        "load_count": registry.load_count,
        "last_mtime": registry._last_mtime,
    }

7 つの End-to-End テスト

import pytest
import time
import torch

class TestModelRegistry:
    
    def test_cold_start_raises_before_file_exists(self, model_dir):
        """ファイルがないとき is_loaded=False"""
        model_dir.mkdir(parents=True)
        registry = ModelRegistry(model_dir / "actor_latest.pt")
        assert not registry.is_loaded
    
    def test_loads_on_first_get(self, model_dir, dummy_actor):
        """最初の get_model() 呼び出し時にロード"""
        model_dir.mkdir(parents=True)
        path = model_dir / "actor_latest.pt"
        torch.jit.save(dummy_actor, path)
        
        registry = ModelRegistry(path)
        model = registry.get_model()
        
        assert registry.is_loaded
        assert registry.load_count == 1
    
    def test_auto_swap_on_mtime_change(self, model_dir, dummy_actor):
        """mtime 変更時に自動リロード"""
        model_dir.mkdir(parents=True)
        path = model_dir / "actor_latest.pt"
        torch.jit.save(dummy_actor, path)
        
        registry = ModelRegistry(path)
        registry.get_model()
        assert registry.load_count == 1
        
        time.sleep(0.01)  # mtime の差を保証
        torch.jit.save(dummy_actor, path)  # 新しいチェックポイント
        
        registry.get_model()
        assert registry.load_count == 2
    
    def test_no_reload_when_mtime_unchanged(self, model_dir, dummy_actor):
        """mtime が同じならリロードしない"""
        model_dir.mkdir(parents=True)
        path = model_dir / "actor_latest.pt"
        torch.jit.save(dummy_actor, path)
        
        registry = ModelRegistry(path)
        registry.get_model()
        registry.get_model()
        registry.get_model()
        
        assert registry.load_count == 1  # 1 回だけロード
    
    def test_forced_reload(self, model_dir, dummy_actor):
        """force_reload=True は mtime に関わらずリロード"""
        model_dir.mkdir(parents=True)
        path = model_dir / "actor_latest.pt"
        torch.jit.save(dummy_actor, path)
        
        registry = ModelRegistry(path)
        registry.get_model()
        registry.get_model(force_reload=True)
        
        assert registry.load_count == 2
    
    def test_concurrent_reload_safety(self, model_dir, dummy_actor):
        """並行リロードリクエストでレースコンディションなし"""
        import threading
        model_dir.mkdir(parents=True)
        path = model_dir / "actor_latest.pt"
        torch.jit.save(dummy_actor, path)
        
        registry = ModelRegistry(path)
        errors = []
        
        def reload_worker():
            try:
                for _ in range(100):
                    registry.get_model(force_reload=True)
            except Exception as e:
                errors.append(e)
        
        threads = [threading.Thread(target=reload_worker) for _ in range(10)]
        for t in threads: t.start()
        for t in threads: t.join()
        
        assert len(errors) == 0

パフォーマンスプロファイル

os.stat() 呼び出しのコストが気になるかもしれない。測定してみた:

  • ローカル SSD: ~2μs
  • NFS (EFS): ~8μs
  • 推論レイテンシ: ~400μs (scripted アクター基準)

os.stat() は全レイテンシの 0.5~2% 程度だ。無視して問題ない。

まとめ

ファイル mtime 確認という最もシンプルなアプローチが、複雑なファイルウォッチャーやメッセージキューよりも実用的だ。プラットフォーム非依存で、追加依存性がなく、テストしやすい。本番環境で 3 ヶ月使用しており、予期しないリロードは一度もなかった。