# %%
import os
import numpy as np
import pandas as pd
import math
import time
import argparse

import DGP
import BayesCoxCP_simul

def parse_args(argv=None):
    """Parse the command-line options of the BayesCoxCP simulation driver.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default) argparse reads ``sys.argv[1:]``, so existing
            ``parse_args()`` calls behave as before. Passing a list makes
            the parser usable programmatically and from tests.

    Returns:
        dict: option name -> parsed value (``vars`` of the argparse
        namespace). Note that ``--no_load_opt`` is stored under the key
        ``load_opt_hyper`` and defaults to ``True``.
    """
    parser = argparse.ArgumentParser(description="BayesCoxCP Simulation")

    # replication count and time horizons
    parser.add_argument("--nrepl", type=int, default=20, help="number of replications")
    parser.add_argument("--TH", type=int, default=30000, help="time horizon (main run)")
    parser.add_argument("--TH_tuning", type=int, default=3000, help="time horizon for tuning")

    # workflow switches
    parser.add_argument("--tuning", action="store_true", help="run hyperparameter tuning")
    parser.add_argument("--no_load_opt", dest="load_opt_hyper", action="store_false",
                        help="do not load previously tuned hyperparameters")
    parser.add_argument("--no_fitting", action="store_true", help="only do tuning, skip final fitting")

    # model settings (pmin/pmax are forwarded to the model class unchanged)
    parser.add_argument("--pmin", type=int, default=1)
    parser.add_argument("--pmax", type=int, default=10)
    parser.add_argument("--dim", type=int, default=5, help="length of the coefficient vector beta0")
    parser.add_argument("--save_dir", type=str, default="result",
                        help="output directory, joined onto the current working directory")

    # experiment grid: every combination of the four lists below is run
    parser.add_argument("--ngrid_list", type=int, nargs="+", default=[10, 100, 1000, 30000],
                        help="list of grid sizes")
    parser.add_argument("--cdf0str_list", nargs="+", default=["twomode", "tnorm2"], help="cdf0 variants")
    parser.add_argument("--covstr_list", nargs="+", default=["ph"],
                        help="cdf_true variants (looked up as DGP.cdf_<name>)")
    parser.add_argument("--xstr_list", nargs="+", default=["unif"], help="covariate distribution")

    return vars(parser.parse_args(argv))


# Fallback hyperparameters, in the order (ini, deg_gam, deg_gam_2, a0l_tune),
# used when tuning is skipped and no tuned values can be loaded.
DEFAULT_HYPER = (64, 2**(-1), 2**(-1), 10**5)


def load_opt_hyper(save_path):
    """Load tuned hyperparameters from the first row of a CSV file.

    Args:
        save_path: path of a CSV with the columns ``ini``, ``deg_gam``,
            ``deg_gam_2`` and ``a0l_tune``.

    Returns:
        tuple: ``(ini, deg_gam, deg_gam_2, a0l_tune)``. ``ini`` is a Python
        ``int``; the other three are returned as read by pandas. If the file
        does not exist, a message is printed and ``DEFAULT_HYPER`` is returned.

    Raises:
        ValueError: if the stored ``ini`` cannot be converted to an integer.
    """
    try:
        opt_hyper = pd.read_csv(save_path)
    except FileNotFoundError:
        print(f"Optimal hyperparameter file not found: {save_path}")
        return DEFAULT_HYPER

    # Convert directly instead of type-checking first: pandas hands back NumPy
    # scalars, and NumPy integers are not instances of the builtin ``int``, so
    # an isinstance gate would reject a perfectly valid "64" in the file.
    try:
        ini = int(opt_hyper['ini'].values[0])
    except (TypeError, ValueError) as err:
        raise ValueError("Cannot convert 'ini' to integer.") from err
    deg_gam = opt_hyper['deg_gam'].values[0]
    deg_gam_2 = opt_hyper['deg_gam_2'].values[0]
    a0l_tune = opt_hyper['a0l_tune'].values[0]

    print("----------------")
    print(f"Load opt hyper (ini, deg_gam, deg_gam_2, a0l_tune): {ini}, {deg_gam}, {deg_gam_2}, {a0l_tune}")
    print("----------------")
    return ini, deg_gam, deg_gam_2, a0l_tune


if __name__ == "__main__":
    # Driver: for every (cdf0, cdf_true, x distribution, grid size) combination,
    # obtain hyperparameters (tune / load / default) and optionally run the fit.
    args = parse_args()

    save_dir = os.path.join(os.getcwd(), args['save_dir'])
    os.makedirs(save_dir, exist_ok=True)

    # set algorithm (looked up once; it does not depend on the loop variables)
    class_name = 'BayesCoxCP'
    BayesCoxCP_class = getattr(BayesCoxCP_simul, class_name)

    for cdf0str in args['cdf0str_list']:
        # set cdf0
        cdf0 = getattr(DGP, f"cdf0_{cdf0str}")

        for covstr in args['covstr_list']:
            # set cdf_true
            cdf_true = getattr(DGP, f"cdf_{covstr}")

            for xstr in args['xstr_list']:
                # set the distribution of x
                gen_x = getattr(DGP, f"gen_x_{xstr}")

                for ngrid in args['ngrid_list']:
                    # set the effect of x (rebuilt per configuration, as before,
                    # so different configurations never share the same array)
                    beta0 = np.full(args['dim'], 4/math.sqrt(args['dim']))

                    # Common file-name suffix of the tuning artefacts; they are
                    # always keyed by the tuning horizon, never the main one.
                    tuning_tag = f"dim_{args['dim']}_cdf0_{cdf0str}_cov_{covstr}_xdist_{xstr}_T_{args['TH_tuning']}_grid_{ngrid}_nrepl_{args['nrepl']}_{class_name}.csv"
                    save_name_opt = f"OptHyper_{tuning_tag}"

                    if args['tuning']:
                        # Full factorial grid of candidates, one row per combination.
                        hyperpara_list = pd.DataFrame(
                            np.array(np.meshgrid(
                                2**(np.arange(-4, 1, dtype=float)/3),
                                2**(np.arange(-12, -8, dtype=float)/2),
                                [64, 128, 256],
                                [10**5],
                            )).T.reshape(-1, 4),
                            columns=["deg_gam", "deg_gam_2", "ini", "a0l_tune"],
                        )
                        hyperpara_list["ini"] = hyperpara_list["ini"].astype(int)
                        performance_list = pd.DataFrame(np.zeros(len(hyperpara_list)), columns=["unk"])
                        TIME_HORIZON = args['TH_tuning']

                        for tunInd in range(len(performance_list)):
                            deg_gam = hyperpara_list.at[tunInd, "deg_gam"]
                            deg_gam_2 = hyperpara_list.at[tunInd, "deg_gam_2"]
                            ini = hyperpara_list.at[tunInd, "ini"]
                            a0l_tune = hyperpara_list.at[tunInd, "a0l_tune"]

                            print(f"Evaluating the {tunInd}-th hyperparameter, gamma={deg_gam}, gamma_2={deg_gam_2}, initial length={ini}, a0l_tune={a0l_tune}")

                            start_time = time.time()

                            BayesCox_tuning = BayesCoxCP_class(ini, deg_gam, deg_gam_2, args['pmin'], args['pmax'], ngrid, gen_x, args['dim'], beta0, cdf0, cdf_true, TIME_HORIZON, args['nrepl'], args['tuning'], save_dir, a0l_tune=a0l_tune)

                            BayesCox_tuning.fitting()

                            end_time = time.time()
                            print(f"Time for tuning...: {end_time - start_time:.6f}")

                            performance_list.iloc[tunInd] = BayesCox_tuning.get_unk_cum_real()

                        # save performance
                        save_name = f"tuningResult_{tuning_tag}"
                        performance_list.to_csv(os.path.join(save_dir, save_name), index=False)

                        # pick the candidate with the largest stored score; the
                        # table is read back so the choice matches the saved file
                        tuning_table = pd.read_csv(os.path.join(save_dir, save_name))
                        unk_maxTunInd = tuning_table.idxmax().iloc[0]

                        ini = hyperpara_list.at[unk_maxTunInd, 'ini']
                        deg_gam = hyperpara_list.at[unk_maxTunInd, 'deg_gam']
                        deg_gam_2 = hyperpara_list.at[unk_maxTunInd, 'deg_gam_2']
                        a0l_tune = hyperpara_list.at[unk_maxTunInd, 'a0l_tune']

                        print("----------------")
                        print(f"Select (unk.alpha, unk.gamma, unk.gamma_2, a0l_tune): {ini}, {deg_gam}, {deg_gam_2}, {a0l_tune}")
                        print("----------------")

                        # save the optimal hyperparameter (one-row table; the
                        # array round-trip stores every value, ini included, as float)
                        opt_hyper = pd.DataFrame(np.array([deg_gam, deg_gam_2, ini, a0l_tune]).T.reshape(-1, 4), columns=["deg_gam", "deg_gam_2", "ini", "a0l_tune"])
                        opt_hyper.to_csv(os.path.join(save_dir, save_name_opt), index=False)
                    elif args['load_opt_hyper']:
                        # loading optimal hyperparameter (defaults if the file is missing)
                        TIME_HORIZON = args['TH_tuning']
                        save_path = os.path.join(save_dir, save_name_opt)
                        ini, deg_gam, deg_gam_2, a0l_tune = load_opt_hyper(save_path)
                    else:
                        # default
                        ini, deg_gam, deg_gam_2, a0l_tune = DEFAULT_HYPER

                    if args['no_fitting']:
                        print('End tuning and no fitting.')
                    else:
                        # run with the optimal hyperparameter
                        TIME_HORIZON = args['TH']
                        tuning_here = False
                        name_flag = f"dim_{args['dim']}_cdf0_{cdf0str}_cov_{covstr}_xdist_{xstr}_T_{args['TH']}_ngrid_{ngrid}_nrepl_{args['nrepl']}_{class_name}"

                        start_time = time.time()

                        BayesCox_model = BayesCoxCP_class(ini, deg_gam, deg_gam_2, args['pmin'], args['pmax'], ngrid, gen_x, args['dim'], beta0, cdf0, cdf_true, TIME_HORIZON, args['nrepl'], tuning_here, save_dir, name_flag, a0l_tune=a0l_tune)

                        BayesCox_model.fitting()

                        end_time = time.time()
                        print(f"Time for model fitting...: {end_time - start_time:.6f}")




