import torch
from torch import nn
import numpy as np

class DiceLoss(nn.Module):
    """
    Soft Dice loss for multi-class semantic segmentation.

    For each class ``i`` the soft Dice coefficient ``dice_i`` is computed over
    the whole batch, and the returned loss is
    ``sum_i weight[i] * (1 - dice_i) / n_classes``.

    Args:
        n_classes (int): Number of classes in the segmentation task.
    """
    def __init__(self, n_classes):
        super(DiceLoss, self).__init__()
        self.n_classes = n_classes

    def _one_hot_encoder(self, input_tensor):
        """
        Convert an integer label map into a one-hot encoded float tensor.

        The per-class boolean masks are concatenated along dim 1, so the input
        must already carry a singleton channel axis at dim 1.

        Args:
            input_tensor (torch.Tensor): Label map of shape
                (batch_size, 1, *spatial).

        Returns:
            torch.Tensor: One-hot float tensor of shape
                (batch_size, n_classes, *spatial).
        """
        tensor_list = [input_tensor == i for i in range(self.n_classes)]
        output_tensor = torch.cat(tensor_list, dim=1)
        return output_tensor.float()

    def _dice_loss(self, score, target):
        """
        Calculate the soft Dice loss for a single class.

        Args:
            score (torch.Tensor): Predicted scores/probabilities for the class.
            target (torch.Tensor): Target mask for the class (same shape).

        Returns:
            torch.Tensor: Scalar ``1 - dice`` for the class.
        """
        target = target.float()
        smooth = 1e-5  # avoids 0/0 when the class is absent from both tensors
        intersect = torch.sum(score * target)
        y_sum = torch.sum(target * target)
        z_sum = torch.sum(score * score)
        loss = (2 * intersect + smooth) / (z_sum + y_sum + smooth)
        loss = 1 - loss
        return loss

    def forward(self, inputs, target, weight=None, softmax=False):
        """
        Calculate the Dice loss for a given prediction and target.

        Args:
            inputs (torch.Tensor): Predictions of shape
                (batch_size, n_classes, *spatial); logits if ``softmax`` is
                True, otherwise per-class scores/probabilities.
            target (torch.Tensor): Integer label map of shape
                (batch_size, *spatial) or (batch_size, 1, *spatial).
            weight (list, optional): Per-class weights; defaults to all ones.
            softmax (bool, optional): Whether to apply softmax over dim 1.

        Returns:
            torch.Tensor: Scalar weighted Dice loss averaged over classes.

        Raises:
            AssertionError: If the one-hot target shape does not match inputs.
        """
        if softmax:
            inputs = torch.softmax(inputs, dim=1)
        if target.dim() == inputs.dim() - 1:
            # Label maps usually have no channel axis; add the singleton one
            # that the one-hot encoder concatenates along.
            target = target.unsqueeze(1)
        target = self._one_hot_encoder(target)
        if weight is None:
            weight = [1] * self.n_classes
        assert inputs.size() == target.size(), 'predict {} & target {} shape do not match'.format(inputs.size(), target.size())
        loss = 0.0
        for i in range(self.n_classes):
            loss += self._dice_loss(inputs[:, i], target[:, i]) * weight[i]
        return loss / self.n_classes

class MeanIoU(nn.Module):
    """
    Mean Intersection over Union (IoU) metric for semantic segmentation.

    Args:
        numLabels (int): Number of classes in the segmentation task.
        batchSize (int): Nominal batch size. Kept for backward compatibility
            only; the actual batch size is read from the inputs, so a smaller
            final batch is scored normally instead of being skipped.
    """
    def __init__(self, numLabels=3, batchSize=1):
        super().__init__()
        self.numLabels = numLabels
        self.batchSize = batchSize

    def IoU_coef(self, y_pred, y_true, smooth=0.0001):
        """
        Calculate the (soft) Intersection over Union coefficient of two masks.

        Args:
            y_pred (torch.Tensor | np.ndarray): Predicted mask or scores.
            y_true (torch.Tensor | np.ndarray): Ground truth mask.
            smooth (float, optional): Smoothing term added to numerator and
                denominator.

        Returns:
            torch.Tensor: Scalar float32 CPU tensor with the IoU coefficient.
        """
        # Work on detached CPU float copies: the metric is not differentiable
        # and both operands must share a device and dtype.
        pred = torch.as_tensor(y_pred).detach().cpu().flatten().float()
        true = torch.as_tensor(y_true).detach().cpu().flatten().float()
        intersection = torch.sum(true * pred)
        union = torch.sum(true + pred) - intersection
        return (intersection + smooth) / (union + smooth)

    def _to_one_hot(self, y_true, y_pred):
        """Return ``y_true`` as a (batch, numLabels, *spatial) one-hot tensor."""
        y_true = torch.as_tensor(y_true)
        expected = (y_pred.shape[0], self.numLabels, *y_pred.shape[2:])
        if tuple(y_true.shape) == expected:
            return y_true  # already one-hot encoded
        if y_true.dim() == len(y_pred.shape) and y_true.shape[1] == 1:
            y_true = y_true.squeeze(1)  # (B, 1, *spatial) label map
        one_hot = nn.functional.one_hot(y_true.long(), self.numLabels)
        # one_hot appends the class axis last; move it to dim 1 (not reshape,
        # which would scramble pixels across channels).
        return one_hot.movedim(-1, 1)

    def Mean_IoU(self, y_pred, y_true, smooth=0.0001):
        """
        Calculate the Mean Intersection over Union score across classes.

        Args:
            y_pred (torch.Tensor): Predictions of shape
                (batch, numLabels, *spatial).
            y_true (torch.Tensor): Either an integer label map of shape
                (batch, *spatial) / (batch, 1, *spatial), or a one-hot tensor
                with the same shape as ``y_pred``.
            smooth (float, optional): Smoothing term passed to ``IoU_coef``.

        Returns:
            torch.Tensor: Scalar mean IoU over ``numLabels`` classes.
        """
        y_true = self._to_one_hot(y_true, y_pred)
        iou_sum = sum(
            self.IoU_coef(y_pred[:, index], y_true[:, index], smooth=smooth)
            for index in range(self.numLabels)
        )
        return iou_sum / self.numLabels

    def forward(self, y_pred, y_true):
        """
        Calculate the Mean Intersection over Union score.

        Args:
            y_pred (torch.Tensor): Predicted segmentation masks.
            y_true (torch.Tensor): Ground truth segmentation masks.

        Returns:
            torch.Tensor: Scalar mean IoU score.
        """
        # smooth=1 preserves the value this module has always reported.
        return self.Mean_IoU(y_pred, y_true, smooth=1)