from sltd_utils import *
import argparse
from data_loading_utils import *
from baselines.vi_shaping import vi_shaping
from baselines.ltd_supervised import mozannar_action_ltd, madras_action_ltd
from sklearn.model_selection import train_test_split
import uncertainty_decomposition.utils as udutils


def str2bool(value):
    """Parse a command-line boolean such as ``True``/``False``/``1``/``0``.

    ``argparse`` with ``type=bool`` treats every non-empty string (including
    ``"False"``) as ``True``, so boolean-valued options use this converter.

    Args:
        value: The raw option string; a ``bool`` is returned unchanged.

    Returns:
        The parsed boolean.

    Raises:
        argparse.ArgumentTypeError: If ``value`` is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    normalized = value.strip().lower()
    if normalized in ("true", "t", "yes", "y", "1"):
        return True
    if normalized in ("false", "f", "no", "n", "0", ""):
        return False
    raise argparse.ArgumentTypeError("Boolean value expected, got %r" % value)


if __name__ == "__main__":
    # Load parameters
    parser = argparse.ArgumentParser()

    # Data-loading settings
    # env = {randomwalk, discrete_toy}; only "discrete_toy" is currently handled below
    parser.add_argument("--env", default="discrete_toy")  # environment name
    parser.add_argument("--p", default=1.0, type=float)  # parameter for discrete data

    # These configs are for the learning settings
    parser.add_argument("--seed", default=0, type=int)  # random seed; forwarded to data loading and used in file names
    # NOTE(review): the store_false flags below default to True and become False when passed,
    # which is the opposite of what their names suggest. Their values are unpacked but not
    # read later in this script (the SLTD calls pass these settings explicitly).
    parser.add_argument("--partial", action='store_false')
    parser.add_argument("--non_stationary", action='store_false')
    parser.add_argument("--non_stationary_policy", action='store_false')
    # baselines (training): {sltd, gold, sltd_st, reward_shaping, mozannar_ltd, madras_ltd}
    # baselines (--decomp_only): {sltd, sltd_st, sltd_1st, uniform, gold, reward_shaping}
    # Note that reward_shaping (value iteration) only works with discrete states.
    parser.add_argument("--baseline", default="sltd", type=str)  # Method to use to learn to defer
    # NOTE(review): learn_dynamics is parsed but not read below; the SLTD calls pass learn_dynamics=True.
    parser.add_argument("--learn_dynamics", action='store_true', default=False)

    # Method specific arguments
    parser.add_argument("--risk_aversion", default=0.05,
                        type=float)  # Risk aversion parameter for SLTD if using hypothesis testing
    parser.add_argument("--defer_cost", default=0.3,
                        type=float)  # Deferral penalty/cost for VI and SLTD (also passed to the LtD baselines)
    parser.add_argument("--defer_th", default=0.00,
                        type=float)  # threshold passed to collect_trajectories_defer in --decomp_only mode
    parser.add_argument("--alpha_mozannar", default=1,
                        type=float)  # Only used for Mozannar
    parser.add_argument("--MDP_samples_train", default=200, type=int)  # Number of MDP samples to use for training
    # NOTE(review): store_false, so this defaults to True; not read later in this script.
    parser.add_argument("--learn_policy_class", action="store_false")  # Whether to learn policy class
    parser.add_argument("--policy_class", default='decision_tree',
                        type=str)  # policy class, intended to be used only if the learn flag is true
    # If true, the target policy is modified to be the optimal policy, if one exists (usually clinicians).
    # Accepts True/False/1/0/yes/no; a bare "--test_opt" means True.
    parser.add_argument("--test_opt", default=False, type=str2bool, nargs="?", const=True)
    parser.add_argument("--one_step", default=0, type=int)
    parser.add_argument("--defer_method", default="hypothesis",
                        type=str)  # options: mean or hypothesis (testing); default is hypothesis. Only used by SLTD.
    parser.add_argument("--defer_protocol", default="once", type=str)  # {once, mixture}
    parser.add_argument("--results_path", default="./", type=str)
    parser.add_argument("--data_path", default="./data/disc_example_", type=str)

    # Uncertainty decomposition method
    parser.add_argument("--decomp_method", default='explain_policy',
                        type=str)
    # If true, skip training and only run the decomposition on a previously saved policy.
    # Accepts True/False/1/0/yes/no; a bare "--decomp_only" means True.
    parser.add_argument("--decomp_only", default=False, type=str2bool, nargs="?", const=True)
    args = parser.parse_args()

    # Copy the parsed options into plain local names used throughout the script.
    env = args.env
    p = args.p

    seed = args.seed
    partial = args.partial
    non_stationary = args.non_stationary
    non_stationary_policy = args.non_stationary_policy
    baseline = args.baseline
    learn_dynamics = args.learn_dynamics

    alpha = args.risk_aversion  # SLTD risk-aversion level
    defer_cost = args.defer_cost
    train_K_no = args.MDP_samples_train

    learn_policy_class = args.learn_policy_class
    policy_class_str = args.policy_class
    defer_method = args.defer_method
    defer_protocol = args.defer_protocol
    test_opt = args.test_opt
    one_step = args.one_step
    alpha_mozannar = args.alpha_mozannar

    decomp_method = args.decomp_method
    decomp_only = args.decomp_only
    defer_th = args.defer_th

    # Create the output tree <results_path>/<env>/{models,buffers,plots,ckpts}.
    # exist_ok=True avoids the check-then-create race when several runs (seeds/baselines)
    # start in parallel and share the same results directory.
    data_path = args.data_path
    results_path = os.path.join(args.results_path, env)
    for sub_dir in ("models", "buffers", "plots", "ckpts"):
        os.makedirs(os.path.join(results_path, sub_dir), exist_ok=True)

    print("Running expts for %s, %s" % (env, baseline))

    def _split_history(history):
        """Split logged transitions into a stratified 80/20 train/test split.

        Column 0 of ``history`` is used as the input feature, column 1 as the
        label and the last column as the time index, each reshaped to a column
        vector. Returns ``(x_train, x_test, y_train, y_test, t_train, t_test)``
        in ``train_test_split`` order, stratified on the label column.
        """
        data_x = history[:, 0].reshape(-1, 1)
        data_y = history[:, 1].reshape(-1, 1)
        data_t = history[:, -1].reshape(-1, 1)
        return train_test_split(data_x, data_y, data_t, test_size=0.2, stratify=data_y)

    models_dir = os.path.join(results_path, 'models')

    if env == "discrete_toy":
        episodes_repo_obs, H_tk_obs, ep_length, A_space, S_space, true_tx, tx_mat, r_mat, clinician_policy, candidate_policy, init_prior = \
            load_discrete_data(
                env, seed, p, test_opt=test_opt, baseline=baseline, dir_name=data_path)

        if not decomp_only:
            # ---- Training mode: fit a deferral policy and pickle it under <results_path>/models ----
            if baseline in ("sltd", "gold"):
                g_s, Q, g_s_1st, Q_1st, dirich_alpha_dict, defer_vec, P_sas_dict = SLTD_stochastic_nst_multistep(
                    H_T=H_tk_obs, alpha=alpha, tau=ep_length,
                    K_no=train_K_no,
                    pi_st=clinician_policy,
                    target_policy=candidate_policy,
                    S_space=S_space,
                    A_space=A_space, tx_mat=None,
                    defer_method=defer_method,
                    defer_cost=defer_cost, true_tx=true_tx,
                    learn_dynamics=True,
                    non_stationary=True, non_stationary_policy=True,
                    one_step=one_step, env=env)

                with open(os.path.join(models_dir, '%s_policy_%d_cost_%f_K_%d_true_dynamics.pkl' %
                                       (baseline, seed, defer_cost, train_K_no)), 'wb') as f:
                    pkl.dump(
                        {'defer_policy': g_s, 'value': Q, 'online_value_defer': None, 'online_value_tar': None,
                         'defer_policy_1st': g_s_1st, 'value_1st': Q_1st, 'online_value_defer_1st': None}, f)

                # Raw artefacts (posterior Dirichlet params, transition estimates, data) for later analysis.
                with open(os.path.join(models_dir, '%s_policy_%d_cost_%f_K_%d_uncertainty_raw.pkl' %
                                       (baseline, seed, defer_cost, train_K_no)), 'wb') as f:
                    pkl.dump({'defer_policy': g_s, 'P_sas_dict': P_sas_dict,
                              'defer_policy_1st': g_s_1st, 'H_tk': H_tk_obs, 'A_space': A_space,
                              'dirch': dirich_alpha_dict,
                              'true_tx': true_tx, 'reward_mat': r_mat, 'defer_cost': defer_cost}, f)

            elif baseline == "sltd_st":
                g_s, Q, dirich_alpha_dict, defer_vec, P_sas_dict = SLTD_stochastic(
                    H_T=H_tk_obs, alpha=alpha,
                    tau=ep_length,
                    K_no=train_K_no,
                    pi_st=clinician_policy,
                    target_policy=candidate_policy,
                    S_space=S_space,
                    A_space=A_space, tx_mat=tx_mat,
                    defer_method=defer_method,
                    defer_cost=defer_cost, true_tx=None,
                    learn_dynamics=True,
                    non_stationary=False,
                    non_stationary_policy=True,
                    one_step=one_step, env=env)

                with open(os.path.join(models_dir, '%s_policy_%d_cost_%f_K_%d_true_dynamics.pkl' %
                                       (baseline, seed, defer_cost, train_K_no)), 'wb') as f:
                    pkl.dump(
                        {'defer_policy': g_s, 'value': Q, 'online_value_defer': None, 'online_value_tar': None},
                        f)

            elif baseline == "reward_shaping":
                # Value iteration on a state/action space augmented with a deferral option
                # whose reward is -defer_cost.
                vi_obj = vi_shaping(tx_mat=true_tx, reward_vec=r_mat)
                vi_obj.augment_state_action(deferral_cost=-defer_cost)
                vi_policy, V = vi_obj.run()
                # Flatten the 2-D policy array into a dict keyed by (row, column) indices,
                # presumably (state, time) -- matching the g_s keys used in decomposition mode.
                vi_policy_dict = {(st, tt): vi_policy[st, tt]
                                  for st in range(vi_policy.shape[0])
                                  for tt in range(vi_policy.shape[1])}

                with open(os.path.join(models_dir,
                                       '%s_policy_%d_cost_%f_true_dynamics.pkl' % (baseline, seed, defer_cost)),
                          'wb') as f:
                    pkl.dump(
                        {'defer_policy': vi_policy_dict, 'value': V, 'augmented_tx': vi_obj.tx_dict,
                         'augmented_reward': vi_obj.reward}, f)

            elif baseline == "mozannar_ltd":
                data_x, data_x_test, data_y, data_y_test, data_t, data_t_test = _split_history(H_tk_obs)
                ltd_obj = mozannar_action_ltd(input_dim=true_tx[0].shape[1], output_dim=len(np.unique(data_y)),
                                              clinician_policy=clinician_policy, alpha=alpha_mozannar, env=env,
                                              nst=True)
                ltd_obj.encode_data(data_x=data_x, data_y=data_y, data_x_test=data_x_test, data_y_test=data_y_test,
                                    data_t=data_t, data_t_test=data_t_test,
                                    one_hot=True)
                ltd_obj.train_model(lr=0.05, n_epochs=50)
                V, Cost, Freq = ltd_obj.eval(true_tx=true_tx, reward_mat=r_mat, one_hot=True, func=False, env=env,
                                             defer_cost=defer_cost,
                                             n_trials=500, results_path=results_path, baseline=baseline, seed=seed,
                                             alpha_mozannar=alpha_mozannar)
                print("Value estimate:", np.mean(V - Cost), np.std(V - Cost))
                with open(os.path.join(models_dir, '%s_policy_%d_defer_cost_%f_alpha_%f_true_dynamics.pkl' %
                                       (baseline, seed, defer_cost, alpha_mozannar)),
                          'wb') as f:
                    pkl.dump({'defer_policy': V, 'defer_freq': Freq, 'defer_cost': Cost}, f)

            elif baseline == "madras_ltd":
                data_x, data_x_test, data_y, data_y_test, data_t, data_t_test = _split_history(H_tk_obs)
                ltd_obj = madras_action_ltd(input_dim=true_tx[0].shape[1], output_dim=2,
                                            clinician_policy=clinician_policy, target_policy=candidate_policy, env=env,
                                            train_net=False, nst=True, defer_cost=defer_cost)
                ltd_obj.encode_data(data_x=data_x, data_y=data_y, data_x_test=data_x_test, data_y_test=data_y_test,
                                    data_t=data_t, data_t_test=data_t_test,
                                    one_hot=True)
                ltd_obj.train_model(lr=0.01, n_epochs=100)
                V, Cost, Freq = ltd_obj.eval(true_tx=true_tx, reward_mat=r_mat, one_hot=True,
                                             func=False, env=env, defer_cost=defer_cost,
                                             n_trials=1000, results_path=results_path, seed=seed, baseline=baseline)
                print("Value estimate", np.mean(V - Cost), np.std(V - Cost), np.mean(V), np.std(V))
                with open(os.path.join(models_dir,
                                       '%s_policy_%d_cost_%f_true_dynamics.pkl' % (baseline, seed, defer_cost)),
                          'wb') as f:
                    pkl.dump({'defer_policy': V, 'defer_cost': Cost, 'defer_freq': Freq}, f)
            else:
                raise ValueError('%s method is not implemented for %s environment!' % (baseline, env))
        else:
            # ---- Decomposition-only mode: reload a saved deferral policy and roll it out ----
            if decomp_method == "explain_policy" and baseline in (
                    "sltd", "sltd_st", "sltd_1st", "uniform", "gold", "reward_shaping"):
                baseline0 = baseline
                if baseline0 == "sltd_1st":
                    # The first-step policy is stored inside the "sltd" model pickle.
                    baseline = "sltd"

                _, _, ep_length, A_space, S_space, true_tx, tx_mat, r_mat, clinician_policy, candidate_policy, init_prior = \
                    load_discrete_data(
                        env, seed, p, test_opt=test_opt, baseline=baseline, dir_name=data_path)

                if baseline != "uniform":
                    # reward_shaping pickles omit the K suffix (see the training branch above).
                    if baseline != "reward_shaping":
                        model_name = '%s_policy_%d_cost_%f_K_%d_true_dynamics.pkl' % (
                            baseline, seed, defer_cost, train_K_no)
                    else:
                        model_name = '%s_policy_%d_cost_%f_true_dynamics.pkl' % (baseline, seed, defer_cost)
                    # NOTE: pickle.load on files this script wrote itself; do not point it at untrusted data.
                    with open(os.path.join(models_dir, model_name), 'rb') as f:
                        dt = pkl.load(f)
                    g_s = dt['defer_policy_1st'] if baseline0 == "sltd_1st" else dt['defer_policy']
                else:
                    # Uniform baseline: the same value 0.5 for every (state, t) key.
                    # T is the number of entries in true_tx (presumably one per time step).
                    T = len(true_tx.keys())
                    g_s = {(s, tt): 0.5 for s in S_space for tt in range(0, T)}

                udutils.collect_trajectories_defer(env=env, baseline=baseline0, seed=seed,
                                                   target_policy=candidate_policy,
                                                   clinician_policy=clinician_policy, g_s=g_s,
                                                   A_space=A_space, true_tx=true_tx, reward_mat=r_mat,
                                                   defer_cost=defer_cost,
                                                   n_trials=10000, init_prior=init_prior,
                                                   results_path=results_path, nst_defer_policy=True,
                                                   th_sltd=defer_th, th_sltd_st=defer_th, th_1st=defer_th, K=train_K_no)

                # Reference rollouts of the pure clinician and pure target policies (no deferral cost).
                if baseline0 == "sltd" and train_K_no > 1:
                    udutils.collect_trajectories(env=env, baseline=baseline, seed=seed, policy=clinician_policy,
                                                 A_space=A_space, true_tx=true_tx, reward_mat=r_mat,
                                                 defer_cost=0,
                                                 n_trials=10000, init_prior=init_prior, policy_name='clinician',
                                                 results_path=results_path)

                    udutils.collect_trajectories(env=env, baseline=baseline, seed=seed, policy=candidate_policy,
                                                 A_space=A_space, true_tx=true_tx, reward_mat=r_mat,
                                                 defer_cost=0,
                                                 n_trials=10000, init_prior=init_prior, policy_name='target',
                                                 results_path=results_path)

            else:
                raise ValueError("%s decomposition is not implemented for method %s" % (decomp_method, baseline))

    else:
        raise ValueError('%s environment not defined!' % env)
