import torch
from torch import Tensor
from torch.nn import functional as F
from torchmetrics import Metric, MeanSquaredError
import numpy as np

# Natural log of 2π, the constant term of a Gaussian log-density.
LOG2PI = np.log(np.pi * 2)

class TrainAbstractMetricsDiscrete(torch.nn.Module):
    """Placeholder training-metrics module for the discrete setting.

    Every hook is a stub: ``forward`` and ``reset`` ignore their inputs and
    return ``None``; ``log_epoch_metrics`` reports a pair of ``None`` values.
    """

    def __init__(self):
        super().__init__()

    def forward(self, masked_pred_X, masked_pred_E, true_X, true_E, log: bool):
        """Ignore the given predictions/targets; always returns ``None``."""
        return None

    def reset(self):
        """No state is held, so there is nothing to reset."""
        return None

    def log_epoch_metrics(self):
        """Return ``(None, None)``: this placeholder tracks no metrics."""
        return None, None


class TrainAbstractMetrics(torch.nn.Module):
    """Placeholder training-metrics module for the noise-prediction setting.

    Every hook is a stub: ``forward`` and ``reset`` ignore their inputs and
    return ``None``; ``log_epoch_metrics`` reports a pair of ``None`` values.
    """

    def __init__(self):
        super().__init__()

    def forward(self, masked_pred_epsX, masked_pred_epsE, pred_y, true_epsX, true_epsE, true_y, log):
        """Ignore the given predictions/targets; always returns ``None``."""
        return None

    def reset(self):
        """No state is held, so there is nothing to reset."""
        return None

    def log_epoch_metrics(self):
        """Return ``(None, None)``: this placeholder tracks no metrics."""
        return None, None


class SumExceptBatchMetric(Metric):
    """Running total of ``values`` divided by the number of leading-dim entries.

    Each ``update`` adds the sum over *all* entries of ``values`` and counts
    ``values.shape[0]`` samples; ``compute`` returns total / count.
    """

    def __init__(self):
        super().__init__()
        # Both states are float scalars, summed across processes when synced.
        for state_name in ('total_value', 'total_samples'):
            self.add_state(state_name, default=torch.tensor(0.), dist_reduce_fx="sum")

    def update(self, values) -> None:
        """Accumulate the grand total of ``values`` and its leading-dim size."""
        batch_total = torch.sum(values)
        batch_size = values.shape[0]
        self.total_value += batch_total
        self.total_samples += batch_size

    def compute(self):
        """Return accumulated total divided by accumulated sample count."""
        numerator, denominator = self.total_value, self.total_samples
        return numerator / denominator


class SumExceptBatchMSE(MeanSquaredError):
    """Squared error summed over all elements, counted per leading-dim entry.

    ``update`` adds the full sum of squared residuals to the inherited
    ``sum_squared_error`` state but increments ``total`` only by
    ``preds.shape[0]``. NOTE(review): the inherited ``MeanSquaredError.compute``
    presumably divides the two, giving a per-sample (not per-element) error.
    """

    def update(self, preds: Tensor, target: Tensor) -> None:
        """Add one batch of predictions and ground-truth values.

        Args:
            preds: model predictions; must have exactly ``target``'s shape.
            target: ground-truth values.
        """
        assert preds.shape == target.shape
        batch_sse, batch_count = self._mean_squared_error_update(preds, target)
        self.sum_squared_error += batch_sse
        self.total += batch_count

    def _mean_squared_error_update(self, preds: Tensor, target: Tensor):
        """Return ``(sum of squared residuals, preds.shape[0])``.

        No shape validation happens here; ``update`` asserts matching shapes
        before calling this helper.
        """
        residual = preds - target
        return torch.sum(residual * residual), preds.shape[0]


class SumExceptBatchKL(Metric):
    """Running KL divergence summed over all entries, averaged over ``p.size(0)``.

    ``F.kl_div`` treats its first argument as log-probabilities, so ``q`` is
    expected in log space. Presumably it is the log of the predicted
    distribution; confirm at the call site. ``p`` is expected as plain
    probabilities.
    """

    def __init__(self):
        super().__init__()
        for state_name in ('total_value', 'total_samples'):
            self.add_state(state_name, default=torch.tensor(0.), dist_reduce_fx="sum")

    def update(self, p, q) -> None:
        """Add ``sum(KL(p || exp(q)))`` for the batch and count ``p.size(0)`` samples."""
        kl_sum = F.kl_div(q, p, reduction='sum')
        self.total_value += kl_sum
        self.total_samples += p.size(0)

    def compute(self):
        """Return accumulated KL divided by accumulated sample count."""
        numerator, denominator = self.total_value, self.total_samples
        return numerator / denominator


class CrossEntropyMetric(Metric):
    """Running cross-entropy between predictions and target class vectors, per row."""

    def __init__(self):
        super().__init__()
        for state_name in ('total_ce', 'total_samples'):
            self.add_state(state_name, default=torch.tensor(0.), dist_reduce_fx="sum")

    def update(self, preds: Tensor, target: Tensor) -> None:
        """Add the summed cross-entropy of one batch.

        Args:
            preds: unnormalized scores as consumed by ``F.cross_entropy``,
                shape (bs * n, d) or (bs * n * n, d).
            target: per-row class vectors of the same shape; the argmax over
                the last dimension is used as the class index.
        """
        class_index = target.argmax(dim=-1)
        self.total_ce += F.cross_entropy(preds, class_index, reduction='sum')
        self.total_samples += preds.size(0)

    def compute(self):
        """Return accumulated cross-entropy divided by the number of rows seen."""
        numerator, denominator = self.total_ce, self.total_samples
        return numerator / denominator

class MSEMetric(Metric):
    """Running weighted squared error summed over all entries, averaged per row."""

    def __init__(self):
        super().__init__()
        for state_name in ('total_mse', 'total_samples'):
            self.add_state(state_name, default=torch.tensor(0.), dist_reduce_fx="sum")

    def update(self, preds: Tensor, target: Tensor, weight: Tensor) -> None:
        """Add the weighted squared error of one batch.

        Args:
            preds: model predictions, shape (bs * n, d) or (bs * n * n, d).
            target: ground-truth values, broadcast-compatible with ``preds``.
            weight: time-related weight multiplied elementwise into the squared
                error; must broadcast against ``preds``.
        """
        squared_error = (preds - target) ** 2
        self.total_mse += torch.sum(squared_error * weight)
        self.total_samples += preds.size(0)

    def compute(self):
        """Return accumulated weighted squared error divided by rows seen."""
        numerator, denominator = self.total_mse, self.total_samples
        return numerator / denominator

class dtime_loss_BFN(Metric):
    """Discrete-time Bayesian Flow Network loss for categorical variables.

    Each ``update`` draws one noisy sender sample ``y`` per variable and
    accumulates ``N * -log sum_k e_hat_k * N(y | alpha (K e_k - 1), alpha K I)``.
    ``compute`` returns the accumulated loss divided by the number of
    variables seen. The structure follows Algorithm 7 of the BFN paper
    (discrete-time loss, discrete data).

    NOTE(review): Algorithm 7 also contains the sender term
    ``N * log N(y | alpha (K e_x - 1), alpha K I)``. It is not included here.
    It does not depend on ``e_hat``, so gradients are unaffected, but the
    reported value is not the KL itself. Confirm this is intended.

    Args:
        N: number of discrete time steps ``n``; multiplies every per-variable loss.
    """

    def __init__(self, N: int):
        super().__init__()
        self.add_state('total_dtime_loss', default=torch.tensor(0.), dist_reduce_fx="sum")
        self.add_state('total_samples', default=torch.tensor(0.), dist_reduce_fx="sum")
        self.N = N

    @staticmethod
    def _log_with_finite_grad(probs: Tensor) -> Tensor:
        """Elementwise ``log`` that maps exact zeros to ``-inf`` without NaN gradients.

        ``torch.log`` at 0 back-propagates ``grad / 0``. Even with the zero
        upstream gradient that ``logsumexp`` assigns to a ``-inf`` entry, this
        gives ``0 / 0 = NaN``. Taking the log of a zero-free copy and then
        restoring ``-inf`` leaves the forward value identical and gives those
        entries a zero gradient. Non-zero entries (including invalid negatives
        or NaNs) go through ``torch.log`` unchanged.
        """
        nonzero = probs != 0
        zero_free = torch.where(nonzero, probs, torch.ones_like(probs))
        return torch.where(nonzero, torch.log(zero_free), torch.full_like(probs, float('-inf')))

    def update(self, e_hat: Tensor, e_x: Tensor, alpha: Tensor) -> None:
        """Accumulate the discrete-time loss for one batch of D variables.

        Args:
            e_hat: per-class mixture weights, shape (D, K). These are
                presumably the predicted class probabilities; confirm at the
                caller.
            e_x: target encoding with the same shape as ``e_hat`` (asserted).
                It is presumably one-hot, as the ``K * e_x - 1`` mean implies.
            alpha: per-variable accuracy. It must broadcast against (D, K),
                i.e. shape (D, 1).
        """
        assert e_x.size() == e_hat.size()
        K = e_hat.size(-1)

        # Sender sample y ~ N(alpha (K e_x - 1), alpha K I).
        mean = alpha * (K * e_x - 1)                 # (D, K)
        std = torch.sqrt(alpha * K)                  # (D, 1), broadcasts to (D, K)
        y = mean + std * torch.randn_like(mean)      # (D, K)

        # Row k of (K * I - 1) equals K e_k - 1, the sender mean if the class were k.
        # Build the identity directly on the target device (no host-to-device copy).
        eye = torch.eye(K, device=e_x.device)                   # (K, K)
        class_means = alpha.unsqueeze(-1) * (K * eye - 1)       # (D, K, K)
        class_std = torch.sqrt(alpha * K).unsqueeze(-1)         # (D, 1, 1)

        # log N(y | alpha (K e_k - 1), alpha K I) for every candidate class k:
        # an isotropic Gaussian log-density summed over the K coordinates.
        log_gauss = (
            (-0.5 * LOG2PI - torch.log(class_std))
            - (y.unsqueeze(1) - class_means) ** 2 / (2 * class_std ** 2)
        ).sum(-1)                                               # (D, K)

        # log sum_k e_hat_k N(y | ...), evaluated stably in log space.
        log_weights = self._log_with_finite_grad(e_hat)         # (D, K)
        log_likelihood = torch.logsumexp(log_weights + log_gauss, dim=-1)  # (D,)
        loss = self.N * (-log_likelihood)                       # (D,)

        self.total_dtime_loss += torch.sum(loss)
        self.total_samples += e_hat.size(0)

    def compute(self):
        """Return the accumulated loss divided by the number of variables seen."""
        return self.total_dtime_loss / self.total_samples


class ProbabilityMetric(Metric):
    """Tracks the mean predicted probability of a class during training.

    ``compute`` returns the sum of every entry passed to ``update`` divided by
    the total number of entries.
    """

    def __init__(self):
        super().__init__()
        for state_name in ('prob', 'total'):
            self.add_state(state_name, default=torch.tensor(0.), dist_reduce_fx="sum")

    def update(self, preds: Tensor) -> None:
        """Add the sum and the element count of ``preds``."""
        batch_sum = preds.sum()
        batch_count = preds.numel()
        self.prob += batch_sum
        self.total += batch_count

    def compute(self):
        """Return the mean of all entries seen so far."""
        numerator, denominator = self.prob, self.total
        return numerator / denominator


class NLL(Metric):
    """Running mean of the entries of every ``batch_nll`` tensor passed to ``update``."""

    def __init__(self):
        super().__init__()
        for state_name in ('total_nll', 'total_samples'):
            self.add_state(state_name, default=torch.tensor(0.), dist_reduce_fx="sum")

    def update(self, batch_nll) -> None:
        """Add the sum and the element count of ``batch_nll``."""
        batch_sum = torch.sum(batch_nll)
        batch_count = batch_nll.numel()
        self.total_nll += batch_sum
        self.total_samples += batch_count

    def compute(self):
        """Return the mean of all NLL entries seen so far."""
        numerator, denominator = self.total_nll, self.total_samples
        return numerator / denominator