import collections
from typing import Optional, Union

import gym
import gym.spaces
import jax
import numpy as np

from src.agents.buffers.dataset import DatasetDict, Dataset


def _init_replay_dict(
    obs_space: gym.Space, capacity: int
) -> Union[np.ndarray, DatasetDict]:
    """Preallocate uninitialized storage for ``capacity`` observations.

    Args:
        obs_space: a ``gym.spaces.Box`` (stored as one array) or a
            ``gym.spaces.Dict`` (stored recursively, one entry per sub-space).
        capacity: number of rows (leading dimension) to allocate.

    Returns:
        For a Box, ``np.empty((capacity, *obs_space.shape), dtype=obs_space.dtype)``.
        For a Dict, a dict with the same keys as ``obs_space.spaces``, each
        mapping to the storage built for that sub-space.

    Raises:
        TypeError: if ``obs_space`` (or any nested sub-space) is neither a Box
            nor a Dict.
    """
    if isinstance(obs_space, gym.spaces.Box):
        # np.empty does not initialize memory; rows are meaningful only once written.
        return np.empty((capacity, *obs_space.shape), dtype=obs_space.dtype)
    if isinstance(obs_space, gym.spaces.Dict):
        return {
            key: _init_replay_dict(sub_space, capacity)
            for key, sub_space in obs_space.spaces.items()
        }
    raise TypeError(
        f"Unsupported observation space {type(obs_space).__name__}; "
        "expected gym.spaces.Box or gym.spaces.Dict."
    )


def _insert_recursively(
    dataset_dict: DatasetDict, data_dict: DatasetDict, insert_index: int
):
    """Store a single transition at row ``insert_index`` of (nested) storage.

    ``dataset_dict`` is either a numpy array, whose row ``insert_index`` is
    overwritten with ``data_dict``, or a dict whose keys must equal those of
    ``data_dict``; dicts are walked recursively key by key.

    Raises:
        AssertionError: if a dict level's keys differ between the two arguments.
        TypeError: if a storage node is neither an ndarray nor a dict.
    """
    if isinstance(dataset_dict, dict):
        assert dataset_dict.keys() == data_dict.keys()
        for key, storage in dataset_dict.items():
            _insert_recursively(storage, data_dict[key], insert_index)
        return
    if not isinstance(dataset_dict, np.ndarray):
        raise TypeError()
    dataset_dict[insert_index] = data_dict


def _insert_recursively_batch(
    dataset_dict: DatasetDict, data_dict: DatasetDict, insert_index: int, size: int
):
    """Store ``size`` consecutive rows starting at ``insert_index`` in (nested) storage.

    For an ndarray node, rows ``[insert_index, insert_index + size)`` are
    assigned from ``data_dict``. For a dict node, keys must match those of
    ``data_dict`` and each entry is handled recursively.

    Raises:
        AssertionError: if a dict level's keys differ between the two arguments.
        TypeError: if a storage node is neither an ndarray nor a dict.
    """
    if isinstance(dataset_dict, dict):
        assert dataset_dict.keys() == data_dict.keys()
        for key, storage in dataset_dict.items():
            _insert_recursively_batch(storage, data_dict[key], insert_index, size)
    elif isinstance(dataset_dict, np.ndarray):
        rows = slice(insert_index, insert_index + size)
        dataset_dict[rows] = data_dict
    else:
        raise TypeError()


class ReplayBuffer(Dataset):
    """Fixed-capacity transition storage preallocated as numpy arrays.

    All arrays are allocated once in ``__init__``. ``insert`` writes one
    transition at a time, wrapping to row 0 after the last row;
    ``insert_batch`` writes a contiguous block of rows. ``len(buffer)`` is the
    number of rows considered filled, never more than ``capacity``.
    """

    def __init__(
        self,
        observation_space: gym.Space,
        action_space: gym.Space,
        capacity: int,
        next_observation_space: Optional[gym.Space] = None,
        seed: Optional[int] = None,
        embodiment: Union[bool, str] = False,
        include_imgs: bool = False,
    ):
        """Allocate storage for ``capacity`` transitions.

        Args:
            observation_space: Box or (nested) Dict space for ``observations``.
            action_space: its ``shape``/``dtype`` define the ``actions`` array.
            capacity: maximum number of rows held.
            next_observation_space: space for ``next_observations``; defaults
                to ``observation_space``.
            seed: forwarded to ``Dataset.__init__`` (presumably seeds sampling
                — TODO confirm against ``Dataset``).
            embodiment: any truthy value adds a float32 ``embodiment`` column.
            include_imgs: only honored when ``embodiment`` is truthy; adds a
                float32 ``images`` array of shape ``(capacity, 64, 64, 3)``.
        """
        if next_observation_space is None:
            next_observation_space = observation_space

        observation_data = _init_replay_dict(observation_space, capacity)
        next_observation_data = _init_replay_dict(next_observation_space, capacity)
        dataset_dict = dict(
            observations=observation_data,
            next_observations=next_observation_data,
            actions=np.empty((capacity, *action_space.shape), dtype=action_space.dtype),
            rewards=np.empty((capacity,), dtype=np.float32),
            masks=np.empty((capacity,), dtype=np.float32),
            dones=np.empty((capacity,), dtype=np.float32),
        )
        if embodiment:
            dataset_dict.update({'embodiment': np.empty((capacity, ), dtype=np.float32)})
            # Original author's note: image storage is intended only for jupyter use.
            if include_imgs:
                dataset_dict.update({'images': np.empty((capacity, 64, 64, 3), dtype=np.float32)})
        # Dataset presumably exposes this dict as ``self.dataset_dict``, which
        # the insert methods write into — TODO confirm against Dataset.
        super().__init__(dataset_dict, seed=seed)

        self._size = 0
        self._capacity = capacity
        self._insert_index = 0

    def __len__(self) -> int:
        """Return the number of filled rows (at most ``capacity``)."""
        return self._size

    def insert(self, data_dict: DatasetDict):
        """Write one transition at the current write position.

        The write position advances by one and wraps to 0 at ``capacity``,
        so once full the oldest row is overwritten.
        """
        _insert_recursively(self.dataset_dict, data_dict, self._insert_index)
        self._insert_index = (self._insert_index + 1) % self._capacity
        self._size = min(self._size + 1, self._capacity)

    def insert_batch(self, data_dict: DatasetDict):
        """Write a batch of transitions as one contiguous block of rows.

        The block is never split across the end of the storage: if it does not
        fit in the rows after the current write position, writing restarts at
        row 0 and rows past the old write position are left untouched. The
        write position then advances past the block, so consecutive batches
        no longer overwrite each other.

        Args:
            data_dict: same (possibly nested) key structure as the storage,
                every array having the batch size as its leading dimension.

        Raises:
            ValueError: if ``data_dict`` has no arrays, or the batch has more
                rows than ``capacity``.
        """
        # Top-level values may be nested dicts (Dict observation spaces), so
        # take the batch size from an actual array leaf.
        leaves = jax.tree_util.tree_leaves(data_dict)
        if not leaves:
            raise ValueError("insert_batch received a batch with no arrays.")
        batch_size = leaves[0].shape[0]
        if batch_size > self._capacity:
            raise ValueError(
                f"Batch of {batch_size} transitions exceeds buffer capacity "
                f"{self._capacity}."
            )

        if self._insert_index + batch_size > self._capacity:
            self._insert_index = 0
        _insert_recursively_batch(
            self.dataset_dict, data_dict, self._insert_index, batch_size
        )
        # Update bookkeeping only after the write succeeded.
        self._size = max(self._size, self._insert_index + batch_size)
        self._insert_index = (self._insert_index + batch_size) % self._capacity

    def get_iterator(self, queue_size: int = 2, sample_args: Optional[dict] = None):
        """Yield samples passed through ``jax.device_put``, prefetching ahead.

        Keeps ``queue_size`` samples queued: after each one is yielded, a new
        one is drawn with ``self.sample(**sample_args)``. With
        ``queue_size >= 1`` the generator never terminates; with
        ``queue_size <= 0`` it yields nothing.

        See https://flax.readthedocs.io/en/latest/_modules/flax/jax_utils.html#prefetch_to_device
        queue_size = 2 should be ok for one GPU.
        """
        # None instead of a shared mutable ``{}`` default.
        if sample_args is None:
            sample_args = {}

        queue = collections.deque()

        def enqueue(n):
            for _ in range(n):
                data = self.sample(**sample_args)
                queue.append(jax.device_put(data))

        enqueue(queue_size)
        while queue:
            yield queue.popleft()
            enqueue(1)