from src.wavelets import calculate_wavelet_coeffs
import numpy as np
import matplotlib.pyplot as plt
from sklearn.preprocessing import normalize
from sklearn.metrics import roc_auc_score, silhouette_samples
from sklearn.decomposition import PCA
from sklearn.neighbors import KNeighborsClassifier
import pdb
import pickle
import numpy as np
import torch
import os
import argparse
from src.vanilla_wot import VanillaWOT
import ot
import pathlib

# Command-line interface. Numeric options declare an explicit `type=` so values
# passed on the command line are converted; without it argparse keeps them as
# strings, which breaks range()/np.linspace()/solver calls downstream.
parser = argparse.ArgumentParser()
parser.add_argument(
	"--outpath", default="experiment_data/toy_set/", 
	help="outpath of experiment results, both data and figures"
)
parser.add_argument(
	"--datapath", default="data", 
	help="datapath of bifurcation data"
)
parser.add_argument(
	'--dropout', default=False, action=argparse.BooleanOptionalAction, 
	help="run dropout experiment if True"
)
parser.add_argument(
	'--noise', default=False, action=argparse.BooleanOptionalAction, 
	help="run noising experiment if True"
)
parser.add_argument('--plot', default=True, action=argparse.BooleanOptionalAction, help="plot results if True")
parser.add_argument(
	'--num_trials', type=int, default=10, help="number of trials to run for confidence intervals"
)
parser.add_argument(
	'--max_noise', type=float, default=0.15, help="max noise to add, value is in fraction of mean distance between points"
)
parser.add_argument(
	'--max_dropout', type=float, default=0.9, help="max fraction to dropout from the dataset"
)
parser.add_argument(
	'--num_intervals', type=int, default=20, help="number of intervals to run for noise and dropout experiments between min and max values"
)
parser.add_argument(
	"--epsilon_noise", type=float, default=1e-3, 
	help="entropic regularization parameter for noise experiment"
)
parser.add_argument(
	"--epsilon_dropout", type=float, default=1e-4, 
	help="entropic regularization parameter for dropout experiment"
)
parser.add_argument(
	"--agg_op", default="mean", 
	help="aggregation operation for wavelet coefficients scales"
)
parser.add_argument(
	"--n_scales", type=int, default=20, 
	help="number of wavelet scales to use"
)
args = parser.parse_args()

"""
Evaluation code below is largely adapted from SCOT (Ritambhara Singh, Pinar Demetci, Rebecca Santorella)
"""
def _to_numpy(mat):
	"""Return ``mat`` as a NumPy array; torch tensors are detached and moved to CPU first."""
	if torch.is_tensor(mat):
		return mat.detach().cpu().numpy()
	return np.asarray(mat)


def calc_frac_idx(x1_mat,x2_mat, true_dict):
	"""
	Returns fraction closer than true match for each sample (as an array).

	The true match of ``x1_mat[i]`` is taken to be ``x2_mat[i]`` (same row index).
	For each row ``i`` the fraction is ``rank / (nsamp - 1)``. Here ``rank`` is the
	number of ``x2_mat`` rows strictly closer (Euclidean) to ``x1_mat[i]`` than
	``x2_mat[i]`` is, and ``nsamp = x1_mat.shape[0]``.

	Args:
		x1_mat: 2-D NumPy array or torch tensor of query points.
		x2_mat: 2-D NumPy array or torch tensor of candidate points; needs at
			least as many rows as ``x1_mat``.
		true_dict: unused; kept for signature compatibility.

	Returns:
		``(fracs, x)``. ``fracs`` is a list of floats, one per row of ``x1_mat``.
		``x`` is the 1-based row numbers ``[1, ..., nsamp]``.
	"""
	# Normalise both inputs to NumPy so mixed numpy/torch inputs (and pure numpy
	# inputs) are handled uniformly.
	x1 = _to_numpy(x1_mat)
	x2 = _to_numpy(x2_mat)
	nsamp = x1.shape[0]
	# With a single sample there is no other candidate; avoid dividing by zero.
	denom = max(nsamp - 1, 1)

	fracs = []
	for row_idx in range(nsamp):
		euc_dist = np.sqrt(np.sum(np.square(x1[row_idx, :] - x2), axis=1))
		# Position of the true distance in the ascending-sorted distances equals
		# the number of candidates that are strictly closer.
		rank = int(np.count_nonzero(euc_dist < euc_dist[row_idx]))
		fracs.append(float(rank) / denom)

	return fracs, list(range(1, nsamp + 1))

def calc_domainAveraged_FOSCTTM(x1_mat, x2_mat, true_dict):
	"""
	Per-sample FOSCTTM averaged over both matching directions.

	``calc_frac_idx`` is run twice. The first run uses ``x1_mat`` as queries
	against ``x2_mat``; the second run reverses the roles and passes the inverted
	``true_dict``. Row ``i`` of each input is treated as the true match of row
	``i`` of the other.

	Returns:
		List whose i-th entry is the mean of the two directional fractions for
		sample i.
	"""
	forward, _ = calc_frac_idx(x1_mat, x2_mat, true_dict)
	inverse_dict = {value: key for key, value in true_dict.items()}
	backward, _ = calc_frac_idx(x2_mat, x1_mat, inverse_dict)
	return [(fwd + bwd) / 2 for fwd, bwd in zip(forward, backward)]

# CONSTANTS
# Run-wide settings; everything except the wavelet kernel comes from the CLI.
WAVELET_KERNEL = "simple_tight"  # passed to VanillaWOT as `w_op`
N_SCALES = args.n_scales  # passed to VanillaWOT as `n_scales`
AGG_OP = args.agg_op  # passed to wot.solve as `agg_op`
# Entropic regularization (`epsilon` for wot.solve), one value per experiment.
NOISE_EPSILON, DROPOUT_EPSILON = args.epsilon_noise, args.epsilon_dropout

if args.noise:
	# EXPERIMENT DATA STORES
	# Noise multipliers to sweep. Each key collects one score per trial: the mean
	# of the two directional FOSCTTM values.
	noise_multipliers = np.linspace(0.0, args.max_noise, num=args.num_intervals)
	wot_info = {x: [] for x in noise_multipliers}

	# TOY DATASETS
	clean_X1 = np.genfromtxt(os.path.join(args.datapath, "simulations/s1_mapped1.txt"))
	clean_X2 = np.genfromtxt(os.path.join(args.datapath, "simulations/s1_mapped2.txt"))

	# Mean pairwise Euclidean distance per domain; the noise scale is a multiple of it.
	X1_avg_dist = torch.cdist(torch.from_numpy(clean_X1), torch.from_numpy(clean_X1)).mean().item()
	X2_avg_dist = torch.cdist(torch.from_numpy(clean_X2), torch.from_numpy(clean_X2)).mean().item()

	for _ in range(args.num_trials):
		for multiplier in noise_multipliers:
			# Random row permutation applied to X2 before alignment (presumably so the
			# solver cannot exploit the row correspondence).
			X_2_permutation_indices = np.random.permutation(clean_X2.shape[0])
			X_2_permutation_matrix = np.eye(clean_X2.shape[0])[X_2_permutation_indices]

			# NOTE(review): np.random.normal's second argument is the standard
			# deviation, although these values are named (and plotted) as variance.
			X1_var = multiplier * X1_avg_dist
			X2_var = multiplier * X2_avg_dist

			X1 = clean_X1 + np.random.normal(0.0, X1_var, clean_X1.shape)
			X2 = clean_X2 + np.random.normal(0.0, X2_var, clean_X2.shape)

			wot = VanillaWOT(X1, X_2_permutation_matrix @ X2, n_scales=N_SCALES, w_op=WAVELET_KERNEL, dist="euclidean")
			wot.solve(epsilon=NOISE_EPSILON, agg_op=AGG_OP)

			# Row i of aligned_point_X1 is scored against the identically permuted X1;
			# row i of aligned_point_X2 is scored against the unpermuted X2.
			aligned_point_X1 = wot.project(to_X2=False)
			aligned_point_X2 = wot.project()

			FOSCTTM_X1 = np.mean(calc_domainAveraged_FOSCTTM(X_2_permutation_matrix @ X1, aligned_point_X1, {}))
			FOSCTTM_X2 = np.mean(calc_domainAveraged_FOSCTTM(X2, aligned_point_X2, {}))
			print(f"-------Variance Multiplier: {multiplier}------------")
			print(f"FOSCTTM X1: {FOSCTTM_X1}")
			print(f"FOSCTTM X2: {FOSCTTM_X2}")

			wot_info[multiplier].append(np.mean([FOSCTTM_X1, FOSCTTM_X2]))

	# DUMPING EXPERIMENT DATA
	pathlib.Path(args.outpath).mkdir(parents=True, exist_ok=True)
	with open(os.path.join(args.outpath, "wot-simple.pickle"), "wb") as f:
		pickle.dump(wot_info, f)

	if args.plot:
		# Median per noise level, with asymmetric error bars spanning the
		# interquartile range (25th to 75th percentile) across trials.
		x_wot = list(wot_info.keys())
		y_wot = np.array([np.median(wot_info[x]) for x in x_wot])
		y_wot_upper = np.array([np.percentile(wot_info[x], 75) for x in x_wot]) - y_wot
		y_wot_lower = y_wot - np.array([np.percentile(wot_info[x], 25) for x in x_wot])

		# PLOTTING RESULTS
		plt.title("Mean FOSCTTM vs Additive Gaussian Noise")
		plt.errorbar(x_wot, y_wot, c='cyan', fmt='^', ecolor = 'darkcyan', capsize=4,yerr=[y_wot_lower, y_wot_upper], linewidth=1, linestyle='-', label="Wavelet Optimal Transport (Simple Tight)")
		plt.xlabel("Variance of Gaussian Noise (in Fraction of Mean Distance Between Points)")
		plt.ylabel("Mean FOSCTTM")

		plt.xlim(0, args.max_noise + 0.001)
		plt.ylim(0, 0.5)
		plt.xticks(np.arange(0, args.max_noise + 0.01, 0.03))
		plt.yticks(np.arange(0, 0.51, 0.1))
		plt.legend(loc="lower right")
		plt.grid(True)
		# Save before show(): once show() returns, the figure may already be
		# released and savefig() would write an empty canvas.
		plt.savefig(os.path.join(args.outpath, "noise_vs_foscttm.png"))
		plt.show()
		plt.close()

#########################
# DROPOUT
#########################
if args.dropout:
	# EXPERIMENT DATA STORES
	# Dropout fractions to sweep. Each key collects one score per trial: the mean
	# of the two directional FOSCTTM values.
	dropout_fractions = np.linspace(0.0, args.max_dropout, num=args.num_intervals)
	wot_info = {x: [] for x in dropout_fractions}

	# TOY DATASETS
	clean_X1 = np.genfromtxt(os.path.join(args.datapath, "simulations/s1_mapped1.txt"))
	clean_X2 = np.genfromtxt(os.path.join(args.datapath, "simulations/s1_mapped2.txt"))

	# Mean pairwise Euclidean distance per domain; used as the scale of the corrupting noise.
	X1_avg_dist = torch.cdist(torch.from_numpy(clean_X1), torch.from_numpy(clean_X1)).mean().item()
	X2_avg_dist = torch.cdist(torch.from_numpy(clean_X2), torch.from_numpy(clean_X2)).mean().item()

	for _ in range(args.num_trials):
		for multiplier in dropout_fractions:
			# Independently select int(multiplier * n) sample indices in each domain.
			# NOTE(review): only samples selected in BOTH domains are corrupted. When
			# both domains have n rows, the expected corrupted fraction is therefore
			# about multiplier**2, not multiplier.
			X1_dropout_count = int(multiplier * len(clean_X1))
			X2_dropout_count = int(multiplier * len(clean_X2))
			X1_selection = np.sort(np.random.choice(len(clean_X1), X1_dropout_count, replace=False))
			X2_selection = np.sort(np.random.choice(len(clean_X2), X2_dropout_count, replace=False))

			# `intersect` holds sample (row) indices present in both selections. Use it
			# directly as the row index, not its positions within the selection arrays.
			intersect = np.intersect1d(X1_selection, X2_selection)

			# Corrupted rows are excluded from scoring in both domains.
			X1_mask = np.ones(len(clean_X1), dtype=bool)
			X1_mask[intersect] = False
			X2_mask = np.ones(len(clean_X2), dtype=bool)
			X2_mask[intersect] = False

			# Corrupt the shared rows with Gaussian noise scaled by the mean pairwise distance.
			X1 = np.array(clean_X1)
			X2 = np.array(clean_X2)
			X1[intersect] = X1[intersect] + np.random.normal(0.0, X1_avg_dist, X1[intersect].shape)
			X2[intersect] = X2[intersect] + np.random.normal(0.0, X2_avg_dist, X2[intersect].shape)

			wot = VanillaWOT(X1, X2, n_scales=N_SCALES, w_op=WAVELET_KERNEL, dist="euclidean")
			wot.solve(epsilon=DROPOUT_EPSILON, agg_op=AGG_OP)

			aligned_point_X1 = wot.project(to_X2=False)
			aligned_point_X2 = wot.project()

			# If every row was corrupted there is nothing left to score; fall back to
			# scoring all rows instead of averaging an empty list.
			if not X1_mask.any() or not X2_mask.any():
				X1_mask = np.ones(X1.shape[0], dtype=bool)
				X2_mask = np.ones(X2.shape[0], dtype=bool)

			FOSCTTM_X1 = np.mean(calc_domainAveraged_FOSCTTM(X1[X1_mask], aligned_point_X1[X1_mask], {}))
			FOSCTTM_X2 = np.mean(calc_domainAveraged_FOSCTTM(X2[X2_mask], aligned_point_X2[X2_mask], {}))
			print(f"-------Dropout Multiplier: {multiplier}------------")
			print(f"FOSCTTM X1: {FOSCTTM_X1}")
			print(f"FOSCTTM X2: {FOSCTTM_X2}")

			wot_info[multiplier].append(np.mean([FOSCTTM_X1, FOSCTTM_X2]))

	pathlib.Path(args.outpath).mkdir(parents=True, exist_ok=True)
	with open(os.path.join(args.outpath, "wot-dropout.pickle"), "wb") as f:
		pickle.dump(wot_info, f)

	if args.plot:
		# Median per dropout fraction, with asymmetric error bars spanning the
		# interquartile range (25th to 75th percentile) across trials.
		x_wot = list(wot_info.keys())
		y_wot = np.array([np.median(wot_info[x]) for x in x_wot])
		y_wot_upper = np.array([np.percentile(wot_info[x], 75) for x in x_wot]) - y_wot
		y_wot_lower = y_wot - np.array([np.percentile(wot_info[x], 25) for x in x_wot])

		plt.title("Mean FOSCTTM vs Dropout")
		plt.errorbar(x_wot, y_wot, c='cyan', fmt='^', ecolor = 'darkcyan', capsize=4,yerr=[y_wot_lower, y_wot_upper], linewidth=1, linestyle='-', label="Wavelet Optimal Transport (Simple Tight)")
		plt.xlabel("Dropout Fraction")
		plt.ylabel("Mean FOSCTTM")

		plt.xlim(0, args.max_dropout + 0.01)
		plt.ylim(0, 0.5)
		plt.xticks(np.arange(0, args.max_dropout + 0.01, 0.1))
		plt.yticks(np.arange(0, 0.51, 0.1))
		plt.legend(loc="lower right")
		plt.grid(True)
		# Save before show(): once show() returns, the figure may already be
		# released and savefig() would write an empty canvas.
		plt.savefig(os.path.join(args.outpath, "dropout_vs_foscttm.png"))
		plt.show()
		plt.close()