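# Dataset preparation utilities: build the dataset folder structure, generate
# the spatio-temporal sample meshes, and evaluate the L2 relative error of the
# PINN-based prior against the analytical solution.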

import os
import sys
import math
import pandas as pd

FILE_ABS_PATH = os.path.abspath(__file__)
FILE_DIR_PATH = os.path.dirname(FILE_ABS_PATH)
# The project modules star-imported below (system, config, dataloader) are
# presumably located one level above this file. Resolve that directory from the
# file's own location so the imports work regardless of the current working
# directory; the original cwd-relative ".." entry is kept as a fallback.
sys.path.append(os.path.normpath(os.path.join(FILE_DIR_PATH, "..")))
sys.path.append("..")

from system import *
from config import *
from dataloader import *
# Parse the command-line options. `parser` comes from one of the star imports
# above (presumably config); sys.argv[1:] is what argparse uses by default.
args = parser.parse_args(sys.argv[1:])


# Dataset directory structure
DATASET_DIRECTORIES = ["/full", \
    "/train/initial", "/train/boundary", \
    "/train/collocation/shot/interpolation", "/train/collocation/shot/extrapolation", \
    "/train/collocation/unshot/interpolation", "/train/collocation/unshot/extrapolation",\
    "/test/interpolation", "/test/extrapolation", \
    "/val"]

# Domain bounds and grid sizes.
# DOMAIN["spatial"] holds, per spatial dimension, a list of (left, right)
# intervals; DOMAIN["temporal"] holds the (start, end) time interval.
# MESH holds the matching number of linspace points per axis.
if args.is_2D == 1:
    # 2D case: unit square in space; the spatial grid is fixed at 20 x 20
    # (args.xgrid is not used here).
    DOMAIN = {"spatial": [[(0, 1)], [(0, 1)]], "temporal": [(0, 1)]}
    MESH = {"spatial": [20, 20], "temporal": [args.nt]}
else:
    # 1D case: x in [0, 2*pi] with args.xgrid points.
    DOMAIN = {"spatial": [[(0, 2 * math.pi)]], "temporal": [(0, 1)]}
    MESH = {"spatial": [args.xgrid], "temporal": [args.nt]}

# Dataset directory generation
def generate_dataset_directory(*, base_dir=None, systems=None, init_conds=None, directories=None):
    """Create the on-disk dataset folder tree.

    For every system and every initial condition this creates
    ``<base_dir>/dataset/<system>/<init_cond>`` and each sub-directory listed
    in ``directories`` below it. Existing directories are left untouched, so
    the function can safely be re-run.

    Args:
        base_dir: Root under which ``dataset/`` is created. Defaults to the
            directory containing this file (``FILE_DIR_PATH``).
        systems: Iterable of system names. Defaults to ``SYSTEM``.
        init_conds: Iterable of initial-condition names. Defaults to
            ``INIT_COND``.
        directories: Sub-directory paths relative to each
            ``<system>/<init_cond>`` folder; a leading "/" is allowed.
            Defaults to ``DATASET_DIRECTORIES``.

    Raises:
        OSError: propagated from ``os.makedirs`` (e.g. ``FileExistsError``
            when a target path exists but is not a directory).
    """
    base_dir = FILE_DIR_PATH if base_dir is None else base_dir
    systems = SYSTEM if systems is None else systems
    init_conds = INIT_COND if init_conds is None else init_conds
    directories = DATASET_DIRECTORIES if directories is None else directories

    for system_elem in systems:
        system_path = f"{base_dir}/dataset/{system_elem}"
        # Create the system folder even when there are no initial conditions.
        os.makedirs(system_path, exist_ok=True)
        for init_cond_elem in init_conds:
            cond_path = f"{system_path}/{init_cond_elem}"
            os.makedirs(cond_path, exist_ok=True)
            for dir_elem in directories:
                # Entries start with "/"; strip it to avoid "cond_path//full".
                os.makedirs(f"{cond_path}/{dir_elem.lstrip('/')}", exist_ok=True)

# Generating mesh function.
def make_mesh(spatial_temporal_linspace):
    """Return every point of the tensor-product grid spanned by the given axes.

    Args:
        spatial_temporal_linspace: Sequence of array-likes, one per axis.
            Column vectors and flat arrays both work, because ``np.meshgrid``
            flattens each input.

    Returns:
        A 2-D array with one row per grid point and one column per axis.
        Rows follow ``np.meshgrid``'s default ``"xy"`` indexing, in which the
        first axis varies fastest.
    """
    grids = np.meshgrid(*spatial_temporal_linspace)
    columns = [grid.ravel() for grid in grids]
    return np.column_stack(columns)

# Generating domain function.
def generate_domain(*, domain=None, mesh=None):
    """Build the full, collocation, initial and boundary sample meshes.

    Each spatial interval and the temporal interval are discretized with
    ``np.linspace`` (endpoints included) using the point counts in ``mesh``.

    Args:
        domain: Dict with keys ``"spatial"`` (per spatial dimension, a list of
            ``(left, right)`` intervals) and ``"temporal"`` (list whose first
            entry is the ``(start, end)`` time interval). Defaults to the
            module-level ``DOMAIN``.
        mesh: Dict with keys ``"spatial"`` (point count per spatial dimension)
            and ``"temporal"`` (list whose first entry is the number of time
            points). Defaults to the module-level ``MESH``.

    Returns:
        Tuple of four ``torch`` tensors. Each has one row per point and
        columns ``(x_1, ..., x_d, t)``:
          - ``X_star``: every grid point.
          - ``X_star_noboundary_noinitial``: spatial points without the first
            and last value of each spatial axis, at every time except the first.
          - ``X_star_initial``: every spatial grid point at the first time
            value (rows deduplicated and sorted by ``np.unique``).
          - ``X_star_boundary``: points whose coordinate along some spatial
            axis is that axis's first or last value, over all times (rows
            deduplicated and sorted by ``np.unique``).
    """
    domain = DOMAIN if domain is None else domain
    mesh = MESH if mesh is None else mesh

    # One column-vector linspace per spatial interval.
    # NOTE(review): every interval becomes its own axis and is sized by its
    # dimension's point count, which is only correct while each spatial
    # dimension holds a single interval (as in DOMAIN).
    spatial_linspace = [
        np.linspace(left, right, mesh["spatial"][i_spatial]).reshape(-1, 1)
        for i_spatial, intervals in enumerate(domain["spatial"])
        for left, right in intervals
    ]
    t_start, t_end = domain["temporal"][0]
    temporal_linspace = [np.linspace(t_start, t_end, mesh["temporal"][0]).reshape(-1, 1)]

    # Full grid.
    X_star = make_mesh(spatial_linspace + temporal_linspace)

    # Collocation grid: drop both spatial endpoints and the first time value.
    x_noboundary = [sp[1:-1].flatten() for sp in spatial_linspace]
    t_noinitial = [temporal_linspace[0][1:]]
    X_star_noboundary_noinitial = make_mesh(x_noboundary + t_noinitial)

    # Boundary grid: for each spatial axis keep only its two endpoints while
    # the other axes (and time) span their full range.
    boundary_segments = []
    for i_dim in range(len(spatial_linspace)):
        x_boundary = [
            np.concatenate((sp[0:1], sp[-1:]), axis=0) if j == i_dim else sp
            for j, sp in enumerate(spatial_linspace)
        ]
        boundary_segments.append(make_mesh(x_boundary + temporal_linspace))
    if boundary_segments:
        # Corner points appear in several segments; np.unique removes duplicates.
        X_star_boundary = np.unique(np.vstack(boundary_segments), axis=0)
    else:
        X_star_boundary = np.empty((0, len(spatial_linspace) + 1))

    # Initial grid: every spatial point at the first time value.
    X_star_initial = make_mesh(spatial_linspace + [temporal_linspace[0][0:1]])
    X_star_initial = np.unique(X_star_initial, axis=0)

    return (
        torch.tensor(X_star),
        torch.tensor(X_star_noboundary_noinitial),
        torch.tensor(X_star_initial),
        torch.tensor(X_star_boundary),
    )

def _parameter_grid(start, stop, step):
    """Return ``start, start + step, ...`` up to ``stop`` (inclusive when reached).

    When ``start`` is an ``int``, integer-valued points are returned as
    ``int``. They then format like the former ``start + i`` values ("5"
    rather than "5.0") in the dataset file names.

    Raises:
        ValueError: if ``step`` is not positive.
    """
    if step <= 0:
        raise ValueError(f"step must be positive, got {step}")
    n_steps = int((stop - start) // step)
    values = []
    for idx in range(n_steps + 1):
        value = start + idx * step
        if isinstance(start, int) and float(value).is_integer():
            value = int(value)
        values.append(value)
    return values


def calculate_loss():
    """Print and return the mean L2 relative error of the PINN-based prior.

    For every ``(beta, nu, rho)`` on the grid from ``args.*_min`` to
    ``args.*_max`` with spacing ``args.*_step``, this loads two files from
    ``dataset/<system>/<args.u0_str>/full/``:
    ``<beta>_<nu>_<rho>_0.0_0.0_PINN_based.csv`` (prior) and
    ``<beta>_<nu>_<rho>_0.0_0.0_analytical.csv`` (solution).
    It then computes ``||prior - solution||_2 / ||solution||_2`` over column 2
    of the tensors returned by ``cdr_DATA(...).get_all()``.

    Returns:
        The mean L2 relative error over the grid, or ``nan`` if the grid is
        empty (some ``*_max`` is smaller than its ``*_min``).
    """
    import itertools

    beta_values = _parameter_grid(args.beta_min, args.beta_max, args.beta_step)
    nu_values = _parameter_grid(args.nu_min, args.nu_max, args.nu_step)
    rho_values = _parameter_grid(args.rho_min, args.rho_max, args.rho_step)

    L2_RE_loss_list = []
    # Same visiting order as nested loops: beta outermost, rho innermost.
    for beta, nu, rho in itertools.product(beta_values, nu_values, rho_values):
        system = system_determinator(beta, nu, rho, 0, 0)
        data_dir = f"{FILE_DIR_PATH}/dataset/{system}/{args.u0_str}/full"

        # Load prior.
        PRIOR_LOAD_PATH = f"{data_dir}/{beta}_{nu}_{rho}_0.0_0.0_PINN_based.csv"
        prior_test_collocation_data = cdr_DATA(pd.read_csv(PRIOR_LOAD_PATH)).get_all()

        # Load solution.
        SOLUTION_LOAD_PATH = f"{data_dir}/{beta}_{nu}_{rho}_0.0_0.0_analytical.csv"
        solution_full_collocation_data = cdr_DATA(pd.read_csv(SOLUTION_LOAD_PATH)).get_all()

        # Column 2 presumably holds the field value u (after the coordinate
        # columns) -- TODO confirm against cdr_DATA.
        prior_u = prior_test_collocation_data[:, 2]
        true_u = solution_full_collocation_data[:, 2]
        L2_error_norm = torch.linalg.norm(prior_u - true_u, 2, dim=0)
        L2_true_norm = torch.linalg.norm(true_u, 2, dim=0)
        L2_RE_loss_list.append((L2_error_norm / L2_true_norm).item())

    mean_error = (
        sum(L2_RE_loss_list) / len(L2_RE_loss_list) if L2_RE_loss_list else float("nan")
    )

    print("=================================")
    print(f"beta range : {args.beta_min} ~ {args.beta_max}")
    print(f"nu range : {args.nu_min} ~ {args.nu_max}")
    print(f"rho range : {args.rho_min} ~ {args.rho_max}")
    print("List:", L2_RE_loss_list)
    print("Result:", mean_error)
    print("=================================")
    return mean_error

def _main():
    """Script entry point: build the dataset directory skeleton only."""
    generate_dataset_directory()


if __name__ == "__main__":
    _main()
    