# %%
import os
import pandas as pd
import argparse
import matplotlib.pyplot as plt


def parse_args():
    """Parse the command-line options of the regret-plotting script.

    Options are read from ``sys.argv``. The result is returned as a plain
    dictionary keyed by option name (``nrepl``, ``TH``, ``dim``, ``save_dir``,
    ``load_dir``, ``ngrid_list``, ``cdf0str_list``, ``covstr_list``,
    ``xstr_list``).
    """
    # (flag, keyword arguments for add_argument), registered in this order so
    # that the generated --help text lists the options in the same sequence.
    option_specs = (
        ("--nrepl", dict(type=int, default=20, help="number of replications")),
        ("--TH", dict(type=int, default=30000, help="time horizon (main run)")),
        ("--dim", dict(type=int, default=5)),
        ("--save_dir", dict(type=str, default="exp_result", help="directory to save the experimental results")),
        ("--load_dir", dict(type=str, default="result", help="directory under each method containing results")),
        ("--ngrid_list", dict(type=int, nargs="+", default=[10, 100, 1000, 30000], help="list of grid sizes")),
        ("--cdf0str_list", dict(nargs="+", default=["twomode", "tnorm2"], help="cdf0 variants")),
        ("--covstr_list", dict(nargs="+", default=["ph"])),
        ("--xstr_list", dict(nargs="+", default=["unif"], help="covariate distribution")),
    )

    cli = argparse.ArgumentParser(description="Plot cumulative regret curves for different methods.")
    for flag, spec in option_specs:
        cli.add_argument(flag, **spec)

    namespace = cli.parse_args()
    return vars(namespace)


def load_regret(class_name, result_dir, setting_flag, ngrid, time_horizon, nrepl):
    """Load the regret table of one method for one grid size.

    Args:
        class_name: Method name; ``'CoxCP'`` selects the two-file layout,
            any other name selects the single-file layout.
        result_dir: Directory that contains the method's CSV files.
        setting_flag: Common file-name prefix
            (``dim_.._cdf0_.._cov_.._xdist_.._T_..``).
        ngrid: Grid size encoded in the file name.
        time_horizon: Number of leading columns kept in the ``'CoxCP'`` layout.
        nrepl: Number of replications encoded in the file name.

    Returns:
        A DataFrame of regret values. For ``'CoxCP'`` it is the difference
        between the "opt" and "unk" tables restricted to their first
        ``time_horizon`` columns; otherwise it is the stored regret table
        as read from disk.
    """
    if class_name == 'CoxCP':
        name_flag = f"{setting_flag}_grid_{ngrid}_nrepl_{nrepl}"
        opt_results = pd.read_csv(f"{result_dir}/cumRev_unk_opt_coxph_{name_flag}.csv")
        unk_results = pd.read_csv(f"{result_dir}/cumRev_unk_coxph_{name_flag}.csv")
        return opt_results.iloc[:, :time_horizon] - unk_results.iloc[:, :time_horizon]

    # NOTE(review): unlike the CoxCP layout this table is not truncated to
    # `time_horizon` columns -- confirm the file never holds extra columns.
    name_flag = f"{setting_flag}_ngrid_{ngrid}_nrepl_{nrepl}_{class_name}"
    return pd.read_csv(f"{result_dir}/cumRev_unk_reg_coxph_{name_flag}.csv")


def summarize_cumulative_regret(regret):
    """Return ``(mean, std)`` arrays of the cumulative regret.

    The table is accumulated across columns (``axis=1``); mean and standard
    deviation are then taken over rows, giving one value per column.
    """
    # NOTE(review): the input file names start with "cumRev"; confirm the
    # stored values are per-step so the cumulative sum is not applied twice.
    cumulative = regret.cumsum(axis=1)
    return cumulative.mean().to_numpy(), cumulative.std().to_numpy()


def draw_regret_panel(ax, ngrid, class_name_list, color_list, mean_dict, std_dict):
    """Draw mean +/- std regret curves of every method on one subplot.

    Returns the list of line handles (one per method, in ``class_name_list``
    order) so the caller can build a shared legend.
    """
    handles = []
    for class_name, color in zip(class_name_list, color_list):
        mean = mean_dict[class_name][ngrid]
        std = std_dict[class_name][ngrid]
        steps = range(len(mean))

        line, = ax.plot(steps, mean, linestyle='-', color=color, alpha=1)
        handles.append(line)
        ax.fill_between(steps, mean - std, mean + std, facecolor=color, alpha=0.2)

    ax.set_title(f'K = {ngrid}', fontsize=18)
    ax.grid(True)
    return handles


if __name__ == "__main__":
    args = parse_args()

    # Define methods, legend labels, colors
    class_name_list = ['BayesCoxCP', 'CoxCP']
    line_name_list = ['Choi et al. (2023)' if class_name == 'CoxCP' else class_name for class_name in class_name_list]
    color_list = ['red' if class_name == 'BayesCoxCP' else 'blue' for class_name in class_name_list]

    # Output directory is the same for every setting; it is only created
    # right before the first figure is written.
    save_dir = os.path.join(os.getcwd(), args['save_dir'])

    for cdf0str in args['cdf0str_list']:
        for covstr in args['covstr_list']:
            for xstr in args['xstr_list']:
                # Common part of every input/output file name for this setting
                setting_flag = f"dim_{args['dim']}_cdf0_{cdf0str}_cov_{covstr}_xdist_{xstr}_T_{args['TH']}"

                # Prepare mean and std of cumulative regret
                cum_reg_mean_dict = {class_name: {} for class_name in class_name_list}
                cum_reg_std_dict = {class_name: {} for class_name in class_name_list}

                for class_name in class_name_list:
                    print('Reading csv files...')

                    # Does not depend on the grid size, so compute it once per method
                    result_dir = os.path.join(os.getcwd(), class_name, args['load_dir'])

                    for ngrid in args['ngrid_list']:
                        regret = load_regret(class_name, result_dir, setting_flag, ngrid, args['TH'], args['nrepl'])
                        mean, std = summarize_cumulative_regret(regret)
                        cum_reg_mean_dict[class_name][ngrid] = mean
                        cum_reg_std_dict[class_name][ngrid] = std

                # plotting: one subplot per grid size, four per row
                num_figures = len(args['ngrid_list'])
                columns = 4
                rows = (num_figures + columns - 1) // columns

                fig, axes = plt.subplots(rows, columns, figsize=(5 * columns, 4 * rows), sharex=True, sharey=True)
                axes = axes.flatten()

                fig.text(-0.01, 0.5, 'Cumulative Regret', va='center', rotation='vertical', fontsize=16)
                fig.text(0.5, -0.02, 'Times', ha='center', fontsize=16)

                # plotting each subplot; every panel uses the same colors, so the
                # handles of the last drawn panel are enough for the shared legend
                legend_handles = []
                for ax, ngrid in zip(axes, args['ngrid_list']):
                    legend_handles = draw_regret_panel(
                        ax, ngrid, class_name_list, color_list, cum_reg_mean_dict, cum_reg_std_dict)

                fig.legend(legend_handles, line_name_list, loc='upper center', bbox_to_anchor=(0.5, -0.04), ncol=len(line_name_list), fontsize=18)

                # hiding blank subplot
                for blank_ax in axes[num_figures:]:
                    fig.delaxes(blank_ax)

                plt.tight_layout(rect=[0, 0, 1, 0.98])

                # save
                os.makedirs(save_dir, exist_ok=True)
                name_flag = f"{setting_flag}_nrepl_{args['nrepl']}"
                plt.savefig(f'{save_dir}/{name_flag}.png', format='png', dpi=300, bbox_inches='tight')

                plt.show()

                # Release the figure; otherwise one figure per setting stays
                # registered with pyplot for the rest of the run.
                plt.close(fig)


