from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F

from .one_hot import one_hot
import numpy as np


# based on:
# https://github.com/kevinzakka/pytorch-goodies/blob/master/losses.py

def soft_dice_loss_v2(input: torch.Tensor, target: torch.Tensor, eps: float = 1e-8,
                      global_weight: float = 1.) -> torch.Tensor:
    r"""Compute the soft Sørensen-Dice loss with squared terms in the denominator.

    Per sample ``b`` and class ``c`` the score is

    .. math::

        \text{Dice}_{b,c} = \frac{2 \sum_{h,w} p_{bchw} t_{bchw} + \epsilon}
                                 {\sum_{h,w} p_{bchw}^2 + \sum_{h,w} t_{bchw}^2 + \epsilon}

    where ``p = softmax(input, dim=1)`` and ``t`` is the one-hot encoded
    target. The returned loss is ``global_weight * mean(1 - Dice)`` over the
    batch and class axes.

    See :class:`SoftDiceLossV2` for the ``nn.Module`` wrapper.

    Args:
        input: raw class scores (logits) of shape ``(B, C, H, W)``; a softmax
            is applied over dim 1.
        target: class-index labels whose last two dims must match the spatial
            dims of ``input``; encoded via ``one_hot(..., num_classes=C)``.
            Presumably ``(B, H, W)`` integer labels in ``[0, C-1]`` — TODO
            confirm against ``one_hot``'s contract.
        eps: smoothing term added to numerator and denominator so that empty
            classes give ``0/0 -> 1`` instead of NaN.
        global_weight: scalar multiplier applied to the final loss.

    Returns:
        A scalar tensor holding the weighted mean soft-Dice loss.

    Raises:
        TypeError: if ``input`` or ``target`` is not a ``torch.Tensor``.
        ValueError: if ``input`` is not 4-D, the spatial dims of ``input`` and
            ``target`` differ, or the two tensors live on different devices.
    """
    if not torch.is_tensor(input):
        raise TypeError("Input type is not a torch.Tensor. Got {}"
                        .format(type(input)))

    if not torch.is_tensor(target):
        raise TypeError("Target type is not a torch.Tensor. Got {}"
                        .format(type(target)))

    if not len(input.shape) == 4:
        raise ValueError("Invalid input shape, we expect BxNxHxW. Got: {}"
                         .format(input.shape))

    if not input.shape[-2:] == target.shape[-2:]:
        raise ValueError("input and target shapes must be the same. Got: {} and {}"
                         .format(input.shape, target.shape))

    if not input.device == target.device:
        raise ValueError(
            "input and target must be in the same device. Got: {} and {}" .format(
                input.device, target.device))

    # compute softmax over the classes axis
    input_soft: torch.Tensor = F.softmax(input, dim=1)

    # create the labels one hot tensor, matching input's class count/device/dtype
    target_one_hot: torch.Tensor = one_hot(
        target, num_classes=input.shape[1],
        device=input.device, dtype=input.dtype)

    # reduce over the spatial axes only -> one score per (sample, class)
    dims = (2, 3)
    intersection = torch.sum(input_soft * target_one_hot, dims)
    # "soft" formulation: squared terms in the denominator
    cardinality = torch.sum(torch.pow(input_soft, 2) + torch.pow(target_one_hot, 2), dims)

    # https://gist.github.com/jeremyjordan/9ea3032a32909f71dd2ab35fe3bacc08
    dice_score = (2. * intersection + eps) / (cardinality + eps)

    # Since p^2 + t^2 >= 2*p*t elementwise, dice_score <= 1, so the loss is
    # non-negative for global_weight >= 0 (assuming one_hot yields
    # non-negative values). No runtime sign check is done here: it would force
    # a host/device sync on every call.
    return global_weight * torch.mean(1. - dice_score)


class SoftDiceLossV2(nn.Module):
    r"""Criterion that computes the soft Sørensen-Dice coefficient loss.

    Based on the Sørensen-Dice coefficient [1], but using squared scores and
    squared targets in the denominator (the "soft" formulation):

    .. math::

        \text{Dice}(x, class) = \frac{2 \sum X \cdot Y}{\sum X^2 + \sum Y^2}

    where:
       - :math:`X` expects to be the scores of each class.
       - :math:`Y` expects to be the one-hot tensor with the class labels.

    The loss is then computed as:

    .. math::

        \text{loss}(x, class) = 1 - \text{Dice}(x, class)

    averaged over batch and classes, smoothed by ``eps`` in numerator and
    denominator, and scaled by ``global_weight`` (see ``soft_dice_loss_v2``).

    [1] https://en.wikipedia.org/wiki/S%C3%B8rensen%E2%80%93Dice_coefficient

    Args:
        global_weight: scalar multiplier applied to the final loss.
        eps: smoothing term forwarded to ``soft_dice_loss_v2``. The default of
            1e-6 preserves this module's historical value; it differs from
            the functional default of 1e-8.

    Shape:
        - Input: :math:`(N, C, H, W)` where C = number of classes.
        - Target: :math:`(N, H, W)` where each value is
          :math:`0 ≤ targets[i] ≤ C−1`.

    Examples:
        >>> N = 5  # num_classes
        >>> loss = SoftDiceLossV2()
        >>> input = torch.randn(1, N, 3, 5, requires_grad=True)
        >>> target = torch.empty(1, 3, 5, dtype=torch.long).random_(N)
        >>> output = loss(input, target)
        >>> output.backward()
    """

    def __init__(self, global_weight: float = 1., eps: float = 1e-6) -> None:
        super().__init__()
        self.eps: float = eps
        self.global_weight = global_weight

    def forward(  # type: ignore
            self,
            input: torch.Tensor,
            target: torch.Tensor) -> torch.Tensor:
        """Return ``soft_dice_loss_v2(input, target, self.eps, self.global_weight)``."""
        return soft_dice_loss_v2(input, target, self.eps, self.global_weight)
