import torch
import torch.nn as nn
import time
import wandb
import numpy as np

from torchmetrics import MeanSquaredError
from didigress.metrics.abstract_metrics import CrossEntropyMetric
from didigress.utils import real_s_to_scaled
from didigress.diffusion import diffusion_utils

class TrainLoss(nn.Module):
    """Weighted sum of the per-head training losses, with running metrics.

    Node and edge cross-entropies are always computed. Depending on the
    ``cfg.features`` flags, the charges head (cross-entropy), the 3D position
    head (MSE) and the insertion/deletion heads (insert time ``s`` and
    ``delt``, both cross-entropy) are added on top.

    ``lambda_train`` is indexed as: ``[0]`` nodes, ``[1]`` edges, ``[2]`` y
    (only used when logging; the y loss is never added to the batch loss in
    this class), ``[3]`` charges, ``[4]`` positions.
    """

    def __init__(self, lambda_train, cfg, name="train"):
        """Create the metric objects required by the enabled heads.

        Args:
            lambda_train: indexable collection of per-head loss weights.
            cfg: config object; reads ``cfg.features.{use_3d, use_charges,
                use_ins_del}`` and, when insertion/deletion is enabled,
                ``cfg.model.diffusion_steps`` and
                ``cfg.features.{s_loss_lambda, delt_loss_lambda, zeta_D, zeta_w}``.
            name: prefix used for every logged key.
        """
        super().__init__()

        self.use_3d = cfg.features.use_3d
        self.use_charges = cfg.features.use_charges
        self.use_ins_del = cfg.features.use_ins_del
        self.name = name

        self.node_loss = CrossEntropyMetric()
        self.edge_loss = CrossEntropyMetric()
        self.y_loss = CrossEntropyMetric()

        if self.use_ins_del:
            self.n_classes = cfg.model.diffusion_steps
            self.s_loss = CrossEntropyMetric()
            self.s_loss_lambda = cfg.features.s_loss_lambda
            self.T = cfg.model.diffusion_steps

            _, dz = diffusion_utils.compute_linear_zeta(T=cfg.model.diffusion_steps,
                                                        D=cfg.features.zeta_D,
                                                        w=cfg.features.zeta_w)

            # The timestep window used by the s/delt losses is bounded by the
            # closest zeros of dz on either side of its maximum.
            zero_indices = np.where(dz == 0)[0]
            peak = np.argmax(dz)
            self.left_interval_end = zero_indices[zero_indices < peak].max()
            self.right_interval_start = zero_indices[zero_indices > peak].min()

            print("self.left_interval_end: ", self.left_interval_end, ", ",
                  "self.right_interval_start: ", self.right_interval_start)

            self.delt_loss = CrossEntropyMetric()
            self.delt_loss_lambda = cfg.features.delt_loss_lambda
        if self.use_3d:
            self.train_pos_mse = MeanSquaredError(sync_on_compute=False, dist_sync_on_step=False)
        if self.use_charges:
            self.charges_loss = CrossEntropyMetric()

        self.lambda_train = lambda_train

    @staticmethod
    def _detached(value):
        """Detach ``value`` when it is a tensor; return plain numbers unchanged."""
        return value.detach() if torch.is_tensor(value) else value

    def forward(self, masked_pred, masked_true, log: bool,
                masked_pred_delt=None, masked_true_delt=None):
        """Compute the weighted batch loss and optionally log it.

        Only entries selected by ``masked_true.node_mask`` (and, for edges,
        off-diagonal pairs of selected nodes) contribute to the loss.

        Args:
            masked_pred: placeholder holding the predictions (``X``, ``E`` and,
                depending on the enabled heads, ``insert_time``, ``charges``,
                ``pos``).
            masked_true: placeholder holding the targets plus ``node_mask``,
                ``t_int`` and ``y``.
            log: when True, build the logging dict (and send it to wandb if a
                run is active).
            masked_pred_delt: object whose ``.y`` holds the delt predictions;
                required when ``use_ins_del`` is enabled.
            masked_true_delt: delt targets; required when ``use_ins_del`` is
                enabled.

        Returns:
            ``(batch_loss, to_log)`` where ``to_log`` is ``None`` if ``log`` is
            False. In the logged dict, ``-1`` marks a head that had no samples
            in this batch.

        Raises:
            ValueError: if ``use_ins_del`` is enabled and the delt inputs are
                missing.
        """
        node_mask = masked_true.node_mask
        bs, n = node_mask.shape

        true_X = masked_true.X[node_mask]
        masked_pred_X = masked_pred.X[node_mask]

        # Edges are kept only between two valid nodes, excluding self-loops.
        diag_mask = ~torch.eye(n, device=node_mask.device, dtype=torch.bool).unsqueeze(0).repeat(bs, 1, 1)
        edge_mask = diag_mask & node_mask.unsqueeze(-1) & node_mask.unsqueeze(-2)
        masked_pred_E = masked_pred.E[edge_mask]
        true_E = masked_true.E[edge_mask]

        # Check that the masking is correct: every selected target row must
        # have at least one non-zero entry.
        assert (true_X != 0.).any(dim=-1).all()
        assert (true_E != 0.).any(dim=-1).all()

        loss_X = self.node_loss(masked_pred_X, true_X) if true_X.numel() > 0 else 0.0
        loss_E = self.edge_loss(masked_pred_E, true_E) if true_E.numel() > 0 else 0.0

        summed_loss = self.lambda_train[0] * loss_X + \
                      self.lambda_train[1] * loss_E

        has_s = False
        has_delt = False
        if self.use_ins_del:
            if masked_pred_delt is None or masked_true_delt is None:
                raise ValueError("masked_pred_delt and masked_true_delt are required when use_ins_del is enabled")

            # Insert-time loss: valid nodes of graphs whose timestep lies
            # before the right end of the window.
            s_mask = node_mask & (masked_true.t_int < self.right_interval_start).expand((-1, node_mask.size(-1)))
            true_s = masked_true.insert_time[s_mask]
            has_s = true_s.numel() > 0
            # Skip the metric entirely on an empty selection so that zero
            # samples are never fed to it (presumably that would produce a NaN
            # batch value -- TODO confirm against CrossEntropyMetric).
            if has_s:
                true_s_onehot = nn.functional.one_hot(true_s.long(), num_classes=self.n_classes).squeeze(1)
                masked_pred_s = masked_pred.insert_time[s_mask]
                loss_s = self.s_loss(masked_pred_s, true_s_onehot)
                summed_loss = summed_loss + self.s_loss_lambda * loss_s

            # delt loss: graphs whose timestep lies strictly inside the window.
            valid_delt_mask = (masked_true.t_int > self.left_interval_end) & \
                              (masked_true.t_int < self.right_interval_start)
            valid_delt_mask = valid_delt_mask.squeeze(-1)
            delt_output = masked_pred_delt.y[valid_delt_mask]
            delt_target = masked_true_delt[valid_delt_mask]
            # The emptiness check must look at the filtered targets, not at
            # the full batch.
            has_delt = delt_target.numel() > 0
            if has_delt:
                loss_delt = self.delt_loss(delt_output, delt_target)
                summed_loss = summed_loss + self.delt_loss_lambda * loss_delt

        if self.use_charges:
            true_charges = masked_true.charges[node_mask]
            masked_pred_charges = masked_pred.charges[node_mask]

            # Check that the masking is correct
            assert (true_charges != 0.).any(dim=-1).all()

            loss_charges = self.charges_loss(masked_pred_charges, true_charges) if true_charges.numel() > 0 else 0.0
            summed_loss = summed_loss + self.lambda_train[3] * loss_charges

        if self.use_3d:
            true_pos = masked_true.pos[node_mask]
            masked_pred_pos = masked_pred.pos[node_mask]

            loss_pos = self.train_pos_mse(masked_pred_pos, true_pos) if true_pos.numel() > 0 else 0.0
            summed_loss = summed_loss + self.lambda_train[4] * loss_pos

        #TODO: fix everywhere the lambda_train order.
        batch_loss = summed_loss

        if log:
            to_log = {f"{self.name}_loss/X_CE": self.lambda_train[0] * self.node_loss.compute() if true_X.numel() > 0 else -1,
                      f"{self.name}_loss/E_CE": self.lambda_train[1] * self.edge_loss.compute() if true_E.numel() > 0 else -1.0,
                      f"{self.name}_loss/y_CE": self.lambda_train[2] * self.y_loss.compute() if masked_true.y.numel() > 0 else -1.0,
                      # batch_loss is a plain float when no head had samples.
                      f"{self.name}_loss/batch_loss": self._detached(batch_loss)}

            if self.use_ins_del:
                to_log.update({f"{self.name}_loss/s_CE": self.s_loss_lambda * self.s_loss.compute() if has_s else -1})
                to_log.update({f"{self.name}_loss/delt_CE": self.delt_loss_lambda * self.delt_loss.compute() if has_delt else -1})
            if self.use_charges:
                to_log.update({f"{self.name}_loss/charges_CE": self.lambda_train[3] * self.charges_loss.compute() if true_charges.numel() > 0 else -1})
            if self.use_3d:
                to_log.update({f"{self.name}_loss/pos_mse": self.lambda_train[4] * self.train_pos_mse.compute() if true_pos.numel() > 0 else -1})
        else:
            to_log = None

        if log and wandb.run:
            wandb.log(to_log, commit=True)
        return batch_loss, to_log

    def reset(self):
        """Reset every metric object owned by the enabled heads."""
        metrics = [self.node_loss, self.edge_loss, self.y_loss]
        if self.use_charges:
            metrics.append(self.charges_loss)
        if self.use_3d:
            metrics.append(self.train_pos_mse)
        if self.use_ins_del:
            metrics.extend([self.delt_loss, self.s_loss])

        for metric in metrics:
            metric.reset()

    def log_epoch_metrics(self):
        """Compute the accumulated per-head metrics and log them.

        The epoch values are unweighted (``lambda_train`` is not applied). A
        head that accumulated no samples is reported as ``-1.0``; that sentinel
        is also what gets added into ``overall_loss`` for such a head.

        Returns:
            The dict of logged values. It is also sent to wandb (with
            ``commit=False``) when a run is active.
        """
        epoch_node_loss = self.node_loss.compute() if self.node_loss.total_samples > 0 else -1.0
        epoch_edge_loss = self.edge_loss.compute() if self.edge_loss.total_samples > 0 else -1.0
        epoch_y_loss = self.y_loss.compute() if self.y_loss.total_samples > 0 else -1.0

        overall_loss = epoch_node_loss + epoch_edge_loss

        to_log = {f"{self.name}_epoch/x_CE": epoch_node_loss,
                  f"{self.name}_epoch/E_CE": epoch_edge_loss,
                  f"{self.name}_epoch/y_CE": epoch_y_loss}

        if self.use_charges:
            # Guard on the sample count, like the other cross-entropy metrics.
            epoch_charges_loss = self.charges_loss.compute() if self.charges_loss.total_samples > 0 else -1.0
            to_log.update({f"{self.name}_epoch/charges_CE": epoch_charges_loss})
            overall_loss += epoch_charges_loss
        if self.use_3d:
            epoch_pos_loss = self.train_pos_mse.compute() if self.train_pos_mse.total > 0 else -1.0
            to_log.update({f"{self.name}_epoch/pos_mse": epoch_pos_loss})
            overall_loss += epoch_pos_loss
        if self.use_ins_del:
            epoch_ins_del_loss_s = self.s_loss.compute() if self.s_loss.total_samples > 0 else -1.0
            to_log.update({f"{self.name}_epoch/s_CE": epoch_ins_del_loss_s})
            epoch_ins_del_loss_delt = self.delt_loss.compute() if self.delt_loss.total_samples > 0 else -1.0
            to_log.update({f"{self.name}_epoch/delt_CE": epoch_ins_del_loss_delt})
            overall_loss += (epoch_ins_del_loss_s + epoch_ins_del_loss_delt)

        # overall_loss is a plain float when every head reported the sentinel.
        to_log.update({f"{self.name}_epoch/overall_loss": self._detached(overall_loss)})

        if wandb.run:
            wandb.log(to_log, commit=False)

        return to_log