from torch.utils.data import Dataset, DataLoader
import numpy as np

class PDE_Dataset(Dataset):
    """Map-style dataset pairing per-sample arrays with one shared ``G``.

    Indexing yields the 5-tuple ``(y[idx], b[idx], G, iter_num[idx], f_item)``,
    where ``f_item`` is ``f[idx]`` when an ``f`` array was supplied and the
    integer ``0`` otherwise. ``G`` is returned unchanged for every sample.

    Args:
        y: Indexable array; its leading axis defines the dataset length.
        b: Indexable array aligned with ``y`` along the first axis.
        G: Object shared by all samples (not indexed).
        iter_num: Indexable array aligned with ``y`` along the first axis.
        f: Optional indexable array aligned with ``y``; ``None`` to omit.
    """

    def __init__(self, y, b, G, iter_num, f=None):
        self.y, self.b, self.G = y, b, G
        self.iter_num = iter_num
        self.f = f

    def __len__(self):
        # Sample count comes from the first dimension of ``y`` only.
        return self.y.shape[0]

    def __getitem__(self, idx):
        sample_y, sample_b = self.y[idx], self.b[idx]
        sample_iter = self.iter_num[idx]
        # Use a scalar placeholder so batches keep a uniform 5-tuple shape.
        sample_f = self.f[idx] if self.f is not None else 0
        return sample_y, sample_b, self.G, sample_iter, sample_f


def _build_dataset(data_dir, prefix, suffix="", f_stem=None):
    """Load one split's ``.npy`` arrays from ``data_dir`` into a ``PDE_Dataset``.

    Reads ``{prefix}_data{suffix}``, ``{prefix}_b{suffix}``,
    ``{prefix}_G{suffix}`` and ``{prefix}_iter_num{suffix}`` (each with a
    ``.npy`` extension). If ``f_stem`` is given, it also reads ``{f_stem}.npy``.

    Args:
        data_dir: Directory containing the ``.npy`` files.
        prefix: File-name prefix, e.g. ``"train"`` or ``"test"``.
        suffix: File-name suffix appended after the field name, e.g. ``"1"``.
        f_stem: Optional file stem (without ``.npy``) of the extra ``f`` array.

    Returns:
        A ``PDE_Dataset`` built from the loaded arrays.
    """
    import os

    def load(stem):
        # os.path.join tolerates data_dir with or without a trailing separator.
        return np.load(os.path.join(data_dir, stem + ".npy"))

    y = load(f"{prefix}_data{suffix}")
    b = load(f"{prefix}_b{suffix}")
    iter_num = load(f"{prefix}_iter_num{suffix}")
    G = load(f"{prefix}_G{suffix}")
    f = load(f_stem) if f_stem is not None else None
    return PDE_Dataset(y, b, G, iter_num, f)


def get_dataloader(opt, *, data_dir="data_generation/", num_workers=8, shuffle_test=True):
    """Build the training loader and the four test loaders.

    Args:
        opt: Options object; only ``opt.batch_size`` is read, for the training
            loader.
        data_dir: Directory holding the ``.npy`` files (default matches the
            previously hard-coded ``"data_generation/"``).
        num_workers: Worker processes for every ``DataLoader``.
        shuffle_test: Whether the test loaders shuffle. The default ``True``
            preserves the original behavior. NOTE(review): ``False`` is
            usually preferable for reproducible evaluation order.

    Returns:
        ``(train_loader, test_loaders)``, where ``test_loaders`` is a list of
        four loaders for test splits 1-4, each with ``batch_size=1``.

    Raises:
        Whatever ``np.load`` raises (e.g. ``FileNotFoundError``) for a
        missing or unreadable file.
    """
    train_dataset = _build_dataset(data_dir, "train")
    train_loader = DataLoader(
        train_dataset,
        batch_size=opt.batch_size,
        shuffle=True,
        num_workers=num_workers,
    )

    # Only test split 4 carries an extra ``f`` array, stored under a name that
    # does not follow the ``test_<field><k>`` pattern used by the other files.
    extra_f_stems = {4: "test_all_f4"}
    test_loaders = []
    for k in range(1, 5):
        dataset = _build_dataset(data_dir, "test", str(k), extra_f_stems.get(k))
        test_loaders.append(
            DataLoader(dataset, batch_size=1, shuffle=shuffle_test, num_workers=num_workers)
        )

    return train_loader, test_loaders


