import torch as t
import torch.nn.functional as F

from .base_trainer import BaseTrainer
from ..torch_utils import accuracy, merge, select


class LossFunctions(BaseTrainer):
    """Loss terms shared by trainers built on ``BaseTrainer``.

    Provides a TD regression term, a cross-entropy ("OOD") term over
    Q-values, reward and latent-transition regression terms, and a
    tree-search policy loss that weights the solver's ``tree_policy_nll``
    by a loss delta computed without gradient.
    """

    def __init__(
        self,
        args,
    ):
        """Store the per-term loss weights read from ``args``.

        Args:
            args: Configuration object. This class reads ``w_tdl``,
                ``w_oodl``, ``w_rwl``, ``w_trl`` and ``w_tsl`` from it and
                forwards it unchanged to ``BaseTrainer.__init__``.
        """
        super().__init__(args)

        self.td_loss_weight = args.w_tdl
        self.ood_loss_weight = args.w_oodl
        self.reward_loss_weight = args.w_rwl
        self.transition_loss_weight = args.w_trl
        # Not read inside this class; presumably a subclass/caller uses it to
        # scale ``compute_tree_search_policy_loss`` -- confirm.
        self.tree_loss_weight = args.w_tsl

    def compute_accuracy(self, q_values, actions):
        """Return the accuracy of greedy actions against ``actions``.

        The greedy action is ``argmax`` of ``q_values`` over the last
        dimension; the comparison itself is delegated to the project's
        ``accuracy`` helper.
        """
        return accuracy(
            t.argmax(q_values, -1),
            actions,
        )

    def predict_latent_state_encodings(self, state_encodings, actions):
        """Roll ``self.solver.f_transition`` forward along an action sequence.

        Args:
            state_encodings: Starting encodings; used as-is for step 0.
            actions: Action sequence whose dim 1 indexes time. Presumably
                ``select(actions, index=i)`` picks step ``i`` -- confirm.

        Returns:
            The ``actions.shape[1]`` encodings (the input followed by
            ``actions.shape[1] - 1`` predicted ones) stacked along dim 1.
            The action at the final step is not consumed.
        """
        latent_state_encodings = [state_encodings]
        for i in range(actions.shape[1] - 1):
            latent_state_encodings.append(
                self.solver.f_transition(
                    latent_state_encodings[i],
                    select(actions, index=i),
                )
            )
        return t.stack(latent_state_encodings, dim=1)

    def compute_td_loss(
        self,
        predicted_values,
        target_values,
        reduction="mean",
    ):
        """Mean-squared error between predicted and target values.

        ``reduction`` is passed straight to ``F.mse_loss``; use ``"none"``
        to get per-element losses.
        """
        return F.mse_loss(
            predicted_values,
            target_values,
            reduction=reduction,
        )

    def compute_ood_loss(
        self,
        q_values,
        actions,
        reduction="mean",
    ):
        """Cross-entropy of ``actions`` with ``q_values`` used as logits.

        Minimizing it raises the Q-value of the given action relative to the
        other actions. The "OOD" name presumably refers to penalizing
        out-of-distribution actions -- confirm. ``reduction`` is passed
        straight to ``F.cross_entropy``.
        """
        return F.cross_entropy(
            q_values,
            actions,
            reduction=reduction,
        )

    def compute_reward_loss(
        self,
        predicted_rewards,
        target_rewards,
        reduction="mean",
    ):
        """Mean-squared error between predicted and target rewards.

        ``reduction`` defaults to ``"mean"`` (the previous fixed behavior)
        and is passed straight to ``F.mse_loss``, mirroring
        ``compute_td_loss``.
        """
        return F.mse_loss(
            predicted_rewards,
            target_rewards,
            reduction=reduction,
        )

    def compute_transition_loss(
        self,
        predicted_state_encodings,
        target_state_encodings,
    ):
        """Squared error summed over all non-batch dims, averaged over dim 0.

        Both inputs are flattened from dim 1 onward, so each sample
        contributes the total squared error of its whole encoding.
        """
        return (
            F.mse_loss(
                t.flatten(predicted_state_encodings, start_dim=1),
                t.flatten(target_state_encodings, start_dim=1),
                reduction="none",
            )
            .sum(-1)
            .mean()
        )

    def compute_tree_search_policy_loss(
        self,
        actions,
        values,
    ):
        """Weight the solver's tree-policy NLL by a per-trial loss delta.

        For every sample, the weighted TD + OOD loss is evaluated on each of
        ``self.solver.n_trials`` sets of Q-values in ``self.solver.q_values``.
        Those losses are turned into weights without gradient, so only
        ``self.solver.tree_policy_nll`` receives gradients.

        Args:
            actions: Dataset actions, one per sample.
            values: TD targets, one per sample.

        Returns:
            A 0-dim tensor: the weighted NLL summed over trials and averaged
            over samples. If the solver has no ``tree_policy_nll`` attribute,
            a 0-dim float32 zero on ``self.device``.
        """
        if not hasattr(self.solver, "tree_policy_nll"):
            # Nothing to optimize. A 0-dim zero matches the shape returned by
            # the main path below (``.sum(-1).mean()``) and is created
            # directly on the target device.
            return t.zeros((), dtype=t.float32, device=self.device)

        n_trials = self.solver.n_trials
        with t.no_grad():
            # One row per (sample, trial) pair; presumably
            # (batch * n_trials, n_actions) -- confirm against the solver.
            q_values = merge(
                t.stack(self.solver.q_values, dim=1),
                dims=(0, 1),
            )
            # Repeat each sample's action / target once per trial so they
            # line up row-for-row with ``q_values``.
            actions = merge(
                actions.view(-1, 1).repeat(1, n_trials),
                dims=(0, 1),
            )
            values = merge(
                values.view(-1, 1).repeat(1, n_trials),
                dims=(0, 1),
            )
            td_loss = self.compute_td_loss(
                select(q_values, actions),
                values,
                reduction="none",
            )
            ood_loss = self.compute_ood_loss(
                q_values,
                actions,
                reduction="none",
            )

            loss = self.td_loss_weight * td_loss + self.ood_loss_weight * ood_loss
            loss = loss.view(-1, n_trials)

            # Weights per trial column:
            #   column 0      -> -loss[:, -1]
            #   column k >= 1 -> loss[:, k - 1] - loss[:, -1]
            final_loss = loss[:, -1].view(-1, 1)
            delta_loss = t.cat(
                tensors=(-final_loss, (loss - final_loss)[:, :-1]),
                dim=-1,
            )

        tree_policy_nll = t.cat(self.solver.tree_policy_nll, dim=-1)
        tree_search_policy_loss = tree_policy_nll * delta_loss

        return tree_search_policy_loss.sum(-1).mean()
