import os

import numpy as np
from torch.utils.data import Dataset, DataLoader

class PDE_Dataset(Dataset):
    """Map-style dataset of per-sample arrays plus G arrays shared by every sample.

    Each item is the 10-tuple ``(y, b, b1, b2, G, G1, G2, f, f1, f2)``:

    * ``y``, ``b``, ``b1``, ``b2`` are indexed per sample.
    * ``G``, ``G1``, ``G2`` are returned as-is (not indexed) for every item.
    * ``f``, ``f1``, ``f2`` are indexed per sample when supplied. Otherwise a
      zero array shaped like ``b``, ``b1``, ``b2`` respectively is returned.

    The dataset length is the size of the first axis of ``y``. Array shapes and
    dtypes are whatever the caller provides (NOTE(review): presumably NumPy
    arrays loaded from ``.npy`` files; confirm for other callers).
    """

    def __init__(self, y, b, b1, b2, G, G1, G2, f=None, f1=None, f2=None):
        """Store the arrays; ``f``, ``f1``, ``f2`` are optional and independent."""
        self.y = y
        self.b = b
        self.b1 = b1
        self.b2 = b2
        # Shared across all samples; never indexed.
        self.G = G
        self.G1 = G1
        self.G2 = G2
        # Optional per-sample arrays; None means "use zeros".
        self.f = f
        self.f1 = f1
        self.f2 = f2

    def __len__(self):
        """Return the number of samples (first dimension of ``y``)."""
        return self.y.shape[0]

    @staticmethod
    def _sample_or_zeros(source, idx, like):
        """Return ``source[idx]``, or zeros shaped like *like* when *source* is None."""
        if source is None:
            return np.zeros_like(like)
        return source[idx]

    def __getitem__(self, idx):
        """Return the 10-tuple described in the class docstring for sample *idx*."""
        y = self.y[idx]
        b = self.b[idx]
        b1 = self.b1[idx]
        b2 = self.b2[idx]
        # Resolve each forcing array independently, so supplying only some of
        # f / f1 / f2 no longer indexes into None.
        f = self._sample_or_zeros(self.f, idx, b)
        f1 = self._sample_or_zeros(self.f1, idx, b1)
        f2 = self._sample_or_zeros(self.f2, idx, b2)
        return y, b, b1, b2, self.G, self.G1, self.G2, f, f1, f2


def _load_arrays(data_dir, names):
    """Load ``<data_dir>/<name>.npy`` for each name, returning arrays in the same order."""
    return [np.load(os.path.join(data_dir, name + ".npy")) for name in names]


def _test_split_names(i):
    """Return the ``.npy`` base names (without extension) for test split *i*."""
    names = [
        f"test_data{i}",
        f"test_b{i}", f"test_b{i}_1", f"test_b{i}_2",
        f"test_G{i}", f"test_G{i}_1", f"test_G{i}_2",
    ]
    if i in _TEST_SPLITS_WITH_F:
        # Only these splits ship f arrays, under a "test_all_" prefix.
        names += [f"test_all_f{i}", f"test_all_f{i}_1", f"test_all_f{i}_2"]
    return names


# Test split indices, in the order their loaders are returned.
_TEST_SPLIT_IDS = (1, 2, 3, 4)
# Test splits that also have f / f1 / f2 arrays on disk.
_TEST_SPLITS_WITH_F = frozenset({4})


def get_dataloader(opt, data_dir="data_generation/", num_workers=8):
    """Build the training loader and the list of test loaders from ``.npy`` files.

    Args:
        opt: Options object; only ``opt.batch_size`` is read (training batch size).
        data_dir: Directory containing the ``.npy`` files. Defaults to the
            original hard-coded ``"data_generation/"``; a trailing slash is optional.
        num_workers: Worker processes per ``DataLoader`` (default 8, as before).

    Returns:
        ``(train_loader, test_loaders)``. ``train_loader`` shuffles with
        ``opt.batch_size``. ``test_loaders`` is a list of four unshuffled
        loaders with ``batch_size=1``, for test splits 1..4 in order.

    Raises:
        FileNotFoundError: If any expected ``.npy`` file is missing (from ``np.load``).
    """
    train_names = [
        "train_data",
        "train_b", "train_b1", "train_b2",
        "train_G", "train_G1", "train_G2",
    ]
    train_dataset = PDE_Dataset(*_load_arrays(data_dir, train_names))
    train_loader = DataLoader(
        train_dataset, batch_size=opt.batch_size, shuffle=True, num_workers=num_workers
    )

    test_loaders = []
    for i in _TEST_SPLIT_IDS:
        # The positional order of the loaded arrays matches PDE_Dataset's
        # (y, b, b1, b2, G, G1, G2[, f, f1, f2]) signature.
        test_dataset = PDE_Dataset(*_load_arrays(data_dir, _test_split_names(i)))
        test_loaders.append(
            DataLoader(test_dataset, batch_size=1, shuffle=False, num_workers=num_workers)
        )

    return train_loader, test_loaders


