from common.imports import *
from ..dec_MAHDDRQN.buffer import Buffer, EpisodeBuffer, _pad_mask_data

class FactBuffer(Buffer):
    """Batch buffer (extends ``Buffer``) that additionally carries global states.

    Adds ``states`` / ``next_states`` entries to every batch. Their width,
    ``state_dim``, is the length of ``venv.call('get_state')[0]``.
    """

    # Keys copied from each episode sequence into the batch; every one of them
    # is allocated by _init_batch.
    _FACT_BATCH_KEYS = (
        'm_obs', 'm_acts', 'm_valids', 'm_rewards', 'm_j_valids', 'm_j_rewards',
        'm_j_gammas', 'next_m_obs', 'dones', 'm_hs', 'next_m_hs', 'states',
        'next_states',
    )

    def __init__(self, venv, args, agent_id, device):
        super().__init__(venv, args, agent_id, device)
        # NOTE(review): assigned after the parent constructor; if Buffer.__init__
        # itself calls _init_batch, state_dim would not exist yet — confirm.
        self.state_dim = len(venv.call('get_state')[0])

    def _set_batch(self, batch, i, ep_batch):
        """Copy one episode sequence ``ep_batch`` into row ``i`` of every batch tensor.

        Args:
            batch: dict returned by ``_init_batch``.
            i: row (sample) index within the batch.
            ep_batch: dict with the same keys, e.g. from ``FactEpisodeBuffer.sample``.
        """
        for key in self._FACT_BATCH_KEYS:
            batch[key][i] = ep_batch[key]

    def _init_batch(self):
        """Allocate an uninitialized ``(batch_size, seq_len, ...)`` tensor per batch key.

        Tensors are created directly on ``self.device``. This avoids building
        each one on the host and then copying it with ``.to(device)``.
        """
        def empty(*trailing, dtype=None):
            # dtype=None keeps torch's default dtype, like th.empty(...) without dtype.
            return th.empty(self.batch_size, self.seq_len, *trailing,
                            dtype=dtype, device=self.device)

        return {
            'm_obs': empty(self.obs_dim),
            'm_acts': empty(1, dtype=th.int64),
            'm_valids': empty(dtype=th.bool),
            'm_rewards': empty(1),
            'm_j_valids': empty(dtype=th.bool),
            'm_j_rewards': empty(1),
            'm_j_gammas': empty(1),
            'next_m_obs': empty(self.obs_dim),
            'dones': empty(1),
            'm_hs': empty(self.h_size),
            'next_m_hs': empty(self.h_size),
            'states': empty(self.state_dim),
            'next_states': empty(self.state_dim),
        }
    
class FactEpisodeBuffer(EpisodeBuffer):
    """Per-episode storage (extends ``EpisodeBuffer``) that also records global states.

    Adds ``b_states`` / ``b_next_states`` rows. Their width is the length of
    ``venv.call('get_state')[0]``.
    """

    def __init__(self, venv, args, agent_id, device):
        super().__init__(venv, args, agent_id, device)

        state_dim = len(venv.call('get_state')[0])
        # NOTE(review): the state tensors are allocated here rather than in
        # _init_batches, so they are not re-zeroed if _init_batches is invoked
        # again later (e.g. by a parent reset) — confirm this is intended.
        self.b_states = th.zeros((self.capacity, state_dim), device=device)
        self.b_next_states = th.zeros_like(self.b_states)

    def _init_batches(self, h_size, device):
        """Allocate zero-filled per-step storage and reset the write cursor ``idx`` to 0.

        Args:
            h_size: width of the stored hidden-state rows.
            device: device on which all storage tensors are created.
        """
        # Here we init to 0 since the extra values will be used as a 0-mask
        cap = self.capacity
        self.b_obs = th.zeros((cap, self.obs_dim), device=device)
        self.b_m_acts = th.zeros((cap, 1), dtype=th.int32, device=device)
        # Has to be bool so it can be used directly as a mask to index elements.
        self.b_m_valids = th.zeros(cap, dtype=th.bool, device=device)
        self.b_m_rewards = th.zeros((cap, 1), device=device)
        # Also bool; used to select valid trajectories / steps.
        self.b_m_j_valids = th.zeros_like(self.b_m_valids)
        self.b_m_j_rewards = th.zeros((cap, 1), device=device)
        self.b_m_j_gammas = th.zeros_like(self.b_m_j_rewards)
        self.b_next_m_obs = th.zeros_like(self.b_obs)
        self.b_dones = th.zeros_like(self.b_m_j_rewards)
        self.b_m_hs = th.zeros((cap, h_size), device=device)
        self.b_next_m_hs = th.zeros_like(self.b_m_hs)

        self.idx = 0

    def store(self, m_obs, m_act, m_valid, m_reward, m_j_valid, m_j_reward, m_j_gamma, next_m_obs, done, m_h, next_m_h, state, next_state):
        """Write one step at row ``self.idx`` of every storage tensor, then advance ``idx``.

        No capacity check is done here; writing past ``capacity`` raises
        torch's out-of-bounds ``IndexError``.
        """
        i = self.idx
        self.b_obs[i] = m_obs
        self.b_m_acts[i] = m_act
        self.b_m_valids[i] = m_valid
        self.b_m_rewards[i] = m_reward
        self.b_m_j_valids[i] = m_j_valid
        self.b_m_j_rewards[i] = m_j_reward
        self.b_m_j_gammas[i] = m_j_gamma
        self.b_next_m_obs[i] = next_m_obs
        self.b_dones[i] = done
        self.b_m_hs[i] = m_h
        self.b_next_m_hs[i] = next_m_h
        self.b_states[i] = state
        self.b_next_states[i] = next_state

        self.idx = i + 1

    def sample(self, start_idx):
        """Return views of rows ``[start_idx, start_idx + seq_len)`` of every storage tensor.

        The slices are views, not copies. If the window runs past ``capacity``,
        the returned tensors are shorter than ``seq_len``.
        """
        idxs = slice(start_idx, start_idx + self.seq_len, 1)

        return {
            'm_obs': self.b_obs[idxs],
            'm_acts': self.b_m_acts[idxs],
            'm_valids': self.b_m_valids[idxs],
            'm_rewards': self.b_m_rewards[idxs],
            'm_j_valids': self.b_m_j_valids[idxs],
            'm_j_rewards': self.b_m_j_rewards[idxs],
            'm_j_gammas': self.b_m_j_gammas[idxs],
            'next_m_obs': self.b_next_m_obs[idxs],
            'dones': self.b_dones[idxs],
            'm_hs': self.b_m_hs[idxs],
            'next_m_hs': self.b_next_m_hs[idxs],
            'states': self.b_states[idxs],
            'next_states': self.b_next_states[idxs],
        }
    
def _split_mask_batch(batch):
    """Reduce a padded batch, in place, to its steps selected by ``m_j_valids``.

    Trajectories (rows) with no selected step are dropped. For the remaining
    rows, the selected steps are gathered in order:

    * ``m_obs``, ``next_m_obs``, ``m_hs`` and ``next_m_hs`` become tuples with
      one tensor per kept trajectory (via ``th.split_with_sizes``);
    * ``m_acts``, ``m_rewards``, ``m_j_rewards``, ``m_j_gammas``, ``dones``,
      ``states``, ``next_states`` and ``m_valids`` are flattened across the
      kept trajectories.

    ``batch['m_j_valids']`` itself is left unchanged.

    Args:
        batch: dict of tensors laid out as ``(batch, seq, ...)``, presumably as
            allocated by ``FactBuffer._init_batch``. ``batch['m_j_valids']`` is
            assumed to be a boolean ``(batch, seq)`` mask.
    """
    j_valids = batch['m_j_valids']
    steps_per_traj = j_valids.sum(-1)        # selected steps in each trajectory
    mask_valid_traj = steps_per_traj > 0     # trajectories with at least 1 selected step
    # NOTE: if no trajectory is valid, every entry below ends up empty; callers
    # may want to skip the update in that case.

    # Plain Python ints for the split sizes. .tolist() does a single
    # device->host transfer instead of one per 0-d tensor element.
    split_sizes = steps_per_traj[mask_valid_traj].tolist()

    # Masking the full (batch, seq) grid gives the same result as first
    # dropping empty trajectories and then masking, because all-False rows
    # contribute no elements. Filtering empty rows only matters for the split
    # sizes, so there are no zero-length chunks.
    for key in ('m_obs', 'next_m_obs', 'm_hs', 'next_m_hs'):
        batch[key] = th.split_with_sizes(batch[key][j_valids], split_sizes)
    for key in ('m_acts', 'm_rewards', 'm_j_rewards', 'm_j_gammas', 'dones',
                'states', 'next_states', 'm_valids'):
        batch[key] = batch[key][j_valids]
