import os
import torch
import pickle
import logging
import wandb

from timm.optim import Lamb as Lamb_timm
from timm.scheduler import CosineLRScheduler as CosineLRScheduler_timm
import numpy as np
import torch.nn as nn
import torch.nn.functional as F

from models.models_mux import model_dict
from torch.utils.data import DataLoader
# from tensorboardX import SummaryWriter

from torch.utils.data.distributed import DistributedSampler

from scripts import forward
import helperfunctions.CurriculumLib as CurLib
from helperfunctions.CurriculumLib import DataLoader_riteyes
from helperfunctions.hfunctions import mod_scalar
from helperfunctions.utils import EarlyStopping, make_logger
from helperfunctions.utils import SpikeDetection, get_nparams
from helperfunctions.utils import move_to_single


def train(args, path_dict, validation_mode=False, test_mode=False):

    rank_cond = (args['local_rank'] == 0) or not args['do_distributed']
    rank_cond_early_stop = rank_cond 
    #deactivate tensorboard log
    rank_cond = False


    net_dict = []

    # %% Load model  加载网络模型
    if args['model'] == 'DenseElNet':
        norm = nn.BatchNorm3d
    else:
        norm = nn.BatchNorm2d
    net = model_dict[args['model']](args,
                                    norm=norm,
                                    act_func=F.leaky_relu)

    # %% Weight loaders
    # if it is pretrained, then load pretrained weights
    if args['pretrained'] or args['continue_training'] or args['weights_path']:

        if args['weights_path']:
            path_pretrained = args['weights_path']
        elif args['pretrained']:
            path_pretrained = os.path.join(path_dict['repo_root'],
                                           '..',
                                           'pretrained',
                                           'pretrained.git_ok')
        elif args['continue_training']:
            path_pretrained = os.path.join(args['continue_training'])


        net_dict = torch.load(path_pretrained,
                              map_location=torch.device('cuda'))
        state_dict_single = move_to_single(net_dict['state_dict'])
        net.load_state_dict(state_dict_single, strict=False)
        print(f'Pretrained model loaded from: {path_pretrained}')

    if test_mode:
        print('Test mode detected. Loading best model.')
        
        if args['path_model']:
            net_dict = torch.load(args['path_model'],
                                  map_location=torch.device('cuda'))
        else:
            net_dict = torch.load(os.path.join(path_dict['results'],
                                               'myPara.pt'),
                                  map_location=torch.device('cuda'))

        # 确保保存的参数与解析的参数匹配
        state_dict_single = move_to_single(net_dict['state_dict'])
        net.load_state_dict(state_dict_single, strict=False)

        # Do not initialize a writer
        writer = []
    elif validation_mode:
        print('Validation mode detected. Loading model.')
        if args['path_model']:
            net_dict = torch.load(args['path_model'],
                                  map_location=torch.device('cuda'))
        else:
            net_dict = torch.load(os.path.join(path_dict['results'],
                                               'myPara.pt'),
                                  map_location=torch.device('cuda'))

        # Ensure saved arguments match with parsed arguments
        state_dict_single = move_to_single(net_dict['state_dict'])
        net.load_state_dict(state_dict_single, strict=False)

        net_dict_ip = torch.load("cur_objs/pretrained/IP.pt",
                                 map_location=torch.device('cuda'))
        state_dict_single_ip = move_to_single(net_dict_ip['state_dict'])
        net.Ipnet.load_state_dict(state_dict_single_ip, strict=False)

        # Do not initialize a writer
        writer = []
    else:
        # Initialize tensorboard if rank 0
        if rank_cond:
            # writer = SummaryWriter(path_dict['logs'])
            writer = []
        else:
            writer = []
        net_dict = torch.load("cur_objs/pretrained/myPara.pt",
                              map_location=torch.device('cuda'))
        print('Loading model.')
        state_dict_single = move_to_single(net_dict['state_dict'])
        net.load_state_dict(state_dict_single, strict=False)

        net_dict_ip = torch.load("cur_objs/pretrained/IP.pt",
                                 map_location=torch.device('cuda'))
        state_dict_single_ip = move_to_single(net_dict_ip['state_dict'])
        net.Ipnet.load_state_dict(state_dict_single_ip, strict=False)


    if args['use_GPU']:
        net.cuda()


    # %% Initialize logger
    logger = make_logger(path_dict['logs']+'/train_log.log',
                         rank=args['local_rank'] if args['do_distributed'] else 0)
    # logger.write_summary(str(net.parameters))
    logger.write('# of parameters: {}'.format(get_nparams(net)))

    if test_mode==0 and validation_mode==0:
        logger.write('Training!')
        if args['exp_name'] != 'DEBUG':
            wandb.watch(net)
    elif validation_mode:
        logger.write('Validating!')
    else:
        logger.write('Testing!')

    # %% Training and validation loops or test only
    train_validation_loops(net,
                           net_dict,
                           logger,
                           args,
                           path_dict,
                           writer,
                           rank_cond,
                           rank_cond_early_stop,
                           validation_mode,
                           test_mode)

    # %% Closing functions and logging
    if writer:
        writer.close()


def train_validation_loops(net, net_dict, logger, args,
                           path_dict, writer, rank_cond, 
                           rank_cond_early_stop, 
                           validation_mode, test_mode):
    if args['use_pkl_for_dataload']:
        # 如果使用.pkl文件加载数据集对象
        path_cur_obj = os.path.join(path_dict['repo_root'],
                                    'cur_objs',
                                    args['mode'],
                                    'cond_' + args['cur_obj'] + 'S.pkl')
        # 组装.pkl文件的路径
        with open(path_cur_obj, 'rb') as f:
            # 打开.pkl文件并加载训练、验证和测试对象
            train_obj, valid_obj, test_obj = pickle.load(f)
            print("---use_pkl_for_dataload---")
    else:
        # 如果不使用.pkl文件加载数据集对象
        path_cur_obj = os.path.join(path_dict['repo_root'],
                                    'cur_objs',
                                    'dataset_selections' + '.pkl')
        # 组装dataset_selections.pkl文件的路径
        path2h5 = args['path_data']
        # 打开.pkl文件并加载DS_sel
        DS_sel = pickle.load(open(path_cur_obj, 'rb'))
        # 从主键路径读取所有数据集
        AllDS = CurLib.readArchives(args['path2MasterKey'])

        if (args['cur_obj'] == 'OpenEDS_S'):
            # 如果cur_obj是OpenEDS_S，选择S子集
            sel = 'S'
        else:
            # 否则选择cur_obj
            sel = args['cur_obj']

        if (args['cur_obj'] == 'Ours'):
            # 如果cur_obj是Ours，则进行以下操作

            # 训练和验证对象
            AllDS_cond = CurLib.selSubset(AllDS, DS_sel['train'][sel])
            dataDiv_obj = CurLib.generate_fileList(AllDS_cond, mode='none', notest=False)
            train_obj = DataLoader_riteyes(dataDiv_obj, path2h5, 'train', True, (480, 640),
                                           scale=0.5, num_frames=args['frames'], args=args)
            valid_obj = DataLoader_riteyes(dataDiv_obj, path2h5, 'valid', False, (480, 640), sort='nothing',
                                           scale=0.5, num_frames=args['frames'], args=args)

            # 测试对象
            AllDS_cond = CurLib.selSubset(AllDS, DS_sel['test'][sel])
            dataDiv_obj = CurLib.generate_fileList(AllDS_cond, mode='none', notest=False)
            test_obj = DataLoader_riteyes(dataDiv_obj, path2h5, 'test', False, (480, 640), sort='nothing',
                                          scale=0.5, num_frames=args['frames'], args=args)
        else:
            if 'none' not in args['exp_name']:
                # 如果在调试模式下，从.pkl文件加载数据
                with open('cur_objs/dataDiv_obj_train.pkl', 'rb') as f:
                    dataDiv_obj = pickle.load(f)
                    # print((dataDiv_obj.folds[0]))
                    with open('cur_objs/train_names.txt', 'w') as fhh:
                        for s in dataDiv_obj.arch:
                            fhh.write(s+'\n')
            else:
                # 训练和验证对象
                AllDS_cond = CurLib.selSubset(AllDS, DS_sel['train'][sel])
                dataDiv_obj = CurLib.generate_fileList(AllDS_cond, mode='vanilla', notest=False)
            # file_name = 'dataDiv_obj_train.pkl'
            # with open(file_name, 'wb') as file:
            #     pickle.dump(dataDiv_obj, file)
            #     print(f'Pickle saved "{file_name}"')
            #     print(os.getcwd())
            train_obj = DataLoader_riteyes(dataDiv_obj, path2h5, 'train', True, (480, 640),
                                           scale=0.5, num_frames=args['frames'], args=args)
            valid_obj = DataLoader_riteyes(dataDiv_obj, path2h5, 'valid', False, (480, 640), sort='nothing',
                                           scale=0.5, num_frames=args['frames'], args=args)

            # if 'none' not in args['exp_name']:
            #     # 如果在调试模式下，从.pkl文件加载数据
            #     with open('cur_objs/dataDiv_obj_test.pkl', 'rb') as f:
            #         dataDiv_obj = pickle.load(f)
            #         # print((dataDiv_obj.folds[0]))
            #         with open('cur_objs/test_names.txt', 'w') as fhh:
            #             for s in dataDiv_obj.arch:
            #                 fhh.write(s+'\n')
            # else:
            #     # 测试对象
            #     AllDS_cond = CurLib.selSubset(AllDS, DS_sel['test'][sel])
            #     dataDiv_obj = CurLib.generate_fileList(AllDS_cond, mode='none', notest=False)
            #     # file_name = 'dataDiv_obj_test.pkl'
            # # with open(file_name, 'wb') as file:
            # #     pickle.dump(dataDiv_obj, file)
            # #     print(f'Pickle saved "{file_name}"')
            # #     print(os.getcwd())
            # test_obj = DataLoader_riteyes(dataDiv_obj, path2h5, 'test', False, (480, 640), sort='nothing',
            #                               scale=0.5, num_frames=args['frames'], args=args)
            # if 'DEBUG' not in args['exp_name']:
            #  args['batches_per_ep'] = train_obj.__len__()
            # 如果使用.pkl文件加载数据集对象
            path_cur_obj = os.path.join(path_dict['repo_root'],
                                        'cur_objs',
                                        args['mode'],
                                        'cond_' + args['test_obj'] + '.pkl')
            # 组装.pkl文件的路径
            with open(path_cur_obj, 'rb') as f:
                # 打开.pkl文件并加载训练、验证和测试对象
                test_obj, _, _ = pickle.load(f)
                print("---use_pkl_for_dataload---")

    # FIXME Remove unwanted validation-train overlap
    print(f'Starting the procedure of removing unwanted train-val video overlap...')
    # 开始移除训练集和验证集之间不需要的视频重叠
    train_vid_ids = list(np.unique(train_obj.imList[:, :, 1]))
    val_vid_ids = list(np.unique(valid_obj.imList[:, :, 1]))
    # 获取训练集中的所有唯一视频ID

    for vid_id in np.unique(valid_obj.imList[:, :, 1]):
        # 遍历验证集中的所有唯一视频ID
        if vid_id in train_vid_ids:
            # 如果视频ID在训练集中出现过
            print(f'Discarded valid overlap video_id:{vid_id}')
            # 打印被丢弃的验证集重叠视频ID
            bad_ids = ((valid_obj.imList[:, :, 1] == vid_id).sum(axis=-1) > 0)
            # 找出所有包含该视频ID的验证集样本
            valid_obj.imList = valid_obj.imList[~bad_ids]
            # 从验证集中移除这些样本
    print("train_vid_ids:", train_vid_ids)
    print("valid_vid_ids:",val_vid_ids)

    # FIXME Subselect 100k test images
    # print('Sub-selecting first 100k test frames')
    # # 选择前100,000个测试帧
    # test_cutoff = int(100000 / test_obj.imList.shape[1])
    # # 计算需要的测试视频数目
    # test_obj.imList = test_obj.imList[:test_cutoff]
    # # 保留前test_cutoff个测试视频

    print(f'')
    print(f'')
    print(f'Number of images:')
    print(f'Train images left: {train_obj.imList.shape[0] * train_obj.imList.shape[1]}')
    print(f'Valid images left: {valid_obj.imList.shape[0] * valid_obj.imList.shape[1]}')
    print(f'Test images left: {test_obj.imList.shape[0] * test_obj.imList.shape[1]}')
    # 打印训练集、验证集和测试集中剩余的图像数量
    print(f'')
    print(f'')

    # %% Specify flags of importance
    train_obj.augFlag = args['aug_flag']
    # 设置训练集数据增强标志
    valid_obj.augFlag = False
    # 验证集不进行数据增强
    test_obj.augFlag = False
    # 测试集不进行数据增强

    train_obj.equi_var = args['equi_var']
    # 设置训练集等效变量
    valid_obj.equi_var = args['equi_var']
    # 设置验证集等效变量
    test_obj.equi_var = args['equi_var']
    # 设置测试集等效变量

    # %% Modify path information
    train_obj.path2data = path_dict['path_data']
    # 设置训练集数据路径
    valid_obj.path2data = path_dict['path_data']
    # 设置验证集数据路径
    test_obj.path2data = path_dict['path_data']
    # 设置测试集数据路径

    # %% Modify scale at which we are working
    train_obj.scale = args['scale_factor']
    # 设置训练集使用的缩放比例
    valid_obj.scale = args['scale_factor']
    # 设置验证集使用的缩放比例
    test_obj.scale = args['scale_factor']
    # 设置测试集使用的缩放比例

    # %% Create distributed samplers
    # 创建分布式采样器用于训练集
    train_sampler = DistributedSampler(train_obj,
                                       rank=args['local_rank'],
                                       shuffle=False,
                                       num_replicas=args['world_size'],
                                       )

    # 创建分布式采样器用于验证集
    valid_sampler = DistributedSampler(valid_obj,
                                       rank=args['local_rank'],
                                       shuffle=False,
                                       num_replicas=args['world_size'],
                                       )

    # 创建分布式采样器用于测试集
    test_sampler = DistributedSampler(test_obj,
                                      rank=args['local_rank'],
                                      shuffle=False,
                                      num_replicas=args['world_size'],
                                      )


    # %% Define dataloaders
    logger.write('Initializing loaders')
    # 记录初始化加载器
    # 如果处于验证模式
    if validation_mode:
        # 创建验证集数据加载器
        valid_loader = DataLoader(valid_obj,
                                  shuffle=False,
                                  num_workers=args['workers'],
                                  drop_last=True,
                                  pin_memory=True,
                                  batch_size=args['batch_size'],
                                  sampler=valid_sampler if args['do_distributed'] else None,
                                  )
    # 如果处于测试模式
    elif test_mode:
        # 创建测试集数据加载器
        test_loader = DataLoader(test_obj,
                                 shuffle=False,
                                 num_workers=args['workers'],
                                 drop_last=True,
                                 pin_memory=True,
                                 batch_size=args['batch_size'],
                                 sampler=test_sampler if args['do_distributed'] else None,
                                 )

    else:
        # 如果处于训练模式
        # 创建训练集数据加载器
        train_loader = DataLoader(train_obj,
                                  shuffle=args['random_dataloader'],
                                  num_workers=args['workers'],
                                  drop_last=True,
                                  pin_memory=True,
                                  batch_size=args['batch_size'],
                                  sampler=train_sampler if args['do_distributed'] else None,
                                  )

        # 创建验证集数据加载器
        valid_loader = DataLoader(valid_obj,
                                  shuffle=False,
                                  num_workers=args['workers'],
                                  drop_last=True,
                                  pin_memory=True,
                                  batch_size=args['batch_size'],
                                  sampler=valid_sampler if args['do_distributed'] else None,
                                  )


    # %% Early stopping criterion
    if '3D' in args['early_stop_metric'] or '2D' in args['early_stop_metric']:
        # 如果早停指标是3D或2D
        early_stop = EarlyStopping(metric=args['early_stop_metric'],
                                   patience=args['early_stop'],
                                   verbose=True,
                                   delta=0.001,  # 需要0.1%的改进
                                   rank_cond=rank_cond_early_stop,
                                   mode='min',
                                   fName='Mybest_model.pt',
                                   path_save=path_dict['results'],
                                   )
        # 创建早停对象，目标是最小化指标
    else:
        early_stop = EarlyStopping(metric=args['early_stop_metric'],
                                   patience=args['early_stop'],
                                   verbose=True,
                                   delta=0.001,  # 需要0.1%的改进
                                   rank_cond=rank_cond_early_stop,
                                   mode='max',
                                   fName='Mybest_model.pt',
                                   path_save=path_dict['results'],
                                   )
        # 创建早停对象，目标是最大化指标

    # %% Define alpha and beta scalars
    # 通过动态调整标量在训练过程中的值，来逐步引入或强调某些损失项，以达到更好的训练效果
    if args['curr_learn_losses']:
        # 如果当前学习损失的标志被设置
        alpha_scalar = mod_scalar([0, args['epochs']], [0, 1])
        # 定义alpha标量，其在训练期间从0到1变化
        beta_scalar = mod_scalar([10, 20], [0, 1])
        # 定义beta标量，其在第10到20个epoch之间从0到1变化

    # %% Optimizer
    # 获取所有不包含'adv'的模型参数
    # param_list_enc = [param for name, param in net.featureExtractor.named_parameters() if 'adv' not in name]
    # param_list_eye = [param for name, param in net.eyeBallBranch.named_parameters() if 'adv' not in name]
    # param_list_gaze = [param for name, param in net.gazeBranch.named_parameters() if 'adv' not in name]
    param_list= [param for name, param in net.named_parameters() if 'adv' not in name]

    # 学习器和优化器
    if 'LAMB' in args['optimizer_type']: 
        # 如果使用LAMB优化器，创建LAMB优化器实例
        # optimizer_enc = Lamb_timm(param_list_enc, lr=args['lr'], weight_decay=args['wd'])
        # # 如果使用AdamW优化器和余弦调度器，创建AdamW优化器实例
        # warmup_epochs = 5
        # scheduler_enc = CosineLRScheduler_timm(optimizer_enc,
        #                                        t_initial=args['epochs'],
        #                                        lr_min=args['lr'] / 100.0,
        #                                        warmup_t=4,
        #                                        warmup_lr_init=args['lr'] / 100.0)

        # optimizer_gaze = Lamb_timm(param_list_gaze, lr=args['lr'], weight_decay=args['wd'])
        # # 如果使用AdamW优化器和余弦调度器，创建AdamW优化器实例
        # warmup_epochs = 5
        # scheduler_gaze = CosineLRScheduler_timm(optimizer_gaze,
        #                                         t_initial=args['epochs'],
        #                                         lr_min=args['lr'] / 100.0,
        #                                         warmup_t=4,
        #                                         warmup_lr_init=args['lr'] / 100.0)

        # optimizer_eye = Lamb_timm(param_list_eye, lr=args['lr'], weight_decay=args['wd'])
        # # 如果使用AdamW优化器和余弦调度器，创建AdamW优化器实例
        # warmup_epochs = 5
        # scheduler_eye = CosineLRScheduler_timm(optimizer_eye,
        #                                        t_initial=args['epochs'],
        #                                        lr_min=args['lr'] / 100.0,
        #                                        warmup_t=4,
        #                                        warmup_lr_init=args['lr'] / 100.0)
        optimizer = Lamb_timm(param_list, lr=args['lr'], weight_decay=args['wd'])
        # 如果使用AdamW优化器和余弦调度器，创建AdamW优化器实例
        warmup_epochs = 5
        scheduler = CosineLRScheduler_timm(optimizer,
                                               t_initial=args['epochs'],
                                               lr_min=args['lr'] / 100.0,
                                               warmup_t=4,
                                               warmup_lr_init=args['lr'] / 100.0)
        # 使用余弦学习率调度器
        use_sched = True
        # 设置调度器标志为True

    elif 'adamw_cos' in args['optimizer_type']:
        optimizer = torch.optim.AdamW(param_list, lr=args['lr'], betas=(0.9, 0.99), weight_decay=args['wd'])
        # 如果使用AdamW优化器和余弦调度器，创建AdamW优化器实例
        warmup_epochs = 5
        scheduler = CosineLRScheduler_timm(optimizer,
                                           t_initial=args['epochs'],
                                           lr_min=args['lr'] / 100.0,
                                           warmup_t=4,
                                           warmup_lr_init=args['lr'] / 100.0)

        optimizer = torch.optim.AdamW(param_list, lr=args['lr'], betas=(0.9, 0.99), weight_decay=args['wd'])
        # 使用余弦学习率调度器
        use_sched = True
        # 设置调度器标志为True

    elif 'adamw_step' in args['optimizer_type']:
        optimizer = torch.optim.AdamW(param_list, lr=args['lr'], betas=(0.9, 0.99), weight_decay=args['wd'])
        # 如果使用AdamW优化器和阶梯调度器，创建AdamW优化器实例
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=15, gamma=0.1)
        # 使用阶梯学习率调度器
        use_sched = True
        # 设置调度器标志为True

    elif 'adam_cos' in args['optimizer_type']:
        optimizer = torch.optim.Adam(param_list, lr=args['lr'], betas=(0.9, 0.99), weight_decay=args['wd'])
        # 如果使用Adam优化器和余弦调度器，创建Adam优化器实例
        warmup_epochs = 5
        scheduler = CosineLRScheduler_timm(optimizer,
                                           t_initial=args['epochs'],
                                           lr_min=args['lr'] / 1000.0,
                                           warmup_t=4,
                                           warmup_lr_init=args['lr'] / 100.0)
        # 使用余弦学习率调度器
        use_sched = True
        # 设置调度器标志为True

    elif 'adam_step' in args['optimizer_type']:
        optimizer = torch.optim.Adam(param_list, lr=args['lr'], betas=(0.9, 0.99), weight_decay=args['wd'])
        # 如果使用Adam优化器和阶梯调度器，创建Adam优化器实例
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=15, gamma=0.1)
        # 使用阶梯学习率调度器
        use_sched = True
        # 设置调度器标志为True

    else:
        optimizer = torch.optim.Adam(param_list, lr=args['lr'], amsgrad=False, weight_decay=args['wd'])
        # 默认情况下，使用Adam优化器
        use_sched = False
        # 不使用调度器

    if args['adv_DG']:
        param_list = [param for name, param in net.named_parameters() if 'adv' in name]
        # 如果使用对抗性领域泛化，获取所有包含'adv'的模型参数
        optimizer_disc = torch.optim.Adam(param_list, lr=args['lr'], amsgrad=True)
        # 为对抗性参数创建Adam优化器实例
    else:
        optimizer_disc = False
        # 不使用对抗性优化器

    # %% Loops and what not

    # Create a checkpoint based on current scores
    checkpoint = {}
    checkpoint['args'] = args  # Save arguments
    # 创建一个检查点并保存当前参数

    # Randomize the dataset again the next time you exit
    # to the main loop.
    args['time_to_update'] = True
    # 下次退出主循环时重新随机化数据集
    last_epoch_validation = False
    # 设置最后一个epoch验证标志为False

    # specify the mode for forward and save results

    # 测试模式
    if test_mode:
        logging.info('Entering test mode only ...')
        logger.write('Entering test mode only ...')
        # 如果处于测试模式，记录进入测试模式

        args['alpha'] = 0.5
        args['beta'] = 0.5
        # 设置alpha和beta为0.5
        # 运行测试并获取测试结果
        test_result = forward(net,
                              [],
                              logger,
                              test_loader,
                              optimizer,
                              args,
                              path_dict,
                              writer=writer,
                              rank_cond=rank_cond,
                              epoch=0,
                              mode='test',
                              batches_per_ep=len(test_loader) if 'DEBUG' not in args['exp_name'] else 10,
                              # batches_per_ep=args['batches_per_ep'],
                              last_epoch_valid=True,
                              csv_save_dir=path_dict['exp'])
        # 将测试结果保存到检查点
        checkpoint['test_result'] = test_result
        print("test_result:", test_result)


        epoch = 0
        if args['exp_name'] != 'DEBUG':
            logger.write('Test results:')
            for key, item in checkpoint['test_result'].items():
                if 'mean' in key and 'loss' not in key:
                    # 判断 item 是否为 NumPy 数组
                    if isinstance(item, np.ndarray):
                        # 将 NumPy 数组转换为 Python 列表，并对列表中的每个元素进行格式化
                        formatted_item = [f'{val:.3f}' for val in item.tolist()]
                        # 使用逗号分隔格式化后的字符串
                        formatted_string = ', '.join(formatted_item)
                        wandb.log({'test_avg/{}'.format(key): item.mean(), 'epoch': epoch})
                        logger.write(f'test_avg/{key}: {formatted_string}')
                    else:
                        wandb.log({'test_avg/{}'.format(key): item, 'epoch': epoch})
                        logger.write(f'test_avg/{key}: {item:.3f}')
            logger.write(' ')
            logger.write(' ')
            logger.write(' ')
            # 记录测试结果

        if args['save_results_here']:
            # 确保结果保存目录存在
            os.makedirs(os.path.dirname(args['save_results_here']),
                        exist_ok=True)

            # Save out test results here instead 将测试结果保存到指定位置
            with open(args['save_results_here'], 'wb') as f:
                pickle.dump(checkpoint, f)
        else:
            # Ensure the directory exists 确保默认结果保存目录存在
            os.makedirs(path_dict['results'], exist_ok=True)

            # Save out the test results 将测试结果保存到默认位置
            with open(path_dict['results'] + '/test_results.pkl', 'wb') as f:
                pickle.dump(checkpoint, f)

    # 验证模式
    elif validation_mode:
        logger.write('Entering validation mode only ...')
        logging.info('Entering validation mode  only...')
        args['alpha'] = 0.5
        args['beta'] = 0.5
        # 如果处于验证模式，记录进入验证模式并设置alpha和beta为0.5
        # valid_batches_per_ep = int(4608 / (args['frames'] * args['batch_size']))
        valid_result = forward(net,
                               [],
                               logger,
                               valid_loader,
                               optimizer,
                               args,
                               path_dict,
                               writer=writer,
                               rank_cond=rank_cond,
                               epoch=0,
                               mode='valid',
                               # batches_per_ep=len(valid_loader) if 'DEBUG' not in args['exp_name'] else 10,
                               # batches_per_ep=args['batches_per_ep'],
                               batches_per_ep=len(valid_loader),
                               last_epoch_valid=True,
                               csv_save_dir=path_dict['exp'])
        # 运行验证并获取验证结果
        print("valid_result:", valid_result)

        checkpoint['valid_result'] = valid_result
        # 将验证结果保存到检查点

        epoch = 0
        if args['exp_name'] != 'DEBUG':
            logger.write('Validation results:')
            for key, item in checkpoint['valid_result'].items():
                # 判断 item 是否为 NumPy 数组
                if isinstance(item, np.ndarray):
                    # 将 NumPy 数组转换为 Python 列表，并对列表中的每个元素进行格式化
                    formatted_item = [f'{val:.3f}' for val in item.tolist()]
                    # 使用逗号分隔格式化后的字符串
                    formatted_string = ', '.join(formatted_item)
                    wandb.log({'valid_avg/{}'.format(key): item.mean(), 'epoch': epoch})
                    logger.write(f'valid_avg/{key}: {formatted_string}')
                else:
                    wandb.log({'valid_avg/{}'.format(key): item, 'epoch': epoch})
                    logger.write(f'valid_avg/{key}: {item:.3f}')
            logger.write(' ')
            logger.write(' ')
            logger.write(' ')
            # 记录验证结果

        if args['save_results_here']:
            # Ensure the directory exists
            os.makedirs(os.path.dirname(args['save_results_here']),
                        exist_ok=True)
            # 确保结果保存目录存在

            # Save out test results here instead
            with open(args['save_results_here'], 'wb') as f:
                pickle.dump(checkpoint, f)
            # 将验证结果保存到指定位置
        else:
            # Ensure the directory exists
            os.makedirs(path_dict['results'], exist_ok=True)
            # 确保默认结果保存目录存在

            # Save out the test results
            with open(path_dict['results'] + '/valid_results.pkl', 'wb') as f:
                # 将验证结果保存到默认位置
                pickle.dump(checkpoint, f)

    # 训练模式
    else:
        # 如果remove_spikes参数为True，则创建SpikeDetection实例，否则为False
        spiker = SpikeDetection() if args['remove_spikes'] else False
        logging.info('Entering train mode ...')
        logger.write('Entering train mode ...')
        # 记录进入训练模式

        if args['continue_training']:
            optimizer.load_state_dict(net_dict['optimizer'])
            epoch = net_dict['epoch'] + 1
            # 如果继续训练，从保存的优化器状态加载并设置起始epoch
        else:
            epoch = 0
            # 否则从第0个epoch开始

        # Disable early stop and keep training until it maxes out, this allows
        # us to test at the regular best model while saving intermediate result
        # while (epoch < args['epochs']) and not early_stop.early_stop:
        # 禁用早停，保持训练直到达到最大epoch，这允许在保存中间结果的同时测试最佳模型

        while (epoch < args['epochs']):
            # 当epoch小于设定的最大epoch时，继续训练
            if args['time_to_update']:
                # 如果需要更新数据集

                # Toggle flag back to False
                args['time_to_update'] = False
                # 重置标志为False

                if args['one_by_one_ds']:
                    train_loader.dataset.sort('one_by_one_ds', args['batch_size'])
                    valid_loader.dataset.sort('one_by_one_ds', args['batch_size'])
                    # 按照one_by_one_ds排序训练和验证数据集
                else:
                    # for sequence dataset:
                    train_loader.dataset.sort('ordered')
                    valid_loader.dataset.sort('ordered')
                    # 按照ordered排序训练和验证数据集

            # Set epochs for samplers
            train_sampler.set_epoch(epoch)
            valid_sampler.set_epoch(epoch)
            # 为采样器设置当前epoch

            # %%
            logging.info('Starting epoch: %d' % epoch)
            logger.write('Starting epoch: %d' % epoch)
            # 记录当前开始的epoch

            if args['curr_learn_losses']:
                args['alpha'] = alpha_scalar.get_scalar(epoch)
                args['beta'] = beta_scalar.get_scalar(epoch)
                # 如果使用动态学习损失，获取当前epoch的alpha和beta值
            else:
                args['alpha'] = 0.5
                args['beta'] = 0.5
                # 否则将alpha和beta设置为0.5

            if args['dry_run']:
                train_batches_per_ep = len(train_loader)
                valid_batches_per_ep = len(valid_loader)
                # 如果是dry run模式，设置每个epoch的训练和验证批次数为数据集长度
            else:
                train_batches_per_ep = args['batches_per_ep']
                if args['reduce_valid_samples']:
                    valid_batches_per_ep = 10
                    # 如果减少验证样本，设置每个epoch的验证批次数为10
                else:
                    valid_batches_per_ep = len(valid_loader)
                    # 否则根据参数计算每个epoch的验证批次数

            train_result = forward(net,
                                   spiker,
                                   logger,
                                   train_loader,
                                   optimizer,
                                   args,
                                   path_dict,
                                   optimizer_disc=optimizer_disc,
                                   writer=writer,
                                   rank_cond=rank_cond,
                                   epoch=epoch,
                                   mode='train',
                                   batches_per_ep=train_batches_per_ep)
                                   # batches_per_ep=50)
            # 进行训练并获取训练结果

            if epoch == args['epochs'] - 1:
                last_epoch_validation = True
                valid_batches_per_ep = len(valid_loader)
                # 如果是最后一个epoch，设置验证批次数为验证集长度

            # incase you want to validate the whole validation set just Remove True
            if (args['reduce_valid_samples'] and (epoch % args['perform_valid'] != 0)) \
                    or ('DEBUG' in args['exp_name']) or True:
                valid_result = forward(net,
                                       spiker,
                                       logger,
                                       valid_loader,
                                       optimizer,
                                       args,
                                       path_dict,
                                       writer=writer,
                                       rank_cond=rank_cond,
                                       epoch=epoch,
                                       mode='valid',
                                       batches_per_ep=valid_batches_per_ep,
                                       # batches_per_ep=50,
                                       last_epoch_valid=last_epoch_validation)
                # 如果减少验证样本且当前epoch不是需要验证的epoch，或在DEBUG模式下，进行部分验证
            else:
                valid_result = forward(net,
                                       spiker,
                                       logger,
                                       valid_loader,
                                       optimizer,
                                       args,
                                       path_dict,
                                       writer=writer,
                                       rank_cond=rank_cond,
                                       epoch=epoch,
                                       mode='valid',
                                       batches_per_ep=len(valid_loader),
                                       # batches_per_ep=50,
                                       last_epoch_valid=last_epoch_validation)
                # 否则进行完整验证

            # Update the check point weights. VERY IMPORTANT!
            checkpoint['state_dict'] = move_to_single(net.state_dict())
            # 计算参数总数
            total_params = sum(p.numel() for p in checkpoint['state_dict'].values())

            # 打印参数量
            print(f"Total parameters in propagator: {total_params}")

            checkpoint['optimizer'] = optimizer.state_dict()
            # 更新检查点中的模型和优化器状态

            checkpoint['epoch'] = epoch
            checkpoint['train_result'] = train_result
            checkpoint['valid_result'] = valid_result
            # 保存当前epoch和训练、验证结果

            if args['exp_name'] != 'DEBUG':

                logger.write('---Train results:---')
                for key, item in checkpoint['train_result'].items():
                    # 判断 item 是否为 NumPy 数组
                    if isinstance(item, np.ndarray):
                        # 将 NumPy 数组转换为 Python 列表，并对列表中的每个元素进行格式化
                        formatted_item = [f'{val:.3f}' for val in item.tolist()]
                        # 使用逗号分隔格式化后的字符串
                        formatted_string = ', '.join(formatted_item)
                        wandb.log({'train_avg/{}'.format(key): item.mean(), 'epoch': epoch})
                        logger.write(f'train_avg/{key}: {formatted_string}')
                    else:
                        wandb.log({'train_avg/{}'.format(key): item, 'epoch': epoch})
                        logger.write(f'train_avg/{key}: {item:.3f}')
                logger.write(' ')
                logger.write(' ')
                logger.write(' ')

                logger.write('---Validation results:---')
                for key, item in checkpoint['valid_result'].items():
                    if 'mean' in key and 'loss' not in key:
                        # 判断 item 是否为 NumPy 数组
                        if isinstance(item, np.ndarray):
                            # 将 NumPy 数组转换为 Python 列表，并对列表中的每个元素进行格式化
                            formatted_item = [f'{val:.3f}' for val in item.tolist()]
                            # 使用逗号分隔格式化后的字符串
                            formatted_string = ', '.join(formatted_item)
                            wandb.log({'valid_avg/{}'.format(key): item.mean(), 'epoch': epoch})
                            logger.write(f'valid_avg/{key}: {formatted_string}')
                        else:
                            wandb.log({'valid_avg/{}'.format(key): item, 'epoch': epoch})
                            logger.write(f'valid_avg/{key}: {item:.3f}')
                logger.write(' ')
                logger.write(' ')
                logger.write(' ')
                # 记录验证结果并将其上传到wandb

            # Save out the best validation result and model
            early_stop(checkpoint)
            # 使用早停机制保存最佳验证结果和模型

            if args['exp_name'] != 'DEBUG':
                wandb.log({'val_score': checkpoint['valid_result']['score_mean'], 'epoch': epoch})
                wandb.log({'gaze_3D_ang_deg_mean': checkpoint['valid_result']['gaze_3D_ang_deg_mean'], 'epoch': epoch})
                # 上传验证得分和3D gaze角度平均值到wandb

            # If epoch is a multiple of args['save_every'], then write out
            if (epoch % args['save_every']) == 0:
                # 如果当前epoch是save_every的倍数，保存模型

                # Ensure that you do not update the validation score at this
                # point and simply save the model
                if '3D' in args['early_stop_metric']:
                    temp_score = checkpoint['valid_result']['gaze_3D_ang_deg_mean']
                elif '2D' in args['early_stop_metric']:
                    temp_score = checkpoint['valid_result']['gaze_ang_deg_mean']
                else:
                    temp_score = checkpoint['valid_result']['masked_rendering_iou_mean']
                early_stop.save_checkpoint(temp_score,
                                           checkpoint,
                                           update_val_score=False,
                                           use_this_name_instead='myPara.pt')
                # 保存当前模型状态为last.pt，不更新验证得分

            if use_sched:
                scheduler.step(epoch=epoch)
                if args['exp_name'] != 'DEBUG':
                    wandb.log({'lr': optimizer.param_groups[0]['lr']})
                # 如果使用学习率调度器，更新学习率并记录到wandb

            epoch += 1
            # 增加epoch计数器


if __name__ == '__main__':
    print('Entry script is run.py')
