import numpy as np
from torch.utils.data import Dataset
import pickle
import h5py
from PDEs import PDE

def get_dataset(args):
    """Build the train, validation and test datasets for one PDE.

    ``args`` is indexed with the keys ``'pde'`` (used to format the file
    names under ``data/``), ``'train_samples'`` (passed as ``n_data`` to all
    three datasets) and ``'noise'`` (passed to the training dataset only).

    The time/space resolution handed to each dataset is read from the
    unpickled PDE object (``nt_effective`` and ``nx``).

    Returns:
        A ``(train_dataset, valid_dataset, test_dataset)`` tuple of
        ``PDEDataset`` instances.
    """
    pde_name = args['pde']
    h5_files = {
        'train': f"data/{pde_name}_train_1024_default.h5",
        'valid': f"data/{pde_name}_valid_1024_default.h5",
        'test': f"data/{pde_name}_test_4096_default.h5",
    }

    # NOTE(review): pickle.load can execute arbitrary code; only point this
    # at PDE description files produced by a trusted source.
    with open(f"data/{pde_name}_default.pkl", 'rb') as handle:
        pde = pickle.load(handle)

    # Keyword arguments common to all three splits.
    # NOTE(review): the validation and test splits also use 'train_samples'
    # as their sample count -- confirm this is intended.
    shared = dict(
        nt=pde.nt_effective,
        nx=pde.nx,
        n_data=args['train_samples'],
        pde=pde,
    )
    train_noise = args['noise']

    train_dataset = PDEDataset(path=h5_files['train'], mode='train', noise=train_noise, **shared)
    valid_dataset = PDEDataset(path=h5_files['valid'], mode='valid', **shared)
    test_dataset = PDEDataset(path=h5_files['test'], mode='test', **shared)
    return train_dataset, valid_dataset, test_dataset

class PDEDataset(Dataset):
    """Point-wise samples of PDE trajectories with finite-difference derivatives.

    Trajectories are read from the HDF5 dataset ``f[mode]['pde_{nt}-{nx}']``,
    optionally perturbed with multiplicative Gaussian noise, and flattened
    into one sample per interior grid point.  Each sample is a dict with the
    keys ``'t'``, ``'x'``, ``'u'``, ``'dudt'``, ``'dudx'``, ``'dudxdx'``,
    ``'dudxdxdx'`` and ``'dudxdxdxdx'`` (all ``float32``).
    """

    def __init__(self, path: str, mode: str, nt: int, nx: int, pde: "PDE | None" = None, n_data: int = -1, noise: float = 0.0):
        """Load trajectories from ``path`` and precompute all sample fields.

        Args:
            path: HDF5 file to read.
            mode: Name of the top-level HDF5 group to read from.
            nt: Time resolution used to form the dataset name ``pde_{nt}-{nx}``.
            nx: Space resolution used to form the dataset name.
            pde: Object providing ``tmax``, ``dt``, ``dx``, ``nt_effective``
                and ``nx``.  Falls back to ``PDE()`` when ``None``
                (assumes the default-constructed PDE exposes those
                attributes -- TODO confirm).
            n_data: Number of trajectories to draw; ``-1`` draws as many as
                the file holds.  Trajectories are drawn uniformly at random
                *with replacement*, so duplicates are possible.
            noise: Standard deviation of the multiplicative noise factor
                ``1 + noise * N(0, 1)`` applied to every value of ``u``.

        Raises:
            ValueError: If the loaded trajectories do not match the
                ``(nt_effective, nx)`` grid described by ``pde``.
        """
        super().__init__()
        self.mode = mode
        self.dataset = f'pde_{nt}-{nx}'
        self.pde = PDE() if pde is None else pde
        # Use the resolved object from here on so the ``pde=None`` default works.
        pde = self.pde

        # The context manager closes the file even if a lookup or read fails.
        with h5py.File(path, 'r') as f:
            stored = f[self.mode][self.dataset]
            dataset_size = stored.shape[0]
            if n_data == -1:
                n_data = dataset_size
            ind = np.random.randint(dataset_size, size=n_data)
            # Read only the dataset that is used; fancy indexing returns a
            # private copy that is safe to modify in place below.
            u = np.array(stored)[ind]

        t = np.linspace(pde.tmax - (pde.nt_effective - 1) * pde.dt, pde.tmax, pde.nt_effective)
        x = np.linspace(0.0, (pde.nx - 1) * pde.dx, pde.nx)

        # The noise factor is drawn even when ``noise == 0`` so that the
        # global NumPy RNG advances identically regardless of the noise level.
        u *= (1 + np.random.randn(*u.shape) * noise)

        self.data = self._build_samples(u, t, x, pde.dt, pde.dx)

    @staticmethod
    def _build_samples(u, t, x, dt, dx):
        """Flatten ``u`` into per-point fields with central-difference derivatives.

        ``u`` is indexed as ``(trajectory, time, space)``.  Only points where
        every stencil fits are kept: time indices ``1..-2`` (first-order
        stencil in time) and space indices ``2..-3`` (the third and fourth
        space derivatives need two neighbours on each side).

        Returns:
            Dict mapping field name to a flat ``float32`` array; all arrays
            have the same length.
        """
        if u.ndim != 3 or tuple(u.shape[1:]) != (t.size, x.size):
            raise ValueError(
                f"trajectories of shape {u.shape} do not match the "
                f"(nt, nx) = ({t.size}, {x.size}) grid"
            )
        n_traj, n_t, n_x = u.shape

        dudt = (u[:, 2:, :] - u[:, :-2, :]) / (2 * dt)
        dudx = (u[:, :, 2:] - u[:, :, :-2]) / (2 * dx)
        dudxdx = (u[:, :, 2:] - 2 * u[:, :, 1:-1] + u[:, :, :-2]) / dx / dx
        dudxdxdx = (u[:, :, 4:] - 2 * u[:, :, 3:-1] + 2 * u[:, :, 1:-3] - u[:, :, :-4]) / 2 / dx / dx / dx
        dudxdxdxdx = (u[:, :, 4:] - 4 * u[:, :, 3:-1] + 6 * u[:, :, 2:-2] - 4 * u[:, :, 1:-3] + u[:, :, :-4]) / dx / dx / dx / dx

        # Each derivative array is already shortened by its own stencil, so
        # the slices below trim every field to the common interior region.
        fields = {
            't': np.tile(t[None, 1:-1, None], (n_traj, 1, n_x - 4)),
            'x': np.tile(x[None, None, 2:-2], (n_traj, n_t - 2, 1)),
            'u': u[:, 1:-1, 2:-2],
            'dudt': dudt[:, :, 2:-2],
            'dudx': dudx[:, 1:-1, 1:-1],
            'dudxdx': dudxdx[:, 1:-1, 1:-1],
            'dudxdxdx': dudxdxdx[:, 1:-1, :],
            'dudxdxdxdx': dudxdxdxdx[:, 1:-1, :],
        }
        return {name: values.reshape(-1).astype(np.float32) for name, values in fields.items()}

    def __len__(self):
        """Return the number of per-point samples."""
        return len(self.data['u'])

    def __getitem__(self, idx):
        """Return a dict with the value of every field at sample ``idx``."""
        return {key: self.data[key][idx] for key in self.data}
