import torch
from torch.nn import Module
from .bregman_pytorch import sinkhorn


class OT_Loss(Module):
    """Optimal-transport loss between a teacher and a student score distribution.

    The transport problem is solved with ``sinkhorn`` outside the autograd graph.
    Only its ``log["beta"]`` output is used, to build a fixed (detached) gradient
    that is then paired with the student scores.
    """

    # Class-level fallback so instances that never ran this ``__init__`` (e.g.
    # unpickled from an older version of this class) still have a threshold.
    cpu_threshold = 2000

    def __init__(
        self, num_of_iter_in_ot=100, reg=10.0, method="sinkhorn", cpu_threshold=2000
    ):
        """
        Args:
            num_of_iter_in_ot: Passed to ``sinkhorn`` as ``maxIter``.
            reg: Regularisation value passed positionally to ``sinkhorn``.
            method: Solver name passed to ``sinkhorn`` as ``method``.
            cpu_threshold: When ``t_scores.shape[0]`` is at least this value, the
                Sinkhorn inputs are moved to the CPU and ``beta`` is moved back
                to the input device afterwards. Defaults to the previously
                hard-coded 2000.
        """
        super().__init__()
        self.num_of_iter_in_ot = num_of_iter_in_ot
        self.reg = reg
        self.method = method
        self.cpu_threshold = cpu_threshold

    @staticmethod
    def _normalize_by_max(cost):
        """Scale ``cost`` so its maximum is 1; an all-zero matrix is returned as-is."""
        max_val = cost.max()
        # A single element, or all-identical scores/points, gives an all-zero
        # matrix; dividing by its max (0) would turn the whole cost into NaN.
        return cost / torch.where(max_val > 0, max_val, torch.ones_like(max_val))

    @classmethod
    def _build_cost_map(cls, t_scores, s_scores, pts, cost_type, device):
        """Return the cost matrix selected by ``cost_type`` (rows: teacher, cols: student)."""
        score_cost = None
        if cost_type != "dist":
            # Squared difference between every teacher score and every student score.
            score_cost = cls._normalize_by_max(
                (t_scores.detach().unsqueeze(1) - s_scores.detach().unsqueeze(0)) ** 2
            )
            if cost_type == "score":
                return score_cost
        # Pairwise squared Euclidean distance between the points in ``pts``.
        coord_x = pts[:, 0]
        coord_y = pts[:, 1]
        dist_x = (coord_x.reshape(1, -1) - coord_x.reshape(-1, 1)) ** 2
        dist_y = (coord_y.reshape(1, -1) - coord_y.reshape(-1, 1)) ** 2
        dist_cost = cls._normalize_by_max((dist_x + dist_y).to(device))
        if cost_type == "dist":
            return dist_cost
        return dist_cost + score_cost

    def _solve_beta(self, target_prob, source_prob, cost_map, use_cpu):
        """Run ``sinkhorn`` and return ``log["beta"]``.

        When ``use_cpu`` is set, the inputs are moved to the CPU and ``beta`` is
        moved back to ``target_prob``'s device.
        """
        if not use_cpu:
            _, log = sinkhorn(
                target_prob,
                source_prob,
                cost_map,
                self.reg,
                maxIter=self.num_of_iter_in_ot,
                log=True,
                method=self.method,
            )
            return log["beta"]
        _, log = sinkhorn(
            target_prob.cpu(),
            source_prob.cpu(),
            cost_map.cpu(),
            self.reg,
            maxIter=self.num_of_iter_in_ot,
            log=True,
            method=self.method,
        )
        return log["beta"].to(target_prob.device)

    def forward(
        self, t_scores, s_scores, pts, cost_type="all", clamp_ot=False, aux_cost=None
    ):
        """Calculate the OT loss between the teacher and student distributions.

        The cost map is ``cost = dist(p_t, p_s) + dist(score_t, score_s)``; each
        term is a squared L2 distance rescaled so that its maximum is 1
        (``cost_type`` selects which terms are used). The marginals are the
        softmax of the teacher and student scores. Sinkhorn runs under
        ``torch.no_grad()``; the returned loss is ``<im_grad, s_scores>`` with
        ``im_grad`` detached, so the gradient w.r.t. ``s_scores`` is ``im_grad``.

        Args:
            t_scores: Teacher scores, Tensor with shape (N, ).
            s_scores: Student scores, Tensor with shape (N, ).
            pts: Coordinates; columns 0 and 1 are used as x and y. Presumably
                shape (N, 2), one point per score — TODO confirm with callers.
            cost_type: "all" (distance + score), "dist" (distance only) or
                "score" (score only).
            clamp_ot: If True, clamp the returned loss to be >= 0.
            aux_cost: Optional extra cost added to the cost map; must broadcast
                against it (presumably (N, N)).

        Returns:
            A scalar tensor. For empty input (N == 0) it is a zero scalar that
            is still connected to ``s_scores`` in the autograd graph.

        Raises:
            ValueError: If ``cost_type`` is not "all", "dist" or "score".
        """
        valid_cost_types = ("all", "dist", "score")
        if cost_type not in valid_cost_types:
            raise ValueError(
                f"cost_type must be one of {valid_cost_types}, got {cost_type!r}"
            )
        if t_scores.numel() == 0:
            # Nothing to transport; max() over an empty cost matrix would raise.
            return s_scores.sum() * 0.0

        with torch.no_grad():
            t_scores_prob = torch.softmax(t_scores, dim=0)
            s_scores_prob = torch.softmax(s_scores, dim=0)
            cost_map = self._build_cost_map(
                t_scores, s_scores, pts, cost_type, t_scores_prob.device
            )
            if aux_cost is not None:
                cost_map = cost_map + aux_cost
            source_prob = s_scores_prob.detach().view(-1)
            target_prob = t_scores_prob.detach().view(-1)
            use_cpu = t_scores.shape[0] >= self.cpu_threshold
            # Presumably shape (N,), matching source_prob.
            beta = self._solve_beta(target_prob, source_prob, cost_map, use_cpu)

        # Gradient of the OT loss w.r.t. the unnormalised student density:
        #   im_grad = beta / count - <beta, density> / count^2
        # NOTE(review): raw s_scores serve as the "density" here; if they can be
        # negative, count may be ~0 and only the 1e-8 keeps this finite.
        source_density = s_scores.detach().view(-1)
        source_count = source_density.sum()
        denom = source_count * source_count + 1e-8
        im_grad_1 = source_count / denom * beta
        im_grad_2 = (source_density * beta).sum() / denom  # scalar
        im_grad = (im_grad_1 - im_grad_2).detach()

        # loss = <im_grad, predicted density>; its gradient w.r.t. s_scores is im_grad.
        loss = torch.sum(s_scores * im_grad)
        if clamp_ot:
            loss = torch.clamp_min(loss, 0)
        return loss
