import numpy as np
from .hamiltons_eq import hamiltons_eq

def symp_euler_step(x_prev, delta_t, grad_H_, gravity,
                    force_node_idx, force_vec, force_snap_idx, idx_snap,
                    fixed_indices, n_obj, dof):
    """Advance the state by one symplectic-Euler step (momentum first).

    The momentum is updated from the time derivative evaluated at
    ``(q_prev, p_prev)``. The position is then updated from the time
    derivative evaluated at ``(q_prev, p_next)``.

    Args:
        x_prev: Current state. Its last axis holds ``q`` and ``p``
            concatenated; it is split into two equal halves.
        delta_t: Time-step size.
        grad_H_: Callable evaluated on the state with an extra leading axis
            of size 1, which is squeezed off its result. Its output is
            passed to ``hamiltons_eq`` (presumably it returns dH/dx --
            confirm against the caller).
        gravity: Gravity vector added to dp/dt at every node, or ``None``
            to disable gravity.
        force_node_idx: Index (or indices) into dp/dt of the node(s) that
            receive the external force, or ``None`` for no external force.
        force_vec: External force added to dp/dt at ``force_node_idx``.
        force_snap_idx: Snapshot index at which the external force acts.
        idx_snap: Current snapshot index.
        fixed_indices: Indices of wall nodes whose q and p are held at
            zero, or ``None`` if there are no walls.
        n_obj: Iterable of ints giving the leading (node-grid) shape of q/p.
        dof: Number of degrees of freedom per node.

    Returns:
        The next state: ``q_next`` and ``p_next`` concatenated along the
        last axis.

    Raises:
        AssertionError: If any wall-node entry of ``x_prev`` is non-zero.
    """
    q_prev, p_prev = np.split(x_prev, 2, axis=-1)

    # Use an explicit `is not None` test. Truthiness of a NumPy index array
    # raises for more than one element, and a lone index 0 is falsy, which
    # would silently skip the check.
    has_walls = fixed_indices is not None
    if has_walls and not np.all(x_prev[fixed_indices] == 0.0):
        # Raise explicitly (not `assert`) so the check survives `python -O`.
        raise AssertionError("Wall-nodes moved from initial position.")

    # dp/dt at the current state: the second half of hamiltons_eq's output.
    dedx_prev = grad_H_(x_prev[np.newaxis, ...]).squeeze(0)
    _, dpdt_prev = np.split(hamiltons_eq(dedx_prev), 2, axis=-1)

    # Apply gravity and the (snapshot-gated) external force.
    if gravity is not None:
        # Reps `(*n_obj, 1)` handles a multi-dimensional n_obj.
        # `np.tile(gravity, *n_obj)` passes extra positional arguments and
        # raises TypeError whenever len(n_obj) > 1.
        gravity_nodewise = np.tile(gravity, (*n_obj, 1)).reshape(*n_obj, dof)
        dpdt_prev += gravity_nodewise
    if force_node_idx is not None and force_snap_idx == idx_snap:
        dpdt_prev[force_node_idx] += force_vec

    # Momentum update.
    p_next = p_prev + delta_t * dpdt_prev
    if has_walls:
        # Zero wall momenta BEFORE the second gradient evaluation, so that
        # dq/dt below sees the constrained momentum.
        p_next[fixed_indices] = 0.0

    # dq/dt at (q_prev, p_next): the first half of hamiltons_eq's output.
    x_mid = np.concatenate([q_prev, p_next], axis=-1)
    dedx_next = grad_H_(x_mid[np.newaxis, ...]).squeeze(0)
    dqdt_next, _ = np.split(hamiltons_eq(dedx_next), 2, axis=-1)

    # Position update.
    q_next = q_prev + delta_t * dqdt_next
    if has_walls:
        # p_next was already zeroed at the walls above; only q needs it here.
        q_next[fixed_indices] = 0.0

    return np.concatenate([q_next, p_next], axis=-1)
