import numpy as np
import torch

from environments import Env


class DreamerMemory:
    """Fixed-capacity ring buffer of per-step transitions, sampled as sequences.

    Each field (observation, action, reward, done, fake, last, available
    actions) lives in its own pre-allocated ``float32`` NumPy array with one
    row per stored step. Writes advance ``next_idx`` and wrap around at
    ``capacity``, overwriting the oldest rows. Sampling returns windows of
    ``sequence_length`` consecutive rows that never straddle the write cursor.

    ``state_size`` and ``n_agents`` are stored on the instance but are not
    used by any method of this class.
    """

    def __init__(self, capacity, sequence_length, action_size, obs_size, state_size, device, env_type, n_agents):
        self.capacity = capacity
        self.sequence_length = sequence_length
        self.action_size = action_size
        self.obs_size = obs_size
        self.state_size = state_size
        self.device = device
        self.env_type = env_type
        self.n_agents = n_agents

        self._allocate(capacity)

    def _allocate(self, capacity):
        """Allocate uninitialised storage for ``capacity`` steps and rewind the cursor."""
        # np.empty leaves the rows uninitialised; only rows that have been
        # written through append() are ever handed out by sample_position().
        self.observations = np.empty((capacity, self.obs_size), dtype=np.float32)
        self.actions = np.empty((capacity, self.action_size), dtype=np.float32)
        self.rewards = np.empty((capacity, 1), dtype=np.float32)
        self.dones = np.empty((capacity, 1), dtype=np.float32)
        self.fake = np.empty((capacity, 1), dtype=np.float32)
        self.last = np.empty((capacity, 1), dtype=np.float32)
        self.av_actions = np.empty((capacity, self.action_size), dtype=np.float32)

        self.next_idx = 0
        self.full = False

    def init_buffer(self, new_capacity, env_type):
        """Discard all stored data and re-create the buffer with a new capacity.

        Args:
            new_capacity: number of steps the new buffer can hold.
            env_type: replaces ``self.env_type``.
        """
        self.capacity = new_capacity
        self.env_type = env_type
        self._allocate(new_capacity)

    def append(self, agent_id, obs, action, reward, done, fake, last, av_action):
        """Store one agent's slice of a rollout, one row per step.

        Every argument except ``agent_id`` is indexed as ``arg[step][agent_id]``
        for ``step`` in ``range(len(obs))``.

        Args:
            agent_id: index of the agent whose data is stored.
            obs, action, reward, done, fake, last: per-step, per-agent data.
            av_action: per-step, per-agent available-action data, or ``None``
                to leave the corresponding rows of ``av_actions`` untouched.
        """
        for step, step_obs in enumerate(obs):
            slot = self.next_idx
            self.observations[slot] = step_obs[agent_id]
            self.actions[slot] = action[step][agent_id]
            self.rewards[slot] = reward[step][agent_id]
            self.dones[slot] = done[step][agent_id]
            self.fake[slot] = fake[step][agent_id]
            self.last[slot] = last[step][agent_id]

            if av_action is not None:
                self.av_actions[slot] = av_action[step][agent_id]

            # Advance the cursor; wrapping back to 0 means every row has been
            # written at least once.
            self.next_idx = (slot + 1) % self.capacity
            self.full = self.full or self.next_idx == 0

    def tenzorify(self, nparray):
        """Convert a NumPy array into a float32 torch tensor."""
        return torch.from_numpy(nparray).float()

    def sample(self, batch_size):
        """Draw ``batch_size`` random sequences and return them as a dict of tensors."""
        return self.get_transitions(self.sample_positions(batch_size))

    def process_batch(self, val, idxs, batch_size):
        """Gather rows ``idxs`` of ``val`` as a tensor on ``self.device``.

        ``idxs`` must be a flat, time-major index vector of length
        ``sequence_length * batch_size``; the result has shape
        ``(sequence_length, batch_size, 1, -1)``.
        """
        gathered = val[idxs].reshape(self.sequence_length, batch_size, 1, -1)
        return torch.as_tensor(gathered).to(self.device)

    def process_global_state(self, val, idxs, batch_size):
        """Same gathering and reshaping as :meth:`process_batch`."""
        return self.process_batch(val, idxs, batch_size)

    def get_transitions(self, idxs):
        """Build the training batch for a ``(batch, sequence_length)`` index array.

        ``observation``, ``av_action``, ``fake`` and ``last`` drop the first
        time step, while ``reward``, ``action`` and ``done`` drop the last one,
        so every returned tensor has ``sequence_length - 1`` time steps.
        ``av_action`` is ``None`` unless ``env_type`` is ``Env.STARCRAFT``.
        """
        batch_size = len(idxs)
        # Flatten time-major so the reshape in process_batch puts time first.
        vec_idxs = idxs.transpose().reshape(-1)

        def gather(storage):
            return self.process_batch(storage, vec_idxs, batch_size)

        observation = gather(self.observations)[1:]
        reward = gather(self.rewards)[:-1]
        action = gather(self.actions)[:-1]
        av_action = gather(self.av_actions)[1:] if self.env_type == Env.STARCRAFT else None
        done = gather(self.dones)[:-1]
        fake = gather(self.fake)[1:]
        last = gather(self.last)[1:]

        return {'observation': observation, 'action': action, 'av_action': av_action,
                'reward': reward, 'done': done, 'fake': fake, 'last': last}

    def sample_position(self):
        """Return ``sequence_length`` consecutive buffer indices, chosen uniformly.

        The window is taken from the stored data in write order, so it never
        crosses the write cursor (``next_idx`` can only be its first index,
        which is the oldest row once the buffer is full).

        Raises:
            ValueError: if fewer than ``sequence_length`` steps are stored.
        """
        stored = len(self)
        if stored < self.sequence_length:
            raise ValueError(
                f"cannot sample a sequence of length {self.sequence_length} "
                f"from a buffer holding {stored} steps"
            )

        # Once the buffer has wrapped, the oldest row sits at the write cursor;
        # before that it is row 0.
        oldest = self.next_idx if self.full else 0
        # Offsets 0 .. stored - sequence_length (inclusive) are exactly the
        # starts whose window stays inside the stored data.
        offset = np.random.randint(0, stored - self.sequence_length + 1)
        start = oldest + offset
        return np.arange(start, start + self.sequence_length) % self.capacity

    def sample_positions(self, batch_size):
        """Return a ``(batch_size, sequence_length)`` array of sampled index windows."""
        return np.asarray([self.sample_position() for _ in range(batch_size)])

    def __len__(self):
        """Number of steps currently stored."""
        return self.capacity if self.full else self.next_idx

    def clean(self):
        """Empty the buffer without re-allocating its storage."""
        # Rewinding the cursor is what actually makes the buffer empty for
        # __len__ and sampling; the old rows are simply overwritten later.
        self.next_idx = 0
        self.full = False
        # These two attributes are not read anywhere in this class; they are
        # still set so external code that inspects them keeps working.
        self.memory = list()
        self.position = 0
