import numpy as np
import scipy.io as scio
import os
from matplotlib import pyplot as plt


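# Download plasticity/plas_N987_T20.mat
# from the GeoFNO datasets (https://drive.google.com/drive/folders/1JUkPbx0-lgjFHPURH_kp1uqjfRn3aw9-)
# into Forward_Problem/data. The script below reads it as data/plas_N987_T20.mat,
# i.e. relative to the directory it is run from.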

def load_PLA(path, out_dir="_try"):
    """Save the first two entries of a .mat file's ``output`` variable as .npy files.

    Parameters
    ----------
    path : str
        Path to a MATLAB .mat file that contains an ``output`` variable.
    out_dir : str, optional
        Directory that receives ``output0.npy`` and ``output1.npy``.
        It is created if it does not exist. Defaults to ``"_try"``, the
        location previously hard-coded here.

    Returns
    -------
    None
    """
    matdata = scio.loadmat(path)
    output = matdata["output"]
    # np.save does not create missing directories; without this the call
    # fails with FileNotFoundError on a fresh checkout.
    os.makedirs(out_dir, exist_ok=True)
    # Entries 0 and 1 along the first axis; np.save appends the ".npy" suffix.
    np.save(os.path.join(out_dir, "output0"), output[0])
    np.save(os.path.join(out_dir, "output1"), output[1])
    return None

def make_grid(n1=101, n2=31, n3=20):
    """Return an (n1, n2, n3, 3) array of coordinates on the unit cube.

    Entry ``[i, j, k]`` is ``(a[i], b[j], c[k])`` where ``a``, ``b``, ``c`` are
    ``np.linspace(0, 1, n)`` for ``n1``, ``n2``, ``n3`` respectively. This matches
    a row-major triple loop (first axis outermost) reshaped to that shape.
    """
    axes = [np.linspace(0, 1, n) for n in (n1, n2, n3)]
    # indexing="ij" keeps the first axis outermost (the default "xy" swaps axes 0 and 1).
    return np.stack(np.meshgrid(*axes, indexing="ij"), axis=-1)


def lift_input(inp, n2=31, n3=20):
    """Broadcast ``inp`` to shape ``inp.shape + (n2, n3, 1)``.

    Every value ``inp[...]`` is copied to all ``n2 * n3`` trailing positions.
    The result is a read-only broadcast view, so no memory is allocated for
    the repeated values.
    """
    inp = np.asarray(inp)
    return np.broadcast_to(inp[..., None, None, None], inp.shape + (n2, n3, 1))


if __name__ == "__main__":
    matdata = scio.loadmat("data/plas_N987_T20.mat")
    inputs = matdata["input"]
    y2 = matdata["output"]
    # Derive the sample count from the data instead of hard-coding 987, so
    # x always lines up with y1/y2 in the train/val split below.
    n_samples = inputs.shape[0]

    grid = make_grid()
    # Broadcast view instead of np.repeat: avoids materializing ~1.5 GB here.
    # np.save pickles the full values, so the saved file is unchanged.
    x = np.broadcast_to(grid, (n_samples,) + grid.shape)
    print(x.shape)

    y1 = lift_input(inputs)
    print(y1.shape)
    print(y2.shape)

    n_train = 900
    # Saving a dict stores a pickled object array; load it with
    # np.load(path, allow_pickle=True).item().
    np.save("data/PLA_train", {"x": x[:n_train, ...], "y2": y2[:n_train, ...], "y1": y1[:n_train, ...]})
    np.save("data/PLA_val", {"x": x[n_train:, ...], "y2": y2[n_train:, ...], "y1": y1[n_train:, ...]})
