
from torch.utils.data import DataLoader, TensorDataset
from scipy.io import loadmat
import numpy as np
import torch
from h5py import File


def _split_time_slices(data, sub_sampling_steps, input_time_slices_num, next_time_slices_num):
    """Subsample axes 2-4 of a rank-6 tensor and split it into paired input/target windows.

    ``data`` is indexed as ``data[sample, time, a2, a3, a4, channel]`` (rank 6 is
    implied by the slicing below; axes 2-4 are presumably spatial — confirm).
    Axes 2-4 are strided by ``sub_sampling_steps``. The first
    ``input_time_slices_num`` time slices form the input window and the following
    ``next_time_slices_num`` slices form the target window. The sample and time
    axes are then merged, so row ``s * k + i`` of ``x`` (time ``i`` of sample ``s``)
    is paired with row ``s * k + i`` of ``y`` (time ``k + i`` of the same sample),
    where ``k`` is the window length.

    Returns:
        ``(x, y)``, each of shape ``(samples * k, a2', a3', a4', channel)``.

    Raises:
        ValueError: if the input and target windows do not contain the same
            number of time slices. That happens when the two counts differ or
            when ``data`` has too few time steps; the rows could not be paired.
    """
    step = sub_sampling_steps
    data = data[:, :, ::step, ::step, ::step, :]

    start, stop = input_time_slices_num, input_time_slices_num + next_time_slices_num
    x = data[:, :start]
    y = data[:, start:stop]

    # Merging (sample, time) only yields aligned pairs when both windows have
    # the same length. Otherwise TensorDataset would fail on a bare assert
    # (skipped under `python -O`, leading to silently mis-paired samples).
    if x.size(1) != y.size(1):
        raise ValueError(
            f"input window has {x.size(1)} time slices but target window has "
            f"{y.size(1)} (requested input={input_time_slices_num}, "
            f"next={next_time_slices_num}, available={data.size(1)}); "
            "they must be equal to pair inputs with targets"
        )

    # flatten(0, 1) merges the sample and time axes; equivalent to
    # reshape(-1, *shape[2:]) but well-defined even when a dimension is empty.
    return x.flatten(0, 1), y.flatten(0, 1)


def Loading_Data(file_path, sub_sampling_steps, resolution, num_train, num_test, batch_size, input_time_slices_num, next_time_slices_num, inverse = False, *,
                 train_path='your/path/data/train_rayleigh_taylor_instability_At_25.pt',
                 test_path='your/path/data/test_rayleigh_taylor_instability_At_25.pt'):
    """Load the train/test tensors and wrap them in input->target DataLoaders.

    Each file is expected to hold a rank-6 tensor indexed as
    ``[sample, time, a2, a3, a4, channel]``. Axes 2-4 are strided by
    ``sub_sampling_steps``. The first ``input_time_slices_num`` time slices become
    inputs and the next ``next_time_slices_num`` slices become targets, with
    sample and time merged into the batch axis.

    Args:
        file_path: Unused by this function; kept for interface compatibility.
        sub_sampling_steps: Stride applied to axes 2, 3 and 4.
        resolution: Unused by this function.
        num_train: Unused by this function (all training samples are used).
        num_test: Unused by this function (all test samples are used).
        batch_size: Batch size for both loaders.
        input_time_slices_num: Number of leading time slices used as inputs.
        next_time_slices_num: Number of following time slices used as targets;
            must equal ``input_time_slices_num`` (see Raises).
        inverse: Unused by this function.
        train_path: Path of the training tensor file (keyword-only).
        test_path: Path of the test tensor file (keyword-only).

    Returns:
        ``(x_test, y_test, train_loader, test_loader)``. ``train_loader``
        shuffles; ``test_loader`` does not.

    Raises:
        ValueError: if the input and target windows differ in length for either
            split.
    """
    # NOTE(review): torch.load may unpickle arbitrary objects depending on the
    # installed torch version's `weights_only` default — only load trusted files.
    train_data = torch.load(train_path)
    x_train, y_train = _split_time_slices(
        train_data, sub_sampling_steps, input_time_slices_num, next_time_slices_num)

    test_data = torch.load(test_path)
    x_test, y_test = _split_time_slices(
        test_data, sub_sampling_steps, input_time_slices_num, next_time_slices_num)

    train_loader = DataLoader(TensorDataset(x_train, y_train), batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(TensorDataset(x_test, y_test), batch_size=batch_size, shuffle=False)

    return x_test, y_test, train_loader, test_loader