Policy Gradients


The Nuts and Bolts

The ultimate goal in RL is to learn a policy $\pi: \mathcal{S} \rightarrow \Delta(\mathcal{A})$ that maximises the expected return in an MDP $\mathcal{M} = (\mathcal{S}, \mathcal{A}, P, r, \rho_0, \gamma)$ with reward $r: \mathcal{S} \times \mathcal{A} \times \mathcal{S} \rightarrow \mathbb{R}$ provided by the environment. At time $t$ the agent samples $a_t \sim \pi(\cdot|s_t)$ and the environment transitions to $s_{t+1} \sim P(\cdot|s_t,a_t)$, where $a_t\in\mathcal{A}$ and $s_t\in\mathcal{S}$.

It’s convenient to bundle one full episode into a trajectory $\tau = (s_0, a_0, \ldots, a_{T}, s_{T+1})$ for $t \in [0, T]$, representing $T+1$ transitions.

In general,

  1. Observe the current state (or observation) $s_t$ ($s_0\sim \rho_0(\cdot)$ at episode start)
  2. Select the action $a_t$
  3. Transition to the next state $s_{t+1} \sim P(\cdot|s_t,a_t)$
  4. Retrieve the reward $r_t = r(s_t,a_t,s_{t+1})$ from the environment
  5. Repeat
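
To make the loop concrete, here is a minimal sketch of one episode using a Gymnasium-style interface; the environment, the random stand-in policy and the variable names are illustrative assumptions, not something prescribed by this post.

import gymnasium as gym

env = gym.make("CartPole-v1")               # any environment with the standard reset/step API
obs, info = env.reset()                     # s_0 ~ rho_0
done, episode_return = False, 0.0
while not done:
    action = env.action_space.sample()      # stand-in for a_t ~ pi(.|s_t)
    obs, reward, terminated, truncated, info = env.step(action)  # s_{t+1} ~ P, r_t
    episode_return += reward                # accumulate the (undiscounted) return
    done = terminated or truncated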

So far I’ve described the stochastic setting. There is also a deterministic special case where the policy is $a_t = \mu(s_t)$ and, if the environment is noise free (rare in practice), the dynamics are $s_{t+1} = f(s_t,a_t)$. In most real-world cases, stochasticity is impossible to ignore.

In the oversimplified RL loop above, step 2 is the chief concern of most RL research. Since this post is titled “Policy Gradients” I’m going to focus on action selection. There’s plenty of room for later posts on reward functions (stay tuned).

Markov Decision Processes

Before we start, it’s useful to formally define a Markov Decision Process (MDP). It is a stochastic system where the future depends only on the present and not the past. Mathematically,

$$P(S_{t+1}\mid S_t) = P(S_{t+1}\mid S_1,\ldots, S_t).$$

If you’re familiar with greedy algorithms, the spirit is similar: the next decision depends only on the current state, not on the path taken to reach it.

Action Selection

  2. Select the action $a_t$

The best action should, by definition, yield the best reward, or more concretely the best expected return, and it comes from the policy $a_t \sim \pi(\cdot|s_t)$. The effectiveness of this policy can be quantified with the value function,

\begin{equation} V^\pi(s) = \mathbb{E}_{\tau \sim \pi}\left[\sum_{t=0}^{\infty} \gamma^t r_{t}\;\bigg|\; s_0 = s\right]. \end{equation}

This captures how valuable a state $s$ is when actions are sampled directly from the policy, $a_t \sim \pi(\cdot|s_t)$, assuming the episode starts at $t = 0$.

Eq. 1 employs a reward discounted by $\gamma$ to determine how myopic $(\gamma \approx 0)$ or foresighted $(\gamma \approx 1)$ the agent should be. The infinite sum requires $\gamma \in [0,1)$ to converge. If the present reward isn’t more valuable than a future reward, discounting can be removed in favour of a finite time horizon,

$$V^\pi(s) = \mathbb{E}_{\tau\sim\pi }\left[R(\tau)\mid s_0 = s\right] ,\quad R(\tau) =\begin{cases} \sum_{k=0}^\infty \gamma^k r_{k}, & \text{(infinite horizon, discounted)} \\ \sum_{k=0}^{T} r_{k}, & \text{(finite horizon, undiscounted)}. \end{cases}$$

This assumes the episode starts at $t = 0$; otherwise we can express the return more generally for any $t$,

$$R(\tau) =\begin{cases} \sum_{k=0}^\infty \gamma^k r_{t+k}, \\ \sum_{k=0}^{T - t} r_{t+k}. \end{cases}$$
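
Both returns are straightforward to compute from a list of rewards collected during an episode; a quick illustrative sketch (function names are my own):

def discounted_return(rewards, gamma=0.99):
    # sum_k gamma^k * r_{t+k} over a finite list of rewards
    return sum(gamma**k * r for k, r in enumerate(rewards))

def undiscounted_return(rewards):
    # sum of rewards over the finite horizon
    return sum(rewards)

print(discounted_return([1.0, 1.0, 1.0], gamma=0.9))  # 1 + 0.9 + 0.81 = 2.71
print(undiscounted_return([1.0, 1.0, 1.0]))           # 3.0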

The big $\mathbb{E}$ represents the expected return $\mathbb{E}_{\tau\sim\pi }[R(\tau)]$ under a given policy $\pi$; at its core, RL is basically an optimisation problem around this quantity. It tells us how well the policy is performing, and we can use it to formally define the optimal policy,

$$\pi^* = \argmax_\pi \mathbb{E}_{\tau\sim\pi }[R(\tau)].$$

The value of a state $V^\pi(s)$ is useful but not the complete picture. We might also be interested in the effect of an arbitrary action $a$ taken in a given state. We can use the Q-function for this,

$$Q^\pi(s, a) = \mathbb{E}_{\tau\sim\pi}[R(\tau)\mid s_0 = s, a_0 = a].$$

Unless stated otherwise, $Q(s,a)$ denotes $Q^{\pi}(s,a)$ for the current policy; $Q^{*}$ denotes the optimal action-value.

If the action space $\mathcal{A}$ and state space $\mathcal{S}$ are both small and discrete, it’s feasible to estimate Q-values for every pair and store them in a Q-table with $|\mathcal{S}| |\mathcal{A}|$ entries. This can be used to produce the next best action based on,

$$a = \argmax_{a\in \mathcal{A}} Q^\pi(s,a).$$

In discrete action spaces $\argmax_{a\in \mathcal{A}} Q^\pi(s,a)$ can be computed with a finite $O(|\mathcal{A}|)$ scan of the Q-table, which takes one line of numpy.

import numpy as np
Q = np.zeros((state_count, action_count), dtype=float)  # Q-table: one row per state, one column per action
a = np.argmax(Q[s, :])                                   # greedy action for the current state index s

In continuous spaces, $\argmax_{a \in \mathcal{A}} Q^\pi(s, a)$ is generally only tractable when $Q^\pi(s,\cdot)$ has convenient structure, e.g. when it is concave and differentiable in $a$, so that convex optimisation techniques can find the maximising action.
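
To illustrate that special case, here is a rough sketch of choosing an action by gradient ascent on a differentiable critic; `q_net` and its `(s, a) -> scalar` interface are assumptions made for this example, not an API from the post.

import torch

def best_action(q_net, s, action_dim, steps=100, lr=0.05):
    # gradient ascent on a -> Q(s, a); only sensible when Q(s, .) is concave in a
    a = torch.zeros(action_dim, requires_grad=True)
    opt = torch.optim.Adam([a], lr=lr)
    for _ in range(steps):
        opt.zero_grad()
        (-q_net(s, a)).backward()   # maximise Q by minimising -Q
        opt.step()
    return a.detach()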

Most real-world problems (especially robotics) unfortunately suffer from the curse of dimensionality, with extremely large continuous action spaces of torque controls, actuator commands and gripper commands. State spaces are likewise high dimensional and continuous, spanning joint angles and velocities, torques, motor currents and other sensor readings.

It’s always easier to grok compute limits with real examples. A typical Q-table takes up $|\mathcal{S}| |\mathcal{A}| \times \texttt{dsize}$ bytes of memory, where $\texttt{dsize} = 4$ bytes if you’re using single-precision floats. An indoor mobile navigator operating in a 100 x 100 m workspace with a 0.1 m discretisation per cell has a total state space of $1000\times 1000 = 10^6$ cells. Even before adding orientation and speed, the Q-table already needs $10^6$ rows. Now add 6 actions and your total memory footprint is $6\times 10^6 \times 4 \texttt{ bytes} = 24 \texttt{ MB}$. What happens when we add a $z$-axis (discretised the same) for a new 100 x 100 x 100 m workspace for a drone? That’s 24 GB. As you can see this does not scale well, especially on embedded hardware that often has tight compute and memory limits.
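
The arithmetic is easy to reproduce with a throwaway script:

dsize = 4                                   # bytes per float32 entry
actions = 6
states_2d = 1000 * 1000                     # 100 m x 100 m at 0.1 m resolution
states_3d = 1000 * 1000 * 1000              # add a z-axis at the same resolution
print(states_2d * actions * dsize / 1e6)    # 24.0 (MB)
print(states_3d * actions * dsize / 1e9)    # 24.0 (GB)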

This is why a pure $\arg\max_a Q^\pi(s,a)$ is often impractical, motivating other methods like actor-critic and deterministic policy gradients. I’ll cover when to use value-based methods (SARSA, Q-learning, DQN) in a separate post; they’re still useful.

However, we’re here to talk about policy gradients.

Policy Gradients

The real world is high dimensional and continuous, so we cannot sample $V(s)$ densely enough to find a global best action without catastrophe, like having a robot arm swing at you or a drone fly through a wall.

Instead we can borrow concepts from supervised learning with gradient descent to optimise a policy parametrised by $\theta \in \mathbb{R}^d$ directly from interaction, updating $\pi_\theta$ with gradients of the expected return estimated from rollouts in the real environment. In short, policy gradients learn by doing.

Gradient updates in supervised learning typically minimise a loss via descent, $\theta \leftarrow \theta - \eta \nabla_{\theta}\mathcal{L}(\theta)$, whereas in policy gradients we want to maximise an objective,

$$J(\theta) = \mathbb{E}_{ \tau\sim \pi_{\theta}}\left[R(\tau)\right],$$

therefore we’re interested in performing gradient ascent,

$$\theta \leftarrow \theta + \alpha \nabla_{\theta}J(\theta).$$
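
In an autodiff framework, ascent is usually implemented as descent on the negated objective. A minimal sketch, assuming `policy` is a differentiable model and `J_estimate(policy)` returns a Monte Carlo estimate of $J(\theta)$ (both are placeholder names, not defined anywhere in this post):

import torch

opt = torch.optim.SGD(policy.parameters(), lr=1e-3)
opt.zero_grad()
loss = -J_estimate(policy)   # minimising -J is maximising J
loss.backward()              # backprop fills each parameter's .grad
opt.step()                   # theta <- theta - lr * grad(-J) = theta + lr * grad(J)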

Beyond gradient-based non-convex optimisation, little carries over from supervised learning; techniques like dropout and very deep networks do not transfer cleanly into RL and can even decrease learning stability (way more where that came from in this talk by John Schulman).

RL gets exciting once we attempt to compute $\nabla_{\theta}J(\theta)$,

\begin{align*} \nabla_\theta J(\theta) = \nabla_\theta J(\pi_\theta) &= \nabla_\theta \mathbb{E}_{\tau \sim \pi_\theta} [R(\tau)] \\ &= \nabla_\theta \int_\tau P(\tau|\theta)R(\tau) \,\mathrm{d}\tau \\ &= \int_\tau \nabla_\theta P(\tau|\theta)R(\tau) \,\mathrm{d}\tau. \end{align*}

A big issue here is the intractable trajectory integral $\int_\tau P(\tau|\theta)R(\tau)\,\mathrm{d}\tau$. In continuous MDPs the trajectory space $\tau \in \mathcal{T}$ is often enormous and uncountable, involves unknown dynamics $P(s_{t+1} | s_t, a_t)$, and the integral has no closed form. The log-derivative trick can help substitute $\nabla_\theta P(\tau|\theta)$ for something tractable,

\begin{align*} \nabla_\theta \log P(\tau|\theta) &= \dfrac{1}{P(\tau|\theta)}\nabla_\theta P(\tau|\theta) \\ \Rightarrow \nabla_\theta P(\tau|\theta) &= P(\tau|\theta) \nabla_\theta \log P(\tau|\theta). \end{align*}

The final gradient expression becomes,

\begin{align} \nabla_\theta J(\theta) &= \int_\tau \nabla_\theta P(\tau|\theta)R(\tau) \,\mathrm{d}\tau \\ &= \int_\tau P(\tau|\theta) \nabla_\theta \log P(\tau|\theta)R(\tau) \,\mathrm{d}\tau \\ &= \mathbb{E}_{\tau \sim \pi_\theta} [\nabla_\theta \log P(\tau|\theta)R(\tau)]. \end{align}

We can use the policy gradient theorem to remove the intractable terms from Eq. 4,

$$\nabla_\theta J(\theta) = \nabla_\theta \mathbb{E}_{ \tau\sim \pi_{\theta}}\left[\sum_{t=0}^T \gamma^t r(s_t,a_t, s_{t+1})\right] = \mathbb{E}_{\tau \sim \pi_\theta} \left[\sum^T_{t=0} \nabla_\theta \log\pi_\theta(a_t|s_t)R(\tau)\right].$$

Policy Gradient Theorem

We can expand the definition of $P(\tau | \theta)$, where $\rho_0$ is the initial state distribution, and apply the log-derivative trick to remove the terms with unknown dynamics,

\begin{equation} P(\tau | \theta) = \rho_0(s_0) \prod^{T}_{t=0} P(s_{t+1}|s_t, a_t) \pi_\theta(a_t|s_t). \end{equation}

Calculating the gradient-log-probability $\nabla_\theta \log P(\tau | \theta)$ of Eq. 5,

\begin{align*} \log P(\tau | \theta) &= \log \rho_0(s_0) + \sum^T_{t=0} \log \big[ P(s_{t+1}|s_t, a_t) \pi_\theta(a_t|s_t)\big] \\ &= \log \rho_0(s_0) + \sum^T_{t=0} \left[ \log P(s_{t+1}|s_t, a_t) + \log\pi_\theta(a_t|s_t)\right] \qquad \because \log \prod_i x_i = \sum_i \log x_i, \\ \Rightarrow \nabla_\theta \log P(\tau | \theta) &= \nabla_\theta \log \rho_0(s_0) + \sum^T_{t=0} \left[ \nabla_\theta \log P(s_{t+1}|s_t, a_t) + \nabla_\theta \log\pi_\theta(a_t|s_t)\right]. \end{align*}

Since both $\rho_0$ and $P$ are independent of $\theta$, their gradients are zero,

\begin{align*} \nabla_\theta \log P(\tau | \theta) &= \cancel{\nabla_\theta \log \rho_0(s_0)} + \sum^T_{t=0} \left[ \cancel{\nabla_\theta \log P(s_{t+1}|s_t, a_t)} + \nabla_\theta \log\pi_\theta(a_t|s_t)\right] \\ &= \sum^T_{t=0} \nabla_\theta \log\pi_\theta(a_t|s_t). \end{align*}

Substituting back into Eq. 4 produces the final result,

$$\nabla_\theta J(\theta) = \mathbb{E}_{\tau \sim \pi_\theta} \left[\sum^T_{t=0} \nabla_\theta \log\pi_\theta(a_t|s_t)R(\tau)\right].$$
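
In practice the expectation is approximated with a Monte Carlo average over sampled trajectories, and autodiff differentiates a surrogate loss whose gradient matches the expression above. A minimal REINFORCE-style sketch in PyTorch, assuming `log_probs` holds the logged $\log\pi_\theta(a_t|s_t)$ tensors and `rewards` the rewards from one episode, and `optimiser` wraps the policy parameters (all names are illustrative):

import torch

def policy_gradient_loss(log_probs, rewards):
    # surrogate loss: -(sum_t log pi(a_t|s_t)) * R(tau); its gradient is the
    # negative single-trajectory policy gradient estimate
    R = sum(rewards)                       # total return R(tau)
    return -torch.stack(log_probs).sum() * R

loss = policy_gradient_loss(log_probs, rewards)
optimiser.zero_grad()
loss.backward()
optimiser.step()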

Whilst this serves as the foundation for a general policy gradient, there are many variants which minimise variance through different choices of $\Psi_t$,

$$\nabla_\theta J(\theta) = \mathbb{E}_{\tau \sim \pi_\theta} \left[\sum^T_{t=0} \nabla_\theta \log\pi_\theta(a_t|s_t)\Psi_t\right].$$

Schulman et al. summarised the following options for $\Psi_t$ (a sketch of options 2 and 3 follows the list),

  1. $\sum_{t=0}^\infty r_{t}$, the total reward for the trajectory
  2. $\sum_{t'=t}^\infty r_{t'}$, the reward following action $a_t$
  3. $\sum_{t'=t}^\infty r_{t'} - b(s_t)$, a baselined version of (2)
  4. $Q^{\pi}(s_t,a_t) = \mathbb{E}_{s_{t+1:\infty},\, a_{t+1:\infty}}\left[ \sum_{l=0}^{\infty}r_{t+l} \right]$, the state-action value function
  5. $A^\pi(s_t,a_t) = Q^{\pi}(s_t,a_t) - V^{\pi}(s_t)$, the advantage function
  6. $r_t + V^\pi(s_{t+1}) - V^\pi(s_t)$, the TD residual, where $V^{\pi}(s_t) = \mathbb{E}_{s_{t+1:\infty},\, a_{t:\infty}}\left[ \sum_{l=0}^{\infty}r_{t+l} \right]$
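
As promised above, a small sketch of options 2 and 3: the reward-to-go at each timestep, optionally shifted by a baseline (a crude constant baseline here, purely for illustration; in practice $b(s_t)$ is usually a learned value function):

def rewards_to_go(rewards):
    # Psi_t = sum_{t'=t}^{T} r_{t'}, computed with a reverse cumulative sum
    rtg, running = [], 0.0
    for r in reversed(rewards):
        running += r
        rtg.append(running)
    return list(reversed(rtg))

rewards = [1.0, 0.0, 2.0]
psi = rewards_to_go(rewards)                 # option 2: [3.0, 2.0, 2.0]
baseline = sum(rewards) / len(rewards)       # crude constant baseline b
psi_baselined = [p - baseline for p in psi]  # option 3: [2.0, 1.0, 1.0]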