import time
from collections import OrderedDict

import torch
import torch.nn as nn
import torch.nn.functional as F

import faiss
import faiss.contrib.torch_utils
from matplotlib import pyplot as plt

from helperfunctions.hfunctions import assert_torch_invalid


class SobelFilter(nn.Module):
    """Fixed-weight Sobel-style edge detector returning the gradient magnitude.

    Adapted from:
    https://github.com/chaddy1004/sobel-operator-pytorch/blob/master/model.py

    A single 3x3 convolution (one input channel, two output channels,
    reflect padding, no bias) holds a horizontal and a vertical kernel.
    The kernels used here are the classic Sobel kernels scaled by 2.
    The weights are frozen (``requires_grad=False``).
    """

    def __init__(self, device):
        """Build the convolution and install the two fixed kernels on `device`."""
        super().__init__()
        self.filter = nn.Conv2d(1, 2, kernel_size=3, stride=1, padding=1,
                                padding_mode='reflect', bias=False)

        # Horizontal (x) and vertical (y) derivative kernels.
        kernel_x = torch.tensor([[2.0, 0.0, -2.0],
                                 [4.0, 0.0, -4.0],
                                 [2.0, 0.0, -2.0]]).to(device)
        kernel_y = torch.tensor([[2.0, 4.0, 2.0],
                                 [0.0, 0.0, 0.0],
                                 [-2.0, -4.0, -2.0]]).to(device)

        # (2, 3, 3) -> (2, 1, 3, 3): the (out_channels, in_channels, kH, kW)
        # layout that Conv2d expects for its weight.
        kernels = torch.stack((kernel_x, kernel_y), dim=0).unsqueeze(1)

        # Replace the randomly initialised weight with the frozen kernels.
        self.filter.weight = nn.Parameter(kernels, requires_grad=False)

    def forward(self, img):
        """Return sqrt(gx^2 + gy^2) of `img`, keeping a single channel dim.

        `img` is passed straight to the 1-input-channel convolution, so it
        must have a channel dimension of size 1.
        """
        grads = self.filter(img)
        squared = grads * grads
        return squared.sum(dim=1, keepdim=True).sqrt()


def find_nearest_targets(P_source, P_target, gpu_resource=None):
    """Return, for every source point, the index of its nearest target point.

    A faiss flat L2 index is built over `P_target` (on the GPU when
    `gpu_resource` is given, otherwise on the CPU) and queried with
    `P_source` for the single nearest neighbour.

    Args:
        P_source: query points; second dimension must match `P_target`.
        P_target: points to index, shape (n_targets, d).
        gpu_resource: faiss GPU resources object, or None for a CPU index.

    Returns:
        The nearest-target indices with the k=1 dimension squeezed away.
    """
    dim = P_target.shape[1]
    if gpu_resource is None:
        index = faiss.IndexFlatL2(dim)
    else:
        index = faiss.GpuIndexFlatL2(gpu_resource, dim)

    index.add(P_target)
    # search() returns (distances, indices); only the indices are needed.
    _, neighbour_idx = index.search(P_source, k=1)
    return neighbour_idx.squeeze(1)

def nearest_target_distance(P_source, P_target, gpu_resources):
    """Match each source point to its nearest target and return the RMSE.

    Args:
        P_source: query points.
        P_target: candidate points; rows are selected along dimension 0.
        gpu_resources: forwarded to `find_nearest_targets` (None -> CPU index).

    Returns:
        Tuple ``(rmse, P_target_nearest)`` where `P_target_nearest` holds the
        selected target row for every source point and `rmse` is the square
        root of the mean squared error between it and `P_source`.
    """
    matched_idx = find_nearest_targets(P_source=P_source,
                                       P_target=P_target,
                                       gpu_resource=gpu_resources)
    P_target_nearest = torch.index_select(P_target, 0, matched_idx)

    # RMSE = sqrt(mean squared error over all elements).
    mse = torch.nn.functional.mse_loss(P_target_nearest, P_source, reduction='mean')
    return torch.sqrt(mse), P_target_nearest

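# Helpers that turn mask images into flat lists of pixel (u, v) locations,
# either for all pixels of a class or for the edge pixels of a binary mask.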
def mask_img_2_UV_flat(mask, class_idx):
    """Collect the pixel locations in `mask` that equal `class_idx`.

    The first three index dimensions of `mask` are read as
    (batch, row, column).

    Returns:
        Tuple ``(UV_flat, B_idx, max_n)``:
        UV_flat: (n_hits, 2) tensor of (column, row) = (x, y) locations.
        B_idx: (n_hits,) batch index of every location.
        max_n: Python int, the largest number of hits in any batch entry.
    """
    hits = torch.nonzero(mask == class_idx, as_tuple=True)
    batch_of_hit = hits[0]
    # Hits per batch entry -> size of the most populated entry.
    largest = int(torch.bincount(batch_of_hit).max())
    # Index dims are (batch, row, col); emit (x, y) = (col, row).
    xy = torch.stack((hits[2], hits[1]), dim=-1)
    return xy, batch_of_hit, largest

def mask_img_2_edge_UV_flat(mask, sobel_filter):
    """Collect the edge pixel locations of `mask`.

    A channel dimension is inserted at position 1 before calling
    `sobel_filter` and removed afterwards. Every pixel with a strictly
    positive filter response counts as an edge pixel.

    Returns:
        The ``(UV_flat, B_idx, max_n)`` tuple of `mask_img_2_UV_flat`
        evaluated on the binarised edge map.
    """
    response = sobel_filter(mask.unsqueeze(1)).squeeze(1)
    # Binarise: 1 where the filter responded, 0 elsewhere.
    edge_map = torch.where(response > 0, 1, 0)
    return mask_img_2_UV_flat(edge_map, 1)

def pad_UV_i(UV_i, max_n_i, invalid_uv=0):
    """Pad a list of 2-D points with filler rows up to `max_n_i` rows.

    Args:
        UV_i: (n, 2) tensor of points.
        max_n_i: number of rows the result should have.
        invalid_uv: value written into both columns of every filler row.

    Returns:
        Tuple ``(padded, n)``: the (max_n_i, 2) tensor with the original rows
        first, and the original row count `n`.
    """
    n_valid = UV_i.shape[0]
    filler = invalid_uv * torch.ones(max_n_i - n_valid, 2, device=UV_i.device)
    return torch.cat((UV_i, filler)), n_valid

def dist_loss(x, y):
    """Return the mean Euclidean distance between matching points of x and y.

    The L2 norm of ``x - y`` is taken over the last dimension and the
    resulting distances are averaged over all remaining elements.

    The norm is computed with `torch.linalg.vector_norm` rather than an
    explicit ``square().sum().sqrt()`` chain: the derivative of ``sqrt`` at 0
    is infinite, so a point that coincides exactly with its target would
    produce NaN gradients, whereas the norm's backward pass yields 0 there.

    Args:
        x, y: broadcast-compatible tensors whose last dimension holds the
            point coordinates.

    Returns:
        Scalar tensor with the mean distance.
    """
    diff = x - y
    if not (diff.is_floating_point() or diff.is_complex()):
        # vector_norm rejects integer tensors; promote them to the default
        # float dtype, which is what sqrt() did implicitly.
        diff = diff.to(torch.get_default_dtype())
    return torch.linalg.vector_norm(diff, ord=2, dim=-1).mean()


def rendered_semantics_loss_vectorized2(gt_mask, rend_dict, sobel_filter, faiss_gpu_res, rend_dict_loss, args):
    """Nearest-neighbour losses between predicted iris points and a GT mask.

    Iris-only variant: pixels of `gt_mask` equal to 2 are the iris region,
    and the GT edge is the Sobel edge of the binary mask ``gt_mask >= 2``.
    Nearest neighbours are chosen from detached squared distances; the
    Euclidean distance that carries gradients is then evaluated only on the
    selected pairs.

    Args:
        gt_mask: (B, H, W) label mask.
        rend_dict: dict with 'iris_UV' of shape (B, N, 2) and, when the edge
            term is enabled, 'edge_idx_iris' whose 'sides' entry indexes
            dimension 1 of 'iris_UV'.
        sobel_filter: callable used to extract the GT edge pixels.
        faiss_gpu_res: unused here; kept so the signature stays unchanged.
        rend_dict_loss: same layout as `rend_dict`; only read when the 2D
            edge term is enabled.
        args: dict of loss weights. 'loss_w_rend_diameter' must be 0. Each of
            'loss_w_rend_gt_2_pred', 'loss_w_rend_pred_2_gt',
            'loss_w_rend_pred_2_gt_edge' and 'loss_w_rend_pred_2_gt_edge_2D'
            enables its term when truthy and scales it.

    Returns:
        Tuple ``(total_loss, loss_dict)``: the sum of all enabled weighted
        terms (the float 0.0 if none is enabled) and the per-term dict.
    """
    # The diameter term is not supported by this implementation.
    assert args['loss_w_rend_diameter'] == 0

    w_gt_2_pred = args['loss_w_rend_gt_2_pred']
    w_pred_2_gt = args['loss_w_rend_pred_2_gt']
    w_edge = args['loss_w_rend_pred_2_gt_edge']
    w_edge_2d = args['loss_w_rend_pred_2_gt_edge_2D']
    # Both region terms consume the GT region pixels, so either one enables
    # their extraction.
    need_region = bool(w_gt_2_pred or w_pred_2_gt)
    need_edge = bool(w_edge or w_edge_2d)

    dev = gt_mask.device
    iris_idx = 2        # label of iris pixels in gt_mask
    invalid_uv = 0      # coordinate value written into padding rows
    invalid_dist = 1e9  # penalty that keeps padding rows from being selected

    # Predicted templates and their edge subsets.
    UV_iris_pred_1 = rend_dict['iris_UV']
    if w_edge:
        UV_iris_edge_pred_1 = UV_iris_pred_1[:, rend_dict['edge_idx_iris']['sides'], :]
    if w_edge_2d:
        UV_iris_pred_2 = rend_dict_loss['iris_UV']
        UV_iris_edge_pred_2 = UV_iris_pred_2[:, rend_dict_loss['edge_idx_iris']['sides'], :]

    # Basic shape checks.
    assert gt_mask.shape[0] == UV_iris_pred_1.shape[0]
    assert len(gt_mask.shape) == 3
    assert len(UV_iris_pred_1.shape) == 3
    assert UV_iris_pred_1.shape[2] == 2, '虹膜特征数量无效'

    # One GT list per mask image; asserted above to match the predictions.
    n_img = gt_mask.shape[0]

    def _pack(uv_flat, b_idx, max_n):
        """Re-pack flat GT points into (B, max_n, 2) plus a (B, max_n) validity mask."""
        valid = torch.ones(n_img, max_n, dtype=torch.bool, device=dev)
        rows = []
        for i in range(n_img):
            uv_i, cnt_i = pad_UV_i(uv_flat[b_idx == i], max_n, invalid_uv)
            rows.append(uv_i)
            valid[i, cnt_i:] = False  # rows past cnt_i are padding
        return torch.stack(rows), valid

    def _sq_dist(uv_pred, uv_gt, valid):
        """Squared distances (B, n_gt, n_pred); padded GT rows are pushed far away."""
        d = (uv_pred.detach().unsqueeze(1) - uv_gt.unsqueeze(2)).square().sum(-1)
        return d + invalid_dist * (~valid).unsqueeze(-1)

    def _pred_2_gt(uv_pred, uv_gt, sq_dist):
        """Mean distance from every predicted point to its nearest GT point."""
        idx = torch.argmin(sq_dist, dim=1).unsqueeze(-1).repeat(1, 1, 2)
        return dist_loss(torch.gather(uv_gt, dim=1, index=idx), uv_pred)

    # GT preprocessing and the all-pairs distances need no gradients.
    with torch.no_grad():
        if need_region:
            UV_iris_gt, mask_iris_gt = _pack(*mask_img_2_UV_flat(gt_mask, iris_idx))
            dist_iris = _sq_dist(UV_iris_pred_1, UV_iris_gt, mask_iris_gt)

        if need_edge:
            # Binary mask of all labels >= iris_idx, then its edge pixels.
            iris_mask_gt = torch.where(gt_mask >= iris_idx, 1., 0.)
            UV_iris_edge_gt, mask_iris_edge_gt = _pack(
                *mask_img_2_edge_UV_flat(iris_mask_gt, sobel_filter))
            if w_edge:
                dist_iris_edge = _sq_dist(UV_iris_edge_pred_1, UV_iris_edge_gt, mask_iris_edge_gt)
            if w_edge_2d:
                dist_iris_edge2 = _sq_dist(UV_iris_edge_pred_2, UV_iris_edge_gt, mask_iris_edge_gt)

    loss_dict = {}

    if w_gt_2_pred:
        # For every valid GT pixel, the closest predicted point.
        gt_2_pred_iris_idx = torch.argmin(dist_iris, dim=2).unsqueeze(-1).repeat(1, 1, 2)
        UV_iris_pred_closest = torch.gather(UV_iris_pred_1, dim=1, index=gt_2_pred_iris_idx)
        loss_iris_gt_2_pred = dist_loss(UV_iris_pred_closest[mask_iris_gt], UV_iris_gt[mask_iris_gt])
        loss_dict['iris_gt_2_pred'] = loss_iris_gt_2_pred * w_gt_2_pred

    if w_pred_2_gt:
        loss_dict['iris_pred_2_gt'] = _pred_2_gt(UV_iris_pred_1, UV_iris_gt, dist_iris) * w_pred_2_gt

    if w_edge:
        loss_dict['iris_pred_2_gt_edge'] = _pred_2_gt(
            UV_iris_edge_pred_1, UV_iris_edge_gt, dist_iris_edge) * w_edge

    if w_edge_2d:
        loss_dict['iris_pred_2_gt_edge_2D'] = _pred_2_gt(
            UV_iris_edge_pred_2, UV_iris_edge_gt, dist_iris_edge2) * w_edge_2d

    # Sum the weighted terms in insertion order.
    total_loss = 0.0
    for term in loss_dict.values():
        total_loss += term

    return total_loss, loss_dict

    # #--------------------------------------------------------------------------
    # #                              DEBUGGING (should go above loss components)
    # #--------------------------------------------------------------------------
    # # 确保 UV_iris_pred 和 UV_pupil_pred 是连续的张量
    # UV_iris_pred = UV_iris_pred.contiguous()
    # UV_pupil_pred = UV_pupil_pred.contiguous()
    # i = 2
    # # Extract GT iris & pupil locations # 提取 GT 虹膜和瞳孔的位置
    # y, x = torch.where(gt_mask[i]==iris_idx)
    # UV_iris_gt_i = torch.stack((x,y), axis=1).float()
    # y, x = torch.where(gt_mask[i]==pupil_idx)
    # UV_pupil_gt_i = torch.stack((x,y), axis=1).float()

    # if args['loss_w_rend_gt_2_pred']:
    # #计算虹膜预测值与最接近的真实值的距离及最接近的真实值
    #     loss_iris_gt_2_pred, \
    #         UV_iris_pred_nearest_i = nearest_target_distance(P_source=UV_iris_gt_i,
    #                                                         P_target=UV_iris_pred[i],
    #                                                         gpu_resources=faiss_gpu_res)
    # # 计算虹膜预测值与最接近的真实值之间的平均距离
    #     loss_tmp_1 = (UV_iris_gt_i - UV_iris_pred_nearest_i).square().sum(-1).sqrt().mean()
    # # 计算瞳孔预测值与最接近的真实值的距离及最接近的真实值
    #     loss_pupil_gt_2_pred, \
    #         UV_pupil_pred_nearest_i = nearest_target_distance(P_source=UV_pupil_gt_i,
    #                                                         P_target=UV_pupil_pred[i],
    #                                                         gpu_resources=faiss_gpu_res)
    # # 计算瞳孔预测值与最接近的真实值之间的平均距离
    #     loss_tmp_2 = (UV_pupil_gt_i - UV_pupil_pred_nearest_i).square().sum(-1).sqrt().mean()

    # # 计算临时变量，用于调试
    #     tmp = (UV_iris_pred[i].detach().unsqueeze(0) - UV_iris_gt_i.unsqueeze(1)).square().sum(-1).sqrt()
    #     tmp = UV_iris_pred[i][torch.argmin(tmp, dim=1)]
    #     print(torch.abs(UV_iris_pred_nearest_i - tmp).max())
    #     print((torch.abs(UV_iris_pred_nearest_i - tmp) > 0).sum())
    # # 计算虹膜预测值与最接近的真实值的距离损失
    #     loss_tmp_1_1 = dist_loss(UV_iris_pred_closest[i][mask_iris_gt[i]], UV_iris_gt[i][mask_iris_gt[i]])
    # # 计算虹膜预测值最接近的真实值与最接近的真实值之间的距离损失
    #     loss_tmp_1_2 = dist_loss(UV_iris_pred_nearest_i, UV_iris_gt_i)
    # #--------------------------------------------------------------------------
    # #--------------------------------------------------------------------------

def rendered_semantics_loss_vectorized3(gt_mask, rend_dict, sobel_filter, faiss_gpu_res, rend_dict_loss, args):
    """Nearest-neighbour losses between predicted iris/pupil points and a GT mask.

    Pixels of `gt_mask` equal to 1 are the iris region and pixels equal to 2
    the pupil region. The GT iris edge is the Sobel edge of the binary mask
    ``gt_mask >= 1`` (which includes the pupil label); the GT pupil edge is
    the Sobel edge of ``gt_mask == 2``. Nearest neighbours are chosen from
    detached squared distances; the Euclidean distance that carries
    gradients is then evaluated only on the selected pairs.

    Args:
        gt_mask: (B, H, W) label mask.
        rend_dict: dict with 'iris_UV' and 'pupil_UV' of shape (B, N, 2) and,
            when the edge term is enabled, 'edge_idx_iris' / 'edge_idx_pupil'
            whose 'sides' entries index dimension 1 of the respective points.
        sobel_filter: callable used to extract the GT edge pixels.
        faiss_gpu_res: unused here; kept so the signature stays unchanged.
        rend_dict_loss: same layout as `rend_dict`; only read when the 2D
            edge term is enabled.
        args: dict of loss weights. 'loss_w_rend_diameter' must be 0. Each of
            'loss_w_rend_gt_2_pred', 'loss_w_rend_pred_2_gt',
            'loss_w_rend_pred_2_gt_edge' and 'loss_w_rend_pred_2_gt_edge_2D'
            enables its term when truthy and scales it.

    Returns:
        Tuple ``(total_loss, loss_dict)``: the sum of all enabled weighted
        terms (the float 0.0 if none is enabled) and the per-term dict with
        an 'iris_*' and a 'pupil_*' entry for every enabled term.
    """
    # The diameter term is not supported by this implementation.
    assert args['loss_w_rend_diameter'] == 0

    w_gt_2_pred = args['loss_w_rend_gt_2_pred']
    w_pred_2_gt = args['loss_w_rend_pred_2_gt']
    w_edge = args['loss_w_rend_pred_2_gt_edge']
    w_edge_2d = args['loss_w_rend_pred_2_gt_edge_2D']
    # Both region terms consume the GT region pixels, so either one enables
    # their extraction.
    need_region = bool(w_gt_2_pred or w_pred_2_gt)
    need_edge = bool(w_edge or w_edge_2d)

    dev = gt_mask.device
    iris_idx = 1        # label of iris pixels in gt_mask
    pupil_idx = 2       # label of pupil pixels in gt_mask
    invalid_uv = 0      # coordinate value written into padding rows
    invalid_dist = 1e9  # penalty that keeps padding rows from being selected

    # Predicted templates and their edge subsets.
    UV_iris_pred_1 = rend_dict['iris_UV']
    UV_pupil_pred_1 = rend_dict['pupil_UV']
    if w_edge:
        UV_iris_edge_pred_1 = UV_iris_pred_1[:, rend_dict['edge_idx_iris']['sides'], :]
        UV_pupil_edge_pred_1 = UV_pupil_pred_1[:, rend_dict['edge_idx_pupil']['sides'], :]
    if w_edge_2d:
        UV_iris_pred_2 = rend_dict_loss['iris_UV']
        UV_pupil_pred_2 = rend_dict_loss['pupil_UV']
        UV_iris_edge_pred_2 = UV_iris_pred_2[:, rend_dict_loss['edge_idx_iris']['sides'], :]
        UV_pupil_edge_pred_2 = UV_pupil_pred_2[:, rend_dict_loss['edge_idx_pupil']['sides'], :]

    # Basic shape checks.
    assert gt_mask.shape[0] == UV_pupil_pred_1.shape[0] == UV_iris_pred_1.shape[0]
    assert len(gt_mask.shape) == 3
    assert len(UV_pupil_pred_1.shape) == len(UV_iris_pred_1.shape) == 3
    assert UV_iris_pred_1.shape[2] == 2, '虹膜特征数量无效'
    assert UV_pupil_pred_1.shape[2] == 2, '瞳孔特征数量无效'

    # One GT list per mask image; asserted above to match the predictions.
    n_img = gt_mask.shape[0]

    def _pack(uv_flat, b_idx, max_n):
        """Re-pack flat GT points into (B, max_n, 2) plus a (B, max_n) validity mask."""
        valid = torch.ones(n_img, max_n, dtype=torch.bool, device=dev)
        rows = []
        for i in range(n_img):
            uv_i, cnt_i = pad_UV_i(uv_flat[b_idx == i], max_n, invalid_uv)
            rows.append(uv_i)
            valid[i, cnt_i:] = False  # rows past cnt_i are padding
        return torch.stack(rows), valid

    def _sq_dist(uv_pred, uv_gt, valid):
        """Squared distances (B, n_gt, n_pred); padded GT rows are pushed far away."""
        d = (uv_pred.detach().unsqueeze(1) - uv_gt.unsqueeze(2)).square().sum(-1)
        return d + invalid_dist * (~valid).unsqueeze(-1)

    def _gt_2_pred(uv_pred, uv_gt, sq_dist, valid):
        """Mean distance from every valid GT point to its nearest predicted point."""
        idx = torch.argmin(sq_dist, dim=2).unsqueeze(-1).repeat(1, 1, 2)
        closest = torch.gather(uv_pred, dim=1, index=idx)
        return dist_loss(closest[valid], uv_gt[valid])

    def _pred_2_gt(uv_pred, uv_gt, sq_dist):
        """Mean distance from every predicted point to its nearest GT point."""
        idx = torch.argmin(sq_dist, dim=1).unsqueeze(-1).repeat(1, 1, 2)
        return dist_loss(torch.gather(uv_gt, dim=1, index=idx), uv_pred)

    # GT preprocessing and the all-pairs distances need no gradients.
    with torch.no_grad():
        if need_region:
            UV_iris_gt, mask_iris_gt = _pack(*mask_img_2_UV_flat(gt_mask, iris_idx))
            UV_pupil_gt, mask_pupil_gt = _pack(*mask_img_2_UV_flat(gt_mask, pupil_idx))
            dist_iris = _sq_dist(UV_iris_pred_1, UV_iris_gt, mask_iris_gt)
            dist_pupil = _sq_dist(UV_pupil_pred_1, UV_pupil_gt, mask_pupil_gt)

        if need_edge:
            # Iris edge: boundary of all labels >= iris_idx (pupil included).
            iris_mask_gt = torch.where(gt_mask >= iris_idx, 1., 0.)
            UV_iris_edge_gt, mask_iris_edge_gt = _pack(
                *mask_img_2_edge_UV_flat(iris_mask_gt, sobel_filter))
            # Pupil edge: boundary of the pupil label only.
            pupil_mask_gt = torch.where(gt_mask == pupil_idx, 1., 0.)
            UV_pupil_edge_gt, mask_pupil_edge_gt = _pack(
                *mask_img_2_edge_UV_flat(pupil_mask_gt, sobel_filter))

            if w_edge:
                dist_iris_edge = _sq_dist(UV_iris_edge_pred_1, UV_iris_edge_gt, mask_iris_edge_gt)
                dist_pupil_edge = _sq_dist(UV_pupil_edge_pred_1, UV_pupil_edge_gt, mask_pupil_edge_gt)
            if w_edge_2d:
                dist_iris_edge2 = _sq_dist(UV_iris_edge_pred_2, UV_iris_edge_gt, mask_iris_edge_gt)
                dist_pupil_edge2 = _sq_dist(UV_pupil_edge_pred_2, UV_pupil_edge_gt, mask_pupil_edge_gt)

    loss_dict = {}

    if w_gt_2_pred:
        loss_dict['iris_gt_2_pred'] = _gt_2_pred(
            UV_iris_pred_1, UV_iris_gt, dist_iris, mask_iris_gt) * w_gt_2_pred
        loss_dict['pupil_gt_2_pred'] = _gt_2_pred(
            UV_pupil_pred_1, UV_pupil_gt, dist_pupil, mask_pupil_gt) * w_gt_2_pred

    if w_pred_2_gt:
        loss_dict['iris_pred_2_gt'] = _pred_2_gt(UV_iris_pred_1, UV_iris_gt, dist_iris) * w_pred_2_gt
        loss_dict['pupil_pred_2_gt'] = _pred_2_gt(UV_pupil_pred_1, UV_pupil_gt, dist_pupil) * w_pred_2_gt

    if w_edge:
        loss_dict['iris_pred_2_gt_edge'] = _pred_2_gt(
            UV_iris_edge_pred_1, UV_iris_edge_gt, dist_iris_edge) * w_edge
        loss_dict['pupil_pred_2_gt_edge'] = _pred_2_gt(
            UV_pupil_edge_pred_1, UV_pupil_edge_gt, dist_pupil_edge) * w_edge

    if w_edge_2d:
        loss_dict['iris_pred_2_gt_edge_2D'] = _pred_2_gt(
            UV_iris_edge_pred_2, UV_iris_edge_gt, dist_iris_edge2) * w_edge_2d
        loss_dict['pupil_pred_2_gt_edge_2D'] = _pred_2_gt(
            UV_pupil_edge_pred_2, UV_pupil_edge_gt, dist_pupil_edge2) * w_edge_2d

    # Sum the weighted terms in insertion order.
    total_loss = 0.0
    for term in loss_dict.values():
        total_loss += term

    return total_loss, loss_dict

    # #--------------------------------------------------------------------------
    # #                              DEBUGGING (should go above loss components)
    # #--------------------------------------------------------------------------
    # # 确保 UV_iris_pred 和 UV_pupil_pred 是连续的张量
    # UV_iris_pred = UV_iris_pred.contiguous()
    # UV_pupil_pred = UV_pupil_pred.contiguous()
    # i = 2
    # # Extract GT iris & pupil locations # 提取 GT 虹膜和瞳孔的位置
    # y, x = torch.where(gt_mask[i]==iris_idx)
    # UV_iris_gt_i = torch.stack((x,y), axis=1).float()
    # y, x = torch.where(gt_mask[i]==pupil_idx)
    # UV_pupil_gt_i = torch.stack((x,y), axis=1).float()

    # if args['loss_w_rend_gt_2_pred']:
    # #计算虹膜预测值与最接近的真实值的距离及最接近的真实值
    #     loss_iris_gt_2_pred, \
    #         UV_iris_pred_nearest_i = nearest_target_distance(P_source=UV_iris_gt_i,
    #                                                         P_target=UV_iris_pred[i],
    #                                                         gpu_resources=faiss_gpu_res)
    # # 计算虹膜预测值与最接近的真实值之间的平均距离
    #     loss_tmp_1 = (UV_iris_gt_i - UV_iris_pred_nearest_i).square().sum(-1).sqrt().mean()
    # # 计算瞳孔预测值与最接近的真实值的距离及最接近的真实值
    #     loss_pupil_gt_2_pred, \
    #         UV_pupil_pred_nearest_i = nearest_target_distance(P_source=UV_pupil_gt_i,
    #                                                         P_target=UV_pupil_pred[i],
    #                                                         gpu_resources=faiss_gpu_res)
    # # 计算瞳孔预测值与最接近的真实值之间的平均距离
    #     loss_tmp_2 = (UV_pupil_gt_i - UV_pupil_pred_nearest_i).square().sum(-1).sqrt().mean()

    # # 计算临时变量，用于调试
    #     tmp = (UV_iris_pred[i].detach().unsqueeze(0) - UV_iris_gt_i.unsqueeze(1)).square().sum(-1).sqrt()
    #     tmp = UV_iris_pred[i][torch.argmin(tmp, dim=1)]
    #     print(torch.abs(UV_iris_pred_nearest_i - tmp).max())
    #     print((torch.abs(UV_iris_pred_nearest_i - tmp) > 0).sum())
    # # 计算虹膜预测值与最接近的真实值的距离损失
    #     loss_tmp_1_1 = dist_loss(UV_iris_pred_closest[i][mask_iris_gt[i]], UV_iris_gt[i][mask_iris_gt[i]])
    # # 计算虹膜预测值最接近的真实值与最接近的真实值之间的距离损失
    #     loss_tmp_1_2 = dist_loss(UV_iris_pred_nearest_i, UV_iris_gt_i)
    # #--------------------------------------------------------------------------
    # #--------------------------------------------------------------------------

def _pack_padded_uv(uv_flat, batch_idx, max_n, n_batch, pad_value, device):
    """Regroup flat UV rows per sample into one padded tensor plus a validity mask.

    Rows of ``uv_flat`` whose ``batch_idx`` equals ``b`` are handed to
    ``pad_UV_i`` (expected to return the rows padded to ``max_n`` and the number
    of real rows) and the per-sample results are stacked along a new leading axis.

    Returns:
        (uv, valid): the stacked padded UV tensor and a boolean ``(n_batch, max_n)``
        mask that is True for original rows and False for padding.
    """
    valid = torch.ones(n_batch, max_n, dtype=torch.bool, device=device)
    padded = []
    for b in range(n_batch):
        uv_b, n_real_b = pad_UV_i(uv_flat[batch_idx == b], max_n, pad_value)
        padded.append(uv_b)
        valid[b, n_real_b:] = False
    return torch.stack(padded), valid


def _masked_sq_dist(uv_pred, uv_gt, valid_gt, invalid_dist=1e9):
    """Squared distances for all (GT row, predicted row) pairs, laid out as (B, N_gt, N_pred).

    Padded GT rows get ``invalid_dist`` added so that they never win an argmin
    against a real row.  The prediction is detached: this tensor is only used to
    pick indices, the differentiable distance is recomputed on the selected pairs.
    """
    sq_dist = (uv_pred.detach().unsqueeze(1) - uv_gt.unsqueeze(2)).square().sum(-1)
    return sq_dist + invalid_dist * (~valid_gt).unsqueeze(-1)


def _closest_rows(points, sq_dist, dim):
    """Gather from ``points`` (B, N, 2) the rows that minimise ``sq_dist`` along ``dim``."""
    idx = torch.argmin(sq_dist, dim=dim).unsqueeze(-1).repeat(1, 1, 2)
    return torch.gather(points, dim=1, index=idx)


def rendered_semantics_loss_vectorized(gt_mask, rend_dict, sobel_filter, faiss_gpu_res, args):
    """Batched nearest-neighbour loss between GT segmentation pixels and predicted UV templates.

    Args:
        gt_mask: 3-D label tensor; value 1 marks iris pixels and 2 marks pupil pixels.
        rend_dict: must provide 'iris_UV' and 'pupil_UV' (3-D, last axis of size 2).
            When the edge loss is enabled it must also provide 'edge_idx_iris' and
            'edge_idx_pupil', each indexable with the key 'sides'.
        sobel_filter: forwarded to ``mask_img_2_edge_UV_flat`` for the edge loss.
        faiss_gpu_res: unused by this implementation; kept so the signature matches
            ``rendered_semantics_loss``.
        args: mapping with the weights 'loss_w_rend_gt_2_pred', 'loss_w_rend_pred_2_gt',
            'loss_w_rend_pred_2_gt_edge' and 'loss_w_rend_diameter' (must be 0).
            A weight that is falsy disables the corresponding term.

    Returns:
        (total_loss, loss_dict): the sum of all enabled weighted terms (``0.0`` when
        none is enabled) and the individual weighted terms keyed by name.
    """
    assert args['loss_w_rend_diameter'] == 0, \
        'loss_w_rend_diameter is not supported by the vectorized loss'

    dev = gt_mask.device

    # Class labels used in gt_mask.
    iris_idx = 1
    pupil_idx = 2

    # Predicted templates.
    UV_iris_pred = rend_dict['iris_UV']
    UV_pupil_pred = rend_dict['pupil_UV']

    w_edge = args['loss_w_rend_pred_2_gt_edge']
    if w_edge:
        which_edge = 'sides'
        UV_iris_edge_pred = UV_iris_pred[:, rend_dict['edge_idx_iris'][which_edge], :]
        UV_pupil_edge_pred = UV_pupil_pred[:, rend_dict['edge_idx_pupil'][which_edge], :]

    # Basic shape checks. The message is a plain string: the previous
    # ``assert cond, print(...)`` form printed the text and raised with message None.
    assert gt_mask.shape[0] == UV_pupil_pred.shape[0] == UV_iris_pred.shape[0]
    assert len(gt_mask.shape) == 3
    assert len(UV_pupil_pred.shape) == len(UV_iris_pred.shape) == 3
    assert UV_iris_pred.shape[2] == 2, '虹膜特征数量无效'
    assert UV_pupil_pred.shape[2] == 2, '瞳孔特征数量无效'

    w_gt_2_pred = args['loss_w_rend_gt_2_pred']
    w_pred_2_gt = args['loss_w_rend_pred_2_gt']

    # Use the real batch dimension. It was previously derived from
    # args['batch_size'] * args['frames'], which has to equal gt_mask.shape[0]
    # anyway for the padded GT tensors to line up with the predictions.
    n_i = gt_mask.shape[0]
    invalid_uv = 0  # fill value handed to pad_UV_i for padded rows

    # Ground-truth preprocessing and index selection need no gradients.
    with torch.no_grad():
        # The all-pairs region distances are only needed by the two region terms.
        if w_gt_2_pred or w_pred_2_gt:
            flat, b_idx, max_n = mask_img_2_UV_flat(gt_mask, iris_idx)
            UV_iris_gt, mask_iris_gt = _pack_padded_uv(flat, b_idx, max_n, n_i, invalid_uv, dev)
            dist_iris = _masked_sq_dist(UV_iris_pred, UV_iris_gt, mask_iris_gt)

            flat, b_idx, max_n = mask_img_2_UV_flat(gt_mask, pupil_idx)
            UV_pupil_gt, mask_pupil_gt = _pack_padded_uv(flat, b_idx, max_n, n_i, invalid_uv, dev)
            dist_pupil = _masked_sq_dist(UV_pupil_pred, UV_pupil_gt, mask_pupil_gt)

        if w_edge:
            # The iris mask includes the pupil (labels >= iris_idx), the pupil mask
            # is the pupil label only; edges are extracted with the Sobel filter.
            iris_mask_gt = torch.where(gt_mask >= iris_idx, 1., 0.)
            flat, b_idx, max_n = mask_img_2_edge_UV_flat(iris_mask_gt, sobel_filter)
            UV_iris_edge_gt, mask_iris_edge_gt = _pack_padded_uv(flat, b_idx, max_n, n_i, invalid_uv, dev)
            dist_iris_edge = _masked_sq_dist(UV_iris_edge_pred, UV_iris_edge_gt, mask_iris_edge_gt)

            pupil_mask_gt = torch.where(gt_mask == pupil_idx, 1., 0.)
            flat, b_idx, max_n = mask_img_2_edge_UV_flat(pupil_mask_gt, sobel_filter)
            UV_pupil_edge_gt, mask_pupil_edge_gt = _pack_padded_uv(flat, b_idx, max_n, n_i, invalid_uv, dev)
            dist_pupil_edge = _masked_sq_dist(UV_pupil_edge_pred, UV_pupil_edge_gt, mask_pupil_edge_gt)

    loss_dict = {}

    # Every real GT pixel is pulled towards its closest predicted point.
    if w_gt_2_pred:
        UV_iris_pred_closest = _closest_rows(UV_iris_pred, dist_iris, dim=2)
        loss_iris = dist_loss(UV_iris_pred_closest[mask_iris_gt], UV_iris_gt[mask_iris_gt])
        UV_pupil_pred_closest = _closest_rows(UV_pupil_pred, dist_pupil, dim=2)
        loss_pupil = dist_loss(UV_pupil_pred_closest[mask_pupil_gt], UV_pupil_gt[mask_pupil_gt])
        loss_dict['iris_gt_2_pred'] = loss_iris * w_gt_2_pred
        loss_dict['pupil_gt_2_pred'] = loss_pupil * w_gt_2_pred

    # Every predicted point is pulled towards its closest real GT pixel.
    if w_pred_2_gt:
        loss_iris = dist_loss(_closest_rows(UV_iris_gt, dist_iris, dim=1), UV_iris_pred)
        loss_pupil = dist_loss(_closest_rows(UV_pupil_gt, dist_pupil, dim=1), UV_pupil_pred)
        loss_dict['iris_pred_2_gt'] = loss_iris * w_pred_2_gt
        loss_dict['pupil_pred_2_gt'] = loss_pupil * w_pred_2_gt

    # Predicted edge points are pulled towards their closest GT edge pixel.
    if w_edge:
        loss_iris = dist_loss(_closest_rows(UV_iris_edge_gt, dist_iris_edge, dim=1), UV_iris_edge_pred)
        loss_pupil = dist_loss(_closest_rows(UV_pupil_edge_gt, dist_pupil_edge, dim=1), UV_pupil_edge_pred)
        loss_dict['iris_pred_2_gt_edge'] = loss_iris * w_edge
        loss_dict['pupil_pred_2_gt_edge'] = loss_pupil * w_edge

    total_loss = sum(loss_dict.values(), 0.0)
    return total_loss, loss_dict

    # #--------------------------------------------------------------------------
    # #                              DEBUGGING (should go above loss components)
    # #--------------------------------------------------------------------------
    # # 确保 UV_iris_pred 和 UV_pupil_pred 是连续的张量
    # UV_iris_pred = UV_iris_pred.contiguous()
    # UV_pupil_pred = UV_pupil_pred.contiguous()  
    # i = 2
    # # Extract GT iris & pupil locations # 提取 GT 虹膜和瞳孔的位置
    # y, x = torch.where(gt_mask[i]==iris_idx)
    # UV_iris_gt_i = torch.stack((x,y), axis=1).float()
    # y, x = torch.where(gt_mask[i]==pupil_idx)
    # UV_pupil_gt_i = torch.stack((x,y), axis=1).float()

    # if args['loss_w_rend_gt_2_pred']:
    # #计算虹膜预测值与最接近的真实值的距离及最接近的真实值
    #     loss_iris_gt_2_pred, \
    #         UV_iris_pred_nearest_i = nearest_target_distance(P_source=UV_iris_gt_i, 
    #                                                         P_target=UV_iris_pred[i], 
    #                                                         gpu_resources=faiss_gpu_res)
    # # 计算虹膜预测值与最接近的真实值之间的平均距离
    #     loss_tmp_1 = (UV_iris_gt_i - UV_iris_pred_nearest_i).square().sum(-1).sqrt().mean()
    # # 计算瞳孔预测值与最接近的真实值的距离及最接近的真实值
    #     loss_pupil_gt_2_pred, \
    #         UV_pupil_pred_nearest_i = nearest_target_distance(P_source=UV_pupil_gt_i, 
    #                                                         P_target=UV_pupil_pred[i], 
    #                                                         gpu_resources=faiss_gpu_res)
    # # 计算瞳孔预测值与最接近的真实值之间的平均距离
    #     loss_tmp_2 = (UV_pupil_gt_i - UV_pupil_pred_nearest_i).square().sum(-1).sqrt().mean()

    # # 计算临时变量，用于调试
    #     tmp = (UV_iris_pred[i].detach().unsqueeze(0) - UV_iris_gt_i.unsqueeze(1)).square().sum(-1).sqrt()
    #     tmp = UV_iris_pred[i][torch.argmin(tmp, dim=1)]
    #     print(torch.abs(UV_iris_pred_nearest_i - tmp).max())
    #     print((torch.abs(UV_iris_pred_nearest_i - tmp) > 0).sum())
    # # 计算虹膜预测值与最接近的真实值的距离损失
    #     loss_tmp_1_1 = dist_loss(UV_iris_pred_closest[i][mask_iris_gt[i]], UV_iris_gt[i][mask_iris_gt[i]])
    # # 计算虹膜预测值最接近的真实值与最接近的真实值之间的距离损失
    #     loss_tmp_1_2 = dist_loss(UV_iris_pred_nearest_i, UV_iris_gt_i)
    # #--------------------------------------------------------------------------
    # #--------------------------------------------------------------------------


def rendered_semantics_loss(gt_mask, rend_dict, sobel_filter, faiss_gpu_res, args):
    """Per-sample (looped) loss between GT segmentation masks and predicted UV templates.

    Args:
        gt_mask: label tensor indexed per sample; value 1 marks iris pixels and
            2 marks pupil pixels.
        rend_dict: provides 'pupil_UV', 'iris_UV' (last axis of size 2) and the
            index lookups 'edge_idx_iris' / 'edge_idx_pupil' (keys 'sides',
            'left', 'right' are read depending on the enabled terms).
        sobel_filter: callable applied to the binary masks to obtain GT edges;
            only called when the edge term is enabled.
        faiss_gpu_res: forwarded to ``nearest_target_distance`` as ``gpu_resources``.
        args: mapping with 'batch_size', 'frames' and the weights
            'loss_w_rend_gt_2_pred', 'loss_w_rend_pred_2_gt',
            'loss_w_rend_pred_2_gt_edge', 'loss_w_rend_diameter'.
            A falsy weight disables the corresponding term.

    Returns:
        (total_loss, loss_dict): each enabled weighted term averaged over the
        ``batch_size * frames`` processed samples, and their sum (``0.0`` when no
        term is enabled).
    """
    # Class labels used in gt_mask.
    iris_idx = 1
    pupil_idx = 2

    UV_pupil_pred = rend_dict['pupil_UV']
    UV_iris_pred = rend_dict['iris_UV']
    side_idx_iris = rend_dict['edge_idx_iris']
    side_idx_pupil = rend_dict['edge_idx_pupil']

    # The message is a plain string: ``assert cond, print(...)`` printed the text
    # and raised an AssertionError whose message was None.
    assert gt_mask.shape[0] == UV_pupil_pred.shape[0] == UV_iris_pred.shape[0]
    assert UV_iris_pred.shape[2] == 2, '虹膜特征数量无效'
    assert UV_pupil_pred.shape[2] == 2, '瞳孔特征数量无效'

    w_gt_2_pred = args['loss_w_rend_gt_2_pred']
    w_pred_2_gt = args['loss_w_rend_pred_2_gt']
    w_edge = args['loss_w_rend_pred_2_gt_edge']
    w_diameter = args['loss_w_rend_diameter']

    if w_edge:
        # GT edge masks: non-zero Sobel response of the binary region masks.
        iris_mask_gt = torch.where(gt_mask >= iris_idx, 1., 0.)  # TODO == iris_idx
        iris_edge_gt = sobel_filter(iris_mask_gt.unsqueeze(1)).squeeze(1) > 0
        pupil_mask_gt = torch.where(gt_mask == pupil_idx, 1., 0.)
        pupil_edge_gt = sobel_filter(pupil_mask_gt.unsqueeze(1)).squeeze(1) > 0

    UV_iris_pred = UV_iris_pred.contiguous()
    UV_pupil_pred = UV_pupil_pred.contiguous()

    if w_diameter:
        # Built once instead of once per sample.
        mse = torch.nn.MSELoss(reduction='mean')

    loss_dict = {}

    def add_term(key, value):
        # Accumulate per-sample terms; a key only appears once its term was computed.
        loss_dict[key] = loss_dict.get(key, 0.0) + value

    # TODO remove this loop after testing (see the vectorized variant).
    n_i = args['batch_size'] * args['frames']
    for i in range(n_i):
        # GT iris / pupil pixel coordinates as (x, y) rows.
        y, x = torch.where(gt_mask[i] == iris_idx)
        UV_iris_gt = torch.stack((x, y), dim=1).float()
        y, x = torch.where(gt_mask[i] == pupil_idx)
        UV_pupil_gt = torch.stack((x, y), dim=1).float()

        # GT pixels -> nearest predicted point.
        if w_gt_2_pred:
            loss_iris, _ = nearest_target_distance(P_source=UV_iris_gt,
                                                   P_target=UV_iris_pred[i],
                                                   gpu_resources=faiss_gpu_res)
            loss_pupil, _ = nearest_target_distance(P_source=UV_pupil_gt,
                                                    P_target=UV_pupil_pred[i],
                                                    gpu_resources=faiss_gpu_res)
            add_term('iris_gt_2_pred', loss_iris * w_gt_2_pred)
            add_term('pupil_gt_2_pred', loss_pupil * w_gt_2_pred)

        # Predicted points -> nearest GT pixel.
        if w_pred_2_gt:
            loss_iris, _ = nearest_target_distance(P_source=UV_iris_pred[i],
                                                   P_target=UV_iris_gt,
                                                   gpu_resources=faiss_gpu_res)
            loss_pupil, _ = nearest_target_distance(P_source=UV_pupil_pred[i],
                                                    P_target=UV_pupil_gt,
                                                    gpu_resources=faiss_gpu_res)
            add_term('iris_pred_2_gt', loss_iris * w_pred_2_gt)
            add_term('pupil_pred_2_gt', loss_pupil * w_pred_2_gt)

        # Predicted edge points -> nearest GT edge pixel.
        if w_edge:
            y, x = torch.where(iris_edge_gt[i])
            UV_iris_gt_edge = torch.stack((x, y), dim=1).float()
            loss_iris, _ = nearest_target_distance(P_source=UV_iris_pred[i, side_idx_iris['sides']],
                                                   P_target=UV_iris_gt_edge,
                                                   gpu_resources=faiss_gpu_res)
            y, x = torch.where(pupil_edge_gt[i])
            UV_pupil_gt_edge = torch.stack((x, y), dim=1).float()
            loss_pupil, _ = nearest_target_distance(P_source=UV_pupil_pred[i, side_idx_pupil['sides']],
                                                    P_target=UV_pupil_gt_edge,
                                                    gpu_resources=faiss_gpu_res)
            add_term('iris_pred_2_gt_edge', loss_iris * w_edge)
            add_term('pupil_pred_2_gt_edge', loss_pupil * w_edge)

        # Horizontal extent of the GT region vs. the distance between two fixed
        # template points (the 'left' indices 4 and 23 are kept from the original).
        if w_diameter:
            diameter_iris_gt = torch.max(UV_iris_gt[..., 0]) - torch.min(UV_iris_gt[..., 0])
            diameter_iris_pred = (UV_iris_pred[i, side_idx_iris['right'][0], 0]
                                  - UV_iris_pred[i, side_idx_iris['left'][4], 0])
            loss_iris = torch.sqrt(mse(diameter_iris_pred, diameter_iris_gt))

            diameter_pupil_gt = torch.max(UV_pupil_gt[..., 0]) - torch.min(UV_pupil_gt[..., 0])
            diameter_pupil_pred = (UV_pupil_pred[i, side_idx_pupil['right'][0], 0]
                                   - UV_pupil_pred[i, side_idx_pupil['left'][23], 0])
            loss_pupil = torch.sqrt(mse(diameter_pupil_pred, diameter_pupil_gt))

            add_term('iris_diameter', loss_iris * w_diameter)
            add_term('pupil_diameter', loss_pupil * w_diameter)

    # Average every term over the processed samples (explicit n_i instead of the
    # leaked loop variable) and sum them up.
    total_loss = 0.0
    for key in loss_dict:
        loss_dict[key] = loss_dict[key] / n_i
        total_loss = total_loss + loss_dict[key]
    return total_loss, loss_dict

def loss_fn_rend_sprvs(gt_dict, pred_dict, args):
    """Supervised losses between rendered predictions and ground-truth eye parameters.

    Args:
        gt_dict: provides 'eyeball' and 'gaze_vector'.  On the last axis of
            'eyeball', element 0 is used as a scale factor (presumably the
            eyeball radius - confirm against the data loader) and elements 1:3
            as the UV centre; the first two components of 'gaze_vector' are
            used for the UV projection.
        pred_dict: provides 'gaze_vector_3D' (also defines the target device),
            'eyeball_c_UV' and 'pupil_c_UV'.
        args: mapping with the weights 'loss_w_supervise_eyeball_center',
            'loss_w_supervise_pupil_center', 'loss_w_supervise_gaze_vector_3D_L2',
            'loss_w_supervise_gaze_vector_3D_cos_sim' and
            'loss_w_supervise_gaze_vector_UV'.  A falsy weight disables the term.

    Returns:
        (total_loss, loss_dict): the sum of the enabled weighted terms (``0.0``
        when none is enabled) and the individual weighted terms keyed by name.
    """
    def rmse(x, y):
        # Mean Euclidean distance, vectors living on the last axis.
        return (x - y).square().sum(-1).sqrt().mean()

    def cos_dist(x, y):
        # 1 - mean cosine similarity. The similarity is taken along the last
        # axis, like rmse above; the default dim=1 is only the component axis
        # for 2-D inputs and compared across the wrong axis for 3-D inputs.
        return 1 - F.cosine_similarity(x, y, dim=-1).mean()

    def gt_pupil_center_uv():
        # GT pupil centre = eyeball centre + scale * in-plane gaze components.
        eyeball = gt_dict['eyeball']
        return eyeball[..., 1:3] + (eyeball[..., 0:1] * gt_dict['gaze_vector'][..., :2])

    loss_dict = {}
    device = pred_dict['gaze_vector_3D'].device

    # Eyeball centre in UV space.
    w_eyeball = args['loss_w_supervise_eyeball_center']
    if w_eyeball:
        gt_eyeball_c_UV = gt_dict['eyeball'][..., 1:3].to(device)
        loss_dict['eyeball_c_UV'] = rmse(gt_eyeball_c_UV, pred_dict['eyeball_c_UV']) * w_eyeball

    # Pupil centre in UV space.
    w_pupil = args['loss_w_supervise_pupil_center']
    if w_pupil:
        gt_pupil_c_UV = gt_pupil_center_uv().to(device)
        loss_dict['pupil_c_UV'] = rmse(gt_pupil_c_UV, pred_dict['pupil_c_UV']) * w_pupil

    # 3-D gaze vector, Euclidean distance.
    w_gaze_l2 = args['loss_w_supervise_gaze_vector_3D_L2']
    if w_gaze_l2:
        gt_gaze_vector_3D = gt_dict['gaze_vector'].to(device)
        loss_dict['gaze_vector_3D_L2'] = rmse(gt_gaze_vector_3D, pred_dict['gaze_vector_3D']) * w_gaze_l2

    # 3-D gaze vector, cosine distance.
    w_gaze_cos = args['loss_w_supervise_gaze_vector_3D_cos_sim']
    if w_gaze_cos:
        gt_gaze_vector_3D = gt_dict['gaze_vector'].to(device)
        loss_dict['gaze_vector_3D_cos_sim'] = cos_dist(gt_gaze_vector_3D, pred_dict['gaze_vector_3D']) * w_gaze_cos

    # Normalised in-plane gaze direction (pupil centre minus eyeball centre).
    w_gaze_uv = args['loss_w_supervise_gaze_vector_UV']
    if w_gaze_uv:
        gt_gaze_vector_UV = (gt_pupil_center_uv() - gt_dict['eyeball'][..., 1:3]).to(device)
        # The two epsilons (1e-9 for GT, 1e-5 for the prediction) are kept as in
        # the original so that the loss values do not change.
        gt_gaze_vector_UV = gt_gaze_vector_UV / (torch.norm(gt_gaze_vector_UV, dim=-1, keepdim=True) + 1e-9)

        pred_gaze_vector_UV = pred_dict['pupil_c_UV'] - pred_dict['eyeball_c_UV']
        pred_gaze_vector_UV = pred_gaze_vector_UV / (torch.norm(pred_gaze_vector_UV, dim=-1, keepdim=True) + 1e-5)

        loss_dict['gaze_vector_UV'] = rmse(gt_gaze_vector_UV, pred_gaze_vector_UV) * w_gaze_uv

    total_loss = sum(loss_dict.values(), 0.0)
    return total_loss, loss_dict


def _left_half_edge_uv(edge_response):
    """Return the x-sorted (x, y) edge pixels lying strictly left of the middle column.

    Edge pixels are the positions where ``edge_response`` is positive.  The
    middle column is the rounded mean of the smallest and largest edge x.
    Selecting ``x < middle`` yields the same rows as slicing the sorted list up
    to the first pixel on the middle column, but also works when no edge pixel
    lies exactly on that column (the slice-based version raised IndexError).
    """
    y, x = torch.where(edge_response > 0)
    uv = torch.stack((x, y), dim=1).float()
    xs = uv[:, 0]
    middle_location = torch.round((xs.min() + xs.max()) / 2)
    _, order = torch.sort(xs)
    uv_sorted_x = uv[order]
    return uv_sorted_x[uv_sorted_x[:, 0] < middle_location]


def rendered_semantics_loss_OLD_01_05(gt_mask, rend_dict, sobel_filter, faiss_gpu_res, args):
    """Legacy per-sample rendered-semantics loss (superseded by ``rendered_semantics_loss``).

    Args:
        gt_mask: label tensor indexed per sample; value 1 marks iris pixels and
            2 marks pupil pixels.
        rend_dict: provides 'pupil_UV' and 'iris_UV' (each asserted to hold 5000
            points with 2 features) and the index lookups 'edge_idx_iris' /
            'edge_idx_pupil' (keys 'whole', 'left', 'right' are read depending
            on the enabled terms).
        sobel_filter: callable applied to the binary masks to obtain GT edges.
        faiss_gpu_res: forwarded to ``nearest_target_distance`` as ``gpu_resources``.
        args: mapping with 'batch_size', 'frames' and the weights
            'loss_w_rend_gt_2_pred', 'loss_w_rend_pred_2_gt',
            'loss_w_rend_pred_2_gt_edge', 'loss_w_rend_diameter'.
            A falsy weight disables the corresponding term.

    Returns:
        (total_loss, loss_dict): each enabled weighted term averaged over the
        ``batch_size * frames`` processed samples, and their sum (``0.0`` when no
        term is enabled).
    """
    UV_pupil_pred = rend_dict['pupil_UV']
    UV_iris_pred = rend_dict['iris_UV']
    side_idx_iris = rend_dict['edge_idx_iris']
    side_idx_pupil = rend_dict['edge_idx_pupil']

    # The message is a plain string: ``assert cond, print(...)`` printed the text
    # and raised an AssertionError whose message was None.
    assert gt_mask.shape[0] == UV_pupil_pred.shape[0] == UV_iris_pred.shape[0]
    assert UV_iris_pred.shape[1] == 5000, 'iris not valid number of points'
    assert UV_pupil_pred.shape[1] == 5000, 'pupil not valid number of points'
    assert UV_iris_pred.shape[2] == 2, 'iris not valid number of features'
    assert UV_pupil_pred.shape[2] == 2, 'pupil not valid number of features'

    # Class labels used in gt_mask.
    iris_idx = 1
    pupil_idx = 2

    # GT edge responses (Sobel filtering of the binary region masks).
    iris_mask_gt = torch.where(gt_mask >= iris_idx, 1., 0.)  # TODO == iris_idx
    iris_edge_gt = sobel_filter(iris_mask_gt.unsqueeze(1)).squeeze(1)
    pupil_mask_gt = torch.where(gt_mask == pupil_idx, 1., 0.)
    pupil_edge_gt = sobel_filter(pupil_mask_gt.unsqueeze(1)).squeeze(1)

    UV_iris_pred = UV_iris_pred.contiguous()
    UV_pupil_pred = UV_pupil_pred.contiguous()

    w_gt_2_pred = args['loss_w_rend_gt_2_pred']
    w_pred_2_gt = args['loss_w_rend_pred_2_gt']
    w_edge = args['loss_w_rend_pred_2_gt_edge']
    w_diameter = args['loss_w_rend_diameter']

    if w_diameter:
        # Built once instead of once per sample.
        mse = torch.nn.MSELoss(reduction='mean')

    loss_dict = {}

    def add_term(key, value):
        # Accumulate per-sample terms; a key only appears once its term was computed.
        loss_dict[key] = loss_dict.get(key, 0.0) + value

    # TODO remove this loop after testing.
    n_i = args['batch_size'] * args['frames']
    for i in range(n_i):
        # Left halves of the GT edges. As in the original, they are extracted for
        # every sample regardless of the enabled terms, and only the LEFT half is
        # used as the target of the edge term below.
        UV_left_iris_gt_edge = _left_half_edge_uv(iris_edge_gt[i])
        UV_left_pupil_gt_edge = _left_half_edge_uv(pupil_edge_gt[i])

        # GT iris / pupil pixel coordinates as (x, y) rows.
        y, x = torch.where(gt_mask[i] == iris_idx)
        UV_iris_gt = torch.stack((x, y), dim=1).float()
        y, x = torch.where(gt_mask[i] == pupil_idx)
        UV_pupil_gt = torch.stack((x, y), dim=1).float()

        # GT pixels -> nearest predicted point.
        if w_gt_2_pred:
            loss_iris, _ = nearest_target_distance(P_source=UV_iris_gt,
                                                   P_target=UV_iris_pred[i],
                                                   gpu_resources=faiss_gpu_res)
            loss_pupil, _ = nearest_target_distance(P_source=UV_pupil_gt,
                                                    P_target=UV_pupil_pred[i],
                                                    gpu_resources=faiss_gpu_res)
            add_term('iris_gt_2_pred', loss_iris * w_gt_2_pred)
            add_term('pupil_gt_2_pred', loss_pupil * w_gt_2_pred)

        # Whole predicted template -> nearest GT pixel.
        if w_pred_2_gt:
            loss_iris, _ = nearest_target_distance(P_source=UV_iris_pred[i],
                                                   P_target=UV_iris_gt,
                                                   gpu_resources=faiss_gpu_res)
            loss_pupil, _ = nearest_target_distance(P_source=UV_pupil_pred[i],
                                                    P_target=UV_pupil_gt,
                                                    gpu_resources=faiss_gpu_res)
            add_term('iris_pred_2_gt', loss_iris * w_pred_2_gt)
            add_term('pupil_pred_2_gt', loss_pupil * w_pred_2_gt)

        # Predicted edge points ('whole' index set) -> nearest left-half GT edge pixel.
        if w_edge:
            loss_iris, _ = nearest_target_distance(P_source=UV_iris_pred[i, side_idx_iris['whole']],
                                                   P_target=UV_left_iris_gt_edge,
                                                   gpu_resources=faiss_gpu_res)
            loss_pupil, _ = nearest_target_distance(P_source=UV_pupil_pred[i, side_idx_pupil['whole']],
                                                    P_target=UV_left_pupil_gt_edge,
                                                    gpu_resources=faiss_gpu_res)
            add_term('iris_pred_2_gt_edge', loss_iris * w_edge)
            add_term('pupil_pred_2_gt_edge', loss_pupil * w_edge)

        # Horizontal extent of the GT region vs. the distance between two fixed
        # template points (the 'left' indices 4 and 23 are kept from the original).
        if w_diameter:
            diameter_iris_gt = torch.max(UV_iris_gt[..., 0]) - torch.min(UV_iris_gt[..., 0])
            diameter_iris_pred = (UV_iris_pred[i, side_idx_iris['right'][0], 0]
                                  - UV_iris_pred[i, side_idx_iris['left'][4], 0])
            loss_iris = torch.sqrt(mse(diameter_iris_pred, diameter_iris_gt))

            diameter_pupil_gt = torch.max(UV_pupil_gt[..., 0]) - torch.min(UV_pupil_gt[..., 0])
            diameter_pupil_pred = (UV_pupil_pred[i, side_idx_pupil['right'][0], 0]
                                   - UV_pupil_pred[i, side_idx_pupil['left'][23], 0])
            loss_pupil = torch.sqrt(mse(diameter_pupil_pred, diameter_pupil_gt))

            add_term('iris_diameter', loss_iris * w_diameter)
            add_term('pupil_diameter', loss_pupil * w_diameter)

    # Average every term over the processed samples (explicit n_i instead of the
    # leaked loop variable) and sum them up.
    total_loss = 0.0
    for key in loss_dict:
        loss_dict[key] = loss_dict[key] / n_i
        total_loss = total_loss + loss_dict[key]

    return total_loss, loss_dict