# Dataset preparation: builds the space-time sampling grids (full, interior,
# initial and boundary point sets) and creates the dataset folder structure.

import os
import sys
import math
import numpy as np
import torch

# Absolute path of this file, and the directory one level above the folder
# that contains it (presumably the project root holding ``utils`` -- the
# dataset folders are also created underneath it).
FILE_ABS_PATH = os.path.abspath(__file__)
FILE_DIR_PATH = os.path.dirname(os.path.dirname(FILE_ABS_PATH))
# Make ``utils`` importable independent of the current working directory;
# a bare "..") entry only resolves correctly when launched from this folder.
if FILE_DIR_PATH not in sys.path:
    sys.path.append(FILE_DIR_PATH)
# Keep the original cwd-relative entry as a fallback for existing setups.
sys.path.append("..")

from utils.config import *
args = parser.parse_args()

# SPATIAL_SIZE is the side length L of the square spatial domain [0, L]^2.
# COEFF is the wavenumber used by the "sin" initial condition. In every
# branch COEFF * SPATIAL_SIZE == 2*pi, so sin(COEFF*x) spans exactly one
# period across the domain.
if args.spatial_size == "pi":
    SPATIAL_SIZE = math.pi
    COEFF = 2
elif args.spatial_size == "1":
    SPATIAL_SIZE = 1
    COEFF = 2 * math.pi
else:
    # Any other value falls back to a [0, 2*pi] domain.
    SPATIAL_SIZE = 2 * math.pi
    COEFF = 1

# Domain and mesh arrays.
# DOMAIN["spatial"]: one list of (left, right) intervals per spatial axis.
# Its extent follows SPATIAL_SIZE so the grid matches COEFF above.
# DOMAIN["temporal"]: the (start, end) time interval.
DOMAIN = {"spatial": [[(0, SPATIAL_SIZE)], [(0, SPATIAL_SIZE)]], "temporal": [(0, 1.0)]}
# MESH: number of grid points per spatial axis (args.xgrid for both) and
# number of time steps (args.nt).
MESH = {"spatial": [args.xgrid, args.xgrid], "temporal": [args.nt]}



#### Define initial condition function ####
def function(u0: str):
    """Return the initial condition u(x, y, t=0) selected by name.

    Args:
        u0: ``'sin'`` selects ``1 + sin(COEFF*x) * sin(COEFF*y)``; any other
            value selects the constant field 1.

    Returns:
        A callable ``f(x, y)`` taking torch tensors and returning a tensor with
        the broadcast shape of ``x`` and ``y``.
    """
    if u0 == 'sin':
        # COEFF is read from the module-level setting when the callable runs.
        return lambda x, y: 1 + torch.sin(COEFF * x) * torch.sin(COEFF * y)
    # Constant initial condition. Return a tensor of ones (not the bare int 1)
    # so both branches yield the same type and shape, and callers can use
    # tensor methods such as .reshape / .shape on the result.
    return lambda x, y: torch.ones_like(x * y)

# Generating mesh function.
def make_mesh(spatial_temporal_linspace):
    """Return every combination of the given 1-D coordinate arrays as rows.

    Each input is flattened and combined with ``np.meshgrid`` (default "xy"
    indexing). The result has one column per input and one row per grid
    point.
    """
    coordinate_grids = np.meshgrid(*spatial_temporal_linspace)
    columns = [grid.reshape(-1) for grid in coordinate_grids]
    return np.column_stack(columns)

# Generating domain function.
def generate_domain(domain=None, mesh=None):
    """Build the space-time point sets on a tensor-product grid.

    Args:
        domain: dict with "spatial" (per spatial axis, a list of
            ``(left, right)`` intervals) and "temporal" (a list whose first
            entry is the ``(start, end)`` time interval). Defaults to the
            module-level ``DOMAIN``.
        mesh: dict with "spatial" (number of grid points per spatial axis)
            and "temporal" (a list whose first entry is the number of time
            points). Defaults to the module-level ``MESH``.

    Returns:
        Tuple of torch tensors, each with columns ``(x_1, ..., x_d, t)``:
        ``(X_star, X_star_noboundary_noinitial, X_star_initial, X_star_boundary)``
          * X_star: the full space-time grid.
          * X_star_noboundary_noinitial: interior spatial nodes at every time
            after the first one (collocation points).
          * X_star_initial: all spatial nodes at the first time (sorted,
            deduplicated by ``np.unique``).
          * X_star_boundary: points whose coordinate on at least one spatial
            axis is that axis's first or last node, at every time (sorted,
            deduplicated by ``np.unique``).
    """
    if domain is None:
        domain = DOMAIN
    if mesh is None:
        mesh = MESH

    print(len(domain["spatial"]), "Dimension")

    # 1-D grids, each stored as an (n, 1) column vector.
    # NOTE: every (left, right) interval is appended as its own axis, so a
    # spatial dimension listing more than one interval would add extra
    # columns; only one interval per dimension is effectively supported.
    spatial_linspace = []
    for i_spatial, intervals in enumerate(domain["spatial"]):
        for left, right in intervals:
            spatial_linspace.append(
                np.linspace(left, right, mesh["spatial"][i_spatial]).reshape(-1, 1)
            )
    before, after = domain["temporal"][0]
    temporal_linspace = [np.linspace(before, after, mesh["temporal"][0]).reshape(-1, 1)]

    X_star = make_mesh(spatial_linspace + temporal_linspace)

    # Collocation mesh: drop the first/last node of every spatial axis and the
    # first time node.
    x_noboundary = [sp[1:-1].flatten() for sp in spatial_linspace]
    t_noinitial = [temporal_linspace[0][1:]]
    X_star_noboundary_noinitial = make_mesh(x_noboundary + t_noinitial)

    # Boundary mesh: for each spatial axis, restrict that axis to its two end
    # nodes while the other axes and time keep their full grids. Segments are
    # collected and stacked once instead of being re-stacked every iteration.
    boundary_segments = []
    for i in range(len(spatial_linspace)):
        x_boundary = [
            np.concatenate((sp[0:1], sp[-1:]), axis=0) if j == i else sp
            for j, sp in enumerate(spatial_linspace)
        ]
        boundary_segments.append(make_mesh(x_boundary + temporal_linspace))
    if boundary_segments:
        X_star_boundary = np.vstack(boundary_segments)
    else:
        # No spatial axes: empty point set with only the time column.
        X_star_boundary = np.empty((0, len(temporal_linspace)))
    # Edges/corners are produced by several axes; np.unique removes duplicates.
    X_star_boundary = np.unique(X_star_boundary, axis=0)

    # Initial mesh: every spatial node at the first time node only.
    X_star_initial = make_mesh(spatial_linspace + [temporal_linspace[0][0:1]])
    X_star_initial = np.unique(X_star_initial, axis=0)

    return (
        torch.tensor(X_star),
        torch.tensor(X_star_noboundary_noinitial),
        torch.tensor(X_star_initial),
        torch.tensor(X_star_boundary),
    )


# Dataset directory generation
def generate_dataset_directory(base_dir=None):
    """Create the on-disk folder layout for the datasets.

    Creates, under ``<base_dir>/dataset``:
      * ``2D_cdr/full``, ``2D_cdr/train``, ``2D_cdr/valid/extrapolation``,
        ``2D_cdr/valid/interpolation``, ``2D_cdr/test/extrapolation``,
        ``2D_cdr/test/interpolation``
      * ``SWE``, ``CNSE`` and ``Darcy01``
    Existing directories are reused, so repeated calls are harmless.

    Args:
        base_dir: parent directory of ``dataset``. Defaults to the
            module-level ``FILE_DIR_PATH``.

    Raises:
        FileExistsError: if one of the target paths exists but is not a
            directory (raised by ``os.makedirs``).
    """
    if base_dir is None:
        base_dir = FILE_DIR_PATH
    dataset_root = os.path.join(base_dir, "dataset")

    cdr_root = os.path.join(dataset_root, "2D_cdr")
    split_dirs = (
        "full",
        "train",
        "valid/extrapolation",
        "valid/interpolation",
        "test/extrapolation",
        "test/interpolation",
    )
    for split_dir in split_dirs:
        # makedirs creates missing parents (including cdr_root itself);
        # exist_ok avoids the isdir-then-create race.
        os.makedirs(os.path.join(cdr_root, split_dir), exist_ok=True)

    for pde_name in ("SWE", "CNSE", "Darcy01"):
        os.makedirs(os.path.join(dataset_root, pde_name), exist_ok=True)


if __name__ == "__main__":
    # When executed as a script, only the dataset folder layout is created;
    # the grid builders above are meant to be imported and called elsewhere.
    generate_dataset_directory()

