from typing import List, Tuple

import torch
from mlwiz.training.callback.metric import Metric, MulticlassClassification
from torch.nn import CrossEntropyLoss


class ELBO_Classification(Metric):
    """Negative ELBO (normalised by ``n_obs``) for a classification model.

    The model must return the eight outputs unpacked in
    :meth:`get_predictions_and_targets`; only ``output_state``,
    ``log_p_theta``, ``kld_alpha`` and ``n_obs`` are used here.
    """

    @property
    def name(self) -> str:
        return "ELBO_Classification"

    def get_predictions_and_targets(
        self, targets: torch.Tensor, *outputs: List[torch.Tensor]
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Compute the mini-batch ELBO estimate divided by ``n_obs``.

        Args:
            targets (:class:`torch.Tensor`): class targets, in any format
                accepted by :class:`torch.nn.CrossEntropyLoss`
            outputs (List[:class:`torch.Tensor`]): the model outputs
                ``(output_state, embeddings, log_p_theta, kld_alpha,
                qW_probs, n_obs, variational_Ws, forward_delta)``

        Returns:
            A tuple ``(elbo, targets)``. ``elbo`` has a leading dimension of
            size 1 when ``log_p_theta`` and ``kld_alpha`` are scalars or
            1-element tensors.
        """
        (
            output_state,
            _embeddings,
            log_p_theta,
            kld_alpha,
            _qW_probs,
            n_obs,
            _variational_Ws,
            _forward_delta,
        ) = outputs

        # log p(y | x): the mean mini-batch log-likelihood is multiplied by
        # n_obs (size of the training set), i.e. the mini-batch is treated as
        # if it were the entire dataset.
        log_p_y = -CrossEntropyLoss(reduction="mean")(output_state, targets) * n_obs

        # Out-of-place sum: an in-place ``+=`` on the 1-element view would
        # reject operands that broadcast to a larger shape and would also
        # mutate log_p_y through the view.
        # NOTE(review): kld_alpha is *added*; presumably the model already
        # returns it with the sign the ELBO needs — confirm.
        elbo = log_p_y.unsqueeze(0) + log_p_theta + kld_alpha

        # Renormalise by n_obs (this only rescales the ELBO and its gradients).
        return elbo / n_obs, targets

    def compute_metric(
        self, targets: torch.Tensor, predictions: torch.Tensor
    ) -> torch.Tensor:
        """Return the negated ELBO values averaged along dim 0.

        Maximising the ELBO is done by minimising its negation.
        """
        return -predictions.mean(0)


class ELBO_Regression(Metric):
    """Negative ELBO (normalised by ``n_obs``) for a regression model.

    The model must return the eight outputs unpacked in
    :meth:`get_predictions_and_targets`; only ``output_state``,
    ``log_p_theta``, ``kld_alpha`` and ``n_obs`` are used here.
    """

    @property
    def name(self) -> str:
        return "ELBO_Regression"

    def get_predictions_and_targets(
        self, targets: torch.Tensor, *outputs: List[torch.Tensor]
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Compute the mini-batch ELBO estimate divided by ``n_obs``.

        Args:
            targets (:class:`torch.Tensor`): regression targets; assumed to
                have the same shape as ``output_state`` (broadcasting would
                otherwise silently change the loss) — TODO confirm
            outputs (List[:class:`torch.Tensor`]): the model outputs
                ``(output_state, embeddings, log_p_theta, kld_alpha,
                qW_probs, n_obs, variational_Ws, forward_delta)``

        Returns:
            A tuple ``(elbo, targets)``. ``elbo`` has a leading dimension of
            size 1 when ``log_p_theta`` and ``kld_alpha`` are scalars or
            1-element tensors.
        """
        (
            output_state,
            _embeddings,
            log_p_theta,
            kld_alpha,
            _qW_probs,
            n_obs,
            _variational_Ws,
            _forward_delta,
        ) = outputs

        # log p(y | x): -0.5 * squared error summed over dim 1 (so
        # output_state must be at least 2-D), averaged over the mini-batch.
        # This is a unit-variance Gaussian log-likelihood up to a constant.
        # The result is multiplied by n_obs (size of the training set), as
        # if the mini-batch were the entire dataset.
        # Dividing by the Python scalar 2.0 (rather than a 1-element tensor)
        # keeps log_p_y 0-d, so the ELBO below is shape [1] like
        # ELBO_Classification, and keeps the dtype of output_state (a 1-D
        # float32 tensor would downcast float64 results).
        squared_error = ((output_state - targets) ** 2).sum(1)
        log_p_y = -squared_error.mean() / 2.0 * n_obs

        # Out-of-place sum: avoids in-place broadcasting failures and does not
        # mutate log_p_y through a view.
        # NOTE(review): kld_alpha is *added*; presumably the model already
        # returns it with the sign the ELBO needs — confirm.
        elbo = log_p_y.unsqueeze(0) + log_p_theta + kld_alpha

        # Renormalise by n_obs (this only rescales the ELBO and its gradients).
        return elbo / n_obs, targets

    def compute_metric(
        self, targets: torch.Tensor, predictions: torch.Tensor
    ) -> torch.Tensor:
        """Return the negated ELBO values averaged along dim 0.

        Maximising the ELBO is done by minimising its negation.
        """
        return -predictions.mean(0)


class CELoss(Metric):
    """Cross-entropy between the model's ``output_state`` and the targets.

    Unlike the ELBO metrics, no ``n_obs`` scaling or prior terms are applied.
    """

    @property
    def name(self) -> str:
        return "Cross Entropy"

    def get_predictions_and_targets(
        self, targets: torch.Tensor, *outputs: List[torch.Tensor]
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Return the unreduced cross-entropy values together with the targets.

        Args:
            targets (:class:`torch.Tensor`): class targets, in any format
                accepted by :class:`torch.nn.CrossEntropyLoss`
            outputs (List[:class:`torch.Tensor`]): the model outputs
                ``(output_state, embeddings, log_p_theta, kld_alpha,
                qW_probs, n_obs, variational_Ws, forward_delta)``; only
                ``output_state`` is used

        Returns:
            A tuple ``(cross_entropy, targets)``, where ``cross_entropy``
            holds one (negative log-likelihood) value per sample.
        """
        (
            output_state,
            _embeddings,
            _log_p_theta,
            _kld_alpha,
            _qW_probs,
            _n_obs,
            _variational_Ws,
            _forward_delta,
        ) = outputs

        # reduction="none" keeps one loss value per sample; compute_metric
        # averages them along dim 0.
        cross_entropy = CrossEntropyLoss(reduction="none")(output_state, targets)
        return cross_entropy, targets

    def compute_metric(
        self, targets: torch.Tensor, predictions: torch.Tensor
    ) -> torch.Tensor:
        """Return the per-sample cross-entropy values averaged along dim 0."""
        return predictions.mean(0)


class Prior_theta(Metric):
    """Reports the ``log_p_theta`` term returned by the model."""

    @property
    def name(self) -> str:
        return "Prior_theta"

    def get_predictions_and_targets(
        self, targets: torch.Tensor, *outputs: List[torch.Tensor]
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Extract ``log_p_theta`` from the model outputs.

        Args:
            targets (:class:`torch.Tensor`): ground truth (passed through)
            outputs (List[:class:`torch.Tensor`]): the model outputs
                ``(output_state, embeddings, log_p_theta, kld_alpha,
                qW_probs, n_obs, variational_Ws, forward_delta)``

        Returns:
            A tuple ``(log_p_theta, targets)``.
        """
        (
            _output_state,
            _embeddings,
            log_p_theta,
            _kld_alpha,
            _qW_probs,
            _n_obs,
            _variational_Ws,
            _forward_delta,
        ) = outputs

        return log_p_theta, targets

    def compute_metric(
        self, targets: torch.Tensor, predictions: torch.Tensor
    ) -> torch.Tensor:
        """Return the reported values averaged along dim 0 (not negated)."""
        return predictions.mean(0)


class Prior_gamma(Metric):
    """Reports the ``kld_alpha`` term returned by the model."""

    @property
    def name(self) -> str:
        # NOTE(review): this name differs from the class name ("Prior_gamma").
        # It is left unchanged because renaming it would change the name the
        # metric is reported under — confirm whether the mismatch is intended.
        return "Prior_alpha"

    def get_predictions_and_targets(
        self, targets: torch.Tensor, *outputs: List[torch.Tensor]
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Extract ``kld_alpha`` from the model outputs.

        Args:
            targets (:class:`torch.Tensor`): ground truth (passed through)
            outputs (List[:class:`torch.Tensor`]): the model outputs
                ``(output_state, embeddings, log_p_theta, kld_alpha,
                qW_probs, n_obs, variational_Ws, forward_delta)``

        Returns:
            A tuple ``(kld_alpha, targets)``.
        """
        (
            _output_state,
            _embeddings,
            _log_p_theta,
            kld_alpha,
            _qW_probs,
            _n_obs,
            _variational_Ws,
            _forward_delta,
        ) = outputs

        return kld_alpha, targets

    def compute_metric(
        self, targets: torch.Tensor, predictions: torch.Tensor
    ) -> torch.Tensor:
        """Return the reported values averaged along dim 0 (not negated)."""
        return predictions.mean(0)


class TotalWidth(Metric):
    """Reports the sum of the leading dimensions of the ``qW_probs`` tensors."""

    @property
    def name(self) -> str:
        return "Total Width"

    def get_predictions_and_targets(
        self, targets: torch.Tensor, *outputs: List[torch.Tensor]
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Compute the total width as ``sum(q.shape[0] for q in qW_probs)``.

        Args:
            targets (:class:`torch.Tensor`): ground truth (passed through)
            outputs (List[:class:`torch.Tensor`]): the model outputs
                ``(output_state, embeddings, log_p_theta, kld_gamma,
                qW_probs, n_obs, variational_Ws, forward_delta)``;
                ``qW_probs`` is presumably one tensor per layer — TODO confirm

        Returns:
            A tuple ``(total_width, targets)``, where ``total_width`` is a
            float32 tensor of shape ``[1]`` (0.0 when ``qW_probs`` is empty).
        """
        (
            _output_state,
            _embeddings,
            _log_p_theta,
            _kld_gamma,
            qW_probs,
            _n_obs,
            _variational_Ws,
            _forward_delta,
        ) = outputs

        total_width = float(sum(q.shape[0] for q in qW_probs))
        # Explicit float32 so the result does not depend on the default dtype.
        return torch.tensor([total_width], dtype=torch.float32), targets

    def compute_metric(
        self, targets: torch.Tensor, predictions: torch.Tensor
    ) -> torch.Tensor:
        """Return the mean over all reported widths."""
        return predictions.mean()


class ForwardTime(Metric):
    """Reports the ``forward_delta`` value returned by the model."""

    @property
    def name(self) -> str:
        return "Forward Time"

    def get_predictions_and_targets(
        self, targets: torch.Tensor, *outputs: List[torch.Tensor]
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Wrap ``forward_delta`` into a 1-element tensor.

        Args:
            targets (:class:`torch.Tensor`): ground truth (passed through)
            outputs (List[:class:`torch.Tensor`]): the model outputs
                ``(output_state, embeddings, log_p_theta, kld_gamma,
                qW_probs, n_obs, variational_Ws, forward_delta)``;
                ``forward_delta`` is presumably the forward-pass duration as
                a Python number — TODO confirm units

        Returns:
            A tuple ``(forward_delta_tensor, targets)``.
        """
        (
            _output_state,
            _embeddings,
            _log_p_theta,
            _kld_gamma,
            _qW_probs,
            _n_obs,
            _variational_Ws,
            forward_delta,
        ) = outputs

        return torch.tensor([forward_delta]), targets

    def compute_metric(
        self, targets: torch.Tensor, predictions: torch.Tensor
    ) -> torch.Tensor:
        """Return the mean over all reported forward times."""
        return predictions.mean()


class MachineTranslationMulticlassClassification(MulticlassClassification):
    """Multiclass classification metric applied to flattened sequence outputs."""

    def get_predictions_and_targets(
        self, targets: torch.Tensor, *outputs: List[torch.Tensor]
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Flatten ``outputs[0]`` and ``targets`` so the entries line up 1:1.

        All leading dimensions of ``outputs[0]`` are merged into one, giving
        a 2-D ``(N, C)`` tensor where ``C = outputs[0].shape[-1]`` (presumably
        the class scores — TODO confirm). ``targets`` is flattened into a
        1-D vector. The remaining outputs are ignored.

        Args:
            targets (:class:`torch.Tensor`): ground truth
            outputs (List[:class:`torch.Tensor`]): outputs of the model

        Returns:
            A tuple of tensors (predicted_values, target_values)
        """
        first_output = outputs[0]

        # reshape is equivalent to contiguous().view(): it returns a view
        # when possible and copies otherwise.
        flat_predictions = first_output.reshape(-1, first_output.shape[-1])
        flat_targets = targets.reshape(-1)

        return flat_predictions, flat_targets


# class BLEU(Metric):
#
#     def get_predictions_and_targets(
#         self, targets: torch.Tensor, *outputs: List[torch.Tensor]
#     ) -> Tuple[torch.Tensor, torch.Tensor]:
#         """
#         Returns output[0] as predictions and dataset targets.
#         Squeezes the first dimension of output and targets to get
#         single vector.
#
#         Args:
#             targets (:class:`torch.Tensor`): ground truth
#             outputs (List[:class:`torch.Tensor`]): outputs of the model
#
#         Returns:
#             A tuple of tensors (predicted_values, target_values)
#         """
#         outputs = outputs[0]
#
#         output_reshape = outputs.contiguous().view(-1, outputs.shape[-1])
#
#         targets = targets.contiguous().view(-1)
#
#         return output_reshape, targets
