# python imports
import argparse
import copy
import os
import sys
import time
import datetime
from pprint import pprint
import csv
import pickle

# torch imports
import torch
import torch.nn as nn
import torch.utils.data
# for visualization
from torch.utils.tensorboard import SummaryWriter

# our code
from libs.core import load_config
from libs.datasets import make_dataset, make_data_loader
from libs.modeling import make_meta_arch
from libs.utils import (train_one_epoch,  valid_one_epoch,train_one_epoch_mean_teacher,train_one_epoch_aa,train_one_epoch_test,
                        ANETdetection,   save_checkpoint, make_optimizer,
                        make_scheduler,  fix_random_seed, ModelEma)
from libs.utils.weighted_boxes_fusion import weighted_boxes_fusion
import torch.distributed as dist
from tqdm import tqdm
from kmeans_pytorch import kmeans, kmeans_predict
from sklearn.decomposition import PCA
from sklearn.cluster import KMeans  
from sklearn.metrics import silhouette_score  
import matplotlib.pyplot as plt  
import time,datetime

################################################################################
def _save_epoch_checkpoint(state_model, model_ema, optimizer, scheduler,
                           epoch, folder):
    """Write ``epoch_XXX.pth.tar`` into ``folder``.

    The file holds the state of ``state_model``, the EMA model, the optimizer
    and the scheduler. The dict layout matches what the resume branch of
    ``main`` reads back.
    """
    save_states = {
        'epoch': epoch,
        'state_dict': state_model.state_dict(),
        'scheduler': scheduler.state_dict(),
        'optimizer': optimizer.state_dict(),
        'state_dict_ema': model_ema.module.state_dict(),
    }
    save_checkpoint(
        save_states,
        False,
        file_folder=folder,
        file_name='epoch_{:03d}.pth.tar'.format(epoch)
    )


def _group_pseudo_labels(pseudo_label, grouped=None):
    """Group flat detections into ``{video_id: [[t_start, t_end, label, score], ...]}``.

    ``pseudo_label`` is the detection container returned by ``valid_one_epoch``
    with ``domain='target'``. It is read as parallel sequences under the keys
    'video-id', 't-start', 't-end', 'label' and 'score'.

    When ``grouped`` is given, detections are appended to it. This lets the
    results of several models be merged. Every detection is kept: the previous
    inline ``if/else`` silently dropped the first detection of each video.
    """
    if grouped is None:
        grouped = {}
    columns = zip(pseudo_label['video-id'], pseudo_label['t-start'],
                  pseudo_label['t-end'], pseudo_label['label'],
                  pseudo_label['score'])
    for video_id, t_start, t_end, label, score in columns:
        grouped.setdefault(video_id, []).append([t_start, t_end, label, score])
    return grouped


def _mc_dropout_uncertainty(model, loader, num_samples=20, top_k=5):
    """Score every video in ``loader`` by Monte Carlo dropout disagreement.

    The model is put in eval mode with only its ``nn.Dropout`` layers
    re-enabled. It is run ``num_samples`` times per batch, and the first
    ``top_k`` segments/scores/labels of each pass are kept.

    Returns:
        ``{video_id: float}``. Each value is the mean squared deviation of the
        kept segment coordinates from their across-pass mean, multiplied by
        the mean across-pass variance of the kept scores.

    Assumes batch size 1 (only ``video_list[0]`` is used). Also assumes every
    pass yields tensors of identical shape, since ``torch.stack`` requires it;
    TODO confirm the model always returns >= ``top_k`` detections.
    """
    model.eval()

    def enable_dropout(m):
        if isinstance(m, nn.Dropout):
            m.train()

    model.apply(enable_dropout)

    samples_per_video = {}
    with torch.no_grad():
        for video_list in tqdm(loader):
            video_id = video_list[0]['video_id']
            samples_per_video[video_id] = []
            for _ in range(num_samples):
                det = model(video_list)[0]
                samples_per_video[video_id].append(
                    {key: det[key][:top_k]
                     for key in ('segments', 'scores', 'labels')})

    uncertainty = {}
    for video_id, samples in tqdm(samples_per_video.items()):
        segments = torch.stack([s['segments'] for s in samples])
        scores = torch.stack([s['scores'] for s in samples])
        # squared deviation from the across-pass mean, averaged over the last
        # dimension of the segment tensor (presumably (start, end))
        segment_var = (segments - segments.mean(dim=0)).pow(2).mean(dim=2)
        score_var = scores.var(dim=0)
        # stored as a plain float so the pickle can be loaded without CUDA
        uncertainty[video_id] = float(segment_var.mean() * score_var.mean())
    return uncertainty


def _uncertainty_cache_path(cfg_filename):
    """Path of the pickle caching the MC-dropout uncertainty scores."""
    parts = cfg_filename.split("_")
    return f'/fengfangming/TAD/actionformer/{parts[0]}2{parts[-2]}_uncertainty_dict.pkl'


def _split_by_uncertainty(uncertainty_dict, divisor=8):
    """Split video ids into (source-similar, source-dissimilar) lists.

    A video is "source-similar" when its score exceeds ``mean / divisor``.
    Every other video, including any NaN score, goes to the dissimilar list.

    Raises:
        ValueError: if ``uncertainty_dict`` is empty (the mean is undefined).
    """
    if not uncertainty_dict:
        raise ValueError("uncertainty_dict is empty; cannot split the target domain.")
    threshold = sum(uncertainty_dict.values()) / len(uncertainty_dict) / divisor
    source_sim_lst = []
    source_unsim_lst = []
    for video_name, score in uncertainty_dict.items():
        if score > threshold:
            source_sim_lst.append(video_name)
        else:
            source_unsim_lst.append(video_name)
    return source_sim_lst, source_unsim_lst


def main(args, my_local_rank, my_port):
    """Main function that handles training / inference.

    The stage is selected by flags on ``args``:

    * no ``--resume``, or ``--finetune``: run the supervised loop (100 epochs,
      source loader + target loader) and return.
    * ``--resume`` + ``--update_target_self_division``: compute MC-dropout
      uncertainty scores on the target domain, pickle them and return.
    * ``--resume`` + ``--aa``: split the target domain with the cached scores
      and run 30 epochs of ``train_one_epoch_aa`` with pseudo labels. Then fall
      through to mean teacher.
    * ``--resume``: run 30 epochs of mean-teacher self-training on the target
      domain.

    Args:
        args: parsed command-line namespace (see the ``__main__`` block).
        my_local_rank: CUDA device index used by this process.
        my_port: string written to ``MASTER_PORT`` for the NCCL process group.
    """
    # os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
    # log file is named after the start time, 'YYYY-mm-dd HH:MM:SS'
    time1 = datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')
    sys.stdout = Logger(f"/fengfangming/TAD/actionformer/logs/fact_anet_logs/{time1}.txt")

    """0. setup gpus"""
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = my_port
    local_rank = my_local_rank
    dist.init_process_group(backend='nccl', init_method='env://')
    torch.cuda.set_device(local_rank)
    device = torch.device("cuda", local_rank)

    """1. setup parameters / folders"""
    args.start_epoch = 0
    print(args.config)
    if os.path.isfile(args.config):
        cfg = load_config(args.config)
    else:
        raise ValueError("Config file does not exist.")
    pprint(cfg)

    # prep for output folder (based on time stamp)
    if not os.path.exists(cfg['output_folder']):
        os.mkdir(cfg['output_folder'])
    cfg_filename = os.path.basename(args.config).replace('.yaml', '')
    if not args.output:
        ts = datetime.datetime.fromtimestamp(int(time.time()))
        ckpt_folder = os.path.join(
            cfg['output_folder'], cfg_filename + '_' + str(ts))
    else:
        ckpt_folder = os.path.join(
            cfg['output_folder'], cfg_filename + '_' + str(args.output))
    if not os.path.exists(ckpt_folder):
        os.mkdir(ckpt_folder)

    # tensorboard writer
    # TODO: consider os.path.join(ckpt_folder, 'logs') so runs do not share one dir
    tb_writer = SummaryWriter('logs/tf_logs')

    # fix the random seeds (this will fix everything)
    rng_generator = fix_random_seed(cfg['init_rand_seed'], include_cuda=True)

    # re-scale learning rate / # workers based on number of GPUs
    cfg['opt']["learning_rate"] *= len(cfg['devices'])
    cfg['loader']['num_workers'] *= len(cfg['devices'])

    """2. create dataset / dataloader"""
    # source domain
    train_dataset = make_dataset(
        cfg['dataset_name'], True, cfg['train_split'], **cfg['dataset']
    )
    train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
    # update cfg based on dataset attributes (fix to epic-kitchens)
    train_db_vars = train_dataset.get_attributes()

    # evaluation set: is_training=False, consistent with the target val set
    # below, so any training-only dataset behaviour is not applied during eval
    val_dataset = make_dataset(
        cfg['dataset_name'], False, cfg['val_split'], **cfg['dataset']
    )
    val_sampler = torch.utils.data.distributed.DistributedSampler(val_dataset)
    val_db_vars = val_dataset.get_attributes()
    val_det_eval = ANETdetection(
        val_dataset.json_file,
        val_dataset.split[0],
        tiou_thresholds=val_db_vars['tiou_thresholds']
    )

    cfg['model']['train_cfg']['head_empty_cls'] = train_db_vars['empty_label_ids']

    # data loaders
    train_loader = make_data_loader(
        train_dataset, True, rng_generator, train_sampler, **cfg['loader'])
    val_loader = make_data_loader(
        val_dataset, False, None, val_sampler, 1, cfg['loader']['num_workers'])

    if os.path.isfile(args.config_target):
        cfg_target = load_config(args.config_target)
    else:
        raise ValueError("Config_target file does not exist.")

    # target domain
    train_dataset_target = make_dataset(
        cfg_target['dataset_name'], True, cfg_target['train_split'], **cfg_target['dataset']
    )
    # NOTE: the target evaluation set is also built from train_split
    val_dataset_target = make_dataset(
        cfg_target['dataset_name'], False, cfg_target['train_split'], **cfg_target['dataset']
    )
    train_dataset_target_db_vars = train_dataset_target.get_attributes()
    val_dataset_target_db_vars = val_dataset_target.get_attributes()

    train_dataset_target_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset_target)
    val_dataset_target_sampler = torch.utils.data.distributed.DistributedSampler(val_dataset_target)

    train_loader_target = make_data_loader(
        train_dataset_target, False, None, train_dataset_target_sampler,
        cfg_target['loader']['batch_size'], cfg_target['loader']['num_workers'])
    val_loader_target = make_data_loader(
        val_dataset_target, False, None, val_dataset_target_sampler,
        1, cfg_target['loader']['num_workers'])

    val_dataset_target_det_eval = ANETdetection(
        val_dataset_target.json_file,
        val_dataset_target.split[0],
        tiou_thresholds=val_dataset_target_db_vars['tiou_thresholds']
    )

    cfg_target['model']['train_cfg']['head_empty_cls'] = train_dataset_target_db_vars['empty_label_ids']

    """3. create model, optimizer, and scheduler"""
    model = make_meta_arch(cfg['model_name'], **cfg['model'])
    model.to(device)
    # TODO: source-domain training runs the domain discriminator forward twice
    #     (source and target classification losses), which errors unless
    #     find_unused_parameters=False. Target-domain training does not use the
    #     discriminator, leaving unused parameters, so it needs
    #     find_unused_parameters=True.
    model = torch.nn.parallel.DistributedDataParallel(
        model, device_ids=[local_rank], find_unused_parameters=False)

    optimizer = make_optimizer(model, cfg['opt'])
    num_iters_per_epoch = len(train_loader)
    scheduler = make_scheduler(optimizer, cfg['opt'], num_iters_per_epoch)

    # enable model EMA
    print("Using model EMA ...")
    model_ema = ModelEma(model)

    """4. Resume from model / Misc"""
    if args.resume:
        if os.path.isfile(args.resume):
            # load onto this process's GPU rather than a fixed device index
            checkpoint = torch.load(args.resume, map_location=device)
            args.start_epoch = checkpoint['epoch']
            model.load_state_dict(checkpoint['state_dict'])
            model_ema.module.load_state_dict(checkpoint['state_dict_ema'], strict=False)
            # also load the optimizer / scheduler if necessary
            optimizer.load_state_dict(checkpoint['optimizer'])
            scheduler.load_state_dict(checkpoint['scheduler'])
            print("=> loaded checkpoint '{:s}' (epoch {:d})".format(
                args.resume, checkpoint['epoch']
            ))
            del checkpoint
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))
            return

    # save the current config
    with open(os.path.join(ckpt_folder, 'config.txt'), 'w') as fid:
        pprint(cfg, stream=fid)
        fid.flush()

    """4. training / validation loop"""
    print("\nStart training model {:s} ...".format(cfg['model_name']))

    # a resumed (non-finetune) run skips this supervised stage entirely
    max_epochs = 0 if (args.resume and not args.finetune) else 100
    val_freq = 1
    for epoch in range(args.start_epoch, max_epochs):
        # DistributedSampler only reshuffles when told the epoch
        train_sampler.set_epoch(epoch)
        train_dataset_target_sampler.set_epoch(epoch)
        train_one_epoch(
            train_loader,
            train_loader_target,
            model,
            optimizer,
            scheduler,
            epoch,
            model_ema=model_ema,
            clip_grad_l2norm=cfg['train_cfg']['clip_grad_l2norm'],
            tb_writer=tb_writer,
            print_freq=args.print_freq,
            max_epochs=max_epochs,
            finetune=args.finetune
        )
        if (epoch + 1) % val_freq == 0:
            print('test in source train dataset')
            valid_one_epoch(
                val_loader,
                model,
                curr_epoch=epoch,
                evaluator=val_det_eval,
                output_file=None,
                ext_score_file=cfg['test_cfg']['ext_score_file'],
                tb_writer=None,
            )
            valid_one_epoch(
                val_loader_target,
                model,
                -1,
                evaluator=val_dataset_target_det_eval,
                output_file=None,
                ext_score_file=cfg_target['test_cfg']['ext_score_file'],
                tb_writer=None,
            )
        # NOTE: intentionally saves EVERY epoch whenever ckpt_freq > 0 (not every
        # ckpt_freq epochs): later stages resume from e.g. epoch_001 / epoch_003
        if (epoch + 1) == max_epochs or args.ckpt_freq > 0:
            _save_epoch_checkpoint(model, model_ema, optimizer, scheduler,
                                   epoch + 1, ckpt_folder)
    if not args.resume or args.finetune:
        print("Source domain all done!")
        return

    """1.5 Monte Carlo dropout: score target videos to split them into source-similar / source-dissimilar sets"""
    if args.update_target_self_division:
        uncertainty_dict = _mc_dropout_uncertainty(model, val_loader_target)
        # cache the scores so later runs do not need to recompute them
        with open(_uncertainty_cache_path(cfg_filename), 'wb') as f:
            pickle.dump(uncertainty_dict, f)
        return

    if args.aa:
        # the update_target_self_division branch returned above, so load the
        # cached scores (a local file written by this script; never point
        # pickle.load at untrusted data)
        with open(_uncertainty_cache_path(cfg_filename), 'rb') as f:
            uncertainty_dict = pickle.load(f)
        source_sim_lst, source_unsim_lst = _split_by_uncertainty(uncertainty_dict)

        # adversarial alignment: dataloaders for the source-similar and
        # source-dissimilar subsets of the target domain
        cfg_target['dataset']['target_div_lst'] = source_sim_lst
        source_sim_dataset = make_dataset(
            cfg_target['dataset_name'], True, cfg_target['train_split'], **cfg_target['dataset']
        )
        cfg_target['dataset']['target_div_lst'] = source_unsim_lst
        source_unsim_dataset = make_dataset(
            cfg_target['dataset_name'], True, cfg_target['train_split'], **cfg_target['dataset']
        )
        cfg_target['dataset']['target_div_lst'] = None

        source_sim_dataset_sampler = torch.utils.data.distributed.DistributedSampler(source_sim_dataset)
        source_unsim_dataset_sampler = torch.utils.data.distributed.DistributedSampler(source_unsim_dataset)

        source_sim_loader = make_data_loader(
            source_sim_dataset, False, None, source_sim_dataset_sampler,
            cfg_target['loader']['batch_size'], cfg_target['loader']['num_workers'])
        source_unsim_loader = make_data_loader(
            source_unsim_dataset, False, None, source_unsim_dataset_sampler,
            cfg_target['loader']['batch_size'], cfg_target['loader']['num_workers'])

        max_aa_epoch = 30
        aa_ckpt_floder = '/fengfangming/TAD/actionformer/aa_ckpt'
        source_name = cfg_filename.split('_')[0]
        target_name = cfg_filename.split('_')[-2]
        for aa_epoch in range(max_aa_epoch):
            source_sim_dataset_sampler.set_epoch(aa_epoch)
            source_unsim_dataset_sampler.set_epoch(aa_epoch)
            # generate pseudo labels with the current model first
            pseudo_label, _ = valid_one_epoch(
                val_loader_target,
                model,
                aa_epoch,
                evaluator=val_dataset_target_det_eval,
                output_file=None,
                ext_score_file=cfg_target['test_cfg']['ext_score_file'],
                tb_writer=None,
                domain='target',
                source=source_name,
                target=target_name,
            )
            pseudo_label_dict = _group_pseudo_labels(pseudo_label)
            train_one_epoch_aa(
                source_unsim_loader,
                source_sim_loader,
                model,
                optimizer,
                scheduler,
                curr_epoch=aa_epoch,
                max_epochs=max_aa_epoch,
                model_ema=model_ema,
                clip_grad_l2norm=cfg['train_cfg']['clip_grad_l2norm'],
                tb_writer=tb_writer,
                print_freq=args.print_freq,
                Pseudo_label_dict=pseudo_label_dict,
            )
            # save ckpt once in a while (and always after the last aa epoch)
            if ((aa_epoch + 1) == max_aa_epoch or
                    (args.ckpt_freq > 0 and (aa_epoch + 1) % args.ckpt_freq == 0)):
                _save_epoch_checkpoint(model, model_ema, optimizer, scheduler,
                                       aa_epoch + 1, aa_ckpt_floder)

    """2. Mean Teacher"""
    teacher_model = copy.deepcopy(model)
    max_ema_epoch = 30
    update_Pseudo_label_freq = 1
    for ema_epoch in range(max_ema_epoch):
        train_dataset_target_sampler.set_epoch(ema_epoch)
        # pseudo labels from both teacher and student, fused by weighted box fusion
        if ema_epoch % update_Pseudo_label_freq == 0:
            print('valid on Teacher model')
            pseudo_teacher, _ = valid_one_epoch(
                val_loader_target,
                teacher_model,
                -1,
                evaluator=val_dataset_target_det_eval,
                output_file=None,
                ext_score_file=cfg_target['test_cfg']['ext_score_file'],
                tb_writer=None,
                domain='target'
            )
            print('valid on Student model')
            pseudo_student, _ = valid_one_epoch(
                val_loader_target,
                model,
                -1,
                evaluator=val_dataset_target_det_eval,
                output_file=None,
                ext_score_file=cfg_target['test_cfg']['ext_score_file'],
                tb_writer=None,
                domain='target'
            )
            # group each model's detections independently (the two result
            # sets need not have the same length)
            pseudo_label_dict = _group_pseudo_labels(pseudo_teacher)
            _group_pseudo_labels(pseudo_student, pseudo_label_dict)

            consensus_pseudo_label_dict = {}
            for video_id, annos in pseudo_label_dict.items():
                wbf = weighted_boxes_fusion()
                fusions = []
                for anno in annos:
                    fusions = wbf.add_boxes(anno[:2], anno[2], anno[3])
                consensus_pseudo_label_dict[video_id] = fusions

        train_one_epoch_mean_teacher(
            train_loader_target,
            model,
            optimizer,
            scheduler,
            curr_epoch=ema_epoch,
            model_ema=model_ema,
            clip_grad_l2norm=cfg['train_cfg']['clip_grad_l2norm'],
            tb_writer=tb_writer,
            print_freq=args.print_freq,
            Pseudo_label_dict=consensus_pseudo_label_dict,
        )
        # Mean Teacher: the teacher tracks an EMA of the student. The previous
        # argument order moved the student toward a frozen teacher, so the
        # checkpointed teacher never changed.
        update_ema_variables(model, teacher_model)
        print('valid on Teacher model')
        valid_one_epoch(
            val_loader_target,
            teacher_model,
            -1,
            evaluator=val_dataset_target_det_eval,
            output_file=None,
            ext_score_file=cfg_target['test_cfg']['ext_score_file'],
            tb_writer=None,
            domain='target'
        )

        # save ckpt once in a while (and always after the last mean-teacher epoch)
        if ((ema_epoch + 1) == max_ema_epoch or
                (args.ckpt_freq > 0 and (ema_epoch + 1) % args.ckpt_freq == 0)):
            _save_epoch_checkpoint(teacher_model, model_ema, optimizer, scheduler,
                                   ema_epoch + 1, ckpt_folder)
    print('really all done!')
    return
    
def update_ema_variables(model, ema_model, alpha=0.99):
    """Move ``ema_model`` toward ``model`` with an exponential moving average.

    For every floating-point state-dict entry::

        ema = alpha * ema + (1 - alpha) * model

    Non floating-point entries, such as BatchNorm ``num_batches_tracked``
    counters, are copied from ``model``. Averaging them as floats and loading
    the result back truncates the value, which is what the previous
    implementation did.

    Args:
        model: module providing the new weights; it is not modified.
        ema_model: module updated in place; must have the same state-dict keys
            (and devices) as ``model``.
        alpha: decay factor; larger values keep more of ``ema_model``.
    """
    # Use the true average until the exponential average is more correct
    # alpha = min(1 - 1 / (global_step + 1), alpha)
    with torch.no_grad():
        source_state = model.state_dict()
        # state_dict() tensors share storage with the module's parameters and
        # buffers, so updating them in place updates ema_model directly
        for name, ema_value in ema_model.state_dict().items():
            value = source_state[name]
            if ema_value.is_floating_point():
                ema_value.mul_(alpha).add_(value, alpha=1. - alpha)
            else:
                ema_value.copy_(value)

class Logger(object):
    """Tee stream: everything written goes to the current stdout and a log file.

    ``main`` installs an instance as ``sys.stdout`` so every ``print`` is also
    appended to ``filename``.
    """

    def __init__(self, filename="Default.log"):
        self.terminal = sys.stdout
        # explicit UTF-8: printed text may be non-ASCII, and the locale's
        # default encoding is not guaranteed to handle it
        self.log = open(filename, "a", encoding="utf-8")

    def write(self, message):
        self.terminal.write(message)
        self.log.write(message)

    def flush(self):
        # previously a no-op, so buffered log output could be lost on a crash
        self.terminal.flush()
        self.log.flush()

    def __getattr__(self, name):
        # Delegate other stream attributes (isatty, encoding, fileno, ...) to
        # the wrapped terminal so code that inspects sys.stdout keeps working.
        # Guard our own fields to avoid infinite recursion before __init__ ran.
        if name in ("terminal", "log"):
            raise AttributeError(name)
        return getattr(self.terminal, name)

################################################################################
if __name__ == '__main__':
    """Entry Point"""
    # the arg parser
    parser = argparse.ArgumentParser(
      description='Train a point-based transformer for action localization')
    parser.add_argument('config', metavar='DIR',
                        help='path to a config file')
    parser.add_argument('config_target', metavar='DIR',
                        help='path to a config_target file')
    parser.add_argument('-p', '--print-freq', default=10, type=int,
                        help='print frequency (default: 10 iterations)')
    parser.add_argument('-c', '--ckpt-freq', default=5, type=int,
                        help='checkpoint frequency (default: every 5 epochs)')
    parser.add_argument('--output', default='', type=str,
                        help='name of exp folder (default: none)')
    parser.add_argument('--resume', default='', type=str, metavar='PATH',
                        help='path to a checkpoint (default: none)')
    parser.add_argument("--update_target_self_division", action="store_true", help="Run or not.")
    parser.add_argument("--aa", action="store_true", help="adversarial_alignment.")
    parser.add_argument("--finetune", action="store_true", help="Run or not.")
    # Defaults reproduce the GPU / port that used to be hard-coded below.
    parser.add_argument('--local-rank', '--local_rank', dest='local_rank',
                        default=2, type=int,
                        help='CUDA device index for this process (default: 2)')
    parser.add_argument('--port', default='3222', type=str,
                        help='MASTER_PORT for the process group (default: 3222)')

    # Experiment recipes: pass one of these argument lists on the command line
    # (plus --local-rank / --port to pick the GPU and rendezvous port).
    #
    # Stage 1 - train the model on the source domain:
    #   ./configs4aa/fact_split_anet_model.yaml ./configs4aa/anet_i3d.yaml --output fact2anet
    #   ./configs4aa/fact_th14_model.yaml ./configs4aa/thumos_i3d.yaml --output fact2th14
    #   ./configs4aa/anet_th14_model.yaml ./configs4aa/thumos_i3d.yaml --output anet2th14
    #   (older: ./configs/anet_i3d_thumosModel.yaml ./configs/thumos_i3d.yaml --output reproduce)
    #
    # Stage 2 - compute and cache MC-dropout uncertainty scores on the target
    # domain (used to split it into source-similar / source-dissimilar sets):
    #   ./configs4aa/fact_split_anet_model.yaml ./configs4aa/anet_i3d.yaml --update_target_self_division --resume ckpt/fact_split_anet_model_fact2anet/epoch_001.pth.tar
    #   ./configs4aa/fact_th14_model.yaml ./configs4aa/thumos_i3d.yaml --update_target_self_division --resume ckpt/fact_th14_model_fact2th14/epoch_003.pth.tar
    #   ./configs4aa/anet_th14_model.yaml ./configs4aa/thumos_i3d.yaml --update_target_self_division --resume ckpt/anet_th14_model_anet2th14/epoch_005.pth.tar
    #
    # Stage 3 - split the target domain by the uncertainty scores, train with
    # pseudo labels (adversarial alignment), then mean teacher:
    #   ./configs4aa/fact_split_anet_model.yaml ./configs4aa/anet_i3d.yaml --output aa_fact2anet --resume ckpt/fact_split_anet_model_fact2anet/epoch_001.pth.tar --aa
    #   ./configs4aa/fact_th14_model.yaml ./configs4aa/thumos_i3d.yaml --output aa_fact2th14 --resume ckpt/fact_th14_model_fact2th14/epoch_003.pth.tar --aa
    #   ./configs4aa/fact_th14_model.yaml ./configs4aa/thumos_i3d.yaml --output aa_anet2th14 --resume ckpt/anet_th14_model_anet2th14/epoch_005.pth.tar --aa
    #
    # Mean teacher only, on the target domain, from a source-trained model (omit --aa):
    #   ./configs/anet_i3d_thumosModel.yaml ./configs/thumos_i3d.yaml --output MT --resume ckpt_debug/anet_i3d_thumosModel_debug_aa/epoch_010.pth.tar
    #
    # Fine-tune a source-trained model on an auxiliary domain:
    #   ./configs/mixed.yaml ./configs/thumos_i3d.yaml --output reproduce --resume ckpt/anet_i3d_thumosModel_reproduce/epoch_025.pth.tar --finetune

    # Without command-line arguments, run the experiment that used to be
    # hard-coded here; otherwise honour the arguments actually passed.
    default_argv = ('./configs4aa/fact_th14_model.yaml '
                    './configs4aa/thumos_i3d.yaml '
                    '--output aa_fact2th14 '
                    '--resume ckpt/fact_th14_model_fact2th14/epoch_003.pth.tar '
                    '--aa').split()
    args = parser.parse_args(sys.argv[1:] or default_argv)
    main(args, my_local_rank=args.local_rank, my_port=args.port)
    # torchrun --nproc_per_node=1  --nnodes=1 --standalone train_SFOD.py
