import numpy as np
import matplotlib.pyplot as plt
from sklearn.preprocessing import normalize
from sklearn.metrics import roc_auc_score, silhouette_samples
from sklearn.decomposition import PCA
from sklearn.neighbors import KNeighborsClassifier
import pdb
import pickle
import numpy as np
import torch
import os

import scipy.io as sio
import pywavefront
import pandas as pd
from src.utils import calculate_geodesic
from scipy.spatial import Delaunay
import argparse
import pathlib

from src.ewot import EWOT
from src.lwot import LWOT

# Command-line interface. Numeric options declare an explicit ``type`` so that
# values supplied on the command line are converted from strings; without it
# only the defaults would be numeric.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--outpath", default="experiment_data/shape_correspondence/",
    help="outpath of experiment results, both data and figures"
)
parser.add_argument(
    "--datapath", default="data/shape_matching/",
    help="datapath of shape data"
)
parser.add_argument(
    "--lwot", default=False, action=argparse.BooleanOptionalAction,
    help="run lwot if True or run ewot if False"
)
parser.add_argument(
    "--unbalanced", default=False, action=argparse.BooleanOptionalAction,
    help="run unbalanced sinkhorn iterations if True or balanced if False"
)
parser.add_argument(
    "--wavelet_kernel", default="heat",
    help="wavelet kernel to use"
)
parser.add_argument(
    "--n_samples", default=1000, type=int,
    help="number of points to sample from each shape"
)
parser.add_argument(
    "--epsilon", default=0.1, type=float,
    help="entropic regularization parameter"
)
parser.add_argument(
    "--agg_op", default="sum",
    help="aggregation operation for wavelet coefficients scales"
)
parser.add_argument(
    "--n_scales", default=20, type=int,
    help="number of wavelet scales to use"
)
parser.add_argument(
    "--plot", default=True, action=argparse.BooleanOptionalAction,
    help="plot results if True"
)
args = parser.parse_args()

# Create the output directory if it does not exist yet.
pathlib.Path(args.outpath).mkdir(parents=True, exist_ok=True)

# Constants shared by the experiment and plotting code below.
TEST_SET = ["test-set1", "test-set2", "test-set3", "test-set4"]
N_SCALES = args.n_scales
EPSILON = args.epsilon
AGG_OP = args.agg_op
WAVELET_KERNEL = args.wavelet_kernel
DIST = "geodesic"  # forwarded as ``dist=`` to the LWOT/EWOT constructors
RHO1 = 1.0  # forwarded as ``rho=`` to ``solve``
RHO2 = 1.0  # forwarded as ``rho2=`` to ``solve``

def calculate_geo_error(X1, X2, X1_labels, X2_labels, X1_gt, X2_gt, k=5, to_X2=True):
    """Return per-landmark geodesic errors between two shapes, scaled by sqrt(area).

    The ground-truth points whose labels occur in both shapes are appended to
    the sampled point clouds, the clouds are aligned with LWOT or EWOT
    (selected by ``args.lwot``), and the last ``true_correspondences`` rows of
    ``wot.project()`` are compared against the ground-truth points of the
    second shape using the distances returned by ``calculate_geodesic``.

    Args:
        X1: Sampled points of the first shape (array-like, one row per point).
        X2: Sampled points of the second shape; the error slicing assumes it
            has the same number of rows as ``X1``.
        X1_labels: Landmark labels of the first shape.
        X2_labels: Landmark labels of the second shape.
        X1_gt: Landmark coordinates of the first shape, one row per label.
        X2_gt: Landmark coordinates of the second shape, one row per label.
        k: Unused; kept for backward compatibility.
        to_X2: Unused; kept for backward compatibility.

    Returns:
        One error value per shared landmark.

    Raises:
        ValueError: If the shapes share no landmark label, or if the number of
            shared landmarks differs between the two shapes.
    """
    n_samples = len(X1)
    X1_gt = X1_gt.squeeze()
    X2_gt = X2_gt.squeeze()

    # Flatten the labels once so that both (N, 1) and (1, N) layouts produce
    # row indices into the ground-truth coordinate arrays.
    X1_labels = np.ravel(X1_labels)
    X2_labels = np.ravel(X2_labels)
    shared_labels = list(set(X2_labels.tolist()) & set(X1_labels.tolist()))

    X1_indices = np.where(np.isin(X1_labels, shared_labels))[0]
    X2_indices = np.where(np.isin(X2_labels, shared_labels))[0]

    true_correspondences = len(X1_indices)
    if true_correspondences == 0:
        # ``X2[:-0]`` below would be empty and fail obscurely inside Delaunay.
        raise ValueError("the two shapes share no landmark labels")
    if len(X2_indices) != true_correspondences:
        # The error slice below relies on the same landmark count in both shapes.
        raise ValueError(
            "shared landmark count differs between shapes: "
            f"{true_correspondences} vs {len(X2_indices)}"
        )
    # NOTE(review): landmarks are paired by position after filtering, which
    # presumably relies on both label arrays being ordered alike - confirm.

    X1 = torch.from_numpy(np.concatenate([X1, X1_gt[X1_indices]]))
    X2 = torch.from_numpy(np.concatenate([X2, X2_gt[X2_indices]]))

    solver_cls = LWOT if args.lwot else EWOT
    wot = solver_cls(X1, X2, n_scales=N_SCALES, w_op=WAVELET_KERNEL, dist=DIST)
    wot.solve(epsilon=EPSILON, agg_op=AGG_OP, balanced=(not args.unbalanced), rho=RHO1, rho2=RHO2)

    # Append the projected landmarks so they take part in the distance matrix.
    # (Baselines such as UnionCom and Pamona were previously swapped in here by
    # substituting their aligned output for ``wot.project()``.)
    aligned_point_X2 = wot.project()
    X2 = torch.cat([X2, aligned_point_X2[-true_correspondences:]])

    dist_matrix = calculate_geodesic(X2)

    # Surface-area estimate from the second shape (samples + landmarks only,
    # without the projected points appended above).
    surface_points = X2[:-true_correspondences].detach().cpu().numpy()
    triangulation = Delaunay(surface_points)
    # NOTE(review): for 3-D input Delaunay yields tetrahedra; only the first
    # three vertices of each simplex are used as a triangle here - confirm
    # this is the intended area estimate.
    surface_triangles = surface_points[triangulation.simplices]
    edge_a = surface_triangles[:, 1] - surface_triangles[:, 0]
    edge_b = surface_triangles[:, 2] - surface_triangles[:, 0]
    # axis=1 gives one area per triangle; without it the norm collapses every
    # cross product into a single Frobenius norm and the sum is meaningless.
    triangle_areas = np.linalg.norm(np.cross(edge_a, edge_b), axis=1) / 2
    surface_area = np.sum(triangle_areas)

    # Rows: ground-truth landmarks of shape 2; columns: projected landmarks.
    # The diagonal pairs the i-th landmark with its i-th projection.
    err = dist_matrix[n_samples:-true_correspondences, n_samples + true_correspondences:].diagonal() / np.sqrt(surface_area)

    return err

def _load_sampled_shape(name):
    """Load the landmark data and a random vertex sample for the shape ``name``.

    Reads ``shape_gts/<name>.mat`` and ``shape_data/models/<name>.obj`` below
    ``args.datapath`` and draws ``args.n_samples`` vertex indices with
    ``np.random.choice``. Returns ``(sampled_vertices, centroids, points)``.
    """
    gt = sio.loadmat(os.path.join(args.datapath, f"shape_gts/{name}.mat"))
    centroid_coords = gt['centroids']
    centroid_indices = gt['points']
    mesh = pywavefront.Wavefront(os.path.join(args.datapath, f'shape_data/models/{name}.obj'), collect_faces=True)
    mesh_vertices = mesh.vertices
    picked = np.random.choice(np.arange(0, len(mesh_vertices)), args.n_samples)
    return np.array(mesh_vertices)[picked], centroid_coords, centroid_indices


# Evaluate every animal pair listed in each test-set file and pickle its errors.
for test_file in TEST_SET:
    pair_list_path = os.path.join(args.datapath, "shape_data/test-sets", test_file + ".txt")
    with open(pair_list_path, "r") as f:
        animal_pairs = [row.strip().split(",") for row in f.readlines()]

    for animal_pair in animal_pairs:
        animal1 = animal_pair[0]
        animal2 = animal_pair[1]

        # Re-seed per pair so the vertex samples are reproducible; the first
        # shape is sampled before the second.
        np.random.seed(0)
        X1, X1_centroid_coords, X1_centroid_indices = _load_sampled_shape(animal1)
        X2, X2_centroid_coords, X2_centroid_indices = _load_sampled_shape(animal2)

        errors = calculate_geo_error(
            X1, X2,
            X1_centroid_indices, X2_centroid_indices,
            X1_centroid_coords, X2_centroid_coords,
        )
        print(f"All errors: {errors}")
        print(f"Mean error: {errors.mean()}")
        print(f"Animal pair: {animal_pair}")

        # Results are stored per kernel and per test set, one file per pair.
        result_dir = pathlib.Path(args.outpath, args.wavelet_kernel, test_file)
        result_dir.mkdir(parents=True, exist_ok=True)
        with open(result_dir / f"{animal1}_{animal2}.txt", "wb") as f:
            pickle.dump(errors, f)

if args.plot:
    # For each test set: gather the pickled per-pair errors written by the
    # experiment loop, print summary statistics and draw the cumulative error
    # curve. Curves for baselines and other kernels were previously plotted
    # alongside; only the current wavelet kernel's results are loaded here.
    for test_file in TEST_SET:
        pairs_path = os.path.join(args.datapath, "shape_data/test-sets", test_file + ".txt")
        with open(pairs_path, "r") as pairs_file:
            animal_pairs = [line.strip().split(",") for line in pairs_file]

        curr_our_errors = []
        for animal1, animal2 in animal_pairs:
            result_path = os.path.join(args.outpath, WAVELET_KERNEL, test_file, f"{animal1}_{animal2}.txt")
            with open(result_path, "rb") as result_file:
                # NOTE: pickle must only be used on files produced by this
                # script; never point --outpath at untrusted data.
                curr_our_errors.extend(np.ravel(pickle.load(result_file)))

        if not curr_our_errors:
            # Nothing to summarise; avoids a division by zero below.
            print(f"No errors found for {test_file}; skipping plot.")
            continue

        print(f"Test File: {test_file}")
        print(f"Wavelet Kernel: {WAVELET_KERNEL}")
        print(f"Our Mean: {np.mean(curr_our_errors)}")

        curr_our_errors = np.array(curr_our_errors)

        print(f"Test File: {test_file}")
        print(f"Wavelet Kernel: {WAVELET_KERNEL}")
        print(f"Our < 0.25: {np.count_nonzero(curr_our_errors < 0.25) / len(curr_our_errors)}")

        # Cumulative distribution of the errors, expressed in percent.
        our_values, our_base = np.histogram(curr_our_errors, bins=200)
        our_running_total = np.cumsum(our_values)
        our_cumulative = (our_running_total / our_running_total.max()) * 100

        plt.title("Cumulative Relative Geodesic Error (SHREC20)")
        plt.plot(our_base[:-1], our_cumulative, c='cyan', label=f"WOT_{WAVELET_KERNEL}")
        plt.xlabel("Relative Geodesic Error")
        plt.ylabel("% Matches")

        plt.xlim(0, 1)
        plt.ylim(0, 100)
        plt.xticks(np.arange(0, 1.1, 0.1))
        plt.yticks(np.arange(0, 101, 20))
        plt.legend()
        plt.grid(True)
        # Save before showing: once show() returns the figure may already be
        # released, which would write an empty image.
        plt.savefig(os.path.join(args.outpath, WAVELET_KERNEL, test_file, "cum_plot.png"))
        plt.show()
        plt.close()