import torch
from torch import Tensor
import torch.nn as nn
from torchmetrics import Metric, MeanSquaredError, MetricCollection
import time
import wandb
from src.metrics.abstract_metrics import SumExceptBatchMetric, SumExceptBatchMSE, SumExceptBatchKL, CrossEntropyMetric, \
    ProbabilityMetric, NLL, MSEMetric, dtime_loss_BFN


class NodeMSE(MeanSquaredError):
    """ Mean squared error for node features.

    Adds no behaviour on top of ``MeanSquaredError``; presumably kept as a
    separate class so node metrics are distinguishable by type/name --
    TODO confirm against callers. """
    def __init__(self, *args, **kwargs):
        # Forward keyword arguments as well, so base-class options can be set by name.
        super().__init__(*args, **kwargs)


class EdgeMSE(MeanSquaredError):
    """ Mean squared error for edge features.

    Adds no behaviour on top of ``MeanSquaredError``; presumably kept as a
    separate class so edge metrics are distinguishable by type/name --
    TODO confirm against callers. """
    def __init__(self, *args, **kwargs):
        # Forward keyword arguments as well, so base-class options can be set by name.
        super().__init__(*args, **kwargs)
        


class TrainLoss(nn.Module):
    """ Train with MSE: the batch loss is node MSE + edge MSE + y MSE.

    Each term is tracked by its own metric object, which is also queried for
    the values that get logged. """
    def __init__(self):
        super(TrainLoss, self).__init__()
        self.train_node_mse = NodeMSE()
        self.train_edge_mse = EdgeMSE()
        self.train_y_mse = MeanSquaredError()

    @staticmethod
    def _running_value(metric):
        """ Return ``metric.compute()``, or -1 when the metric has not seen any data yet. """
        return metric.compute() if metric.total > 0 else -1

    def forward(self, masked_pred_epsX, masked_pred_epsE, pred_y, true_epsX, true_epsE, true_y, log: bool):
        """ Compute the batch MSE and optionally log it.
        masked_pred_epsX / true_epsX : predicted / target node tensors
        masked_pred_epsE / true_epsE : predicted / target edge tensors
        pred_y / true_y : predicted / target global tensors
        log : boolean. When True, the batch value and the metrics' current values
              are sent to wandb (only if a wandb run is active).
        Returns the sum of the three terms. A term whose target is empty
        contributes 0.0, so the result is a plain float when all targets are empty. """
        mse_X = self.train_node_mse(masked_pred_epsX, true_epsX) if true_epsX.numel() > 0 else 0.0
        mse_E = self.train_edge_mse(masked_pred_epsE, true_epsE) if true_epsE.numel() > 0 else 0.0
        mse_y = self.train_y_mse(pred_y, true_y) if true_y.numel() > 0 else 0.0
        mse = mse_X + mse_E + mse_y

        if log:
            # mse is a python float when every target was empty; only tensors can be detached.
            batch_mse = mse.detach() if torch.is_tensor(mse) else mse
            # Same "-1 when nothing was accumulated" convention as log_epoch_metrics,
            # so compute() is never called on a metric that was not updated.
            to_log = {'train_loss/batch_mse': batch_mse,
                      'train_loss/node_MSE': self._running_value(self.train_node_mse),
                      'train_loss/edge_MSE': self._running_value(self.train_edge_mse),
                      'train_loss/y_mse': self._running_value(self.train_y_mse)}
            if wandb.run:
                wandb.log(to_log, commit=True)

        return mse

    def reset(self):
        """ Reset the three underlying metrics. """
        for metric in (self.train_node_mse, self.train_edge_mse, self.train_y_mse):
            metric.reset()

    def log_epoch_metrics(self):
        """ Log (if a wandb run is active) and return the current value of each metric.
        A metric that has not been updated is reported as -1. """
        to_log = {"train_epoch/epoch_X_mse": self._running_value(self.train_node_mse),
                  "train_epoch/epoch_E_mse": self._running_value(self.train_edge_mse),
                  "train_epoch/epoch_y_mse": self._running_value(self.train_y_mse)}
        if wandb.run:
            wandb.log(to_log)
        return to_log



class TrainLossDiscrete(nn.Module):
    """ Train with Cross entropy.

    Node, edge and y terms are each tracked by their own CrossEntropyMetric;
    forward returns their lambda_train-weighted sum for the current batch. """
    def __init__(self, lambda_train):
        """ lambda_train : indexable holding the weights applied to the
        node (index 0), edge (index 1) and y (index 2) losses. """
        super().__init__()
        self.node_loss = CrossEntropyMetric()
        self.edge_loss = CrossEntropyMetric()
        self.y_loss = CrossEntropyMetric()
        self.lambda_train = lambda_train

    @staticmethod
    def _flatten_valid(true, *aligned):
        """ Flatten `true` and every tensor in `aligned` to rows over their last
        dimension and keep only the rows where `true` is not all-zero (unmasked). """
        true = torch.reshape(true, (-1, true.size(-1)))
        keep = (true != 0.).any(dim=-1)
        flat = [true[keep, :]]
        for tensor in aligned:
            flat.append(torch.reshape(tensor, (-1, tensor.size(-1)))[keep, :])
        return tuple(flat)

    def forward(self, masked_pred_X, masked_pred_E, pred_y, true_X, true_E, true_y, log: bool):
        """ Compute train metrics
        masked_pred_X : tensor -- (bs, n, dx)
        masked_pred_E : tensor -- (bs, n, n, de)
        pred_y : tensor -- (bs, )
        true_X : tensor -- (bs, n, dx)
        true_E : tensor -- (bs, n, n, de)
        true_y : tensor -- (bs, )
        log : boolean. When True, batch and running values are sent to wandb
              (only if a wandb run is active).
        Returns the lambda_train-weighted sum of the three losses. A component
        without any unmasked data contributes 0.0. """
        # Remove masked rows: (bs * n, dx) / (bs * n * n, de) restricted to non-zero targets
        flat_true_X, flat_pred_X = self._flatten_valid(true_X, masked_pred_X)
        flat_true_E, flat_pred_E = self._flatten_valid(true_E, masked_pred_E)

        # Test emptiness AFTER masking: a batch whose rows are all masked would
        # otherwise call the metric with zero rows (assumed to yield a NaN batch
        # value for a sum / count metric -- CrossEntropyMetric is defined elsewhere).
        has_X = flat_true_X.numel() > 0
        has_E = flat_true_E.numel() > 0
        has_y = true_y.numel() > 0

        loss_X = self.node_loss(flat_pred_X, flat_true_X) if has_X else 0.0
        loss_E = self.edge_loss(flat_pred_E, flat_true_E) if has_E else 0.0
        loss_y = self.y_loss(pred_y, true_y) if has_y else 0.0

        if log:
            batch_ce = loss_X + loss_E + loss_y
            # The sum is a python float when no component had data; only tensors can be detached.
            to_log = {"train_loss/batch_CE": batch_ce.detach() if torch.is_tensor(batch_ce) else batch_ce,
                      "train_loss/X_CE": self.node_loss.compute() if has_X else -1,
                      "train_loss/E_CE": self.edge_loss.compute() if has_E else -1,
                      "train_loss/y_CE": self.y_loss.compute() if has_y else -1}
            if wandb.run:
                wandb.log(to_log, commit=True)
        return self.lambda_train[0] * loss_X + self.lambda_train[1] * loss_E + self.lambda_train[2] * loss_y

    def reset(self):
        """ Reset the three underlying metrics. """
        for metric in [self.node_loss, self.edge_loss, self.y_loss]:
            metric.reset()

    def log_epoch_metrics(self):
        """ Log (if a wandb run is active, without committing the step) and return
        the current value of each metric. A metric without samples is reported as -1. """
        epoch_node_loss = self.node_loss.compute() if self.node_loss.total_samples > 0 else -1
        epoch_edge_loss = self.edge_loss.compute() if self.edge_loss.total_samples > 0 else -1
        # The attribute is `y_loss`; there is no `train_y_loss` on this class.
        epoch_y_loss = self.y_loss.compute() if self.y_loss.total_samples > 0 else -1

        to_log = {"train_epoch/x_CE": epoch_node_loss,
                  "train_epoch/E_CE": epoch_edge_loss,
                  "train_epoch/y_CE": epoch_y_loss}
        if wandb.run:
            wandb.log(to_log, commit=False)

        return to_log




class TrainLossBFN(nn.Module):
    """ Train loss for the BFN model.

    Node and edge terms use MSEMetric when continuous_t is True and
    dtime_loss_BFN(sample_steps) otherwise; the y term uses CrossEntropyMetric.
    forward returns the lambda_train-weighted sum together with a dict of
    values to log. """
    def __init__(self, lambda_train, sample_steps: int, continuous_t: bool=True):
        """ lambda_train : indexable holding the weights applied to the
                           node (index 0), edge (index 1) and y (index 2) losses.
        sample_steps : passed to dtime_loss_BFN; only used when continuous_t is False.
        continuous_t : selects the node / edge loss (see class docstring). """
        super().__init__()
        self.node_loss = MSEMetric() if continuous_t else dtime_loss_BFN(sample_steps)
        self.edge_loss = MSEMetric() if continuous_t else dtime_loss_BFN(sample_steps)
        self.continuous_t = continuous_t
        self.y_loss = CrossEntropyMetric() #TODO check whether this should be the mse metric
        self.lambda_train = lambda_train

    @staticmethod
    def _flatten_valid(true, *aligned):
        """ Flatten `true` and every tensor in `aligned` to rows over their last
        dimension and keep only the rows where `true` is not all-zero (unmasked). """
        true = torch.reshape(true, (-1, true.size(-1)))
        keep = (true != 0.).any(dim=-1)
        flat = [true[keep, :]]
        for tensor in aligned:
            flat.append(torch.reshape(tensor, (-1, tensor.size(-1)))[keep, :])
        return tuple(flat)

    def forward(self, masked_pred_X, masked_pred_E, pred_y, true_X, true_E, true_y, weight_x, weight_e, log: bool):
        """ Compute train metrics
        masked_pred_X : tensor -- (bs, n, dx)
        masked_pred_E : tensor -- (bs, n, n, de)
        pred_y : tensor -- (bs, )
        true_X : tensor -- (bs, n, dx)
        true_E : tensor -- (bs, n, n, de)
        weight_x : tensor -- (bs, n, 1)
        weight_e : tensor -- (bs, n, n, 1)
        true_y : tensor -- (bs, )
        log : boolean. Currently unused: the dict of values is always built and
              returned, and nothing is sent to wandb from here.
        Returns (loss, to_log): the lambda_train-weighted sum of the three losses
        and the dict of batch / running values. A component without any unmasked
        data contributes 0.0 and is reported as -1. """
        # Remove masked rows; the weights are filtered with the same row mask as their targets.
        flat_true_X, flat_pred_X, flat_weight_x = self._flatten_valid(true_X, masked_pred_X, weight_x)
        flat_true_E, flat_pred_E, flat_weight_e = self._flatten_valid(true_E, masked_pred_E, weight_e)

        # Test emptiness AFTER masking so the metrics are never called with zero rows
        # (assumed to produce a NaN batch value for a sum / count metric -- the
        # metric classes are defined elsewhere).
        has_X = flat_true_X.numel() > 0
        has_E = flat_true_E.numel() > 0
        has_y = true_y.numel() > 0

        loss_X = self.node_loss(flat_pred_X, flat_true_X, flat_weight_x) if has_X else 0.0
        loss_E = self.edge_loss(flat_pred_E, flat_true_E, flat_weight_e) if has_E else 0.0
        loss_y = self.y_loss(pred_y, true_y) if has_y else 0.0

        # NOTE(review): no extra rescaling (multiply by the number of steps, normalise
        # by batch size) is applied for the discrete-time case; presumably
        # dtime_loss_BFN takes care of it -- TODO confirm.

        batch_loss = loss_X + loss_E + loss_y
        # The sum is a python float when no component had data; only tensors can be detached.
        to_log = {"train_loss/batch_MSE": batch_loss.detach() if torch.is_tensor(batch_loss) else batch_loss,
                  "train_loss/X_MSE": self.node_loss.compute() if has_X else -1,
                  "train_loss/E_MSE": self.edge_loss.compute() if has_E else -1,
                  "train_loss/y_CE": self.y_loss.compute() if has_y else -1}
        return self.lambda_train[0] * loss_X + self.lambda_train[1] * loss_E + self.lambda_train[2] * loss_y, to_log

    def reset(self):
        """ Reset the three underlying metrics. """
        for metric in [self.node_loss, self.edge_loss, self.y_loss]:
            metric.reset()

    def log_epoch_metrics(self):
        """ Return the current value of each metric; a metric without samples is
        reported as -1. Nothing is sent to wandb from here. """
        epoch_node_loss = self.node_loss.compute() if self.node_loss.total_samples > 0 else -1
        epoch_edge_loss = self.edge_loss.compute() if self.edge_loss.total_samples > 0 else -1
        # The attribute is `y_loss`; there is no `train_y_loss` on this class.
        epoch_y_loss = self.y_loss.compute() if self.y_loss.total_samples > 0 else -1

        to_log = {"train_epoch/X_MSE": epoch_node_loss,
                  "train_epoch/E_MSE": epoch_edge_loss,
                  "train_epoch/y_CE": epoch_y_loss}

        return to_log



