import glob
import json
import os
import warnings

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

# Directory (relative to the working directory) that holds the per-seed result CSVs
# and the SNR JSON read by this script; the generated figures are written here too.
base_path = "03_results/paper_results/unimodal"

def load_all_seeds(exp_type, base_dir=None):
	"""Load and concatenate the per-seed result CSVs of one experiment type.

	Files are matched with the pattern
	``unimodal_param_robustness_exp-{exp_type}_seed-*.csv``.

	Args:
		exp_type: Experiment tag embedded in the file name.
		base_dir: Directory to search. Defaults to the module-level ``base_path``.

	Returns:
		A single DataFrame containing the rows of every matching file with a
		fresh 0..n-1 index, or an empty DataFrame (no columns) when no file matches.
	"""
	if base_dir is None:
		base_dir = base_path
	pattern = os.path.join(base_dir, f"unimodal_param_robustness_exp-{exp_type}_seed-*.csv")
	# glob yields matches in arbitrary, filesystem-dependent order; sorting makes
	# the row order of the concatenated frame reproducible across machines.
	files = sorted(glob.glob(pattern))
	if not files:
		return pd.DataFrame()
	return pd.concat([pd.read_csv(f) for f in files], ignore_index=True)

# Rank marked by the dotted red reference line in every panel; also used as the
# baseline rank when the 'data' experiment contains no baseline rows.
_REFERENCE_RANK = 5

# Hand-written tick labels for the middle panel, applied positionally to the
# parameter names in their order of first appearance in the data.
_PARAM_TICK_LABELS = ['n', 'distribution', 'connectivity', 'nonlinearity depth', 'nonlinearity type']


def _require_results(df, exp_type):
	"""Return ``df``, failing with a clear message when nothing was loaded for ``exp_type``."""
	# load_all_seeds returns a column-less frame when no file matched; without
	# this check the first column access would raise an opaque KeyError.
	if len(df.columns) == 0:
		raise FileNotFoundError(
			f"No result files found for experiment '{exp_type}' in '{base_path}'"
		)
	return df


def _rank_stats(df, by="param_value", sort=True):
	"""Return the per-group mean and SEM of ``final_ranks`` as columns 'mean' and 'sem'."""
	return df.groupby(by, sort=sort)['final_ranks'].agg(['mean', 'sem'])


def _mark_reference_rank(ax):
	"""Draw the dotted red horizontal reference line at ``_REFERENCE_RANK`` on ``ax``."""
	ax.axhline(_REFERENCE_RANK, color='red', linestyle='dotted', linewidth=2)


def _plot_sample_size_panel(ax, df_sample):
	"""Plot mean ± SEM of the final rank against sample size on a log x-axis."""
	# groupby sorts the group keys, so the points come out ordered by sample size
	stats = _rank_stats(df_sample[df_sample['param_name'] == 'n_samples'])
	ax.errorbar(stats.index, stats['mean'].values, yerr=stats['sem'].values, marker='o', color='blue', capsize=6, markersize=10, linewidth=3)
	ax.set_xscale('log')
	ax.set_xlabel('Sample Size')
	ax.set_ylabel('Final Rank')
	ax.set_title('')
	ax.grid(True, alpha=0.3)
	_mark_reference_rank(ax)


def _prepare_param_data(df_data):
	"""Return the parameter-sweep rows with string ``param_value`` and a ``rank_deviation`` column.

	``rank_deviation`` is the absolute difference between ``final_ranks`` and the
	mean rank of the baseline rows (rows whose ``param_name`` is missing), or
	``_REFERENCE_RANK`` when there are no baseline rows.
	"""
	# .copy() so the column assignments below act on an independent frame
	# instead of a derived view of df_data.
	plot_data = df_data.dropna(subset=['param_name', 'param_value']).copy()
	plot_data['param_value'] = plot_data['param_value'].astype(str)
	baseline_rows = df_data[df_data['param_name'].isna()]
	if baseline_rows.empty:
		target_rank = _REFERENCE_RANK
	else:
		target_rank = baseline_rows['final_ranks'].mean()
	plot_data['rank_deviation'] = np.abs(plot_data['final_ranks'] - target_rank)
	return plot_data


def _param_tick_labels(param_names):
	"""Return ``_PARAM_TICK_LABELS`` when its length matches the data, else the raw parameter names."""
	if len(param_names) == len(_PARAM_TICK_LABELS):
		return list(_PARAM_TICK_LABELS)
	# A length mismatch would make set_xticklabels raise, so fall back to the
	# names found in the data and tell the user the pretty labels were dropped.
	warnings.warn(
		f"Expected {len(_PARAM_TICK_LABELS)} swept parameters but found {len(param_names)}; "
		"using raw parameter names as tick labels"
	)
	return [str(name) for name in param_names]


def _plot_param_panel(ax, plot_data, plot_type):
	"""Draw the middle panel: per-parameter mean rank ± SEM ("point") or a box plot of rank deviations."""
	param_names = plot_data["param_name"].unique()
	if plot_type == "point":
		positions = list(range(len(param_names)))
		# sort=False keeps groups in order of first appearance, matching unique()
		stats = _rank_stats(plot_data, by="param_name", sort=False).reindex(param_names)
		ax.errorbar(positions, stats['mean'].values, yerr=stats['sem'].values, fmt='o', color='C0', capsize=6, markersize=10, linewidth=3)
		ax.set_xticks(positions)
		ax.set_ylabel("Final Rank")
		ax.grid(True, alpha=0.3)
	else:
		sns.boxplot(
			data=plot_data,
			x="param_name",
			y="rank_deviation",
			ax=ax,
			palette="tab10",
			showfliers=False
		)
		ax.set_ylabel("Δ Rank")
		ax.tick_params(axis='x', rotation=30)
	ax.set_xticklabels(_param_tick_labels(param_names), rotation=30)
	ax.set_xlabel("Parameter Name")
	ax.set_title("")
	ax.legend([], [], frameon=False)
	# NOTE(review): in the deviation plot a reference line at rank 5 may not be
	# meaningful (the y-axis is |rank - baseline|) — confirm the intended value.
	_mark_reference_rank(ax)


def _load_snr_table(path):
	"""Return the JSON mapping stored at ``path``, or an empty dict when the file does not exist."""
	if not os.path.exists(path):
		return {}
	with open(path, "r") as f:
		return json.load(f)


def _lookup_snr(snr_table, noise_value):
	"""Return the finite SNR recorded for ``noise_value``, or NaN when it is zero-noise, unknown or non-finite."""
	if noise_value == 0.0:
		# zero noise is excluded from the log-scaled SNR axis
		return np.nan
	snr = snr_table.get(str(noise_value))
	if snr is None:
		# JSON object keys are always strings, so fall back to a numeric
		# comparison to match keys that are formatted differently (e.g. "1" vs 1.0).
		for key, candidate in snr_table.items():
			try:
				matches = float(key) == float(noise_value)
			except (TypeError, ValueError):
				continue
			if matches:
				snr = candidate
				break
	if snr is None:
		return np.nan
	snr = float(snr)
	return snr if np.isfinite(snr) else np.nan


def _plot_noise_panel(ax, df_noise):
	"""Draw the right panel: rank vs. SNR (bottom, inverted log axis) and rank vs. sparsity (top axis)."""
	snr_table = _load_snr_table(os.path.join(base_path, "signal_to_noise_ratios.json"))
	noise_stats = _rank_stats(df_noise[df_noise['param_name'] == 'noise_variance']).reset_index()
	sparsity_stats = _rank_stats(df_noise[df_noise['param_name'] == 'data_sparsity']).reset_index()

	noise_values = list(noise_stats['param_value'])
	snr_values = np.array([_lookup_snr(snr_table, val) for val in noise_values], dtype=float)
	plottable = np.isfinite(snr_values)
	unmapped = [val for val, ok in zip(noise_values, plottable) if not ok and val != 0.0]
	if unmapped:
		# Previously a missing SNR crashed the script with a TypeError; now the
		# affected points are skipped, but not silently.
		warnings.warn(f"No finite SNR available for noise values {unmapped}; they are omitted from the plot")

	plot_x = snr_values[plottable]
	plot_y = noise_stats['mean'].to_numpy()[plottable]
	plot_sem = noise_stats['sem'].to_numpy()[plottable]

	noise_line = ax.errorbar(plot_x, plot_y, yerr=plot_sem, marker='o', color='purple', label='Noise', capsize=6, markersize=10, linewidth=3)
	ax.set_xlabel('Signal-to-Noise Ratio')
	ax.set_ylabel('Final Rank')
	ax.set_title('')
	ax.grid(True, alpha=0.3)
	_mark_reference_rank(ax)
	# Log scale, inverted so the axis runs from large to small SNR
	ax.set_xscale('log')
	ax.invert_xaxis()
	# Tick only the three largest SNRs and the smallest one
	sorted_x = sorted(plot_x.tolist(), reverse=True)
	tick_x = sorted_x[:3] + [sorted_x[-1]] if len(sorted_x) > 3 else sorted_x
	ax.set_xticks(tick_x)
	ax.set_xticklabels([f'{x:.2f}' for x in tick_x], rotation=30)

	# Second x-axis on top for the sparsity sweep, sharing the rank y-axis
	ax_top = ax.twiny()
	sparsity_line = ax_top.errorbar(sparsity_stats['param_value'], sparsity_stats['mean'], yerr=sparsity_stats['sem'], marker='s', color='orange', label='Sparsity', capsize=6, markersize=10, linewidth=3)
	ax_top.set_xlabel('Dropout')
	ax_top.xaxis.set_label_position('top')
	ax_top.xaxis.set_ticks_position('top')
	ax_top.grid(False)

	# One legend for both curves, placed to the right of the panel
	ax.legend([noise_line, sparsity_line], ['Noise', 'Dropout'], loc='center left', bbox_to_anchor=(1.02, 0.5), borderaxespad=0)


def _plot_supplementary(plot_data):
	"""Save one panel per swept parameter showing mean rank ± SEM over its values."""
	param_names = plot_data["param_name"].unique()
	n_params = len(param_names)
	if n_params == 0:
		# plt.subplots cannot create a grid with zero columns
		warnings.warn("No parameter-sweep rows found; skipping the supplementary figure")
		return
	# squeeze=False always yields a 2-D axes array, so a single parameter needs no special case
	fig_supp, axes_supp = plt.subplots(1, n_params, figsize=(6*n_params, 6), sharey=True, squeeze=False)
	for i, param in enumerate(param_names):
		ax = axes_supp[0, i]
		param_df = plot_data[plot_data["param_name"] == param].copy()
		if param == "data_dim":
			# param_value was cast to str upstream; restore numbers so the
			# values are ordered numerically and can be shown on a log axis
			param_df["param_value"] = pd.to_numeric(param_df["param_value"], errors="coerce")
		stats = _rank_stats(param_df)
		# Only these two sweeps get a connecting line. The line width stays non-zero
		# in every case because errorbar reuses it for the error-bar stems when no
		# elinewidth is given; the connecting line is suppressed via linestyle alone.
		connect = param in ("data_dim", "hidden_connectivity")
		ax.errorbar(stats.index, stats['mean'].values, yerr=stats['sem'].values, marker='o', linestyle='-' if connect else '', capsize=5, markersize=8, linewidth=2)
		if param == "data_dim":
			ax.set_xscale('log')
		ax.set_title(param)
		ax.set_xlabel("Param Value")
		if i == 0:
			ax.set_ylabel("Final Rank")
		ax.grid(True, alpha=0.3)
		ax.tick_params(axis='x', rotation=30)
		_mark_reference_rank(ax)
	fig_supp.tight_layout()
	fig_supp.subplots_adjust(bottom=0.38)
	fig_supp.savefig(os.path.join(base_path, "unimodal_param_robustness_data_supp.png"), dpi=300)


def main(middle_plot_type="point"):
	"""Build and save the parameter-robustness summary and supplementary figures.

	Loads the 'samples', 'data' and 'noise' experiment results for all seeds,
	draws a three-panel summary figure (rank vs. sample size, rank per swept
	data parameter, rank vs. SNR and sparsity), saves it as
	``unimodal_param_robustness_summary.png`` under ``base_path`` and shows it.
	Afterwards a supplementary figure with one panel per swept data parameter
	is saved as ``unimodal_param_robustness_data_supp.png``.

	Args:
		middle_plot_type: ``"point"`` plots the mean final rank ± SEM per
			parameter; ``"delta"`` draws a box plot of the absolute deviation
			from the baseline rank.

	Raises:
		ValueError: If ``middle_plot_type`` is neither ``"point"`` nor ``"delta"``.
		FileNotFoundError: If no result file was found for one of the experiments.
	"""
	if middle_plot_type not in ("point", "delta"):
		raise ValueError(f"middle_plot_type must be 'point' or 'delta', got {middle_plot_type!r}")

	# Load all seeds for each experiment type
	df_sample = _require_results(load_all_seeds('samples'), 'samples')
	df_data = _require_results(load_all_seeds('data'), 'data')
	df_noise = _require_results(load_all_seeds('noise'), 'noise')

	# One font size for every text element (this changes global rcParams)
	fs = 24
	plt.rcParams.update({'font.size': fs, 'axes.titlesize': fs, 'axes.labelsize': fs, 'xtick.labelsize': fs, 'ytick.labelsize': fs, 'legend.fontsize': fs})

	# Three equally wide panels in one row
	fig = plt.figure(figsize=(21, 7))
	gs = fig.add_gridspec(1, 3, wspace=0.25)
	ax1 = fig.add_subplot(gs[0, 0])
	ax2 = fig.add_subplot(gs[0, 1])
	ax3 = fig.add_subplot(gs[0, 2])

	plot_data = _prepare_param_data(df_data)
	_plot_sample_size_panel(ax1, df_sample)
	_plot_param_panel(ax2, plot_data, middle_plot_type)
	_plot_noise_panel(ax3, df_noise)

	# Leave room below for the rotated tick labels and on the right for the legend
	fig.subplots_adjust(left=0.07, right=0.85, top=0.85, bottom=0.35, wspace=0.25)
	fig.savefig(os.path.join(base_path, "unimodal_param_robustness_summary.png"), dpi=300)
	plt.show()

	_plot_supplementary(plot_data)

# Generate the figures only when this file is run as a script, not when it is imported.
if __name__ == "__main__":
	main()
