import torch
from torch import Tensor
from torch.nn import KLDivLoss
from torch.nn import functional as F
from torchmetrics import Metric, MeanSquaredError


class TrainAbstractMetricsDiscrete(torch.nn.Module):
    """Placeholder training-metrics module for discrete predictions.

    Every hook accepts its arguments, ignores them and returns ``None``, so an
    instance can be dropped in wherever a training-metrics object is expected
    but nothing should be computed or logged.
    """

    def __init__(self):
        super().__init__()

    def forward(self, masked_pred_X, masked_pred_E, true_X, true_E, log: bool):
        """Ignore the given predictions/targets; nothing is recorded."""
        return None

    def reset(self):
        """No accumulated state exists, so there is nothing to clear."""
        return None

    def log_epoch_metrics(self, current_epoch):
        """No metrics are tracked, so nothing is logged."""
        return None


class TrainAbstractMetrics(torch.nn.Module):
    """Placeholder training-metrics module for continuous (eps/y) predictions.

    All hooks are no-ops that return ``None``; use it when the training loop
    requires a metrics object but no metrics should be computed.
    """

    def __init__(self):
        super().__init__()

    def forward(self, masked_pred_epsX, masked_pred_epsE, pred_y, true_epsX, true_epsE, true_y, log):
        """Ignore the given predictions/targets; nothing is recorded."""
        return None

    def reset(self):
        """No accumulated state exists, so there is nothing to clear."""
        return None

    def log_epoch_metrics(self, current_epoch):
        """No metrics are tracked, so nothing is logged."""
        return None


class SumExceptBatchMetric(Metric):
    """Running sum of all values, normalised by the number of batch rows seen.

    Every element of each ``values`` tensor contributes to the numerator, but
    only its leading dimension contributes to the denominator. ``compute``
    therefore returns the per-sample total (summed over all non-batch
    dimensions) averaged over samples.
    """

    def __init__(self):
        super().__init__()
        for state_name in ('total_value', 'total_samples'):
            self.add_state(state_name, default=torch.tensor(0.), dist_reduce_fx="sum")

    def update(self, values) -> None:
        """Add the sum of all entries of ``values`` and its batch size."""
        batch_size = values.shape[0]
        self.total_value += values.sum()
        self.total_samples += batch_size

    def compute(self):
        """Return the accumulated sum divided by the accumulated batch count."""
        return self.total_value / self.total_samples


class SumExceptBatchMSE(MeanSquaredError):
    """Squared-error metric that sums over every non-batch dimension.

    The squared error of all elements is added to ``sum_squared_error``, while
    ``total`` grows only by the size of the leading (batch) dimension. The final
    value is produced by the inherited ``MeanSquaredError.compute``.
    """

    def update(self, preds: Tensor, target: Tensor) -> None:
        """Accumulate the squared error and batch count of one batch.

        Args:
            preds: Predictions from model.
            target: Ground truth values; must have exactly the shape of ``preds``.
        """
        assert preds.shape == target.shape
        batch_sse, batch_count = self._mean_squared_error_update(preds, target)
        self.sum_squared_error += batch_sse
        self.total += batch_count

    def _mean_squared_error_update(self, preds: Tensor, target: Tensor):
        """Return ``(sum of squared differences over all elements, preds.shape[0])``.

        Pure helper: performs no shape check and mutates no metric state
        (``update`` is responsible for both).
        """
        residual = preds - target
        return (residual * residual).sum(), preds.shape[0]


class SumExceptBatchKL(Metric):
    """KL divergence summed over all elements, averaged over the batch dimension."""

    def __init__(self):
        super().__init__()
        self.add_state('total_value', default=torch.tensor(0.), dist_reduce_fx="sum")
        self.add_state('total_samples', default=torch.tensor(0.), dist_reduce_fx="sum")

    def update(self, p, q) -> None:
        """Accumulate ``F.kl_div(q, p, reduction='sum')`` and the batch size of ``p``.

        ``q`` is the ``input`` argument of ``F.kl_div``, which PyTorch treats as
        log-probabilities, and ``p`` is the ``target`` (probabilities).
        NOTE(review): presumably callers already pass ``q`` in log space — confirm.
        """
        batch_kl = F.kl_div(input=q, target=p, reduction='sum')
        self.total_value += batch_kl
        self.total_samples += p.size(0)

    def compute(self):
        """Return the accumulated KL divided by the accumulated batch count."""
        return self.total_value / self.total_samples

class CrossEntropyMetric(Metric):
    """Running (optionally weighted) cross-entropy, averaged over rows of ``preds``."""

    def __init__(self):
        super().__init__()
        self.add_state('total_ce', default=torch.tensor(0.), dist_reduce_fx="sum")
        self.add_state('total_samples', default=torch.tensor(0.), dist_reduce_fx="sum")

    def update(self, preds, target, target_weight=None, sample_weight=None) -> None:
        """Accumulate the cross-entropy of one batch.

        Args:
            preds: Scores passed to ``F.cross_entropy`` (which expects unnormalised
                logits), shaped (bs * n, d) or (bs * n * n, d).
            target: Targets of the same shape; the argmax over the last dimension
                is used as the class index.
            target_weight: Optional per-class weights, cast to ``preds``' dtype and
                forwarded as ``weight`` to ``F.cross_entropy``.
            sample_weight: Optional weights, flattened and multiplied into the
                per-row losses before they are summed.
        """
        class_idx = target.argmax(dim=-1)
        # weight=None is F.cross_entropy's default, so one call covers both cases.
        class_weight = None if target_weight is None else target_weight.type_as(preds)
        per_row = F.cross_entropy(preds, class_idx, weight=class_weight, reduction='none')
        if sample_weight is not None:
            per_row = per_row * sample_weight.view(-1)
        self.total_ce += per_row.sum()
        self.total_samples += preds.size(0)

    def compute(self):
        """Return the accumulated cross-entropy divided by the number of rows seen."""
        return self.total_ce / self.total_samples

class KLDivergenceForPrior(Metric):
    """Summed KL divergence of the X, E and pos predictions, averaged over rows of ``pred_X``.

    Each prediction is converted to log-probabilities with a log-softmax over its
    last dimension and compared with its eps-smoothed target through
    ``KLDivLoss(reduction='none')``. The three summed KL values are added together.
    """

    def __init__(self):
        super().__init__()
        self.add_state('total_kl', default=torch.tensor(0.), dist_reduce_fx="sum")
        self.add_state('total_samples', default=torch.tensor(0.), dist_reduce_fx="sum")
        self.kl_div = KLDivLoss(reduction='none')

    def update(self, pred_X, pred_E, pred_pos, target_X, target_E, target_pos, sample_weight=None) -> None:
        """Accumulate the KL terms of one batch.

        Args:
            pred_X, pred_E, pred_pos: Unnormalised scores; each is log-softmaxed over
                its last dimension.
            target_X, target_E, target_pos: Target distributions, presumably of the
                same shape as the matching prediction (TODO confirm).
            sample_weight: Optional weights, flattened and multiplied element-wise
                into each KL tensor (see the broadcasting note below).
        """
        eps = 1e-8
        pairs = ((pred_X, target_X), (pred_E, target_E), (pred_pos, target_pos))

        component_sums = []
        for pred, target in pairs:
            # log_softmax is the numerically stable form of softmax(...).log().
            # The latter underflows to log(0) = -inf for very negative logits,
            # which turns the KL term into inf. Adding eps to the logits cannot
            # prevent that, because softmax is shift-invariant.
            log_pred = F.log_softmax(pred, dim=-1)
            # Targets keep their small eps shift before entering KLDivLoss.
            kl = self.kl_div(log_pred, target + eps)
            if sample_weight is not None:
                # NOTE(review): kl has the element-wise shape of `pred`, so a
                # flattened sample_weight broadcasts against its LAST dimension.
                # That only lines up when len(sample_weight) equals that size (or 1).
                # Confirm whether per-row weighting was intended, i.e. summing
                # over dim -1 first, as BCEWithLogitsMetric does.
                kl = kl * sample_weight.view(-1)
            component_sums.append(kl.sum())

        output_X, output_E, output_pos = component_sums
        self.total_kl += output_X + output_E + output_pos
        self.total_samples += pred_X.size(0)

    def compute(self):
        """Return the accumulated KL divided by the number of ``pred_X`` rows seen."""
        return self.total_kl / self.total_samples
    
class KLDivergenceMetricForAtomDist(Metric):
    """Running KL divergence between eps-smoothed targets and predicted probabilities.

    The result is averaged over the rows (leading dimension) of ``preds``.
    """

    def __init__(self):
        super().__init__()
        self.add_state('total_kl', default=torch.tensor(0.), dist_reduce_fx="sum")
        self.add_state('total_samples', default=torch.tensor(0.), dist_reduce_fx="sum")
        self.kl_div = KLDivLoss(reduction='none')

    def update(self, preds, target, target_weight=None, sample_weight=None) -> None:
        """Accumulate the KL divergence of one batch.

        ``preds`` is used directly as probabilities: eps is added and the log is
        taken, with no softmax applied. ``target`` is also shifted by eps.
        ``target_weight`` is accepted but unused.
        NOTE(review): the element-wise KL has the shape of ``preds``, so a flattened
        ``sample_weight`` broadcasts against its last dimension. Confirm that this is
        the intended weighting.
        """
        eps = 1e-8
        log_preds = torch.log(preds + eps)
        kl = self.kl_div(log_preds, target + eps)
        if sample_weight is not None:
            kl = kl * sample_weight.view(-1)
        self.total_kl += kl.sum()
        self.total_samples += preds.size(0)

    def compute(self):
        """Return the accumulated KL divided by the number of rows seen."""
        return self.total_kl / self.total_samples

class FocalLossMetric(Metric):
    def __init__(self, gamma=1.2, alpha=None):
        """Focal-loss-style metric: cross-entropy weighted by the true-class probability.

        Each row contributes ``-alpha_t * p_t ** gamma * log(p_t)``, where ``p_t`` is
        the softmax probability of the true class.

        NOTE(review): standard focal loss uses ``(1 - p_t) ** gamma``, which
        up-weights hard (low ``p_t``) examples. This implementation uses
        ``p_t ** gamma``, which down-weights them. It is kept as-is because it
        appears deliberate; confirm with the author.

        Args:
            gamma (float): Exponent applied to ``p_t``.
            alpha (Tensor, sequence, scalar or None): Optional class weights. A 1-D
                value of length d is indexed by each row's true class. A scalar
                scales every row. A tensor with 2 or more dims is gathered along its
                last dimension. If None, no class weighting is applied.
        """
        super().__init__()
        self.gamma = gamma
        self.alpha = alpha

        self.add_state('total_focal_loss', default=torch.tensor(0.), dist_reduce_fx="sum")
        self.add_state('total_samples', default=torch.tensor(0.), dist_reduce_fx="sum")

    def update(self, preds, target, target_weight=None, sample_weight=None) -> None:
        """Accumulate the focal-style loss of one batch.

        Args:
            preds: Unnormalised scores, (bs * n, d) or (bs * n * n, d).
            target: Targets of the same shape; argmax over the last dim gives the class.
            target_weight: Accepted for interface compatibility; unused.
            sample_weight: Optional weights, flattened and multiplied into the
                per-row losses.
        """
        class_idx = torch.argmax(target, dim=-1)  # one-hot -> class indices
        index = class_idx.unsqueeze(-1)

        log_probs = F.log_softmax(preds, dim=-1)
        target_log_probs = log_probs.gather(dim=-1, index=index).squeeze(-1)
        # exp is element-wise, so gathering first and then exponentiating equals
        # exp(log_probs).gather(...) without building the full probability tensor.
        target_probs = target_log_probs.exp()

        focal_weight = target_probs ** self.gamma

        if self.alpha is not None:
            # alpha is a plain attribute (not a registered state/buffer), so it does
            # not follow the module to preds' device/dtype; convert it on every call.
            alpha = torch.as_tensor(self.alpha, dtype=focal_weight.dtype, device=focal_weight.device)
            if alpha.dim() == 0:
                alpha_factor = alpha
            elif alpha.dim() == 1:
                # Per-class weights. torch.gather requires the input and the index
                # to have the same rank, so look the weights up by indexing instead.
                alpha_factor = alpha[class_idx]
            else:
                alpha_factor = alpha.gather(dim=-1, index=index).squeeze(-1)
            focal_weight = focal_weight * alpha_factor

        focal_loss = -focal_weight * target_log_probs

        if sample_weight is not None:
            focal_loss = focal_loss * sample_weight.view(-1)

        self.total_focal_loss += focal_loss.sum()
        self.total_samples += preds.size(0)

    def compute(self):
        """Return the accumulated loss divided by the number of rows seen."""
        return self.total_focal_loss / self.total_samples

class ProbabilityMetric(Metric):
    """Mean of every predicted value seen so far.

    Intended for tracking the marginal predicted probability of a class during
    training: each update adds the sum and element count of ``preds``.
    """

    def __init__(self):
        super().__init__()
        self.add_state('prob', default=torch.tensor(0.), dist_reduce_fx="sum")
        self.add_state('total', default=torch.tensor(0.), dist_reduce_fx="sum")

    def update(self, preds: Tensor) -> None:
        """Add the element count and the sum of ``preds``."""
        self.total += preds.numel()
        self.prob += preds.sum()

    def compute(self):
        """Return the sum of all elements seen divided by their count."""
        return self.prob / self.total


class NLL(Metric):
    """Mean over all elements of the negative log-likelihood tensors passed to ``update``."""

    def __init__(self):
        super().__init__()
        for state_name in ('total_nll', 'total_samples'):
            self.add_state(state_name, default=torch.tensor(0.), dist_reduce_fx="sum")

    def update(self, batch_nll) -> None:
        """Add the sum and the element count of ``batch_nll``."""
        element_count = batch_nll.numel()
        self.total_nll += batch_nll.sum()
        self.total_samples += element_count

    def compute(self):
        """Return the accumulated NLL divided by the number of elements seen."""
        return self.total_nll / self.total_samples

class BCEWithLogitsMetric(Metric):
    """Multi-task binary cross-entropy (from logits), summed over tasks, averaged over rows."""

    def __init__(self):
        super().__init__()
        self.add_state('total_bce', default=torch.tensor(0.), dist_reduce_fx="sum")
        self.add_state('total_samples', default=torch.tensor(0.), dist_reduce_fx="sum")

    def update(self, preds, target, target_weight=None, sample_weight=None) -> None:
        """Accumulate the BCE of one batch.

        Args:
            preds: Logits, (bs, d); each column is a separate binary task.
            target: Ground-truth values, (bs, d).
            target_weight: Optional weights, cast to ``preds``' dtype and forwarded
                as ``weight`` to ``F.binary_cross_entropy_with_logits``.
            sample_weight: Optional weights, flattened and multiplied into the
                per-row (task-summed) losses.
        """
        # weight=None is the functional's default, so one call covers both cases.
        weight = None if target_weight is None else target_weight.type_as(preds)
        per_element = F.binary_cross_entropy_with_logits(preds, target, weight=weight, reduction='none')

        per_row = per_element.sum(dim=-1)  # total over tasks
        if sample_weight is not None:
            per_row = per_row * sample_weight.view(-1)

        self.total_bce += per_row.sum()
        self.total_samples += preds.size(0)

    def compute(self):
        """Return the accumulated BCE divided by the number of rows seen."""
        return self.total_bce / self.total_samples
    
class MSEMetric(Metric):
    """Multi-task squared error, summed over tasks and averaged over rows."""

    def __init__(self):
        super().__init__()
        self.add_state('total_mse', default=torch.tensor(0.), dist_reduce_fx="sum")
        self.add_state('total_samples', default=torch.tensor(0.), dist_reduce_fx="sum")

    def update(self, preds, target, target_weight=None, sample_weight=None) -> None:
        """Accumulate the squared error of one batch.

        Args:
            preds: Predictions, (bs, d); each column is a separate regression task.
            target: Ground-truth values, (bs, d).
            target_weight: Optional weights, cast to ``preds``' dtype and multiplied
                element-wise into the squared errors.
            sample_weight: Optional weights, flattened and multiplied into the
                per-row (task-summed) errors.
        """
        per_element = F.mse_loss(preds, target, reduction='none')
        if target_weight is not None:
            per_element = per_element * target_weight.type_as(preds)

        per_row = per_element.sum(dim=-1)  # total over tasks
        if sample_weight is not None:
            per_row = per_row * sample_weight.view(-1)

        self.total_mse += per_row.sum()
        self.total_samples += preds.size(0)

    def compute(self):
        """Return the accumulated squared error divided by the number of rows seen."""
        return self.total_mse / self.total_samples
    
