[심층 강화학습] 14. 연속 행동 공간: DDPG와 분포 정책
[심층 강화학습] 14. 연속 행동 공간: DDPG와 분포 정책
개요
지금까지 다룬 대부분의 환경은 이산 행동 공간(왼쪽/오른쪽, 각 방향키 등)을 가졌다. 하지만 로봇 제어, 자율주행, 물리 시뮬레이션 같은 현실 문제에서는 연속 행동 공간이 필요하다. 관절 각도, 토크, 가속도 등은 실수 값이기 때문이다.
이 글에서는 연속 행동 공간에 대한 정책 설계, A2C 확장, DDPG(Deep Deterministic Policy Gradient), 그리고 분포 정책 그래디언트를 다룬다.
왜 연속 공간이 필요한가
이산화의 한계
연속 행동을 이산화하는 방법(예: 토크를 -1, -0.5, 0, 0.5, 1로 분할)은 간단하지만 문제가 있다:
- 차원의 저주: N개의 관절 각각을 K단계로 이산화하면 행동 공간이 K^N으로 폭발한다
- 정밀도 부족: 이산화 간격 사이의 세밀한 제어가 불가능하다
- 비효율성: 대부분의 이산 행동 조합은 무의미하다
예를 들어 6관절 로봇을 10단계로 이산화하면 10^6 = 백만 개의 행동이 생긴다.
연속 행동 공간의 표현
연속 정책은 행동의 확률 분포를 출력한다:
- 가우시안 정책: 평균과 분산을 출력하여 정규분포에서 행동을 샘플링
- 베타 정책: 유한 범위의 행동에 적합한 베타 분포 사용
- 결정적 정책: 하나의 행동 값을 직접 출력 (DDPG)
행동 공간 설계
import torch
import torch.nn as nn
import numpy as np
class ContinuousActionSpace:
    """Box-style continuous action space bounded element-wise by [low, high]."""

    def __init__(self, low, high):
        # Bounds are stored as float32 arrays; shape mirrors the lower bound.
        self.low = np.array(low, dtype=np.float32)
        self.high = np.array(high, dtype=np.float32)
        self.shape = self.low.shape

    def sample(self):
        """Draw a uniformly random action from the box."""
        return np.random.uniform(self.low, self.high)

    def clip(self, action):
        """Clamp an action element-wise into [low, high]."""
        return np.clip(action, self.low, self.high)
A2C를 연속 행동 공간에 적용
기존 A2C의 카테고리컬 분포를 가우시안 분포로 교체한다:
class ContinuousActorCritic(nn.Module):
    """Actor-Critic for continuous action spaces.

    The actor outputs the mean of a Gaussian policy (squashed to [-1, 1]
    by tanh); the log standard deviation is a state-independent learnable
    parameter. The critic shares the feature trunk and outputs V(s).
    """

    def __init__(self, obs_size, act_size):
        super().__init__()
        # Shared feature extractor used by both the actor head and the critic.
        self.shared = nn.Sequential(
            nn.Linear(obs_size, 256),
            nn.ReLU(),
            nn.Linear(256, 256),
            nn.ReLU(),
        )
        # Actor head: mean of the Gaussian policy.
        self.mu = nn.Sequential(
            nn.Linear(256, act_size),
            nn.Tanh(),  # constrain the mean to the [-1, 1] range
        )
        # Log standard deviation learned as a state-independent parameter.
        self.log_std = nn.Parameter(torch.zeros(act_size))
        # Critic head: scalar state value V(s).
        self.value = nn.Linear(256, 1)

    def forward(self, obs):
        """Return (mu, std, value) for a batch of observations."""
        features = self.shared(obs)
        mu = self.mu(features)
        std = self.log_std.exp()
        value = self.value(features)
        return mu, std, value

    def get_action(self, obs):
        """Sample an action from the Gaussian policy.

        Returns (action, summed log-probability, state value).
        """
        mu, std, value = self.forward(obs)
        dist = torch.distributions.Normal(mu, std)
        action = dist.sample()
        # Sum per-dimension log-probs: action dims are independent Gaussians.
        log_prob = dist.log_prob(action).sum(dim=-1)
        return action, log_prob, value

    def evaluate(self, obs, action):
        """Re-evaluate stored actions under the current policy.

        Returns (log_prob, entropy, value); used during the update step.
        """
        mu, std, value = self.forward(obs)
        dist = torch.distributions.Normal(mu, std)
        log_prob = dist.log_prob(action).sum(dim=-1)
        entropy = dist.entropy().sum(dim=-1)
        return log_prob, entropy, value
연속 A2C 학습
def train_continuous_a2c(env, model, num_steps=2048, num_updates=1000,
                         gamma=0.99, lr=3e-4, entropy_coef=0.01):
    """Train a continuous-action A2C agent.

    Collects `num_steps` transitions per update, computes GAE advantages,
    and performs one combined actor/critic gradient step.

    Args:
        env: environment with the classic gym API (`reset()` -> obs,
            `step(a)` -> (obs, reward, done, info)).
        model: module exposing `get_action(obs)` and `evaluate(obs, actions)`
            (e.g. ContinuousActorCritic).
        num_steps: rollout length per update.
        num_updates: number of gradient updates to run.
        gamma: discount factor.
        lr: Adam learning rate.
        entropy_coef: weight of the entropy bonus (encourages exploration).
    """
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    obs = env.reset()
    for update in range(num_updates):
        # --- Collect a rollout of experience ---
        # NOTE: rollout log-probs are not stored; they are recomputed in
        # model.evaluate() below so gradients flow through current params.
        # (The original also collected them into an unused list.)
        observations = []
        actions = []
        rewards = []
        dones = []
        values = []
        for _ in range(num_steps):
            obs_t = torch.FloatTensor(obs).unsqueeze(0)
            action, log_prob, value = model.get_action(obs_t)
            action_np = action.detach().numpy().flatten()
            # The Gaussian policy is unbounded, so clip to the valid range
            # before stepping the environment.
            action_clipped = np.clip(action_np, env.action_space.low,
                                     env.action_space.high)
            next_obs, reward, done, _ = env.step(action_clipped)
            observations.append(obs)
            actions.append(action.detach())
            rewards.append(reward)
            dones.append(done)
            values.append(value.squeeze())
            obs = next_obs if not done else env.reset()
        # --- GAE (Generalized Advantage Estimation) ---
        advantages, returns = compute_gae(rewards, values, dones, gamma)
        # --- Single policy/value update on the full rollout ---
        obs_batch = torch.FloatTensor(np.array(observations))
        act_batch = torch.cat(actions)
        adv_batch = torch.FloatTensor(advantages)
        ret_batch = torch.FloatTensor(returns)
        new_log_probs, entropy, new_values = model.evaluate(
            obs_batch, act_batch
        )
        policy_loss = -(new_log_probs * adv_batch).mean()
        value_loss = (ret_batch - new_values.squeeze()).pow(2).mean()
        entropy_loss = -entropy.mean()
        loss = policy_loss + 0.5 * value_loss + entropy_coef * entropy_loss
        optimizer.zero_grad()
        loss.backward()
        # Gradient clipping stabilizes training with large advantages.
        torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
        optimizer.step()
def compute_gae(rewards, values, dones, gamma, lam=0.95):
    """Generalized Advantage Estimation (GAE-lambda).

    Walks the rollout backwards, accumulating the exponentially weighted
    sum of TD errors. Done flags reset the accumulator so advantages never
    bootstrap across episode boundaries; the value after the final rollout
    step is bootstrapped as 0.

    Returns:
        (advantages, returns) as plain Python lists aligned with rewards.
    """
    num_steps = len(rewards)
    baseline = [v.detach().item() for v in values]
    baseline.append(0)  # bootstrap value for the step after the rollout
    advantages = [0.0] * num_steps
    running = 0
    for t in range(num_steps - 1, -1, -1):
        if dones[t]:
            # Terminal step: no bootstrap term, restart the accumulator.
            running = rewards[t] - baseline[t]
        else:
            td_error = rewards[t] + gamma * baseline[t + 1] - baseline[t]
            running = td_error + gamma * lam * running
        advantages[t] = running
    returns = [adv + val for adv, val in zip(advantages, baseline[:-1])]
    return advantages, returns
결정적 정책 그래디언트 (DDPG)
DDPG의 핵심 아이디어
DDPG(Deep Deterministic Policy Gradient) 는 DQN의 아이디어를 연속 행동 공간에 적용한 알고리즘이다:
- 결정적 정책: 확률적으로 샘플링하지 않고, 상태에 대해 하나의 행동을 직접 출력
- 경험 리플레이: DQN처럼 과거 경험을 버퍼에 저장하고 재사용
- 타겟 네트워크: 학습 안정화를 위한 소프트 업데이트
- 탐색 노이즈: 결정적 정책에 노이즈를 추가하여 탐색
네트워크 구조
class DDPGActor(nn.Module):
    """Deterministic policy network: maps a state to a single action.

    The tanh output is scaled by `act_limit` so each action dimension
    covers [-act_limit, act_limit].
    """

    def __init__(self, obs_size, act_size, act_limit):
        super().__init__()
        # Maximum absolute action value; rescales the tanh output.
        self.act_limit = act_limit
        self.net = nn.Sequential(
            nn.Linear(obs_size, 400),
            nn.ReLU(),
            nn.Linear(400, 300),
            nn.ReLU(),
            nn.Linear(300, act_size),
            nn.Tanh(),
        )

    def forward(self, obs):
        # Scale from [-1, 1] to the environment's action range.
        return self.act_limit * self.net(obs)
class DDPGCritic(nn.Module):
    """Action-value function Q(s, a).

    State and action are concatenated and fed through an MLP producing a
    single scalar Q-value per sample.
    """

    def __init__(self, obs_size, act_size):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(obs_size + act_size, 400),
            nn.ReLU(),
            nn.Linear(400, 300),
            nn.ReLU(),
            nn.Linear(300, 1),
        )

    def forward(self, obs, action):
        """Return Q(s, a), shape (batch, 1)."""
        x = torch.cat([obs, action], dim=-1)
        return self.net(x)
탐색 노이즈: Ornstein-Uhlenbeck 프로세스
DDPG는 시간적 상관관계가 있는 OU 노이즈를 사용하여 연속적이고 부드러운 탐색을 수행한다:
class OrnsteinUhlenbeckNoise:
    """Ornstein-Uhlenbeck process for temporally correlated exploration.

    Each sample() call nudges the internal state toward `mu` at rate
    `theta` and perturbs it with Gaussian noise of scale `sigma`,
    producing smooth, correlated noise suited to physical control.
    """

    def __init__(self, size, mu=0.0, theta=0.15, sigma=0.2):
        self.size = size
        self.mu = mu
        self.theta = theta
        self.sigma = sigma
        self.state = np.ones(size) * mu

    def reset(self):
        """Restart the process at its long-run mean."""
        self.state = np.ones(self.size) * self.mu

    def sample(self):
        """Advance the process one step and return a copy of the state."""
        drift = self.theta * (self.mu - self.state)
        diffusion = self.sigma * np.random.randn(self.size)
        self.state += drift + diffusion
        return self.state.copy()
DDPG 전체 구현
import copy
from collections import deque
import random
class ReplayBuffer:
    """Fixed-capacity FIFO store of transitions for off-policy learning."""

    def __init__(self, capacity=1000000):
        # deque silently drops the oldest transition once full.
        self.buffer = deque(maxlen=capacity)

    def push(self, state, action, reward, next_state, done):
        """Append one (s, a, r, s', done) transition."""
        self.buffer.append((state, action, reward, next_state, done))

    def sample(self, batch_size):
        """Return a uniformly sampled minibatch as stacked numpy arrays."""
        transitions = random.sample(self.buffer, batch_size)
        states, actions, rewards, next_states, dones = zip(*transitions)
        return (np.array(states),
                np.array(actions),
                np.array(rewards, dtype=np.float32),
                np.array(next_states),
                np.array(dones, dtype=np.float32))

    def __len__(self):
        return len(self.buffer)
class DDPGAgent:
    """DDPG agent: deterministic actor, Q critic, target networks,
    replay buffer, and Ornstein-Uhlenbeck exploration noise.
    """

    def __init__(self, obs_size, act_size, act_limit,
                 gamma=0.99, tau=0.005, lr_actor=1e-4, lr_critic=1e-3):
        self.gamma = gamma  # discount factor
        self.tau = tau      # soft-update interpolation coefficient
        # Main networks.
        self.actor = DDPGActor(obs_size, act_size, act_limit)
        self.critic = DDPGCritic(obs_size, act_size)
        # Target networks, updated slowly via Polyak averaging for stability.
        self.target_actor = copy.deepcopy(self.actor)
        self.target_critic = copy.deepcopy(self.critic)
        self.actor_optimizer = torch.optim.Adam(
            self.actor.parameters(), lr=lr_actor)
        self.critic_optimizer = torch.optim.Adam(
            self.critic.parameters(), lr=lr_critic)
        self.replay_buffer = ReplayBuffer()
        self.noise = OrnsteinUhlenbeckNoise(act_size)

    def select_action(self, state, add_noise=True):
        """Compute the deterministic action, optionally adding OU noise.

        Noise is for training-time exploration; pass add_noise=False at
        evaluation. The result is clipped to the valid action range.
        """
        state_t = torch.FloatTensor(state).unsqueeze(0)
        action = self.actor(state_t).detach().numpy().flatten()
        if add_noise:
            action += self.noise.sample()
        return np.clip(action, -self.actor.act_limit, self.actor.act_limit)

    def update(self, batch_size=256):
        """One gradient step for critic and actor from a replay minibatch.

        No-op until the buffer holds at least `batch_size` transitions.
        """
        if len(self.replay_buffer) < batch_size:
            return
        states, actions, rewards, next_states, dones = \
            self.replay_buffer.sample(batch_size)
        states_t = torch.FloatTensor(states)
        actions_t = torch.FloatTensor(actions)
        rewards_t = torch.FloatTensor(rewards).unsqueeze(1)
        next_states_t = torch.FloatTensor(next_states)
        dones_t = torch.FloatTensor(dones).unsqueeze(1)
        # === Critic update: regress Q(s, a) onto the Bellman target ===
        with torch.no_grad():
            # Target networks compute the bootstrap term for stability.
            next_actions = self.target_actor(next_states_t)
            target_q = self.target_critic(next_states_t, next_actions)
            # (1 - done) zeroes the bootstrap at terminal transitions.
            target_value = rewards_t + self.gamma * (1 - dones_t) * target_q
        current_q = self.critic(states_t, actions_t)
        critic_loss = nn.MSELoss()(current_q, target_value)
        self.critic_optimizer.zero_grad()
        critic_loss.backward()
        self.critic_optimizer.step()
        # === Actor update ===
        # Policy objective: maximize Q(s, mu(s)) -> minimize its negative.
        predicted_actions = self.actor(states_t)
        actor_loss = -self.critic(states_t, predicted_actions).mean()
        self.actor_optimizer.zero_grad()
        actor_loss.backward()
        self.actor_optimizer.step()
        # === Soft-update target networks toward the main networks ===
        self._soft_update(self.actor, self.target_actor)
        self._soft_update(self.critic, self.target_critic)

    def _soft_update(self, source, target):
        """Polyak averaging: target <- tau * source + (1 - tau) * target."""
        for src_param, tgt_param in zip(source.parameters(),
                                        target.parameters()):
            tgt_param.data.copy_(
                self.tau * src_param.data + (1 - self.tau) * tgt_param.data
            )
분포 정책 그래디언트
D4PG: Distributed Distributional DDPG
D4PG는 DDPG에 분포 강화학습(Distributional RL)을 결합한 알고리즘이다. Q 값의 기대값 대신 전체 분포를 학습한다.
class DistributionalCritic(nn.Module):
    """Distributional Q-function: predicts a categorical distribution over
    returns (C51-style) instead of a scalar expectation.

    The distribution is supported on `num_atoms` evenly spaced atoms in
    [v_min, v_max]; the network outputs logits over those atoms.
    """

    def __init__(self, obs_size, act_size,
                 num_atoms=51, v_min=-10.0, v_max=10.0):
        super().__init__()
        self.num_atoms = num_atoms
        self.v_min = v_min
        self.v_max = v_max
        # Register the support as a buffer so it follows the module across
        # .to(device) and state_dict, unlike a plain tensor attribute
        # (which would stay on CPU and break GPU training).
        self.register_buffer(
            "support", torch.linspace(v_min, v_max, num_atoms))
        # Spacing between adjacent atoms.
        self.delta_z = (v_max - v_min) / (num_atoms - 1)
        self.net = nn.Sequential(
            nn.Linear(obs_size + act_size, 400),
            nn.ReLU(),
            nn.Linear(400, 300),
            nn.ReLU(),
            nn.Linear(300, num_atoms),
        )

    def forward(self, obs, action):
        """Return per-atom probabilities, shape (batch, num_atoms)."""
        x = torch.cat([obs, action], dim=-1)
        logits = self.net(x)
        probs = torch.softmax(logits, dim=-1)
        return probs

    def get_q_value(self, obs, action):
        """Expected Q-value: probability-weighted sum over the support."""
        probs = self.forward(obs, action)
        return (probs * self.support.unsqueeze(0)).sum(dim=-1, keepdim=True)
분포 TD 학습
def distributional_critic_loss(critic, target_critic, target_actor,
                               states, actions, rewards, next_states,
                               dones, gamma=0.99):
    """Distributional critic loss via categorical projection (C51-style).

    The target distribution's support is shifted by the Bellman backup and
    projected back onto the fixed atom grid; the loss is the cross-entropy
    between the projected target and the critic's predicted distribution.

    Args:
        critic: DistributionalCritic being trained (its raw logits are used).
        target_critic, target_actor: frozen target networks.
        states, actions, next_states: batch tensors.
        rewards, dones: shape (batch,) tensors.
        gamma: discount factor.

    Returns:
        Scalar cross-entropy loss tensor.
    """
    num_atoms = critic.num_atoms
    v_min = critic.v_min
    v_max = critic.v_max
    delta_z = critic.delta_z
    support = critic.support
    with torch.no_grad():
        next_actions = target_actor(next_states)
        target_probs = target_critic(next_states, next_actions)
        # Bellman-updated support: Tz = r + gamma * (1 - done) * z
        tz = rewards.unsqueeze(1) + gamma * (1 - dones.unsqueeze(1)) * \
            support.unsqueeze(0)
        tz = tz.clamp(v_min, v_max)
        # Categorical projection onto the fixed support.
        b = (tz - v_min) / delta_z
        l = b.floor().long()
        u = b.ceil().long()
        # BUGFIX: when Tz lands exactly on an atom, l == u and both
        # interpolation weights (u - b) and (b - l) are zero, silently
        # dropping that probability mass. Separate the indices so the
        # full mass is assigned (standard C51 projection fix).
        l[(u > 0) & (l == u)] -= 1
        u[(l < (num_atoms - 1)) & (l == u)] += 1
        projected = torch.zeros_like(target_probs)
        for i in range(num_atoms):
            # Split atom i's mass between its two neighboring atoms.
            projected.scatter_add_(1, l[:, i:i+1],
                target_probs[:, i:i+1] * (u[:, i:i+1].float() - b[:, i:i+1]))
            projected.scatter_add_(1, u[:, i:i+1],
                target_probs[:, i:i+1] * (b[:, i:i+1] - l[:, i:i+1].float()))
    # Cross-entropy between the projected target and the prediction.
    current_logits = critic.net(torch.cat([states, actions], dim=-1))
    loss = -(projected * torch.log_softmax(current_logits, dim=-1)).sum(dim=-1).mean()
    return loss
실험 결과 비교
MuJoCo의 HalfCheetah-v4 환경에서의 성능 비교:
| 알고리즘 | 100만 스텝 평균 보상 | 학습 안정성 | 샘플 효율성 |
|---|---|---|---|
| A2C (연속) | 약 3000 | 중간 | 낮음 |
| DDPG | 약 8000 | 낮음 (민감) | 높음 |
| D4PG | 약 9500 | 높음 | 높음 |
DDPG는 하이퍼파라미터에 민감하지만 잘 튜닝하면 높은 성능을 보인다. D4PG는 분포 학습 덕분에 더 안정적이다.
핵심 요약
- 연속 행동 공간은 가우시안 정책(A2C) 또는 결정적 정책(DDPG)으로 처리한다
- DDPG는 DQN의 핵심 기법(리플레이 버퍼, 타겟 네트워크)을 연속 공간에 적용한다
- OU 노이즈로 시간적 상관이 있는 부드러운 탐색을 수행한다
- 분포 정책 그래디언트(D4PG)는 리턴의 분포를 학습하여 안정성을 높인다
다음 글에서는 정책 업데이트의 안정성을 보장하는 Trust Region 방법(TRPO, PPO, ACKTR) 을 다룬다.
[Deep RL] 14. Continuous Action Spaces: DDPG and Distributional Policies
Overview
Most environments we have covered so far had discrete action spaces (left/right, directional keys, etc.). However, real-world problems like robot control, autonomous driving, and physics simulations require continuous action spaces. Joint angles, torques, and accelerations are real-valued quantities.
This post covers policy design for continuous action spaces, A2C extension, DDPG (Deep Deterministic Policy Gradient), and distributional policy gradients.
Why Continuous Spaces Are Needed
Limitations of Discretization
Discretizing continuous actions (e.g., splitting torque into -1, -0.5, 0, 0.5, 1) is simple but problematic:
- Curse of dimensionality: Discretizing each of N joints into K levels causes the action space to explode to K^N
- Lack of precision: Fine-grained control between discretization intervals is impossible
- Inefficiency: Most discrete action combinations are meaningless
For example, discretizing a 6-joint robot into 10 levels creates 10^6 = one million actions.
Representing Continuous Action Spaces
Continuous policies output probability distributions over actions:
- Gaussian policy: Outputs mean and variance, samples actions from a normal distribution
- Beta policy: Uses beta distribution suitable for bounded action ranges
- Deterministic policy: Directly outputs a single action value (DDPG)
Action Space Design
import torch
import torch.nn as nn
import numpy as np
class ContinuousActionSpace:
    """Box-style continuous action space bounded element-wise by [low, high]."""

    def __init__(self, low, high):
        # Bounds are stored as float32 arrays; shape mirrors the lower bound.
        self.low = np.array(low, dtype=np.float32)
        self.high = np.array(high, dtype=np.float32)
        self.shape = self.low.shape

    def sample(self):
        """Draw a uniformly random action from the box."""
        return np.random.uniform(self.low, self.high)

    def clip(self, action):
        """Clamp an action element-wise into [low, high]."""
        return np.clip(action, self.low, self.high)
Applying A2C to Continuous Action Spaces
We replace the categorical distribution in the existing A2C with a Gaussian distribution:
class ContinuousActorCritic(nn.Module):
    """Actor-Critic for continuous action spaces.

    The actor outputs the mean of a Gaussian policy (squashed to [-1, 1]
    by tanh); the log standard deviation is a state-independent learnable
    parameter. The critic shares the feature trunk and outputs V(s).
    """

    def __init__(self, obs_size, act_size):
        super().__init__()
        # Shared feature extractor used by both the actor head and the critic.
        self.shared = nn.Sequential(
            nn.Linear(obs_size, 256),
            nn.ReLU(),
            nn.Linear(256, 256),
            nn.ReLU(),
        )
        # Actor head: mean of the Gaussian policy.
        self.mu = nn.Sequential(
            nn.Linear(256, act_size),
            nn.Tanh(),  # constrain the mean to the [-1, 1] range
        )
        # Log standard deviation learned as a state-independent parameter.
        self.log_std = nn.Parameter(torch.zeros(act_size))
        # Critic head: scalar state value V(s).
        self.value = nn.Linear(256, 1)

    def forward(self, obs):
        """Return (mu, std, value) for a batch of observations."""
        features = self.shared(obs)
        mu = self.mu(features)
        std = self.log_std.exp()
        value = self.value(features)
        return mu, std, value

    def get_action(self, obs):
        """Sample an action from the Gaussian policy.

        Returns (action, summed log-probability, state value).
        """
        mu, std, value = self.forward(obs)
        dist = torch.distributions.Normal(mu, std)
        action = dist.sample()
        # Sum per-dimension log-probs: action dims are independent Gaussians.
        log_prob = dist.log_prob(action).sum(dim=-1)
        return action, log_prob, value

    def evaluate(self, obs, action):
        """Re-evaluate stored actions under the current policy.

        Returns (log_prob, entropy, value); used during the update step.
        """
        mu, std, value = self.forward(obs)
        dist = torch.distributions.Normal(mu, std)
        log_prob = dist.log_prob(action).sum(dim=-1)
        entropy = dist.entropy().sum(dim=-1)
        return log_prob, entropy, value
Continuous A2C Training
def train_continuous_a2c(env, model, num_steps=2048, num_updates=1000,
                         gamma=0.99, lr=3e-4, entropy_coef=0.01):
    """Train a continuous-action A2C agent.

    Collects `num_steps` transitions per update, computes GAE advantages,
    and performs one combined actor/critic gradient step.

    Args:
        env: environment with the classic gym API (`reset()` -> obs,
            `step(a)` -> (obs, reward, done, info)).
        model: module exposing `get_action(obs)` and `evaluate(obs, actions)`
            (e.g. ContinuousActorCritic).
        num_steps: rollout length per update.
        num_updates: number of gradient updates to run.
        gamma: discount factor.
        lr: Adam learning rate.
        entropy_coef: weight of the entropy bonus (encourages exploration).
    """
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    obs = env.reset()
    for update in range(num_updates):
        # --- Collect a rollout of experience ---
        # NOTE: rollout log-probs are not stored; they are recomputed in
        # model.evaluate() below so gradients flow through current params.
        # (The original also collected them into an unused list.)
        observations = []
        actions = []
        rewards = []
        dones = []
        values = []
        for _ in range(num_steps):
            obs_t = torch.FloatTensor(obs).unsqueeze(0)
            action, log_prob, value = model.get_action(obs_t)
            action_np = action.detach().numpy().flatten()
            # The Gaussian policy is unbounded, so clip to the valid range
            # before stepping the environment.
            action_clipped = np.clip(action_np, env.action_space.low,
                                     env.action_space.high)
            next_obs, reward, done, _ = env.step(action_clipped)
            observations.append(obs)
            actions.append(action.detach())
            rewards.append(reward)
            dones.append(done)
            values.append(value.squeeze())
            obs = next_obs if not done else env.reset()
        # --- GAE (Generalized Advantage Estimation) ---
        advantages, returns = compute_gae(rewards, values, dones, gamma)
        # --- Single policy/value update on the full rollout ---
        obs_batch = torch.FloatTensor(np.array(observations))
        act_batch = torch.cat(actions)
        adv_batch = torch.FloatTensor(advantages)
        ret_batch = torch.FloatTensor(returns)
        new_log_probs, entropy, new_values = model.evaluate(
            obs_batch, act_batch
        )
        policy_loss = -(new_log_probs * adv_batch).mean()
        value_loss = (ret_batch - new_values.squeeze()).pow(2).mean()
        entropy_loss = -entropy.mean()
        loss = policy_loss + 0.5 * value_loss + entropy_coef * entropy_loss
        optimizer.zero_grad()
        loss.backward()
        # Gradient clipping stabilizes training with large advantages.
        torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
        optimizer.step()
def compute_gae(rewards, values, dones, gamma, lam=0.95):
    """Generalized Advantage Estimation (GAE-lambda).

    Walks the rollout backwards, accumulating the exponentially weighted
    sum of TD errors. Done flags reset the accumulator so advantages never
    bootstrap across episode boundaries; the value after the final rollout
    step is bootstrapped as 0.

    Returns:
        (advantages, returns) as plain Python lists aligned with rewards.
    """
    num_steps = len(rewards)
    baseline = [v.detach().item() for v in values]
    baseline.append(0)  # bootstrap value for the step after the rollout
    advantages = [0.0] * num_steps
    running = 0
    for t in range(num_steps - 1, -1, -1):
        if dones[t]:
            # Terminal step: no bootstrap term, restart the accumulator.
            running = rewards[t] - baseline[t]
        else:
            td_error = rewards[t] + gamma * baseline[t + 1] - baseline[t]
            running = td_error + gamma * lam * running
        advantages[t] = running
    returns = [adv + val for adv, val in zip(advantages, baseline[:-1])]
    return advantages, returns
Deterministic Policy Gradient (DDPG)
Core Idea of DDPG
DDPG (Deep Deterministic Policy Gradient) applies DQN ideas to continuous action spaces:
- Deterministic policy: Instead of probabilistic sampling, directly outputs a single action for a given state
- Experience replay: Stores past experiences in a buffer and reuses them, like DQN
- Target network: Soft updates for learning stability
- Exploration noise: Adds noise to the deterministic policy for exploration
Network Architecture
class DDPGActor(nn.Module):
    """Deterministic policy network: maps a state to a single action.

    The tanh output is scaled by `act_limit` so each action dimension
    covers [-act_limit, act_limit].
    """

    def __init__(self, obs_size, act_size, act_limit):
        super().__init__()
        # Maximum absolute action value; rescales the tanh output.
        self.act_limit = act_limit
        self.net = nn.Sequential(
            nn.Linear(obs_size, 400),
            nn.ReLU(),
            nn.Linear(400, 300),
            nn.ReLU(),
            nn.Linear(300, act_size),
            nn.Tanh(),
        )

    def forward(self, obs):
        # Scale from [-1, 1] to the environment's action range.
        return self.act_limit * self.net(obs)
class DDPGCritic(nn.Module):
    """Action-value function Q(s, a).

    State and action are concatenated and fed through an MLP producing a
    single scalar Q-value per sample.
    """

    def __init__(self, obs_size, act_size):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(obs_size + act_size, 400),
            nn.ReLU(),
            nn.Linear(400, 300),
            nn.ReLU(),
            nn.Linear(300, 1),
        )

    def forward(self, obs, action):
        """Return Q(s, a), shape (batch, 1)."""
        x = torch.cat([obs, action], dim=-1)
        return self.net(x)
Exploration Noise: Ornstein-Uhlenbeck Process
DDPG uses temporally correlated OU noise for smooth, continuous exploration:
class OrnsteinUhlenbeckNoise:
    """Ornstein-Uhlenbeck process for temporally correlated exploration.

    Each sample() call nudges the internal state toward `mu` at rate
    `theta` and perturbs it with Gaussian noise of scale `sigma`,
    producing smooth, correlated noise suited to physical control.
    """

    def __init__(self, size, mu=0.0, theta=0.15, sigma=0.2):
        self.size = size
        self.mu = mu
        self.theta = theta
        self.sigma = sigma
        self.state = np.ones(size) * mu

    def reset(self):
        """Restart the process at its long-run mean."""
        self.state = np.ones(self.size) * self.mu

    def sample(self):
        """Advance the process one step and return a copy of the state."""
        drift = self.theta * (self.mu - self.state)
        diffusion = self.sigma * np.random.randn(self.size)
        self.state += drift + diffusion
        return self.state.copy()
Full DDPG Implementation
import copy
from collections import deque
import random
class ReplayBuffer:
    """Fixed-capacity FIFO store of transitions for off-policy learning."""

    def __init__(self, capacity=1000000):
        # deque silently drops the oldest transition once full.
        self.buffer = deque(maxlen=capacity)

    def push(self, state, action, reward, next_state, done):
        """Append one (s, a, r, s', done) transition."""
        self.buffer.append((state, action, reward, next_state, done))

    def sample(self, batch_size):
        """Return a uniformly sampled minibatch as stacked numpy arrays."""
        transitions = random.sample(self.buffer, batch_size)
        states, actions, rewards, next_states, dones = zip(*transitions)
        return (np.array(states),
                np.array(actions),
                np.array(rewards, dtype=np.float32),
                np.array(next_states),
                np.array(dones, dtype=np.float32))

    def __len__(self):
        return len(self.buffer)
class DDPGAgent:
    """DDPG agent: deterministic actor, Q critic, target networks,
    replay buffer, and Ornstein-Uhlenbeck exploration noise.
    """

    def __init__(self, obs_size, act_size, act_limit,
                 gamma=0.99, tau=0.005, lr_actor=1e-4, lr_critic=1e-3):
        self.gamma = gamma  # discount factor
        self.tau = tau      # soft-update interpolation coefficient
        # Main networks.
        self.actor = DDPGActor(obs_size, act_size, act_limit)
        self.critic = DDPGCritic(obs_size, act_size)
        # Target networks, updated slowly via Polyak averaging for stability.
        self.target_actor = copy.deepcopy(self.actor)
        self.target_critic = copy.deepcopy(self.critic)
        self.actor_optimizer = torch.optim.Adam(
            self.actor.parameters(), lr=lr_actor)
        self.critic_optimizer = torch.optim.Adam(
            self.critic.parameters(), lr=lr_critic)
        self.replay_buffer = ReplayBuffer()
        self.noise = OrnsteinUhlenbeckNoise(act_size)

    def select_action(self, state, add_noise=True):
        """Compute the deterministic action, optionally adding OU noise.

        Noise is for training-time exploration; pass add_noise=False at
        evaluation. The result is clipped to the valid action range.
        """
        state_t = torch.FloatTensor(state).unsqueeze(0)
        action = self.actor(state_t).detach().numpy().flatten()
        if add_noise:
            action += self.noise.sample()
        return np.clip(action, -self.actor.act_limit, self.actor.act_limit)

    def update(self, batch_size=256):
        """One gradient step for critic and actor from a replay minibatch.

        No-op until the buffer holds at least `batch_size` transitions.
        """
        if len(self.replay_buffer) < batch_size:
            return
        states, actions, rewards, next_states, dones = \
            self.replay_buffer.sample(batch_size)
        states_t = torch.FloatTensor(states)
        actions_t = torch.FloatTensor(actions)
        rewards_t = torch.FloatTensor(rewards).unsqueeze(1)
        next_states_t = torch.FloatTensor(next_states)
        dones_t = torch.FloatTensor(dones).unsqueeze(1)
        # === Critic update: regress Q(s, a) onto the Bellman target ===
        with torch.no_grad():
            # Target networks compute the bootstrap term for stability.
            next_actions = self.target_actor(next_states_t)
            target_q = self.target_critic(next_states_t, next_actions)
            # (1 - done) zeroes the bootstrap at terminal transitions.
            target_value = rewards_t + self.gamma * (1 - dones_t) * target_q
        current_q = self.critic(states_t, actions_t)
        critic_loss = nn.MSELoss()(current_q, target_value)
        self.critic_optimizer.zero_grad()
        critic_loss.backward()
        self.critic_optimizer.step()
        # === Actor update ===
        # Policy objective: maximize Q(s, mu(s)) -> minimize its negative.
        predicted_actions = self.actor(states_t)
        actor_loss = -self.critic(states_t, predicted_actions).mean()
        self.actor_optimizer.zero_grad()
        actor_loss.backward()
        self.actor_optimizer.step()
        # === Soft-update target networks toward the main networks ===
        self._soft_update(self.actor, self.target_actor)
        self._soft_update(self.critic, self.target_critic)

    def _soft_update(self, source, target):
        """Polyak averaging: target <- tau * source + (1 - tau) * target."""
        for src_param, tgt_param in zip(source.parameters(),
                                        target.parameters()):
            tgt_param.data.copy_(
                self.tau * src_param.data + (1 - self.tau) * tgt_param.data
            )
Distributional Policy Gradient
D4PG: Distributed Distributional DDPG
D4PG combines distributional reinforcement learning with DDPG. Instead of learning the expected Q-value, it learns the entire distribution of returns.
class DistributionalCritic(nn.Module):
    """Distributional Q-function: predicts a categorical distribution over
    returns (C51-style) instead of a scalar expectation.

    The distribution is supported on `num_atoms` evenly spaced atoms in
    [v_min, v_max]; the network outputs logits over those atoms.
    """

    def __init__(self, obs_size, act_size,
                 num_atoms=51, v_min=-10.0, v_max=10.0):
        super().__init__()
        self.num_atoms = num_atoms
        self.v_min = v_min
        self.v_max = v_max
        # Register the support as a buffer so it follows the module across
        # .to(device) and state_dict, unlike a plain tensor attribute
        # (which would stay on CPU and break GPU training).
        self.register_buffer(
            "support", torch.linspace(v_min, v_max, num_atoms))
        # Spacing between adjacent atoms.
        self.delta_z = (v_max - v_min) / (num_atoms - 1)
        self.net = nn.Sequential(
            nn.Linear(obs_size + act_size, 400),
            nn.ReLU(),
            nn.Linear(400, 300),
            nn.ReLU(),
            nn.Linear(300, num_atoms),
        )

    def forward(self, obs, action):
        """Return per-atom probabilities, shape (batch, num_atoms)."""
        x = torch.cat([obs, action], dim=-1)
        logits = self.net(x)
        probs = torch.softmax(logits, dim=-1)
        return probs

    def get_q_value(self, obs, action):
        """Expected Q-value: probability-weighted sum over the support."""
        probs = self.forward(obs, action)
        return (probs * self.support.unsqueeze(0)).sum(dim=-1, keepdim=True)
Distributional TD Learning
def distributional_critic_loss(critic, target_critic, target_actor,
                               states, actions, rewards, next_states,
                               dones, gamma=0.99):
    """Distributional critic loss via categorical projection (C51-style).

    The target distribution's support is shifted by the Bellman backup and
    projected back onto the fixed atom grid; the loss is the cross-entropy
    between the projected target and the critic's predicted distribution.

    Args:
        critic: DistributionalCritic being trained (its raw logits are used).
        target_critic, target_actor: frozen target networks.
        states, actions, next_states: batch tensors.
        rewards, dones: shape (batch,) tensors.
        gamma: discount factor.

    Returns:
        Scalar cross-entropy loss tensor.
    """
    num_atoms = critic.num_atoms
    v_min = critic.v_min
    v_max = critic.v_max
    delta_z = critic.delta_z
    support = critic.support
    with torch.no_grad():
        next_actions = target_actor(next_states)
        target_probs = target_critic(next_states, next_actions)
        # Bellman-updated support: Tz = r + gamma * (1 - done) * z
        tz = rewards.unsqueeze(1) + gamma * (1 - dones.unsqueeze(1)) * \
            support.unsqueeze(0)
        tz = tz.clamp(v_min, v_max)
        # Categorical projection onto the fixed support.
        b = (tz - v_min) / delta_z
        l = b.floor().long()
        u = b.ceil().long()
        # BUGFIX: when Tz lands exactly on an atom, l == u and both
        # interpolation weights (u - b) and (b - l) are zero, silently
        # dropping that probability mass. Separate the indices so the
        # full mass is assigned (standard C51 projection fix).
        l[(u > 0) & (l == u)] -= 1
        u[(l < (num_atoms - 1)) & (l == u)] += 1
        projected = torch.zeros_like(target_probs)
        for i in range(num_atoms):
            # Split atom i's mass between its two neighboring atoms.
            projected.scatter_add_(1, l[:, i:i+1],
                target_probs[:, i:i+1] * (u[:, i:i+1].float() - b[:, i:i+1]))
            projected.scatter_add_(1, u[:, i:i+1],
                target_probs[:, i:i+1] * (b[:, i:i+1] - l[:, i:i+1].float()))
    # Cross-entropy between the projected target and the prediction.
    current_logits = critic.net(torch.cat([states, actions], dim=-1))
    loss = -(projected * torch.log_softmax(current_logits, dim=-1)).sum(dim=-1).mean()
    return loss
Experimental Results Comparison
Performance comparison on MuJoCo's HalfCheetah-v4 environment:
| Algorithm | Average reward at 1M steps | Learning stability | Sample efficiency |
|---|---|---|---|
| A2C (continuous) | ~3000 | Medium | Low |
| DDPG | ~8000 | Low (sensitive) | High |
| D4PG | ~9500 | High | High |
DDPG is sensitive to hyperparameters but achieves high performance when well-tuned. D4PG is more stable thanks to distributional learning.
Key Takeaways
- Continuous action spaces are handled with Gaussian policies (A2C) or deterministic policies (DDPG)
- DDPG applies DQN's core techniques (replay buffer, target network) to continuous spaces
- OU noise provides temporally correlated, smooth exploration
- Distributional policy gradient (D4PG) improves stability by learning the distribution of returns
In the next post, we will cover Trust Region methods (TRPO, PPO, ACKTR) that guarantee stability of policy updates.