import collections
from typing import Tuple, Union

import numpy as np
from tqdm import tqdm

# Immutable record for one sampled minibatch: each field holds the rows that
# were selected from the matching per-transition array. ``dones_float`` is not
# carried in a batch.
Batch = collections.namedtuple(
    "Batch", "observations actions rewards masks next_observations"
)


def split_into_trajectories(
    observations, actions, rewards, masks, dones_float, next_observations
):
    """Group flat, index-aligned transition arrays into per-episode lists.

    Each transition is the 6-tuple
    ``(observation, action, reward, mask, done, next_observation)`` read at the
    same index from every input. After any index whose ``dones_float`` entry
    equals 1.0, a new trajectory is started, except after the final index. A
    trailing run without a terminal flag therefore stays in the last list.

    The number of steps is taken from ``len(observations)``. Iteration is
    wrapped in a tqdm progress bar.

    Returns:
        A list of lists of 6-tuples. It always holds at least one list, which
        is empty when ``observations`` is empty.
    """
    n_steps = len(observations)
    current = []
    trajectories = [current]

    for t in tqdm(range(n_steps)):
        step = (
            observations[t],
            actions[t],
            rewards[t],
            masks[t],
            dones_float[t],
            next_observations[t],
        )
        current.append(step)

        # Only open a fresh trajectory when there is another step to put in it.
        ends_episode = dones_float[t] == 1.0
        has_more = t + 1 < n_steps
        if ends_episode and has_more:
            current = []
            trajectories.append(current)

    return trajectories


def merge_trajectories(trajs):
    """Concatenate trajectories back into six stacked, index-aligned arrays.

    This undoes ``split_into_trajectories``. Every trajectory is a sequence of
    ``(observation, action, reward, mask, done, next_observation)`` tuples, and
    the transitions of all trajectories are laid out in their original order.

    Returns:
        ``(observations, actions, rewards, masks, dones_float,
        next_observations)``, each built with ``np.stack``.

    Raises:
        ValueError: if ``trajs`` contains no transitions (``np.stack`` rejects
            an empty sequence), or if a transition does not unpack into
            exactly six values.
    """
    # Unpacking each step enforces the 6-field shape before anything is stacked.
    flat = [
        (obs, act, rew, mask, done, next_obs)
        for traj in trajs
        for obs, act, rew, mask, done, next_obs in traj
    ]
    # Stack column by column, in field order.
    return tuple(np.stack([row[col] for row in flat]) for col in range(6))


class Dataset(object):
    """Flat, in-memory collection of transitions with trajectory-level helpers.

    All per-transition arrays are indexed along their first axis. ``size``
    bounds the row indices drawn by ``sample`` / ``sample_multiple``.

    Trajectory-level methods recover episodes with ``split_into_trajectories``,
    where an episode ends at a ``dones_float`` entry equal to 1.0.
    ``take_top`` and ``take_random`` rewrite the stored arrays in place.
    """

    def __init__(
        self,
        observations: np.ndarray,
        actions: np.ndarray,
        rewards: np.ndarray,
        masks: np.ndarray,
        dones_float: np.ndarray,
        next_observations: np.ndarray,
        size: int,
    ):
        self.observations = observations
        self.actions = actions
        self.rewards = rewards
        self.masks = masks
        self.dones_float = dones_float
        self.next_observations = next_observations
        # Upper bound (exclusive) for the row indices used by sampling.
        self.size = size

    def sample(self, batch_size: int) -> Batch:
        """Draw ``batch_size`` row indices uniformly from ``[0, size)`` (with
        replacement, using NumPy's global RNG) and return those rows as a Batch.
        """
        indx = np.random.randint(self.size, size=batch_size)
        return self._batch_at(indx)

    def sample_multiple(self, batch_size: int, n_batches: int) -> Tuple[Batch, ...]:
        """Return ``n_batches`` independent batches of ``batch_size`` rows each.

        All indices are drawn in a single RNG call and then cut into
        consecutive chunks, one chunk per batch.
        """
        indx = np.random.randint(self.size, size=batch_size * n_batches)
        batches = []
        for i in range(n_batches):
            # Slice once per batch instead of once per field.
            chunk = indx[i * batch_size : (i + 1) * batch_size]
            batches.append(self._batch_at(chunk))
        return tuple(batches)

    def get_initial_states(
        self, and_action: bool = False
    ) -> Union[np.ndarray, Tuple[np.ndarray, np.ndarray]]:
        """Return the first observation of every trajectory, stacked on axis 0.

        Trajectories are ordered by ascending undiscounted episode return
        before their first steps are collected.

        Args:
            and_action: when True, also return the stacked first actions.

        Returns:
            ``states``, or ``(states, actions)`` when ``and_action`` is True.

        Raises:
            IndexError: if a trajectory is empty. This happens for an empty
                dataset, where ``split_into_trajectories`` yields one empty list.
        """
        trajs = self._trajectories()
        trajs.sort(key=self._episode_return)

        states = np.stack([traj[0][0] for traj in trajs], 0)
        if not and_action:
            return states
        actions = np.stack([traj[0][1] for traj in trajs], 0)
        return states, actions

    def get_monte_carlo_returns(self, discount) -> np.ndarray:
        """Return one discounted return per trajectory, in trajectory order.

        For a trajectory with rewards ``r_0, r_1, ...`` the value is
        ``sum_i r_i * discount**i``, measured from the trajectory's first step.
        ``masks`` are not consulted.
        """
        mc_returns = []
        for traj in self._trajectories():
            mc_return = 0.0
            for i, (_, _, reward, _, _, _) in enumerate(traj):
                mc_return += reward * (discount**i)
            mc_returns.append(mc_return)

        return np.asarray(mc_returns)

    def take_top(self, percentile: float = 100.0):
        """Keep only the highest-return ``percentile`` % of trajectories, in place.

        Trajectories are ranked by undiscounted return. At least one trajectory
        is always kept, and ``size`` is reset to the new number of transitions.

        Raises:
            ValueError: if ``percentile`` is not in ``(0, 100]``.
        """
        # Validate before the (progress-bar emitting) split.
        self._check_percent(percentile, "percentile")

        trajs = self._trajectories()
        trajs.sort(key=self._episode_return)

        n_keep = max(1, int(len(trajs) * percentile / 100))
        # The sort is ascending, so the best trajectories are at the end.
        self._replace_with(trajs[-n_keep:])

    def take_random(self, percentage: float = 100.0):
        """Keep a random ``percentage`` % of trajectories, in place.

        Trajectories are shuffled with NumPy's global RNG. At least one
        trajectory is always kept, and ``size`` is reset to the new number of
        transitions.

        Raises:
            ValueError: if ``percentage`` is not in ``(0, 100]``.
        """
        self._check_percent(percentage, "percentage")

        trajs = self._trajectories()
        np.random.shuffle(trajs)

        n_keep = max(1, int(len(trajs) * percentage / 100))
        self._replace_with(trajs[-n_keep:])

    # ------------------------------------------------------------------
    # Private helpers
    # ------------------------------------------------------------------

    def _batch_at(self, indx) -> Batch:
        """Gather the rows at ``indx`` from every sampled array into a Batch."""
        return Batch(
            observations=self.observations[indx],
            actions=self.actions[indx],
            rewards=self.rewards[indx],
            masks=self.masks[indx],
            next_observations=self.next_observations[indx],
        )

    def _trajectories(self):
        """Split the stored transitions into per-episode lists of 6-tuples."""
        return split_into_trajectories(
            self.observations,
            self.actions,
            self.rewards,
            self.masks,
            self.dones_float,
            self.next_observations,
        )

    @staticmethod
    def _episode_return(traj):
        """Return the undiscounted sum of rewards of one trajectory (0 if empty)."""
        episode_return = 0
        for _, _, rew, _, _, _ in traj:
            episode_return += rew
        return episode_return

    @staticmethod
    def _check_percent(value: float, name: str) -> None:
        """Raise ValueError unless ``0 < value <= 100``.

        An explicit check is used instead of ``assert``, because asserts are
        stripped when Python runs with ``-O``.
        """
        if not 0.0 < value <= 100.0:
            raise ValueError(f"{name} must be in (0, 100], got {value!r}")

    def _replace_with(self, trajs) -> None:
        """Overwrite the stored arrays with the concatenation of ``trajs``."""
        (
            self.observations,
            self.actions,
            self.rewards,
            self.masks,
            self.dones_float,
            self.next_observations,
        ) = merge_trajectories(trajs)

        self.size = len(self.observations)
