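"""Biased replay buffer for Stable-Baselines3 off-policy algorithms.

It stores the mode of each experience (`info["ground_truth_mode"]`, or
`info["active_mode"]` when the former is absent) and uses it to bias
sampling towards modes that are under-represented in the buffer.
"""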
from stable_baselines3.common.buffers import ReplayBuffer
from stable_baselines3.common.buffers import ReplayBufferSamples
from stable_baselines3.common.vec_env import VecNormalize
from collections import defaultdict
from itertools import product
from typing import List
from typing import Any
from typing import Dict
from typing import Optional
import random
import numpy as np


class BiasedModeReplayBuffer(ReplayBuffer):
    """Replay buffer that samples experiences in a mode-balanced way.

    Besides the usual transition data, it records one integer mode per
    experience (taken from ``info["ground_truth_mode"]``, or from
    ``info["active_mode"]`` when the former is missing). :meth:`sample` uses
    these modes so that every mode present in the buffer is drawn with equal
    probability, regardless of how many experiences it holds.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # ``sample`` reads ``self.next_observations`` directly, so the
        # memory-optimized storage layout is not supported. An explicit raise
        # (rather than ``assert``) keeps the check active under ``python -O``.
        if self.optimize_memory_usage:
            raise ValueError("Optimized memory usage not supported!")

        # Mode of each stored experience, indexed like the other buffer
        # arrays: ``self.modes[step_index, env_index]``.
        self.modes = np.zeros((self.buffer_size, self.n_envs), dtype=np.int32)

    def add(
        self,
        obs: np.ndarray,
        next_obs: np.ndarray,
        action: np.ndarray,
        reward: np.ndarray,
        done: np.ndarray,
        infos: List[Dict[str, Any]],
    ) -> None:
        """Store a transition together with the mode of each environment.

        :param infos: one info dict per environment; each must contain
            ``"ground_truth_mode"`` or, failing that, ``"active_mode"``.
        """
        # ``ReplayBuffer.add`` writes to the current ``self.pos`` (and moves
        # it), so the modes must be saved to that slot *before* delegating.
        modes = [
            info["ground_truth_mode"] if "ground_truth_mode" in info else info["active_mode"]
            for info in infos
        ]
        self.modes[self.pos] = np.array(modes, dtype=np.int32)

        super().add(obs, next_obs, action, reward, done, infos)

    def sample(
        self,
        batch_size: int,
        env: Optional[VecNormalize] = None,
    ) -> ReplayBufferSamples:
        """Sample a mode-balanced batch of experiences (with replacement).

        Each mode currently represented in the buffer is selected with equal
        probability, and within a mode every experience is equally likely.
        Equivalently, an experience whose mode holds ``n_m`` stored
        experiences is drawn with probability ``1 / (n_modes * n_m)``.

        :param batch_size: number of experiences to draw.
        :param env: optional ``VecNormalize`` wrapper used to normalize
            observations and rewards, exactly as in
            ``ReplayBuffer._get_samples``.
        :return: the sampled batch, converted to torch tensors.
        :raises ValueError: if the buffer holds no experiences yet.
        """
        # The buffer arrays are indexed by (step index, env index); only the
        # first ``self.pos`` steps are valid until the buffer has wrapped.
        upper_bound = self.buffer_size if self.full else self.pos
        flat_modes = self.modes[:upper_bound].reshape(-1)
        if flat_modes.size == 0:
            raise ValueError("Cannot sample from an empty replay buffer.")

        # ``mode_ids[i]`` is the index (among distinct modes) of experience
        # i's mode; ``mode_counts[k]`` is how many experiences have mode k.
        _, mode_ids, mode_counts = np.unique(
            flat_modes, return_inverse=True, return_counts=True
        )
        # Inverse-frequency weights: each mode's total mass becomes 1/n_modes.
        # This also covers the single-mode case without special handling.
        probs = 1.0 / mode_counts[mode_ids]
        probs /= probs.sum()
        # ``np.random`` matches the RNG used by SB3's own ``BaseBuffer.sample``.
        flat_choices = np.random.choice(flat_modes.size, size=batch_size, p=probs)

        # ``flat_modes`` is the row-major flattening of (upper_bound, n_envs).
        batch_inds, env_indices = np.divmod(flat_choices, self.n_envs)

        # Packaging logic, exactly as in ``ReplayBuffer._get_samples``.
        next_obs = self._normalize_obs(
            self.next_observations[batch_inds, env_indices, :], env
        )
        data = (
            self._normalize_obs(self.observations[batch_inds, env_indices, :], env),
            self.actions[batch_inds, env_indices, :],
            next_obs,
            # Only use dones that are not due to timeouts
            # (``timeouts`` stays all-False unless timeout handling is enabled).
            (self.dones[batch_inds, env_indices] * (1 - self.timeouts[batch_inds, env_indices])).reshape(-1, 1),
            self._normalize_reward(self.rewards[batch_inds, env_indices].reshape(-1, 1), env),
        )
        return ReplayBufferSamples(*tuple(map(self.to_torch, data)))
