import numpy as np
import torch as t
from torch.utils.data import Dataset

from ..simulators import SIMULATOR


class NavigationDataset(Dataset):
    """Fixed-length windows over concatenated trajectories loaded from disk.

    Each item starts at a flat step index and spans ``n_steps + 1`` consecutive
    entries, truncated at the boundary of the trajectory the start index falls
    in and then padded back out to full length.

    Args:
        file: Anything accepted by ``torch.load``. The loaded object must
            support ``["states"]``, ``["actions"]``, ``["rewards"]`` and
            ``["values"]``; each value is a sequence of array-likes that are
            concatenated along axis 0. The per-element lengths of
            ``["actions"]`` define the trajectory boundaries (presumably one
            element per trajectory — verify against the data producer).
        n_steps: Number of steps after the start index included in a window.
    """

    def __init__(self, file, n_steps):
        super().__init__()

        # NOTE(review): torch.load may unpickle arbitrary objects depending on
        # the torch version / weights_only setting; only load trusted files.
        loaded = t.load(file)

        # Flatten the per-chunk arrays into one array per quantity, forcing dtypes.
        self.states = np.concatenate(loaded["states"], axis=0, dtype=bool)
        self.actions = np.concatenate(loaded["actions"], axis=0, dtype=int)
        self.rewards = np.concatenate(loaded["rewards"], axis=0, dtype=int)
        self.values = np.concatenate(loaded["values"], axis=0, dtype=int)

        # Exclusive flat end index of every chunk (running total of lengths).
        chunk_lengths = [len(chunk) for chunk in loaded["actions"]]
        self.trajectory_ends = set(np.cumsum(chunk_lengths))

        # Stored as the full window length: the start step plus n_steps more.
        self.n_steps = n_steps + 1
        self.n_samples = len(self.actions)

    def __len__(self):
        """Return the number of flat steps (one possible window start each)."""
        return self.n_samples

    def __getitem__(self, idx):
        """Return ``(states, actions, rewards, values)`` for the window at ``idx``.

        For ``idx`` in ``range(len(self))``, the window holds ``self.n_steps``
        entries starting at ``idx``. It stops early at the first trajectory end
        after ``idx``. The missing tail is then filled as follows:

        - states, rewards and values repeat their last in-trajectory entry;
        - actions are drawn with ``np.random.randint(SIMULATOR.n_actions)``.

        States, rewards and values are returned as ``float32``; actions as
        ``int``.
        """
        limit = idx + self.n_steps

        # Advance the exclusive stop index until the window is full or a
        # trajectory boundary is reached.
        stop = idx + 1
        while stop < limit:
            if stop in self.trajectory_ends:
                break
            stop += 1

        deficit = limit - stop

        def extend(column, padding):
            # Append `padding` (exactly `deficit` rows) to the in-trajectory slice.
            return np.concatenate((column[idx:stop], padding), axis=0)

        if deficit <= 0:
            states, actions, rewards, values = (
                column[idx:stop]
                for column in (self.states, self.actions, self.rewards, self.values)
            )
        else:
            states = extend(self.states, [self.states[stop - 1]] * deficit)
            actions = extend(self.actions, np.random.randint(SIMULATOR.n_actions, size=deficit))
            rewards = extend(self.rewards, [self.rewards[stop - 1]] * deficit)
            values = extend(self.values, [self.values[stop - 1]] * deficit)

        as_float = np.float32
        return (
            states.astype(as_float),
            actions.astype(int),
            rewards.astype(as_float),
            values.astype(as_float),
        )
