import collections
from typing import Optional, Union

import gym
import numpy as np

from jaxrl.datasets.dataset import Dataset


class ReplayBuffer(Dataset):
    """Fixed-capacity circular buffer of single-step transitions.

    All arrays are preallocated for ``capacity`` transitions. ``insert``
    writes at ``insert_index`` and wraps around modulo ``capacity``, so once
    the buffer is full each new transition overwrites the oldest one.
    """

    def __init__(self, observation_space: gym.spaces.Box,
                 action_space: Union[gym.spaces.Discrete, gym.spaces.Box],
                 capacity: int):

        observations = np.empty((capacity, *observation_space.shape),
                                dtype=observation_space.dtype)
        actions = np.empty((capacity, *action_space.shape),
                           dtype=action_space.dtype)
        rewards = np.empty((capacity, ), dtype=np.float32)
        masks = np.empty((capacity, ), dtype=np.float32)
        dones_float = np.empty((capacity, ), dtype=np.float32)
        next_observations = np.empty((capacity, *observation_space.shape),
                                     dtype=observation_space.dtype)
        super().__init__(observations=observations,
                         actions=actions,
                         rewards=rewards,
                         masks=masks,
                         dones_float=dones_float,
                         next_observations=next_observations,
                         size=0)
        self.size = 0

        # Slot that the next ``insert`` writes to.
        self.insert_index = 0
        self.capacity = capacity

    def initialize_with_dataset(self, dataset: Dataset,
                                num_samples: Optional[int] = None):
        """Fill an empty buffer with transitions copied from ``dataset``.

        Args:
            dataset: object exposing ``observations``, ``actions``,
                ``rewards``, ``masks``, ``dones_float`` and
                ``next_observations`` arrays indexable by the same indices.
            num_samples: number of transitions to copy. ``None`` (or a value
                larger than the dataset) copies the whole dataset; a smaller
                value copies a uniformly random subset.
        """
        # ``insert_index == 0`` is also true right after a full wrap, so the
        # emptiness check must look at ``size``.
        assert self.size == 0, 'Can only initialize an empty replay buffer with a dataset.'

        dataset_size = len(dataset.observations)

        if num_samples is None:
            num_samples = dataset_size
        else:
            num_samples = min(dataset_size, num_samples)
        assert self.capacity >= num_samples, 'Dataset cannot be larger than the replay buffer capacity.'

        if num_samples < dataset_size:
            perm = np.random.permutation(dataset_size)
            indices = perm[:num_samples]
        else:
            indices = np.arange(num_samples)

        self.observations[:num_samples] = dataset.observations[indices]
        self.actions[:num_samples] = dataset.actions[indices]
        self.rewards[:num_samples] = dataset.rewards[indices]
        self.masks[:num_samples] = dataset.masks[indices]
        self.dones_float[:num_samples] = dataset.dones_float[indices]
        self.next_observations[:num_samples] = dataset.next_observations[
            indices]

        # A completely full buffer must wrap its write head back to slot 0;
        # leaving it at ``capacity`` would make the next ``insert`` index out
        # of bounds.
        self.insert_index = 0 if num_samples == self.capacity else num_samples
        self.size = num_samples

    def insert(self, observation: np.ndarray, action: np.ndarray,
               reward: float, mask: float, done_float: float,
               next_observation: np.ndarray):
        """Store one transition, overwriting the oldest one when full."""
        self.observations[self.insert_index] = observation
        self.actions[self.insert_index] = action
        self.rewards[self.insert_index] = reward
        self.masks[self.insert_index] = mask
        self.dones_float[self.insert_index] = done_float
        self.next_observations[self.insert_index] = next_observation

        self.insert_index = (self.insert_index + 1) % self.capacity
        self.size = min(self.size + 1, self.capacity)
    

# Field layout of the mini-batches returned by ``NStepReplayBuffer.sample``.
Batch = collections.namedtuple(
    typename='Batch',
    field_names='observations actions rewards masks next_observations')


class NStepReplayBuffer(Dataset):
    """Circular replay buffer that samples n-step discounted returns.

    Let ``n = n_step_trgt``. For a start slot ``i`` whose window
    ``i, i+1, ..., i+n-1`` (slot indices modulo ``capacity``) contains no
    transition with ``dones_float == 1``, ``n_step_returns[i]`` holds
    ``sum_k discount**k * rewards[i+k]`` for ``k = 0..n-1``. ``sample`` only
    returns such windows. It pairs the return with the first transition's
    observation/action and with the last transition's mask and next
    observation.

    NOTE(review): the learner consuming these batches presumably has to
    bootstrap with ``discount ** n_step_trgt`` rather than ``discount``;
    confirm against the caller.
    """

    # Upper bound on rejection-sampling rounds in ``sample``. This keeps a
    # buffer in which (almost) every window contains a done from looping
    # forever.
    _MAX_RESAMPLE_ROUNDS = 100

    def __init__(self, observation_space: gym.spaces.Box,
                 action_space: Union[gym.spaces.Discrete, gym.spaces.Box],
                 capacity: int, discount: float = 0.99, n_step_trgt: int = 5):

        observations = np.empty((capacity, *observation_space.shape),
                                dtype=observation_space.dtype)
        actions = np.empty((capacity, *action_space.shape),
                           dtype=action_space.dtype)
        rewards = np.empty((capacity, ), dtype=np.float32)
        masks = np.empty((capacity, ), dtype=np.float32)
        dones_float = np.empty((capacity, ), dtype=np.float32)
        next_observations = np.empty((capacity, *observation_space.shape),
                                     dtype=observation_space.dtype)
        super().__init__(observations=observations,
                         actions=actions,
                         rewards=rewards,
                         masks=masks,
                         dones_float=dones_float,
                         next_observations=next_observations,
                         size=0)

        # Only slots whose window is done-free hold a meaningful value; the
        # rest stay uninitialised or stale and are rejected by ``sample``.
        self.n_step_returns = np.empty((capacity, ), dtype=np.float32)
        self.n_step_trgt = int(n_step_trgt)
        self.discount = discount

        self.size = 0

        # Slot that the next ``insert`` writes to.
        self.insert_index = 0
        self.capacity = capacity

    def _windows(self, starts):
        """Return slot indices ``starts + [0, n)`` wrapped modulo ``capacity``.

        A scalar start yields shape ``(n,)``; an array of shape ``(B,)``
        yields ``(B, n)``.
        """
        offsets = np.arange(self.n_step_trgt)
        return (np.asarray(starts)[..., None] + offsets) % self.capacity

    def _window_has_done(self, starts):
        """Return whether each window beginning at ``starts`` contains a done."""
        return np.any(self.dones_float[self._windows(starts)] == 1.0, axis=-1)

    def _discount_weights(self):
        """Return ``[1, discount, ..., discount**(n-1)]`` as float64."""
        return self.discount ** np.arange(self.n_step_trgt, dtype=np.float64)

    def _fill_n_step_returns(self, count: int):
        """Compute n-step returns for every complete window inside ``[0, count)``."""
        num_windows = count - self.n_step_trgt + 1
        if num_windows <= 0:
            return
        returns = np.zeros(num_windows, dtype=np.float64)
        has_done = np.zeros(num_windows, dtype=bool)
        # Sliding sums over shifted slices avoid materialising an
        # (num_windows, n) index array for large datasets.
        for k, weight in enumerate(self._discount_weights()):
            returns += weight * self.rewards[k:k + num_windows]
            has_done |= self.dones_float[k:k + num_windows] == 1.0
        valid = ~has_done
        target = self.n_step_returns[:num_windows]  # view: writes go through
        target[valid] = returns[valid]

    def initialize_with_dataset(self, dataset: Dataset,
                                num_samples: Optional[int] = None):
        """Fill an empty buffer from ``dataset`` and compute its n-step returns.

        Args:
            dataset: object exposing ``observations``, ``actions``,
                ``rewards``, ``masks``, ``dones_float`` and
                ``next_observations`` arrays indexable by the same indices.
            num_samples: number of transitions to copy. ``None`` (or a value
                larger than the dataset) copies the whole dataset; a smaller
                value copies a uniformly random subset.
        """
        # ``insert_index == 0`` is also true right after a full wrap, so the
        # emptiness check must look at ``size``.
        assert self.size == 0, 'Can only initialize an empty replay buffer with a dataset.'

        dataset_size = len(dataset.observations)

        if num_samples is None:
            num_samples = dataset_size
        else:
            num_samples = min(dataset_size, num_samples)
        assert self.capacity >= num_samples, 'Dataset cannot be larger than the replay buffer capacity.'

        if num_samples < dataset_size:
            # NOTE(review): a random subset breaks temporal order. The n-step
            # windows computed below then span unrelated transitions unless
            # done flags separate them. Consider a contiguous slice instead.
            perm = np.random.permutation(dataset_size)
            indices = perm[:num_samples]
        else:
            indices = np.arange(num_samples)

        self.observations[:num_samples] = dataset.observations[indices]
        self.actions[:num_samples] = dataset.actions[indices]
        self.rewards[:num_samples] = dataset.rewards[indices]
        self.masks[:num_samples] = dataset.masks[indices]
        self.dones_float[:num_samples] = dataset.dones_float[indices]
        self.next_observations[:num_samples] = dataset.next_observations[
            indices]

        # Without this, the loaded slots would expose uninitialised memory
        # through ``sample``.
        self._fill_n_step_returns(num_samples)

        # A completely full buffer must wrap its write head back to slot 0.
        self.insert_index = 0 if num_samples == self.capacity else num_samples
        self.size = num_samples

    def insert(self, observation: np.ndarray, action: np.ndarray,
               reward: float, mask: float, done_float: float,
               next_observation: np.ndarray):
        """Store one transition and compute the n-step return it completes."""
        self.observations[self.insert_index] = observation
        self.actions[self.insert_index] = action
        self.rewards[self.insert_index] = reward
        self.masks[self.insert_index] = mask
        self.dones_float[self.insert_index] = done_float
        self.next_observations[self.insert_index] = next_observation

        # Count transitions including the one just written. The window ending
        # at this slot is complete as soon as n transitions exist, which
        # includes the very first window that starts at slot 0.
        stored = min(self.size + 1, self.capacity)
        if stored >= self.n_step_trgt:
            start = (self.insert_index - self.n_step_trgt + 1) % self.capacity
            if not self._window_has_done(start):
                # Modular indices handle windows that wrap past the end of
                # the arrays while keeping the discount powers contiguous.
                window = self._windows(start)
                self.n_step_returns[start] = np.sum(
                    self.rewards[window] * self._discount_weights())

        self.insert_index = (self.insert_index + 1) % self.capacity
        self.size = min(self.size + 1, self.capacity)

    def sample(self, batch_size: int) -> Batch:
        """Draw ``batch_size`` n-step transitions uniformly from valid windows.

        Raises:
            ValueError: if fewer than ``n_step_trgt`` transitions are stored.
        """
        num_starts = self.size - self.n_step_trgt + 1
        if num_starts <= 0:
            raise ValueError(
                f'Need at least {self.n_step_trgt} transitions to sample, '
                f'have {self.size}.')

        # Windows are drawn in age order starting at the oldest stored
        # transition. Slot 0 is the oldest before the buffer wraps; afterwards
        # it is the next slot to be overwritten. This way no window straddles
        # the write head, which would mix the newest data with the oldest.
        oldest = (self.insert_index - self.size) % self.capacity

        def draw(count):
            offsets = np.random.randint(num_starts, size=count)
            return (oldest + offsets) % self.capacity

        indx = draw(batch_size)
        # Reject windows containing a done and re-check the replacements.
        # The loop is bounded so a buffer with no valid windows degrades to a
        # best-effort batch instead of hanging.
        for _ in range(self._MAX_RESAMPLE_ROUNDS):
            invalid = self._window_has_done(indx)
            if not invalid.any():
                break
            indx[invalid] = draw(int(invalid.sum()))

        # The n-step return starting at ``indx`` belongs to the state/action
        # at ``indx``; the bootstrap state and mask come from the window's
        # last transition.
        last = (indx + self.n_step_trgt - 1) % self.capacity
        return Batch(observations=self.observations[indx],
                     actions=self.actions[indx],
                     rewards=self.n_step_returns[indx],
                     masks=self.masks[last],
                     next_observations=self.next_observations[last])
