import scipy.io as scio
import numpy as np
import os
from matplotlib import pyplot as plt

# Download NavierStokes_V1e-5_N1200_T20.zip from the FNO datasets
# (https://drive.google.com/drive/folders/1UnbQh2WWc6knEHbLn-ZaXrKUZhp7pjt-)
# and unzip it into Forward_Problem/data, so that the .mat file ends up at
# ROOT/SRC.mat (see the constants below).

# Directory holding the raw .mat file; the NS_train / NS_val .npy files are
# written here as well.
ROOT = "data"
# File stem of the FNO Navier-Stokes dataset (scipy.io.loadmat appends ".mat").
SRC = "NavierStokes_V1e-5_N1200_T20"
# Spatial grid resolution (N x M) and sample count (L) -- presumably the
# 64x64 grid and 1200 trajectories implied by the file name; TODO confirm.
N, M = 64, 64
L = 1200


def draw_2D(x, y, u, x_label, y_label, u_label, filename=None):
    """Draw the scalar field *u* at coordinates (*x*, *y*) as a colour map.

    Args:
        x, y: Coordinate arrays. Gouraud shading requires them to have the
            same shape as *u*.
        u: 2-D array of values to colour.
        x_label, y_label: Axis labels.
        u_label: Label for the colour bar.
        filename: Where to save the figure (at 300 dpi). If ``None``, the
            figure is shown interactively rather than passing ``None`` on
            to ``savefig``.
    """
    mesh = plt.pcolormesh(x, y, u, cmap='rainbow', shading='gouraud')
    plt.colorbar(mesh, label=u_label)
    plt.xlabel(x_label)
    plt.ylabel(y_label)
    if filename is None:
        plt.show()
    else:
        plt.savefig(filename, dpi=300)
    # Clear the current figure so consecutive calls don't stack plots/colour bars.
    plt.clf()
    

def show_NS(u, filename):
    """Plot one 2-D field of the dataset and save it to *filename*.

    Args:
        u: Array-like with ``N * M`` elements. It is reshaped to an (N, M)
            grid whose rows (axis 0) are drawn along y and whose columns
            (axis 1) are drawn along x.
        filename: Output image path, forwarded to ``draw_2D``.
    """
    field = np.reshape(u, (N, M))
    n_rows, n_cols = field.shape
    # meshgrid's default 'xy' indexing yields (n_rows, n_cols) arrays, i.e. the
    # same shape as `field`, as gouraud pcolormesh requires -- this also holds
    # for non-square grids, unlike manually repeated/transposed rows.
    xx, yy = np.meshgrid(
        np.linspace(0, n_cols, n_cols),
        np.linspace(0, n_rows, n_rows),
    )
    draw_2D(xx, yy, field, "x", "y", "u", filename)
    

def load_NS(path, verbose=False, idx=0):
    n_frame = 10
    matdata = scio.loadmat(path)        
    y0 = matdata['u'][:,:,:,:n_frame]
    y = matdata['u'][:,:,:,-n_frame:]
    x = np.reshape(np.dstack(np.meshgrid(np.linspace(0, 1, N), np.linspace(0, 1, M))), (N, M, 2))
    x = np.expand_dims(x, axis=0)
    x = np.repeat(x, L, axis=0)
    # y = np.expand_dims(y, axis=3)
    # y0 = np.expand_dims(y0, axis=3)
    # y0 = np.concatenate((x, y0), axis=3)
    if verbose:
        print(x.shape)
        print(y.shape)
        print(y0.shape)
        show_NS(y0[idx], "test_y0.png")
        show_NS(y[idx], "test_y.png")
    return x, y, y0


def split_train_val(data, train_num=1000):
    """Split *data* along its first axis into training and validation parts.

    Args:
        data: Any sliceable sequence or array.
        train_num: Number of leading entries assigned to the training part
            (default 1000, the value previously hard-coded).

    Returns:
        ``(train_data, val_data)`` as ``data[:train_num]`` and
        ``data[train_num:]``. For NumPy arrays these are views, not copies.
    """
    return data[:train_num], data[train_num:]


if __name__ == "__main__":
    # Coordinates, last-frame window and first-frame window from the raw .mat file.
    coords, target, initial = load_NS(os.path.join(ROOT, SRC), verbose=False)

    # Split each array once, in the same order (x, y2, y1) used as dict keys below.
    named = (("x", coords), ("y2", target), ("y1", initial))
    halves = {key: split_train_val(arr) for key, arr in named}

    # Element 0 of every split pair is the training part, element 1 the validation
    # part. np.save stores each dict as a pickled 0-d object array, so it has to be
    # reloaded with allow_pickle=True.
    for part, stem in enumerate(("NS_train", "NS_val")):
        np.save(os.path.join(ROOT, stem), {key: pair[part] for key, pair in halves.items()})
