import torch
import numpy as np
from dataloaders.sequential_dataset import SequentialDataSet

class DummyDataset(SequentialDataSet):
    """Random-valued stand-in for a sequential dataset, e.g. for smoke tests.

    All sample values and grid coordinates come from ``torch.rand`` (uniform
    on [0, 1)), so only the tensor shapes are meaningful.
    """

    def __init__(self,
                 shape,  # Per-sample shape; must match the data dimensions you expect downstream
                 num_samples=100,  # Number of samples in the dataset
                 if_test=False,
    ):
        """
        Initializes a dummy dataset.

        :param shape: per-sample shape (a sequence of ints). The stored data
            tensor has shape ``(num_samples, *shape)``, so ``shape`` must NOT
            include the sample/batch dimension. Presumably laid out as
            (spatial dimensions..., time, channels) -- TODO confirm against
            the consumer of this dataset.
        :param num_samples: total number of samples in the dataset
        :param if_test: accepted for signature compatibility only; it has no
            effect on the generated data.
        """
        # NOTE(review): SequentialDataSet.__init__ is never called here;
        # confirm the base class needs no initialization.
        self.data = torch.rand(num_samples, *shape)  # (num_samples, *shape)
        # One coordinate per point along the first per-sample axis, giving a
        # (shape[0], 1) tensor; presumably a 1-D spatial grid -- confirm.
        self.grid = torch.rand(shape[0], 1)

    def __len__(self) -> int:
        """Return the number of samples (``num_samples`` at construction)."""
        return len(self.data)

    def __getitem__(self, idx):
        """
        Return ``(sample, sample, grid)`` for index ``idx``.

        The first element (the initial-condition slot) and the second (the
        full-data target) are both ``self.data[idx]``, i.e. the complete
        stored sample; no time slicing is applied to the first element.
        ``grid`` is the single shared ``(shape[0], 1)`` tensor, so it is the
        same object for every index and in-place edits affect all samples.
        """
        return self.data[idx], self.data[idx], self.grid

    def input_shape(self):
        """
        Return the per-sample shape (``self.data.shape[1:]``), which equals
        the ``shape`` passed to the constructor, as a ``torch.Size``.
        """
        return self.data.shape[1:]
