import numpy as np
import pickle
import os
import matplotlib.pyplot as plt
import matplotlib._color_data as mcd

# ---------------------------------------------------------------------------
# Simulation constants
# ---------------------------------------------------------------------------
# sigma1 / sigma0 are not referenced anywhere else in the visible file.
sigma1, sigma0 = 1.0, 0.3
# N trajectories of step_n time steps each; n_trials is the size of the last
# buffer axis (only index 0 is written by the simulation loop below).
N, step_n, n_trials = 500, 15, 1
n_states, n_actions = 8, 2
DISCOUNT, EPS = 0.999, 0.05

# Trajectory buffers: eight independent zero arrays of shape
# (N, step_n, n_trials), plus one patient-type slot per trajectory.
(x_data_state, y_data_state, clinician_state, ground_truth_defer,
 target_action, clinician_action, reward, clinician_reward) = (
    np.zeros((N, step_n, n_trials)) for _ in range(8))
p_type_list = np.zeros(N)

# ---------------------------------------------------------------------------
# Baseline transition tensors
# ---------------------------------------------------------------------------
# tx_mat[patient_type][action, state] is the distribution over next states
# (it is passed as `p` to np.random.choice in the simulation loop).
tx_mat = {ptype: np.zeros((n_actions, n_states, n_states)) for ptype in (0, 1)}

# Patient type 0.  Action 1 always puts all mass on state 4; action 0 varies by
# state.  Both actions of a state share the same additive smoothing offset.
_type0_action0_rows = [
    [0.5, 0, 0.5, 0, 0, 0, 0, 0],
    [0, 0.5, 0, 0.5, 0, 0, 0, 0],
    [0, 0, 0.4, 0.6, 0, 0, 0, 0],
    [0, 0.0, 0.5, 0.5, 0, 0, 0, 0],
    [0, 0, 0.2, 0.0, 0.8, 0, 0, 0],
    [0, 0, 0.2, 0.0, 0.8, 0, 0, 0],
    [0, 0, 0.2, 0.0, 0.8, 0, 0, 0],
    [0, 0, 0.2, 0.0, 0.8, 0, 0, 0],
]
_type0_action1_row = [0, 0, 0, 0, 1, 0, 0, 0]
_type0_offsets = [0 * 1e-5, 0 * 1e-4, 0 * 1e-3, 0 * 1e-2,
                  1e-3 / 2, 1e-1, 2 * 1e-1, 2 * 1e-1]
for _state, (_row, _offset) in enumerate(zip(_type0_action0_rows, _type0_offsets)):
    tx_mat[0][0, _state] = np.array(_row) + _offset
    tx_mat[0][1, _state] = np.array(_type0_action1_row) + _offset

# ----
# Patient type 1.  Every smoothing offset is scaled by `nst`; with nst == 0 the
# rows are used exactly as written.
nst = 0
# One entry per state: (action-0 row, action-0 offset, action-1 row, action-1 offset).
_type1_table = [
    ([0.5, 0.5, 0, 0, 0, 0, 0, 0], nst * 1e-4,
     [0, 0.2, 0.2, 0, 0, 0, 0, 0], nst * 1e-4),
    ([0, 0.5, 0, 0.0, 0, 0.5, 0, 0], nst * 1e-3,
     [0, 0, 0.5, 0.0, 0.5, 0, 0, 0], nst * 1e-3),
    ([0, 0.2, 0.0, 0.4, 0, 0.4, 0, 0], nst * 1e-2,
     [0, 0, 0.0, 0.0, 0.5, 0, 0.5, 0], nst * 1e-2),
    ([0, 0.0, 0.0, 0.5, 0, 0.5, 0, 0], nst * 1e-3 / 2,
     [0, 0, 0.0, 0, 0.5, 0.2, 0.3, 0], nst * 3 * 1e-3 / 2),
    #######
    ([0, 0.2, 0.0, 0.0, 0.0, 0.8, 0, 0], nst * 1e-3,
     [0, 0, 0, 0, 0.4, 0.1, 0.5, 0], nst * 1 * 1e-1),
    ([0, 0.0, 0.0, 0.0, 0.0, 0.0, 0, 1.0], nst * 1 * 1e-3,
     [0, 0, 0.0, 0, 0.3, 0.0, 0.5, 0.2], nst * 0.1 * 1e-2),
    ([0, 0, 0.0, 0.0, 0.0, 0.0, 0, 1.0], nst * 1 * 1e-2,
     [0, 0, 0, 0, 0.4, 0, 0.5, 0.1], nst * 0 * 1e-1),
    ([0, 0, 0, 0.0, 0, 0, 0, 1], nst * 0 * 1e-1,
     [0, 0, 0, 0, 0.2, 0, 0.7, 0.1], nst * 0 * 1e-1),
]
for _state, (_row0, _offset0, _row1, _offset1) in enumerate(_type1_table):
    tx_mat[1][0, _state] = np.array(_row0) + _offset0
    tx_mat[1][1, _state] = np.array(_row1) + _offset1

# The offsets (and a few raw rows) leave rows that do not sum to one, so every
# row is renormalised in place.
for _ptype in (0, 1):
    tx_mat[_ptype] /= tx_mat[_ptype].sum(axis=-1, keepdims=True)

# ---------------------------------------------------------------------------
# Time-dependent policies
# ---------------------------------------------------------------------------
# Each policy maps a time step to a per-state list.  For the target and
# clinician policies the entry is the probability of taking action 1 (it is
# used as the Bernoulli `p` in the simulation loop).  Steps t_0..t_1 inclusive
# form a window in which all three policies switch values.
t_0 = 5
t_1 = 12
target_policy = {}
clinician_policy = {}
gold_policy = {}
for t in range(step_n):
    if t < t_0 or t > t_1:
        # Outside the window: states 2-4 get the integer 1, all others 0.1.
        target_policy[t] = [1 if s in (2, 3, 4) else 0.1 for s in range(n_states)]
        gold_policy[t] = [0] * n_states
        clinician_policy[t] = [0.1] * n_states
    else:
        # Inside the window: states 2-4 are overridden, all others share a base value.
        target_policy[t] = [0. if s in (2, 3, 4) else 0.7 for s in range(n_states)]
        gold_policy[t] = [1] * n_states
        clinician_policy[t] = [0.7 if s in (2, 3, 4) else 0.9 for s in range(n_states)]

# ---------------------------------------------------------------------------
# Rewards and discretisation metadata
# ---------------------------------------------------------------------------
# The reward depends on the state only, so every action column of r_mat holds
# the same vector.
reward_vec = [1, 1, 1, 1, 1, 1, -5, 1]
r_mat = np.zeros((len(reward_vec), n_actions))
for a in range(n_actions):
    r_mat[:, a] = reward_vec
S_card = [*range(n_states)]
# Written to bins.pkl under the keys 'xc' and 'xp' respectively.
bin_id_c, bin_id_p = 2, 4


def plot(n_list, x_state, y_state, state_clinician, action, rewards, clinician_rewards, defer, patient_type_list,
         out_dir=None):
    """Plot selected trajectories on five stacked panels and save one PDF per entry.

    Panels, top to bottom: target-policy state (dotted teal) overlaid with the
    clinician state (black), ``y_state``, the action, cumulative rewards
    (``rewards`` in black, ``clinician_rewards`` in teal) and ``defer``.

    NOTE: a single figure is reused and never cleared between entries, so the
    file written for position ``nn`` contains the curves of entries ``0..nn``.

    Args:
        n_list: Trajectory indices (first axis of the arrays) to plot, in order.
        x_state, y_state, state_clinician, action, rewards, clinician_rewards,
            defer: Arrays indexed as ``[trajectory, time step, trial]``; the
            module-level ``step_n`` and ``n_trials`` give the extent of the
            last two axes that is drawn.
        patient_type_list: Accepted for call compatibility; not used.
        out_dir: Directory for the PDFs. Defaults to the module-level
            ``dir_name``, which must then be defined before the call.
    """
    if out_dir is None:
        out_dir = dir_name

    fig, axes = plt.subplots(5, 1, sharex=True, figsize=(21, 21))
    try:
        ax1, ax2, ax3, ax4, ax5 = axes
        # Labels do not depend on the data, so set them once up front.
        for axis, label in zip(axes, ('X', 'Y', 'A', 'reward', 'defer')):
            axis.set_ylabel(label)

        steps = range(step_n)
        teal = mcd.XKCD_COLORS["xkcd:teal"]
        olive = mcd.XKCD_COLORS["xkcd:olive"]

        for nn, plot_id in enumerate(n_list):
            for ii in range(n_trials):
                ax1.plot(steps, x_state[plot_id, :, ii], marker="o", linestyle=":", color=teal)
                ax1.plot(steps, state_clinician[plot_id, :, ii], marker="o", color="k")
                ax2.plot(steps, y_state[plot_id, :, ii], marker="v", color="g")
                ax3.plot(steps, action[plot_id, :, ii], marker="p", color="m")
                ax4.plot(steps, np.cumsum(rewards[plot_id, :, ii]), color='k')
                ax4.plot(steps, np.cumsum(clinician_rewards[plot_id, :, ii]), color=teal)
                ax5.plot(steps, defer[plot_id, :, ii], marker="*", color=olive)
            # File name keeps the historical double underscore.
            fig.savefig(os.path.join(out_dir, f'random_walk__{nn}.pdf'), dpi=300)
    finally:
        # Release the figure; pyplot otherwise keeps it alive for the whole process.
        plt.close(fig)


# Simulate N trajectories under the target and clinician policies, persist
# everything to ./data/disc_example_<pp>/ and plot a few sampled trajectories.
#
# tx_dict[patient_type][t - 1] holds the normalised transition tensor that was
# used to move from step t - 1 to step t.  It is keyed by the sampled patient
# type (0 or 1), so it works for any value of pp.
tx_dict = {}
# pp is the probability that a sampled patient is of type 1, i.e. it determines
# the fraction of each subpopulation.
for pp in [1.0]:
    for n in range(N):
        # Sample type:
        p_type = int(np.random.binomial(1, p=pp, size=1)[0])
        p_type_list[n] = p_type
        step_tx = tx_dict.setdefault(p_type, {})

        # Both trajectories start in state 0.
        x0 = 0
        x_data_state[n, 0, :] = x0
        clinician_state[n, 0, :] = x0

        for t in range(1, step_n):
            prev_x = int(x_data_state[n, t - 1, 0])
            prev_c = int(clinician_state[n, t - 1, 0])

            # Draw order (clinician action, target action, target state,
            # clinician state) is kept fixed so seeded runs stay comparable.
            clinician_action[n, t, 0] = np.random.binomial(1, p=clinician_policy[t][prev_c], size=1)[0]
            target_action[n, t, 0] = np.random.binomial(1, p=target_policy[t][prev_x], size=1)[0]

            # The tensor depends only on (patient type, t): build it once and reuse it.
            if t - 1 not in step_tx:
                tx = np.copy(tx_mat[p_type])
                if t_0 <= t <= t_1:
                    # Inside the window action 1 inherits the baseline action-0
                    # dynamics, while action 0 places equal weight on states
                    # 2, 4 and 6 only (uniform over them after normalisation).
                    tx[1, :, :] = tx[0, :, :]
                    tx[0, :, :] = 0.00
                    tx[0, :, [2, 4, 6]] = (t - t_0 + 3) * DISCOUNT * EPS
                tx /= tx.sum(axis=-1, keepdims=True)
                step_tx[t - 1] = tx
            tx = step_tx[t - 1]

            if t >= t_0 and n == 0:
                print('normalized', t - 1, tx[1, :, :])

            x_data_state[n, t, 0] = np.random.choice(
                S_card, p=tx[int(target_action[n, t, 0]), prev_x], size=1)[0]
            clinician_state[n, t, 0] = np.random.choice(
                S_card, p=tx[int(clinician_action[n, t, 0]), prev_c], size=1)[0]

            # The reward logged at step t is that of the state occupied at t - 1.
            reward[n, t, 0] = reward_vec[prev_x]
            clinician_reward[n, t, 0] = reward_vec[prev_c]

    dir_name = './data/disc_example_' + str(pp)
    # makedirs also creates a missing ./data parent and tolerates an existing directory.
    os.makedirs(dir_name, exist_ok=True)

    # File name -> object, written in this order.  Some objects are stored
    # under two names.
    pickled_outputs = {
        'x_state.pkl': x_data_state,
        'discrete_state.pkl': S_card,
        'y_state.pkl': y_data_state,
        'action.pkl': target_action,
        'defer.pkl': ground_truth_defer,
        'clinician.pkl': clinician_state,
        'clinician_reward.pkl': clinician_reward,
        'clinician_action.pkl': clinician_action,
        'policy_reward.pkl': reward,
        'x_state_discrete.pkl': x_data_state,
        'clinician_state_discrete.pkl': clinician_state,
        'clinician_policy.pkl': clinician_policy,
        'target_policy.pkl': target_policy,
        'gold_policy.pkl': gold_policy,
        'true_tx.pkl': tx_dict,
        'reward_mat.pkl': r_mat,
        'bins.pkl': {'xp': bin_id_p, 'xc': bin_id_c},
    }
    for file_name, payload in pickled_outputs.items():
        with open(os.path.join(dir_name, file_name), 'wb') as f:
            pickle.dump(payload, f)

    # Plot a sample (drawn with replacement) of each patient type that occurred.
    a_type = [i for i in range(N) if int(p_type_list[i]) == 0]
    n_easy_model = list(np.random.choice(a_type, size=5)) if a_type else []
    b_type = [i for i in range(N) if int(p_type_list[i]) == 1]
    n_easy_defer = list(np.random.choice(b_type, size=10)) if b_type else []
    plot(n_easy_model + n_easy_defer, x_data_state, y_data_state, clinician_state, target_action, reward,
         clinician_reward, ground_truth_defer, p_type_list)
