"""
For training Hamiltonian graph networks on the chain mass-spring data with
optimizer and spring potential options
"""

import os
import toml
import numpy as np
import torch as tc
from time import time
from math import ceil
from argparse import ArgumentParser
from src.data import MassSpring
from src.model import GNN
from src.train import sample_and_linear_solve, traditional_training
from src.utils import SamplingArgs, LinearSolveArgs, eval_dxdt, TradTrainingArgs
from src.utils import Mesh
from sklearn.model_selection import train_test_split
from tabulate import tabulate

# ---- Command-line interface -------------------------------------------------
argparser = ArgumentParser()
argparser.add_argument(
    "-f", "--toml",
    type=str,
    required=True,
    help=".toml file for reading config of the experiment",
)
argparser.add_argument(
    "-s", "--system",
    type=str,
    choices=["spring-chain", "anharmonic-chain", "morse-chain"],
    required=True,
    help="System to learn",
)
argparser.add_argument(
    "-o", "--outdir",
    type=str,
    required=True,
    help="Output directory to save the integration results and models",
)
argparser.add_argument(
    "-t", "--trainer",
    type=str,
    choices=["adam", "swim", "elm"],
    required=True,
    help="Trainer to use to optimize the model",
)
args = vars(argparser.parse_args())
out_dir = args["outdir"]
trainer = args["trainer"]

# ---- Experiment configuration ------------------------------------------------
config = toml.load(args["toml"])

# The config names the precision; pick matching numpy and torch dtypes.
precision = config["dtype"]
if precision == "float":
    dtype_np, dtype_tc = np.float32, tc.float32
elif precision == "double":
    dtype_np, dtype_tc = np.float64, tc.float64
else:
    raise ValueError("Unknown precision")

n_obj = config["n_obj"]
auto_diff_mode = config["auto_diff_mode"]
device = config["device"]
print("-> Training with device:", device)

data_config = config["data"]
model_config = config["model"]
train_config = config["train"]
# Per-system settings live under [data.<system>] in the toml file.
system_config = data_config[args["system"]]
dof = system_config["dof"]
# Flattened state width: positions and momenta each carry `dof` entries per object.
n_features = 2 * dof * np.prod(n_obj)
normalize_data = system_config["normalize_data"]

def prepare_spring_chain(n_obj, data_seed, system_name="chain"):
    """Sample a random chain mass-spring dataset and split it into train/test.

    Positions ``q`` and momenta ``p`` are drawn uniformly within the bounds in
    ``system_config``. A ``MassSpring`` system built from them provides the
    flattened states, their time derivatives, the matrix returned by
    ``system.L()`` and the graph edge index.

    Args:
        n_obj: object count (``np.prod`` is applied, so a shape whose product
            is the count also works).
        data_seed: seed for sampling; ``data_seed + 1`` seeds the split.
        system_name: potential name forwarded as ``MassSpring(system=...)``.

    Returns:
        ``(x_train, x_test, L, edge_index, dxdt_train, dxdt_test)``.

    Raises:
        ValueError: if the configured ``dof`` is neither 2 nor 3.
    """
    generator = np.random.default_rng(data_seed)
    n_points = data_config["n_points"]
    n_features = np.prod(n_obj) * 2*dof

    def draw(low_key, high_key, width):
        # Draw one (n_points, width) uniform sample; the call order below
        # determines the RNG stream, so it must stay qx, qy (or q), then p.
        samples = generator.uniform(system_config[low_key], system_config[high_key],
                                    size=(n_points, width))
        return samples.astype(dtype_np)

    if dof == 3:
        q = draw("q_min", "q_max", n_features//2)
    elif dof == 2:
        # x coordinates of all objects first, then all y coordinates.
        quarter = n_features//4
        q = np.concatenate([draw("qx_min", "qx_max", quarter),
                            draw("qy_min", "qy_max", quarter)], axis=-1)
    else:
        raise ValueError("DOF can be 2 or 3, i.e. the system can be in 2D or 3D.")
    p = draw("p_min", "p_max", n_features//2)

    system = MassSpring(
        n_points, n_features, q, p, n_obj, dof,
        data_config["mass"], system_config["spring_constant"], Mesh(system_config["meshing"]),
        system_config["l"], system_config["D"], system_config["a"],
        system=system_name,
    )
    states = tc.from_numpy(system.to_array(flatten=True))
    derivatives = tc.from_numpy(system.dxdt(flatten=True))
    n_train = ceil(n_points*data_config["train_test_split"])
    x_train, x_test, dxdt_train, dxdt_test = train_test_split(
        states, derivatives,
        train_size=n_train, shuffle=True, random_state=data_seed + 1,
    )

    # Sanity output of the target ranges ("test " is padded to align columns).
    for label, targets in (("train", dxdt_train), ("test ", dxdt_test)):
        print(f"-> {label} dxdt min/max", targets.min().item(), targets.max().item())

    l_matrix = tc.from_numpy(system.L())
    return x_train, x_test, l_matrix, system.edge_index(), dxdt_train, dxdt_test

def prepare_data(system_name, n_obj, data_seed):
    """Build the dataset for the CLI system name ``system_name``.

    Every supported system is a chain; only the potential name passed on to
    ``prepare_spring_chain`` differs.

    Raises:
        ValueError: with args ``("Unknown system", system_name)`` when the
            name is not one of the supported chain systems.
    """
    # (CLI system name, MassSpring potential name); compared with ``==``.
    chain_potentials = (
        ("spring-chain", "chain"),
        ("anharmonic-chain", "anharmonic"),
        ("morse-chain", "morse"),
    )
    for cli_name, potential in chain_potentials:
        if system_name == cli_name:
            return prepare_spring_chain(n_obj, data_seed, system_name=potential)
    raise ValueError("Unknown system", system_name)

def create_gnn(n_obj, edge_index, model_seed, normalize_data=False):
    """Construct the GNN from the ``[model]`` config section.

    Args:
        n_obj: object count forwarded to ``GNN``.
        edge_index: graph connectivity forwarded to ``GNN``.
        model_seed: seed forwarded to ``GNN``.
        normalize_data: enables both the translation- and the
            rotation-invariance options of ``GNN``.
    """
    cfg = model_config
    encoder_width = cfg["enc_width"]
    # The configured "width" is the total; messages get what the encoder leaves.
    message_width = cfg["width"] - encoder_width
    return GNN(
        dof=dof,
        n_obj=n_obj,
        edge_index=edge_index,
        make_translation_invariant=normalize_data,
        make_rotation_invariant=normalize_data,
        n_msg_passes=1,
        take_absolute_diff=False,
        direct=cfg["direct"],
        msg_width=message_width,
        enc_width=encoder_width,
        local_pooling=cfg["local_pooling"],
        global_pooling=cfg["global_pooling"],
        activ_str=cfg["activ_str"],
        init_method=cfg["init_method"],
        seed=model_seed,
        dtype=dtype_tc,
    )

def eval_and_return_err(model, x, L, dxdt):
    """Evaluate ``model`` on ``(x, dxdt)`` quietly and return its two errors.

    Returns:
        ``(mse, rel2)`` as produced by ``eval_dxdt``; the results table
        reports the second value as the relative error.
    """
    mse_err, rel2_err = eval_dxdt(model, x, L, dxdt, verbose=False)
    return (mse_err, rel2_err)

# ---- Data, model, training --------------------------------------------------
# Only the training step differs between trainers; evaluation, saving and
# reporting below are shared.
x_train, x_test, L, edge_index, dxdt_train, dxdt_test = prepare_data(args["system"], n_obj, data_config["data_seed"])
rf_hgn = create_gnn(n_obj, edge_index, model_config["model_seed"], normalize_data)

if trainer in ["swim", "elm"]:
    # Print the actual trainer name (the old banner said SWIM for ELM as well).
    print(f"==== Training ({trainer.upper()}) RF-HGN")
    sampling_args = SamplingArgs(
        # ELM always uses the "random" parameter sampler.
        param_sampler="random" if trainer == "elm" else train_config["param_sampler"],
        seed=train_config["sampling_seed"],
        sample_uniformly=True,
        resample_duplicates=train_config["resample_duplicates"],
        dtype=dtype_np,
    )
    linear_solve_args = LinearSolveArgs(
        mode="forward",
        driver=train_config["driver"],
        rcond=train_config["rcond"],
        device=device
    )
    time0 = time()
    sample_and_linear_solve(rf_hgn, x_train, L, dxdt_train, sampling_args, linear_solve_args)
    time1 = time()
    model_name = f"{trainer}_rf_hgn"
elif trainer == "adam":
    print("==== Training (Adam) HGN")
    training_args = TradTrainingArgs(
        n_steps=train_config["n_steps"],
        batch_size=train_config["batch_size"],
        device=device,
        weight_init=train_config["weight_init"],
        lr_start=train_config["lr_start"],
        lr_end=train_config["lr_end"],
        weight_decay=train_config["weight_decay"],
        patience=train_config["patience"],
        optim_type=train_config["optim_type"],
        sched_type=train_config["sched_type"],
    )
    time0 = time()
    traditional_training(rf_hgn, x_train, x_test, L, dxdt_train, dxdt_test, training_args)
    time1 = time()
    model_name = "adam_hgn"
else:
    # argparse restricts the choices, but fail clearly rather than with a
    # NameError on the paths further down if a new trainer is added.
    raise ValueError(f"Unknown trainer: {trainer}")

training_time = time1 - time0
mse_train, rel2_train = eval_and_return_err(rf_hgn, x_train, L, dxdt_train)
mse_test, rel2_test = eval_and_return_err(rf_hgn, x_test, L, dxdt_test)
print(f"took {training_time:.2f} seconds")

# ---- Save the model ------------------------------------------------------------
# The model (.pt) and results table (.txt) share one file stem.
file_stem = f"{model_name}_chain_{n_obj}_{dof}DOF_{args['system']}"
if out_dir:
    # Create the output directory up front so saving does not fail on a fresh path.
    os.makedirs(out_dir, exist_ok=True)
model_path = os.path.join(out_dir, f"{file_stem}.pt")
results_table_path = os.path.join(out_dir, f"{file_stem}.txt")
tc.save(rf_hgn, model_path)
print(f"-> Model is saved at '{model_path}'")

# ---- Print and save evaluation, then exit -------------------------------------
results_table_title = "\nTable: Final model evaluation and test results."
results_table = tabulate(
    headers=["", "Train MSE", "Train (relative)", "Test MSE", "Test (relative)"],
    tabular_data=[
        ["Loss/Error", mse_train, rel2_train, mse_test, rel2_test],
    ],
    floatfmt=".2e"
)
print('\n' + results_table + '\n')

with open(results_table_path, 'w', encoding="utf-8") as file:
    filestr = '\n'.join([
        results_table_title,
        str(results_table),
        f"-> Training took {training_time:.3e} or {training_time:.3f} seconds",
    ])
    # End the text file with a newline.
    file.write(filestr + '\n')

# Builtin equivalent of sys.exit(0); `exit` is only injected by the site module.
raise SystemExit(0)
