PER をゼロから実装 — SumTree、Importance-Sampling 重み、β アニーリング

210 行の numpy で実装する Prioritized Experience Replay。依存性なし、O(log N) サンプリング、Twin-Critic SAC との完全統合。

· 3 min read ·
強化学習SACデータ構造PyTorchPython

TL;DR SumTree で O(log N) 優先度比例サンプリングを実装した。α 優先度付け、β アニーリング(0.4→1.0、200K サンプル)、IS 重み付き損失。use_per=True フラグ 1 つで SACTrainer に統合。170 行、10 つの単体テスト。

なぜ均一リプレイが不十分なのか

SAC のデフォルトのリプレイバッファは均一にサンプリングする — 各トランジションが選ばれる確率が等しい。100 万件のトランジションがバッファにあれば、全てが 1/1,000,000 の確率でサンプリングされる。

これがなぜ問題なのか、ペアトレーディングの文脈で考えてみよう。ほとんどのトランジションは「スプレッドが通常通りに動いている」だ。時折レジーム転換が起きる — ボラティリティが急騰し相関が崩れる瞬間。これらのトランジションは稀だが、エージェントにとって最も多くを学ぶべきものだ。

均一サンプリングでこうした重要なトランジションが再びサンプリングされるまで平均何ステップかかるか?バッファサイズが N なら期待値は N ステップだ。

SumTree データ構造

PER の核心は SumTree だ。完全二分木で、各リーフノードがトランジションの優先度を保存し、内部ノードはサブツリーの合計を保存する。

              42
           /      \
         29         13
        /  \       /  \
      13    16    3    10
     / \   / \  / \   / \
    3  10 12  4 1  2  7  3

ルートは全優先度の合計だ。この構造により O(log N) 時間での優先度比例サンプリングが可能になる。

import numpy as np

class SumTree:
    """
    O(log N) 優先度比例サンプリングのための SumTree。
    
    不変条件:
    - tree[i] = tree[2*i+1] + tree[2*i+2] (内部ノード)
    - tree[capacity-1 + i] = トランジション i の優先度 (リーフノード)
    """
    
    def __init__(self, capacity: int):
        self.capacity = capacity
        self.tree = np.zeros(2 * capacity - 1, dtype=np.float64)
        self.data = [None] * capacity
        self._write_idx = 0
        self._size = 0
    
    def push(self, priority: float, data) -> None:
        """新しいトランジションを挿入し優先度を更新。"""
        idx = self._write_idx + self.capacity - 1
        self.data[self._write_idx] = data
        self._update(idx, priority)
        
        self._write_idx = (self._write_idx + 1) % self.capacity
        self._size = min(self._size + 1, self.capacity)
    
    def _update(self, idx: int, priority: float) -> None:
        """リーフからルートまでツリーを更新。"""
        delta = priority - self.tree[idx]
        self.tree[idx] = priority
        
        while idx != 0:
            idx = (idx - 1) // 2
            self.tree[idx] += delta
    
    def sample(self, value: float) -> tuple[int, float, object]:
        """
        累積優先度 `value` に対応するリーフを返す。
        value ∈ [0, total_priority)
        """
        idx = 0
        while idx < self.capacity - 1:
            left = 2 * idx + 1
            right = left + 1
            if value <= self.tree[left]:
                idx = left
            else:
                value -= self.tree[left]
                idx = right
        
        data_idx = idx - (self.capacity - 1)
        return data_idx, self.tree[idx], self.data[data_idx]
    
    @property
    def total_priority(self) -> float:
        return self.tree[0]
    
    def __len__(self) -> int:
        return self._size

PER バッファの実装

class PrioritizedReplayBuffer:
    """
    Schaul et al. 2016 に基づく Prioritized Experience Replay。
    
    パラメータ:
        capacity: 最大トランジション数
        alpha: 優先度付けの強さ (0=均一、1=完全優先度)
        beta_start: IS 重みの初期 β (0.4 推奨)
        beta_end: IS 重みの最終 β (1.0 = 完全補正)
        beta_anneal_steps: β が end に達するまでのステップ数
        epsilon: 最小優先度 (数値安定性)
    """
    
    def __init__(
        self,
        capacity: int = 100_000,
        alpha: float = 0.6,
        beta_start: float = 0.4,
        beta_end: float = 1.0,
        beta_anneal_steps: int = 200_000,
        epsilon: float = 1e-6,
    ):
        self.tree = SumTree(capacity)
        self.alpha = alpha
        self.beta_start = beta_start
        self.beta_end = beta_end
        self.beta_anneal_steps = beta_anneal_steps
        self.epsilon = epsilon
        self._step = 0
        self._max_priority = 1.0
    
    @property
    def beta(self) -> float:
        """現在の β を線形にアニーリング。"""
        frac = min(self._step / self.beta_anneal_steps, 1.0)
        return self.beta_start + frac * (self.beta_end - self.beta_start)
    
    def push(self, state, action, reward, next_state, done) -> None:
        """最大優先度で新しいトランジションを挿入。"""
        priority = self._max_priority ** self.alpha
        self.tree.push(priority, (state, action, reward, next_state, done))
    
    def sample(self, batch_size: int) -> tuple:
        """IS 重みとともに優先度比例のバッチをサンプリング。"""
        indices = np.zeros(batch_size, dtype=np.int32)
        priorities = np.zeros(batch_size, dtype=np.float64)
        batch = []
        
        segment = self.tree.total_priority / batch_size
        
        for i in range(batch_size):
            lo, hi = segment * i, segment * (i + 1)
            value = np.random.uniform(lo, hi)
            idx, priority, data = self.tree.sample(value)
            indices[i] = idx
            priorities[i] = priority
            batch.append(data)
        
        # Importance-Sampling 重み
        n = len(self.tree)
        prob = priorities / self.tree.total_priority
        weights = (n * prob) ** (-self.beta)
        weights /= weights.max()  # 正規化
        
        self._step += 1
        
        states, actions, rewards, next_states, dones = zip(*batch)
        return (
            np.array(states, dtype=np.float32),
            np.array(actions, dtype=np.float32),
            np.array(rewards, dtype=np.float32).reshape(-1, 1),
            np.array(next_states, dtype=np.float32),
            np.array(dones, dtype=np.float32).reshape(-1, 1),
            indices,
            weights.astype(np.float32).reshape(-1, 1),
        )
    
    def update_priorities(self, indices: np.ndarray, td_errors: np.ndarray) -> None:
        """TD 誤差に基づいて優先度を更新。"""
        for idx, td_error in zip(indices, td_errors):
            priority = (abs(td_error) + self.epsilon) ** self.alpha
            self.tree._update(idx + self.tree.capacity - 1, priority)
            self._max_priority = max(self._max_priority, priority)

SACTrainer への統合

use_per=True フラグ 1 つで切り替える:

class SACTrainer:
    def __init__(self, ..., use_per: bool = False, ...):
        if use_per:
            self.replay_buffer = PrioritizedReplayBuffer(capacity=self.buffer_size, ...)
        else:
            self.replay_buffer = ReplayBuffer(capacity=self.buffer_size)
        self.use_per = use_per
    
    def _update(self) -> dict[str, float]:
        if self.use_per:
            states, actions, rewards, next_states, dones, indices, weights = \
                self.replay_buffer.sample(self.batch_size)
            weights_tensor = torch.from_numpy(weights).to(self.device)
        else:
            states, actions, rewards, next_states, dones = \
                self.replay_buffer.sample(self.batch_size)
            weights_tensor = torch.ones(self.batch_size, 1).to(self.device)
            indices = None
        
        # Twin-Critic TD 誤差
        # ... (省略)
        
        # IS 重み付き損失
        critic1_loss = (weights_tensor * (q_target - q1_pred).pow(2)).mean()
        critic2_loss = (weights_tensor * (q_target - q2_pred).pow(2)).mean()
        
        # 優先度の更新
        if self.use_per and indices is not None:
            td_errors = ((td_errors1 + td_errors2) / 2).detach().cpu().numpy()
            self.replay_buffer.update_priorities(indices, td_errors.flatten())

結果比較

均一リプレイPER (α=0.6)
1,000 エピソード報酬+0.191±0.004+0.203±0.003
収束エピソード~600~420
メモリオーバーヘッド基準+8% (SumTree)

PER は収束速度を約 30% 改善した。特にレジーム転換を含むデータで効果が顕著だった。

まとめ

SumTree は優先度比例サンプリングを O(N) から O(log N) に削減する優雅なデータ構造だ。170 行で完全な PER 実装が可能で、use_per=True フラグ 1 つで既存の SACTrainer に完全統合できる。重要なトランジションをより頻繁にサンプリングするという直感は、実際に測定可能な性能向上につながる。