import os
import numpy as np
import torch as tc
import time
import toml
from math import ceil
from argparse import ArgumentParser
from tabulate import tabulate
from src.data import get_equil_r, r
from src.data.molecular_dynamics import H, grad_H, edge_index_radius, get_poisson_matrix, simulate, get_x0_and_box, animate_2D
from src.model import GNN
from src.train import sample_and_linear_solve, traditional_training
from src.utils import eval_dxdt, LinearSolveArgs, SamplingArgs, TradTrainingArgs, hamiltons_eq, flatten, mse, l2_err, memusage

# Command-line interface of the experiment. Every option except --save is required.
argparser = ArgumentParser()
argparser.add_argument("-f", "--toml", type=str, help=".toml file for reading config of the experiment.", required=True)
argparser.add_argument("-d", "--data", required=True, type=str, help="Data path.")
argparser.add_argument("-nt", "--n_train_sims", required=True, type=int, help="Number of training trajectories.")
argparser.add_argument("-ns", "--n_snaps", required=True, type=int, help="Number of snapshots to take from trajectories.")
argparser.add_argument("-zdx", "--zero_shot_q_disp", required=True, type=float,
                       help="Position noise (q_noise) used when creating the initial state of the zero-shot test system.")
argparser.add_argument("-o", "--outdir", type=str, required=True,
                       help="Output directory to save the experiment results.")
argparser.add_argument("--testobj", required=True, type=int,
                       help="Number of particles along each axis of the zero-shot test system "
                            "(testobj * testobj particles in total).")
# --save only controls the .npz dump of the zero-shot trajectories; the trained
# models are always written to the output directory.
argparser.add_argument("--save", default=False, action="store_true",
                       help="Saves the true and predicted zero-shot test trajectories as .npz files in the output directory.")
args = argparser.parse_args()
config = toml.load(args.toml)
data_path = args.data

# File constants
seed = config["seed"]
rng = np.random.default_rng(seed)


def get_seed():
    """Draw a new integer seed from the experiment-wide generator `rng`."""
    return rng.integers(0, 10**9)


dtype_np = np.float32
dtype_tc = tc.float32
# lower the values below if memory is an issue
n_train_sims = args.n_train_sims
n_snaps = args.n_snaps
dof = 2
train_eval_ratio = 0.9
GB = 1024 ** 3


def convert_bytes_to_GB(n_bytes):
    """Convert a byte count to GB, using 1 GB = 1024**3 bytes."""
    return n_bytes / GB


# Model constants
_model_cfg = config["model"]
enc_width = _model_cfg["enc_width"]
network_width = _model_cfg["network_width"]
# the message width is whatever remains of the network width after the encoder
msg_width = network_width - enc_width
activ_str = _model_cfg["activ_str"]

def get_traj_stats(x):
    """Summarise pairwise distances and momenta of a trajectory.

    x is expected to have shape (n_steps, n_obj, 2*dof); its last axis is
    split into two equal halves, positions q first and momenta p second.

    Returns ((r_min, r_mean, r_max), (p_min, p_mean, p_max)), where the r
    statistics are taken over `r(q_i, q_j)` for every unordered pair of
    objects in every step, and the p statistics over all entries of p.
    """
    q, p = np.split(x, 2, axis=-1)
    p_stats = (np.min(p), np.mean(p), np.max(p))
    # every unordered pair (i < j) of objects, step by step
    dists = [
        r(q_step[i], q_step[j])
        for q_step in q
        for i in range(len(q_step))
        for j in range(i + 1, len(q_step))
    ]
    r_stats = (np.min(dists), np.mean(dists), np.max(dists))
    return r_stats, p_stats

def get_sim_args(filename):
    """Parse the simulation settings encoded in a dataset file name.

    The base name must consist of '_'-separated fields, e.g.
    "data-lj/100sim_2particles_50steps_0.005deltat_2.0cutoff.npy".

    Returns the tuple (n_simulations, n_obj, n_steps, delta_t, cutoff); the
    first three entries are ints, the last two floats.
    """
    fields = os.path.basename(filename).split('_')
    # (converter, text stripped from the field) for each positional field
    layout = (
        (int, "sim"),
        (int, "particles"),
        (int, "steps"),
        (float, "deltat"),
        (float, "cutoff.npy"),
    )
    return tuple(
        convert(fields[pos].replace(tag, ""))
        for pos, (convert, tag) in enumerate(layout)
    )

# Read lennard-jones data (created with create_lennard_jones_data.py)
# The simulation settings are recovered from the data file name itself.
n_simulations, n_obj, n_steps, delta_t, cutoff = get_sim_args(data_path)
data = np.load(data_path)
print(data.shape)
# NOTE(review): this is negative when n_train_sims > n_simulations; the slice
# below clamps the number of training simulations, this value does not.
n_test_sims = n_simulations - n_train_sims
print(f"-> n_simulations    = {n_simulations}")
print(f"-> n_obj            = {n_obj}")
print(f"-> n_steps          = {n_steps}")
print(f"-> delta_t          = {delta_t}")
print(f"-> cutoff           = {cutoff}")

# Keep the first n_train_sims simulations (at most n_simulations) and every
# n_snaps-th entry along the second axis (presumably the time steps — TODO confirm).
# NOTE(review): n_snaps is used as a stride here, while the CLI help describes it
# as the number of snapshots — confirm which one is intended.
x = data[:min(n_simulations, n_train_sims), ::n_snaps]
# merge the leading axes so that every row is one state of shape (n_obj, 4);
# 4 presumably equals 2 * dof (q and p with dof = 2)
x = x.reshape(-1, n_obj, 4)
# reference energies and gradients with mass = eps = sig = 1.0
ke, pe = H(x, mass=1.0, eps=1.0, sig=1.0, cutoff=cutoff, as_separate=True)
dHdx = grad_H(x, mass=1.0, eps=1.0, sig=1.0, cutoff=cutoff)
# regression targets for training
dxdt = hamiltons_eq(dHdx)

# print(f"-> computed true ke of shape {ke.shape}")
# print(f"   min = {np.min(ke):.2e}, mean = {np.mean(ke):.2e}, max = {np.max(ke):2e}")
# print(f"-> computed true pe of shape {pe.shape}")
# print(f"   min = {np.min(pe):.2e}, mean = {np.mean(pe):.2e}, max = {np.max(pe):2e}")
# print(f"-> computed true dHdx of shape {dHdx.shape}")
# print(f"-> computed true dxdt of shape {dxdt.shape}")
# print(f"   dqdt min = {np.min(dxdt[..., :2]):.2e}, mean = {np.mean(dxdt[..., 2:]):.2e}, max = {np.max(dxdt[..., 2:]):2e}")
# print(f"   dpdt min = {np.min(dxdt[..., 2:]):.2e}, mean = {np.mean(dxdt[..., :2]):.2e}, max = {np.max(dxdt[..., :2]):2e}")

# only dxdt is needed from here on
del dHdx

# Compute edge index for each state
if np.isinf(cutoff):
    # Infinite cutoff: one edge index is computed from the first state only.
    # NOTE(review): train_test_split later indexes `edge_index` by sample index
    # (through get_indices) — confirm this static value supports that.
    edge_index = edge_index_radius(x[0], np.inf)
    # print("-> cutoff=np.inf, we have the static edge index", edge_index)
else:
    # Finite cutoff: one edge index per state, stored in the same order as `x`.
    edge_index = []
    # distinct node degrees seen so far; only used by the commented-out print below
    node_degrees = []
    for state in x:
        edge_index_state = edge_index_radius(state, cutoff) # presumably the edges between particles within `cutoff` — TODO confirm
        # NOTE(review): assumes edge_index_radius returns [] for a state without
        # edges and a (src, dst) pair of tensors otherwise — confirm against its definition.
        if edge_index_state != []:
            src, dst = edge_index_state
            # count how often each node id occurs among the edge endpoints
            all_nodes = tc.cat([src, dst])
            degree_count = tc.bincount(all_nodes)
            unique_degrees = tc.unique(degree_count)
            for unique_degree in unique_degrees:
                item = unique_degree.detach().item()
                if not item in node_degrees:
                    # print("-> New node degree:", item)
                    node_degrees.append(item)
                    # print("corresponding edge_index_state")
                    # print(edge_index_state)
        # states without edges are appended too, so indices stay aligned with `x`
        edge_index.append(edge_index_state)
    # print(f"-> cutoff={cutoff}, we have dynamic edge indexing with unique degrees", node_degrees)

def create_gnn(n_obj, edge_index, init_method, model_seed):
    """Construct a GNN using the module-level model configuration.

    n_obj, edge_index, init_method and model_seed are forwarded to the GNN
    constructor (model_seed as `seed`); dof, the widths, the activation
    string and the torch dtype come from the module-level constants.
    """
    gnn_kwargs = dict(
        dof=dof,
        n_obj=n_obj,
        edge_index=edge_index,
        take_absolute_diff=True,
        msg_width=msg_width,
        enc_width=enc_width,
        activ_str=activ_str,
        init_method=init_method,
        seed=model_seed,
        dtype=dtype_tc,
    )
    return GNN(**gnn_kwargs)

def trad_train(model, x_train, x_test, L, dxdt_train, dxdt_test):
    """
    Traditionally train (using gradient-descent based optimization) FCNN or GNN

    The optimisation settings are read from the "train" table of the loaded
    config, the device from its top-level "device" entry. Returns whatever
    `traditional_training` returns.
    """
    train_cfg = config["train"]
    training_args = TradTrainingArgs(
        n_steps=train_cfg["n_steps"],
        batch_size=train_cfg["batch_size"],
        device=config["device"],
        weight_init=train_cfg["weight_init"],
        lr_start=train_cfg["lr_start"],
        lr_end=train_cfg["lr_end"],
        weight_decay=train_cfg["weight_decay"],
        patience=train_cfg["patience"],
        optim_type=train_cfg["optim_type"],
        sched_type=train_cfg["sched_type"],
    )
    history = traditional_training(model, x_train, x_test, L, dxdt_train, dxdt_test, training_args)
    return history

def sample_fit(model, x_train, L, dxdt_train, param_sampler, sampling_seed):
    """Fit `model` by parameter sampling followed by a linear solve.

    param_sampler and sampling_seed configure the sampling step; the
    least-squares driver, rcond and device are read from the loaded config.

    Returns (dxdt_mse, dxdt_rel2) as reported by `eval_dxdt` on the
    training data after fitting.
    """
    sampler_cfg = SamplingArgs(
        seed=sampling_seed,
        param_sampler=param_sampler,
        sample_uniformly=True,
        dtype=dtype_np,
    )
    train_cfg = config["train"]
    solver_cfg = LinearSolveArgs(
        driver=train_cfg["driver"],
        rcond=train_cfg["rcond"],
        device=config["device"],
        # Set batch size here as well to try batched lstsq,
        # e.g. batch_size=config["train"].get("batch_size", None)
        batch_size=None,
    )
    sample_and_linear_solve(model, x_train, L, dxdt_train, sampler_cfg, solver_cfg)
    memusage()
    train_mse, train_rel2 = eval_dxdt(model, x_train, L, dxdt_train, verbose=False)
    return train_mse, train_rel2

# create dataset
def get_indices(arr, indices):
    """Return a new list holding `arr[index]` for each index, in the given order."""
    return [arr[position] for position in indices]

def train_test_split(x, edge_index, y, train_set_ratio, rng):
    """Randomly split samples, their edge indices and targets into two parts.

    The sample order is shuffled in place with `rng`; the first
    ceil(len(x) * train_set_ratio) shuffled indices form the training part,
    the remaining ones the test part. `x` and `y` are indexed with index
    arrays, `edge_index` element by element.

    Returns (x_train, x_test, edge_index_train, edge_index_test, y_train, y_test).
    """
    n_samples = len(x)
    order = np.arange(n_samples)
    rng.shuffle(order)
    n_train: int = ceil(n_samples * train_set_ratio)
    train_ids = order[:n_train]
    test_ids = order[n_train:]

    x_train = x[train_ids]
    x_test = x[test_ids]
    edge_index_train = get_indices(edge_index, train_ids)
    edge_index_test = get_indices(edge_index, test_ids)
    y_train = y[train_ids]
    y_test = y[test_ids]
    return x_train, x_test, edge_index_train, edge_index_test, y_train, y_test

# prepare dataset
# flatten states and targets, then split them (with the edge indices) into a
# training part and an evaluation part
x_flat = tc.from_numpy(flatten(x))
dxdt_flat = tc.from_numpy(flatten(dxdt))
(
    x_train, x_eval,
    edge_index_train, edge_index_eval,
    dxdt_train, dxdt_eval,
) = train_test_split(x_flat, edge_index, dxdt_flat, train_eval_ratio, rng)

print("prepared dataset")
print("x_train of shape", x_train.shape)
print("x_eval of shape", x_eval.shape)
print("edge_index_train of length", len(edge_index_train))
print("edge_index_eval of length", len(edge_index_eval))
print("dxdt_train of shape", dxdt_train.shape)
print("dxdt_eval of shape", dxdt_eval.shape)

# start from clean CUDA memory statistics so the peak usage printed after
# training does not include earlier allocations
_device = config["device"]
if _device == "cuda":
    tc.cuda.reset_peak_memory_stats(_device)
    tc.cuda.empty_cache()
L = tc.from_numpy(get_poisson_matrix(n_obj, 2, dtype=dtype_np))

# drawn once and shared by all models trained below
model_seed = get_seed()
sampler_seed = get_seed()

# Train a random feature (SWIM) HGN
swim_gnn = create_gnn([n_obj], edge_index_train, init_method="none", model_seed=model_seed)
# the timed span covers everything sample_fit does, including its evaluation on the training set
time0 = time.perf_counter()
swim_train_mse, swim_train_rel2 = sample_fit(swim_gnn, x_train, L, dxdt_train, "relu", sampling_seed=sampler_seed)
time1 = time.perf_counter()

swim_gnn_train_time = time1 - time0
print(f"(SWIM) RF-HGN finished training, took {swim_gnn_train_time} seconds")
if config["device"] == "cuda":
    mem_usage = convert_bytes_to_GB(tc.cuda.max_memory_allocated(config["device"]))
    # single quotes inside the f-string: reusing the enclosing quote character
    # is a SyntaxError before Python 3.12
    print(f"used {mem_usage:.2f} GB memory, device={config['device']}")
# switch the model to the evaluation graphs before evaluating on the held-out split
swim_gnn.edge_index = edge_index_eval
swim_eval_mse, swim_eval_rel2 = eval_dxdt(swim_gnn, x_eval, L, dxdt_eval, verbose=False)
print("dxdt")
print(f"- Train MSE         {swim_train_mse:.2e}")
print(f"- Train Rel. L2     {swim_train_rel2:.2e}")
print(f"- Eval MSE          {swim_eval_mse:.2e}")
print(f"- Eval Rel. L2      {swim_eval_rel2:.2e}")

# exit(0)

# Train a random feature (ELM) HGN
# Same model seed and sampling seed as the SWIM model; only the parameter sampler differs.
elm_gnn = create_gnn([n_obj], edge_index_train, init_method="none", model_seed=model_seed)
# the timed span covers everything sample_fit does, including its evaluation on the training set
time0 = time.perf_counter()
elm_train_mse, elm_train_rel2 = sample_fit(elm_gnn, x_train, L, dxdt_train, "random", sampling_seed=sampler_seed)
time1 = time.perf_counter()

elm_gnn_train_time = time1 - time0
print(f"(ELM) RF-HGN finished training, took {elm_gnn_train_time} seconds")
if config["device"] == "cuda":
    # NOTE(review): the peak-memory statistics are not reset before this run, so
    # the value is the peak over everything executed since the last reset.
    mem_usage = convert_bytes_to_GB(tc.cuda.max_memory_allocated(config["device"]))
    # single quotes inside the f-string: reusing the enclosing quote character
    # is a SyntaxError before Python 3.12
    print(f"used {mem_usage:.2f} GB memory, device={config['device']}")
# switch the model to the evaluation graphs before evaluating on the held-out split
elm_gnn.edge_index = edge_index_eval
elm_eval_mse, elm_eval_rel2 = eval_dxdt(elm_gnn, x_eval, L, dxdt_eval, verbose=False)
print("dxdt")
print(f"- Train MSE         {elm_train_mse:.2e}")
print(f"- Train Rel. L2     {elm_train_rel2:.2e}")
print(f"- Eval MSE          {elm_eval_mse:.2e}")
print(f"- Eval Rel. L2      {elm_eval_rel2:.2e}")

# Train the traditionally trained (gradient-based) HGN baseline
adam_gnn = create_gnn([n_obj], edge_index_train, init_method="relu", model_seed=model_seed)
time0 = time.perf_counter()
adam_train_hist = trad_train(adam_gnn, x_train, x_eval, L, dxdt_train, dxdt_eval)
time1 = time.perf_counter()
# unlike the sampled models, the training-set evaluation is outside the timed span
adam_train_mse, adam_train_rel2 = eval_dxdt(adam_gnn, x_train, L, dxdt_train, verbose=False)
adam_gnn_train_time = time1 - time0
print(f"(Adam) RF-HGN finished training, took {adam_gnn_train_time} seconds")
if config["device"] == "cuda":
    # NOTE(review): the peak-memory statistics are not reset before this run, so
    # the value is the peak over everything executed since the last reset.
    mem_usage = convert_bytes_to_GB(tc.cuda.max_memory_allocated(config["device"]))
    # single quotes inside the f-string: reusing the enclosing quote character
    # is a SyntaxError before Python 3.12
    print(f"used {mem_usage:.2f} GB memory, device={config['device']}")
# switch the model to the evaluation graphs before evaluating on the held-out split
adam_gnn.edge_index = edge_index_eval
adam_eval_mse, adam_eval_rel2 = eval_dxdt(adam_gnn, x_eval, L, dxdt_eval, verbose=False)
print("dxdt")
print(f"- Train MSE         {adam_train_mse:.2e}")
print(f"- Train Rel. L2     {adam_train_rel2:.2e}")
print(f"- Eval MSE          {adam_eval_mse:.2e}")
print(f"- Eval Rel. L2      {adam_eval_rel2:.2e}")

# Save models
# Create the output directory first: without it tc.save fails on a missing
# directory only after all three models have been trained.
os.makedirs(args.outdir or ".", exist_ok=True)
swimgnn_modelpath = os.path.join(args.outdir, f"swim_rfhgn_{n_obj}train_lj.pt")
elmgnn_modelpath = os.path.join(args.outdir, f"elm_rfhgn_{n_obj}train_lj.pt")
adamgnn_modelpath = os.path.join(args.outdir, f"adam_hgn_{n_obj}train_lj.pt")
# The whole model objects are saved (not just state dicts), so loading them
# needs the model class definitions to be importable.
tc.save(swim_gnn, swimgnn_modelpath)
tc.save(elm_gnn, elmgnn_modelpath)
tc.save(adam_gnn, adamgnn_modelpath)
print(f"-> Models are saved at '{swimgnn_modelpath}', '{elmgnn_modelpath}', '{adamgnn_modelpath}'")
# exit(0)

# Evaluate simulations at 5 points linearly spaced
n_evals_in_traj = 5

# Evaluate zero-shot on a system of args.testobj x args.testobj particles
# (e.g. train on 49 particles, test on 100 with --testobj 10)
n_obj_x, n_obj_y = args.testobj, args.testobj
r_eq = get_equil_r(sig=1.0)
x0, box_start, box_end = get_x0_and_box(n_obj_x, n_obj_y, 1.0, r_eq, q_noise=args.zero_shot_q_disp, p_noise=0.0, dtype=dtype_np, rng=rng)
print(f"-> Simulation box is created with {n_obj_x * n_obj_y} number of particles for testing")

# update x0 a bit: run a short simulation (1000 steps with step size 1e-3) and
# continue from its last state
print("-> Equilibrating..")
x_traj, _, _, _, _ = simulate(
    x0, 1.0, 1.0, 1.0, cutoff,
    1000, 1e-3,
    bc="reflective", box_start=box_start, box_end=box_end
)
x0 = x_traj[-1]

# Ground truth simulation
# NOTE: delta_t and n_steps parsed from the data file name are overwritten here;
# model_predict_traj reads these module-level values for its rollout.
delta_t = 1e-5
n_steps = 100_000
print("\n\n" + "-"*40 + "\n\n")
print("-> Simulating ground truth trajectory using the true Hamiltonian of the system")
print(f"  delta_t={delta_t}")
t0 = time.perf_counter()
x_traj, ke_traj, pe_traj, h_traj, dHdx_traj = simulate(
    x0, 1.0, 1.0, 1.0, cutoff,
    n_steps, delta_t,
    bc="reflective", box_start=box_start, box_end=box_end)
t1 = time.perf_counter()
print(f"-> Took {(t1 - t0):.2f} seconds")
dxdt_traj = hamiltons_eq(dHdx_traj)
print(f"-> computed x_traj {x_traj.shape}, h_traj {h_traj.shape}, dxdt_traj {dxdt_traj.shape}")
# Drop the initial point from every trajectory array. The derivative array
# sliced here must be dxdt_traj; the original code sliced (and rebound) the
# training targets `dxdt` instead.
x_traj, ke_traj, pe_traj, h_traj, dxdt_traj = x_traj[1:], ke_traj[1:], pe_traj[1:], h_traj[1:], dxdt_traj[1:]
q_traj = x_traj[..., :dof]

def model_predict_traj(gnn, filename_prefix):
    """Roll out the zero-shot test trajectory with `gnn` and report the errors.

    The rollout starts from the module-level `x0` and uses the same `cutoff`,
    `n_steps`, `delta_t` and box as the ground-truth simulation. Position
    errors against the ground-truth `q_traj` and the true/predicted conserved
    quantities are evaluated at `n_evals_in_traj` linearly spaced steps,
    printed, and written as a table to `<outdir>/<filename_prefix>.txt`.
    With --save, the true and predicted trajectories are additionally stored
    in `<outdir>/<filename_prefix>.npz`.

    Args:
        gnn: trained model passed to `simulate` as `model`.
        filename_prefix: base name (without extension) of the output files.

    Returns:
        None.
    """
    # Model prediction
    print("\n\n" + "-"*40 + "\n\n")
    print("-> Predicting with the trained model")
    # Time this rollout with its own start stamp; relying on the module-level t0
    # of the ground-truth simulation would report the time elapsed since then.
    t0 = time.perf_counter()
    x_traj_pred, ke_traj_pred, pe_traj_pred, h_traj_pred, dHdx_traj_pred = simulate(
        x0, 1.0, 1.0, 1.0, cutoff,
        n_steps, delta_t, model=gnn,
        bc="reflective", box_start=box_start, box_end=box_end)
    t1 = time.perf_counter()
    print(f"-> Took {(t1 - t0):.2f} seconds")
    dxdt_traj_pred = hamiltons_eq(dHdx_traj_pred)
    print(f"-> computed x_traj_pred {x_traj_pred.shape}, h_traj_pred {h_traj_pred.shape}, dxdt_traj_pred {dxdt_traj_pred.shape}")

    # separate q, energy and dxdt, and remove the initial point where the error is 0
    # note that h_traj_pred uses the model's Hamiltonian-like conserved quantity but ke_traj_pred and pe_traj_pred
    # are evaluated using the ground truth Hamiltonian just for inspecting the rollout trajectory snapshots.
    # This way we can check the model's conserved quantity, which should be conserved when used with
    # a symplectic integrator, and compare it with the true energy along the trajectory.
    x_traj_pred, ke_traj_pred, pe_traj_pred, h_traj_pred, dxdt_traj_pred = x_traj_pred[1:], ke_traj_pred[1:], pe_traj_pred[1:], h_traj_pred[1:], dxdt_traj_pred[1:]
    q_traj_pred = x_traj_pred[..., :dof]
    if args.save:
        testpath = os.path.join(args.outdir, filename_prefix + ".npz")
        np.savez(
            testpath,
            q_traj=q_traj, q_traj_pred=q_traj_pred,
            ke_traj=ke_traj, ke_traj_pred=ke_traj_pred,
            pe_traj=pe_traj, pe_traj_pred=pe_traj_pred,
            h_traj=h_traj, h_traj_pred=h_traj_pred
        )
        print(f"-> Test trajectory results are saved at '{testpath}'")

    # indices (into the trajectories without their initial point) of the evaluated steps
    test_indices = np.linspace(0, n_steps - 2, num=n_evals_in_traj, dtype=np.int64)
    q_mse, q_rel2 = np.zeros(n_evals_in_traj, dtype=dtype_np), np.zeros(n_evals_in_traj, dtype=dtype_np)
    h_true = np.zeros(n_evals_in_traj, dtype=dtype_np)
    h_pred = np.zeros(n_evals_in_traj, dtype=dtype_np)
    # h_mse, h_rel2 = np.zeros(n_evals_in_traj, dtype=dtype_np), np.zeros(n_evals_in_traj, dtype=dtype_np)
    # dxdt_mse, dxdt_rel2 = np.zeros(n_evals_in_traj, dtype=dtype_np), np.zeros(n_evals_in_traj, dtype=dtype_np)
    for idx, test_index in enumerate(test_indices):
        q_mse[idx] = mse(q_traj[test_index], q_traj_pred[test_index], verbose=False)
        q_rel2[idx] = l2_err(q_traj[test_index], q_traj_pred[test_index], verbose=False)
        h_true[idx] = h_traj[test_index]
        h_pred[idx] = h_traj_pred[test_index]

    print("Zero-shot test results on the evaluated trajectory at time steps:", test_indices, "with simulation delta_t =", delta_t, "num. steps", n_steps, "and n_obj", n_obj_x*n_obj_y)
    print(f"-> q mse        {q_mse}")
    print(f"-> q rel2       {q_rel2}")
    print(f"-> real hamil   {h_true}")
    print(f"-> pred hamil   {h_pred}")

    # rmse_traj = np.sqrt(((q_traj- q_traj_pred)**2).mean(axis=(1, 2)))

    # framing_length = 10
    # framing_length = 100
    # framing_length = 1000
    # animate_2D(q_traj[::framing_length], ke_traj, pe_traj, h_traj,
    #            q_pred=q_traj_pred[::framing_length], ke_pred=ke_traj_pred, pe_pred=pe_traj_pred, h_pred=h_traj_pred,
    #            rmse=rmse_traj, delta_t=delta_t, framing_length=framing_length, filename="main_lennard_jones.mp4")

    # Save tabulated data
    table_title = "\nTable: Predicted trajectory evaluation, error values on positions (q) are displayed with true and predicted conserved (energy) values."
    # column headers show 1-based step numbers
    arr_columns = [ f"T={step_idx+1}" for step_idx in test_indices ]
    results_table = tabulate(
        headers=[""] + arr_columns,
        tabular_data=[
            ["q MSE"] + list(q_mse),
            ["q L2 rel."] + list(q_rel2),
            ["True H"] + list(h_true),
            ["Pred H"] + list(h_pred),
        ],
        floatfmt=".3e"
    )

    print(results_table)

    filename = os.path.join(args.outdir, filename_prefix + ".txt")
    with open(filename, 'w') as f:
        f.write(table_title + '\n' + results_table)
    print(f"-> Experiment results are saved at {filename}")

# Zero-shot rollout and report for each trained model, in training order.
_zero_shot_runs = (
    (swim_gnn, f"swim_gnn_{n_obj}train_{n_obj_x*n_obj_y}test"),
    (elm_gnn, f"elm_gnn_{n_obj}train_{n_obj_x*n_obj_y}test"),
    (adam_gnn, f"adam_gnn_{n_obj}train_{n_obj_x*n_obj_y}test"),
)
for _gnn, _prefix in _zero_shot_runs:
    model_predict_traj(_gnn, _prefix)

exit(0)
