import os

import h5py
import torch
from torch.utils.data import Dataset

from ._base import register_dataset


@register_dataset('rd')
class ReactionDiffusionDataset(Dataset):
    """Map-style dataset that reads reaction-diffusion solutions from HDF5.

    The HDF5 file must contain a 3-D dataset named ``'solution'``. Its shape is
    unpacked as ``(n_data, nx, nt)``; presumably the axes are simulations x
    spatial points x time steps (TODO confirm against the data generator).
    Samples are read on demand, one slice along the first axis per
    ``__getitem__`` call, so the whole array is never loaded into memory.

    Args:
        root: Directory that contains ``data_file``.
        split: Split name. It is stored as ``self.split`` but does NOT select
            a subset: every sample in ``data_file`` is served.
            NOTE(review): presumably each split lives in its own file —
            confirm with callers.
        data_file: HDF5 file name, relative to ``root``.

    Raises:
        KeyError: If the file has no ``'solution'`` dataset.
        ValueError: If ``'solution'`` is not 3-dimensional.

    NOTE(review): the open h5py handle is inherited by forked DataLoader
    workers and cannot be pickled for spawn-based workers. If
    ``num_workers > 0`` misbehaves, consider reopening the file lazily in
    each worker.
    """

    def __init__(self, root, split, data_file):
        self.root = root
        self.split = split
        self.data_file = data_file
        # The file stays open for the dataset's lifetime so samples can be
        # read lazily; it is released by close() / __del__.
        self.file = h5py.File(os.path.join(root, data_file), 'r')

        try:
            self.data = self.file['solution']
            if len(self.data.shape) != 3:
                raise ValueError(
                    f"expected 'solution' to be 3-D (n_data, nx, nt), "
                    f"got shape {self.data.shape}"
                )
        except Exception:
            # Release the handle now instead of waiting for GC, then re-raise.
            self.file.close()
            raise

        self.n_data, self.nx, self.nt = self.data.shape

    def close(self):
        """Close the HDF5 file.

        Safe to call more than once, and safe on an instance whose
        ``__init__`` failed before ``self.file`` was assigned.
        """
        file = getattr(self, 'file', None)
        if file is not None:
            # h5py's File.close() is a no-op on an already-closed file.
            file.close()

    def __del__(self):
        self.close()

    def __len__(self):
        """Return the number of samples (size of the first axis)."""
        return self.n_data

    def __getitem__(self, index):
        """Return sample ``index`` as a float32 tensor of shape ``(nx, nt)``."""
        # Indexing the h5py dataset reads only this slice from disk as a
        # NumPy array; .float() casts it to torch.float32.
        data = torch.from_numpy(self.data[index]).float()
        return data
