import numpy.random
from tqdm import tqdm
import pickle as pkl
import numpy as np
import os

# Lower bound applied (via np.clip) to the behaviour/data-policy probability of
# the logged action in the importance-sampling denominators, so a zero
# probability cannot cause a division by zero.
CLIP_CNST = 1e-5
def collect_trajectories(env, baseline, seed, policy,
                         A_space, true_tx, reward_mat, defer_cost=0.01,
                         n_trials=500, init_prior=None, policy_name='target',
                         results_path='./results', init_s0=None, T=None):
    """Roll out a single ``policy`` for ``n_trials`` episodes and pickle the trajectories.

    Every trial starts from ``init_s0``. For the tabular environments
    ("discrete_toy", "diabetes") that start state is forced to 0; otherwise the
    caller-supplied ``init_s0`` is used. At each step an action is drawn from
    ``policy`` at the current state, the reward is read from ``reward_mat`` and
    the next state is obtained from ``true_tx``.

    Args:
        env: Environment name: "discrete_toy", "diabetes", "randomwalk" or "hiv".
        baseline: Tag used as the first component of the output filename.
        seed: Integer tag used in the output filename (it does not seed the RNG).
        policy: Tabular envs: ``policy[t][s]`` (P(action == 1) when
            ``len(A_space) == 2``, otherwise a probability vector over actions).
            "randomwalk": callable ``policy(s)``; "hiv": callable ``policy(s, t)``.
        A_space: Action space; exactly two actions selects Bernoulli sampling.
        true_tx: Tabular envs: ``true_tx[t][a, s]`` is the next-state
            distribution. "randomwalk": callable ``true_tx(s, a, t)`` returning a
            3-tuple whose first item is indexed with ``[0]``. "hiv": callable
            ``true_tx(s, a)`` returning the next state.
        reward_mat: "diabetes": ``reward_mat[t][s, a]``; "discrete_toy":
            ``reward_mat[s, a]``; otherwise callable ``reward_mat(s, a)``.
        defer_cost: Only used in the output filename.
        n_trials: Number of rollouts.
        init_prior: Not used by any supported environment; kept for
            backward compatibility.
        policy_name: Tag used in the output filename and the printed summary.
        results_path: Output root; the pickle is written to
            ``<results_path>/models`` (created if missing).
        init_s0: Start state for "randomwalk"/"hiv" (must convert to a single int).
        T: Horizon; defaults to ``len(true_tx.keys())`` (tabular envs only).

    Raises:
        ValueError: If ``env`` is not one of the supported environments.
    """
    supported_envs = ("discrete_toy", "diabetes", "randomwalk", "hiv")
    if env not in supported_envs:
        # "sepsis_diabetes" used to be accepted here, but no action/reward branch
        # handles it, so it crashed with UnboundLocalError on the first step.
        raise ValueError("Unsupported env %r; expected one of %s" % (env, supported_envs))
    tabular = env in ("discrete_toy", "diabetes")
    binary = len(A_space) == 2

    if tabular:
        # Tabular environments always start every trial from state 0.
        init_s0 = np.zeros(1)

    if T is None:
        T = int(len(true_tx.keys()))

    result_dict = {t: [] for t in range(T)}
    cost_dict = {t: [] for t in range(T)}  # always 0 in this function
    state_dict = {t: [] for t in range(T)}
    action_dict = {t: [] for t in range(T)}
    value_dict = {t: [] for t in range(T)}

    for _ in tqdm(range(n_trials)):
        # .item() extracts the scalar without NumPy's deprecated int(array) conversion.
        s_t = int(np.asarray(init_s0).item())
        cumulative_reward = 0
        for tt in range(T):
            # Evaluate the policy at the current state using the env's convention.
            if tabular:
                probs = policy[tt][s_t]
                choices = range(true_tx[tt].shape[0])
            elif env == "randomwalk":
                probs = policy(s_t)
                choices = A_space
            else:  # "hiv"
                probs = policy(s_t, tt)
                choices = A_space
            if binary:
                act = np.random.binomial(n=1, p=probs, size=1)[0]
            else:
                act = np.random.choice(a=choices, p=probs, size=1)[0]

            if env == "diabetes":
                reward = reward_mat[tt][s_t, act]
            elif env == "discrete_toy":
                reward = reward_mat[s_t, act]
            else:
                reward = reward_mat(s_t, act)
            cumulative_reward += reward

            state_dict[tt].append(s_t)
            action_dict[tt].append(act)
            result_dict[tt].append(reward)
            value_dict[tt].append(cumulative_reward)
            cost_dict[tt].append(0)

            if tabular:
                s_t = int(np.random.choice(a=true_tx[tt].shape[1], size=1,
                                           p=true_tx[tt][act, s_t])[0])
            elif env == "randomwalk":
                next_state, _, _ = true_tx(s_t, act, tt)
                s_t = next_state[0]
            else:  # "hiv"
                s_t = true_tx(s_t, act)

    models_dir = os.path.join(results_path, 'models')
    os.makedirs(models_dir, exist_ok=True)
    out_file = os.path.join(models_dir, '%s_policy_%d_%s_cost_%f_trajectories.pkl'
                            % (baseline, seed, policy_name, defer_cost))
    with open(out_file, 'wb') as f:
        pkl.dump({'reward': result_dict, 'states': state_dict, 'actions': action_dict, 'cost': cost_dict,
                  'value': value_dict}, f)

    # value_dict[T - 1] holds each trial's undiscounted return over the horizon.
    value_array = np.array(value_dict[T - 1])
    print(policy_name, 'value:', value_array.mean())


def _bernoulli_prob(p, a):
    """Return the probability of binary action ``a`` (0 or 1) when P(action == 1) is ``p``."""
    return a * p + (1 - a) * (1 - p)


def _policy_probs(env, policy, s, tt):
    """Evaluate ``policy`` at state ``s`` and step ``tt`` using the convention of ``env``.

    Tabular envs ("discrete_toy", "diabetes") index ``policy[tt][s]``;
    "randomwalk" calls ``policy(s)``; "hiv" calls ``policy(s, tt)``.
    """
    if env in ("discrete_toy", "diabetes"):
        return policy[tt][s]
    if env == "randomwalk":
        return policy(s)
    return policy(s, tt)


def _sample_action(probs, binary, choices):
    """Draw one action: Bernoulli(``probs``) if ``binary``, else categorical over ``choices``."""
    if binary:
        return np.random.binomial(n=1, p=probs, size=1)[0]
    return np.random.choice(a=choices, p=probs, size=1)[0]


def _step_reward(env, reward_mat, s, a, tt):
    """Return the immediate reward for action ``a`` in state ``s`` at step ``tt``."""
    if env == "diabetes":
        return reward_mat[tt][s, a]
    if env == "discrete_toy":
        return reward_mat[s, a]
    return reward_mat(s, a)


def _next_state(env, true_tx, s, a, tt):
    """Return the successor of state ``s`` under action ``a`` at step ``tt``."""
    if env in ("discrete_toy", "diabetes"):
        return int(np.random.choice(a=true_tx[tt].shape[1], size=1, p=true_tx[tt][a, s])[0])
    if env == "randomwalk":
        next_state, _, _ = true_tx(s, a, tt)
        return next_state[0]
    return true_tx(s, a)


def _is_ratio(env, binary, target_policy, clinician_policy, data_policy, p_defer, s, a, tt):
    """Per-step importance ratio for the logged state/action pair (``s``, ``a``).

    The numerator is the probability of ``a`` under the mixture that follows
    ``clinician_policy`` with probability ``p_defer`` and ``target_policy``
    otherwise; the denominator is its probability under ``data_policy``,
    clipped to [CLIP_CNST, 1] to avoid division by zero.
    """
    tgt = _policy_probs(env, target_policy, s, tt)
    cln = _policy_probs(env, clinician_policy, s, tt)
    beh = _policy_probs(env, data_policy, s, tt)
    if env == "randomwalk" or (binary and env in ("discrete_toy", "diabetes")):
        # Policies return P(action == 1); convert to the probability of `a`.
        tgt, cln, beh = (_bernoulli_prob(p, a) for p in (tgt, cln, beh))
    else:
        # Policies return a probability vector over actions.
        tgt, cln, beh = tgt[a], cln[a], beh[a]
    return ((1 - p_defer) * tgt + p_defer * cln) / np.clip(beh, CLIP_CNST, 1)


def collect_trajectories_defer(env, baseline, seed, clinician_policy, target_policy, data_policy, g_s,
                               A_space, true_tx, reward_mat, H_tk_obs=None, defer_cost=0.01,
                               n_trials=500, init_prior=None,
                               results_path='./results',
                               nst_defer_policy=False, th_sltd=0.2, th_sltd_st=0.2, th_1st=0.02, K=200, init_s0=None,
                               T=None, N=500, verbose=False):
    """Roll out deferral policies and build an importance-sampling value estimate.

    Each trial simulates two trajectories from the same start state:

    * stochastic: in a state present in ``g_s`` the agent defers to
      ``clinician_policy`` with probability ``g_s[key]`` (for
      ``baseline == "reward_shaping"`` with ``nst_defer_policy`` the entry is an
      action index, and ``len(A_space)`` means "defer"; once deferred, the
      trial keeps deferring). Otherwise it follows ``target_policy``.
    * deterministic ("_det" outputs): it defers whenever ``g_s[key]`` is at
      least a baseline-specific threshold.

    Every deferral step subtracts ``defer_cost`` from the cumulative reward.
    In parallel, trial ``jj`` is re-weighted against the ``jj``-th logged row of
    ``H_tk_obs`` at each step (per-decision importance sampling). The
    cumulative IS values at each step are divided by the mean IS weight at
    that step.

    Args:
        env: "discrete_toy", "diabetes", "randomwalk" or "hiv".
        baseline: Selects the deferral rule/threshold; also used in the filename.
        seed: Integer tag used in the filename (it does not seed the RNG).
        clinician_policy, target_policy, data_policy: Policies in the env's
            convention (tabular ``policy[t][s]``; "randomwalk" ``policy(s)``;
            "hiv" ``policy(s, t)``). ``data_policy`` is the behaviour policy
            used in the IS denominator.
        g_s: Deferral table keyed by state, or by ``(state, t)`` when
            ``nst_defer_policy`` is true. IS lookups always use ``(state, t)``.
        A_space: Action space; exactly two actions selects Bernoulli sampling.
        true_tx, reward_mat: Transition/reward models in the env's convention.
        H_tk_obs: Logged data (2-D array); column 0 is the state, 1 the action,
            2 the reward and the last column the time step. Required.
        defer_cost: Cost charged on each deferral step.
        n_trials: Number of rollouts.
        init_prior: Start-state distribution for tabular envs (uniform over
            ``true_tx[0].shape[1]`` states when None).
        results_path: Output root; the pickle goes to ``<results_path>/models``.
        nst_defer_policy: Whether ``g_s`` is time-indexed.
        th_sltd: Deterministic threshold for "sltd" (and "gold" when
            ``nst_defer_policy``); also used in the filename.
        th_sltd_st: Deterministic threshold for "sltd_st".
        th_1st: Deterministic threshold for every other baseline.
        K, N: Only used in the output filename.
        init_s0: Start state for "randomwalk"/"hiv".
        T: Horizon; defaults to ``len(true_tx.keys())`` (tabular envs only).
        verbose: Print per-step importance-weight debugging output.

    Raises:
        ValueError: If ``env`` is unsupported or ``H_tk_obs`` is None.
    """
    supported_envs = ("discrete_toy", "diabetes", "randomwalk", "hiv")
    if env not in supported_envs:
        # "sepsis_diabetes" used to be accepted here, but no action/reward branch
        # handles it, so it crashed with UnboundLocalError on the first step.
        raise ValueError("Unsupported env %r; expected one of %s" % (env, supported_envs))
    if H_tk_obs is None:
        raise ValueError("H_tk_obs (logged trajectories) is required for the IS estimate")
    tabular = env in ("discrete_toy", "diabetes")
    binary = len(A_space) == 2

    if tabular:
        if init_prior is not None:
            init_s0 = np.random.choice(a=range(len(init_prior)), p=init_prior, size=1)
        else:
            init_s0 = np.random.choice(a=range(true_tx[0].shape[1]), size=1)

    if T is None:
        T = int(len(true_tx.keys()))

    # Deterministic deferral threshold (note: "gold" only maps to th_sltd in the
    # time-indexed case, as in the original branching).
    if baseline == "sltd" or (nst_defer_policy and baseline == "gold"):
        det_threshold = th_sltd
    elif baseline == "sltd_st":
        det_threshold = th_sltd_st
    else:
        det_threshold = th_1st
    # For reward shaping, g_s stores an action index and the extra last index means "defer".
    defer_action = float(len(A_space))

    # Logged rows belonging to each time step; constant across trials, so computed once.
    t_idx_by_t = {tt: np.where(H_tk_obs[:, -1] == tt)[0] for tt in range(T)}

    result_dict = {t: [] for t in range(T)}
    cost_dict = {t: [] for t in range(T)}
    state_dict = {t: [] for t in range(T)}
    action_dict = {t: [] for t in range(T)}
    value_dict = {t: [] for t in range(T)}
    defer_freq_dict = {t: [] for t in range(T)}
    defer_times_dict = {jj: [] for jj in range(n_trials)}

    result_dict_det = {t: [] for t in range(T)}
    cost_dict_det = {t: [] for t in range(T)}
    state_dict_det = {t: [] for t in range(T)}
    action_dict_det = {t: [] for t in range(T)}
    value_dict_det = {t: [] for t in range(T)}
    defer_freq_dict_det = {t: [] for t in range(T)}
    defer_times_dict_det = {jj: [] for jj in range(n_trials)}

    result_is_dict = {t: [] for t in range(T)}
    cost_is_dict = {t: [] for t in range(T)}
    value_is_dict = {t: [] for t in range(T)}
    is_dict = {t: [] for t in range(T)}

    min_t_defer = []
    min_t_defer_det = []
    for jj in tqdm(range(n_trials)):
        shaping_flag = 0
        # .item() extracts the scalar without NumPy's deprecated int(array) conversion.
        s_t = s_t_det = int(np.asarray(init_s0).item())
        min_t_found = False
        min_t_found_det = False
        cumulative_reward = 0
        cumulative_reward_det = 0
        cumulative_reward_is = 0
        is_weight = 1
        for tt in range(T):
            key = (s_t, tt) if nst_defer_policy else s_t
            key_det = (s_t_det, tt) if nst_defer_policy else s_t_det

            # Stochastic deferral decision.
            if key in g_s:
                if nst_defer_policy and baseline == "reward_shaping":
                    if g_s[key] == defer_action or shaping_flag == 1:
                        epsilon_vec = 1
                        shaping_flag = 1
                    else:
                        epsilon_vec = 0
                else:
                    epsilon_vec = np.random.binomial(n=1, p=g_s[key], size=1)[0]
                if epsilon_vec == 1:
                    defer_times_dict[jj].append(tt)
                    if not min_t_found:
                        min_t_defer.append(tt)
                        min_t_found = True
            else:
                epsilon_vec = 0

            # Deterministic (thresholded) deferral decision.
            if key_det in g_s:
                epsilon_vec_det = g_s[key_det] >= det_threshold
                if epsilon_vec_det == 1:
                    defer_times_dict_det[jj].append(tt)
                    if not min_t_found_det:
                        min_t_defer_det.append(tt)
                        min_t_found_det = True
            else:
                epsilon_vec_det = 0

            # Per-decision IS update against the jj-th logged row at this step.
            t_idx = t_idx_by_t[tt]
            if jj < len(t_idx):
                row = t_idx[jj]
                s_obs = H_tk_obs[row, 0]
                if tabular:
                    s_obs = int(s_obs)
                a_obs = int(H_tk_obs[row, 1])
                is_weight *= _is_ratio(env, binary, target_policy, clinician_policy, data_policy,
                                       g_s[(s_obs, tt)], s_obs, a_obs, tt)
                if verbose:
                    print('jj', jj, 'tt', tt, 'is_weight', is_weight)

            # Each trajectory samples clinician/target actions at its OWN current state.
            choices = range(true_tx[tt].shape[0]) if tabular else A_space
            cp = _sample_action(_policy_probs(env, clinician_policy, s_t, tt), binary, choices)
            tp = _sample_action(_policy_probs(env, target_policy, s_t, tt), binary, choices)
            act = epsilon_vec * cp + (1 - epsilon_vec) * tp
            cp_det = _sample_action(_policy_probs(env, clinician_policy, s_t_det, tt), binary, choices)
            tp_det = _sample_action(_policy_probs(env, target_policy, s_t_det, tt), binary, choices)
            act_det = epsilon_vec_det * cp_det + (1 - epsilon_vec_det) * tp_det

            reward = _step_reward(env, reward_mat, s_t, act, tt)
            reward_det = _step_reward(env, reward_mat, s_t_det, act_det, tt)

            step_cost = defer_cost if epsilon_vec == 1 else 0
            step_cost_det = defer_cost if epsilon_vec_det == 1 else 0
            if epsilon_vec == 1:
                cumulative_reward += (reward - defer_cost)
            else:
                cumulative_reward += reward
            if epsilon_vec_det == 1:
                cumulative_reward_det += (reward_det - defer_cost)
            else:
                cumulative_reward_det += reward_det

            state_dict[tt].append(s_t)
            action_dict[tt].append(act)
            result_dict[tt].append(reward)
            value_dict[tt].append(cumulative_reward)
            defer_freq_dict[tt].append(epsilon_vec)
            cost_dict[tt].append(step_cost)

            state_dict_det[tt].append(s_t_det)
            action_dict_det[tt].append(act_det)
            result_dict_det[tt].append(reward_det)
            value_dict_det[tt].append(cumulative_reward_det)
            defer_freq_dict_det[tt].append(epsilon_vec_det)
            cost_dict_det[tt].append(step_cost_det)

            if jj < len(t_idx):
                r_obs = H_tk_obs[t_idx[jj], 2]
                result_is_dict[tt].append(is_weight * r_obs)
                cumulative_reward_is += (is_weight * r_obs)
                value_is_dict[tt].append(cumulative_reward_is)
                is_dict[tt].append(is_weight)
                cost_is_dict[tt].append(step_cost)
                if verbose:
                    print('is weight', is_weight, value_is_dict[tt][-1])

            s_t = _next_state(env, true_tx, s_t, act, tt)
            s_t_det = _next_state(env, true_tx, s_t_det, act_det, tt)

    # Self-normalise the IS values by the mean weight at each step (skip empty steps).
    for tt in range(T):
        if is_dict[tt]:
            is_mean = np.mean(is_dict[tt])
            value_is_dict[tt] = [x / is_mean for x in value_is_dict[tt]]

    value_array = np.array(value_dict[T - 1])
    value_is_array = np.array(value_is_dict[T - 1])
    # Total deferral cost per trial (summed over time steps).
    defer_cost_array = np.asarray(list(cost_dict.values())).sum(0)
    if verbose:
        print(defer_cost_array.shape)
    defer_freq_mat = np.array(list(defer_freq_dict.values()))
    mean_defer_time = np.nanmean(np.array([np.nanmean(defer_times_dict[jj]) for jj in range(n_trials)]))

    # The reported value adds the deferral cost back, i.e. it is the pure return.
    print(baseline, " min_t:", np.median(min_t_defer), np.mean(min_t_defer), 'value:',
          (value_array + defer_cost_array).mean(), 'value_is:', value_is_array.mean(),
          'defer frequency:', np.sum(defer_freq_mat) / np.prod(np.shape(defer_freq_mat)), 'mean defer time:',
          mean_defer_time)

    models_dir = os.path.join(results_path, 'models')
    os.makedirs(models_dir, exist_ok=True)
    out_file = os.path.join(models_dir, '%s_policy_%d_%s_cost_%f_th_%f_K_%d_N_%d_trajectories.pkl' % (
        baseline, seed, baseline, defer_cost, th_sltd, K, N))
    with open(out_file, 'wb') as f:
        pkl.dump({'reward': result_dict, 'states': state_dict, 'actions': action_dict, 'cost': cost_dict,
                  'value': value_dict, 'defer_freq': defer_freq_dict, 'defer_time': defer_times_dict,
                  'is_reward': result_is_dict, 'is_cost': cost_is_dict, 'is_value': value_is_dict,
                  'reward_det': result_dict_det, 'states_det': state_dict_det, 'actions_det': action_dict_det,
                  'cost_det': cost_dict_det, 'value_det': value_dict_det, 'defer_freq_det': defer_freq_dict_det,
                  'defer_time_det': defer_times_dict_det}, f)
