import gym
from gym import spaces
import numpy as np
from icu_benchmarks.data.loader import *
from tqdm import tqdm
from collections import defaultdict
from . import colnames

# Default value for the ``data_path`` constructor argument of the environments
# below (forwarded to ``ICUVariableLengthDataset`` in ``get_dataset``).
# Left as an empty string here.
DATAPATH=''


class HiridSimulatedEnv(gym.Env):
    """Base gym environment that exposes ICU data as an offline dataset.

    The environment itself is not interactive (``reset`` and ``step`` are
    stubs returning ``None``); its purpose is ``get_dataset``, which loads the
    training split through ``ICUVariableLengthDataset`` and converts it into
    flat arrays of observations, actions, rewards, terminals and infos.

    Subclasses must implement ``get_actions`` and usually extend
    ``set_col_idxs`` to resolve their treatment columns.

    Args:
        version: ``'v0'`` (keep every row, one chunk per patient) or ``'v1'``
            (keep every 12th row, chunks of at most 20 rows).
        discrete_treatments: an ``int`` builds a ``Discrete`` action space,
            a ``tuple`` builds a ``MultiDiscrete`` one.
        obs_cols: names of the observation columns (must support ``.copy()``
            and ``.index()``, e.g. a list).
        gin_config_file: path of the gin config parsed in ``get_dataset``.
        data_path: location handed to ``ICUVariableLengthDataset``.
        hidden: names from ``obs_cols`` to remove from the observations; their
            values are kept in the per-step info dict under ``'hidden'``.
            ``None`` (default) hides nothing.

    Raises:
        ValueError: if ``version`` is unknown, or if a hidden name is not part
            of ``obs_cols``.
    """

    def __init__(self,
        version,
        discrete_treatments,
        obs_cols,
        gin_config_file,
        data_path = DATAPATH,
        hidden=None):
        super().__init__()

        self.version = version

        if self.version == 'v0':
            self.resampling = 1 # 5 min sampling
            self.max_len = -1   # -1 means: do not split a patient sequence
        elif self.version == 'v1':
            self.resampling = 12 # 1 hour
            self.max_len = 20 # 20 hours
        else:
            # Fail early: without this, resampling/max_len stay undefined and
            # get_dataset() dies much later with an AttributeError.
            raise ValueError(
                f"Unknown version {version!r}; expected 'v0' or 'v1'")

        self.columns = None
        all_obs_cols = obs_cols.copy()
        # A fresh list per instance instead of a shared mutable default.
        self.hidden = [] if hidden is None else hidden
        # Position of every hidden column inside the *full* obs_cols list,
        # needed later to put the columns back (recover_original_obs).
        self.obs_hidden_idx = [all_obs_cols.index(el) for el in self.hidden]
        # Observation columns actually exposed: everything that is not hidden.
        self.obs_cols = [col for col in all_obs_cols if col not in self.hidden]

        # Both are resolved by set_col_idxs() once the data columns are known.
        self.obs_col_idxs = None
        self.treatment_cols_idx = None

        self.gin_config_file = gin_config_file
        self.data_path = data_path

        self.observation_space = spaces.Box(
            low = np.array([-1]*len(self.obs_cols)),
            high = np.array([1]*len(self.obs_cols)))
        if isinstance(discrete_treatments, tuple):
            self.action_space = spaces.MultiDiscrete(discrete_treatments)
        else:
            self.action_space = spaces.Discrete(discrete_treatments)

    def reset(self):
        """Stub: the environment is offline-only and returns ``None``."""
        return

    def step(self, action):
        """Stub: the environment is offline-only and returns ``None``."""
        return

    def _convert_to_reward(self, seq_events):
        """Turn a sequence of event labels into per-step rewards.

        Each label ``e`` is mapped to ``1 - 2*e`` (so 0 -> +1 and 1 -> -1,
        assuming binary labels). A step whose mapped value is +1 directly
        after a step with -1 receives an extra +1.

        Args:
            seq_events: non-empty numpy array of labels.

        Returns:
            A list with one reward per element of ``seq_events``.
        """
        seq_events = 1-2*seq_events
        rewards = [seq_events[0]]
        for i in range(1, len(seq_events)):
            bonus = int((seq_events[i-1] == -1.) and (seq_events[i] == 1.))
            rewards.append(seq_events[i] + bonus)
        return rewards

    def get_actions(self, sample):
        """Extract the actions of ``sample``; must be provided by subclasses."""
        raise NotImplementedError

    def get_info(self, sample):
        """Return one ``{'hidden': values}`` dict per row of ``sample``.

        ``values`` holds the row's entries of the hidden columns, selected
        with ``self.hidden_idx`` (set by ``set_col_idxs``).
        """
        return np.array([{'hidden': el} for el in sample[:,self.hidden_idx]])

    def set_col_idxs(self,):
        """Resolve observation/hidden column names to indices in ``self.columns``."""
        assert self.columns is not None
        self.obs_col_idxs = [np.where(el == self.columns)[0][0] for el in self.obs_cols]
        # Kept as index *arrays* (no trailing [0]): this gives every 'hidden'
        # info entry a trailing axis that recover_original_obs squeezes away.
        self.hidden_idx = [np.where(el == self.columns)[0] for el in self.hidden]

    @staticmethod
    def _chunks(l1, l2, n):
        """Yield aligned slices of ``l1`` and ``l2`` of at most ``n`` items.

        ``n == -1`` yields both sequences unsplit as a single chunk.
        """
        assert len(l1) ==len(l2)
        if n == -1:
            yield l1, l2
            return
        for i in range(0, len(l1), n):
            yield l1[i:i+n], l2[i:i+n]

    def get_dataset(self):
        """Load the training split and flatten it into a transition dataset.

        For every patient window the rows are subsampled with stride
        ``self.resampling``, rows with a NaN label are dropped, and the rest is
        cut into chunks of at most ``self.max_len`` rows. Every chunk ends with
        a terminal flag of 1.

        Returns:
            dict with the keys ``'observations'``, ``'actions'``, ``'rewards'``,
            ``'terminals'``, ``'timeouts'`` (all zeros) and ``'infos'``, each
            concatenated over all chunks.
        """
        import gin
        gin.parse_config_file(self.gin_config_file)

        dataset = ICUVariableLengthDataset(self.data_path,
                                   split='train')

        dataloader = dataset.h5_loader
        self.columns = np.array(dataloader.columns)
        self.set_col_idxs()

        states, actions, rewards, terminals, infos = [], [], [], [], []
        for start, stop, _ in tqdm(dataloader.patient_windows[dataset.split]):
            labels = dataloader.labels[dataset.split][start:stop][::self.resampling]
            sequence = dataloader.lookup_table[dataset.split][start:stop][::self.resampling]
            # Only keep the time steps that carry a label.
            labelled = ~np.isnan(labels)
            sequence = sequence[labelled]
            labels = labels[labelled]

            if labels.shape[0] == 0:
                continue

            for sample, event in self._chunks(sequence, labels, self.max_len):
                states.append(sample[:,self.obs_col_idxs])
                actions.append(self.get_actions(sample))
                rewards.append(self._convert_to_reward(event))
                # Last step of every chunk is terminal.
                terminals.append(np.concatenate([np.zeros(sample.shape[0]-1), [1]]))
                infos.append(self.get_info(sample))

        states = np.concatenate(states)
        actions = np.concatenate(actions)
        rewards = np.concatenate(rewards)
        terminals = np.concatenate(terminals)
        infos = np.concatenate(infos)

        return {'observations': states, 'actions': actions, 'rewards': rewards,
                'terminals': terminals, 'timeouts': np.zeros_like(terminals), 'infos': infos}

    def recover_original_obs(self, obs, info):
        """Re-insert the hidden columns into ``obs``.

        Args:
            obs: 2D array of observations without the hidden columns
                (3D input is not handled yet).
            info: sequence of info dicts as produced by ``get_info``.

        Returns:
            ``obs`` with the hidden columns back at their original positions;
            ``obs`` itself when nothing is hidden.
        """
        if len(self.hidden) == 0:
            # Nothing was removed; the squeeze below would also fail on the
            # empty 'hidden' entries.
            return obs

        to_insert = np.squeeze(np.stack([el['hidden'] for el in info], 0), -1)
        if to_insert.shape[:-1] != obs.shape[:-1]:
            # deal with 3D: TODO

            #deal with 2D: pad by repeating the last known hidden values
            to_insert = np.concatenate([to_insert] +
                [to_insert[-1:, :]] * (len(obs)-len(to_insert)) , axis=0)

        # Insert in ascending order of the original index: each stored index
        # refers to the full column layout, which is only valid once every
        # hidden column to its left has been put back.
        order = sorted(range(len(self.hidden)),
                       key=self.obs_hidden_idx.__getitem__)
        for i in order:
            obs = np.insert(obs, self.obs_hidden_idx[i], to_insert[:, i], axis=-1)
        return obs
        


class HiridSimulatedEnv_Circ(HiridSimulatedEnv):
    """Environment using ``colnames.circ_obs_cols`` and all treatment groups.

    The action of a time step holds one binary flag per group in
    ``colnames.circ_act_cols``.

    NOTE(review): the action space is declared with ``discrete_treatments=2``
    (``Discrete(2)``) although ``get_actions`` returns one flag per treatment
    group -- confirm whether a ``MultiDiscrete`` space was intended.

    Args:
        version: dataset version forwarded to the base class.
        data_path: location of the data forwarded to the base class.
        hidden: observation columns to hide; ``None`` (default) hides nothing.
    """

    def __init__(self,
        version='v1',
        data_path = DATAPATH,
        hidden = None):
        gin_config_file = '../configs/config_circ.gin'
        obs_cols = colnames.circ_obs_cols
        super().__init__(
            data_path=data_path,
            version = version,
            discrete_treatments = 2,
            obs_cols = obs_cols,
            gin_config_file = gin_config_file,
            # Fresh list per instance instead of a shared mutable default.
            hidden=[] if hidden is None else hidden
            )

    def set_col_idxs(self,):
        """Resolve column indices, including one index list per treatment group."""
        super().set_col_idxs()
        self.treatment_cols_idx = [
            [np.where(el == self.columns)[0][0] for el in group]
            for group in  colnames.circ_act_cols
        ]

    def get_actions(self, sample):
        """Return, per row, one 0/1 flag per treatment group.

        A flag is 1 when any column of the group is strictly positive in
        that row. The groups are stacked along the last axis.
        """
        return np.stack(
                    [np.any(np.array(sample[:,treatcol] > 0), -1).astype(int)
                    for treatcol in self.treatment_cols_idx], -1)



class HiridSimulatedEnv_Fluids(HiridSimulatedEnv):
    """Environment with a binary action built from ``colnames.circ_act_cols[0]``.

    Presumably that first group holds the fluid treatment columns (going by
    the class name) -- verify against ``colnames``.

    Args:
        version: dataset version forwarded to the base class.
        data_path: location of the data forwarded to the base class.
        hidden: observation columns to hide; ``None`` (default) hides nothing.
    """

    def __init__(self,
        version='v1',
        data_path = DATAPATH,
        hidden = None):
        gin_config_file = '../configs/config_circ.gin'
        obs_cols = colnames.circ_obs_cols
        super().__init__(
            data_path=data_path,
            version = version,
            discrete_treatments = 2,
            obs_cols = obs_cols,
            gin_config_file = gin_config_file,
            # Fresh list per instance instead of a shared mutable default.
            hidden=[] if hidden is None else hidden
            )

    def set_col_idxs(self,):
        """Resolve column indices, including those of the first treatment group."""
        super().set_col_idxs()
        self.treatment_cols_idx = [
            np.where(el == self.columns)[0][0] for el in colnames.circ_act_cols[0]
        ]

    def get_actions(self, sample):
        """Return 1 for every row where any treatment column is > 0, else 0."""
        return np.any(np.array(sample[:, self.treatment_cols_idx] > 0), -1).astype(int)
    
    



class HiridSimulatedEnv_Vaso(HiridSimulatedEnv):
    """Environment with a binary action built from ``colnames.circ_act_cols[1]``.

    Presumably that second group holds the vasopressor treatment columns
    (going by the class name) -- verify against ``colnames``.

    Args:
        version: dataset version forwarded to the base class.
        data_path: location of the data forwarded to the base class.
        hidden: observation columns to hide; ``None`` (default) hides nothing.
    """

    def __init__(self,
        version='v1',
        data_path = DATAPATH,
        hidden=None):
        gin_config_file = '../configs/config_circ.gin'
        obs_cols = colnames.circ_obs_cols
        super().__init__(
            data_path=data_path,
            version = version,
            discrete_treatments = 2,
            obs_cols = obs_cols,
            gin_config_file = gin_config_file,
            # Fresh list per instance instead of a shared mutable default.
            hidden=[] if hidden is None else hidden
            )

    def set_col_idxs(self,):
        """Resolve column indices, including those of the second treatment group."""
        super().set_col_idxs()
        self.treatment_cols_idx = [
            np.where(el == self.columns)[0][0] for el in colnames.circ_act_cols[1]
        ]

    def get_actions(self, sample):
        """Return 1 for every row where any treatment column is > 0, else 0."""
        return np.any(np.array(sample[:, self.treatment_cols_idx] > 0), -1).astype(int)


