import numpy as np
from matplotlib import pyplot as plt


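# Download Pipe_Q.npy, Pipe_X.npy and Pipe_Y.npy from the pipe/ folder of the
# Geo-FNO datasets (https://drive.google.com/drive/folders/1JUkPbx0-lgjFHPURH_kp1uqjfRn3aw9-)
# into the data/ directory (relative to the working directory) that this script reads from.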

GRID_SIZE = 129  # points per axis of the uniform reference grid
N_TRAIN = 1000   # number of leading samples used for training
N_VAL = 200      # number of trailing samples used for validation


def make_grid(n_points=GRID_SIZE):
    """Return an (n_points, n_points, 2) array of uniform coordinates on [0, 1]^2.

    Entry ``[i, j]`` equals ``[t[i], t[j]]`` with ``t = np.linspace(0, 1, n_points)``,
    i.e. the first coordinate varies along axis 0 and the second along axis 1.
    """
    t = np.linspace(0, 1, n_points)
    # indexing="ij" keeps axis 0 <-> first coordinate (matrix indexing, not Cartesian).
    first, second = np.meshgrid(t, t, indexing="ij")
    return np.stack((first, second), axis=-1)


def build_splits(x_grid, y1, y2, n_train=N_TRAIN, n_val=N_VAL):
    """Split per-sample arrays into train/validation dicts.

    The train split takes the first ``n_train`` samples and the validation split the
    last ``n_val`` samples. ``x_grid`` is shared by every sample and is tiled once
    per sample in each split.

    Args:
        x_grid: Grid array shared by all samples (e.g. the output of ``make_grid``).
        y1: Per-sample array; first axis indexes samples.
        y2: Per-sample array with the same number of samples as ``y1``.
        n_train: Number of leading samples for training.
        n_val: Number of trailing samples for validation.

    Returns:
        ``(train_data, val_data)`` dicts, each with keys ``"x"``, ``"y2"``, ``"y1"``.

    Raises:
        ValueError: If ``y1`` and ``y2`` disagree on the sample count, or if the
            train and validation ranges would overlap.
    """
    n_samples = len(y1)
    if len(y2) != n_samples:
        raise ValueError(f"y1 has {n_samples} samples but y2 has {len(y2)}")
    if n_train < 0 or n_val < 0 or n_train + n_val > n_samples:
        raise ValueError(
            f"cannot take {n_train} train + {n_val} val samples from {n_samples} "
            "without overlap"
        )

    def tile(count):
        # Only materialize as many grid copies as the split needs.
        return np.repeat(x_grid[np.newaxis], count, axis=0)

    # Slice from an explicit start so n_val == 0 yields an empty split
    # (y[-0:] would return the whole array).
    val_start = n_samples - n_val
    train_data = {"x": tile(n_train), "y2": y2[:n_train], "y1": y1[:n_train]}
    val_data = {"x": tile(n_val), "y2": y2[val_start:], "y1": y1[val_start:]}
    return train_data, val_data


if __name__ == "__main__":
    # NOTE(review): [:, 0] keeps only index 0 of Q's second axis — presumably one
    # physical field of the solution; confirm against the dataset description.
    Q = np.expand_dims(np.load("data/Pipe_Q.npy")[:, 0], axis=-1)
    X = np.expand_dims(np.load("data/Pipe_X.npy"), axis=-1)
    Y = np.expand_dims(np.load("data/Pipe_Y.npy"), axis=-1)

    # y1: per-sample (X, Y) coordinates stacked on a trailing axis; y2: target field.
    y1 = np.concatenate((X, Y), axis=-1)
    y2 = Q

    train_data, val_data = build_splits(make_grid(), y1, y2)

    # np.save stores each dict as a pickled 0-d object array (".npy" is appended);
    # read back with np.load(path, allow_pickle=True).item().
    np.save("data/PIPE_train", train_data)
    np.save("data/PIPE_val", val_data)
