#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
This started as a copy of https://bitbucket.org/RSKothari/multiset_gaze/src/master/ 
with additional changes and modifications to adjust it to our implementation. 

Copyright (c) 2021 Rakshit Kothari, Aayush Chaudhary, Reynold Bailey, Jeff Pelz, 
and Gabriel Diaz
"""

#try git

import gc
import sys
import time
import traceback

import h5py
import math
from collections import OrderedDict

import matplotlib.pyplot as plt

import torch
import torch.nn.functional as F
import numpy as np
import faiss
import wandb

import cv2

from einops import rearrange
from torch import nn

from helperfunctions.loss import get_seg_loss, get_uncertain_l1_loss
from helperfunctions.hfunctions import assert_torch_invalid
# from helperfunctions.loss import get_l2c_loss

from helperfunctions.hfunctions import plot_images_with_annotations
from helperfunctions.hfunctions import convert_to_list_entries
from helperfunctions.hfunctions import merge_two_dicts, fix_batch
from helperfunctions.hfunctions import generate_rend_masks

from helperfunctions.utils import get_seg_metrics, get_distance, compute_norm
from helperfunctions.utils import getAng_metric, generate_pseudo_labels
from helperfunctions.utils import remove_underconfident_psuedo_labels

from Visualitation_TEyeD.gaze_estimation import generate_gaze_gt

from rendering.rendering import render_semantics, euler_to_rotation, eyeball_center
from rendering.rendered_semantics_loss import rendered_semantics_loss, rendered_semantics_loss_vectorized, SobelFilter, \
    rendered_semantics_loss_vectorized2

from rendering.rendered_semantics_loss import loss_fn_rend_sprvs
import os
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'

# Detach every torch tensor in a dict from the autograd graph, move it to the CPU and convert it to a NumPy array
def detach_cpu_numpy(data_dict):
    """Return a copy of ``data_dict`` with every tensor turned into a NumPy array.

    Tensor values are detached from the autograd graph, moved to the CPU and
    converted with ``.numpy()``. Every other value is passed through
    unchanged. The input dictionary itself is not modified.

    Args:
        data_dict: Mapping whose values may be ``torch.Tensor`` instances.

    Returns:
        dict: New dictionary with the same keys in the same order.
    """
    # isinstance() rather than searching for 'torch' in the type name: the
    # string match also hits non-tensor objects such as torch.Size,
    # torch.device or torch.dtype, which have no .detach() and would raise
    # AttributeError.
    return {
        key: value.detach().cpu().numpy() if isinstance(value, torch.Tensor) else value
        for key, value in data_dict.items()
    }

# Move every torch tensor in a dict to the given (GPU) device
def move_gpu(data_dict, device):
    """Return a copy of ``data_dict`` with every tensor moved to ``device``.

    Only ``torch.Tensor`` values are moved (via ``.to(device)``); all other
    values are passed through unchanged. The input dictionary is not modified.

    Args:
        data_dict: Mapping whose values may be ``torch.Tensor`` instances.
        device: Anything accepted by ``Tensor.to`` as a target device.

    Returns:
        dict: New dictionary with the same keys in the same order.
    """
    # isinstance() rather than searching for 'torch' in the type name: the
    # string match also hits torch.Size / torch.device / torch.dtype values,
    # which have no .to() and would raise AttributeError.
    return {
        key: value.to(device) if isinstance(value, torch.Tensor) else value
        for key, value in data_dict.items()
    }

# Move every torch tensor in a dict to the given device (CPU or GPU)
def send_to_device(data_dict, device):
    """Return a copy of ``data_dict`` with every tensor sent to ``device``.

    Only ``torch.Tensor`` values are transferred (via ``.to(device)``); all
    other values are passed through unchanged. The input dictionary is not
    modified.

    Args:
        data_dict: Mapping whose values may be ``torch.Tensor`` instances.
        device: Anything accepted by ``Tensor.to`` as a target device
            (CPU or GPU).

    Returns:
        dict: New dictionary with the same keys in the same order.
    """
    out_dict = {}
    for key, value in data_dict.items():
        # A real type check instead of matching 'torch' in the type name,
        # which would also select torch.Size / torch.device / torch.dtype
        # objects that cannot be moved with .to().
        if isinstance(value, torch.Tensor):
            out_dict[key] = value.to(device)
        else:
            out_dict[key] = value
    return out_dict


def gaze_loss(pred, target):
    """Combined MSE + angular loss between predicted and target gaze vectors.

    The loss is ``2.5 * MSE(pred, target) + 2.5 * mean(acos(<pred, target>))``
    where the dot product is taken over the last dimension. The vectors are
    not normalised here, so the angular term is only a true angle for
    unit-length inputs.

    Args:
        pred: Predicted gaze vectors. If this is a tensor it is used as-is
            (cast to float32) so that gradients flow back to whatever
            produced it. Array-likes are converted to a new leaf tensor.
        target: Ground-truth gaze vectors (tensor or array-like). Treated as
            a constant and moved to the device of ``pred``.

    Returns:
        torch.Tensor: Scalar loss.
    """
    if torch.is_tensor(pred):
        # Do NOT rebuild the tensor with torch.tensor(): that copies the data
        # and detaches it from the autograd graph, so the network would never
        # receive a gradient from this loss.
        pred = pred.float()
    else:
        fallback_device = 'cuda' if torch.cuda.is_available() else 'cpu'
        pred = torch.tensor(pred, dtype=torch.float32, device=fallback_device,
                            requires_grad=True)

    if torch.is_tensor(target):
        target = target.detach().to(device=pred.device, dtype=torch.float32)
    else:
        target = torch.tensor(target, dtype=torch.float32, device=pred.device)

    # Mean squared error term
    mse_loss = nn.MSELoss()(pred, target)

    # Angular error term. The derivative of acos is infinite at +/-1, so the
    # cosine is clamped slightly inside the open interval to keep the
    # backward pass finite.
    eps = 1e-7
    cosine = torch.sum(pred * target, dim=-1)
    cos_loss = torch.acos(torch.clamp(cosine, -1.0 + eps, 1.0 - eps)).mean()

    return 2.5 * mse_loss + 2.5 * cos_loss


def evaluate_gaze_metrics(pred_gaze, true_gaze, metic_dict):
    """Compute angular gaze-estimation metrics and store them in ``metic_dict``.

    Both inputs are normalised along their last dimension, the per-sample
    angle between them is computed in degrees, and summary statistics are
    written into ``metic_dict`` under the keys ``Mean_Angular_Error``,
    ``Angular_Error_Std``, ``Cosine_Similarity``, ``DA@5``, ``DA@10`` and
    ``DA@15`` (fraction of samples with an error below 5/10/15 degrees).

    :param pred_gaze: predicted gaze directions, vector components in the
        last dimension (documented by the original author as (N, 3))
    :param true_gaze: ground-truth gaze directions, same layout
    :param metic_dict: dictionary that receives the metrics (mutated in place)
    :return: the same ``metic_dict`` object
    """
    eps = 1e-9  # guards the division for zero-length vectors

    unit_pred = pred_gaze / (torch.norm(pred_gaze, dim=-1, keepdim=True) + eps)
    unit_true = true_gaze / (torch.norm(true_gaze, dim=-1, keepdim=True) + eps)

    # Cosine similarity, clamped so rounding cannot push it outside acos' domain
    cos = torch.clamp(torch.sum(unit_pred * unit_true, dim=-1), -1.0, 1.0)

    # Angular error in degrees
    err_deg = torch.acos(cos) * (180 / torch.pi)

    def _share_below(limit_deg):
        # Fraction of samples whose angular error is under the limit
        return (err_deg < limit_deg).float().mean().item()

    summary = (
        ('Mean_Angular_Error', err_deg.mean().item()),
        ('Angular_Error_Std', err_deg.std().item()),
        ('Cosine_Similarity', cos.mean().item()),
        ('DA@5', _share_below(5)),
        ('DA@10', _share_below(10)),
        ('DA@15', _share_below(15)),
    )
    for name, value in summary:
        metic_dict[name] = value

    return metic_dict
    
def load_isomap_gt(file_path, device="cuda"):
    """Load a saved Isomap ground-truth array as a float32 tensor.

    Args:
        file_path: Path to an ``.npz`` archive containing an array stored
            under the key ``"gt"``.
        device: Device the resulting tensor is moved to.

    Returns:
        torch.Tensor: The ``"gt"`` array as a float32 tensor on ``device``.
    """
    # np.load keeps the .npz archive open until the returned NpzFile is
    # closed; the context manager releases the file handle as soon as the
    # array has been read into memory.
    with np.load(file_path) as data:
        gt = torch.from_numpy(data["gt"]).float().to(device)
    print(f"Loaded Isomap GT with shape: {gt.shape}")
    return gt

def forward(net,
            spiker,
            logger,
            loader,
            optimizer,
            args,
            path_dict,
            epoch=0,
            mode='test',
            writer=[],
            rank_cond=False,
            optimizer_disc=False,
            batches_per_ep=2000,
            last_epoch_valid = False,
            csv_save_dir=None):
    """Run ``batches_per_ep`` batches through ``net`` and aggregate the metrics.

    For every batch the function fetches data from ``loader`` (restarting the
    iterator when it is exhausted), runs the network, optionally renders the
    eye semantics and computes the rendering / supervised losses,
    back-propagates the gaze loss when ``mode == 'train'``, collects per-batch
    metrics, logs them to wandb and periodically saves example outputs.

    Args:
        net: Model. Called as ``net(data_dict, args, projected_features_GT)``
            and expected to return ``(out_dict_gaze, out_dict_eye)``.
        spiker: Optional loss-spike detector exposing ``update(value)``; a
            falsy value disables spike detection.
        logger: Object exposing ``write(msg, do_warn=...)``.
        loader: Iterable yielding batch dictionaries.
        optimizer: Optimizer; stepped in train mode and zeroed on every batch.
        args: Dictionary of run options (loss weights, head switches,
            ``batch_size``, ``frames``, ``mixed_precision``, ...). The key
            ``time_to_update`` is set to True whenever the loader restarts.
        path_dict: Dictionary of output paths; ``path_dict['results']`` is
            used for the test HDF5 file and the dict is passed to ``save_out``.
        epoch: Epoch index, used for logging only.
        mode: ``'train'`` enables training; any other value runs the network
            in eval mode without gradients. ``'test'`` additionally forces
            rendered-mask generation and may create the test HDF5 file.
        writer: Unused.
        rank_cond: Unused.
        optimizer_disc: Optional second optimizer; only ``zero_grad`` is
            called on it.
        batches_per_ep: Number of batches to process.
        last_epoch_valid: If True, rendered masks are generated on every batch.
        csv_save_dir: Optional existing directory; when given, the raw results
            are written to ``<csv_save_dir>/<mode>_raw_results.csv`` by
            ``aggregate_metrics``.

    Returns:
        The value returned by ``aggregate_metrics`` for all processed batches.
    """
    device = next(net.parameters()).device
    print(device)

    # The faiss GPU resource is intentionally not created; None is forwarded
    # to the rendering loss.
    print('Using faiss CPU resource.')
    faiss_gpu_res = None

    # Sobel filter used by the rendering loss
    sobel_filter = SobelFilter(device)

    logger.write('----{}. Epoch: {}----'.format(mode, epoch))

    # Put the network into training or evaluation mode
    if mode == 'train':
        net.train()
    else:
        net.eval()

    io_time = []
    loader_iter = iter(loader)
    metrics = []
    available_predicted_mask = False

    train_with_mask = args['loss_w_rend_pred_2_gt_edge'] or args['loss_w_rend_gt_2_pred'] \
                        or args['loss_w_rend_pred_2_gt']

    if (mode == 'test') and args['save_test_maps']:
        # NOTE(review): the per-sample writing code is currently disabled, so
        # this file is only created here and closed at the end.
        logger.write('Generating test object')
        test_results_obj = h5py.File(path_dict['results']+'/test_results.h5',
                                     'w', swmr=True)

    # Only the ground truth needed for this mode is loaded, and it is placed
    # on the same device as the network instead of the default "cuda".
    if mode == 'train':
        isomap_gt = load_isomap_gt("./isomap_gt/isomap_gt_epoch0.npz", device=device)
    else:
        isomap_gt = load_isomap_gt("./isomap_gt/isomap_gt_valid_epoch0.npz", device=device)

    # Rendering-head outputs that must stay inside [-1, 1]
    bounded_keys = ('T', 'R', 'L', 'focal', 'r_pupil', 'r_iris')

    # Sanity checks on the network output before rendering: (key, test, message).
    # NOTE(review): 'T' is only checked for inf, not NaN — kept as in the
    # original.
    pre_render_checks = (
        ('gaze_vector_3D', torch.isnan, 'NaN gaze_vector_3D BEFORE FUNCTION'),
        ('T', torch.isinf, 'inf problem T inf before function'),
        ('R', torch.isnan, 'NaN problem R BEFORE FUNCTION'),
        ('R', torch.isinf, 'inf problem R inf before function'),
    )

    # Sanity checks on the rendered points: (key, message)
    post_render_checks = (
        ('pupil_UV', 'invalid pupil from rendering points'),
        ('iris_UV', 'invalid iris from rendering points'),
        ('pupil_c_UV', 'invalid pupil center'),
        ('eyeball_c_UV', 'invalid eyeball center'),
    )

    def _discard_batch(reason):
        # Drop any accumulated gradients and report why the batch is skipped
        optimizer.zero_grad()
        net.zero_grad()
        print(reason)

    for bt in range(batches_per_ep):

        start_time = time.time()

        try:
            data_dict = next(loader_iter)
        except StopIteration:
            # Loader exhausted: restart it and flag that an update is due.
            # Only StopIteration is handled so that real loader errors and
            # KeyboardInterrupt are no longer silently swallowed.
            print('Loader reset')
            loader_iter = iter(loader)
            data_dict = next(loader_iter)
            args['time_to_update'] = True
        else:
            print('----epoch:{}----batch:{}/{}'.format(epoch,bt,batches_per_ep))

        if torch.any(data_dict['is_bad']):
            logger.write('Bad batch found!', do_warn=True)

            # DDP crashes when a batch is skipped because gradients no longer
            # match across ranks. Bad samples are therefore replaced by
            # fix_batch instead of skipping the batch.
            data_dict = fix_batch(data_dict)

        end_time = time.time()
        io_time.append(end_time - start_time)

        if args['do_distributed']:
            torch.distributed.barrier()

        if args['cur_obj'] != 'Ours':
            # Ellipse orientation angles must lie in [0, 2*pi]
            assert torch.all(data_dict['pupil_ellipse'][:, :, -1] >= 0), 'pupil ellipse orientation >= 0'
            assert torch.all(data_dict['pupil_ellipse'][:, :, -1] <= 2 * (np.pi)), 'pupil ellipse orientation <= 2*pi'
            assert torch.all(data_dict['iris_ellipse'][:, :, -1] >= 0), 'iris ellipse orientation >= 0'
            assert torch.all(data_dict['iris_ellipse'][:, :, -1] <= 2 * (np.pi)), 'iris ellipse orientation <= 2*pi'

        with torch.autograd.set_detect_anomaly(bool(args['detect_anomaly'])):
            with torch.cuda.amp.autocast(enabled=bool(args['mixed_precision'])):
                batch_results_rend = {}
                batch_results_gaze = {}
                batch_size = args['batch_size']
                frames = args['frames']

                # Slice of the precomputed Isomap GT for this batch.
                # NOTE(review): assumes the loader yields samples in the order
                # the GT was stored and never restarts mid-epoch — confirm.
                start_idx = bt * batch_size * frames
                end_idx = start_idx + (batch_size * frames)
                projected_features_GT = isomap_gt[start_idx:end_idx]

                if mode == 'train':
                    print("train net begin")
                    out_dict_gaze, out_dict_eye = net(data_dict, args, projected_features_GT)
                    # nn.Module has no generic `.device`; report the device
                    # of the parameters determined above.
                    print("train net finish", device)
                else:
                    with torch.no_grad():
                        out_dict_gaze, out_dict_eye = net(data_dict, args, projected_features_GT)
                        print("valid net finish")

                # Skip batches whose input images are entirely zero
                if torch.all(data_dict['image'] == 0):
                    _discard_batch('invalid input image')
                    continue

                # Reshape the ground truth (merges batch and frame dimensions)
                data_dict = reshape_gt(data_dict, args)

                H = data_dict['image'].shape[1]
                W = data_dict['image'].shape[2]
                image_resolution_diagonal = math.sqrt(H ** 2 + W ** 2)

                if args['net_rend_head']:
                    # Skip the batch if any bounded prediction leaves [-1, 1]
                    if any(torch.any(out_dict_eye[key] < -1) or torch.any(out_dict_eye[key] > 1)
                           for key in bounded_keys):
                        _discard_batch('invalid predicted values from rend head')
                        continue

                    # Skip the batch on NaN / inf predictions
                    skip_reason = None
                    for key, is_invalid, reason in pre_render_checks:
                        if is_invalid(out_dict_eye[key]).any():
                            skip_reason = reason
                            break
                    if skip_reason is not None:
                        _discard_batch(skip_reason)
                        continue

                    # Render the eye semantics from the predicted parameters
                    out_dict_eye, rend_dict = render_semantics(out_dict_eye, H=H, W=W, args=args, data_dict=data_dict)

                    # Skip the batch if rendering produced NaN / inf points
                    skip_reason = None
                    for key, reason in post_render_checks:
                        rendered = rend_dict[key]
                        if torch.isnan(rendered).any() or torch.isinf(rendered).any():
                            skip_reason = reason
                            break
                    if skip_reason is not None:
                        _discard_batch(skip_reason)
                        continue

                    # Relabel class 3 as class 2 in the ground-truth mask
                    data_dict['mask'][data_dict['mask'] == 3] = 2

                    data_dict = send_to_device(data_dict, device)

                    if train_with_mask:
                        # Choose the rendering-loss implementation
                        if args['loss_rend_vectorized']:
                            loss_fn_rend = rendered_semantics_loss_vectorized2
                        else:
                            loss_fn_rend = rendered_semantics_loss

                        # Loss between rendered pupil/iris and the GT mask
                        total_loss_rend, loss_dict_rend = loss_fn_rend(data_dict['mask'],
                                                                       rend_dict,
                                                                       sobel_filter,
                                                                       faiss_gpu_res,
                                                                       None,
                                                                       args)

                        iterations = args['batch_size'] * args['frames']

                        # Rendered masks are only produced periodically, on
                        # the last validation epoch, or in test mode
                        if (bt % args['produce_rend_mask_per_iter'] == 0 or last_epoch_valid \
                                or (mode == 'test')):
                            available_predicted_mask = True

                            rend_dict['eyeball_circle'] = eyeball_center(out_dict_eye,
                                                                         H=H,
                                                                         W=W,
                                                                         args=args)

                            rend_dict = generate_rend_masks(rend_dict, H, W, iterations)

                            rend_dict['mask'] = torch.argmax(rend_dict['predict'], dim=1)
                            rend_dict['gaze_img'] = rend_dict['mask_gaze']

                            rend_dict['mask'] = torch.clamp(rend_dict['mask'], min=0, max=255)
                            rend_dict['gaze_img'] = torch.clamp(rend_dict['gaze_img'], min=0, max=255)

                            # Store the masks as NumPy arrays
                            rend_dict['mask'] = rend_dict['mask'].detach().cpu().numpy()
                            rend_dict['gaze_img'] = rend_dict['gaze_img'].detach().cpu().numpy()
                        else:
                            available_predicted_mask = False

                        if torch.is_tensor(total_loss_rend):
                            total_loss_rend_value = total_loss_rend.item()
                            # NOTE(review): this result is overwritten by the
                            # `is_spike = False` reset below; only the
                            # spiker's internal state is affected here.
                            is_spike = spiker.update(total_loss_rend_value) if spiker else False
                        else:
                            total_loss_rend_value = total_loss_rend

                        # Metrics of the rendered result
                        batch_results_rend = get_metrics_rend(detach_cpu_numpy(rend_dict),
                                                            detach_cpu_numpy(data_dict),
                                                            batch_results_rend,
                                                            image_resolution_diagonal,
                                                            args,
                                                            available_predicted_mask)
                        batch_results_rend['loss/rend_total'] = total_loss_rend_value
                        for k in loss_dict_rend:
                            batch_results_rend[f'loss/rend_{k}'] = loss_dict_rend[k].item()
                    else:
                        total_loss_rend = 0.0
                else:
                    # No rendering head: neutral placeholders
                    total_loss_rend = 0.0
                    total_loss_rend_value = 0.0
                    loss_dict_rend = {}
                    rend_dict = {}

                # Supervised losses on the 3D eye model and/or the gaze branch
                if args['loss_w_supervise']:
                    if args['net_rend_head']:
                        total_supervised_loss_eye, loss_dict_supervised_eye = loss_fn_rend_sprvs(data_dict,
                                                                                        rend_dict,
                                                                                        args)
                        batch_results_rend[f'loss/eyeBranch_total'] = total_supervised_loss_eye.item()
                        for k in loss_dict_supervised_eye:
                            batch_results_rend[f'loss/eyeBranch_{k}'] = loss_dict_supervised_eye[k].item()

                    if args['net_simply_head']:
                        total_supervised_loss_gaze, loss_dict_supervised_gaze = loss_fn_rend_sprvs(data_dict,
                                                                                        out_dict_gaze,
                                                                                        args)
                        batch_results_gaze[f'loss/gazeBranch_total'] = total_supervised_loss_gaze.item()
                        for k in loss_dict_supervised_gaze:
                            batch_results_gaze[f'loss/gazeBranch_{k}'] = loss_dict_supervised_gaze[k].item()

                is_spike = False

            # Metrics of the simple (gaze-only) head
            if args['loss_w_supervise']:
                if args['net_simply_head'] and (args['net_rend_head'] == False):
                    batch_results_gaze = get_metrics_simple(out_dict_gaze,
                                                              move_gpu(data_dict, out_dict_gaze['gaze_vector_3D'].device),
                                                              batch_results_gaze,
                                                              image_resolution_diagonal,
                                                              args)

            # Assemble the losses depending on which heads are enabled.
            # NOTE(review): `loss_gaze` is only defined when both
            # args['net_simply_head'] and args['loss_w_supervise'] are set,
            # yet it is used unconditionally further down — other
            # configurations raise NameError there.
            if args['net_rend_head'] and (args['net_simply_head'] == False):
                # Rendering head only
                loss_eye = total_loss_rend
                if args['loss_w_supervise']:
                    loss_eye += args['loss_w_supervise_eye'] * total_supervised_loss_eye
                    is_spike = spiker.update(loss_eye.item()) if spiker else False
            elif args['net_rend_head'] == False and args['net_simply_head']:
                # Gaze head only
                if args['loss_w_supervise']:
                    loss_gaze = args['loss_w_supervise_gaze'] * total_supervised_loss_gaze
                    is_spike = spiker.update(loss_gaze.item()) if spiker else False
            elif args['net_rend_head'] and args['net_simply_head']:
                # Both heads
                loss_eye = total_loss_rend
                if args['loss_w_supervise']:
                    loss_eye += args['loss_w_supervise_eye'] * total_supervised_loss_eye
                    loss_gaze = args['loss_w_supervise_gaze'] * total_supervised_loss_gaze
                    loss_total = loss_gaze + loss_eye
                    is_spike = spiker.update(loss_total.item()) if spiker else False

            if mode == 'train':
                # Adversarial (adv_DG) training and feature-extractor
                # freezing are currently disabled.
                if args['net_simply_head']:
                    # NOTE(review): only the gaze loss is back-propagated;
                    # loss_eye / loss_total are computed but never used for
                    # the update.
                    loss_gaze.backward()

                    if not is_spike:
                        # Optional gradient clipping
                        if args['grad_clip_norm'] > 0:
                            torch.nn.utils.clip_grad_norm_(net.parameters(),
                                                           max_norm=args['grad_clip_norm'],
                                                           norm_type=2)

                        print("Gaze loss:", loss_gaze.item())
                        optimizer.step()
                    else:
                        # Loss spike: skip the optimizer step for this batch
                        print('-------------')
                        print('Spike detected! Loss: {}'.format(loss_gaze.item()))
                        print('-------------')

            # Zero out gradients no matter what
            optimizer.zero_grad()
            net.zero_grad()

            if optimizer_disc:
                optimizer_disc.zero_grad()

        if args['do_distributed']:
            torch.cuda.synchronize()

        # Merge the metrics of both branches
        batch_results = merge_two_dicts(detach_cpu_numpy(batch_results_rend),
                                        detach_cpu_numpy(batch_results_gaze))

        # Record the loss and the per-batch metrics
        batch_results['loss'] = loss_gaze.item()
        batch_metrics = aggregate_metrics([batch_results])
        metrics.append(batch_results)

        if args['exp_name'] != 'DEBUG':
            log_wandb(batch_metrics, rend_dict, data_dict, out_dict_eye, loss_gaze,
                        available_predicted_mask, mode, epoch, bt, H, W, args)

        # Periodically save example outputs (train, validation and test)
        if (available_predicted_mask and args['net_rend_head'] and \
                     bt % args['produce_rend_mask_per_iter'] == 0):
            gt_dict = {}
            gt_dict['image'] = data_dict['image']
            gt_dict['mask'] = data_dict['mask']
            gt_dict['iris_ellipse'] = data_dict['iris_ellipse']
            save_out(gt_dict, out_dict_eye, rend_dict, data_dict['image'], path_dict, mode,
                                        is_spike, args, epoch, bt)
        elif (bt % args['produce_rend_mask_per_iter'] ==0 and train_with_mask):
            # No rendered mask available for this batch
            gt_dict = {}
            gt_dict['image'] = data_dict['image']
            gt_dict['mask'] = data_dict['mask']
            gt_dict['iris_ellipse'] = data_dict['iris_ellipse']
            save_out(gt_dict, out_dict_eye, None, data_dict['image'], path_dict, mode,
                                        is_spike, args, epoch, bt)

        # Explicitly free per-batch objects
        del out_dict_eye
        del out_dict_gaze
        del rend_dict
        del batch_results_gaze
        del batch_results_rend
        torch.cuda.empty_cache()

    # Raw results are written to CSV only when an existing directory was
    # given. The path defaults to None so it is always bound, including when
    # csv_save_dir is set but does not exist.
    csv_save_path = None
    if csv_save_dir is not None and os.path.isdir(csv_save_dir):
        csv_save_path = os.path.join(csv_save_dir, f'{mode}_raw_results.csv')

    results_dict = aggregate_metrics(metrics, csv_save_path)

    if (mode == 'test') and args['save_test_maps']:
        test_results_obj.close()

    # Release the loader iterator and cached memory
    del loader_iter

    torch.cuda.empty_cache()
    gc.collect()

    return results_dict


# Log the per-batch metrics and image data to wandb during training
def log_wandb(batch_metrics, rend_dict, data_dict, out_dict, loss,
              available_predicted_mask, mode, epoch, bt, H, W, args):
    # 在wandb中记录训练结果
    class_labels = {0: 'background', 1: 'iris', 2: 'pupil'}
    class_labels_gaze = {0: 'background', 1: 'gaze'}
    # 定义类别标签和凝视类别标签

    # 每隔100个批次将所有指标记录到wandb中
    if (bt % 100 == 0):
        for key, item in batch_metrics.items():
            if ('mean' in key) and (not np.isnan(item).any()):
                # 记录包含'mean'的键且值不是NaN的指标
                if ('pupil_c_px_dst' in key) or ('eyeball_c_px_dist' in key) or \
                    ('gaze_ang_deg' in key) or ('loss' in key) or ('score' in key) or \
                    ('rendering_iou' in key) or ('masked_rendering_iou' in key) or \
                    ('norm' in key) or ('gaze' in key):
                    # 记录特定关键字的指标
                    wandb.log({'{}/{}'.format(mode, key): item, 'epoch': epoch, 'batch': bt}, commit=False)

        # 如果启用了ellseg_head和loss_w_ellseg
        if args['net_ellseg_head'] and args['loss_w_ellseg']:
            if bt % args['produce_rend_mask_per_iter'] == 0:
                # 每隔指定的批次数记录ellseg_mask
                mask_img = wandb.Image(data_dict['image'][0].detach().cpu().numpy(),
                                        masks={
                                                "predictions": {
                                                    "mask_data": out_dict['mask'][0],
                                                    "class_labels": class_labels
                                                },
                                                "ground_truth": {
                                                    "mask_data": data_dict['mask'][0].detach().cpu().numpy(),
                                                    "class_labels": class_labels
                                                }
                                        })
                wandb.log({'{}/ellseg_mask'.format(mode): mask_img, 'epoch': epoch, 'batch': bt}, commit=False)

        # 如果启用了rend_head
        if args['net_rend_head']:
            if (available_predicted_mask):
                # 如果预测掩码可用，记录rend_mask
                mask_img = wandb.Image(data_dict['image'][0].detach().cpu().numpy(),
                                        masks={
                                            "predictions": {
                                                "mask_data": np.clip(rend_dict['mask'][0], 0, 255),
                                                "class_labels": class_labels
                                            },
                                            "ground_truth": {
                                                "mask_data": np.clip(data_dict['mask'][0].detach().cpu().numpy(), 0, 255),
                                                "class_labels": class_labels
                                            }
                                        })
                wandb.log({'{}/rend_mask'.format(mode): mask_img, 'epoch': epoch, 'batch': bt}, commit=False)
                if 'TEyeD' in args['cur_obj']:
                    # 如果当前对象是TEyeD，生成并记录gaze_mask的ground_truth
                    gaze_mask_gt = generate_gaze_gt(data_dict['eyeball'][0].detach().cpu().numpy(),
                                            data_dict['gaze_vector'][0].detach().cpu().numpy(),
                                            H, W)
                    gaze_img = wandb.Image(data_dict['image'][0].detach().cpu().numpy(),
                                            masks={
                                                "predictions": {
                                                    "mask_data": rend_dict['gaze_img'][0],
                                                    "class_labels": class_labels_gaze
                                                },
                                                "ground_truth": {
                                                    "mask_data": gaze_mask_gt,
                                                    "class_labels": class_labels_gaze
                                                }
                                            })
                else:
                    # 否则，仅记录gaze_mask的预测
                    gaze_img = wandb.Image(data_dict['image'][0].detach().cpu().numpy(),
                                            masks={
                                                "predictions": {
                                                    "mask_data": rend_dict['gaze_img'][0],
                                                    "class_labels": class_labels_gaze
                                                }
                                            })
                wandb.log({'{}/gaze_mask'.format(mode): gaze_img, 'epoch': epoch, 'batch': bt}, commit=False)

        # 最后记录损失值
        wandb.log({'{}/loss'.format(mode): loss.item(), 'epoch': epoch, 'batch': bt})

# Save model outputs as figures: ellseg results, rendering results and ground truth.
def save_out(gt_dict, ellseg_dict, rend_dict, image, path_dict, mode,
             is_spike, args, epoch, bt):
    """Write diagnostic figures for the enabled network heads.

    Args:
        gt_dict: ground-truth entries to plot next to the rendering output.
        ellseg_dict: ellseg-head outputs; gets an 'image' entry added in place.
        rend_dict: rendering-head outputs, or ``None`` when no rendering is
            available; gets an 'image' entry added in place.
        image: input images attached to the dictionaries before plotting.
        path_dict: paths dictionary forwarded to ``save_plot``.
        mode: run mode, used in the output file name.
        is_spike: whether this batch was flagged as a spike (selects the
            '<mode>_spike' file-name scheme in ``save_plot``).
        args: configuration dictionary.
        epoch, bt: epoch and batch indices, used in the output file name.
    """
    if args['net_ellseg_head'] and args['loss_w_ellseg']:
        ellseg_dict['image'] = image
        # `rendering` and `is_list_of_entries` are required by save_plot and
        # were previously omitted here, which raised a TypeError.
        save_plot(ellseg_dict, path_dict, mode, 'ellseg', True,
                  is_spike, args, epoch, bt, False, False)

    # Callers may pass rend_dict=None when no rendered mask exists for the
    # batch; there is nothing to plot for the rendering head in that case.
    if not args['net_rend_head'] or rend_dict is None:
        return

    rend_dict['image'] = image
    # Rendering result with the mask overlay.
    save_plot(rend_dict, path_dict, mode, 'rend', True,
              is_spike, args, epoch, bt, True, False)
    # Ground truth with the mask overlay.
    save_plot(gt_dict, path_dict, mode, 'gt', True,
              is_spike, args, epoch, bt, False, False)

    # Gaze (mask-free) plot only for these datasets. The previous test
    # `'TEyeD' or 'UEGaze' in ...` was always true because the non-empty
    # string 'TEyeD' is truthy on its own.
    cur_obj = args['cur_obj']
    if 'TEyeD' in cur_obj or 'UEGaze' in cur_obj:
        save_plot(rend_dict, path_dict, mode, 'rend', False,
                  is_spike, args, epoch, bt, True, False)

# Draw a result dictionary as an annotated figure and write it to disk.
def save_plot(data_dict, path_dict, mode, head, mask, is_spike, args, epoch, bt,
              rendering=False, is_list_of_entries=False):
    """Plot ``data_dict`` via ``plot_images_with_annotations`` and save it.

    The file is written below ``path_dict['figures']``:
      * spike batches:  '<mode>_spike/<head>_ep_<epoch>_bt_<bt>.jpg'
      * mask plots:     '<mode>_<epoch>/<head>_mask_bt_<bt>.jpg'
      * gaze plots:     '<mode>_<epoch>/<head>_gaze_bt_<bt>.jpg'

    Args:
        data_dict: entries to plot; torch tensors are converted with
            ``detach_cpu_numpy`` before plotting.
        path_dict: must provide the 'figures' output directory.
        mode: run mode, part of the file name and forwarded to the plotter.
        head: name of the network head, part of the file name.
        mask: forwarded as ``mask``; also selects the mask/gaze file name.
        is_spike: selects the spike file-name scheme.
        args: configuration dictionary forwarded to the plotter.
        epoch, bt: epoch and batch indices.
        rendering: forwarded as ``rendering``. Defaults to False so callers
            that omit it (as the ellseg call in ``save_out`` does) no longer
            fail with a TypeError.
        is_list_of_entries: forwarded as ``is_list_of_entries``; defaults to
            False for the same reason.
    """
    if is_spike:
        im_name = '{}_spike/{}_ep_{}_bt_{}.jpg'.format(mode, head, epoch, bt)
    elif mask:
        im_name = '{}_{}/{}_mask_bt_{}.jpg'.format(mode, epoch, head, bt)
    else:
        im_name = '{}_{}/{}_gaze_bt_{}.jpg'.format(mode, epoch, head, bt)

    path_im_image = os.path.join(path_dict['figures'], im_name)

    plot_images_with_annotations(detach_cpu_numpy(data_dict),
                                 args,
                                 write=path_im_image,
                                 rendering=rendering,
                                 mask=mask,
                                 remove_saturated=False,
                                 is_list_of_entries=is_list_of_entries,
                                 is_predict=True,
                                 show=False,
                                 mode=mode,
                                 epoch=epoch,
                                 batch=bt)


def reshape_gt(gt_dict, args):
    """Merge the batch and frame axes of every ground-truth entry in place.

    Each listed entry of ``gt_dict`` is passed through einops ``rearrange``
    with a pattern that folds the leading ``b f`` axes into a single
    ``(b f)`` axis. Mask entries are only touched when one of the rendering
    mask losses is enabled, and the spatial-weight / distance-map entries only
    when the ellseg head is enabled.

    Returns:
        The same ``gt_dict`` object, updated in place.
    """
    # Masks are only reshaped when a rendering mask loss is active.
    train_with_mask = args['loss_w_rend_pred_2_gt_edge'] or args['loss_w_rend_gt_2_pred'] \
                      or args['loss_w_rend_pred_2_gt']

    frame_maps = 'b f h w-> (b f) h w'
    frame_stacks = 'b f c h w-> (b f) c h w'
    frame_vectors = 'b f e -> (b f) e'
    frame_scalars = 'b f -> (b f)'

    def _fold(entries):
        # Apply the patterns one key at a time, in the listed order.
        for key, pattern in entries:
            gt_dict[key] = rearrange(gt_dict[key], pattern)

    _fold([('image', frame_maps)])

    if train_with_mask:
        _fold([('mask', frame_maps),
               ('mask_available', frame_scalars)])

    if args['net_ellseg_head']:
        _fold([('spatial_weights', frame_maps),
               ('distance_map', frame_stacks)])

    _fold([
        # Pupil centre / ellipse annotations.
        ('pupil_center', frame_vectors),
        ('pupil_ellipse', frame_vectors),
        ('pupil_center_norm', frame_vectors),
        ('pupil_center_available', frame_scalars),
        ('pupil_ellipse_norm', frame_vectors),
        ('pupil_ellipse_available', frame_scalars),
        # Iris ellipse annotations.
        ('iris_ellipse', frame_vectors),
        ('iris_ellipse_norm', frame_vectors),
        ('iris_ellipse_available', frame_scalars),
        # Dataset id, image id and bad-sample flag.
        ('ds_num', frame_scalars),
        ('im_num', frame_scalars),
        ('is_bad', frame_scalars),
        # Eyeball, gaze vector and landmark annotations.
        ('eyeball', frame_vectors),
        ('gaze_vector', frame_vectors),
        ('pupil_lm_2D', frame_vectors),
        ('pupil_lm_3D', frame_vectors),
        ('iris_lm_2D', frame_vectors),
        ('iris_lm_3D', frame_vectors),
    ])

    return gt_dict


def reshape_ellseg_out(out_dict, args):
    """Merge the batch and frame axes of the ellseg-head outputs in place.

    When ``args['net_ellseg_head']`` is falsy the dictionary is returned
    untouched. Otherwise every listed entry is passed through einops
    ``rearrange`` with a pattern that folds the leading ``b f`` axes into a
    single ``(b f)`` axis.

    Returns:
        The same ``out_dict`` object.
    """
    if not args['net_ellseg_head']:
        return out_dict

    frame_vectors = 'b f e -> (b f) e'
    layout = (
        # Raw network prediction and the derived ellipses / mask.
        ('predict', 'b f c h w-> (b f) c h w'),
        ('iris_ellipse', frame_vectors),
        ('pupil_ellipse', frame_vectors),
        ('mask', 'b f h w-> (b f) h w'),
        # Pupil confidence, centre and normalised ellipses.
        ('pupil_conf', frame_vectors),
        ('pupil_center', frame_vectors),
        ('pupil_ellipse_norm', frame_vectors),
        ('pupil_ellipse_norm_regressed', frame_vectors),
        # Iris centre, confidence and normalised ellipses.
        ('iris_center', frame_vectors),
        ('iris_conf', frame_vectors),
        ('iris_ellipse_norm', frame_vectors),
        ('iris_ellipse_norm_regressed', frame_vectors),
    )
    for key, pattern in layout:
        out_dict[key] = rearrange(out_dict[key], pattern)

    return out_dict


# %% Loss function: total training loss of the EllSeg head
def get_loss_ellseg(gt_dict, pd_dict, alpha, beta,
             make_aleatoric=False, regress_loss=True, bias_removal=False,
             pseudo_labels=False, adv_loss=False, label_tracker=False):
    '''
    Compute the combined EllSeg loss and its individual components.

    Parameters:
    - gt_dict: ground-truth dictionary (normalised centres/ellipses, their
      availability flags, 'ds_num' and, for pseudo labels, 'mask').
    - pd_dict: prediction dictionary of the network.
    - alpha: weight applied to the domain loss; only used when regress_loss
      is true (as 2*alpha*loss_da).
    - beta: weight applied to the adversarial term (0.4*beta) and to the
      pseudo-label loss.
    - make_aleatoric: forwarded as do_aleatoric to the regression losses.
    - regress_loss: include the regressed centre / ellipse losses.
    - bias_removal: add a cross-entropy domain loss on pd_dict['ds_onehot'].
    - pseudo_labels: add a confidence-weighted pseudo-label loss.
    - adv_loss: subtract a discriminator cross-entropy on
      pd_dict['disc_onehot'].
    - label_tracker: forwarded to remove_underconfident_psuedo_labels.

    Returns (loss, da_loss_dec, loss_dict), where loss_dict holds the scalar
    value (.item()) of each component:
    da_loss: domain-adaptation loss.
    seg_loss: segmentation loss.
    pseudo_loss: pseudo-label loss.
    da_loss_dec: domain-adaptation loss of the discriminator (adversarial).
    iris_c_loss: iris centre loss.
    pupil_c_loss: pupil centre loss.
    iris_el_loss: iris ellipse parameter loss.
    pupil_c_reg_loss: regressed pupil centre loss.
    pupil_params_loss: pupil ellipse parameter loss.
    '''
    # Segmentation loss
    loss_seg = get_seg_loss(gt_dict, pd_dict, 0.5)
    # L1 loss without uncertainty (centres derived from the segmentation)
    loss_pupil_c = get_uncertain_l1_loss(gt_dict['pupil_center_norm'],
                                         pd_dict['pupil_ellipse_norm'][:, :2],
                                         None,
                                         uncertain=False,
                                         cond=gt_dict['pupil_center_available'],
                                         do_aleatoric=False)

    loss_iris_c  = get_uncertain_l1_loss(gt_dict['iris_ellipse_norm'][:, :2],
                                         pd_dict['iris_ellipse_norm'][:, :2],
                                         None,
                                         uncertain=False,
                                         cond=gt_dict['iris_ellipse_available'],
                                         do_aleatoric=False)

    # L1 loss with uncertainty (regressed pupil centre)
    loss_pupil_c_reg = get_uncertain_l1_loss(gt_dict['pupil_center_norm'],
                                             pd_dict['pupil_ellipse_norm_regressed'][:, :2],
                                             None,
                                             uncertain=pd_dict['pupil_conf'][:, :2],
                                             cond=gt_dict['pupil_center_available'],
                                             do_aleatoric=make_aleatoric)

    # L1 loss with uncertainty (regressed ellipse parameters); the list is
    # passed as the third argument of get_uncertain_l1_loss
    # (presumably per-parameter weights -- TODO confirm against the helper).
    loss_pupil_el = get_uncertain_l1_loss(gt_dict['pupil_ellipse_norm'][:, 2:],
                                          pd_dict['pupil_ellipse_norm_regressed'][:, 2:],
                                          [4, 4, 3],
                                          uncertain=pd_dict['pupil_conf'][:, 2:],
                                          cond=gt_dict['pupil_ellipse_available'],
                                          do_aleatoric=make_aleatoric)

    loss_iris_el = get_uncertain_l1_loss(gt_dict['iris_ellipse_norm'],
                                         pd_dict['iris_ellipse_norm_regressed'],
                                         [1, 1, 4, 4, 3],
                                         uncertain=pd_dict['iris_conf'],
                                         cond=gt_dict['iris_ellipse_available'],
                                         do_aleatoric=make_aleatoric)

    # Gradient reversal: classify the dataset id at every spatial location
    if bias_removal:
        num_samples = gt_dict['ds_num'].shape[0]
        gt_ds_num = gt_dict['ds_num'].reshape(num_samples, 1, 1)
        # Tile the per-sample dataset id over the last two dims of ds_onehot
        gt_ds_num = gt_ds_num.repeat((1, ) + pd_dict['ds_onehot'].shape[-2:])
        loss_da = torch.nn.functional.cross_entropy(pd_dict['ds_onehot'],
                                                    gt_ds_num)
    else:
        loss_da = torch.tensor([0.0]).to(loss_seg.device)

    # NOTE(review): alpha only scales loss_da in the regress_loss branch.
    if not regress_loss:
        loss = 0.5*(20*loss_seg + loss_da) + (1-0.5)*(loss_pupil_c + loss_iris_c)
    else:
        loss = 0.5*(20*loss_seg + 2*alpha*loss_da) + \
               (1-0.5)*(loss_pupil_c + loss_iris_c +
                         loss_pupil_el + loss_iris_el + loss_pupil_c_reg)

    if adv_loss:
        da_loss_dec = torch.nn.functional.cross_entropy(pd_dict['disc_onehot'],
                                                        gt_dict['ds_num'].to(torch.long))

        # Inverse domain classification: we want domain confusion to increase as
        # training progresses. Make sure the weight never exceeds that of the main
        # loss in any epoch, otherwise unintended solutions that merely confuse the
        # discriminator may emerge.
        loss = loss - 0.4*beta*da_loss_dec
    else:
        da_loss_dec = torch.tensor([0.0]).to(loss_seg.device)

    if pseudo_labels:
        # Generate entropy-based pseudo labels and confidences
        pseudo_labels, conf = generate_pseudo_labels(pd_dict['predict'])

        # Classify each prediction as "good" or "bad" from its entropy-based
        # confidence, based on samples that have ground-truth information.
        # NOTE(review): the returned `loc` is overwritten two statements below,
        # so only possible side effects of this call are kept -- confirm intent.
        loc = remove_underconfident_psuedo_labels(conf.detach(),
                                                  label_tracker,
                                                  gt_dict=False)  # gt_dict

        # Remove pseudo labels wherever ground truth exists
        loc = gt_dict['mask'] != -1  # locations with ground truth
        pseudo_labels[loc] = -1  # disable pseudo samples
        conf[loc] = 0.0

        # Number of pixels and samples with non-zero confidence.
        # NOTE(review): because of the +1, num_valid_pxs is always >= 1, so
        # num_valid_samples counts every sample -- confirm intent.
        num_valid_pxs = torch.sum(conf.flatten(start_dim=-2) > 0, dim=-1) + 1
        num_valid_samples = torch.sum(num_valid_pxs > 0)

        pseudo_loss = torch.nn.functional.cross_entropy(pd_dict['predict'],
                                                        pseudo_labels,
                                                        ignore_index=-1,
                                                        reduction='none')
        pseudo_loss = (conf * pseudo_loss).flatten(start_dim=-2)
        pseudo_loss = torch.sum(pseudo_loss, dim=-1) / num_valid_pxs  # average over pixels
        pseudo_loss = torch.sum(pseudo_loss, dim=0) / num_valid_samples  # average over samples

        loss = loss + beta * pseudo_loss
    else:
        # Placeholder, only read through .item() below
        pseudo_loss = torch.tensor([0.0])

    loss_dict = {'da_loss': loss_da.item(),
                 'seg_loss': loss_seg.item(),
                 'pseudo_loss': pseudo_loss.item(),
                 'da_loss_dec': da_loss_dec.item(),
                 'iris_c_loss': loss_iris_c.item(),
                 'pupil_c_loss': loss_pupil_c.item(),
                 'iris_el_loss': loss_iris_el.item(),
                 'pupil_c_reg_loss': loss_pupil_c_reg.item(),
                 'pupil_params_loss': loss_pupil_el.item(),
                 }

    return loss, da_loss_dec, loss_dict

# Compute the simple gaze-accuracy metrics (3D angular errors) for one batch.
def get_metrics_simple(out_dict, gt_dict, metric_dict, diagonal, args, mode="rend"):
    """Add the 3D gaze angular-error metrics to ``metric_dict``.

    Args:
        out_dict: model outputs; ``out_dict['gaze_vector_3D']`` must be a torch
            tensor holding the predicted 3D gaze vectors.
        gt_dict: ground truth; ``gt_dict['gaze_vector']`` must be a torch
            tensor holding the reference 3D gaze vectors.
        metric_dict: dictionary that receives the metrics (updated in place).
        diagonal: unused here; kept for signature compatibility.
        args: unused here; kept for signature compatibility.
        mode: unused here; kept for signature compatibility.

    Returns:
        ``metric_dict`` with these entries added:
          * 'gaze_3D_ang_deg': angle in degrees between predicted and
            reference vectors, as computed by ``angular_error``.
          * 'gaze_3D_xy_ang_deg': the same angle computed on the first two
            components, each re-normalised to unit length.
          * 'score': the same object as 'gaze_3D_ang_deg'.
    """
    gt_gaze_3d = gt_dict['gaze_vector'].detach().cpu().numpy()
    pred_gaze_3d = out_dict['gaze_vector_3D'].detach().cpu().numpy()

    metric_dict['gaze_3D_ang_deg'] = angular_error(pred_gaze_3d, gt_gaze_3d)

    # Angular error restricted to the first two components. A vector whose
    # x/y part has zero length yields NaN/inf after this normalisation.
    pred_xy = pred_gaze_3d[..., :2]
    gt_xy = gt_gaze_3d[..., :2]
    metric_dict['gaze_3D_xy_ang_deg'] = angular_error(
        pred_xy / np.linalg.norm(pred_xy, axis=-1, keepdims=True),
        gt_xy / np.linalg.norm(gt_xy, axis=-1, keepdims=True))

    # The score is the 3D angular error regardless of which gaze losses are
    # enabled (the former if/else assigned the same value in both branches).
    metric_dict['score'] = metric_dict['gaze_3D_ang_deg']

    return metric_dict

# Angular error in degrees between two sets of vectors.
def angular_error(pred, gt):
    """Return the angle in degrees between ``pred`` and ``gt``.

    The angle is ``arccos`` of the dot product taken over the last axis, so
    both inputs are expected to be unit-length vectors; they are not
    normalised here.

    Args:
        pred: numpy array of predicted vectors, shape (..., D).
        gt: numpy array of reference vectors, broadcastable against ``pred``.

    Returns:
        Array of shape (...) with the angle in degrees.
    """
    cos_angle = (pred * gt).sum(-1)
    # Floating-point rounding can push the dot product of two unit vectors
    # marginally outside [-1, 1], where arccos returns NaN; clamp it first.
    cos_angle = np.clip(cos_angle, -1.0, 1.0)
    return np.degrees(np.arccos(cos_angle))


# %% Get performance metrics for the rendering head
def get_metrics_rend(rendering_dict, gt_dict, metric_dict, diagonal, args, available_predicted_mask):
    """Add the rendering-head evaluation metrics to ``metric_dict``.

    Depending on the configuration this adds segmentation IoU of the rendered
    mask, pixel distances of the pupil / eyeball centres and 3D gaze angular
    errors.

    Args:
        rendering_dict: rendering outputs; reads 'mask', 'pupil_c_UV',
            'eyeball_c_UV' and 'gaze_vector_3D' as needed.
        gt_dict: ground truth; reads 'mask', 'mask_available', 'eyeball' and
            'gaze_vector' as needed. Entries are used with numpy operations.
        metric_dict: dictionary that receives the metrics (updated in place).
        diagonal: divisor used to express pixel distances as a percentage.
        args: configuration dictionary.
        available_predicted_mask: whether a rendered mask exists for the batch.

    Returns:
        The updated ``metric_dict``.
    """
    batch_size = args['batch_size']
    num_frames = args['frames']
    num_samples = batch_size * num_frames

    # Mask metrics are only relevant when a rendering mask loss is enabled.
    train_with_mask = args['loss_w_rend_pred_2_gt_edge'] or args['loss_w_rend_gt_2_pred'] \
                      or args['loss_w_rend_pred_2_gt']

    def _rendering_iou_is_score():
        # The rendering IoU acts as the score only when the rendering head is
        # on and the ellseg head is off or carries zero loss weight.
        # `== False` is kept on purpose to preserve the original semantics.
        return args['net_rend_head'] and \
            (args['net_ellseg_head'] == False or args['loss_w_ellseg'] == 0)

    if train_with_mask:
        if available_predicted_mask:
            # Rendered mask restricted to the pixels where the GT mask is > 0.
            masked_predicted = np.where(gt_dict['mask'] > 0, 1, 0) * rendering_dict['mask']

            metric_dict['rendering_iou'] = get_seg_metrics(gt_dict['mask'],
                                                           rendering_dict['mask'],
                                                           gt_dict['mask_available'],
                                                           batch_size, num_frames)
            if _rendering_iou_is_score():
                metric_dict['score'] = metric_dict['rendering_iou'].mean(axis=1)

            metric_dict['masked_rendering_iou'] = get_seg_metrics(gt_dict['mask'],
                                                                  masked_predicted,
                                                                  gt_dict['mask_available'],
                                                                  batch_size, num_frames)
        else:
            # No rendered mask: fill with NaN so that later nan-aware
            # averaging ignores this batch.
            if _rendering_iou_is_score():
                metric_dict['score'] = np.full((num_samples,), np.nan)
            metric_dict['rendering_iou'] = np.full((num_samples, 3), np.nan)
            metric_dict['masked_rendering_iou'] = np.full((num_samples, 3), np.nan)

    cur_obj = args['cur_obj']
    is_teyed = 'TEyeD' in cur_obj

    if is_teyed:
        # GT pupil centre in UV coordinates, built from the GT eyeball entry
        # (first component scales the gaze x/y, components 1:3 are the centre).
        gt_eyeball_c_UV = gt_dict['eyeball'][..., 1:3]
        gt_pupil_c_UV = gt_eyeball_c_UV + \
            (gt_dict['eyeball'][..., 0:1] * gt_dict['gaze_vector'][..., :2])

        # Euclidean pixel distance of the pupil centre.
        pupil_c_px_dist = np.sqrt((rendering_dict['pupil_c_UV'][..., 0] - gt_pupil_c_UV[..., 0]) ** 2
                                  + (rendering_dict['pupil_c_UV'][..., 1] - gt_pupil_c_UV[..., 1]) ** 2)
        metric_dict['pupil_c_px_dst'] = pupil_c_px_dist
        metric_dict['norm_pupil_c_px_dst'] = (pupil_c_px_dist * 100) / diagonal

        # Euclidean pixel distance of the eyeball centre.
        eyeball_c_px_dist = np.sqrt((rendering_dict['eyeball_c_UV'][..., 0] - gt_eyeball_c_UV[..., 0]) ** 2
                                    + (rendering_dict['eyeball_c_UV'][..., 1] - gt_eyeball_c_UV[..., 1]) ** 2)
        metric_dict['eyeball_c_px_dist'] = eyeball_c_px_dist
        metric_dict['norm_eyeball_c_px_dist'] = (eyeball_c_px_dist * 100) / diagonal

    if is_teyed or 'nvgaze' not in cur_obj:
        # 3D gaze angular error, on the full vector and on its x/y part.
        gt_gaze_3d = gt_dict['gaze_vector']
        pred_gaze_3d = rendering_dict['gaze_vector_3D']
        metric_dict['gaze_3D_ang_deg'] = angular_error(pred_gaze_3d, gt_gaze_3d)
        metric_dict['gaze_3D_xy_ang_deg'] = angular_error(
            pred_gaze_3d[..., :2] / np.linalg.norm(pred_gaze_3d[..., :2], axis=-1, keepdims=True),
            gt_gaze_3d[..., :2] / np.linalg.norm(gt_gaze_3d[..., :2], axis=-1, keepdims=True))
    else:
        # nvgaze: placeholder zeros for the gaze and distance metrics.
        metric_dict['gaze_ang_deg'] = np.zeros((num_samples))
        metric_dict['eyeball_c_px_dist'] = np.zeros((num_samples))
        metric_dict['pupil_c_px_dst'] = np.zeros((num_samples))

    return metric_dict


# %% Get performance metrics for the ellipse-segmentation (EllSeg) head
def get_metrics_ellseg(pd_dict, gt_dict, metric_dict, args):
    """Add the EllSeg evaluation metrics and the combined score to ``metric_dict``.

    Adds segmentation IoU (predicted and reconstructed mask), centre, axes and
    angle distances for pupil and iris, a per-sample 'score' (nan-mean of four
    terms) and 'iou_score'. Entries without available ground truth are NaN in
    the score terms.

    Returns:
        The updated ``metric_dict``.
    """
    batch_size, num_frames = args['batch_size'], args['frames']

    # The smaller image side normalises the centre distances in the score.
    height, width = gt_dict['mask'].shape[-2:]
    scale = min(height, width)

    # Segmentation IoU of the predicted and of the reconstructed mask.
    for metric_key, pred_key in (('iou', 'mask'), ('iou_recon', 'mask_recon')):
        metric_dict[metric_key] = get_seg_metrics(gt_dict['mask'],
                                                  pd_dict[pred_key],
                                                  gt_dict['mask_available'],
                                                  batch_size, num_frames)

    # Centre distances (the pupil has a dedicated centre annotation).
    metric_dict['pupil_c_dst'] = get_distance(gt_dict['pupil_center'],
                                              pd_dict['pupil_ellipse'][:, :2],
                                              gt_dict['pupil_center_available'])
    metric_dict['iris_c_dst'] = get_distance(gt_dict['iris_ellipse'][:, :2],
                                             pd_dict['iris_ellipse'][:, :2],
                                             gt_dict['iris_ellipse_available'])

    ellipse_parts = (('pupil_ellipse', 'pupil_ellipse_available'),
                     ('iris_ellipse', 'iris_ellipse_available'))

    # Axes distances (ellipse columns 2:-1).
    for metric_key, (ell_key, avail_key) in zip(('pupil_axes_dst', 'iris_axes_dst'),
                                                ellipse_parts):
        metric_dict[metric_key] = get_distance(gt_dict[ell_key][:, 2:-1],
                                               pd_dict[ell_key][:, 2:-1],
                                               gt_dict[avail_key])

    # Angle distances (last ellipse column).
    for metric_key, (ell_key, avail_key) in zip(('pupil_ang_dst', 'iris_ang_dst'),
                                                ellipse_parts):
        metric_dict[metric_key] = getAng_metric(gt_dict[ell_key][:, -1],
                                                pd_dict[ell_key][:, -1],
                                                gt_dict[avail_key])

    def _score_term(compute, avail_key):
        # Evaluate `compute` only if any sample has ground truth, otherwise
        # start from an all-NaN vector; then blank out unavailable samples.
        if np.any(gt_dict[avail_key]):
            term = compute()
        else:
            term = np.nan*np.zeros((gt_dict['mask_available'].shape[0], ))
        term[~gt_dict[avail_key]] = np.nan
        return term

    iou_term = _score_term(lambda: metric_dict['iou'][..., -2:].mean(axis=1),
                           'mask_available')
    iou_recon_term = _score_term(lambda: metric_dict['iou_recon'][..., -2:].mean(axis=1),
                                 'mask_available')
    pupil_c_term = _score_term(lambda: 1 - (1/scale)*metric_dict['pupil_c_dst'],
                               'pupil_center_available')
    iris_c_term = _score_term(lambda: 1 - (1/scale)*metric_dict['iris_c_dst'],
                              'iris_ellipse_available')

    # Final score: per-sample mean over the four terms, ignoring NaN.
    all_terms = np.stack([iou_term, iou_recon_term, pupil_c_term, iris_c_term], axis=1)
    metric_dict['score'] = np.nanmean(all_terms, axis=1)

    metric_dict['iou_score'] = iou_term

    return metric_dict

# Aggregate per-batch metric dictionaries into overall means (and an optional CSV).
def aggregate_metrics(list_metric_dicts, csv_save_path=None):
    """Combine a list of per-batch metric dictionaries.

    The keys of the first dictionary define which metrics are aggregated.
    Keys containing 'loss' are stacked into one array (one value per batch);
    all other keys are concatenated along axis 0. For every key the NaN-aware
    mean over axis 0 is stored as '<key>_mean'; keys containing 'iou'
    additionally store the same mean under '<key>'.

    A key that cannot be aggregated (missing in a later dictionary,
    zero-dimensional entries, mismatching shapes, ...) is reported with a
    traceback and skipped, so one bad metric does not discard the rest.

    Args:
        list_metric_dicts: list of dictionaries, one per batch.
        csv_save_path: optional path; when given, the raw (non-loss) values
            are written there as CSV, with one column per class for 'iou'
            metrics.

    Returns:
        Dictionary of aggregated metrics; empty when ``list_metric_dicts`` is
        empty.
    """
    # No batches were processed: nothing to aggregate (previously IndexError).
    if not list_metric_dicts:
        return {}

    agg_dict = {}  # aggregated values
    raw_dict = {}  # raw per-sample / per-batch values

    for key_entry in list_metric_dicts[0].keys():
        try:
            per_batch = [ele[key_entry] for ele in list_metric_dicts]
            if 'loss' in key_entry:
                raw_dict[key_entry] = np.array(per_batch)
            else:
                # Entries must be at least 1-D here; np.concatenate rejects
                # zero-dimensional arrays and the key is then skipped below.
                raw_dict[key_entry] = np.concatenate(per_batch, axis=0)
            if 'iou' in key_entry:
                # Per-class mean, also exposed under the plain key.
                agg_dict[key_entry] = np.nanmean(raw_dict[key_entry], axis=0)
            agg_dict[key_entry + '_mean'] = np.nanmean(raw_dict[key_entry], axis=0)
        except Exception:
            # Deliberate best-effort: report and continue with the other keys.
            print(f"Exception occurred while processing key: {key_entry}")
            traceback.print_exc()

    if csv_save_path:
        import pandas as pd
        values = []
        names = []
        for k in raw_dict:
            if 'loss' in k:
                continue
            if 'iou' in k:
                # One CSV column per class.
                for j in range(raw_dict[k].shape[1]):
                    names.append(f'{k}_class{j}')
                    values.append(raw_dict[k][:, j])
            else:
                names.append(k)
                values.append(raw_dict[k])
        values = np.asarray(values).T
        df = pd.DataFrame(data=values, columns=names)
        df.to_csv(csv_save_path)

    return agg_dict

