from typing import Generator, NamedTuple, Optional

import numpy as np
import torch as th
from stable_baselines3.common.buffers import RolloutBuffer
from stable_baselines3.common.vec_env import VecNormalize

class MaskRolloutBufferSamples(NamedTuple):
    """One minibatch drawn from a ``MaskRolloutBuffer``.

    Field order matters: ``MaskRolloutBuffer._get_samples`` builds this tuple
    positionally, so the fields below mirror the order in which that method
    reads the buffer arrays.
    """

    observations: th.Tensor  # rows of the buffer's ``observations``
    actions: th.Tensor  # rows of the buffer's ``actions``
    is_action_valids: th.Tensor  # flattened rows of ``is_action_valids``
    masks: th.Tensor  # rows of the buffer's ``masks``
    old_values: th.Tensor  # flattened rows of the buffer's ``values``
    old_log_prob: th.Tensor  # flattened rows of the buffer's ``log_probs``
    advantages: th.Tensor  # flattened rows of ``advantages``
    returns: th.Tensor  # flattened rows of ``returns``


class MaskRolloutBuffer(RolloutBuffer):
    """Rollout buffer that also stores action-validity flags and action masks.

    Unlike a plain rollout buffer, :meth:`reset` does not clear the storage.
    The write position wraps around in :meth:`add`, so the buffer behaves as a
    ring, and ``reset_index`` records the slot at which the current rollout
    starts.

    NOTE(review): :meth:`get` flattens the storage arrays in place and nothing
    in this class restores the ``(buffer_size, n_envs, ...)`` layout afterwards,
    so calling :meth:`add` after sampling presumably only works with
    ``n_envs == 1`` -- confirm against the training loop.
    """

    def __init__(self, *args, **kwargs):
        """Create the buffer and allocate the extra mask storage.

        All arguments are forwarded unchanged to ``RolloutBuffer.__init__``.
        """
        # Define the extra attributes before the parent constructor runs.
        self.is_action_valids = None
        self.masks = None
        super().__init__(*args, **kwargs)

        # One validity flag per (step, env).
        self.is_action_valids = np.zeros(
            (self.buffer_size, self.n_envs), dtype=np.float32
        )
        # One entry per action; reading ``action_space.n`` requires an action
        # space that exposes ``n``.
        self.masks = np.zeros(
            (self.buffer_size, self.n_envs, self.action_space.n), dtype=np.float32
        )
        # The overridden reset() only records the rollout start, so the parent
        # reset is invoked explicitly, once, here.
        self.reset()
        super().reset()

    def reset(self) -> None:
        """Mark the current write position as the start of a new rollout.

        The stored data is deliberately left untouched; only ``reset_index``
        is updated.
        """
        self.reset_index = self.pos

    def add(
        self,
        obs: np.ndarray,
        action: np.ndarray,
        reward: np.ndarray,
        episode_start: np.ndarray,
        value: th.Tensor,
        log_prob: th.Tensor,
        is_action_valid: Optional[np.ndarray] = None,
        masks: Optional[np.ndarray] = None,
    ) -> None:
        """Store one transition and advance the ring position.

        :param obs: observation, forwarded to the parent ``add``
        :param action: action, forwarded to the parent ``add``
        :param reward: reward, forwarded to the parent ``add``
        :param episode_start: episode-start flag, forwarded to the parent ``add``
        :param value: value estimate, forwarded to the parent ``add``
        :param log_prob: action log-probability, forwarded to the parent ``add``
        :param is_action_valid: validity flag(s) for this step; when ``None``
            the slot keeps whatever it already contained
        :param masks: action mask(s) for this step; when ``None`` the slot
            keeps whatever it already contained
        """
        # Slice assignment copies into the storage, so no extra copy is needed.
        if is_action_valid is not None:
            self.is_action_valids[self.pos] = np.asarray(is_action_valid)
        if masks is not None:
            self.masks[self.pos] = np.asarray(masks)
        super().add(obs, action, reward, episode_start, value, log_prob)
        # Wrap around instead of stopping at the end: ring-buffer behaviour.
        self.pos %= self.buffer_size

    def get(
        self,
        batch_size: Optional[int] = None,
        start_index: int = 0,
        buffer_size: Optional[int] = None,
    ) -> Generator[MaskRolloutBufferSamples, None, None]:
        """Yield shuffled minibatches from a window of the ring buffer.

        :param batch_size: samples per minibatch; ``None`` yields the whole
            window as a single batch
        :param start_index: step slot at which the sampled window starts
        :param buffer_size: number of steps in the window; defaults to every
            stored step (``self.buffer_size`` once full, else ``self.pos``)
        :return: generator of :class:`MaskRolloutBufferSamples`
        """
        if buffer_size is None:
            buffer_size = self.buffer_size if self.full else self.pos
        n_samples = buffer_size * self.n_envs
        indices = np.random.permutation(n_samples)

        # Flatten the storage once; it stays flattened afterwards.
        if not self.generator_ready:
            for name in (
                "observations",
                "actions",
                "is_action_valids",
                "masks",
                "values",
                "log_probs",
                "advantages",
                "returns",
            ):
                self.__dict__[name] = self.swap_and_flatten(self.__dict__[name])
            self.generator_ready = True

        # Return everything, don't create minibatches
        if batch_size is None:
            batch_size = n_samples

        idx = 0
        while idx < n_samples:
            batch = indices[idx : idx + batch_size]
            # Split each sample into (env, step-within-window) so that only
            # the step is offset and wrapped around the ring.  The flat row is
            # rebuilt as env * buffer_size + step, which assumes the env-major
            # layout produced by Stable-Baselines3's swap_and_flatten.
            env_ids, steps = np.divmod(batch, buffer_size)
            rows = env_ids * self.buffer_size + (steps + start_index) % self.buffer_size
            yield self._get_samples(rows)
            idx += batch_size

    def _get_samples(
        self, batch_inds: np.ndarray, env: Optional[VecNormalize] = None
    ) -> MaskRolloutBufferSamples:
        """Gather the rows ``batch_inds`` from every array as torch tensors.

        :param batch_inds: row indices into the flattened storage arrays
        :param env: unused here; kept for signature compatibility
        :return: the selected rows, converted with ``self.to_torch``
        """
        data = (
            self.observations[batch_inds],
            self.actions[batch_inds],
            self.is_action_valids[batch_inds].flatten(),
            self.masks[batch_inds],
            self.values[batch_inds].flatten(),
            self.log_probs[batch_inds].flatten(),
            self.advantages[batch_inds].flatten(),
            self.returns[batch_inds].flatten(),
        )
        return MaskRolloutBufferSamples(*tuple(map(self.to_torch, data)))

    def compute_returns_and_advantage(self, last_values: th.Tensor, dones: np.ndarray) -> None:
        """
        Post-processing step: compute the lambda-return (TD(lambda) estimate)
        and GAE(lambda) advantage.

        Uses Generalized Advantage Estimation (https://arxiv.org/abs/1506.02438)
        to compute the advantage. To obtain Monte-Carlo advantage estimate (A(s) = R - V(S))
        where R is the sum of discounted reward with value bootstrap
        (because we don't always have full episode), set ``gae_lambda=1.0`` during initialization.

        The TD(lambda) estimator has also two special cases:
        - TD(1) is Monte-Carlo estimate (sum of discounted rewards)
        - TD(0) is one-step estimate with bootstrapping (r_t + gamma * v(s_{t+1}))

        For more information, see discussion in https://github.com/DLR-RM/stable-baselines3/pull/375.

        The stored steps are processed in chronological ring order, newest
        first. Before the buffer has wrapped this is simply ``pos - 1 .. 0``;
        once it is full the walk covers every slot, starting from the one just
        before ``pos`` and ending at ``pos`` (the oldest entry).

        :param last_values: state value estimation for the last step (one for each env)
        :param dones: if the last step was a terminal step (one bool for each env).
        """
        # Convert to numpy
        last_values = last_values.clone().cpu().numpy().flatten()

        # ``pos`` wraps to 0 when the ring fills, so it cannot serve as the
        # number of stored steps on its own.
        n_stored = self.buffer_size if self.full else self.pos
        oldest = self.pos if self.full else 0

        last_gae_lam = 0
        newer_step = None  # slot of the chronologically following step
        for offset in reversed(range(n_stored)):
            step = (oldest + offset) % self.buffer_size
            if newer_step is None:
                # Most recent step: bootstrap from the caller-supplied values.
                next_non_terminal = 1.0 - dones
                next_values = last_values
            else:
                next_non_terminal = 1.0 - self.episode_starts[newer_step]
                next_values = self.values[newer_step]
            delta = self.rewards[step] + self.gamma * next_values * next_non_terminal - self.values[step]
            last_gae_lam = delta + self.gamma * self.gae_lambda * next_non_terminal * last_gae_lam
            self.advantages[step] = last_gae_lam
            newer_step = step
        # TD(lambda) estimator, see Github PR #375 or "Telescoping in TD(lambda)"
        # in David Silver Lecture 4: https://www.youtube.com/watch?v=PnHCvfgC_ZA
        self.returns = self.advantages + self.values
