import numpy as np
import pandas as pd
import argparse
from pathlib import Path
from typing import Optional
from matplotlib import pyplot as plt

# Font sizes shared by every subplot: tick labels, axis labels / legend, and titles.
tick_fontsize = 16
label_fontsize = 18
title_fontsize = 18

# Method identifiers accepted by --methods-to-plot; also the default set of
# methods plotted when that option is omitted. The position of a method in the
# plotted list determines its color and its x-jitter offset.
METHODS_LIST = [
    'fast_buf_np_KT',
    'fast_buf_np_K1',
    'fast_buf_np_K4',
    'fast_buf_np_K16',
    'pfn_ar',
    'pfn_independent',
    'tnpd_ar',
    'tnpd_independent',
    'tnpdmg_ar',
    'tnpdmg_independent',
    'tnpa',
    'tnpamg',
    'tnpnd'
]

# Metric identifiers accepted by --metrics-to-plot. Each name is also used as
# the sheet name when reading the collected evaluation workbook.
METRICS_LIST = [
    'sequence_time',
    'sequence_ll_time',
    'MSE',
    'MAE',
    'MC Log Likelihood',
    'Log MC Likelihood'
]

def method_display_name(method: str) -> str:
    """Translate an internal method identifier into its legend label.

    Known baselines map to fixed labels; identifiers starting with
    ``fast_buf_np`` have the ``fast_buf_np_`` prefix rewritten to ``TNP-B-``;
    anything else is returned unchanged.
    """
    fixed_labels = {
        'pfn_ar': 'PFN-AR',
        'pfn_independent': 'PFN-Ind',
        'tnpd_ar': 'TNP-D-AR (Gaussian Head)',
        'tnpd_independent': 'TNP-D-Ind (Gaussian Head)',
        'tnpdmg_ar': 'TNP-D-AR',
        'tnpdmg_independent': 'TNP-D-Ind',
        'tnpa': 'TNP-A (Gaussian Head)',
        'tnpamg': 'TNP-A',
        'tnpnd': 'TNP-ND',
    }
    label = fixed_labels.get(method)
    if label is not None:
        return label
    if method.startswith('fast_buf_np'):
        return method.replace('fast_buf_np_', 'TNP-B-')
    # Unknown identifiers are shown verbatim.
    return method

def find_xlim(axis_values, scale: str):
    """Compute padded x-axis limits around the plotted x values.

    Args:
        axis_values: Non-empty iterable of x positions (e.g. context sizes).
        scale: ``'linear'`` pads by 4 units on each side; any other value is
            treated as a log scale and pads by a factor of 2 on each side.

    Returns:
        A ``(lower, upper)`` tuple suitable for ``Axes.set_xlim``.
    """
    lowest = min(axis_values)
    highest = max(axis_values)
    if scale == 'linear':
        return (lowest - 4, highest + 4)
    # Log scale: use true division so the lower limit stays strictly positive.
    # Floor division would turn a smallest value of 1 into a limit of 0, which
    # is not a valid bound on a logarithmic axis.
    return (lowest / 2, highest * 2)

def time_mapping(time_in_s: np.ndarray, unit: str) -> np.ndarray:
    """Convert durations expressed in seconds into the requested unit.

    Args:
        time_in_s: Durations in seconds.
        unit: One of ``'s'``, ``'ms'`` or ``'us'``.

    Returns:
        The input itself for ``'s'``; otherwise the input multiplied by the
        matching factor.

    Raises:
        ValueError: If ``unit`` is not one of the supported units.
    """
    if unit == 's':
        # Seconds need no conversion; hand the input back untouched.
        return time_in_s
    factor_for_unit = {'ms': 1e3, 'us': 1e6}
    if unit not in factor_for_unit:
        raise ValueError(f"Unsupported time unit: {unit}")
    return time_in_s * factor_for_unit[unit]

def need_y_normalize(metric: str) -> bool:
    """Tell whether values of ``metric`` should be normalized before plotting.

    Normalization is currently switched off for every metric, so this always
    returns ``False`` regardless of ``metric``.
    """
    # Previously considered rule, kept for reference:
    # return metric in ['MC Log Likelihood', 'Log MC Likelihood']
    return False

def metric_ylabel(metric: str, timeunit: Optional[str]=None) -> str:
    """Build the base y-axis label for a metric.

    Args:
        metric: Metric identifier.
        timeunit: Unit string embedded in the label of the timing metrics.

    Returns:
        ``'time (<timeunit>)'`` for the timing metrics, the metric name for
        MSE / MAE, and ``'LL'`` (or ``'Normalized LL'`` when
        ``need_y_normalize`` says so) for the likelihood metrics.

    Raises:
        ValueError: If ``metric`` is not a supported metric.
    """
    if metric in ('sequence_time', 'sequence_ll_time'):
        return f'time ({timeunit})'

    error_labels = {'MSE': 'MSE', 'MAE': 'MAE'}
    if metric in error_labels:
        return error_labels[metric]

    if metric in ('MC Log Likelihood', 'Log MC Likelihood'):
        return 'Normalized LL' if need_y_normalize(metric) else 'LL'

    raise ValueError(f"Unsupported metric: {metric}")

def filter_stats(
    stats,
    key: str,
    method: Optional[str]=None,
    num_permutation: Optional[int]=None,
    num_context: Optional[int]=None,
    num_target: Optional[int]=None,
):
    """Select the rows of one statistics table that match all given filters.

    Args:
        stats: Mapping from table key to a DataFrame whose columns include
            ``('configs', <name>)`` for each filter that is used.
        key: Which table in ``stats`` to filter.
        method: Keep rows whose ``('configs', 'method')`` equals this value.
        num_permutation: Keep rows with this ``('configs', 'num_permutation')``.
        num_context: Keep rows with this ``('configs', 'num_context')``.
        num_target: Keep rows with this ``('configs', 'num_target')``.

    Returns:
        The rows of ``stats[key]`` satisfying every filter that is not ``None``.

    Raises:
        ValueError: If every filter is ``None``.
    """
    df = stats[key]
    requested = {
        'method': method,
        'num_permutation': num_permutation,
        'num_context': num_context,
        'num_target': num_target,
    }
    # Compare against None rather than relying on truthiness so that falsy but
    # meaningful values (for example num_context=0) are honored as filters.
    active = {name: value for name, value in requested.items() if value is not None}
    if not active:
        raise ValueError("At least one filtering condition should be provided.")

    mask = np.ones(len(df), dtype=bool)
    for name, value in active.items():
        mask &= (df[('configs', name)] == value).to_numpy()
    return df[mask]

def is_method_ours(method: str) -> bool:
    """Return True when ``method`` names one of the ``fast_buf_np`` variants."""
    ours_prefix = 'fast_buf_np'
    return method.startswith(ours_prefix)

def _resolve_axis_scales(scales, num_metrics: int, option_name: str):
    """Return exactly one axis scale per plotted metric.

    A list whose length already matches is returned as is. A list whose entries
    are all the same value is broadcast (or shortened) to ``num_metrics``
    entries, since that is unambiguous. Any other mismatch is an error.
    """
    scales = list(scales)
    if len(scales) == num_metrics:
        return scales
    if scales and len(set(scales)) == 1:
        return scales[:1] * num_metrics
    raise ValueError(
        f"Length of metrics_to_plot and {option_name} must be the same "
        f"(or {option_name} must repeat a single value). "
        f"Get {num_metrics} metrics and {option_name}={scales}."
    )


def main(args):
    """Plot the collected evaluation statistics, one subplot per metric.

    Reads ``<eval_path>/<data_key>_evals.xlsx`` (one sheet per metric), draws
    mean +- 1.96 SEM error bars against the number of context points for every
    requested method, saves the figure as PNG, SVG and PDF next to the
    workbook, and finally shows it.

    Args:
        args: Parsed command-line namespace providing ``eval_path``,
            ``data_key``, ``methods_to_plot``, ``metrics_to_plot``,
            ``number_targets_to_plot``, ``x_jitter_size``, ``time_unit``,
            ``x_scale``, ``y_scale``, ``y_lim`` and ``show_legend``.

    Raises:
        FileNotFoundError: If the evaluation workbook does not exist.
        ValueError: If the axis scales cannot be matched to the metrics.
    """
    eval_path = Path(args.eval_path)
    data_key = args.data_key
    methods_to_plot = METHODS_LIST if args.methods_to_plot is None else args.methods_to_plot
    metrics_to_plot = METRICS_LIST if args.metrics_to_plot is None else args.metrics_to_plot
    T_to_plot = args.number_targets_to_plot
    x_jitter_size = args.x_jitter_size
    time_unit = args.time_unit
    y_lim = args.y_lim
    show_legend = args.show_legend

    # One scale per metric; a uniform list is broadcast to the number of metrics.
    xscale = _resolve_axis_scales(args.x_scale, len(metrics_to_plot), 'x_scale')
    yscale = _resolve_axis_scales(args.y_scale, len(metrics_to_plot), 'y_scale')

    table_save_path = eval_path / f'{data_key}_evals.xlsx'
    if not table_save_path.exists():
        raise FileNotFoundError(
            f"Evaluation path {table_save_path} does not exist. Please run collect_baseline_csvs.py first."
        )
    plot_save_path_svg = eval_path / f'{data_key}_evals_{T_to_plot}tar.svg'
    plot_save_path_pdf = eval_path / f'{data_key}_evals_{T_to_plot}tar.pdf'
    plot_save_path_png = eval_path / f'{data_key}_evals_{T_to_plot}tar.png'

    with pd.ExcelFile(table_save_path) as reader:
        stats = {
            key: pd.read_excel(
                reader,
                sheet_name=key,
                header=[0,1],
                index_col=0,
                skiprows=[2],  # Skip the 3rd row (0-indexed), which is the empty row after headers
            ) for key in metrics_to_plot
        }

    # prepare for plotting
    print("Plotting...")
    first_table = stats[metrics_to_plot[0]]
    nums_context = first_table[('configs', 'num_context')].unique()
    nums_target = first_table[('configs', 'num_target')].unique() if T_to_plot is None else [T_to_plot]

    xticks = np.asarray(sorted(nums_context), dtype=float)
    xtick_labels = [str(int(x)) for x in xticks]

    # plot
    fig, axs = plt.subplots(1, len(metrics_to_plot), figsize=(8*len(metrics_to_plot), 5), sharex='all', squeeze=False)

    for i, metric in enumerate(metrics_to_plot):
        ax = axs[0, i]
        for w, m in enumerate(methods_to_plot):
            for T in nums_target:
                # The K=T variant is stored under 'fast_buf_np_KT' but is
                # displayed with the concrete K value.
                method_name = f'fast_buf_np_K{T}' if m == 'fast_buf_np_KT' else m # when K=T, K={T} folder should not exist.
                df_filtered = filter_stats(stats, metric, method=m, num_target=T)
                df_filtered = df_filtered.sort_values(by=[('configs', 'num_context')], axis=0, ascending=True)
                Nc = df_filtered[('configs', 'num_context')].values
                mean = df_filtered[('values', 'mean')].values
                sem = df_filtered[('values', 'sem')].values

                if len(mean) < 1:
                    print(f"No data for {metric}, {m}, {T} targets; skip")
                    continue

                if need_y_normalize(metric):
                    mean = mean/T
                    sem = sem/T

                if metric in ['sequence_time', 'sequence_ll_time']:
                    mean = time_mapping(mean, time_unit)
                    sem = time_mapping(sem, time_unit)

                # jitter x for better visibility; the jitter style must follow
                # the scale of *this* subplot (xscale is a per-metric list).
                if xscale[i] == 'linear':
                    x_jittered = Nc - x_jitter_size + 2*x_jitter_size*w/(len(methods_to_plot)+1)
                else: # log scale: jitter is a percentage of the x value
                    _r = x_jitter_size / 100
                    x_jittered = Nc * (1 - _r + 2*_r*w/(len(methods_to_plot)+1))

                ax.errorbar(
                    x_jittered,
                    mean,
                    yerr=1.96*sem,
                    fmt='s',
                    markersize=6,
                    color=f'C{w}',
                    alpha=0.7 if is_method_ours(m) else 0.4,
                    label=f"{method_display_name(method_name)}: {T} targets" if T_to_plot is None else method_display_name(method_name),
                )

        x_lim = find_xlim(nums_context, xscale[i])

        ax.set_title(f'N_t={T_to_plot}', fontsize=title_fontsize)
        ax.set_xscale(xscale[i])
        ax.set_xticks(xticks, labels=xtick_labels, minor=False, fontsize=tick_fontsize)
        ax.set_xticks([], minor=True)
        ax.set_xlabel('N_c (jitter for visual)', fontsize=label_fontsize)
        ax.set_xlim(*x_lim)

        ax.set_yscale(yscale[i])
        ax.set_ylabel(metric_ylabel(metric, timeunit=time_unit) + ' (mean +- 1.96 SEM)', fontsize=label_fontsize)
        ax.yaxis.set_tick_params(labelsize=tick_fontsize)
        if y_lim[0] is not None or y_lim[1] is not None:
            ax.set_ylim(*y_lim)
        if show_legend:
            ax.legend(loc=(0.655, 0), ncol=1, fontsize=label_fontsize)

    print(f"Saving plot to {plot_save_path_svg.parent}/[{plot_save_path_svg.name} | {plot_save_path_pdf.name} | {plot_save_path_png.name}]")
    fig.savefig( plot_save_path_png )
    fig.savefig( plot_save_path_svg, format="svg" )
    fig.savefig( plot_save_path_pdf, bbox_inches="tight")

    plt.show()

# Command-line entry point: declare the plotting options, parse them, and pass
# the resulting namespace to main().
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description="Visualize baseline evaluation results")
    # Directory that contains the collected '<data_key>_evals.xlsx' workbook.
    parser.add_argument(
        "--eval-path", type=str, default="eval_results/",
        help="Root Path to the evaluation results"
    )
    parser.add_argument(
        "--data-key", type=str, default="gp",
        choices=['gp', 'sawtooth', 'eeg_data', 'eeg_forecasting_data'],
        help="Data key to use for filtering the results"
    )
    # Left as None when omitted; main() then falls back to METHODS_LIST.
    parser.add_argument(
        "--methods-to-plot", type=str, default=None, nargs='+',
        choices=METHODS_LIST,
        help="Methods to plot (all if None)"
    )
    # NOTE(review): the default is a one-element list, not None, so the
    # "all if None" case in the help text is not reachable from the CLI.
    parser.add_argument(
        "--metrics-to-plot", type=str, default=['Log MC Likelihood'], nargs='+',
        choices=METRICS_LIST,
        help="Metrics to plot (all if None)"
    )
    # NOTE(review): type=int with default 16, so None cannot be selected from
    # the command line although the help text mentions it.
    parser.add_argument(
        "--number-targets-to-plot", type=int, default=16,
        help="Number of target points to plot (all if None)"
    )
    parser.add_argument(
        "--x-jitter-size", type=float, default=2,
        help="Size of the jitter applied to the x-axis. For log scale, the number is in percentage, for linear scale, the number is in x axis value."
    )
    parser.add_argument(
        "--time-unit", type=str, default='ms',
        choices=['s', 'ms', 'us'],
        help="Time unit for plotting (s, ms, us); the collected table xlsx is in seconds."
    )
    # One scale per plotted metric; the default supplies one 'linear' entry
    # for every metric in METRICS_LIST.
    parser.add_argument(
        "--x-scale", type=str, nargs='+', default=['linear']*len(METRICS_LIST),
        choices=['linear', 'log'],
        help="X scale for plotting (linear or log)"
    )
    # One scale per plotted metric; the default supplies one 'log' entry for
    # every metric in METRICS_LIST.
    parser.add_argument(
        "--y-scale", type=str, nargs='+', default=['log']*len(METRICS_LIST),
        choices=['linear', 'log'],
        help="Y scale for plotting (linear or log)"
    )
    # Two floats (lower, upper); both default to None, which leaves the
    # y-limits untouched.
    parser.add_argument(
        "--y-lim", type=float, nargs=2, default=[None, None],
        help="Y-axis limits."
    )
    parser.add_argument(
        "--show-legend", action='store_true',
        help="Whether to show the legend."
    )
    args = parser.parse_args()

    main(args)
