TL;DR リクエストごとに
os.stat()1 回で.ptファイルの mtime を確認する。変わっていればthreading.Lockで保護しながらリロード。270 行の FastAPI ルーター、7 つの E2E テスト。複雑なインフラなしで動作する。
問題: 訓練とサービングの非同期性
SAC 訓練ループは 50,000 ステップごとにチェックポイントを保存する:
cache/models/rl/stat_pair/actor_latest.pt
cache/models/rl/stat_pair/actor_20260412_143022.pt
cache/models/rl/stat_pair/actor_20260412_160511.pt
... 本番の FastAPI サーバーはこの actor_latest.pt を読み込んでリアルタイムの売買シグナルを生成する。
問題はサーバーを再起動せずに新しいチェックポイントをどう反映するかだ。3 つの選択肢を検討した:
- inotify/watchdog — ファイルシステムイベントの監視。Linux 専用でコンテナ環境では権限の問題がある。
- Redis pub/sub — 訓練完了時にメッセージを発行。追加インフラが必要だ。
- リクエストごとの mtime 確認 —
os.stat()1 回。シンプルでプラットフォーム非依存、追加依存性なし。
3 番目を選んだ。
コア設計: ModelRegistry
import os
import threading
import time
from pathlib import Path
import torch
class ModelRegistry:
"""
スレッドセーフな `.pt` チェックポイントホットリローダー。
リクエストごとに os.stat() 1 回で変更を検出する。
"""
def __init__(self, model_path: str | Path, device: str = "cpu"):
self.model_path = Path(model_path)
self.device = device
self._model = None
self._last_mtime: float = 0.0
self._lock = threading.Lock()
self._load_count = 0
def get_model(self, force_reload: bool = False) -> torch.nn.Module:
"""現在のモデルを返す。必要であればリロード。"""
current_mtime = self._get_mtime()
if not force_reload and current_mtime == self._last_mtime and self._model is not None:
return self._model # 高速パス: stat 確認後すぐに返す
with self._lock:
# Double-check locking: ロック取得後に再確認
current_mtime = self._get_mtime()
if not force_reload and current_mtime == self._last_mtime and self._model is not None:
return self._model
self._reload(current_mtime)
return self._model
def _get_mtime(self) -> float:
try:
return os.stat(self.model_path).st_mtime
except FileNotFoundError:
return 0.0
def _reload(self, mtime: float) -> None:
new_model = torch.jit.load(self.model_path, map_location=self.device)
new_model.eval()
# アトミックな置き換え
self._model = new_model
self._last_mtime = mtime
self._load_count += 1
print(f"[ModelRegistry] リロード #{self._load_count}: {self.model_path} (mtime={mtime:.3f})")
@property
def load_count(self) -> int:
return self._load_count
@property
def is_loaded(self) -> bool:
return self._model is not None FastAPI ルーター
from fastapi import APIRouter, HTTPException, Depends
from pydantic import BaseModel
router = APIRouter(prefix="/rl", tags=["rl-agent"])
_registry: ModelRegistry | None = None
def get_registry() -> ModelRegistry:
if _registry is None:
raise HTTPException(503, "モデルレジストリが初期化されていません")
return _registry
class PredictRequest(BaseModel):
state: list[float]
class PredictResponse(BaseModel):
action: list[float]
model_version: float # mtime
load_count: int
@router.post("/predict", response_model=PredictResponse)
async def predict(
req: PredictRequest,
force_reload: bool = False,
registry: ModelRegistry = Depends(get_registry),
):
model = registry.get_model(force_reload=force_reload)
state_tensor = torch.tensor(req.state, dtype=torch.float32).unsqueeze(0)
with torch.no_grad():
action, _ = model(state_tensor, deterministic=True)
# Shape 検証
expected_dim = 1
if action.shape[-1] != expected_dim:
raise HTTPException(
500,
f"アクター出力の shape 不一致: 期待 {expected_dim}、実際 {action.shape[-1]}"
)
return PredictResponse(
action=action.squeeze(0).tolist(),
model_version=registry._last_mtime,
load_count=registry.load_count,
)
@router.get("/status")
async def status(registry: ModelRegistry = Depends(get_registry)):
return {
"loaded": registry.is_loaded,
"model_path": str(registry.model_path),
"load_count": registry.load_count,
"last_mtime": registry._last_mtime,
} 7 つの End-to-End テスト
import pytest
import time
import torch
class TestModelRegistry:
def test_cold_start_raises_before_file_exists(self, model_dir):
"""ファイルがないとき is_loaded=False"""
model_dir.mkdir(parents=True)
registry = ModelRegistry(model_dir / "actor_latest.pt")
assert not registry.is_loaded
def test_loads_on_first_get(self, model_dir, dummy_actor):
"""最初の get_model() 呼び出し時にロード"""
model_dir.mkdir(parents=True)
path = model_dir / "actor_latest.pt"
torch.jit.save(dummy_actor, path)
registry = ModelRegistry(path)
model = registry.get_model()
assert registry.is_loaded
assert registry.load_count == 1
def test_auto_swap_on_mtime_change(self, model_dir, dummy_actor):
"""mtime 変更時に自動リロード"""
model_dir.mkdir(parents=True)
path = model_dir / "actor_latest.pt"
torch.jit.save(dummy_actor, path)
registry = ModelRegistry(path)
registry.get_model()
assert registry.load_count == 1
time.sleep(0.01) # mtime の差を保証
torch.jit.save(dummy_actor, path) # 新しいチェックポイント
registry.get_model()
assert registry.load_count == 2
def test_no_reload_when_mtime_unchanged(self, model_dir, dummy_actor):
"""mtime が同じならリロードしない"""
model_dir.mkdir(parents=True)
path = model_dir / "actor_latest.pt"
torch.jit.save(dummy_actor, path)
registry = ModelRegistry(path)
registry.get_model()
registry.get_model()
registry.get_model()
assert registry.load_count == 1 # 1 回だけロード
def test_forced_reload(self, model_dir, dummy_actor):
"""force_reload=True は mtime に関わらずリロード"""
model_dir.mkdir(parents=True)
path = model_dir / "actor_latest.pt"
torch.jit.save(dummy_actor, path)
registry = ModelRegistry(path)
registry.get_model()
registry.get_model(force_reload=True)
assert registry.load_count == 2
def test_concurrent_reload_safety(self, model_dir, dummy_actor):
"""並行リロードリクエストでレースコンディションなし"""
import threading
model_dir.mkdir(parents=True)
path = model_dir / "actor_latest.pt"
torch.jit.save(dummy_actor, path)
registry = ModelRegistry(path)
errors = []
def reload_worker():
try:
for _ in range(100):
registry.get_model(force_reload=True)
except Exception as e:
errors.append(e)
threads = [threading.Thread(target=reload_worker) for _ in range(10)]
for t in threads: t.start()
for t in threads: t.join()
assert len(errors) == 0 パフォーマンスプロファイル
os.stat() 呼び出しのコストが気になるかもしれない。測定してみた:
- ローカル SSD: ~2μs
- NFS (EFS): ~8μs
- 推論レイテンシ: ~400μs (scripted アクター基準)
os.stat() は全レイテンシの 0.5~2% 程度だ。無視して問題ない。
まとめ
ファイル mtime 確認という最もシンプルなアプローチが、複雑なファイルウォッチャーやメッセージキューよりも実用的だ。プラットフォーム非依存で、追加依存性がなく、テストしやすい。本番環境で 3 ヶ月使用しており、予期しないリロードは一度もなかった。