import torch
import numpy as np
from datasets import *
from torch_geometric.loader import DataLoader
from scipy.spatial.distance import cdist
#from geomloss import SamplesLoss
#import ot
# from pytorch3d.ops import knn_points
from scipy.spatial import cKDTree
import utils
from scipy.spatial.distance import directed_hausdorff

# Open questions:
# 1. Can the MMD computation be accelerated further?
# 2. pytorch3d.ops.knn_points is not working here; if it did, would it accelerate the MMD computation?
# 3. MMD sometimes comes out negative even when the samples are ro -- why? Possibly the difference
#    between the U-statistic and V-statistic estimators.

# faster version todo
# def stupid_mmd(x, y):
#     # x is list of point clouds
#     # y is list of point clouds
#     # pc size variable, len(x) and len(y) can be diff


# For QM9, use mask_x and mask_y to mask the invalid padding points
# For MD17, no mask is used
def compute_mmd(x, y, mask_x=None, mask_y=None, kernel_func=None):
    """Estimate the squared MMD between two sets of point clouds.

    This is the biased (V-statistic) estimator: the within-set averages run
    over all ordered pairs, including each cloud paired with itself, and are
    divided by n * n.

    Args:
        x: tensor indexed by sample on dim 0, e.g. [n_x, n_points, 3].
        y: tensor indexed by sample on dim 0, e.g. [n_y, n_points, 3].
        mask_x: optional [n_x, n_points] mask for the padding points of x.
        mask_y: optional [n_y, n_points] mask for the padding points of y.
        kernel_func: callable ``kernel_func(a, b)`` or
            ``kernel_func(a, b, mask_a, mask_b)`` returning one kernel value
            per (broadcast) pair of clouds. Required.

    Returns:
        Scalar tensor ``mean k(x, x') + mean k(y, y') - 2 * mean k(x, y)``.

    Raises:
        TypeError: if ``kernel_func`` is missing or not callable.

    Note:
        The cross term only receives masks when BOTH mask_x and mask_y are
        given; with a single mask the cross term is evaluated unmasked, while
        the within-set term of the masked set still uses its mask.
    """
    # Fail early with a readable message instead of "'NoneType' object is not
    # callable" in the middle of the computation.
    if not callable(kernel_func):
        raise TypeError(
            "compute_mmd requires a callable kernel_func(x, y[, mask_x, mask_y]); "
            f"got {kernel_func!r}"
        )

    def _within_set_mean(clouds, mask):
        # Mean kernel value over all ordered pairs (i, j) of a single set.
        count = len(clouds)

        # Diagonal pairs (i, i).
        if mask is not None:
            diag_total = kernel_func(clouds, clouds, mask, mask).sum()
        else:
            diag_total = kernel_func(clouds, clouds).sum()

        # A single-sample set has no off-diagonal pairs; skip the call so the
        # kernel is never handed an empty batch.
        if count < 2:
            return diag_total / (count * count)

        # Strict upper triangle (i < j); the kernel is treated as symmetric,
        # so each of these pairs is counted twice.
        rows, cols = torch.triu_indices(count, count, offset=1)
        if mask is not None:
            upper = kernel_func(clouds[rows], clouds[cols], mask[rows], mask[cols])
        else:
            upper = kernel_func(clouds[rows], clouds[cols])

        return (upper.sum() * 2 + diag_total) / (count * count)

    xx_mean = _within_set_mean(x, mask_x)
    yy_mean = _within_set_mean(y, mask_y)

    # Cross term: broadcast x along dim 1 and y along dim 0 so the kernel
    # evaluates every (x_i, y_j) pair -> [n_x, n_y].
    if mask_x is not None and mask_y is not None:
        cross = kernel_func(x[:, None], y[None, :], mask_x[:, None], mask_y[None, :])
    else:
        cross = kernel_func(x[:, None], y[None, :])
    xy_mean = cross.mean()

    return xx_mean + yy_mean - 2 * xy_mean

def chamfer_kernel(x, y, mask_x=None, mask_y=None, sigma=1.0):
    """Kernel exp(-d_chamfer(x, y) / (2 * sigma^2)) between point clouds.

    The Chamfer distance used here is the average of the two directed mean
    nearest-neighbour (Euclidean, not squared) distances.

    Args:
        x: shape [..., n_points_x, 3]
        y: shape [..., n_points_y, 3]; leading dims must broadcast with x.
        mask_x: optional [..., n_points_x] mask, non-zero for real points and
            zero for padding. ``None`` means every point of x is real.
        mask_y: optional [..., n_points_y] mask, same convention, for y.
        sigma: kernel bandwidth.

    Returns:
        Tensor of kernel values with the broadcast leading shape of x and y.

    Note:
        Padding points are excluded both as query points and as
        nearest-neighbour candidates. A cloud whose mask has no valid point
        yields a non-finite distance.
    """
    # One pairwise matrix serves both directions: rows index x, columns y.
    pairwise = torch.cdist(x, y)  # [..., n_points_x, n_points_y]
    inf = float("inf")

    valid_x = None if mask_x is None else mask_x.to(torch.bool)
    valid_y = None if mask_y is None else mask_y.to(torch.bool)

    # Padding must not be picked as somebody's nearest neighbour, so push the
    # padded candidates to +inf before taking the min.
    to_y = pairwise if valid_y is None else pairwise.masked_fill(~valid_y.unsqueeze(-2), inf)
    to_x = pairwise if valid_x is None else pairwise.masked_fill(~valid_x.unsqueeze(-1), inf)

    nearest_for_x = to_y.min(dim=-1).values  # [..., n_points_x]
    nearest_for_y = to_x.min(dim=-2).values  # [..., n_points_y]

    def _mean_over_valid(values, valid):
        # Average over real points only; padded query points contribute 0.
        if valid is None:
            return values.mean(dim=-1)
        kept = torch.where(valid, values, torch.zeros_like(values))
        return kept.sum(dim=-1) / valid.sum(dim=-1)

    chamfer_dist = (
        _mean_over_valid(nearest_for_x, valid_x)
        + _mean_over_valid(nearest_for_y, valid_y)
    ) / 2.0

    return torch.exp(-chamfer_dist / (2 * sigma * sigma))

def hausdorff_kernel(x, y, mask_x=None, mask_y=None, sigma=1.0):
    """Kernel exp(-d_hausdorff(x, y) / (2 * sigma^2)) between point clouds.

    The Hausdorff distance is the larger of the two directed distances, each
    being the maximum over points of the nearest-neighbour Euclidean distance
    to the other cloud.

    Args:
        x: shape [..., n_points_x, 3]
        y: shape [..., n_points_y, 3]; leading dims must broadcast with x.
        mask_x: optional [..., n_points_x] mask, non-zero for real points and
            zero for padding. ``None`` means every point of x is real.
        mask_y: optional [..., n_points_y] mask, same convention, for y.
        sigma: kernel bandwidth.

    Returns:
        Tensor of kernel values with the broadcast leading shape of x and y.

    Note:
        Padding points are excluded both as query points and as
        nearest-neighbour candidates.
    """
    # One pairwise matrix serves both directions: rows index x, columns y.
    pairwise = torch.cdist(x, y)  # [..., n_points_x, n_points_y]
    inf = float("inf")

    valid_x = None if mask_x is None else mask_x.to(torch.bool)
    valid_y = None if mask_y is None else mask_y.to(torch.bool)

    # Padded candidates are pushed to +inf so the min can never select them.
    to_y = pairwise if valid_y is None else pairwise.masked_fill(~valid_y.unsqueeze(-2), inf)
    to_x = pairwise if valid_x is None else pairwise.masked_fill(~valid_x.unsqueeze(-1), inf)

    nearest_for_x = to_y.min(dim=-1).values  # [..., n_points_x]
    nearest_for_y = to_x.min(dim=-2).values  # [..., n_points_y]

    # Padded query points are set to 0; distances are non-negative, so a 0
    # never wins the max unless every real point already sits at distance 0.
    if valid_x is not None:
        nearest_for_x = torch.where(valid_x, nearest_for_x, torch.zeros_like(nearest_for_x))
    if valid_y is not None:
        nearest_for_y = torch.where(valid_y, nearest_for_y, torch.zeros_like(nearest_for_y))

    worst_x_to_y = nearest_for_x.max(dim=-1).values
    worst_y_to_x = nearest_for_y.max(dim=-1).values
    hausdorff_dist = torch.max(worst_x_to_y, worst_y_to_x)

    return torch.exp(-hausdorff_dist / (2 * sigma * sigma))

def stupid_kernel(x, y, mask_x=None, mask_y=None, sigma=10000):
    """Weak baseline kernel built from crude per-cloud summary statistics.

    Each cloud is reduced to a 12-dim embedding (3 mean coordinates plus a
    flattened 3 x 3 matrix) and the kernel is
    ``exp(-||emb_x - emb_y||^2 / sigma)``. It exists to show how much the
    choice of kernel matters.

    Args:
        x: shape [..., n_points, 3]
        y: shape [..., n_points, 3]; leading dims must broadcast with x.
        mask_x: optional [..., n_points] mask, non-zero for real points and
            zero for padding. ``None`` means every point of x is real.
        mask_y: optional [..., n_points] mask, same convention, for y.
        sigma: divisor of the squared embedding distance.

    Returns:
        Tensor of kernel values with the broadcast leading shape of x and y.
    """

    def _embed(points, mask):
        # Summary embedding [..., 12] of one batch of clouds.
        if mask is None:
            mean = torch.mean(points, dim=-2)  # [..., 3]
            count = points.shape[-2]
        else:
            # Zero the padded rows explicitly so the sums stay correct even
            # when the padding coordinates are not exactly zero.
            points = points * mask.unsqueeze(-1)
            valid = torch.sum(mask, dim=-1).unsqueeze(-1)  # [..., 1]
            mean = torch.sum(points, dim=-2) / valid  # [..., 3]
            count = valid.unsqueeze(-1)  # [..., 1, 1]

        # NOTE(review): n and m are contracted independently, so this is
        # outer(sum_n p, sum_m p) / count, i.e. a function of the mean only --
        # not a second-moment/covariance matrix (that would be
        # '...nd,...ne->...de'). Kept as-is because switching would change
        # the kernel's values; confirm the intent.
        second = torch.einsum('...nd,...me->...de', points, points) / count  # [..., 3, 3]

        return torch.cat((mean, second.flatten(start_dim=-2)), dim=-1)

    embedding_x = _embed(x, mask_x)
    embedding_y = _embed(y, mask_y)

    # Squared Euclidean distance between the two embeddings.
    sq_dist = (embedding_x - embedding_y).pow(2).sum(dim=-1)
    return torch.exp(-1 * sq_dist / sigma)

# def stupid_kernel_buggy(A, B, mask_x=None, mask_y=None, sigma=1000):
#     # weak baseline perhaps / exhibiting importance of choosing kernel carefully

#     # A is batch x N x 3 point cloud
#     # B is batch x N x 3 point cloud

#     if len(A.shape) > 3:
#         A = A.squeeze()
#         B = B.squeeze() # hacky fix

#     batchA = A.shape[0]
#     batchB = B.shape[0]

#     num_A = torch.ones(batchA) * A.shape[1]
#     num_B = torch.ones(batchB) * B.shape[1]

#     if mask_x is not None:
#         # full_mask_x = mask_x.unsqueeze(-1).repeat(1,1,3)
#         # full_mask_y = mask_y.unsqueeze(-1).repeat(1,1,3)
#         # A[full_mask_x] = 0 # this is wrong !!
#         # B[full_mask_y] = 0
#         try:
#             num_A = torch.sum(mask_x.float(), -1)
#             num_B = torch.sum(mask_y.float(), -1)
#         except Exception as e:
#             print(e)
#             breakpoint()
#     meanA, meanB = torch.mean(A, 1), torch.mean(B, 1) # batch x 3 for each

#     covA = torch.einsum('bnd,bme->bde', A, A) / num_A.view(-1, 1, 1) # batch x 3 x 3
#     covB = torch.einsum('bnd,bme->bde', B, B) / num_B.view(-1, 1, 1) # batch x 3 x 3

#     embeddingA = torch.cat((meanA, covA.view(batchA, -1)), 1)
#     embeddingB = torch.cat((meanB, covB.view(batchB, -1)), 1)

#     # return torch.sum(embeddingA * embeddingB)
#     dist = torch.sqrt(torch.sum(torch.abs(embeddingA - embeddingB)**2, -1))
#     kernel_val = torch.exp(-1 * (dist**2) / sigma)
#     return kernel_val

# def stupid_kernel_unpar(A,B):
#     # weak baseline perhaps / exhibiting importance of choosing kernel carefully

#     if mask_x is not None and mask_y is not None:
#         # print(x.shape, y.shape, mask_x.shape, mask_y.shape)
#         mean_x = torch.sum(x, dim=-2) / torch.sum(mask_x, dim=-1).unsqueeze(-1)
#         mean_y = torch.sum(y, dim=-2) / torch.sum(mask_y, dim=-1).unsqueeze(-1)
#         if len(x.shape) == 3:
#             cov_x = torch.einsum('bnd,bme->bde', x, x) / torch.sum(mask_x, dim=-1).unsqueeze(-1).unsqueeze(-1)
#             cov_y = torch.einsum('bnd,bme->bde', y, y) / torch.sum(mask_y, dim=-1).unsqueeze(-1).unsqueeze(-1)
#         else:
#             cov_x = torch.einsum('ijnd,ijme->ijde', x, x) / torch.sum(mask_x, dim=-1).unsqueeze(-1).unsqueeze(-1)
#             cov_y = torch.einsum('ijnd,ijme->ijde', y, y) / torch.sum(mask_y, dim=-1).unsqueeze(-1).unsqueeze(-1)
#     else:
#         mean_x, mean_y = torch.mean(x, dim=-2), torch.mean(y, dim=-2) # B x 3
#         if len(x.shape) == 3:
#             cov_x = torch.einsum('bnd,bme->bde', x, x) / x.shape[1] # B x 3 x 3
#             cov_y = torch.einsum('bnd,bme->bde', y, y) / y.shape[1] # B x 3 x 3
#         else:
#             cov_x = torch.einsum('ijnd,ijme->ijde', x, x) / x.shape[2] 
#             cov_y = torch.einsum('ijnd,ijme->ijde', y, y) / y.shape[2]
#     # print(mean_x.shape, cov_x.shape)
#     # print(mean_y.shape, cov_y.shape)

#     if len(x.shape) == 3:
#         embedding_x = torch.cat((mean_x, cov_x.view(x.shape[0], -1)), dim=-1) # B x 12
#         embedding_y = torch.cat((mean_y, cov_y.view(y.shape[0], -1)), dim=-1) # B x 12
#     else:
#         embedding_x = torch.cat((mean_x, cov_x.view(x.shape[0],x.shape[1], -1)), dim=-1) # B x 1 x 12
#         embedding_y = torch.cat((mean_y, cov_y.view(y.shape[0],y.shape[1], -1)), dim=-1) # 1 x B x 12

#     # return torch.sum(embeddingA * embeddingB)
#     dist = torch.norm(embedding_x - embedding_y, dim=-1)
#     #print(dist.shape)
#     kernel_val = torch.exp(-1 * (dist**2) / 1000) # B
#     return kernel_val


# def stupid_kernel(A,B):
#     # weak baseline perhaps / exhibiting importance of choosing kernel carefully

#     # A is N x 3 point cloud
#     # B is M x 3 point cloud

#     meanA, meanB = torch.mean(A, 0), torch.mean(B, 0)
#     covA = torch.einsum('nd,me->de', A, A) / len(A)
#     covB = torch.einsum('nd,me->de', B, B) / len(B)

#     embeddingA = torch.cat((meanA, covA.view(-1)))
#     embeddingB = torch.cat((meanB, covB.view(-1)))

#     # return torch.sum(embeddingA * embeddingB)
#     dist = torch.norm(embeddingA - embeddingB)
#     kernel_val = torch.exp(-1 * (dist**2) / 1000)
#     return kernel_val

#     # arbitrary distance between dumb descriptors!
#     mn = torch.norm(meanA - meanB)
#     cv = torch.norm(covA-covB, p='fro')

#     # print(f'mn {mn} cv {cv}')
#     dist = torch.exp((mn**2 + cv**2) / 10000)

#     if dist < 0:
#         breakpoint()

#     return dist


# def hausdorff_distance(A, B):
#     """
#     Compute the bidirectional Hausdorff distance between two point clouds A and B.
    
#     Parameters:
#         A (ndarray): First point cloud, shape (N, D).
#         B (ndarray): Second point cloud, shape (M, D).

#     Returns:
#         float: The Hausdorff distance.
#     """
#     d_AB = directed_hausdorff(A, B)[0]  # A → B
#     d_BA = directed_hausdorff(B, A)[0]  # B → A
#     if d_AB <0 or d_BA < 0:
#         breakpoint()
#     return max(d_AB, d_BA)

# def chamfer_kernel(x, y, sigma=1.0):
#     """
#     Args:
#         x: shape [n_points, 3] 
#         y: shape [n_points, 3] 
#     """
#     def chamfer_distance(p1, p2):
#         p1_np = p1.detach().cpu().numpy()
#         p2_np = p2.detach().cpu().numpy()
        
#         tree1 = cKDTree(p1_np)
#         tree2 = cKDTree(p2_np)
        
#         dist1, _ = tree1.query(p2_np)
#         dist2, _ = tree2.query(p1_np)
        
#         mean_dist = (np.mean(dist1) + np.mean(dist2)) / 2.0
        
#         return torch.tensor(mean_dist, device=p1.device)
    
#     dist = chamfer_distance(x, y)
#     return torch.exp(-dist / (2 * sigma * sigma))

# def hausdorff_kernel(x, y, sigma=1.0):
#     """
#     Args:
#         x: shape [n_points, 3] 
#         y: shape [n_points, 3] 
#     """
#     def hausdorff_distance(p1, p2):
#         p1_np = p1.detach().cpu().numpy()
#         p2_np = p2.detach().cpu().numpy()
        
#         tree1 = cKDTree(p1_np)
#         tree2 = cKDTree(p2_np)
        
#         dist1, _ = tree1.query(p2_np)
#         dist2, _ = tree2.query(p1_np)
        
#         h1 = np.max(dist1)
#         h2 = np.max(dist2)
#         max_dist = max(h1, h2)
        
#         return torch.tensor(max_dist, device=p1.device)
    
#     dist = hausdorff_distance(x, y)
#     return torch.exp(-dist / (2 * sigma * sigma))

# #would take a long time to compute
# def emd_kernel(x, y, sigma=1.0):
#     def earth_mover_distance(p1, p2):
#         loss = SamplesLoss(loss="sinkhorn", p=2, blur=0.01)
#         return loss(p1, p2)
    
#     dist = earth_mover_distance(x, y)
#     return torch.exp(-dist / (2 * sigma * sigma))

# #would take a long time to compute
# def gromov_wasserstein_kernel(x, y, sigma=1.0):

#     def gromov_wasserstein_distance(p1, p2):
#         C1 = torch.cdist(p1, p1)
#         C2 = torch.cdist(p2, p2)
        
    
#         p = torch.ones(len(p1)) / len(p1)
#         q = torch.ones(len(p2)) / len(p2)
        
#         gw_dist = ot.gromov.gromov_wasserstein2(
#             C1.numpy(), C2.numpy(), 
#             p.numpy(), q.numpy(),
#             'square_loss', verbose=False
#         )
#         return torch.tensor(gw_dist)
    
    # dist = gromov_wasserstein_distance(x, y)
    # return torch.exp(-dist / (2 * sigma * sigma))
