import numpy as np
import torch as t
from torch.utils.data import Dataset

from ..simulators import SIMULATOR


class ProcgenDataset(Dataset):
    """Fixed-length step windows over recorded trajectories.

    The file's per-trajectory sequences are flattened into step-indexed
    arrays. Sample ``i`` is the window that starts at flat step ``i`` and
    covers ``n_steps + 1`` entries; a window never crosses a trajectory
    boundary and is padded instead (see ``__getitem__``).
    """

    def __init__(self, file, n_steps):
        """Load the recorded trajectories and flatten them.

        Args:
            file: Passed straight to ``torch.load``. The loaded object must
                be indexable by ``"states"``, ``"actions"``, ``"rewards"``
                and ``"values"``, each an iterable of per-trajectory
                sequences.
            n_steps: Every sample holds ``n_steps + 1`` consecutive entries.
        """
        super().__init__()

        # NOTE(review): depending on its ``weights_only`` setting, torch.load
        # can unpickle arbitrary objects -- only point this at trusted files.
        data = t.load(file)

        # States are tensors joined along dim 0.
        # NOTE(review): this presumably relies on every state having a leading
        # dimension of 1 so that rows line up with the flat step index --
        # confirm against the code that writes the file.
        self.states = t.cat(
            [state for states in data["states"] for state in states],
            dim=0,
        ).numpy()
        self.actions = t.tensor(
            [action for actions in data["actions"] for action in actions],
            dtype=t.int,
        ).numpy()
        self.rewards = t.tensor(
            [reward for rewards in data["rewards"] for reward in rewards],
            dtype=t.float32,
        ).numpy()
        self.values = t.tensor(
            [value for values in data["values"] for value in values],
            dtype=t.float32,
        ).numpy()

        # Exclusive end offset of every trajectory in the flattened arrays;
        # the last one equals the total number of steps.
        self.trajectory_ends = set(np.cumsum([len(i) for i in data["actions"]]))
        # Stored as the window length (requested steps plus one).
        self.n_steps = n_steps + 1
        self.n_samples = len(self.actions)

    def __len__(self):
        """Return the number of flattened steps, one sample per step."""
        return self.n_samples

    @staticmethod
    def _repeat_last(window, count):
        """Append ``count`` copies of the last entry of ``window`` on axis 0."""
        return np.concatenate(
            (window, np.repeat(window[-1:], count, axis=0)),
            axis=0,
        )

    def __getitem__(self, idx):
        """Return the window of ``self.n_steps`` entries starting at ``idx``.

        If the trajectory containing ``idx`` ends before the window is full,
        the window is padded: states, rewards and values repeat their last
        real entry, and actions are filled with random integers drawn from
        ``[0, SIMULATOR.n_actions)``.

        Args:
            idx: Flat step index. Negative values count from the end.

        Returns:
            Tuple ``(states, actions, rewards, values)`` of NumPy arrays cast
            to ``float32``, ``int``, ``float32`` and ``float32``.

        Raises:
            IndexError: If ``idx`` is outside ``[-len(self), len(self))``.
        """
        # Raising IndexError is what terminates ``for sample in dataset``:
        # without an ``__iter__`` Python falls back to calling __getitem__
        # with 0, 1, 2, ... until this exception is raised.
        if not -self.n_samples <= idx < self.n_samples:
            raise IndexError(
                f"index {idx} is out of range for dataset of size {self.n_samples}"
            )
        if idx < 0:
            idx += self.n_samples

        # Grow the window until it is full or the trajectory ends.
        stop = idx + self.n_steps
        end = idx + 1
        while end < stop and end not in self.trajectory_ends:
            end += 1

        states = self.states[idx:end]
        actions = self.actions[idx:end]
        rewards = self.rewards[idx:end]
        values = self.values[idx:end]

        missing = stop - end
        if missing > 0:
            states = self._repeat_last(states, missing)
            actions = np.concatenate(
                (actions, np.random.randint(SIMULATOR.n_actions, size=missing)),
                axis=0,
            )
            rewards = self._repeat_last(rewards, missing)
            values = self._repeat_last(values, missing)

        return (
            states.astype(np.float32),
            actions.astype(int),
            rewards.astype(np.float32),
            values.astype(np.float32),
        )
