"""
For experimenting
- scaling of nodes in a graph: random-feature fcnn vs. random-feature gnn       (FCNN cannot scale with the number of nodes, while GNN error remains almost constant)
- comparison of data-agnostic sampling (ELM) vs. data-driven sampling (SWIM)    (Data-driven random sampling of parameters)
- Training a GNN with N number of nodes and testing it with >>N number of nodes (GNN generalization)
"""

import toml
import json
import os
from time import time
import numpy as np
import torch as tc
from math import ceil
from sklearn.model_selection import train_test_split
from argparse import ArgumentParser
from src.utils import Mesh, LinearSolveArgs, SamplingArgs, eval_dxdt, TradTrainingArgs
from src.data import MassSpring
from src.model import FCNN, GNN
from src.train import sample_and_linear_solve, traditional_training

# ---- Command line and experiment configuration ------------------------------
argparser = ArgumentParser()
argparser.add_argument("-f", "--toml", type=str, help=".toml file for reading config of the experiment", required=True)
argparser.add_argument("-o", "--out", type=str, help="Output directory to save the results of the experiment", required=True)
args = vars(argparser.parse_args())
config = toml.load(args["toml"])

# Map the configured precision onto the matching NumPy / PyTorch dtypes.
if config["dtype"] == "float":
    dtype_np, dtype_tc = np.float32, tc.float32
elif config["dtype"] == "double":
    dtype_np, dtype_tc = np.float64, tc.float64
else:
    raise ValueError("Unknown precision")

dof, auto_diff_mode, device = config["dof"], config["auto_diff_mode"], config["device"]
Nx_start    = config["Nx_start"]        # Scaling number of nodes starting point
Nx_gnn      = config["Nx_gnn"]          # specifies when to not train new GNNs and use last one for inference
Nx_end_gnn  = config["Nx_end_gnn"]      # scaling end for GNN
Nx_end_fcnn = config["Nx_end_fcnn"]     # scaling end for FCNN due to training memory overhead

# Validate the scaling bounds with explicit exceptions: `assert` statements are
# removed when Python runs with -O, which would let a bad config through.
if Nx_start < 1:
    # The scaling loop doubles Nx each iteration, so a non-positive start never terminates.
    raise ValueError("Nx_start must be at least 1 because the number of nodes is doubled every iteration.")
if Nx_gnn <= Nx_start:
    raise ValueError("Nx_gnn must be larger than Nx_start for the GNN to be trained at least once.")
if Nx_start >= Nx_end_fcnn:
    raise ValueError("Nx_end_fcnn must be larger than Nx_start for FCNN to be at least trained once for comparison, ideally more.")
if Nx_gnn >= Nx_end_gnn:
    raise ValueError("Nx_end_gnn must be larger than Nx_gnn for the GNN to be trained at least once before being fixed.")
if Nx_end_gnn <= Nx_end_fcnn:
    raise ValueError("Nx_end_gnn must be greater than Nx_end_fcnn because GNN inference costs less memory than FCNN training.")

# GNN checkpoints are written under the size they were trained on
# (Nx_start, 2*Nx_start, 4*Nx_start, ...), but inference-only steps load the
# checkpoint named after Nx_gnn. If Nx_gnn is not one of the visited sizes and
# an inference-only step is reached, that file is never produced by this run,
# so fail now instead of after all the training has been done.
_first_size_from_nx_gnn = Nx_start
while _first_size_from_nx_gnn < Nx_gnn:
    _first_size_from_nx_gnn *= 2
if _first_size_from_nx_gnn != Nx_gnn and _first_size_from_nx_gnn <= Nx_end_gnn:
    raise ValueError(
        f"Nx_gnn ({Nx_gnn}) must be Nx_start ({Nx_start}) times a power of two, "
        f"otherwise no GNN checkpoint for Nx_gnn exists when inference starts at Nx={_first_size_from_nx_gnn}."
    )

out_dir_path= os.path.join(args["out"], f"chain_start{Nx_start}_end{Nx_end_fcnn}fcnn_{Nx_end_gnn}gnn")
# exist_ok avoids the check-then-create race and is a no-op when the directory is already there.
os.makedirs(out_dir_path, exist_ok=True)
data_config, model_config, train_config = config["data"], config["model"], config["train"]

# Experiment results will be gathered here
results = { "n_nodes": [], "elm-rf-hnn": [], "swim-rf-hnn": [], "adam-hnn": [],
                           "elm-rf-hgn": [], "swim-rf-hgn": [], "adam-hgn": [] }

def prepare_data(n_obj, data_seed):
    """
    Sample random states of a mass-spring system and split them into train/test sets.

    `n_points` states with `prod(n_obj) * 2 * dof` features each are drawn
    uniformly between `x_min` and `x_max` (all read from `data_config`), and
    split in half along the last axis into `q` and `p` to build a `MassSpring`.

    Args:
        n_obj: list with the number of objects per dimension, e.g. `[Nx]`.
        data_seed: seed for the sampling; `data_seed + 1` seeds the shuffled split.

    Returns:
        Tuple `(x_train, x_test, L, edge_index, dxdt_train, dxdt_test)` where
        `L` is `system.L()` as a tensor and `edge_index` is `system.edge_index()`.
    """
    generator = np.random.default_rng(data_seed)
    sample_count = data_config["n_points"]
    feature_count = np.prod(n_obj) * 2*dof
    states = generator.uniform(
        data_config["x_min"],
        data_config["x_max"],
        size=(sample_count, feature_count),
    ).astype(dtype_np)
    q, p = np.split(states, 2, axis=-1)

    system = MassSpring(
        sample_count, feature_count, q, p, n_obj, dof,
        data_config["mass"], data_config["spring_constant"],
        Mesh(data_config["meshing"]),
    )

    inputs = tc.from_numpy(system.to_array(flatten=True))
    targets = tc.from_numpy(system.dxdt(flatten=True))
    # train_test_split() is given an absolute train size (rounded up), not a fraction.
    train_count = ceil(sample_count*data_config["train_test_split"])
    x_train, x_test, dxdt_train, dxdt_test = train_test_split(
        inputs, targets, train_size=train_count, shuffle=True, random_state=data_seed+1
    )

    system_L = tc.from_numpy(system.L())
    return x_train, x_test, system_L, system.edge_index(), dxdt_train, dxdt_test

def create_fcnn(n_obj, model_seed):
    """
    Build an `FCNN` for a system with `n_obj` objects.

    Width, activation and init method are read from `model_config`; the input
    size is `prod(n_obj) * 2 * dof`. `model_seed` is forwarded as the model's seed.
    """
    fcnn_kwargs = {
        "dof": dof,
        "n_obj": n_obj,
        "n_features": np.prod(n_obj) * 2*dof,
        "width": model_config["width"],
        "activ_str": model_config["activ_str"],
        "init_method": model_config["init_method"],
        "seed": model_seed,
        "dtype": dtype_tc,
    }
    return FCNN(**fcnn_kwargs)

def create_gnn(n_obj, edge_index, model_seed):
    """
    Build a `GNN` on the graph given by `edge_index` for `n_obj` objects.

    The configured `width` is divided between the two sub-widths: `enc_width`
    is used as configured and the message width is `width - enc_width`.
    The remaining options are read from `model_config`.
    """
    encoder_width = model_config["enc_width"]
    message_width = model_config["width"] - encoder_width
    gnn_kwargs = {
        "dof": dof,
        "n_obj": n_obj,
        "edge_index": edge_index,
        "direct": model_config["direct"],
        "msg_width": message_width,
        "enc_width": encoder_width,
        "local_pooling": model_config["local_pooling"],
        "global_pooling": model_config["global_pooling"],
        "activ_str": model_config["activ_str"],
        "init_method": model_config["init_method"],
        "seed": model_seed,
        "dtype": dtype_tc,
    }
    return GNN(**gnn_kwargs)

def sample_fit(model, x_train, L, dxdt_train, param_sampler, sample_uniformly, sampling_seed):
    """
    Fit `model` with `sample_and_linear_solve` and evaluate it on the training data.

    Args:
        model: model to fit (modified by `sample_and_linear_solve`).
        x_train, dxdt_train: training inputs and their target time derivatives.
        L: passed through unchanged to the solver and to `eval_dxdt`.
        param_sampler: sampler name forwarded to `SamplingArgs`.
        sample_uniformly: forwarded to `SamplingArgs`.
        sampling_seed: seed forwarded to `SamplingArgs`.

    Returns:
        `(mse, relative_l2)` of the fitted model on the training set.
    """
    sampler_cfg = SamplingArgs(
        param_sampler=param_sampler,
        seed=sampling_seed,
        sample_uniformly=sample_uniformly,
        elm_bias_start=train_config["elm_bias_start"],
        elm_bias_end=train_config["elm_bias_end"],
        resample_duplicates=train_config["resample_duplicates"],
        dtype=dtype_np,
    )
    solver_cfg = LinearSolveArgs(
        mode=auto_diff_mode,
        driver=train_config["driver"],
        rcond=train_config["rcond"],
        device=device,
    )

    sample_and_linear_solve(model, x_train, L, dxdt_train, sampler_cfg, solver_cfg)

    # Report how well the freshly solved model reproduces its own training targets.
    train_mse, train_rel2 = eval_dxdt(model, x_train, L, dxdt_train, mode=auto_diff_mode, verbose=False)
    return train_mse, train_rel2

def trad_train(model, x_train, x_test, L, dxdt_train, dxdt_test):
    """
    Train an FCNN or GNN the traditional way (gradient-descent based optimization).

    All hyperparameters come from `train_config`; the device is the globally
    configured one. Returns whatever `traditional_training` returns.
    """
    # Keys copied verbatim from the [train] section of the config.
    hyperparameter_names = (
        "n_steps", "batch_size", "weight_init", "lr_start", "lr_end",
        "weight_decay", "patience", "optim_type", "sched_type",
    )
    hyperparameters = {name: train_config[name] for name in hyperparameter_names}
    training_args = TradTrainingArgs(device=device, **hyperparameters)
    return traditional_training(model, x_train, x_test, L, dxdt_train, dxdt_test, training_args)

def train_and_save(model, x_train, x_test, L, dxdt_train, dxdt_test, param_sampler, model_name, sampling_seed, is_random_feature_model):
    """
    Train `model`, record its errors and timings, and save it to the output directory.

    Random-feature models are fitted with `sample_fit`; all others go through
    `trad_train`. A record with train/test MSE, train/test relative L2 error,
    training time and inference time is appended to `results[model_name]`,
    printed, and the model is saved as `<model_name>_Nx<Nx>.pt`.

    Note that `Nx` in the file name is read from the module-level variable of
    that name, not from an argument.

    NOTE(review): `sample_fit` evaluates on the training data internally, so the
    random-feature `train_time` includes that evaluation, whereas the
    traditional `train_time` does not - confirm this is intended when comparing.
    """
    fit_started = time()
    if is_random_feature_model:
        train_mse, train_rel2 = sample_fit(model, x_train, L, dxdt_train, param_sampler, sample_uniformly=True, sampling_seed=sampling_seed)
        train_time = time() - fit_started
    else:
        trad_train(model, x_train, x_test, L, dxdt_train, dxdt_test)
        train_time = time() - fit_started
        # Training errors are measured after the clock has stopped.
        train_mse, train_rel2 = eval_dxdt(model, x_train, L, dxdt_train, mode=auto_diff_mode, verbose=False)

    # Inference time is the wall time of one evaluation over the test set.
    infer_started = time()
    test_mse, test_rel2 = eval_dxdt(model, x_test, L, dxdt_test, mode=auto_diff_mode, verbose=False)
    infer_time = time() - infer_started

    record = {
        "mse_train": train_mse, "relative_l2_train": train_rel2,
        "mse_test": test_mse, "relative_l2_test": test_rel2,
        "train_time": train_time, "infer_time": infer_time,
    }
    results[model_name].append(record)

    print(f"\t{model_name}\t\t\t{results[model_name][-1]}")
    tc.save(model, os.path.join(out_dir_path, f"{model_name}_Nx{Nx}.pt"))
    # Drops only this function's reference; the caller may still hold the model.
    del model
    tc.cuda.empty_cache()

def load_and_save(Nx, edge_index, model_name, x_test, dxdt_test):    # only for gnn
    """
    Load the GNN saved for `Nx_gnn` nodes and evaluate it on a system with `Nx` nodes.

    The loaded model's `n_obj` and `edge_index` are overwritten to describe the
    current system, then test MSE, relative L2 error and inference time are
    appended to `results[model_name]` and printed. Nothing is written to disk.

    Note that `L` passed to `eval_dxdt` and `Nx_gnn` in the checkpoint name are
    read from module-level variables, not from arguments.
    """
    checkpoint_path = os.path.join(out_dir_path, f"{model_name}_Nx{Nx_gnn}.pt")
    # weights_only=False unpickles the whole saved object; acceptable here only
    # because the checkpoint is expected to be one this script wrote itself.
    model = tc.load(checkpoint_path, weights_only=False)
    model.n_obj = [Nx]
    model.edge_index = edge_index

    infer_started = time()
    test_mse, test_rel2 = eval_dxdt(model, x_test, L, dxdt_test, mode=auto_diff_mode, verbose=False)
    infer_time = time() - infer_started

    record = {"mse_test": test_mse, "relative_l2_test": test_rel2, "infer_time": infer_time}
    results[model_name].append(record)
    print(f"\t{model_name}\t\t\t{results[model_name][-1]}")
    del model

# Seeds shared by every run of the experiment
data_seed = data_config["data_seed"]
model_seed = model_config["model_seed"]
sampling_seed = train_config["sampling_seed"]


def _warm_up():
    """
    Run one throw-away ELM fit of an FCNN on the smallest system.

    Nothing is recorded; data and model are local and released on return.
    Presumably this keeps one-off start-up costs out of the measured timings.
    """
    x_train, x_test, L, edge_index, dxdt_train, dxdt_test = prepare_data([Nx_start], data_seed)
    model = create_fcnn([Nx_start], model_seed)
    sample_fit(model, x_train, L, dxdt_train, param_sampler="random", sample_uniformly=True, sampling_seed=sampling_seed)


print("Warming up...")
_warm_up()
tc.cuda.empty_cache()
print("Warm-up complete.")

# Scale Nx exponentially [2,4,8,16,32,64,128,256...]

# Each variant is (param_sampler, key in `results`, is_random_feature_model).
fcnn_variants = (
    ("random", "elm-rf-hnn", True),     # (ELM) Random-Feature Hamiltonian Neural Network (RF-HNN)
    ("relu", "swim-rf-hnn", True),      # (SWIM) Random-Feature Hamiltonian Neural Network (RF-HNN)
    (None, "adam-hnn", False),          # (Adam) Hamiltonian Neural Network (HNN)
)
gnn_variants = (
    ("random", "elm-rf-hgn", True),     # (ELM) Random-Feature Hamiltonian Graph Network (RF-HGN)
    ("relu", "swim-rf-hgn", True),      # (SWIM) Random-Feature Hamiltonian Graph Network (RF-HGN)
    (None, "adam-hgn", False),          # (Adam) Hamiltonian Graph Network (HGN)
)

Nx = Nx_start
while Nx <= Nx_end_gnn:
    print(f"-> training for n_obj = [{Nx}]")
    # ==== Prepare data
    n_obj = [Nx]
    results["n_nodes"].append(n_obj)
    x_train, x_test, L, edge_index, dxdt_train, dxdt_test = prepare_data(n_obj, data_seed)

    # ==== FCNN: a fresh model is trained for every size below Nx_end_fcnn
    if Nx < Nx_end_fcnn:
        for variant_sampler, variant_name, variant_is_rf in fcnn_variants:
            model = create_fcnn(n_obj, model_seed)
            train_and_save(model, x_train, x_test, L, dxdt_train, dxdt_test,
                           param_sampler=variant_sampler, model_name=variant_name,
                           sampling_seed=sampling_seed, is_random_feature_model=variant_is_rf)

    # ==== GNN: train up to and including Nx_gnn, afterwards only run inference
    if Nx <= Nx_gnn:
        for variant_sampler, variant_name, variant_is_rf in gnn_variants:
            model = create_gnn(n_obj, edge_index, model_seed)
            train_and_save(model, x_train, x_test, L, dxdt_train, dxdt_test,
                           param_sampler=variant_sampler, model_name=variant_name,
                           sampling_seed=sampling_seed, is_random_feature_model=variant_is_rf)
    else:
        # Do not train new GNNs: load the one already trained with Nx_gnn nodes and save results
        for _unused_sampler, variant_name, _unused_flag in gnn_variants:
            load_and_save(Nx, edge_index, variant_name, x_test, dxdt_test)

    # Double the number of nodes; the feature count of every data point grows with it.
    Nx *= 2

# Save results
# `with` closes the file even if serialization raises part-way through.
results_file_path = os.path.join(out_dir_path, "results.json")
with open(results_file_path, "w") as results_file:
    json.dump(results, results_file, indent=4)

# Save configuration for this experiment
config_file_path = os.path.join(out_dir_path, "config.toml")
with open(config_file_path, "w") as config_file:
    toml.dump(config, config_file)

# Exit with status 0 without relying on the `exit` helper injected by the site module.
raise SystemExit(0)
