import os
import shutil
import time
import pickle

import numpy as np
import random
from copy import deepcopy
from tqdm import tqdm

import torch
import torch.optim as optim
import torch.backends.cudnn as cudnn
import torch.nn as nn

from .lr_schedulers import LinearWarmupMultiStepLR, LinearWarmupCosineAnnealingLR
from .postprocessing import postprocess_results
from ..modeling import MaskedConv1D, Scale, AffineDropPath, LayerNorm


################################################################################
def fix_random_seed(seed, include_cuda=True):
    """Seed the Python, NumPy and torch RNGs and configure cuDNN.

    Args:
        seed: integer seed applied to every random number generator.
        include_cuda: when True, also seed CUDA and force deterministic
            cuDNN / cuBLAS behaviour (reproducible training); when False,
            enable the cuDNN benchmark autotuner instead.

    Returns:
        The ``torch.Generator`` returned by ``torch.manual_seed``.
    """
    generator = torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)

    cudnn.enabled = True
    if not include_cuda:
        # reproducibility not requested: let cuDNN pick the fastest kernels
        cudnn.benchmark = True
        return generator

    # training: disable cudnn benchmark to ensure reproducibility
    cudnn.benchmark = False
    cudnn.deterministic = True
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # this is needed for CUDA >= 10.2
    os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
    torch.use_deterministic_algorithms(True, warn_only=True)
    return generator


def save_checkpoint(state, is_best, file_folder,
                    file_name='checkpoint.pth.tar'):
    """Save a checkpoint to ``file_folder/file_name``.

    When ``is_best`` is True, a second file ``model_best.pth.tar`` is written
    containing ``state`` without its ``'optimizer'`` and ``'scheduler'``
    entries. The caller's ``state`` dict is not modified.

    Args:
        state: dict of objects to save with ``torch.save``.
        is_best: whether to also write the slimmed ``model_best.pth.tar``.
        file_folder: destination directory; created (including parents) if
            it does not exist yet.
        file_name: name of the full checkpoint file.
    """
    # makedirs handles nested paths and a folder created by another process
    os.makedirs(file_folder, exist_ok=True)
    torch.save(state, os.path.join(file_folder, file_name))
    if is_best:
        # skip the optimization / scheduler state; build a new dict so the
        # caller's state (which may still be needed for resuming) stays intact
        best_state = {
            key: value for key, value in state.items()
            if key not in ('optimizer', 'scheduler')
        }
        torch.save(best_state, os.path.join(file_folder, 'model_best.pth.tar'))


def print_model_params(model):
    """Print each parameter's name followed by its min, max and mean value."""
    for param_name, tensor in model.named_parameters():
        stats = (tensor.min().item(), tensor.max().item(), tensor.mean().item())
        print(param_name, *stats)
    return


def make_optimizer(model, optimizer_config):
    """Create an SGD or AdamW optimizer with two weight-decay groups.

    Parameters are split following minGPT
    (https://github.com/karpathy/minGPT/blob/master/mingpt/model.py#L134):

    * decayed: ``weight`` of Linear / Conv1d / Conv2d / MaskedConv1D modules;
    * not decayed: every ``bias``, ``weight`` of LayerNorm / BatchNorm1d /
      GroupNorm, ``scale`` of Scale / AffineDropPath, and any ``rel_pe``.

    Args:
        model: the ``torch.nn.Module`` to optimize.
        optimizer_config: dict with ``"type"`` (``"SGD"`` or ``"AdamW"``),
            ``"learning_rate"``, ``"weight_decay"`` and, for SGD, ``"momentum"``.

    Returns:
        The constructed optimizer.

    Raises:
        AssertionError: if a parameter lands in both groups or in neither.
        TypeError: if ``optimizer_config["type"]`` is not supported.
    """
    # module types whose weights are (resp. are not) weight decayed
    decay_module_types = (torch.nn.Linear, torch.nn.Conv1d, torch.nn.Conv2d, MaskedConv1D)
    no_decay_module_types = (LayerNorm, torch.nn.BatchNorm1d, torch.nn.GroupNorm)

    decay, no_decay = set(), set()
    for module_name, module in model.named_modules():
        for param_name, _ in module.named_parameters():
            full_name = f"{module_name}.{param_name}" if module_name else param_name
            is_weight = param_name.endswith('weight')
            if param_name.endswith('bias'):
                no_decay.add(full_name)
            elif is_weight and isinstance(module, decay_module_types):
                decay.add(full_name)
            elif is_weight and isinstance(module, no_decay_module_types):
                no_decay.add(full_name)
            elif param_name.endswith('scale') and isinstance(module, (Scale, AffineDropPath)):
                # corner case of our scale layer
                no_decay.add(full_name)
            elif param_name.endswith('rel_pe'):
                # corner case for relative position encoding
                no_decay.add(full_name)

    # validate that every parameter was put in exactly one group
    param_dict = dict(model.named_parameters())
    overlap = decay & no_decay
    unassigned = param_dict.keys() - (decay | no_decay)
    assert len(overlap) == 0, "parameters %s made it into both decay/no_decay sets!" % (str(overlap), )
    assert len(unassigned) == 0, \
        "parameters %s were not separated into either decay/no_decay set!" \
        % (str(unassigned), )

    def _params_sorted_by_name(names):
        # sorting makes the group order deterministic across runs
        return [param_dict[name] for name in sorted(names)]

    optim_groups = [
        {"params": _params_sorted_by_name(decay), "weight_decay": optimizer_config['weight_decay']},
        {"params": _params_sorted_by_name(no_decay), "weight_decay": 0.0},
    ]

    optimizer_type = optimizer_config["type"]
    if optimizer_type == "SGD":
        return optim.SGD(
            optim_groups,
            lr=optimizer_config["learning_rate"],
            momentum=optimizer_config["momentum"]
        )
    if optimizer_type == "AdamW":
        return optim.AdamW(
            optim_groups,
            lr=optimizer_config["learning_rate"]
        )
    raise TypeError("Unsupported optimizer!")


def make_scheduler(
    optimizer,
    optimizer_config,
    num_iters_per_epoch,
    last_epoch=-1
):
    """Create a learning-rate scheduler that is stepped once per iteration.

    Args:
        optimizer: optimizer whose learning rate is scheduled.
        optimizer_config: dict with ``"warmup"`` (bool), ``"epochs"`` and
            ``"schedule_type"`` (``"cosine"`` or ``"multistep"``). With warmup
            it also needs ``"warmup_epochs"``; the multistep schedule needs
            ``"schedule_steps"`` (epochs at which the LR drops) and
            ``"schedule_gamma"`` (multiplicative LR drop factor).
        num_iters_per_epoch: iterations per epoch, used to convert every
            epoch count in the config into an iteration count.
        last_epoch: index of the last scheduler step, forwarded to the
            scheduler (``-1`` starts from scratch).

    Returns:
        The scheduler. Every scheduler returned here must be stepped each
        iteration, not each epoch.

    Raises:
        TypeError: if ``schedule_type`` is not supported.
    """
    if optimizer_config["warmup"]:
        warmup_epochs = optimizer_config["warmup_epochs"]
        max_epochs = optimizer_config["epochs"] + warmup_epochs
        max_steps = max_epochs * num_iters_per_epoch
        warmup_steps = warmup_epochs * num_iters_per_epoch

        # with linear warmup: call our custom schedulers
        if optimizer_config["schedule_type"] == "cosine":
            scheduler = LinearWarmupCosineAnnealingLR(
                optimizer,
                warmup_steps,
                max_steps,
                last_epoch=last_epoch
            )
        elif optimizer_config["schedule_type"] == "multistep":
            steps = [num_iters_per_epoch * step for step in optimizer_config["schedule_steps"]]
            scheduler = LinearWarmupMultiStepLR(
                optimizer,
                warmup_steps,
                steps,
                gamma=optimizer_config["schedule_gamma"],
                last_epoch=last_epoch
            )
        else:
            raise TypeError("Unsupported scheduler!")

    else:
        max_steps = optimizer_config["epochs"] * num_iters_per_epoch

        # without warmup: call default schedulers
        if optimizer_config["schedule_type"] == "cosine":
            scheduler = optim.lr_scheduler.CosineAnnealingLR(
                optimizer,
                max_steps,
                last_epoch=last_epoch
            )
        elif optimizer_config["schedule_type"] == "multistep":
            steps = [num_iters_per_epoch * step for step in optimizer_config["schedule_steps"]]
            scheduler = optim.lr_scheduler.MultiStepLR(
                optimizer,
                steps,
                # same config key as the warmup branch (was an undefined name)
                gamma=optimizer_config["schedule_gamma"],
                last_epoch=last_epoch
            )
        else:
            raise TypeError("Unsupported scheduler!")

    return scheduler


class AverageMeter(object):
    """Track the latest value and the running weighted average of a quantity.

    Used to compute dataset stats from mini-batches: ``update(val, n)`` records
    ``val`` as an observation with weight ``n`` (e.g. a batch mean over ``n``
    samples).
    """
    def __init__(self):
        self.initialized = False
        self.val = None
        self.avg = None
        self.sum = None
        self.count = 0.0

    def initialize(self, val, n):
        """Start the statistics from the first observation ``val`` (weight ``n``)."""
        self.val, self.avg = val, val
        self.sum = val * n
        self.count = n
        self.initialized = True

    def update(self, val, n=1):
        """Record ``val`` with weight ``n``, initializing on the first call."""
        handler = self.add if self.initialized else self.initialize
        handler(val, n)

    def add(self, val, n):
        """Fold a further observation into the running sum, count and average."""
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count


class ModelEma(torch.nn.Module):
    """Exponential moving average (EMA) of a model's state.

    ``self.module`` is a deep copy of the model, kept in eval mode, whose
    floating-point state is blended towards the live model on every
    :meth:`update`: ``ema = decay * ema + (1 - decay) * model``.
    """
    def __init__(self, model, decay=0.999, device=None):
        super().__init__()
        # make a copy of the model for accumulating moving average of weights
        self.module = deepcopy(model)
        # the copy is only evaluated, never trained directly
        self.module.eval()
        self.decay = decay
        # perform ema on a different device from the model if set
        self.device = device
        if self.device is not None:
            self.module.to(device=device)

    def _update(self, model, update_fn):
        """Write ``update_fn(ema_value, model_value)`` into each floating-point
        state entry of the EMA copy; other entries are copied verbatim.

        Entries are paired positionally over the two ``state_dict()``s, so
        ``model`` must have the same architecture as the copied model.
        """
        with torch.no_grad():
            ema_state = self.module.state_dict().values()
            model_state = model.state_dict().values()
            for ema_v, model_v in zip(ema_state, model_state):
                if self.device is not None:
                    model_v = model_v.to(device=self.device)
                if ema_v.is_floating_point():
                    ema_v.copy_(update_fn(ema_v, model_v))
                else:
                    # integer/bool buffers (e.g. BatchNorm num_batches_tracked)
                    # would be blended in float and truncated back by copy_,
                    # freezing them; copy them over instead
                    ema_v.copy_(model_v)

    def update(self, model):
        """Blend ``model``'s state into the EMA with factor ``self.decay``."""
        self._update(model, update_fn=lambda e, m: self.decay * e + (1. - self.decay) * m)

    def set(self, model):
        """Overwrite the EMA state with ``model``'s current state."""
        self._update(model, update_fn=lambda e, m: m)

################################################################################
def train_one_epoch(
    train_loader,
    train_loader_target,
    model,
    optimizer,
    scheduler,
    curr_epoch,
    max_epochs=20,
    model_ema = None,
    clip_grad_l2norm = -1,
    tb_writer = None,
    print_freq = 20,
    finetune=False,
):
    """Train the model for one epoch on the source-domain loader.

    Without ``finetune``, the model is called as
    ``model(video_list, grl_lambda, domain='source')`` and must return
    ``(losses, domain_pred)``. A source-domain NLL loss (domain label 0) is
    computed over every entry of ``domain_pred`` and added to
    ``losses['final_loss']`` with weight 0, i.e. it is currently disabled.
    With ``finetune=True``, the model is called as
    ``model(video_list, finetune=True)`` and must return only ``losses``.

    Args:
        train_loader: source-domain loader; its length defines the epoch.
        train_loader_target: target-domain loader. One batch is drawn per
            iteration (restarting when exhausted) in non-finetune mode, but it
            is currently unused.
        model: model to train; switched to train mode here.
        optimizer: optimizer, zeroed and stepped every iteration.
        scheduler: LR scheduler, stepped every iteration.
        curr_epoch: epoch index, used for the GRL schedule and logging steps.
        max_epochs: total epochs over which the GRL coefficient ramps up.
        model_ema: optional ``ModelEma`` updated after every step.
        clip_grad_l2norm: clip the gradient L2 norm to this value if > 0.
        tb_writer: optional TensorBoard writer for LR and loss scalars.
        print_freq: log every ``print_freq`` iterations.
        finetune: use the plain (no domain adaptation) forward.
    """
    # set up meters: iteration time and one meter per loss key
    batch_time = AverageMeter()
    losses_tracker = {}
    # number of iterations (batches) per epoch
    num_iters = len(train_loader)
    # switch to train mode
    model.train()
    train_loader_target_iter = iter(train_loader_target)
    # NLLLoss expects log-probabilities; the domain heads presumably end in
    # log_softmax (TODO confirm)
    loss_fn_domain = torch.nn.NLLLoss()
    # main training loop
    print("\n[Train]: Epoch {:d} started".format(curr_epoch))
    start = time.time()
    for iter_idx, video_list in enumerate(train_loader, 0):
        # zero out optim in BOTH modes, otherwise finetune gradients accumulate
        optimizer.zero_grad(set_to_none=True)
        if not finetune:
            # gradient-reversal coefficient 2 / (1 + exp(-10 p)) - 1,
            # ramping from 0 towards 1 as training progresses
            p = float(iter_idx + curr_epoch * num_iters) / (max_epochs * num_iters)
            grl_lambda = 2. / (1. + np.exp(-10 * p)) - 1
            # draw a target batch, restarting the target loader when it runs out
            # NOTE(review): unused while the target-domain loss is disabled
            video_list_target = next(train_loader_target_iter, None)
            if video_list_target is None:
                train_loader_target_iter = iter(train_loader_target)
                video_list_target = next(train_loader_target_iter, None)
            # forward the model on the source domain
            losses, domain_pred = model(video_list, grl_lambda, domain='source')
            # domain classification loss: source videos carry domain label 0
            loss_s_domain_sum = 0
            for pred in domain_pred:
                y_s_domain = torch.zeros(
                    len(video_list), dtype=torch.long, device=pred.device)
                loss_s_domain_sum += loss_fn_domain(pred, y_s_domain)
            # weight 0: the domain term leaves the loss value and (finite)
            # gradients unchanged; out-of-place to avoid mutating a graph node
            losses['final_loss'] = losses['final_loss'] + 0 * loss_s_domain_sum
        else:
            losses = model(video_list, finetune=True)
        losses['final_loss'].backward()
        # gradient cliping (to stabilize training if necessary)
        if clip_grad_l2norm > 0.0:
            torch.nn.utils.clip_grad_norm_(
                model.parameters(),
                clip_grad_l2norm
            )
        # step optimizer / scheduler
        optimizer.step()
        scheduler.step()

        if model_ema is not None:
            model_ema.update(model)

        # printing (only check the stats when necessary to avoid extra cost)
        if (iter_idx != 0) and (iter_idx % print_freq) == 0:
            # measure elapsed time (sync all kernels)
            torch.cuda.synchronize()
            batch_time.update((time.time() - start) / print_freq)
            start = time.time()

            # track all losses
            for key, value in losses.items():
                # init meter if necessary
                if key not in losses_tracker:
                    losses_tracker[key] = AverageMeter()
                # update
                losses_tracker[key].update(value.item())

            # log to tensor board
            lr = scheduler.get_last_lr()[0]
            global_step = curr_epoch * num_iters + iter_idx
            if tb_writer is not None:
                # learning rate (after stepping)
                tb_writer.add_scalar(
                    'train/learning_rate',
                    lr,
                    global_step
                )
                # all losses
                tag_dict = {}
                for key, value in losses_tracker.items():
                    if key != "final_loss":
                        tag_dict[key] = value.val
                tb_writer.add_scalars(
                    'train/all_losses',
                    tag_dict,
                    global_step
                )
                # final loss
                tb_writer.add_scalar(
                    'train/final_loss',
                    losses_tracker['final_loss'].val,
                    global_step
                )

            # print to terminal
            block1 = 'Epoch: [{:03d}][{:05d}/{:05d}]'.format(
                curr_epoch, iter_idx, num_iters
            )
            block2 = 'Time {:.2f} ({:.2f})'.format(
                batch_time.val, batch_time.avg
            )
            block3 = 'Loss {:.2f} ({:.2f})\n'.format(
                losses_tracker['final_loss'].val,
                losses_tracker['final_loss'].avg
            )
            block4 = ''
            for key, value in losses_tracker.items():
                if key != "final_loss":
                    block4  += '\t{:s} {:.2f} ({:.2f})'.format(
                        key, value.val, value.avg
                    )

            print('\t'.join([block1, block2, block3, block4]))

    # finish up and print
    lr = scheduler.get_last_lr()[0]
    print("[Train]: Epoch {:d} finished with lr={:.8f}\n".format(curr_epoch, lr))
    return
# todo
def train_one_epoch_mean_teacher(
    train_loader_target,
    model,
    optimizer,
    scheduler,
    curr_epoch,
    model_ema = None,
    clip_grad_l2norm = -1,
    tb_writer = None,
    print_freq = 20,
    Pseudo_label_dict=None,
    kmeans=None,
    pca=None,
    sum_score=None
):
    """Train the model for one epoch on target-domain videos with pseudo labels.

    For every video in a batch, the real annotations are moved to
    ``segments_real`` / ``labels_real``. ``segments`` / ``labels`` are then
    replaced by the first two entries of ``Pseudo_label_dict[video_id]``.
    Each entry is indexed as ``(start, end, label, ...)``, and its start/end
    are converted with ``t * fps / feat_stride - 0.5 * feat_num_frames / feat_stride``.
    The batch dicts are modified in place.

    Args:
        train_loader_target: target-domain loader; its length defines the epoch.
        model: model to train; called as ``model(video_list, freeMatch=False,
            kmeans=..., pca=..., sum_score=...)`` and must return
            ``(losses, domain_pred)``.
        optimizer, scheduler: both stepped every iteration.
        curr_epoch: epoch index used for logging.
        model_ema: optional ``ModelEma`` updated after every step.
        clip_grad_l2norm: clip the gradient L2 norm to this value if > 0.
        tb_writer: optional TensorBoard writer.
        print_freq: log every ``print_freq`` iterations.
        Pseudo_label_dict: mapping ``video_id -> list of pseudo labels``
            (presumably sorted by descending score; TODO confirm). Required.
        kmeans, pca, sum_score: forwarded to the model unchanged.

    Raises:
        ValueError: if ``Pseudo_label_dict`` is not provided.
    """
    if Pseudo_label_dict is None:
        raise ValueError("Pseudo_label_dict is required to train on pseudo labels")
    # set up meters: iteration time and one meter per loss key
    batch_time = AverageMeter()
    losses_tracker = {}
    # number of iterations (batches) per epoch
    num_iters = len(train_loader_target)
    # switch to train mode
    model.train()
    # main training loop
    print("\n[Train]: Epoch {:d} started".format(curr_epoch))
    start = time.time()
    for iter_idx, video_list_target in enumerate(train_loader_target, 0):
        # write the pseudo labels into the batch as its ground truth
        for video in video_list_target:
            v_fps = video['fps']
            v_stride = video['feat_stride']
            feat_offset = 0.5 * video['feat_num_frames'] / v_stride
            # pseudo-label selection: keep the first two (top-ranked) entries
            # TODO: a better selection strategy (e.g. from active learning)
            gt = Pseudo_label_dict[video['video_id']][:2]
            segments = [np.array(g[:2]) * v_fps / v_stride - feat_offset for g in gt]
            labels = [g[2] for g in gt]
            video['segments_real'] = video['segments']
            video['labels_real'] = video['labels']
            # reshape(-1, 2) keeps an empty selection as a (0, 2) tensor
            video['segments'] = torch.as_tensor(
                np.asarray(segments, dtype=np.float32).reshape(-1, 2))
            video['labels'] = torch.LongTensor(labels)

        # zero out optim
        optimizer.zero_grad(set_to_none=True)
        # forward / backward the model (domain predictions are not used here)
        losses, _ = model(video_list_target,
                          freeMatch=False,
                          kmeans=kmeans,
                          pca=pca,
                          sum_score=sum_score)
        losses['final_loss'].backward()
        # gradient cliping (to stabilize training if necessary)
        if clip_grad_l2norm > 0.0:
            torch.nn.utils.clip_grad_norm_(
                model.parameters(),
                clip_grad_l2norm
            )
        # step optimizer / scheduler
        optimizer.step()
        scheduler.step()

        if model_ema is not None:
            model_ema.update(model)

        # printing (only check the stats when necessary to avoid extra cost)
        if (iter_idx != 0) and (iter_idx % print_freq) == 0:
            # measure elapsed time (sync all kernels)
            torch.cuda.synchronize()
            batch_time.update((time.time() - start) / print_freq)
            start = time.time()

            # track all losses
            for key, value in losses.items():
                # init meter if necessary
                if key not in losses_tracker:
                    losses_tracker[key] = AverageMeter()
                # update
                losses_tracker[key].update(value.item())

            # log to tensor board
            lr = scheduler.get_last_lr()[0]
            global_step = curr_epoch * num_iters + iter_idx
            if tb_writer is not None:
                # learning rate (after stepping)
                tb_writer.add_scalar(
                    'train/learning_rate',
                    lr,
                    global_step
                )
                # all losses
                tag_dict = {}
                for key, value in losses_tracker.items():
                    if key != "final_loss":
                        tag_dict[key] = value.val
                tb_writer.add_scalars(
                    'train/all_losses',
                    tag_dict,
                    global_step
                )
                # final loss
                tb_writer.add_scalar(
                    'train/final_loss',
                    losses_tracker['final_loss'].val,
                    global_step
                )

            # print to terminal
            block1 = 'Epoch: [{:03d}][{:05d}/{:05d}]'.format(
                curr_epoch, iter_idx, num_iters
            )
            block2 = 'Time {:.2f} ({:.2f})'.format(
                batch_time.val, batch_time.avg
            )
            block3 = 'Loss {:.2f} ({:.2f})\n'.format(
                losses_tracker['final_loss'].val,
                losses_tracker['final_loss'].avg
            )
            block4 = ''
            for key, value in losses_tracker.items():
                if key != "final_loss":
                    block4  += '\t{:s} {:.2f} ({:.2f})'.format(
                        key, value.val, value.avg
                    )

            print('\t'.join([block1, block2, block3, block4]))

    # finish up and print
    lr = scheduler.get_last_lr()[0]
    print("[Train]: Epoch {:d} finished with lr={:.8f}\n".format(curr_epoch, lr))
    return

def train_one_epoch_aa(
    train_loader,
    train_loader_target,
    model,
    optimizer,
    scheduler,
    curr_epoch,
    max_epochs=20,
    model_ema = None,
    clip_grad_l2norm = -1,
    tb_writer = None,
    print_freq = 20,
    Pseudo_label_dict=None,
    finetune=False,
):
    """Train one epoch on pseudo-labelled videos with adversarial domain alignment.

    Each ``train_loader`` batch gets its ``segments`` / ``labels`` replaced by
    the first two entries of ``Pseudo_label_dict[video_id]``. The real
    annotations are kept in ``segments_real`` / ``labels_real``, and all labels
    become 0 when the dataset has a single class. A ``train_loader_target``
    batch is drawn alongside, restarting the target loader when exhausted.
    Both batches are truncated to the same size.

    The loss is ``final_loss + 0.01 * (mean source domain NLL + mean target
    domain NLL)``, with domain label 0 for the source batch and 1 for the
    target batch. The source forward must return ``(losses, domain_pred)``;
    the target forward must return ``domain_pred`` only.

    Args:
        train_loader: pseudo-labelled loader; its length defines the epoch.
        train_loader_target: loader of target-domain batches.
        model, optimizer, scheduler: optimizer and scheduler step every iteration.
        curr_epoch: epoch index, used for the GRL schedule and logging steps.
        max_epochs: total epochs over which the GRL coefficient ramps up.
        model_ema: optional ``ModelEma`` updated after every step.
        clip_grad_l2norm: clip the gradient L2 norm to this value if > 0.
        tb_writer: optional TensorBoard writer.
        print_freq: log every ``print_freq`` iterations.
        Pseudo_label_dict: mapping ``video_id -> list of (start, end, label, ...)``.
        finetune: unused; kept for interface compatibility.
    """
    # set up meters: iteration time and one meter per loss key
    batch_time = AverageMeter()
    losses_tracker = {}
    # number of iterations (batches) per epoch
    num_iters = len(train_loader)
    # switch to train mode
    model.train()
    train_loader_target_iter = iter(train_loader_target)
    # single-class datasets use label 0 for every pseudo label
    single_class = train_loader.dataset.num_classes == 1
    # NLLLoss expects log-probabilities; the domain heads presumably end in
    # log_softmax (TODO confirm)
    loss_fn_domain = torch.nn.NLLLoss()
    # main training loop
    print("\n[Train]: Epoch {:d} started".format(curr_epoch))
    start = time.time()
    for iter_idx, video_list in enumerate(train_loader, 0):
        # write the pseudo labels into the batch as its ground truth
        for video in video_list:
            v_fps = video['fps']
            v_stride = video['feat_stride']
            feat_offset = 0.5 * video['feat_num_frames'] / v_stride
            # pseudo-label selection: keep the first two (top-ranked) entries
            # TODO: a better selection strategy (e.g. from active learning)
            gt = Pseudo_label_dict[video['video_id']][:2]
            segments = [np.array(g[:2]) * v_fps / v_stride - feat_offset for g in gt]
            labels = [0 if single_class else g[2] for g in gt]
            video['segments_real'] = video['segments']
            video['labels_real'] = video['labels']
            # reshape(-1, 2) keeps an empty selection as a (0, 2) tensor
            video['segments'] = torch.as_tensor(
                np.asarray(segments, dtype=np.float32).reshape(-1, 2))
            video['labels'] = torch.LongTensor(labels)

        # gradient-reversal coefficient 2 / (1 + exp(-10 p)) - 1,
        # ramping from 0 towards 1 as training progresses
        p = float(iter_idx + curr_epoch * num_iters) / (max_epochs * num_iters)
        grl_lambda = 2. / (1. + np.exp(-10 * p)) - 1
        # draw a target batch, restarting the target loader when it runs out
        video_list_target = next(train_loader_target_iter, None)
        if video_list_target is None:
            train_loader_target_iter = iter(train_loader_target)
            video_list_target = next(train_loader_target_iter, None)
        # make both batches the same size
        batch_size = min(len(video_list), len(video_list_target))
        video_list = video_list[:batch_size]
        video_list_target = video_list_target[:batch_size]

        # zero out optim
        optimizer.zero_grad(set_to_none=True)
        # source domain: detection losses + domain predictions (label 0)
        losses, domain_pred_s = model(video_list, grl_lambda, domain='source')
        loss_s_domain_sum = 0
        for pred in domain_pred_s:
            y_s_domain = torch.zeros(
                len(video_list), dtype=torch.long, device=pred.device)
            loss_s_domain_sum += loss_fn_domain(pred, y_s_domain)
        # target domain: domain predictions only (label 1)
        domain_pred_t = model(video_list_target, grl_lambda, domain='target')
        loss_t_domain_sum = 0
        for pred in domain_pred_t:
            y_t_domain = torch.ones(
                len(video_list_target), dtype=torch.long, device=pred.device)
            loss_t_domain_sum += loss_fn_domain(pred, y_t_domain)
        # average each domain over its own number of predictions
        loss_domain = (loss_s_domain_sum / len(domain_pred_s)
                       + loss_t_domain_sum / len(domain_pred_t))
        # out-of-place so the graph node returned by the model is not mutated
        losses['final_loss'] = losses['final_loss'] + 0.01 * loss_domain
        losses['final_loss'].backward()
        # gradient cliping (to stabilize training if necessary)
        if clip_grad_l2norm > 0.0:
            torch.nn.utils.clip_grad_norm_(
                model.parameters(),
                clip_grad_l2norm
            )
        # step optimizer / scheduler
        optimizer.step()
        scheduler.step()

        if model_ema is not None:
            model_ema.update(model)

        # printing (only check the stats when necessary to avoid extra cost)
        if (iter_idx != 0) and (iter_idx % print_freq) == 0:
            # measure elapsed time (sync all kernels)
            torch.cuda.synchronize()
            batch_time.update((time.time() - start) / print_freq)
            start = time.time()

            # track all losses
            for key, value in losses.items():
                # init meter if necessary
                if key not in losses_tracker:
                    losses_tracker[key] = AverageMeter()
                # update
                losses_tracker[key].update(value.item())

            # log to tensor board
            lr = scheduler.get_last_lr()[0]
            global_step = curr_epoch * num_iters + iter_idx
            if tb_writer is not None:
                # learning rate (after stepping)
                tb_writer.add_scalar(
                    'train/learning_rate',
                    lr,
                    global_step
                )
                # all losses
                tag_dict = {}
                for key, value in losses_tracker.items():
                    if key != "final_loss":
                        tag_dict[key] = value.val
                tb_writer.add_scalars(
                    'train/all_losses',
                    tag_dict,
                    global_step
                )
                # final loss
                tb_writer.add_scalar(
                    'train/final_loss',
                    losses_tracker['final_loss'].val,
                    global_step
                )

            # print to terminal
            block1 = 'Epoch: [{:03d}][{:05d}/{:05d}]'.format(
                curr_epoch, iter_idx, num_iters
            )
            block2 = 'Time {:.2f} ({:.2f})'.format(
                batch_time.val, batch_time.avg
            )
            block3 = 'Loss {:.2f} ({:.2f})\n'.format(
                losses_tracker['final_loss'].val,
                losses_tracker['final_loss'].avg
            )
            block4 = ''
            for key, value in losses_tracker.items():
                if key != "final_loss":
                    block4  += '\t{:s} {:.2f} ({:.2f})'.format(
                        key, value.val, value.avg
                    )

            print('\t'.join([block1, block2, block3, block4]))

    # finish up and print
    lr = scheduler.get_last_lr()[0]
    print("[Train]: Epoch {:d} finished with lr={:.8f}\n".format(curr_epoch, lr))
    return

def train_one_epoch_test(
    train_loader,
    train_loader_target,
    model,
    optimizer,
    scheduler,
    curr_epoch,
    max_epochs=20,
    model_ema = None,
    clip_grad_l2norm = -1,
    tb_writer = None,
    print_freq = 20,
    Pseudo_label_dict=None,
    finetune=False,
):
    """Experimental variant of ``train_one_epoch_aa``.

    Same pseudo-label rewriting as ``train_one_epoch_aa``, except that labels
    are always taken from the pseudo labels. Differences from that function:

    * source and target batches are NOT truncated to the same size, so the
      domain labels are sized per batch;
    * the domain loss is added with weight 0 (disabled);
    * ``model_ema`` is accepted but intentionally not updated.

    Args:
        train_loader: pseudo-labelled loader; its length defines the epoch.
        train_loader_target: loader of target-domain batches, restarted when
            exhausted.
        model, optimizer, scheduler: optimizer and scheduler step every iteration.
        curr_epoch: epoch index, used for the GRL schedule and logging steps.
        max_epochs: total epochs over which the GRL coefficient ramps up.
        model_ema: unused (EMA updates are disabled in this variant).
        clip_grad_l2norm: clip the gradient L2 norm to this value if > 0.
        tb_writer: optional TensorBoard writer.
        print_freq: log every ``print_freq`` iterations.
        Pseudo_label_dict: mapping ``video_id -> list of (start, end, label, ...)``.
        finetune: unused; kept for interface compatibility.
    """
    # set up meters: iteration time and one meter per loss key
    batch_time = AverageMeter()
    losses_tracker = {}
    # number of iterations (batches) per epoch
    num_iters = len(train_loader)
    # switch to train mode
    model.train()
    train_loader_target_iter = iter(train_loader_target)
    # NLLLoss expects log-probabilities; the domain heads presumably end in
    # log_softmax (TODO confirm)
    loss_fn_domain = torch.nn.NLLLoss()
    # main training loop
    print("\n[Train]: Epoch {:d} started".format(curr_epoch))
    start = time.time()
    for iter_idx, video_list in enumerate(train_loader, 0):
        # write the pseudo labels into the batch as its ground truth
        for video in video_list:
            v_fps = video['fps']
            v_stride = video['feat_stride']
            feat_offset = 0.5 * video['feat_num_frames'] / v_stride
            # pseudo-label selection: keep the first two (top-ranked) entries
            # TODO: a better selection strategy (e.g. from active learning)
            gt = Pseudo_label_dict[video['video_id']][:2]
            segments = [np.array(g[:2]) * v_fps / v_stride - feat_offset for g in gt]
            labels = [g[2] for g in gt]
            video['segments_real'] = video['segments']
            video['labels_real'] = video['labels']
            # reshape(-1, 2) keeps an empty selection as a (0, 2) tensor
            video['segments'] = torch.as_tensor(
                np.asarray(segments, dtype=np.float32).reshape(-1, 2))
            video['labels'] = torch.LongTensor(labels)

        # gradient-reversal coefficient 2 / (1 + exp(-10 p)) - 1,
        # ramping from 0 towards 1 as training progresses
        p = float(iter_idx + curr_epoch * num_iters) / (max_epochs * num_iters)
        grl_lambda = 2. / (1. + np.exp(-10 * p)) - 1
        # draw a target batch, restarting the target loader when it runs out
        video_list_target = next(train_loader_target_iter, None)
        if video_list_target is None:
            train_loader_target_iter = iter(train_loader_target)
            video_list_target = next(train_loader_target_iter, None)

        # zero out optim
        optimizer.zero_grad(set_to_none=True)
        # source domain: detection losses + domain predictions (label 0)
        losses, domain_pred_s = model(video_list, grl_lambda, domain='source')
        loss_s_domain_sum = 0
        for pred in domain_pred_s:
            y_s_domain = torch.zeros(
                len(video_list), dtype=torch.long, device=pred.device)
            loss_s_domain_sum += loss_fn_domain(pred, y_s_domain)
        # target domain: domain predictions only (label 1); the labels must be
        # sized from the TARGET batch, which may differ from the source batch
        domain_pred_t = model(video_list_target, grl_lambda, domain='target')
        loss_t_domain_sum = 0
        for pred in domain_pred_t:
            y_t_domain = torch.ones(
                len(video_list_target), dtype=torch.long, device=pred.device)
            loss_t_domain_sum += loss_fn_domain(pred, y_t_domain)
        # average each domain over its own number of predictions
        loss_domain = (loss_s_domain_sum / len(domain_pred_s)
                       + loss_t_domain_sum / len(domain_pred_t))
        # weight 0: the domain term is currently disabled
        losses['final_loss'] = losses['final_loss'] + (0 * loss_domain)
        losses['final_loss'].backward()
        # gradient cliping (to stabilize training if necessary)
        if clip_grad_l2norm > 0.0:
            torch.nn.utils.clip_grad_norm_(
                model.parameters(),
                clip_grad_l2norm
            )
        # step optimizer / scheduler
        optimizer.step()
        scheduler.step()
        # NOTE: model_ema is deliberately not updated in this variant

        # printing (only check the stats when necessary to avoid extra cost)
        if (iter_idx != 0) and (iter_idx % print_freq) == 0:
            # measure elapsed time (sync all kernels)
            torch.cuda.synchronize()
            batch_time.update((time.time() - start) / print_freq)
            start = time.time()

            # track all losses
            for key, value in losses.items():
                # init meter if necessary
                if key not in losses_tracker:
                    losses_tracker[key] = AverageMeter()
                # update
                losses_tracker[key].update(value.item())

            # log to tensor board
            lr = scheduler.get_last_lr()[0]
            global_step = curr_epoch * num_iters + iter_idx
            if tb_writer is not None:
                # learning rate (after stepping)
                tb_writer.add_scalar(
                    'train/learning_rate',
                    lr,
                    global_step
                )
                # all losses
                tag_dict = {}
                for key, value in losses_tracker.items():
                    if key != "final_loss":
                        tag_dict[key] = value.val
                tb_writer.add_scalars(
                    'train/all_losses',
                    tag_dict,
                    global_step
                )
                # final loss
                tb_writer.add_scalar(
                    'train/final_loss',
                    losses_tracker['final_loss'].val,
                    global_step
                )

            # print to terminal
            block1 = 'Epoch: [{:03d}][{:05d}/{:05d}]'.format(
                curr_epoch, iter_idx, num_iters
            )
            block2 = 'Time {:.2f} ({:.2f})'.format(
                batch_time.val, batch_time.avg
            )
            block3 = 'Loss {:.2f} ({:.2f})\n'.format(
                losses_tracker['final_loss'].val,
                losses_tracker['final_loss'].avg
            )
            block4 = ''
            for key, value in losses_tracker.items():
                if key != "final_loss":
                    block4  += '\t{:s} {:.2f} ({:.2f})'.format(
                        key, value.val, value.avg
                    )

            print('\t'.join([block1, block2, block3, block4]))

    # finish up and print
    lr = scheduler.get_last_lr()[0]
    print("[Train]: Epoch {:d} finished with lr={:.8f}\n".format(curr_epoch, lr))
    return


def _apply_pseudo_labels(video_list, pseudo_label_dict, topk=1):
    """Replace each video's targets in ``video_list`` with its pseudo labels, in place.

    For every video dict, the original ``'segments'`` / ``'labels'`` are kept
    under ``'segments_real'`` / ``'labels_real'``. The training targets are then
    rebuilt from the first ``topk`` entries of ``pseudo_label_dict[video_id]``
    (``topk=None`` keeps all entries). Each entry is indexed as ``entry[:2]``
    -> (start, end) and ``entry[2]`` -> class label.

    NOTE(review): taking the leading entries presumably relies on each list
    being sorted by descending score -- confirm against the code that builds
    ``pseudo_label_dict``.
    """
    for video in video_list:
        selected = pseudo_label_dict[video['video_id']][:topk]
        # TODO: a smarter selection strategy could replace plain top-k here,
        # e.g. keeping only entries whose score (entry[-1]) exceeds a threshold,
        # or an active-learning style criterion.
        video['segments_real'] = video['segments']
        video['labels_real'] = video['labels']
        # reshape keeps an (N, 2) layout even when no pseudo label is selected
        # (torch.Tensor([]) alone would have shape (0,)).
        video['segments'] = torch.Tensor(
            [entry[:2] for entry in selected]
        ).reshape(-1, 2)
        video['labels'] = torch.LongTensor([entry[2] for entry in selected])


def train_one_epoch_freeMatch(
    train_loader_target,
    model,
    optimizer,
    scheduler,
    curr_epoch,
    model_ema = None,
    clip_grad_l2norm = -1,
    tb_writer = None,
    print_freq = 20,
    Pseudo_label_dict=None,
    pseudo_label_topk=1,
):
    """Train ``model`` for one epoch on target-domain batches supervised by pseudo labels.

    Before each forward pass, the ground-truth segments/labels of every video in
    the batch are replaced by pseudo labels taken from ``Pseudo_label_dict``.
    The originals are preserved under ``'segments_real'`` / ``'labels_real'``.

    Args:
        train_loader_target: iterable of batches; each batch is a list of video
            dicts with at least ``'video_id'``, ``'segments'`` and ``'labels'``.
        model: called on a batch, returns ``(losses, domain_pred)`` where
            ``losses`` is a dict of scalar tensors containing ``'final_loss'``.
        optimizer: stepped once per batch.
        scheduler: stepped once per batch; ``get_last_lr()[0]`` is logged.
        curr_epoch: epoch index, used for logging and the tensorboard step.
        model_ema: optional object whose ``update(model)`` is called per batch.
        clip_grad_l2norm: if > 0, gradient L2-norm clipping threshold.
        tb_writer: optional tensorboard-style writer for lr and losses.
        print_freq: log every ``print_freq`` iterations.
        Pseudo_label_dict: mapping ``video_id`` -> list of pseudo-label entries
            indexed as ``[start, end, label, ..., score]``. Required.
        pseudo_label_topk: number of leading pseudo-label entries used per
            video (default 1; ``None`` uses all entries).

    Raises:
        ValueError: if ``Pseudo_label_dict`` is None.
    """
    if Pseudo_label_dict is None:
        raise ValueError("train_one_epoch_freeMatch requires Pseudo_label_dict")

    # set up meters: per-iteration time and a dict of per-loss meters
    batch_time = AverageMeter()
    losses_tracker = {}
    # number of iterations (batches) per epoch
    num_iters = len(train_loader_target)
    # switch to train mode
    model.train()
    # main training loop
    print("\n[Train]: Epoch {:d} started".format(curr_epoch))
    start = time.time()
    for iter_idx, video_list_target in enumerate(train_loader_target, 0):
        # use the pseudo labels as the ground truth for this batch
        _apply_pseudo_labels(video_list_target, Pseudo_label_dict, pseudo_label_topk)

        # zero out optim
        optimizer.zero_grad(set_to_none=True)
        # forward / backward the model (domain prediction is not used here)
        losses, _ = model(video_list_target)
        losses['final_loss'].backward()
        # gradient clipping (to stabilize training if necessary)
        if clip_grad_l2norm > 0.0:
            torch.nn.utils.clip_grad_norm_(
                model.parameters(),
                clip_grad_l2norm
            )
        # step optimizer / scheduler
        optimizer.step()
        scheduler.step()

        if model_ema is not None:
            model_ema.update(model)

        # printing (only check the stats when necessary to avoid extra cost)
        if (iter_idx != 0) and (iter_idx % print_freq) == 0:
            # measure elapsed time (sync all kernels so the timing is accurate;
            # skipped on CPU-only hosts where synchronize() would raise)
            if torch.cuda.is_available():
                torch.cuda.synchronize()
            batch_time.update((time.time() - start) / print_freq)
            start = time.time()

            # track all losses (only sampled at print iterations)
            for key, value in losses.items():
                if key not in losses_tracker:
                    losses_tracker[key] = AverageMeter()
                losses_tracker[key].update(value.item())

            # log to tensor board
            lr = scheduler.get_last_lr()[0]
            global_step = curr_epoch * num_iters + iter_idx
            if tb_writer is not None:
                # learning rate (after stepping)
                tb_writer.add_scalar(
                    'train/learning_rate',
                    lr,
                    global_step
                )
                # all losses except the final one
                tag_dict = {
                    key: value.val
                    for key, value in losses_tracker.items()
                    if key != "final_loss"
                }
                tb_writer.add_scalars(
                    'train/all_losses',
                    tag_dict,
                    global_step
                )
                # final loss
                tb_writer.add_scalar(
                    'train/final_loss',
                    losses_tracker['final_loss'].val,
                    global_step
                )

            # print to terminal
            block1 = 'Epoch: [{:03d}][{:05d}/{:05d}]'.format(
                curr_epoch, iter_idx, num_iters
            )
            block2 = 'Time {:.2f} ({:.2f})'.format(
                batch_time.val, batch_time.avg
            )
            block3 = 'Loss {:.2f} ({:.2f})\n'.format(
                losses_tracker['final_loss'].val,
                losses_tracker['final_loss'].avg
            )
            block4 = ''.join(
                '\t{:s} {:.2f} ({:.2f})'.format(key, value.val, value.avg)
                for key, value in losses_tracker.items()
                if key != "final_loss"
            )

            print('\t'.join([block1, block2, block3, block4]))

    # finish up and print
    lr = scheduler.get_last_lr()[0]
    print("[Train]: Epoch {:d} finished with lr={:.8f}\n".format(curr_epoch, lr))
    return


def valid_one_epoch(
    val_loader,
    model,
    curr_epoch,
    ext_score_file = None,
    evaluator = None,
    output_file = None,
    tb_writer = None,
    print_freq = 20,
    domain=None,
    source=None,
    target=None
):
    """Run the model over ``val_loader`` and evaluate or dump the detections.

    Args:
        val_loader: iterable of batches; each batch is passed straight to ``model``.
        model: in eval mode, returns per batch a list of dicts with keys
            ``'video_id'``, ``'segments'`` (column 0 = start, column 1 = end),
            ``'labels'`` and ``'scores'`` (tensors convertible via ``.numpy()``).
        curr_epoch: epoch index; forwarded to the evaluator and used as the
            tensorboard step.
        ext_score_file: optional path; when it is a str and an evaluator is
            given, results are passed through ``postprocess_results`` first.
        evaluator: if given, ``evaluator.evaluate`` computes the mAP.
        output_file: if given and there is no evaluator, the results dict is
            pickled to this path.
        tb_writer: optional writer that receives ``'validation/mAP'``.
        print_freq: currently unused; kept for backward compatibility.
        domain: when ``'target'``, the results dict is returned together with
            the mAP, and neither an evaluator nor an output file is required.
        source, target: forwarded to ``evaluator.evaluate``.

    Returns:
        ``(results, mAP)`` when ``domain == 'target'``, otherwise ``mAP``.
        ``mAP`` is 0.0 when no evaluator is used.

    Raises:
        ValueError: if there is no evaluator, no output file, and
            ``domain != 'target'``.
    """
    # either evaluate the results, save them, or hand them back (target domain)
    if evaluator is None and output_file is None and domain != 'target':
        raise ValueError(
            "valid_one_epoch needs an evaluator or an output_file "
            "(unless domain == 'target')"
        )

    # switch to evaluate mode
    model.eval()
    # NOTE: an earlier variant re-enabled nn.Dropout layers here (m.train())
    # for stochastic inference; it is currently disabled.

    # dict for results (for our evaluation code)
    results = {
        'video-id': [],
        't-start' : [],
        't-end': [],
        'label': [],
        'score': []
    }

    # loop over the validation set (forward without gradients)
    with torch.no_grad():
        for video_list in tqdm(val_loader):
            output = model(video_list)
            # unpack the results into ANet format
            for vid_out in output:
                segments = vid_out['segments']
                num_segs = segments.shape[0]
                if num_segs == 0:
                    continue
                results['video-id'].extend([vid_out['video_id']] * num_segs)
                results['t-start'].append(segments[:, 0])
                results['t-end'].append(segments[:, 1])
                results['label'].append(vid_out['labels'])
                results['score'].append(vid_out['scores'])

    # gather all stats; torch.cat rejects an empty list, so fall back to empty
    # arrays when no video produced any detection.
    # NOTE(review): the empty-array dtypes are a guess (float32 times/scores,
    # int64 labels) -- confirm against what the evaluator expects.
    for key, empty_dtype in (
        ('t-start', np.float32),
        ('t-end', np.float32),
        ('label', np.int64),
        ('score', np.float32),
    ):
        chunks = results[key]
        results[key] = (
            torch.cat(chunks).numpy() if chunks else np.empty(0, dtype=empty_dtype)
        )

    if evaluator is not None:
        if ext_score_file is not None and isinstance(ext_score_file, str):
            results = postprocess_results(results, ext_score_file)
        # call the evaluator
        _, mAP, _ = evaluator.evaluate(results, verbose=True,source=source,target=target,curr_epoch=curr_epoch)
    else:
        # dump to a pickle file that can be directly used for evaluation;
        # with domain == 'target' no output file is required, so skip the dump
        # instead of calling open(None)
        if output_file is not None:
            with open(output_file, "wb") as f:
                pickle.dump(results, f)
        mAP = 0.0

    # log mAP to tb_writer
    if tb_writer is not None:
        tb_writer.add_scalar('validation/mAP', mAP, curr_epoch)
    if domain == 'target':
        return results, mAP
    return mAP