import warp as wp

# Accumulate a weighted squared-L2 loss between matching vec3 entries.
#
# Each thread handles one index. Entries whose ``ee_mask`` value is non-zero
# are skipped; for every other entry ``loss_weight * |pred - gt|^2`` is
# atomically added into ``l2_loss[0]``. The kernel only ever adds to
# ``l2_loss[0]``; it never resets it.
@wp.kernel
def compute_l2_loss_vec3(
    pred: wp.array(dtype=wp.vec3),
    gt: wp.array(dtype=wp.vec3),
    ee_mask: wp.array(dtype=int),
    loss_weight: float,
    l2_loss: wp.array(dtype=float),
):
    idx = wp.tid()

    # Masked entries contribute nothing to the loss.
    if ee_mask[idx] != 0:
        return

    residual = pred[idx] - gt[idx]
    norm = wp.length(residual)
    wp.atomic_add(l2_loss, 0, loss_weight * norm * norm)

# Accumulate a weighted squared-L2 loss between matching vec2 entries.
#
# Each thread handles one index. Entries whose ``ee_mask`` value is non-zero
# are skipped; for every other entry ``loss_weight * |pred - gt|^2`` is
# atomically added into ``l2_loss[0]``. The kernel only ever adds to
# ``l2_loss[0]``; it never resets it.
@wp.kernel
def compute_l2_loss(
    pred: wp.array(dtype=wp.vec2),
    gt: wp.array(dtype=wp.vec2),
    ee_mask: wp.array(dtype=int),
    loss_weight: float,
    l2_loss: wp.array(dtype=float),
):
    idx = wp.tid()

    # Masked entries contribute nothing to the loss.
    if ee_mask[idx] != 0:
        return

    residual = pred[idx] - gt[idx]
    norm = wp.length(residual)
    wp.atomic_add(l2_loss, 0, loss_weight * norm * norm)

# Add the chamfer and track loss scalars into ``loss[0]``.
#
# Only element 0 of each input array is read. The thread index is never
# consulted, so every launched thread performs both additions.
# NOTE(review): presumably launched with dim=1 -- confirm at the call site.
@wp.kernel
def compute_total_loss(
    chamfer_loss: wp.array(dtype=float),
    track_loss: wp.array(dtype=float),
    loss: wp.array(dtype=float),
):
    chamfer_term = chamfer_loss[0]
    track_term = track_loss[0]

    # Two separate atomic adds, chamfer first and track second.
    wp.atomic_add(loss, 0, chamfer_term)
    wp.atomic_add(loss, 0, track_term)

# Add the chamfer and track loss scalars into ``loss[0]``; no other term is
# included.
#
# Only element 0 of each input array is read. The thread index is never
# consulted, so every launched thread performs both additions.
# NOTE(review): presumably launched with dim=1 -- confirm at the call site.
@wp.kernel
def compute_total_loss_without_acc(
    chamfer_loss: wp.array(dtype=float),
    track_loss: wp.array(dtype=float),
    loss: wp.array(dtype=float),
):
    chamfer_term = chamfer_loss[0]
    track_term = track_loss[0]

    # Two separate atomic adds, chamfer first and track second.
    wp.atomic_add(loss, 0, chamfer_term)
    wp.atomic_add(loss, 0, track_term)

# Fold the current chamfer and track losses into their running totals.
#
# ``chamfer_loss[0]`` is atomically added to ``chamfer_loss_total[0]`` and
# ``track_loss[0]`` to ``track_loss_total[0]``. The thread index is never
# consulted, so every launched thread performs both additions.
# NOTE(review): presumably launched with dim=1 -- confirm at the call site.
@wp.kernel
def compute_seq_loss(
    chamfer_loss: wp.array(dtype=float),
    track_loss: wp.array(dtype=float),
    chamfer_loss_total: wp.array(dtype=float),
    track_loss_total: wp.array(dtype=float),
):
    chamfer_term = chamfer_loss[0]
    track_term = track_loss[0]

    wp.atomic_add(chamfer_loss_total, 0, chamfer_term)
    wp.atomic_add(track_loss_total, 0, track_term)

# Fill the pairwise distance matrix between ground-truth and predicted points.
#
# Launched over a 2D grid: the first thread coordinate indexes ``gt`` (and
# ``ee_mask``), the second indexes ``pred``. Rows whose ``ee_mask`` value is
# non-zero are filled with the constant 1e6 instead of a real distance.
# Backward-pass generation is disabled for this kernel.
@wp.kernel(enable_backward=False)
def compute_distances(
    pred: wp.array(dtype=wp.vec3),
    gt: wp.array(dtype=wp.vec3),
    ee_mask: wp.array(dtype=wp.int32),
    distances: wp.array2d(dtype=float),
):
    row, col = wp.tid()

    # Masked ground-truth rows get the large sentinel value.
    if ee_mask[row] != 0:
        distances[row, col] = 1e6
        return

    distances[row, col] = wp.length(gt[row] - pred[col])
#
# For each row of ``distances``, record the column holding the smallest value.
#
# One thread per row. Only values strictly below 1e6 are considered, and the
# first such minimum wins on ties. If no column qualifies, -1 is written.
# Backward-pass generation is disabled for this kernel.
@wp.kernel(enable_backward=False)
def compute_neigh_indices(
    distances: wp.array2d(dtype=float),
    neigh_indices: wp.array(dtype=wp.int32),
):
    row = wp.tid()
    num_cols = distances.shape[1]

    # Explicit float()/int() constructors make these mutable loop variables.
    best_dist = float(1e6)
    best_col = int(-1)
    for col in range(num_cols):
        candidate = distances[row, col]
        if candidate < best_dist:
            best_dist = candidate
            best_col = col

    neigh_indices[row] = best_col

# For each column of ``distances``, record the row holding the smallest value.
#
# One thread per column. Only values strictly below 1e6 are considered, and
# the first such minimum wins on ties. If no row qualifies, -1 is written.
# Backward-pass generation is disabled for this kernel.
@wp.kernel(enable_backward=False)
def compute_neigh_indices_inverse(
    distances: wp.array2d(dtype=float),
    neigh_indices: wp.array(dtype=wp.int32),
):
    col = wp.tid()
    num_rows = distances.shape[0]

    # Explicit float()/int() constructors make these mutable loop variables.
    best_dist = float(1e6)
    best_row = int(-1)
    for row in range(num_rows):
        candidate = distances[row, col]
        if candidate < best_dist:
            best_dist = candidate
            best_row = row

    neigh_indices[col] = best_row
#
# Accumulate the squared-distance chamfer term for each unmasked gt point.
#
# One thread per ``gt`` index ``i``. For entries with ``ee_mask[i] == 0`` the
# kernel looks up ``pred[neigh_indices[i]]`` and atomically adds
# ``loss_weight * |pred_nn - gt[i]|^2 / num_valid`` into ``chamfer_loss[0]``.
#
# Entries whose neighbour index is negative are skipped: the neighbour-search
# kernels in this file write -1 when no distance below 1e6 was found, and
# indexing ``pred`` with -1 would be an out-of-bounds read.
@wp.kernel
def compute_chamfer_loss(
    pred: wp.array(dtype=wp.vec3),
    gt: wp.array(dtype=wp.vec3),
    ee_mask: wp.array(dtype=wp.int32),
    num_valid: int,
    neigh_indices: wp.array(dtype=wp.int32),
    loss_weight: float,
    chamfer_loss: wp.array(dtype=float),
):
    i = wp.tid()
    if ee_mask[i] == 0:
        nearest = neigh_indices[i]
        # -1 means "no neighbour found"; never index pred with it.
        if nearest >= 0:
            min_pred = pred[nearest]
            min_dist = wp.length(min_pred - gt[i])
            final_min_dist = loss_weight * min_dist * min_dist / float(num_valid)
            wp.atomic_add(chamfer_loss, 0, final_min_dist)

#
# Accumulate the L1 (sum of absolute component differences) chamfer term for
# each unmasked gt point.
#
# One thread per ``gt`` index ``i``. For entries with ``ee_mask[i] == 0`` the
# kernel looks up ``pred[neigh_indices[i]]`` and atomically adds
# ``loss_weight * (|dx| + |dy| + |dz|) / num_valid`` into ``chamfer_loss[0]``.
#
# Entries whose neighbour index is negative are skipped: the neighbour-search
# kernels in this file write -1 when no distance below 1e6 was found, and
# indexing ``pred`` with -1 would be an out-of-bounds read.
@wp.kernel
def compute_chamfer_loss_1d(
    pred: wp.array(dtype=wp.vec3),
    gt: wp.array(dtype=wp.vec3),
    ee_mask: wp.array(dtype=wp.int32),
    num_valid: int,
    neigh_indices: wp.array(dtype=wp.int32),
    loss_weight: float,
    chamfer_loss: wp.array(dtype=float),
):
    i = wp.tid()
    if ee_mask[i] == 0:
        nearest = neigh_indices[i]
        # -1 means "no neighbour found"; never index pred with it.
        if nearest >= 0:
            min_pred = pred[nearest]
            min_dist = wp.abs(min_pred[0] - gt[i][0]) + wp.abs(min_pred[1] - gt[i][1]) + wp.abs(min_pred[2] - gt[i][2])
            final_min_dist = loss_weight * min_dist / float(num_valid)
            wp.atomic_add(chamfer_loss, 0, final_min_dist)

# Smooth-L1 penalty for a single non-negative absolute error (threshold 1.0),
# modified from fvcore.nn.smooth_l1_loss: quadratic below 1.0, linear above.
@wp.func
def _smooth_l1(abs_err: float):
    if abs_err < 1.0:
        return 0.5 * (abs_err**2.0)
    return abs_err - 0.5


# Accumulate the smooth-L1 tracking loss between matching pred / gt points.
#
# One thread per index ``i``. Entries whose ``ee_mask`` value is non-zero are
# skipped. For the rest, the per-axis smooth-L1 penalties are summed, scaled
# by ``loss_weight / (num_valid * 3)`` and atomically added into
# ``track_loss[0]``.
@wp.kernel
def compute_track_loss(
    pred: wp.array(dtype=wp.vec3),
    gt: wp.array(dtype=wp.vec3),
    ee_mask: wp.array(dtype=wp.int32),
    num_valid: int,
    loss_weight: float,
    track_loss: wp.array(dtype=float),
):
    i = wp.tid()
    if ee_mask[i] == 0:
        pred_i = pred[i]
        gt_i = gt[i]

        # Per-axis absolute errors.
        dist_x = wp.abs(pred_i[0] - gt_i[0])
        dist_y = wp.abs(pred_i[1] - gt_i[1])
        dist_z = wp.abs(pred_i[2] - gt_i[2])

        temp_track_loss = _smooth_l1(dist_x) + _smooth_l1(dist_y) + _smooth_l1(dist_z)

        # Average over the valid points and their three components.
        average_factor = float(num_valid) * 3.0

        final_track_loss = loss_weight * temp_track_loss / average_factor

        wp.atomic_add(track_loss, 0, final_track_loss)


# Accumulate a smooth-L1 penalty on the change between the current
# difference ``v2 - v1`` and the stored ``prev_acc``.
#
# One thread per index; threads with ``tid >= num_object_points`` exit
# immediately, and entries whose ``ee_mask`` value is non-zero are skipped.
# The per-axis penalties are summed, scaled by
# ``acc_weight / (num_object_points * 3)`` and atomically added into
# ``acc_loss[0]``.
@wp.kernel
def compute_acc_loss(
    v1: wp.array(dtype=wp.vec3),
    v2: wp.array(dtype=wp.vec3),
    prev_acc: wp.array(dtype=wp.vec3),
    num_object_points: int,
    ee_mask: wp.array(dtype=wp.int32),
    acc_weight: float,
    acc_loss: wp.array(dtype=wp.float32),
):
    tid = wp.tid()
    if tid >= num_object_points:
        return

    if ee_mask[tid] == 0:
        cur_acc = v2[tid] - v1[tid]

        # Access the vec3 components of prev_acc by index.
        prev = prev_acc[tid]

        dist_x = wp.abs(cur_acc[0] - prev[0])
        dist_y = wp.abs(cur_acc[1] - prev[1])
        dist_z = wp.abs(cur_acc[2] - prev[2])

        temp_acc_loss = _smooth_l1(dist_x) + _smooth_l1(dist_y) + _smooth_l1(dist_z)

        # The normalizer uses the total object point count, not a count of
        # unmasked entries.
        average_factor = float(num_object_points) * 3.0

        final_acc_loss = acc_weight * temp_acc_loss / average_factor

        wp.atomic_add(acc_loss, 0, final_acc_loss)

#
# @wp.kernel
# def compute_acc_loss(
#     v1: wp.array(dtype=wp.vec3),
#     v2: wp.array(dtype=wp.vec3),
#     prev_acc: wp.array(dtype=wp.vec3),
#     num_object_points: int,
#     ee_mask: wp.array(dtype=wp.int32),
#     acc_weight: float,
#     acc_loss: wp.array(dtype=wp.float32),
# ):
#     if ee_mask == 0:
#         # Calculate the smooth l1 loss modifed from fvcore.nn.smooth_l1_loss
#         tid = wp.tid()
#         cur_acc = v2[tid] - v1[tid]
#         cur_x = cur_acc[0]
#         cur_y = cur_acc[1]
#         cur_z = cur_acc[2]
#
#         prev_x = prev_acc[tid][0]
#         prev_y = prev_acc[tid][1]
#         prev_z = prev_acc[tid][2]
#
#         dist_x = wp.abs(cur_x - prev_x)
#         dist_y = wp.abs(cur_y - prev_y)
#         dist_z = wp.abs(cur_z - prev_z)
#
#         if dist_x < 1.0:
#             temp_acc_loss_x = 0.5 * (dist_x**2.0)
#         else:
#             temp_acc_loss_x = dist_x - 0.5
#
#         if dist_y < 1.0:
#             temp_acc_loss_y = 0.5 * (dist_y**2.0)
#         else:
#             temp_acc_loss_y = dist_y - 0.5
#
#         if dist_z < 1.0:
#             temp_acc_loss_z = 0.5 * (dist_z**2.0)
#         else:
#             temp_acc_loss_z = dist_z - 0.5
#
#         temp_acc_loss = temp_acc_loss_x + temp_acc_loss_y + temp_acc_loss_z
#
#         average_factor = float(num_object_points) * 3.0
#
#         final_acc_loss = acc_weight * temp_acc_loss / average_factor
#
#         wp.atomic_add(acc_loss, 0, final_acc_loss)


