import numpy as np
from .hamiltons_eq import hamiltons_eq

def runge_kutta_step(x_prev, delta_t, grad_H_, gravity,
                     force_node_idx, force_vec, force_snap_idx, idx_snap,
                     fixed_indices, n_obj, dof):
    """Advance the state ``x_prev`` by one classical RK4 step of size ``delta_t``.

    The time derivative of the state comes from ``hamiltons_eq`` applied to
    ``grad_H_``. Optional external terms are then added to the second half of
    the last axis (the ``dpdt`` part):

    * a uniform ``gravity`` vector, broadcast to every node, and
    * a point force ``force_vec`` on ``force_node_idx``, applied only when
      ``force_snap_idx == idx_snap``.

    The two external terms are independent: the point force is applied
    whether or not ``gravity`` is given.

    Args:
        x_prev: Current state array. The last axis is split as
            ``[..., :dof]`` (first half) and ``[..., dof:]`` (second half,
            the ``dpdt`` part). Presumably shaped ``(*n_obj, 2 * dof)`` --
            TODO confirm against callers.
        delta_t: Step size.
        grad_H_: Callable returning dH/dx. It is called with a leading
            singleton batch axis, which is squeezed off the result.
        gravity: Acceleration term added to ``dpdt`` for every node, or
            ``None`` to disable. Must be broadcastable to ``(*n_obj, dof)``.
        force_node_idx: Index (into the leading node axes) receiving
            ``force_vec``, or ``None`` for no point force.
        force_vec: Force added to ``dpdt[force_node_idx]``.
        force_snap_idx: Snapshot index at which the point force is active.
        idx_snap: Current snapshot index.
        fixed_indices: Index into ``x_next`` whose entries are set to 0.0
            after the step, or ``None``.
        n_obj: Iterable giving the node-grid shape.
        dof: Number of entries in each half of the state's last axis.

    Returns:
        The state after one RK4 step.
    """
    # Loop-invariant external terms: computed once per step, not once per
    # RK stage. broadcast_to (rather than np.tile(gravity, *n_obj), which
    # breaks when n_obj has more than one entry) also supports N-D node grids.
    gravity_nodewise = None
    if gravity is not None:
        gravity_nodewise = np.broadcast_to(np.asarray(gravity), (*n_obj, dof))
    apply_force = force_node_idx is not None and force_snap_idx == idx_snap

    def dxdt(x):
        """Return the state derivative at ``x``, including external terms."""
        dedx = grad_H_(x[np.newaxis, ...]).squeeze(0)
        dx = hamiltons_eq(dedx)
        if gravity_nodewise is None and not apply_force:
            return dx

        # Work on a fresh array so the output of hamiltons_eq is never
        # mutated through a view.
        if gravity_nodewise is not None:
            dpdt = dx[..., dof:] + gravity_nodewise
        else:
            dpdt = np.array(dx[..., dof:], copy=True)
        if apply_force:
            dpdt[force_node_idx] += force_vec
        return np.concatenate([dx[..., :dof], dpdt], axis=-1)

    k1 = dxdt(x_prev)
    k2 = dxdt(x_prev + 0.5 * delta_t * k1)
    k3 = dxdt(x_prev + 0.5 * delta_t * k2)
    k4 = dxdt(x_prev + delta_t * k3)

    x_next = x_prev + (delta_t / 6.0) * (k1 + 2 * k2 + 2 * k3 + k4)

    if fixed_indices is not None:
        # NOTE(review): this zeroes the entire indexed entry, i.e. both
        # halves of the last axis, not just the dpdt half -- confirm this is
        # intended for fixed nodes rather than holding their positions.
        x_next[fixed_indices] = 0.0

    return x_next

