
import gym
#from d3rlpy.dataset import MDPDataset
import joblib
from os.path import join
import os
from tqdm import tqdm

import sys
# Root directory for the project's data and helper code. With the empty
# default, every path joined onto it is relative to the current working
# directory.
DATAPATH=''
# Extend the import search path with <DATAPATH>/dataprep (presumably where
# the `sepsisSimDiabetes` package imported below lives -- confirm).
sys.path.append(join(DATAPATH, 'dataprep'))

from sepsisSimDiabetes.MDP import MDP
from sepsisSimDiabetes.Action import Action
from sepsisSimDiabetes.State import State, sepsis_state_from_sa
from gym import spaces
import numpy as np

def _to_one_hot(arr, dim):
    """One-hot encode an integer index or a collection of integer indices.

    Args:
        arr: A single integer index, or a list / ``np.ndarray`` of integer
            indices. Multi-dimensional input is flattened in row-major
            order so that every index gets its own row.
        dim: Number of categories, i.e. the length of each one-hot vector.

    Returns:
        A float array of shape ``(number_of_indices, dim)`` for list/array
        input, or of shape ``(dim,)`` for a scalar index.
    """
    if isinstance(arr, (list, np.ndarray)):
        # Flatten so that an (n, 1) input does not broadcast against the row
        # indices and light up every label in every row.
        indices = np.asarray(arr).reshape(-1)
        encoded = np.zeros((indices.size, dim))
        # An empty list becomes a float64 array, which NumPy rejects as an
        # index; there is nothing to set in that case anyway.
        if indices.size:
            encoded[np.arange(indices.size), indices] = 1
        return encoded

    encoded = np.zeros((dim,))
    encoded[arr] = 1
    return encoded

def states_to_one_hot(states, dim):
    """One-hot encode every trajectory of state indices.

    Each element of ``states`` is passed to ``_to_one_hot`` with ``dim``
    categories (a tqdm progress bar tracks the iteration), and the encoded
    trajectories are collected into a single NumPy array.
    """
    encoded_trajs = [_to_one_hot(trajectory, dim) for trajectory in tqdm(states)]
    return np.array(encoded_trajs)


def _sepsis_state_reward(sepsis_state):
    """Return the reward (-1, 0 or +1) of one decoded sepsis state."""
    num_abnormal = sepsis_state.get_num_abnormal()
    if num_abnormal >= 3:
        return -1
    if num_abnormal == 0 and not sepsis_state.on_treatment():
        return 1
    return 0


def sepsis_reward_fn(state, action, p_diabetes=0.2):
    """Compute the diabetes-marginalised reward for state/action index pairs.

    Every (state, action) pair is decoded twice with ``sepsis_state_from_sa``,
    once as a non-diabetic and once as a diabetic patient, and the two rewards
    are mixed with weights ``1 - p_diabetes`` and ``p_diabetes``. A decoded
    state scores -1 when ``get_num_abnormal()`` is at least 3, +1 when it is 0
    and ``on_treatment()`` is false, and 0 otherwise.

    Args:
        state: State indices; any array-like (scalar, list or ``np.ndarray``).
        action: Action indices, paired element-wise with ``state`` after
            flattening.
        p_diabetes: Weight given to the diabetic reward. The default of 0.2
            reproduces the previously hard-coded 0.2 / 0.8 mixture.

    Returns:
        ``np.ndarray`` of rewards with the same shape as ``state``.
    """
    # Accept lists and scalars as well as arrays.
    state = np.asarray(state)
    action = np.asarray(action)
    input_shape = state.shape
    p_nondiabetes = 1 - p_diabetes

    rewards = []
    for s, a in zip(state.reshape(-1), action.reshape(-1)):
        nondiab_state = sepsis_state_from_sa(state_idx=s, action_idx=a, diabetic_idx=0)
        diab_state = sepsis_state_from_sa(state_idx=s, action_idx=a, diabetic_idx=1)
        rewards.append(
            p_diabetes * _sepsis_state_reward(diab_state)
            + p_nondiabetes * _sepsis_state_reward(nondiab_state)
        )

    return np.array(rewards).reshape(input_shape)




class SepsisSimulatedEnv(gym.Env):
    """Gym environment around the simulated sepsis ``MDP``.

    Observations are the full state vector with its last four entries
    removed (the original comment describes these as action and diabetes);
    the diabetes flag is appended again when ``observe_diabetes`` is set.
    There are 8 discrete actions and an episode ends on the first non-zero
    reward or after ``max_num_steps`` steps. ``get_dataset`` exposes a
    pre-generated offline dataset as a D4RL-style dictionary.
    """

    def __init__(self, dataset = "eps_0_1-100k",
                     data_path = join(DATAPATH, 'data'),
                     observe_diabetes = False):
        """Build the spaces, the underlying MDP and locate the offline data.

        Args:
            dataset: Name of the sub-directory of ``<DATAPATH>/datagen`` that
                holds ``1-alldata.joblib``.
            data_path: Directory handed to ``MDP`` as its ``data_path``.
            observe_diabetes: If true, the diabetes flag is appended to
                every observation.

        Raises:
            NotImplementedError: If the dataset file does not exist.
        """
        super().__init__()

        self.observe_diabetes = observe_diabetes
        obs_low = [0]*4
        obs_high = [3,3,2,5]
        if observe_diabetes:
            # One extra binary component for the diabetes flag.
            obs_low.extend([0])
            obs_high.extend([1])
        self.obs_dim = len(obs_low)
        self.observation_space = spaces.Box(low = np.array(obs_low), high = np.array(obs_high))
        self.action_dim = 8
        self.action_space = spaces.Discrete(self.action_dim)

        self.max_num_steps = 20
        self._elapsed_steps = 0

        self.p_diabetes = 0.2
        self.MDP = MDP(p_diabetes=self.p_diabetes,
                        data_path = data_path) # Random initial state

        self.datapath = join(DATAPATH, f"datagen/{dataset}/1-alldata.joblib")
        if not os.path.exists(self.datapath):
            raise NotImplementedError("This version of sepsis_env has not been implemented")

    def get_obs_from_state(self, state, info=False):
        """Convert a state (index or ``State`` object) into an observation.

        Args:
            state: A ``State`` instance, or a full-state index given as a
                Python int, a NumPy integer scalar or an ``np.ndarray``.
            info: If true, also return ``{'diabetic': flag}``.

        Returns:
            The observation vector, or ``(observation, info_dict)`` when
            ``info`` is true.
        """
        # NumPy integer scalars (e.g. from iterating a 1-D index array) are
        # indices too, not State objects.
        if isinstance(state, (int, np.integer, np.ndarray)):
            state = State(state_idx = state, idx_type = 'full')
        full_vector = state.get_state_vector(idx_type='full')
        # Last entry is the diabetes flag; the last four entries (action and
        # diabetes, per the original comment) are not part of the observation.
        diab = full_vector[-1]
        observation = full_vector[:-4]
        if self.observe_diabetes:
            observation = np.concatenate([observation, [diab]])
        if info:
            return observation, {'diabetic': diab}
        return observation

    def reset(self):
        """Draw a new initial state, reset the step counter, return its observation."""
        self.MDP.state = self.MDP.get_new_state()
        self._elapsed_steps = 0
        return self.get_obs_from_state(state = self.MDP.state)

    def step(self, action):
        """Apply ``action`` (an action index) and advance the MDP one step.

        Returns:
            ``(next_obs, reward, done, info)`` where ``done`` is true when
            the reward is non-zero or the step limit has been reached, and
            ``info`` is an empty dict.
        """
        action = Action(action_idx=action)
        reward = float(self.MDP.transition(action))
        next_obs = self.get_obs_from_state(state = self.MDP.state)

        self._elapsed_steps += 1
        done = (reward != 0) or (self._elapsed_steps >= self.max_num_steps)
        info = {}
        return next_obs, reward, done, info

    def get_dataset(self): # d4rl format
        """Load the offline dataset and flatten it into D4RL-style arrays.

        Each trajectory is truncated to its recorded length and the last
        transition actually emitted for a trajectory is flagged as terminal.

        Returns:
            Dict with keys ``observations``, ``actions``, ``rewards``,
            ``infos``, ``terminals``, ``timeouts`` (all zeros),
            ``metadata_action_dim`` and ``metadata_obs_dim``.
        """
        # NOTE(review): joblib.load unpickles the file; only load trusted data.
        states, actions, lengths, rewards, _diab, _emp_tx_totals, _emp_r_totals = \
            joblib.load(self.datapath)

        obs, acts, rews, infos, terminals = [], [], [], [], []
        for traj_s, traj_a, traj_r, traj_len in zip(states, actions, rewards, lengths):
            max_len = traj_len[0]
            emitted_before = len(obs)
            for s, a, r in zip(traj_s[:max_len], traj_a[:max_len], traj_r[:max_len]):
                observation, info = self.get_obs_from_state(s, info=True)
                obs.append(observation)
                acts.append(a)
                rews.append(r)
                infos.append(info)
                terminals.append(0)
            # Flag the final emitted transition; building the flags in the
            # same loop keeps them aligned with the transitions even for an
            # empty trajectory.
            if len(obs) > emitted_before:
                terminals[-1] = 1
        terminals = np.array(terminals)

        obs = np.array(obs, dtype=np.uint8)
        actions = np.array(acts, dtype=np.uint8)
        # Rewards may be negative (the simulator's reward function in this
        # module uses -1; the stored data presumably follows the same
        # convention), so an unsigned dtype would corrupt them.
        rewards = np.array(rews, dtype=np.float32)
        obs, actions, rewards = np.squeeze(obs), np.squeeze(actions), np.squeeze(rewards)

        # NOTE(review): every trajectory end is marked terminal and timeouts
        # are always zero, including episodes cut off at the step limit.
        dataset = {'observations': obs, 'actions': actions, 'rewards': rewards,
                    'infos': infos,
                    'terminals': terminals, 'timeouts': np.zeros_like(terminals),
                    'metadata_action_dim': self.action_dim, 'metadata_obs_dim': self.obs_dim}

        return dataset


    def get_lp_solution(self, return_value=True, policy_init = None):
        """Delegate to ``MDP.get_optimal_policy`` and return its result.

        Args:
            return_value: Forwarded unchanged to ``get_optimal_policy``.
            policy_init: Forwarded unchanged to ``get_optimal_policy``.
        """
        return self.MDP.get_optimal_policy(return_value=return_value, policy_init=policy_init)