import collections
from typing import Tuple, Union

import numpy as np

# Immutable record holding one sampled minibatch; every field is looked up
# by name (e.g. ``batch.rewards``). Field order matters for positional
# construction and unpacking, so it must stay as listed here.
Batch = collections.namedtuple(
    'Batch',
    'observations actions rewards masks next_observations mc_returns',
)




def split_into_trajectories(observations, actions, rewards, masks, dones_float,
                            next_observations, mc_returns=None):
    '''
    Group a flat, time-ordered sequence of transitions into trajectories.

    A trajectory ends at every index whose ``dones_float`` entry equals 1.0.
    Each transition is stored as the tuple
    ``(observation, action, reward, mask, done, next_observation)``, with the
    matching ``mc_returns`` entry appended as a seventh field when
    ``mc_returns`` is given.

    Returns:
        A list of trajectories, each a list of transition tuples. The list
        always holds at least one trajectory, so empty input yields ``[[]]``.
    '''
    with_mc_returns = mc_returns is not None
    num_transitions = len(observations)

    finished = []
    current = []
    for idx in range(num_transitions):
        transition = (observations[idx], actions[idx], rewards[idx],
                      masks[idx], dones_float[idx], next_observations[idx])
        if with_mc_returns:
            transition = transition + (mc_returns[idx],)
        current.append(transition)

        # Only open a new trajectory when more data follows, so a terminal
        # flag on the very last transition does not leave a trailing empty
        # trajectory behind.
        is_last = idx + 1 >= num_transitions
        if dones_float[idx] == 1.0 and not is_last:
            finished.append(current)
            current = []

    finished.append(current)
    return finished


def merge_trajectories(trajs):
    '''
    Flatten a list of trajectories back into stacked per-field arrays.

    Every transition must be a 7-field tuple
    ``(observation, action, reward, mask, done, next_observation, mc_return)``.

    Returns:
        A 7-tuple of arrays in that same field order, each built with
        ``np.stack`` over all transitions in trajectory order.
    '''
    # One accumulator list per transition field, in tuple order.
    columns = [[] for _ in range(7)]

    for traj in trajs:
        # Unpacking by name enforces that each transition has seven fields.
        for obs, act, rew, mask, done, next_obs, mc_return in traj:
            fields = (obs, act, rew, mask, done, next_obs, mc_return)
            for column, value in zip(columns, fields):
                column.append(value)

    return tuple(np.stack(column) for column in columns)


class Dataset(object):
    '''
    In-memory set of transitions stored as parallel arrays.

    Index ``i`` of every array describes the same transition. An entry of 1.0
    in ``dones_float`` marks the last transition of a trajectory; this is the
    boundary the module-level ``split_into_trajectories`` splits on.
    '''

    def __init__(self, observations: np.ndarray, actions: np.ndarray,
                 rewards: np.ndarray, masks: np.ndarray,
                 dones_float: np.ndarray, next_observations: np.ndarray,
                 size: int, mc_returns=None, discount: float = 0.99):
        '''
        Store the transition arrays and make sure ``mc_returns`` exists.

        Args:
            observations: Per-transition observations.
            actions: Per-transition actions.
            rewards: Per-transition rewards.
            masks: Per-transition masks (passed through to sampled batches).
            dones_float: Per-transition flags; 1.0 ends a trajectory.
            next_observations: Per-transition successor observations.
            size: Number of transitions; upper bound for sampled indices.
            mc_returns: Precomputed per-transition Monte-Carlo returns. When
                None they are computed from ``rewards`` and ``dones_float``.
            discount: Discount factor used only when ``mc_returns`` is None.
        '''
        self.observations = observations
        self.actions = actions
        self.rewards = rewards
        self.masks = masks
        self.dones_float = dones_float
        self.next_observations = next_observations
        self.size = size

        # init monte-carlo returns
        if mc_returns is None:
            self.get_monte_carlo_returns_per_transition(discount)
        else:
            self.mc_returns = mc_returns

    def sample(self, batch_size: int) -> Batch:
        '''
        Draw ``batch_size`` transitions uniformly at random, with replacement.

        Args:
            batch_size: Number of transitions to draw.
        Returns:
            A Batch whose fields are the stored arrays indexed by the same
            random indices.
        '''
        indx = np.random.randint(self.size, size=batch_size)
        return Batch(observations=self.observations[indx],
                     actions=self.actions[indx],
                     rewards=self.rewards[indx],
                     masks=self.masks[indx],
                     next_observations=self.next_observations[indx],
                     mc_returns=self.mc_returns[indx])

    def split_into_trajectories(self):
        '''
        Split the stored transitions into trajectories.

        Returns:
            A list of trajectories; each transition is the 7-field tuple
            (observation, action, reward, mask, done, next_observation,
            mc_return).
        '''
        return split_into_trajectories(self.observations, self.actions,
                                       self.rewards, self.masks,
                                       self.dones_float,
                                       self.next_observations,
                                       self.mc_returns)

    def merge_trajectories(self, trajs):
        '''Stack a list of 7-field trajectories back into per-field arrays.'''
        return merge_trajectories(trajs)

    @staticmethod
    def _episode_return(traj):
        '''Return the undiscounted sum of the rewards of one trajectory.'''
        # The reward is the third field of every transition tuple; indexing
        # (instead of unpacking) keeps this independent of how many trailing
        # fields the tuple carries.
        return sum(transition[2] for transition in traj)

    def get_initial_states(
        self,
        and_action: bool = False
        ) -> Union[np.ndarray, Tuple[np.ndarray, np.ndarray]]:
        '''
        Get the initial states of the trajectories in the dataset.

        The trajectories are ordered by ascending undiscounted return before
        their first transitions are collected.

        Args:
            and_action: If True, return the initial actions as well.
        Returns:
            The stacked first observations, or a ``(states, actions)`` tuple
            when ``and_action`` is True.
        '''
        trajs = self.split_into_trajectories()
        trajs.sort(key=self._episode_return)

        states = np.stack([traj[0][0] for traj in trajs], 0)
        if not and_action:
            return states

        actions = np.stack([traj[0][1] for traj in trajs], 0)
        return states, actions

    def get_monte_carlo_returns_per_transition(self, discount) -> np.ndarray:
        '''
        Compute the monte-carlo returns per transition.

        The result is also stored on ``self.mc_returns``, replacing any
        previous value.

        Args:
            discount: The discount factor.
        Returns:
            The monte-carlo returns per transition as a numpy array.
        '''
        # Called from __init__ before self.mc_returns exists, so the split
        # must not include mc_returns here.
        trajs = split_into_trajectories(self.observations, self.actions,
                                        self.rewards, self.masks,
                                        self.dones_float,
                                        self.next_observations)

        mc_returns_per_transition = []
        for traj in trajs:
            per_episode = []
            # Returns are accumulated within one trajectory only; the value
            # following its last transition is taken to be zero.
            mc_return = 0.0
            for transition in reversed(traj):
                mc_return = transition[2] + (discount * mc_return)
                per_episode.append(mc_return)
            # Accumulated back-to-front; restore chronological order.
            per_episode.reverse()
            mc_returns_per_transition.append(per_episode)

        self.mc_returns = np.hstack(mc_returns_per_transition)
        return self.mc_returns

    def take_frac(self, percentage: float = 100.0, selection_criteria: str = 'random'):
        '''
        Keep only a fraction of the trajectories, modifying the dataset in place.

        At least one trajectory is always kept.

        Args:
            percentage: percentage of the trajectories to keep, in (0, 100].
            selection_criteria: 'random' keeps a random subset, 'top' keeps
                the trajectories with the highest undiscounted return. Any
                other value keeps the last trajectories in dataset order.
        '''
        assert 0.0 < percentage <= 100.0, 'percentage must be in (0, 100]'

        trajs = self.split_into_trajectories()

        if selection_criteria == 'random':
            np.random.shuffle(trajs)
        elif selection_criteria == 'top':
            # Ascending sort: the best trajectories end up at the tail, which
            # is the part retained below.
            trajs.sort(key=self._episode_return)

        N = int(len(trajs) * percentage / 100)
        N = max(1, N)

        trajs = trajs[-N:]

        (self.observations, self.actions, self.rewards, self.masks,
         self.dones_float, self.next_observations,
         self.mc_returns) = merge_trajectories(trajs)

        self.size = len(self.observations)

    def get_train_validation_split(self, train_fraction: float = 0.8
                               ) -> Tuple['Dataset', 'Dataset']:
        '''
        Split the dataset into a training and validation set.

        The split is made over whole, randomly shuffled trajectories, so no
        trajectory is shared between the two sets.

        Args:
            train_fraction: The fraction of the trajectories to use for
                training.
        Returns:
            A Dataset object of the training and validation set.
        Raises:
            ValueError: If either side of the split would contain no
                trajectories.
        '''
        trajs = self.split_into_trajectories()
        print(f'Number of trajectories: {len(trajs)}')
        train_size = int(train_fraction * len(trajs))

        traj_lengths = [len(traj) for traj in trajs]
        print(f'Mean length of trajectories: {np.mean(traj_lengths)} and Std is: {np.std(traj_lengths)}.')

        np.random.shuffle(trajs)

        print(f'Number of trajectories for training: {train_size} and validation: {len(trajs) - train_size}')

        # Merging an empty list of trajectories fails inside np.stack with an
        # unhelpful message, so reject a degenerate split explicitly.
        if train_size <= 0 or train_size >= len(trajs):
            raise ValueError(
                'train_fraction={} leaves {} training and {} validation '
                'trajectories; both sets need at least one trajectory.'.format(
                    train_fraction, train_size, len(trajs) - train_size))

        datasets = []
        for subset in (trajs[:train_size], trajs[train_size:]):
            (observations, actions, rewards, masks, dones_float,
             next_observations, mc_returns) = merge_trajectories(subset)
            datasets.append(Dataset(observations,
                                    actions,
                                    rewards,
                                    masks,
                                    dones_float,
                                    next_observations,
                                    size=len(observations),
                                    mc_returns=mc_returns))

        train_dataset, valid_dataset = datasets
        return train_dataset, valid_dataset