from typing import List, Dict, Any
from stable_baselines3.common.buffers import ReplayBuffer
from gymnasium import spaces
import numpy as np
from typing import Optional
from stable_baselines3.common.vec_env import VecNormalize
from flax.struct import dataclass
import jax


@dataclass
class ReplayBufferSamples:
    """Batch of sampled transitions produced by ``MultiObjectiveReplayBuffer._get_samples``.

    Keep the field order stable: callers may construct this class positionally.
    Although the fields are annotated as ``jax.Array``, the buffer in this file
    appears to fill them with NumPy arrays; any conversion to device arrays is
    presumably left to the consumer (TODO confirm).
    """
    observations: jax.Array  # (batch, *obs_shape), optionally normalized — TODO confirm obs layout
    actions: jax.Array  # (batch, action_dim) as stored by the buffer
    next_observations: jax.Array  # same layout as `observations`
    dones: jax.Array  # (batch, 1) boolean terminal flags
    rewards: jax.Array  # (batch, num_reward, 1): one reward per objective

@dataclass
class ReplayBufferSamplesWeight:
    """Batch of sampled transitions plus their stored weight vectors.

    Produced by ``MOReplayBufferWeight._get_samples``. The first five fields mirror
    ``ReplayBufferSamples``. Keep the field order stable: callers may construct
    this class positionally. The values appear to be NumPy arrays despite the
    ``jax.Array`` annotations (TODO confirm).
    """
    observations: jax.Array  # (batch, *obs_shape), optionally normalized — TODO confirm obs layout
    actions: jax.Array  # (batch, action_dim) as stored by the buffer
    next_observations: jax.Array  # same layout as `observations`
    dones: jax.Array  # (batch, 1) boolean terminal flags
    rewards: jax.Array  # (batch, num_reward, 1): one reward per objective
    weight: jax.Array  # (batch, num_reward): weight vector stored with each transition


class MultiObjectiveReplayBuffer(ReplayBuffer):
    """SB3 ``ReplayBuffer`` that stores one reward per objective for each transition.

    The ``rewards`` array gets a trailing ``num_reward`` axis. Every random draw
    made while sampling (transition indices and env indices) comes from a
    private ``numpy.random.Generator`` seeded with ``seed``, so sampling is
    reproducible independently of the global NumPy RNG.
    """

    def __init__(
        self,
        buffer_size: int,
        observation_space: spaces.Space,
        action_space: spaces.Space,
        num_reward: int = 1,
        n_envs: int = 1,
        optimize_memory_usage: bool = False,
        handle_timeout_termination: bool = True,
        *,
        seed: int
    ):
        """
        :param buffer_size: buffer capacity, forwarded to ``ReplayBuffer``
        :param observation_space: observation space
        :param action_space: action space
        :param num_reward: number of objectives (length of each reward vector)
        :param n_envs: number of parallel environments
        :param optimize_memory_usage: share one array for ``obs`` and ``next_obs``
        :param handle_timeout_termination: record time-limit truncations separately
            so they are not treated as terminal when sampling
        :param seed: seed for the private RNG used by ``sample`` and ``_get_samples``
        """
        self.num_reward = num_reward
        super().__init__(
            buffer_size,
            observation_space,
            action_space,
            'cpu',  # storage device is fixed to CPU
            n_envs,
            optimize_memory_usage,
            handle_timeout_termination
        )
        self.np_rngs = np.random.default_rng(seed)
        # Overwrite the base class's reward storage with a per-objective axis.
        self.rewards = np.zeros((self.buffer_size, self.n_envs, self.num_reward), dtype=np.float32)

    def sample(self, batch_size: int, env: Optional[VecNormalize] = None) -> ReplayBufferSamples:
        """
        Sample a batch of transitions using the buffer's seeded RNG.

        Both storage modes draw indices from ``self.np_rngs`` instead of
        delegating to the inherited implementation, so the ``seed`` passed to
        ``__init__`` governs every draw.

        With ``optimize_memory_usage`` the element at index ``self.pos`` is never
        sampled. Observations and next observations share one array, so that
        slot's transition is invalid.
        See https://github.com/DLR-RM/stable-baselines3/pull/28#issuecomment-637559274

        :param batch_size: Number of elements to sample
        :param env: associated gym VecEnv
            to normalize the observations/rewards when sampling
        :return: a batch built by ``_get_samples``
        :raises ValueError: if the buffer contains no transitions
        """
        if not self.full and self.pos == 0:
            raise ValueError("Cannot sample from an empty replay buffer")

        if not self.optimize_memory_usage:
            upper_bound = self.buffer_size if self.full else self.pos
            batch_inds = self.np_rngs.integers(0, upper_bound, size=batch_size)
        elif self.full:
            # Offsets start at 1 so that index `self.pos` itself is never drawn.
            batch_inds = (self.np_rngs.integers(1, self.buffer_size, size=batch_size) + self.pos) % self.buffer_size
        else:
            batch_inds = self.np_rngs.integers(0, self.pos, size=batch_size)
        return self._get_samples(batch_inds, env=env)

    def _get_samples(self, batch_inds: np.ndarray, env: Optional[VecNormalize] = None) -> ReplayBufferSamples:
        """
        Gather the transitions at ``batch_inds``, each taken from a randomly chosen env.

        :param batch_inds: 1-D array of buffer indices to gather
        :param env: optional VecNormalize used to normalize observations
        :return: ``ReplayBufferSamples`` with dones shaped ``(batch, 1)`` (bool)
            and rewards shaped ``(batch, num_reward, 1)``
        """
        # Use the seeded generator (not the global np.random) for reproducibility.
        env_indices = self.np_rngs.integers(0, self.n_envs, size=len(batch_inds))

        if self.optimize_memory_usage:
            # next_obs is stored in the following slot of the shared observations array.
            next_obs = self._normalize_obs(self.observations[(batch_inds + 1) % self.buffer_size, env_indices, :], env)
        else:
            next_obs = self._normalize_obs(self.next_observations[batch_inds, env_indices, :], env)

        # Only use dones that are not due to timeouts: a truncated transition must still
        # bootstrap. When timeout handling is deactivated, timeouts stay all-False and
        # this reduces to the raw dones.
        dones = np.logical_and(
            self.dones[batch_inds, env_indices],
            np.logical_not(self.timeouts[batch_inds, env_indices]),
        ).reshape(-1, 1)

        return ReplayBufferSamples(
            observations=self._normalize_obs(self.observations[batch_inds, env_indices, :], env),
            actions=self.actions[batch_inds, env_indices, :],
            next_observations=next_obs,
            dones=dones,
            rewards=self.rewards[batch_inds, env_indices, :].reshape(-1, self.num_reward, 1),
        )


class MOReplayBufferWeight(MultiObjectiveReplayBuffer):
    """Multi-objective replay buffer that also stores a weight vector per transition.

    ``weights`` has the same layout as ``rewards`` (one value per objective per env
    per slot). It is returned with each sampled transition as
    ``ReplayBufferSamplesWeight.weight``.
    """

    def __init__(self,
                 buffer_size: int,
                 observation_space: spaces.Space,
                 action_space: spaces.Space,
                 num_reward: int = 1,
                 n_envs: int = 1,
                 optimize_memory_usage: bool = False,
                 handle_timeout_termination: bool = True,
                 *,
                 seed: int
                 ):
        """Same parameters as ``MultiObjectiveReplayBuffer``; also allocates ``weights``."""
        super().__init__(
            buffer_size, observation_space, action_space, num_reward, n_envs, optimize_memory_usage, handle_timeout_termination, seed=seed
        )
        self.weights = np.zeros((self.buffer_size, self.n_envs, self.num_reward), dtype=np.float32)

    def add(
        self,
        obs: np.ndarray,
        next_obs: np.ndarray,
        action: np.ndarray,
        reward: np.ndarray,
        done: np.ndarray,
        infos: List[Dict[str, Any]],
        weight: np.ndarray,
    ) -> None:
        """
        Store one step of transitions (one per env) together with its weight vector.

        :param obs: current observations
        :param next_obs: next observations
        :param action: actions taken
        :param reward: per-objective rewards, presumably shaped ``(n_envs, num_reward)``
        :param done: done flags
        :param infos: per-env info dicts; the inherited ``add`` reads
            ``"TimeLimit.truncated"`` when timeout handling is enabled
        :param weight: weight vector, presumably shaped ``(n_envs, num_reward)``
        """
        # Write the weight *before* delegating: the inherited add() advances
        # (and may wrap) self.pos. A copy avoids modification by reference.
        self.weights[self.pos] = np.array(weight)
        super().add(obs, next_obs, action, reward, done, infos)

    def _get_samples(self, batch_inds: np.ndarray, env: Optional[VecNormalize] = None) -> ReplayBufferSamplesWeight:
        """
        Gather the transitions and weights at ``batch_inds``, each from a randomly chosen env.

        :param batch_inds: 1-D array of buffer indices to gather
        :param env: optional VecNormalize used to normalize observations
        :return: ``ReplayBufferSamplesWeight`` with dones ``(batch, 1)`` (bool),
            rewards ``(batch, num_reward, 1)`` and weight ``(batch, num_reward)``
        """
        # Use the seeded generator (not the global np.random) for reproducibility.
        env_indices = self.np_rngs.integers(0, self.n_envs, size=len(batch_inds))

        if self.optimize_memory_usage:
            # next_obs is stored in the following slot of the shared observations array.
            next_obs = self._normalize_obs(self.observations[(batch_inds + 1) % self.buffer_size, env_indices, :],
                                           env)
        else:
            next_obs = self._normalize_obs(self.next_observations[batch_inds, env_indices, :], env)

        # Only use dones that are not due to timeouts: a truncated transition must still
        # bootstrap. When timeout handling is deactivated, timeouts stay all-False and
        # this reduces to the raw dones.
        dones = np.logical_and(
            self.dones[batch_inds, env_indices],
            np.logical_not(self.timeouts[batch_inds, env_indices]),
        ).reshape(-1, 1)

        return ReplayBufferSamplesWeight(
            observations=self._normalize_obs(self.observations[batch_inds, env_indices, :], env),
            actions=self.actions[batch_inds, env_indices, :],
            next_observations=next_obs,
            dones=dones,
            rewards=self.rewards[batch_inds, env_indices, :].reshape(-1, self.num_reward, 1),
            weight=self.weights[batch_inds, env_indices, :].reshape(-1, self.num_reward),
        )
