import torch
from typing import Tuple, Optional


class ReplayBuffer:
    """Per-environment circular replay buffer backed by preallocated tensors.

    Every field is stored as ``(num_envs, capacity_per_env, *feature_shape)``. Each call to
    :meth:`add_transitions` writes one slot for all environments at once; after
    ``capacity_per_env`` writes the slot pointer wraps to 0 and the oldest data is overwritten.
    """

    class Transition:
        """Mutable holder for one step of data from all environments; fields default to ``None``."""

        def __init__(self):
            self.observations: torch.Tensor = None # type: ignore
            self.critic_observations: torch.Tensor = None # type: ignore
            self.next_observations: torch.Tensor = None # type: ignore
            self.next_critic_observations: torch.Tensor = None # type: ignore
            self.actions: torch.Tensor = None # type: ignore
            self.rewards: torch.Tensor = None # type: ignore
            self.dones: torch.Tensor = None # type: ignore
            self.hidden_states: torch.Tensor = None # type: ignore

        def clear(self):
            """Reset every field back to ``None``."""
            self.__init__()

    def __init__(
        self,
        num_envs: int,
        capacity_per_env: int,
        obs_shape: Tuple[int, ...],
        privileged_obs_shape: Optional[Tuple[int, ...]],
        action_shape: Tuple[int, ...],
        device: str = "cpu",
    ):
        """Allocate zero-filled storage for ``num_envs * capacity_per_env`` transitions.

        Args:
            num_envs: Number of parallel environments written per step.
            capacity_per_env: Number of slots kept for each environment.
            obs_shape: Per-environment shape of ``Transition.observations``.
            privileged_obs_shape: Per-environment shape of ``Transition.critic_observations``,
                or ``None`` to keep no separate critic observations. In that case the critic
                samples ``observations``, and next critic observations are sized with
                ``obs_shape``.
            action_shape: Per-environment shape of ``Transition.actions``.
            device: Device on which every storage tensor is allocated.
        """
        # Replay buffer data shape
        self.obs_shape = obs_shape
        self.privileged_obs_shape = privileged_obs_shape
        self.action_shape = action_shape

        # Replay buffer capacity size (total slot count across all environments)
        self.num_envs = num_envs
        self.capacity = capacity_per_env * num_envs
        self.capacity_per_env = capacity_per_env

        # Replay buffer device
        self.device = device

        # Replay buffer data initialization
        self._init_storage()
        # Replay buffer internal state:
        # next_p is the slot written by the next add_transitions call, is_full flips once the
        # pointer has wrapped, and cur_capacity_per_env counts valid slots per environment.
        self.next_p = 0
        self.is_full = False
        self.cur_capacity_per_env = 0

    def _init_storage(self):
        """Allocate zero-filled storage tensors of shape ``(num_envs, capacity_per_env, ...)``."""
        self.observations = torch.zeros(self.num_envs, self.capacity_per_env, *self.obs_shape, device=self.device)
        self.actions = torch.zeros(self.num_envs, self.capacity_per_env, *self.action_shape, device=self.device)
        self.rewards = torch.zeros(self.num_envs, self.capacity_per_env, 1, device=self.device)
        if self.privileged_obs_shape is not None:
            self.privileged_observations = torch.zeros(
                self.num_envs, self.capacity_per_env, *self.privileged_obs_shape, device=self.device
            )
            self.next_critic_observations = torch.zeros(
                self.num_envs, self.capacity_per_env, *self.privileged_obs_shape, device=self.device
            )
        else:
            # No privileged obs: the critic reuses the actor observation shape.
            self.privileged_observations = None
            self.next_critic_observations = torch.zeros(
                self.num_envs, self.capacity_per_env, *self.obs_shape, device=self.device
            )
        # Done flags are stored as uint8.
        self.dones = torch.zeros(self.num_envs, self.capacity_per_env, 1, device=self.device).byte()

    @torch.no_grad()
    def add_transitions(self, transition: Transition):
        """Copy one step for every environment into the current slot, then advance the slot.

        ``transition.rewards`` and ``transition.dones`` are indexed with ``[:, None]`` to match
        the trailing singleton axis of their storage, so they are expected to be 1-D
        ``(num_envs,)``. ``transition.critic_observations`` is read only when the buffer was
        built with a ``privileged_obs_shape``.
        """
        # Wrap lazily: the pointer may rest at capacity_per_env until the next write.
        if self.next_p >= self.capacity_per_env:
            self.next_p = 0
            self.is_full = True
        self.observations[:, self.next_p].copy_(transition.observations)
        self.actions[:, self.next_p].copy_(transition.actions)
        self.rewards[:, self.next_p].copy_(transition.rewards[:, None])
        if self.privileged_observations is not None:
            self.privileged_observations[:, self.next_p].copy_(transition.critic_observations)
        self.next_critic_observations[:, self.next_p].copy_(transition.next_critic_observations)
        self.dones[:, self.next_p].copy_(transition.dones[:, None])
        self.next_p += 1
        self.cur_capacity_per_env = self.capacity_per_env if self.is_full else self.next_p

    @torch.no_grad()
    def reccurent_mini_batch_generator(self, batch_size: int, num_epochs=1):
        """Placeholder for recurrent sampling; always raises ``NotImplementedError``."""
        raise NotImplementedError("Recurrent mini batch generator is not implemented yet")

    @torch.no_grad()
    def mini_batch_generator(self, batch_size: int, num_epochs=1):
        """Yield ``num_epochs`` mini-batches sampled uniformly, with replacement, from filled slots.

        Each yielded tuple is ``(observations, critic_observations, next_critic_observations,
        actions, rewards, dones, None)``; every tensor has ``batch_size`` rows. When the buffer
        has no privileged observations, ``critic_observations`` is sampled from ``observations``.

        Raises:
            RuntimeError: On the first iteration if no transition has been added yet.
        """
        filled = self.cur_capacity_per_env
        num_samples = filled * self.num_envs
        if num_samples == 0:
            raise RuntimeError("Cannot sample from an empty replay buffer; add transitions first.")

        # Draw every epoch's indices in one call; epoch k consumes slice [k*B, (k+1)*B).
        indices = torch.randint(0, num_samples, (num_epochs * batch_size,), device=self.device)

        # Merge the (env, slot) axes of the filled region into one flat sample axis.
        # Slicing only the first two axes keeps this valid for scalar feature shapes too.
        observations = self.observations[:, :filled].flatten(0, 1)
        next_critic_observations = self.next_critic_observations[:, :filled].flatten(0, 1)
        if self.privileged_observations is not None:
            critic_observations = self.privileged_observations[:, :filled].flatten(0, 1)
        else:
            critic_observations = observations

        actions = self.actions[:, :filled].flatten(0, 1)
        rewards = self.rewards[:, :filled].flatten(0, 1)
        dones = self.dones[:, :filled].flatten(0, 1)

        for epoch in range(num_epochs):
            batch_idx = indices[epoch * batch_size : (epoch + 1) * batch_size]
            yield (
                observations[batch_idx],
                critic_observations[batch_idx],
                next_critic_observations[batch_idx],
                actions[batch_idx],
                rewards[batch_idx],
                dones[batch_idx],
                None,
            )

class NStepReplayBuffer(ReplayBuffer):
    """Sliding window over the most recent ``nstep`` transitions, used to form n-step targets.

    The parent storage is allocated with ``capacity_per_env=nstep``. Every call to
    :meth:`add_transitions` drops the oldest step, appends the newest one, and returns the
    discounted reward, next critic observation and done flag computed over the window by
    ``compute_nstep_return``.
    """

    def __init__(
        self,
        num_envs: int,
        obs_shape: Tuple[int, ...],
        privileged_obs_shape: Tuple[int, ...],
        action_shape: Tuple[int, ...],
        gamma: float = 0.99,
        nstep: int = 1,
        device: str = "cpu",
    ):
        # Window bookkeeping is in place before the parent allocates its storage.
        self.nstep = nstep
        self.nstep_count = 0
        super().__init__(
            num_envs=num_envs,
            capacity_per_env=nstep,
            obs_shape=obs_shape,
            privileged_obs_shape=privileged_obs_shape,
            action_shape=action_shape,
            device=device,
        )
        self.gamma = gamma
        # Column of discount factors gamma**0 .. gamma**(nstep - 1), shape (nstep, 1).
        discounts = [gamma**k for k in range(nstep)]
        self.gamma_array = torch.tensor(discounts).to(device).reshape(-1, 1)

    @torch.no_grad()
    def fifo_shift(self, transition):
        """Drop step 0 of every window tensor and append ``transition`` as the newest step.

        Each attribute is rebound to a freshly concatenated tensor along dim 1.
        """

        def slide(window, newest):
            return torch.cat((window[:, 1:], newest), dim=1)

        self.observations = slide(self.observations, transition.observations.unsqueeze(1))
        self.actions = slide(self.actions, transition.actions.unsqueeze(1))
        self.rewards = slide(self.rewards, transition.rewards[:, None, None])
        if self.privileged_observations is not None:
            self.privileged_observations = slide(
                self.privileged_observations, transition.critic_observations.unsqueeze(1)
            )
        self.next_critic_observations = slide(
            self.next_critic_observations, transition.next_critic_observations.unsqueeze(1)
        )
        self.dones = slide(self.dones, transition.dones[:, None, None])
        self.hidden_states = None

    @torch.no_grad()
    def add_transitions(self, transition: ReplayBuffer.Transition):
        """Push ``transition`` into the window and return ``(reward, next_critic_obs, done)``.

        Before :meth:`is_ready` returns True the window still holds zero-initialized slots, so
        the result is presumably only meaningful after that point.

        Raises:
            NotImplementedError: If the buffer was built with ``nstep <= 1``.
        """
        if self.nstep <= 1:
            raise NotImplementedError("To use NStep Replay Buffer, set nstep > 1")

        self.fifo_shift(transition)
        # Count pushes until the window is full, then stop counting.
        if self.nstep_count < self.nstep:
            self.nstep_count += 1

        nstep_reward, bootstrap_obs, nstep_done = compute_nstep_return(
            nstep_buf_next_obs=self.next_critic_observations,
            nstep_buf_done=self.dones,
            nstep_buf_reward=self.rewards,
            gamma_array=self.gamma_array,
        )
        return nstep_reward, bootstrap_obs, nstep_done

    def is_ready(self):
        """Return True once ``nstep`` transitions have been pushed into the window."""
        return not self.nstep_count < self.nstep


@torch.jit.script
def compute_nstep_return(nstep_buf_next_obs, nstep_buf_done, nstep_buf_reward, gamma_array):
    """Collapse an n-step window into one discounted reward, next observation and done flag per env.

    Inputs follow the layout of the n-step buffers in this module:
    ``nstep_buf_next_obs`` is ``(num_envs, nstep, *obs)``, ``nstep_buf_done`` and
    ``nstep_buf_reward`` are ``(num_envs, nstep, 1)``, and ``gamma_array`` is ``(nstep, 1)``
    holding ``gamma**k``. Any non-zero done flag counts as done.

    For each env the window ends at its first done step, or at the last step when no done
    occurs. Rewards up to and including that step are discounted and summed, and the next
    observation is taken at that step. ``done`` keeps the dtype of ``nstep_buf_done`` and is
    set if any step in the window is done.

    Returns:
        ``(discounted_rewards (num_envs, 1), next_obs (num_envs, *obs), done (num_envs, 1))``.
    """
    # (num_envs, nstep) boolean view of the done flags.
    buf_done = nstep_buf_done.squeeze(-1) != 0
    num_steps = buf_done.shape[1]
    env_has_done = buf_done.any(dim=1)

    # argmax is not implemented for bool tensors, so go through uint8. For rows containing a
    # done it returns the first done step (argmax returns the first index among ties).
    first_done_step = buf_done.to(torch.uint8).argmax(dim=1)
    # Windows without any done run through the last step.
    end_step = first_done_step.masked_fill(~env_has_done, num_steps - 1)

    done = nstep_buf_done[:, -1].clone()
    done[env_has_done] = True

    env_ids = torch.arange(buf_done.shape[0], device=buf_done.device)
    next_obs = nstep_buf_next_obs[env_ids, end_step]

    # Keep rewards at steps 0..end_step (inclusive) for each env.
    step_ids = torch.arange(num_steps, device=buf_done.device)
    mask = step_ids.unsqueeze(0) <= end_step.unsqueeze(1)
    discounted_rewards = nstep_buf_reward * gamma_array
    discounted_rewards = (discounted_rewards * mask.unsqueeze(-1)).sum(1)
    return discounted_rewards, next_obs, done
