import os
import torch
import numpy as np
from collections import defaultdict
from torchvision import datasets, transforms
from tqdm import tqdm
import matplotlib.pyplot as plt
from sklearn.manifold import MDS
import umap
from sklearn.cluster import DBSCAN
from scipy.stats import mode
import random
from sklearn.neighbors import NearestNeighbors
from matplotlib.colors import BoundaryNorm
from matplotlib.colorbar import ColorbarBase

from sw2 import (
    Wasserstein_Distance,
    Sliced_Wasserstein_Distance,
    Projected_Wasserstein_Distance,
    Energy_based_Sliced_Wasserstein,
    Expected_Sliced_Transport,
    Min_SWGG,
    Max_Sliced_Wasserstein_Distance,
)
from utils import generate_uniform_unit_sphere_projections, optimal_alpha, optimal_alpha_general

# ==== Hyperparameters ====
# NOTE(review): N_POINTS, THRESHOLD, NUM_SAMPLES_PER_LABEL and NUM_CLASSES are
# not referenced by the functions in this file; they are kept in case other
# modules import them.
N_POINTS = 100
THRESHOLD = 0.5
NUM_SAMPLES_PER_LABEL = 50
NUM_CLASSES = 10
# Passed as `dim` / `num_projections` when the shared projection matrix is
# generated in compute_distance_matrices.
PROJ_DIM = 3
NUM_PROJ = 200

# Prefer the GPU, but fall back to the CPU so the script still runs on hosts
# without CUDA instead of failing on the first `.to(DEVICE)`.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
DTYPE = torch.float32

# Maps the name used in `<name>_matrix.npy` files to the name shown in plot
# titles and used in output figure file names.
display_dist_name = {
    # The exact-Wasserstein matrix is plotted as well, so it needs an entry.
    "Wasserstein": "Wasserstein",
    "SW": "SW",
    "PWD": "PWD",
    "EBSW": "EBSW",
    "EST": "EST",
    "MaxSW": "MaxSW",
    "MinSWGG": "MinSWGG",
    "fastE": "RG-e",
    "fastOp": "RG-o",
    "fastSP": "RG-s",
    "fast4": "RG-se",
    "fast6": "RG-seo",
}


def extract_pointcloud_dataset(categories, train_dir, num_samples_per_class=500, save_path="compare_wormhole_pc"):
    """Build a shuffled, class-balanced subset from per-category files and save it.

    For every category, ``<train_dir>/<category>.pt`` is loaded and
    ``num_samples_per_class`` entries are drawn without replacement. Categories
    whose file is missing or holds too few entries are skipped with a printed
    warning. A category's label is its position in ``categories``, so a skipped
    category leaves a gap in the label values.

    The result is written to ``<save_path>/num_samples_<total>/`` as
    ``X_train.pt`` (samples) and ``y_train.pt`` (int64 labels).

    Args:
        categories: Category names; each must match a ``.pt`` file stem.
        train_dir: Directory containing the per-category ``.pt`` files.
        num_samples_per_class: Number of entries drawn from each category.
        save_path: Root directory under which the dataset folder is created.

    Returns:
        The directory the two files were written to.

    Raises:
        RuntimeError: If no category contributed any samples.
    """
    all_samples = []
    labels = []

    for label, cat in enumerate(categories):
        file_path = os.path.join(train_dir, f"{cat}.pt")
        if not os.path.exists(file_path):
            print(f"❌ File not found for category: {cat}")
            continue

        data = torch.load(file_path)
        if len(data) < num_samples_per_class:
            print(f"⚠️ Not enough samples for category: {cat}, only {len(data)} samples")
            continue

        # Random subset without replacement.
        chosen = torch.randperm(len(data))[:num_samples_per_class]
        all_samples.append(data[chosen])
        labels.extend([label] * num_samples_per_class)

    if not all_samples:
        # torch.cat on an empty list fails with an opaque message; say what
        # actually went wrong instead.
        raise RuntimeError(
            f"No category provided enough samples; nothing to save (train_dir={train_dir!r})."
        )

    final_tensor = torch.cat(all_samples, dim=0)
    labels = torch.tensor(labels, dtype=torch.long)

    # Shuffle samples and labels with the same permutation.
    perm = torch.randperm(final_tensor.shape[0])
    final_tensor = final_tensor[perm]
    labels = labels[perm]

    test_save_dir = os.path.join(save_path, f"num_samples_{len(final_tensor)}")
    os.makedirs(test_save_dir, exist_ok=True)
    torch.save(final_tensor, os.path.join(test_save_dir, "X_train.pt"))
    torch.save(labels, os.path.join(test_save_dir, "y_train.pt"))

    print(f"✅ Final tensor shape: {final_tensor.shape}, labels shape: {labels.shape}")
    # Returned so callers do not have to re-derive the folder name, which
    # depends on how many categories were actually kept.
    return test_save_dir



def _pairwise_distance_matrix(samples, distance_fn):
    """Return the symmetric (N, N) float32 matrix of ``distance_fn`` over all sample pairs."""
    n = samples.shape[0]
    matrix = np.zeros((n, n), dtype=np.float32)
    for i in tqdm(range(n)):
        for j in range(i + 1, n):
            # float() accepts plain numbers as well as one-element tensors, so
            # this works whether or not the distance function already called
            # `.item()` on its result.
            value = float(distance_fn(samples[i], samples[j]))
            matrix[i, j] = value
            matrix[j, i] = value
    return matrix


def _sample_distinct_pairs(n_items, n_pairs):
    """Draw up to ``n_pairs`` distinct unordered index pairs (i != j) from ``range(n_items)``."""
    # Never ask for more pairs than exist; otherwise the loop below would
    # never terminate.
    n_pairs = min(n_pairs, n_items * (n_items - 1) // 2)
    seen = set()
    pairs = []
    while len(pairs) < n_pairs:
        i, j = (int(v) for v in np.random.randint(0, n_items, 2))
        key = (min(i, j), max(i, j))  # (i, j) and (j, i) are the same pair
        if i != j and key not in seen:
            seen.add(key)
            pairs.append((i, j))
    return np.array(pairs, dtype=np.int32).reshape(-1, 2)


def _load_or_sample_reference_pairs(X_train, cache_dir, num_pairs):
    """Return ``(sampled_indices, wasserstein_values)``, reusing the on-disk cache if present."""
    sampled_idx_path = os.path.join(cache_dir, "sampled_indices.npy")
    w_values_path = os.path.join(cache_dir, "wasserstein_values.npy")

    if os.path.exists(sampled_idx_path) and os.path.exists(w_values_path):
        sampled_indices = np.load(sampled_idx_path)
        wasserstein_values = np.load(w_values_path)
        print(f"\n[Cache] Loaded sampled_indices ({len(sampled_indices)}) and wasserstein_values from cache.")
        return sampled_indices, wasserstein_values

    print(f"\nSampling {num_pairs} pairs for Wasserstein estimation...")
    sampled_indices = _sample_distinct_pairs(X_train.shape[0], num_pairs)

    # Reference values use a larger iteration cap (10000) than the full
    # "Wasserstein" matrix (1000).
    wasserstein_values = np.array(
        [
            Wasserstein_Distance(X_train[i], X_train[j], numItermax=10000, device=DEVICE).item()
            for i, j in sampled_indices
        ],
        dtype=np.float32,
    )

    # Save cache
    np.save(sampled_idx_path, sampled_indices)
    np.save(w_values_path, wasserstein_values)
    print(f"[Cache] Saved sampled_indices -> {sampled_idx_path}")
    print(f"[Cache] Saved wasserstein_values -> {w_values_path}")
    return sampled_indices, wasserstein_values


def compute_distance_matrices(data_path, matrix_distance_folder="pairwise_distance_matrices_extend"):
    """Compute and save pairwise distance matrices for the samples in ``X_train.pt``.

    Steps:
      1. Load ``<data_path>/X_train.pt`` and compute, for each base distance
         ("Wasserstein", "SW", "PWD", "EBSW", "EST", "MinSWGG", "MaxSW"), the
         full symmetric pairwise matrix, saved as ``<name>_matrix.npy``.
      2. Sample a few pairs and compute reference Wasserstein values for them
         (cached under ``_alpha_cache`` and reused on later runs).
      3. For each predefined group of base distances, fit coefficients with
         ``optimal_alpha_general`` against the reference values, save them as
         ``alpha_<group>.npy``, and save the weighted sum of the group's
         matrices as ``<out_name>_matrix.npy``.

    Args:
        data_path: Directory that contains ``X_train.pt``.
        matrix_distance_folder: Sub-folder of ``data_path`` receiving all outputs.

    Raises:
        ValueError: If ``X_train.pt`` holds fewer than two samples.
    """
    out_dir = os.path.join(data_path, matrix_distance_folder)
    os.makedirs(out_dir, exist_ok=True)
    X_train = torch.load(os.path.join(data_path, "X_train.pt")).to(DEVICE).to(DTYPE)
    N = X_train.shape[0]
    if N < 2:
        raise ValueError(f"Need at least two samples to compute pairwise distances, got {N}.")

    # One projection matrix shared by all projection-based distances.
    projection_matrix = generate_uniform_unit_sphere_projections(
        dim=PROJ_DIM, requires_grad=False, num_projections=NUM_PROJ, dtype=DTYPE, device=DEVICE
    )

    distance_functions = {
        "Wasserstein": lambda x, y: Wasserstein_Distance(x, y, numItermax=1000, device=DEVICE).item(),
        "SW": lambda x, y: Sliced_Wasserstein_Distance(x, y, projection_matrix=projection_matrix, device=DEVICE, dtype=DTYPE).item(),
        "PWD": lambda x, y: Projected_Wasserstein_Distance(x, y, projection_matrix=projection_matrix, device=DEVICE, dtype=DTYPE).item(),
        "EBSW": lambda x, y: Energy_based_Sliced_Wasserstein(x, y, projection_matrix=projection_matrix, device=DEVICE, dtype=DTYPE).item(),
        "EST": lambda x, y: Expected_Sliced_Transport(x, y, projection_matrix=projection_matrix, device=DEVICE, dtype=DTYPE).item(),
        # These two return a sequence; only its first element is used.
        "MinSWGG": lambda x, y: Min_SWGG(x, y, lr=5e-2, num_iter=5, s=2, std=0.5, device=DEVICE, dtype=DTYPE)[0],
        "MaxSW": lambda x, y: Max_Sliced_Wasserstein_Distance(x, y, require_optimize=True, lr=1e-1, num_iter=5, device=DEVICE, dtype=DTYPE)[0]
    }

    dist_matrices = {}
    for name, func in distance_functions.items():
        print(f"\nComputing pairwise distance matrix for {name}...")
        dist_matrix = _pairwise_distance_matrix(X_train, func)
        dist_matrices[name] = dist_matrix
        np.save(os.path.join(out_dir, f"{name}_matrix.npy"), dist_matrix)
        print(f"Saved: {name}_matrix.npy")

    num_samples_optimal_alpha = 10
    cache_dir = os.path.join(out_dir, "_alpha_cache")
    os.makedirs(cache_dir, exist_ok=True)
    sampled_indices, wasserstein_values = _load_or_sample_reference_pairs(
        X_train, cache_dir, num_samples_optimal_alpha
    )

    # (base distances to combine, name of the combined output matrix)
    custom_groups = [
        (["SW", "PWD"], "fastSP"),
        (["EBSW", "EST"], "fastE"),
        (["MaxSW", "MinSWGG"], "fastOp"),
        (["SW", "PWD", "EBSW", "EST"], "fast4"),
        (["SW", "PWD", "EBSW", "EST", "MaxSW", "MinSWGG"], "fast6"),
    ]
    ridge_lambda = 0.0

    y_pairs = np.array(wasserstein_values, dtype=float)
    for group_names, out_name in custom_groups:
        print(f"\n[Alpha] Estimating coefficients for group {group_names} -> {out_name}")

        # One row per sampled pair, one column per base distance in the group.
        X_pairs = np.array(
            [[dist_matrices[m][i, j] for m in group_names] for (i, j) in sampled_indices],
            dtype=float
        )

        alphas = optimal_alpha_general(X_pairs, y_pairs, ridge=ridge_lambda)

        alpha_path = os.path.join(out_dir, f"alpha_{'_'.join(group_names)}.npy")
        np.save(alpha_path, alphas.astype(np.float32))
        print(f"Estimated alphas for {group_names}: {alphas} -> saved to {alpha_path}")

        # Combined matrix = sum_k alpha_k * D_k over the group's base distances.
        fast = np.zeros((N, N), dtype=np.float32)
        for a_k, mname in zip(alphas, group_names):
            fast += float(a_k) * dist_matrices[mname]

        np.save(os.path.join(out_dir, f"{out_name}_matrix.npy"), fast)
        print(f"Saved: {out_name}_matrix.npy")





def visualize_umap_with_class_separation(data_path, dist_name="Wasserstein", embed_dim=2,
                                         n_neighbors=15, min_dist=0.01,
                                         class_labels=None,
                                         matrix_distance_folder="pairwise_distance_matrices"):
    """Embed a precomputed distance matrix with UMAP and save a class-colored scatter plot.

    Reads ``<data_path>/<matrix_distance_folder>/<dist_name>_matrix.npy`` and
    ``<data_path>/y_train.pt``, fits UMAP with ``metric='precomputed'``, plots
    the first two embedding coordinates colored by label, and writes a PNG and
    a PDF next to the matrix.

    Args:
        data_path: Directory containing ``y_train.pt`` and the matrix folder.
        dist_name: Matrix name, i.e. the prefix of ``<dist_name>_matrix.npy``.
        embed_dim: Number of UMAP components; only the first two are plotted.
        n_neighbors: UMAP ``n_neighbors``.
        min_dist: UMAP ``min_dist``.
        class_labels: Legend entries; defaults to the sorted unique labels.
        matrix_distance_folder: Sub-folder of ``data_path`` holding the
            matrices and receiving the figures.
    """
    saved_path = os.path.join(data_path, matrix_distance_folder)
    dist_matrix = np.load(os.path.join(saved_path, f"{dist_name}_matrix.npy"))
    y_train = torch.load(os.path.join(data_path, "y_train.pt")).cpu().numpy()

    if class_labels is None:
        class_labels = sorted(np.unique(y_train).tolist())

    # Names without a display alias (e.g. "Wasserstein") are shown as-is
    # rather than raising KeyError.
    shown_name = display_dist_name.get(dist_name, dist_name)

    reducer = umap.UMAP(
        metric='precomputed',
        n_components=embed_dim,
        n_neighbors=n_neighbors,
        min_dist=min_dist,
        random_state=42
    )
    embedding = reducer.fit_transform(dist_matrix)

    cmap = plt.get_cmap('tab10')

    fig, ax = plt.subplots(figsize=(8, 8))
    scatter = ax.scatter(embedding[:, 0], embedding[:, 1],
                         c=y_train, cmap=cmap, s=50, alpha=0.8)

    # Add color legend.
    # NOTE(review): `draw_custom_legend_blocks` is neither defined nor imported
    # in this file. Use it when some other module provides it; otherwise build
    # the legend from the scatter itself so the colors are guaranteed to match.
    legend_helper = globals().get("draw_custom_legend_blocks")
    if callable(legend_helper):
        legend_helper(fig, class_labels=class_labels, cmap=cmap)
    else:
        # num=None -> one handle per unique label value, in sorted order.
        handles, auto_texts = scatter.legend_elements(num=None)
        if len(class_labels) == len(handles):
            texts = [str(lbl) for lbl in class_labels]
        else:
            texts = auto_texts
        fig.legend(handles, texts, loc="center right", title="Class", frameon=False)

    ax.set_title(f"ShapeNetV2: Embeddings of {shown_name}", fontsize=20, weight='bold', pad=15)
    ax.axis('equal')

    # Leave a margin on the right for the legend.
    fig.subplots_adjust(
        left=0.05,
        right=0.88,
        top=0.92,
        bottom=0.05
    )

    file_stem = f"umap_class_separation_{shown_name.lower()}_{embed_dim}"
    fig.savefig(os.path.join(saved_path, f"{file_stem}.png"))
    fig.savefig(os.path.join(saved_path, f"{file_stem}.pdf"), dpi=100)
    # Close this specific figure so repeated calls do not accumulate figures.
    plt.close(fig)

if __name__ == "__main__":
    # Pipeline: extract a balanced subset -> compute distance matrices ->
    # plot a UMAP embedding for every matrix.

    # The position of a category in this list is its class label.
    sub_categories = ['table', 'airplane', 'car', 'lamp', 'vessel', 'bench', 'cabinet', 'monitor', 'bathtub', 'guitar']
    num_classes = len(sub_categories)
    num_samples_per_class = 500

    # All three stages must work on the same directory: the extraction step
    # writes X_train.pt / y_train.pt to <save_root>/num_samples_<total>, and
    # the later steps read from there.
    # NOTE(review): the folder name assumes no category was skipped during
    # extraction (total == num_classes * num_samples_per_class).
    save_root = "preprocessed_dataset/point_cloud/saved_embeddings_umap"
    data_path = os.path.join(save_root, f"num_samples_{num_classes * num_samples_per_class}")

    extract_pointcloud_dataset(
        categories=sub_categories[:num_classes],
        train_dir="preprocessed_dataset/point_cloud",
        num_samples_per_class=num_samples_per_class,
        save_path=save_root,
    )

    # Write the matrices into the folder the visualization step reads from.
    compute_distance_matrices(
        data_path=data_path,
        matrix_distance_folder="pairwise_distance_matrices",
    )

    list_dist_name = ["SW", "PWD", "EBSW", "EST", "MaxSW", "MinSWGG", "fastE", "fastOp", "fastSP", "fast4", "fast6", "Wasserstein"]
    for dst in list_dist_name:
        visualize_umap_with_class_separation(
            data_path=data_path,
            dist_name=dst,
        )