
Understanding why there isn't a log probability in TRPO and PPO's objective

Posted on Sat, Aug 17, 2019

Recently I have been reading quite a lot about off-policy policy gradients, importance sampling, etc. When I was reading about Trust Region Policy Optimization (TRPO), I couldn't help but notice that TRPO's objective doesn't have the log probability normally present in policy gradient methods such as A2C, as shown below.

\max J_{TRPO}(\theta) = \mathbb{E}_{a \sim\pi_{\theta_{old}}} \left[\frac{\pi_{\theta}(a_t|s_t)}{\pi_{\theta_{old}}(a_t|s_t)} A_t \right]

\max J_{A2C}(\theta) = \mathbb{E}_{a \sim\pi_{\theta}} \left[ \log \pi_{\theta}(a_t|s_t) A_t \right]

This was really puzzling to me because the log probability trick [1] is standard in policy gradient methods. So what's the connection? After a bunch of searching, this answer by mglss [2] on StackExchange seems to have answered my question. In a nutshell, the log probability is still present in TRPO's objective. Let's take the gradient of TRPO's objective:
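
\nabla_{\theta} J_{TRPO}(\theta) = \mathbb{E}_{a \sim\pi_{\theta_{old}}} \left[\frac{\nabla_{\theta} \pi_{\theta}(a_t|s_t)}{\pi_{\theta_{old}}(a_t|s_t)} A_t \right] = \mathbb{E}_{a \sim\pi_{\theta_{old}}} \left[\frac{\pi_{\theta}(a_t|s_t)}{\pi_{\theta_{old}}(a_t|s_t)} \nabla_{\theta} \log \pi_{\theta}(a_t|s_t) A_t \right]

where the second equality uses the identity \nabla_{\theta} \pi_{\theta} = \pi_{\theta} \nabla_{\theta} \log \pi_{\theta} and the fact that \pi_{\theta_{old}} does not depend on \theta.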

So apparently the log probability trick still works! Let's verify this with code.

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.distributions.categorical import Categorical
action = 0
advantage = torch.tensor(1.)

target_logits = torch.tensor([1., 1., 1., 1.], requires_grad=True)
behavior_logits = torch.tensor([2.3, 1., 1., 1.], requires_grad=True)
# note: Categorical treats a positional argument as (unnormalized) probabilities
# and normalizes it by its sum, not with a softmax
target_probs = Categorical(target_logits)
behavior_probs = Categorical(behavior_logits)

TRPO_loss = -(target_probs.probs / 
              behavior_probs.probs.detach())[action] * advantage
print(TRPO_loss)
TRPO_loss.backward()
print(target_logits.grad)

# tensor(-0.5761, grad_fn=<MulBackward0>)
# tensor([-0.4321,  0.1440,  0.1440,  0.1440])
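
As a quick sanity check on these numbers: Categorical normalizes the positional argument by its sum, so the target probability of the chosen action is p_0 = x_0 / \sum_i x_i = 0.25 and the detached behavior probability is b_0 = 2.3 / 5.3 \approx 0.434. Hence

-\frac{p_0}{b_0} = -\frac{0.25}{0.434} \approx -0.5761, \qquad \frac{\partial}{\partial x_0}\left(-\frac{p_0}{b_0}\right) = -\frac{3/16}{0.434} \approx -0.4321, \qquad \frac{\partial}{\partial x_j}\left(-\frac{p_0}{b_0}\right) = \frac{1/16}{0.434} \approx 0.1440 \; (j \neq 0)

which matches the printed loss and gradient. Now let's write the same objective with the log probability and a detached importance sampling ratio: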

target_logits = torch.tensor([1., 1., 1., 1.], requires_grad=True)
behavior_logits = torch.tensor([2.3, 1., 1., 1.], requires_grad=True)
target_probs = Categorical(target_logits)
behavior_probs = Categorical(behavior_logits)
importance_sampling = (target_probs.probs / behavior_probs.probs).detach()
# the .detach() above is **very important**
TRPO_log_loss = -target_probs.log_prob(
    torch.tensor(action)) * importance_sampling[action] * advantage

print(TRPO_log_loss)
TRPO_log_loss.backward()
print(target_logits.grad)

# tensor(0.7986, grad_fn=<MulBackward0>)
# tensor([-0.4321,  0.1440,  0.1440,  0.1440])

As shown by the results, the gradients produced by the two losses are the same (even though the loss values themselves differ, which doesn't matter for optimization). I want to take this chance to stress that the .detach() in the calculation of TRPO_log_loss is very important. If we omit it, the computed gradient becomes the following, which is incorrect.

\mathbb{E}_{a \sim\pi_{\theta_{old}}} \left[\frac{\nabla_{\theta} \left( \pi_{\theta}(a_t|s_t) \log \pi_{\theta}(a_t|s_t)\right)}{\pi_{\theta_{old}}(a_t|s_t)} A_t \right]
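
The extra term comes from the product rule: expanding the numerator gives

\nabla_{\theta} \left( \pi_{\theta}(a_t|s_t) \log \pi_{\theta}(a_t|s_t) \right) = \left( \log \pi_{\theta}(a_t|s_t) + 1 \right) \nabla_{\theta} \pi_{\theta}(a_t|s_t)

so the gradient picks up an unwanted factor of \log \pi_{\theta}(a_t|s_t) + 1.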

And empirically, the gradient is also incorrect.

target_logits = torch.tensor([1., 1., 1., 1.], requires_grad=True)
behavior_logits = torch.tensor([2.3, 1., 1., 1.], requires_grad=True)
target_probs = Categorical(target_logits)
behavior_probs = Categorical(behavior_logits)
importance_sampling = (target_probs.probs / behavior_probs.probs)
# without the .detach() above
TRPO_log_loss = -target_probs.log_prob(
    torch.tensor(action)) * importance_sampling[action] * advantage

print(TRPO_log_loss)
TRPO_log_loss.backward()
print(target_logits.grad)

# tensor(0.7986, grad_fn=<MulBackward0>)
# tensor([ 0.1669, -0.0556, -0.0556, -0.0556])

Given this insight, I proceeded to modify my PPO implementation to use this kind of loss and verify whether the two losses result in identical gradients and training results. Long story short, using the clipped objective together with the log loss results in a different gradient, but I don't feel compelled to investigate because that's sort of beyond the point of this post ... 😅 Instead, we experiment with PPO without the clipped objective, in which case the gradient and the trained policy are the same, as shown below.

The final weights of the first fully connected layer of the policy trained using PPO without the clipped objective on CartPole-v0. Check out ppo_no_clipped.py
Similarly, the weights trained using PPO without the clipped objective and with the log loss demonstrated above. Check out ppo_no_clipped_log_loss.py

The only difference between these two files is the following:

# in ppo_no_clipped.py
policy_loss = -torch.mean(importance_sampling * 
    torch.Tensor(advantages)[:step+1].to(device))

# in ppo_no_clipped_log_loss.py
policy_loss = -torch.mean(importance_sampling.detach() * 
    newlogproba * torch.Tensor(advantages)[:step+1].to(device))

Through this analysis, we see a nice connection between the objective of most policy gradient methods such as A2C and that of TRPO/PPO. By the way, you can also run the code yourself at https://colab.research.google.com/drive/14A5GDrFiFzsBo6PF68yh4FW8EKMx5Zlt

Why don't major TRPO/PPO implementations use the log probability?

I suspect the reason is twofold. First, it makes the TRPO objective inconvenient to express, since it would then have to be something like

\max J_{TRPO}(\theta) = \frac{\pi_{\theta'}(a_t|s_t)}{\pi_{\theta_{old}}(a_t|s_t)} \log \pi_{\theta}(a_t|s_t) A_t

\text{where } \pi_{\theta'} \text{ is a copy of } \pi_{\theta}

Essentially, it has to be spelled out that the importance sampling ratio does not contribute any gradient to the target policy. Obviously, this is quite redundant.

Another reason might be purely about implementation. As demonstrated above, the person implementing the policy loss would need to consciously recognize the importance of detaching the importance sampling ratio from the gradient calculation of the policy. It is simply much easier not to have to worry about the detachment and to leave everything to the autograd package, as the sketch below illustrates.
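
For illustration, here is a minimal sketch (not taken from either file above) of the ratio-based loss that most implementations use, reusing action and advantage from the toy snippets earlier; new_logprob and old_logprob are just illustrative names. Because the behavior policy's log probability carries no gradient, autograd produces the correct gradient without any explicit detach:

# a minimal sketch of the ratio-based loss, reusing the toy setting from above
target_logits = torch.tensor([1., 1., 1., 1.], requires_grad=True)
behavior_logits = torch.tensor([2.3, 1., 1., 1.])  # no requires_grad: a constant
target_probs = Categorical(target_logits)
behavior_probs = Categorical(behavior_logits)

new_logprob = target_probs.log_prob(torch.tensor(action))
old_logprob = behavior_probs.log_prob(torch.tensor(action))
ratio = torch.exp(new_logprob - old_logprob)  # equals pi_theta / pi_theta_old
policy_loss = -torch.mean(ratio * advantage)
policy_loss.backward()
print(target_logits.grad)
# should print the same gradient as TRPO_loss above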

Conclusion

This article explores why TRPO and PPO don't have the log probability as part of their objectives. Both in theory and in implementation, I show that the log probability is effectively still there in the objective of TRPO and PPO. Hopefully this insight will be helpful to you :)

References

[1] http://incompleteideas.net/book/RLbook2018.pdf#page=349

[2] https://ai.stackexchange.com/questions/7685/why-is-the-log-probability-replaced-with-the-importance-sampling-in-the-loss-fun

[3] https://lilianweng.github.io/lil-log/2018/04/08/policy-gradient-algorithms.html

[4] https://arxiv.org/abs/1611.01224