import sys
# Set before the remaining imports so that importing them (including the
# local `config` module) does not write .pyc files.
sys.dont_write_bytecode = True

import numpy as np 
import tensorflow as tf 
import config


def generate_data(num_agents):
    """ Generate the configuration of the reference trajectories.

    Agents are grouped in pairs; pair i is translated by i * 12.5 along x
    and gets two horizontal obstacle lines (y = -7.6 and y = -7.9, x in
    [-6, 6] + i * 12.5), each sampled with 32 points.

    Args:
        num_agents (int): The number of cars. Must be a positive multiple
            of 4: omega is tiled in groups of 4 and scale / phase in
            groups of 2, so other values cannot be broadcast together.
    Returns:
        scale (N, 2): The scale of the sine trajectories.
        omega (N, 2): The angular velocities.
        phase (N, 2): The initial phases.
        trans (N, 2): The translation of the trajectories.
        init (N, 5): The initial states of the cars: the reference position
            and velocity at config.TIME_OFFSET followed by a zero rotation.
            np.zeros defaults to float64, so init is float64.
        obstacles (M, 2): The outline of the obstacles as points,
            M = 64 * (num_agents // 2).
    Raises:
        ValueError: If num_agents is not a positive multiple of 4.
    """
    if num_agents <= 0 or num_agents % 4 != 0:
        raise ValueError(
            'num_agents must be a positive multiple of 4, got {}'.format(
                num_agents))

    scale = np.tile(
        np.array([[6, 2], [6, 7]], dtype=np.float32),
        (num_agents // 2, 1))
    omega = np.tile(
        np.array([[0.1, 0.2], [0.2, 0.1], [-0.1, 0.2], [-0.2, 0.1]],
                 dtype=np.float32),
        (num_agents // 4, 1))
    trans = np.array(
        [[(i // 2) * 12.5, 0] for i in range(num_agents)],
        dtype=np.float32)
    phase = np.tile(
        np.array([[0, 0], [0, 0]], dtype=np.float32),
        (num_agents // 2, 1))

    z, dzdt, _ = reference_trajectory_np(
        config.TIME_OFFSET, scale, omega, phase, trans)
    init = np.concatenate([z, dzdt, np.zeros([num_agents, 1])], axis=1)

    # The x samples of one obstacle line are the same for every pair up to
    # the translation, so build them once.
    base_x = np.linspace(-1, 1, 32) * 6
    obstacles = []
    for i in range(num_agents // 2):
        line_x = np.expand_dims(base_x + i * 12.5, axis=1)
        for j in (7.6, 7.9):
            line_y = -np.ones_like(line_x) * j
            obstacles.append(np.concatenate([line_x, line_y], axis=1))
    obstacles = np.concatenate(obstacles, axis=0)

    return scale, omega, phase, trans, init, obstacles
        

def car_dynamics_np(s, u):
    """ Time derivative of the car state (NumPy version).

    Args:
        s (N, 5): Per-car state: position (2), velocity (2), rotation (1).
        u (N, 2): The control input, one value per position axis.
    Returns:
        dsdt (N, 5): The time derivative of every state column.
    """
    _, _, vx, vy, heading = np.split(s, 5, axis=1)
    ux, uy = np.split(u, 2, axis=1)
    damping = -config.K
    # Rotation rate depends on the velocity component normal to the heading.
    heading_rate = (vy * np.cos(heading) - vx * np.sin(heading)) / config.L
    return np.concatenate(
        [vx, vy, damping * vx + ux, damping * vy + uy, heading_rate], axis=1)


def reference_trajectory_np(t, scale, omega, phase, trans):
    """ Evaluate the sinusoidal reference trajectories at time t.

    Each coordinate follows trans + scale * sin(omega * t + phase).

    Args:
        t (float): The current time.
        scale (N, 2): The amplitudes of the sines.
        omega (N, 2): The angular velocities.
        phase (N, 2): The initial phases.
        trans (N, 2): The offsets of the trajectories.
    Returns:
        z (N, 2): The reference location.
        dzdt (N, 2): Its first time derivative.
        ddzdt (N, 2): Its second time derivative.
    """
    angle = omega * t + phase
    sin_angle = np.sin(angle)
    position = trans + scale * sin_angle
    velocity = scale * omega * np.cos(angle)
    acceleration = -scale * omega ** 2 * sin_angle
    return position, velocity, acceleration


def car_controller_np(s, z, dzdt, ddzdt):
    """ Nominal tracking controller (NumPy version).

    The position error is fed back with gain config.KAPPA while its norm is
    below config.G / config.KAPPA, and with (approximately) constant
    magnitude config.G beyond that. The reference velocity and acceleration
    are fed forward, cancelling the -config.K damping of the car dynamics.

    Args:
        s (N, 5): The location, velocity and rotation of the car.
        z (N, 2): The reference location.
        dzdt (N, 2): The first time derivative of the reference.
        ddzdt (N, 2): The second time derivative of the reference.
    Returns:
        u (N, 2): The control input.
    """
    err = s[:, :2] - z
    # The small epsilon keeps the division below finite at zero error.
    err_norm = np.linalg.norm(err, axis=1, keepdims=True) + 1e-6
    within = err_norm < config.G / config.KAPPA
    linear_term = config.KAPPA * err * within
    saturated_term = config.G * err * (1 - within) / err_norm
    feedback = linear_term + saturated_term
    return ddzdt + config.K * dzdt - feedback


def car_dynamics_tf(s, u):
    """ Time derivative of the car state (TensorFlow version).

    Args:
        s (N, 5): Per-car state: position (2), velocity (2), rotation (1).
        u (N, 2): The control input, one value per position axis.
    Returns:
        dsdt (N, 5): The time derivative of every state column.
    """
    _, _, vx, vy, heading = tf.split(s, 5, axis=1)
    ux, uy = tf.split(u, 2, axis=1)
    damping = -config.K
    # Rotation rate depends on the velocity component normal to the heading.
    heading_rate = (
        vy * tf.math.cos(heading) - vx * tf.math.sin(heading)) / config.L
    return tf.concat(
        [vx, vy, damping * vx + ux, damping * vy + uy, heading_rate], axis=1)


def reference_trajectory_tf(t, scale, omega, phase, trans):
    """ Evaluate the sinusoidal reference trajectories at time t (TF).

    Each coordinate follows trans + scale * sin(omega * t + phase).

    Args:
        t (float): The current time.
        scale (N, 2): The amplitudes of the sines.
        omega (N, 2): The angular velocities.
        phase (N, 2): The initial phases.
        trans (N, 2): The offsets of the trajectories.
    Returns:
        z (N, 2): The reference location.
        dzdt (N, 2): Its first time derivative.
        ddzdt (N, 2): Its second time derivative.
    """
    angle = omega * t + phase
    sin_angle = tf.math.sin(angle)
    position = trans + scale * sin_angle
    velocity = scale * omega * tf.math.cos(angle)
    acceleration = -scale * omega ** 2 * sin_angle
    return position, velocity, acceleration


def car_controller_tf(s, z, dzdt, ddzdt):
    """ Nominal tracking controller (TensorFlow version).

    The position error is fed back with gain config.KAPPA while its norm is
    below config.G / config.KAPPA, and with (approximately) constant
    magnitude config.G beyond that. The reference velocity and acceleration
    are fed forward, cancelling the -config.K damping of the car dynamics.

    Args:
        s (N, 5): The location, velocity and rotation of the car.
        z (N, 2): The reference location.
        dzdt (N, 2): The first time derivative of the reference.
        ddzdt (N, 2): The second time derivative of the reference.
    Returns:
        u (N, 2): The control input.
    """
    err = s[:, :2] - z
    # The small epsilon keeps the division below finite at zero error.
    err_norm = tf.norm(err, axis=1, keepdims=True) + 1e-6
    within = tf.cast(tf.less(err_norm, config.G / config.KAPPA), tf.float32)
    linear_term = config.KAPPA * err * within
    saturated_term = config.G * err * (1 - within) / err_norm
    feedback = linear_term + saturated_term
    return ddzdt + config.K * dzdt - feedback


def network_cbf(x, d, r, indices=None):
    """ Control barrier function as a neural network.

    Every agent sees its K nearest agents and its num nearest obstacle
    points. A shared point-wise MLP (1x1 convolutions) maps each of these
    K + num neighbours to one CBF value.

    Args:
        x (N, N, 5): The pairwise state differences of N agents.
        d (N, num, 5): The state differences to the nearest obstacle points.
            The first four channels are kept as-is; the fifth is encoded
            with cos / sin.
        r (float): The radius of the dangerous zone.
        indices (N*K, 2), optional: gather_nd indices of the neighbours to
            keep (see remove_distant_agents); recomputed when None.
    Returns:
        h (N, K+num, 1): The CBF values, zeroed outside config.OBS_RADIUS.
        mask (N, K+num, 1): 1.0 for neighbours within config.OBS_RADIUS.
        indices: The neighbour indices from remove_distant_agents
            (False when N <= config.TOP_K).
    """
    # Agent features (8 channels): 4 raw difference channels, cos/sin of the
    # fifth, an identity flag (1 on the diagonal), and distance minus r.
    raw_agent, agent_heading = tf.split(x, [4, 1], axis=2)
    agent_feat = tf.concat(
        [raw_agent, tf.math.cos(agent_heading), tf.math.sin(agent_heading)],
        axis=2)
    agent_dist = tf.sqrt(
        tf.reduce_sum(tf.square(agent_feat[:, :, :2]) + 1e-4, axis=2))
    self_flag = tf.expand_dims(tf.eye(tf.shape(agent_feat)[0]), 2)
    agent_feat = tf.concat(
        [agent_feat, self_flag, tf.expand_dims(agent_dist - r, 2)], axis=2)
    agent_feat, indices = remove_distant_agents(
        x=agent_feat, k=config.TOP_K, indices=indices)
    kept_dist = tf.sqrt(tf.reduce_sum(
        tf.square(agent_feat[:, :, :2]) + 1e-4, axis=2, keepdims=True))
    agent_mask = tf.cast(
        tf.less_equal(kept_dist, config.OBS_RADIUS), tf.float32)

    # Obstacle features in the same 8-channel layout; identity flag is 0.
    raw_obs, obs_heading = tf.split(d, [4, 1], axis=2)
    obs_feat = tf.concat(
        [raw_obs, tf.math.cos(obs_heading), tf.math.sin(obs_heading)], axis=2)
    obs_dist = tf.sqrt(
        tf.reduce_sum(tf.square(obs_feat[:, :, :2]) + 1e-4, axis=2))
    obs_feat = tf.concat(
        [obs_feat, tf.zeros_like(obs_heading),
         tf.expand_dims(obs_dist - r, 2)], axis=2)
    obs_mask = tf.cast(
        tf.less_equal(tf.expand_dims(obs_dist, 2), config.OBS_RADIUS),
        tf.float32)

    features = tf.concat([agent_feat, obs_feat], axis=1)
    mask_combine = tf.concat([agent_mask, obs_mask], axis=1)

    # The scope names determine the variable names (and AUTO_REUSE sharing
    # between calls), so they are spelled out literally.
    layers = (
        ('cbf/conv_1', 64, tf.nn.relu),
        ('cbf/conv_2', 128, tf.nn.relu),
        ('cbf/conv_3', 128, tf.nn.relu),
        ('cbf/conv_4', 1, None),
    )
    out = features
    for scope, width, activation in layers:
        out = tf.contrib.layers.conv1d(inputs=out,
                                       num_outputs=width,
                                       kernel_size=1,
                                       reuse=tf.AUTO_REUSE,
                                       scope=scope,
                                       activation_fn=activation)
    h = out * mask_combine
    return h, mask_combine, indices


def network_action(s, z_ref, d, obs_radius=2.0, indices=None):
    """Controller as a neural network.

    A point-wise MLP encodes every observed neighbour (the K nearest agents
    and the num nearest obstacle points). The encodings are max-pooled, and
    a fully connected head predicts a residual that is added to the nominal
    control of car_controller_tf.

    Args:
        s (N, 5): The current state of N agents.
        z_ref (N, 6): The reference location, velocity and acceleration.
        d (N, num, 5): The state differences to the nearest obstacle points
            (presumably from detect_nearest_obstacles; confirm at caller).
        obs_radius (float): The observation radius for neighbouring agents.
            NOTE(review): obstacle points are masked with config.OBS_RADIUS
            and `<=` instead of obs_radius and `<`; confirm this is intended.
        indices (N*K, 2), optional: gather_nd indices of the neighbours to
            keep (see remove_distant_agents); recomputed when None.
    Returns:
        u (N, 2): The control action.
    """
    # Pairwise agent features (7 channels): 4 raw difference channels,
    # cos/sin of the heading difference, and an identity flag.
    x = tf.expand_dims(s, 1) - tf.expand_dims(s, 0)
    x, phi_diff = tf.split(x, [4, 1], axis=2)
    x = tf.concat([x, tf.math.cos(phi_diff), tf.math.sin(phi_diff)], axis=2)
    x = tf.concat([x, tf.expand_dims(tf.eye(tf.shape(x)[0]), 2)], axis=2)
    x, _ = remove_distant_agents(x=x, k=config.TOP_K, indices=indices)
    dist = tf.norm(x[:, :, :2], axis=2, keepdims=True)
    mask = tf.cast(tf.less(dist, obs_radius), tf.float32)

    # Obstacle features in the same 7-channel layout (identity flag = 0).
    d, dphi = tf.split(d, [4, 1], axis=2)
    d = tf.concat([d, tf.math.cos(dphi), tf.math.sin(dphi),
                   tf.zeros_like(dphi)], axis=2)
    dist = tf.sqrt(
        tf.reduce_sum(tf.square(d[:, :, :2]) + 1e-4, axis=2, keepdims=True))
    mask_obs = tf.cast(tf.less_equal(dist, config.OBS_RADIUS), tf.float32)

    x = tf.concat([x, d], axis=1)
    mask_combine = tf.concat([mask, mask_obs], axis=1)

    # Point-wise encoder; the scope names define the variable names.
    for scope, width in (('action/conv_1', 64), ('action/conv_2', 128)):
        x = tf.contrib.layers.conv1d(inputs=x,
                                     num_outputs=width,
                                     kernel_size=1,
                                     reuse=tf.AUTO_REUSE,
                                     scope=scope,
                                     activation_fn=tf.nn.relu)
    # ReLU outputs are >= 0, so zeroing unobserved neighbours never raises
    # the max-pooled value above that of the observed ones.
    x = tf.reduce_max(x * mask_combine, axis=1)
    # Append the agent's tracking error, velocity, cos/sin of its rotation
    # and the reference velocity/acceleration.
    x = tf.concat([x, s[:, :2] - z_ref[:, :2], s[:, 2:4],
                   tf.math.cos(s[:, 4:]), tf.math.sin(s[:, 4:]),
                   z_ref[:, 2:]], axis=1)
    fc_layers = (
        ('action/fc_1', 64, tf.nn.relu),
        ('action/fc_2', 128, tf.nn.relu),
        ('action/fc_3', 64, tf.nn.relu),
        ('action/fc_4', 2, None),
    )
    for scope, width, activation in fc_layers:
        x = tf.contrib.layers.fully_connected(inputs=x,
                                              num_outputs=width,
                                              reuse=tf.AUTO_REUSE,
                                              scope=scope,
                                              activation_fn=activation)
    z, dzdt, ddzdt = tf.split(z_ref, [2, 2, 2], axis=1)
    u_ref = car_controller_tf(s, z, dzdt, ddzdt)
    return x + u_ref


def loss_barrier(h, s, d, indices=None, eps=(5e-2, 1e-3)):
    """ Build the loss function for the control barrier functions.

    Args:
        h (N, K+num, 1): The CBF values from network_cbf for the K kept
            neighbouring agents followed by the num obstacle points.
        s (N, 5): The current state of N agents.
        d (N, num, C): The relative states to the nearest obstacle points;
            only the first two channels (relative position) are used here.
        indices (N*K, 2), optional: gather_nd indices of the kept
            neighbours (see remove_distant_agents).
        eps (2, ): The margins. The dangerous loss is zero once
            h <= -eps[0]; the safe loss is zero once h >= eps[1].
    Returns:
        loss_dang (float): The barrier loss for dangerous states.
        loss_safe (float): The barrier loss for safe states.
        acc_dang (float): The fraction of dangerous states with h <= 0,
            or -1.0 when there are none.
        acc_safe (float): The fraction of safe states with h >= 0,
            or -1.0 when there are none.
    """
    h_reshape = tf.reshape(h, [-1])
    dang_mask = compute_dangerous_mask(
        s, d=d, r=config.DIST_MIN_THRES, indices=indices)
    dang_mask_reshape = tf.reshape(dang_mask, [-1])
    safe_mask = compute_safe_mask(s, d=d, r=config.DIST_SAFE, indices=indices)
    safe_mask_reshape = tf.reshape(safe_mask, [-1])

    dang_h = tf.boolean_mask(h_reshape, dang_mask_reshape)
    safe_h = tf.boolean_mask(h_reshape, safe_mask_reshape)

    num_dang = tf.cast(tf.shape(dang_h)[0], tf.float32)
    num_safe = tf.cast(tf.shape(safe_h)[0], tf.float32)

    # Hinge losses, averaged over the selected entries (the epsilon avoids
    # dividing by zero when no entry is selected).
    loss_dang = tf.reduce_sum(
        tf.math.maximum(dang_h + eps[0], 0)) / (1e-5 + num_dang)
    loss_safe = tf.reduce_sum(
        tf.math.maximum(-safe_h + eps[1], 0)) / (1e-5 + num_safe)

    acc_dang = tf.reduce_sum(tf.cast(
        tf.less_equal(dang_h, 0), tf.float32)) / (1e-12 + num_dang)
    acc_safe = tf.reduce_sum(tf.cast(
        tf.greater_equal(safe_h, 0), tf.float32)) / (1e-12 + num_safe)

    # Report -1 as a sentinel accuracy when a category is empty.
    acc_dang = tf.cond(
        tf.greater(num_dang, 0), lambda: acc_dang, lambda: -tf.constant(1.0))
    acc_safe = tf.cond(
        tf.greater(num_safe, 0), lambda: acc_safe, lambda: -tf.constant(1.0))

    return loss_dang, loss_safe, acc_dang, acc_safe


def loss_derivatives(s, u, h, x, d, indices=None, eps=(8e-2, 0, 3e-2)):
    """ Build the loss function for the derivatives of the CBF.

    The state is advanced one explicit Euler step under u, and the CBF is
    re-evaluated. The losses push
    deriv = h_next - h + TIME_STEP * ALPHA_CBF * h to be >= eps[i] for the
    dangerous, safe and medium (neither) neighbours respectively.

    Args:
        s (N, 5): The current state of N agents.
        u (N, 2): The control action.
        h (N, K+num, 1): The CBF values at the current state.
        x: Unused; kept for interface compatibility. The next-step
            differences are recomputed from s and u.
        d (N, num, 5): The state differences to the nearest obstacle points.
            NOTE(review): the same d is reused for the next state; it is not
            recomputed from s_next.
        indices (N*K, 2), optional: gather_nd indices of the kept
            neighbours. When None and N > config.TOP_K, network_cbf picks
            neighbours from the next state, which may differ from those of h.
        eps (3, ): The margins for dangerous, safe and medium states.
    Returns:
        loss_dang_deriv (float): The derivative loss of dangerous states.
        loss_safe_deriv (float): The derivative loss of safe states.
        loss_medium_deriv (float): The derivative loss of medium states.
        acc_dang_deriv (float): The fraction of dangerous states with
            deriv >= 0, or -1.0 when there are none.
        acc_safe_deriv (float): The same for safe states.
        acc_medium_deriv (float): The same for medium states.
    """
    dsdt = car_dynamics_tf(s, u)
    s_next = s + dsdt * config.TIME_STEP

    x_next = tf.expand_dims(s_next, 1) - tf.expand_dims(s_next, 0)
    h_next, _, _ = network_cbf(
        x=x_next, r=config.DIST_MIN_THRES, d=d, indices=indices)

    deriv = h_next - h + config.TIME_STEP * config.ALPHA_CBF * h

    deriv_reshape = tf.reshape(deriv, [-1])
    dang_mask = compute_dangerous_mask(
        s, d=d, r=config.DIST_MIN_THRES, indices=indices)
    dang_mask_reshape = tf.reshape(dang_mask, [-1])
    safe_mask = compute_safe_mask(s, d=d, r=config.DIST_SAFE, indices=indices)
    safe_mask_reshape = tf.reshape(safe_mask, [-1])
    # Medium: neither inside the dangerous radius nor outside the safe one.
    medium_mask_reshape = tf.logical_not(
        tf.logical_or(dang_mask_reshape, safe_mask_reshape))

    dang_deriv = tf.boolean_mask(deriv_reshape, dang_mask_reshape)
    safe_deriv = tf.boolean_mask(deriv_reshape, safe_mask_reshape)
    medium_deriv = tf.boolean_mask(deriv_reshape, medium_mask_reshape)

    num_dang = tf.cast(tf.shape(dang_deriv)[0], tf.float32)
    num_safe = tf.cast(tf.shape(safe_deriv)[0], tf.float32)
    num_medium = tf.cast(tf.shape(medium_deriv)[0], tf.float32)

    loss_dang_deriv = tf.reduce_sum(
        tf.math.maximum(-dang_deriv + eps[0], 0)) / (1e-5 + num_dang)
    loss_safe_deriv = tf.reduce_sum(
        tf.math.maximum(-safe_deriv + eps[1], 0)) / (1e-5 + num_safe)
    loss_medium_deriv = tf.reduce_sum(
        tf.math.maximum(-medium_deriv + eps[2], 0)) / (1e-5 + num_medium)

    acc_dang_deriv = tf.reduce_sum(tf.cast(
        tf.greater_equal(dang_deriv, 0), tf.float32)) / (1e-12 + num_dang)
    acc_safe_deriv = tf.reduce_sum(tf.cast(
        tf.greater_equal(safe_deriv, 0), tf.float32)) / (1e-12 + num_safe)
    acc_medium_deriv = tf.reduce_sum(tf.cast(
        tf.greater_equal(medium_deriv, 0), tf.float32)) / (1e-12 + num_medium)

    # Report -1 as a sentinel accuracy when a category is empty.
    acc_dang_deriv = tf.cond(
        tf.greater(num_dang, 0), lambda: acc_dang_deriv, lambda: -tf.constant(1.0))
    acc_safe_deriv = tf.cond(
        tf.greater(num_safe, 0), lambda: acc_safe_deriv, lambda: -tf.constant(1.0))
    acc_medium_deriv = tf.cond(
        tf.greater(num_medium, 0), lambda: acc_medium_deriv, lambda: -tf.constant(1.0))

    return (loss_dang_deriv, loss_safe_deriv, loss_medium_deriv,
            acc_dang_deriv, acc_safe_deriv, acc_medium_deriv)


def loss_actions(s, u, z_ref, d, indices):
    """ Penalise deviation from the nominal controller for safe agents.

    The per-component penalty is min(|e|, e^2) with e = u - u_ref. It is
    only counted for agents whose every neighbour (agent or obstacle point)
    lies outside config.DIST_SAFE.

    Args:
        s (N, 5): The current state of N agents.
        u (N, 2): The control action.
        z_ref (N, 6): The reference location, velocity and acceleration.
        d (N, num, C): The relative states to the nearest obstacle points.
        indices (N*K, 2): gather_nd indices of the kept neighbours.
    Returns:
        loss (float): The loss function for control actions.
    """
    z, dzdt, ddzdt = tf.split(z_ref, [2, 2, 2], axis=1)
    nominal = car_controller_tf(s, z, dzdt, ddzdt)
    penalty = tf.minimum(tf.abs(u - nominal), (u - nominal)**2)

    neighbour_safe = tf.cast(
        compute_safe_mask(s, d, config.DIST_SAFE, indices), tf.float32)
    # Agent weight is 1 only when all of its neighbours are safe.
    all_safe = tf.cast(
        tf.equal(tf.reduce_mean(neighbour_safe, axis=1), 1), tf.float32)
    return (tf.reduce_sum(penalty * all_safe)
            / (1e-4 + tf.reduce_sum(all_safe)))


def compute_dangerous_mask(s, d, r, indices=None):
    """ Identify neighbours (agents and obstacle points) closer than r.

    Args:
        s (N, 5): The current state of N agents.
        d (N, num, C): The relative states to the nearest obstacle points;
            only the first two channels (relative position) are used.
        r (float): The dangerous radius.
        indices (N*K, 2), optional: gather_nd indices of the kept
            neighbours (see remove_distant_agents).
    Returns:
        mask (N, K+num, 1): True for other agents and obstacle points
            strictly inside r; an agent is never dangerous to itself.
            K = min(N, config.TOP_K).
    """
    count = tf.shape(s)[0]
    pairwise = tf.expand_dims(s, 1) - tf.expand_dims(s, 0)
    pairwise = tf.concat(
        [pairwise, tf.expand_dims(tf.eye(count), 2)], axis=2)
    pairwise, _ = remove_distant_agents(pairwise, config.TOP_K, indices)
    agent_dist = tf.norm(pairwise[:, :, :2], axis=2, keepdims=True)
    # The last channel is the identity flag; 0 means "another agent".
    is_other = tf.equal(pairwise[:, :, -1:], 0)
    agent_mask = tf.logical_and(tf.less(agent_dist, r), is_other)
    obstacle_mask = tf.less(tf.norm(d[:, :, :2], axis=2, keepdims=True), r)
    return tf.concat([agent_mask, obstacle_mask], axis=1)


def compute_safe_mask(s, d, r, indices=None):
    """ Identify neighbours (agents and obstacle points) farther than r.

    Args:
        s (N, 5): The current state of N agents.
        d (N, num, C): The relative states to the nearest obstacle points;
            only the first two channels (relative position) are used.
        r (float): The safe radius.
        indices (N*K, 2), optional: gather_nd indices of the kept
            neighbours (see remove_distant_agents).
    Returns:
        mask (N, K+num, 1): True for agents and obstacle points strictly
            outside r; an agent always counts as safe to itself.
            K = min(N, config.TOP_K).
    """
    count = tf.shape(s)[0]
    pairwise = tf.expand_dims(s, 1) - tf.expand_dims(s, 0)
    pairwise = tf.concat(
        [pairwise, tf.expand_dims(tf.eye(count), 2)], axis=2)
    pairwise, _ = remove_distant_agents(pairwise, config.TOP_K, indices)
    agent_dist = tf.norm(pairwise[:, :, :2], axis=2, keepdims=True)
    # The last channel is the identity flag; 1 means "the agent itself".
    is_self = tf.equal(pairwise[:, :, -1:], 1)
    agent_mask = tf.logical_or(tf.greater(agent_dist, r), is_self)
    obstacle_mask = tf.greater(
        tf.norm(d[:, :, :2], axis=2, keepdims=True), r)
    return tf.concat([agent_mask, obstacle_mask], axis=1)


def dangerous_mask_np(s, o, r):
    """ Identify the agents within the dangerous radius.

    Args:
        s (N, 5): The current state of N agents; only the positions (first
            two columns) are used.
        o (M, 2): The outline of the obstacles as M points. May be empty.
        r (float): The dangerous radius.
    Returns:
        mask (N, N) float32: entry (i, j) is 1 when agents i != j are closer
            than r. The whole row i is 1 when agent i is closer than r to
            any obstacle point. All other entries are 0.
    """
    pos = s[:, :2]
    agent_dist = np.linalg.norm(
        np.expand_dims(pos, 1) - np.expand_dims(pos, 0), axis=2)
    not_self = ~np.eye(agent_dist.shape[0], dtype=bool)
    mask = np.logical_and(agent_dist < r, not_self)

    o = np.reshape(o, [-1, 2])
    # np.amin cannot reduce over zero obstacle points, so only add the
    # obstacle term when there is at least one point.
    if o.shape[0] > 0:
        obs_dist = np.amin(
            np.linalg.norm(pos[:, None, :] - o[None, :, :], axis=2),
            axis=1, keepdims=True)
        mask = np.logical_or(mask, obs_dist < r)

    return mask.astype(np.float32)


def remove_distant_agents(x, k, indices=None):
    """ Keep only the k nearest agents for every agent.

    Args:
        x (N, N, C): The pairwise state differences of N agents. The first
            two channels (relative position) rank the neighbours. N and C
            must be statically known.
        k (int): The number of nearest agents to keep.
        indices (N*k, 2), optional: gather_nd indices from a previous call.
            When given, they are reused instead of recomputing neighbours.
    Returns:
        x (N, K, C): The kept neighbours, K = k when N > k. When N <= k,
            x is returned unchanged (K = N).
        indices: (N*k, 2) int32 gather_nd indices, or False when N <= k
            (nothing was removed).
    """
    n, _, c = x.get_shape().as_list()
    if n <= k:
        return x, False
    if indices is None:
        # Only build the distance / top_k ops when indices are not supplied.
        d_norm = tf.sqrt(tf.reduce_sum(tf.square(x[:, :, :2]) + 1e-6, axis=2))
        # top_k of the negated distance gives the k closest columns per row.
        _, nearest = tf.nn.top_k(-d_norm, k=k)
        row_indices = tf.expand_dims(
            tf.range(tf.shape(nearest)[0]), 1) * tf.ones_like(nearest)
        row_indices = tf.reshape(row_indices, [-1, 1])
        column_indices = tf.reshape(nearest, [-1, 1])
        indices = tf.concat([row_indices, column_indices], axis=1)
    x = tf.reshape(tf.gather_nd(x, indices), [n, k, c])
    return x, indices


def detect_nearest_obstacles(s, o, num=8):
    """ Detect the nearest obstacle points of every agent.

    Args:
        s (N, 5): The current state of N agents. N must be statically known.
        o (M, 2): The outlines of the obstacles as M points (M >= num).
        num (int): The number of nearest obstacle points to keep.
    Returns:
        d (N, num, 5): For each kept point, the agent state minus the point
            padded with zeros. The first two channels are the relative
            position; the remaining three are the agent's own velocity and
            rotation.
    """
    n, _ = s.get_shape().as_list()
    states = tf.reshape(s, [-1, 1, 5])
    points = tf.reshape(o, [1, -1, 2])
    # Pad obstacle points to the 5-channel state layout with zeros.
    padded = tf.concat(
        [points, tf.zeros_like(points), tf.zeros_like(points[:, :, :1])],
        axis=2)
    diff = states - padded
    dist = tf.norm(diff[:, :, :2], axis=2)
    _, nearest = tf.nn.top_k(-dist, k=num)
    rows = tf.reshape(
        tf.expand_dims(tf.range(tf.shape(nearest)[0]), 1)
        * tf.ones_like(nearest),
        [n * num, 1])
    cols = tf.reshape(nearest, [n * num, 1])
    gathered = tf.gather_nd(diff, tf.concat([rows, cols], axis=1))
    return tf.reshape(gathered, [n, num, 5])
