"""
This file conducts the optimization comparison study, comparing iterative gradient-based optimization
algorithms such as Adam, RMSProp, LBFGS, ... against random-feature training using the SWIM algorithm.
"""

import toml
import json
import os
from time import time
import numpy as np
import torch as tc
from math import ceil
from sklearn.model_selection import train_test_split
from argparse import ArgumentParser
from src.utils import Mesh, LinearSolveArgs, SamplingArgs, eval_dxdt, TradTrainingArgs
from src.data import MassSpring
from src.model import FCNN, GNN
from src.train import sample_and_linear_solve, traditional_training

argparser = ArgumentParser()
argparser.add_argument("-f", "--toml", type=str, help=".toml file for reading config of the experiment", required=True)
argparser.add_argument("-o", "--out", type=str, help="Output directory to save the results of the experiment", required=True)
args = vars(argparser.parse_args())
config = toml.load(args["toml"])

# Map the configured precision string onto matching numpy / torch dtypes.
if config["dtype"] == "float":
    dtype_np, dtype_tc = np.float32, tc.float32
elif config["dtype"] == "double":
    dtype_np, dtype_tc = np.float64, tc.float64
else:
    raise ValueError(f"Unknown precision {config['dtype']!r} (expected 'float' or 'double')")
dof, auto_diff_mode, device = config["dof"], config["auto_diff_mode"], config["device"]

# Lattice size of the mass-spring system, e.g. [Nx, Ny]; it is used as `n_obj` below.
N = config["N"]
repeat = config["repeat"]            # number of independent runs (seeds are incremented after each run)
optim_types = config["optim_types"]  # training algorithms to compare; "swim" selects sampling + linear solve

# Output layout: <out>/lattice{N}_{device}/ receives results.json and config.toml,
# and its "models" sub-directory receives the saved models.
out_dir_path = os.path.join(args["out"], f"lattice{N}_{device}")
model_out_dir_path = os.path.join(out_dir_path, "models")
# exist_ok=True avoids the check-then-create race of os.path.exists followed by os.makedirs.
os.makedirs(out_dir_path, exist_ok=True)
os.makedirs(model_out_dir_path, exist_ok=True)
data_config, model_config, train_config = config["data"], config["model"], config["train"]

# Experiment results will be gathered here
all_results = []

def prepare_data(n_obj, data_seed):
    """
    Generate a random mass-spring dataset and split it into training and test sets.

    Each of the ``data_config["n_points"]`` samples has ``np.prod(n_obj) * 2 * dof`` features drawn
    uniformly from [x_min, x_max). The first half of the features is handed to MassSpring as q,
    the second half as p.

    Args:
        n_obj: system size passed to MassSpring (the lattice size N in this script).
        data_seed: seed for drawing the states; ``data_seed + 1`` seeds the train/test shuffle.

    Returns:
        (x_train, x_test, L, edge_index, dxdt_train, dxdt_test), where the inputs, targets and L are
        torch tensors built from ``system.to_array``, ``system.dxdt`` and ``system.L()``, and
        edge_index is whatever ``system.edge_index()`` returns.
    """
    generator = np.random.default_rng(data_seed)
    n_samples = data_config["n_points"]
    width = np.prod(n_obj) * 2 * dof

    states = generator.uniform(data_config["x_min"], data_config["x_max"], size=(n_samples, width))
    states = states.astype(dtype_np)
    positions, momenta = np.split(states, 2, axis=-1)

    mesh = Mesh(data_config["meshing"])
    system = MassSpring(
        n_samples, width, positions, momenta, n_obj, dof,
        data_config["mass"], data_config["spring_constant"], mesh,
    )

    inputs = tc.from_numpy(system.to_array(flatten=True))
    targets = tc.from_numpy(system.dxdt(flatten=True))
    n_train = ceil(n_samples * data_config["train_test_split"])
    x_train, x_test, dxdt_train, dxdt_test = train_test_split(
        inputs, targets, train_size=n_train, shuffle=True, random_state=data_seed + 1,
    )

    L = tc.from_numpy(system.L())
    edges = system.edge_index()
    return x_train, x_test, L, edges, dxdt_train, dxdt_test

def create_gnn(n_obj, edge_index, model_seed):
    """
    Build a GNN on the graph ``edge_index`` from the hyper-parameters in the [model] config table.

    The configured total ``width`` is divided between the encoder (``enc_width``) and the
    message part (``width - enc_width``).
    """
    hp = model_config
    return GNN(
        dof=dof,
        n_obj=n_obj,
        edge_index=edge_index,
        direct=hp["direct"],
        msg_width=hp["width"] - hp["enc_width"],
        enc_width=hp["enc_width"],
        local_pooling=hp["local_pooling"],
        global_pooling=hp["global_pooling"],
        activ_str=hp["activ_str"],
        init_method=hp["init_method"],
        seed=model_seed,
        dtype=dtype_tc,
    )

def create_fcnn(n_obj, model_seed):
    """
    Build a fully connected network from the hyper-parameters in the [model] config table.

    Its input size is ``np.prod(n_obj) * 2 * dof``, the same feature count that ``prepare_data`` draws.
    """
    hp = model_config
    n_inputs = np.prod(n_obj) * 2 * dof
    return FCNN(
        dof=dof,
        n_obj=n_obj,
        n_features=n_inputs,
        width=hp["width"],
        activ_str=hp["activ_str"],
        init_method=hp["init_method"],
        seed=model_seed,
        dtype=dtype_tc,
    )

def _swim_fit(model, x_train, L, dxdt_train, param_sampler, sample_uniformly, sampling_seed):
    """
    Fit ``model`` in place with ``sample_and_linear_solve``, without evaluating it afterwards.

    Kept separate from ``sample_fit`` so that ``train`` can time the fit alone, just like the
    gradient-based optimizers, whose reported train time also excludes the final evaluation.
    """
    sampling_args = SamplingArgs(
        param_sampler=param_sampler,
        seed=sampling_seed,
        sample_uniformly=sample_uniformly,
        resample_duplicates=train_config["resample_duplicates"],
        dtype=dtype_np,
    )
    linear_solve_args = LinearSolveArgs(
        mode=auto_diff_mode,
        driver=train_config["driver"],
        rcond=train_config["rcond"],
        device=device,
    )
    sample_and_linear_solve(model, x_train, L, dxdt_train, sampling_args, linear_solve_args)


def sample_fit(model, x_train, L, dxdt_train, param_sampler, sample_uniformly, sampling_seed):
    """
    Fit ``model`` via sampling + linear solve and report its error on the training data.

    Returns:
        (dxdt_mse, dxdt_rel2) as produced by ``eval_dxdt`` on the training data.
    """
    _swim_fit(model, x_train, L, dxdt_train, param_sampler, sample_uniformly, sampling_seed)
    dxdt_mse, dxdt_rel2 = eval_dxdt(model, x_train, L, dxdt_train, mode=auto_diff_mode, verbose=False)
    return dxdt_mse, dxdt_rel2


def _sync_device():
    """Wait for queued CUDA work when running on a CUDA device, so wall-clock timings include it."""
    if tc.device(device).type == "cuda":
        tc.cuda.synchronize()


def _trad_schedule(optim_type, n_train):
    """
    Return ``(n_steps, batch_size, sched_type, lr_start)`` for a gradient-based optimizer.

    Args:
        optim_type: optimizer name as listed in ``config["optim_types"]``.
        n_train: number of training samples; full-batch optimizers use it as batch size.
    """
    if optim_type == "lbfgs":
        # Full-batch L-BFGS only runs a small, fixed number of steps because it already converges
        # within them (default 100, overridable via "n_steps_lbfgs"). The scheduler is disabled,
        # so the learning rate stays at the configured value.
        return train_config.get("n_steps_lbfgs", 100), n_train, None, train_config["lr_lbfgs_adadelta"]
    if optim_type == "rprop":
        # Full-batch Rprop: convert the mini-batch budget into the number of full-batch steps that
        # process the same total number of samples; run at least one step.
        n_steps = max(1, (train_config["n_steps"] * train_config["batch_size"]) // n_train)
        return n_steps, n_train, None, train_config["lr_start"]
    if optim_type == "adadelta":
        # Mini-batch updates with the scheduler disabled (fixed learning rate).
        return train_config["n_steps"], train_config["batch_size"], None, train_config["lr_lbfgs_adadelta"]
    if optim_type in ("sgd", "sgdmomentum", "asgd"):
        # Mini-batch updates with the configured scheduler and an SGD-specific start learning rate.
        return (train_config["n_steps"], train_config["batch_size"],
                train_config["sched_type"], train_config["lr_start_sgd"])
    # All other optimizers: mini-batch updates with the configured scheduler and start learning rate.
    return (train_config["n_steps"], train_config["batch_size"],
            train_config["sched_type"], train_config["lr_start"])


def trad_train(model, x_train, x_test, L, dxdt_train, dxdt_test, optim_type):
    """
    Traditionally train (using gradient-descent based optimization) FCNN or GNN.

    Step count, batch size, scheduler and start learning rate depend on ``optim_type``
    (see ``_trad_schedule``); the remaining hyper-parameters come from the [train] config table.

    Returns:
        The history returned by ``traditional_training``.
    """
    n_steps, batch_size, sched, lr_start = _trad_schedule(optim_type, len(x_train))
    training_args = TradTrainingArgs(
        n_steps=n_steps,
        batch_size=batch_size,
        device=device,
        weight_init=train_config["weight_init"],
        lr_start=lr_start,
        lr_end=train_config["lr_end"],
        weight_decay=train_config["weight_decay"],
        patience=train_config["patience"],
        optim_type=optim_type,
        sched_type=sched,
    )
    return traditional_training(model, x_train, x_test, L, dxdt_train, dxdt_test, training_args)


def train(model, x_train, x_test, L, dxdt_train, dxdt_test, param_sampler, sampling_seed, run_id, optim_type):
    """
    Train ``model`` with ``optim_type``, evaluate it, and save it under ``model_out_dir_path``.

    ``optim_type == "swim"`` fits the model by sampling + linear solve (``param_sampler`` and
    ``sampling_seed`` are only used in that case); any other value is passed to ``trad_train`` as a
    gradient-based optimizer. In both cases the reported train time covers only the fit itself,
    not the evaluation, and is measured after waiting for pending CUDA work.

    Returns:
        dict whose keys are prefixed with ``f"{optim_type}-"``: train/test MSE and relative L2 error
        and train time in seconds; for gradient-based optimizers also the loss/test-loss/lr
        histories, the number of optimizer steps actually configured, and the best test loss/error.
    """
    from time import perf_counter  # monotonic, high-resolution clock intended for measuring durations

    results = {}
    is_swim = optim_type == "swim"

    _sync_device()
    time0 = perf_counter()
    if is_swim:
        _swim_fit(model, x_train, L, dxdt_train, param_sampler, sample_uniformly=True, sampling_seed=sampling_seed)
    else:
        hist = trad_train(model, x_train, x_test, L, dxdt_train, dxdt_test, optim_type)
    _sync_device()
    train_time = perf_counter() - time0

    if not is_swim:
        # Record the step count that was actually used (lbfgs/rprop differ from train_config["n_steps"]).
        n_steps, _, _, _ = _trad_schedule(optim_type, len(x_train))
        results[f"{optim_type}-train-loss"] = hist["loss"]
        results[f"{optim_type}-test-loss"] = hist["test_loss"]
        results[f"{optim_type}-lr"] = hist["lr"]
        results[f"{optim_type}-train-steps"] = n_steps
        results[f"{optim_type}-best-test-loss"] = hist["best_test_loss"]
        results[f"{optim_type}-best-test-error"] = hist["best_test_error"]

    train_mse, train_rel2 = eval_dxdt(model, x_train, L, dxdt_train, mode=auto_diff_mode, verbose=False)
    test_mse, test_rel2 = eval_dxdt(model, x_test, L, dxdt_test, mode=auto_diff_mode, verbose=False)
    results[f"{optim_type}-train-mse"] = train_mse
    results[f"{optim_type}-train-rel2"] = train_rel2
    results[f"{optim_type}-test-mse"] = test_mse
    results[f"{optim_type}-test-rel2"] = test_rel2
    results[f"{optim_type}-train-time"] = train_time

    # save the model; the caller still holds a reference to it and is responsible for releasing it
    tc.save(model, os.path.join(model_out_dir_path, f"{optim_type}_N{N}_run{run_id}.pt"))
    tc.cuda.empty_cache()

    return results

data_seed = data_config["data_seed"]
model_seed = model_config["model_seed"]
sampling_seed = train_config["sampling_seed"]

# the number of objects is given by the lattice size N
n_obj = N

# ==== Warm-up system
# One untimed fit before the experiment, presumably so that one-off start-up costs are not
# attributed to the first timed run.
# NOTE(review): the warm-up uses an FCNN while the runs below use GNNs — confirm this is intended.
print("Warming up...")
x_train, x_test, L, edge_index, dxdt_train, dxdt_test = prepare_data(n_obj, data_seed)
model = create_fcnn(n_obj, model_seed)
_, _ = sample_fit(model, x_train, L, dxdt_train, param_sampler="random", sample_uniformly=True, sampling_seed=sampling_seed)
del model, x_train, x_test, L, edge_index, dxdt_train, dxdt_test
tc.cuda.empty_cache()
print("Warm-up complete.")

# The order listed here matches the per-run print statement inside the loop below.
print("The results of each run for each model will be printed in format {MSE_TRAIN, L2_RELATIVE_TRAIN, MSE_TEST, L2_RELATIVE_TEST, TRAIN_TIME} with sufficient precisions")

for run in range(repeat):
    print(f"-> Running experiment:  {run+1}/{repeat}")
    # ==== PREPARE DATA
    x_train, x_test, L, edge_index, dxdt_train, dxdt_test = prepare_data(n_obj, data_seed)

    save_result = {"run_id": run, "data_seed": data_seed, "model_seed": model_seed, "sampling_seed": sampling_seed}
    for optim_type in optim_types:
        print("=" * 40)
        print(f"-> Using training algorithm: {optim_type}")
        print("=" * 40)
        # A fresh GNN with the same seed for every algorithm, so all start from the same initialization.
        model = create_gnn(n_obj, edge_index, model_seed)
        # param_sampler / sampling_seed are only used by train() when optim_type == "swim".
        result = train(model, x_train, x_test, L, dxdt_train, dxdt_test, run_id=run,
                       param_sampler="relu", sampling_seed=sampling_seed, optim_type=optim_type)
        # Release the model here, where the last reference lives, before the next algorithm starts.
        del model
        tc.cuda.empty_cache()

        train_mse = result[f"{optim_type}-train-mse"]
        train_rel2 = result[f"{optim_type}-train-rel2"]
        test_mse = result[f"{optim_type}-test-mse"]
        test_rel2 = result[f"{optim_type}-test-rel2"]
        train_time = result[f"{optim_type}-train-time"]
        print(f"Train MSE {train_mse:.2e}, Train REL2 {train_rel2:.2e}, Test MSE {test_mse:.2e}, Test REL2 {test_rel2:.2e}, Train time: {train_time:.2f}")
        save_result[optim_type] = result

    all_results.append(save_result)

    # ==== Update seeds
    data_seed += 1
    model_seed += 1
    sampling_seed += 1

print("========== RESULTS")
for result in all_results:
    for optim_type in optim_types:
        optim_type_result = result[optim_type]
        test_mse = optim_type_result[f"{optim_type}-test-mse"]
        test_rel2 = optim_type_result[f"{optim_type}-test-rel2"]
        train_time = optim_type_result[f"{optim_type}-train-time"]
        print(f"-> {optim_type:<12} MSE {test_mse:>8.2e} REL2 {test_rel2:>8.2e}   TRAIN TIME {train_time:>7.4f}")

# Save the results to results.json and the config to config.toml for reproducibility.

# Save results
results_file_path = os.path.join(out_dir_path, "results.json")
with open(results_file_path, "w", encoding="utf-8") as results_file:
    json.dump(all_results, results_file, indent=4)
print(f"-> Results are saved at '{results_file_path}'")

# Save configuration for this experiment
config_file_path = os.path.join(out_dir_path, "config.toml")
with open(config_file_path, "w", encoding="utf-8") as config_file:
    toml.dump(config, config_file)
print(f"-> config.toml is saved at '{config_file_path}'")

# SystemExit is what the site-provided exit() raises, without depending on the site module.
raise SystemExit(0)
