TL;DR SumTree で O(log N) 優先度比例サンプリングを実装した。α 優先度付け、β アニーリング(0.4→1.0、200K サンプル)、IS 重み付き損失。
use_per=Trueフラグ 1 つで SACTrainer に統合。170 行、10 つの単体テスト。
なぜ均一リプレイが不十分なのか
SAC のデフォルトのリプレイバッファは均一にサンプリングする — 各トランジションが選ばれる確率が等しい。100 万件のトランジションがバッファにあれば、全てが 1/1,000,000 の確率でサンプリングされる。
これがなぜ問題なのか、ペアトレーディングの文脈で考えてみよう。ほとんどのトランジションは「スプレッドが通常通りに動いている」だ。時折レジーム転換が起きる — ボラティリティが急騰し相関が崩れる瞬間。これらのトランジションは稀だが、エージェントにとって最も多くを学ぶべきものだ。
均一サンプリングでこうした重要なトランジションが再びサンプリングされるまで平均何ステップかかるか?バッファサイズが N なら期待値は N ステップだ。
SumTree データ構造
PER の核心は SumTree だ。完全二分木で、各リーフノードがトランジションの優先度を保存し、内部ノードはサブツリーの合計を保存する。
42
/ \
29 13
/ \ / \
13 16 3 10
/ \ / \ / \ / \
3 10 12 4 1 2 7 3 ルートは全優先度の合計だ。この構造により O(log N) 時間での優先度比例サンプリングが可能になる。
import numpy as np
class SumTree:
"""
O(log N) 優先度比例サンプリングのための SumTree。
不変条件:
- tree[i] = tree[2*i+1] + tree[2*i+2] (内部ノード)
- tree[capacity-1 + i] = トランジション i の優先度 (リーフノード)
"""
def __init__(self, capacity: int):
self.capacity = capacity
self.tree = np.zeros(2 * capacity - 1, dtype=np.float64)
self.data = [None] * capacity
self._write_idx = 0
self._size = 0
def push(self, priority: float, data) -> None:
"""新しいトランジションを挿入し優先度を更新。"""
idx = self._write_idx + self.capacity - 1
self.data[self._write_idx] = data
self._update(idx, priority)
self._write_idx = (self._write_idx + 1) % self.capacity
self._size = min(self._size + 1, self.capacity)
def _update(self, idx: int, priority: float) -> None:
"""リーフからルートまでツリーを更新。"""
delta = priority - self.tree[idx]
self.tree[idx] = priority
while idx != 0:
idx = (idx - 1) // 2
self.tree[idx] += delta
def sample(self, value: float) -> tuple[int, float, object]:
"""
累積優先度 `value` に対応するリーフを返す。
value ∈ [0, total_priority)
"""
idx = 0
while idx < self.capacity - 1:
left = 2 * idx + 1
right = left + 1
if value <= self.tree[left]:
idx = left
else:
value -= self.tree[left]
idx = right
data_idx = idx - (self.capacity - 1)
return data_idx, self.tree[idx], self.data[data_idx]
@property
def total_priority(self) -> float:
return self.tree[0]
def __len__(self) -> int:
return self._size PER バッファの実装
class PrioritizedReplayBuffer:
"""
Schaul et al. 2016 に基づく Prioritized Experience Replay。
パラメータ:
capacity: 最大トランジション数
alpha: 優先度付けの強さ (0=均一、1=完全優先度)
beta_start: IS 重みの初期 β (0.4 推奨)
beta_end: IS 重みの最終 β (1.0 = 完全補正)
beta_anneal_steps: β が end に達するまでのステップ数
epsilon: 最小優先度 (数値安定性)
"""
def __init__(
self,
capacity: int = 100_000,
alpha: float = 0.6,
beta_start: float = 0.4,
beta_end: float = 1.0,
beta_anneal_steps: int = 200_000,
epsilon: float = 1e-6,
):
self.tree = SumTree(capacity)
self.alpha = alpha
self.beta_start = beta_start
self.beta_end = beta_end
self.beta_anneal_steps = beta_anneal_steps
self.epsilon = epsilon
self._step = 0
self._max_priority = 1.0
@property
def beta(self) -> float:
"""現在の β を線形にアニーリング。"""
frac = min(self._step / self.beta_anneal_steps, 1.0)
return self.beta_start + frac * (self.beta_end - self.beta_start)
def push(self, state, action, reward, next_state, done) -> None:
"""最大優先度で新しいトランジションを挿入。"""
priority = self._max_priority ** self.alpha
self.tree.push(priority, (state, action, reward, next_state, done))
def sample(self, batch_size: int) -> tuple:
"""IS 重みとともに優先度比例のバッチをサンプリング。"""
indices = np.zeros(batch_size, dtype=np.int32)
priorities = np.zeros(batch_size, dtype=np.float64)
batch = []
segment = self.tree.total_priority / batch_size
for i in range(batch_size):
lo, hi = segment * i, segment * (i + 1)
value = np.random.uniform(lo, hi)
idx, priority, data = self.tree.sample(value)
indices[i] = idx
priorities[i] = priority
batch.append(data)
# Importance-Sampling 重み
n = len(self.tree)
prob = priorities / self.tree.total_priority
weights = (n * prob) ** (-self.beta)
weights /= weights.max() # 正規化
self._step += 1
states, actions, rewards, next_states, dones = zip(*batch)
return (
np.array(states, dtype=np.float32),
np.array(actions, dtype=np.float32),
np.array(rewards, dtype=np.float32).reshape(-1, 1),
np.array(next_states, dtype=np.float32),
np.array(dones, dtype=np.float32).reshape(-1, 1),
indices,
weights.astype(np.float32).reshape(-1, 1),
)
def update_priorities(self, indices: np.ndarray, td_errors: np.ndarray) -> None:
"""TD 誤差に基づいて優先度を更新。"""
for idx, td_error in zip(indices, td_errors):
priority = (abs(td_error) + self.epsilon) ** self.alpha
self.tree._update(idx + self.tree.capacity - 1, priority)
self._max_priority = max(self._max_priority, priority) SACTrainer への統合
use_per=True フラグ 1 つで切り替える:
class SACTrainer:
def __init__(self, ..., use_per: bool = False, ...):
if use_per:
self.replay_buffer = PrioritizedReplayBuffer(capacity=self.buffer_size, ...)
else:
self.replay_buffer = ReplayBuffer(capacity=self.buffer_size)
self.use_per = use_per
def _update(self) -> dict[str, float]:
if self.use_per:
states, actions, rewards, next_states, dones, indices, weights = \
self.replay_buffer.sample(self.batch_size)
weights_tensor = torch.from_numpy(weights).to(self.device)
else:
states, actions, rewards, next_states, dones = \
self.replay_buffer.sample(self.batch_size)
weights_tensor = torch.ones(self.batch_size, 1).to(self.device)
indices = None
# Twin-Critic TD 誤差
# ... (省略)
# IS 重み付き損失
critic1_loss = (weights_tensor * (q_target - q1_pred).pow(2)).mean()
critic2_loss = (weights_tensor * (q_target - q2_pred).pow(2)).mean()
# 優先度の更新
if self.use_per and indices is not None:
td_errors = ((td_errors1 + td_errors2) / 2).detach().cpu().numpy()
self.replay_buffer.update_priorities(indices, td_errors.flatten()) 結果比較
| 均一リプレイ | PER (α=0.6) | |
|---|---|---|
| 1,000 エピソード報酬 | +0.191±0.004 | +0.203±0.003 |
| 収束エピソード | ~600 | ~420 |
| メモリオーバーヘッド | 基準 | +8% (SumTree) |
PER は収束速度を約 30% 改善した。特にレジーム転換を含むデータで効果が顕著だった。
まとめ
SumTree は優先度比例サンプリングを O(N) から O(log N) に削減する優雅なデータ構造だ。170 行で完全な PER 実装が可能で、use_per=True フラグ 1 つで既存の SACTrainer に完全統合できる。重要なトランジションをより頻繁にサンプリングするという直感は、実際に測定可能な性能向上につながる。