import torch
import torch.nn as nn
import numpy as np
import xarray as xr
from xrspatial import proximity


class AsymmetricLoss(nn.Module):
    """Masked asymmetric focal loss on sigmoid probabilities.

    Positives and negatives get separate focusing exponents (``gamma_pos`` /
    ``gamma_neg``), and the negative-side probability is shifted up by
    ``clip`` (capped at 1) before the log is taken.

    Parameters
    ----------
    gamma_neg: focusing exponent applied where ``y == 0``
    gamma_pos: focusing exponent applied where ``y == 1``
    clip: amount added to ``1 - sigmoid(x)``; ``None`` or ``0`` disables it
    eps: lower clamp applied to probabilities before ``log``
    disable_torch_grad_focal_loss: if True, the focusing weight is computed
        without autograd tracking, i.e. it acts as a constant factor in the
        backward pass
    """

    def __init__(
        self,
        gamma_neg=4,
        gamma_pos=2,
        clip=0.05,
        eps=1e-8,
        disable_torch_grad_focal_loss=True,
    ):
        super().__init__()

        self.gamma_neg = gamma_neg
        self.gamma_pos = gamma_pos
        self.clip = clip
        self.disable_torch_grad_focal_loss = disable_torch_grad_focal_loss
        self.eps = eps

    def _focusing_weight(self, xs_pos, xs_neg, y):
        """Return ``(1 - pt) ** gamma`` with pt = p if y > 0 else 1 - p."""
        pt = xs_pos * y + xs_neg * (1 - y)
        one_sided_gamma = self.gamma_pos * y + self.gamma_neg * (1 - y)
        return torch.pow(1 - pt, one_sided_gamma)

    def forward(self, x, y, mask):
        """
        Parameters
        ----------
        x: input logits
        y: targets (multi-label binarized vector)
        mask: mask for the loss (exclude ocean pixels from the loss)

        Returns
        -------
        Scalar tensor: the negated sum of the masked element-wise loss.
        """

        # Calculating Probabilities
        x_sigmoid = torch.sigmoid(x)
        xs_pos = x_sigmoid
        xs_neg = 1 - x_sigmoid

        # Asymmetric Clipping
        if self.clip is not None and self.clip > 0:
            xs_neg = (xs_neg + self.clip).clamp(max=1)

        # Basic CE calculation
        los_pos = y * torch.log(xs_pos.clamp(min=self.eps))
        los_neg = (1 - y) * torch.log(xs_neg.clamp(min=self.eps))
        loss = los_pos + los_neg

        # Asymmetric Focusing
        if self.gamma_neg > 0 or self.gamma_pos > 0:
            if self.disable_torch_grad_focal_loss:
                # Scoped context instead of toggling the global grad switch:
                # the previous grad mode is restored on exit (even on error),
                # so an outer torch.no_grad() block is not silently undone.
                with torch.no_grad():
                    one_sided_w = self._focusing_weight(xs_pos, xs_neg, y)
            else:
                one_sided_w = self._focusing_weight(xs_pos, xs_neg, y)
            loss = loss * one_sided_w

        return -(loss * mask).sum()


class Ralloss(nn.Module):
    """Masked asymmetric loss with polynomial correction terms.

    The positive term adds first- and second-order polynomial terms in
    ``1 - p`` to ``log(p)``; the negative term adds a linear term to the log
    and is scaled by ``(lamb - p) * p**2 * (lamb - p_neg)``. Both are then
    multiplied by an asymmetric focusing weight.

    Parameters
    ----------
    gamma_neg: focusing exponent applied where ``y == 0``
    gamma_pos: focusing exponent applied where ``y == 1``
    clip: amount added to ``1 - sigmoid(x)``; ``None`` or ``0`` disables it
    eps: lower clamp applied to probabilities before ``log``
    lamb: offset used in the negative-term scaling factors
    epsilon_neg: coefficient of the linear term in the negative branch
    epsilon_pos: coefficient of the linear term in the positive branch
    epsilon_pos_pow: coefficient of the quadratic term in the positive branch
    disable_torch_grad_focal_loss: if True, the focusing weight is computed
        without autograd tracking, i.e. it acts as a constant factor in the
        backward pass
    """

    def __init__(
        self,
        gamma_neg=4,
        gamma_pos=2,
        clip=0.05,
        eps=1e-8,
        lamb=1.5,
        epsilon_neg=0.0,
        epsilon_pos=1.0,
        epsilon_pos_pow=-2.5,
        disable_torch_grad_focal_loss=False,
    ):
        super().__init__()

        self.gamma_neg = gamma_neg
        self.gamma_pos = gamma_pos
        self.clip = clip
        self.disable_torch_grad_focal_loss = disable_torch_grad_focal_loss
        self.eps = eps

        # parameters of Taylor expansion polynomials
        self.epsilon_pos = epsilon_pos
        self.epsilon_neg = epsilon_neg
        self.epsilon_pos_pow = epsilon_pos_pow
        # Not read by forward(); kept so existing attribute access still works.
        self.margin = 1.0
        self.lamb = lamb

    def _focusing_weight(self, xs_pos, xs_neg, y):
        """Return ``(1 - pt) ** gamma`` with pt = p if y > 0 else 1 - p."""
        pt = xs_pos * y + xs_neg * (1 - y)
        one_sided_gamma = self.gamma_pos * y + self.gamma_neg * (1 - y)
        return torch.pow(1 - pt, one_sided_gamma)

    def forward(self, x, y, mask):
        """
        Parameters
        ----------
        x: input logits
        y: targets (multi-label binarized vector)
        mask: mask for the loss (exclude ocean pixels from the loss)

        Returns
        -------
        Scalar tensor: the negated sum of the masked element-wise loss.
        """
        # Calculating Probabilities
        x_sigmoid = torch.sigmoid(x)
        xs_pos = x_sigmoid
        xs_neg = 1 - x_sigmoid

        # Asymmetric Clipping
        if self.clip is not None and self.clip > 0:
            xs_neg = (xs_neg + self.clip).clamp(max=1)

        # Clamp once; both values are reused by several terms below.
        pos_clamped = xs_pos.clamp(min=self.eps)
        neg_clamped = xs_neg.clamp(min=self.eps)

        # Basic Taylor expansion polynomials
        los_pos = y * (
            torch.log(pos_clamped)
            + self.epsilon_pos * (1 - pos_clamped)
            + self.epsilon_pos_pow * 0.5 * torch.pow(1 - pos_clamped, 2)
        )
        los_neg = (
            (1 - y)
            * (torch.log(neg_clamped) + self.epsilon_neg * neg_clamped)
            * (self.lamb - x_sigmoid)
            * x_sigmoid**2
            * (self.lamb - xs_neg)
        )
        loss = los_pos + los_neg

        # Asymmetric Focusing
        if self.gamma_neg > 0 or self.gamma_pos > 0:
            if self.disable_torch_grad_focal_loss:
                # Scoped context instead of toggling the global grad switch:
                # the previous grad mode is restored on exit (even on error),
                # so an outer torch.no_grad() block is not silently undone.
                with torch.no_grad():
                    one_sided_w = self._focusing_weight(xs_pos, xs_neg, y)
            else:
                one_sided_w = self._focusing_weight(xs_pos, xs_neg, y)
            loss = loss * one_sided_w

        return -(loss * mask).sum()


def an_full(batch, model, params, loc_to_feats, neg_type="hard"):
    """Compute the "assume negative" loss over all classes for one batch.

    Every class is treated as absent at both the observed locations and an
    equal number of random background locations; the entry for each observed
    location's own class is then replaced by a weighted positive term.

    Parameters
    ----------
    batch: 3-tuple ``(loc_feat, _, class_id)``; the middle element is unused
    model: called on the stacked features; when ``params["input_enc"]`` is
        not ``"spherical_harmonics"`` it is called with ``return_feats=True``
        and must expose ``class_emb``
    params: dict read for ``"device"``, ``"input_enc"`` and ``"pos_weight"``
    loc_to_feats: callable turning random locations into features (skipped
        for the ``"spherical_harmonics"`` encoding)
    neg_type: ``"hard"`` uses ``neg_log(1 - p)``; ``"entropy"`` uses
        ``-bernoulli_entropy(1 - p)``

    Returns
    -------
    Scalar tensor: mean of the observed-location loss plus mean of the
    background loss.

    Raises
    ------
    NotImplementedError: if ``neg_type`` is not one of the values above.
    """
    # NOTE(review): `utils`, `neg_log` and `bernoulli_entropy` are not
    # defined in the visible part of this file -- confirm they are in scope.
    device = params["device"]
    use_sh = params["input_enc"] == "spherical_harmonics"

    loc_feat, _, class_id = batch
    loc_feat = loc_feat.to(device)
    class_id = class_id.to(device)

    batch_size = loc_feat.shape[0]

    # create random background samples and extract features
    rand_loc = utils.rand_samples(batch_size, device, rand_type="spherical")
    rand_feat = rand_loc if use_sh else loc_to_feats(rand_loc, normalize=False)

    # get location embeddings (observed rows first, random rows after)
    loc_cat = torch.cat((loc_feat, rand_feat), 0)
    if use_sh:
        # model output is used directly as per-class logits
        loc_emb_cat = model(loc_cat)
        loc_pred = torch.sigmoid(loc_emb_cat[:batch_size, :])
        loc_pred_rand = torch.sigmoid(loc_emb_cat[batch_size:, :])
    else:
        loc_emb_cat = model(loc_cat, return_feats=True)
        loc_pred = torch.sigmoid(model.class_emb(loc_emb_cat[:batch_size, :]))
        loc_pred_rand = torch.sigmoid(model.class_emb(loc_emb_cat[batch_size:, :]))

    # data loss
    if neg_type == "hard":
        loss_pos = neg_log(1.0 - loc_pred)  # assume negative
        loss_bg = neg_log(1.0 - loc_pred_rand)  # assume negative
    elif neg_type == "entropy":
        loss_pos = -1 * bernoulli_entropy(1.0 - loc_pred)  # entropy
        loss_bg = -1 * bernoulli_entropy(1.0 - loc_pred_rand)  # entropy
    else:
        raise NotImplementedError(f"Unsupported neg_type: {neg_type!r}")

    # Row indices come from the actual batch, not params["batch_size"], so a
    # batch whose size differs from the configured one still indexes correctly.
    rows = torch.arange(batch_size, device=loss_pos.device)
    # overwrite the observed class entry with the weighted positive loss
    loss_pos[rows, class_id] = params["pos_weight"] * neg_log(
        loc_pred[rows, class_id]
    )

    # total loss
    return loss_pos.mean() + loss_bg.mean()
