import collections
from typing import Dict, Optional, Union

import gym
import gym.spaces
import jax
import numpy as np

from expo.data.dataset import Dataset, DatasetDict


def _init_replay_dict(
    obs_space: gym.Space, capacity: int
) -> Union[np.ndarray, DatasetDict]:
    """Pre-allocate uninitialized storage for ``capacity`` samples of ``obs_space``.

    Args:
        obs_space: A ``gym.spaces.Box`` (allocated as one array) or a
            ``gym.spaces.Dict`` (allocated recursively, one entry per
            sub-space, keyed like ``obs_space.spaces``).
        capacity: Number of samples to hold; becomes the leading axis.

    Returns:
        For a Box, an ``np.empty`` array of shape ``(capacity, *obs_space.shape)``
        with the space's dtype. For a Dict, a (possibly nested) dict of such
        arrays. Contents are uninitialized until written.

    Raises:
        TypeError: If ``obs_space`` (or any nested sub-space) is neither a
            Box nor a Dict.
    """
    if isinstance(obs_space, gym.spaces.Box):
        return np.empty((capacity, *obs_space.shape), dtype=obs_space.dtype)
    if isinstance(obs_space, gym.spaces.Dict):
        return {
            key: _init_replay_dict(sub_space, capacity)
            for key, sub_space in obs_space.spaces.items()
        }
    raise TypeError(
        f"Unsupported observation space {type(obs_space).__name__}; "
        "expected gym.spaces.Box or gym.spaces.Dict."
    )


def _insert_recursively(
    dataset_dict: DatasetDict, data_dict: DatasetDict, insert_index: int
):
    """Write one sample into pre-allocated storage at row ``insert_index``.

    ``data_dict`` must mirror the structure of ``dataset_dict``:
    - Wherever ``dataset_dict`` holds an ``np.ndarray``, the matching
      ``data_dict`` value is assigned to that array's ``insert_index`` row.
    - Wherever it holds a ``dict``, both sides must have exactly the same keys,
      and the function recurses into each key.

    Raises:
        ValueError: If a dict level of ``data_dict`` does not have the same keys
            as the matching level of ``dataset_dict``.
        TypeError: If ``dataset_dict`` contains something other than an
            ``np.ndarray`` or a ``dict``.
    """
    if isinstance(dataset_dict, np.ndarray):
        dataset_dict[insert_index] = data_dict
    elif isinstance(dataset_dict, dict):
        # Explicit check rather than `assert`. Under `python -O` an assert is
        # removed, and a mismatch would then silently drop extra keys or
        # surface as an opaque KeyError.
        if dataset_dict.keys() != data_dict.keys():
            missing = sorted(map(str, dataset_dict.keys() - data_dict.keys()))
            unexpected = sorted(map(str, data_dict.keys() - dataset_dict.keys()))
            raise ValueError(
                "Inserted data keys do not match buffer keys "
                f"(missing: {missing}, unexpected: {unexpected})."
            )
        for key, storage in dataset_dict.items():
            _insert_recursively(storage, data_dict[key], insert_index)
    else:
        raise TypeError(
            f"Unsupported storage type {type(dataset_dict).__name__}; "
            "expected np.ndarray or dict."
        )


def _init_robo_replay_dict(
    example_observation: np.ndarray, capacity: int
) -> np.ndarray:
    """Pre-allocate uninitialized storage for ``capacity`` copies of an example.

    Unlike ``_init_replay_dict``, the per-row shape and dtype come from a
    concrete example value rather than from a gym space.

    Args:
        example_observation: One observation. Any object exposing ``.shape``
            and ``.dtype`` is used directly (NumPy/JAX arrays, NumPy scalars,
            ``gym.spaces.Box``). Anything else is converted with ``np.asarray``
            (e.g. nested lists or Python scalars).
        capacity: Number of rows to allocate (leading axis).

    Returns:
        An ``np.empty`` array of shape ``(capacity, *example.shape)`` with the
        example's dtype. Contents are uninitialized until written.
    """
    example = example_observation
    if not (hasattr(example, "shape") and hasattr(example, "dtype")):
        example = np.asarray(example)
    return np.empty((capacity, *example.shape), dtype=example.dtype)


class RoboReplayBuffer(Dataset):
    """Fixed-capacity ring buffer of transitions, sized from example data.

    Storage is pre-allocated with ``np.empty`` under the keys
    ``observations``, ``next_observations``, ``actions``, ``rewards``,
    ``masks`` and ``dones``. Rows ``[0, len(self))`` are the ones that have
    been written. Observation/action layouts come from example values rather
    than gym spaces (compare ``ReplayBuffer``).
    """

    def __init__(
        self,
        example_observation,
        example_action,
        capacity: int,
    ):
        """Allocate storage for ``capacity`` transitions.

        Args:
            example_observation: One observation. Its shape/dtype fix the
                per-row layout of ``observations`` and ``next_observations``.
            example_action: One action. Its ``.shape``/``.dtype`` fix the
                per-row layout of ``actions``.
            capacity: Maximum number of transitions kept. Once full, new
                inserts overwrite rows starting again from row 0.
        """
        observation_data = _init_robo_replay_dict(example_observation, capacity)
        next_observation_data = _init_robo_replay_dict(example_observation, capacity)
        dataset_dict = dict(
            observations=observation_data,
            next_observations=next_observation_data,
            actions=np.empty((capacity, *example_action.shape), dtype=example_action.dtype),
            rewards=np.empty((capacity,), dtype=np.float32),
            masks=np.empty((capacity,), dtype=np.float32),
            dones=np.empty((capacity,), dtype=bool),
        )

        super().__init__(dataset_dict)

        self._size = 0  # number of rows currently holding data
        self._capacity = capacity
        self._insert_index = 0  # row the next insert() writes to

    def __len__(self) -> int:
        """Return the number of stored transitions (at most ``capacity``)."""
        return self._size

    def insert(self, data_dict: DatasetDict):
        """Write one transition at the current insert row and advance it.

        ``data_dict`` must have the same keys as the buffer storage. The
        insert row wraps to 0 after the last row, overwriting the oldest data.
        """
        _insert_recursively(self.dataset_dict, data_dict, self._insert_index)

        self._insert_index = (self._insert_index + 1) % self._capacity
        self._size = min(self._size + 1, self._capacity)

    def insert_dataset(self, dataset: Dict[str, np.ndarray]):
        """Insert an entire offline dataset at once, replacing buffer contents.

        The dataset size is taken from ``len(dataset['observations'])``. Only
        buffer keys present in ``dataset`` are written; rows of any missing key
        keep their previous (possibly uninitialized) contents.

        If the dataset fits, it is copied into rows ``[0, dataset_size)``.
        Otherwise a random subset of ``capacity`` rows is drawn without
        replacement, using NumPy's global RNG.
        """
        dataset_size = len(dataset["observations"])

        if dataset_size <= self._capacity:
            for key, storage in self.dataset_dict.items():
                if key in dataset:
                    storage[:dataset_size] = dataset[key]

            self._size = dataset_size
            # Keep the ring-buffer invariant: the next insert() goes into the
            # first unwritten row, wrapping to row 0 when the buffer is
            # exactly full. The previous `min(size + 1, capacity)` skipped a
            # row, and it indexed out of bounds when dataset_size == capacity.
            self._insert_index = dataset_size % self._capacity
        else:
            indices = np.random.choice(dataset_size, self._capacity, replace=False)

            for key, storage in self.dataset_dict.items():
                if key in dataset:
                    storage[:] = dataset[key][indices]

            self._size = self._capacity
            self._insert_index = 0

    def get_iterator(self, queue_size: int = 2, sample_args: Optional[dict] = None):
        """Yield batches from ``self.sample`` with a small device-put prefetch queue.

        The queue keeps ``queue_size`` batches that have already gone through
        ``jax.device_put`` ahead of the consumer. After each yield it refills
        one batch, so the generator never terminates on its own when
        ``queue_size >= 1``.

        Args:
            queue_size: Number of batches kept enqueued ahead of the consumer.
            sample_args: Keyword arguments forwarded to ``self.sample``. The
                accepted keys depend on the ``Dataset`` base class.
        """
        # See https://flax.readthedocs.io/en/latest/_modules/flax/jax_utils.html#prefetch_to_device
        # queue_size = 2 should be ok for one GPU.
        if sample_args is None:
            sample_args = {}

        queue = collections.deque()

        def enqueue(n):
            for _ in range(n):
                data = self.sample(**sample_args)
                queue.append(jax.device_put(data))

        enqueue(queue_size)
        while queue:
            yield queue.popleft()
            enqueue(1)


class ReplayBuffer(Dataset):
    """Fixed-capacity ring buffer of transitions, laid out from gym spaces.

    Storage is pre-allocated with ``np.empty`` under the keys
    ``observations``, ``next_observations``, ``actions``, ``rewards``,
    ``masks`` and ``dones``. Observation entries follow the structure of
    their gym space: one array for a ``Box``, a nested dict of arrays for a
    ``Dict``. Rows ``[0, len(self))`` are the ones that have been written.
    """

    def __init__(
        self,
        observation_space: gym.Space,
        action_space: gym.Space,
        capacity: int,
        next_observation_space: Optional[gym.Space] = None,
    ):
        """Allocate storage for ``capacity`` transitions.

        Args:
            observation_space: Space describing ``observations`` (Box or Dict).
            action_space: Space whose ``shape``/``dtype`` define ``actions``.
            capacity: Maximum number of transitions kept. Once full, new
                inserts overwrite rows starting again from row 0.
            next_observation_space: Space for ``next_observations``. Defaults
                to ``observation_space``.
        """
        if next_observation_space is None:
            next_observation_space = observation_space

        observation_data = _init_replay_dict(observation_space, capacity)
        next_observation_data = _init_replay_dict(next_observation_space, capacity)
        dataset_dict = dict(
            observations=observation_data,
            next_observations=next_observation_data,
            actions=np.empty((capacity, *action_space.shape), dtype=action_space.dtype),
            rewards=np.empty((capacity,), dtype=np.float32),
            masks=np.empty((capacity,), dtype=np.float32),
            dones=np.empty((capacity,), dtype=bool),
        )

        super().__init__(dataset_dict)

        self._size = 0  # number of rows currently holding data
        self._capacity = capacity
        self._insert_index = 0  # row the next insert() writes to

    def __len__(self) -> int:
        """Return the number of stored transitions (at most ``capacity``)."""
        return self._size

    def insert(self, data_dict: DatasetDict):
        """Write one transition at the current insert row and advance it.

        ``data_dict`` must mirror the (possibly nested) key structure of the
        buffer storage. The insert row wraps to 0 after the last row,
        overwriting the oldest data.
        """
        _insert_recursively(self.dataset_dict, data_dict, self._insert_index)

        self._insert_index = (self._insert_index + 1) % self._capacity
        self._size = min(self._size + 1, self._capacity)

    def insert_dataset(self, dataset: Dict[str, np.ndarray]):
        """Insert an entire offline dataset at once, replacing buffer contents.

        Only flat (array-valued) buffer entries are supported, because each
        matched entry is sliced as a whole array. The dataset size is taken
        from ``len(dataset['observations'])``. Only buffer keys present in
        ``dataset`` are written; rows of any missing key keep their previous
        (possibly uninitialized) contents.

        If the dataset fits, it is copied into rows ``[0, dataset_size)``.
        Otherwise a random subset of ``capacity`` rows is drawn without
        replacement, using NumPy's global RNG.
        """
        dataset_size = len(dataset["observations"])

        if dataset_size <= self._capacity:
            for key, storage in self.dataset_dict.items():
                if key in dataset:
                    storage[:dataset_size] = dataset[key]

            self._size = dataset_size
            # Keep the ring-buffer invariant: the next insert() goes into the
            # first unwritten row, wrapping to row 0 when the buffer is
            # exactly full. The previous `min(size + 1, capacity)` skipped a
            # row, and it indexed out of bounds when dataset_size == capacity.
            self._insert_index = dataset_size % self._capacity
        else:
            indices = np.random.choice(dataset_size, self._capacity, replace=False)

            for key, storage in self.dataset_dict.items():
                if key in dataset:
                    storage[:] = dataset[key][indices]

            self._size = self._capacity
            self._insert_index = 0

    def get_iterator(self, queue_size: int = 2, sample_args: Optional[dict] = None):
        """Yield batches from ``self.sample`` with a small device-put prefetch queue.

        The queue keeps ``queue_size`` batches that have already gone through
        ``jax.device_put`` ahead of the consumer. After each yield it refills
        one batch, so the generator never terminates on its own when
        ``queue_size >= 1``.

        Args:
            queue_size: Number of batches kept enqueued ahead of the consumer.
            sample_args: Keyword arguments forwarded to ``self.sample``. The
                accepted keys depend on the ``Dataset`` base class.
        """
        # See https://flax.readthedocs.io/en/latest/_modules/flax/jax_utils.html#prefetch_to_device
        # queue_size = 2 should be ok for one GPU.
        if sample_args is None:
            sample_args = {}

        queue = collections.deque()

        def enqueue(n):
            for _ in range(n):
                data = self.sample(**sample_args)
                queue.append(jax.device_put(data))

        enqueue(queue_size)
        while queue:
            yield queue.popleft()
            enqueue(1)
