import torch
import numpy as np
from datasets import *
from torch_geometric.loader import DataLoader
from scipy.spatial.distance import cdist
from geomloss import SamplesLoss
import ot
# from pytorch3d.ops import knn_points
from scipy.spatial import cKDTree
import utils
import matplotlib.pyplot as plt
import numpy as np
import time

# Open questions:
# 1. Can the MMD computation be accelerated further?
# 2. pytorch3d.ops.knn_points is not working here; would it accelerate the MMD computation?
# 3. The MMD is sometimes negative even when the samples are "ro" (presumably "rotated"). Why?
#    Possibly the difference between the U-statistic and V-statistic estimators.
def compute_mmd(x, y, kernel_func):
    """Estimate the MMD between two sample sets under a pairwise kernel.

    Every within-set pair (diagonal included) is averaged for each set, and
    every cross pair is averaged between the two sets.

    Args:
        x: sequence of equally shaped tensors; stacked along a new leading dim.
        y: sequence of equally shaped tensors; stacked along a new leading dim.
        kernel_func: callable taking two batched tensors and returning one
            kernel value per leading-batch entry. It is also called with
            inputs indexed as ``x[:, None]`` and ``y[None, :]``, so it has to
            broadcast over the leading dims.

    Returns:
        mean k(x, x') + mean k(y, y') - 2 * mean k(x, y).
    """
    def mean_within(samples, count):
        # Only the strict upper triangle is evaluated and then counted twice
        # (this relies on the kernel being symmetric); the diagonal is added
        # separately so the mean runs over all count * count pairs.
        rows, cols = torch.triu_indices(count, count, offset=1).to(torch.long)
        upper = kernel_func(samples[rows], samples[cols])  # [count*(count-1)/2]
        diagonal = kernel_func(samples, samples)  # [count]
        return (upper.sum() * 2 + diagonal.sum()) / (count * count)

    size_x, size_y = len(x), len(y)
    batch_x = torch.stack(x, dim=0)
    batch_y = torch.stack(y, dim=0)

    within_x = mean_within(batch_x, size_x)
    within_y = mean_within(batch_y, size_y)

    # Cross term: all pairs at once through broadcasting -> [n_x, n_y].
    cross = kernel_func(batch_x[:, None], batch_y[None, :]).mean()

    return within_x + within_y - 2 * cross

def chamfer_kernel(x, y, sigma=1.0):
    """Exponential kernel on the symmetric Chamfer distance between point clouds.

    Args:
        x: shape [batch_size, n_points, 3]
        y: shape [batch_size, n_points, 3]
        sigma: bandwidth; the Chamfer distance is divided by 2 * sigma**2.

    Returns:
        Tensor of shape [batch_size] holding exp(-chamfer / (2 * sigma**2)),
        where chamfer is the average of the two directed mean
        nearest-neighbour distances.
    """
    # A single pairwise matrix serves both directions: cdist(y, x) is its
    # transpose, so the y -> x nearest neighbours are a min over dim -2.
    pairwise = torch.cdist(x, y)  # [batch, n_points_x, n_points_y]
    x_to_y = pairwise.min(dim=-1).values  # [batch, n_points_x]
    y_to_x = pairwise.min(dim=-2).values  # [batch, n_points_y]

    chamfer_dist = (x_to_y.mean(dim=-1) + y_to_x.mean(dim=-1)) / 2.0  # [batch]

    return torch.exp(-chamfer_dist / (2 * sigma * sigma))

def hausdorff_kernel(x, y, sigma=1.0):
    """Exponential kernel on the symmetric Hausdorff distance between point clouds.

    Args:
        x: shape [batch_size, n_points, 3]
        y: shape [batch_size, n_points, 3]
        sigma: bandwidth; the Hausdorff distance is divided by 2 * sigma**2.

    Returns:
        Tensor of shape [batch_size] holding exp(-hausdorff / (2 * sigma**2)),
        where hausdorff is the larger of the two directed worst-case
        nearest-neighbour distances.
    """
    # A single pairwise matrix serves both directions: cdist(y, x) is its
    # transpose, so the y -> x nearest neighbours are a min over dim -2.
    pairwise = torch.cdist(x, y)  # [batch, n_points_x, n_points_y]
    x_to_y = pairwise.min(dim=-1).values  # [batch, n_points_x]
    y_to_x = pairwise.min(dim=-2).values  # [batch, n_points_y]

    hausdorff_dist = torch.max(
        x_to_y.max(dim=-1).values,
        y_to_x.max(dim=-1).values,
    )  # [batch]

    return torch.exp(-hausdorff_dist / (2 * sigma * sigma))


# NOTE: this kernel takes a long time to compute.
def emd_kernel(x, y, sigma=1.0):
    """Exponential kernel on a Sinkhorn loss between two point clouds.

    The value standing in for the earth mover's distance is geomloss's
    ``SamplesLoss(loss="sinkhorn", p=2, blur=0.01)`` evaluated on (x, y).

    Args:
        x: point cloud tensor handed to SamplesLoss as the first sample set
            (presumably [batch_size, n_points, 3] like the other kernels --
            confirm against geomloss's accepted shapes).
        y: point cloud tensor handed to SamplesLoss as the second sample set.
        sigma: bandwidth; the loss is divided by 2 * sigma**2.

    Returns:
        exp(-loss / (2 * sigma**2)).
    """
    sinkhorn = SamplesLoss(loss="sinkhorn", p=2, blur=0.01)
    transport_cost = sinkhorn(x, y)
    return torch.exp(-transport_cost / (2 * sigma * sigma))

# NOTE: this kernel takes a long time to compute.
def gromov_wasserstein_kernel(x, y, sigma=1.0):
    """Exponential kernel on the Gromov-Wasserstein discrepancy of two point clouds.

    Each cloud is summarised by its intra-cloud Euclidean distance matrix with
    uniform point weights, and POT's ``ot.gromov.gromov_wasserstein2`` is run
    with the square loss.

    Args:
        x: a single point cloud; ``len(x)`` is used as its number of points, so
            this looks like it expects an unbatched [n_points, dim] tensor --
            TODO confirm before using it with batched callers.
        y: a single point cloud, same convention as ``x``.
        sigma: bandwidth; the discrepancy is divided by 2 * sigma**2.

    Returns:
        exp(-gw / (2 * sigma**2)) as a tensor on the device of ``x``.
        The solver runs in NumPy, so no gradient flows back to the inputs.
    """

    def gromov_wasserstein_distance(p1, p2):
        # Intra-cloud pairwise distance matrices.
        C1 = torch.cdist(p1, p1)
        C2 = torch.cdist(p2, p2)

        # Uniform weight on every point of each cloud.
        p = torch.ones(len(p1)) / len(p1)
        q = torch.ones(len(p2)) / len(p2)

        # .numpy() alone raises for CUDA tensors and for tensors that require
        # grad, so detach and move to the CPU before handing data to POT.
        gw_dist = ot.gromov.gromov_wasserstein2(
            C1.detach().cpu().numpy(), C2.detach().cpu().numpy(),
            p.numpy(), q.numpy(),
            'square_loss', verbose=False
        )
        return torch.tensor(gw_dist, device=p1.device)

    dist = gromov_wasserstein_distance(x, y)
    return torch.exp(-dist / (2 * sigma * sigma))




results = {}
# Kernels to evaluate, keyed by the name used in the printed summary and the plot.
kernels = {
    'chamfer': chamfer_kernel,
    'hausdorff': hausdorff_kernel,
    # 'emd': emd_kernel
    # 'gromov_wasserstein': gromov_wasserstein_kernel
}
# Legend label and color per kernel. Kernels without an entry fall back to a
# generic label and matplotlib's color cycle, so enabling another kernel above
# is enough to get it plotted.
plot_styles = {
    'chamfer': ('Chamfer kernel', 'blue'),
    'hausdorff': ('Hausdorff kernel', 'orange'),
}

Sample_batches = 100  # number of MMD estimates per class-balance value
sample_sizes = 1000   # point clouds drawn from each dataset per estimate
transformed_class_balance_list = [0.0, 0.125, 0.25, 0.375, 0.5, 0.625, 0.75, 0.875, 1.0]

MD17dir = 'data/md17/revised aspirin'
name = 'revised aspirin'
base_dataset = MD17(root=MD17dir, name=name)
print(name, len(base_dataset))

# One entry per value in transformed_class_balance_list, for every kernel.
average_mmd_by_kernel = {kernel_name: [] for kernel_name in kernels}
var_mmd_by_kernel = {kernel_name: [] for kernel_name in kernels}


for transformed_class_balance in transformed_class_balance_list:
    fixed_quaternions = utils.generate_many_random_quaternions(len(base_dataset))
    # NOTE: the lambda reads fixed_quaternions when it is called; both datasets
    # are consumed inside this iteration, before the name is rebound.
    rotation_operator = (lambda data, idx: utils.memoized_rotation(data, idx, fixed_quaternions))
    mmd_results = {kernel_name: [] for kernel_name in kernels}

    # dataset_0 pairs the rotation with the swept value, dataset_1 always with
    # 1.0 (presumably the fraction of samples that get rotated -- confirm
    # against MetaAugmentedDataset).
    transform_list_0 = [(rotation_operator, transformed_class_balance)]
    dataset_0 = MetaAugmentedDataset(base_dataset, transform_list_0, label_operator=None)

    transform_list_1 = [(rotation_operator, 1.0)]
    dataset_1 = MetaAugmentedDataset(base_dataset, transform_list_1, label_operator=None)

    loader_0 = DataLoader(dataset_0, batch_size=1, shuffle=False)
    loader_1 = DataLoader(dataset_1, batch_size=1, shuffle=False)
    iter_0 = iter(loader_0)
    iter_1 = iter(loader_1)

    for _ in range(Sample_batches):
        point_clouds_1 = []
        point_clouds_2 = []

        # Draw up to sample_sizes point clouds from each loader, in lockstep,
        # continuing where the previous batch stopped.
        for _ in range(sample_sizes):
            try:
                batch_0 = next(iter_0)
                batch_1 = next(iter_1)
            except StopIteration:
                break
            point_clouds_1.append(batch_0[0].pos)
            point_clouds_2.append(batch_1[0].pos)

        # Both lists grow together, so an empty one means the loaders are
        # exhausted and every later batch would be empty as well.
        if not point_clouds_1:
            break

        for kernel_name, kernel_func in kernels.items():
            mmd = compute_mmd(point_clouds_1, point_clouds_2, kernel_func)
            # Store plain floats so the statistics and the plot below do not
            # depend on tensor device or on tensor-to-numpy conversion.
            mmd_results[kernel_name].append(float(mmd))

    for kernel_name, mmd_list in mmd_results.items():
        if mmd_list:
            average_mmd = sum(mmd_list) / len(mmd_list)
            var_mmd = torch.var(torch.tensor(mmd_list)).item()
        else:
            # No batch could be drawn; record NaN so the curves stay aligned
            # with transformed_class_balance_list.
            average_mmd = var_mmd = float('nan')
        print(f'{name},Transformed class balance: {transformed_class_balance}, Average {kernel_name} MMD over {Sample_batches} batches: {average_mmd}, Variance {kernel_name} MMD: {var_mmd}')
        average_mmd_by_kernel[kernel_name].append(average_mmd)
        var_mmd_by_kernel[kernel_name].append(var_mmd)


# Mean MMD per class-balance value with a +/- one standard deviation band.
plt.figure(figsize=(10, 6))
for kernel_name in kernels:
    label, color = plot_styles.get(kernel_name, (f'{kernel_name} kernel', None))
    color_kwargs = {'color': color} if color is not None else {}
    mean_curve = np.asarray(average_mmd_by_kernel[kernel_name], dtype=float)
    std_curve = np.sqrt(np.asarray(var_mmd_by_kernel[kernel_name], dtype=float))
    (line,) = plt.plot(transformed_class_balance_list, mean_curve, label=label, marker='o', **color_kwargs)
    # Reuse the line's color so the band matches even for cycle-picked colors.
    plt.fill_between(
        transformed_class_balance_list,
        mean_curve - std_curve,
        mean_curve + std_curve,
        color=line.get_color(),
        alpha=0.2,
    )

plt.title('MMD vs Transformed class balance')
plt.xlabel('Transformed class balance')
plt.ylabel('MMD')
plt.legend()
# The file name tracks the actual experiment size rather than a hard-coded "100_1000".
plt.savefig(f'{name}_mmd_vs_transformed_class_balance_{Sample_batches}_{sample_sizes}.png')
plt.show()




 
