import scipy.io as scio
import numpy as np
import os
from matplotlib import pyplot as plt


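# Prerequisite: download Darcy_241.zip from the FNO datasets
# (https://drive.google.com/drive/folders/1UnbQh2WWc6knEHbLn-ZaXrKUZhp7pjt-)
# and unzip it to Forward_Problem/datas.
# NOTE(review): ROOT below is "data", not "datas" -- confirm which directory name is intended.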

# Grid resolution; it is embedded in the source file names and sets both
# grid dimensions (N and M) below.
RES = 241
# Directory the .mat files are read from and the .npy outputs are written to.
ROOT = "data"
# The two source files differ only in their "smooth" suffix.
_SRC_PATTERN = "piececonst_r{}_N1024_smooth{}.mat"
SRC1 = _SRC_PATTERN.format(RES, 1)
SRC2 = _SRC_PATTERN.format(RES, 2)
# Square grid: N x M points.
N = M = RES
# Presumably the number of samples stored per .mat file (cf. "N1024" in the
# file names) -- TODO confirm.
L = 1024


def draw_2D(x, y, u, x_label, y_label, u_label, filename=None):
    """Draw ``u`` over the ``(x, y)`` grid as a colour map, then save or show it.

    Args:
        x: Grid of x coordinates handed to ``plt.pcolormesh``.
        y: Grid of y coordinates handed to ``plt.pcolormesh``.
        u: Values to colour. With ``shading='gouraud'`` matplotlib expects
            ``x`` and ``y`` to have the same shape as ``u``.
        x_label: Label for the x axis.
        y_label: Label for the y axis.
        u_label: Label for the colour bar.
        filename: Path the figure is saved to at 300 dpi. When ``None``
            (the default) the figure is shown with ``plt.show()`` instead.
    """
    mesh = plt.pcolormesh(x, y, u, cmap='rainbow', shading='gouraud')
    plt.colorbar(mesh, label=u_label)
    plt.xlabel(x_label)
    plt.ylabel(y_label)
    if filename is None:
        # savefig(None) is not a valid call, so the default falls back to
        # displaying the figure rather than failing.
        plt.show()
    else:
        plt.savefig(filename, dpi=300)
    # Clear the current figure so the next call starts from a blank canvas.
    plt.clf()
    

def show_DF(u, filename):
    """Plot a single field sample and write the image to ``filename``.

    Args:
        u: Array with ``N * M`` elements; it is reshaped to ``(N, M)`` using
            the module-level grid dimensions.
        filename: Output path forwarded to ``draw_2D``.
    """
    field = np.reshape(u, (N, M))
    rows, cols = field.shape
    # 'xy' meshgrid: x runs along the columns and y along the rows, and both
    # grids come out shaped (rows, cols) like ``field``. This is what
    # pcolormesh with gouraud shading needs, also when rows != cols.
    x, y = np.meshgrid(np.linspace(0, cols, cols), np.linspace(0, rows, rows))
    draw_2D(x, y, field, "x", "y", "u", filename)
    

def load_DF(path, verbose=False, idx=0):
    """Load one ``.mat`` file and return coordinates, solutions and coefficients.

    Args:
        path: File readable by ``scipy.io.loadmat`` that contains the
            variables ``'coeff'`` and ``'sol'``.
        verbose: When true, print the shapes of the three returned arrays and
            save preview plots of sample ``idx`` to ``test_f.png`` and
            ``test_y.png``.
        idx: Index of the sample plotted when ``verbose`` is set.

    Returns:
        Tuple ``(x, y, f)``:
            x: Coordinate grid on ``[0, 1] x [0, 1]`` with shape
                ``(num_samples, N, M, 2)``; the same grid is repeated once per
                sample found in the file.
            y: ``'sol'`` with an extra axis inserted at position 3.
            f: ``'coeff'`` with an extra axis inserted at position 3.
        ``'sol'`` and ``'coeff'`` are assumed to be shaped
        ``(num_samples, N, M)`` -- TODO confirm against the dataset.
    """
    matdata = scio.loadmat(path)
    f = matdata['coeff']
    y = matdata['sol']

    # 'xy' meshgrid over (M, N) points gives two (N, M) arrays, so the stacked
    # grid is (N, M, 2) without a reshape. A reshape would scramble the
    # coordinates whenever N != M.
    grid = np.dstack(np.meshgrid(np.linspace(0, 1, M), np.linspace(0, 1, N)))

    # Repeat the grid once per loaded sample so x always lines up with y and f,
    # whatever number of samples the file holds.
    num_samples = f.shape[0]
    x = np.repeat(np.expand_dims(grid, axis=0), num_samples, axis=0)

    y = np.expand_dims(y, axis=3)
    f = np.expand_dims(f, axis=3)
    if verbose:
        print(x.shape)
        print(y.shape)
        print(f.shape)
        show_DF(f[idx], "test_f.png")
        show_DF(y[idx], "test_y.png")
    return x, y, f


def split_train_val(data, ratio=0.5):
    """Split ``data`` along its first axis into a leading and a trailing part.

    Args:
        data: Sliceable object exposing ``shape``; the split is made along
            axis 0.
        ratio: Fraction of ``data.shape[0]`` given to the first part. The cut
            index is ``int(data.shape[0] * ratio)``, i.e. truncated.

    Returns:
        Tuple ``(train_data, val_data)`` of the slices before and from the cut.
    """
    cut = int(data.shape[0] * ratio)
    return data[:cut], data[cut:]


if __name__ == "__main__":
    # Load both source files and join each returned array along axis 0.
    loaded = [load_DF(os.path.join(ROOT, src), verbose=False) for src in (SRC1, SRC2)]
    x, y, f = (np.concatenate(pair, axis=0) for pair in zip(*loaded))

    # Every array is cut at the same position (split_train_val's default ratio).
    (train_x, val_x), (train_y, val_y), (train_f, val_f) = (
        split_train_val(arr) for arr in (x, y, f)
    )

    # Key layout: "y2" carries the 'sol' array and "y1" the 'coeff' array.
    train_data = {"x": train_x, "y2": train_y, "y1": train_f}
    val_data = {"x": val_x, "y2": val_y, "y1": val_f}

    # np.save appends ".npy"; a dict is stored as a pickled object array.
    for stem, payload in (("DF_train", train_data), ("DF_val", val_data)):
        np.save(os.path.join(ROOT, stem), payload)

    for payload in (train_data, val_data):
        print(payload["x"].shape, payload["y1"].shape, payload["y2"].shape)
