from __future__ import print_function
import argparse

import torch
import torch.nn as nn
import torch.optim as optim

from utils.misc import *
from utils.test_helpers import *
from utils.prepare_dataset import *

# ----------------------------------

import copy
import time
import pandas as pd

import random
import numpy as np
import torch.backends.cudnn as cudnn
import torch.nn.functional as F

from discrepancy import *
from offline import *
from utils.trick_helpers import *
from utils.contrastive import *

from online import FeatureQueue

# os.environ["CUDA_VISIBLE_DEVICES"] = "4"

# ----------------------------------

# ----------------------------------
# Command-line configuration
# ----------------------------------

parser = argparse.ArgumentParser()
parser.add_argument('--dataset', default='cifar10')
parser.add_argument('--dataroot', default=None)
parser.add_argument('--shared', default=None)

# --- Data loading ------------------------------------------------------
parser.add_argument('--depth', default=26, type=int)
parser.add_argument('--width', default=1, type=int)
parser.add_argument('--batch_size', default=128, type=int)
# Batch size of the alignment loaders; capped by --num_sample in the loop below.
parser.add_argument('--batch_size_align', default=512, type=int)
# When larger than --batch_size_align it must be a multiple of it; the
# surplus is held in a FeatureQueue of recent target features.
parser.add_argument('--queue_size', default=256, type=int)
parser.add_argument('--group_norm', default=0, type=int)
parser.add_argument('--workers', default=0, type=int)
# Number of samples drawn for test-time training; also caps both batch sizes.
parser.add_argument('--num_sample', default=1000000, type=int)

# --- Optimisation schedule ----------------------------------------------
parser.add_argument('--lr', default=0.001, type=float)
parser.add_argument('--nepoch', default=500, type=int, help='maximum number of epoch for ttt')
parser.add_argument('--bnepoch', default=2, type=int, help='first few epochs to update bn stat')
parser.add_argument('--delayepoch', default=0, type=int)
parser.add_argument('--stopepoch', default=66, type=int)

# --- Output -------------------------------------------------------------
parser.add_argument('--outf', default='.')

# --- Corruption and checkpoint ------------------------------------------
parser.add_argument('--level', default=5, type=int)
# 'all' runs every entry of common_corruptions in turn; any other value
# restricts the run to that single corruption.
parser.add_argument('--corruption', default='all',
                    help="corruption to evaluate, or 'all' (default) for every common corruption")
parser.add_argument('--resume', default=None, help='directory of pretrained model')
parser.add_argument('--ckpt', default=None, type=int)
# When set, the optimizer updates only the extractor (ssh stays fixed).
parser.add_argument('--fix_ssh', action='store_true')

# --- Adaptation method --------------------------------------------------
parser.add_argument('--method', default='ssl', choices=['ssl', 'align', 'both'])
parser.add_argument('--divergence', default='all', choices=['all', 'coral', 'mmd'])
parser.add_argument('--scale_ext', default=0.5, type=float, help='scale of align loss on ext')
parser.add_argument('--scale_ssh', default=0.2, type=float, help='scale of align loss on ssh')

# --- Self-supervised task -----------------------------------------------
parser.add_argument('--ssl', default='contrastive', help='self-supervised task')
parser.add_argument('--temperature', default=0.5, type=float)

# --- Feature-alignment toggles ------------------------------------------
parser.add_argument('--align_ext', action='store_true')
parser.add_argument('--align_ssh', action='store_true')

# --- Model --------------------------------------------------------------
parser.add_argument('--model', default='slim_resnet50_new', help='model name (default: slim_resnet50_new)')
parser.add_argument('--save_every', default=100, type=int)

# --- Visualisation ------------------------------------------------------
parser.add_argument('--tsne', action='store_true')

# --- Reproducibility and loss weights -----------------------------------
parser.add_argument('--seed', default=0, type=int)
parser.add_argument('--app', help='slim_yml')
# Weight on the KL feature-alignment loss.
parser.add_argument('--lambda_a', default=1.0, type=float)
# Weight on the contrastive self-supervised loss.
parser.add_argument('--lambda_s', default=1.0, type=float)
# Weight on the weak->strong augmentation pseudo-label cross-entropy loss.
parser.add_argument('--lambda_d', default=1.0, type=float)
parser.add_argument('--print_loss', action='store_true')


args = parser.parse_args()

# The 15 common corruption types evaluated by the per-corruption loop below.
common_corruptions = ['gaussian_noise', 'shot_noise', 'impulse_noise', 'defocus_blur', 'glass_blur',
                      'motion_blur', 'zoom_blur', 'snow', 'frost', 'fog',
                      'brightness', 'contrast', 'elastic_transform', 'pixelate', 'jpeg_compression']

# Honour --corruption. The loop below overwrites args.corruption on every
# iteration, so without this restriction the flag would be silently ignored.
if args.corruption != 'all':
    common_corruptions = [args.corruption]

for cor in common_corruptions:
    args.corruption = cor

    print(args)

    my_makedir(args.outf)

    torch.manual_seed(args.seed)
    random.seed(args.seed)
    np.random.seed(args.seed)

    cudnn.benchmark = True

    # -------------------------------

    # net, ext, head, ssh, classifier = build_resnet50(args)
    net, ext, head, ssh, classifier = build_slim_resnet50(args)

    _, teloader = prepare_test_data(args)

    # -------------------------------

    args.batch_size = min(args.batch_size, args.num_sample)
    args.batch_size_align = min(args.batch_size_align, args.num_sample)

    args_align = copy.deepcopy(args)
    args_align.ssl = None
    args_align.batch_size = args.batch_size_align

    if args.method == 'align':
        _, trloader = prepare_test_data(args_align, ttt=True, num_sample=args.num_sample)
    else:
        _, trloader = prepare_train_data(args, args.num_sample)

    if args.method == 'both':
        _, trloader_extra = prepare_test_data(args_align, ttt=True, num_sample=args.num_sample)
        trloader_extra_iter = iter(trloader_extra)

        _, traugloader = prepare_aug_train_data(args, args.num_sample)
        traugloader_iter = iter(traugloader)

    # -------------------------------

    print('Resuming from %s...' %(args.resume))

    load_resnet50(net, head, ssh, classifier, args)

    if torch.cuda.device_count() > 1:
        ext = torch.nn.DataParallel(ext)

    # ----------- Offline Feature Summarization ------------

    if args.method in ['align', 'both']:

        if args.queue_size > args.batch_size_align:
            assert args.queue_size % args.batch_size_align == 0
            # reset batch size by queue size
            args_align.batch_size = args.queue_size

        _, offlineloader = prepare_train_data(args_align)

        MMD_SCALE_FACTOR = 0.5
        if args.align_ext:
            args_align.scale = args.scale_ext
            widths_train = [0.25, 0.5, 0.75, 1.0]
            for width_mult in widths_train:
                ext.apply(
                    lambda m: setattr(m, 'width_mult', width_mult))  # 逐个遍历model的子模块，给子模块中的变量width_mult赋值
                if width_mult == 1.0:
                    cov_src_ext, coral_src_ext, mu_src_ext, mmd_src_ext = offline(offlineloader, ext, args.scale_ext)
                    scale_coral_ext = args.scale_ext / coral_src_ext
                    scale_mmd_ext = args.scale_ext / mmd_src_ext * MMD_SCALE_FACTOR
                    if args.queue_size > args.batch_size_align:
                        queue_ext = FeatureQueue(dim=mu_src_ext.shape[0], length=args.queue_size - args.batch_size_align)

                    bias = cov_src_ext.max().item() / 30.
                    template_ext_cov = torch.eye(2048).cuda() * bias
                elif width_mult == 0.75:
                    cov_src_ext_75, coral_src_ext_75, mu_src_ext_75, mmd_src_ext_75 = offline(offlineloader, ext, args.scale_ext)
                    scale_coral_ext_75 = args.scale_ext / coral_src_ext_75
                    scale_mmd_ext_75 = args.scale_ext / mmd_src_ext_75 * MMD_SCALE_FACTOR
                    if args.queue_size > args.batch_size_align:
                        queue_ext_75 = FeatureQueue(dim=mu_src_ext_75.shape[0], length=args.queue_size - args.batch_size_align)

                    bias = cov_src_ext_75.max().item() / 30.
                    template_ext_cov_75 = torch.eye(1536).cuda() * bias
                elif width_mult == 0.5:
                    cov_src_ext_50, coral_src_ext_50, mu_src_ext_50, mmd_src_ext_50 = offline(offlineloader, ext, args.scale_ext)
                    scale_coral_ext_50 = args.scale_ext / coral_src_ext_50
                    scale_mmd_ext_50 = args.scale_ext / mmd_src_ext_50 * MMD_SCALE_FACTOR
                    if args.queue_size > args.batch_size_align:
                        queue_ext_50 = FeatureQueue(dim=mu_src_ext_50.shape[0], length=args.queue_size - args.batch_size_align)

                    bias = cov_src_ext_50.max().item() / 30.
                    template_ext_cov_50 = torch.eye(1024).cuda() * bias
                else:
                    cov_src_ext_25, coral_src_ext_25, mu_src_ext_25, mmd_src_ext_25 = offline(offlineloader, ext, args.scale_ext)
                    scale_coral_ext_25 = args.scale_ext / coral_src_ext_25
                    scale_mmd_ext_25 = args.scale_ext / mmd_src_ext_25 * MMD_SCALE_FACTOR
                    if args.queue_size > args.batch_size_align:
                        queue_ext_25 = FeatureQueue(dim=mu_src_ext_25.shape[0], length=args.queue_size - args.batch_size_align)

                    bias = cov_src_ext_25.max().item() / 30.
                    template_ext_cov_25 = torch.eye(512).cuda() * bias
            # cov_src_ext, coral_src_ext, mu_src_ext, mmd_src_ext = offline(offlineloader, ext, args.scale_ext)
            # scale_coral_ext = args.scale_ext / coral_src_ext
            # scale_mmd_ext = args.scale_ext / mmd_src_ext * MMD_SCALE_FACTOR

            # construct queue
            # if args.queue_size > args.batch_size_align:
            #     queue_ext = FeatureQueue(dim=mu_src_ext.shape[0], length=args.queue_size-args.batch_size_align)

        if args.align_ssh:
            args_align.scale = args.scale_ssh
            from models.SSHead import ExtractorHead

            widths_train = [0.25, 0.5, 0.75, 1.0]
            for width_mult in widths_train:
                ext.apply(
                    lambda m: setattr(m, 'width_mult', width_mult))  # 逐个遍历model的子模块，给子模块中的变量width_mult赋值
                head.apply(
                    lambda m: setattr(m, 'width_mult', width_mult))  # 逐个遍历model的子模块，给子模块中的变量width_mult赋值
                if width_mult == 1.0:
                    cov_src_ssh, coral_src_ssh, mu_src_ssh, mmd_src_ssh = offline(offlineloader,
                                                                                ExtractorHead(ext, head).cuda(),
                                                                                args.scale_ssh)
                    scale_align_ssh = args.scale_ssh / coral_src_ssh
                    scale_mmd_ssh = args.scale_ssh / mmd_src_ssh * MMD_SCALE_FACTOR

                    if args.queue_size > args.batch_size_align:
                        queue_ssh = FeatureQueue(dim=mu_src_ssh.shape[0], length=args.queue_size - args.batch_size_align)

                    bias = cov_src_ssh.max().item() / 30.
                    template_ssh_cov = torch.eye(128).cuda() * bias
                elif width_mult == 0.75:
                    cov_src_ssh_75, coral_src_ssh_75, mu_src_ssh_75, mmd_src_ssh_75 = offline(offlineloader,
                                                                                ExtractorHead(ext, head).cuda(),
                                                                                args.scale_ssh)
                    scale_align_ssh_75 = args.scale_ssh / coral_src_ssh_75
                    scale_mmd_ssh_75 = args.scale_ssh / mmd_src_ssh_75 * MMD_SCALE_FACTOR

                    if args.queue_size > args.batch_size_align:
                        queue_ssh_75 = FeatureQueue(dim=mu_src_ssh_75.shape[0], length=args.queue_size - args.batch_size_align)

                    bias = cov_src_ssh_75.max().item() / 30.
                    template_ssh_cov_75 = torch.eye(128).cuda() * bias
                elif width_mult == 0.5:
                    cov_src_ssh_50, coral_src_ssh_50, mu_src_ssh_50, mmd_src_ssh_50 = offline(offlineloader,
                                                                                ExtractorHead(ext, head).cuda(),
                                                                                args.scale_ssh)
                    scale_align_ssh_50 = args.scale_ssh / coral_src_ssh_50
                    scale_mmd_ssh_50 = args.scale_ssh / mmd_src_ssh_50 * MMD_SCALE_FACTOR

                    if args.queue_size > args.batch_size_align:
                        queue_ssh_50 = FeatureQueue(dim=mu_src_ssh_50.shape[0], length=args.queue_size - args.batch_size_align)

                    bias = cov_src_ssh_50.max().item() / 30.
                    template_ssh_cov_50 = torch.eye(128).cuda() * bias
                else:
                    cov_src_ssh_25, coral_src_ssh_25, mu_src_ssh_25, mmd_src_ssh_25 = offline(offlineloader,
                                                                                ExtractorHead(ext, head).cuda(),
                                                                                args.scale_ssh)
                    scale_align_ssh_25 = args.scale_ssh / coral_src_ssh_25
                    scale_mmd_ssh_25 = args.scale_ssh / mmd_src_ssh_25 * MMD_SCALE_FACTOR

                    if args.queue_size > args.batch_size_align:
                        queue_ssh_25 = FeatureQueue(dim=mu_src_ssh_25.shape[0], length=args.queue_size - args.batch_size_align)

                    bias = cov_src_ssh_25.max().item() / 30.
                    template_ssh_cov_25 = torch.eye(128).cuda() * bias
            # cov_src_ssh, coral_src_ssh, mu_src_ssh, mmd_src_ssh = offline(offlineloader, ExtractorHead(ext, head).cuda(), args.scale_ssh)
            # scale_align_ssh = args.scale_ssh / coral_src_ssh
            # scale_mmd_ssh = args.scale_ssh / mmd_src_ssh * MMD_SCALE_FACTOR
            #
            # if args.queue_size > args.batch_size_align:
            #     queue_ssh = FeatureQueue(dim=mu_src_ssh.shape[0], length=args.queue_size-args.batch_size_align)

    # ----------- Test ------------

    if args.tsne:
        args_src = copy.deepcopy(args)
        args_src.corruption = 'original'
        _, srcloader = prepare_test_data(args_src)
        feat_src, label_src, tsne_src = visu_feat(ext, srcloader, os.path.join(args.outf, 'original.pdf'))
        feat_tar, label_tar, tsne_tar = visu_feat(ext, teloader, os.path.join(args.outf, args.corruption + '_test_class.pdf'))
        calculate_distance(feat_src, label_src, tsne_src, feat_tar, label_tar, tsne_tar)
        # comp_feat(feat_src, label_src, feat_tar, label_tar, os.path.join(args.outf, args.corruption + '_test_marginal.pdf'))

    all_err_cls = []
    all_err_cls_75 = []
    all_err_cls_50 = []
    all_err_cls_25 = []
    all_err_ssh = []

    print('Running...')
    print('Error (%)\t\ttest')

    # err_cls = test(teloader, net)[0]
    # print(('Epoch %d/%d:' %(0, args.nepoch)).ljust(24) +
    #             '%.2f\t\t' %(err_cls*100))
    widths_train = [0.25, 0.5, 0.75, 1.0]
    for width_mult in widths_train:
        net.apply(
            lambda m: setattr(m, 'width_mult', width_mult))  # 逐个遍历model的子模块，给子模块中的变量width_mult赋值
        err_cls = test(teloader, net)[0]
        print(width_mult)
        print(('Epoch %d/%d:' %(0, args.nepoch)).ljust(24) +
                '%.2f\t\t' %(err_cls*100))
    # -------------------------------

    if args.fix_ssh:
        optimizer = optim.SGD(ext.parameters(), lr=args.lr, momentum=0.9)
    else:
        optimizer = optim.SGD(ssh.parameters(), lr=args.lr, momentum=0.9)

    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer,
        'min', factor=0.5, patience=10, cooldown=10,
        threshold=0.0001, threshold_mode='rel', min_lr=0.0001, verbose=True)

    # criterion = SupConLoss(temperature=args.temperature).cuda()
    criterion = MultiConLoss(temperature=args.temperature).cuda()
    criterion_dis = torch.nn.KLDivLoss(reduction="batchmean").cuda()
    criterion_ce = torch.nn.CrossEntropyLoss().cuda()

    lambda_a = args.lambda_a
    lambda_s = args.lambda_s
    lambda_d = args.lambda_d
    print_loss = args.print_loss

    # ----------- Improved Test-time Training ------------

    is_both_activated = False

    for epoch in range(1, args.nepoch+1):

        tic = time.time()

        if args.fix_ssh:
            classifier.eval()
            head.eval()
        else:
            classifier.eval()
            head.train()
        ext.train()

        feature_loss = []
        feature_loss_75 = []
        feature_loss_5 = []
        feature_loss_25 = []
        aug_loss = []
        aug_loss_75 = []
        aug_loss_5 = []
        aug_loss_25 = []
        ext_align_loss = []
        ext_align_loss_75 = []
        ext_align_loss_5 = []
        ext_align_loss_25 = []
        ssh_align_loss = []
        ssh_align_loss_75 = []
        ssh_align_loss_5 = []
        ssh_align_loss_25 = []
        

        for batch_idx, (inputs, labels) in enumerate(trloader):

            optimizer.zero_grad()

            if args.method in ['ssl', 'both']:
                images = torch.cat([inputs[0], inputs[1]], dim=0)
                images = images.cuda(non_blocking=True)
                labels = labels.cuda(non_blocking=True)
                bsz = labels.shape[0]
                # widths_train = [0.25, 0.5, 0.75, 1.0]
                widths_train = [1.0, 0.75, 0.5, 0.25]
                for width_mult in widths_train:
                    ssh.apply(
                        lambda m: setattr(m, 'width_mult', width_mult))  # 逐个遍历model的子模块，给子模块中的变量width_mult赋值
                    if width_mult == 1.0:
                        features = ssh(images)
                        f1, f2 = torch.split(features, [bsz, bsz], dim=0)
                        features = torch.cat([f1.unsqueeze(1), f2.unsqueeze(1)], dim=1)
                        loss = criterion(features)
                        loss = loss * lambda_s
                        feature_loss.append(loss.detach().cpu().numpy())
                        loss.backward()
                        # del loss
                    elif width_mult == 0.75:
                        features = ssh(images)
                        f3, f4 = torch.split(features, [bsz, bsz], dim=0)
                        # features3 = torch.cat([f1.unsqueeze(1).detach(), f3.unsqueeze(1)], dim=1)
                        # features4 = torch.cat([f2.unsqueeze(1).detach(), f4.unsqueeze(1)], dim=1)
                        features = torch.cat([f1.unsqueeze(1).detach(), f2.unsqueeze(1).detach(), f3.unsqueeze(1), f4.unsqueeze(1)], dim=1)
                        # loss = criterion(features3) + criterion(features4) + criterion(features)
                        loss = criterion(features)
                        loss = loss * lambda_s
                        feature_loss_75.append(loss.detach().cpu().numpy())
                        loss.backward()
                        # del loss
                    elif width_mult == 0.5:
                        features = ssh(images)
                        f5, f6 = torch.split(features, [bsz, bsz], dim=0)
                        # features5 = torch.cat([f1.unsqueeze(1).detach(), f5.unsqueeze(1)], dim=1)
                        # features6 = torch.cat([f2.unsqueeze(1).detach(), f6.unsqueeze(1)], dim=1)
                        features = torch.cat([f1.unsqueeze(1).detach(), f2.unsqueeze(1).detach(), f5.unsqueeze(1), f6.unsqueeze(1)], dim=1)
                        # loss = criterion(features5) + criterion(features6) + criterion(features)
                        loss = criterion(features)
                        loss = loss * lambda_s
                        feature_loss_5.append(loss.detach().cpu().numpy())
                        loss.backward()
                        # del loss
                    else:
                        features = ssh(images)
                        # print(width_mult)
                        # print(features.size()) 512 * 128
                        f7, f8 = torch.split(features, [bsz, bsz], dim=0)
                        # print(f1.size()) 256 * 128
                        # print(f2.size()) 256 * 128
                        # features7 = torch.cat([f1.unsqueeze(1).detach(), f7.unsqueeze(1)], dim=1)
                        # features8 = torch.cat([f2.unsqueeze(1).detach(), f8.unsqueeze(1)], dim=1)
                        features = torch.cat([f1.unsqueeze(1).detach(), f2.unsqueeze(1).detach(), f7.unsqueeze(1), f8.unsqueeze(1)], dim=1)
                        # print(features.size()) 256 * 2 * 128
                        # loss = criterion(features7) + criterion(features8) + criterion(features)
                        loss = criterion(features)
                        loss = loss * lambda_s
                        feature_loss_25.append(loss.detach().cpu().numpy())
                        # loss = criterion(features)
                        loss.backward()
                    del loss

            if args.method == 'both' and is_both_activated:

                try:
                    inputs, labels = next(traugloader_iter)
                except StopIteration:
                    del traugloader_iter
                    traugloader_iter = iter(traugloader)
                    inputs, labels = next(traugloader_iter)

                # widths_train = [0.25, 0.5, 0.75, 1.0]

                images = torch.cat([inputs[0], inputs[1]], dim=0)
                images_str = inputs[1]
                images = images.cuda(non_blocking=True)
                images_str = images_str.cuda(non_blocking=True)
                labels = labels.cuda(non_blocking=True)
                bsz = labels.shape[0]

                widths_train = [1.0, 0.75, 0.5, 0.25]
                for width_mult in widths_train:
                    net.apply(
                        lambda m: setattr(m, 'width_mult', width_mult))  # 逐个遍历model的子模块，给子模块中的变量width_mult赋值
                    if width_mult == 1.0:
                        outputs = net(images)
                        weakaugout, straugout = torch.split(outputs, [bsz, bsz], dim=0)
                        # print(weakaugout.size()) 256 * 10
                        weakaugout = torch.argmax(weakaugout, dim=1)
                        # print(weakaugout.size())
                        # print(labels.size())
                        # loss = criterion_dis(F.softmax(straugout / 7, dim=1), F.softmax(weakaugout / 7, dim=1))
                        loss = criterion_ce(straugout, weakaugout)
                        loss = loss * lambda_d
                        aug_loss.append(loss.detach().cpu().numpy())
                        loss.backward()
                    else:
                        str_outputs = net(images_str)
                        # loss = criterion_dis(F.softmax(str_outputs / 7, dim=1), F.softmax(weakaugout.detach() / 7, dim=1))
                        loss = criterion_ce(str_outputs, weakaugout)
                        loss = loss * lambda_d
                        if width_mult == 0.75:
                            aug_loss_75.append(loss.detach().cpu().numpy())
                        elif width_mult == 0.5:
                            aug_loss_5.append(loss.detach().cpu().numpy())
                        elif width_mult == 0.25:
                            aug_loss_25.append(loss.detach().cpu().numpy())
                        loss.backward()
                    del loss

            if args.method == 'both' and is_both_activated:

                try:
                    inputs, _ = next(trloader_extra_iter)
                except StopIteration:
                    del trloader_extra_iter
                    trloader_extra_iter = iter(trloader_extra)
                    inputs, _ = next(trloader_extra_iter)

                if args.align_ext:

                    loss = 0
                    widths_train = [0.25, 0.5, 0.75, 1.0]
                    for width_mult in widths_train:
                        loss = 0
                        ext.apply(
                            lambda m: setattr(m, 'width_mult', width_mult))  # 逐个遍历model的子模块，给子模块中的变量width_mult赋值
                        feat_ext = ext(inputs.cuda())
                        if width_mult == 1.0:
                            if args.queue_size > args.batch_size_align:
                                feat_queue = queue_ext.get()
                                queue_ext.update(feat_ext)
                                if feat_queue is not None:
                                    feat_ext = torch.cat([feat_ext, feat_queue.cuda()])
                            cov_ext = covariance(feat_ext)
                            mu_ext = feat_ext.mean(dim=0)
                            source_domain = torch.distributions.MultivariateNormal(mu_src_ext, cov_src_ext + template_ext_cov)
                            target_domain = torch.distributions.MultivariateNormal(mu_ext, cov_ext + template_ext_cov)
                            loss += (torch.distributions.kl_divergence(source_domain, target_domain) + torch.distributions.kl_divergence(target_domain, source_domain)) * lambda_a
                            # if args.divergence in ['coral', 'all']:
                            #     cov_ext = covariance(feat_ext)
                            #     loss += coral(cov_src_ext, cov_ext) * scale_coral_ext
                            # if args.divergence in ['mmd', 'all']:
                            #     mu_ext = feat_ext.mean(dim=0)
                            #     loss += linear_mmd(mu_src_ext, mu_ext) * scale_mmd_ext
                            ext_align_loss.append(loss.detach().cpu().numpy())
                        elif width_mult == 0.75:
                            if args.queue_size > args.batch_size_align:
                                feat_queue = queue_ext_75.get()
                                queue_ext_75.update(feat_ext)
                                if feat_queue is not None:
                                    feat_ext = torch.cat([feat_ext, feat_queue.cuda()])

                            cov_ext = covariance(feat_ext)
                            mu_ext = feat_ext.mean(dim=0)
                            source_domain = torch.distributions.MultivariateNormal(mu_src_ext_75, cov_src_ext_75 + template_ext_cov_75)
                            target_domain = torch.distributions.MultivariateNormal(mu_ext, cov_ext + template_ext_cov_75)
                            loss += (torch.distributions.kl_divergence(source_domain,
                                                                       target_domain) + torch.distributions.kl_divergence(
                                target_domain, source_domain)) * lambda_a

                            # if args.divergence in ['coral', 'all']:
                            #     cov_ext = covariance(feat_ext)
                            #     loss += coral(cov_src_ext_75, cov_ext) * scale_coral_ext_75
                            # if args.divergence in ['mmd', 'all']:
                            #     mu_ext = feat_ext.mean(dim=0)
                            #     loss += linear_mmd(mu_src_ext_75, mu_ext) * scale_mmd_ext_75
                            ext_align_loss_75.append(loss.detach().cpu().numpy())
                        elif width_mult == 0.5:
                            if args.queue_size > args.batch_size_align:
                                feat_queue = queue_ext_50.get()
                                queue_ext_50.update(feat_ext)
                                if feat_queue is not None:
                                    feat_ext = torch.cat([feat_ext, feat_queue.cuda()])

                            cov_ext = covariance(feat_ext)
                            mu_ext = feat_ext.mean(dim=0)
                            source_domain = torch.distributions.MultivariateNormal(mu_src_ext_50, cov_src_ext_50 + template_ext_cov_50)
                            target_domain = torch.distributions.MultivariateNormal(mu_ext, cov_ext + template_ext_cov_50)
                            loss += (torch.distributions.kl_divergence(source_domain,
                                                                       target_domain) + torch.distributions.kl_divergence(
                                target_domain, source_domain)) * lambda_a

                            # if args.divergence in ['coral', 'all']:
                            #     cov_ext = covariance(feat_ext)
                            #     loss += coral(cov_src_ext_50, cov_ext) * scale_coral_ext_50
                            # if args.divergence in ['mmd', 'all']:
                            #     mu_ext = feat_ext.mean(dim=0)
                            #     loss += linear_mmd(mu_src_ext_50, mu_ext) * scale_mmd_ext_50
                            ext_align_loss_5.append(loss.detach().cpu().numpy())
                        else:
                            if args.queue_size > args.batch_size_align:
                                feat_queue = queue_ext_25.get()
                                queue_ext_25.update(feat_ext)
                                if feat_queue is not None:
                                    feat_ext = torch.cat([feat_ext, feat_queue.cuda()])

                            cov_ext = covariance(feat_ext)
                            mu_ext = feat_ext.mean(dim=0)
                            source_domain = torch.distributions.MultivariateNormal(mu_src_ext_25, cov_src_ext_25 + template_ext_cov_25)
                            target_domain = torch.distributions.MultivariateNormal(mu_ext, cov_ext + template_ext_cov_25)
                            loss += (torch.distributions.kl_divergence(source_domain,
                                                                       target_domain) + torch.distributions.kl_divergence(
                                target_domain, source_domain)) * lambda_a
                            ext_align_loss_25.append(loss.detach().cpu().numpy())
                            # if args.divergence in ['coral', 'all']:
                            #     cov_ext = covariance(feat_ext)
                            #     loss += coral(cov_src_ext_25, cov_ext) * scale_coral_ext_25
                            # if args.divergence in ['mmd', 'all']:
                            #     mu_ext = feat_ext.mean(dim=0)
                            #     loss += linear_mmd(mu_src_ext_25, mu_ext) * scale_mmd_ext_25

                        # queue
                        # if args.queue_size > args.batch_size_align:
                        #     feat_queue = queue_ext.get()
                        #     queue_ext.update(feat_ext)
                        #     if feat_queue is not None:
                        #         feat_ext = torch.cat([feat_ext, feat_queue.cuda()])

                        # if args.divergence in ['coral', 'all']:
                        #     cov_ext = covariance(feat_ext)
                        #     loss += coral(cov_src_ext, cov_ext) * scale_coral_ext
                        # if args.divergence in ['mmd', 'all']:
                        #     mu_ext = feat_ext.mean(dim=0)
                        #     loss += linear_mmd(mu_src_ext, mu_ext) * scale_mmd_ext

                        loss.backward()
                        del loss

                # if args.align_ssh:

                #     loss = 0

                #     widths_train = [0.25, 0.5, 0.75, 1.0]
                #     for width_mult in widths_train:
                #         loss = 0
                #         ext.apply(
                #             lambda m: setattr(m, 'width_mult', width_mult))  # 逐个遍历model的子模块，给子模块中的变量width_mult赋值
                #         head.apply(
                #             lambda m: setattr(m, 'width_mult', width_mult))  # 逐个遍历model的子模块，给子模块中的变量width_mult赋值

                #         feat_ssh = head(ext(inputs.cuda()))
                #         if width_mult == 1.0:
                #             # queue
                #             if args.queue_size > args.batch_size_align:
                #                 feat_queue = queue_ssh.get()
                #                 queue_ssh.update(feat_ssh)
                #                 if feat_queue is not None:
                #                     feat_ssh = torch.cat([feat_ssh, feat_queue.cuda()])

                #             cov_ssh = covariance(feat_ssh)
                #             mu_ssh = feat_ssh.mean(dim=0)
                #             source_domain = torch.distributions.MultivariateNormal(mu_src_ssh, cov_src_ssh + template_ssh_cov)
                #             target_domain = torch.distributions.MultivariateNormal(mu_ssh, cov_ssh + template_ssh_cov)
                #             loss += (torch.distributions.kl_divergence(source_domain,
                #                                                        target_domain) + torch.distributions.kl_divergence(
                #                 target_domain, source_domain)) * 0.5
                #             ssh_align_loss.append(loss.detach().cpu().numpy())
                #             # if args.divergence in ['coral', 'all']:
                #             #     cov_ssh = covariance(feat_ssh)
                #             #     loss += coral(cov_src_ssh, cov_ssh) * scale_align_ssh
                #             # if args.divergence in ['mmd', 'all']:
                #             #     mu_ssh = feat_ssh.mean(dim=0)
                #             #     loss += linear_mmd(mu_src_ssh, mu_ssh) * scale_mmd_ssh
                #         elif width_mult == 0.75:
                #             # queue
                #             if args.queue_size > args.batch_size_align:
                #                 feat_queue = queue_ssh_75.get()
                #                 queue_ssh_75.update(feat_ssh)
                #                 if feat_queue is not None:
                #                     feat_ssh = torch.cat([feat_ssh, feat_queue.cuda()])

                #             cov_ssh = covariance(feat_ssh)
                #             mu_ssh = feat_ssh.mean(dim=0)
                #             source_domain = torch.distributions.MultivariateNormal(mu_src_ssh_75, cov_src_ssh_75 + template_ssh_cov_75)
                #             target_domain = torch.distributions.MultivariateNormal(mu_ssh, cov_ssh + template_ssh_cov_75)
                #             loss += (torch.distributions.kl_divergence(source_domain,
                #                                                        target_domain) + torch.distributions.kl_divergence(
                #                 target_domain, source_domain)) * 0.5

                #             # if args.divergence in ['coral', 'all']:
                #             #     cov_ssh = covariance(feat_ssh)
                #             #     loss += coral(cov_src_ssh_75, cov_ssh) * scale_align_ssh_75
                #             # if args.divergence in ['mmd', 'all']:
                #             #     mu_ssh = feat_ssh.mean(dim=0)
                #             #     loss += linear_mmd(mu_src_ssh_75, mu_ssh) * scale_mmd_ssh_75
                #         elif width_mult == 0.5:
                #             # queue
                #             if args.queue_size > args.batch_size_align:
                #                 feat_queue = queue_ssh_50.get()
                #                 queue_ssh_50.update(feat_ssh)
                #                 if feat_queue is not None:
                #                     feat_ssh = torch.cat([feat_ssh, feat_queue.cuda()])

                #             cov_ssh = covariance(feat_ssh)
                #             mu_ssh = feat_ssh.mean(dim=0)
                #             source_domain = torch.distributions.MultivariateNormal(mu_src_ssh_50, cov_src_ssh_50 + template_ssh_cov_50)
                #             target_domain = torch.distributions.MultivariateNormal(mu_ssh, cov_ssh + template_ssh_cov_50)
                #             loss += (torch.distributions.kl_divergence(source_domain,
                #                                                        target_domain) + torch.distributions.kl_divergence(
                #                 target_domain, source_domain)) * 0.5

                #             # if args.divergence in ['coral', 'all']:
                #             #     cov_ssh = covariance(feat_ssh)
                #             #     loss += coral(cov_src_ssh_50, cov_ssh) * scale_align_ssh_50
                #             # if args.divergence in ['mmd', 'all']:
                #             #     mu_ssh = feat_ssh.mean(dim=0)
                #             #     loss += linear_mmd(mu_src_ssh_50, mu_ssh) * scale_mmd_ssh_50
                #         else:
                #             # queue
                #             if args.queue_size > args.batch_size_align:
                #                 feat_queue = queue_ssh_25.get()
                #                 queue_ssh_25.update(feat_ssh)
                #                 if feat_queue is not None:
                #                     feat_ssh = torch.cat([feat_ssh, feat_queue.cuda()])

                #             cov_ssh = covariance(feat_ssh)
                #             mu_ssh = feat_ssh.mean(dim=0)
                #             source_domain = torch.distributions.MultivariateNormal(mu_src_ssh_25, cov_src_ssh_25 + template_ssh_cov_25)
                #             target_domain = torch.distributions.MultivariateNormal(mu_ssh, cov_ssh + template_ssh_cov_25)
                #             loss += (torch.distributions.kl_divergence(source_domain,
                #                                                        target_domain) + torch.distributions.kl_divergence(
                #                 target_domain, source_domain)) * 0.5

                #             # if args.divergence in ['coral', 'all']:
                #             #     cov_ssh = covariance(feat_ssh)
                #             #     loss += coral(cov_src_ssh_25, cov_ssh) * scale_align_ssh_25
                #             # if args.divergence in ['mmd', 'all']:
                #             #     mu_ssh = feat_ssh.mean(dim=0)
                #             #     loss += linear_mmd(mu_src_ssh_25, mu_ssh) * scale_mmd_ssh_25

                #         # # queue
                #         # if args.queue_size > args.batch_size_align:
                #         #     feat_queue = queue_ssh.get()
                #         #     queue_ssh.update(feat_ssh)
                #         #     if feat_queue is not None:
                #         #         feat_ssh = torch.cat([feat_ssh, feat_queue.cuda()])
                #         #
                #         # if args.divergence in ['coral', 'all']:
                #         #     cov_ssh = covariance(feat_ssh)
                #         #     loss += coral(cov_src_ssh, cov_ssh) * scale_align_ssh
                #         # if args.divergence in ['mmd', 'all']:
                #         #     mu_ssh = feat_ssh.mean(dim=0)
                #         #     loss += linear_mmd(mu_src_ssh, mu_ssh) * scale_mmd_ssh

                #         loss.backward()
                #         del loss

            if epoch > args.bnepoch:
                optimizer.step()

        if print_loss:
            print('Epoch %d/%d feature_loss: %f' % (epoch, args.nepoch, np.mean(feature_loss)))
            print('Epoch %d/%d feature_75_loss: %f' % (epoch, args.nepoch, np.mean(feature_loss_75)))
            print('Epoch %d/%d feature_5_loss: %f' % (epoch, args.nepoch, np.mean(feature_loss_5)))
            print('Epoch %d/%d feature_25_loss: %f' % (epoch, args.nepoch, np.mean(feature_loss_25)))
            if is_both_activated:
                print('Epoch %d/%d aug_loss: %f' % (epoch, args.nepoch, np.mean(aug_loss)))
                print('Epoch %d/%d aug_75_loss: %f' % (epoch, args.nepoch, np.mean(aug_loss_75)))
                print('Epoch %d/%d aug_5_loss: %f' % (epoch, args.nepoch, np.mean(aug_loss_5)))
                print('Epoch %d/%d aug_25_loss: %f' % (epoch, args.nepoch, np.mean(aug_loss_25)))
                print('Epoch %d/%d ext_align_loss: %f' % (epoch, args.nepoch, np.mean(ext_align_loss)))
                print('Epoch %d/%d ext_75_align_loss: %f' % (epoch, args.nepoch, np.mean(ext_align_loss_75)))
                print('Epoch %d/%d ext_5_align_loss: %f' % (epoch, args.nepoch, np.mean(ext_align_loss_5)))
                print('Epoch %d/%d ext_25_align_loss: %f' % (epoch, args.nepoch, np.mean(ext_align_loss_25)))
            # print('Epoch %d/%d ssh_align_loss: %f' % (epoch, args.nepoch, np.mean(ssh_align_loss)))
            # print('Epoch %d/%d ssh_75_align_loss: %f' % (epoch, args.nepoch, np.mean(ssh_align_loss_75)))
            # print('Epoch %d/%d ssh_5_align_loss: %f' % (epoch, args.nepoch, np.mean(ssh_align_loss_5)))
            # print('Epoch %d/%d ssh_25_align_loss: %f' % (epoch, args.nepoch, np.mean(ssh_align_loss_25)))

        widths_train = [0.25, 0.5, 0.75, 1.0]
        for width_mult in widths_train:
            loss = 0
            net.apply(
                lambda m: setattr(m, 'width_mult', width_mult))  # 逐个遍历model的子模块，给子模块中的变量width_mult赋值
            if width_mult == 1.0:
                # err_cls = test(teloader, net)[0]
                err_cls = test_ensemble(teloader, net, width_mult)[0]
                all_err_cls.append(err_cls)
                toc = time.time()
                print(('width %f Epoch %d/%d (%.0fs):' % (width_mult, epoch, args.nepoch, toc - tic)).ljust(24) +
                    '%.2f\t\t' % (err_cls * 100))
            elif width_mult == 0.75:
                # err_cls_75 = test(teloader, net)[0]
                err_cls_75 = test_ensemble(teloader, net, width_mult)[0]
                all_err_cls_75.append(err_cls_75)
                toc = time.time()
                print(('width %f Epoch %d/%d (%.0fs):' % (width_mult, epoch, args.nepoch, toc - tic)).ljust(24) +
                    '%.2f\t\t' % (err_cls_75 * 100))
            elif width_mult == 0.5:
                # err_cls_50 = test(teloader, net)[0]
                err_cls_50 = test_ensemble(teloader, net, width_mult)[0]
                all_err_cls_50.append(err_cls_50)
                toc = time.time()
                print(('width %f Epoch %d/%d (%.0fs):' % (width_mult, epoch, args.nepoch, toc - tic)).ljust(24) +
                    '%.2f\t\t' % (err_cls_50 * 100))
            else:
                err_cls_25 = test(teloader, net)[0]
                all_err_cls_25.append(err_cls_25)
                toc = time.time()
                print(('width %f Epoch %d/%d (%.0fs):' % (width_mult, epoch, args.nepoch, toc - tic)).ljust(24) +
                    '%.2f\t\t' % (err_cls_25 * 100))
        # err_cls = test(teloader, net)[0]
        # all_err_cls.append(err_cls)

        # toc = time.time()
        # print(('width %f Epoch %d/%d (%.0fs):' %(width_mult, epoch, args.nepoch, toc-tic)).ljust(24) +
        #                 '%.2f\t\t' %(err_cls*100))

        # both components
        if args.method == 'both' and not is_both_activated and epoch > args.bnepoch + args.delayepoch:
            is_both_activated = True

        # termination
        if epoch > (args.stopepoch + 1) and all_err_cls[-args.stopepoch] < min(all_err_cls[-args.stopepoch+1:]):
            print("Termination: {:.2f}".format(all_err_cls[-args.stopepoch]*100))
            print("Termination: {:.2f}".format(min(all_err_cls_75[-args.stopepoch+1:])*100))
            print("Termination: {:.2f}".format(min(all_err_cls_50[-args.stopepoch+1:])*100))
            print("Termination: {:.2f}".format(min(all_err_cls_25[-args.stopepoch+1:])*100))
            break

        # save
        if epoch > args.bnepoch and epoch % args.save_every == 0 and all_err_cls[-1] < min(all_err_cls[:-2]):
            state = {'net': net.state_dict(), 'head': head.state_dict()}
            save_file = os.path.join(args.outf, args.corruption + '_' +  args.method + '.pth')
            torch.save(state, save_file)
            print('Save model to', save_file)

        if args.tsne and epoch > args.bnepoch and err_cls < min(all_err_cls[:-1]):
            ext_best = copy.deepcopy(ext.state_dict())

        # lr decay
        scheduler.step(err_cls)

    # -------------------------------

    if args.method == 'ssl':
        prefix = os.path.join(args.outf, args.corruption + '_ssl')
    elif args.method == 'align':
        prefix = os.path.join(args.outf, args.corruption + '_align')
    elif args.method == 'both':
        prefix = os.path.join(args.outf, args.corruption + '_tttpp')
    else:
        raise NotImplementedError

    if args.tsne:
        ext.load_state_dict(ext_best, strict=True)
        feat_tar, label_tar, tsne_tar = visu_feat(ext, teloader, prefix+'_class.pdf')
        calculate_distance(feat_src, label_src, tsne_src, feat_tar, label_tar, tsne_tar)
        # comp_feat(feat_src, label_src, feat_tar, label_tar, prefix+'_marginal.pdf')

    # -------------------------------

    # df = pd.DataFrame([all_err_cls, all_err_ssh]).T
    # df.to_csv(prefix, index=False, float_format='%.4f', header=False)
