import argparse
import os
import sys
from functools import wraps

import fire
import matplotlib.pyplot as plt

from psystems.nsprings import (chain)

MAINPATH = ".."  # nopep8
sys.path.append(MAINPATH)  # nopep8

import src
from jax.config import config
from src.md import *
from src.nve import NVEStates
from src.utils import *

# Switch on the JAX runtime flags this script relies on before any JAX
# computation is traced: `jax_enable_x64` and `jax_debug_nans`.
for _jax_flag in ("jax_enable_x64", "jax_debug_nans"):
    config.update(_jax_flag, True)
del _jax_flag  # keep the loop variable out of the module namespace

def _str2bool(value):
    """Convert a command-line token such as ``"false"`` or ``"1"`` to a bool.

    ``type=bool`` must not be used with argparse: ``bool("False")`` is
    ``True`` because every non-empty string is truthy.

    Raises:
        argparse.ArgumentTypeError: if *value* is not a recognised spelling
            of true/false.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ("1", "true", "t", "yes", "y", "on"):
        return True
    if lowered in ("0", "false", "f", "no", "n", "off"):
        return False
    raise argparse.ArgumentTypeError(
        f"expected a boolean value, got {value!r}")


# Module-level options. `allow_abbrev=False` stops argparse from claiming
# prefixes of its long options, so flags meant for another parser of the
# same command line are never swallowed by accident.
parser = argparse.ArgumentParser(
    description="Provide the number of nodes (--nodes), the training flag "
                "(--training), the random seed (--seed) and the number of "
                "configurations (--n_config).",
    allow_abbrev=False,
)
parser.add_argument("-N", "--nodes", type=int, default=3)
parser.add_argument("-T", "--training", type=_str2bool, default=True)
parser.add_argument("-S", "--seed", type=int, default=42)
parser.add_argument("-C", "--n_config", type=int, default=100)
# Only consume the options declared above. `parse_args()` would abort the
# whole script with "unrecognized arguments" for any other flag on the
# command line (the file also hands sys.argv to python-fire).
args, _unparsed_args = parser.parse_known_args()


def main(N1=5, N2=1, dim=2, runs=2000, nconfig=100, ifdrag=0):
    """Generate spring-chain trajectories, save them and plot their energies.

    The system has ``N = N1 * N2`` nodes. ``nconfig`` initial configurations
    are drawn with ``chain(N)``, each one is integrated with ``predition``
    and the resulting states are written below
    ``../results/{N}-Spring-data/1/`` together with energy plots and OVITO
    dumps for (at most) the first ten trajectories.

    Args:
        N1: First factor of the node count.
        N2: Second factor of the node count; ``None`` means "same as N1".
        dim: Spatial dimension forwarded to ``lnn.accelerationFull``.
        runs: ``runs`` argument forwarded to ``predition``.
        nconfig: Number of initial configurations to generate and simulate.
        ifdrag: Currently unused; kept so existing command lines still work.
    """
    if N2 is None:
        N2 = N1

    N = N1 * N2

    tag = f"{N}-Spring-data"
    seed = 420
    out_dir = "../results"
    rname = False
    # With `rname` disabled every run writes to the fixed sub-directory "1".
    rstring = datetime.now().strftime("%m-%d-%Y_%H-%M-%S") if rname else "1"
    filename_prefix = f"{out_dir}/{tag}/{rstring}/"

    def _filename(name):
        """Return the output path for *name*, creating its directory."""
        file = f"{filename_prefix}/{name}"
        os.makedirs(os.path.dirname(file), exist_ok=True)
        filename = file.replace("//", "/")
        print("===", filename, "===")
        return filename

    def shift(R, dR, V):
        """Displacement rule handed to the integrator: plain addition."""
        return R + dR, V

    def OUT(f):
        """Wrap an I/O function so its first argument is resolved by `_filename`."""
        @wraps(f)
        def func(file, *args, **kwargs):
            return f(_filename(file), *args, **kwargs)

        return func

    savefile = OUT(src.io.savefile)
    save_ovito = OUT(src.io.save_ovito)

    np.random.seed(seed)

    # The first two items returned by `chain` are used as (R, V) below.
    init_confs = [chain(N)[:2] for _ in range(nconfig)]

    _, _, senders, receivers = chain(N)

    print("Saving init configs...")
    savefile("initial-configs.pkl",
             init_confs, metadata={"N1": N1, "N2": N2})

    masses = jnp.ones(N)

    dt = 1.0e-3
    stride = 100

    def pot_energy_orig(x):
        """Sum of `lnn.SPRING` over all sender/receiver pairs."""
        dr = jnp.square(x[senders, :] - x[receivers, :]).sum(axis=1)
        return vmap(partial(lnn.SPRING, stiffness=1.0, length=1.0))(dr).sum()

    kin_energy = partial(lnn._T, mass=masses)

    def Lactual(x, v, params):
        """Lagrangian: kinetic minus potential energy (`params` is unused)."""
        return kin_energy(v) - pot_energy_orig(x)

    def drag(x, v, params):
        """Non-conservative force term: none."""
        return 0.0

    acceleration_fn_orig = lnn.accelerationFull(N, dim,
                                                lagrangian=Lactual,
                                                non_conservative_forces=drag,
                                                constraints=None,
                                                external_force=None)

    def force_fn_orig(R, V, params, mass=None):
        """Return the acceleration, scaled by `mass` when one is given."""
        acceleration = acceleration_fn_orig(R, V, params)
        if mass is None:
            return acceleration
        return acceleration * mass.reshape(-1, 1)

    @jit
    def forward_sim(R, V):
        return predition(R, V, None, force_fn_orig, shift, dt, masses,
                         stride=stride, runs=runs)

    print("Data generation ...")
    dataset_states = []
    for ind, (R, V) in enumerate(init_confs, start=1):
        print(f"{ind}/{len(init_confs)}", end='\r')
        dataset_states.append(forward_sim(R, V))

    print("Saving datafile...")
    savefile("model_states_test.pkl", dataset_states)

    def cal_energy(states):
        """Return an array whose columns are PE, KE, L and KE + PE per state."""
        KE = vmap(kin_energy)(states.velocity)
        PE = vmap(pot_energy_orig)(states.position)
        L = vmap(Lactual, in_axes=(0, 0, None))(
            states.position, states.velocity, None)
        return jnp.array([PE, KE, L, KE + PE]).T

    print("plotting energy...")
    # Only the first ten trajectories are plotted and exported.
    for ind, states in enumerate(dataset_states[:10], start=1):
        Es = cal_energy(states)

        # Start a fresh figure for every trajectory; otherwise pyplot keeps
        # drawing into the same implicit axes and each saved image would
        # also contain the curves of all earlier trajectories.
        plt.figure()
        plt.plot(Es, label=["PE", "KE", "L", "TE"], lw=6, alpha=0.5)
        plt.legend(bbox_to_anchor=(1, 1))
        plt.ylabel("Energy")
        plt.xlabel("Time step")

        title = f"{N}-Spring random state {ind}"
        plt.title(title)
        plt.savefig(
            _filename(title.replace(" ", "_") + ".png"), dpi=300)
        # Release the figure so open figures do not pile up in memory.
        plt.close()

        save_ovito(f"dataset_{ind}.data", [
            state for state in NVEStates(states)], lattice="")


# Script entry point: expose `main` as a command-line interface via
# python-fire. This call is unguarded, so it runs whenever the file is
# executed or imported.
fire.Fire(component=main)
