import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import List

from .base import kl_div, get_region_size, TwoItemsLoss

# Small positive constant added to numerators, denominators and log
# arguments in this module to avoid division by zero and log(0).
_EPS: float = 1.0e-10


class DiceWithLogitsLoss(nn.Module):
    """
    Soft Dice loss (http://arxiv.org/abs/1912.11619) on sigmoid probabilities,
    for the multi-label setting: every channel is scored independently.

    With P = sigmoid(input) and Y = target, for every sample b and channel c:
        dice[b, c] = (2 * sum(P * Y) + eps) / (sum(P) + sum(Y) + eps)
    where the sums run over all spatial positions and eps = _EPS. The loss is
        l = mean_{b,c}(1 - dice)        if log_loss is False (default)
        l = mean_{b,c}(-log(dice))      if log_loss is True

    Args:
        log_loss: use -log(dice) instead of 1 - dice.

    Shapes:
        input: logits of shape (minibatch, C, *spatial) with at least one
            spatial dimension, e.g. (minibatch, C, h, w).
        target: same shape as input.

    Raises:
        ValueError: if input and target shapes differ.
    """
    def __init__(self, log_loss=False) -> None:
        super(DiceWithLogitsLoss, self).__init__()
        self.act = nn.Sigmoid()
        self.log_loss = log_loss

    def forward(self, input : torch.Tensor, target : torch.Tensor) -> torch.Tensor:
        # Reject mismatched shapes explicitly instead of relying on broadcasting.
        if input.shape != target.shape:
            raise ValueError(
                "input and target must have the same shape, "
                f"got {tuple(input.shape)} and {tuple(target.shape)}")

        prob = self.act(input).type(torch.float32)
        target = target.type(torch.float32)

        # Collapse every spatial dimension so (N, C, h, w), (N, C, d, h, w), ...
        # are all reduced to per-(sample, channel) sums.
        prob = prob.flatten(start_dim=2)
        target = target.flatten(start_dim=2)
        intersection = (prob * target).sum(dim=-1)
        cardinality = prob.sum(dim=-1) + target.sum(dim=-1)
        dice_score = (2 * intersection + _EPS) / (cardinality + _EPS)

        if self.log_loss:
            loss = -torch.log(dice_score).mean()
        else:
            loss = (1 - dice_score).mean()
        return loss


class BCEWithDiceWithLogitsLoss(TwoItemsLoss):
    """
    Binary cross entropy plus a weighted soft Dice term, both on logits.
    The loss can be described as:
        l = BCE(X, Y) + alpha * Dice(X, Y)

    Args:
        alpha, factor, step_size, max_alpha: forwarded unchanged to
            TwoItemsLoss. ``self.alpha`` is not assigned in this class; it is
            read in ``forward`` as the Dice weight (presumably set by the base).
        log_loss: selects the -log(dice) form of DiceWithLogitsLoss.

    Note: ``forward`` returns only the combined scalar, not a
    (total, term1, term2) tuple.
    """
    def __init__(
        self,
        alpha : float = 0.1,
        factor : float = 1.0,
        step_size : int = 0,
        max_alpha : float = 100,
        log_loss : bool = False,
    ) -> None:
        super().__init__(alpha, factor, step_size, max_alpha)
        self.log_loss = log_loss
        # Submodules are registered in the same order as before: bce, then dice.
        self.bce = nn.BCEWithLogitsLoss()
        self.dice = DiceWithLogitsLoss(log_loss)

    def forward(self, inputs : torch.Tensor, targets : torch.Tensor) -> torch.Tensor:
        bce_term = self.bce(inputs, targets)
        dice_term = self.dice(inputs, targets)
        return bce_term + self.alpha * dice_term


class PartialBCEWithRegionKL(TwoItemsLoss):
    """
    Partial cross entropy on labelled foreground pixels plus a region-size prior.

    The loss can be described as:
        l = - mean_i(y_i * log(p_i)) + alpha * (KL(r, s_fg) + KL(1 - r, s_bg))
    where
        p    = sigmoid(X)                         (no temperature in the CE term)
        y    = ``masks``; entries with y_i == 0 contribute nothing to the CE term
        r    = ``gt_region``, the ground-truth region size
        s_fg = get_region_size(sigmoid(temp * X))
        s_bg = get_region_size(1 - sigmoid(temp * X))
    The CE mean runs over every element of the per-sample flattened masks,
    i.e. it is normalised by the total element count.

    Args:
        alpha, factor, step_size, max_alpha: forwarded unchanged to
            TwoItemsLoss; ``self.alpha`` weights the KL term in ``forward``.
        thres: stored as ``self.thres``; not used by this class.
        temp: temperature applied to the logits before the sigmoid in the KL term.
    """
    def __init__(self, alpha : float = 0.1,
                 factor : float = 1.0,
                 step_size : int = 0,
                 max_alpha : float = 100,
                 thres : float = 0.5,
                 temp : float = 1.0) -> None:
        super().__init__(alpha, factor, step_size, max_alpha)
        self.thres = thres
        self.temp = temp
        self.act = nn.Sigmoid()

    def forward(self, inputs : torch.Tensor, targets : List[torch.Tensor]):
        """
        Args:
            inputs: logits whose first dimension is the batch dimension.
            targets: ``[masks, gt_region]``. ``masks`` must hold the same number
                of elements per sample as ``inputs``.

        Returns:
            ``(loss, loss_bce, loss_kl)``.
        """
        masks, gt_region = targets

        # KL divergence term, computed on temperature-scaled probabilities.
        probs = self.act(self.temp * inputs)
        prob_fg = get_region_size(probs)
        prob_bg = get_region_size(1 - probs)
        loss_kl = kl_div(gt_region, prob_fg) + kl_div(1 - gt_region, prob_bg)

        # Foreground (first) term of cross entropy, without temperature.
        # logsigmoid(x) == log(sigmoid(x)), evaluated stably: log(sigmoid(x) + eps)
        # flattens at log(eps) for very negative logits and its gradient vanishes,
        # while logsigmoid keeps a usable gradient there.
        n = inputs.size(0)
        # reshape (not view) accepts non-contiguous tensors. Flattening both per
        # sample pairs elements even when shapes differ only by singleton dims.
        log_probs = F.logsigmoid(inputs).reshape(n, -1)
        masks = masks.reshape(n, -1)
        loss_bce = -(masks * log_probs).mean()

        loss = loss_bce + self.alpha * loss_kl

        return loss, loss_bce, loss_kl


class PartialBCEWithRegionKLv2(TwoItemsLoss):
    """
    Partial cross entropy on labelled foreground pixels plus a foreground-only
    region-size prior.

    The loss can be described as:
        l = - mean_i(y_i * log(p_i)) + alpha * KL(r, s_fg)
    where
        p    = sigmoid(X)                         (no temperature in the CE term)
        y    = ``masks``; entries with y_i == 0 contribute nothing to the CE term
        r    = ``gt_region``, the ground-truth region size
        s_fg = get_region_size(sigmoid(temp * X))
    Unlike PartialBCEWithRegionKL there is no background KL term.

    Args:
        alpha, factor, step_size, max_alpha: forwarded unchanged to
            TwoItemsLoss; ``self.alpha`` weights the KL term in ``forward``.
        thres: stored as ``self.thres``; not used by this class.
        temp: temperature applied to the logits before the sigmoid in the KL term.
    """
    def __init__(self, alpha : float = 0.1,
                 factor : float = 1.0,
                 step_size : int = 0,
                 max_alpha : float = 100,
                 thres : float = 0.5,
                 temp : float = 1.0) -> None:
        super().__init__(alpha, factor, step_size, max_alpha)
        self.thres = thres
        self.temp = temp
        self.act = nn.Sigmoid()

    def forward(self, inputs : torch.Tensor, targets : List[torch.Tensor]):
        """
        Args:
            inputs: logits whose first dimension is the batch dimension.
            targets: ``[masks, gt_region]``. ``masks`` must hold the same number
                of elements per sample as ``inputs``.

        Returns:
            ``(loss, loss_bce, loss_kl)``.
        """
        masks, gt_region = targets

        # KL divergence term (foreground only), on temperature-scaled probabilities.
        probs = self.act(self.temp * inputs)
        prob_fg = get_region_size(probs)
        loss_kl = kl_div(gt_region, prob_fg)

        # Foreground (first) term of cross entropy, without temperature.
        # logsigmoid(x) == log(sigmoid(x)), evaluated stably: log(sigmoid(x) + eps)
        # flattens at log(eps) for very negative logits and its gradient vanishes,
        # while logsigmoid keeps a usable gradient there.
        n = inputs.size(0)
        # reshape (not view) accepts non-contiguous tensors. Flattening both per
        # sample pairs elements even when shapes differ only by singleton dims.
        log_probs = F.logsigmoid(inputs).reshape(n, -1)
        masks = masks.reshape(n, -1)
        loss_bce = -(masks * log_probs).mean()

        loss = loss_bce + self.alpha * loss_kl

        return loss, loss_bce, loss_kl


class BCEWithRegionKL(nn.Module):
    """
    Binary cross entropy loss with region size priors measured by KL divergence.
    The loss can be described as:
        l = BCE(X, Y) + alpha * [KL(r, s) + KL(1 - r, 1 - s)]
    where r is the ground-truth region size,
    s = get_region_size(sigmoid(temp * X)) and
    KL(p, q) = mean(p * log(p / q)) with the convention 0 * log(0 / q) = 0.
    Assuming r and s are fractions in [0, 1] (TODO confirm against
    get_region_size), the bracket matches both foreground and background sizes.

    Args:
        alpha: initial weight of the KL term.
        factor: multiplier applied to alpha by ``adjust_alpha``.
        step_size: ``adjust_alpha`` updates alpha every ``step_size`` epochs;
            0 disables the schedule.
        max_alpha: upper bound for alpha.
        thres: stored as ``self.thres``; not used by this class.
        temp: temperature applied to the logits before the sigmoid in the KL term.
    """
    def __init__(self, alpha : float = 0.1,
                 factor : float = 1,
                 step_size : int = 0,
                 max_alpha : float = 100,
                 thres : float = 0.5,
                 temp : float = 20.) -> None:
        super().__init__()
        self.alpha = alpha
        self.max_alpha = max_alpha
        self.factor = factor
        self.step_size = step_size
        self.thres = thres
        self.temp = temp
        self.act = nn.Sigmoid()
        self.bce = nn.BCEWithLogitsLoss()

    def kl_div(self, p : torch.Tensor, q : torch.Tensor, eps : float = 1e-10) -> torch.Tensor:
        """
        Return mean(p * log(p / q)), computed element-wise and then averaged.

        Entries with p == 0 contribute exactly 0 instead of the NaN produced by
        0 * log(0). ``q`` is clamped to at least ``eps``, so a saturated
        prediction (q == 0 while p > 0) gives a large finite value instead of inf.
        """
        q = q.clamp_min(eps)
        # xlogy(x, y) = x * log(y), defined as 0 when x == 0.
        return (torch.xlogy(p, p) - torch.xlogy(p, q)).mean()

    def forward(self, input : torch.Tensor, targets : List[torch.Tensor]):
        """
        Args:
            input: logits.
            targets: ``[target, region_size]``. ``target`` is passed to
                BCEWithLogitsLoss together with ``input``.

        Returns:
            ``(loss, loss_bce, loss_kl)``.
        """
        target, region_size = targets
        # cross entropy loss
        loss_bce = self.bce(input, target)
        # size prior on temperature-scaled probabilities
        prob = self.act(self.temp * input)
        prob_fg_size = get_region_size(prob)
        loss_kl = (self.kl_div(region_size, prob_fg_size)
                   + self.kl_div(1 - region_size, 1 - prob_fg_size))

        loss = loss_bce + self.alpha * loss_kl

        return loss, loss_bce, loss_kl

    def adjust_alpha(self, epoch : int) -> None:
        """Multiply alpha by ``factor`` (capped at ``max_alpha``) every ``step_size`` epochs."""
        if self.step_size == 0:
            return
        if (epoch + 1) % self.step_size == 0:
            self.alpha = min(self.alpha * self.factor, self.max_alpha)


class BCEWithRegionLog(TwoItemsLoss):
    """
    Binary cross entropy with a logarithmic region-size regularizer.
    The loss is computed as:
        l = BCE(X, Y) + alpha * mean(log(s + r + eps))
    where s = get_region_size(sigmoid(temp * X)) is the predicted region size,
    r is ``gt_region`` and eps = _EPS.
    NOTE(review): r is added inside the log, so the regularizer is not simply
    log(prob_region); confirm this is intended.

    Args:
        alpha, factor, step_size, max_alpha: forwarded unchanged to
            TwoItemsLoss; ``self.alpha`` weights the log term in ``forward``.
        thres: stored as ``self.thres``; not used by this class.
        temp: temperature applied to the logits before the sigmoid.
    """
    def __init__(
        self,
        alpha : float = 0.1,
        factor : float = 1,
        step_size : int = 0,
        max_alpha : float = 100,
        thres : float = 0.5,
        temp : float = 1.,
    ) -> None:
        super().__init__(alpha, factor, step_size, max_alpha)
        self.thres = thres
        self.temp = temp
        self.act = nn.Sigmoid()
        self.bce = nn.BCEWithLogitsLoss()

    def forward(self, inputs : torch.Tensor, targets : List[torch.Tensor]):
        """Return ``(total, bce_term, log_term)`` for ``targets = [masks, gt_region]``."""
        masks, gt_region = targets
        bce_term = self.bce(inputs, masks)

        # size regularizer on temperature-scaled probabilities
        predicted_size = get_region_size(self.act(self.temp * inputs))
        log_term = torch.log(predicted_size + gt_region + _EPS).mean()

        total = bce_term + self.alpha * log_term
        return total, bce_term, log_term


class BCEWithRegionMAE(TwoItemsLoss):
    """
    Binary cross entropy plus a region-size prior measured by mean absolute
    error (L1). The loss is computed as:
        l = BCE(X, Y) + alpha * MAE(r, s)
    where r is ``gt_region`` and s = get_region_size(sigmoid(temp * X)).

    Args:
        alpha, factor, step_size, max_alpha: forwarded unchanged to
            TwoItemsLoss; ``self.alpha`` weights the MAE term in ``forward``.
        thres: stored as ``self.thres``; not used by this class.
        temp: temperature applied to the logits before the sigmoid.
    """
    def __init__(
        self,
        alpha : float = 0.1,
        factor : float = 1,
        step_size : int = 0,
        max_alpha : float = 100,
        thres : float = 0.5,
        temp : float = 1.,
    ) -> None:
        super().__init__(alpha, factor, step_size, max_alpha)
        self.thres = thres
        self.temp = temp
        self.act = nn.Sigmoid()
        self.bce = nn.BCEWithLogitsLoss()
        self.mae = nn.L1Loss()

    def forward(self, inputs : torch.Tensor, targets : List[torch.Tensor]):
        """Return ``(total, bce_term, mae_term)`` for ``targets = [masks, gt_region]``."""
        masks, gt_region = targets
        bce_term = self.bce(inputs, masks)

        # size regularizer on temperature-scaled probabilities
        scaled_probs = self.act(self.temp * inputs)
        predicted_size = get_region_size(scaled_probs)
        mae_term = self.mae(gt_region, predicted_size)

        return bce_term + self.alpha * mae_term, bce_term, mae_term


class BCEWithRegionMSE(TwoItemsLoss):
    """
    Binary cross entropy plus a region-size prior measured by mean squared
    error. The loss is computed as:
        l = BCE(X, Y) + alpha * MSE(r, s)
    where r is ``gt_region`` and s = get_region_size(sigmoid(temp * X)).

    Args:
        alpha, factor, step_size, max_alpha: forwarded unchanged to
            TwoItemsLoss; ``self.alpha`` weights the MSE term in ``forward``.
        thres: stored as ``self.thres``; not used by this class.
        temp: temperature applied to the logits before the sigmoid.
    """
    def __init__(
        self,
        alpha : float = 0.1,
        factor : float = 1,
        step_size : int = 0,
        max_alpha : float = 100,
        thres : float = 0.5,
        temp : float = 1.,
    ) -> None:
        super().__init__(alpha, factor, step_size, max_alpha)
        self.thres = thres
        self.temp = temp
        self.act = nn.Sigmoid()
        self.bce = nn.BCEWithLogitsLoss()
        self.mse = nn.MSELoss()

    def forward(self, inputs : torch.Tensor, targets : List[torch.Tensor]):
        """Return ``(total, bce_term, mse_term)`` for ``targets = [masks, gt_region]``."""
        masks, gt_region = targets
        bce_term = self.bce(inputs, masks)

        # size prior on temperature-scaled probabilities
        scaled_probs = self.act(self.temp * inputs)
        predicted_size = get_region_size(scaled_probs)
        mse_term = self.mse(gt_region, predicted_size)

        total = bce_term + self.alpha * mse_term
        return total, bce_term, mse_term


class DiceWithRegionKL(nn.Module):
    """
    Soft Dice loss with a region-size prior measured by KL divergence.
    The loss can be described as:
        l = Dice(X, Y) + alpha * kl_div(region_size, get_region_size(sigmoid(X)))
    using the module-level ``kl_div`` / ``get_region_size`` helpers and a
    default-configured DiceWithLogitsLoss.

    Args:
        alpha: initial weight of the KL term.
        factor: multiplier applied to alpha by ``adjust_alpha`` (default 5).
        step_size: ``adjust_alpha`` updates alpha every ``step_size`` epochs
            (default 5); 0 or a negative value disables the schedule.
        max_alpha: upper bound for alpha (default 20).
    """
    def __init__(self, alpha : float = 0.1,
                 factor : float = 5,
                 step_size : int = 5,
                 max_alpha : float = 20) -> None:
        super(DiceWithRegionKL, self).__init__()
        self.alpha = alpha
        self.factor = factor
        self.step_size = step_size
        self.max_alpha = max_alpha
        self.act = nn.Sigmoid()
        self.dice = DiceWithLogitsLoss()

    def forward(self, input : torch.Tensor, target : torch.Tensor, region_size : torch.Tensor):
        """
        Args:
            input: logits, passed to DiceWithLogitsLoss together with ``target``.
            target: ground-truth masks.
            region_size: ground-truth region size, the first argument of ``kl_div``.

        Returns:
            ``(loss, loss_dice, loss_kl)``.
        """
        loss_dice = self.dice(input, target)
        # size prior: use the ``input`` argument (there is no ``self.input``).
        prob = self.act(input)
        prob_region_size = get_region_size(prob)
        loss_kl = kl_div(region_size, prob_region_size)

        loss = loss_dice + self.alpha * loss_kl

        return loss, loss_dice, loss_kl

    def adjust_alpha(self, epoch : int) -> None:
        """Multiply alpha by ``factor`` (capped at ``max_alpha``) every ``step_size`` epochs."""
        # Guard against modulo-by-zero; a non-positive step disables the schedule.
        if self.step_size <= 0:
            return
        if (epoch + 1) % self.step_size == 0:
            self.alpha = min(self.alpha * self.factor, self.max_alpha)
