import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import List, Optional

from .base import TwoItemsLoss, get_region_size
from .dice import DiceLoss

_EPS = 1e-10


def _expand_onehot_labels(labels, target_shape, ignore_index):
    """Expand onehot labels to match the size of prediction."""
    bin_labels = labels.new_zeros(target_shape)
    valid_mask = (labels >= 0) & (labels != ignore_index)
    inds = torch.nonzero(valid_mask, as_tuple=True)

    if inds[0].numel() > 0:
        if labels.dim() == 3:
            bin_labels[inds[0], labels[valid_mask], inds[1], inds[2]] = 1
        else:
            bin_labels[inds[0], labels[valid_mask]] = 1

    return bin_labels, valid_mask


class CEWithRegionL2(TwoItemsLoss):
    """
    Cross entropy loss with region size priors measured by squared l2.
    The loss can be described as:
        l = CE(X, Y) + alpha * mean((gt_region - prob_region)^2)

    ``gt_region`` and ``prob_region`` come from ``get_region_size``. It is
    applied to the one-hot ground truth and to the softmax of the scaled logits,
    using the mask of non-ignored positions.

    Args:
        alpha, factor, step_size, max_alpha: forwarded unchanged to
            ``TwoItemsLoss`` (presumably the schedule of the region-term weight;
            see that class).
        thres: stored as ``self.thres``; not used by the code in this class.
        temp: multiplier applied to the logits before the softmax used for the
            region term. The CE term uses the raw logits.
        ignore_index: label value excluded from both terms.
        background_index: if it lies in ``[0, C)``, that class is removed from
            the region term, together with its entry in ``weight``. A negative
            value disables this.
        weight: optional per-class weight tensor. It is passed to
            ``nn.CrossEntropyLoss`` and also scales the per-class region errors.
    """
    def __init__(self, alpha : float = 0.1,
                 factor : float = 1,
                 step_size : int = 0,
                 max_alpha : float = 100,
                 thres : float = 0.5,
                 temp : float = 1.,
                 ignore_index : int = 255,
                 background_index : int = -1,
                 weight=None) -> None:
        super().__init__(alpha, factor, step_size, max_alpha)
        self.thres = thres
        self.temp = temp
        self.ignore_index = ignore_index
        self.background_index = background_index
        self.weight = weight
        self.act = nn.Softmax(dim=1)
        self.ce = nn.CrossEntropyLoss(weight=self.weight, ignore_index=ignore_index)

    @staticmethod
    def _drop_index(t : torch.Tensor, index : int, dim : int) -> torch.Tensor:
        """Return ``t`` without entry ``index`` along axis ``dim``."""
        head = t.narrow(dim, 0, index)
        tail = t.narrow(dim, index + 1, t.shape[dim] - index - 1)
        return torch.cat((head, tail), dim=dim)

    def forward(self, inputs : torch.Tensor, labels : torch.Tensor):
        """
        Args:
            inputs: raw logits with the class axis at dim 1.
            labels: integer class indices; ``ignore_index`` marks skipped positions.

        Returns:
            ``(loss, loss_ce, loss_l2)``: the total ``loss_ce + alpha * loss_l2``
            and its two components.
        """
        loss_ce = self.ce(inputs, labels)
        # size regularizer
        bin_labels, valid_mask = _expand_onehot_labels(labels, inputs.shape, ignore_index=self.ignore_index)
        labels_size = get_region_size(bin_labels, valid_mask)
        probs = self.act(self.temp * inputs)
        probs_size = get_region_size(probs, valid_mask)
        # The sizes are presumably (N, C); the einsum below requires 2-D with classes at dim 1.
        weight = self.weight
        bg = self.background_index
        if 0 <= bg < probs_size.shape[1]:
            # Only the non-background classes enter the region term.
            probs_size = self._drop_index(probs_size, bg, dim=1)
            labels_size = self._drop_index(labels_size, bg, dim=1)
            # Keep the class weights aligned with the remaining columns;
            # otherwise the einsum below fails with a size mismatch.
            if weight is not None:
                weight = self._drop_index(weight, bg, dim=0)
        loss_l2 = (probs_size - labels_size).square()
        if weight is not None:
            loss_l2 = torch.einsum("k,ik->ik", weight.to(loss_l2.device), loss_l2)
        loss_l2 = loss_l2.mean()

        loss = loss_ce + self.alpha * loss_l2

        return loss, loss_ce, loss_l2


class CEWithRegionL1(TwoItemsLoss):
    """
    Cross entropy loss with region size priors measured by l1.
    The loss can be described as:
        l = CE(X, Y) + alpha * mean(|gt_region - prob_region|)

    ``gt_region`` and ``prob_region`` come from ``get_region_size``. It is
    applied to the one-hot ground truth and to the softmax of the scaled logits,
    using the mask of non-ignored positions.

    Args:
        alpha, factor, step_size, max_alpha: forwarded unchanged to
            ``TwoItemsLoss`` (presumably the schedule of the region-term weight;
            see that class).
        thres: stored as ``self.thres``; not used by the code in this class.
        temp: multiplier applied to the logits before the softmax used for the
            region term. The CE term uses the raw logits.
        ignore_index: label value excluded from both terms.
        background_index: if it lies in ``[0, C)``, that class is removed from
            the region term, together with its entry in ``weight``. A negative
            value disables this.
        weight: optional per-class weight tensor. It is passed to
            ``nn.CrossEntropyLoss`` and also scales the per-class region errors.
    """
    def __init__(self, alpha : float = 0.1,
                 factor : float = 1,
                 step_size : int = 0,
                 max_alpha : float = 100,
                 thres : float = 0.5,
                 temp : float = 1.,
                 ignore_index : int = 255,
                 background_index : int = -1,
                 weight=None) -> None:
        super().__init__(alpha, factor, step_size, max_alpha)
        self.thres = thres
        self.temp = temp
        self.ignore_index = ignore_index
        self.background_index = background_index
        self.weight = weight
        self.act = nn.Softmax(dim=1)
        self.ce = nn.CrossEntropyLoss(weight=self.weight, ignore_index=ignore_index)

    @staticmethod
    def _drop_index(t : torch.Tensor, index : int, dim : int) -> torch.Tensor:
        """Return ``t`` without entry ``index`` along axis ``dim``."""
        head = t.narrow(dim, 0, index)
        tail = t.narrow(dim, index + 1, t.shape[dim] - index - 1)
        return torch.cat((head, tail), dim=dim)

    def forward(self, inputs : torch.Tensor, labels : torch.Tensor):
        """
        Args:
            inputs: raw logits with the class axis at dim 1.
            labels: integer class indices; ``ignore_index`` marks skipped positions.

        Returns:
            ``(loss, loss_ce, loss_l1)``: the total ``loss_ce + alpha * loss_l1``
            and its two components.
        """
        loss_ce = self.ce(inputs, labels)
        # size regularizer
        bin_labels, valid_mask = _expand_onehot_labels(labels, inputs.shape, ignore_index=self.ignore_index)
        labels_size = get_region_size(bin_labels, valid_mask)
        probs = self.act(self.temp * inputs)
        probs_size = get_region_size(probs, valid_mask)
        # The sizes are presumably (N, C); the einsum below requires 2-D with classes at dim 1.
        weight = self.weight
        bg = self.background_index
        if 0 <= bg < probs_size.shape[1]:
            # Only the non-background classes enter the region term.
            probs_size = self._drop_index(probs_size, bg, dim=1)
            labels_size = self._drop_index(labels_size, bg, dim=1)
            # Keep the class weights aligned with the remaining columns;
            # otherwise the einsum below fails with a size mismatch.
            if weight is not None:
                weight = self._drop_index(weight, bg, dim=0)
        loss_l1 = (probs_size - labels_size).abs()
        if weight is not None:
            loss_l1 = torch.einsum("k,ik->ik", weight.to(loss_l1.device), loss_l1)
        loss_l1 = loss_l1.mean()

        loss = loss_ce + self.alpha * loss_l1

        return loss, loss_ce, loss_l1


class CEWithRegionDice(TwoItemsLoss):
    """
    Cross entropy loss plus a Dice region term.
    The loss can be described as:
        l = CE(X, Y) + alpha * Dice(X, Y)
    where Dice is ``DiceLoss(mode="multiclass", log_loss=True)``. Both criteria
    skip positions labelled ``ignore_index``.

    ``alpha``, ``factor``, ``step_size`` and ``max_alpha`` are forwarded unchanged
    to ``TwoItemsLoss``.
    """
    def __init__(self, alpha : float = 0.1,
                 factor : float = 1,
                 step_size : int = 0,
                 max_alpha : float = 100,
                 ignore_index : int = 255) -> None:
        super().__init__(alpha, factor, step_size, max_alpha)
        self.ignore_index = ignore_index
        skip = self.ignore_index
        self.ce = nn.CrossEntropyLoss(ignore_index=skip)
        self.dice = DiceLoss(mode="multiclass", log_loss=True, ignore_index=skip)

    def forward(self, inputs : torch.Tensor, labels : torch.Tensor):
        """Return ``(total, ce_term, dice_term)`` with ``total = ce_term + alpha * dice_term``."""
        ce_term = self.ce(inputs, labels)
        dice_term = self.dice(inputs, labels)
        return ce_term + self.alpha * dice_term, ce_term, dice_term


class CEWithRegionKL(TwoItemsLoss):
    """
    Cross entropy loss with a region size prior measured by a KL-style term.
    The loss can be described as:
        l = CE(X, Y) + alpha * mean(gt_region * log(gt_region / prob_region))
    using the convention 0 * log(0 / q) = 0.

    ``gt_region`` and ``prob_region`` come from ``get_region_size``. It is
    applied to the one-hot ground truth and to the softmax of the scaled logits,
    using the mask of non-ignored positions.

    Args:
        alpha, factor, step_size, max_alpha: forwarded unchanged to
            ``TwoItemsLoss``.
        temp: multiplier applied to the logits before the softmax used for the
            region term. The CE term uses the raw logits.
        ignore_index: label value excluded from both terms.
    """
    def __init__(self, alpha : float = 0.1,
                 factor : float = 1,
                 step_size : int = 0,
                 max_alpha : float = 100,
                 temp : float = 1.,
                 ignore_index : int = 255) -> None:
        super().__init__(alpha, factor, step_size, max_alpha)
        self.temp = temp
        self.ignore_index = ignore_index
        self.act = nn.Softmax(dim=1)
        self.ce = nn.CrossEntropyLoss(ignore_index=ignore_index)

    def kl_div(self, p : torch.Tensor, q : torch.Tensor) -> torch.Tensor:
        """
        Return the mean over all entries of ``p * log(p / q)``.

        Entries with ``p == 0`` contribute 0 (the 0 * log 0 = 0 convention)
        instead of NaN. This happens when a class has an empty ground-truth
        region. ``q`` is floored at ``_EPS``, so a vanishing predicted region
        gives a large finite penalty rather than inf.
        """
        # xlogy(x, y) = x * log(y), defined as 0 where x == 0.
        x = torch.xlogy(p, p) - torch.xlogy(p, q.clamp_min(_EPS))
        return x.mean()

    def forward(self, inputs : torch.Tensor, labels : torch.Tensor):
        """
        Args:
            inputs: raw logits with the class axis at dim 1.
            labels: integer class indices; ``ignore_index`` marks skipped positions.

        Returns:
            ``(loss, loss_ce, loss_kl)``: the total ``loss_ce + alpha * loss_kl``
            and its two components.
        """
        loss_ce = self.ce(inputs, labels)
        # size regularizer
        bin_labels, valid_mask = _expand_onehot_labels(labels, inputs.shape, ignore_index=self.ignore_index)
        labels_size = get_region_size(bin_labels, valid_mask)
        probs = self.act(self.temp * inputs)
        probs_size = get_region_size(probs, valid_mask)
        loss_kl = self.kl_div(labels_size, probs_size)

        loss = loss_ce + self.alpha * loss_kl

        return loss, loss_ce, loss_kl
