import argparse
import os
import sys
from functools import wraps
from psystems.nsprings import (chain)

MAINPATH = ".."  # nopep8
sys.path.append(MAINPATH)  # nopep8

import src
from jax.config import config
from src.md import *
from src.nve import NVEStates
from src.utils import *

# Make JAX compute in 64-bit floats and raise an error when a NaN is produced.
config.update("jax_enable_x64", True)
config.update("jax_debug_nans", True)

# Command-line options. They are parsed once at import time; main() and the
# __main__ guard read them through the module-level ``args``.
parser = argparse.ArgumentParser(
    description="Generate spring-chain trajectory datasets.")
parser.set_defaults(training=True)
parser.add_argument("-N", "--nodes", type=int, default=3,
                    help="number of nodes in the spring chain")
# --training and --test write to the same destination; training is the default.
parser.add_argument('--training', action='store_true',
                    help="save the data as the training split (default)")
parser.add_argument('--test', dest='training', action='store_false',
                    help="save the data as the test split")
parser.add_argument("-S", "--seed", type=int, default=42,
                    help="seed for numpy's global random number generator")
parser.add_argument("-C", "--n_config", type=int, default=100,
                    help="number of initial configurations to simulate")
# Only these labels are understood by main(); reject anything else up front
# instead of silently producing an empty dataset.
parser.add_argument("-L", "--label", type=str, default="0",
                    choices=["0", "1", "2"],
                    help="output format: 0 = raw simulator states, "
                         "1 = Datastate (states plus per-step changes), "
                         "2 = Hamiltonian data (use Spring-data-HGNN instead)")
args = parser.parse_args()


# Container for one trajectory's states together with their step-to-step changes.
class Datastate:
    """Trajectory states paired with their one-step increments.

    ``position``, ``velocity``, ``force`` and ``mass`` keep every entry of the
    corresponding ``model_states`` field except the last one along the leading
    axis (presumably time -- confirm against the simulator output).
    ``change_position`` / ``change_velocity`` hold the forward differences, so
    entry ``t`` pairs the state at ``t`` with the change from ``t`` to ``t + 1``.
    """

    def __init__(self, model_states):
        pos_all = model_states.position
        vel_all = model_states.velocity

        # Drop the final entry: it has no successor to difference against.
        self.position = pos_all[:-1]
        self.velocity = vel_all[:-1]
        self.force = model_states.force[:-1]
        self.mass = model_states.mass[:-1]
        # Always initialised to 0; not read anywhere in this class.
        self.index = 0

        # Forward differences between consecutive entries.
        self.change_position = pos_all[1:] - self.position
        self.change_velocity = vel_all[1:] - self.velocity

def main(N1=3, N2=1, dim=2, runs=100, nconfig=100):
    """Simulate spring-chain trajectories and save them as a dataset.

    After seeding numpy's global RNG with ``args.seed``, ``nconfig`` initial
    configurations are drawn with ``chain(N)``. Each one is integrated with
    ``predition``, using a Lagrangian with unit masses and
    ``lnn.SPRING(stiffness=1.0, length=1.0)`` terms along the chain's
    sender/receiver edges. The stored format depends on ``args.label``:

    * ``"0"`` -- the raw states returned by the simulator;
    * ``"1"`` -- ``Datastate`` objects (states plus step-to-step changes);
    * ``"2"`` -- nothing is generated or written; the user is pointed to
      Spring-data-HGNN.

    Files written via ``src.io.savefile`` / ``np.save`` (``split`` is
    ``train`` when ``args.training`` is true, otherwise ``test``):

    * ``../results/{N}-Spring-data/{label}/initial-configs.pkl``
    * ``../results/{N}-Spring-data/{label}/model_states_{split}.pkl``
    * ``../data/{N}-spring/dataset_{split}.npy``

    Args:
        N1: Node-count factor; the chain has ``N = N1 * N2`` nodes.
        N2: Second node-count factor; ``None`` means "same as ``N1``".
        dim: Spatial dimension forwarded to ``lnn.accelerationFull``.
        runs: Forwarded to ``predition`` as its ``runs`` argument.
        nconfig: Number of initial configurations to simulate.

    Raises:
        ValueError: If ``args.label`` is not ``"0"``, ``"1"`` or ``"2"``.
    """
    if N2 is None:
        N2 = N1

    N = N1 * N2

    # Decide on the output format before simulating or touching the disk.
    # Otherwise label "2" (or an unknown label) would still save empty
    # datasets and overwrite the shared ../data/{N}-spring/ files.
    label = args.label
    if label == '2':
        print("For Hamiltonian data please use Spring-data-HGNN")
        return
    if label not in ('0', '1'):
        raise ValueError(
            f"Unsupported label {label!r}; expected '0', '1' or '2'.")

    tag = f"{N}-Spring-data"
    seed = args.seed
    out_dir = "../results"
    rstring = label
    filename_prefix = f"{out_dir}/{tag}/{rstring}/"

    def _filename(name):
        """Return the output path for ``name``, creating its directory."""
        filename = f"{filename_prefix}/{name}".replace("//", "/")
        os.makedirs(os.path.dirname(filename), exist_ok=True)
        print("===", filename, "===")
        return filename

    def shift(R, dR, V):
        # Position update handed to the integrator: R + dR, velocity unchanged.
        return R + dR, V

    def OUT(f):
        """Wrap ``f`` so its first argument (a file name) is resolved by ``_filename``."""
        @wraps(f)
        def func(file, *f_args, **kwargs):
            return f(_filename(file), *f_args, **kwargs)

        return func

    savefile = OUT(src.io.savefile)

    np.random.seed(seed)

    # chain() is called nconfig + 1 times after seeding; keep this order so the
    # generated configurations stay reproducible for a given seed.
    init_confs = [chain(N)[:2] for _ in range(nconfig)]

    _, _, senders, receivers = chain(N)

    print("Saving init configs...")
    savefile(f"initial-configs.pkl",
             init_confs, metadata={"N1": N1, "N2": N2})

    masses = jnp.ones(N)

    dt = 1.0e-3
    stride = 100

    def pot_energy_orig(x):
        # NOTE(review): `dr` is the *squared* edge length; confirm that
        # lnn.SPRING expects squared distances rather than distances.
        dr = jnp.square(x[senders, :] - x[receivers, :]).sum(axis=1)
        return vmap(partial(lnn.SPRING, stiffness=1.0, length=1.0))(dr).sum()

    kin_energy = partial(lnn._T, mass=masses)

    def Lactual(x, v, params):
        # Lagrangian L = T - V.
        return kin_energy(v) - pot_energy_orig(x)

    def drag(x, v, params):
        # No dissipative forces.
        return 0.0

    acceleration_fn_orig = lnn.accelerationFull(N, dim,
                                                lagrangian=Lactual,
                                                non_conservative_forces=drag,
                                                constraints=None,
                                                external_force=None)

    def force_fn_orig(R, V, params, mass=None):
        # Without masses return the acceleration; otherwise scale it per node.
        if mass is None:
            return acceleration_fn_orig(R, V, params)
        else:
            return acceleration_fn_orig(R, V, params) * mass.reshape(-1, 1)

    @jit
    def forward_sim(R, V):
        return predition(R, V, None, force_fn_orig, shift, dt, masses, stride=stride, runs=runs)

    print("Data generation ...")
    dataset_states = []
    for ind, (R, V) in enumerate(init_confs, start=1):
        print(f"{ind}/{len(init_confs)}", end='\r')
        model_states = forward_sim(R, V)
        if label == '0':
            dataset_states.append(model_states)
        else:  # label == '1' (validated above)
            dataset_states.append(Datastate(model_states))

    split = "train" if args.training else "test"

    print("Saving datafile...")
    savefile(f"model_states_{split}.pkl", dataset_states)

    data_dir = f"../data/{N}-spring/"
    os.makedirs(data_dir, exist_ok=True)
    # NOTE(review): for label "1" the list holds Datastate objects, so numpy
    # presumably builds an object array here (stored via pickle) -- confirm
    # this is what downstream loaders expect.
    np.save(f"{data_dir}dataset_{split}", np.array(dataset_states))

if __name__ == "__main__":
    # Script entry point: a chain of ``--nodes`` nodes (second factor fixed
    # to 1) in two dimensions, simulated from ``--n_config`` initial
    # configurations with ``runs=100``.
    main(
        N1=args.nodes,
        N2=1,
        dim=2,
        runs=100,
        nconfig=args.n_config,
    )
