import math
import numpy as np
import pandas as pd
import argparse
from pathlib import Path
from typing import Optional
from matplotlib import pyplot as plt

# Every method folder name the collector knows about, in the order they are
# tried. Methods prefixed with 'fast_buf_np' are treated as "ours" by
# is_method_ours(); they come in a plain and an '_sg' variant, each with the
# same set of K suffixes.
METHODS_LIST = (
    [f'fast_buf_np_{k_suffix}' for k_suffix in ('KT', 'K1', 'K4', 'K16')]
    + [f'fast_buf_np_sg_{k_suffix}' for k_suffix in ('KT', 'K1', 'K4', 'K16')]
    + [
        'pfn_ar', 'pfn_independent',
        'tnpd_ar', 'tnpd_independent',
        'tnpdmg_ar', 'tnpdmg_independent',
        'tnpa', 'tnpamg', 'tnpnd',
    ]
)

def filter_invalid_values(df, filter_large_values=False, large_values_threshold=1e3):
    """Drop rows of ``df`` that contain non-finite values.

    Rows with any NaN, +inf or -inf are removed. When ``filter_large_values``
    is set, rows with any entry strictly greater than
    ``large_values_threshold`` are removed as well (only the upper side is
    checked; large negative values are kept).

    Returns:
        A new, filtered DataFrame; ``df`` itself is not modified.
    """
    cleaned = df.replace([np.inf, -np.inf], np.nan).dropna(axis=0, how='any')
    if not filter_large_values:
        return cleaned
    keep_rows = ~(cleaned > large_values_threshold).any(axis=1)
    return cleaned[keep_rows]

def time_mapping(time_in_s: np.ndarray, unit: str) -> np.ndarray:
    """Convert durations given in seconds to ``unit``.

    Supported units are 's' (input returned as-is, same object), 'ms' and 'us'.

    Raises:
        ValueError: if ``unit`` is not one of the supported units.
    """
    if unit == 's':
        return time_in_s
    for unit_name, factor in (('ms', 1e3), ('us', 1e6)):
        if unit == unit_name:
            return time_in_s * factor
    raise ValueError(f"Unsupported time unit: {unit}")

def need_y_normalize(metric: str) -> bool:
    """Return True for the likelihood metrics that also get a normalized column."""
    normalized_metrics = ('MC Log Likelihood', 'Log MC Likelihood')
    return metric in normalized_metrics

def filter_stats(
    stats,
    key: str,
    method: Optional[str]=None,
    num_permutation: Optional[int]=None,
    num_context: Optional[int]=None,
    num_target: Optional[int]=None,
):
    """Select the rows of ``stats[key]`` that match every given config value.

    ``stats[key]`` is expected to have MultiIndex columns with a ``'configs'``
    level holding ``method``, ``num_permutation``, ``num_context`` and
    ``num_target``. A filter is applied only when its argument is not None,
    so 0 / '' are valid filter values.

    Args:
        stats: mapping from metric name to its stats DataFrame.
        key: which DataFrame in ``stats`` to filter.
        method, num_permutation, num_context, num_target: optional equality filters.

    Returns:
        The filtered DataFrame (possibly empty).

    Raises:
        ValueError: if no filter is given.
    """
    criteria = {
        'method': method,
        'num_permutation': num_permutation,
        'num_context': num_context,
        'num_target': num_target,
    }
    # `is not None` rather than truthiness so that e.g. num_context=0 still filters
    active = {col: value for col, value in criteria.items() if value is not None}
    if not active:
        raise ValueError("At least one filtering condition should be provided.")

    df = stats[key]
    mask = np.ones(len(df), dtype=bool)
    for col, value in active.items():
        mask = np.logical_and(mask, (df[('configs', col)] == value).to_numpy(dtype=bool))
    return df[mask]

def is_method_ours(method: str) -> bool:
    """Return True if ``method`` is one of our methods (name starts with 'fast_buf_np')."""
    our_prefix = 'fast_buf_np'
    return method.startswith(our_prefix)

def dimension(data_key: str) -> int:
    """Return the dimension used to normalize likelihoods for ``data_key``.

    Raises:
        ValueError: for a data key not listed here.
    """
    known_dimensions = (
        (1, ['gp', 'sawtooth', 'bav']),
        (7, ['eeg_data', 'eeg_forecasting_data']),
    )
    for dim, keys in known_dimensions:
        if data_key in keys:
            return dim
    raise ValueError(f"Unsupported data key: {data_key}")

def get_stats(df: pd.DataFrame, col: str):
    mean = df[col].mean().item()
    std = df[col].std().item()
    sem = df[col].sem().item()
    return mean, std, sem

def _jittered_x(num_context, w, num_methods, xscale, x_jitter_size):
    """Offset x positions of the w-th method so error bars of different methods don't overlap.

    Args:
        num_context: float array of true x positions (num_context values).
        w: index of the method in plotting order.
        num_methods: total number of plotted methods.
        xscale: 'linear' or 'log' (anything other than 'linear' uses the log rule).
        x_jitter_size: for 'linear', the offset width in x-axis units;
            for 'log', the offset width as a percentage of x.

    Returns:
        Array of jittered x positions with the same shape as ``num_context``.
    """
    if xscale == 'linear':
        return num_context - x_jitter_size + 2 * x_jitter_size * w / (num_methods + 1)
    ratio = x_jitter_size / 100
    return num_context * (1 - ratio + 2 * ratio * w / (num_methods + 1))


def main(args):
    """Collect evaluation CSVs, write xlsx summaries and (optionally) plot them.

    Expected directory layout under ``args.eval_path``::

        {data_key}*_{P}per_{C}con_{T}tar/{method}/timing_data.csv
        {data_key}*_{P}per_{C}con_{T}tar/{method}/performance_data.csv

    With ``--have-seeds`` the method folders are ``{method}_{seed}`` (integer
    seed). Their CSVs are averaged element-wise across seeds before statistics
    are computed, which assumes all seeds have identically shaped tables.

    Writes to ``eval_path``:
      * ``{data_key}_evals.xlsx``: one sheet per metric, one row per (method, P, C, T).
      * ``{data_key}_evals_summary.xlsx``: the same, pooled over num_context.
      * ``{data_key}_evals_{T}tar.{png,svg}``: the plot, unless ``--do-not-plot``.

    Raises:
        ValueError: if plotting is requested for a T outside the collected target range.
    """
    eval_path = Path(args.eval_path)
    data_key = args.data_key
    # default to every known method when none were requested
    methods_to_collect = args.methods_to_collect if args.methods_to_collect is not None else METHODS_LIST
    have_seeds = args.have_seeds
    min_num_context, max_num_context = args.number_contexts_to_collect_range
    min_num_target, max_num_target = args.number_targets_to_collect_range
    do_not_plot = args.do_not_plot
    # plotting options (unused when do_not_plot)
    T_to_plot = args.number_targets_to_plot
    xscale = args.xscale
    x_jitter_size = args.x_jitter_size
    donot_plot_time = args.hide_time_plot
    # if we don't plot time, time_unit and time_scale are not used
    time_unit = args.time_unit
    time_scale = args.time_scale
    mae_scale = args.mae_scale
    mse_scale = args.mse_scale
    mae_lim = args.mae_lim
    mse_lim = args.mse_lim

    # T_to_plot only matters when plotting; None means "plot every collected target number".
    if not do_not_plot and T_to_plot is not None and not (min_num_target <= T_to_plot <= max_num_target):
        raise ValueError("T_to_plot should be within the range of number of targets collected.")

    # prepare paths
    table_save_path = eval_path / f'{data_key}_evals.xlsx'  # stores all target numbers
    summary_table_save_path = eval_path / f'{data_key}_evals_summary.xlsx'
    plot_save_path_svg = eval_path / f'{data_key}_evals_{T_to_plot}tar.svg'
    plot_save_path_png = eval_path / f'{data_key}_evals_{T_to_plot}tar.png'

    # prepare stats tables
    time_keys = ['sequence_time', 'sequence_ll_time']
    metric_keys = ['MSE', 'MAE', 'MC Log Likelihood', 'Log MC Likelihood']  ### DO NOT change the order (MSE/MAE axes are indexed below)
    key2summaryCol = {  # sheet name -> column name in the CSVs
        'sequence_time': 'sequence_time',
        'sequence_ll_time': 'sequence_ll_time',
        'MSE': 'mse',
        'MAE': 'mae',
        'MC Log Likelihood': 'mean_log_likelihood',
        'Log MC Likelihood': 'log_mean_likelihood',
    }

    stats = {
        key: pd.DataFrame(
            columns=[
                ['configs']*6 + ['values']*6,
                [
                    'method', 'ours', 'num_permutation', 'num_context', 'num_target', 'sample_size',
                    'mean', 'std', 'sem', 'normalized_mean', 'normalized_std', 'normalized_sem']
            ] if need_y_normalize(key) else [
                ['configs']*6 + ['values']*3,
                [
                    'method', 'ours', 'num_permutation', 'num_context', 'num_target', 'sample_size',
                    'mean', 'std', 'sem'
                ]
            ]
        ) for key in time_keys + metric_keys
    }

    # decipher the folder name (e.g. "..._16per_32con_8tar") into num_permutation, num_context, num_target
    def get_nums_per_con_tar_from_path_name(path_name: str):
        try:
            num_permutation = int(path_name.split('per')[0].split('_')[-1])
            num_context = int(path_name.split('con')[0].split('_')[-1])
            num_target = int(path_name.split('tar')[0].split('_')[-1])
            return num_permutation, num_context, num_target
        except Exception as e:
            raise ValueError(f"Invalid path name format: {path_name}") from e

    # collect results
    for path in eval_path.glob(f'{data_key}*'):
        if not path.is_dir():
            continue
        num_permutation, num_context, num_target = get_nums_per_con_tar_from_path_name(path.name)
        if not (min_num_context <= num_context <= max_num_context) or not (min_num_target <= num_target <= max_num_target):
            continue

        for method in methods_to_collect:
            if have_seeds:
                # path: .../{data_key}*_{num_permutation}per_{num_context}con_{num_target}tar/{method}_{seed}
                time_dfs = []  # one table per seed
                ppp_dfs = []  # one table per seed
                seed_prefix = f'{method}_'
                # sorted for a deterministic summation order; any integer seed (not only 0-9) is accepted
                for method_path in sorted(path.glob(f'{seed_prefix}*')):
                    if not method_path.name[len(seed_prefix):].isdecimal():
                        continue
                    time_path = method_path / 'timing_data.csv'
                    ppp_path = method_path / 'performance_data.csv'
                    if not time_path.exists() or not ppp_path.exists():
                        continue

                    print(f"loading {data_key}, {num_permutation} per, {num_context} con, {num_target} tar, {method_path.name}")
                    time_dfs.append(pd.read_csv(time_path))
                    ppp_dfs.append(pd.read_csv(ppp_path))
                if not time_dfs:  # no seed folder with both CSVs
                    continue
                # element-wise average: each row (function) is averaged over seeds
                time_df = sum(time_dfs) / len(time_dfs)
                ppp_df = sum(ppp_dfs) / len(ppp_dfs)
            else:
                # path: .../{data_key}*_{num_permutation}per_{num_context}con_{num_target}tar/{method}
                time_path = path / method / 'timing_data.csv'
                ppp_path = path / method / 'performance_data.csv'
                if not time_path.exists() or not ppp_path.exists():
                    continue

                print(f"loading {data_key}, {num_permutation} per, {num_context} con, {num_target} tar, {method}")
                time_df = pd.read_csv(time_path)
                ppp_df = pd.read_csv(ppp_path)

            flag_ours = is_method_ours(method)  # used for method ordering in plots
            for key in time_keys:
                mean, std, sem = get_stats(time_df, key2summaryCol[key])
                stats[key].loc[len(stats[key]), :] = [
                    method, flag_ours, num_permutation, num_context, num_target, time_df.shape[0],
                    mean, std, sem
                ]
            for key in metric_keys:
                mean, std, sem = get_stats(ppp_df, key2summaryCol[key])
                row = [
                    method, flag_ours, num_permutation, num_context, num_target, ppp_df.shape[0],
                    mean, std, sem
                ]
                if need_y_normalize(key):
                    # divide by data dimension and num_target
                    # (presumably the likelihood is summed over targets and dims — confirm)
                    dim = dimension(data_key)
                    row += [value / dim / num_target for value in (mean, std, sem)]
                stats[key].loc[len(stats[key]), :] = row

    print(f"Saving summary table to {table_save_path}")
    with pd.ExcelWriter(table_save_path, mode='w') as writer:
        for key, df in stats.items():
            df.to_excel(writer, sheet_name=key)

    print("Creating summary table (merge num_context...)")
    stats_summary = dict()
    for key, df in stats.items():
        df_summary = pd.DataFrame(columns=df.columns.drop(('configs', 'num_context')))
        # rows of `df` hold mean/std/sem over functions for one num_context;
        # pool them into mean/std/sem over functions AND num_contexts
        # (equal group sizes: total var = mean of within-var + variance of group means)
        has_normalized = ('values', 'normalized_mean') in df.columns
        methods = df[('configs', 'method')].unique()
        permutations = df[('configs', 'num_permutation')].unique()
        targets = df[('configs', 'num_target')].unique()
        for m in methods:
            for p in permutations:
                for t in targets:
                    df_filtered = filter_stats(stats, key, method=m, num_permutation=p, num_target=t)
                    if len(df_filtered) == 0:
                        continue
                    assert df_filtered[('configs', 'sample_size')].nunique() == 1
                    n = df_filtered[('configs', 'sample_size')].unique().item()
                    k = df_filtered.shape[0]  # number of num_context groups

                    mean_i = df_filtered[('values', 'mean')]
                    var_i = df_filtered[('values', 'std')].pow(2)
                    mean = mean_i.mean().item()
                    var = (var_i + (mean_i - mean).pow(2)).mean().item()
                    std = math.sqrt(var)
                    sem = math.sqrt(var / n / k)
                    row = [m, is_method_ours(m), p, t, f'{n}x{k}', mean, std, sem]
                    if has_normalized:
                        mean_i_norm = df_filtered[('values', 'normalized_mean')]
                        var_i_norm = df_filtered[('values', 'normalized_std')].pow(2)
                        mean_norm = mean_i_norm.mean().item()
                        var_norm = (var_i_norm + (mean_i_norm - mean_norm).pow(2)).mean().item()
                        std_norm = math.sqrt(var_norm)
                        sem_norm = math.sqrt(var_norm / n / k)
                        row += [mean_norm, std_norm, sem_norm]
                    df_summary.loc[len(df_summary), :] = row
        stats_summary[key] = df_summary

    print(f"Saving merged summary table to {summary_table_save_path}")
    with pd.ExcelWriter(summary_table_save_path, mode='w') as writer:
        for key, df_summary in stats_summary.items():
            df_summary.to_excel(writer, sheet_name=f"{key}")

    if do_not_plot:
        return

    ##########################################
    # visualize the per-num_context stats
    #   row 0: metric_keys = ['MSE', 'MAE', 'MC Log Likelihood', 'Log MC Likelihood']
    #   row 1: (legend), time_keys = ['sequence_time', 'sequence_ll_time'], (legend)
    ##########################################
    if stats['sequence_time'].empty:
        print("No results collected; nothing to plot.")
        return

    print("Plotting...")
    nums_context = stats['sequence_time'][('configs', 'num_context')].unique()
    nums_target = stats['sequence_time'][('configs', 'num_target')].unique() if T_to_plot is None else [T_to_plot]
    # our methods first, then alphabetical
    methods = stats['sequence_time'].sort_values(by=[('configs', 'ours'), ('configs', 'method')], ascending=[False, True], axis=0)[('configs', 'method')].unique()
    num_methods = len(methods)

    xticks = np.asarray(sorted(nums_context), dtype=float)
    xtick_labels = [str(int(x)) for x in xticks]
    xlim = (min(nums_context)-4, max(nums_context)+4) if xscale == 'linear' else \
        (min(nums_context)//2, max(nums_context)*2)

    fig, axs = plt.subplots(2, 4, figsize=(25, 10), sharex='all')

    for i, text in enumerate(metric_keys):
        ax = axs[0, i]
        ax.set_title(text)
        for w, m in enumerate(methods):
            for T in nums_target:
                df_filtered = filter_stats(stats, text, method=m, num_target=T)
                df_filtered = df_filtered.sort_values(by=[('configs', 'num_context')], axis=0, ascending=True)
                if len(df_filtered) < 1:
                    print(f"No data for {text}, {m}, {T} targets; skip")
                    continue
                # columns are object dtype (rows appended via .loc), so convert explicitly
                Nc = df_filtered[('configs', 'num_context')].to_numpy(dtype=float)
                mean = df_filtered[('values', 'mean')].to_numpy(dtype=float)
                sem = df_filtered[('values', 'sem')].to_numpy(dtype=float)

                ax.errorbar(
                    _jittered_x(Nc, w, num_methods, xscale, x_jitter_size),
                    mean,
                    yerr=1.96*sem,
                    fmt='s',
                    markersize=6,
                    color=f'C{w}',
                    alpha=0.7 if is_method_ours(m) else 0.4,
                    label=f"{m}: {T} targets" if T_to_plot is None else m,
                )

        ax.set_xscale(xscale)
        ax.set_xlim(*xlim)
        ax.set_xticks(xticks, labels=xtick_labels, minor=False)
        ax.set_xticks([], minor=True)
        if donot_plot_time:
            ax.tick_params(labelbottom=True)
            ax.set_xlabel('num_context (jitter for visual)', loc='right')
        ax.set_ylabel('mean +- 1.96 SEM')
    if not donot_plot_time:
        # axes below these are switched off, so show their x tick labels explicitly
        axs[0, 0].tick_params(labelbottom=True)
        axs[0, 3].tick_params(labelbottom=True)
    axs[0, 0].set_yscale(mse_scale)  # MSE
    axs[0, 1].set_yscale(mae_scale)  # MAE
    axs[0, 0].set_ylim(*mse_lim)  # MSE
    axs[0, 1].set_ylim(*mae_lim)  # MAE

    for i, text in enumerate(time_keys):
        ax = axs[1, i+1]
        if donot_plot_time:
            ax.axis('off')
            continue
        ax.set_title(text)
        for w, m in enumerate(methods):
            for T in nums_target:
                df_filtered = filter_stats(stats, text, method=m, num_target=T)
                df_filtered = df_filtered.sort_values(by=[('configs', 'num_context')], axis=0, ascending=True)
                if len(df_filtered) < 1:
                    print(f"No data for {text}, {m}, {T} targets; skip")
                    continue
                Nc = df_filtered[('configs', 'num_context')].to_numpy(dtype=float)
                mean = time_mapping(df_filtered[('values', 'mean')].to_numpy(dtype=float), time_unit)
                sem = time_mapping(df_filtered[('values', 'sem')].to_numpy(dtype=float), time_unit)

                # same jitter as the metric row so markers line up on the shared x axis
                ax.errorbar(
                    _jittered_x(Nc, w, num_methods, xscale, x_jitter_size),
                    mean,
                    yerr=1.96*sem,
                    fmt='s',
                    markersize=6,
                    color=f'C{w}',
                    alpha=0.7 if is_method_ours(m) else 0.4,
                    label=f"{m}: {T} targets" if T_to_plot is None else m,
                )

        ax.set_xscale(xscale)
        ax.set_yscale(time_scale)
        ax.set_xlim(*xlim)
        ax.set_xticks(xticks, labels=xtick_labels, minor=False)
        ax.set_xticks([], minor=True)
        ax.set_xlabel('num_context (jitter for visual)')
        ax.set_ylabel(f'time ({time_unit}), mean +- 1.96 SEM')

    # single shared legend (taken from the MSE axis) in the bottom-right slot
    handles, labels = axs[0, 0].get_legend_handles_labels()
    axs[1, 0].axis('off')
    axs[1, 3].legend(handles, labels, loc='upper left', ncol=1)
    axs[1, 3].axis('off')
    fig.suptitle(data_key + (f" ({T_to_plot} targets)" if T_to_plot is not None else ""), fontsize=14)
    fig.tight_layout()

    print(f"Saving plot to {plot_save_path_svg.parent}/{plot_save_path_svg.name} and {plot_save_path_png.name}")
    fig.savefig(plot_save_path_png)
    fig.savefig(plot_save_path_svg, format="svg")

    plt.show()

if __name__ == '__main__':
    # Command-line entry point: parse options and run main().
    parser = argparse.ArgumentParser(description="Visualize baseline evaluation results")
    parser.add_argument(
        "--eval-path", type=str, default="eval_results/",
        help="Root Path to the evaluation results"
    )
    parser.add_argument(
        "--data-key", type=str, default="gp",
        choices=['gp', 'sawtooth', 'eeg_data', 'eeg_forecasting_data', 'bav'],
        help="Data key to use for filtering the results"
    )
    parser.add_argument(
        "--methods-to-collect", type=str, default=None, nargs='+',
        choices=METHODS_LIST,
        help="Methods to collect (all if None)"
    )
    parser.add_argument(
        "--have-seeds", action='store_true',
        help="Whether the experiments have seeds (methods end with _[0-9])"
    )
    parser.add_argument(
        "--number-contexts-to-collect-range", type=int, default=[8, 192], nargs=2,
        help="Range of number of contexts to collect (min, max)"
    )
    parser.add_argument(
        "--number-targets-to-collect-range", type=int, default=[1, 128], nargs=2,
        help="Range of number of targets to collect (min, max)"
    )
    parser.add_argument(
        "--do-not-plot", action='store_true',
        help="Skip plotting; only write the xlsx summary tables"
    )
    parser.add_argument(
        "--number-targets-to-plot", type=int, default=16,
        help="Number of target points to plot (must lie within --number-targets-to-collect-range)"
    )
    parser.add_argument(
        "--xscale", type=str, default='linear',
        choices=['linear', 'log'],
        help="X-axis scale (linear or log)"
    )
    parser.add_argument(
        "--x-jitter-size", type=float, default=2,
        help="Size of the jitter applied to the x-axis. For log scale, the number is in percentage, for linear scale, the number is in x axis value."
    )
    parser.add_argument(
        "--hide-time-plot", action='store_true',
        help="Whether to hide the time plots"
    )
    parser.add_argument(
        "--time-unit", type=str, default='ms',
        choices=['s', 'ms', 'us'],
        help="Time unit for plotting (s, ms, us); the collected table xlsx is in seconds."
    )
    parser.add_argument(
        "--time-scale", type=str, default='linear',
        choices=['linear', 'log'],
        help="Time scale for plotting (linear or log)"
    )
    parser.add_argument(
        "--mae-scale", type=str, default='log',
        choices=['linear', 'log'],
        help="MAE scale for plotting (linear or log)"
    )
    parser.add_argument(
        "--mae-lim", type=float, nargs=2, default=[1e-4, 2],
        help="Y-axis limits for MAE plot"
    )
    parser.add_argument(
        "--mse-scale", type=str, default='log',
        choices=['linear', 'log'],
        help="MSE scale for plotting (linear or log)"
    )
    parser.add_argument(
        "--mse-lim", type=float, nargs=2, default=[1e-4, 2],
        help="Y-axis limits for MSE plot"
    )
    args = parser.parse_args()

    main(args)
