import numpy as np
from .hamiltons_eq import hamiltons_eq

def _state_rates(grad_H_, q, p):
    """Evaluate Hamilton's equations at state ``(q, p)`` and return ``(dq/dt, dp/dt)``.

    ``grad_H_`` is called on a batch containing a single state: a leading
    axis is added here and squeezed off its result. The output of
    ``hamiltons_eq`` is split along the last axis into its q and p halves.
    """
    state = np.concatenate([q, p], axis=-1)[np.newaxis, ...]
    dedx = grad_H_(state).squeeze(0)
    dqdt, dpdt = np.split(hamiltons_eq(dedx), 2, axis=-1)
    return dqdt, dpdt


def _add_external_forces(dpdt, gravity_nodewise, force_node_idx, force_vec, apply_force):
    """Return a copy of ``dpdt`` with gravity and the optional point force added.

    Works on a copy so the array produced by ``hamiltons_eq`` is never
    mutated in place (and so read-only outputs are handled too).
    Gravity is added first, then the point force.
    """
    dpdt = np.array(dpdt)
    if gravity_nodewise is not None:
        dpdt += gravity_nodewise
    if apply_force:
        # NOTE: with fancy indexing, duplicate entries in force_node_idx
        # receive force_vec only once (NumPy buffered `+=` semantics).
        dpdt[force_node_idx] += force_vec
    return dpdt


def stormer_verlet_step(x_prev, delta_t, grad_H_, gravity,
                        force_node_idx, force_vec, force_snap_idx, idx_snap,
                        fixed_indices, n_obj, dof):
    """Advance the state ``x = [q, p]`` by one kick-drift-kick (Stormer-Verlet) step.

    The scheme is:

    1. Half-kick ``p_half = p + dt/2 * dp/dt(q, p)``.
    2. Drift ``q_next = q + dt * dq/dt(q, p_half)``.
    3. Half-kick ``p_next = p_half + dt/2 * dp/dt(q_next, p_half)``.

    This explicit form coincides with Stormer-Verlet when H is separable.

    Parameters
    ----------
    x_prev : array
        Current state. The last axis is ``q`` and ``p`` concatenated.
        Presumably of shape ``(*n_obj, 2 * dof)``; confirm with callers.
    delta_t : float
        Time step.
    grad_H_ : callable
        Maps a batch of states (leading axis of size 1) to dH/dx of the
        same shape.
    gravity : array-like or None
        Added to dp/dt of every node in both half-kicks. It must broadcast
        to ``(*n_obj, dof)``, e.g. a length-``dof`` vector.
    force_node_idx, force_vec : index or None, array-like
        Point force added to ``dp/dt[force_node_idx]`` in both half-kicks.
        Applied only when ``force_snap_idx == idx_snap``.
    force_snap_idx, idx_snap : int
        Snapshot at which the point force is active, and the current
        snapshot index.
    fixed_indices : index or None
        Nodes pinned to ``q = 0`` and ``p = 0``. Their ``q`` must already
        be 0 on entry.
    n_obj : sequence of int
        Node-grid dimensions.
    dof : int
        Degrees of freedom per node.

    Returns
    -------
    array
        ``[q_next, p_next]`` concatenated along the last axis.

    Raises
    ------
    AssertionError
        If a fixed node has non-zero ``q`` on entry (assertions enabled).
    """
    q_prev, p_prev = np.split(x_prev, 2, axis=-1)

    gravity_nodewise = None
    if gravity is not None:
        # broadcast_to (unlike np.tile(gravity, *n_obj)) works for any len(n_obj).
        gravity_nodewise = np.broadcast_to(np.asarray(gravity), (*n_obj, dof))
    apply_force = force_node_idx is not None and force_snap_idx == idx_snap
    # `is not None` rather than truthiness: index 0 is a valid fixed node,
    # and NumPy index arrays have no unambiguous truth value.
    has_fixed = fixed_indices is not None

    if has_fixed:
        assert np.all(q_prev[fixed_indices] == 0.0), "Wall-nodes moved from initial position."

    # First half-kick.
    _, dpdt_prev = _state_rates(grad_H_, q_prev, p_prev)
    dpdt_prev = _add_external_forces(dpdt_prev, gravity_nodewise,
                                     force_node_idx, force_vec, apply_force)
    p_half = p_prev + 0.5 * delta_t * dpdt_prev

    # Drift.
    dqdt_half, _ = _state_rates(grad_H_, q_prev, p_half)
    q_next = q_prev + delta_t * dqdt_half
    if has_fixed:
        # Pin fixed nodes *before* the second force evaluation, so that
        # their neighbours do not feel a drifted wall.
        q_next[fixed_indices] = 0.0

    # Second half-kick.
    _, dpdt_next = _state_rates(grad_H_, q_next, p_half)
    dpdt_next = _add_external_forces(dpdt_next, gravity_nodewise,
                                     force_node_idx, force_vec, apply_force)
    p_next = p_half + 0.5 * delta_t * dpdt_next
    if has_fixed:
        p_next[fixed_indices] = 0.0

    return np.concatenate([q_next, p_next], axis=-1)

