"""
Experiment comparing Adam-trained RF-HGNs against SWIM-sampled RF-HGNs
- #inputs: how accuracy and training time change with the number of input samples
- Accuracy: MSE and Relative L2
- Training time: Time to solution
- Generalization: Evaluation beyond training set (test set error)
- Stability: Models are trained with different random seeds.

(training time, accuracy, #hyperparams, universal-approximation)

Target system: 3DOF Lattice [Nx,Ny] where
- Nx:   Number of nodes along x-axis
- Ny:   Number of nodes along y-axis
"""

import toml
import json
import os
from time import time
import numpy as np
import torch as tc
from math import ceil
from sklearn.model_selection import train_test_split
from argparse import ArgumentParser
from src.utils import Mesh, LinearSolveArgs, SamplingArgs, eval_dxdt, TradTrainingArgs
from src.data import MassSpring
from src.model import FCNN, GNN
from src.train import sample_and_linear_solve, traditional_training

# ---- Command-line interface: experiment config in, results directory out
argparser = ArgumentParser()
argparser.add_argument(
    "-f", "--toml",
    type=str,
    required=True,
    help=".toml file for reading config of the experiment",
)
argparser.add_argument(
    "-o", "--out",
    type=str,
    required=True,
    help="Output directory to save the results of the experiment",
)
args = vars(argparser.parse_args())
config = toml.load(args["toml"])

# ---- Floating-point precision, selected once for both NumPy and torch
precision = config["dtype"]
if precision == "double":
    dtype_np, dtype_tc = np.float64, tc.float64
elif precision == "float":
    dtype_np, dtype_tc = np.float32, tc.float32
else:
    raise ValueError("Unknown precision")

# ---- Global settings read by the helper functions below
dof = config["dof"]
auto_diff_mode = config["auto_diff_mode"]
device = config["device"]

# Lattice size and number of repetitions of the experiment
N = config["N"]                 # [Nx,Ny]
repeat = config["repeat"]

# Output layout: <out>/lattice<N>_<device>/ for results, with a models/ subfolder
out_dir_path = os.path.join(args["out"], f"lattice{N}_{device}")
model_out_dir_path = os.path.join(out_dir_path, "models")
# exist_ok=True replaces the exists()-then-makedirs() pattern, which can fail
# if the directory appears between the check and the creation.
os.makedirs(out_dir_path, exist_ok=True)
os.makedirs(model_out_dir_path, exist_ok=True)

# Config sections used by the data, model and training helpers
data_config, model_config, train_config = config["data"], config["model"], config["train"]

# Experiment results will be gathered here (one dict per run)
all_results = []

def prepare_data(n_obj, data_seed):
    """
    Sample random lattice states and split them into train and test sets.

    States are drawn with ``rng.uniform`` between ``x_min`` and ``x_max`` from
    the ``data`` config section, using a generator seeded with ``data_seed``.
    Each state vector is split in half along the last axis; the first half is
    passed to ``MassSpring`` as ``q`` and the second half as ``p``. The
    train/test shuffle uses ``data_seed + 1`` as its random state.

    Returns:
        (x_train, x_test, L, edge_index, dxdt_train, dxdt_test), where ``L`` is
        the tensor form of ``system.L()`` and ``edge_index`` is whatever
        ``system.edge_index()`` returns.
    """
    n_points = data_config["n_points"]
    n_features = np.prod(n_obj) * 2*dof
    lower, upper = data_config["x_min"], data_config["x_max"]

    rng = np.random.default_rng(data_seed)
    samples = rng.uniform(lower, upper, size=(n_points, n_features)).astype(dtype_np)
    q, p = np.split(samples, 2, axis=-1)

    system = MassSpring(
        n_points, n_features, q, p, n_obj, dof,
        data_config["mass"],
        data_config["spring_constant"],
        Mesh(data_config["meshing"]),
    )

    states = tc.from_numpy(system.to_array(flatten=True))
    derivatives = tc.from_numpy(system.dxdt(flatten=True))
    n_train = ceil(n_points*data_config["train_test_split"])
    x_train, x_test, dxdt_train, dxdt_test = train_test_split(
        states, derivatives,
        train_size=n_train, shuffle=True, random_state=data_seed+1,
    )

    L = tc.from_numpy(system.L())
    edge_index = system.edge_index()
    return x_train, x_test, L, edge_index, dxdt_train, dxdt_test

def create_gnn(n_obj, edge_index, model_seed):
    """
    Build a GNN for the lattice ``n_obj`` from the ``model`` config section.

    The message width handed to the GNN is the configured ``width`` minus the
    configured ``enc_width``. ``model_seed`` is forwarded as the model's seed.
    """
    enc_width = model_config["enc_width"]
    msg_width = model_config["width"] - enc_width
    return GNN(
        dof=dof,
        n_obj=n_obj,
        edge_index=edge_index,
        direct=model_config["direct"],
        msg_width=msg_width,
        enc_width=enc_width,
        local_pooling=model_config["local_pooling"],
        global_pooling=model_config["global_pooling"],
        activ_str=model_config["activ_str"],
        init_method=model_config["init_method"],
        seed=model_seed,
        dtype=dtype_tc,
    )

def create_fcnn(n_obj, model_seed):
    """
    Build an FCNN for the lattice ``n_obj`` from the ``model`` config section.

    The input size is ``prod(n_obj) * 2 * dof``, the same feature count that
    ``prepare_data`` uses for the sampled states.
    """
    n_features = np.prod(n_obj) * 2*dof
    return FCNN(
        dof=dof,
        n_obj=n_obj,
        n_features=n_features,
        width=model_config["width"],
        activ_str=model_config["activ_str"],
        init_method=model_config["init_method"],
        seed=model_seed,
        dtype=dtype_tc,
    )

def sample_fit(model, x_train, L, dxdt_train, param_sampler, sample_uniformly, sampling_seed):
    """
    Fit ``model`` with ``sample_and_linear_solve`` and report its training error.

    Sampling options combine the arguments of this function with the ELM bias
    range and ``resample_duplicates`` flag from the ``train`` config section.
    Linear-solve options take ``driver`` and ``rcond`` from the same section,
    plus the global auto-diff mode and device.

    Returns:
        The two values produced by ``eval_dxdt`` on the training set, which
        this file treats as (MSE, relative L2 error).
    """
    sampler_cfg = SamplingArgs(
        param_sampler=param_sampler,
        seed=sampling_seed,
        sample_uniformly=sample_uniformly,
        elm_bias_start=train_config["elm_bias_start"],
        elm_bias_end=train_config["elm_bias_end"],
        resample_duplicates=train_config["resample_duplicates"],
        dtype=dtype_np,
    )
    solver_cfg = LinearSolveArgs(
        mode=auto_diff_mode,
        driver=train_config["driver"],
        rcond=train_config["rcond"],
        device=device,
    )

    # Fits the model in place; the error is measured afterwards on the same data
    sample_and_linear_solve(model, x_train, L, dxdt_train, sampler_cfg, solver_cfg)
    train_mse, train_rel2 = eval_dxdt(model, x_train, L, dxdt_train, mode=auto_diff_mode, verbose=False)
    return train_mse, train_rel2

def trad_train(model, x_train, x_test, L, dxdt_train, dxdt_test):
    """
    Train an FCNN or GNN with gradient-descent based optimization.

    All hyperparameters are read from the ``train`` config section and
    forwarded unchanged to ``TradTrainingArgs``, together with the global
    device. Returns whatever ``traditional_training`` returns.
    """
    # Keys of the `train` config section that map one-to-one onto TradTrainingArgs
    forwarded_keys = (
        "n_steps",
        "batch_size",
        "weight_init",
        "lr_start",
        "lr_end",
        "weight_decay",
        "patience",
        "optim_type",
        "sched_type",
    )
    options = {key: train_config[key] for key in forwarded_keys}
    training_args = TradTrainingArgs(device=device, **options)
    return traditional_training(model, x_train, x_test, L, dxdt_train, dxdt_test, training_args)
def train(model, x_train, x_test, L, dxdt_train, dxdt_test, param_sampler, model_name, sampling_seed, run_id):
    """
    Fit ``model``, evaluate it on both splits, save it, and return the metrics.

    Args:
        model: model to fit (fitted in place).
        x_train, x_test, L, dxdt_train, dxdt_test: data as returned by ``prepare_data``.
        param_sampler: sampler name forwarded to ``sample_fit``; only used when
            ``model_name == "swim"``.
        model_name: ``"swim"`` selects ``sample_fit``; any other value selects
            ``trad_train``. Also used as the prefix of the metric keys and of
            the saved model file name.
        sampling_seed: seed forwarded to ``sample_fit``; only used for ``"swim"``.
        run_id: index of the current run, used in the saved model file name.

    Returns:
        dict with ``{model_name}-train-mse``, ``-train-rel2``, ``-test-mse``,
        ``-test-rel2`` and ``-train-time`` (seconds). On the ``trad_train`` path
        it also holds the training history under hard-coded ``adam-*`` keys.
    """
    # perf_counter() is monotonic and high-resolution, so the measured duration
    # cannot be distorted by system clock adjustments the way time() can.
    from time import perf_counter

    results = {}
    start = perf_counter()
    if model_name == "swim":
        # NOTE: sample_fit also evaluates the training error, so that evaluation
        # is included in the measured time here but not in the branch below.
        train_mse, train_rel2 = sample_fit(model, x_train, L, dxdt_train, param_sampler, sample_uniformly=True, sampling_seed=sampling_seed)
        train_time = perf_counter() - start
    else:
        hist = trad_train(model, x_train, x_test, L, dxdt_train, dxdt_test)
        train_time = perf_counter() - start
        results["adam-train-loss"] = hist["loss"]
        results["adam-test-loss"] = hist["test_loss"]
        results["adam-lr"] = hist["lr"]
        results["adam-train-steps"] = train_config["n_steps"]
        results["adam-best-test-loss"] = hist["best_test_loss"]
        results["adam-best-test-error"] = hist["best_test_error"]
        train_mse, train_rel2 = eval_dxdt(model, x_train, L, dxdt_train, mode=auto_diff_mode, verbose=False)

    test_mse, test_rel2 = eval_dxdt(model, x_test, L, dxdt_test, mode=auto_diff_mode, verbose=False)
    results[f"{model_name}-train-mse"] = train_mse
    results[f"{model_name}-train-rel2"] = train_rel2
    results[f"{model_name}-test-mse"] = test_mse
    results[f"{model_name}-test-rel2"] = test_rel2
    results[f"{model_name}-train-time"] = train_time

    # save the model
    tc.save(model, os.path.join(model_out_dir_path, f"{model_name}_N{N}_run{run_id}.pt"))
    # This only drops the local reference; the caller's own reference keeps the
    # model alive until the caller rebinds or deletes it.
    del model
    tc.cuda.empty_cache()

    return results

# Base seeds for the first run; the experiment loop increments each one per run
data_seed, model_seed, sampling_seed = (
    data_config["data_seed"],
    model_config["model_seed"],
    train_config["sampling_seed"],
)

n_obj = N

# ==== Warm-up system
# One throwaway FCNN fit before the measured runs (presumably so that one-time
# start-up costs do not end up in the first measured training time).
print("Warming up...")
warm_x_train, warm_x_test, warm_L, warm_edge_index, warm_dxdt_train, warm_dxdt_test = prepare_data(n_obj, data_seed)
warm_model = create_fcnn(n_obj, model_seed)
_, _ = sample_fit(
    warm_model, warm_x_train, warm_L, warm_dxdt_train,
    param_sampler="random", sample_uniformly=True, sampling_seed=sampling_seed,
)
# Drop everything created for the warm-up before the real runs start
del warm_model, warm_x_train, warm_x_test, warm_L, warm_edge_index, warm_dxdt_train, warm_dxdt_test
tc.cuda.empty_cache()
print("Warm-up complete.")

print("The results of each run for each model will be printed in format {L2_RELATIVE_TRAIN, MSE_TRAIN, L2_RELATIVE_TEST, MSE_TEST, TRAIN_TIME} with sufficient precisions")

for run in range(repeat):
    print(f"-> Running experiment:  {run+1}/{repeat}")
    # ==== PREPARE DATA (fresh data set for the current data seed)
    x_train, x_test, L, edge_index, dxdt_train, dxdt_test = prepare_data(n_obj, data_seed)

    # Seeds first, then the metrics of each model in training order
    run_record = {"run_id": run, "data_seed": data_seed, "model_seed": model_seed, "sampling_seed": sampling_seed}

    # Both models start from an identically seeded GNN and see the same data;
    # only the fitting procedure selected by the model name differs.
    for model_name, banner in (
        ("swim", "==== Training (SWIM) RF-HGN"),
        ("adam", "==== Training Adam-HGN"),
    ):
        print(banner)
        model = create_gnn(n_obj, edge_index, model_seed)
        outcome = train(model, x_train, x_test, L, dxdt_train, dxdt_test, run_id=run,
                        param_sampler="relu", model_name=model_name, sampling_seed=sampling_seed)

        # One summary line per model: train rel-L2, train MSE, test rel-L2, test MSE, time
        train_rel2 = outcome[f"{model_name}-train-rel2"]
        train_mse = outcome[f"{model_name}-train-mse"]
        test_rel2 = outcome[f"{model_name}-test-rel2"]
        test_mse = outcome[f"{model_name}-test-mse"]
        elapsed = outcome[f"{model_name}-train-time"]
        print(f"{train_rel2:.2e}, {train_mse:.2e}, {test_rel2:.2e}, {test_mse:.2e}, {elapsed:.2f}")

        run_record.update(outcome)

    # ==== Append to all experiment results list
    all_results.append(run_record)

    # ==== Update seeds so the next run uses different data, weights and samples
    data_seed += 1
    model_seed += 1
    sampling_seed += 1

# Save the results to results.json and the config to config.toml for reproducibility.
# Context managers make sure each file is closed even if serialization raises.

# Save results
results_file_path = os.path.join(out_dir_path, "results.json")
with open(results_file_path, "w", encoding="utf-8") as results_file:
    json.dump(all_results, results_file, indent=4)

# Save configuration for this experiment
config_file_path = os.path.join(out_dir_path, "config.toml")
with open(config_file_path, "w", encoding="utf-8") as config_file:
    toml.dump(config, config_file)

# Exit with status 0 without relying on the site-provided `exit` helper,
# which is meant for the interactive interpreter.
raise SystemExit(0)
