Skip to content

Split View: [심층 강화학습] 14. 연속 행동 공간: DDPG와 분포 정책

|

[심층 강화학습] 14. 연속 행동 공간: DDPG와 분포 정책

개요

지금까지 다룬 대부분의 환경은 이산 행동 공간(왼쪽/오른쪽, 각 방향키 등)을 가졌다. 하지만 로봇 제어, 자율주행, 물리 시뮬레이션 같은 현실 문제에서는 연속 행동 공간이 필요하다. 관절 각도, 토크, 가속도 등은 실수 값이기 때문이다.

이 글에서는 연속 행동 공간에 대한 정책 설계, A2C 확장, DDPG(Deep Deterministic Policy Gradient), 그리고 분포 정책 그래디언트를 다룬다.


왜 연속 공간이 필요한가

이산화의 한계

연속 행동을 이산화하는 방법(예: 토크를 -1, -0.5, 0, 0.5, 1로 분할)은 간단하지만 문제가 있다:

  • 차원의 저주: N개의 관절 각각을 K단계로 이산화하면 행동 공간이 K^N으로 폭발한다
  • 정밀도 부족: 이산화 간격 사이의 세밀한 제어가 불가능하다
  • 비효율성: 대부분의 이산 행동 조합은 무의미하다

예를 들어 6관절 로봇을 10단계로 이산화하면 10^6 = 백만 개의 행동이 생긴다.

연속 행동 공간의 표현

연속 정책은 행동의 확률 분포를 출력한다:

  • 가우시안 정책: 평균과 분산을 출력하여 정규분포에서 행동을 샘플링
  • 베타 정책: 유한 범위의 행동에 적합한 베타 분포 사용
  • 결정적 정책: 하나의 행동 값을 직접 출력 (DDPG)

행동 공간 설계

import torch
import torch.nn as nn
import numpy as np

class ContinuousActionSpace:
    """Box-shaped continuous action space bounded element-wise by [low, high]."""

    def __init__(self, low, high):
        # Store both bounds as float32 arrays; the space's shape is
        # whatever shape the bounds carry.
        self.low = np.array(low, dtype=np.float32)
        self.high = np.array(high, dtype=np.float32)
        self.shape = self.low.shape

    def clip(self, action):
        """Project an action onto the valid range, element-wise."""
        return np.clip(action, self.low, self.high)

    def sample(self):
        """Draw one uniformly random action from the box."""
        return np.random.uniform(self.low, self.high)

A2C를 연속 행동 공간에 적용

기존 A2C의 카테고리컬 분포를 가우시안 분포로 교체한다:

class ContinuousActorCritic(nn.Module):
    """Gaussian actor-critic for continuous control.

    The actor head outputs the mean of a diagonal Normal through a tanh
    (so the mean stays in [-1, 1]); the log standard deviation is a
    single state-independent learnable vector. The critic head shares
    the same two-layer trunk.
    """

    def __init__(self, obs_size, act_size):
        super().__init__()

        hidden = 256
        self.shared = nn.Sequential(
            nn.Linear(obs_size, hidden),
            nn.ReLU(),
            nn.Linear(hidden, hidden),
            nn.ReLU(),
        )

        # Policy mean head; tanh bounds the mean to [-1, 1].
        self.mu = nn.Sequential(nn.Linear(hidden, act_size), nn.Tanh())
        # One log-std per action dimension, shared across all states.
        self.log_std = nn.Parameter(torch.zeros(act_size))

        # State-value head.
        self.value = nn.Linear(hidden, 1)

    def forward(self, obs):
        """Return (mean, std, state value) for a batch of observations."""
        features = self.shared(obs)
        return self.mu(features), self.log_std.exp(), self.value(features)

    def get_action(self, obs):
        """Sample an action; also return its log-prob and the state value."""
        mu, std, value = self.forward(obs)
        dist = torch.distributions.Normal(mu, std)
        action = dist.sample()
        return action, dist.log_prob(action).sum(dim=-1), value

    def evaluate(self, obs, action):
        """Return log-prob and entropy of given actions plus the state value."""
        mu, std, value = self.forward(obs)
        dist = torch.distributions.Normal(mu, std)
        return (dist.log_prob(action).sum(dim=-1),
                dist.entropy().sum(dim=-1),
                value)

연속 A2C 학습

def train_continuous_a2c(env, model, num_steps=2048, num_updates=1000,
                         gamma=0.99, lr=3e-4, entropy_coef=0.01):
    """Train a continuous-action A2C model.

    Collects `num_steps` transitions per update, computes GAE advantages
    via `compute_gae`, and takes one combined policy/value/entropy
    gradient step.

    Args:
        env: environment with the classic Gym API (`reset() -> obs`,
            `step(a) -> (obs, reward, done, info)`) — assumed from usage;
            confirm against the Gym version in use.
        model: module exposing `get_action(obs)` and `evaluate(obs, acts)`
            (e.g. ContinuousActorCritic).
        num_steps: rollout length per update.
        num_updates: number of gradient updates.
        gamma: discount factor.
        lr: Adam learning rate.
        entropy_coef: weight of the entropy bonus.
    """
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    obs = env.reset()
    for update in range(num_updates):
        # --- Rollout collection ---
        observations = []
        actions = []
        rewards = []
        dones = []
        values = []

        for _ in range(num_steps):
            obs_t = torch.FloatTensor(obs).unsqueeze(0)
            action, _log_prob, value = model.get_action(obs_t)

            # The Gaussian sample is unbounded; clip before stepping
            # so the environment only sees valid actions.
            action_np = action.detach().numpy().flatten()
            action_clipped = np.clip(action_np, env.action_space.low,
                                     env.action_space.high)

            next_obs, reward, done, _ = env.step(action_clipped)

            observations.append(obs)
            actions.append(action.detach())
            rewards.append(reward)
            dones.append(done)
            # Detach: rollout values serve only as GAE baselines, and
            # fresh values are recomputed in evaluate(); keeping the
            # graph for num_steps forward passes would waste memory.
            values.append(value.detach().squeeze())

            obs = next_obs if not done else env.reset()

        # --- GAE (Generalized Advantage Estimation) ---
        advantages, returns = compute_gae(rewards, values, dones, gamma)

        # --- Single gradient update over the whole rollout ---
        obs_batch = torch.FloatTensor(np.array(observations))
        act_batch = torch.cat(actions)
        adv_batch = torch.FloatTensor(advantages)
        ret_batch = torch.FloatTensor(returns)

        new_log_probs, entropy, new_values = model.evaluate(
            obs_batch, act_batch
        )

        policy_loss = -(new_log_probs * adv_batch).mean()
        value_loss = (ret_batch - new_values.squeeze()).pow(2).mean()
        entropy_loss = -entropy.mean()

        loss = policy_loss + 0.5 * value_loss + entropy_coef * entropy_loss

        optimizer.zero_grad()
        loss.backward()
        # Global gradient-norm clipping for stability.
        torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
        optimizer.step()

def compute_gae(rewards, values, dones, gamma, lam=0.95):
    """Compute GAE advantages and discounted returns for one rollout.

    `values` holds one tensor per step; the trajectory is bootstrapped
    with a terminal value of 0. Returns two plain Python lists
    (advantages, returns) of the same length as `rewards`.
    """
    baselines = [v.detach().item() for v in values]
    baselines.append(0)  # terminal bootstrap value

    advantages = [0.0] * len(rewards)
    running = 0
    for step in range(len(rewards) - 1, -1, -1):
        if dones[step]:
            # Episode ended here: no bootstrap, reset the GAE trace.
            running = rewards[step] - baselines[step]
        else:
            td_error = (rewards[step] + gamma * baselines[step + 1]
                        - baselines[step])
            running = td_error + gamma * lam * running
        advantages[step] = running

    returns = [adv + base for adv, base in zip(advantages, baselines[:-1])]
    return advantages, returns

결정적 정책 그래디언트 (DDPG)

DDPG의 핵심 아이디어

DDPG(Deep Deterministic Policy Gradient) 는 DQN의 아이디어를 연속 행동 공간에 적용한 알고리즘이다:

  • 결정적 정책: 확률적으로 샘플링하지 않고, 상태에 대해 하나의 행동을 직접 출력
  • 경험 리플레이: DQN처럼 과거 경험을 버퍼에 저장하고 재사용
  • 타겟 네트워크: 학습 안정화를 위한 소프트 업데이트
  • 탐색 노이즈: 결정적 정책에 노이즈를 추가하여 탐색

네트워크 구조

class DDPGActor(nn.Module):
    """Deterministic policy: maps a state to one action in [-act_limit, act_limit]."""

    def __init__(self, obs_size, act_size, act_limit):
        super().__init__()
        self.act_limit = act_limit
        # 400/300 hidden layers; tanh bounds the raw output to [-1, 1]
        # before rescaling by act_limit.
        layers = [
            nn.Linear(obs_size, 400), nn.ReLU(),
            nn.Linear(400, 300), nn.ReLU(),
            nn.Linear(300, act_size), nn.Tanh(),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, obs):
        # Rescale the bounded tanh output to the action magnitude.
        return self.act_limit * self.net(obs)

class DDPGCritic(nn.Module):
    """State-action value estimator Q(s, a); the action enters at the input layer."""

    def __init__(self, obs_size, act_size):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(obs_size + act_size, 400), nn.ReLU(),
            nn.Linear(400, 300), nn.ReLU(),
            nn.Linear(300, 1),
        )

    def forward(self, obs, action):
        # Concatenate state and action along the feature axis.
        joint = torch.cat([obs, action], dim=-1)
        return self.net(joint)

탐색 노이즈: Ornstein-Uhlenbeck 프로세스

DDPG는 시간적 상관관계가 있는 OU 노이즈를 사용하여 연속적이고 부드러운 탐색을 수행한다:

class OrnsteinUhlenbeckNoise:
    """Ornstein-Uhlenbeck process: mean-reverting, temporally correlated noise."""

    def __init__(self, size, mu=0.0, theta=0.15, sigma=0.2):
        self.size = size
        self.mu = mu          # long-run mean the process reverts to
        self.theta = theta    # mean-reversion speed
        self.sigma = sigma    # scale of the Gaussian perturbation
        self.state = np.ones(size) * mu

    def reset(self):
        """Restart the process at its mean (call at episode boundaries)."""
        self.state = np.ones(self.size) * self.mu

    def sample(self):
        """Advance the process one step and return a copy of its state."""
        drift = self.theta * (self.mu - self.state)
        diffusion = self.sigma * np.random.randn(self.size)
        self.state += drift + diffusion
        return self.state.copy()

DDPG 전체 구현

import copy
from collections import deque
import random

class ReplayBuffer:
    """Fixed-capacity FIFO store of transitions for off-policy learning."""

    def __init__(self, capacity=1000000):
        # deque evicts the oldest transition once capacity is reached.
        self.buffer = deque(maxlen=capacity)

    def push(self, state, action, reward, next_state, done):
        """Append one transition tuple."""
        self.buffer.append((state, action, reward, next_state, done))

    def sample(self, batch_size):
        """Uniformly sample a batch and return it as stacked numpy arrays."""
        transitions = random.sample(self.buffer, batch_size)
        states, actions, rewards, next_states, dones = zip(*transitions)
        return (np.array(states),
                np.array(actions),
                np.array(rewards, dtype=np.float32),
                np.array(next_states),
                np.array(dones, dtype=np.float32))

    def __len__(self):
        return len(self.buffer)

class DDPGAgent:
    """DDPG agent combining a deterministic actor with a Q-critic.

    Uses a replay buffer, slowly-tracking target networks (soft/Polyak
    updates with rate `tau`), and Ornstein-Uhlenbeck exploration noise.
    """

    def __init__(self, obs_size, act_size, act_limit,
                 gamma=0.99, tau=0.005, lr_actor=1e-4, lr_critic=1e-3):
        self.gamma = gamma  # discount factor
        self.tau = tau      # soft-update interpolation rate

        # Main (online) networks
        self.actor = DDPGActor(obs_size, act_size, act_limit)
        self.critic = DDPGCritic(obs_size, act_size)

        # Target networks start as exact copies and are soft-updated
        self.target_actor = copy.deepcopy(self.actor)
        self.target_critic = copy.deepcopy(self.critic)

        self.actor_optimizer = torch.optim.Adam(
            self.actor.parameters(), lr=lr_actor)
        self.critic_optimizer = torch.optim.Adam(
            self.critic.parameters(), lr=lr_critic)

        self.replay_buffer = ReplayBuffer()
        self.noise = OrnsteinUhlenbeckNoise(act_size)

    def select_action(self, state, add_noise=True):
        """Return the policy action for `state`, optionally with OU noise.

        The result is clipped to [-act_limit, act_limit] so that
        exploration noise cannot push the action out of range.
        """
        state_t = torch.FloatTensor(state).unsqueeze(0)
        action = self.actor(state_t).detach().numpy().flatten()
        if add_noise:
            action += self.noise.sample()
        return np.clip(action, -self.actor.act_limit, self.actor.act_limit)

    def update(self, batch_size=256):
        """One gradient step each for critic and actor from a replay batch.

        No-op until the buffer holds at least `batch_size` transitions.
        """
        if len(self.replay_buffer) < batch_size:
            return

        states, actions, rewards, next_states, dones = \
            self.replay_buffer.sample(batch_size)

        states_t = torch.FloatTensor(states)
        actions_t = torch.FloatTensor(actions)
        rewards_t = torch.FloatTensor(rewards).unsqueeze(1)
        next_states_t = torch.FloatTensor(next_states)
        dones_t = torch.FloatTensor(dones).unsqueeze(1)

        # === Critic update ===
        # Bootstrapped TD target uses the *target* networks; no_grad
        # keeps the target out of the gradient graph.
        with torch.no_grad():
            next_actions = self.target_actor(next_states_t)
            target_q = self.target_critic(next_states_t, next_actions)
            # (1 - done) zeroes the bootstrap term at terminal transitions.
            target_value = rewards_t + self.gamma * (1 - dones_t) * target_q

        current_q = self.critic(states_t, actions_t)
        critic_loss = nn.MSELoss()(current_q, target_value)

        self.critic_optimizer.zero_grad()
        critic_loss.backward()
        self.critic_optimizer.step()

        # === Actor update ===
        # Deterministic policy gradient: maximize Q(s, mu(s)) by
        # minimizing its negation. This backward pass also deposits
        # gradients in the critic's parameters; they are discarded by
        # the critic's zero_grad() at the start of the next update.
        predicted_actions = self.actor(states_t)
        actor_loss = -self.critic(states_t, predicted_actions).mean()

        self.actor_optimizer.zero_grad()
        actor_loss.backward()
        self.actor_optimizer.step()

        # === Soft (Polyak) target-network updates ===
        self._soft_update(self.actor, self.target_actor)
        self._soft_update(self.critic, self.target_critic)

    def _soft_update(self, source, target):
        # target <- tau * source + (1 - tau) * target, parameter-wise.
        for src_param, tgt_param in zip(source.parameters(),
                                         target.parameters()):
            tgt_param.data.copy_(
                self.tau * src_param.data + (1 - self.tau) * tgt_param.data
            )

분포 정책 그래디언트

D4PG: Distributed Distributional DDPG

D4PG는 DDPG에 분포 강화학습(Distributional RL)을 결합한 알고리즘이다. Q 값의 기대값 대신 전체 분포를 학습한다.

class DistributionalCritic(nn.Module):
    """Distributional Q critic (C51-style): predicts a categorical
    distribution of the return over `num_atoms` fixed support atoms.

    Args:
        obs_size: observation dimensionality.
        act_size: action dimensionality.
        num_atoms: number of support atoms.
        v_min / v_max: range of returns the support can represent.
    """

    def __init__(self, obs_size, act_size,
                 num_atoms=51, v_min=-10.0, v_max=10.0):
        super().__init__()
        self.num_atoms = num_atoms
        self.v_min = v_min
        self.v_max = v_max

        # Register the support as a buffer (not a plain attribute) so it
        # moves with the module on .to(device)/.cuda() and is included
        # in the state dict. Attribute access stays `self.support`.
        self.register_buffer("support",
                             torch.linspace(v_min, v_max, num_atoms))
        self.delta_z = (v_max - v_min) / (num_atoms - 1)

        self.net = nn.Sequential(
            nn.Linear(obs_size + act_size, 400),
            nn.ReLU(),
            nn.Linear(400, 300),
            nn.ReLU(),
            nn.Linear(300, num_atoms),
        )

    def forward(self, obs, action):
        """Return atom probabilities, shape (batch, num_atoms)."""
        x = torch.cat([obs, action], dim=-1)
        logits = self.net(x)
        return torch.softmax(logits, dim=-1)

    def get_q_value(self, obs, action):
        """Expected Q-value: probability-weighted sum over the support."""
        probs = self.forward(obs, action)
        return (probs * self.support.unsqueeze(0)).sum(dim=-1, keepdim=True)

분포 TD 학습

def distributional_critic_loss(critic, target_critic, target_actor,
                                states, actions, rewards, next_states,
                                dones, gamma=0.99):
    """Distributional critic loss via categorical projection (C51-style).

    Projects the Bellman-shifted target distribution back onto the fixed
    support and returns the cross-entropy between that projection and the
    online critic's predicted distribution.

    Args:
        critic / target_critic: DistributionalCritic-style modules
            (expose num_atoms, v_min, v_max, delta_z, support, net).
        target_actor: deterministic target policy.
        rewards, dones: 1-D tensors of shape (batch,).
    """
    num_atoms = critic.num_atoms
    v_min = critic.v_min
    v_max = critic.v_max
    delta_z = critic.delta_z
    support = critic.support

    with torch.no_grad():
        next_actions = target_actor(next_states)
        target_probs = target_critic(next_states, next_actions)

        # Bellman-shifted support: Tz = r + gamma * (1 - done) * z
        tz = rewards.unsqueeze(1) + gamma * (1 - dones.unsqueeze(1)) * \
             support.unsqueeze(0)
        tz = tz.clamp(v_min, v_max)

        # Fractional atom index of each shifted atom.
        b = (tz - v_min) / delta_z
        l = b.floor().long()
        u = b.ceil().long()

        projected = torch.zeros_like(target_probs)
        for i in range(num_atoms):
            bi = b[:, i:i+1]
            li = l[:, i:i+1]
            ui = u[:, i:i+1]
            pi = target_probs[:, i:i+1]
            # BUG FIX: when Tz lands exactly on an atom, l == u and both
            # fractional weights (u - b) and (b - l) are zero, silently
            # dropping that probability mass. Assign the full mass to the
            # lower atom in that case.
            exact = (li == ui).float()
            projected.scatter_add_(1, li, pi * ((ui.float() - bi) + exact))
            projected.scatter_add_(1, ui, pi * (bi - li.float()))

    # Cross-entropy between the projected target and the prediction.
    current_logits = critic.net(torch.cat([states, actions], dim=-1))
    loss = -(projected * torch.log_softmax(current_logits, dim=-1)).sum(dim=-1).mean()

    return loss

실험 결과 비교

MuJoCo의 HalfCheetah-v4 환경에서의 성능 비교:

| 알고리즘 | 100만 스텝 평균 보상 | 학습 안정성 | 샘플 효율성 |
|---|---|---|---|
| A2C (연속) | 약 3000 | 중간 | 낮음 |
| DDPG | 약 8000 | 낮음 (민감) | 높음 |
| D4PG | 약 9500 | 높음 | 높음 |

DDPG는 하이퍼파라미터에 민감하지만 잘 튜닝하면 높은 성능을 보인다. D4PG는 분포 학습 덕분에 더 안정적이다.


핵심 요약

  • 연속 행동 공간은 가우시안 정책(A2C) 또는 결정적 정책(DDPG)으로 처리한다
  • DDPG는 DQN의 핵심 기법(리플레이 버퍼, 타겟 네트워크)을 연속 공간에 적용한다
  • OU 노이즈로 시간적 상관이 있는 부드러운 탐색을 수행한다
  • 분포 정책 그래디언트(D4PG)는 리턴의 분포를 학습하여 안정성을 높인다

다음 글에서는 정책 업데이트의 안정성을 보장하는 Trust Region 방법(TRPO, PPO, ACKTR) 을 다룬다.

[Deep RL] 14. Continuous Action Spaces: DDPG and Distributional Policies

Overview

Most environments we have covered so far had discrete action spaces (left/right, directional keys, etc.). However, real-world problems like robot control, autonomous driving, and physics simulations require continuous action spaces. Joint angles, torques, and accelerations are real-valued quantities.

This post covers policy design for continuous action spaces, A2C extension, DDPG (Deep Deterministic Policy Gradient), and distributional policy gradients.


Why Continuous Spaces Are Needed

Limitations of Discretization

Discretizing continuous actions (e.g., splitting torque into -1, -0.5, 0, 0.5, 1) is simple but problematic:

  • Curse of dimensionality: Discretizing each of N joints into K levels causes the action space to explode to K^N
  • Lack of precision: Fine-grained control between discretization intervals is impossible
  • Inefficiency: Most discrete action combinations are meaningless

For example, discretizing a 6-joint robot into 10 levels creates 10^6 = one million actions.

Representing Continuous Action Spaces

Continuous policies output probability distributions over actions:

  • Gaussian policy: Outputs mean and variance, samples actions from a normal distribution
  • Beta policy: Uses beta distribution suitable for bounded action ranges
  • Deterministic policy: Directly outputs a single action value (DDPG)

Action Space Design

import torch
import torch.nn as nn
import numpy as np

class ContinuousActionSpace:
    """Box-shaped continuous action space bounded element-wise by [low, high]."""

    def __init__(self, low, high):
        # Store both bounds as float32 arrays; the space's shape is
        # whatever shape the bounds carry.
        self.low = np.array(low, dtype=np.float32)
        self.high = np.array(high, dtype=np.float32)
        self.shape = self.low.shape

    def clip(self, action):
        """Project an action onto the valid range, element-wise."""
        return np.clip(action, self.low, self.high)

    def sample(self):
        """Draw one uniformly random action from the box."""
        return np.random.uniform(self.low, self.high)

Applying A2C to Continuous Action Spaces

We replace the categorical distribution in the existing A2C with a Gaussian distribution:

class ContinuousActorCritic(nn.Module):
    """Gaussian actor-critic for continuous control.

    The actor head outputs the mean of a diagonal Normal through a tanh
    (so the mean stays in [-1, 1]); the log standard deviation is a
    single state-independent learnable vector. The critic head shares
    the same two-layer trunk.
    """

    def __init__(self, obs_size, act_size):
        super().__init__()

        hidden = 256
        self.shared = nn.Sequential(
            nn.Linear(obs_size, hidden),
            nn.ReLU(),
            nn.Linear(hidden, hidden),
            nn.ReLU(),
        )

        # Policy mean head; tanh bounds the mean to [-1, 1].
        self.mu = nn.Sequential(nn.Linear(hidden, act_size), nn.Tanh())
        # One log-std per action dimension, shared across all states.
        self.log_std = nn.Parameter(torch.zeros(act_size))

        # State-value head.
        self.value = nn.Linear(hidden, 1)

    def forward(self, obs):
        """Return (mean, std, state value) for a batch of observations."""
        features = self.shared(obs)
        return self.mu(features), self.log_std.exp(), self.value(features)

    def get_action(self, obs):
        """Sample an action; also return its log-prob and the state value."""
        mu, std, value = self.forward(obs)
        dist = torch.distributions.Normal(mu, std)
        action = dist.sample()
        return action, dist.log_prob(action).sum(dim=-1), value

    def evaluate(self, obs, action):
        """Return log-prob and entropy of given actions plus the state value."""
        mu, std, value = self.forward(obs)
        dist = torch.distributions.Normal(mu, std)
        return (dist.log_prob(action).sum(dim=-1),
                dist.entropy().sum(dim=-1),
                value)

Continuous A2C Training

def train_continuous_a2c(env, model, num_steps=2048, num_updates=1000,
                         gamma=0.99, lr=3e-4, entropy_coef=0.01):
    """Train a continuous-action A2C model.

    Collects `num_steps` transitions per update, computes GAE advantages
    via `compute_gae`, and takes one combined policy/value/entropy
    gradient step.

    Args:
        env: environment with the classic Gym API (`reset() -> obs`,
            `step(a) -> (obs, reward, done, info)`) — assumed from usage;
            confirm against the Gym version in use.
        model: module exposing `get_action(obs)` and `evaluate(obs, acts)`
            (e.g. ContinuousActorCritic).
        num_steps: rollout length per update.
        num_updates: number of gradient updates.
        gamma: discount factor.
        lr: Adam learning rate.
        entropy_coef: weight of the entropy bonus.
    """
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    obs = env.reset()
    for update in range(num_updates):
        # --- Rollout collection ---
        observations = []
        actions = []
        rewards = []
        dones = []
        values = []

        for _ in range(num_steps):
            obs_t = torch.FloatTensor(obs).unsqueeze(0)
            action, _log_prob, value = model.get_action(obs_t)

            # The Gaussian sample is unbounded; clip before stepping
            # so the environment only sees valid actions.
            action_np = action.detach().numpy().flatten()
            action_clipped = np.clip(action_np, env.action_space.low,
                                     env.action_space.high)

            next_obs, reward, done, _ = env.step(action_clipped)

            observations.append(obs)
            actions.append(action.detach())
            rewards.append(reward)
            dones.append(done)
            # Detach: rollout values serve only as GAE baselines, and
            # fresh values are recomputed in evaluate(); keeping the
            # graph for num_steps forward passes would waste memory.
            values.append(value.detach().squeeze())

            obs = next_obs if not done else env.reset()

        # --- GAE (Generalized Advantage Estimation) ---
        advantages, returns = compute_gae(rewards, values, dones, gamma)

        # --- Single gradient update over the whole rollout ---
        obs_batch = torch.FloatTensor(np.array(observations))
        act_batch = torch.cat(actions)
        adv_batch = torch.FloatTensor(advantages)
        ret_batch = torch.FloatTensor(returns)

        new_log_probs, entropy, new_values = model.evaluate(
            obs_batch, act_batch
        )

        policy_loss = -(new_log_probs * adv_batch).mean()
        value_loss = (ret_batch - new_values.squeeze()).pow(2).mean()
        entropy_loss = -entropy.mean()

        loss = policy_loss + 0.5 * value_loss + entropy_coef * entropy_loss

        optimizer.zero_grad()
        loss.backward()
        # Global gradient-norm clipping for stability.
        torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
        optimizer.step()

def compute_gae(rewards, values, dones, gamma, lam=0.95):
    """Compute GAE advantages and discounted returns for one rollout.

    `values` holds one tensor per step; the trajectory is bootstrapped
    with a terminal value of 0. Returns two plain Python lists
    (advantages, returns) of the same length as `rewards`.
    """
    baselines = [v.detach().item() for v in values]
    baselines.append(0)  # terminal bootstrap value

    advantages = [0.0] * len(rewards)
    running = 0
    for step in range(len(rewards) - 1, -1, -1):
        if dones[step]:
            # Episode ended here: no bootstrap, reset the GAE trace.
            running = rewards[step] - baselines[step]
        else:
            td_error = (rewards[step] + gamma * baselines[step + 1]
                        - baselines[step])
            running = td_error + gamma * lam * running
        advantages[step] = running

    returns = [adv + base for adv, base in zip(advantages, baselines[:-1])]
    return advantages, returns

Deterministic Policy Gradient (DDPG)

Core Idea of DDPG

DDPG (Deep Deterministic Policy Gradient) applies DQN ideas to continuous action spaces:

  • Deterministic policy: Instead of probabilistic sampling, directly outputs a single action for a given state
  • Experience replay: Stores past experiences in a buffer and reuses them, like DQN
  • Target network: Soft updates for learning stability
  • Exploration noise: Adds noise to the deterministic policy for exploration

Network Architecture

class DDPGActor(nn.Module):
    """Deterministic policy: maps a state to one action in [-act_limit, act_limit]."""

    def __init__(self, obs_size, act_size, act_limit):
        super().__init__()
        self.act_limit = act_limit
        # 400/300 hidden layers; tanh bounds the raw output to [-1, 1]
        # before rescaling by act_limit.
        layers = [
            nn.Linear(obs_size, 400), nn.ReLU(),
            nn.Linear(400, 300), nn.ReLU(),
            nn.Linear(300, act_size), nn.Tanh(),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, obs):
        # Rescale the bounded tanh output to the action magnitude.
        return self.act_limit * self.net(obs)

class DDPGCritic(nn.Module):
    """State-action value estimator Q(s, a); the action enters at the input layer."""

    def __init__(self, obs_size, act_size):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(obs_size + act_size, 400), nn.ReLU(),
            nn.Linear(400, 300), nn.ReLU(),
            nn.Linear(300, 1),
        )

    def forward(self, obs, action):
        # Concatenate state and action along the feature axis.
        joint = torch.cat([obs, action], dim=-1)
        return self.net(joint)

Exploration Noise: Ornstein-Uhlenbeck Process

DDPG uses temporally correlated OU noise for smooth, continuous exploration:

class OrnsteinUhlenbeckNoise:
    """Ornstein-Uhlenbeck process: mean-reverting, temporally correlated noise."""

    def __init__(self, size, mu=0.0, theta=0.15, sigma=0.2):
        self.size = size
        self.mu = mu          # long-run mean the process reverts to
        self.theta = theta    # mean-reversion speed
        self.sigma = sigma    # scale of the Gaussian perturbation
        self.state = np.ones(size) * mu

    def reset(self):
        """Restart the process at its mean (call at episode boundaries)."""
        self.state = np.ones(self.size) * self.mu

    def sample(self):
        """Advance the process one step and return a copy of its state."""
        drift = self.theta * (self.mu - self.state)
        diffusion = self.sigma * np.random.randn(self.size)
        self.state += drift + diffusion
        return self.state.copy()

Full DDPG Implementation

import copy
from collections import deque
import random

class ReplayBuffer:
    """Fixed-capacity FIFO store of transitions for off-policy learning."""

    def __init__(self, capacity=1000000):
        # deque evicts the oldest transition once capacity is reached.
        self.buffer = deque(maxlen=capacity)

    def push(self, state, action, reward, next_state, done):
        """Append one transition tuple."""
        self.buffer.append((state, action, reward, next_state, done))

    def sample(self, batch_size):
        """Uniformly sample a batch and return it as stacked numpy arrays."""
        transitions = random.sample(self.buffer, batch_size)
        states, actions, rewards, next_states, dones = zip(*transitions)
        return (np.array(states),
                np.array(actions),
                np.array(rewards, dtype=np.float32),
                np.array(next_states),
                np.array(dones, dtype=np.float32))

    def __len__(self):
        return len(self.buffer)

class DDPGAgent:
    """DDPG agent combining a deterministic actor with a Q-critic.

    Uses a replay buffer, slowly-tracking target networks (soft/Polyak
    updates with rate `tau`), and Ornstein-Uhlenbeck exploration noise.
    """

    def __init__(self, obs_size, act_size, act_limit,
                 gamma=0.99, tau=0.005, lr_actor=1e-4, lr_critic=1e-3):
        self.gamma = gamma  # discount factor
        self.tau = tau      # soft-update interpolation rate

        # Main (online) networks
        self.actor = DDPGActor(obs_size, act_size, act_limit)
        self.critic = DDPGCritic(obs_size, act_size)

        # Target networks start as exact copies and are soft-updated
        self.target_actor = copy.deepcopy(self.actor)
        self.target_critic = copy.deepcopy(self.critic)

        self.actor_optimizer = torch.optim.Adam(
            self.actor.parameters(), lr=lr_actor)
        self.critic_optimizer = torch.optim.Adam(
            self.critic.parameters(), lr=lr_critic)

        self.replay_buffer = ReplayBuffer()
        self.noise = OrnsteinUhlenbeckNoise(act_size)

    def select_action(self, state, add_noise=True):
        """Return the policy action for `state`, optionally with OU noise.

        The result is clipped to [-act_limit, act_limit] so that
        exploration noise cannot push the action out of range.
        """
        state_t = torch.FloatTensor(state).unsqueeze(0)
        action = self.actor(state_t).detach().numpy().flatten()
        if add_noise:
            action += self.noise.sample()
        return np.clip(action, -self.actor.act_limit, self.actor.act_limit)

    def update(self, batch_size=256):
        """One gradient step each for critic and actor from a replay batch.

        No-op until the buffer holds at least `batch_size` transitions.
        """
        if len(self.replay_buffer) < batch_size:
            return

        states, actions, rewards, next_states, dones = \
            self.replay_buffer.sample(batch_size)

        states_t = torch.FloatTensor(states)
        actions_t = torch.FloatTensor(actions)
        rewards_t = torch.FloatTensor(rewards).unsqueeze(1)
        next_states_t = torch.FloatTensor(next_states)
        dones_t = torch.FloatTensor(dones).unsqueeze(1)

        # === Critic update ===
        # Bootstrapped TD target uses the *target* networks; no_grad
        # keeps the target out of the gradient graph.
        with torch.no_grad():
            next_actions = self.target_actor(next_states_t)
            target_q = self.target_critic(next_states_t, next_actions)
            # (1 - done) zeroes the bootstrap term at terminal transitions.
            target_value = rewards_t + self.gamma * (1 - dones_t) * target_q

        current_q = self.critic(states_t, actions_t)
        critic_loss = nn.MSELoss()(current_q, target_value)

        self.critic_optimizer.zero_grad()
        critic_loss.backward()
        self.critic_optimizer.step()

        # === Actor update ===
        # Deterministic policy gradient: maximize Q(s, mu(s)) by
        # minimizing its negation. This backward pass also deposits
        # gradients in the critic's parameters; they are discarded by
        # the critic's zero_grad() at the start of the next update.
        predicted_actions = self.actor(states_t)
        actor_loss = -self.critic(states_t, predicted_actions).mean()

        self.actor_optimizer.zero_grad()
        actor_loss.backward()
        self.actor_optimizer.step()

        # === Soft (Polyak) target-network updates ===
        self._soft_update(self.actor, self.target_actor)
        self._soft_update(self.critic, self.target_critic)

    def _soft_update(self, source, target):
        # target <- tau * source + (1 - tau) * target, parameter-wise.
        for src_param, tgt_param in zip(source.parameters(),
                                         target.parameters()):
            tgt_param.data.copy_(
                self.tau * src_param.data + (1 - self.tau) * tgt_param.data
            )

Distributional Policy Gradient

D4PG: Distributed Distributional DDPG

D4PG combines distributional reinforcement learning with DDPG. Instead of learning the expected Q-value, it learns the entire distribution of returns.

class DistributionalCritic(nn.Module):
    """Distributional Q critic (C51-style): predicts a categorical
    distribution of the return over `num_atoms` fixed support atoms.

    Args:
        obs_size: observation dimensionality.
        act_size: action dimensionality.
        num_atoms: number of support atoms.
        v_min / v_max: range of returns the support can represent.
    """

    def __init__(self, obs_size, act_size,
                 num_atoms=51, v_min=-10.0, v_max=10.0):
        super().__init__()
        self.num_atoms = num_atoms
        self.v_min = v_min
        self.v_max = v_max

        # Register the support as a buffer (not a plain attribute) so it
        # moves with the module on .to(device)/.cuda() and is included
        # in the state dict. Attribute access stays `self.support`.
        self.register_buffer("support",
                             torch.linspace(v_min, v_max, num_atoms))
        self.delta_z = (v_max - v_min) / (num_atoms - 1)

        self.net = nn.Sequential(
            nn.Linear(obs_size + act_size, 400),
            nn.ReLU(),
            nn.Linear(400, 300),
            nn.ReLU(),
            nn.Linear(300, num_atoms),
        )

    def forward(self, obs, action):
        """Return atom probabilities, shape (batch, num_atoms)."""
        x = torch.cat([obs, action], dim=-1)
        logits = self.net(x)
        return torch.softmax(logits, dim=-1)

    def get_q_value(self, obs, action):
        """Expected Q-value: probability-weighted sum over the support."""
        probs = self.forward(obs, action)
        return (probs * self.support.unsqueeze(0)).sum(dim=-1, keepdim=True)

Distributional TD Learning

def distributional_critic_loss(critic, target_critic, target_actor,
                                states, actions, rewards, next_states,
                                dones, gamma=0.99):
    """Distributional critic loss via categorical projection (C51-style).

    Projects the Bellman-shifted target distribution back onto the fixed
    support and returns the cross-entropy between that projection and the
    online critic's predicted distribution.

    Args:
        critic / target_critic: DistributionalCritic-style modules
            (expose num_atoms, v_min, v_max, delta_z, support, net).
        target_actor: deterministic target policy.
        rewards, dones: 1-D tensors of shape (batch,).
    """
    num_atoms = critic.num_atoms
    v_min = critic.v_min
    v_max = critic.v_max
    delta_z = critic.delta_z
    support = critic.support

    with torch.no_grad():
        next_actions = target_actor(next_states)
        target_probs = target_critic(next_states, next_actions)

        # Bellman-shifted support: Tz = r + gamma * (1 - done) * z
        tz = rewards.unsqueeze(1) + gamma * (1 - dones.unsqueeze(1)) * \
             support.unsqueeze(0)
        tz = tz.clamp(v_min, v_max)

        # Fractional atom index of each shifted atom.
        b = (tz - v_min) / delta_z
        l = b.floor().long()
        u = b.ceil().long()

        projected = torch.zeros_like(target_probs)
        for i in range(num_atoms):
            bi = b[:, i:i+1]
            li = l[:, i:i+1]
            ui = u[:, i:i+1]
            pi = target_probs[:, i:i+1]
            # BUG FIX: when Tz lands exactly on an atom, l == u and both
            # fractional weights (u - b) and (b - l) are zero, silently
            # dropping that probability mass. Assign the full mass to the
            # lower atom in that case.
            exact = (li == ui).float()
            projected.scatter_add_(1, li, pi * ((ui.float() - bi) + exact))
            projected.scatter_add_(1, ui, pi * (bi - li.float()))

    # Cross-entropy between the projected target and the prediction.
    current_logits = critic.net(torch.cat([states, actions], dim=-1))
    loss = -(projected * torch.log_softmax(current_logits, dim=-1)).sum(dim=-1).mean()

    return loss

Experimental Results Comparison

Performance comparison on MuJoCo's HalfCheetah-v4 environment:

| Algorithm | Average reward at 1M steps | Learning stability | Sample efficiency |
|---|---|---|---|
| A2C (continuous) | ~3000 | Medium | Low |
| DDPG | ~8000 | Low (sensitive) | High |
| D4PG | ~9500 | High | High |

DDPG is sensitive to hyperparameters but achieves high performance when well-tuned. D4PG is more stable thanks to distributional learning.


Key Takeaways

  • Continuous action spaces are handled with Gaussian policies (A2C) or deterministic policies (DDPG)
  • DDPG applies DQN's core techniques (replay buffer, target network) to continuous spaces
  • OU noise provides temporally correlated, smooth exploration
  • Distributional policy gradient (D4PG) improves stability by learning the distribution of returns

In the next post, we will cover Trust Region methods (TRPO, PPO, ACKTR) that guarantee stability of policy updates.