import random
import numpy as np
import pickle as pkl
from collections import namedtuple
import os
from data.data_utils import clinician_tx, clinician_policy_func, target_policy_func
from sklearn.preprocessing import OneHotEncoder

#### experience replay repo ####
# One step of experience: (state, action, next_state, reward).
Transition = namedtuple('Transition', 'state action next_state reward')

# Same record with an additional ``time`` field (the time index of the step).
Transition_nst = namedtuple('Transition_nst', 'state action next_state reward time')


def load_diabetes_data(env, seed, speed=1 * 60, dir_name='./data/diabetes/'):
    """Load the discretised diabetes data set, its policies and its dynamics.

    Three pickles are read from ``dir_name``:
    ``diabetes_data_disc_try_target.pickle`` (data frame, state space and
    action space), ``diabetes_simulator_policies.p`` (policies) and
    ``diab_nst.pkl`` (transition and reward matrices). The data frame is
    flattened into one ``(s, a, r, s', t)`` row per observation that has no
    NaN in any of state, action, next state or reward.

    Args:
        env: Not used by this function.
        seed: Not used by this function.
        speed: Divisor applied to the number of rows of patient 1 to obtain
            the episode length when the data is not flagged as aggregated.
        dir_name: Directory containing the three pickle files.

    Returns:
        A 12-tuple ``(H_tk_obs, ep_length, A_space, S_space,
        clinician_policy, candidate_policy, true_tx, r_mat, data_x,
        data_x_unencoded, enc, counts_states)`` where

        * ``H_tk_obs`` has one row per data-frame row and the columns
          state, action, reward, next state, time index,
        * ``data_x_unencoded`` is the state column and ``data_x`` its
          one-hot encoding with ``true_tx[0].shape[1]`` columns,
        * ``enc`` is a ``OneHotEncoder`` fitted on ``data_x``,
        * ``counts_states`` is the normalised histogram of states among the
          rows whose time index is 0.
    """
    data_path = os.path.join(dir_name, 'diabetes_data_disc_try_target.pickle')
    print(data_path)
    # SECURITY: pickle.load can execute arbitrary code; only use trusted files.
    with open(data_path, 'rb') as f:
        data_dict = pkl.load(f)

    data_df = data_dict['data']

    H_tk_obs = np.zeros((data_df.shape[0], 5))  # sarst
    data_x = np.zeros((data_df.shape[0], 1))
    k = 0

    # Convert to a plain bool and negate with `not`: `~` is a bitwise operator
    # and is always truthy on a Python bool (~True == -2, ~False == -1).
    AGG = bool(np.all(data_df['AGG'] == 1))
    if not AGG:
        data_df['hour'] = data_df['hour_agg'].dt.hour.astype(int)
        T = data_df[data_df['patient'] == 1].shape[0] / speed
        ep_length = int(T)
    else:
        T = data_df[data_df['patient'] == 1].shape[0]
        ep_length = T - 1

    print('ep length', ep_length)
    for h in range(ep_length):
        # collect time points and estimate dynamics, learn non stationary policy.
        row_idx = np.where(data_df['hour'] == h)[0]
        if h == 0 and not AGG:
            # Drop every 61st matching row.
            # NOTE(review): 61 looks tied to the default speed of 60 - confirm.
            row_idx = [idx for pos, idx in enumerate(row_idx, start=1) if pos % 61 != 0]
        next_idx = np.array(row_idx) + 1
        states = np.array(data_df.iloc[row_idx]['Discrete BG'])
        actions = np.array(data_df.iloc[row_idx]['discrete_CHO_insulin'])
        next_states = np.array(data_df.iloc[next_idx]['Discrete BG'])
        if h == 0:
            # No previous row to compare with: reward is the negated risk.
            rewards = - np.array(data_df.iloc[row_idx]['Risk'])
        else:
            # Reward is the drop in risk relative to the preceding row.
            prev_idx = np.array(row_idx) - 1
            rewards = np.array(data_df.iloc[prev_idx]['Risk']) - np.array(data_df.iloc[row_idx]['Risk'])
        # A transition must not cross from one patient to the next.
        assert np.all(
            np.array(data_df.iloc[row_idx]['patient']) == np.array(data_df.iloc[next_idx]['patient']))
        for ee, s in enumerate(states):
            if np.isnan(s) or np.isnan(actions[ee]) or np.isnan(next_states[ee]) or np.isnan(rewards[ee]):
                continue
            H_tk_obs[k, 0] = int(s)
            H_tk_obs[k, 1] = int(actions[ee])
            H_tk_obs[k, 2] = rewards[ee]
            H_tk_obs[k, 3] = int(next_states[ee])
            H_tk_obs[k, 4] = int(h)
            data_x[k, 0] = int(s)

            k += 1
    # NOTE(review): rows k.. of H_tk_obs/data_x are never filled and stay all
    # zero (state 0, time 0), so they are included in the statistics below and
    # in counts_states. Confirm this is intended before trimming to [:k].
    print("reward stats", np.mean(H_tk_obs[:, 2]), np.std(H_tk_obs[:, 2]))
    S_space = data_dict['state_space']
    A_space = data_dict['action_space']
    # load policy
    with open(os.path.join(dir_name, 'diabetes_simulator_policies.p'), 'rb') as f:
        policies = pkl.load(f)

    clinician_policy = policies['physSoftNonMyopic']
    candidate_policy = policies['candNonMyopic']

    with open(os.path.join(dir_name, 'diab_nst.pkl'), 'rb') as f:
        mat_dict = pkl.load(f)

    true_tx = mat_dict['tx_mat']
    r_mat = mat_dict['r_mat']

    data_x_unencoded = data_x
    # scikit-learn 1.2 renamed `sparse` to `sparse_output` and 1.4 removed the
    # old keyword; try the new spelling first and fall back for old versions.
    try:
        enc = OneHotEncoder(sparse_output=False)
    except TypeError:
        enc = OneHotEncoder(sparse=False)
    # Manual one-hot encoding: one column per state of the transition matrix.
    data_x = np.zeros((data_x_unencoded.shape[0], true_tx[0].shape[1]))
    data_x[np.arange(data_x.shape[0]), data_x_unencoded[:, 0].astype(int)] = 1
    enc.fit(data_x)

    # Empirical distribution of the state at time index 0.
    idx = np.where(H_tk_obs[:, -1] == 0)[0]
    unique, counts = np.unique(H_tk_obs[idx, 0], return_counts=True)
    counts_states = np.zeros(len(S_space))
    for u, c in zip(unique, counts):
        counts_states[int(u)] = c
    counts_states = counts_states / np.sum(counts_states)

    return H_tk_obs, ep_length, A_space, S_space, clinician_policy, candidate_policy, true_tx, r_mat, data_x, data_x_unencoded, enc, counts_states


def load_discrete_data(env, seed, p=1, test_opt=False, baseline="sltd", dir_name='./data/disc_example_'):
    """Load the discrete example data set stored under ``dir_name + str(p)``.

    Args:
        env: Forwarded to ``generate_dataset_nst`` as ``env_name``.
        seed: Forwarded to ``generate_dataset_nst``.
        p: Suffix appended to ``dir_name`` to select the data directory.
        test_opt: When true, the clinician policy is returned as the
            candidate policy as well.
        baseline: ``"gold"`` replaces the candidate policy with the one from
            ``gold_policy.pkl``; any other value keeps ``target_policy.pkl``.
        dir_name: Directory prefix of the data set.

    Returns:
        ``(episodes_repo_obs, H_tk_obs, ep_length, A_space, S_space,
        true_tx[1.], tx_mat, r_mat, clinician_policy, candidate_policy,
        counts_states)``. ``tx_mat`` is the row-normalised transition count
        estimated from the clinician trajectories and ``counts_states`` the
        normalised histogram of states at time index 0.
    """
    dir_name = dir_name + str(p)

    def _unpickle(file_name):
        # Read a single pickle file from the data directory.
        with open(os.path.join(dir_name, file_name), 'rb') as handle:
            return pkl.load(handle)

    # Clinician trajectories (behaviour policy); only channel 0 is used.
    clinician_cf_state = _unpickle('clinician_state_discrete.pkl')[:, :, 0]
    _x_state = _unpickle('x_state_discrete.pkl')[:, :, 0]
    # Partially observed variant (y seen instead of x) - loaded but unused for now.
    _y_state = _unpickle('y_state.pkl')[:, :, 0]
    reward = _unpickle('clinician_reward.pkl')[:, :, 0]
    clinician_action = _unpickle('clinician_action.pkl')[:, :, 0]
    _x_action = _unpickle('action.pkl')[:, :, 0]
    state_space = _unpickle('discrete_state.pkl')
    clinician_policy = _unpickle('clinician_policy.pkl')
    candidate_policy = _unpickle('target_policy.pkl')
    if baseline == "gold":
        candidate_policy = _unpickle('gold_policy.pkl')
    true_tx = _unpickle('true_tx.pkl')
    r_mat = _unpickle('reward_mat.pkl')

    n_episodes, n_points = clinician_cf_state.shape[0], clinician_cf_state.shape[1]
    n_states = len(state_space)
    n_actions = 2
    ep_length = n_points - 1
    S_space = list(range(n_states))
    A_space = list(range(n_actions))

    episodes_repo_obs, H_tk_obs = generate_dataset_nst(
        clinician_x=clinician_cf_state,
        clinician_action=clinician_action,
        reward_mat=reward,
        episodes=n_episodes,
        seed=seed,
        env_name=env,
    )

    # Empirical transition counts; the tiny offset keeps every row sum non-zero.
    tx_mat = 1e-16 + np.zeros((n_actions, n_states, n_states))
    for ep in range(n_episodes):
        for step in range(n_points - 1):
            src = int(clinician_cf_state[ep, step])
            dst = int(clinician_cf_state[ep, step + 1])
            # The action stored at step + 1 is the one that led from src to dst.
            act = int(clinician_action[ep, step + 1])
            tx_mat[act, src, dst] += 1
    tx_mat /= tx_mat.sum(axis=-1, keepdims=True)

    if test_opt:
        candidate_policy = clinician_policy

    # Distribution of the state over the rows whose time index is 0.
    first_step_rows = np.where(H_tk_obs[:, -1] == 0)[0]
    seen_states, seen_counts = np.unique(H_tk_obs[first_step_rows, 0], return_counts=True)
    counts_states = np.zeros(len(S_space))
    for state_value, n_seen in zip(seen_states, seen_counts):
        counts_states[int(state_value)] = n_seen
    counts_states = counts_states / np.sum(counts_states)

    # !!! HACK will not work for other discrete data:
    # the sub-population is selected by indexing true_tx with the key 1.
    selected_tx = true_tx[1.]
    return (episodes_repo_obs, H_tk_obs, ep_length, A_space, S_space, selected_tx,
            tx_mat, r_mat, clinician_policy, candidate_policy, counts_states)


class experienceRepository_nst(object):
    """Bounded store of ``Transition_nst`` records.

    Once ``limit`` records are held, each new record overwrites the slot at
    the current write position, which advances cyclically.
    """

    def __init__(self, limit):
        # position: index of the slot the next stored record is written to.
        self.position = 0
        self.memory = []
        self.limit = limit

    def store(self, *args):
        """Build a ``Transition_nst`` from ``args`` and write it to the current slot."""
        slot = self.position
        has_room = len(self.memory) < self.limit
        if has_room:
            # Grow by one placeholder so that the slot written below exists.
            self.memory.append(None)
        self.memory[slot] = Transition_nst(*args)
        self.position = (slot + 1) % self.limit

    def sample(self, batch_size):
        """Return ``batch_size`` distinct stored records drawn at random."""
        return random.sample(self.memory, k=batch_size)

    def __len__(self):
        """Number of records currently held."""
        held = self.memory
        return len(held)


def build_H_tk_nst(samp, epi_no):
    """Stack the first ``epi_no`` episodes of ``samp`` into one history matrix.

    Args:
        samp: Object with indexable ``state``, ``action``, ``reward``,
            ``next_state`` and ``time`` attributes, one entry per episode.
        epi_no: Number of episodes to include. The first episode is always
            included, even when ``epi_no`` is smaller than 1.

    Returns:
        The per-episode blocks ``[state, action, reward, next_state, time]``
        (joined with ``np.hstack``) stacked vertically in episode order.
    """

    def _episode_block(tt):
        # Columns are joined side by side in the order s, a, r, s', t.
        return np.hstack((samp.state[tt], samp.action[tt], samp.reward[tt],
                          samp.next_state[tt], samp.time[tt]))

    first_block = _episode_block(0)
    later_blocks = [_episode_block(tt) for tt in range(1, epi_no)]
    if not later_blocks:
        # A single episode is returned as-is (no extra stacking is applied).
        return first_block
    # Stack once instead of re-copying the growing matrix for every episode.
    return np.vstack([first_block] + later_blocks)


def generate_dataset_nst(clinician_x, clinician_action, reward_mat, episodes,
                         env_name="randomwalk", seed=0):
    """Turn per-episode trajectories into a transition repository and matrix.

    For every episode ``k`` and step ``t`` the stored transition is
    ``(x[k, t], action[k, t + 1], x[k, t + 1], reward[k, t + 1], t)``.

    Args:
        clinician_x: Array indexed as ``[episode, time]`` holding the states.
        clinician_action: Array indexed as ``[episode, time]``; column 0 is
            never read.
        reward_mat: Array indexed as ``[episode, time]``; column 0 is never
            read.
        episodes: Number of episodes to convert; also the repository capacity.
        env_name: Not used by this function.
        seed: Not used by this function.

    Returns:
        ``(obs_episodes_repo, obs_H_tk)``: the filled
        ``experienceRepository_nst`` (one record of column vectors per
        episode) and the matrix built from it by ``build_H_tk_nst``.
    """
    obs_episodes_repo = experienceRepository_nst(episodes)
    ep_length = clinician_x.shape[1]

    def _column_from_step_one(values):
        # Entries 1 .. ep_length - 1 as a column vector. Stacking beneath an
        # integer seed and dropping it again keeps the dtype promotion of the
        # earlier step-by-step construction, which seeded each sequence
        # with -100.
        column = np.reshape(values[1:ep_length], (-1, 1))
        return np.vstack((-100, column))[1:]

    for k in range(int(episodes)):
        # Slice whole columns per episode instead of growing four arrays with
        # one np.vstack per time step (quadratic in the episode length).
        state_col = np.array(clinician_x[k]).reshape(-1, 1)
        time_col = np.arange(ep_length - 1).reshape(-1, 1)
        obs_episodes_repo.store(state_col[:-1],
                                _column_from_step_one(clinician_action[k]),
                                state_col[1:],
                                _column_from_step_one(reward_mat[k]),
                                time_col)

    obs_episodes = obs_episodes_repo.memory
    # Transpose the list of per-episode records into one record of tuples.
    obs_sample = Transition_nst(*zip(*obs_episodes))
    obs_H_tk = build_H_tk_nst(obs_sample, int(episodes))
    return obs_episodes_repo, obs_H_tk
