import math

import torch
from torch.distributions.categorical import Categorical
import torch.nn.functional as F

from env_loop import make_env_loop
from utils import coroutine


@coroutine
def make_actor_critic_loss_computer(env, model, gamma, lambda_, entropy_weight, clip_reward):
    """Coroutine that converts environment rollouts into an actor-critic loss.

    Every value sent into the coroutine is forwarded to the environment loop
    (as the number of environment steps to run); the coroutine answers with a
    ``(loss, metrics)`` pair computed from the rollout it gets back.

    NOTE(review): ``coroutine`` comes from ``utils``; presumably it primes the
    generator so that the first ``send`` is accepted -- confirm there.

    Args:
        env: Handed to ``make_env_loop`` together with ``model``.
        model: Handed to ``make_env_loop``; put into training mode via
            ``model.train()`` before each rollout is requested.
        gamma: Forwarded to ``compute_lambda_returns``.
        lambda_: Forwarded to ``compute_lambda_returns``.
        entropy_weight: Multiplier of the entropy term; the term enters the
            loss with a negative sign, so higher entropy lowers the loss.
        clip_reward: Forwarded to ``compute_lambda_returns``.

    Yields:
        Tuple ``(loss, metrics)``. ``loss`` is the sum of the policy, entropy
        and value losses. ``metrics`` is a dict of detached tensors with keys
        ``policy_entropy`` (masked mean entropy divided by ln 2),
        ``loss_actions``, ``loss_entropy`` and ``loss_values``.
    """
    rollout_source = make_env_loop(env, model)

    # The first value sent in is the step count for the first rollout.
    steps_requested = yield

    while True:
        model.train()

        rollout = rollout_source.send(steps_requested)
        # Only these fields of the 9-tuple are used; the two `_` slots are ignored.
        _, act, rew, end, _, logits_act, val, val_bootstrap, mask = rollout

        policy = Categorical(logits=logits_act)
        # `mask` selects which entries contribute to every mean below
        # (presumably the valid, non-padded steps -- confirm in env_loop).
        masked_entropy = policy.entropy()[mask].mean()

        # The returns are regression targets only: no gradient flows through them.
        with torch.no_grad():
            targets = compute_lambda_returns(
                rew, end, val, val_bootstrap, gamma, lambda_, clip_reward
            )

        kept_values = val[mask]
        kept_targets = targets[mask]

        # Policy gradient term: log-probability of the taken action weighted by
        # the (gradient-free) advantage estimate `return - value`.
        advantages = (kept_targets - kept_values).detach()
        kept_log_probs = policy.log_prob(act)[mask]
        loss_actions = (-kept_log_probs * advantages).mean()

        # Critic regression towards the lambda-returns.
        loss_values = F.mse_loss(kept_values, target=kept_targets)

        # Entropy bonus, expressed as a loss.
        loss_entropy = -entropy_weight * masked_entropy

        total_loss = loss_actions + loss_entropy + loss_values

        metrics = {
            # Categorical.entropy() is in nats; dividing by ln 2 reports bits.
            'policy_entropy': masked_entropy.detach() / math.log(2),
            'loss_actions': loss_actions.detach(),
            'loss_entropy': loss_entropy.detach(),
            'loss_values': loss_values.detach(),
        }

        steps_requested = yield total_loss, metrics


def compute_lambda_returns(rew, end, val, val_bootstrap, gamma, lambda_, clip_reward):
    """Compute lambda-returns for a batch of trajectories by a backward recursion.

    For every row, with dimension 1 treated as the step axis::

        G[t] = rew[t] + (not end[t]) * gamma * ((1 - lambda_) * V[t + 1] + lambda_ * G[t + 1])

    where ``V[t + 1]`` is ``val[:, t + 1]`` and, past the last step, both
    ``V`` and ``G`` are replaced by ``val_bootstrap``.

    Args:
        rew: 2-D tensor of rewards.
        end: Tensor of the same size as ``rew``; where it is truthy, nothing
            from later steps is propagated into that step's return.
        val: Tensor of the same size as ``rew`` holding value estimates.
        val_bootstrap: 1-D tensor with one value per row of ``rew``, used as
            the value following the last step.
        gamma: Discount factor.
        lambda_: Mixing weight between the one-step target (``0``) and the
            full discounted return (``1``).
        clip_reward: If truthy, rewards are replaced by their sign.

    Returns:
        A new tensor of lambda-returns with the same size as ``rew``; the
        inputs are not modified.

    Raises:
        AssertionError: If the shapes do not satisfy the constraints above.
    """
    assert rew.ndim == 2 and rew.size() == val.size() == end.size()
    assert val_bootstrap.ndim == 1 and val_bootstrap.size(0) == rew.size(0)

    rewards = rew.sign() if clip_reward else rew

    # Factor that zeroes every contribution flowing back across an `end` step.
    not_end = end.logical_not()

    # Value of the following step: `val` shifted left by one, closed by the bootstrap.
    next_values = torch.cat((val[:, 1:], val_bootstrap.unsqueeze(1)), dim=1)

    # One-step part of the target, computed for all steps at once.
    returns = rewards + not_end * gamma * (1 - lambda_) * next_values

    # Recursive part, accumulated from the last step back to the first.
    carried = val_bootstrap
    t = rew.size(1) - 1
    while t >= 0:
        returns[:, t] += not_end[:, t] * gamma * lambda_ * carried
        carried = returns[:, t]
        t -= 1

    return returns
