import torch


def ppo_step(policy_net, value_net, optimizer_policy, optimizer_value, optim_value_iternum, states, actions,
             returns, advantages, fixed_log_probs, clip_epsilon, l2_reg, opponent=False, max_grad=40):
    """Run one PPO clipped-surrogate gradient step on the policy network.

    Only the policy is updated. The critic (value-function) update is
    currently disabled, so ``value_net``, ``optimizer_value``,
    ``optim_value_iternum``, ``returns`` and ``l2_reg`` are accepted for
    interface compatibility but are not used.

    Args:
        policy_net: module exposing ``get_log_prob(states, actions)``.
        value_net: unused (critic update disabled).
        optimizer_policy: optimizer that steps ``policy_net``'s parameters.
        optimizer_value: unused (critic update disabled).
        optim_value_iternum: unused (critic update disabled).
        states: batch passed straight to ``policy_net.get_log_prob``.
        actions: batch passed straight to ``policy_net.get_log_prob``.
        returns: unused (critic update disabled).
        advantages: advantage estimates; treated as constants (detached).
        fixed_log_probs: log-probabilities of ``actions`` under the policy
            the ratio is measured against; treated as constants (detached).
        clip_epsilon: the probability ratio is clamped to
            ``[1 - clip_epsilon, 1 + clip_epsilon]``.
        l2_reg: unused (critic update disabled).
        opponent: if True, use the mirrored clipped objective, so the policy
            is pushed to *decrease* rather than increase the
            advantage-weighted ratio.
        max_grad: max total gradient norm passed to ``clip_grad_norm_``.
    """
    # The surrogate treats the reference log-probs and the advantages as
    # constants. If a caller built them without torch.no_grad(), letting
    # gradients flow through them would cancel the policy gradient
    # (fixed_log_probs) or leak gradients into the critic graph (advantages).
    if torch.is_tensor(fixed_log_probs):
        fixed_log_probs = fixed_log_probs.detach()
    if torch.is_tensor(advantages):
        advantages = advantages.detach()

    # Policy update: clipped surrogate objective.
    log_probs = policy_net.get_log_prob(states, actions)
    ratio = torch.exp(log_probs - fixed_log_probs)
    surr1 = ratio * advantages
    surr2 = torch.clamp(ratio, 1.0 - clip_epsilon, 1.0 + clip_epsilon) * advantages

    if opponent:
        # Maximising E[min(r * (-A), clip(r) * (-A))] is minimising
        # E[max(r * A, clip(r) * A)]; equal to -min(-surr1, -surr2).
        policy_surr = torch.max(surr1, surr2).mean()
    else:
        policy_surr = -torch.min(surr1, surr2).mean()

    optimizer_policy.zero_grad()
    policy_surr.backward()
    torch.nn.utils.clip_grad_norm_(policy_net.parameters(), max_grad)
    optimizer_policy.step()
