import h5py
import torch

from torch.utils.data import Dataset
import numpy as np

# Public API of this module: only the dataset class is exported.
__all__ = ["AdvectionDataset"]

class AdvectionDataset(Dataset):
    """Samples of a 1-D advection solution read from an HDF5 file.

    The file must contain a ``'tensor'`` dataset indexed as
    (sample, time, x) and an ``'x-coordinate'`` dataset; a
    ``'t-coordinate'`` dataset is optional.  Each item is a pair
    ``(sample, grid)`` where ``sample`` has shape (x, t, 1) and ``grid``
    has shape (x, t, 2) holding the (x, t) coordinate of every point.

    Args:
        file_name: Path of the HDF5 file.
        reduced_resolution: Stride applied to the spatial axis.
        reduced_resolution_t: Stride applied to the time axis.
        reduced_batch: Stride applied to the sample axis.
        if_test: If True keep the test split, otherwise the training split.
        test_ratio: Fraction of the (capped) samples reserved for testing;
            the test split is the tail of the sample axis.
        num_samples_max: Cap on the number of samples considered before
            splitting; values <= 0 mean "use all samples".

    Attributes:
        data: float32 tensor of shape (samples, x, t, 1).
        grid: Coordinate tensor of shape (x, t, 2), shared by all samples.
        dx: Spacing between consecutive retained x points, or None when
            fewer than two x points remain after striding.
        dt: Same for the time axis, or None when fewer than two time
            points remain (always None without a ``'t-coordinate'``).
        tmax: Last entry of the (un-strided) time coordinate, or None when
            there is at most one time point.
    """

    def __init__(
        self,
        file_name,
        reduced_resolution=1,
        reduced_resolution_t=1,
        reduced_batch=1,
        if_test=False,
        test_ratio=0.1,
        num_samples_max=-1,
    ):
        with h5py.File(file_name, 'r') as f:
            _data = np.array(f['tensor'], dtype=np.float32)
            if 't-coordinate' in f:
                # The time coordinate and the tensor's time axis may differ
                # in length; keep only their common prefix so that `data`
                # and `grid` agree along the time axis.
                nt = min(_data.shape[1], f['t-coordinate'].shape[0])
                _data = _data[:, :nt]
                t = np.array(f['t-coordinate'], dtype='f')[:nt]
            else:
                # No time coordinate: the grid gets a single t = 0 column.
                t = np.array([0], dtype='f')
            x = np.array(f['x-coordinate'], dtype='f')

        _data = _data[::reduced_batch, ::reduced_resolution_t, ::reduced_resolution]
        # (sample, t, x) -> (sample, x, t), then add a trailing channel axis.
        self.data = np.transpose(_data, (0, 2, 1))[..., None]

        x = torch.tensor(x, dtype=torch.float)
        t = torch.tensor(t, dtype=torch.float)
        X, T = torch.meshgrid((x, t), indexing='ij')
        # Build the grid at full resolution, then stride it exactly like the data.
        self.grid = torch.stack((X, T), dim=-1)[::reduced_resolution, ::reduced_resolution_t]

        self.dx = self._spacing(x, reduced_resolution)
        self.dt = self._spacing(t, reduced_resolution_t)
        self.tmax = t[-1] if t.shape[0] > 1 else None

        if num_samples_max > 0:
            num_samples_max = min(num_samples_max, self.data.shape[0])
        else:
            num_samples_max = self.data.shape[0]

        # Training split is the head [0, test_idx); test split is the tail
        # [test_idx, num_samples_max).
        test_idx = self._train_test_boundary(num_samples_max, test_ratio)
        if if_test:
            self.data = self.data[test_idx:num_samples_max]
        else:
            self.data = self.data[:test_idx]

        self.data = torch.tensor(self.data)

    @staticmethod
    def _train_test_boundary(num_samples, test_ratio):
        """Return the index separating the training head from the test tail.

        This is ``int(num_samples * (1 - test_ratio))``, but the product is
        first rounded to 6 decimals so floating-point error cannot drop a
        whole sample (e.g. ``10 * (1 - 0.9) == 0.9999999999999998``).
        """
        return int(round(num_samples * (1 - test_ratio), 6))

    @staticmethod
    def _spacing(coords, step):
        """Return ``coords[step] - coords[0]``, the spacing after striding.

        Returns None instead of raising when striding by ``step`` leaves
        fewer than two points.
        """
        strided = coords[::step]
        if strided.shape[0] < 2:
            return None
        return strided[1] - strided[0]

    def __len__(self):
        """Number of samples in the selected split."""
        return self.data.shape[0]

    def __getitem__(self, idx):
        """Return ``(sample, grid)``; the same grid is shared by every sample."""
        return self.data[idx], self.grid

