import torch
from torch.nn import functional as F

from .torchbeast.core import vtrace

def compute_actor_losses(
    behavior_policy_logits,
    target_policy_logits,
    actions,
    discounts,
    rewards,
    values,
    bootstrap_value,
    baseline_cost,
):
    """Build the policy-gradient and baseline losses from V-trace returns.

    All arguments except ``baseline_cost`` are forwarded by keyword to
    ``vtrace.from_logits``; the object it returns is read for its
    ``pg_advantages`` and ``vs`` attributes.

    Args:
        behavior_policy_logits: Logits of the policy that generated the data.
        target_policy_logits: Logits of the policy being optimised; also
            used as the logits of the policy-gradient loss.
        actions: Actions taken; also used as targets of the
            policy-gradient loss.
        discounts: Discount factors passed to V-trace.
        rewards: Rewards passed to V-trace.
        values: Baseline value estimates; subtracted from the V-trace
            targets to form the baseline residual.
        bootstrap_value: Bootstrap value passed to V-trace.
        baseline_cost: Multiplier applied to the baseline loss.

    Returns:
        A ``(pg_loss, baseline_loss)`` tuple; ``baseline_loss`` is already
        scaled by ``baseline_cost``.
    """
    vtrace_inputs = dict(
        behavior_policy_logits=behavior_policy_logits,
        target_policy_logits=target_policy_logits,
        actions=actions,
        discounts=discounts,
        rewards=rewards,
        values=values,
        bootstrap_value=bootstrap_value,
    )
    returns = vtrace.from_logits(**vtrace_inputs)

    policy_loss = compute_policy_gradient_loss(
        target_policy_logits, actions, returns.pg_advantages
    )

    # Residual between the V-trace value targets and the current estimates.
    value_residual = returns.vs - values
    value_loss = baseline_cost * compute_baseline_loss(value_residual)

    return policy_loss, value_loss


def compute_baseline_loss(advantages):
    """Return half the squared-error baseline loss.

    The squared values are averaged over dimension 1 and the resulting
    means are summed over the remaining dimensions.

    Args:
        advantages: Tensor with at least two dimensions.
            NOTE(review): presumably laid out as (time, batch) - confirm
            against the caller.

    Returns:
        A scalar tensor.
    """
    squared = advantages ** 2
    mean_over_dim1 = squared.mean(dim=1)
    return 0.5 * mean_over_dim1.sum()


def compute_entropy_loss(logits):
    """Return the negated policy entropy as a loss.

    The entropy of the softmax distribution over the last dimension is
    averaged over dimension 1, summed over what remains, and negated, so
    minimising the result increases entropy.

    Args:
        logits: Unnormalised action scores; the last dimension indexes
            actions. NOTE(review): presumably (time, batch, actions) -
            confirm against the caller.

    Returns:
        A scalar tensor.
    """
    log_probs = F.log_softmax(logits, dim=-1)
    probs = F.softmax(logits, dim=-1)
    entropy = (-probs * log_probs).sum(dim=-1)
    return -entropy.mean(dim=1).sum()


def compute_policy_gradient_loss(logits, actions, advantages):
    """Return the advantage-weighted negative log-likelihood of ``actions``.

    The first two dimensions of ``logits`` and ``actions`` are flattened
    together, the per-element negative log-likelihood is computed, reshaped
    to match ``advantages`` and multiplied by it. The product is averaged
    over dimension 1 and summed over the remaining dimensions.

    Args:
        logits: Unnormalised action scores; the last dimension indexes
            actions.
        actions: Integer indices of the actions taken, matching the first
            two dimensions of ``logits``.
        advantages: Per-element weights. They are treated as constants: no
            gradient flows into them.

    Returns:
        A scalar tensor.
    """
    flat_log_probs = F.log_softmax(torch.flatten(logits, 0, 1), dim=-1)
    cross_entropy = F.nll_loss(
        flat_log_probs,
        target=torch.flatten(actions, 0, 1),
        reduction="none",
    )
    cross_entropy = cross_entropy.view_as(advantages)
    # Detach instead of assigning ``requires_grad = False``: the assignment
    # mutates the caller's tensor and raises RuntimeError for non-leaf
    # tensors that require grad.
    advantages = advantages.detach()
    policy_gradient_loss_per_timestep = cross_entropy * advantages
    return torch.sum(torch.mean(policy_gradient_loss_per_timestep, dim=1))


def compute_forward_dynamics_loss(pred_next_emb, next_emb):
    """Return the L2 prediction error between predicted and actual embeddings.

    The 2-norm of the difference is taken along dimension 2, averaged over
    dimension 1 and summed over what remains.

    Args:
        pred_next_emb: Predicted next-step embeddings (at least 3-D).
        next_emb: Actual next-step embeddings, broadcastable against
            ``pred_next_emb``.

    Returns:
        A scalar tensor.
    """
    residual = pred_next_emb - next_emb
    distance = torch.norm(residual, p=2, dim=2)
    return distance.mean(dim=1).sum()


def compute_forward_binary_loss(pred_next_binary, next_binary):
    """Return the binary cross-entropy between logits and binary targets.

    Thin wrapper around ``F.binary_cross_entropy_with_logits`` using that
    function's default settings.

    Args:
        pred_next_binary: Predicted logits.
        next_binary: Target values with the same shape as the logits.

    Returns:
        The loss tensor produced by ``F.binary_cross_entropy_with_logits``.
    """
    bce = F.binary_cross_entropy_with_logits(
        input=pred_next_binary,
        target=next_binary,
    )
    return bce


def compute_forward_class_loss(pred_next_glyphs, next_glyphs):
    """Return the cross-entropy between predicted glyph logits and targets.

    The first three dimensions of ``next_glyphs`` are flattened and cast to
    integer class indices; the predictions are viewed as one row of logits
    per flattened target.

    Args:
        pred_next_glyphs: Predicted logits; must be viewable as
            ``(number_of_targets, -1)``.
        next_glyphs: Target class ids with at least three dimensions.
            NOTE(review): presumably exactly three dimensions, so that
            the flattened targets are 1-D - confirm against the caller.

    Returns:
        The loss tensor produced by ``F.cross_entropy`` with its default
        settings.
    """
    targets = next_glyphs.flatten(0, 2).long()
    num_targets = targets.size(0)
    class_logits = pred_next_glyphs.view(num_targets, -1)
    return F.cross_entropy(class_logits, targets)


def compute_inverse_dynamics_loss(pred_actions, true_actions):
    """Return the cross-entropy between predicted and actual actions.

    The first two dimensions of both inputs are flattened together, the
    unreduced cross-entropy is computed and reshaped to match
    ``true_actions``, then averaged over dimension 1 and summed over the
    remaining dimensions.

    Args:
        pred_actions: Predicted action logits; the last dimension indexes
            actions.
        true_actions: Integer indices of the actions actually taken,
            matching the first two dimensions of ``pred_actions``.

    Returns:
        A scalar tensor.
    """
    flat_logits = pred_actions.flatten(0, 1)
    flat_targets = true_actions.flatten(0, 1)
    per_element = F.cross_entropy(flat_logits, flat_targets, reduction="none")
    per_element = per_element.view_as(true_actions)
    return per_element.mean(dim=1).sum()
