import torch
import torch.jit as jit


class ReplayBuffer(jit.ScriptModule):
    """Fixed-capacity ring buffer of whole episodes, compiled with TorchScript.

    Storage is a set of pre-allocated tensors whose first axis is the episode
    slot and whose second axis is the time step inside the episode.
    Observations, states and action-availability masks have
    ``episode_limit + 1`` time steps (one extra for the final observation);
    actions, rewards, done flags and the ``actives`` validity mask have
    ``episode_limit`` time steps.

    An episode is written with repeated :meth:`store_transition` calls
    followed by one :meth:`store_last_step` call, which finalises the slot and
    advances the write cursor.
    """

    def __init__(self, ob_dim, st_dim, ac_dim, n_agents, episode_limit, buffer_size):
        """Allocate storage for ``buffer_size`` episodes.

        Args:
            ob_dim: Size of one agent's observation vector.
            st_dim: Size of the state vector.
            ac_dim: Number of discrete actions (width of the availability
                mask and of the one-hot action encoding built in ``sample``).
            n_agents: Number of agents per time step.
            episode_limit: Maximum number of transitions in one episode.
            buffer_size: Number of episode slots.
        """
        super().__init__()
        self.buffer_size = buffer_size
        self.episode_limit = episode_limit
        self.n_agents = n_agents
        self.ac_dim = ac_dim
        self.obs = torch.zeros((buffer_size, episode_limit + 1, n_agents, ob_dim))
        self.states = torch.zeros((buffer_size, episode_limit + 1, st_dim))
        # Unwritten steps default to "every action available".
        self.avails = torch.ones((buffer_size, episode_limit + 1, n_agents, ac_dim), dtype=torch.bool)
        self.actions = torch.zeros((buffer_size, episode_limit, n_agents), dtype=torch.long)
        self.rewards = torch.zeros((buffer_size, episode_limit, 1))
        # Unwritten steps default to done=True and active=False.
        self.dones = torch.ones((buffer_size, episode_limit, 1), dtype=torch.bool)
        self.actives = torch.zeros((buffer_size, episode_limit, 1), dtype=torch.bool)
        # Value of ``eps_id`` passed to ``store_last_step`` for each slot.
        self.episode_len = torch.zeros((self.buffer_size,), dtype=torch.long)
        # Index of the episode slot currently being written.
        self.step = 0
        # Number of slots that ``sample`` may draw from.
        self.current_size = 0
        # The following statistics are not read or written by any method of
        # this class; they are kept as plain attributes.
        self.total_reward = 0.0
        self.dead_allies = 0.0
        self.dead_enemies = 0.0
        self.battle_won = 0.0

    @jit.script_method
    def store_transition(self, eps_id: int, obs, state, avails, actions, reward: float, done: bool):
        """Write one transition into the episode slot currently being filled.

        Args:
            eps_id: Time-step index inside the episode (despite its name it
                indexes the second axis; the episode slot is ``self.step``).
            obs: Observation tensor stored at that time step.
            state: State tensor stored at that time step.
            avails: Action-availability mask stored at that time step.
            actions: Actions stored at that time step.
            reward: Scalar reward for the transition.
            done: Done flag for the transition.
        """
        slot = self.step
        self.obs[slot][eps_id] = obs
        self.states[slot][eps_id] = state
        self.avails[slot][eps_id] = avails
        self.actions[slot][eps_id] = actions
        self.rewards[slot][eps_id] = reward
        self.dones[slot][eps_id] = done
        # Mark the step as holding real data so consumers can mask padding.
        self.actives[slot][eps_id] = True

    @jit.script_method
    def store_last_step(self, eps_id: int, obs, state, avails):
        """Store the final observation of an episode and advance the cursor.

        Everything in the slot after the episode's end is reset to the values
        used at construction time. Without this, a short episode that
        overwrites a longer one after the ring wraps would inherit the old
        episode's tail (including ``actives == True``).

        Args:
            eps_id: Time-step index of the final observation; also recorded
                as the episode's length.
            obs: Final observation tensor.
            state: Final state tensor.
            avails: Final action-availability mask.
        """
        slot = self.step
        self.obs[slot][eps_id] = obs
        self.states[slot][eps_id] = state
        self.avails[slot][eps_id] = avails

        # Clear whatever a previous occupant of this slot left past the end.
        tail = eps_id + 1
        self.obs[slot][tail:] = 0.0
        self.states[slot][tail:] = 0.0
        self.avails[slot][tail:] = True
        self.actions[slot][eps_id:] = 0
        self.rewards[slot][eps_id:] = 0.0
        self.dones[slot][eps_id:] = True
        self.actives[slot][eps_id:] = False

        self.episode_len[slot] = eps_id
        self.step = (slot + 1) % self.buffer_size
        self.current_size = min(self.current_size + 1, self.buffer_size)

    @jit.script_method
    def sample(self, batch_size:int):
        """Draw up to ``batch_size`` distinct episodes uniformly at random.

        The returned observations are the stored observations concatenated on
        the last axis with (a) a one-hot encoding of the action taken at the
        previous time step (all zeros at time step 0 and wherever the previous
        step is not active) and (b) a one-hot agent identifier.

        Args:
            batch_size: Requested number of episodes; clamped to
                ``current_size``.

        Returns:
            Tuple ``(obs, states, avails, actions, rewards, dones, actives)``
            whose first axis is the sampled batch.
        """
        batch_size = min(batch_size, self.current_size)
        ids = torch.randperm(self.current_size)[:batch_size]

        mb_obs = self.obs[ids]
        mb_states = self.states[ids]
        mb_avails = self.avails[ids]
        mb_actions = self.actions[ids]
        mb_rewards = self.rewards[ids]
        mb_dones = self.dones[ids]
        mb_actives = self.actives[ids]

        # One-hot of every stored action, zeroed where the step is inactive.
        # Built from the already-gathered batch so row b of the one-hots
        # always corresponds to row b of mb_obs.
        prev_onehots = torch.eye(self.ac_dim)[mb_actions]
        prev_onehots = prev_onehots * mb_actives.unsqueeze(-1).to(prev_onehots.dtype)
        # Shift by one time step: step t sees the action taken at t - 1, and
        # step 0 has no previous action.
        no_prev_action = torch.zeros((batch_size, 1, self.n_agents, self.ac_dim))
        mb_onehots = torch.cat([no_prev_action, prev_onehots], 1)

        mb_agent_ids = torch.eye(self.n_agents).unsqueeze(0).unsqueeze(0).repeat(batch_size, self.episode_limit+1, 1, 1)
        mb_obs = torch.cat([mb_obs, mb_onehots, mb_agent_ids], -1)
        return mb_obs, mb_states, mb_avails, mb_actions, mb_rewards, mb_dones, mb_actives

    @jit.script_method
    def rstrip(self):
        """Trim the time axis to the longest recorded episode length.

        After the call ``episode_limit`` equals the largest value in
        ``episode_len``. An empty buffer is left untouched.
        """
        if self.episode_len.numel() == 0:
            # max() of an empty tensor raises; there is nothing to trim.
            return
        max_episode_len = int(self.episode_len.max())
        self.obs = self.obs[:, :max_episode_len+1]
        self.states = self.states[:, :max_episode_len+1]
        self.avails = self.avails[:, :max_episode_len+1]
        self.actions = self.actions[:, :max_episode_len]
        self.rewards = self.rewards[:, :max_episode_len]
        self.dones = self.dones[:, :max_episode_len]
        self.actives = self.actives[:, :max_episode_len]
        self.episode_limit = max_episode_len

    @jit.script_method
    def limit(self, n_episodes: int):
        """Keep only the first ``n_episodes`` episode slots.

        ``n_episodes`` is clamped to ``[0, number of allocated slots]`` so the
        bookkeeping can never exceed the real tensor size. ``episode_len`` is
        truncated together with the data, and the write cursor is wrapped into
        the new range.

        Args:
            n_episodes: Number of leading slots to retain.
        """
        n = max(0, min(n_episodes, self.obs.size(0)))
        self.obs = self.obs[:n]
        self.states = self.states[:n]
        self.avails = self.avails[:n]
        self.actions = self.actions[:n]
        self.rewards = self.rewards[:n]
        self.dones = self.dones[:n]
        self.actives = self.actives[:n]
        self.episode_len = self.episode_len[:n]
        self.buffer_size = n
        # NOTE(review): every retained slot is marked as filled, even if fewer
        # episodes had been stored; this presumably assumes a full buffer —
        # confirm against callers.
        self.current_size = n
        # Keep the write cursor inside the shrunken buffer.
        self.step = self.step % n if n > 0 else 0

    @jit.script_method
    def sample_with(self, batch_size:int, obs, states, avails, actions, rewards, dones, actives):
        """Concatenate a caller-supplied batch with a batch sampled from here.

        The caller's tensors come first along the batch axis, followed by the
        result of ``sample(batch_size)``.

        The returned "rewards" do not contain reward values: they are a
        boolean tensor that is False for the caller-supplied rows and True for
        the rows sampled from this buffer.
        NOTE(review): this looks like an origin label (e.g. demonstration vs.
        agent data) — confirm with the training code.

        Args:
            batch_size: Number of episodes to sample from this buffer.
            obs, states, avails, actions, rewards, dones, actives: Batch to
                prepend; ``rewards`` is used only for its shape.

        Returns:
            Tuple ``(obs, states, avails, actions, rewards, dones, actives)``
            of the concatenated batches.
        """
        mb_obs, mb_states, mb_avails, mb_actions, mb_rewards, mb_dones, mb_actives = self.sample(batch_size)
        mb_obs = torch.cat([obs, mb_obs])
        mb_states = torch.cat([states, mb_states])
        mb_avails = torch.cat([avails, mb_avails])
        mb_actions = torch.cat([actions, mb_actions])
        # Origin labels replace the reward values (see docstring).
        mb_rewards = torch.cat([torch.zeros_like(rewards, dtype=torch.bool), torch.ones_like(mb_rewards, dtype=torch.bool)])
        mb_dones = torch.cat([dones, mb_dones])
        mb_actives = torch.cat([actives, mb_actives])
        return mb_obs, mb_states, mb_avails, mb_actions, mb_rewards, mb_dones, mb_actives
