import os
import toml
import torch as tc
import numpy as np
import matplotlib.pyplot as plt
from copy import deepcopy
from argparse import ArgumentParser
from src.data import Pendulum, MassSpring, grad_H_mass_spring, H_mass_spring
from src.utils import Mesh, animate_2D, infer_gradient, flatten_TC, unflatten_TC
from src.model import GNN

# Command-line arguments: the system to simulate, a TOML configuration (e.g. initial
# condition and simulation settings), an optional trained model (.pt) whose learned
# dynamics are simulated alongside the ground truth, and the output directory.
argparser = ArgumentParser()
argparser.add_argument("--system", "-s", type=str, required=True, choices=["pendulum-chain", "spring-chain", "anharmonic-chain", "morse-chain"])
argparser.add_argument("--toml", "-t", type=str, required=True)
argparser.add_argument("--model", "-m", type=str, required=False, default="")
argparser.add_argument("--outdir", "-o", type=str, required=True)
args = vars(argparser.parse_args())

config = toml.load(args["toml"])

dtype, device, data_config, n_obj = config["dtype"], config["device"], config["data"], config["n_obj"]
data_seed = data_config["data_seed"]
system_config = data_config[args["system"]]
simulate_config = config["simulation"]
phase_space_config = config["phase-space"]
model_filepath = args["model"]

# Map the configured precision string onto matching NumPy / PyTorch dtypes.
if dtype == "float":
    dtype_np, dtype_tc = np.float32, tc.float32
elif dtype == "double":
    dtype_np, dtype_tc = np.float64, tc.float64
else:
    raise ValueError("Unknown precision")

# Validate the model path up front. Explicit raises (not `assert`) stay active under
# `python -O`, and a missing file is reported here with a clear message.
if model_filepath:
    if not model_filepath.endswith(".pt"):
        raise ValueError(f"-> Input model '{model_filepath}' must have .pt extension")
    if not os.path.isfile(model_filepath):
        raise FileNotFoundError(f"-> Input model '{model_filepath}' does not exist")

def prepare_spring_chain(n_obj, dof, n_steps, system_name):
    """Construct a MassSpring chain with zeroed trajectory buffers of ``n_steps`` rows.

    Args:
        n_obj: number of nodes in the chain (wrapped in a list for MassSpring).
        dof: degrees of freedom per node.
        n_steps: number of rows allocated for the trajectory buffers.
        system_name: tag forwarded to MassSpring's ``system`` argument; this script
            passes "chain", "anharmonic" or "morse".

    Returns:
        The MassSpring instance. Physical parameters (mass, spring_constant, l, D, a)
        are read from the module-level ``system_config``.
    """
    n_features = n_obj * 2 * dof
    # q and p each take half of the per-step features. They are zero-initialised here;
    # presumably MassSpring.integrate fills them in — TODO confirm.
    q = np.zeros((n_steps, n_features // 2), dtype=dtype_np)
    p = np.zeros((n_steps, n_features // 2), dtype=dtype_np)
    return MassSpring(
        n_points=n_steps, n_features=n_features, q=q, p=p,
        mass=system_config["mass"], spring_constant=system_config["spring_constant"],
        n_obj=[n_obj], dof=dof, meshing=Mesh("rectangular"),
        l=system_config["l"], D=system_config["D"], a=system_config["a"],
        system=system_name,
    )

def simulate_spring_chain(system, n_steps, delta_t, grad_H, H):
    """Integrate ``system`` and return its trajectory components and energies.

    Integration options (fixed indices, gravity, external force, integrator) are read
    from the module-level ``simulate_config``.

    Returns:
        ``(qx, qy, px, py, hamiltonian)``. The first four are ``x[..., 0]`` to
        ``x[..., 3]`` of ``system.to_array()``; ``hamiltonian`` is ``H`` applied to
        that array. NOTE(review): the fixed 0..3 split lines up with (q, p) only when
        ``system.dof == 2``.
    """
    cfg = simulate_config
    gravity = np.array(cfg["gravity"], dtype=dtype_np)
    force_args = (
        np.array(cfg["force_node_idx"], dtype=np.int64),
        np.array(cfg["force_vec"], dtype=np.float64),
        np.array(cfg["force_snap_idx"], dtype=np.int64),
    )
    system.integrate(
        n_steps, delta_t, grad_H, cfg["fixed_indices"], gravity, *force_args,
        integration_method=cfg["integration_method"],
    )

    traj = system.to_array()
    n_dof = system.dof
    positions, momenta = traj[..., :n_dof], traj[..., n_dof:]
    print("-- Trajectory Information --")
    print(f"- q limits: [{np.min(positions)}, {np.max(positions)}]")
    print(f"- p limits: [{np.min(momenta)}, {np.max(momenta)}]")

    qx, qy, px, py = (traj[..., col] for col in range(4))
    energy = H(traj)
    print(f"- hamiltonian limits: [{np.min(energy):.2f}, {np.max(energy):.2f}]")
    return qx, qy, px, py, energy

def get_system(system_name):
    """Build the ground-truth system and its Hamiltonian callables.

    Args:
        system_name: CLI system name ("spring-chain", "anharmonic-chain",
            "morse-chain" or "pendulum-chain").

    Returns:
        ``(system, integration_grad_H, H)``. ``system`` is a MassSpring chain from
        ``prepare_spring_chain``. ``integration_grad_H(x)`` wraps
        ``grad_H_mass_spring`` and ``H(x)`` wraps ``H_mass_spring``, both using the
        potential parameters stored on ``system``.

    Raises:
        NotImplementedError: for "pendulum-chain". The caller unpacks a 3-tuple and
            simulates through the MassSpring API, which a Pendulum does not provide.
        ValueError: for any other unknown system name.
    """
    # CLI system name -> potential tag understood by MassSpring / (grad_)H_mass_spring.
    spring_tags = {
        "spring-chain": "chain",
        "anharmonic-chain": "anharmonic",
        "morse-chain": "morse",
    }
    if system_name == "pendulum-chain":
        # Fail with a clear message instead of a tuple-unpacking TypeError at the call site.
        raise NotImplementedError("pendulum-chain is not supported by this simulation script")
    if system_name not in spring_tags:
        raise ValueError("Unknown system")

    tag = spring_tags[system_name]
    system = prepare_spring_chain(n_obj[0], system_config["dof"], simulate_config["n_steps"], tag)

    def potential_kwargs():
        # Read from `system` at call time, so the callables always use its current parameters.
        if tag == "morse":
            return {"D": system.D, "a": system.a}
        if tag == "anharmonic":
            return {"spring_constant": system.spring_constant, "l": system.l}
        return {"spring_constant": system.spring_constant}

    def H(x):
        return H_mass_spring(
            x, system.n_obj, system.meshing, system.mass, system=tag, **potential_kwargs()
        )

    def integration_grad_H(x):
        return grad_H_mass_spring(
            x, system.dof, system.meshing, system.mass, system=tag, **potential_kwargs()
        )

    return system, integration_grad_H, H

# Ground-truth system plus its gradient (used for integration) and Hamiltonian.
system, integration_grad_H, H = get_system(system_name=args["system"])

if not model_filepath:
    q_pred = None
else:
    pred_sim_path = os.path.join(args["outdir"], f"pred_simulation_{args['system']}.npz")
    if os.path.exists(pred_sim_path):
        raise FileExistsError(f"{pred_sim_path} already exists, remove file or point to another directory")

    # Integrate the learned dynamics on an independent copy, separate from `system`.
    system_pred = deepcopy(system)
    # NOTE(review): weights_only=False unpickles arbitrary Python objects; only load
    # checkpoints from trusted sources.
    model = tc.load(model_filepath, weights_only=False)
    model_width = model.node_encoder.weight.shape[0] + model.msg_encoder.weight.shape[0]
    print(f"-> Model is loaded from '{model_filepath}' width {model_width}")
    # Point the model at this chain's node count and connectivity.
    model.n_obj = n_obj
    model.edge_index = system.edge_index()
    print("-> Edge index and number of nodes is adjusted in model, number of nodes is", n_obj[0])

    def H_model(x):
        """Return the model output for NumPy state ``x``, detached and as a NumPy array."""
        prediction = model(flatten_TC(tc.from_numpy(x)))
        return prediction.detach().cpu().numpy()

    def integration_grad_H_model(x):
        """Return ``infer_gradient(model, ·)`` for ``x``, unflattened with ``n_obj`` and ``system.dof``.

        Presumably this is dx/dt in ``x``'s layout — TODO confirm against infer_gradient.
        """
        flat_grad = infer_gradient(model, flatten_TC(tc.from_numpy(x)))
        return unflatten_TC(flat_grad, n_obj, system.dof).detach().cpu().numpy()

    qx_pred, qy_pred, px_pred, py_pred, e_pred = simulate_spring_chain(
        system_pred,
        simulate_config["n_steps"],
        simulate_config["delta_t"],
        integration_grad_H_model,
        H_model,
    )
    q_pred = np.stack((qx_pred, qy_pred), axis=-1)
    p_pred = np.stack((px_pred, py_pred), axis=-1)
    np.savez(pred_sim_path, q=q_pred, p=p_pred, e=e_pred)
    print(f"-> Model simulation saved at '{pred_sim_path}'")

    # Offset each node's x-coordinate by (node index * spring length) for the printed
    # summary; the file saved above keeps the unshifted positions.
    rest_x = np.arange(n_obj[0]).reshape(1, -1) * simulate_config["spring_length"]
    q_pred[..., 0] = q_pred[..., 0] + rest_x
    indices = np.linspace(0, len(e_pred) - 1, num=5, dtype=np.int64)
    print("-> Positions (model) at:", indices.tolist())
    print(q_pred[indices, phase_space_config["node_idx"]].tolist())
    print("-> Energy (model) at:", indices.tolist())
    print(e_pred[indices].tolist())

true_sim_path = os.path.join(args["outdir"], f"true_simulation_{args['system']}.npz")
if os.path.exists(true_sim_path):
    print(f"-> '{true_sim_path}' exists, skipping ground truth simulation")
    # `exit` is a `site`-module convenience for the interactive shell and is not
    # guaranteed in programs; raising SystemExit directly always works.
    raise SystemExit(0)

qx_true, qy_true, px_true, py_true, e_true = simulate_spring_chain(
    system, simulate_config["n_steps"], simulate_config["delta_t"], integration_grad_H, H
)
q_true = np.stack([qx_true, qy_true], axis=-1)
p_true = np.stack([px_true, py_true], axis=-1)
np.savez(true_sim_path, q=q_true, p=p_true, e=e_true)
print(f"-> Ground truth simulation saved at '{true_sim_path}'")
# Offset each node's x-coordinate by (node index * spring length) for the printed
# summary; the file saved above keeps the unshifted positions.
q_true[..., 0] = q_true[..., 0] + np.arange(n_obj[0]).reshape(1, -1) * simulate_config["spring_length"]
indices = np.linspace(0, len(e_true) - 1, num=5, dtype=np.int64)
print("-> Positions (true) at:", indices.tolist())
print(q_true[indices, phase_space_config["node_idx"]].tolist())
print("-> Energy (true) at:", indices.tolist())
print(e_true[indices].tolist())

# NOTE(review): this unconditional exit makes everything below unreachable (the animation,
# Hamiltonian plot, phase plot and potentials overview). Remove it to produce those outputs.
raise SystemExit(0)

# Animate the ground-truth chain, overlaying the model trajectory when one was simulated.
animation_filename = os.path.join(args["outdir"], f"{args['system']}.mp4")
edge_index_np = system.edge_index().cpu().numpy()
animate_2D(
    q_true,
    edge_index_np,
    framing_length=simulate_config["framing_length"],
    filename=animation_filename,
    q_pred=q_pred,
)
# TODO: plot position-energy-mse plot
print(f"-> Animation saved at '{animation_filename}'")

# Hamiltonian over time: ground truth, plus the model prediction when available.
fig, ax = plt.subplots(1, 1, figsize=(5, 3), dpi=100)
ax.plot(e_true, c="red", label="True")
if model_filepath:
    # The model energy is presumably defined only up to an additive constant, so shift
    # it to start at the ground-truth value before comparing.
    energy_offset = (e_true[0] - e_pred[0]).item()
    e_pred = e_pred + energy_offset
    ax.plot(e_pred, c="tab:blue", linestyle="dashed", label="Pred")
ax.set_yscale("symlog", linthresh=0.015)
ax.set_title(r"Hamiltonian")
ax.set_xlabel("Time step")
ax.grid()
ax.legend()
hamil_plot_filename = os.path.join(args["outdir"], f"hamil-{args['system']}.pdf")
fig.savefig(hamil_plot_filename)
print(f"-> System hamil plot saved at '{hamil_plot_filename}'")

# TODO: plot the phase space for a node
# For dof == 2: evaluate H on a grid of (q_x, q_y) positions of one node, with every
# other position and all momenta held at zero, and draw it as a filled contour plot.
if system_config["dof"] == 2:
    node_idx = phase_space_config["node_idx"]
    q_min, q_max = phase_space_config["q_min"], phase_space_config["q_max"]

    gridsize = 200
    # NOTE(review): indexing q[:, node_idx, k] assumes n_obj holds a single entry.
    q = np.zeros((gridsize**2, *n_obj, system_config["dof"]), dtype=dtype_np)
    p = np.zeros((gridsize**2, *n_obj, system_config["dof"]), dtype=dtype_np)
    # Sample with `gridsize` (not a separate literal) so the grid always matches the
    # arrays above and the reshape below.
    q_node = np.linspace(q_min, q_max, gridsize, dtype=dtype_np)
    # "ij" indexing: grid row i has q_x = q_node[i], column j has q_y = q_node[j].
    qx_node, qy_node = np.meshgrid(q_node, q_node, indexing="ij")
    q[:, node_idx, 0] = qx_node.flatten()
    q[:, node_idx, 1] = qy_node.flatten()
    x = np.concatenate([q, p], axis=-1)
    energy = H(x)
    print(f"- Energy ndarray shape {energy.shape}")
    print(f"- Maximum Energy {energy.max():.2e}")
    print(f"- Minimum Energy {energy.min():.2e}")
    fig, ax = plt.subplots(1, 1, figsize=(6, 4), dpi=100)
    # contourf expects Z indexed as [y, x]; the transpose undoes the "ij" (x-major) layout.
    im = ax.contourf(q_node, q_node, energy.reshape(gridsize, gridsize).T)
    fig.colorbar(im, ax=ax)
    ax.set_title(r"$\mathcal{H}$")
    ax.set_xlabel(fr"$q_x$ of node {node_idx}")
    ax.set_ylabel(fr"$q_y$ of node {node_idx}")
    # Save alongside the other outputs in --outdir rather than the working directory.
    phase_plot_filename = os.path.join(args["outdir"], f"phase_plot_{args['system']}.pdf")
    fig.savefig(phase_plot_filename)
    print(f"-> System phase plot saved at '{phase_plot_filename}'")

# Just plotting an overview of different potentials
def spring_potential(r, k=1.0):
    """Return the harmonic spring energy ``k * r**2 / 2`` (element-wise for arrays)."""
    half_k = 0.5 * k
    return half_k * r**2

def anharmonic_potential(r, k=1.0, l=1.0):
    """Return the harmonic energy plus a quartic term: ``k * r**2 / 2 + l * r**4``."""
    harmonic = 0.5 * k * r**2
    quartic = l * r**4
    return harmonic + quartic

def morse_potential(r, k=1.0, D=1.0):
    """Return the Morse energy ``D * (1 - exp(-a * r))**2`` with its minimum at ``r = 0``.

    The width ``a = sqrt(k / (2 D))`` makes the curvature at the minimum equal ``k``,
    so for small ``r`` the result matches ``spring_potential(r, k)``.

    Args:
        r: displacement from the minimum (scalar or NumPy array).
        k: curvature (harmonic spring constant) at the minimum.
        D: well depth; the energy approaches ``D`` as ``r`` grows.
    """
    a = np.sqrt(k / (2 * D))
    # The exponent must be linear in the displacement (-a * r). A -a / r exponent
    # inverts the shape: it gives D at r -> 0 and vanishes at large r.
    return D * (1 - np.exp(-a * r))**2

# Plot the spring, anharmonic and Morse potentials side by side (default parameters).
fig, axes = plt.subplots(1, 3, figsize=(15, 4), dpi=100)
r = np.linspace(1e-5, 2, num=1000)
for ax, f, title in zip(axes, [spring_potential, anharmonic_potential, morse_potential],
                        ["spring", "anharmonic", "morse"]):
    ax.plot(r, f(r))
    ax.grid(True)
    ax.set_title(title)
    ax.set_xlabel("Distance")
# Save alongside the other outputs in --outdir rather than the working directory.
potentials_filename = os.path.join(args["outdir"], "potentials.pdf")
fig.savefig(potentials_filename)
print(f"-> System potentials saved at '{potentials_filename}'")
