
# from torch.utils.data import DataLoader, TensorDataset
from scipy.io import loadmat
import numpy as np
import torch
from h5py import File

# reading data, same as FNO: https://arxiv.org/pdf/2010.08895.pdf.
class MatReader(object):
    """Read named fields from a MATLAB ``.mat`` file.

    The file is first opened with ``scipy.io.loadmat``. If that fails for
    any reason, it is opened read-only as an HDF5 file through
    ``h5py.File`` instead. ``old_mat`` records which path succeeded:
    ``True`` for ``loadmat``, ``False`` for HDF5.

    Args:
        file_path: Path of the ``.mat`` file to open.
        to_torch: If true, ``read_field`` returns a ``torch.Tensor``
            instead of a NumPy array.
        to_cuda: If true (and ``to_torch`` is true), the returned tensor
            is moved to the GPU with ``.cuda()``.
        to_float: If true, the data is cast to ``np.float32``.
    """

    def __init__(self, file_path, to_torch=True, to_cuda=False, to_float=True):
        super(MatReader, self).__init__()

        self.to_torch = to_torch
        self.to_cuda = to_cuda
        self.to_float = to_float

        self.file_path = file_path

        self.data = None
        self.old_mat = None
        self._load_file()

    def _load_file(self):
        """Open ``self.file_path`` and populate ``self.data`` / ``self.old_mat``."""
        previous = self.data

        try:
            data = loadmat(self.file_path)
            old_mat = True
        except Exception:
            # `except Exception` (not a bare `except`) so KeyboardInterrupt
            # and SystemExit are not swallowed by the fallback.
            # Open read-only: this class never writes, and older h5py
            # releases do not default to read-only mode.
            data = File(self.file_path, "r")
            old_mat = False

        self.data = data
        self.old_mat = old_mat

        # Release a previously opened HDF5 handle only after the new file
        # loaded successfully; the dict returned by loadmat has no `close`.
        close_previous = getattr(previous, "close", None)
        if callable(close_previous):
            close_previous()

    def load_file(self, file_path):
        """Switch the reader to another file and load it immediately."""
        self.file_path = file_path
        self._load_file()

    def read_field(self, field):
        """Return the variable stored under ``field``.

        Args:
            field: Name (key) of the variable inside the loaded file.

        Returns:
            The field's data, converted according to the ``to_float``,
            ``to_torch`` and ``to_cuda`` flags: a NumPy array, or a
            ``torch.Tensor`` when ``to_torch`` is true.
        """
        x = self.data[field]

        if not self.old_mat:
            # Read the whole HDF5 dataset into memory, then reverse the
            # axis order (presumably so the layout matches what loadmat
            # would return for the same variable).
            x = x[()]
            x = np.transpose(x, axes=range(len(x.shape) - 1, -1, -1))

        if self.to_float:
            x = x.astype(np.float32)

        if self.to_torch:
            x = torch.from_numpy(x)

            # The GPU transfer only applies to tensors, hence the nesting.
            if self.to_cuda:
                x = x.cuda()

        return x

    def set_cuda(self, to_cuda):
        """Enable or disable moving returned tensors to the GPU."""
        self.to_cuda = to_cuda

    def set_torch(self, to_torch):
        """Enable or disable conversion of returned data to ``torch.Tensor``."""
        self.to_torch = to_torch

    def set_float(self, to_float):
        """Enable or disable casting of returned data to ``np.float32``."""
        self.to_float = to_float


def Loading_Data(file_path, sub_sampling_steps, resolution, num_train, num_test, batch_size):
    """Load fields ``'a'`` and ``'u'`` from a ``.mat`` file and build data loaders.

    Both fields are sub-sampled along their second axis with stride
    ``sub_sampling_steps``. The first ``num_train`` rows form the training
    split and the last ``num_test`` rows form the test split. The splits are
    not checked for overlap, so they share rows when
    ``num_train + num_test`` exceeds the number of rows in the file.

    Args:
        file_path: Path of the ``.mat`` file, read through ``MatReader``.
        sub_sampling_steps: Stride applied along the second axis of both fields.
        resolution: Number of points per sample after sub-sampling; the
            inputs are reshaped to ``(num_samples, resolution, 1)``.
        num_train: Number of leading rows used for training.
        num_test: Number of trailing rows used for testing.
        batch_size: Batch size of both returned loaders.

    Returns:
        Tuple ``(x_test, y_test, train_loader, test_loader)``. The training
        loader shuffles its samples; the test loader does not.
    """
    reader = MatReader(file_path)

    x_data = reader.read_field('a')[:, ::sub_sampling_steps]
    y_data = reader.read_field('u')[:, ::sub_sampling_steps]

    x_train = x_data[:num_train, :]
    y_train = y_data[:num_train, :]

    # Index the test split from the front: `[-num_test:]` would select every
    # row when num_test == 0 because -0 == 0. The clamp keeps the original
    # "take all rows" behaviour when num_test exceeds the row count.
    x_test = x_data[max(x_data.shape[0] - num_test, 0):, :]
    y_test = y_data[max(y_data.shape[0] - num_test, 0):, :]

    # Only the inputs get a trailing axis of size 1; targets keep their shape.
    x_train = x_train.reshape(num_train, resolution, 1)
    x_test = x_test.reshape(num_test, resolution, 1)

    train_loader = torch.utils.data.DataLoader(
        torch.utils.data.TensorDataset(x_train, y_train),
        batch_size=batch_size,
        shuffle=True,
    )
    test_loader = torch.utils.data.DataLoader(
        torch.utils.data.TensorDataset(x_test, y_test),
        batch_size=batch_size,
        shuffle=False,
    )

    return x_test, y_test, train_loader, test_loader


