import torch
import torch.nn as nn
import os
from omegaconf import OmegaConf
from hydra.utils import instantiate
import sys
sys.path.insert(0, os.path.join(os.path.dirname(os.path.abspath(__file__)), '..', '..', 'src'))

class ApproxEMDCriterion(nn.Module):
    """Criterion that scores two point sets with a pretrained network.

    The network is rebuilt from the Hydra config of the training run that
    produced ``ckpt_path`` and its weights are restored from that checkpoint.
    """

    def __init__(self, ckpt_path):
        """
        :param ckpt_path: path to a checkpoint file whose run config lives at
            ``<ckpt_path>/../../.hydra/config.yaml`` (i.e. two directories up).
        """
        super().__init__()

        model_config_path = os.path.abspath(os.path.join(ckpt_path, '../../.hydra/config.yaml'))
        model_config = OmegaConf.load(model_config_path).model
        # Instantiate the class named by `_target_model_` instead of the config's `_target_`.
        model_config._target_ = model_config._target_model_
        self.model = instantiate(model_config)

        # map_location='cpu' lets checkpoints saved on GPU load on CPU-only hosts;
        # load_state_dict copies the values into self.model's own parameters anyway.
        state_dict = torch.load(ckpt_path, map_location='cpu')['state_dict']
        # Drop the first 6 characters of every key — presumably a wrapper prefix
        # such as 'model.' added by the training module; verify against the checkpoint.
        state_dict = {key[6:]: val for key, val in state_dict.items()}
        self.model.load_state_dict(state_dict)

    def forward(self, output_set, output_mask, target_set, target_mask, reduced=True):
        """
        Score ``output_set`` against ``target_set`` with the pretrained model.

        :param output_set: predicted points, Tensor([B, N, 2]) or Tensor([N, 2])
        :param output_mask: unused; accepted for interface parity with the other criteria
        :param target_set: target points, Tensor([B, M, 2]) or Tensor([M, 2])
        :param target_mask: unused; accepted for interface parity with the other criteria
        :param reduced: if True return the batch mean, otherwise per-sample values
        :return: scalar tensor when ``reduced``; otherwise the model's per-sample
            output with a trailing singleton dimension squeezed away
        """
        # Promote un-batched inputs to a batch of one.
        if output_set.dim() == 2:
            output_set = output_set.unsqueeze(0)
        if target_set.dim() == 2:
            target_set = target_set.unsqueeze(0)

        out = self.model(output_set, target_set)

        # The model may return extra values alongside the score; keep only the first.
        l2_loss = out[0] if isinstance(out, tuple) else out

        # Presumably the model emits [B, 1]; flatten to [B] — TODO confirm with the model.
        if l2_loss.dim() == 2:
            l2_loss = l2_loss.squeeze(1)

        return l2_loss.mean() if reduced else l2_loss


class ChamferCriterion(nn.Module):
    """Reconstruction criterion backed by :func:`chamfer_loss`."""

    @staticmethod
    def forward(output_set, output_mask, target_set, target_mask, reduced=True):
        """
        Chamfer distance between a predicted and a target point set.

        :param output_set: predicted points, Tensor([B, N, 2])
        :param output_mask: padding mask for ``output_set``, Tensor([B, N])
        :param target_set: target points, Tensor([B, M, 2])
        :param target_mask: padding mask for ``target_set``, Tensor([B, M])
        :param reduced: if True average over the batch, otherwise keep per-sample values
        :return: scalar tensor when ``reduced``, else Tensor([B,])
        """
        per_sample = chamfer_loss(output_set, output_mask, target_set, target_mask)
        return per_sample.mean() if reduced else per_sample


class EMDCriterion(nn.Module):
    """Reconstruction criterion backed by :func:`emd_loss` (optimal one-to-one matching)."""

    def __init__(self):
        super().__init__()

    @staticmethod
    def forward(output_set, output_mask, target_set, target_mask, reduced=True):
        """
        Earth mover's distance between a predicted and a target point set.

        :param output_set: predicted points, Tensor([B, N, 2])
        :param output_mask: padding mask for ``output_set``, Tensor([B, N]); forwarded to emd_loss
        :param target_set: target points, Tensor([B, M, 2])
        :param target_mask: padding mask for ``target_set``, Tensor([B, M]); forwarded to emd_loss
        :param reduced: if True average over the batch, otherwise keep per-sample values
        :return: scalar tensor when ``reduced``, else Tensor([B,])
        """
        # emd_loss sums (unsquared) Euclidean distances of the optimal matching.
        l1_loss = emd_loss(output_set, output_mask, target_set, target_mask)  # [B,]

        if reduced:
            return l1_loss.mean()
        return l1_loss


class CombinedCriterion(nn.Module):
    """Reconstruction criterion: per-sample Chamfer distance plus EMD."""

    @staticmethod
    def forward(output_set, output_mask, target_set, target_mask, reduced=True):
        """
        Sum of :func:`chamfer_loss` and :func:`emd_loss` for each sample.

        :param output_set: predicted points, Tensor([B, N, 2])
        :param output_mask: padding mask for ``output_set``, Tensor([B, N])
        :param target_set: target points, Tensor([B, M, 2])
        :param target_mask: padding mask for ``target_set``, Tensor([B, M])
        :param reduced: if True average over the batch, otherwise keep per-sample values
        :return: scalar tensor when ``reduced``, else Tensor([B,])
        """
        args = (output_set, output_mask, target_set, target_mask)
        combined = chamfer_loss(*args) + emd_loss(*args)
        if not reduced:
            return combined
        return combined.mean()


# function to compute chamfer loss between two point sets
def chamfer_loss(x, x_mask, y, y_mask):
    """
    Chamfer distance (squared Euclidean) between two, possibly padded, point sets.

    For every valid point of ``x`` the squared distance to its nearest valid point
    of ``y`` is summed, and vice versa. Padded points are never chosen as a
    nearest neighbour and contribute nothing as queries. N and M may differ.

    :param x: Tensor([B, N, D]) or Tensor([N, D])
    :param x_mask: Tensor([B, N]) (or [N]); True marks padding. None means all valid.
    :param y: Tensor([B, M, D]) or Tensor([M, D])
    :param y_mask: Tensor([B, M]) (or [M]); True marks padding. None means all valid.
    :return: Tensor([B,])
    """
    if x.dim() == 2:
        x = x.unsqueeze(0)
    if y.dim() == 2:
        y = y.unsqueeze(0)

    if x_mask is None:
        x_mask = torch.zeros_like(x[:, :, 0]).bool()
    if y_mask is None:
        y_mask = torch.zeros_like(y[:, :, 0]).bool()
    x_mask = x_mask.bool()
    y_mask = y_mask.bool()
    if x_mask.dim() == 1:
        x_mask = x_mask.unsqueeze(0)
    if y_mask.dim() == 1:
        y_mask = y_mask.unsqueeze(0)

    # Pairwise squared distances.
    dist = torch.sum((x.unsqueeze(2) - y.unsqueeze(1)) ** 2, -1)  # [B, N, M]

    inf = float('inf')
    # Padded candidates get +inf so they can never win the min. Multiplying them
    # by 0 would make them the minimum and collapse every distance to 0.
    x_to_y = dist.masked_fill(y_mask.unsqueeze(1), inf).min(dim=2).values  # [B, N]
    y_to_x = dist.masked_fill(x_mask.unsqueeze(2), inf).min(dim=1).values  # [B, M]

    # Padded query points contribute nothing. If the other set has no valid
    # points, there is no neighbour at all, so contribute 0 instead of inf.
    x_to_y = x_to_y.masked_fill(x_mask | y_mask.all(dim=1, keepdim=True), 0.0)
    y_to_x = y_to_x.masked_fill(y_mask | x_mask.all(dim=1, keepdim=True), 0.0)

    # Sum each direction separately: [B, N] and [B, M] need not have equal sizes.
    return x_to_y.sum(dim=1) + y_to_x.sum(dim=1)

# scipy function to compute optimal assignment
from scipy.optimize import linear_sum_assignment
def emd_loss(pc_source, pc_source_mask=None, pc_target=None, pc_target_mask=None):
    '''
    Earth mover's distance via an exact optimal one-to-one assignment.

    For each batch item, the valid source and target points are matched with
    scipy's Hungarian solver on the Euclidean (p=2) distance matrix. The cost is
    the sum of matched distances. If N != M, only min(N, M) pairs are matched.
    The assignment itself is not differentiated; gradients flow only through the
    selected distances.

    :param pc_source: Tensor([B, N, D]) or Tensor([N, D])
    :param pc_source_mask: optional Tensor([B, N]) (or [N]); True marks padding
        (same convention as chamfer_loss). Padded points are excluded from matching.
    :param pc_target: Tensor([B, M, D]) or Tensor([M, D]); required despite the
        None default
    :param pc_target_mask: optional Tensor([B, M]) (or [M]); True marks padding
    :return: Tensor([B,]); an item with no valid source or target points costs 0
    '''
    if pc_source.dim() == 2:
        pc_source = pc_source[None]
    if pc_target.dim() == 2:
        pc_target = pc_target[None]
    if pc_source_mask is not None and pc_source_mask.dim() == 1:
        pc_source_mask = pc_source_mask[None]
    if pc_target_mask is not None and pc_target_mask.dim() == 1:
        pc_target_mask = pc_target_mask[None]

    # Pairwise Euclidean distances.
    dists = torch.cdist(pc_source, pc_target)  # [B, N, M]

    costs = []
    for b, dist in enumerate(dists):
        # Keep only valid rows/columns so padding never takes part in the matching.
        if pc_source_mask is not None:
            dist = dist[~pc_source_mask[b].bool()]
        if pc_target_mask is not None:
            dist = dist[:, ~pc_target_mask[b].bool()]

        if dist.numel() == 0:
            # Nothing to match. dist.sum() is a zero that keeps dtype, device and graph.
            costs.append(dist.sum())
            continue

        # The solver runs on a detached CPU copy; index back into the live tensor for grads.
        rows, cols = linear_sum_assignment(dist.detach().cpu().numpy())
        costs.append(dist[rows, cols].sum())

    return torch.stack(costs)