import argparse
import os
import sys
from functools import wraps

import fire
import matplotlib.pyplot as plt

from psystems.nsprings import (chain)

# Make the parent directory ("..") importable so that `import src` below
# resolves. This must run before that import. Python resolves a relative
# sys.path entry against the current working directory, so this presumably
# assumes the script is launched from its own directory — confirm.
MAINPATH = ".."  # nopep8
sys.path.append(MAINPATH)  # nopep8

import src
from jax.config import config
from src.utils import *
from src.hamiltonian import get_zdot_lambda, ode

# Global JAX flags, set at import time before any simulation work:
# - jax_enable_x64: use 64-bit (double precision) types by default.
# - jax_debug_nans: raise as soon as a JAX computation produces a NaN. This is
#   a debugging aid that adds overhead; consider disabling it for large runs.
config.update("jax_enable_x64", True)
config.update("jax_debug_nans", True)

# Command-line interface. It is parsed at import time because main()'s default
# arguments are read from `args`.
parser = argparse.ArgumentParser(
    description="Generate N-spring chain trajectory data (--nodes) "
                "for the training or test split.")
# Parser-level defaults override argument-level ones, so training=True is the
# default even though --training uses store_true. Passing --training is
# therefore a no-op; --test flips it to False.
parser.set_defaults(training=True)
parser.add_argument("-N", "--nodes", type=int, default=3,
                    help="number of nodes in the spring chain")
parser.add_argument('--training', action='store_true',
                    help="generate the training split (default)")
parser.add_argument('--test', dest='training', action='store_false',
                    help="generate the test split instead of the training split")
parser.add_argument("-S", "--seed", type=int, default=42,
                    help="seed for np.random before sampling initial configurations")
parser.add_argument("-C", "--n_config", type=int, default=100,
                    help="number of initial configurations (trajectories)")
parser.add_argument("-L", "--label", type=str, default="0",
                    help="results sub-directory label under ../results/<N>-Spring-data/")
args = parser.parse_args()


def main(N=args.nodes, dim=2, nconfig=args.n_config, saveat=100, dt=1e-3,
         stride=100, runs=100):
    """Simulate ``nconfig`` N-spring chains and save the resulting trajectories.

    Each initial (x, p) pair comes from ``chain(N)``. It is integrated with
    ``ode.odeint`` over ``runs * stride`` time points spaced exactly ``dt``
    apart. Every ``stride``-th state is kept, giving ``runs`` frames per
    trajectory.

    Args:
        N: number of nodes in the spring chain.
        dim: dimension argument forwarded to ``get_zdot_lambda`` (presumably
            the spatial dimension; the configs from ``chain(N)`` do not take
            it, so it must agree with them — TODO confirm).
        nconfig: number of initial configurations / trajectories.
        saveat: unused; kept so existing callers keep working.
        dt: integration time step.
        stride: number of integration steps between saved frames.
        runs: number of saved frames per trajectory.

    Side effects:
        Seeds ``np.random`` with ``args.seed``. Writes ``initial-configs.pkl``
        and ``model_states_{train,test}.pkl`` under
        ``../results/{N}-Spring-data/{args.label}/``, and
        ``../data/{N}-spring/dataset_{train,test}.npy``. The split is chosen
        by ``args.training``.
    """
    tag = f"{N}-Spring-data"
    seed = args.seed
    out_dir = "../results"
    rstring = args.label
    filename_prefix = f"{out_dir}/{tag}/{rstring}/"

    def _filename(name):
        """Return the output path for ``name``, creating its directory."""
        # filename_prefix already ends with "/", so collapse the doubled slash.
        filename = f"{filename_prefix}/{name}".replace("//", "/")
        os.makedirs(os.path.dirname(filename), exist_ok=True)
        print("===", filename, "===")
        return filename

    def OUT(f):
        """Wrap a saver ``f(path, ...)`` so ``path`` is resolved via ``_filename``."""
        @wraps(f)
        def func(file, *fargs, **kwargs):
            return f(_filename(file), *fargs, **kwargs)

        return func

    np.random.seed(seed)
    # Only the first two entries of chain(N) (positions x, momenta p) are kept
    # per configuration; they presumably vary randomly, hence the seed above.
    init_confs = [chain(N)[:2] for _ in range(nconfig)]

    # Edge connectivity (the last two entries) comes from one more chain(N) call.
    _, _, senders, receivers = chain(N)

    print("Saving init configs...")
    savefile = OUT(src.io.savefile)
    savefile("initial-configs.pkl",
             init_confs, metadata={"N1": N, "N2": N})

    masses = jnp.ones(N)

    def pot_energy_orig(x):
        """Total spring potential energy summed over all edges."""
        # Squared edge lengths, one entry per (sender, receiver) edge.
        # NOTE(review): confirm SPRING expects squared rather than plain lengths.
        dr = jnp.square(x[senders, :] - x[receivers, :]).sum(axis=1)
        return jax.vmap(partial(src.hamiltonian.SPRING, stiffness=1.0, length=1.0))(dr).sum()

    kin_energy = partial(src.hamiltonian._T, mass=masses)

    def Hactual(x, p, params):
        """Ground-truth Hamiltonian: kinetic plus spring potential energy."""
        return kin_energy(p) + pot_energy_orig(x)

    def drag(x, p, params):
        # No dissipation in this system.
        return 0.0

    zdot, _ = get_zdot_lambda(
        N, dim, Hactual, drag=drag, constraints=None, external_force=None)

    def zdot_func(z, t):
        """ODE right-hand side on the stacked state z = [x; p]."""
        x, p = jnp.split(z, 2)
        return zdot(x, p, None)

    def get_z(x, p):
        """Stack positions and momenta into a single state array."""
        return jnp.vstack([x, p])

    def zz(out, ind=None):
        """Split a batch of stacked states along axis 1 into (x, p), or pick one half."""
        if ind is None:
            x, p = jnp.split(out, 2, axis=1)
            return x, p
        else:
            return jnp.split(out, 2, axis=1)[ind]

    # n_steps samples spaced exactly dt apart. Using n_steps * dt as the
    # endpoint would make the spacing dt * n_steps / (n_steps - 1).
    n_steps = runs * stride
    t = jnp.linspace(0.0, (n_steps - 1) * dt, n_steps)

    print("Data generation ...")
    dataset_states = []
    for ind, (x, p) in enumerate(init_confs, start=1):
        _z_out = ode.odeint(zdot_func, get_z(x, p), t)
        # Keep every stride-th integration step -> `runs` saved frames.
        z_out = _z_out[0::stride]
        xout, pout = zz(z_out)
        zdot_out = jax.vmap(zdot, in_axes=(0, 0, None))(xout, pout, None)
        print(f"{ind}/{len(init_confs)}", end='\r')
        dataset_states.append((z_out, zdot_out))

    split = "train" if args.training else "test"

    print("Saving datafile...")
    savefile(f"model_states_{split}.pkl", dataset_states)

    data_dir = f"../data/{N}-spring/"
    # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
    os.makedirs(data_dir, exist_ok=True)
    # np.save appends the ".npy" extension.
    np.save(f"{data_dir}dataset_{split}", np.array(dataset_states))


if __name__ == "__main__":
    # CLI-derived settings plus fixed dim/runs; all other parameters use
    # main()'s defaults.
    cli_kwargs = dict(N=args.nodes, dim=2, runs=100, nconfig=args.n_config)
    main(**cli_kwargs)
