## Intro
The RAT (randomized advantage transformation) implementation is located in `train_detach.py`, lines 69–167.


### Details
Computing log π(a|s), the log-probability of the taken action:
* Discrete case:
``` python
# Log-probabilities over the last dimension for the current and the old policy outputs.
_logp_full = F.log_softmax(_outputs, dim=-1)
_logp_full_old = F.log_softmax(_outputs_old, dim=-1)

# Index used to pick, per sample, the entry of the action that was taken.
_act_index = _act.unsqueeze(-1)

# Log-probability of the taken action under the current policy.
_logp = _logp_full.gather(-1, _act_index).squeeze(1)

# Log-likelihood ratio (current minus old) of the taken action, and the
# corresponding importance ratio.
_llr = (_logp_full - _logp_full_old).gather(-1, _act_index).squeeze(1)
_ratio = _llr.exp()

# Entropy of the current policy: -sum_a p(a) log p(a), averaged over the batch.
_p_log_p = _logp_full.exp() * _logp_full
_entropy = -_p_log_p.sum(-1).mean()

# KL(old || new) = sum_a p_old(a) * (log p_old(a) - log p_new(a)), averaged over the batch.
_kl_terms = _logp_full_old.exp() * (_logp_full_old - _logp_full)
_real_kl = _kl_terms.sum(dim=-1).mean()

def compute_logp(params, buffers, batch_obs, batch_act):
    """Return the log-probability of one discrete action under the policy net.

    The observation and action each get a leading dimension of size 1 before
    `actor_critic.pi_net` is evaluated through `functional_call` with the
    supplied `params` and `buffers`. The network output is passed through a
    log-softmax over its last dimension, the entry at the action index is
    selected, and the added dimensions are squeezed away again.

    Args:
        params: parameters handed to `functional_call` for `pi_net`.
        buffers: buffers handed to `functional_call` for `pi_net`.
        batch_obs: observation tensor (presumably a single sample, since a
            batch dimension is added here -- confirm against the caller).
        batch_act: action index tensor matching `batch_obs`.

    Returns:
        The selected log-probability with the added dimensions removed.
    """
    obs = batch_obs.unsqueeze(0)
    act_index = batch_act.unsqueeze(0).unsqueeze(-1)

    logits = functional_call(actor_critic.pi_net, (params, buffers), (obs,))
    logp_all = F.log_softmax(logits, dim=-1)

    picked = logp_all.gather(-1, act_index).squeeze(1)
    return picked.squeeze(0)
```

* Continuous case: 
``` python
# Split the current policy output into mean and log-std, mirroring the
# old-policy split below.
# NOTE(review): assumes `_outputs` is the current policy network output, as in
# the discrete case -- confirm against train_detach.py.
_mu, _logstd = _outputs.chunk(2, dim=-1)
_dist = torch.distributions.Normal(_mu, torch.exp(_logstd))
# Log-density of the action: per-dimension log-probs summed over the last dim.
_logp = _dist.log_prob(_act).sum(dim=-1)

# Same quantities under the old policy output.
_mu_old, _logstd_old = _outputs_old.chunk(2, dim=-1)
_dist_old = torch.distributions.Normal(_mu_old, torch.exp(_logstd_old))
_logp_old = _dist_old.log_prob(_act).sum(dim=-1)

# Log-likelihood ratio (current minus old) and the importance ratio.
_llr = _logp - _logp_old
_ratio = torch.exp(_llr)

# Entropy of the current policy, summed over the last dim and averaged.
_entropy = _dist.entropy().sum(dim=-1).mean()

# Closed-form KL(old || new) between the per-dimension Normals:
#   log(std / std_old) + (std_old^2 + (mu_old - mu)^2) / (2 * std^2) - 1/2,
# summed over the last dim and averaged.
_real_kl = (_logstd - _logstd_old + 0.5 * (torch.exp(_logstd_old).pow(2) + (_mu_old - _mu).pow(2)) / torch.exp(_logstd).pow(2) - 0.5).sum(dim=-1).mean()

def compute_logp(params, buffers, batch_obs, batch_act):
    """Return the Gaussian log-density of one continuous action under the policy net.

    The observation and action each get a leading dimension of size 1 before
    `actor_critic.pi_net` is evaluated through `functional_call` with the
    supplied `params` and `buffers`. The network output is split in half along
    its last dimension into a mean and a log-std, the Normal log-density is
    evaluated per dimension in closed form, summed over the last dimension,
    and the added leading dimension is squeezed away.

    Args:
        params: parameters handed to `functional_call` for `pi_net`.
        buffers: buffers handed to `functional_call` for `pi_net`.
        batch_obs: observation tensor (presumably a single sample, since a
            batch dimension is added here -- confirm against the caller).
        batch_act: action tensor matching `batch_obs`.

    Returns:
        The log-density summed over the last dimension, with the added
        leading dimension removed.
    """
    obs = batch_obs.unsqueeze(0)
    act = batch_act.unsqueeze(0)

    outs = functional_call(actor_critic.pi_net, (params, buffers), (obs,))
    mu, logstd = outs.chunk(2, dim=-1)

    # log N(a; mu, std) = -(a - mu)^2 / (2 * std^2) - log(std) - log(sqrt(2 * pi))
    var = torch.exp(logstd) ** 2
    quadratic_term = -((act - mu) ** 2) / (2 * var)
    log_norm_const = math.log(math.sqrt(2 * math.pi))
    logp_per_dim = quadratic_term - logstd - log_norm_const

    return logp_per_dim.sum(dim=-1).squeeze(0)
```

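Compute the `H` and `HHT` matrices (`H` stacks the flattened per-sample gradients of log π, one row per sample; `HHT` is `H Hᵀ` with its columns weighted by the importance ratio and divided by the number of samples):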
``` python
# Per-sample gradients of compute_logp: `vmap` maps over dim 0 of the
# observations and actions while params and buffers are shared (in_dims None).
# NOTE(review): assumes `grad`/`vmap` are the torch.func (functorch) transforms,
# so `grad` differentiates w.r.t. the first argument (params) -- confirm.
ft_compute_sample_grad = vmap(grad(compute_logp), in_dims=(None, None, 0, 0))
ft_per_sample_grads = ft_compute_sample_grad(dict_params, dict_buffers, _obs, _act)  # num_samples x param_shape

with torch.no_grad():
    num_sa = _obs.shape[0]
    # Flatten each parameter's per-sample gradient and concatenate them so that
    # every row holds the full gradient of one sample.
    H = torch.cat([v.contiguous().view(num_sa, -1) for v in ft_per_sample_grads.values()], dim=-1)  # num_samples x num_params
    # Right-multiplying by torch.diag(_ratio) only scales column j by _ratio[j];
    # broadcasting does the same without materialising a dense diagonal matrix
    # and without a second num_samples^3 matmul.
    HHT = (H @ H.t()) * _ratio.unsqueeze(0) / num_sa  # num_samples x num_samples
```

Apply randomized advantage transformation:
``` python
# Flattened momentum buffers held in the policy optimizer state (unset ones skipped).
momentum_buffers = (v['momentum_buffer'] for v in pi_optimizer.state.values())
gk_list = [buf.contiguous().flatten() for buf in momentum_buffers if buf is not None]

# When enabled and momentum is available, subtract H @ g_k from the advantages,
# where g_k is the concatenation of all momentum buffers.
if algo_config.is_karzmarz and gk_list:
    g_k = torch.cat(gk_list, dim=0)
    _adv = _adv - torch.mv(H, g_k)

# Solve (HHT + damping * I) x = adv for the transformed advantages.
damped_hht = HHT + algo_config.cg_damping * torch.eye(num_sa, device=device)
_png_adv = torch.linalg.solve(damped_hht, _adv)
```  