"""
Trains a Hamiltonian model for predicting molecular dynamics.
"""
import os
import math
import numpy as np
import torch as tc
from src.data.molecular_dynamics import edge_index_radius, r, H, grad_H, get_poisson_matrix, \
                                        eval_and_print_err_as_separate
from src.utils import hamiltons_eq
from src.model import GNN
from src.train import sample_and_linear_solve, traditional_training
from src.utils import LinearSolveArgs, SamplingArgs, flatten, eval_dxdt, TradTrainingArgs
from sklearn.model_selection import train_test_split, KFold
from argparse import ArgumentParser
from time import time

# One seed drives every random number generator used by this script so runs
# are reproducible: the NumPy generator `rng` (used for all derived seeds below)
# and torch on the CPU, the current CUDA device and all CUDA devices.
random_state = 482956
rng = np.random.default_rng(random_state)
for _seed_torch in (tc.manual_seed, tc.cuda.manual_seed, tc.cuda.manual_seed_all):
    _seed_torch(random_state)
del _seed_torch

# Command-line interface. Every option except --data has a default.
argparser = ArgumentParser()
argparser.add_argument("-f", "--data", required=True, type=str,
                       help=".npy file containing trajectory data")
argparser.add_argument("-h1", "--enc_width", required=False, type=int, default=16,
                       help="encoding width, default=16")
argparser.add_argument("-h2", "--msg_width", required=False, type=int, default=256,
                       help="message width, due to architecture: network_width=enc_width+msg_width, default=256")
argparser.add_argument("-d", "--driver", required=False, default="gels", type=str,
                       help="lstsq driver of torch, default=gels")
# The device default used to be the lstsq driver name "gels" (copy-paste slip),
# which made `.to(args.device)` fail unless -dev was given explicitly.
argparser.add_argument("-dev", "--device", required=False, default="cpu", type=str,
                       help="default=cpu, options={cpu,cuda}")
argparser.add_argument("-prec", "--precision", type=str, default="double",
                       help="single or double, default=double")
argparser.add_argument("-r", "--rcond", type=float, default=1e-10,
                       help="regularization, default=1e-10")
argparser.add_argument("-ns", "--num_sim", required=False, type=int, default=100,
                       help="Maximum number of trajectories to read the data from")
argparser.add_argument("-k", "--datasize", required=False, default=10, type=int,
                       help="The number of snapshots to take as data from each simulation")
argparser.add_argument("-o", "--output", required=False, type=str, default="pretrained",
                       help="Output directory for the trained model to be saved.")
# "swim" is the default, so it must also be an accepted explicit value.
argparser.add_argument("-t", "--trainer", required=False, type=str, default="swim",
                       choices=["swim", "adam", "lbfgs", "swim-lstsq", "swim-adam", "swim-lbfgs", "swim-ridge-lbfgs"],
                       help="To experiment with traditional training as well. "
                            "swim/swim-lstsq: sampled hidden layers + linear solve; "
                            "swim-adam/swim-lbfgs: sampled, frozen hidden layers + iterative training; "
                            "adam/lbfgs: iterative training of all layers. default=swim")
argparser.add_argument("--n_gd_steps", required=False, type=int, default=100,
                       help="number of optimizer steps for the iterative trainers, default=100")
argparser.add_argument("--batch_size", required=False, type=int, default=32,
                       help="batch size for the iterative trainers, default=32")
args = argparser.parse_args()
data = np.load(args.data)
# Anything other than "single" selects double precision.
dtype_np = np.float32 if args.precision == "single" else np.float64
dtype_tc = tc.float32 if args.precision == "single" else tc.float64

def get_traj_stats(x):
    """Return (min, mean, max) statistics of pairwise distances and momenta.

    Args:
        x: trajectory of shape (n_steps, n_obj, 2*dof). The last axis is split
            in half: positions q (first half) and momenta p (second half).

    Returns:
        ``((r_min, r_mean, r_max), (p_min, p_mean, p_max))``. The
        r-statistics are taken over ``r(q_i, q_j)`` for every unordered
        particle pair i < j at every step. The p-statistics are taken over
        all momentum entries.

    Raises:
        ValueError: from ``np.min`` when there is nothing to reduce (no
            steps, or fewer than two particles so no pairs exist).
    """
    from itertools import combinations

    q, p = np.split(x, 2, axis=-1)
    p_min, p_mean, p_max = np.min(p), np.mean(p), np.max(p)
    # Distances of all unordered pairs (i < j), for every step.
    dists = [r(q_i, q_j) for q_step in q for q_i, q_j in combinations(q_step, 2)]
    r_min, r_mean, r_max = np.min(dists), np.mean(dists), np.max(dists)
    return (r_min, r_mean, r_max), (p_min, p_mean, p_max)

def get_sim_args(filename):
    """Parse the simulation settings encoded in a trajectory file name.

    The base name (any directory prefix is ignored) is expected to look like
    ``<N>sim_<M>particles_<S>steps_<dt>deltat_<rc>cutoff.npy``, e.g.
    ``20sim_9particles_50000steps_0.005deltat_2.0cutoff.npy``.
    Only the first five underscore-separated fields are read.

    Returns:
        (n_simulations, n_obj, n_steps, delta_t, cutoff) as
        (int, int, int, float, float).
    """
    fields = os.path.basename(filename).split('_')
    # (marker stripped from the field, converter), in field order
    layout = (("sim", int), ("particles", int), ("steps", int),
              ("deltat", float), ("cutoff.npy", float))
    return tuple(convert(fields[pos].replace(marker, ""))
                 for pos, (marker, convert) in enumerate(layout))

n_simulations, n_obj, n_steps, delta_t, cutoff = get_sim_args(args.data)
for _name, _value in (("n_simulations", n_simulations), ("n_obj", n_obj),
                      ("n_steps", n_steps), ("delta_t", delta_t), ("cutoff", cutoff)):
    print(f"-> {_name:<17}= {_value}")

# Snapshot positions within a simulation; the same for every simulation.
if args.datasize == 1:
    snapshot_idx = [-1]  # only the final state
else:
    # `datasize` evenly spaced steps in [0, n_steps], de-duplicated; the
    # endpoint n_steps (one past the last step) is clipped to n_steps - 1.
    snapshot_idx = np.clip(
        np.unique(np.linspace(0, n_steps, args.datasize).astype(np.int64)),
        None, n_steps - 1)

# `data` stores the simulations back to back along axis 0, n_steps rows each.
# get_traj_stats(<one simulation>) can be used to print per-simulation statistics.
n_sims_read = min(n_simulations, args.num_sim)
data = np.stack([data[s * n_steps:(s + 1) * n_steps][snapshot_idx] for s in range(n_sims_read)],
                axis=0).astype(dtype_np)
print(f"-> data of shape {data.shape} is read")

# Training states: the snapshots selected above from every simulation,
# flattened into one batch of graphs of shape (n_graphs, n_obj, 4).
# The last axis holds positions q in [..., :2] and momenta p in [..., 2:].
x = data.reshape(-1, n_obj, 4)
print(f"Data p min {np.min(x[..., 2:]):.5f}")
print(f"Data p max {np.max(x[..., 2:]):.5f}")

# Ground-truth targets: energies (ke, pe) for reporting, and dx/dt from
# Hamilton's equations applied to the energy gradient.
ke, pe = H(x, mass=1.0, eps=1.0, sig=1.0, cutoff=cutoff, as_separate=True)
dHdx = grad_H(x, mass=1.0, eps=1.0, sig=1.0, cutoff=cutoff)
dxdt = hamiltons_eq(dHdx)
print(f"-> computed true ke of shape {ke.shape}")
print(f"   min = {np.min(ke):.2e}, mean = {np.mean(ke):.2e}, max = {np.max(ke):.2e}")
print(f"-> computed true pe of shape {pe.shape}")
print(f"   min = {np.min(pe):.2e}, mean = {np.mean(pe):.2e}, max = {np.max(pe):.2e}")
print(f"-> computed true dHdx of shape {dHdx.shape}")
print(f"-> computed true dxdt of shape {dxdt.shape}")
# dq/dt occupies [..., :2] and dp/dt [..., 2:], matching the layout of x.
# Each summary line must use one slice for all three statistics.
print(f"   dqdt min = {np.min(dxdt[..., :2]):.2e}, mean = {np.mean(dxdt[..., :2]):.2e}, max = {np.max(dxdt[..., :2]):.2e}")
print(f"   dpdt min = {np.min(dxdt[..., 2:]):.2e}, mean = {np.mean(dxdt[..., 2:]):.2e}, max = {np.max(dxdt[..., 2:]):.2e}")

if np.isinf(cutoff):
    # No cutoff: every particle is connected to every other one. A single edge
    # index built from the first state is therefore shared by all graphs.
    edge_index = edge_index_radius(x[0], np.inf)
    print("-> cutoff=np.inf, we have the static edge index", edge_index)
else:
    # Finite cutoff: the connectivity depends on the positions. Build one edge
    # index per state and report each node degree the first time it is seen.
    node_degrees = []
    edge_index = []
    for graph_state in x:
        state_edges = edge_index_radius(graph_state, cutoff)
        if state_edges != []:
            senders, receivers = state_edges
            # number of edge endpoints touching each node
            degree_per_node = tc.bincount(tc.cat([senders, receivers]))
            for degree in tc.unique(degree_per_node).tolist():
                if degree in node_degrees:
                    continue
                print("-> New node degree:", degree)
                node_degrees.append(degree)
                print("corresponding edge_index_state")
                print(state_edges)
        edge_index.append(state_edges)
    print(f"-> cutoff={cutoff}, we have dynamic edge indexing with unique degrees", node_degrees)

def _relu_sampling_args():
    """SWIM sampling settings shared by all random-feature trainers (draws one seed from `rng`)."""
    return SamplingArgs(param_sampler="relu",
                        dtype=dtype_np,
                        seed=rng.integers(0, 10**9),
                        sample_uniformly=True,
                        resample_duplicates=True)


def _trad_training_args(optimizer, weight_init):
    """Iterative-training settings for `optimizer` ("adam" or "lbfgs").

    L-BFGS uses a constant learning rate of 1e-1. Adam decays exponentially
    from 1e-2 to 1e-5.
    """
    uses_lbfgs = optimizer == "lbfgs"
    return TradTrainingArgs(n_steps=args.n_gd_steps,
                            # if you want to do batching then you need to also configure edge_index properly inside the model!
                            batch_size=args.batch_size,
                            device=args.device,
                            weight_init=weight_init,
                            lr_start=1e-01 if uses_lbfgs else 1e-02,
                            lr_end=1e-01 if uses_lbfgs else 1e-05,
                            weight_decay=1e-6, patience=10_000,
                            optim_type=optimizer,
                            sched_type="exponential" if optimizer == "adam" else "none")


if args.trainer in ["swim", "swim-lstsq"]:
    # sampled hidden layers + direct linear solve for the remaining weights
    sampling_args = _relu_sampling_args()
    linear_solve_args = LinearSolveArgs(driver=args.driver,
                                        rcond=args.rcond,
                                        device=args.device)
elif args.trainer in ["adam", "lbfgs"]:
    # all layers trained iteratively from a Kaiming-normal initialization
    trad_train_args = _trad_training_args(args.trainer, "kaiming_normal")
elif args.trainer in ["swim-lbfgs", "swim-adam"]:
    # sampled (frozen) hidden layers; no weight init because sampling sets them
    sampling_args = _relu_sampling_args()
    trad_train_args = _trad_training_args(args.trainer.replace("swim-", ""), "none")
else:
    raise NotImplementedError(f"Specified trainer {args.trainer} is not yet implemented.")

# Poisson matrix for n_obj particles with 2 degrees of freedom each. It is
# passed to the solvers and evaluators together with the states and dx/dt.
L = tc.from_numpy(get_poisson_matrix(n_obj, 2, dtype=dtype_np))

# Warm-up: run one large matmul on the target device before anything is timed.
# The operands are generated on the CPU (consuming the global torch RNG) and
# then moved; the result is discarded.
_warmup_lhs = tc.randn(10_000, 10_000).to(args.device)
_warmup_rhs = tc.randn(10_000, 10_000).to(args.device)
tc.matmul(_warmup_lhs, _warmup_rhs)
del _warmup_lhs, _warmup_rhs

def set_edge_index(gnn, x):
    """Rebuild the radius-graph edge indices for the states ``x`` and store them on ``gnn``.

    One edge index is built per state using the module-level ``cutoff``. The
    list assigned to ``gnn.edge_index`` therefore lines up with the states in ``x``.
    """
    gnn.edge_index = [edge_index_radius(state, cutoff) for state in x]


def pre_op(model, x):
    """Hook to run before ``model`` is applied to the states ``x``.

    With a finite cutoff the graph connectivity depends on the positions, so
    the edge indices are recomputed for exactly these states. With an infinite
    cutoff the static edge index passed at construction is kept and nothing
    happens.

    The previous one-line lambda never called ``set_edge_index``: due to the
    conditional-expression precedence it *returned* a new lambda. As a result,
    the edge indices were silently never updated.
    """
    if not np.isinf(cutoff):
        set_edge_index(model, x)

###############################################################################

def _hidden_layer_stats(model):
    """Summarize the sampled hidden layers of ``model``.

    Returns four lists: weight means, weight stds, bias means and bias stds.
    Each list has one entry per layer, in the order node encoder, edge
    encoder, message encoder. Parameters are copied to the CPU first because
    ``Tensor.numpy()`` rejects CUDA tensors, and training may move the model
    to ``args.device``.
    """
    layers = [model.node_encoder, model.edge_encoder, model.msg_encoder]
    weights = [layer.weight.detach().cpu().numpy() for layer in layers]
    biases = [layer.bias.detach().cpu().numpy() for layer in layers]
    return ([w.mean().item() for w in weights],
            [w.std().item() for w in weights],
            [b.mean().item() for b in biases],
            [b.std().item() for b in biases])


def _assert_hidden_layers_unchanged(stats_before, stats_after):
    """Fail if the frozen hidden layers changed during iterative training.

    Compares two ``_hidden_layer_stats`` summaries entry by entry using
    ``math.isclose`` with its default tolerances.
    """
    for lst_before, lst_after in zip(stats_before, stats_after):
        for param_before, param_after in zip(lst_before, lst_after):
            assert math.isclose(param_before, param_after), \
                "Sampled parameters have changed after iterative training."


# perform k-fold cross validation for the random-feature methods
if args.trainer in ["swim", "swim-lstsq", "swim-lbfgs", "swim-adam"]:
    kfold = KFold(n_splits=5, shuffle=True, random_state=rng.integers(0, 10**9))
    total_train_time = 0.0
    train_mse_dqdt, train_rel2_dqdt = [], []
    train_mse_dpdt, train_rel2_dpdt = [], []
    test_mse_dqdt, test_rel2_dqdt = [], []
    test_mse_dpdt, test_rel2_dpdt = [], []
    for fold_idx, (train_indices, test_indices) in enumerate(kfold.split(x)):
        print(f"Fold {fold_idx + 1} / {kfold.n_splits}")

        gnn = GNN(n_obj=[n_obj],
                  dof=2,
                  edge_index=edge_index,
                  enc_width=args.enc_width,
                  msg_width=args.msg_width,
                  activ_str="softplus",
                  seed=rng.integers(0, 10**9),
                  dtype=dtype_tc,
                  )

        # training
        if args.trainer in ["swim", "swim-lstsq"]:
            # The model was built with the edge index of *all* states; align it
            # with the training fold before sampling.
            pre_op(gnn, x[train_indices])
            time0 = time()
            sample_and_linear_solve(gnn,
                                    tc.from_numpy(flatten(x[train_indices])), L,
                                    tc.from_numpy(flatten(dxdt[train_indices])),
                                    sampling_args, linear_solve_args)
            time1 = time()
        elif args.trainer in ["swim-lbfgs", "swim-adam"]:
            gnn = gnn.train()
            # Same alignment as in the swim branch; previously sampling here
            # saw the edge index of the full dataset.
            pre_op(gnn, x[train_indices])
            time0 = time()
            gnn.sample_hidden(tc.from_numpy(flatten(x[train_indices])), sampling_args, e_pred=None) # Initial sampling if Approximate-SWIM
            stats_before = _hidden_layer_stats(gnn)

            gnn.freeze_hidden_layers()

            print("calling with trainer", args.trainer.replace("swim-", ""))
            traditional_training(gnn,
                                 tc.from_numpy(flatten(x[train_indices])), None, L,
                                 tc.from_numpy(flatten(dxdt[train_indices])), None,
                                 trad_train_args, pre_op)

            # Only the non-sampled parameters may change during iterative training.
            _assert_hidden_layers_unchanged(stats_before, _hidden_layer_stats(gnn))

            time1 = time()
        else:
            raise NotImplementedError(f"Specified trainer {args.trainer} is not yet implemented.")
        total_train_time += (time1 - time0)
        print(f"-> Training took {(time1 - time0):.2f} seconds")

        pre_op(gnn, x[train_indices])

        # evaluation
        mse_dqdt, rel2_dqdt, mse_dpdt, rel2_dpdt = \
                eval_dxdt(gnn,
                          tc.from_numpy(flatten(x[train_indices])), L,
                          tc.from_numpy(flatten(dxdt[train_indices])),
                          as_separate=True, verbose=False)
        print(f"train:      mse_dqdt={mse_dqdt:.2e}, mse_dpdt={mse_dpdt:.2e}")

        for err, lst in zip([mse_dqdt, rel2_dqdt, mse_dpdt, rel2_dpdt], [train_mse_dqdt, train_rel2_dqdt, train_mse_dpdt, train_rel2_dpdt]):
            lst.append(err)

        pre_op(gnn, x[test_indices])

        mse_dqdt, rel2_dqdt, mse_dpdt, rel2_dpdt = \
                eval_dxdt(gnn,
                          tc.from_numpy(flatten(x[test_indices])), L,
                          tc.from_numpy(flatten(dxdt[test_indices])),
                          as_separate=True, verbose=False)
        print(f"test :      mse_dqdt={mse_dqdt:.2e}, mse_dpdt={mse_dpdt:.2e}")

        for err, lst in zip([mse_dqdt, rel2_dqdt, mse_dpdt, rel2_dpdt], [test_mse_dqdt, test_rel2_dqdt, test_mse_dpdt, test_rel2_dpdt]):
            lst.append(err)

    avg_train_time = total_train_time / kfold.n_splits

    avg_train_mse_dqdt = sum(train_mse_dqdt) / kfold.n_splits
    avg_train_rel2_dqdt = sum(train_rel2_dqdt) / kfold.n_splits
    avg_train_mse_dpdt = sum(train_mse_dpdt) / kfold.n_splits
    avg_train_rel2_dpdt = sum(train_rel2_dpdt) / kfold.n_splits

    avg_test_mse_dqdt = sum(test_mse_dqdt) / kfold.n_splits
    avg_test_rel2_dqdt = sum(test_rel2_dqdt) / kfold.n_splits
    avg_test_mse_dpdt = sum(test_mse_dpdt) / kfold.n_splits
    avg_test_rel2_dpdt = sum(test_rel2_dpdt) / kfold.n_splits

    print()
    print(f"Average training time       :       {avg_train_time:.2f} seconds")
    print(f"Average train dqdt MSE,REL2 :       {avg_train_mse_dqdt:.2e} {avg_train_rel2_dqdt:.2e}")
    print(f"Average train dpdt MSE,REL2 :       {avg_train_mse_dpdt:.2e} {avg_train_rel2_dpdt:.2e}")
    print(f"Average test  dqdt MSE,REL2 :       {avg_test_mse_dqdt:.2e} {avg_test_rel2_dqdt:.2e}")
    print(f"Average test  dpdt MSE,REL2 :       {avg_test_mse_dpdt:.2e} {avg_test_rel2_dpdt:.2e}")
elif args.trainer in ["adam", "lbfgs"]:
    # Fully trained baselines: a single 80/20 train/test split instead of k-fold.
    gnn = GNN(n_obj=[n_obj],
              dof=2,
              edge_index=edge_index,
              enc_width=args.enc_width,
              msg_width=args.msg_width,
              activ_str="softplus",
              seed=rng.integers(0, 10**9),
              dtype=dtype_tc,
              )
    gnn.unfreeze_hidden_layers()
    x_train, x_test, dxdt_train, dxdt_test = train_test_split(
        tc.from_numpy(flatten(x)),
        tc.from_numpy(flatten(dxdt)),
        test_size=0.2, shuffle=True, random_state=rng.integers(0, 10**9),
    )
    time0 = time()
    traditional_training(gnn, x_train, x_test, L, dxdt_train, dxdt_test, trad_train_args, pre_op)
    time1 = time()
    print(f"-> Training took {(time1 - time0):.2f} seconds")
else:
    raise NotImplementedError(f"Specified trainer {args.trainer} is not implemented yet.")

###############################################################################

# Train and save the model for evaluation: the final model is fit on all states
# (no held-out split). It is then evaluated on the same states and pickled
# with torch.save.
print()
print("Final training and saving the model for evaluation..")

gnn = GNN(n_obj=[n_obj],
          dof=2,
          edge_index=edge_index,
          enc_width=args.enc_width,
          msg_width=args.msg_width,
          activ_str="softplus",
          seed=rng.integers(0, 10**9),
          dtype=dtype_tc,
          )
if args.trainer in ["swim", "swim-lstsq"]:
    pre_op(gnn, x)
    time0 = time()
    sample_and_linear_solve(gnn,
                            tc.from_numpy(flatten(x)), L,
                            tc.from_numpy(flatten(dxdt)),
                            sampling_args, linear_solve_args)
    time1 = time()
elif args.trainer in ["swim-lbfgs", "swim-adam"]:
    gnn = gnn.train()
    time0 = time()
    gnn.sample_hidden(tc.from_numpy(flatten(x)), sampling_args, e_pred=None) # Initial sampling if Approximate-SWIM
    gnn.freeze_hidden_layers()
    # (mean, std) of every weight and bias of the sampled, frozen hidden layers.
    # .cpu() is needed because Tensor.numpy() rejects CUDA tensors and training
    # may move the model to args.device.
    stats_before = [(values.mean().item(), values.std().item())
                    for layer in (gnn.node_encoder, gnn.edge_encoder, gnn.msg_encoder)
                    for values in (layer.weight.detach().cpu().numpy(), layer.bias.detach().cpu().numpy())]
    traditional_training(gnn,
                         tc.from_numpy(flatten(x)), None, L,
                         tc.from_numpy(flatten(dxdt)), None,
                         trad_train_args, pre_op)
    stats_after = [(values.mean().item(), values.std().item())
                   for layer in (gnn.node_encoder, gnn.edge_encoder, gnn.msg_encoder)
                   for values in (layer.weight.detach().cpu().numpy(), layer.bias.detach().cpu().numpy())]

    # Only the non-sampled parameters may change during iterative training.
    for (mean_before, std_before), (mean_after, std_after) in zip(stats_before, stats_after):
        assert math.isclose(mean_before, mean_after) and math.isclose(std_before, std_after), \
            "Sampled parameters have changed after iterative training."

    time1 = time()
elif args.trainer in ["adam", "lbfgs"]:
    time0 = time()
    traditional_training(gnn,
                         tc.from_numpy(flatten(x)), None, L,
                         tc.from_numpy(flatten(dxdt)), None,
                         trad_train_args, pre_op)
    time1 = time()
else:
    raise NotImplementedError(f"Specified trainer {args.trainer} is not yet implemented.")

print(f"-> Training took {(time1 - time0):.2f} seconds")

pre_op(gnn, x)
mse_dqdt, rel2_dqdt, mse_dpdt, rel2_dpdt = \
        eval_dxdt(gnn,
                  tc.from_numpy(flatten(x)), L,
                  tc.from_numpy(flatten(dxdt)),
                  as_separate=True, verbose=False)
print(f"train:      mse_dqdt={mse_dqdt:.2e}, mse_dpdt={mse_dpdt:.2e}")
print(f"train:      rel_dqdt={rel2_dqdt:.2e}, rel_dpdt={rel2_dpdt:.2e}")

# Create the output directory if needed; otherwise tc.save raises
# FileNotFoundError after all of the training above.
os.makedirs(args.output, exist_ok=True)
model_path = os.path.join(args.output, f"rf_hgn_md_{n_obj}_2DOF.pt")
tc.save(gnn, model_path)
print(f"-> Model is saved at '{model_path}'")
