import re
from pathlib import Path
from typing import Optional

import numpy as np
import pandas as pd
from matplotlib import pyplot as plt

# This script is superseded. `raise` (unlike `assert`) still stops execution
# when Python runs with -O, which strips assert statements.
raise SystemExit("This script is deprecated. Please use collect_baseline_csvs.py instead.")

eval_path = Path('eval_results/')
data_key = 'sawtooth' # gp | sawtooth
T_to_plot = 128
mae_lim = [-0.1, 2]
mae_ticks = [0, 0.1, 0.5, 1, 2]
mse_lim = [-0.1, 2]
mse_ticks = [0, 0.1, 0.5, 1, 2]

table_save_path = eval_path / f'{data_key}_evals_{T_to_plot}tar.xlsx'
plot_save_path = eval_path / f'{data_key}_evals_{T_to_plot}tar.svg'


# Statistic names searched for in each evaluation_report.txt line.
time_keys = ['sequence_time', 'sequence_ll_time']
metric_keys = ['MSE', 'MAE', 'Log Likelihood']

# Two-level column layout shared by every per-statistic table.
stat_columns = pd.MultiIndex.from_arrays([
    ['configs'] * 4 + ['values'] * 2,
    ['method', 'num_permutation', 'num_context', 'num_target', 'mean', 'std'],
])


def _parse_count(dir_name: str, suffix: str) -> int:
    """Return N from the first ``N<suffix>`` token of ``dir_name``.

    The token must start the name or follow an underscore, e.g.
    ``_parse_count('sawtooth_4per_32con_128tar', 'con') == 32``. Matching the
    whole token (instead of splitting on ``suffix``) avoids mis-parsing names
    whose earlier parts happen to contain ``suffix``.

    Raises:
        ValueError: If no such token exists.
    """
    match = re.search(rf'(?:^|_)(\d+){re.escape(suffix)}', dir_name)
    if match is None:
        raise ValueError(f"No '<N>{suffix}' token in directory name {dir_name!r}")
    return int(match.group(1))


# Gather raw rows per statistic first and build each DataFrame once, instead of
# enlarging DataFrames row by row (quadratic, and leaves all columns as object dtype).
rows = {key: [] for key in time_keys + metric_keys}

for path in eval_path.glob(f'{data_key}*/*/evaluation_report.txt'):
    # Layout: <eval_path>/<config dir>/<method>/evaluation_report.txt
    config_dir = path.parent.parent.name
    num_permutation = _parse_count(config_dir, 'per')
    num_context = _parse_count(config_dir, 'con')
    num_target = _parse_count(config_dir, 'tar')
    method = path.parent.name

    print(data_key, num_permutation, num_context, num_target, method)
    config = [method, num_permutation, num_context, num_target]

    with open(path, 'r') as f_report:
        for line in f_report:
            if any(key in line for key in time_keys):
                # Time lines: "<key> <mean> <std> ..."
                line_split = line.split()
                rows[line_split[0]].append(
                    config + [float(line_split[1]), float(line_split[2])]
                )
            if any(key in line for key in metric_keys):
                # Metric lines: "<key>[:] <mean> <token> <std> ...". "Log Likelihood"
                # contains a space, so join it temporarily so split() keeps it whole.
                line_split = line.replace("Log Likelihood", "Log_Likelihood").split()
                key = line_split[0].split(':')[0].replace("Log_Likelihood", "Log Likelihood")
                rows[key].append(
                    config + [float(line_split[1]), float(line_split[3])]
                )

# One table per statistic; infinite values are marked missing (NaN) so later
# filtering can drop them.
stats = {
    key: pd.DataFrame(key_rows, columns=stat_columns).replace([np.inf, -np.inf], np.nan)
    for key, key_rows in rows.items()
}

with pd.ExcelWriter(table_save_path, mode='w') as writer:
    for key, df in stats.items():
        df.to_excel(writer, sheet_name=key)


# let's visualize the summary
## remember:
## time_keys = ['sequence_time', 'sequence_ll_time']
## metric_keys = ['MSE', 'MAE', 'Log Likelihood']

def filter_stats(
    stats,
    key: str,
    method: Optional[str]=None,
    num_permutation: Optional[int]=None,
    num_context: Optional[int]=None,
    num_target: Optional[int]=None,
    filter_nan: bool=True,
    filter_large_values: bool=False,
    large_values_threshold: float=1e3,
) -> pd.DataFrame:
    """Select the rows of ``stats[key]`` that match a configuration.

    Args:
        stats: Mapping from statistic name to a DataFrame with two-level
            ``('configs', ...)`` / ``('values', ...)`` columns.
        key: Which statistic table to filter (a key of ``stats``).
        method, num_permutation, num_context, num_target: Keep only rows whose
            matching ``('configs', <name>)`` column equals this value. ``None``
            disables that condition. At least one condition must be given.
        filter_nan: Drop rows whose mean or std is missing.
        filter_large_values: Drop rows where ``mean + std`` exceeds
            ``large_values_threshold``.
        large_values_threshold: Threshold used when ``filter_large_values`` is set.

    Returns:
        A new DataFrame with the selected rows; ``stats`` is not modified.

    Raises:
        AssertionError: (via ``assert``) if no configuration condition is given.
    """
    df = stats[key]
    conditions = {
        'method': method,
        'num_permutation': num_permutation,
        'num_context': num_context,
        'num_target': num_target,
    }
    # Compare with None rather than truthiness so falsy but valid values
    # (e.g. num_context=0) still act as filters.
    active = {name: value for name, value in conditions.items() if value is not None}
    assert active, "At least one filtering condition should be provided."

    # Explicit bool dtype so an empty frame still yields a boolean row mask.
    mask = pd.Series(True, index=df.index, dtype=bool)
    for name, value in active.items():
        mask &= df[('configs', name)] == value

    mean = df[('values', 'mean')]
    std = df[('values', 'std')]
    if filter_nan:
        # isna() and isnull() are aliases; one check per column is enough.
        mask &= ~(mean.isna() | std.isna())
    if filter_large_values:
        mask &= ~(mean + std > large_values_threshold)
    return df[mask]


# x-axis values and series shared by all panels.
nums_context = stats['sequence_time'][('configs', 'num_context')].unique()
nums_target = stats['sequence_time'][('configs', 'num_target')].unique() if T_to_plot is None else [T_to_plot]
methods = stats['sequence_time'][('configs', 'method')].unique()


def _plot_errorbar_panel(ax, key: str, ylabel: str) -> None:
    """Plot mean +- std of statistic ``key`` against num_context on ``ax``.

    Draws one errorbar series per (method, num_target) pair, shifting each
    method slightly along x so overlapping points stay visible, then applies
    the shared title and symlog x-axis formatting.
    """
    for w, m in enumerate(methods):
        for T in nums_target:
            df_filtered = filter_stats(
                stats, key, method=m, num_target=T,
                filter_nan=True, filter_large_values=False,
            )
            Nc = df_filtered[('configs', 'num_context')].values
            ax.errorbar(
                Nc - 0.95 + 2*w/(len(methods)+1),  # Jitter for better visibility
                df_filtered[('values', 'mean')].values,
                yerr=df_filtered[('values', 'std')].values,
                fmt='h',
                markersize=12,
                color=f'C{w}',
                alpha=0.5,
                label=f"{m}: {T} targets" if T_to_plot is None else m,
            )
    ax.set_title(key)
    ax.set_xscale('symlog')
    ax.set_xlim(4, 2*max(nums_context))
    ax.set_xticks(nums_context, labels=[str(n) for n in nums_context])
    ax.set_xlabel('num_context', loc='right')
    ax.set_ylabel(ylabel)


fig, axs = plt.subplots(2, 3, figsize=(15, 8))

# Top row: prediction metrics (symlog y for all but the log likelihood).
for i, key in enumerate(metric_keys):
    _plot_errorbar_panel(axs[0, i], key, 'mean +- std')
    if key != 'Log Likelihood':
        axs[0, i].set_yscale('symlog')
axs[0, 0].set_ylim(*mse_lim) # MSE
axs[0, 1].set_ylim(*mae_lim) # MAE
axs[0, 0].set_yticks(mse_ticks, labels=[str(t) for t in mse_ticks])
axs[0, 1].set_yticks(mae_ticks, labels=[str(t) for t in mae_ticks])

# Bottom row: timings in the two right-hand panels; the bottom-left panel has
# nothing to show, so hide it instead of drawing empty axes.
for i, key in enumerate(time_keys):
    _plot_errorbar_panel(axs[1, i + 1], key, 'time (s), mean +- std')
    axs[1, i + 1].set_yscale('symlog')
axs[1, 0].axis('off')

handles, labels = axs[0, 0].get_legend_handles_labels()
fig.legend(handles, labels, loc='lower center', ncol=10)
fig.suptitle(data_key + (f" ({T_to_plot} targets)" if T_to_plot is not None else ""), fontsize=14)

fig.savefig(plot_save_path, format="svg")
fig.savefig(plot_save_path.with_suffix('.png'))

plt.show()
