import torch
import numpy as np
from datasets import *
from torch_geometric.loader import DataLoader
from scipy.spatial.distance import cdist
from geomloss import SamplesLoss
import ot
# from pytorch3d.ops import knn_points
from scipy.spatial import cKDTree
import utils
import matplotlib.pyplot as plt
import numpy as np

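# Open questions:
# 1. Can the MMD computation be accelerated further?
# 2. pytorch3d.ops.knn_points does not seem to work here; would it speed up the MMD computation?
# 3. The raw MMD estimate is sometimes negative even when the samples are (presumably) rotated
#    copies of each other -- why? Possibly the difference between the U- and V-statistic estimators.
#    Note: compute_mmd includes the diagonal kernel terms (a V-statistic), which is always >= 0 for a
#    positive-definite kernel, so negative values suggest the exponential Chamfer/Hausdorff kernels
#    are not positive definite.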



def _mean_within(samples, kernel_func, unbiased):
    """Return the mean of the within-set kernel matrix of ``samples``.

    Only the upper triangle is evaluated (``kernel_func`` is assumed symmetric).
    With ``unbiased=False`` the diagonal k(s, s) terms are included (V-statistic);
    with ``unbiased=True`` they are excluded (U-statistic).
    """
    n = len(samples)
    off_diag = 0.0
    for i in range(n):
        for j in range(i + 1, n):
            off_diag += kernel_func(samples[i], samples[j])
    if unbiased:
        return off_diag * 2 / (n * (n - 1))
    diag = 0.0
    for s in samples:
        diag += kernel_func(s, s)
    return (off_diag * 2 + diag) / (n * n)


def compute_mmd(x, y, kernel_func, unbiased=False):
    """Estimate the squared Maximum Mean Discrepancy (MMD^2) between two sample sets.

    Args:
        x: sequence of samples from the first distribution; each element is passed
            as-is to ``kernel_func`` (in this script: point clouds, [n_points, 3]).
        y: sequence of samples from the second distribution (same element type).
        kernel_func: callable ``k(a, b)`` returning a scalar similarity. It must be
            symmetric, because only the upper triangle of each within-set kernel
            matrix is evaluated.
        unbiased: if False (default), use the biased V-statistic, which includes the
            diagonal terms k(x_i, x_i). If True, use the unbiased U-statistic, which
            excludes them (needs at least two samples per set).

    Returns:
        The MMD^2 estimate clamped at zero. When the raw estimate is negative, the
        int ``0`` is returned; otherwise, the value has whatever type ``kernel_func``
        produces (e.g. a 0-d tensor).

    Raises:
        ValueError: if ``x`` or ``y`` is empty, or has fewer than two samples when
            ``unbiased=True``.
    """
    n_x = len(x)
    n_y = len(y)
    min_samples = 2 if unbiased else 1
    if n_x < min_samples or n_y < min_samples:
        raise ValueError(
            f"compute_mmd needs at least {min_samples} sample(s) per set, got {n_x} and {n_y}"
        )

    xx_mean = _mean_within(x, kernel_func, unbiased)
    yy_mean = _mean_within(y, kernel_func, unbiased)

    xy_sum = 0.0
    for a in x:
        for b in y:
            xy_sum += kernel_func(a, b)
    xy_mean = xy_sum / (n_x * n_y)

    # Clamp: with a kernel that is not positive definite the estimate can dip below zero.
    return max(xx_mean + yy_mean - 2 * xy_mean, 0)

def chamfer_kernel(x, y, sigma=1.0):
    """Exponential similarity between two point clouds based on the Chamfer distance.

    The Chamfer distance used here is the average of the two directional mean
    nearest-neighbour (Euclidean) distances, and the kernel value is
    ``exp(-d / (2 * sigma**2))``.

    Args:
        x: point cloud tensor of shape [n_points, 3].
        y: point cloud tensor of shape [n_points, 3].
        sigma: kernel bandwidth.

    Returns:
        0-d tensor on ``x.device``.
    """
    cloud_a = x.detach().cpu().numpy()
    cloud_b = y.detach().cpu().numpy()

    # Distance from every point of one cloud to its nearest neighbour in the other cloud.
    b_to_a = cKDTree(cloud_a).query(cloud_b)[0]
    a_to_b = cKDTree(cloud_b).query(cloud_a)[0]

    chamfer = torch.tensor((np.mean(b_to_a) + np.mean(a_to_b)) / 2.0, device=x.device)
    return torch.exp(-chamfer / (2 * sigma * sigma))

def hausdorff_kernel(x, y, sigma=1.0):
    """Exponential similarity between two point clouds based on the Hausdorff distance.

    The Hausdorff distance is the larger of the two directional worst-case
    nearest-neighbour (Euclidean) distances, and the kernel value is
    ``exp(-d / (2 * sigma**2))``.

    Args:
        x: point cloud tensor of shape [n_points, 3].
        y: point cloud tensor of shape [n_points, 3].
        sigma: kernel bandwidth.

    Returns:
        0-d tensor on ``x.device``.
    """
    cloud_a = x.detach().cpu().numpy()
    cloud_b = y.detach().cpu().numpy()

    b_to_a, _ = cKDTree(cloud_a).query(cloud_b)
    a_to_b, _ = cKDTree(cloud_b).query(cloud_a)

    # Worst-case nearest-neighbour distance, taken over both directions.
    hausdorff = max(np.max(b_to_a), np.max(a_to_b))
    dist = torch.tensor(hausdorff, device=x.device)
    return torch.exp(-dist / (2 * sigma * sigma))

# NOTE: slow to compute; disabled (commented out) in the `kernels` dict below.
def emd_kernel(x, y, sigma=1.0):
    """Exponential kernel on an entropic optimal-transport cost between two point clouds.

    The cost comes from geomloss ``SamplesLoss(loss="sinkhorn", p=2, blur=0.01)``,
    and the kernel value is ``exp(-cost / (2 * sigma**2))``.

    NOTE(review): despite the name, this is a Sinkhorn (entropy-regularised)
    approximation rather than an exact earth mover's distance. For p=2, geomloss
    documents its ground cost as |x - y|^2 / 2; confirm this scaling is what is intended.

    Args:
        x: point cloud tensor (presumably [n_points, 3], as for the other kernels).
        y: point cloud tensor.
        sigma: kernel bandwidth.
    """
    # A new solver object is built on every call.
    sinkhorn = SamplesLoss(loss="sinkhorn", p=2, blur=0.01)
    ot_cost = sinkhorn(x, y)
    return torch.exp(-ot_cost / (2 * sigma * sigma))

# NOTE: slow to compute; disabled (commented out) in the `kernels` dict below.
def gromov_wasserstein_kernel(x, y, sigma=1.0):
    """Exponential kernel on the Gromov-Wasserstein discrepancy between two point clouds.

    Each cloud's intra-cloud Euclidean distance matrix is compared with POT's
    ``ot.gromov.gromov_wasserstein2``, using square loss and uniform weights on the
    points. The clouds are therefore compared only through their internal
    geometry. The kernel value is ``exp(-d / (2 * sigma**2))``.

    Args:
        x: point cloud tensor; rows are points (presumably [n_points, 3] -- TODO confirm).
        y: point cloud tensor; may have a different number of points than ``x``.
        sigma: kernel bandwidth.

    Returns:
        0-d tensor on ``x.device``, matching the other kernels in this file.
    """
    # POT is called with NumPy arrays here. A plain ``.numpy()`` raises for CUDA tensors
    # and for tensors that require grad, so detach and move to CPU first.
    p1 = x.detach().cpu()
    p2 = y.detach().cpu()

    C1 = torch.cdist(p1, p1).numpy()
    C2 = torch.cdist(p2, p2).numpy()

    # Uniform probability mass on the points of each cloud.
    p = (torch.ones(len(p1)) / len(p1)).numpy()
    q = (torch.ones(len(p2)) / len(p2)).numpy()

    gw_dist = ot.gromov.gromov_wasserstein2(
        C1, C2,
        p, q,
        'square_loss', verbose=False
    )
    dist = torch.tensor(gw_dist, device=x.device)
    return torch.exp(-dist / (2 * sigma * sigma))





kernels = {
    'chamfer': chamfer_kernel,
    'hausdorff': hausdorff_kernel,
    # 'emd': emd_kernel,
    # 'gromov_wasserstein': gromov_wasserstein_kernel,
}
# Per-kernel summary statistics: one 'mean' and one 'var' entry per value in
# transformed_class_balance_list (NaN when no batch could be evaluated).
results = {kernel_name: {'mean': [], 'var': []} for kernel_name in kernels}


Sample_batches = 1000  # upper bound on MMD estimates per balance level (fewer if the loaders run out)
sample_sizes = 100     # point clouds per side in each MMD estimate
transformed_class_balance_list = [0.0, 0.125, 0.25, 0.375, 0.5, 0.625, 0.75, 0.875, 1.0]
qm9dir = 'data/qm9'
# MD17dir = 'data/md17'
base_dataset = QM9(root=qm9dir)
# base_dataset = MD17(root=MD17dir)


for transformed_class_balance in transformed_class_balance_list:
    # One random quaternion per dataset entry, regenerated for every balance level.
    fixed_quaternions = utils.generate_many_random_quaternions(len(base_dataset))
    rotation_operator = (lambda data, idx: utils.memoized_rotation(data, idx, fixed_quaternions))
    mmd_results = {kernel_name: [] for kernel_name in kernels}

    # dataset_0 applies the rotation with weight `transformed_class_balance`, dataset_1 with 1.0
    # (presumably the fraction of rotated samples -- semantics live in MetaAugmentedDataset).
    transform_list_0 = [(rotation_operator, transformed_class_balance)]
    dataset_0 = MetaAugmentedDataset(base_dataset, transform_list_0, label_operator=None)

    transform_list_1 = [(rotation_operator, 1.0)]
    dataset_1 = MetaAugmentedDataset(base_dataset, transform_list_1, label_operator=None)

    loader_0 = DataLoader(dataset_0, batch_size=1, shuffle=False)
    loader_1 = DataLoader(dataset_1, batch_size=1, shuffle=False)

    iter_0 = iter(loader_0)
    iter_1 = iter(loader_1)

    for _ in range(Sample_batches):
        point_clouds_1 = []
        point_clouds_2 = []

        for _ in range(sample_sizes):
            try:
                batch_0 = next(iter_0)
                batch_1 = next(iter_1)
            except StopIteration:
                break
            point_clouds_1.append(batch_0[0].pos)
            point_clouds_2.append(batch_1[0].pos)

        if not point_clouds_1:
            # Loaders exhausted: every further batch would be empty too.
            break

        for kernel_name, kernel_func in kernels.items():
            mmd = compute_mmd(point_clouds_1, point_clouds_2, kernel_func)
            mmd_results[kernel_name].append(mmd)

    for kernel_name, mmd_list in mmd_results.items():
        n_batches = len(mmd_list)
        if n_batches == 0:
            average_mmd = float('nan')
            var_mmd = float('nan')
        else:
            # float() normalises 0-d tensors and the bare int 0 that compute_mmd's clamp can return.
            mmd_values = torch.tensor([float(m) for m in mmd_list], dtype=torch.float64)
            average_mmd = mmd_values.mean().item()
            var_mmd = mmd_values.var().item()  # Bessel-corrected, like torch.var; NaN for one batch
        results[kernel_name]['mean'].append(average_mmd)
        results[kernel_name]['var'].append(var_mmd)
        print(f'Transformed class balance: {transformed_class_balance}, Average {kernel_name} MMD over {n_batches} batches: {average_mmd}, Variance {kernel_name} MMD: {var_mmd}')


name = 'QM9'
# Fixed colours for the two default kernels; any other kernel takes a Matplotlib cycle colour.
default_colors = {'chamfer': 'blue', 'hausdorff': 'orange'}
colors = {
    kernel_name: default_colors.get(kernel_name, f'C{idx}')
    for idx, kernel_name in enumerate(results)
}
x_values = np.asarray(transformed_class_balance_list)

plt.figure(figsize=(10, 6))
# Shaded band = mean +/- one standard deviation across batches.
for kernel_name, stats in results.items():
    mean = np.asarray(stats['mean'], dtype=float)
    std = np.sqrt(np.asarray(stats['var'], dtype=float))
    plt.fill_between(x_values, mean - std, mean + std, color=colors[kernel_name], alpha=0.2)
for kernel_name, stats in results.items():
    plt.plot(x_values, stats['mean'], label=f'{kernel_name.capitalize()} kernel', color=colors[kernel_name], marker='o')

plt.title('MMD vs Transformed class balance')
plt.xlabel('Transformed class balance')
plt.ylabel('MMD')
plt.legend()
plt.savefig(f'{name}_mmd_vs_transformed_class_balance.png')
plt.show()







    # # Example usage
    # for pc1, pc2 in zip(point_clouds_1, point_clouds_2):
    #     if detect_rotation_relation(pc1, pc2):
    #         print("Rotation relation detected between the point cloud samples.")
    #     else:
    #         print("No rotation relation detected between the point cloud samples.")

