# -*- coding: UTF-8 -*-
'''
@Project ：PD_Gaze 
@File    ：EyeModel.py
@Author  ：xyf
@Date    ：2025/6/10 19:22 
'''
import gc
import logging
import math
import pickle
import sys
import time

from torch import nn
from torch.utils.data import DataLoader
from tqdm import tqdm

from helperfunctions.CurriculumLib import DataLoader_riteyes
from helperfunctions.hfunctions import create_experiment_folder_tree, fix_batch, merge_two_dicts, generate_rend_masks
import helperfunctions.CurriculumLib as CurLib
# !/usr/bin/env python3
# -*- coding: utf-8 -*-
import os
from pprint import pprint

from rendering.rendered_semantics_loss import SobelFilter, loss_fn_rend_sprvs, rendered_semantics_loss_vectorized3, \
    rendered_semantics_loss
from rendering.rendering import render_semantics, eyeball_center
from scripts import reshape_gt, get_metrics_simple, move_gpu, detach_cpu_numpy, aggregate_metrics, log_wandb, \
    reshape_ellseg_out, send_to_device, get_metrics_rend, save_out
from timm.optim import Lamb as Lamb_timm
from timm.scheduler import CosineLRScheduler as CosineLRScheduler_timm
import torch
import random
import warnings
import numpy as np
import wandb
import torch.nn.functional as F
from args_maker import make_args
from helperfunctions.utils import move_to_single, make_logger, get_nparams, EarlyStopping, SpikeDetection
from models.models_mux import model_dict

# Suppress warnings
warnings.filterwarnings('ignore')
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'
os.environ["WANDB_API_KEY"] = '+++++++++++'  # 将引号内的+替换成自己在wandb上的一串值
os.environ["WANDB_MODE"] = "offline"  # 离线  （此行代码不用修改）

class EyeModelTrainer:
    def __init__(self, args):
        self.args = args
        model_name = args['model']
        cur_objs = args['cur_obj']
        print(f'[Trainer Log] Model Name: \033[0;32;40m\t{model_name}\033[0m')
        # 检查是否有可用的 GPU
        if torch.cuda.is_available():
            self.device = torch.device('cuda')  # 使用 CUDA 设备
            self.GPU_num = torch.cuda.device_count()  # 获取 GPU 数量
            print(f'[Trainer Log] \033[0;32;40m\t{self.GPU_num} GPUs Detected \033[0m')
        else:
            self.device = torch.device('cpu')  # 使用 CPU
            self.GPU_num = 0
            print(f'[Trainer Log] \033[0;32;40m\tNo GPU Detected. Run on CPU \033[0m')

        self.save_pt_name = f'{model_name}_{cur_objs}_myPara.pt'
        self.path_dict, self.exp_name_str = create_experiment_folder_tree(args['repo_root'],
                                                                args['path_exp_tree'],
                                                                args['exp_name'],
                                                                args['only_test'],
                                                                create_tree=args['local_rank'] == 0 if args[
                                                                    'do_distributed'] else True)
        if 'DEBUG' not in args['exp_name']:
            wandb.init(project="GazeModel",
                       entity='xyfdxb',
                       config=args, name=self.exp_name_str)



    def run(self):
        self.path_dict['repo_root'] = args['repo_root']
        self.path_dict['path_data_source'] = args['path_data_source']
        self.path_dict['path_data_target'] = args['path_data_target']
        # %%
        torch.backends.cudnn.deterministic = False
        torch.backends.cudnn.benchmark = True
        torch.backends.cudnn.enabled = True

        # Set seeds
        torch.manual_seed(args['seed'])
        np.random.seed(args['seed'])
        random.seed(args['seed'])

        print('---------')
        print('解析的参数')
        pprint(args)  # 打印解析后的参数
        print('---------')

        # Train and save validated model
        if not args['only_test']:
            if not args['only_valid']:
                print('train mode')
                self.train(args, self.path_dict, validation_mode=False, test_mode=False)

                print('validation mode')
                self.train(args, self.path_dict, validation_mode=True, test_mode=False)

                # Test out best model and save results
                print("test mode")
                self.train(args, self.path_dict, validation_mode=False, test_mode=True)
        elif args['only_valid']:
            print('validation mode')
            self.train(args, self.path_dict, validation_mode=True, test_mode=False)
        elif args['only_test']:
            print('test mode')
            self.train(args, self.path_dict, validation_mode=False, test_mode=True)

        print("run done!")
        wandb.finish()


    def train(self, args, path_dict, validation_mode=False, test_mode=False):

        net_dict = []

        # %% Load model  加载网络模型
        if args['model'] == 'DenseElNet':
            norm = nn.BatchNorm3d
        else:
            norm = nn.BatchNorm2d

        # 构建模型
        print(f'[Trainer Log] Model Building......', end='')
        sys.stdout.flush()  # 刷新标准输出缓冲，立即打印上面内容

        # 如果使用多个 GPU，则使用 nn.DataParallel 包装模型
        if self.GPU_num > 1:
            self.model = nn.DataParallel(model_dict[args['model']](args,
                                        norm=norm,
                                        act_func=F.leaky_relu)).to(self.device)
        else:
            self.model = model_dict[args['model']](args,
                                        norm=norm,
                                        act_func=F.leaky_relu)  # 使用单 GPU 或 CPU

        print(f' Complete!')  # 模型构建完成


        # %% Weight loaders
        # if it is pretrained, then load pretrained weights
        if args['pretrained'] or args['continue_training'] or args['weights_path']:

            if args['weights_path']:
                path_pretrained = args['weights_path']
            elif args['pretrained']:
                path_pretrained = os.path.join(path_dict['repo_root'],
                                               '..',
                                               'pretrained',
                                               'pretrained.git_ok')
            elif args['continue_training']:
                path_pretrained = os.path.join(args['continue_training'])

            net_dict = torch.load(path_pretrained,
                                  map_location=torch.device('cuda'))
            state_dict_single = move_to_single(net_dict['state_dict'])
            self.model.load_state_dict(state_dict_single, strict=False)
            print(f'Pretrained model loaded from: {path_pretrained}')

        if test_mode:
            print('Test mode detected. Loading best model.')

            if args['path_model']:
                net_dict = torch.load(args['path_model'],
                                      map_location=torch.device('cuda'))
            else:
                net_dict = torch.load(os.path.join(path_dict['results'], self.save_pt_name),
                                      map_location=torch.device('cuda'))

            # 确保保存的参数与解析的参数匹配
            state_dict_single = move_to_single(net_dict['state_dict'])
            self.model.load_state_dict(state_dict_single, strict=False)

            # Do not initialize a writer
            writer = []
        elif validation_mode:
            print('Validation mode detected. Loading model.')
            if args['path_model']:
                net_dict = torch.load(args['path_model'],
                                      map_location=torch.device('cuda'))
            else:
                net_dict = torch.load(os.path.join(path_dict['results'], self.save_pt_name),
                                      map_location=torch.device('cuda'))

            # Ensure saved arguments match with parsed arguments
            state_dict_single = move_to_single(net_dict['state_dict'])
            self.model.load_state_dict(state_dict_single, strict=False)

            # Do not initialize a writer
            writer = []
        else:
            writer = []


        if args['use_GPU']:
            self.model.cuda()

        # %% Initialize logger
        logger = make_logger(path_dict['logs'] + '/train_log.log',
                             rank=args['local_rank'] if args['do_distributed'] else 0)
        logger.write(f'[Experiment args: ] {str(args)}')
        logger.write('# num of parameters: {}'.format(get_nparams(self.model)))

        if test_mode == 0 and validation_mode == 0:
            logger.write('Training!')
            if args['exp_name'] != 'DEBUG':
                wandb.watch(self.model)
        elif validation_mode:
            logger.write('Validating!')
        else:
            logger.write('Testing!')

        # %% Training and validation loops or test only
        self.train_validation_loops(self.model,
                               net_dict,
                               logger,
                               args,
                               path_dict,
                               writer,
                               validation_mode,
                               test_mode)

        # %% Closing functions and logging
        if writer:
            writer.close()

    def train_validation_loops(self, net, net_dict, logger, args,
                               path_dict, writer,
                               validation_mode, test_mode):

        print("---use_pkl_for_dataload---")
        # 如果使用.pkl文件加载数据集对象
        path_cur_obj_train = os.path.join(path_dict['repo_root'],
                                    'cur_objs',
                                    args['mode'],
                                    'cond_' + args['cur_obj'] + '.pkl')
        # 组装.pkl文件的路径
        with open(path_cur_obj_train, 'rb') as f:
            # 打开.pkl文件并加载训练、验证和测试对象
            train_obj, valid_obj, _ = pickle.load(f)

        path_cur_obj_test = os.path.join(path_dict['repo_root'],
                                    'cur_objs',
                                    args['mode'],
                                    'cond_' + args['test_obj'] + '.pkl')
        # 组装.pkl文件的路径
        with open(path_cur_obj_test, 'rb') as f:
            # 打开.pkl文件并加载训练、验证和测试对象
            test_obj, _, _ = pickle.load(f)


        # FIXME Remove unwanted validation-train overlap
        print(f'Starting the procedure of removing unwanted train-val video overlap...')
        # 开始移除训练集和验证集之间不需要的视频重叠
        train_vid_ids = list(np.unique(train_obj.imList[:, :, 1]))
        val_vid_ids = list(np.unique(valid_obj.imList[:, :, 1]))
        # 获取训练集中的所有唯一视频ID

        for vid_id in np.unique(valid_obj.imList[:, :, 1]):
            # 遍历验证集中的所有唯一视频ID
            if vid_id in train_vid_ids:
                # 如果视频ID在训练集中出现过
                # 打印被丢弃的验证集重叠视频ID
                bad_ids = ((valid_obj.imList[:, :, 1] == vid_id).sum(axis=-1) > 0)
                # 找出所有包含该视频ID的验证集样本
                valid_obj.imList = valid_obj.imList[~bad_ids]
                # 从验证集中移除这些样本

        print(f'')
        print(f'')
        print(f'Number of images:')
        print(f'Train images left: {train_obj.imList.shape[0] * train_obj.imList.shape[1]}')
        print(f'Valid images left: {valid_obj.imList.shape[0] * valid_obj.imList.shape[1]}')
        print(f'Test images left: {test_obj.imList.shape[0] * test_obj.imList.shape[1]}')
        # 打印训练集、验证集和测试集中剩余的图像数量
        print(f'')
        print(f'')

        # %% Specify flags of importance
        train_obj.augFlag = args['aug_flag']
        # 设置训练集数据增强标志
        valid_obj.augFlag = False
        # 验证集不进行数据增强
        test_obj.augFlag = False
        # 测试集不进行数据增强

        # %% Modify path information
        train_obj.path2data = path_dict['path_data_source']
        # 设置训练集数据路径
        valid_obj.path2data = path_dict['path_data_source']
        # 设置验证集数据路径
        test_obj.path2data = path_dict['path_data_target']

        train_obj.scale = args['scale_factor']
        # 设置训练集使用的缩放比例
        valid_obj.scale = args['scale_factor']
        # 设置验证集使用的缩放比例
        test_obj.scale = args['scale_factor']
        # 设置测试集数据路径

        # %% Define dataloaders
        logger.write('Initializing loaders')
        # 记录初始化加载器
        # 如果处于验证模式
        if validation_mode:
            # 创建验证集数据加载器
            valid_loader = DataLoader(valid_obj,
                                      shuffle=False,
                                      num_workers=args['workers'],
                                      drop_last=True,
                                      pin_memory=True,
                                      batch_size=args['batch_size'],
                                      sampler= None,
                                      )
        # 如果处于测试模式
        elif test_mode:
            # 创建测试集数据加载器
            test_loader = DataLoader(test_obj,
                                     shuffle=False,
                                     num_workers=args['workers'],
                                     drop_last=True,
                                     pin_memory=True,
                                     batch_size=args['batch_size'],
                                     sampler= None,
                                     )

        else:
            # 如果处于训练模式
            # 创建训练集数据加载器
            train_loader = DataLoader(train_obj,
                                      shuffle=args['random_dataloader'],
                                      num_workers=args['workers'],
                                      drop_last=True,
                                      pin_memory=True,
                                      batch_size=args['batch_size'],
                                      sampler= None,
                                      )

            # 创建验证集数据加载器
            valid_loader = DataLoader(valid_obj,
                                      shuffle=False,
                                      num_workers=args['workers'],
                                      drop_last=True,
                                      pin_memory=True,
                                      batch_size=args['batch_size'],
                                      sampler= None,
                                      )

        # %% Early stopping criterion
        if '3D' in args['early_stop_metric'] or '2D' in args['early_stop_metric']:
            # 如果早停指标是3D或2D
            early_stop = EarlyStopping(metric=args['early_stop_metric'],
                                       patience=args['early_stop'],
                                       verbose=True,
                                       delta=0.001,  # 需要0.1%的改进
                                       rank_cond=True,
                                       mode='min',
                                       fName='Mybest_model.pt',
                                       path_save=path_dict['results'],
                                       )
            # 创建早停对象，目标是最小化指标
        else:
            early_stop = EarlyStopping(metric=args['early_stop_metric'],
                                       patience=args['early_stop'],
                                       verbose=True,
                                       delta=0.001,  # 需要0.1%的改进
                                       rank_cond=True,
                                       mode='max',
                                       fName='Mybest_model.pt',
                                       path_save=path_dict['results'],
                                       )
            # 创建早停对象，目标是最大化指标

        # %% Optimizer
        param_list = [param for name, param in net.named_parameters() if 'adv' not in name]
        # 学习器和优化器
        if 'LAMB' in args['optimizer_type']:
            # 如果使用LAMB优化器，创建LAMB优化器实例
            optimizer = Lamb_timm(param_list, lr=args['lr'], weight_decay=args['wd'])
            # 如果使用AdamW优化器和余弦调度器，创建AdamW优化器实例
            warmup_epochs = 5
            scheduler = CosineLRScheduler_timm(optimizer,
                                               t_initial=args['epochs'],
                                               lr_min=args['lr'] / 100.0,
                                               warmup_t=4,
                                               warmup_lr_init=args['lr'] / 100.0)
            # 使用余弦学习率调度器
            use_sched = True
            # 设置调度器标志为True

        else:
            optimizer = torch.optim.Adam(param_list, lr=args['lr'], amsgrad=False, weight_decay=args['wd'])
            # 默认情况下，使用Adam优化器
            use_sched = False
            # 不使用调度器
            # 不使用对抗性优化器

        # Create a checkpoint based on current scores
        checkpoint = {}
        checkpoint['args'] = args  # Save arguments
        # 创建一个检查点并保存当前参数

        args['time_to_update'] = True
        # 下次退出主循环时重新随机化数据集
        last_epoch_validation = False
        # 设置最后一个epoch验证标志为False

        # 测试模式
        if test_mode:
            logging.info('Entering test mode only ...')
            logger.write('Entering test mode only ...')
            # 如果处于测试模式，记录进入测试模式

            # 运行测试并获取测试结果
            test_result = self.Netforward(net,
                                  [],
                                  logger,
                                  test_loader,
                                  optimizer,
                                  args,
                                  path_dict,
                                  writer=writer,
                                  epoch=0,
                                  mode='test',
                                  batches_per_ep=len(test_loader) if 'DEBUG' not in args['exp_name'] else 10,
                                  # batches_per_ep=args['batches_per_ep'],
                                  last_epoch_valid=True,
                                  csv_save_dir=path_dict['exp'])
            # 将测试结果保存到检查点
            checkpoint['test_result'] = test_result
            print("test_result:", test_result)

            epoch = 0
            if args['exp_name'] != 'DEBUG':
                logger.write('Test results:')
                for key, item in checkpoint['test_result'].items():
                    if 'mean' in key and 'loss' not in key:
                        # 判断 item 是否为 NumPy 数组
                        if isinstance(item, np.ndarray):
                            # 将 NumPy 数组转换为 Python 列表，并对列表中的每个元素进行格式化
                            formatted_item = [f'{val:.3f}' for val in item.tolist()]
                            # 使用逗号分隔格式化后的字符串
                            formatted_string = ', '.join(formatted_item)
                            wandb.log({'test_avg/{}'.format(key): item.mean(), 'epoch': epoch})
                            logger.write(f'test_avg/{key}: {formatted_string}')
                        else:
                            wandb.log({'test_avg/{}'.format(key): item, 'epoch': epoch})
                            logger.write(f'test_avg/{key}: {item:.3f}')
                logger.write(' ')
                logger.write(' ')
                logger.write(' ')
                # 记录测试结果

            if args['save_results_here']:
                # 确保结果保存目录存在
                os.makedirs(os.path.dirname(args['save_results_here']),
                            exist_ok=True)

                # Save out test results here instead 将测试结果保存到指定位置
                with open(args['save_results_here'], 'wb') as f:
                    pickle.dump(checkpoint, f)
            else:
                # Ensure the directory exists 确保默认结果保存目录存在
                os.makedirs(path_dict['results'], exist_ok=True)

                # Save out the test results 将测试结果保存到默认位置
                with open(path_dict['results'] + '/test_results.pkl', 'wb') as f:
                    pickle.dump(checkpoint, f)

        # 验证模式
        elif validation_mode:
            logger.write('Entering validation mode only ...')
            logging.info('Entering validation mode  only...')

            # 如果处于验证模式，记录进入验证模式并设置alpha和beta为0.5
            # valid_batches_per_ep = int(4608 / (args['frames'] * args['batch_size']))
            valid_result = self.Netforward(net,
                                   [],
                                   logger,
                                   valid_loader,
                                   optimizer,
                                   args,
                                   path_dict,
                                   writer=writer,
                                   epoch=0,
                                   mode='valid',
                                   # batches_per_ep=len(valid_loader) if 'DEBUG' not in args['exp_name'] else 10,
                                   # batches_per_ep=args['batches_per_ep'],
                                   batches_per_ep=len(valid_loader),
                                   last_epoch_valid=True,
                                   csv_save_dir=path_dict['exp'])
            # 运行验证并获取验证结果
            print("valid_result:", valid_result)

            checkpoint['valid_result'] = valid_result
            # 将验证结果保存到检查点

            epoch = 0
            if args['exp_name'] != 'DEBUG':
                logger.write('Validation results:')
                for key, item in checkpoint['valid_result'].items():
                    # 判断 item 是否为 NumPy 数组
                    if isinstance(item, np.ndarray):
                        # 将 NumPy 数组转换为 Python 列表，并对列表中的每个元素进行格式化
                        formatted_item = [f'{val:.3f}' for val in item.tolist()]
                        # 使用逗号分隔格式化后的字符串
                        formatted_string = ', '.join(formatted_item)
                        wandb.log({'valid_avg/{}'.format(key): item.mean(), 'epoch': epoch})
                        logger.write(f'valid_avg/{key}: {formatted_string}')
                    else:
                        wandb.log({'valid_avg/{}'.format(key): item, 'epoch': epoch})
                        logger.write(f'valid_avg/{key}: {item:.3f}')
                logger.write(' ')
                logger.write(' ')
                logger.write(' ')
                # 记录验证结果

            if args['save_results_here']:
                # Ensure the directory exists
                os.makedirs(os.path.dirname(args['save_results_here']),
                            exist_ok=True)
                # 确保结果保存目录存在

                # Save out test results here instead
                with open(args['save_results_here'], 'wb') as f:
                    pickle.dump(checkpoint, f)
                # 将验证结果保存到指定位置
            else:
                # Ensure the directory exists
                os.makedirs(path_dict['results'], exist_ok=True)
                # 确保默认结果保存目录存在

                # Save out the test results
                with open(path_dict['results'] + '/valid_results.pkl', 'wb') as f:
                    # 将验证结果保存到默认位置
                    pickle.dump(checkpoint, f)

        # 训练模式
        else:
            # 如果remove_spikes参数为True，则创建SpikeDetection实例，否则为False
            spiker = SpikeDetection() if args['remove_spikes'] else False
            logging.info('Entering train mode ...')
            logger.write('Entering train mode ...')
            # 记录进入训练模式

            if args['continue_training']:
                optimizer.load_state_dict(net_dict['optimizer'])
                epoch = net_dict['epoch'] + 1
                # 如果继续训练，从保存的优化器状态加载并设置起始epoch
            else:
                epoch = 0
                # 否则从第0个epoch开始

            while (epoch < args['epochs']):
                # 当epoch小于设定的最大epoch时，继续训练
                if args['time_to_update']:
                    # 如果需要更新数据集

                    # Toggle flag back to False
                    args['time_to_update'] = False
                    # 重置标志为False

                    if args['one_by_one_ds']:
                        train_loader.dataset.sort('one_by_one_ds', args['batch_size'])
                        valid_loader.dataset.sort('one_by_one_ds', args['batch_size'])
                        # 按照one_by_one_ds排序训练和验证数据集
                    else:
                        # for sequence dataset:
                        train_loader.dataset.sort('ordered')
                        valid_loader.dataset.sort('ordered')
                        # 按照ordered排序训练和验证数据集

                # 为采样器设置当前epoch

                # %%
                logging.info('Starting epoch: %d' % epoch)
                logger.write('Starting epoch: %d' % epoch)
                # 记录当前开始的epoch



                if args['dry_run']:
                    train_batches_per_ep = len(train_loader)
                    valid_batches_per_ep = len(valid_loader)
                    # 如果是dry run模式，设置每个epoch的训练和验证批次数为数据集长度
                else:
                    train_batches_per_ep = args['batches_per_ep']
                    if args['reduce_valid_samples']:
                        valid_batches_per_ep = 10
                        # 如果减少验证样本，设置每个epoch的验证批次数为10
                    else:
                        valid_batches_per_ep = len(valid_loader)
                        # 否则根据参数计算每个epoch的验证批次数

                train_result = self.Netforward(net,
                                       spiker,
                                       logger,
                                       train_loader,
                                       optimizer,
                                       args,
                                       path_dict,
                                       writer=writer,
                                       epoch=epoch,
                                       mode='train',
                                       batches_per_ep=train_batches_per_ep)
                # 进行训练并获取训练结果

                if epoch == args['epochs'] - 1:
                    last_epoch_validation = True
                    valid_batches_per_ep = len(valid_loader)
                    # 如果是最后一个epoch，设置验证批次数为验证集长度

                # incase you want to validate the whole validation set just Remove True
                if (args['reduce_valid_samples'] and (epoch % args['perform_valid'] != 0)) \
                        or ('DEBUG' in args['exp_name']) or True:
                    valid_result = self.Netforward(net,
                                           spiker,
                                           logger,
                                           valid_loader,
                                           optimizer,
                                           args,
                                           path_dict,
                                           writer=writer,
                                           epoch=epoch,
                                           mode='valid',
                                           batches_per_ep=valid_batches_per_ep,
                                           # batches_per_ep=50,
                                           last_epoch_valid=last_epoch_validation)
                    # 如果减少验证样本且当前epoch不是需要验证的epoch，或在DEBUG模式下，进行部分验证
                else:
                    valid_result = self.Netforward(net,
                                           spiker,
                                           logger,
                                           valid_loader,
                                           optimizer,
                                           args,
                                           path_dict,
                                           writer=writer,
                                           epoch=epoch,
                                           mode='valid',
                                           batches_per_ep=len(valid_loader),
                                           # batches_per_ep=50,
                                           last_epoch_valid=last_epoch_validation)
                    # 否则进行完整验证

                # Update the check point weights. VERY IMPORTANT!
                checkpoint['state_dict'] = move_to_single(net.state_dict())
                # 计算参数总数
                total_params = sum(p.numel() for p in checkpoint['state_dict'].values())

                # 打印参数量
                print(f"Total parameters in propagator: {total_params}")

                checkpoint['optimizer'] = optimizer.state_dict()
                # 更新检查点中的模型和优化器状态

                checkpoint['epoch'] = epoch
                checkpoint['train_result'] = train_result
                checkpoint['valid_result'] = valid_result
                # 保存当前epoch和训练、验证结果

                if args['exp_name'] != 'DEBUG':

                    logger.write('---Train results:---')
                    for key, item in checkpoint['train_result'].items():
                        # 判断 item 是否为 NumPy 数组
                        if isinstance(item, np.ndarray):
                            # 将 NumPy 数组转换为 Python 列表，并对列表中的每个元素进行格式化
                            formatted_item = [f'{val:.3f}' for val in item.tolist()]
                            # 使用逗号分隔格式化后的字符串
                            formatted_string = ', '.join(formatted_item)
                            wandb.log({'train_avg/{}'.format(key): item.mean(), 'epoch': epoch})
                            logger.write(f'train_avg/{key}: {formatted_string}')
                        else:
                            wandb.log({'train_avg/{}'.format(key): item, 'epoch': epoch})
                            logger.write(f'train_avg/{key}: {item:.3f}')
                    logger.write(' ')
                    logger.write(' ')
                    logger.write(' ')

                    logger.write('---Validation results:---')
                    for key, item in checkpoint['valid_result'].items():
                        if 'mean' in key and 'loss' not in key:
                            # 判断 item 是否为 NumPy 数组
                            if isinstance(item, np.ndarray):
                                # 将 NumPy 数组转换为 Python 列表，并对列表中的每个元素进行格式化
                                formatted_item = [f'{val:.3f}' for val in item.tolist()]
                                # 使用逗号分隔格式化后的字符串
                                formatted_string = ', '.join(formatted_item)
                                wandb.log({'valid_avg/{}'.format(key): item.mean(), 'epoch': epoch})
                                logger.write(f'valid_avg/{key}: {formatted_string}')
                            else:
                                wandb.log({'valid_avg/{}'.format(key): item, 'epoch': epoch})
                                logger.write(f'valid_avg/{key}: {item:.3f}')
                    logger.write(' ')
                    logger.write(' ')
                    logger.write(' ')
                    # 记录验证结果并将其上传到wandb

                # Save out the best validation result and model
                early_stop(checkpoint)
                # 使用早停机制保存最佳验证结果和模型

                if args['exp_name'] != 'DEBUG':
                    wandb.log({'val_score': checkpoint['valid_result']['score_mean'], 'epoch': epoch})
                    wandb.log(
                        {'gaze_3D_ang_deg_mean': checkpoint['valid_result']['gaze_3D_ang_deg_mean'], 'epoch': epoch})
                    # 上传验证得分和3D gaze角度平均值到wandb

                # If epoch is a multiple of args['save_every'], then write out
                if (epoch % args['save_every']) == 0:
                    # 如果当前epoch是save_every的倍数，保存模型

                    # Ensure that you do not update the validation score at this
                    # point and simply save the model
                    if '3D' in args['early_stop_metric']:
                        temp_score = checkpoint['valid_result']['gaze_3D_ang_deg_mean']
                    elif '2D' in args['early_stop_metric']:
                        temp_score = checkpoint['valid_result']['gaze_ang_deg_mean']
                    else:
                        temp_score = checkpoint['valid_result']['masked_rendering_iou_mean']
                    early_stop.save_checkpoint(temp_score,
                                               checkpoint,
                                               update_val_score=False,
                                               use_this_name_instead=self.save_pt_name)
                    # 保存当前模型状态为last.pt，不更新验证得分

                if use_sched:
                    scheduler.step(epoch=epoch)
                    if args['exp_name'] != 'DEBUG':
                        wandb.log({'lr': optimizer.param_groups[0]['lr']})
                    # 如果使用学习率调度器，更新学习率并记录到wandb

                epoch += 1
                # 增加epoch计数器

    def Netforward(self,
                net,
                spiker,
                logger,
                loader,
                optimizer,
                args,
                path_dict,
                epoch=0,
                mode='test',
                writer=[],
                batches_per_ep=2000,
                last_epoch_valid=False,
                csv_save_dir=None):
        # 定义前向传播函数，接收网络模型、尖峰检测器、日志记录器、数据加载器、优化器、参数字典、路径字典等作为参数
        net_param_tmp = next(net.parameters())
        device = net_param_tmp.device
        print(device)

        # 初始化 Sobel 滤波器
        sobel_filter = SobelFilter(device)  # .get_device()

        logger.write('----{}. Epoch: {}----'.format(mode, epoch))

        # 根据模式设置网络模型为训练模式或评估模式
        if mode == 'train':
            net.train()
        else:
            net.eval()
        # 初始化 IO 时间列表和数据加载器迭代器
        io_time = []
        loader_iter = iter(loader)
        # 初始化度量、数据集 ID、嵌入和预测掩码标志
        metrics = []
        available_predicted_mask = False

        train_with_mask = args['loss_w_rend_pred_2_gt_edge'] or args['loss_w_rend_gt_2_pred'] \
                          or args['loss_w_rend_pred_2_gt']

        for bt in tqdm(range(batches_per_ep), desc=f"Epoch {epoch} - Batches", leave=True):
            # 开始每个批次的循环
            start_time = time.time()
            # 记录批次开始时间
            try:
                data_dict = next(loader_iter)
                print('----epoch:{}----batch:{}/{}'.format(epoch, bt, batches_per_ep))
                # 尝试从迭代器中获取下一个数据批次
            except:
                print('Loader reset')
                # 如果发生异常（迭代器耗尽），重置迭代器
                loader_iter = iter(loader)
                data_dict = next(loader_iter)
                # 再次获取数据批次
                args['time_to_update'] = True
                # 标记需要更新

            if torch.any(data_dict['is_bad']):
                logger.write('Bad batch found!', do_warn=True)
                data_dict = fix_batch(data_dict)

            end_time = time.time()
            # 记录批次结束时间
            io_time.append(end_time - start_time)
            # 计算并记录IO时间

            with torch.autograd.set_detect_anomaly(bool(args['detect_anomaly'])):
                with torch.cuda.amp.autocast(enabled=bool(args['mixed_precision'])):
                    # 使用自动混合精度和异常检测
                    # 初始化批次结果字典
                    batch_results_rend = {}
                    batch_results_gaze = {}
                    batch_size = args['batch_size']
                    frames = args['frames']

                    if mode == 'train':
                        print("train net begin")
                        out_dict_gaze, out_dict_eye = net(data_dict, args)
                        print("train net finish", net.device)
                        # 如果是训练模式，执行前向传播
                    else:
                        with torch.no_grad():
                            out_dict_gaze, out_dict_eye = net(data_dict, args)
                            print("valid net finish")
                            # 如果是测试模式，执行无梯度的前向传播

                    if torch.all(data_dict['image'] == 0):
                        optimizer.zero_grad()
                        net.zero_grad()
                        print('invalid input image')
                        continue
                        # 如果输入图像全为零，跳过该批次

                    data_dict = reshape_gt(data_dict, args)
                    out_dict_eye = reshape_ellseg_out(out_dict_eye, args)
                    # 重塑张量，将批次和帧合并为一个维度
                    H = data_dict['image'].shape[1]
                    W = data_dict['image'].shape[2]
                    # 获取图像的高度和宽度

                    image_resolution_diagonal = math.sqrt(H ** 2 + W ** 2)
                    # 计算图像分辨率对角线
                    # 初始化渲染损失、损失值、损失字典和渲染字典
                    if args['net_rend_head']:
                        # 如果使用渲染头，执行以下操作
                        # 检查预测值是否在合理范围内，若不在，则跳过该批次
                        if torch.any(torch.any(out_dict_eye['T'] < -1) or torch.any(out_dict_eye['T'] > 1)) or \
                                torch.any(torch.any(out_dict_eye['R'] < -1) or torch.any(out_dict_eye['R'] > 1)) or \
                                torch.any(torch.any(out_dict_eye['L'] < -1) or torch.any(out_dict_eye['L'] > 1)) or \
                                torch.any(
                                    torch.any(out_dict_eye['focal'] < -1) or torch.any(out_dict_eye['focal'] > 1)) or \
                                torch.any(torch.any(out_dict_eye['r_pupil'] < -1) or torch.any(
                                    out_dict_eye['r_pupil'] > 1)) or \
                                torch.any(
                                    torch.any(out_dict_eye['r_iris'] < -1) or torch.any(out_dict_eye['r_iris'] > 1)):
                            optimizer.zero_grad()
                            net.zero_grad()
                            print('invalid predicted values from rend head')
                            continue

                        # 检查预测值是否为 NaN 或 Inf，若是，则跳过该批次
                        if torch.isnan(out_dict_eye['gaze_vector_3D']).any():
                            optimizer.zero_grad()
                            net.zero_grad()
                            print('NaN gaze_vector_3D BEFORE FUNCTION')
                            continue
                        if torch.isinf(out_dict_eye['T']).any():
                            optimizer.zero_grad()
                            net.zero_grad()
                            print('inf problem T inf before function')
                            continue

                        if torch.isnan(out_dict_eye['R']).any():
                            optimizer.zero_grad()
                            net.zero_grad()
                            print('NaN problem R BEFORE FUNCTION')
                            continue
                        if torch.isinf(out_dict_eye['R']).any():
                            optimizer.zero_grad()
                            net.zero_grad()
                            print('inf problem R inf before function')
                            continue

                        # 渲染语义
                        out_dict_eye, rend_dict = render_semantics(out_dict_eye, H=H, W=W, args=args,
                                                                   data_dict=data_dict)

                        # 检查渲染后的值是否为 NaN 或 Inf，若是，则跳过该批次
                        if (torch.isnan(rend_dict['pupil_UV']).any() or torch.isinf(rend_dict['pupil_UV']).any()):
                            optimizer.zero_grad()
                            net.zero_grad()
                            print('invalid pupil from rendering points')
                            continue

                        if (torch.isnan(rend_dict['iris_UV']).any() or torch.isinf(rend_dict['iris_UV']).any()):
                            optimizer.zero_grad()
                            net.zero_grad()
                            print('invalid iris from rendering points')
                            continue

                        if (torch.isnan(rend_dict['pupil_c_UV']).any() or torch.isinf(rend_dict['pupil_c_UV']).any()):
                            optimizer.zero_grad()
                            net.zero_grad()
                            print('invalid pupil center')
                            continue

                        if (torch.isnan(rend_dict['eyeball_c_UV']).any() or torch.isinf(
                                rend_dict['eyeball_c_UV']).any()):
                            optimizer.zero_grad()
                            net.zero_grad()
                            print('invalid eyeball center')
                            continue

                        # 将数据字典发送到设备（如 GPU）
                        # data_dict['mask'][data_dict['mask'] == 3] = 2  # 将虹膜移到2
                        # # 定义类别颜色映射（示例：3个类别）
                        # colormap = np.array([
                        #     [0, 0, 0],  # 类别0：黑色（背景）
                        #     [255, 0, 0],  # 类别1：红色
                        #     [0, 255, 0]  # 类别2：绿色
                        # ])
                        #
                        # # 将类别索引转换为 RGB 图像
                        # mask_rgb = colormap[data_dict['mask'][0]]  # 形状变为 [H, W, 3]
                        #
                        # plt.imshow(mask_rgb)
                        # plt.axis('off')
                        # plt.show()

                        data_dict = send_to_device(data_dict, device)

                        if train_with_mask:
                            # 根据参数选择损失函数
                            if args['loss_rend_vectorized']:
                                loss_fn_rend = rendered_semantics_loss_vectorized3
                            else:
                                loss_fn_rend = rendered_semantics_loss

                            # 计算渲染的瞳孔和虹膜损失
                            total_loss_rend, loss_dict_rend = loss_fn_rend(data_dict['mask'],
                                                                           rend_dict,
                                                                           sobel_filter,
                                                                           None,
                                                                           None,
                                                                           args)

                            iterations = args['batch_size'] * args['frames']
                            # 如果满足条件，生成渲染掩码
                            if (bt % args['produce_rend_mask_per_iter'] == 0 or last_epoch_valid \
                                    or (mode == 'test')):
                                available_predicted_mask = True

                                # 生成渲染掩码
                                rend_dict['eyeball_circle'] = eyeball_center(out_dict_eye,
                                                                             H=H,
                                                                             W=W,
                                                                             args=args)

                                rend_dict = generate_rend_masks(rend_dict, H, W, iterations)

                                rend_dict['mask'] = torch.argmax(rend_dict['predict'], dim=1)
                                rend_dict['gaze_img'] = rend_dict['mask_gaze']

                                rend_dict['mask'] = torch.clamp(rend_dict['mask'], min=0, max=255)
                                rend_dict['gaze_img'] = torch.clamp(rend_dict['gaze_img'], min=0, max=255)

                                rend_dict['mask'] = rend_dict['mask'].detach().cpu().numpy()
                                rend_dict['gaze_img'] = rend_dict['gaze_img'].detach().cpu().numpy()
                                # 将渲染掩码转换为 NumPy 数组并限制其范围

                                # mask = rend_dict['gaze_img']
                                # # 如果你选择的是第一个方法，显示单个通道
                                # mask_channel = mask[0, :, :]  # 选择第一个通道

                                # # 绘制掩码图像
                                # plt.figure(figsize=(6, 6))
                                # plt.imshow(mask_channel, cmap='gray')  # 使用灰度色彩映射显示掩码
                                # plt.colorbar()  # 添加颜色条，用于显示掩码值的范围
                                # plt.title('Mask Visualization')  # 设置标题
                                # plt.axis('off')  # 关闭坐标轴显示
                                # plt.show()  # 显示掩码图像

                            # 如果不满足条件，渲染掩码不可用
                            else:
                                available_predicted_mask = False

                            if torch.is_tensor(total_loss_rend):
                                total_loss_rend_value = total_loss_rend.item()
                                is_spike = spiker.update(total_loss_rend_value) if spiker else False
                            else:
                                # 记录渲染损失值
                                total_loss_rend_value = total_loss_rend

                            # 获取渲染结果的指标
                            batch_results_rend = get_metrics_rend(detach_cpu_numpy(rend_dict),
                                                                  detach_cpu_numpy(data_dict),
                                                                  batch_results_rend,
                                                                  image_resolution_diagonal,
                                                                  args,
                                                                  available_predicted_mask)
                            # 记录渲染总损失
                            batch_results_rend['loss/rend_total'] = total_loss_rend_value
                            for k in loss_dict_rend:
                                # 记录渲染损失字典中的各项损失
                                batch_results_rend[f'loss/rend_{k}'] = loss_dict_rend[k].item()
                        else:
                            total_loss_rend = 0.0
                    else:
                        # 初始化渲染损失、损失值、损失字典和渲染字典
                        total_loss_rend = 0.0
                        total_loss_rend_value = 0.0
                        loss_dict_rend = {}
                        rend_dict = {}
                    model_name = args['model']
                    #     #add loss in case we want to supervise the 3D Eye model or directly the UV point
                    if args['loss_w_supervise']:
                        if args['net_rend_head']:
                            # 计算注视向量的损失
                            total_supervised_loss_eye, loss_dict_supervised_eye = loss_fn_rend_sprvs(data_dict,
                                                                                                     rend_dict,
                                                                                                     args)
                            # 计算监督损失
                            batch_results_rend[f'loss/{model_name}_total'] = total_supervised_loss_eye.item()
                            # print("loss/eye_total", total_supervised_loss.item())
                            for k in loss_dict_supervised_eye:
                                # 记录监督损失
                                batch_results_rend[f'loss/{model_name}_{k}'] = loss_dict_supervised_eye[k].item()

                    # 初始化椭圆分割损失、损失值、对抗损失和损失字典
                    is_spike = False

                # take metrics of the simply head
                if args['loss_w_supervise']:
                    if args['net_rend_head'] and not train_with_mask:
                        # 获取渲染头的指标（不使用掩码）
                        batch_results_rend = get_metrics_simple(rend_dict,
                                                                move_gpu(data_dict, rend_dict['pupil_c_UV'].device),
                                                                batch_results_rend,
                                                                image_resolution_diagonal,
                                                                args, "rend")

                # define losses
                loss_eye = total_loss_rend
                if args['loss_w_supervise']:
                    # 如果只有渲染头分支，计算总损失
                    loss_eye += args['loss_w_supervise_eye'] * total_supervised_loss_eye
                    # 更新损失尖峰检测器
                    is_spike = spiker.update(loss_eye.item()) if spiker else False


                if mode == 'train':
                    # 反向传播总损失
                    loss_eye.backward()

                    if not is_spike:
                        # Gradient clipping, if needed, goes here
                        if args['grad_clip_norm'] > 0:
                            # 如果设置了梯度裁剪，执行梯度裁剪
                            grad_norm = torch.nn.utils.clip_grad_norm_(net.parameters(),
                                                                       max_norm=args['grad_clip_norm'],
                                                                       norm_type=2)
                        print("Gaze loss:", loss_eye.item())
                        optimizer.step()
                    else:
                        # 如果检测到损失尖峰，打印警告信息
                        total_norm = np.inf
                        print('-------------')
                        print('Spike detected! Loss: {}'.format(loss_eye.item()))
                        print('-------------')

                # Zero out gradients no matter what
                # 清零优化器和网络的梯度
                optimizer.zero_grad()
                net.zero_grad()

            # Merge metrics
            # 合并批次结果
            batch_results = merge_two_dicts(detach_cpu_numpy(batch_results_rend),
                                            detach_cpu_numpy(batch_results_gaze))

            # 记录损失和批次指标
            batch_results['loss'] = loss_eye.item()
            batch_metrics = aggregate_metrics([batch_results])
            metrics.append(batch_results)

            if args['exp_name'] != 'DEBUG':
                # 记录到wandb
                log_wandb(batch_metrics, rend_dict, data_dict, out_dict_gaze, loss_eye,
                          available_predicted_mask, mode, epoch, bt, H, W, args)

            # 如果满足条件，保存输出
            if (available_predicted_mask and args['net_rend_head'] and \
                    bt % args['produce_rend_mask_per_iter'] == 0):
                gt_dict = {}
                gt_dict['image'] = data_dict['image']
                gt_dict['pupil_center_available'] = data_dict['pupil_center_available']
                gt_dict['pupil_center'] = data_dict['pupil_center']
                gt_dict['pupil_ellipse_available'] = data_dict['pupil_ellipse_available']
                gt_dict['pupil_ellipse'] = data_dict['pupil_ellipse']
                gt_dict['mask'] = data_dict['mask']
                gt_dict['iris_ellipse_available'] = data_dict['iris_ellipse_available']
                gt_dict['iris_ellipse'] = data_dict['iris_ellipse']
                # Saving spiky conditions unnecessarily bloats the drive
                save_out(gt_dict, out_dict_eye, rend_dict, data_dict['image'], path_dict, mode,
                         is_spike, args, epoch, bt)
                # 如果满足条件且训练时使用掩码，保存输出
            elif (bt % args['produce_rend_mask_per_iter'] == 0 and train_with_mask):
                gt_dict = {}
                gt_dict['image'] = data_dict['image']
                gt_dict['pupil_center_available'] = data_dict['pupil_center_available']
                gt_dict['pupil_center'] = data_dict['pupil_center']
                gt_dict['pupil_ellipse_available'] = data_dict['pupil_ellipse_available']
                gt_dict['pupil_ellipse'] = data_dict['pupil_ellipse']
                gt_dict['mask'] = data_dict['mask']
                gt_dict['iris_ellipse_available'] = data_dict['iris_ellipse_available']
                gt_dict['iris_ellipse'] = data_dict['iris_ellipse']
                save_out(gt_dict, out_dict_eye, None, data_dict['image'], path_dict, mode,
                         is_spike, args, epoch, bt)

            del out_dict_gaze  # Explicitly free up memory 释放内存
            del out_dict_eye  # Explicitly free up memory 释放内存
            del rend_dict  # Explicitly free up memory 释放内存
            del batch_results_gaze
            del batch_results_rend
            torch.cuda.empty_cache()
        # 如果提供了csv保存目录且目录存在，则生成保存路径，否则将保存路径设置为None
        if csv_save_dir is not None:
            if os.path.isdir(csv_save_dir):
                csv_save_path = os.path.join(csv_save_dir, f'{mode}_raw_results.csv')
        else:
            csv_save_path = None
        # 调用aggregate_metrics函数聚合所有批次的指标并保存到csv文件中（如果csv_save_path不为None）
        results_dict = aggregate_metrics(metrics, csv_save_path)

        # 清除RAM中积累的数据
        del loader_iter
        # 清除CUDA缓存和RAM缓存
        torch.cuda.empty_cache()
        gc.collect()
        # 返回聚合后的结果字典
        return results_dict


if __name__ == '__main__':
    args = vars(make_args())
    Eyetrainer = EyeModelTrainer(args)
    Eyetrainer.run()

