import torch
from torch import Tensor, nn

import smlm
from smlm.losses.cov_weighting import CoVWeighting


def optimal_transport_loss(
    x: Tensor,
    p: Tensor,
    logsigma: Tensor,
    x_gt: Tensor,
    mask_gt: Tensor,
    mask_dist: Tensor,
    reg: float,
) -> Tensor:
    """Match predictions to ground truth with entropic OT and return the loss.

    Every prediction carries a mass of ``1 / n_preds``. Every valid
    ground-truth entry receives the same mass, and one extra "background"
    row absorbs whatever mass is left, so unmatched predictions are
    transported to the background.

    Args:
        x: Predicted values, 3-D ``(bs, n_preds, d)``.
        p: Detection logits, indexed as ``p[:, None]`` and concatenated along
            dim 1 with a ``(bs, n_gt, n_preds)`` tensor, i.e. ``(bs, n_preds)``.
        logsigma: Log standard deviations, broadcast against the last
            dimension of ``x - x_gt``.
        x_gt: Ground-truth values, ``(bs, n_gt, d)``.
        mask_gt: Boolean mask ``(bs, n_gt)``; ``True`` marks a real
            ground-truth entry, ``False`` marks an entry that gets zero mass.
        mask_dist: Boolean mask broadcastable to ``(bs, n_gt, n_preds)``.
            Pairs where it is ``False`` get a large cost when the transport
            plan is computed.
        reg: Relative entropic regularisation; it is multiplied by the median
            pairwise loss before being handed to ``sinkhorn``.

    Returns:
        Scalar tensor: the plan-weighted loss summed per sample and averaged
        over the batch.
    """
    bs, n_preds, _ = x.size()
    device, dtype = x.device, x.dtype

    # Marginals: uniform mass on predictions (b); same atom on each valid
    # ground truth plus one trailing background bin with the remainder (a).
    atom_w = 1 / n_preds
    b = torch.full((bs, n_preds), atom_w, device=device, dtype=dtype)
    a = torch.where(mask_gt, atom_w, 0.0)
    bg = 1.0 - a.sum(1, keepdim=True)
    a = torch.cat([a, bg], dim=1)

    # p loss func is a logit BCE:
    #   lp1 = softplus(-p)     = -log(sigmoid(p))      (target 1)
    #   lp0 = p + softplus(-p) = -log(1 - sigmoid(p))  (target 0)
    p = p[:, None]
    lp1 = torch.nn.functional.softplus(-p, beta=1.0)
    lp0 = p + lp1

    # distance loss func is a gaussian with learnt uncertainties:
    # 0.5 * ((x - x_gt) / sigma)^2 + log(sigma), per (gt, pred, dim).
    delta = x[:, None, :] - x_gt[:, :, None]
    delta = delta / torch.exp(logsigma)
    delta = torch.square(delta)
    delta = 0.5 * delta + logsigma

    loss = torch.sum(delta, dim=-1) + lp1
    # NOTE(review): `reg` and `inf` are derived from `loss` without detaching,
    # and the median can be <= 0 when `logsigma` is negative enough, which
    # would flip or zero the regularisation - confirm this is acceptable.
    reg = reg * loss.median()
    # Finite stand-in for "infinite" cost on pairs excluded by `mask_dist`.
    inf = 10 * loss.max()
    cost = torch.where(mask_dist, loss, inf)
    # Append the background row: assigning a prediction to it costs lp0.
    loss = torch.cat([loss, lp0], dim=1)
    cost = torch.cat([cost, lp0], dim=1)

    # The plan is computed from the masked `cost`, but the returned value
    # weights the unmasked `loss` by that plan.
    K = sinkhorn(a, b, cost, reg=reg)
    loss = (loss * K).view(bs, -1).sum(-1)
    loss = loss.mean()
    return loss


def focal_loss(x: Tensor, y: Tensor, gamma: float, alpha: float = 1.0):
    """Return the element-wise focal loss between probabilities and targets.

    Args:
        x: Predicted probabilities (passed to ``binary_cross_entropy``).
        y: Targets, same shape as ``x``.
        gamma: Exponent applied to ``1 - pt``.
        alpha: Constant multiplier applied to the result.

    Returns:
        ``alpha * (1 - pt) ** gamma * BCE(x, y)`` with
        ``pt = x * y + (1 - x) * (1 - y)``; no reduction is applied here.
    """
    # For y in {0, 1} this is x where y == 1 and (1 - x) where y == 0.
    agreement = x * y + (1.0 - x) * (1.0 - y)
    modulation = (1.0 - agreement).pow(gamma)
    return alpha * modulation * binary_cross_entropy(x, y)


def binary_cross_entropy(x: Tensor, y: Tensor) -> Tensor:
    """Return the unreduced binary cross entropy of probabilities ``x`` vs ``y``.

    Thin wrapper around ``torch.nn.functional.binary_cross_entropy`` with
    ``reduction="none"``, so the output has one loss value per element.
    """
    per_element = nn.functional.binary_cross_entropy(
        input=x, target=y, reduction="none"
    )
    return per_element


def sinkhorn(
    a: torch.Tensor,
    b: torch.Tensor,
    K: torch.Tensor,
    reg: float,
    *,
    n_iters: int = 20,
) -> torch.Tensor:
    """Compute an entropy-regularised transport plan with log-domain Sinkhorn.

    Args:
        a: Row marginals, ``(batch, n_rows)``. The log is taken without an
            epsilon, so rows with zero mass get ``-inf`` potentials and a
            zero row in the plan (as long as some other row has mass).
        b: Column marginals, ``(batch, n_cols)``; ``1e-12`` is added before
            taking the log.
        K: Cost tensor, ``(batch, n_rows, n_cols)``. Despite the name it is a
            cost, not a Gibbs kernel: it is divided by ``-reg`` here.
        reg: Entropic regularisation strength (a Python float or a 0-d
            tensor both work with the division below).
        n_iters: Number of alternating ``u`` / ``v`` updates. With ``0`` the
            potentials stay at zero and the result is ``exp(-K / reg)``.

    Returns:
        The plan ``exp(u + (-K / reg) + v)`` with the same shape as ``K``.
        ``v`` is updated last, so column sums match ``b`` (up to the
        epsilon); row sums match ``a`` only as far as the iterations have
        converged.
    """
    assert a.ndim == 2 and b.ndim == 2 and K.ndim == 3

    log_a = torch.log(a)
    log_b = torch.log(b + 1e-12)
    # Dual potentials in the log domain; also the result when n_iters == 0.
    u = torch.zeros_like(a)
    v = torch.zeros_like(b)
    log_kernel = K / (-reg)

    for _ in range(n_iters):
        # Row update reduces over columns, column update over rows.
        u = log_a - torch.logsumexp(log_kernel + v[:, None, :], dim=-1)
        v = log_b - torch.logsumexp(log_kernel + u[:, :, None], dim=-2)

    log_plan = u[:, :, None] + log_kernel + v[:, None, :]
    return torch.exp(log_plan)


class OptimalTransportLoss(nn.Module):
    """Weight the output of ``optimal_transport_loss`` with ``CoVWeighting``.

    NOTE(review): ``forward`` does not match the ``optimal_transport_loss``
    defined in this file. It passes ``rsigma``, which that function does not
    accept, and omits the required ``logsigma`` and ``mask_dist`` arguments,
    so the call is expected to raise ``TypeError``. That function also
    returns a scalar mean, whereas ``torch.dot`` needs 1-D tensors. This
    class looks like it targets an earlier version of the function - confirm
    whether it is still used before relying on it.
    """

    def __init__(self, reg: float):
        super().__init__()
        # Regularisation value forwarded to `optimal_transport_loss`.
        self.reg = reg
        # Presumably the leading 2 is the number of loss terms to balance;
        # the meaning of `eps` / `t_lim` is defined by CoVWeighting - TODO confirm.
        self.weighter = CoVWeighting(2, eps=1e-12, t_lim=float("inf"))

    def forward(
        self, x: Tensor, p: Tensor, rsigma: Tensor, x_gt: Tensor, mask_gt: Tensor
    ):
        # NOTE(review): keyword `rsigma` is not a parameter of the
        # `optimal_transport_loss` in this file (see class docstring).
        losses = optimal_transport_loss(
            x=x, p=p, rsigma=rsigma, x_gt=x_gt, mask_gt=mask_gt, reg=self.reg
        )
        # Combine the loss terms into one value using the weights produced
        # by the weighter.
        weights = self.weighter(losses)
        loss = torch.dot(weights, losses)
        return loss


class OptimalTransportWithUncertaintiesLoss(nn.Module):
    """Optimal-transport loss with learnt per-coordinate log-uncertainties.

    Inputs are rescaled by the reciprocal of
    ``[*pixel_size, 2 * pixel_size.max(), photon_cst]`` and then passed to
    ``optimal_transport_loss`` together with a distance mask that only allows
    a ground-truth point to be matched with grid entries close to it.

    Args:
        reg: Relative entropic regularisation forwarded to
            ``optimal_transport_loss``.
        pixel_size: 1-D tensor of pixel sizes. ``logsigma`` has 4 entries, so
            this presumably holds 2 values (x, y) - TODO confirm.
        photon_cst: Normalisation constant for the last coordinate.
        grid_height: ``h`` passed to ``map_coordinates_cell_center``.
        grid_width: ``w`` passed to ``map_coordinates_cell_center``.
        cell_width: ``cell_width`` passed to ``map_coordinates_cell_center``.
        cell_height: ``cell_height`` passed to ``map_coordinates_cell_center``.
        dist_threshold: Upper bound (exclusive) on the squared distance, in
            normalised units, between a grid entry and a ground-truth point
            for the pair to be allowed in the matching.

    The keyword-only arguments default to the values that were previously
    hard-coded, so existing three-argument construction behaves as before.
    """

    def __init__(
        self,
        reg: float,
        pixel_size: Tensor,
        photon_cst: float,
        *,
        grid_height: int = 32,
        grid_width: int = 32,
        cell_width: int = 2,
        cell_height: int = 2,
        dist_threshold: float = 2 * (1.5**2),
    ):
        super().__init__()
        self.reg = reg
        # Learnt log standard deviations, one per normalised coordinate.
        self.logsigma = torch.nn.Parameter(torch.ones((4,)))
        z_chara_scale = 2 * pixel_size.max()
        # NOTE(review): `Tensor([...])` builds a tensor of the default tensor
        # type; confirm it is compatible with `pixel_size`'s device and dtype.
        normalize = torch.cat([pixel_size, Tensor([z_chara_scale, photon_cst])], dim=0)
        self.register_buffer("rnormalize", normalize.reciprocal())
        self.threshold = dist_threshold
        grid = smlm.utils.coordinates.map_coordinates_cell_center(
            h=grid_height,
            w=grid_width,
            cell_width=cell_width,
            cell_height=cell_height,
            device=torch.get_default_device(),
        )
        # Add a leading batch axis, then flatten the map into a list of
        # entries; presumably (1, n_cells, 2) afterwards - TODO confirm
        # against `map2list`.
        grid = grid.unsqueeze(0)
        grid = smlm.utils.map2list.map2list(grid)
        self.register_buffer("grid", grid)

    def forward(self, x: Tensor, p: Tensor, x_gt: Tensor, mask_gt: Tensor) -> Tensor:
        """Return the scalar OT loss for a batch.

        Args:
            x: Predictions; the last dimension is scaled by ``rnormalize``.
            p: Detection logits forwarded to ``optimal_transport_loss``.
            x_gt: Ground truth, indexed as ``x_gt[:, :, None, :2]``, so it is
                3-D with at least two trailing coordinates.
            mask_gt: Boolean validity mask for the ground-truth entries.
        """
        # Bring predictions and targets to the same normalised units.
        x = x * self.rnormalize
        x_gt = x_gt * self.rnormalize

        # Squared distance between every grid entry and the first two
        # coordinates of every ground-truth point.
        delta = self.grid[:, None, :] - x_gt[:, :, None, :2]
        delta = torch.square(delta)
        delta = torch.sum(delta, dim=-1)
        # Only pairs closer than the threshold may be matched.
        mask = delta < self.threshold

        return optimal_transport_loss(
            x=x,
            p=p,
            logsigma=self.logsigma,
            x_gt=x_gt,
            mask_gt=mask_gt,
            mask_dist=mask,
            reg=self.reg,
        )
