import os
import toml
import numpy as np
import torch as tc
from time import time
from math import ceil
from argparse import ArgumentParser
from src.data import MassSpring
from src.model import GNN
from src.train import sample_and_linear_solve, traditional_training
from src.utils import SamplingArgs, LinearSolveArgs, eval_dxdt, TradTrainingArgs
from src.utils import Mesh
from sklearn.model_selection import train_test_split

# ---- Command-line interface ------------------------------------------------
argparser = ArgumentParser()
argparser.add_argument(
    "-f", "--toml",
    type=str,
    required=True,
    help=".toml file for reading config of the experiment",
)
argparser.add_argument(
    "-o", "--outdir",
    type=str,
    required=True,
    help="Output directory to save the integration results and models",
)
argparser.add_argument(
    "-t", "--trainer",
    type=str,
    choices=["adam", "swim"],
    required=True,
)
args = vars(argparser.parse_args())
out_dir = args["outdir"]
trainer = args["trainer"]

# ---- Experiment configuration (read from the TOML file) --------------------
config = toml.load(args["toml"])

# The "dtype" entry selects the matching NumPy / PyTorch floating-point types.
_precision = config["dtype"]
if _precision == "float":
    dtype_np = np.float32
    dtype_tc = tc.float32
elif _precision == "double":
    dtype_np = np.float64
    dtype_tc = tc.float64
else:
    raise ValueError("Unknown precision")

n_obj = config["n_obj"]
dof = config["dof"]
auto_diff_mode = config["auto_diff_mode"]
device = config["device"]
print("-> Training with device:", device)

# Every object contributes `dof` position and `dof` momentum entries.
n_features = np.prod(n_obj) * 2*dof

# Sub-tables consumed by the data, model and training helpers below.
data_config = config["data"]
model_config = config["model"]
train_config = config["train"]

def prepare_data(n_obj, data_seed):
    """Sample a random mass-spring dataset and split it into train and test sets.

    Positions ``q`` and momenta ``p`` are drawn uniformly from the ranges in
    the module-level ``data_config``, passed to ``MassSpring``, and the
    flattened states and time derivatives it returns are split with
    ``train_test_split``.

    Args:
        n_obj: Object count(s) of the system; ``np.prod(n_obj) * 2*dof`` is
            used as the number of features per sample.
        data_seed: Seed of the NumPy generator used for sampling;
            ``data_seed + 1`` seeds the train/test shuffle.

    Returns:
        Tuple ``(x_train, x_test, L, edge_index, dxdt_train, dxdt_test)``,
        where ``L`` is ``system.L()`` converted to a tensor and ``edge_index``
        is whatever ``system.edge_index()`` returns.

    Raises:
        ValueError: If the module-level ``dof`` is neither 2 nor 3.
    """
    generator = np.random.default_rng(data_seed)
    n_points = data_config["n_points"]
    n_features = np.prod(n_obj) * 2*dof

    def draw(low_key, high_key, width):
        # Uniform samples of shape (n_points, width) in the configured dtype.
        samples = generator.uniform(
            data_config[low_key], data_config[high_key], size=(n_points, width)
        )
        return samples.astype(dtype_np)

    # NOTE: the draw order (positions first, then momenta) determines the
    # random stream and therefore the generated dataset; keep it stable.
    if dof == 2:
        # In 2D the x and y coordinates have separate sampling ranges.
        per_axis = n_features//4
        q = np.concatenate(
            [draw("qx_min", "qx_max", per_axis), draw("qy_min", "qy_max", per_axis)],
            axis=-1,
        )
    elif dof == 3:
        q = draw("q_min", "q_max", n_features//2)
    else:
        raise ValueError("DOF can be 2 or 3, i.e. the system can be in 2D or 3D.")
    p = draw("p_min", "p_max", n_features//2)

    system = MassSpring(
        n_points, n_features, q, p, n_obj, dof,
        data_config["mass"], data_config["spring_constant"],
        Mesh(data_config["meshing"]),
    )

    states = tc.from_numpy(system.to_array(flatten=True))
    derivatives = tc.from_numpy(system.dxdt(flatten=True))
    n_train = ceil(n_points*data_config["train_test_split"])
    x_train, x_test, dxdt_train, dxdt_test = train_test_split(
        states, derivatives,
        train_size=n_train, shuffle=True, random_state=data_seed+1,
    )

    L = tc.from_numpy(system.L())
    return x_train, x_test, L, system.edge_index(), dxdt_train, dxdt_test

def create_gnn(n_obj, edge_index, model_seed, skip_normalization=False):
    """Construct a ``GNN`` configured from the module-level ``model_config``.

    Args:
        n_obj: Object count(s) forwarded to ``GNN``.
        edge_index: Edge index forwarded to ``GNN``.
        model_seed: Seed forwarded to ``GNN`` as ``seed``.
        skip_normalization: When truthy, both the translation-invariance and
            the rotation-invariance flags are passed as ``False``; otherwise
            both are ``True``.

    Returns:
        The newly created ``GNN`` instance.
    """
    # A single switch drives both invariance flags.
    use_invariances = not skip_normalization
    return GNN(
        dof=dof,
        n_obj=n_obj,
        edge_index=edge_index,
        make_translation_invariant=use_invariances,
        make_rotation_invariant=use_invariances,
        direct=model_config["direct"],
        # The configured total width is split between message and encoder parts.
        msg_width=model_config["width"]-model_config["enc_width"],
        enc_width=model_config["enc_width"],
        local_pooling=model_config["local_pooling"],
        global_pooling=model_config["global_pooling"],
        activ_str=model_config["activ_str"],
        init_method=model_config["init_method"],
        seed=model_seed,
        dtype=dtype_tc,
    )

def eval_and_print_err(model, label, x, L, dxdt):
    """Evaluate ``model`` with ``eval_dxdt`` and print the two error values.

    Args:
        model: Model passed to ``eval_dxdt``.
        label: Heading printed (after a blank line) above the error values.
        x: Inputs passed to ``eval_dxdt``.
        L: Passed through unchanged to ``eval_dxdt``.
        dxdt: Reference values passed to ``eval_dxdt``.
    """
    errors = eval_dxdt(model, x, L, dxdt, verbose=False)
    # eval_dxdt yields a pair; presumably mean squared error and relative
    # L2 error, going by the names used here -- confirm against src.utils.
    mse, rel2 = errors
    print(f"\n{label}")
    print(f"-> mse : {mse:.2e}")
    print(f"-> rel2: {rel2:.2e}")

def _report_and_save(model, label, file_stem, elapsed):
    """Print train/test errors and the elapsed time, then save ``model``.

    Args:
        model: Trained model to evaluate and save.
        label: Prefix of the report headings (e.g. ``"RF-HGN"``).
        file_stem: Leading part of the saved file name (e.g. ``"rf_hgn"``).
        elapsed: Training time in seconds.
    """
    eval_and_print_err(model, f"{label} Train", x_train, L, dxdt_train)
    eval_and_print_err(model, f"{label} Test", x_test, L, dxdt_test)
    print(f"took {elapsed:.2f} seconds")
    model_path = os.path.join(out_dir, f"{file_stem}_chain_{n_obj}_{dof}DOF.pt")
    tc.save(model, model_path)


# ---- Driver: build data and model, train with the chosen trainer, save -----
x_train, x_test, L, edge_index, dxdt_train, dxdt_test = prepare_data(n_obj, data_config["data_seed"])
rf_hgn = create_gnn(n_obj, edge_index, model_config["model_seed"])

# Create the output directory before training so that a missing directory does
# not make tc.save fail after the whole run has finished. An empty out_dir
# means "current directory" for os.path.join, so there is nothing to create.
if out_dir:
    os.makedirs(out_dir, exist_ok=True)

if trainer == "swim":
    print("==== Training (SWIM) RF-HGN")
    sampling_args = SamplingArgs(
        param_sampler=train_config["param_sampler"],
        seed=train_config["sampling_seed"],
        sample_uniformly=True,
        resample_duplicates=train_config["resample_duplicates"],
        dtype=dtype_np,
    )
    linear_solve_args = LinearSolveArgs(
        mode="forward",
        driver=train_config["driver"],
        rcond=train_config["rcond"],
        device=device
    )
    time0 = time()
    sample_and_linear_solve(rf_hgn, x_train, L, dxdt_train, sampling_args, linear_solve_args)
    time1 = time()
    _report_and_save(rf_hgn, "RF-HGN", "rf_hgn", time1-time0)
elif trainer == "adam":
    print("==== Training (Adam) HGN")
    training_args = TradTrainingArgs(
        n_steps=train_config["n_steps"],
        batch_size=train_config["batch_size"],
        device=device,
        weight_init=train_config["weight_init"],
        lr_start=train_config["lr_start"],
        lr_end=train_config["lr_end"],
        weight_decay=train_config["weight_decay"],
        patience=train_config["patience"],
        optim_type=train_config["optim_type"],
        sched_type=train_config["sched_type"],
    )
    time0 = time()
    traditional_training(rf_hgn, x_train, x_test, L, dxdt_train, dxdt_test, training_args)
    time1 = time()
    _report_and_save(rf_hgn, "HGN", "hgn", time1-time0)

# The `exit` builtin is an interactive helper installed by the `site` module
# and is missing under `python -S` or in frozen builds; raising SystemExit
# directly ends the script with the same status code.
raise SystemExit(0)
