import itertools
import math
import os
import sys

import torch

sys.path.append("..")

from config import *
from system import *
from utils import *
from data_gen.dataset_util import *
from PINN.adaptivePINN import *
from PINN.naive import *
from dataloader import *
# Command-line options, parsed once when the module loads. `parser` comes from
# one of the wildcard imports above (presumably `config`).
args = parser.parse_args()

# Absolute location of this script; the dataset output folder is anchored here.
FILE_ABS_PATH = os.path.abspath(__file__)
FILE_DIR_PATH = os.path.dirname(FILE_ABS_PATH)

# Use the GPU whenever PyTorch can see one, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

def generate_dataset():
    
    # Determine parameter intervals.
    beta_interval = int((args.beta_max - args.beta_min) // args.beta_step)
    nu_interval = int((args.nu_max - args.nu_min) // args.nu_step)
    rho_interval = int((args.rho_max - args.rho_min) // args.rho_step)
    epsilon_interval = int((args.epsilon_max - args.epsilon_min) // args.epsilon_step)
    theta_interval = int((args.theta_max - args.theta_min) // args.theta_step)

    # Bring a domain.
    X_star, X_star_noboundary_noinitial, X_star_initial, X_star_boundary = generate_domain()
    
    # For each beta, nu, rho, epsilon and theta, generate the prior.
    i, j, k, l, m = 0, 0, 0, 0, 0
    for i in range(beta_interval + 1):
        for j in range(nu_interval + 1):
            for k in range(rho_interval + 1):
                for l in range(epsilon_interval + 1):
                    for m in range(theta_interval + 1):
                
                        beta = args.beta_min + (i * args.beta_step)
                        nu = args.nu_min + (j * args.nu_step)
                        rho = args.rho_min + (k * args.rho_step)
                        epsilon = args.epsilon_min + (l * args.epsilon_step)
                        theta = args.theta_min + (m * args.theta_step)
                
                        # Obtain a save path and check if there is already a solution.
                        system = system_determinator(beta, nu, rho, epsilon, theta)
                        str_analytical = "analytical" if args.numerical else "PINN_based"
                        save_path_root = f"{FILE_DIR_PATH}/dataset/{system}/{args.u0_str}"
                        file_name = f"{beta}_{nu}_{rho}_{epsilon}_{theta}_{str_analytical}.csv"
                        full_save_path_dir = save_path_root + DATASET_DIRECTORIES[0]
                        full_save_path = f"{full_save_path_dir}/{file_name}"
                        #if os.path.isfile(full_save_path): continue
                            
                
                        # Generate the solution in full mesh.
                        u_vals = 0
                 
                        if args.numerical:
                            u_vals, u_v = convection_diffusion_reaction_discrete_solution(args.u0_str, beta, nu, rho, epsilon, theta, MESH["spatial"][0], MESH["temporal"][0])
                            u_vals = torch.tensor(u_vals)
                            u_vals = u_vals.unsqueeze(1)
                        else:
                            DOMAIN_LIST = [X_star, X_star_noboundary_noinitial, X_star_initial, X_star_boundary]
                            
                            # naive PINN
                            for redo in range(10):
                                pinn = NaivePINN(beta, nu, rho, epsilon, theta)
                                result = pinn.train(DOMAIN_LIST, args.PINN_epoch)
                                if result > 0: break
                            u_vals = pinn.predict(X_star[:, 0:1].to(device, dtype=torch.float), X_star[:, 1:2].to(device, dtype=torch.float))\
                                .detach().cpu()
                            
                        full_data = torch.cat((X_star, u_vals), dim=1)
                        
                        N_train_shot, N_train_unshot, N_test, N_val = 800, 200, 1000, 1000
                        X_train_shot, X_train_unshot, X_test, X_val = 0, 0, 0, 0
                        if not args.extrapolation:
                            N_train_shot, N_train_unshot, N_test, N_val = 800, 200, 1000, 1000
                            indices = torch.randperm(X_star_noboundary_noinitial.shape[0])
                            X_train_shot = X_star_noboundary_noinitial[indices[:N_train_shot], :]
                            X_train_unshot = X_star_noboundary_noinitial[indices[N_train_shot:N_train_shot+N_train_unshot], :]
                            X_test = X_star_noboundary_noinitial[indices[N_train_shot+N_train_unshot:N_train_shot+N_train_unshot+N_test], :]
                            X_val = X_star_noboundary_noinitial[indices[N_train_shot+N_train_unshot+N_test:N_train_shot+N_train_unshot+N_test+N_val], :]
                        else:
                            X_below = X_star_noboundary_noinitial[X_star_noboundary_noinitial[:, 1] <= 0.6]
                            X_above = X_star_noboundary_noinitial[X_star_noboundary_noinitial[:, 1] > 0.6]
                            below_indices = torch.randperm(X_below.shape[0])
                            above_indices = torch.randperm(X_above.shape[0])
                            X_train_shot = X_below[below_indices[:N_train_shot], :]
                            X_train_unshot = X_below[below_indices[N_train_shot:N_train_shot + N_train_unshot], :]
                            X_val = X_below[below_indices[N_train_shot+N_train_unshot:N_train_shot + N_train_unshot+N_val], :]
                            X_test = X_above[above_indices[:N_test], :]
                        
                        # Match the solution.
                        boundary_data = match(X_star_boundary, full_data)
                        initial_data = match(X_star_initial, full_data)
                        train_shot_data = match(X_train_shot, full_data)
                        train_unshot_data = match(X_train_unshot, full_data)
                        test_data = match(X_test, full_data)
                        val_data = match(X_val, full_data)
                        
                        
                        # Make .csv file.
                        column = ['x_data', 't_data', 'u_data'] if not args.is_2D else ['x_data', 'y_data', 't_data', 'u_data']
                        if not args.extrapolation:
                            dataset_collection = [full_data, initial_data, boundary_data, train_shot_data, None, train_unshot_data, None, test_data, None, val_data]
                        else:
                            dataset_collection = [None, None, None, None, train_shot_data, None, train_unshot_data, None, test_data, None]
                        assert len(dataset_collection) == len(DATASET_DIRECTORIES), "Length mismatch."
                        for dataset_elem in range(len(dataset_collection)):
                            if dataset_collection[dataset_elem] != None:
                                make_csv(dataset_collection[dataset_elem], column, f"{save_path_root}{DATASET_DIRECTORIES[dataset_elem]}/{file_name}")
                        
                        # Print the log.
                        print(beta, nu, rho, epsilon, theta)

    print("Generating dataset is done.")

if __name__ == "__main__":
    # Script entry point: generate every dataset described by the parsed command-line `args`.
    generate_dataset()