import torch 
import torch.nn as nn 
import matplotlib.pyplot as plt
import matplotlib.cm as cm
import numpy as np
from sw2 import Wasserstein_Distance, Sliced_Wasserstein_Distance, Projected_Wasserstein_Distance, Energy_based_Sliced_Wasserstein, Max_Sliced_Wasserstein_Distance, Min_SWGG, Expected_Sliced_Transport
from utils import generate_uniform_unit_sphere_projections, optimal_alpha
import os
import time
from data.shapenet_dataset import ShapeNet15kPointClouds
from sklearn.metrics import mean_squared_error, r2_score
import random
import json
import time
import argparse
from sklearn.metrics import r2_score



class Wasserstein_kNN:
    """k-nearest-neighbour classifier driven by an arbitrary pairwise metric.

    The full test-by-train distance matrix is computed once per distinct test
    set and cached, so repeated ``predict``/``accuracy`` calls with different
    ``k`` on the same test set only re-read the cached ranking.

    Args:
        X_train: Training samples, one sample per index along dim 0.
        y_train: Training labels, one per sample in ``X_train``.
        metric: Callable ``metric(a, b)`` returning a scalar distance, either a
            Python number or a single-element tensor.
        device: Device holding the training data, the cached test set and the
            distance/rank matrices.
    """

    def __init__(self, X_train, y_train, metric, device='cpu'):
        self.X_train = X_train.to(device)
        self.y_train = y_train.to(device)
        self.metric = metric
        self.device = device
        self._cached_test = None  # private copy of the last test set prepared
        self._dist_matrix = None  # (N_test, N_train) distances for _cached_test
        self._rank_matrix = None  # per-row argsort of _dist_matrix, nearest first

    def _prepare_test(self, X_test):
        """Compute and cache distances and rankings for ``X_test``.

        Nothing is recomputed when ``X_test`` equals the cached test set.
        """
        X_test = X_test.to(self.device)
        if self._cached_test is not None and torch.equal(X_test, self._cached_test):
            return
        n_test, n_train = X_test.shape[0], self.X_train.shape[0]
        D = torch.empty(n_test, n_train, device=self.device)
        for i in range(n_test):
            xi = X_test[i]
            # float() accepts both Python numbers and single-element tensors and
            # yields a plain host scalar, so no autograd history is retained.
            row = [float(self.metric(xi, self.X_train[j])) for j in range(n_train)]
            D[i] = torch.tensor(row, device=self.device)
        # Store a copy: `.to()` on the same device returns the caller's own
        # tensor, and caching that alias would make later in-place edits by the
        # caller look "unchanged" and return stale predictions.
        self._cached_test = X_test.clone()
        self._dist_matrix = D
        self._rank_matrix = torch.argsort(D, dim=1)

    def predict(self, X_test, k=1):
        """Return, for each test sample, the most frequent label among its k nearest neighbours.

        Raises:
            ValueError: if ``k`` is smaller than 1.
        """
        if k < 1:
            raise ValueError(f"k must be >= 1, got {k}")
        self._prepare_test(X_test)
        knn_labels = self.y_train[self._rank_matrix[:, :k]]
        # Tie-breaking between equally frequent labels is whatever torch.mode does.
        return torch.mode(knn_labels, dim=1).values

    def accuracy(self, X_test, y_test, k=1):
        """Return the fraction of test samples whose k-NN prediction equals ``y_test``."""
        y_pred = self.predict(X_test, k)
        return (y_pred == y_test.to(self.device)).float().mean().item()

    def distances(self):
        """Return a copy of the cached distance matrix, or None before any prediction."""
        return None if self._dist_matrix is None else self._dist_matrix.clone()

    def rankings(self):
        """Return a copy of the cached per-row neighbour ranking, or None before any prediction."""
        return None if self._rank_matrix is None else self._rank_matrix.clone()


def compute_and_save_distance_matrix(metric_fn, X_test, X_train, save_path, metric_name):
    """Compute a dense test-by-train distance matrix, save it, and log the timing.

    Args:
        metric_fn: Callable ``metric_fn(a, b)`` returning a scalar distance as a
            Python number or a single-element tensor.
        X_test: Test samples, one per index along dim 0.
        X_train: Training samples, one per index along dim 0.
        save_path: Destination passed to ``torch.save`` for the resulting
            (N_test, N_train) float32 CPU tensor.
        metric_name: Label used in progress messages and in the timing log.

    Side effects:
        Writes ``save_path``, prints progress every 10 test samples, and
        appends one timing line to ``result.txt`` in the directory of
        ``save_path``.
    """
    n_test, n_train = X_test.shape[0], X_train.shape[0]
    D = torch.empty(n_test, n_train)
    start = time.perf_counter()
    for i in range(n_test):
        x = X_test[i]
        for j in range(n_train):
            # Store a detached host scalar. Assigning the metric's tensor
            # directly would record autograd history in D whenever that tensor
            # requires grad, keeping every metric graph alive until D is freed.
            D[i, j] = float(metric_fn(x, X_train[j]))
        if (i + 1) % 10 == 0 or i == n_test - 1:
            print(f"{metric_name}: Processed {i+1}/{n_test} test samples")
    elapsed = time.perf_counter() - start
    torch.save(D, save_path)
    print(f"Saved {metric_name} distance matrix to {save_path}")
    print(f"==> [{metric_name}] Time elapsed: {elapsed:.2f} seconds")
    with open(os.path.join(os.path.dirname(save_path), "result.txt"), "a") as f:
        f.write(f"{metric_name} distance matrix processing time: {elapsed:.2f} seconds\n")


if __name__ == "__main__":
    # Experiment driver:
    #   1. build (or reload from cache) a fixed train/test split of per-category
    #      point clouds loaded from <parent_dir>/{train,val}/<category>.pt;
    #   2. on random pairs of training clouds, fit via utils.optimal_alpha the
    #      blend weight alpha of alpha*A + (1-alpha)*B against the Wasserstein
    #      distance for three estimator pairs, and report R² scores;
    #   3. compute and save test-by-train distance matrices for every metric so
    #      they can be blended later.

    parser = argparse.ArgumentParser()
    parser.add_argument("--parent_dir", type=str, default="data/point_cloud",
                        help="Directory of normalized point clouds")
    parser.add_argument("--saved_path", type=str, default="saved_knn",
                        help="Where to save results")
    parser.add_argument("--num_train", type=int, default=50, help="Number of train samples per class")
    parser.add_argument("--num_test", type=int, default=100, help="Number of test samples per class")
    args = parser.parse_args()

    num_pairs = 10  # number of random training pairs used to calibrate alpha
    num_train = args.num_train
    num_test = args.num_test
    parent_dir = args.parent_dir
    saved_path = args.saved_path
    os.makedirs(saved_path, exist_ok=True)

    header = f"parent_dir: {parent_dir}\nsaved_path: {saved_path}\nnum_train: {num_train}\nnum_test: {num_test}"
    print(header)
    with open(f"{saved_path}/result.txt", "a") as f:
        f.write(header + " \n")

    categories = ['table', 'chair', 'airplane', 'car', 'sofa', 'rifle', 'lamp', 'vessel', 'bench', 'speaker',
    'cabinet', 'monitor', 'bus', 'bathtub', 'guitar', 'faucet', 'clock', 'pot', 'cellphone', 'jar', 'bottle',
    'telephone', 'laptop', 'bookshelf', 'knife', 'train', 'motorcycle', 'can', 'file', 'pistol', 'bed', 'piano',
    'stove', 'mug', 'bowl', 'washer', 'printer', 'helmet', 'microwave', 'skateboard', 'tower', 'camera', 'basket',
    'tin_can', 'pillow', 'dishwasher', 'mailbox', 'rocket', 'bag',
    'earphone', 'birdhouse', 'microphone', 'remote_control', 'keyboard', 'cap'][:10]

    torch.manual_seed(1)
    np.random.seed(1)

    # Fall back to CPU so the script still runs on machines without a GPU.
    DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
    DTYPE = torch.float32

    X_train_path = f"{saved_path}/X_train.pt"
    y_train_path = f"{saved_path}/y_train.pt"
    X_test_path  = f"{saved_path}/X_test.pt"
    y_test_path  = f"{saved_path}/y_test.pt"
    cache_paths = [X_train_path, y_train_path, X_test_path, y_test_path]

    if all(os.path.exists(p) for p in cache_paths):
        print("==> Loading cached tensors")
        X_train, y_train, X_test, y_test = [torch.load(p, map_location=DEVICE) for p in cache_paths]
    else:
        print("==> Creating new tensors and saving")
        all_point_clouds, all_labels = [], []
        all_point_clouds_test, all_labels_test = [], []
        for idx, thing in enumerate(categories):
            pc = torch.load(f"{parent_dir}/train/{thing}.pt")
            pc_test = torch.load(f"{parent_dir}/val/{thing}.pt")
            # Random subset of at most num_train / num_test clouds per category.
            pc = pc[torch.randperm(pc.shape[0])[:num_train]]
            pc_test = pc_test[torch.randperm(pc_test.shape[0])[:num_test]]

            all_point_clouds.append(pc)
            # NOTE(review): class labels are stored as DTYPE (float32), not an
            # integer dtype; kept so fresh runs match previously cached y_*.pt.
            all_labels.append(torch.full((pc.shape[0],), idx, dtype=DTYPE))
            all_point_clouds_test.append(pc_test)
            all_labels_test.append(torch.full((pc_test.shape[0],), idx, dtype=DTYPE))

        X_train = torch.cat(all_point_clouds, dim=0).to(DEVICE)
        y_train = torch.cat(all_labels, dim=0).to(DEVICE)
        X_test = torch.cat(all_point_clouds_test, dim=0).to(DEVICE)
        y_test = torch.cat(all_labels_test, dim=0).to(DEVICE)

        for tensor, path in zip((X_train, y_train, X_test, y_test), cache_paths):
            torch.save(tensor, path)

    # Re-seed so the projection directions and calibration pairs below do not
    # depend on whether the branch above consumed random numbers (fresh split)
    # or not (cache hit). Without this, reruns with the same seed disagree.
    torch.manual_seed(1)
    np.random.seed(1)

    projection_matrix = generate_uniform_unit_sphere_projections(
        dim=3, requires_grad=False, num_projections=100, dtype=DTYPE, device=DEVICE
    )

    # 2*num_pairs distinct training indices, split into num_pairs disjoint pairs.
    chosen_idx = torch.randperm(X_train.shape[0], device=DEVICE)[:2*num_pairs]
    first_idx = chosen_idx[:num_pairs].tolist()
    second_idx = chosen_idx[num_pairs:].tolist()

    # NOTE(review): the calibration estimators differ from the distance-matrix
    # metrics defined further down: Wasserstein uses the default numItermax here
    # but 100000 there; Min_SWGG uses num_iter=5, s=20 here but 10/10 there;
    # Max-SW uses num_iter=5 here but 10 there. The fitted alphas are therefore
    # tuned for slightly different estimators than the saved matrices. Confirm
    # this is intentional.
    samples = {name: [] for name in ("WD", "SW", "PWD", "EBSW", "EST", "MinSWGG", "MaxSW")}
    for i, j in zip(first_idx, second_idx):
        x, y = X_train[i], X_train[j]
        samples["WD"].append(Wasserstein_Distance(x, y, device=DEVICE).item())
        samples["SW"].append(Sliced_Wasserstein_Distance(x, y, projection_matrix=projection_matrix, device=DEVICE, dtype=DTYPE).item())
        samples["PWD"].append(Projected_Wasserstein_Distance(x, y, projection_matrix=projection_matrix, device=DEVICE, dtype=DTYPE).item())
        samples["EBSW"].append(Energy_based_Sliced_Wasserstein(x, y, projection_matrix=projection_matrix, device=DEVICE, dtype=DTYPE).item())
        samples["EST"].append(Expected_Sliced_Transport(x, y, projection_matrix=projection_matrix, device=DEVICE, dtype=DTYPE).item())
        samples["MinSWGG"].append(Min_SWGG(x, y, lr=5e-2, num_iter=5, s=20, std=0.5, device=DEVICE, dtype=DTYPE)[0].item())
        samples["MaxSW"].append(Max_Sliced_Wasserstein_Distance(x, y, require_optimize=True, lr=1e-1, num_iter=5, device=DEVICE, dtype=DTYPE)[0].item())
    samples = {name: np.array(values) for name, values in samples.items()}
    ws = samples["WD"]

    # (json key, report label, estimator A, estimator B) for alpha*A + (1-alpha)*B.
    blends = [
        ("SWD_PWD", "SWD+PWD", "SW", "PWD"),
        ("EBSW_EST", "EBSW+EST", "EBSW", "EST"),
        ("MaxSW_MinSWGG", "MaxSW+MinSWGG", "MaxSW", "MinSWGG"),
    ]
    alphas, report = {}, []
    for key, label, a_name, b_name in blends:
        a_vals, b_vals = samples[a_name], samples[b_name]
        alpha = optimal_alpha(a_vals, b_vals, ws)
        r2 = r2_score(ws, alpha * a_vals + (1 - alpha) * b_vals)
        alphas[key] = float(alpha)
        report.append(f"Optimal alpha for {label}: {alpha}, R²: {r2:.4f}")
    for name in ("SW", "PWD", "EBSW", "EST", "MaxSW", "MinSWGG"):
        report.append(f"{name}, R²: {r2_score(ws, samples[name]):.4f}")

    with open(f"{saved_path}/optimal_alpha.json", "w") as f:
        json.dump(alphas, f, indent=2)
    with open(f"{saved_path}/result.txt", "a") as f:
        f.writelines(line + "\n" for line in report)
    for line in report:
        print(line)

    # Metric name -> distance function; the name also sets the output file
    # "<saved_path>/<name>_dist.pt". Insertion order is the computation order.
    matrix_metrics = {
        "SWD": lambda x, y: Sliced_Wasserstein_Distance(x, y, projection_matrix=projection_matrix, device=DEVICE, dtype=DTYPE),
        "PWD": lambda x, y: Projected_Wasserstein_Distance(x, y, projection_matrix=projection_matrix, device=DEVICE, dtype=DTYPE),
        "EBSW": lambda x, y: Energy_based_Sliced_Wasserstein(x, y, projection_matrix=projection_matrix, device=DEVICE, dtype=DTYPE),
        "EST": lambda x, y: Expected_Sliced_Transport(x, y, projection_matrix=projection_matrix, device=DEVICE, dtype=DTYPE),
        "MinSWGG": lambda x, y: Min_SWGG(x, y, lr=5e-2, num_iter=10, s=10, std=0.5, device=DEVICE, dtype=DTYPE)[0],
        "MaxSW": lambda x, y: Max_Sliced_Wasserstein_Distance(x, y, require_optimize=True, lr=1e-1, num_iter=10, device=DEVICE, dtype=DTYPE)[0],
        "WD": lambda x, y: Wasserstein_Distance(x, y, numItermax=100000, device=DEVICE),
    }

    print("Saving all metric distance matrices for later alpha blending...")
    for name, metric_fn in matrix_metrics.items():
        compute_and_save_distance_matrix(metric_fn, X_test, X_train, f"{saved_path}/{name}_dist.pt", name)
