import pandas as pd
import numpy as np
import os
from tqdm import tqdm
import sys
sys.path.append("..")
from data import get_data

from sklearn import preprocessing, metrics

import itertools
import argparse

def get_args():
	"""Parse the command-line options of the model-selection script.

	``--all_kernels`` and ``--all_distances`` default to empty lists so that
	either family can be omitted and the two lists can still be concatenated
	by the caller. ``--methods`` and ``--all_datasets`` are required, because
	without them there is nothing to evaluate.

	Returns:
		argparse.Namespace: the parsed options.
	"""
	parser = argparse.ArgumentParser()
	parser.add_argument("--result_folder", type=str, default=".",
		help="Root folder holding one sub-folder of clustering CSVs per affinity.")
	parser.add_argument("--path_to_data", type=str, default="../data/datasets",
		help="Folder handed to get_data to load the datasets.")
	parser.add_argument("--output_csv", type=str, default="model_selection.csv",
		help="Path of the CSV file the scores are written to.")
	# nargs="+" options default to None when omitted; an empty list keeps
	# `all_kernels + all_distances` valid when only one family is given.
	parser.add_argument("--all_kernels", type=str, nargs="+", default=[],
		help="Kernel affinities to evaluate (sub-folder names of --result_folder).")
	parser.add_argument("--all_distances", type=str, nargs="+", default=[],
		help="Distance affinities to evaluate (sub-folder names of --result_folder).")
	parser.add_argument("--methods", type=str, nargs="+", required=True,
		help="Clustering method names used in the result file names.")
	parser.add_argument("--all_datasets", type=str, nargs="+", required=True,
		help="Dataset names to load and evaluate.")
	parser.add_argument("--n_bootstrap", type=int, default=20,
		help="Number of bootstrap result files probed per clustering.")

	return parser.parse_args()


def compute_wcss(X, y, affinity="euclidean"):
	"""Return the within-cluster sum of squares (WCSS) of a labelled dataset.

	Args:
		X: 2-D array with one sample per row.
		y: 1-D array of cluster labels, one per row of ``X``.
		affinity: ``"euclidean"`` or ``"sqeuclidean"`` select the ordinary
			input-space WCSS. Any other value is passed as ``metric`` to
			``sklearn.metrics.pairwise_kernels`` and the quantity
			``trace(K) - sum_c (1/n_c) * sum_{i,j in c} K[i, j]`` is returned.

	Returns:
		The WCSS as a float (``0`` when there are no samples).

	Raises:
		ValueError: if ``y`` does not hold exactly one label per row of ``X``.
	"""
	if len(y) != X.shape[0]:
		raise ValueError(
			f"X has {X.shape[0]} samples but y has {len(y)} labels"
		)

	wcss = 0
	if affinity == "euclidean" or affinity == "sqeuclidean":
		for label in np.unique(y):
			cluster_indices, = np.where(y == label)
			members = X[cluster_indices]
			# sum_{i,j} ||x_i - x_j||^2 / (2 n) == sum_i ||x_i - mean||^2, so the
			# centroid form gives the same value without building the
			# n x n pairwise distance matrix of the cluster.
			wcss += ((members - members.mean(axis=0)) ** 2).sum()
	else:
		data_kernel = metrics.pairwise_kernels(X, metric=affinity)
		# Start from the sum of self-similarities, then subtract each
		# cluster's mean-normalised block sum.
		wcss = np.diag(data_kernel).sum()
		for label in np.unique(y):
			cluster_indices, = np.where(y == label)

			cluster_kernel = data_kernel[cluster_indices][:, cluster_indices]
			wcss -= cluster_kernel.sum() / len(cluster_indices)
	return wcss


def main():
	"""Score every saved clustering and write the scores to a CSV file.

	For each (affinity, method, dataset) combination and each K in 2..15, the
	file ``<result_folder>/<affinity>/<dataset>_<method>_k<K>.csv`` is read,
	if it exists, as the predicted labels for the standard-scaled dataset.
	Each file found yields one output row with the silhouette score, the
	Davies-Bouldin score and the log of the within-cluster sum of squares.

	Sibling files ``..._k<K>_bootstrap<i>.csv`` (``i < n_bootstrap``), when
	present, are read with the ``y`` column as labels and every other column
	as features; the log-WCSS of each one is stored under ``BootWCSS_<i>``.
	"""
	args = get_args()

	# argparse leaves an omitted nargs="+" option as None; treat a missing
	# affinity family as empty so the other one can be used on its own.
	affinities = (args.all_kernels or []) + (args.all_distances or [])

	datasets = {}
	for name in args.all_datasets:
		# Re-apply standard scaler; the ground-truth targets are not used here.
		X, _ = get_data(name, args.path_to_data)
		datasets[name] = preprocessing.StandardScaler().fit_transform(X)

	# Materialise the product so tqdm knows the total and can show progress.
	combinations = list(itertools.product(affinities, args.methods, args.all_datasets))

	results = []
	for affinity, method, dataset in tqdm(combinations):
		data = datasets[dataset]

		# Iterate over the requested number of clusters
		for k in range(2, 16):
			file_name = os.path.join(args.result_folder, affinity, f"{dataset}_{method}_k{k}.csv")
			if not os.path.exists(file_name):
				continue

			y_pred = pd.read_csv(file_name).to_numpy().reshape(-1)

			tmp_result = {"Method": method, "Affinity": affinity, "Dataset": dataset, "K": k}

			# Compute the silhouette score
			silhouette = metrics.silhouette_score(data, y_pred)

			# Compute the davies-bouldin score
			davies_bouldin = metrics.davies_bouldin_score(data, y_pred)

			# Compute the within-cluster sum of squares
			wcss = compute_wcss(data, y_pred, affinity)

			tmp_result["WCSS"] = np.log(wcss)
			tmp_result["DB"] = davies_bouldin
			tmp_result["Silhouette"] = silhouette

			# Split off the extension once so only the file's own suffix is
			# touched, even if a folder name happens to contain ".csv".
			stem, extension = os.path.splitext(file_name)
			for bootstrap in range(args.n_bootstrap):
				bootstrap_file_name = f"{stem}_bootstrap{bootstrap}{extension}"
				if not os.path.exists(bootstrap_file_name):
					continue

				bootstrap_csv = pd.read_csv(bootstrap_file_name)

				data_columns = [x for x in bootstrap_csv.columns if x != "y"]
				X_bootstrap = bootstrap_csv[data_columns].to_numpy()
				y_bootstrap = bootstrap_csv["y"].to_numpy().reshape(-1)

				bootstrap_wcss = compute_wcss(X_bootstrap, y_bootstrap, affinity)
				tmp_result[f"BootWCSS_{bootstrap}"] = np.log(bootstrap_wcss)

			results.append(tmp_result)

	result_df = pd.DataFrame(results)

	result_df.to_csv(args.output_csv, index=False)


# Run the evaluation only when executed as a script, not when imported.
if __name__ == "__main__":
	main()
