################################################
################## IMPORT ######################
################################################
import os
import sys
from datetime import datetime
from functools import wraps

import matplotlib.pyplot as plt

from psystems.nsprings import (chain, edge_order)

MAINPATH = ".."  # nopep8
sys.path.append(MAINPATH)  # nopep8

import jraph
import src
from jax.config import config
from src.graph1 import *
from src.md import *
from src.models import MSE, initialize_mlp
from src.utils import *

import time

# Run JAX in 64-bit precision and make it raise as soon as a computation
# produces a NaN (debugging aid).
for _flag in ("jax_enable_x64", "jax_debug_nans"):
    config.update(_flag, True)


def namestr(obj, namespace):
    """Return every name in *namespace* that is bound to exactly *obj*.

    Matching uses identity (``is``), not equality, and names are returned in
    the mapping's iteration order.
    """
    matches = []
    for candidate, bound in namespace.items():
        if bound is obj:
            matches.append(candidate)
    return matches


def pprint(*args, namespace=globals()):
    """Print each argument prefixed by the name(s) it is bound to.

    Names are found by identity lookup in ``namespace``, which defaults to
    this module's globals. Distinct variables can share one object; for
    example, small ints are cached, so ``epochs = 100`` and ``stride = 100``
    are the same object. All matching names are therefore shown, joined by
    ``/``, instead of silently printing the first one. A value bound to no
    name is labelled ``(unnamed)`` instead of raising ``IndexError``.

    Args:
        *args: Values to print.
        namespace: Mapping of names to objects used to find labels.
    """
    for arg in args:
        names = [name for name, value in namespace.items() if value is arg]
        label = "/".join(names) if names else "(unnamed)"
        print(f"{label}: {arg}")

import argparse

parser = argparse.ArgumentParser(
    description="Provide number of nodes (--nodes) and number of epochs (--epochs).")
parser.add_argument("-N", "--nodes", type=int, default=3,
                    help="number of nodes in the spring system (default: 3)")
parser.add_argument("-E", "--epochs", type=int, default=100,
                    help="number of training epochs (default: 100)")
args = parser.parse_args()

# Run configuration. N is re-derived later from the loaded dataset, so
# --nodes mainly selects which "{N}-Spring" dataset/output directory is used.
N = args.nodes
epochs = args.epochs
seed = 42
rname = True
dt = 1.0e-3
ifdrag = 0
stride = 100
trainm = 1
lr = 0.001            # learning rate passed to the optimizer
withdata = None
datapoints = None     # if set, only the first `datapoints` dataset entries are used
batch_size = 100
ifDataEfficiency = 0  # 1: data-efficiency run (outputs under ../data-efficiency)
if_noisy_data = 0     # 1: add Gaussian noise to the data (outputs under ../noisy_data)

# Data-efficiency runs read the number of data points from the first
# command-line argument and derive the batch size from it.
# NOTE(review): sys.argv[1] is also consumed by argparse above (e.g. "-N"),
# so int(sys.argv[1]) fails for flag-style invocations; confirm how this mode
# is meant to be launched before enabling it.
if ifDataEfficiency == 1:
    data_points = int(sys.argv[1])
    batch_size = int(data_points / 100)

print("Configs: ")
pprint(N, epochs, seed, rname, dt, stride, lr, batch_size, namespace=locals())

# Timestamp-based run name (not referenced elsewhere in the visible script).
randfilename = "{}_{}".format(
    datetime.now().strftime("%m-%d-%Y_%H-%M-%S"), datapoints)

PSYS = "{}-Spring".format(N)
TAG = "gnode"

# Root output directory depends on the experiment mode.
out_dir = ("../data-efficiency" if ifDataEfficiency == 1
           else "../noisy_data" if if_noisy_data == 1
           else "../results")


def _filename(name, tag=TAG):
    """Return the output path for file *name*, creating its directory.

    Files tagged "data" always live under ``../results/{PSYS}-data/0/``.
    Everything else goes under ``{out_dir}/{PSYS}-{tag}/{run}/``, where run is
    "0", or "0_{data_points}" in data-efficiency mode. The directory is
    created if missing. The path, with "//" collapsed to "/", is printed and
    returned.
    """
    run_dir = "0_" + str(data_points) if ifDataEfficiency == 1 else "0"

    if tag != "data":
        prefix = f"{out_dir}/{PSYS}-{tag}/{run_dir}/"
    else:
        prefix = f"../results/{PSYS}-{tag}/0/"

    raw_path = f"{prefix}/{name}"
    os.makedirs(os.path.dirname(raw_path), exist_ok=True)
    clean_path = raw_path.replace("//", "/")
    print("===", clean_path, "===")
    return clean_path


def displacement(a, b):
    """Return ``a - b``, the displacement pointing from ``b`` to ``a``."""
    diff = a - b
    return diff


def shift(R, dR, V):
    """Move positions ``R`` by ``dR``; velocities ``V`` pass through unchanged."""
    moved = R + dR
    return moved, V


def OUT(f):
    """Wrap an I/O helper so its first argument can be a bare file name.

    The wrapper resolves that name through ``_filename`` (using the keyword
    ``tag``, default ``TAG``) and then calls *f* with the resolved path. All
    remaining positional and keyword arguments are forwarded unchanged.
    """
    @wraps(f)
    def _wrapped(file, *args, tag=TAG, **kwargs):
        resolved = _filename(file, tag=tag)
        return f(resolved, *args, **kwargs)

    return _wrapped


# I/O helpers from `src`, wrapped by OUT. Callers pass a bare file name (and
# optionally tag=...); `_filename` resolves the full path.
loadmodel, savemodel = OUT(src.models.loadmodel), OUT(src.models.savemodel)
loadfile, savefile, save_ovito = (OUT(src.io.loadfile),
                                  OUT(src.io.savefile),
                                  OUT(src.io.save_ovito))

################################################
################## CONFIG ######################
################################################
# Seed both NumPy's global RNG and the JAX PRNG key used for initialisation.
np.random.seed(seed)
key = random.PRNGKey(seed)

try:
    dataset_states = loadfile("model_states_train.pkl", tag="data")[0]
except Exception as err:
    # Keep whatever loadfile raised as the cause, and let KeyboardInterrupt /
    # SystemExit propagate instead of masking them (the old bare `except:`).
    raise Exception("Generate dataset first.") from err

# Optionally keep only the first `datapoints` entries of the dataset.
if datapoints is not None:
    dataset_states = dataset_states[:datapoints]

model_states = dataset_states[0]

print(
    f"Total number of data points: {len(dataset_states)}x{model_states.position.shape[0]}")

# Particle count and spatial dimension are taken from the data itself; this
# overrides the N obtained from the command line.
N, dim = model_states.position.shape[-2:]
species = jnp.zeros(N, dtype=int)
masses = jnp.ones(N)

# Flatten every stored state into samples of shape (-1, 1, N, dim).
Rs, Vs, Fs = (
    arr.reshape(-1, 1, N, dim)
    for arr in States().fromlist(dataset_states).get_array()
)

# Optional noisy-data experiment: add Gaussian noise (mean 0, std 1) to the
# positions, velocities and forces of every sample before the train/test split.
if (if_noisy_data == 1):
    # Copy into mutable NumPy arrays so samples can be updated in place.
    Rs = np.array(Rs)
    Fs = np.array(Fs)
    Vs = np.array(Vs)

    # Fixed seed for reproducible noise. NOTE: this reseeds NumPy's global
    # RNG, so every later np.random draw in this script (including the
    # shuffle) depends on it.
    np.random.seed(100)
    for i in range(len(Rs)):
        # np.random.normal(0, 1, 1) draws ONE value, which is broadcast over
        # the whole sample i (every particle and coordinate gets the same
        # offset). NOTE(review): confirm this is intended rather than
        # independent per-element noise. The R, V, F draw order per sample
        # fixes the random stream; reordering changes the data.
        Rs[i] += np.random.normal(0, 1, 1)
        Vs[i] += np.random.normal(0, 1, 1)
        Fs[i] += np.random.normal(0, 1, 1)

    # Back to JAX arrays for the rest of the pipeline.
    Rs = jnp.array(Rs)
    Fs = jnp.array(Fs)
    Vs = jnp.array(Vs)

# Shuffle all samples with a random permutation of their indices.
mask = np.random.choice(len(Rs), len(Rs), replace=False)
allRs, allVs, allFs = Rs[mask], Vs[mask], Fs[mask]

# The first 75% of the shuffled samples train the model; the rest are held
# out as the test set.
Ntr = int(0.75 * len(Rs))
Nts = len(Rs) - Ntr

Rs, Vs, Fs = allRs[:Ntr], allVs[:Ntr], allFs[:Ntr]
Rst, Vst, Fst = allRs[Ntr:], allVs[Ntr:], allFs[Ntr:]

print(f"training data shape(Rs): {Rs.shape}")
print(f"test data shape(Rst): {Rst.shape}")

################################################
################### ML Model ###################
################################################

print("Creating Chain")
# Spring-chain connectivity; only the sender/receiver index arrays are kept.
_, _, senders, receivers = chain(N)
eorder = edge_order(len(senders))

# Input feature sizes: Ef (the original note calls it the eij dim) and Nf
# are both `dim`-wide.
Ef = Nf = dim
Oh = 1

# Embedding sizes; Nei_ is the node embedding size used by the mass network.
Eei = Nei = Nei_ = 5

# Width and number of hidden layers used by `get_layers`/`mlp`.
hidden, nhidden = 5, 2


def get_layers(in_, out_):
    """Return MLP layer sizes: ``in_``, ``nhidden`` x ``hidden``, then ``out_``."""
    return [in_, *([hidden] * nhidden), out_]


def mlp(in_, out_, key, **kwargs):
    """Initialise an MLP from ``in_`` to ``out_`` with the sizes from ``get_layers``.

    Extra keyword arguments are forwarded to ``initialize_mlp``.
    """
    layer_sizes = get_layers(in_, out_)
    return initialize_mlp(layer_sizes, key, **kwargs)


# Parameter initialisation for every sub-network of the force model.
# NOTE(review): every initialiser below receives the same PRNG `key`, so
# networks with identical layer sizes (e.g. fne/fneke) presumably start from
# identical weights; split the key (jax.random.split) if independent
# initialisations are intended.
fneke_params = initialize_mlp([Oh, Nei], key)
fne_params = initialize_mlp([Oh, Nei], key)

fb_params = mlp(Ef, Eei, key)
fv_params = mlp(Nei + Eei, Nei, key)
fe_params = mlp(Nei, Eei, key)

ff1_params = mlp(Eei, dim, key)
ff2_params = mlp(Nei, dim, key)
ff3_params = mlp(Nei + dim + dim, dim, key)
ke_params = initialize_mlp([1 + Nei, 10, 10, 1], key, affine=[True])
mass_params = initialize_mlp([Nei_, 5, 1], key, affine=[True])

# All sub-network parameters are grouped under a single "Fqqdot" entry.
Fparams = {
    "fb": fb_params,
    "fv": fv_params,
    "fe": fe_params,
    "ff1": ff1_params,
    "ff2": ff2_params,
    "ff3": ff3_params,
    "fne": fne_params,
    "fneke": fneke_params,
    "ke": ke_params,
    "mass": mass_params,
}

params = {"Fqqdot": Fparams}


def graph_force_fn(params, graph):
    """Evaluate ``a_cdgnode_cal_force_q_qdot`` on *graph* with ``useT=True``.

    NOTE(review): ``eorder=None`` is passed even though a module-level
    ``eorder`` is computed from the chain; confirm this is intentional.
    """
    return a_cdgnode_cal_force_q_qdot(params, graph, eorder=None, useT=True)


# First training sample, used to build the template graph in `_force_fn`.
R, V = Rs[0, 0], Vs[0, 0]


def _force_fn(species):
    """Build a force function over the fixed spring-chain graph.

    The returned ``apply(R, V, params)`` evaluates ``graph_force_fn`` on the
    chain topology (``senders``/``receivers``), using node positions ``R``,
    velocities ``V`` and node types ``species``.

    The template graph is built once from the module-level sample ``R``/``V``.
    ``apply`` creates a new ``GraphsTuple`` with a fresh node dict instead of
    mutating the template's dict in place. This keeps ``apply`` free of side
    effects, so no traced values are written into closed-over state when it
    runs under ``jit``/``vmap``.
    """
    state_graph = jraph.GraphsTuple(
        nodes={"position": R, "velocity": V, "type": species},
        edges={},
        senders=senders,
        receivers=receivers,
        n_node=jnp.array([R.shape[0]]),
        n_edge=jnp.array([senders.shape[0]]),
        globals={},
    )

    def apply(R, V, params):
        graph = state_graph._replace(
            nodes={**state_graph.nodes, "position": R, "velocity": V})
        return graph_force_fn(params, graph)

    return apply


apply_fn = _force_fn(species)


def F_q_qdot(x, v, params):
    """Evaluate the graph force model at positions ``x`` and velocities ``v``.

    Only the ``"Fqqdot"`` sub-dictionary of ``params`` is used.
    """
    return apply_fn(x, v, params["Fqqdot"])


# Alias used by the training code below.
acceleration_fn_model = F_q_qdot
# in_axes=(0, 0, None): map over the leading axis of x and v while sharing
# params. Applied twice to cover two leading batch axes.
v_acceleration_fn_model = vmap(acceleration_fn_model, in_axes=(0, 0, None))
v_v_acceleration_fn_model = vmap(v_acceleration_fn_model,
                                 in_axes=(0, 0, None))

################################################
################## ML Training #################
################################################

@jit
def loss_fn(params, Rs, Vs, Fs):
    """``MSE`` between the model's predictions on (Rs, Vs) and the targets Fs."""
    return MSE(v_v_acceleration_fn_model(Rs, Vs, params), Fs)


def gloss(*args):
    """Return ``(loss, grads)`` of ``loss_fn``; gradients are w.r.t. its first argument."""
    loss_and_grad = value_and_grad(loss_fn)
    return loss_and_grad(*args)


def update(i, opt_state, params, loss__, *data):
    """Compute the gradient for one batch and take an optimizer step.

    ``loss__`` is accepted so callers can pass a loss placeholder, but it is
    unused. Returns ``(new_opt_state, new_params, batch_loss)``.
    """
    batch_loss, grads = gloss(params, *data)
    new_state = opt_update(i, grads, opt_state)
    return new_state, get_params(new_state), batch_loss


@jit
def step(i, ps, *args):
    """Jitted entry point: splat ``ps`` (e.g. ``(opt_state, params, 0)``) and
    the batch arrays into ``update``."""
    call_args = (*ps, *args)
    return update(i, *call_args)


opt_init, opt_update_, get_params = optimizers.adam(lr)


@jit
def opt_update(i, grads_, opt_state):
    """Apply one Adam step after sanitising the gradients.

    In a single tree traversal, every gradient leaf has NaN/+-inf replaced by
    finite values (``jnp.nan_to_num``) and is then clipped to
    ``[-1000, 1000]``. Non-finite or very large gradients therefore never
    reach the optimizer.

    Uses ``jax.tree_util.tree_map``, since the ``jax.tree_map`` alias is
    deprecated. The clip bounds are passed positionally, which works with
    both the old ``a_min``/``a_max`` and the new ``min``/``max`` keyword names.
    """
    grads_ = jax.tree_util.tree_map(
        lambda g: jnp.clip(jnp.nan_to_num(g), -1000.0, 1000.0), grads_)
    return opt_update_(i, grads_, opt_state)


def batching(*args, size=None):
    """Split parallel arrays into equal-sized batches along axis 0.

    Args:
        *args: Arrays (or array-likes) sharing the same leading length L.
        size: Requested batch size. The number of batches is either
            ceil(L / size) or one fewer, whichever lets equal-sized batches
            cover more of the L samples. The actual batch size is
            L // nbatches and may differ from ``size``. Trailing samples
            that do not fill a complete batch are dropped. ``None`` puts all
            samples into a single batch.

    Returns:
        A list with one ``jnp`` array per input, each of shape
        ``(nbatches, batch_size, *arg.shape[1:])``.

    Raises:
        ValueError: If ``size`` is given and is not positive, or if ``size``
            is given and the inputs are empty.
    """
    L = len(args[0])
    if size is not None:
        if not size > 0:
            raise ValueError(f"batch size must be positive, got {size!r}")
        if L == 0:
            raise ValueError("cannot split an empty dataset into batches")
        nbatches1 = int((L - 0.5) // size) + 1  # == ceil(L / size) for L >= 1
        nbatches2 = max(1, nbatches1 - 1)
        size1 = L // nbatches1
        size2 = L // nbatches2
        # Pick whichever option uses more of the samples.
        if size1 * nbatches1 > size2 * nbatches2:
            size, nbatches = size1, nbatches1
        else:
            size, nbatches = size2, nbatches2
    else:
        nbatches, size = 1, L

    n_used = nbatches * size
    batched = []
    for arg in args:
        arr = jnp.asarray(arg)
        batched.append(arr[:n_used].reshape((nbatches, size) + arr.shape[1:]))
    return batched


bRs, bVs, bFs = batching(Rs, Vs, Fs,
                         size=min(len(Rs), batch_size))

print("training ...")

opt_state = opt_init(params)
epoch = 0
optimizer_step = -1
larray = []   # mean training loss per epoch
ltarray = []  # test-set loss per epoch
# Lowest training loss seen at a checkpoint. Start from +inf so the first
# checkpoint always writes trained_model_low.dil, even when early losses
# exceed 1000.
last_loss = float("inf")

start = time.time()
train_time_arr = []
for epoch in range(epochs):
    # One pass over all mini-batches; `l` accumulates the batch losses.
    l = 0.0
    count = 0
    for data in zip(bRs, bVs, bFs):
        optimizer_step += 1
        opt_state, params, l_ = step(
            optimizer_step, (opt_state, params, 0), *data)
        l += l_
        count += 1

    l = l / count  # batching() always yields at least one batch
    larray += [l]
    ltarray += [loss_fn(params, Rst, Vst, Fst)]

    # Periodic checkpoint of the current parameters and loss history.
    if epoch % 1000 == 0:
        metadata = {"savedat": epoch}
        savefile("trained_model.dil", params, metadata=metadata)
        savefile("loss_array.dil", (larray, ltarray), metadata=metadata)
        if last_loss > larray[-1]:
            last_loss = larray[-1]
            savefile("trained_model_low.dil", params, metadata=metadata)

# Record wall-clock training time before plotting so it covers only the
# training loop (the plot was previously included in the measurement).
now = time.time()
train_time_arr.append((now - start))

# Training and test loss per epoch on a log scale. The original drew and
# saved this identical figure twice; once is enough, and the figure is
# closed afterwards to free it.
fig, axs = plt.subplots(1, 1)
axs.semilogy(larray, label="Training")
axs.semilogy(ltarray, label="Test")
axs.set_xlabel("Epoch")
axs.set_ylabel("Loss")
axs.legend()
fig.savefig(_filename("training_loss.png"))
plt.close(fig)

# Final checkpoint with the parameters after the last optimizer step.
params = get_params(opt_state)
# Record the last epoch index. The loop's `metadata` is only refreshed at
# checkpoint epochs (so it may be stale) and is undefined when epochs == 0.
metadata = {"savedat": epoch}
savefile("trained_model.dil", params, metadata=metadata)
savefile("loss_array.dil", (larray, ltarray), metadata=metadata)

# Update the low-loss checkpoint if the final loss is the best seen; guard
# against an empty loss history (no epochs run).
if larray and last_loss > larray[-1]:
    last_loss = larray[-1]
    savefile("trained_model_low.dil", params, metadata=metadata)
# Outside data-efficiency runs, also export the training time and loss curves.
# NOTE(review): this directory uses the N read from the dataset, while
# `_filename` uses PSYS built from the --nodes argument; they differ if the
# two disagree.
if ifDataEfficiency == 0:
    results_dir = f"../results/{N}-Spring-gnode/"
    os.makedirs(results_dir, exist_ok=True)
    np.savetxt(results_dir + "training_time.txt", train_time_arr, delimiter="\n")
    np.save(results_dir + "train_loss.npy", larray)
    np.save(results_dir + "test_loss.npy", ltarray)

