#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
This started as a copy of https://bitbucket.org/RSKothari/multiset_gaze/src/master/ 
with additional changes and modifications to adjust it to our implementation. 

Copyright (c) 2021 Rakshit Kothari, Aayush Chaudhary, Reynold Bailey, Jeff Pelz, 
and Gabriel Diaz
"""
import os
from pprint import pprint

import torch
import random
import warnings
import numpy as np
import wandb
import string

from datetime import datetime   
from distutils.dir_util import copy_tree

from main import train
from args_maker import make_args

# Silence every Python warning for the duration of the run.
warnings.filterwarnings('ignore')

# Process-wide environment configuration, applied at import time.
os.environ.update({
    # Option string handed to PyTorch's CUDA allocator.
    'PYTORCH_CUDA_ALLOC_CONF': 'expandable_segments:True',
    # Replace the '+' placeholder inside the quotes with your own W&B API key.
    'WANDB_API_KEY': '+++++++++++',
    # Run W&B in offline mode (this line does not need to be changed).
    'WANDB_MODE': 'offline',
})

def create_experiment_folder_tree(repo_root,
                                  path_exp_records,
                                  exp_name,
                                  is_test=False,
                                  create_tree=True):
    """
    Create the on-disk folder layout for one experiment run.

    Args:
        repo_root (str): Root directory of the repository. Only needed by the
            source-snapshot step, which is currently disabled (see TODO).
        path_exp_records (str): Directory that holds all experiment folders.
        exp_name (str): Base name of the experiment.
        is_test (bool): When True, `exp_name` is used verbatim as the folder
            name; otherwise a '_yy_mm_dd_HH_MM_SS' timestamp is appended.
            Defaults to False.
        create_tree (bool): Only consulted by the disabled source-snapshot
            step, so it currently has no effect. Defaults to True.

    Returns:
        tuple: `(path_dict, exp_name_str)`. `path_dict` maps 'results',
        'figures', 'logs' and 'src' to their sub-folder paths and 'exp' to the
        experiment folder itself; `exp_name_str` is the final folder name.
    """
    # Test runs reuse the given name; other runs get a timestamp suffix.
    exp_name_str = (
        exp_name
        if is_test
        else exp_name + '_' + datetime.now().strftime('%y_%m_%d_%H_%M_%S')
    )

    # Full path of this experiment's folder.
    path_exp = os.path.join(path_exp_records, exp_name_str)

    # One sub-folder per artifact type.
    path_dict = {
        sub_dir: os.path.join(path_exp, sub_dir)
        for sub_dir in ('results', 'figures', 'logs', 'src')
    }
    for folder in path_dict.values():
        # Already-existing folders are left untouched.
        os.makedirs(folder, exist_ok=True)

    # The experiment folder itself is exposed under the 'exp' key.
    path_dict['exp'] = path_exp

    # TODO: exclude hidden files, then re-enable the source snapshot, i.e.
    # when `(not is_test) and create_tree`, copy `repo_root` into the 'src'
    # folder of the experiment:
    #     copy_tree(repo_root, os.path.join(path_exp, 'src'))

    return path_dict, exp_name_str



def cleanup():
    """Destroy the torch.distributed process group."""
    dist = torch.distributed
    dist.destroy_process_group()


if __name__ == '__main__':
    # Work with the parsed arguments as a dict; several entries are
    # rewritten below before being handed to train().
    args = vars(make_args())

    # In distributed runs only local rank 0 is flagged to create the tree.
    # The conditional form avoids reading 'local_rank' for non-distributed runs.
    create_tree = (args['local_rank'] == 0) if args['do_distributed'] else True
    path_dict, exp_name_str = create_experiment_folder_tree(
        args['repo_root'],
        args['path_exp_tree'],
        args['exp_name'],
        args['only_test'],
        create_tree=create_tree)

    # Experiments whose name contains 'DEBUG' are not tracked in W&B.
    if 'DEBUG' not in args['exp_name']:
        wandb.init(project="GazeModel",
                   entity='xyfdxb',
                   config=args, name=exp_name_str)

    path_dict['repo_root'] = args['repo_root']
    path_dict['path_data'] = args['path_data']

    # When loading data from pkl files, force the frame count to 1.
    # NOTE(review): the original comment said the frame count is predefined
    # to 10 to load pkl files with 10 images, but the code sets 1 — confirm
    # which value is intended.
    if args['use_pkl_for_dataload']:
        args['frames'] = 1

    # %% DDP essentials
    if args['do_distributed']:
        torch.distributed.init_process_group(backend='nccl',
                                             init_method='env://')
        world_size = torch.distributed.get_world_size()
    else:
        world_size = 1

    # Module-level value kept from the original script (the former `global`
    # statement was a no-op at module scope and has been dropped).
    batch_size_global = int(args['batch_size']*world_size)

    # NOTE(review): torch.cuda.set_device(args['local_rank']) was disabled in
    # the original script; presumably the device is selected inside train()
    # — confirm.
    args['world_size'] = world_size
    # The configured batch size is split across the participating processes.
    args['batch_size'] = int(args['batch_size']/world_size)

    # %% cuDNN: favour speed (autotuned kernels) over determinism.
    torch.backends.cudnn.deterministic = False
    torch.backends.cudnn.benchmark = True
    torch.backends.cudnn.enabled = True

    # Seed every RNG used by the script.
    torch.manual_seed(args['seed'])
    np.random.seed(args['seed'])
    random.seed(args['seed'])

    print('---------')
    print('解析的参数')
    pprint(args)  # Dump the parsed (and adjusted) arguments.
    print('---------')

    # %% Run-mode dispatch.
    # The original code attached the 'only_valid' / 'only_test' branches as
    # `elif`s of the `do_distributed` teardown check, so in distributed runs
    # those two modes never called train(). The mode is now chosen first and
    # the teardown happens afterwards. Precedence is unchanged: 'only_valid'
    # wins over 'only_test' when both are set.
    if args['only_valid']:
        print('validation mode')
        train(args, path_dict, validation_mode=True, test_mode=False)
    elif args['only_test']:
        print('test mode')
        train(args, path_dict, validation_mode=False, test_mode=True)
    else:
        # Full pipeline: train, then validate, then test.
        print('train mode')
        train(args, path_dict, validation_mode=False, test_mode=False)

        print('validation mode')
        train(args, path_dict, validation_mode=True, test_mode=False)

        # Test out best model and save results
        print("test mode")
        train(args, path_dict, validation_mode=False, test_mode=True)

    # Close the process group once every rank has finished its work.
    if args['do_distributed']:
        torch.distributed.barrier()
        cleanup()

    print("run done!")
    wandb.finish()