import torch
import numpy as np
from datasets import *
from torch_geometric.loader import DataLoader
from scipy.spatial.distance import cdist
from geomloss import SamplesLoss
import ot
# from pytorch3d.ops import knn_points
from scipy.spatial import cKDTree
import utils
import matplotlib.pyplot as plt
import numpy as np
import yaml
from pathlib import Path

# Open questions:
# 1. Can the MMD computation be accelerated further?
# 2. pytorch3d.ops.knn_points is not working here; would it accelerate the MMD computation?
# 3. The MMD is sometimes negative even when the samples are ro[tated?]. Why? Possibly the
#    difference between the U-statistic and V-statistic estimators.

def compute_mmd(x, y, kernel_func, chunk_size=None):
    """Estimate the squared MMD between two samples (biased V-statistic).

    Computes ``mean k(x_i, x_j) + mean k(y_i, y_j) - 2 * mean k(x_i, y_j)``,
    where every mean runs over all ordered pairs, ``i == j`` included.

    Args:
        x: Non-empty list/tuple of equally shaped tensors, or a tensor that is
            already stacked along dim 0.
        y: Same as ``x``, for the second sample.
        kernel_func: Kernel assumed to be symmetric (only the upper triangle
            of the within-sample terms is evaluated). It is called as
            ``kernel_func(a, b)`` with two equally shaped batches, and as
            ``kernel_func(x[:, None], y[None, :])`` for the cross term, so it
            must broadcast over leading batch dimensions.
        chunk_size: Optional number of ``x`` samples per cross-term call.
            ``None`` (default) evaluates the whole ``[n_x, n_y]`` cross term
            in a single call; a positive integer bounds peak memory at the
            cost of more kernel calls.

    Returns:
        The MMD estimate as returned by the tensor reductions (a 0-dim tensor
        when ``kernel_func`` returns tensors).
        NOTE(review): the value can come out slightly negative through
        rounding, or when ``kernel_func`` is not positive semi-definite.

    Raises:
        ValueError: If ``x`` or ``y`` is empty, or ``chunk_size`` is not
            positive.
    """
    n_x, n_y = len(x), len(y)
    if n_x == 0 or n_y == 0:
        raise ValueError("compute_mmd needs at least one sample in both x and y")
    if chunk_size is not None and chunk_size < 1:
        raise ValueError("chunk_size must be a positive integer or None")

    if not torch.is_tensor(x):
        x = torch.stack(x, dim=0)
    if not torch.is_tensor(y):
        y = torch.stack(y, dim=0)

    xx_mean = _mean_self_kernel(x, kernel_func)
    yy_mean = _mean_self_kernel(y, kernel_func)

    if chunk_size is None:
        xy_mean = kernel_func(x[:, None], y[None, :]).mean()  # [n_x, n_y]
    else:
        # Accumulate the cross term block-of-rows by block-of-rows so that only
        # a [chunk_size, n_y] slab is alive at any time.
        xy_sum = 0.0
        for start in range(0, n_x, chunk_size):
            block = x[start:start + chunk_size, None]
            xy_sum = xy_sum + kernel_func(block, y[None, :]).sum()
        xy_mean = xy_sum / (n_x * n_y)

    return xx_mean + yy_mean - 2 * xy_mean


def _mean_self_kernel(samples, kernel_func):
    """Return the mean of ``k(s_i, s_j)`` over all ordered pairs of ``samples``."""
    n = samples.shape[0]
    upper = torch.triu_indices(n, n, offset=1, device=samples.device)
    # Each unordered pair i < j is evaluated once and counted twice below.
    off_diag = kernel_func(samples[upper[0]], samples[upper[1]])  # [n*(n-1)/2]
    diag = kernel_func(samples, samples)  # [n]
    return (off_diag.sum() * 2 + diag.sum()) / (n * n)

def chamfer_kernel(x, y, sigma=1.0):
    """Exponential kernel ``exp(-d_chamfer(x, y) / (2 * sigma**2))``.

    The Chamfer distance used here is the average of the two directed mean
    nearest-neighbour (Euclidean) distances between the clouds.

    Args:
        x: shape [..., n_points_x, dim]; leading batch dimensions broadcast
            against those of ``y``.
        y: shape [..., n_points_y, dim].
        sigma: kernel bandwidth.

    Returns:
        Tensor with the broadcast batch shape, one kernel value per cloud pair.
    """
    # One pairwise matrix serves both directions: cdist(y, x) is its transpose.
    pairwise = torch.cdist(x, y)  # [..., n_points_x, n_points_y]
    x_to_y = pairwise.min(dim=-1).values  # nearest y for every x point
    y_to_x = pairwise.min(dim=-2).values  # nearest x for every y point

    chamfer_dist = (x_to_y.mean(dim=-1) + y_to_x.mean(dim=-1)) / 2.0  # [...]

    return torch.exp(-chamfer_dist / (2 * sigma * sigma))

def hausdorff_kernel(x, y, sigma=1.0):
    """Exponential kernel ``exp(-d_hausdorff(x, y) / (2 * sigma**2))``.

    The Hausdorff distance is the larger of the two directed distances, each
    being the maximum over points of the nearest-neighbour (Euclidean)
    distance to the other cloud.

    Args:
        x: shape [..., n_points_x, dim]; leading batch dimensions broadcast
            against those of ``y``.
        y: shape [..., n_points_y, dim].
        sigma: kernel bandwidth.

    Returns:
        Tensor with the broadcast batch shape, one kernel value per cloud pair.
    """
    # One pairwise matrix serves both directions: cdist(y, x) is its transpose.
    pairwise = torch.cdist(x, y)  # [..., n_points_x, n_points_y]
    x_to_y = pairwise.min(dim=-1).values  # nearest y for every x point
    y_to_x = pairwise.min(dim=-2).values  # nearest x for every y point

    hausdorff_dist = torch.maximum(x_to_y.max(dim=-1).values,
                                   y_to_x.max(dim=-1).values)  # [...]

    return torch.exp(-hausdorff_dist / (2 * sigma * sigma))


#would take a long time to compute
def emd_kernel(x, y, sigma=1.0):
    """Exponential kernel built on geomloss' Sinkhorn loss.

    Evaluates ``SamplesLoss(loss="sinkhorn", p=2, blur=0.01)`` on ``(x, y)``
    and returns ``exp(-loss / (2 * sigma**2))``.

    Args:
        x: first point cloud (or batch of clouds) as accepted by
            ``SamplesLoss`` — presumably [n_points, dim] or
            [batch, n_points, dim]; confirm against geomloss.
        y: second point cloud, same layout as ``x``.
        sigma: kernel bandwidth.
    """
    sinkhorn = SamplesLoss(loss="sinkhorn", p=2, blur=0.01)
    transport_cost = sinkhorn(x, y)
    bandwidth = 2 * sigma * sigma
    return torch.exp(-transport_cost / bandwidth)

#would take a long time to compute
def gromov_wasserstein_kernel(x, y, sigma=1.0):
    """Exponential kernel ``exp(-GW(x, y) / (2 * sigma**2))``.

    ``GW`` is the value returned by POT's ``ot.gromov.gromov_wasserstein2``
    with ``'square_loss'``, evaluated on the intra-cloud Euclidean distance
    matrices with uniform weights on the points.

    Args:
        x: a single point cloud; ``len(x)`` is used as its number of points,
            so this presumably expects shape [n_points, dim] rather than a
            batch.
            NOTE(review): ``compute_mmd`` calls kernels with stacked batches,
            which this function does not appear to handle — confirm before
            enabling it in ``kernels``.
        y: second point cloud, same layout as ``x``.
        sigma: kernel bandwidth.

    Returns:
        0-dim CPU tensor with the kernel value (no gradient flows through it).
    """
    # Tensor.numpy() only works on CPU tensors that do not require grad, so
    # detach and move the inputs explicitly before handing them to POT.
    p1 = x.detach().cpu()
    p2 = y.detach().cpu()

    C1 = torch.cdist(p1, p1)
    C2 = torch.cdist(p2, p2)

    # Uniform mass on every point of each cloud.
    p = torch.ones(len(p1)) / len(p1)
    q = torch.ones(len(p2)) / len(p2)

    gw_dist = ot.gromov.gromov_wasserstein2(
        C1.numpy(), C2.numpy(),
        p.numpy(), q.numpy(),
        'square_loss', verbose=False
    )

    dist = torch.tensor(gw_dist)
    return torch.exp(-dist / (2 * sigma * sigma))




# Module-level container; nothing in the visible code writes to it.
results = dict()

# Kernels evaluated for every split, keyed by the name used for the
# per-kernel result lists. The commented entries are disabled kernels.
kernels = dict(
    chamfer=chamfer_kernel,
    hausdorff=hausdorff_kernel,
    # emd=emd_kernel,
    # gromov_wasserstein=gromov_wasserstein_kernel,
)


# Sample_batches = 1000
# sample_sizes = 100
# Nine evenly spaced values 0, 1/8, ..., 1 (all exactly representable as
# floats); each one is paired with the rotation operator in the sweep below.
transformed_class_balance_list = [step / 8 for step in range(9)]

def load_config(config_path):
    """Parse the YAML file at ``config_path`` and return the loaded object.

    Uses ``yaml.safe_load``, so only plain YAML types are constructed.
    """
    with open(config_path, 'r') as stream:
        return yaml.safe_load(stream)

def load_all_batches(config):
    """Load the index arrays of all splits listed in ``config['split_args']``.

    Looks for the keys ``batch_indices1`` ... ``batch_indices10``; each
    present key names a comma-delimited text file of integer indices. Missing
    keys are skipped silently.

    Args:
        config: mapping with a ``'split_args'`` sub-mapping. An optional
            ``'subset_size'`` entry, when truthy, keeps only that many leading
            entries of every index array.

    Returns:
        List of integer numpy arrays (at least 1-D), in key order.
    """
    split_args = config['split_args']
    # Optional: a missing key means "use every index", same as None or 0.
    subset_size = split_args.get('subset_size')

    batches = []
    for i in range(1, 11):
        batch_key = f'batch_indices{i}'
        if batch_key not in split_args:
            continue

        # ndmin=1 keeps a one-value file sliceable (loadtxt would otherwise
        # return a 0-d array).
        indices = np.loadtxt(split_args[batch_key], delimiter=',', dtype=int, ndmin=1)

        if subset_size:
            indices = indices[:subset_size]

        batches.append(indices)

    return batches


# Experiment configuration (molecule, data directory, split index files).
config_path = 'configs/dataset/md17_mmd_test.yaml'
config = load_config(config_path)

# Dataset root is data/<directory_name>/<mol_name>; the molecule name is also
# passed to MD17 and reused in the printed log line and the plot.
MD17dir = "data/{}/{}".format(config['directory_name'], config['mol_name'])
name = config['mol_name']
base_dataset = MD17(root=MD17dir, name=name)
print(name, len(base_dataset))

# One index array per split found in the config.
all_batches = load_all_batches(config)



# Mean and variance of the MMD across splits; one entry is appended per value
# in transformed_class_balance_list.
avg_chamfer_mmd_list = []
avg_hausdorff_mmd_list = []
var_chamfer_mmd_list = []
var_hausdorff_mmd_list = []

for transformed_class_balance in transformed_class_balance_list:
    # MMD values of every split for the current balance, keyed by kernel name.
    batch_results = {'chamfer': [],
                     'hausdorff': []}
    for indices in all_batches:
        current_dataset = torch.utils.data.Subset(base_dataset, indices)
        # One quaternion per sample, drawn once per split and shared by both
        # augmented datasets through the operator below.
        fixed_quaternions = utils.generate_many_random_quaternions(len(current_dataset))
        rotation_operator = (lambda data, idx: utils.memoized_rotation(data, idx, fixed_quaternions))

        # dataset_0 pairs the rotation with the swept balance value, dataset_1
        # with 1.0 (presumably the fraction/probability of rotated samples —
        # confirm against MetaAugmentedDataset).
        transform_list_0 = [(rotation_operator, transformed_class_balance)]
        dataset_0 = MetaAugmentedDataset(current_dataset, transform_list_0, label_operator=None)

        transform_list_1 = [(rotation_operator, 1.0)]
        dataset_1 = MetaAugmentedDataset(current_dataset, transform_list_1, label_operator=None)

        loader_0 = DataLoader(dataset_0, batch_size=1, shuffle=False)
        loader_1 = DataLoader(dataset_1, batch_size=1, shuffle=False)

        point_clouds_1 = []
        point_clouds_2 = []

        for data_0, data_1 in zip(loader_0, loader_1):
            point_clouds_1.append(data_0[0].pos)
            point_clouds_2.append(data_1[0].pos)

            # Never collect more clouds than there are base samples.
            if len(point_clouds_1) >= len(current_dataset):
                break

        for kernel_name, kernel_func in kernels.items():
            mmd = compute_mmd(point_clouds_1, point_clouds_2, kernel_func)
            # Store a plain float so the numpy statistics below do not depend
            # on implicit tensor-to-array conversion; setdefault keeps kernels
            # other than the two plotted ones from raising KeyError.
            batch_results.setdefault(kernel_name, []).append(float(mmd))
    avg_chamfer_mmd_list.append(np.mean(batch_results['chamfer']))
    avg_hausdorff_mmd_list.append(np.mean(batch_results['hausdorff']))
    var_chamfer_mmd_list.append(np.var(batch_results['chamfer']))
    var_hausdorff_mmd_list.append(np.var(batch_results['hausdorff']))
    print(f"name {name}, Transformed class balance {transformed_class_balance}:, chamfer mmd: {avg_chamfer_mmd_list[-1]}, hausdorff mmd: {avg_hausdorff_mmd_list[-1]}, chamfer var: {var_chamfer_mmd_list[-1]}, hausdorff var: {var_hausdorff_mmd_list[-1]}")

# Plot the mean MMD per kernel against the swept balance value, with a
# +/- one standard deviation band (across splits) around each curve.
plt.figure(figsize=(10, 6))
chamfer_std_mmd_list = np.sqrt(var_chamfer_mmd_list)
hausdorff_std_mmd_list = np.sqrt(var_hausdorff_mmd_list)
plt.fill_between(
    transformed_class_balance_list,
    np.array(avg_chamfer_mmd_list) - chamfer_std_mmd_list,
    np.array(avg_chamfer_mmd_list) + chamfer_std_mmd_list,
    color='blue', alpha=0.2,
)
plt.fill_between(
    transformed_class_balance_list,
    np.array(avg_hausdorff_mmd_list) - hausdorff_std_mmd_list,
    np.array(avg_hausdorff_mmd_list) + hausdorff_std_mmd_list,
    color='orange', alpha=0.2,
)
plt.plot(transformed_class_balance_list, avg_chamfer_mmd_list, label='Chamfer kernel', color='blue', marker='o')
plt.plot(transformed_class_balance_list, avg_hausdorff_mmd_list, label='Hausdorff kernel', color='orange', marker='o')

# Report the number of splits that were actually loaded; missing split files
# are skipped by the loader, so this is not necessarily 10.
plt.title(f'{name}: Average MMD vs Transformed class balance ({len(all_batches)} splits)')
plt.xlabel('Transformed class balance')
plt.ylabel('MMD')
plt.legend()
plt.savefig(f'{name}_mmd_vs_transformed_class_balance_10splits_1000.png')
plt.show()






    # # Example usage
    # for pc1, pc2 in zip(point_clouds_1, point_clouds_2):
    #     if detect_rotation_relation(pc1, pc2):
    #         print("Rotation relation detected between the point cloud samples.")
    #     else:
    #         print("No rotation relation detected between the point cloud samples.")

