import h5py
import torch

from torch.utils.data import Dataset
import numpy as np

import math

# Public API of this module: only the dataset class is exported by ``import *``.
__all__ = ['DarcyDataset']

class DarcyDataset(Dataset):
    """PDE dataset read from a single HDF5 file.

    Two file layouts are recognised:

    * scalar files with a ``'tensor'`` dataset laid out
      ``[batch, time, x1, ..., xd]`` (1D or 2D in space). A 2D file with a
      single time step (``nt == 1``) is treated as Darcy flow: the ``'nu'``
      field (input) is stored as channel 0, followed by the ``'tensor'``
      solution (label).
    * NS (CFD) files with ``'density'``, ``'pressure'``, ``'Vx'`` (and ``'Vy'``
      in 2D) datasets, each ``[batch, time, x1, ..., xd]``, stacked as
      channels in that order.

    After construction ``self.data`` is a float32 tensor laid out
    ``[batch, x1, ..., xd, time, channel]`` and ``self.grid`` is the strided
    coordinate grid laid out ``[x1, ..., xd, time, d + 1]``.

    Args:
        file_name: path of the HDF5 file.
        initial_step: number of leading time steps returned as model input.
        reduced_resolution: stride applied to every spatial axis.
        reduced_resolution_t: stride applied to the time axis.
        reduced_batch: stride applied to the sample axis.
        if_test: if True keep samples ``[test_idx:num_samples_max]``,
            otherwise keep samples ``[:test_idx]``, where
            ``test_idx = int(num_samples_max * (1 - test_ratio))``.
        test_ratio: fraction of samples reserved for the test split.
        num_samples_max: if positive, cap on the number of samples considered
            before splitting; otherwise all samples are considered.

    Raises:
        ValueError: if the field datasets are neither 1D nor 2D in space.
    """

    def __init__(self, file_name,
                 initial_step=1,
                 reduced_resolution=1,
                 reduced_resolution_t=1,
                 reduced_batch=1,
                 if_test=False,
                 test_ratio=0.1,
                 num_samples_max=-1,
                 ):
        strides = (reduced_batch, reduced_resolution, reduced_resolution_t)
        with h5py.File(file_name, 'r') as f:
            if 'tensor' in f:  # scalar equations (including Darcy flow)
                self.data, coords = self._load_scalar(f, *strides)
            else:  # NS equations
                self.data, coords = self._load_cfd(f, *strides)

        self.grid = self._make_grid(coords, reduced_resolution, reduced_resolution_t)

        # Spacings are measured on the full-resolution coordinates with the
        # stride as offset, so they are the spacings of the strided grid.
        x, t = coords[0], coords[-1]
        self.dx = x[reduced_resolution] - x[0]
        # dt needs at least two time steps after striding; otherwise it is None
        # (t[reduced_resolution_t] would be out of range).
        self.dt = t[reduced_resolution_t] - t[0] if t.shape[0] > reduced_resolution_t else None
        self.tmax = t[-1] if t.shape[0] > 1 else None

        # Define the max number of samples
        n_available = self.data.shape[0]
        if num_samples_max > 0:
            num_samples_max = min(num_samples_max, n_available)
        else:
            num_samples_max = n_available

        # Construct train/test dataset
        test_idx = int(num_samples_max * (1 - test_ratio))
        if if_test:
            self.data = self.data[test_idx:num_samples_max]
        else:
            self.data = self.data[:test_idx]

        # time steps used as initial conditions
        self.initial_step = initial_step
        self.data = torch.tensor(self.data)

    def __len__(self):
        """Number of samples in the selected (train or test) split."""
        return self.data.shape[0]

    def __getitem__(self, idx):
        """Return ``(initial, sample, grid)`` for sample ``idx``.

        ``initial`` is the first ``initial_step`` time steps of the sample,
        ``sample`` is the full ``[x1, ..., xd, time, channel]`` trajectory and
        ``grid`` is the coordinate grid shared by all samples.
        """
        return self.data[idx, ..., :self.initial_step, :], self.data[idx], self.grid

    @classmethod
    def _load_scalar(cls, f, reduced_batch, reduced_resolution, reduced_resolution_t):
        """Load a file whose solution lives in a single ``'tensor'`` dataset.

        Returns:
            ``(data, coords)``: ``data`` laid out
            ``[batch, x1, ..., xd, time, channel]`` and ``coords`` the list of
            coordinate tensors ``[x1, ..., xd, t]``.

        Raises:
            ValueError: if ``'tensor'`` is not 3- or 4-dimensional.
        """
        raw = np.array(f['tensor'], dtype=np.float32)  # [batch, time, x1, ..., xd]
        if raw.ndim not in (3, 4):
            raise ValueError(
                "'tensor' must be [batch, time, x] or [batch, time, x, y], "
                f"got shape {raw.shape}")
        # recorrect the nt: use the shorter of the data and coordinate time axes
        nt = min(raw.shape[1], f['t-coordinate'].shape[0]) if 't-coordinate' in f else 1

        if raw.ndim == 3:  # 1D
            data = cls._subsample_time_last(
                raw, reduced_batch, reduced_resolution, reduced_resolution_t)[..., None]
            return data, cls._read_coords(f, ('x-coordinate',), nt)

        if nt == 1:  # 2D Darcy flow: the time axis is not strided
            u = cls._subsample_time_last(raw, reduced_batch, reduced_resolution, 1)  # label
            nu = np.array(f['nu'], dtype=np.float32)[:, None]  # input, given a unit time axis
            nu = cls._subsample_time_last(nu, reduced_batch, reduced_resolution, 1)
            # channels [nu, u...], then a singleton time axis before the channels
            data = np.concatenate([nu, u], axis=-1)[..., None, :]
        else:  # 2D time-dependent
            data = cls._subsample_time_last(
                raw, reduced_batch, reduced_resolution, reduced_resolution_t)[..., None]
        return data, cls._read_coords(f, ('x-coordinate', 'y-coordinate'), nt)

    @classmethod
    def _load_cfd(cls, f, reduced_batch, reduced_resolution, reduced_resolution_t):
        """Load an NS file with one ``[batch, time, x1, ..., xd]`` dataset per variable.

        Channels are ``density, pressure, Vx`` in 1D and
        ``density, pressure, Vx, Vy`` in 2D.

        Returns:
            ``(data, coords)`` as in :meth:`_load_scalar`.

        Raises:
            ValueError: if ``'density'`` is not 3- or 4-dimensional.
        """
        shape = f['density'].shape  # [batch, time, x1, ..., xd]
        if len(shape) == 3:  # 1D
            keys, axes = ('density', 'pressure', 'Vx'), ('x-coordinate',)
        elif len(shape) == 4:  # 2D
            keys, axes = ('density', 'pressure', 'Vx', 'Vy'), ('x-coordinate', 'y-coordinate')
        else:
            raise ValueError(
                "'density' must be [batch, time, x] or [batch, time, x, y], "
                f"got shape {shape}")
        # recorrect the nt
        nt = min(shape[1], f['t-coordinate'].shape[0])

        data = None
        for channel, key in enumerate(keys):
            field = cls._subsample_time_last(
                np.array(f[key], dtype=np.float32),
                reduced_batch, reduced_resolution, reduced_resolution_t)
            if data is None:
                # Size the buffer from the strided field itself: strided slicing
                # keeps ceil(n / stride) entries per axis, which floor division
                # under-counts when n is not divisible by the stride.
                data = np.empty(field.shape + (len(keys),), dtype=np.float32)
            data[..., channel] = field
        return data, cls._read_coords(f, axes, nt)

    @staticmethod
    def _subsample_time_last(arr, reduced_batch, reduced_resolution, reduced_resolution_t):
        """Stride a ``[batch, time, x1, ..., xd]`` array and move time last.

        Returns a view laid out ``[batch, x1, ..., xd, time]``.
        """
        index = (slice(None, None, reduced_batch), slice(None, None, reduced_resolution_t))
        index += (slice(None, None, reduced_resolution),) * (arr.ndim - 2)
        return np.moveaxis(arr[index], 1, -1)

    @staticmethod
    def _read_coords(f, names, nt):
        """Read the coordinate datasets ``names`` followed by the time axis.

        The time axis is the first ``nt`` entries of ``'t-coordinate'``, or
        ``[0]`` when the file has no such dataset. Returns a list of float32
        tensors ``[x1, ..., xd, t]``.
        """
        coords = [torch.tensor(np.array(f[name], dtype=np.float32)) for name in names]
        if 't-coordinate' in f:
            t = np.array(f['t-coordinate'], dtype=np.float32)[:nt]
        else:
            t = np.array([0], dtype=np.float32)
        coords.append(torch.tensor(t))
        return coords

    @staticmethod
    def _make_grid(coords, reduced_resolution, reduced_resolution_t):
        """Build the strided coordinate grid from axes ``[x1, ..., xd, t]``.

        Returns a tensor laid out ``[x1, ..., xd, t, d + 1]`` whose last axis
        holds each point's coordinates; spatial axes are strided by
        ``reduced_resolution`` and the time axis by ``reduced_resolution_t``.
        """
        grid = torch.stack(torch.meshgrid(*coords, indexing='ij'), dim=-1)
        index = (slice(None, None, reduced_resolution),) * (len(coords) - 1)
        index += (slice(None, None, reduced_resolution_t),)
        return grid[index]

