from typing import Optional, Union
from functools import partial

import collections
import gym
import numpy as np
from numba import njit

from koopman.data.dataset import Dataset

# Container returned by ReplayBuffer.sample_seq(); each field holds a batch of sequences.
Batch = collections.namedtuple(
    typename='Batch',
    field_names='observations actions next_observations rewards masks seq_masks',
)

# numba_sample_seq() is the JIT-compiled sequence sampler used by ReplayBuffer.sample_seq().
@njit
def numba_sample_seq(all_observations, all_actions, all_rewards,
                     all_next_observations, all_masks,
                     start_indexes, traj_ends, max_seq_length, capacity):
    '''
    Sample, for every start index, a window of up to max_seq_length steps that
    stays inside the trajectory containing that start index.

    Storage is treated as a circular buffer of size ``capacity``. The distance
    from a start index to its trajectory end is measured modulo ``capacity``.
    As a result, trajectories that wrap past the end of the arrays are followed
    back to index 0. A window on a trajectory that ends at ``capacity - 1``
    stops there rather than reading unrelated data at index 0.

    Args:
        all_observations: stored observations, shape (capacity, *obs_shape)
        all_actions: stored actions, shape (capacity, *act_shape)
        all_rewards: stored rewards, 1-D with one entry per storage slot
        all_next_observations: stored next observations, shape (capacity, *obs_shape)
        all_masks: stored masks, 1-D with one entry per storage slot
        start_indexes: storage index where each window starts, shape (batch_size,)
        traj_ends: storage index of the last step of the trajectory that
            contains each start index, shape (batch_size,)
        max_seq_length: number of slots in each window
        capacity: number of storage slots (the modulus used for wrapping)
    Returns:
        observations: (batch_size, max_seq_length, *obs_shape). Slot j holds the
            observation of the j-th step. The slot immediately after the
            trajectory's last step holds that step's next observation.
        actions: (batch_size, max_seq_length, *act_shape)
        rewards: (batch_size, max_seq_length), float32
        masks: (batch_size, max_seq_length), float32
        seq_masks: (batch_size, max_seq_length), int32; 1 for every filled
            observation slot, 0 for zero padding.
    '''

    batch_size = start_indexes.shape[0]
    observations = np.zeros(
        (batch_size, max_seq_length, *all_observations[0].shape), dtype=all_observations.dtype
    )
    actions = np.zeros(
        (batch_size, max_seq_length, *all_actions[0].shape), dtype=all_actions.dtype
    )
    rewards = np.zeros(
        (batch_size, max_seq_length), dtype=np.float32
    )
    masks = np.zeros(
        (batch_size, max_seq_length), dtype=np.float32
    )
    seq_masks = np.zeros(
        (batch_size, max_seq_length), dtype=np.int32
    )

    for i in range(batch_size):
        start = start_indexes[i]
        # Offset (in steps) of the trajectory's last step from the start index,
        # taken around the ring so a wrapped trajectory gets a non-negative value.
        last = traj_ends[i] - start
        if last < 0:
            last += capacity
        for j in range(max_seq_length):
            if j <= last:
                cur = (start + j) % capacity
                observations[i, j] = all_observations[cur]
                actions[i, j] = all_actions[cur]
                rewards[i, j] = all_rewards[cur]
                masks[i, j] = all_masks[cur]
                seq_masks[i, j] = 1
            elif j == last + 1:
                # One slot past the final step: close the window with the final
                # step's next observation (no action/reward exists for it).
                observations[i, j] = all_next_observations[traj_ends[i]]
                seq_masks[i, j] = 1
            else:
                break

    return observations, actions, rewards, masks, seq_masks


class ReplayBuffer(Dataset):
    '''
    Fixed-capacity circular buffer of transitions that can grow online.

    On top of a static Dataset it provides:
    1. initialize_with_dataset(): fill an empty buffer from an existing dataset
    2. insert(): append one transition, overwriting the oldest slot once full
    3. sample_seq(): sample a batch of within-trajectory sequences using the
       numba-compiled numba_sample_seq()

    Trajectory bookkeeping:
        traj_info[k] = [start index, end index] (storage positions) of
            trajectory slot k. The end of the trajectory currently being
            inserted is only written when its done_float == 1.0 step arrives.
        traj_index_map[i] = trajectory slot of the transition stored at i.
        traj_insert_index = slot of the trajectory currently being inserted.
    '''
    def __init__(self, observation_space: gym.spaces.Box,
                 action_space: Union[gym.spaces.Discrete,
                                     gym.spaces.Box],
                 capacity: int
                 ):

        observations = np.empty((capacity, *observation_space.shape),
                                dtype=observation_space.dtype)
        actions = np.empty((capacity, *action_space.shape),
                           dtype=action_space.dtype)
        rewards = np.empty((capacity, ), dtype=np.float32)
        masks = np.empty((capacity, ), dtype=np.float32)
        dones_float = np.empty((capacity, ), dtype=np.float32)
        next_observations = np.empty((capacity, *observation_space.shape),
                                     dtype=observation_space.dtype)

        super().__init__(observations=observations,
                         actions=actions,
                         rewards=rewards,
                         masks=masks,
                         dones_float=dones_float,
                         next_observations=next_observations,
                         size=0)
        self.size = 0

        self.insert_index = 0
        self.capacity = capacity
        # Trajectory slot indices wrap modulo len(traj_info). There is one more
        # slot than there are storage positions, and every closed trajectory
        # holds at least one transition. So by the time a slot is reused, every
        # transition that referenced it has been overwritten.
        self.traj_info = np.zeros((self.capacity + 1, 2), dtype=np.int32)
        self.traj_index_map = np.zeros((self.capacity,), dtype=np.int32)
        self.traj_insert_index = 0
        self.traj_info[0, 0] = 0
        self.observation_space = observation_space
        self.action_space = action_space

    def _close_trajectory(self, end_index: int):
        '''Record end_index as the last step of the current trajectory and open the next slot.'''
        self.traj_info[self.traj_insert_index, 1] = end_index
        self.traj_insert_index = (self.traj_insert_index + 1) % len(self.traj_info)
        # The next trajectory begins at the following storage position (wrapping).
        self.traj_info[self.traj_insert_index, 0] = (end_index + 1) % self.capacity

    def initialize_with_dataset(self, dataset, num_samples=None):
        """Fill an empty replay buffer with (a subset of) a dataset.

        Args:
            dataset: object exposing per-transition arrays ``observations``,
                ``actions``, ``rewards``, ``masks``, ``dones_float`` and
                ``next_observations``.
            num_samples: number of transitions to copy. Defaults to the whole
                dataset and is clipped to the dataset size.
        """

        # Check size rather than insert_index: a full buffer that has wrapped
        # around also has insert_index == 0.
        assert self.size == 0, 'Can only initialize an empty replay buffer with a dataset.'

        dataset_size = len(dataset.observations)
        print("Initializing with a dataset of size ", dataset_size)
        if num_samples is None:
            num_samples = dataset_size
        else:
            num_samples = min(dataset_size, num_samples)
        assert self.capacity >= num_samples, 'Dataset cannot be larger than the replay buffer capacity.'

        if num_samples < dataset_size:
            # NOTE(review): a random subset does not keep trajectories contiguous,
            # so sample_seq() may splice transitions from different episodes.
            indices = np.random.permutation(dataset_size)[:num_samples]
        else:
            indices = np.arange(num_samples)

        self.observations[:num_samples] = dataset.observations[indices]
        self.actions[:num_samples] = dataset.actions[indices]
        self.rewards[:num_samples] = dataset.rewards[indices]
        self.masks[:num_samples] = dataset.masks[indices]
        self.dones_float[:num_samples] = dataset.dones_float[indices]
        self.next_observations[:num_samples] = dataset.next_observations[indices]
        # Wrap so that a dataset filling the buffer makes the next insert() overwrite slot 0.
        self.insert_index = num_samples % self.capacity
        self.size = num_samples

        # Rebuild trajectory bookkeeping from the copied transitions only; slots
        # past num_samples hold uninitialized memory.
        self.traj_info[self.traj_insert_index, 0] = 0
        for i in range(num_samples):
            self.traj_index_map[i] = self.traj_insert_index
            if self.dones_float[i] == 1.0:
                self._close_trajectory(i)

    def insert(self, observation: np.ndarray, action: np.ndarray,
               reward: float, mask: float, done_float: float,
               next_observation: np.ndarray):
        ''' Insert a new transition, overwriting the oldest one once the buffer is full. '''

        index = self.insert_index
        self.observations[index] = observation
        self.actions[index] = action
        self.rewards[index] = reward
        self.masks[index] = mask
        self.dones_float[index] = done_float
        self.next_observations[index] = next_observation

        self.traj_index_map[index] = self.traj_insert_index

        if done_float == 1.0:
            self._close_trajectory(index)

        self.insert_index = (index + 1) % self.capacity
        self.size = min(self.size + 1, self.capacity)

    def sample_seq(self, batch_size: int, seq_length: int) -> Batch:
        """Sample a batch of sequences that stay within a single trajectory.

        Args:
            batch_size: number of sequences.
            seq_length: number of time steps per sequence.
        Returns:
            Batch with fields of shape (batch_size, seq_length, ...). seq_masks is
            1 where the corresponding observation slot was filled from the
            buffer and 0 for zero padding past the trajectory end.
        Raises:
            ValueError: if the buffer is empty.
        """
        if self.size == 0:
            raise ValueError('Cannot sample from an empty replay buffer.')

        start_indexes = np.random.randint(self.size, size=batch_size).astype(np.int32)
        traj_indexes = self.traj_index_map[start_indexes]
        # Fancy indexing returns a copy, so the patch below leaves traj_info intact.
        traj_ends = self.traj_info[traj_indexes, 1]
        # The trajectory still being inserted has no recorded end yet; its last
        # step so far is the most recently written slot.
        in_progress = traj_indexes == self.traj_insert_index
        traj_ends[in_progress] = (self.insert_index - 1) % self.capacity

        # Sample one extra slot so next_observations can be read from observations[:, 1:].
        observations, actions, rewards, masks, seq_masks = numba_sample_seq(
            self.observations, self.actions, self.rewards, self.next_observations,
            self.masks, start_indexes, traj_ends, seq_length + 1,
            self.capacity
        )

        return Batch(observations=observations[:, :-1],
                     actions=actions[:, :-1],
                     next_observations=observations[:, 1:],
                     rewards=rewards[:, :-1],
                     masks=masks[:, :-1],
                     seq_masks=seq_masks[:, :-1])
