from __future__ import annotations
import numpy as np
import torch
from torch.utils.data import Dataset
from ..gridworld.env import LongCorridor
from ..gridworld.agent import sample_trajectory

class SeqPredDataset(Dataset):
    """Windowed sequence-prediction examples built from sampled rollouts.

    Each example pairs a starting observation ``O[t]`` with ``k`` consecutive
    feature rows (``A`` and ``HDF`` rows concatenated along axis 1) and the
    ``k`` following observations ``O[t+1 : t+k+1]`` as prediction targets.

    Args:
        env: Environment handed to ``sample_trajectory``.
        n_traj: Number of trajectories to sample; must be >= 1.
        T: Horizon handed to ``sample_trajectory``.
        k: Window length (feature steps per example == predicted
            observations per example); must be >= 1.
        seed: Base seed; trajectory ``i`` is sampled with ``seed + i``.

    Raises:
        ValueError: If ``n_traj < 1`` or ``k < 1``.

    Attributes:
        trajs: The raw ``(O, A, HDF, P)`` tuples from ``sample_trajectory``.
        examples: List of ``(o0, feats, targets)`` numpy tuples.
        obs_dim: Observation width, ``O.shape[1]`` of the first trajectory.
        feat_dim: Feature width, ``A.shape[1] + HDF.shape[1]`` of the first
            trajectory.
    """

    def __init__(self, env: LongCorridor, n_traj: int = 200, T: int = 50, k: int = 3, seed: int = 123):
        if n_traj < 1:
            # obs_dim / feat_dim are read from the first trajectory below.
            raise ValueError(f"n_traj must be >= 1, got {n_traj}")
        if k < 1:
            raise ValueError(f"k must be >= 1, got {k}")
        self.env = env
        self.k = k
        self.trajs = [sample_trajectory(env, T=T, seed=seed + i) for i in range(n_traj)]
        self.examples = []
        for (O, A, HDF, _P) in self.trajs:
            self.examples.extend(self._build_windows(O, A, HDF, k))
        O0, A0, HDF0, _ = self.trajs[0]
        self.obs_dim = O0.shape[1]
        # Derived from the data instead of hard-coding 10, so it stays in sync
        # with the concatenated feature rows built in _build_windows.
        self.feat_dim = A0.shape[1] + HDF0.shape[1]

    @staticmethod
    def _build_windows(O, A, HDF, k):
        """Return ``(O[t], feats, O[t+1:t+k+1])`` for every full window of one trajectory.

        ``feats`` is ``A[t:t+k]`` and ``HDF[t:t+k]`` concatenated along axis 1.
        The window count is bounded by every array's length so no slice is
        silently truncated by numpy: targets need ``t + k < len(O)``, and the
        feature rows need ``t + k <= len(A)`` and ``t + k <= len(HDF)``.
        """
        n_windows = min(O.shape[0] - k, A.shape[0] - k + 1, HDF.shape[0] - k + 1)
        windows = []
        for t in range(n_windows):  # empty when n_windows <= 0
            feats = np.concatenate([A[t:t + k], HDF[t:t + k]], axis=1)  # [k, A_dim + HDF_dim]
            windows.append((O[t], feats, O[t + 1:t + k + 1]))
        return windows

    def __len__(self):
        """Number of windows across all trajectories."""
        return len(self.examples)

    def __getitem__(self, idx: int):
        """Return ``(o0, feats, targets)`` as float32 tensors.

        Shapes: ``o0`` is ``O[t]``, ``feats`` is ``[k, feat_dim]`` and
        ``targets`` is ``O[t+1:t+k+1]``.
        """
        o0, feats, targets = self.examples[idx]
        return (
            torch.from_numpy(o0).float(),
            torch.from_numpy(feats).float(),
            torch.from_numpy(targets).float(),
        )
