"""
This file includes function implementations for the following areas:

Physics
- Hamiltonian functions H and grad_H for molecular dynamics
- Also the Poisson matrix computation to evaluate the Hamilton's equations

Numerics
- 2D lattice creation to find the equilibrium points easier in 2D
- Edge index in the fully-connected or k-nearest sense
- Simulation loop to time-step an MD system using the Hamiltonian formulation
- Initial condition generation for an MD system
- Full experiment pipeline including traj-generation, data-generation, learning, evaluation and plotting

Math
- Find an equilibrium state given an MD problem (hopefully that is not far away)
- Distance calculation with a minimum image convention

Plotting
- For plotting numerical solution and prediction in 1D.

Learning
- Dataset generation for MD simulations to train the model by
  sampling with an r_min and rejecting the data points that are very
  close to each other

"""

import numpy as np
import torch as tc
import scipy.spatial.distance as spdist
import matplotlib.pyplot as plt
import matplotlib.animation as anim
from matplotlib.gridspec import GridSpec, GridSpecFromSubplotSpec
from sklearn.neighbors import NearestNeighbors
from sklearn.model_selection import train_test_split
from scipy.optimize import minimize
from tqdm import tqdm
from time import time
from src.model import GNN
from src.train import sample_and_linear_solve
from src.utils import stormer_verlet_step, eval_dxdt, hamiltons_eq, flatten_TC, unflatten_TC, grad, flatten

################ PHYSICS ################

def H(x, mass, eps, sig, cutoff, as_separate=False, box_size=None):
    """Evaluates the Hamiltonian (kinetic part + potential part) for a batch of states.

    x is of shape (n_graphs, n_obj, 2*dof) with positions in the first half and
    momenta in the second half of the last axis.
    mass, eps, sig, cutoff are forwarded to `kinetic_part` and `potential_part`.
    If as_separate is True, the tuple (kinetic, potential) is returned instead of the sum.
    box_size must be None, otherwise NotImplementedError is raised.

    Outputs Hamiltonians of shape (n_graphs, 1), or two arrays of that shape if as_separate.
    """
    if box_size is not None:
        raise NotImplementedError("Periodic boundary condition is not implemented yet, but box_size was given")

    n_graphs = len(x)
    positions, momenta = np.split(x, 2, axis=-1)  # each of shape (n_graphs, n_obj, dof)
    kinetic = np.empty((n_graphs, 1), dtype=x.dtype)
    potential = np.empty((n_graphs, 1), dtype=x.dtype)

    # Sum the per-object energies of every graph in the batch
    for g, (q_g, p_g) in enumerate(zip(positions, momenta)):
        kinetic[g] = kinetic_part(p_g, mass).sum()
        potential[g] = potential_part(q_g, eps, sig, cutoff, box_size).sum()

    return (kinetic, potential) if as_separate else kinetic + potential

def grad_H(x, mass, eps, sig, cutoff, box_size=None):
    """Evaluates the gradient of the Hamiltonian for a batch of states.

    x is of shape (n_graphs, n_obj, 2*dof) with positions in the first half and
    momenta in the second half of the last axis.
    box_size must be None, otherwise NotImplementedError is raised.

    Outputs the gradient of shape (n_graphs, n_obj, 2*dof), laid out as [dH/dq, dH/dp]
    along the last axis.
    """
    if box_size is not None:
        raise NotImplementedError("Periodic boundary condition is not implemented yet, but box_size was given")

    positions, momenta = np.split(x, 2, axis=-1)  # each of shape (n_graphs, n_obj, dof)
    dH_dp = np.empty_like(momenta, dtype=x.dtype)    # the kinetic part depends on p only
    dH_dq = np.empty_like(positions, dtype=x.dtype)  # the potential part depends on q only

    for g, (q_g, p_g) in enumerate(zip(positions, momenta)):
        dH_dp[g] = grad_kinetic_part(p_g, mass)                           # (n_obj, dof)
        dH_dq[g] = grad_potential_part(q_g, eps, sig, cutoff, box_size)   # (n_obj, dof)

    return np.concatenate([dH_dq, dH_dp], axis=-1)

def kinetic_part(p, mass):
    """Kinetic energy per object given momenta p of shape (n_obj, dof)
    Outputs an array of shape (n_obj, 1) with entries 0.5 * |p|^2 / mass
    """
    squared_norm = (p**2).sum(axis=-1, keepdims=True)  # |p|^2 per object, (n_obj, 1)
    energy_per_object = 0.5 * (squared_norm / mass)
    return energy_per_object

def grad_kinetic_part(p, mass):
    """Gradient of the kinetic energy with respect to the momenta, dT/dp
    given p of shape (n_obj, dof). Outputs p / mass with the same shape as p.
    """
    dT_dp = p / mass
    return dT_dp

def potential_part(q, eps, sig, cutoff=np.inf, box_size=None):
    """Local Lennard-Jones potential energy per object given positions q of shape (n_obj, dof)

    Every unordered pair closer than cutoff contributes LJ(distance); that pair energy
    is split evenly between its two objects, so the sum over the output is the total
    potential energy. box_size must be None, otherwise NotImplementedError is raised.

    Outputs local potential parts of shape (n_obj, 1)
    """
    if box_size is not None:
        raise NotImplementedError("Periodic boundary condition is not implemented yet, but box_size was given")

    count = len(q)
    local_energy = np.zeros((count, 1), dtype=q.dtype)
    for a in range(count):
        for b in range(a + 1, count):
            separation = r(q[a], q[b], box_size=box_size)
            if not separation < cutoff:
                continue  # pair lies outside the interaction radius
            half_pair_energy = LJ(separation, eps, sig) / 2.0
            local_energy[b] += half_pair_energy
            local_energy[a] += half_pair_energy

    return local_energy

def grad_potential_part(q, eps, sig, cutoff=np.inf, box_size=None):
    """Gradient of the Lennard-Jones potential energy with respect to the positions
    given q of shape (n_obj, dof).

    For every unordered pair (i, j) closer than cutoff, the chain rule
    dLJ/dr * dr/dq is accumulated for both objects, with opposite signs.
    box_size must be None, otherwise NotImplementedError is raised.

    Outputs the gradient of shape (n_obj, dof)
    """
    if box_size is not None:
        raise NotImplementedError("Periodic boundary condition is not implemented yet, but box_size was given")

    count = len(q)
    gradient = np.zeros_like(q, dtype=q.dtype)
    for a in range(count):
        for b in range(a + 1, count):
            separation, offset = r(q[a], q[b], box_size=box_size, return_delta=True)
            if not separation < cutoff:
                continue  # pair lies outside the interaction radius
            slope = grad_LJ(separation, eps, sig)  # dLJ/dr, scalar
            unit_offset = offset / separation      # d|q_b - q_a| / dq_b, of shape (dof)
            contribution = unit_offset * slope
            gradient[b] += contribution
            gradient[a] -= contribution            # d|q_b - q_a| / dq_a has the opposite sign

    return gradient

def LJ(r, eps, sig):
    """Lennard-Jones 12-6 potential for a distance r
    Outputs 4 * eps * ((sig/r)^12 - (sig/r)^6)
    """
    ratio = sig / r
    attractive = ratio ** 6       # (sig/r)^6
    repulsive = attractive ** 2   # (sig/r)^12
    return 4 * eps * (repulsive - attractive)

def grad_LJ(r, eps, sig):
    """Derivative of the Lennard-Jones 12-6 potential with respect to the distance r
    Outputs 4 * eps * (-12 * sig^12 / r^13 + 6 * sig^6 / r^7)
    """
    repulsive_term = sig**12 / r**13
    attractive_term = sig**6 / r**7
    return 4 * eps * (-12.0 * repulsive_term + 6.0 * attractive_term)

def get_poisson_matrix(n_obj, dof, dtype=np.float64) -> np.ndarray:
    """Builds the Poisson matrix for a flattened state [q, p] given
    n_obj (int) number of objects, dof (int) degrees of freedom per object and the dtype.

    The result has shape (2*n_obj*dof, 2*n_obj*dof) with +1 on the entries that pair
    each position coordinate with its momentum and -1 on the transposed entries,
    i.e. the block form [[0, I], [-I, 0]].
    """
    half = n_obj * dof        # number of position (or momentum) coordinates
    n_features = 2 * half     # for q and p and all the objects in the system
    Ls = np.zeros((n_features, n_features), dtype)
    coord = np.arange(half)
    Ls[coord, half + coord] = 1.
    Ls[half + coord, coord] = -1.
    return Ls

def get_equil_r(sig):
    """Equilibrium distance of the Lennard-Jones potential, 2^(1/6) * sig, given
    sig (scalar) sigma parameter in LJ
    """
    minimum_factor = 2**(1/6)
    return minimum_factor * sig

###################### NUMERIC UTILS #######################

def get_eq_chain_x0(n_obj, dof, sig):
    """Builds a straight chain whose neighbours sit at the equilibrium distance
    given n_obj (int) number of particles
    dof (int) number of spatial dimensions
    sig (float) sigma parameter of Lennard Jones potential
    Returns positions q of shape (n_obj, dof); particle k (1-based) is placed at
    k * r_eq on the first axis and at zero on all other axes.
    """
    r_eq = get_equil_r(sig)

    def place(k):
        position = np.zeros(dof)
        position[0] = r_eq * k  # chain runs along the first axis
        return position

    return np.vstack([place(k) for k in range(1, n_obj + 1)])

def get_eq_lattice_x0(nx, ny, sig):
    """Builds a rectangular 2D lattice whose neighbours along the x and y axes sit
    at the equilibrium distance, given
    nx (int) number of lattice points along x
    ny (int) number of lattice points along y
    sig (float) sigma parameter of Lennard Jones
    Returns positions q of shape (nx * ny, 2), with x varying fastest
    """
    spacing = get_equil_r(sig)
    grid_x, grid_y = np.meshgrid(np.arange(nx) * spacing, np.arange(ny) * spacing)
    return np.column_stack((grid_x.ravel(), grid_y.ravel()))

def get_eq_triangle_x0(sig):
    """Returns an equilateral triangle in 2D whose three sides all have the
    Lennard-Jones equilibrium length, given
    sig (float) sigma parameter of Lennard Jones
    Returns positions q of shape (3, 2), one particle per row
    """
    r_eq = get_equil_r(sig)
    # Interior angle of an equilateral triangle; with this angle the third particle
    # is at distance r_eq from both of the others.
    angle = np.pi / 3
    q0 = np.array([0.0, 0.0])
    q1 = np.array([r_eq, 0.0])
    q2 = np.array([np.cos(angle) * r_eq, np.sin(angle) * r_eq])
    # Stack along the first axis so that rows are particles: shape (3, 2)
    return np.stack([q0, q1, q2], axis=0)

def get_edge_index(n_obj) -> tc.Tensor:
    """Returns the edge index for a fully-connected graph given
    n_obj of type int

    Each unordered pair of nodes appears once, as [higher index, lower index]
    (convention: connection towards lower indices). The result is the transposed
    tensor of that pair list, i.e. shape (2, n_edges) when there is at least one pair.
    """
    pairs = [[j, i] for i in range(n_obj) for j in range(i + 1, n_obj)]
    return tc.tensor(pairs).T

def edge_index_shortest_connections(q) -> tc.Tensor:
    """Returns the edge_index of a chain that links the nodes in sorted position order
    given positions q of shape (n_obj, dof)

    Consecutive nodes in the sorted order are connected, and every edge is stored as
    [higher index, lower index]. Each sorted row is read with `.item()`, so this only
    works when a row holds a single value (dof == 1).
    """
    order = np.argsort(q, axis=0)  # chain minimum to maximum
    links = []
    for current, following in zip(order[:-1], order[1:]):
        i = current.item()
        j = following.item()
        links.append([max(i, j), min(i, j)])
    return tc.tensor(links).T

def edge_index_k_nearest(q, degree_limit, cutoff=np.inf, verbose=False) -> tc.Tensor:
    """Returns the edge index for the given positions q of shape (n_obj, dof)
    with
    - degree_limit number of maximum degrees of a node
    - cutoff radius; only pairs closer than cutoff can be connected
    - verbose, if True the resulting edge index is printed

    Pairs are visited greedily from the shortest to the longest distance and an edge
    is added whenever both end points are still below degree_limit, so shorter
    distances are preferred if there is an option. Edges are stored as
    [higher index, lower index]; the returned tensor is the transposed edge list,
    i.e. shape (2, n_edges) when at least one edge was selected.

    Raises ValueError (from `r`) if two particles overlap.
    """
    n_obj = len(q)

    # Step 1: Collect the pairwise distances of all pairs inside the cutoff radius
    candidates = []  # entries: (distance, higher index, lower index)
    for i in range(n_obj):
        for j in range(i + 1, n_obj):
            dist = r(q[i], q[j], box_size=None)
            if dist < cutoff:
                candidates.append((dist, j, i))

    # Sort on the distance only: the sort is stable, so ties keep the (i, j) scan order
    candidates.sort(key=lambda item: item[0])

    # Step 2: Pick shortest distances first, up to the node degree limit
    node_degrees = [0] * n_obj
    edge_index = []
    for _, src, dst in candidates:
        if node_degrees[src] < degree_limit and node_degrees[dst] < degree_limit:
            edge_index.append([src, dst])
            node_degrees[src] += 1
            node_degrees[dst] += 1

    edge_index = tc.tensor(edge_index).T
    if verbose:
        print(edge_index)
    return edge_index

def edge_index_radius(q, cutoff, degree_limit=None):
    """Outputs edge_index given positions (q) of shape (n_obj, dof) and
    cutoff (float): every pair closer than cutoff is a candidate edge [higher index, lower index].

    Without degree_limit all candidates are kept in scan order. With degree_limit the
    candidates are visited from the shortest to the longest distance and an edge is
    kept only while both of its nodes have fewer than degree_limit edges.

    Returns edge_index of shape (2, n_edges) as a torch tensor, or an empty list
    if no edge was selected.
    """
    n_obj = len(q)

    # Scan all unordered pairs once and remember the ones inside the cutoff radius
    within_cutoff = []  # entries: [higher index, lower index, distance]
    for i in range(n_obj):
        for j in range(i + 1, n_obj):
            dist = r(q[i], q[j])
            if dist < cutoff:
                within_cutoff.append([j, i, dist])

    if degree_limit is None:
        selected = [[j, i] for j, i, _ in within_cutoff]
    else:
        within_cutoff.sort(key=lambda entry: entry[-1])  # shortest distances first
        degrees = [0] * n_obj
        selected = []
        for j, i, _ in within_cutoff:
            if not (degrees[j] < degree_limit and degrees[i] < degree_limit):
                continue
            selected.append([j, i])
            degrees[j] += 1
            degrees[i] += 1

    # An empty selection is reported as a plain empty list instead of a tensor
    if not selected:
        return []

    return tc.tensor(selected).T

def apply_reflective_bc(x, box_start, box_end):
    """Apply reflective boundary condition given
    x (positions, momenta) of shape (n_obj, 2*dof),
    box_start (scalar for all dof),
    box_end (scalar for all dof).

    A coordinate that left the box is mirrored at the wall it crossed and the
    matching momentum component changes sign. x is modified in place.

    Returns updated state x (the same array object that was passed in)
    """
    _, n_features = x.shape
    dof = n_features // 2
    q, p = np.split(x, 2, axis=-1)

    # Signed distances to both walls, taken before anything is modified
    gap_to_start = q - box_start
    gap_to_end = box_end - q
    crossed_start = gap_to_start < 0.0
    crossed_end = gap_to_end < 0.0

    positions = x[..., :dof]
    momenta = x[..., dof:]

    if crossed_start.any():
        # mirror at the lower wall FIXME: may change energy
        positions[crossed_start] = box_start - gap_to_start[crossed_start]
        momenta[crossed_start] = -p[crossed_start]
    if crossed_end.any():
        # mirror at the upper wall
        positions[crossed_end] = box_end + gap_to_end[crossed_end]
        momenta[crossed_end] = -p[crossed_end]

    return x

def simulate(x0, mass, eps, sig, cutoff, n_steps, delta_t, model=None, bc=None, box_start=0.0, box_end=1.0):
    """Given initial state of shape (n_obj, 2*dof) consisting of [q,p] this function simulates the
    system using the Hamiltonian formulation and stormer-verlet integrator, given
    - x0 of shape (n_obj, 2*dof)
    - mass, eps, sig, cutoff as molecular dynamics parameters
    - n_steps is the number of time steps to simulate
    - delta_t is the time step size
    - model is optional; if None the ground truth `H` / `grad_H` drive the integration, otherwise
      the model's prediction and its gradient (via `grad`) are used and `model.edge_index` is
      rebuilt from the current positions before every step
    - bc is for boundary condition: "reflective" applies `apply_reflective_bc` after every step,
      "periodic" currently raises NotImplementedError, any other value applies no boundary handling
    - if bc is given, then boundaries box_start and box_end is used

    Returns the tuple (x_traj, ke_traj, pe_traj, h_pred_traj, grad_h_traj) of numpy arrays, each
    with n_steps + 1 entries along the first axis (the initial state followed by one entry per step).
    ke_traj / pe_traj always come from the ground truth `H`; h_pred_traj is the Hamiltonian of the
    integrating model shifted by a constant so that it matches the true Hamiltonian at x0.
    """
    if bc == "periodic":
        raise NotImplementedError("Periodic boundary condition is not implemented yet")
        # NOTE(review): the three statements below are unreachable because of the raise above
        box_size = box_end - box_start
        print(f"-> periodic boundaries with box_size {box_size}")
        assert cutoff < (box_size / 2), "cutoff is too large, particles see themselves"
    else:
        box_size = None # just to be safe, this is not required for reflective BC

    n_obj, n_features = x0.shape
    dof = n_features // 2

    if model is None: # use ground truth to simulate the system
        model_H = lambda x: H(x, mass, eps, sig, cutoff, box_size=box_size)
        integration_grad_H = lambda x: grad_H(x, mass, eps, sig, cutoff, box_size)
    else:
        # use the learned model for both the Hamiltonian value and its gradient
        model_H = lambda x: model.forward(flatten_TC(tc.from_numpy(x))).detach().numpy()
        integration_grad_H = lambda x: unflatten_TC(grad(model, flatten_TC(tc.from_numpy(x))), model.n_obj, dof).detach().numpy()
        n_obj_train = model.n_obj # Training number of objects for this model, saving only for logging purposes
        model.n_obj = [n_obj]     # This update is required if doing zero-shot (training number of nodes is different than testing number of nodes)
        # NOTE(review): model.n_obj is only restored at the very end of this function, so it stays
        # overwritten if anything below raises

    # ground-truth kinetic and potential energy of the initial state
    ke, pe = H(x0[np.newaxis, ...], mass, eps, sig, cutoff, as_separate=True, box_size=box_size)

    node_degrees = [] # node degrees that have been encountered so far (used for logging below)
    if model is not None:
        degree_limit = int(np.prod(n_obj_train)) - 1 # The model would not work with unseen node degrees
        model.edge_index = edge_index_shortest_connections(x0[..., :dof]) if dof==1 else \
                           edge_index_radius(x0[..., :dof], cutoff, degree_limit)
        # NOTE(review): edge_index_radius returns [] when there is no edge and a tensor otherwise;
        # the `!= []` / `== []` checks rely on how a tensor compares with a list -- confirm
        if cutoff < np.inf and model.edge_index != []:
            # print("-> current edge_index =", model.edge_index)
            src, dst = model.edge_index
            all_nodes = tc.cat([src, dst])
            degree_count = tc.bincount(all_nodes)
            unique_degrees = tc.unique(degree_count)
            for unique_degree in unique_degrees:
                item = unique_degree.detach().item()
                if not item in node_degrees:
                    # print("-> New node degree:", item)
                    node_degrees.append(item)
                    # print("corresponding edge_index")
                    # print(model.edge_index)
            # NOTE(review): this check sits inside a branch that already requires edge_index != [],
            # and its body is a no-op
            if model.edge_index == [] and not 0 in node_degrees:
                pass
                # print("-> New node degree: 0")
                # print("corresponding edge_index")
                # print([])

    h0_pred = model_H(x0[np.newaxis, ...]).squeeze(axis=-1)
    h0_true = H(x0[np.newaxis, ...], mass, eps, sig, cutoff, box_size=box_size).squeeze(axis=-1)
    # constant offset that aligns the predicted Hamiltonian with the true one at the initial state
    integration_const = h0_true - h0_pred
    # dHdx = grad_H(x0[np.newaxis, ...], mass, eps, sig, cutoff, box_size=box_size)
    dHdx = integration_grad_H(x0[np.newaxis, ...])

    # every trajectory list starts with the values at the initial state
    x_traj, ke_traj, pe_traj, h_pred_traj, grad_h_traj = [x0], [ke.squeeze(0)], [pe.squeeze(0)], [h0_pred.squeeze(0).squeeze(-1) + integration_const], [dHdx.squeeze(0)]
    for _ in tqdm(range(n_steps)):
        x_prev = x_traj[-1]
        if model is not None:
            # rebuild the graph connectivity from the current positions before stepping
            degree_limit = int(np.prod(n_obj_train)) - 1 # The model would not work with unseen node degrees
            model.edge_index = edge_index_shortest_connections(x_prev[..., :dof]) if dof==1 else \
                               edge_index_radius(x_prev[..., :dof], cutoff, degree_limit)
            if cutoff < np.inf and model.edge_index != []:
                # print("-> current edge_index =", model.edge_index)
                src, dst = model.edge_index
                all_nodes = tc.cat([src, dst])
                degree_count = tc.bincount(all_nodes)
                unique_degrees = tc.unique(degree_count)
                for unique_degree in unique_degrees:
                    item = unique_degree.detach().item()
                    if not item in node_degrees:
                        print("-> New node degree:", item)
                        node_degrees.append(item)
                        print("corresponding edge_index")
                        print(model.edge_index)
            # NOTE(review): 0 is never appended to node_degrees, so this message is printed on
            # every step for which the edge index is empty
            if model.edge_index == [] and not 0 in node_degrees:
                print("-> New node degree: 0")
                print("corresponding edge_index")
                print([])

        # advance the state by one time step; all optional forcing/constraint arguments are disabled
        x_next = stormer_verlet_step(x_prev, delta_t, integration_grad_H,
                                     gravity=None, force_node_idx=None, force_vec=None,
                                     force_snap_idx=None, idx_snap=None, fixed_indices=None,
                                     n_obj=n_obj, dof=dof)

        if bc == "periodic":
            # NOTE(review): not reachable at the moment, bc == "periodic" raises at the top
            # TODO: handle periodic logic also correctly using the model
            q = x_next[..., :dof]
            q = ((q - box_start) % (box_end - box_start)) + box_start
            x_next[..., :dof] = q
        elif bc == "reflective":
            x_next = apply_reflective_bc(x_next, box_start, box_end)

        # energies are always evaluated with the ground truth, the prediction with the integrating model
        ke_next, pe_next = H(x_next[np.newaxis, ...], mass, eps, sig, cutoff, as_separate=True, box_size=box_size)
        h_next_pred = model_H(x_next[np.newaxis, ...])
        # dHdx_next = grad_H(x_next[np.newaxis, ...], mass, eps, sig, cutoff, box_size=box_size)
        dHdx_next = integration_grad_H(x_next[np.newaxis, ...])

        x_traj.append(x_next)
        ke_traj.append(ke_next.squeeze(0))
        pe_traj.append(pe_next.squeeze(0))
        h_pred_traj.append(h_next_pred.squeeze(axis=-1) + integration_const)
        grad_h_traj.append(dHdx_next.squeeze(0))

    if model is not None:
        model.n_obj = n_obj_train # set back the training number of nodes for logging purposes
    return np.asarray(x_traj), np.asarray(ke_traj), np.asarray(pe_traj), np.asarray(h_pred_traj), np.asarray(grad_h_traj)

def generate_initial_condition(n_obj, dof, mass, eps, sig, q_noise, p_noise, rng, dtype=np.float32):
    """Generates an initial condition by first finding the equilibrium state of the problem
    (via `get_equil_x` with an infinite cutoff) and then displacing positions and momenta
    with uniform noise drawn from rng in [-q_noise, q_noise] and [-p_noise, p_noise].

    Returns the perturbed state of shape (n_obj, 2*dof) and r_eq (float), the distance
    between the first two particles of the unperturbed equilibrium.
    """
    # cutoff is set to inf here anyways to find the equil at a non-infinite solution
    equilibrium = get_equil_x(n_obj, dof, mass, eps, sig, cutoff=np.inf, dtype=dtype)
    q_eq, p_eq = np.split(equilibrium, 2, axis=-1)

    # equil distance, measured before any noise is applied
    r_eq = np.linalg.norm(q_eq[0] - q_eq[1]).item()

    # positions are perturbed first, then momenta (keeps the rng draw order)
    q_start = q_eq + rng.uniform(-q_noise, q_noise, size=q_eq.shape).astype(dtype)
    p_start = p_eq + rng.uniform(-p_noise, p_noise, size=p_eq.shape).astype(dtype)
    return np.hstack([q_start, p_start]), r_eq

def get_x0_and_box(nx, ny, sig, r_eq, q_noise, p_noise, dtype, rng):
    """Creates a noisy lattice state together with a tight box for reflective boundary condition

    Positions start from `get_eq_lattice_x0(nx, ny, sig)` and momenta from zero; both are
    perturbed with uniform noise from rng. The box limits are the smallest and largest
    perturbed coordinate (over both axes) extended by r_eq / 2.

    Returns x of shape (n_obj, 2*dof), box_start, box_end
    """
    positions = get_eq_lattice_x0(nx, ny, sig).astype(dtype)
    momenta = np.zeros_like(positions, dtype=dtype)

    # positions are perturbed first, then momenta (keeps the rng draw order)
    positions = positions + rng.uniform(-q_noise, q_noise, size=positions.shape).astype(dtype)
    momenta = momenta + rng.uniform(-p_noise, p_noise, size=momenta.shape).astype(dtype)

    state = np.column_stack([positions, momenta])
    half_spacing = r_eq / 2.0
    box_start = np.min(positions) - half_spacing
    box_end = np.max(positions) + half_spacing
    return state, box_start, box_end

###################### MATH UTILS #######################

def get_equil_x(n_obj, dof, mass, eps, sig, cutoff, initial_positions=None, dtype=np.float32):
    """Returns an equilibrated state of shape (n_obj, 2*dof) given
    n_obj of type int which is the total number of particles in the system and
    dof of type int representing the number of spatial dimensions

    Momenta are obtained by minimizing the total kinetic part with BFGS. Positions are
    taken from initial_positions if given, otherwise they are obtained by minimizing the
    total potential part with BFGS starting from an evenly spaced initial guess.
    """
    n_coords = int(n_obj*dof)

    def as_objects(flat):
        # optimizer works on flat vectors, the physics functions on (n_obj, dof)
        return flat.reshape(n_obj, dof)

    def total_kinetic(flat_p):
        return kinetic_part(as_objects(flat_p), mass).sum()

    def total_kinetic_jac(flat_p):
        return grad_kinetic_part(as_objects(flat_p), mass).reshape(-1)

    def total_potential(flat_q):
        return potential_part(as_objects(flat_q), eps, sig, cutoff).sum()

    def total_potential_jac(flat_q):
        return grad_potential_part(as_objects(flat_q), eps, sig, cutoff).reshape(-1)

    # Kinetic part minimum (0), limits are computed such that the equilibrium will not be placed at infinity (there is only 1 finite equilibrium point)
    p_limit = n_obj + 1
    p_guess = np.linspace(-p_limit, p_limit, n_coords).astype(dtype).reshape(-1)
    p_result = minimize(total_kinetic, p_guess, jac=total_kinetic_jac, method="BFGS")
    p_eq = p_result.x.reshape(n_obj, dof).astype(dtype)

    # Potential part minimum
    if initial_positions is None:
        q_limit = n_obj + 0.5
        q_guess = np.linspace(-q_limit, q_limit, n_coords).astype(dtype).reshape(-1)
        q_result = minimize(total_potential, q_guess, jac=total_potential_jac, method="BFGS")
        q_eq = q_result.x.reshape(n_obj, dof).astype(dtype)
    else:
        q_eq = initial_positions

    return np.column_stack([q_eq, p_eq])

def r(q_i, q_j, box_size=None, return_delta=False):
    """Computes the distance (L2 norm) between q_i and q_j of shape (dof) each

    box_size must be None (checked with an assert); the minimum image branch below is
    kept for the unfinished periodic BC support.
    Raises ValueError if the distance is numerically zero.

    Outputs |q_j - q_i| (scalar), or the tuple (distance, q_j - q_i) if return_delta is True
    """
    displacement = q_j - q_i
    assert box_size is None, "box_size implementation for periodic BC is not fully implemented yet."
    if box_size is not None:
        # minimum image convention: wrap each component towards the closest periodic clone
        displacement -= box_size * np.round(displacement / box_size)

    separation = np.linalg.norm(displacement)
    if np.isclose(separation, 0.0):
        raise ValueError("Particles are overlapping or self-interacting (|q_j - q_i|)")

    return (separation, displacement) if return_delta else separation

###################### PLOTTING ######################

def plot_simulation(x_traj, ke_traj, pe_traj, box_start=None, box_end=None):
    """Plots a simulation: one panel for the kinetic part, one for the potential part and
    one panel per position dimension, given
    x_traj of shape (n_steps, n_obj, 2*dof)
    ke_traj, the kinetic part per time step
    pe_traj, the potential part per time step
    If box_start and box_end are both given then the boundaries are displayed as
    dashed lines in the position panels.
    Returns the matplotlib figure.
    """
    dof = x_traj.shape[-1] // 2
    num_axes = 2 + dof # 2 for kinetic and potential energies and others for position dims
    fig, axes = plt.subplots(1, num_axes, figsize=(4*num_axes, 3), dpi=100)
    ke_ax, pe_ax = axes[0], axes[1]
    position_axes = axes[2:]

    # energies
    ke_ax.plot(ke_traj, linewidth=3.0)
    ke_ax.set_title("Total kinetic part: T(p)")
    pe_ax.plot(pe_traj, linewidth=3.0)
    pe_ax.set_title("Total potential part: V(q)")

    # positions: one panel per dimension, one scatter series per object
    time_steps = range(len(x_traj))
    for idx_dim, ax in enumerate(position_axes):
        for idx_obj in range(x_traj.shape[1]):
            ax.scatter(time_steps, x_traj[:, idx_obj, idx_dim], label=rf"$q_{idx_obj}$", linewidths=1.0, s=1.0)
            ax.set_title(f"Positions q{idx_dim}")

    if box_start is not None and box_end is not None:
        for ax in position_axes:
            ax.axhline(box_start, color="k", linestyle="dashed")
            ax.axhline(box_end, color="k", linestyle="dashed")

    for ax in position_axes:
        ax.legend()
    for ax in axes:
        ax.set_xlabel("Time step")
    fig.tight_layout()
    return fig

def plot_simulation_gradients(dHdx_traj):
    """Plots the gradient of the Hamiltonian over time, one panel per feature
    (last axis of the input) and one line per object, given
    dHdx_traj of shape (n_steps, n_obj, 2*dof)
    Returns the matplotlib figure.
    """
    num_axes = dHdx_traj.shape[-1]
    fig, axes = plt.subplots(1, num_axes, figsize=(4*num_axes, 3), dpi=100)

    for idx_dim, ax in enumerate(axes):
        for idx_obj in range(dHdx_traj.shape[1]):
            ax.plot(dHdx_traj[:, idx_obj, idx_dim], label=rf"$dH/dx_{idx_obj}$", linewidth=2.5)

    for ax in axes:
        ax.legend()
    for ax in axes:
        ax.set_xlabel("Time step")
    fig.tight_layout()
    return fig

def plot_simulation_pred(axes, x_traj_pred, ke_traj_pred, pe_traj_pred):
    """Overlays model predictions (in red) on axes laid out like the panels of
    `plot_simulation`: axes[0] kinetic part, axes[1] potential part, axes[2:] positions.
    x_traj_pred is indexed as [time step, object, dimension]
    ke_traj_pred of shape (traj_len, 1)
    pe_traj_pred of shape (traj_len, 1)
    """
    # energies
    for ax, energy in zip(axes[:2], (ke_traj_pred, pe_traj_pred)):
        ax.plot(energy, c="red", linestyle="dashed", linewidth=1)

    # positions
    time_steps = range(len(x_traj_pred))
    for idx_dim, ax in enumerate(axes[2:]):
        for idx_obj in range(x_traj_pred.shape[1]):
            ax.scatter(time_steps, x_traj_pred[:, idx_obj, idx_dim],
                       c="red", linewidths=0.5, s=1.0, alpha=0.5, marker=",")

def plot_simulation_gradients_pred(axes, dHdx_traj_pred):
    """Overlays predicted gradients as red dashed lines on axes laid out like the
    panels of `plot_simulation_gradients` (one axis per feature), given
    dHdx_traj_pred of shape (traj_len, n_obj, dof*2)
    """
    n_obj = dHdx_traj_pred.shape[1]
    for idx_dim, ax in enumerate(axes):
        for idx_obj in range(n_obj):
            ax.plot(dHdx_traj_pred[:, idx_obj, idx_dim], c="red", linestyle="dashed", linewidth=1)

def plot_traj_MSE(mse_traj, mse_dHdp_traj, mse_dHdq_traj):
    """Plots the three given MSE series over time steps in side-by-side panels
    (trajectory, dHdp, dHdq) on a symlog y-axis and returns the matplotlib figure."""
    fig, axes = plt.subplots(1, 3, figsize=(12, 3), dpi=100)
    traj_ax, dHdp_ax, dHdq_ax = axes[0], axes[1], axes[2]

    traj_ax.plot(mse_traj, label="Traj. MSE")
    traj_ax.set_ylabel("MSE")
    traj_ax.legend()
    dHdp_ax.plot(mse_dHdp_traj, label="dHdp MSE")
    dHdq_ax.plot(mse_dHdq_traj, label="dHdq MSE")

    # shared styling for all panels
    for ax in axes:
        ax.set_yscale("symlog")
        ax.grid()
        ax.legend()
        ax.set_xlabel("Time step")
    fig.tight_layout()
    return fig

def get_canvas():
    """Creates the figure layout used for animations: a large equal-aspect axis in the
    top row and a row of four small axes below it (the third grid row stays empty).

    Returns fig, ax_anim, (ax_ke, ax_pe, ax_rel2, ax_h)
    """
    fig = plt.figure(figsize=(15, 20))
    layout = GridSpec(nrows=3, ncols=1, height_ratios=[3, 1, 1], hspace=0.2)

    # top row: animation axis
    ax_anim = fig.add_subplot(layout[0])
    ax_anim.set_aspect("equal")

    # second row: four property axes side by side
    property_row = GridSpecFromSubplotSpec(nrows=1, ncols=4, subplot_spec=layout[1], wspace=0.3)
    ax_ke, ax_pe, ax_rel2, ax_h = [fig.add_subplot(property_row[col]) for col in range(4)]

    return fig, ax_anim, (ax_ke, ax_pe, ax_rel2, ax_h)

def animate_2D(q_true, ke_true, pe_true, h_true, box_start=None, box_end=None,
               q_pred=None, ke_pred=None, pe_pred=None, h_pred=None,
               rel2=None, framing_length=250, filename="lj_simulation.mp4"):
    """Renders a 2D particle animation together with energy curves and saves it via ffmpeg.
    q_true positions per animation frame; q_true[frame] is used as the (x, y) scatter offsets
           and q_true[..., 0] / q_true[..., 1] determine the default axis limits
    ke_true, pe_true, h_true kinetic / potential / Hamiltonian values per time step
    box_start, box_end if given, lower / upper limit used for BOTH axes of the position panel,
           otherwise the data range plus a margin of 10% of max |q_true| is used
    q_pred, ke_pred, pe_pred, h_pred optional predicted counterparts of the arguments above
    rel2 relative L2 error of the positions per time step
    framing_length number of time steps between two consecutive animation frames,
           frame f draws the curves for the first f * framing_length + 1 time steps
    filename path of the saved animation

    If q_pred is given, ke_pred, pe_pred, h_pred and rel2 are all indexed during rendering,
    so they have to be passed together with it.
    The figure is not closed here.
    """
    has_pred = q_pred is not None
    if has_pred:
        # hollow red markers so that the prediction underneath stays visible
        true_args = { "facecolors": "none", "edgecolors": "red" }
    else:
        true_args = { "edgecolors": "tab:blue", "c": "tab:blue" }
    fig, ax_anim, (ax_ke, ax_pe, ax_rel2, ax_h) = get_canvas()

    # limits of the position panel: either the given box or the data range with a margin
    margin = 0.1 * np.max(np.abs(q_true))
    anim_xlim_min = np.min(q_true[..., 0] - margin).item() if box_start is None else box_start
    anim_xlim_max = np.max(q_true[..., 0] + margin).item() if box_end is None else box_end
    anim_ylim_min = np.min(q_true[..., 1] - margin).item() if box_start is None else box_start
    anim_ylim_max = np.max(q_true[..., 1] + margin).item() if box_end is None else box_end
    ax_anim.set_xlim(anim_xlim_min, anim_xlim_max)
    ax_anim.set_ylim(anim_ylim_min, anim_ylim_max)
    ax_anim.set_title(r"Positions $q$")

    series_true = [ke_true, pe_true, h_true]
    series_pred = [ke_pred, pe_pred, h_pred]
    energy_axes = [ax_ke, ax_pe, ax_h]
    lines_true = []
    for x, x_pred, ax, title in zip(series_true, series_pred, energy_axes, ["Kinetic", "Potential", "Hamiltonian"]):
        ax.set_xlim(0, len(x))
        if x_pred is None:
            ax.set_ylim(np.min(x), np.max(x))
        else:
            # an infinite prediction extremum would make the limits unusable, fall back to the true range
            min_lim = np.min(x_pred)
            min_lim = np.min(x) if np.isinf(min_lim) else min(min_lim, np.min(x))
            max_lim = np.max(x_pred)
            max_lim = np.max(x) if np.isinf(max_lim) else max(max_lim, np.max(x))
            ax.set_ylim(min_lim, max_lim)
        ax.set_title(title)
        ax.set_xlabel("Time step")
        lines, = ax.plot([], [], linestyle="dashed", color="red", label=r"using true $H$", zorder=5, linewidth=2.0)
        lines_true.append(lines)
        ax.legend()

    [ke_line_true, pe_line_true, h_line_true] = lines_true
    sc_true = ax_anim.scatter([], [], s=80, **true_args, label=r"using true $H$", zorder=5)

    # (line artist, data series) pairs that are extended on every frame
    curves = list(zip(lines_true, series_true))

    if has_pred:
        sc_pred = ax_anim.scatter([], [], s=80, edgecolors="tab:blue", label=r"prediction", zorder=1)
        ax_rel2.set_xlim(0, len(rel2))
        ax_rel2.set_ylim(np.min(rel2), np.max(rel2))
        ax_rel2.set_title(r"Relative $L^2$ positions")
        ax_rel2.set_xlabel("Time step")
        line_rel2, = ax_rel2.plot([], [], linewidth=2)
        lines_pred = []
        for ax in energy_axes:
            lines, = ax.plot([], [], label="prediction", zorder=1, linewidth=2.0)
            lines_pred.append(lines)
            ax.legend()  # refresh the legend so that it includes the prediction entry
        [ke_line_pred, pe_line_pred, h_line_pred] = lines_pred
        curves += list(zip(lines_pred + [line_rel2], series_pred + [rel2]))
        artists = (ke_line_true, ke_line_pred,
                   pe_line_true, pe_line_pred,
                   h_line_true, h_line_pred,
                   line_rel2, sc_true, sc_pred)
    else:
        artists = (ke_line_true, pe_line_true, h_line_true, sc_true)

    fig.tight_layout()

    def init():
        # start from empty scatter plots, the artists are handed back for blitting
        sc_true.set_offsets(np.empty((0, 2)))
        if has_pred:
            sc_pred.set_offsets(np.empty((0, 2)))
        return artists

    def update(frame):
        sc_true.set_offsets(q_true[frame])
        if has_pred:
            sc_pred.set_offsets(q_pred[frame])
        # animation frame f corresponds to time step f * framing_length of the per-step series
        n_steps = (frame * framing_length) + 1
        steps = np.arange(1, n_steps + 1)
        for line, x in curves:
            line.set_data(steps, x[:n_steps].flatten())
        return artists

    # the context manager closes the progress bar even if rendering fails
    with tqdm(total=len(q_true), desc="Rendering") as pbar:
        def progress_callback(i, _):
            # i is the zero-based index of the frame that was just drawn
            pbar.update(i + 1 - pbar.n)

        # interval : Delays between frames in ms (the saved file uses the fps given to save)
        ani = anim.FuncAnimation(
            fig, update, frames=len(q_true),
            init_func=init, blit=True, interval=20,
        )
        print("Animating..")
        ani.save(filename, writer="ffmpeg", fps=30, progress_callback=progress_callback)
    print(f"-> Saved animation as '{filename}'")

#################### LEARNING ####################

def generate_analytical_dataset(n_graphs, n_obj, dof,
                                mass, eps, sig, cutoff,
                                r_clip, q_noise_scale, p_noise_scale, seed, box_size=None, dtype=np.float32):
    """Generates analytical dataset for Hamiltonian learning for dynamical systems focusing on
    molecular dynamics (MD). In MD, r_clip is essential and provides a rejection sampling approach
    in order to not have very large potentials in the dataset (which is unrealistic for most
    of the scenarios we are considering).
    n_graphs (int) specifies number of data samples to generate (number of systems)
    n_obj (int) specifies the number of particles in the systems
    dof (int) specifies the spatial dimension in the systems
    mass (float) specifies the mass of the objects
    eps (float) specifies the epsilon parameter in LJ
    sig (float) specifies the sigma parameter in LJ
    cutoff (float) specifies the cutoff parameter in LJ, used for the energies and gradients
                   (the equilibrium state itself is computed with cutoff=np.inf)
    r_clip (float) specifies the minimum distance to have in the dataset to clip large potentials
    q_noise_scale (float) positions are sampled as equilibrium + uniform noise in [-q_noise_scale, q_noise_scale)
    p_noise_scale (float) momenta are sampled as equilibrium + uniform noise in [-p_noise_scale, p_noise_scale)
    seed (int) random state
    box_size (float) box size for PBC if given, affects the distances and the potential
                     (note that `H` raises NotImplementedError when box_size is given)
    dtype numpy dtype of the sampled states

    Returns (x, ke, pe, dHdx, dxdt) where x is of shape (n_graphs, n_obj, 2*dof),
    ke and pe are the energies from `H(..., as_separate=True)`, dHdx is the output of `grad_H`
    and dxdt is the output of `hamiltons_eq(dHdx)`.

    NOTE: positions are resampled until all pairwise distances are >= r_clip, there is no
    attempt limit, so an r_clip that cannot be satisfied makes this loop forever.
    """
    x_eq = get_equil_x(n_obj, dof, mass, eps, sig, cutoff=np.inf, dtype=dtype)
    q_eq, p_eq = np.split(x_eq, 2, axis=-1) # each of shape (n_obj, dof)
    x = np.empty((n_graphs, n_obj, 2*dof), dtype=dtype)
    rng = np.random.default_rng(seed)
    for idx_x in range(n_graphs):
        while True:
            q_noise = rng.uniform(-q_noise_scale, q_noise_scale, size=q_eq.shape).astype(dtype)
            q = q_eq + q_noise
            # TODO: Also soft reject to enforce smoother transition in a small region to make learning easier
            # stops at the first pair that is too close instead of checking the remaining pairs
            too_close = any(
                r(q[i], q[j], box_size) < r_clip
                for i in range(n_obj)
                for j in range(i + 1, n_obj)
            )
            if not too_close:
                break   # accept this q, otherwise resample

        p_noise = rng.uniform(-p_noise_scale, p_noise_scale, size=p_eq.shape).astype(dtype)
        p = p_eq + p_noise

        x[idx_x] = np.column_stack([ q, p ])
        if idx_x % 1000 == 0:
            print(f"..generating dataset, currently at n_graphs={idx_x}")

    # energies of the dataset to inspect
    ke, pe = H(x, mass, eps, sig, cutoff, as_separate=True, box_size=box_size)

    # gradients of the dataset
    dHdx = grad_H(x, mass, eps, sig, cutoff, box_size)
    dxdt = hamiltons_eq(dHdx)

    return x, ke, pe, dHdx, dxdt

#################### OTHER ##################

def eval_and_print_err(model, label, x, L, dxdt, box_size=None, cutoff=tc.inf):
    """Evaluates the model with `eval_dxdt` and prints the resulting MSE and relative L2 error
    under the given label. box_size and cutoff are accepted but not used here.
    """
    errors = eval_dxdt(model, x, L, dxdt, verbose=False)
    mse, rel2 = errors

    # header line first, then one line per metric
    print(f"\n{label}")
    print(f"-> mse : {mse:.2e}")
    print(f"-> rel2: {rel2:.2e}")

def eval_and_print_err_as_separate(model, label, x, L, dxdt):
    """Evaluates the model with `eval_dxdt(..., as_separate=True)` and prints the MSE and
    relative L2 error of dq/dt and dp/dt side by side under the given label.
    """
    errors = eval_dxdt(model, x, L, dxdt, as_separate=True, verbose=False)
    # returned in the order: dq/dt errors first, then dp/dt errors
    mse_dqdt, rel2_dqdt, mse_dpdt, rel2_dpdt = errors

    print(f"\n{label}")
    print(f"-> mse  dq/dt : {mse_dqdt:.2e}    mse  dp/dt : {mse_dpdt:.2e}")
    print(f"-> rel2 dq/dt : {rel2_dqdt:.2e}    rel2 dp/dt : {rel2_dpdt:.2e}")
