"""
This script trains a Hamiltonian Graph Network (HGN) on 4-chain mass-spring
data generated around a fixed geometry, and then tests it on a similar 8-chain mass-spring geometry
that is rotated and translated, in order to assess the effect of the translation- and rotation-invariant
encoding of our graph network architecture. Trained and tested models are:
    - (Adam) HNN trained on 8-chain mass-spring data, because an HNN cannot be trained on the 4-chain mass-spring
      and then be tested on the 8-chain mass-spring (zero-shot); it serves as a comparison for the graph architectures
    - (SWIM) RF-HGN trained on 4-chain mass-spring data (and tested as explained above using zero-shot generalization)
    - Plain non-invariant (SWIM) RF-HGN: the same as the RF-HGN above but NOT translation-rotation invariant, to
      show the importance of incorporating geometric biases into graph network models.
TL;DR: We test the translation and rotation invariance of our graph network HGN against the previous work HNN.
"""
import copy
import os
from argparse import ArgumentParser
from time import time

import numpy as np
import toml
import torch as tc
from sklearn.model_selection import train_test_split

from src.data import MassSpring, grad_H_mass_spring
from src.model import GNN
from src.train import sample_and_linear_solve
from src.utils import SamplingArgs, LinearSolveArgs, eval_dxdt, unflatten_TC, flatten_TC, grad
from src.utils import unflatten, rotate_2d, Mesh

argparser = ArgumentParser()
argparser.add_argument("-f", "--toml", type=str, help=".toml file for reading config of the experiment", required=True)
argparser.add_argument("-o", "--outdir", type=str, default=".",
                       help="Output directory to save the integration results and models "
                            "(created if missing; defaults to the current directory)")
args = vars(argparser.parse_args())
out_dir = args["outdir"]
config = toml.load(args["toml"])

# Map the config's precision name to the matching NumPy and PyTorch dtypes.
_PRECISIONS = {
    "float": (np.float32, tc.float32),
    "double": (np.float64, tc.float64),
}
precision = config["dtype"]
if precision not in _PRECISIONS:
    raise ValueError(f"Unknown precision {precision!r}; expected one of {sorted(_PRECISIONS)}")
dtype_np, dtype_tc = _PRECISIONS[precision]

n_obj_train, n_obj_test = config["n_obj_train"], config["n_obj_test"]
dof, auto_diff_mode, device = config["dof"], config["auto_diff_mode"], config["device"]

# Create the output directory now, so an unusable path fails immediately instead of
# after all the fitting and integration work that precedes the saves at the end.
os.makedirs(out_dir, exist_ok=True)

# Length of one flattened state vector: (total number of objects) * (dof positions + dof momenta).
n_features_train = np.prod(n_obj_train) * 2*dof
n_features_test = np.prod(n_obj_test) * 2*dof

# for integration with zero-shot test system
len_traj = 100   # forwarded as the first argument of system.integrate (trajectory length)
delta_t = 1e-4   # integration time step

data_config, model_config, train_config = config["data"], config["model"], config["train"]

########################################
##### Prepare data: 4-chain angles #####
########################################
def prepare_data(n_points, n_obj, n_features, translation=0.0, rotation_degree=0.0):
    """Create a ``MassSpring`` system holding ``n_points`` random states of one fixed zig-zag geometry.

    Positions start at zero.
    Each odd-indexed node along the first object axis then gets a random offset along +x,
    drawn uniformly from ``[q_min, q_max)``.
    Even-indexed nodes keep position zero.
    Momenta are drawn uniformly from ``[p_min, p_max)`` for every node and component.
    Afterwards, every position coordinate is shifted by ``translation`` and positions are
    rotated by ``rotation_degree``.
    Momenta are rotated by ``-rotation_degree`` instead, just to make the transformation
    less trivial.

    The random generator is re-created from ``data_config["data_seed"]`` on every call, so
    repeated calls with identical arguments produce identical data.

    Args:
        n_points: number of sampled states.
        n_obj: sequence of object counts per axis; only ``n_obj[0]`` drives the zig-zag pattern.
        n_features: length of one flattened state vector.
        translation: scalar added to every position coordinate.
        rotation_degree: angle forwarded to ``rotate_2d``. NOTE(review): this script passes
            ``np.pi/3``, which suggests radians despite the name — confirm against ``rotate_2d``.

    Returns:
        MassSpring: system built from the flattened positions/momenta and ``data_config``.
    """
    generator = np.random.default_rng(data_config["data_seed"])
    state = unflatten(np.zeros((n_points, n_features), dtype=dtype_np), n_obj, dof)
    # first half of the last axis -> positions, second half -> momenta
    q, p = np.split(state, 2, axis=-1)

    # Zig-zag geometry: only odd-indexed nodes are displaced, always towards +x.
    for node in range(1, n_obj[0], 2):
        q[:, node, 0] = generator.uniform(data_config["q_min"], data_config["q_max"],
                                          size=(n_points)).astype(dtype_np)

    # momenta need no geometric structure, so sample them uniformly everywhere
    p = generator.uniform(data_config["p_min"], data_config["p_max"],
                          size=(n_points, *n_obj, dof)).astype(dtype_np)

    # global rigid transformation: shift positions, then rotate positions and momenta oppositely
    q += translation
    q = rotate_2d(q, rotation_degree, dtype_np)
    p = rotate_2d(p, -rotation_degree, dtype_np)

    return MassSpring(
        n_points, n_features,
        q.reshape(n_points, -1), p.reshape(n_points, -1),
        n_obj, dof,
        data_config["mass"], data_config["spring_constant"],
        Mesh(data_config["meshing"]),
    )

def create_gnn(n_obj, edge_index, model_seed, skip_normalization=False):
    """Build an RF-HGN ``GNN`` for a system with ``n_obj`` objects and connectivity ``edge_index``.

    All architecture hyper-parameters come from ``model_config``.
    By default, both translation and rotation invariance are enabled.
    ``skip_normalization=True`` disables both, giving the "plain" non-invariant baseline.
    """
    invariant = not skip_normalization
    return GNN(
        dof=dof,
        n_obj=n_obj,
        edge_index=edge_index,
        make_translation_invariant=invariant,
        make_rotation_invariant=invariant,
        direct=model_config["direct"],
        # the message width is whatever remains of the total width after the encoder part
        msg_width=model_config["width"] - model_config["enc_width"],
        enc_width=model_config["enc_width"],
        local_pooling=model_config["local_pooling"],
        global_pooling=model_config["global_pooling"],
        activ_str=model_config["activ_str"],
        init_method=model_config["init_method"],
        seed=model_seed,
        dtype=dtype_tc,
    )

# Training data: 4-chain zig-zag systems with no translation/rotation applied.
print("Preparing 4-chain train/test data with fixed 'noisy' geometry")
system = prepare_data(data_config["n_points"], n_obj_train, n_features_train) # create data with fixed geometry
# Random (shuffled) train/test split of the flattened states and their time derivatives;
# random_state is derived from data_seed so the split is reproducible.
x_train, x_test, dxdt_train, dxdt_test = train_test_split(
    tc.from_numpy(system.to_array(flatten=True)),
    tc.from_numpy(system.dxdt(flatten=True)),
    train_size=data_config["train_test_split"], shuffle=True, random_state=data_config["data_seed"]+1,
)
# Matrix passed alongside states to the fit/eval helpers.
# NOTE(review): presumably the structure matrix mapping grad H to dx/dt; confirm in MassSpring.L.
L = tc.from_numpy(system.L())
# 4-chain graph connectivity used to build the GNNs below.
edge_index = system.edge_index()

########################################
##### Fit using invariant RF-HGN   #####
########################################
# Hidden-parameter sampling settings for sample_and_linear_solve (sampler type and seed from train_config).
sampling_args = SamplingArgs(
    param_sampler=train_config["param_sampler"],
    seed=train_config["sampling_seed"],
    sample_uniformly=True,
    resample_duplicates=train_config["resample_duplicates"],
    dtype=dtype_np,
)
# Linear-solve settings (solver driver, rcond cutoff, device) for the output-layer fit.
linear_solve_args = LinearSolveArgs(
    mode="forward",
    driver=train_config["driver"],
    rcond=train_config["rcond"],
    device=device
)
rf_hgn = create_gnn(n_obj_train, edge_index, model_config["model_seed"])
# time0/time1 bracket the fit; the duration is printed in the next section.
time0 = time()
sample_and_linear_solve(rf_hgn, x_train, L, dxdt_train, sampling_args, linear_solve_args)
time1 = time()
def eval_and_print_err(model, label, x, L, dxdt):
    """Score ``model`` on states ``x`` against targets ``dxdt`` with ``eval_dxdt``.

    Prints a blank line, then ``label``, then the MSE and relative L2 error in
    scientific notation.
    """
    mse, rel2 = eval_dxdt(model, x, L, dxdt, verbose=False)
    print(f"\n{label}")
    print(f"-> mse : {mse:.2e}", f"-> rel2: {rel2:.2e}", sep="\n")

eval_and_print_err(rf_hgn, "Translation-Rotation Invariant RF-HGN Train", x_train, L, dxdt_train)
eval_and_print_err(rf_hgn, "Translation-Rotation Invariant RF-HGN Test", x_test, L, dxdt_test)
print(f"took {(time1-time0):.2f} seconds")

# Baseline: same architecture, seeds and 4-chain training split, but with the
# translation/rotation normalization switched off.
plain_rf_hgn = create_gnn(n_obj_train, edge_index, model_config["model_seed"], skip_normalization=True)
time0 = time()
sample_and_linear_solve(plain_rf_hgn, x_train, L, dxdt_train, sampling_args, linear_solve_args)
time1 = time()
eval_and_print_err(plain_rf_hgn, "Not-Translation-Rotation Invariant RF-HGN Train", x_train, L, dxdt_train)
eval_and_print_err(plain_rf_hgn, "Not-Translation-Rotation Invariant RF-HGN Test", x_test, L, dxdt_test)
print(f"took {(time1-time0):.2f} seconds")

######################################################################
##### Zero-shot test: 8-chain data, translated + rotated (HGNs)  #####
######################################################################
# An untransformed 8-chain train/test split used to be generated here, but every
# value it produced was overwritten or never read, so it is no longer built.
# NOTE(review): rotation_degree receives np.pi/3 (radians?) — confirm rotate_2d's units.
system = prepare_data(data_config["n_points"], n_obj_test, n_features_test,
                      translation=1_000.0, rotation_degree=np.pi/3)  # 8-chain data, shifted by 1000 and rotated
L = tc.from_numpy(system.L())
x_test = tc.from_numpy(system.to_array(flatten=True))
dxdt_test = tc.from_numpy(system.dxdt(flatten=True))
print("x_test mean", x_test.mean())
# Zero-shot transfer: re-point both 4-chain models at the 8-chain graph without refitting.
plain_rf_hgn.edge_index = system.edge_index()
plain_rf_hgn.n_obj = n_obj_test
eval_and_print_err(plain_rf_hgn, "Plain RF-HGN (8-chain data translated+rotated)", x_test, L, dxdt_test)
rf_hgn.edge_index = system.edge_index()
rf_hgn.n_obj = n_obj_test
eval_and_print_err(rf_hgn, "RF-HGN (8-chain data translated+rotated)", x_test, L, dxdt_test)

#######################################################################################
####### Integrate and plot MSE for invariant and non-invariant versions of RF-HGN #####
#######################################################################################
# Initial conditions for the trajectory comparison: same transformed 8-chain geometry.
# NOTE(review): len_traj is reused here as the number of initial conditions (n_points),
# not only as the trajectory length — confirm this coupling is intended.
system = prepare_data(len_traj, n_obj_test, n_features_test,
                      translation=1_000.0, rotation_degree=np.pi/3)  # transformed 8-chain initial states
def integrate(system, integration_grad_H):
    """Roll ``system`` out with Störmer-Verlet using ``integration_grad_H`` for the dynamics.

    Uses the module-level ``len_traj`` and ``delta_t``.
    Prints the wall-clock time and the value ranges of positions, momenta and the Hamiltonian.

    Returns:
        (qx, qy, hamiltonian): component 0 and component 1 of the last axis of
        ``system.to_array()`` after integration, and ``system.H()``.
    """
    start = time()
    system.integrate(len_traj, delta_t, integration_grad_H, integration_method="stormer_verlet")
    print(f"Integration took: {(time() - start):.2f} seconds")

    trajectory = system.to_array()
    # the last axis holds dof position components followed by dof momentum components
    positions, momenta = trajectory[..., :dof], trajectory[..., dof:]
    print(f"q limits: [{np.min(positions)}, {np.max(positions)}]")
    print(f"p limits: [{np.min(momenta)}, {np.max(momenta)}]")

    energy = system.H()
    print(f"hamiltonian limits: [{np.min(energy):.2f}, {np.max(energy):.2f}]")
    return trajectory[..., 0], trajectory[..., 1], energy

def integration_grad_H(x):
    """Ground-truth dynamics: analytic mass-spring ("chain") grad H with the integration system's parameters."""
    return grad_H_mass_spring(x, system.dof, system.meshing, system.mass, system.spring_constant, "chain")


def _model_grad_H(model):
    """Wrap ``model`` as a NumPy-in / NumPy-out grad-H callable for ``integrate``.

    Binding ``model`` through this factory (rather than a lambda created in the loop body)
    avoids the late-binding closure pitfall if the callable ever outlives one iteration.
    """
    def model_grad_H(x):
        dH = grad(model, flatten_TC(tc.from_numpy(x)))
        # .cpu() is a no-op for CPU tensors and keeps .numpy() working if the output
        # ends up on another device.
        return unflatten_TC(dH, model.n_obj, dof).detach().cpu().numpy()
    return model_grad_H


# Snapshot the initial conditions before anything is integrated.
# integrate() is invoked on the system object itself, and its state is printed before
# and after, so it presumably advances ``system`` in place (TODO confirm in
# MassSpring.integrate).
# Each model rollout therefore starts from its own untouched copy, which guarantees that
# true and predicted trajectories share the same initial state.
initial_system = copy.deepcopy(system)

print("Before integrating true, x[0]=", system.to_array()[0])
qx_true, qy_true, hamil_true = integrate(system, integration_grad_H)
print("After integrating true, x[0]=", system.to_array()[0])
qx_true_path = os.path.join(out_dir, f"qx_true_{n_obj_test}_{len_traj}_{delta_t:.2e}.npy")
qy_true_path = os.path.join(out_dir, f"qy_true_{n_obj_test}_{len_traj}_{delta_t:.2e}.npy")
hamil_true_path = os.path.join(out_dir, f"energy_true_{n_obj_test}_{len_traj}_{delta_t:.2e}.npy")
for data, path in zip([qx_true, qy_true, hamil_true], [qx_true_path, qy_true_path, hamil_true_path]):
    with open(path, "wb") as f:
        np.save(f, data)
    print(f"-> saved data at '{path}'")

for label, model in zip(["plain-rf-hgn", "rf-hgn"], [plain_rf_hgn, rf_hgn]):
    print("-> Integrating", label)
    model_system = copy.deepcopy(initial_system)
    qx_pred, qy_pred, hamil_pred = integrate(model_system, _model_grad_H(model))
    qx_pred_path = os.path.join(out_dir, f"qx_{n_obj_test}_{label}_{len_traj}_{delta_t:.2e}.npy")
    qy_pred_path = os.path.join(out_dir, f"qy_{n_obj_test}_{label}_{len_traj}_{delta_t:.2e}.npy")
    hamil_pred_path = os.path.join(out_dir, f"energy_{n_obj_test}_{label}_{len_traj}_{delta_t:.2e}.npy")
    for data, path in zip([qx_pred, qy_pred, hamil_pred], [qx_pred_path, qy_pred_path, hamil_pred_path]):
        with open(path, "wb") as f:
            np.save(f, data)
        print(f"-> saved data at {path}")

    # Squared position error, averaged over axis 1, both absolute and relative to the true
    # positions' mean squared magnitude (same axis).
    sq_err = (qx_true - qx_pred)**2 + (qy_true - qy_pred)**2
    mse = np.mean(sq_err, axis=1)
    rel = mse / np.mean(qx_true**2 + qy_true**2, axis=1)
    print(f"=trajectory mean of all (q) MSE: {mse.mean():.2e}")
    print(f"=trajectory mean of all (q) REL: {rel.mean():.2e}")

    # save the models
    tc.save(model, os.path.join(out_dir, f"{label}_{n_obj_test}_chain.pt"))
