from torch.utils.data import Dataset
import torch

class Basic_dataset(Dataset):
    """Dataset that serves randomly positioned windows of a time series."""

    def __init__(self, task_params, data, data_eval=None):
        """
        Basic dataset class for time series data
        Args:
            task_params (dict): dictionary of task parameters; the keys
                'dur' (window length in time steps) and 'n_trials'
                (number of trials per epoch) are read here
            data (np.ndarray; T x dim_x): time series data
            data_eval (np.ndarray, optional): separate series stored as
                self.data_eval; when omitted, self.data_eval is the same
                tensor as self.data
        """
        super().__init__()
        self.task_params = task_params
        # torch.from_numpy shares memory with the numpy array (no copy).
        self.data = torch.from_numpy(data)
        if data_eval is not None:
            self.data_eval = torch.from_numpy(data_eval)
        else:
            self.data_eval = self.data
        self.dur = task_params['dur']
        self.n_trials = task_params['n_trials']

    def __len__(self):
        """ Return number of trials in an epoch """
        return self.n_trials

    def __getitem__(self, idx):
        """
        Return a trial of length self.dur
        Args:
            idx (int): trial index, arbitrary as trials are sampled randomly
        Returns:
            trial (torch.tensor; dim_x x self.dur): trial of length self.dur
            empty (torch.tensor; 0 x self.dur): zero-row tensor on the same
                device as the data (presumably a placeholder for external
                inputs -- confirm against the consumer)
        Raises:
            ValueError: if self.dur is longer than the time series
        """
        n_steps = self.data.shape[0]
        # Valid start indices are 0 .. n_steps - dur inclusive.
        n_starts = n_steps - self.dur + 1
        if n_starts < 1:
            raise ValueError(
                f"trial duration {self.dur} exceeds the number of "
                f"time steps {n_steps}"
            )
        # randint's `high` is exclusive, so passing n_starts makes the final
        # window [n_steps - dur, n_steps) reachable and keeps the call valid
        # when the series is exactly one window long.
        t_start = int(torch.randint(low=0, high=n_starts, size=(1,)).item())
        t_end = t_start + self.dur
        trial = self.data[t_start:t_end].T
        return trial, torch.zeros(0, self.dur, device=self.data.device)

