#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
This started as a copy of https://bitbucket.org/RSKothari/multiset_gaze/src/master/ 
with additional changes and modifications to adjust it to our implementation. 

Copyright (c) 2021 Rakshit Kothari, Aayush Chaudhary, Reynold Bailey, Jeff Pelz, 
and Gabriel Diaz
"""
from pprint import pprint
import argparse
import getpass
import socket
import os

# 导入getpass模块以获取当前用户名
username = getpass.getuser()
# 导入socket模块以获取当前主机名
host = socket.gethostname()
# print(username)
# print(host)
# 检查当前用户名和主机名是否分别为'nipopovic'和'archer'
if username == 'NickX' and host == 'Xiao':
    # 设置环境变量WANDB_DIR
    os.environ["WANDB_DIR"] = "D:/MyData/PDGE/GazeBranch/PD_Gaze/results/PDGaze"
    # 设置masterkey_root路径
    masterkey_root = 'D:/Xiao/DataSet/TEyeD//MasterKey'
    # 设置dataset_root路径
    dataset_root_source = 'D:/Xiao/DataSet/TEyeD/All'
    dataset_root_target = 'D:/Xiao/DataSet/TEyeD/All'
    # 设置results_root路径
    results_root = 'D:/MyData/PDGE/GazeBranch/PD_Gaze/results/PDGaze'
    baseline_results_root = 'D:/MyData/PDGE/GazeBranch/PD_Gaze/results'
    # 设置default_repo路径
    default_repo = 'D:/MyData/PDGE/GazeBranch/PD_Gaze'
    # 打印当前用户和机器信息
    print('User:NickX----machine:Xiao')

else:
    # 设置环境变量WANDB_DIR
    os.environ["WANDB_DIR"] = "/ai/mnt/Code/TEyeD/GazeBranch/ResNet/PD_Gaze/results/PDGaze"
    # 设置masterkey_root路径
    masterkey_root = '/ai/mnt/DataSet/LPW/MasterKey'
    # 设置dataset_root路径
    dataset_root_source = '/ai/mnt/DataSet/LPW/All'
    dataset_root_target = '/ai/mnt/DataSet/LPW/All'
    # 设置results_root路径
    results_root = '/ai/mnt/Code/TEyeD/GazeBranch/ResNet/PD_Gaze/results/PDGaze'
    # 设置default_repo路径
    default_repo = '/ai/mnt/Code/TEyeD/GazeBranch/ResNet/PD_Gaze'
    # 打印当前用户和机器信息
    print('User:DUDULong----machine:Xiao')


def make_args():
    # 创建参数解析器对象
    parser = argparse.ArgumentParser()
    # 添加实验名称参数
    parser.add_argument('--exp_name', type=str, default='Test',
                        help='experiment string or identifier')
    # 添加是否使用pkl文件加载数据的参数
    parser.add_argument('--use_pkl_for_dataload', type=bool, default=False,
                        help='pkl to load the data')
    # 添加每隔多少次迭代生成一次掩码的参数
    parser.add_argument('--produce_rend_mask_per_iter', type=int ,default=2000,
                        help='set the num of iteration to generate the mask')
    # 添加执行验证的参数，每隔多少次迭代进行一次验证
    parser.add_argument('--perform_valid', type=int ,default=25,
                        help='perform validation')
    # ######
    # %% 超参数
    # ######
    parser.add_argument('--lr', type=float, default=2e-3,
                        help='base learning rate')
    parser.add_argument('--wd', type=float, default=0,
                        help='weight_decay')
    # 添加种子值参数，用于确保可重复性
    parser.add_argument('--seed', type=int, default=108,
                        help='seed valexp_nameue for all packages')
    parser.add_argument('--batch_size', type=int, default=4,
                        help='batch size for training')
    parser.add_argument('--feature_save_num', type=int, default=10000,
                        help='feature_save_num')
    parser.add_argument('--prefix', type=str, default='TrainSet',
                        help='experiment string or identifier')
    # 添加本地GPU编号参数
    parser.add_argument('--local_rank', type=int, default=0,
                        help='rank to set GPU')
    parser.add_argument('--lr_decay', type=int, default=0,
                        help='learning rate decay')
    parser.add_argument('--dropout', type=float, default=0.0,
                        help='dropout? anything above 0 activates dropout')

    # ######
    # %% 模型特定参数
    # ######
    parser.add_argument('--base_channel_size', type=int, default=32,
                        help='base channel size around which model grows')  # 添加基础通道大小参数
    parser.add_argument('--growth_rate', type=float, default=1.2,
                        help='growth rate of channels in network')  # 添加网络中通道增长率参数
    parser.add_argument('--track_running_stats', type=int, default=0,
                        help='disable running stats for better transfer') # 添加禁用运行时统计信息参数，用于更好的迁移学习
    parser.add_argument('--extra_depth', type=int, default=0,
                        help='extra convolutions to the encoder') # 添加额外的卷积层参数，添加到编码器中
    parser.add_argument('--grad_rev', type=int, default=0,
                        help='gradient reversal for dataset identity') # 添加数据集身份的梯度反转参数
    parser.add_argument('--adv_DG', type=int, default=0,
                        help='enable discriminator') # 添加启用鉴别器参数
    parser.add_argument('--equi_var', type=int, default=0,
                        help='normalize data to respect image dimensions') # 添加标准化数据以适应图像维度参数
    parser.add_argument('--num_blocks', type=int, default=4,
                        help='number of encoder decoder blocks') # 添加编码器解码器块数量参数
    parser.add_argument('--use_frn_tlu', type=int, default=0,
                        help='replace BN+L.RELU with FRN+TLU') # 添加使用FRN+TLU替换BN+L.RELU的参数
    parser.add_argument('--use_instance_norm', type=int, default=0,
                        help='replace BN with IN') # 添加使用实例归一化替换批归一化的参数
    parser.add_argument('--use_group_norm', type=int, default=0,
                        help='replace BN with GN, 8 channels per group') # 添加使用组归一化替换批归一化的参数，每组8个通道
    parser.add_argument('--use_ada_instance_norm', type=int, default=0,
                        help='use adaptive instance normalization') # 添加使用自适应实例归一化的参数
    parser.add_argument('--use_ada_instance_norm_mixup', type=int, default=0,
                        help='use adaptive instance normalization with mixup') # 添加使用自适应实例归一化和mixup的参数
    parser.add_argument('--diff_threshold', type=float, default=0.3,
                        help='差分算法阈值')

    # ######
    # %% 判别统计参数
    # ######
    parser.add_argument('--disc_base_channel_size', type=int, default=8,
                        help='discriminator base channels?')
    '''
    关于判别器通道大小的注释。
    Rakshit - 根据我的实验，我发现判别器很容易区分域，增加更多通道只是占用内存，而不会增加判别能力。因此，我将其保持在较小的值。
    '''
    # ######
    # %% 实验参数
    # ######
    parser.add_argument('--path_exp_tree', type=str,
                        default=results_root,
                        help='path to all experiments result folder')
    parser.add_argument('--path_data_source', type=str,
                        default=dataset_root_source,
                        help='path to all H5 file data  train')
    parser.add_argument('--path_data_target', type=str,
                        default=dataset_root_target,
                        help='path to all H5 file data  test')
    parser.add_argument('--path2MasterKey', type=str,
                        default=masterkey_root,
                        help='path to all MasterKey file data')
    parser.add_argument('--pathResult', type=str,
                        default=baseline_results_root,
                        help='path to all MasterKey file data')
    parser.add_argument('--path_model', type=str, default=[],
                        help='path to model for test purposes')
    parser.add_argument('--repo_root', type=str,
                        default=default_repo,
                        help='path to repo root')
    parser.add_argument('--reduce_valid_samples', type=int, default=0,
                        help='reduce the number of\
                            validaton samples to speed up')    # 减少验证样本数量以加快速度
    parser.add_argument('--save_every', type=int, default=1,
                        help='save weights every 1 iterations') # 每1次迭代保存一次权重

    # ######
    # %% 训练或测试参数
    # ######
    parser.add_argument('--mode', type=str, default='one_vs_one',
                        help='training mode:\
                            one_vs_one, all_vs_one, all-one_vs_one')
    parser.add_argument('--epochs', type=int, default=100,
                        help='number of training epochs')
    parser.add_argument('--cur_obj', type=str, default='TEyeD',
                        help='which dataset to train on or remove?\
                            in all_vs_one, this flag does nothing')
    parser.add_argument('--test_obj', type=str, default='TEyeD',
                        help='which dataset to test?\
                            in all_vs_one, this flag does nothing')                        
    parser.add_argument('--aug_flag', type=int, default=1,
                        help='enable augmentations?') # 是否启用数据增强
    parser.add_argument('--one_by_one_ds', type=int, default=0,
                        help='train on a single dataset, one after the other') # 依次在个数据集上训练
    parser.add_argument('--early_stop', type=int, default=20,
                        help='early stop epoch count') # 提前停止轮数
    parser.add_argument('--mixed_precision', type=int, default=0,
                        help='enable mixed precision training and testing') # 启用混合精度训练和测试
    parser.add_argument('--batches_per_ep', type=int, default=10,
                        help='number of batches per training epoch') # 每个训练轮次的批次数
    parser.add_argument('--use_GPU', type=int, default=0,
                        help='train on GPU?')
    parser.add_argument('--remove_spikes', type=int, default=1,
                        help='remove noisy batches for smooth training') # 移除噪声批次以平滑训练
    parser.add_argument('--pseudo_labels', type=int, default=0,
                        help='generate pseudo labels on datasets with missing\
                            labels') # 在缺少标签的数据集上生成伪标签
    parser.add_argument('--frames', type=int, default=4,
                        help='number of frames that is used') # 使用的帧数

    # ######
    # %% 模型特定参数
    # ######
    parser.add_argument('--use_scSE', type=int, default=0,
                        help='在每个编码器或解码器块的末尾使用并行空间和通道激励')
    parser.add_argument('--make_aleatoric', type=int, default=0,
                        help='在潜在回归过程中添加不确定性公式')
    parser.add_argument('--scale_factor', type=float, default=0.0,
                        help='修改缩放因子')
    parser.add_argument('--make_uncertain', type=int, default=0,
                        help='激活不确定性和认知不确定性')
    parser.add_argument('--continue_training', type=str, default='',
                        help='从这些权重继续训练')
    parser.add_argument('--regression_from_latent', type=int, default=1,
                        help='禁用从潜在空间进行回归？')
    parser.add_argument('--curr_learn_losses', type=int, default=1,
                        help='添加两个斜坡器并根据需要使用它们')
    parser.add_argument('--regress_channel_grow', type=float, default=0,
                        help='在回归模块中增长通道。默认值0表示通道大小保持不变。')
    parser.add_argument('--maxpool_in_regress_mod', type=int, default=-1,
                        help='在回归模块中用最大池化替换平均池化，如果为-1，则禁用池化')
    parser.add_argument('--dilation_in_regress_mod', type=int, default=1,
                        help='在回归模块中启用膨胀')
    parser.add_argument('--groups', type=int, default=1,
                        help='组大小？默认：所有，即groups=1')
    # parser.add_argument('--Gaze_Train',  action='store_true', default=True,
    #                     help='是否训练Gaze分支')
    # parser.add_argument('--Eye_Train',  action='store_true', default=True,
    #                     help='是否训练Eye分支')

    # ######
    # %% 模型选择
    # ######
    parser.add_argument('--net_simply_head', action='store_true', default=True,
                        help='直接预测二维的眼球和瞳孔坐标')
    parser.add_argument('--net_simply_head_tanh', type=str, default=1,
                        help='为简单监督激活tanh')
    parser.add_argument('--net_ellseg_head', action='store_true', default=False,
                        help='计算分割头')
    parser.add_argument('--net_rend_head', action='store_true', default=False,
                        help='计算渲染头（3D眼模型）')
    parser.add_argument('--model', type=str, default='DenseEl0',
                        help='DenseElNet, RITNet')

    parser.add_argument('--train_data_percentage', type=float, default=1.0,
                        help='训练数据的百分比')
    parser.add_argument('--loss_w_supervise', type=float, default=0.0,
                        help='使用真实标签进行监督的损失组件')
    parser.add_argument('--loss_w_supervise_eye', type=float, default=0.0,
                        help='使用真实标签进行监督eye分支的损失组件')
    parser.add_argument('--loss_w_supervise_gaze', type=float, default=0.0,
                        help='使用真实标签进行监督gaze分支的损失组件')
    parser.add_argument('--loss_w_supervise_eyeball_center', type=float, default=0.0,
                        help='使用真实标签监督眼球中心的损失组件')
    parser.add_argument('--loss_w_supervise_pupil_center', type=float, default=0.0,
                        help='使用真实标签监督瞳孔中心的损失组件')
    parser.add_argument('--loss_w_supervise_gaze_vector_3D_L2', type=float, default=0.0,
                        help='使用真实标签监督凝视向量的L2损失组件')
    parser.add_argument('--loss_w_supervise_gaze_vector_3D_cos_sim', type=float, default=0.0,
                        help='使用真实标签监督凝视向量的余弦相似度损失组件')
    parser.add_argument('--loss_w_supervise_gaze_vector_UV', type=float, default=0.0,
                        help='使用真实标签监督凝视向量的UV损失组件')
    parser.add_argument('--loss_w_ellseg', type=float, default=0.0,
                        help='损失组件权重')
    parser.add_argument('--loss_rend_vectorized', action='store_true', default=False,
                        help='计算分割头')
    parser.add_argument('--temp_n_angles', type=int, default=100,
                        help='模板中离散角度的数量')
    parser.add_argument('--temp_n_radius', type=int, default=50,
                        help='模板中离散半径的数量')
    parser.add_argument('--loss_w_rend_gt_2_pred', type=float, default=0.0,
                        help='损失组件权重 0.15')
    parser.add_argument('--loss_w_rend_pred_2_gt', type=float, default=0.0,
                        help='损失组件权重 0.15')
    parser.add_argument('--loss_w_rend_pred_2_gt_edge', type=float, default=0.0,
                        help='损失组件权重 0.15')
    parser.add_argument('--loss_w_rend_pred_2_gt_edge_2D', type=float, default=0.0,
                        help='损失组件权重 0.15')
    parser.add_argument('--loss_w_rend_diameter', type=float, default=0.0,
                        help='损失组件直径权重 0.05')
    parser.add_argument('--random_dataloader', action='store_true', default=False,
                        help='随机打乱图像')

    parser.add_argument('--scale_bound_eye', type=str, default='version_0',
                        help='加载权重')

    # ######
    # %% 预训练条件
    # ######
    parser.add_argument('--weights_path', type=str, default=None,
                        help='加载权重的路径')
    parser.add_argument('--pretrained', type=int, default=0,
                        help='从完整数据集上预训练的模型加载权重')
    parser.add_argument('--pretrained_resnet', action='store_true', default=False,
                        help='从timm库中加载预训练的resnet模型权重')
    parser.add_argument('--optimizer_type', type=str, default='LAMB',
                        help='优化器类型，默认是LAMB')
    parser.add_argument('--only_test', type=str, default=1,
                        help='仅测试模式')
    parser.add_argument('--only_valid', type=str, default=0,
                        help='仅验证模式')
    parser.add_argument('--only_train', type=str, default=0,
                        help='仅训练模式')

    # ######
    # %% 通道参数
    # ######
    parser.add_argument('--workers', type=int, default=2,
                        help='工作线程数')
    parser.add_argument('--num_batches_to_plot', type=int, default=10,
                        help='要绘制的批次数')
    parser.add_argument('--detect_anomaly', type=int, default=0,
                        help='启用异常检测？')
    parser.add_argument('--grad_clip_norm', type=float, default=0.0,
                        help='启用梯度剪辑时输入的规范值')
    parser.add_argument('--num_samples_for_embedding', type=int, default=200,
                        help='用于t-SNE投影的批次数')
    parser.add_argument('--do_distributed', type=int, default=0,
                        help='启用分布式训练？')
    parser.add_argument('--dry_run', action='store_true', default=True,
                        help="使用整个训练/验证集运行一个单独的epoch")
    parser.add_argument('--save_test_maps', action='store_true',
                        help='保存测试地图')

    # 提前终止的度量标准选择：3D或2D凝视向量或IoU
    parser.add_argument('--early_stop_metric', type=str, default='3D',
                        help='提前终止的度量标准选择')

    # %% 仅测试条件
    parser.add_argument('--save_results_here', type=str, default='',
                        help='如果提供了路径，将覆盖保存最终测试结果的路径')

    # %% Parse arguments
    args = parser.parse_args()

    if args.groups != 1:
        # 我们需要所有卷积层中的通道数为偶数。可以通过将增长因子设置为1.5来实现
        args.growth_rate = 1.5

    if args.mode == 'one_vs_one':
        print('One vs One模式')
        args.num_sets = 1  # 设置数据集数量为1

    if args.mode == 'all_vs_one':
        print('检测到All vs One模式。忽略cur_obj标志。')
        args.cur_obj = 'allvsone'  # 将当前对象设置为allvsone
        args.num_sets = 9  # 设置数据集数量为9

    if args.mode == 'pretrained':
        print('检测到预训练模式。')
        args.cur_obj = 'pretrained'  # 将当前对象设置为预训练
        args.num_sets = 4  # 设置数据集数量为4

    if args.mode == 'all-one_vs_one':
        args.num_sets = 8  # 设置数据集数量为8

    if args.one_by_one_ds:
        print('禁用峰值移除')
        args.remove_spikes = 0  # 禁用峰值移除

    # if args.dry_run:
    #     args.epochs = 1  # 如果是干跑模式，将epochs设置为1

    if args.make_uncertain:
        args.make_aleatoric = True  # 启用aleatoric
        args.dropout = 0.2  # 设置dropout为0.2

    print('{} sets detected'.format(args.num_sets))  # 打印检测到的数据集数量

    # opt = vars(args)
    # print('---------')
    # print('解析的参数')
    # pprint(opt)  # 打印解析后的参数
    return args  # 返回解析后的参数
