import torch
import torch.nn as nn
from torch import Tensor
from torch.nn import functional as F
from torchmetrics import Metric
from torchmetrics import MeanSquaredError


class CrossEntropyMetric(Metric):
    """Running cross-entropy averaged over every row seen since the last reset.

    The metric accumulates the *summed* (optionally class-weighted)
    cross-entropy together with the number of rows passed to ``update``;
    ``compute`` returns their ratio. Note that the divisor is the row count,
    not the sum of the class weights.

    Args:
        cls_weight: Optional per-class rescaling weights forwarded to
            ``F.cross_entropy`` as ``weight``. May be a sequence of numbers or
            a tensor; ``None`` disables weighting.
    """

    def __init__(self, cls_weight = None):
        super().__init__()
        if cls_weight is not None:
            if isinstance(cls_weight, Tensor):
                # Copy explicitly: torch.tensor(tensor) emits a UserWarning.
                cls_weight = cls_weight.detach().clone()
            else:
                cls_weight = torch.tensor(cls_weight)
            if not cls_weight.is_floating_point():
                # Integer weights (e.g. [1, 5]) would otherwise become a Long
                # tensor, which cross_entropy rejects as `weight`.
                cls_weight = cls_weight.to(torch.get_default_dtype())
        self.cls_weight = cls_weight
        self.add_state('total_ce', default=torch.tensor(0.), dist_reduce_fx="sum")
        self.add_state('total_samples', default=torch.tensor(0.), dist_reduce_fx="sum")

    def update(self, preds: Tensor, target: Tensor) -> None:
        """ Update state with predictions and targets.
            preds: Predictions from model   (bs * n, d) or (bs * n * n, d)
            target: Ground truth values     (bs * n, d) or (bs * n * n, d).

            ``target`` is reduced to class indices with an argmax over its
            last dimension before the cross-entropy is evaluated. """
        if self.cls_weight is not None:
            # The weights are a plain attribute (not a registered buffer), so
            # they must follow the predictions' device by hand.
            self.cls_weight = self.cls_weight.to(preds.device)
        class_index = torch.argmax(target, dim=-1)
        batch_ce = F.cross_entropy(preds, class_index, reduction='sum', weight=self.cls_weight)
        self.total_ce += batch_ce
        self.total_samples += preds.size(0)

    def compute(self):
        """Return accumulated cross-entropy divided by the accumulated row count.

        The result is NaN (0 / 0) when no rows have been accumulated.
        """
        return self.total_ce / self.total_samples


class TrainLoss(nn.Module):
    """ Train with Cross entropy

    Combines three running metrics into one weighted training loss:
    cross-entropy on node features, cross-entropy on edge features
    (strictly upper-triangular entries only) and mean squared error on ``y``.

    Args:
        lambda_train: Indexable of three weights applied, in order, to the
            node, edge and y losses.
        edge_cls_weight: Optional per-class weights for the edge
            cross-entropy; ``None`` disables weighting.
    """
    def __init__(self, lambda_train, edge_cls_weight = None):
        super().__init__()

        self.node_loss = CrossEntropyMetric()
        self.edge_loss = CrossEntropyMetric(edge_cls_weight)
        self.y_loss = MeanSquaredError()
        self.lambda_train = lambda_train

    @staticmethod
    def _compute_or_sentinel(metric, seen):
        """Return ``metric.compute()``, or -1 when ``seen`` (its sample count) is zero."""
        return metric.compute() if seen > 0 else -1

    def forward(self, masked_pred_X, masked_pred_E, pred_y, true_X, true_E, true_y):
        """Compute the weighted training loss for one batch.

        Node and edge tensors are flattened over all leading dimensions; rows
        whose target is all zeros are dropped, and for edges only entries
        strictly above the diagonal are kept.

        Args:
            masked_pred_X: Node predictions, last dim = node classes.
            masked_pred_E: Edge predictions, last dim = edge classes; the
                leading dims must be valid input for ``torch.triu``
                (expected (bs, n, n)).
            pred_y: Predictions compared to ``true_y`` with MSE.
            true_X: Node targets, same shape as ``masked_pred_X``.
            true_E: Edge targets, same shape as ``masked_pred_E``.
            true_y: Targets for ``pred_y``.

        Returns:
            ``(total_loss, to_log)`` where ``to_log`` maps metric names to the
            running metric values, or -1 when a value is unavailable.
        """
        true_X = torch.reshape(true_X, (-1, true_X.size(-1)))  # (bs * n, dx)
        masked_pred_X = torch.reshape(masked_pred_X, (-1, masked_pred_X.size(-1)))  # (bs * n, dx)
        # Remove masked rows
        mask_X = (true_X != 0.).any(dim=-1)
        flat_true_X = true_X[mask_X, :]
        flat_pred_X = masked_pred_X[mask_X, :]
        # Guard on the rows that survive masking: feeding an empty batch to
        # the metric would evaluate 0 / 0 and turn the total loss into NaN.
        loss_X = self.node_loss(flat_pred_X, flat_true_X) if flat_true_X.numel() > 0 else 0.0

        # DAG, upper triangular
        mask_E = torch.triu(torch.ones(masked_pred_E.shape[:-1], device = masked_pred_E.device), diagonal = 1).reshape(-1).bool()
        true_E = torch.reshape(true_E, (-1, true_E.size(-1)))  # (bs * n * n, de)
        masked_pred_E = torch.reshape(masked_pred_E, (-1, masked_pred_E.size(-1)))   # (bs * n * n, de)
        # Remove masked rows
        mask_E = mask_E & (true_E != 0.).any(dim=-1)
        flat_true_E = true_E[mask_E, :]
        flat_pred_E = masked_pred_E[mask_E, :]
        loss_E = self.edge_loss(flat_pred_E, flat_true_E) if flat_true_E.numel() > 0 else 0.0

        loss_y = self.y_loss(pred_y, true_y) if true_y.numel() > 0 else 0.0

        # Running values; -1 marks "not available" (empty input, or nothing
        # accumulated yet) instead of leaking NaN into the logs.
        to_log = {"train_loss/X_CE": self._compute_or_sentinel(self.node_loss, self.node_loss.total_samples) if true_X.numel() > 0 else -1,
                  "train_loss/E_CE": self._compute_or_sentinel(self.edge_loss, self.edge_loss.total_samples) if true_E.numel() > 0 else -1,
                  "train_loss/y_MSE": self._compute_or_sentinel(self.y_loss, self.y_loss.total) if true_y.numel() > 0 else -1}

        total_loss = self.lambda_train[0] * loss_X + self.lambda_train[1] * loss_E + self.lambda_train[2] * loss_y

        return total_loss, to_log

    def reset(self):
        """Reset the accumulated state of all three metrics."""
        for metric in [self.node_loss, self.edge_loss, self.y_loss]:
            metric.reset()

    def log_epoch_metrics(self):
        """Return the epoch-level metric values as a dict ready for logging.

        Each component is reported as -1 when its metric has accumulated no
        samples. Such components are left out of ``train_epoch/loss`` so the
        sentinel never distorts the weighted total.
        """
        has_node = self.node_loss.total_samples > 0
        has_edge = self.edge_loss.total_samples > 0
        has_y = self.y_loss.total > 0

        epoch_node_loss = self.node_loss.compute() if has_node else -1
        epoch_edge_loss = self.edge_loss.compute() if has_edge else -1
        epoch_y_loss = self.y_loss.compute() if has_y else -1

        # Weighted sum over the components that actually have data.
        epoch_loss = 0.0
        if has_node:
            epoch_loss = epoch_loss + self.lambda_train[0] * epoch_node_loss
        if has_edge:
            epoch_loss = epoch_loss + self.lambda_train[1] * epoch_edge_loss
        if has_y:
            epoch_loss = epoch_loss + self.lambda_train[2] * epoch_y_loss

        to_log = {"train_epoch/X_CE": epoch_node_loss,
                  "train_epoch/E_CE": epoch_edge_loss,
                  "train_epoch/y_MSE": epoch_y_loss,
                  "train_epoch/loss": epoch_loss}

        return to_log



