from torch.utils.data import DataLoader, TensorDataset
from scipy.io import loadmat
import numpy as np
import torch
from h5py import File

# reading data, same as FNO: https://arxiv.org/pdf/2010.08895.pdf.
class MatReader(object):
    """Read variables from a MATLAB ``.mat`` file as NumPy arrays or torch tensors.

    The file is first loaded with :func:`scipy.io.loadmat`. If that fails (for
    example on v7.3 files, which scipy cannot read), it is opened read-only
    with ``h5py.File`` instead.

    Args:
        file_path: path of the ``.mat`` file to open.
        to_torch: if True, :meth:`read_field` returns a ``torch.Tensor``.
        to_cuda: if True (and ``to_torch``), tensors are moved with ``.cuda()``.
        to_float: if True, fields are cast to ``np.float32``.
    """

    def __init__(self, file_path, to_torch=True, to_cuda=False, to_float=True):
        super(MatReader, self).__init__()

        self.to_torch = to_torch
        self.to_cuda = to_cuda
        self.to_float = to_float

        self.file_path = file_path

        # dict returned by loadmat, or an open h5py.File handle
        self.data = None
        # True -> loaded by scipy.io.loadmat; False -> opened with h5py
        self.old_mat = None
        self._load_file()

    def _load_file(self):
        """Load ``self.file_path``, preferring scipy and falling back to h5py."""
        try:
            self.data = loadmat(self.file_path)
            self.old_mat = True
        except Exception:
            # Deliberately not a bare ``except``: Ctrl-C / SystemExit must
            # propagate instead of silently triggering the h5py fallback.
            # Open read-only explicitly: older h5py releases defaulted to
            # append mode ('a'), which can create or modify the data file.
            self.data = File(self.file_path, "r")
            self.old_mat = False

    def close(self):
        """Release the currently loaded file (closes the h5py handle, if any).

        Safe to call more than once; afterwards :meth:`read_field` cannot be
        used until :meth:`load_file` is called again.
        """
        if self.old_mat is False and self.data is not None:
            self.data.close()
        self.data = None
        self.old_mat = None

    def load_file(self, file_path):
        """Switch this reader to ``file_path``, closing the previous file first."""
        self.close()
        self.file_path = file_path
        self._load_file()

    def read_field(self, field):
        """Return variable ``field`` from the loaded file.

        Applied in order: axis reversal (h5py-loaded files only), a cast to
        ``np.float32`` (if ``to_float``), conversion with ``torch.from_numpy``
        (if ``to_torch``) and ``.cuda()`` (if ``to_torch`` and ``to_cuda``).

        Note: for scipy-loaded files with ``to_float=False`` the result shares
        memory with the array cached in ``self.data``.
        """
        x = self.data[field]

        if not self.old_mat:
            # h5py yields a Dataset; ``[()]`` reads it fully into a NumPy array.
            x = x[()]
            # Reverse all axes -- presumably undoing the column-major order in
            # which MATLAB writes v7.3 (HDF5) files.
            x = np.transpose(x, axes=range(len(x.shape) - 1, -1, -1))

        if self.to_float:
            x = x.astype(np.float32)

        if self.to_torch:
            x = torch.from_numpy(x)

            if self.to_cuda:
                x = x.cuda()

        return x

    def set_cuda(self, to_cuda):
        """Set whether subsequent reads move tensors to CUDA."""
        self.to_cuda = to_cuda

    def set_torch(self, to_torch):
        """Set whether subsequent reads return torch tensors."""
        self.to_torch = to_torch

    def set_float(self, to_float):
        """Set whether subsequent reads cast to float32."""
        self.to_float = to_float



def Loading_Data(file_path, sub_sampling_steps, resolution, num_train, num_test, batch_size, input_time_slices_num, next_time_slices_num, inverse=False):
    """Load variable ``'u'`` from ``file_path`` and build train/test DataLoaders.

    ``u`` is indexed as ``u[sample, dim1, dim2, step]``; both middle axes are
    sub-sampled with stride ``sub_sampling_steps``. The first ``num_train``
    samples form the training set and the last ``num_test`` samples the test
    set (they overlap if ``num_train + num_test`` exceeds the sample count).
    For each sample, steps ``[0, input_time_slices_num)`` are the input ``x``
    and the following ``next_time_slices_num`` steps are the target ``y``.

    Args:
        file_path: ``.mat`` file containing the variable ``'u'``.
        sub_sampling_steps: stride applied to both middle axes.
        resolution: expected size of both middle axes after sub-sampling.
        num_train: number of leading samples used for training.
        num_test: number of trailing samples used for testing (may be 0).
        batch_size: batch size of both DataLoaders.
        input_time_slices_num: number of leading steps used as input.
        next_time_slices_num: number of subsequent steps used as target.
        inverse: has no effect; kept for backward compatibility (both of its
            branches in the original implementation were identical).

    Returns:
        ``(x_test, y_test, train_loader, test_loader)``. The training loader
        shuffles; the test loader does not.

    Raises:
        AssertionError: if the sub-sampled grid is not
            ``resolution x resolution`` or fewer than ``next_time_slices_num``
            target steps are available.
        RuntimeError: from ``reshape`` if fewer than ``num_train`` /
            ``num_test`` samples exist.
    """
    # Read (and float32-convert) the field ONCE; the previous version loaded
    # and converted the whole array four times.
    reader = MatReader(file_path)
    u = reader.read_field('u')[:, ::sub_sampling_steps, ::sub_sampling_steps, :]

    n_total = u.shape[0]
    t_in = input_time_slices_num
    t_end = input_time_slices_num + next_time_slices_num

    train = u[:num_train]
    # ``u[-num_test:]`` would select the WHOLE dataset when num_test == 0,
    # so slice from an explicit (non-negative) start index instead.
    test = u[max(n_total - num_test, 0):]

    # clone() yields independent tensors, so the full ``u`` can be freed and
    # in-place edits of the returned test tensors never touch the train set.
    x_train = train[..., :t_in].clone()
    y_train = train[..., t_in:t_end].clone()
    x_test = test[..., :t_in].clone()
    y_test = test[..., t_in:t_end].clone()
    del u, train, test

    print(y_train.shape)
    print(y_test.shape)
    assert tuple(y_train.shape[1:3]) == (resolution, resolution), (
        f"expected a {resolution}x{resolution} grid after sub-sampling by "
        f"{sub_sampling_steps}, got {tuple(y_train.shape[1:3])}")
    assert next_time_slices_num == y_train.shape[-1], (
        f"expected {next_time_slices_num} target steps, got {y_train.shape[-1]}")

    # Pins the expected shape; raises if too few samples were available.
    x_train = x_train.reshape(num_train, resolution, resolution, input_time_slices_num)
    x_test = x_test.reshape(num_test, resolution, resolution, input_time_slices_num)

    train_loader = DataLoader(TensorDataset(x_train, y_train), batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(TensorDataset(x_test, y_test), batch_size=batch_size, shuffle=False)

    return x_test, y_test, train_loader, test_loader