from __future__ import print_function, division

import sys
import time
import torch
from torch import nn
from itertools import cycle

from .util import AverageMeter, accuracy, param_dist


def train_negrad(epoch, train_loader, delete_loader, model, criterion, optimizer, alpha, opt, quiet=False):
    """Train for one epoch with a combined descent/ascent objective.

    Every batch of ``train_loader`` is paired with a batch taken from
    ``itertools.cycle(delete_loader)``.  The optimised quantity is
    ``alpha * criterion(retain) - (1 - alpha) * criterion(delete)``, i.e. the
    loss on the delete batch enters with a negative sign.

    Args:
        epoch: Epoch number; only used in the progress line.
        train_loader: Iterable of ``(input, target)`` batches; one pass over
            it bounds the number of steps.
        delete_loader: Iterable of ``(input, target)`` batches, wrapped in
            ``itertools.cycle`` so it never runs out before ``train_loader``.
        model: Network that is put in train mode and called on both batches.
        criterion: Callable ``criterion(output, target)`` returning the loss.
        optimizer: Object whose ``zero_grad()``/``step()`` are called per batch.
        alpha: Weight of the retain loss; the delete loss gets ``1 - alpha``.
        opt: Options object; ``opt.print_freq`` sets the logging interval.
        quiet: If true, skip accuracy/loss bookkeeping and all printing.

    Returns:
        ``(top1.avg, losses.avg)``.  These meters are only updated when
        ``quiet`` is false.
    """
    model.train()

    batch_time, data_time = AverageMeter(), AverageMeter()
    losses, top1, top5 = AverageMeter(), AverageMeter(), AverageMeter()
    verbose = not quiet
    use_cuda = torch.cuda.is_available()

    tick = time.time()
    for step, (retain_batch, delete_batch) in enumerate(zip(train_loader, cycle(delete_loader))):
        inputs, target = retain_batch
        del_inputs, del_target = delete_batch
        data_time.update(time.time() - tick)

        inputs = inputs.float()
        del_inputs = del_inputs.float()
        if use_cuda:
            inputs, target = inputs.cuda(), target.cuda()
            del_inputs, del_target = del_inputs.cuda(), del_target.cuda()

        # ===================forward=====================
        output = model(inputs)
        del_output = model(del_inputs)
        retain_loss = criterion(output, target)
        forget_loss = criterion(del_output, del_target)

        # Positive weight on the retain loss, negative weight on the delete loss.
        objective = alpha * retain_loss - (1 - alpha) * forget_loss

        if verbose:
            n_samples = inputs.size(0)
            acc1, acc5 = accuracy(output, target, topk=(1, 5))
            losses.update(objective.item(), n_samples)
            top1.update(acc1[0], n_samples)
            top5.update(acc5[0], n_samples)

        # ===================backward=====================
        optimizer.zero_grad()
        objective.backward()
        optimizer.step()

        # ===================meters=====================
        now = time.time()
        batch_time.update(now - tick)
        tick = time.time()

        if verbose and step % opt.print_freq == 0:
            progress = ('Epoch: [{0}][{1}/{2}]\t'
                        'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                        'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                        'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                        'Acc@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                        'Acc@5 {top5.val:.3f} ({top5.avg:.3f})')
            print(progress.format(
                epoch, step, len(train_loader), batch_time=batch_time,
                data_time=data_time, loss=losses, top1=top1, top5=top5))
            sys.stdout.flush()

    if verbose:
        summary = ' * Acc@1 {top1.avg:.3f} Acc@5 {top5.avg:.3f}'
        print(summary.format(top1=top1, top5=top5))

    return top1.avg, losses.avg

def train_vanilla(epoch, train_loader, model, criterion, optimizer, opt, quiet=False):
    """Run one ordinary supervised training epoch.

    Args:
        epoch: Epoch number; only used in the progress line.
        train_loader: Iterable of ``(input, target)`` batches.
        model: Network that is put in train mode and called on each batch.
        criterion: Callable ``criterion(output, target)`` returning the loss.
        optimizer: Object whose ``zero_grad()``/``step()`` are called per batch.
        opt: Options object; ``opt.print_freq`` sets the logging interval.
        quiet: If true, skip accuracy/loss bookkeeping and all printing.

    Returns:
        ``(top1.avg, losses.avg)``.  These meters are only updated when
        ``quiet`` is false.
    """
    model.train()

    batch_time, data_time = AverageMeter(), AverageMeter()
    losses, top1, top5 = AverageMeter(), AverageMeter(), AverageMeter()
    verbose = not quiet
    use_cuda = torch.cuda.is_available()

    tick = time.time()
    for step, batch in enumerate(train_loader):
        inputs, target = batch
        data_time.update(time.time() - tick)

        inputs = inputs.float()
        if use_cuda:
            inputs, target = inputs.cuda(), target.cuda()

        # ===================forward=====================
        output = model(inputs)
        batch_loss = criterion(output, target)

        if verbose:
            n_samples = inputs.size(0)
            acc1, acc5 = accuracy(output, target, topk=(1, 5))
            losses.update(batch_loss.item(), n_samples)
            top1.update(acc1[0], n_samples)
            top5.update(acc5[0], n_samples)

        # ===================backward=====================
        optimizer.zero_grad()
        batch_loss.backward()
        optimizer.step()

        # ===================meters=====================
        batch_time.update(time.time() - tick)
        tick = time.time()

        if verbose and step % opt.print_freq == 0:
            progress = ('Epoch: [{0}][{1}/{2}]\t'
                        'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                        'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                        'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                        'Acc@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                        'Acc@5 {top5.val:.3f} ({top5.avg:.3f})')
            print(progress.format(
                epoch, step, len(train_loader), batch_time=batch_time,
                data_time=data_time, loss=losses, top1=top1, top5=top5))
            sys.stdout.flush()

    if verbose:
        summary = ' * Acc@1 {top1.avg:.3f} Acc@5 {top5.avg:.3f}'
        print(summary.format(top1=top1, top5=top5))

    return top1.avg, losses.avg

def train_distill(epoch, train_loader, module_list, swa_model, criterion_list, optimizer, opt, split, quiet=False, trans=False):
    """Run one epoch of logit-level distillation of ``module_list[0]``.

    The student (``module_list[0]``) and the teacher (``module_list[-1]``,
    kept in eval mode and evaluated under ``torch.no_grad()``) are both run on
    every batch.  Depending on ``split`` the optimised loss is

    * ``"minimize"``: ``opt.gamma * cls + opt.alpha * div + opt.beta * kd``
    * ``"maximize"``: ``-div``

    where ``cls = criterion_list[0](student_logits, target)``,
    ``div = criterion_list[1](student_logits, teacher_logits)`` and ``kd`` is
    ``0`` (the only supported ``opt.distill`` mode is ``'kd'``).  In both
    cases ``param_dist(model_s, swa_model, opt.smoothing)`` is added.

    Only logits are produced by the forward passes in this function, so the
    feature-based distillation modes cannot be served here; they previously
    failed with ``NameError`` on the never-assigned ``feat_s``/``feat_t`` and
    are now rejected explicitly before any work is done.

    Args:
        epoch: Epoch number; only used in the progress line.
        train_loader: Iterable of ``(input, target)`` batches.  With
            ``trans=True`` ``input`` is a dict of tensors passed to the models
            as keyword arguments.
        module_list: Sequence of modules; first is the student, last the teacher.
        swa_model: Second argument handed to ``param_dist``.
        criterion_list: ``[criterion_cls, criterion_div, ...]``.
        optimizer: Object whose ``zero_grad()``/``step()`` are called per batch.
        opt: Options object; reads ``distill``, ``gamma``, ``alpha``, ``beta``,
            ``smoothing`` and ``print_freq``.
        split: ``"minimize"`` or ``"maximize"``.
        quiet: If true, skip meter updates and printing.
        trans: If true, treat ``input`` as a dict of tensors.

    Returns:
        ``(top1.avg, losses.avg)`` for ``split == "minimize"``, otherwise
        ``kd_losses.avg``.

    Raises:
        NotImplementedError: ``opt.distill`` is anything other than ``'kd'``.
        ValueError: ``split`` is neither ``"minimize"`` nor ``"maximize"``.
    """
    # Modes whose loss is defined on intermediate features rather than logits.
    feature_modes = ('hint', 'crd', 'attention', 'nst', 'similarity', 'rkd',
                     'pkt', 'kdsvd', 'correlation', 'vid')
    # Dict entries that must stay integer-typed when ``trans`` is set.
    index_keys = ("input_ids", "token_type_ids", "labels")

    if opt.distill in feature_modes:
        raise NotImplementedError(
            "distill mode '{}' needs intermediate features, which "
            "train_distill does not compute".format(opt.distill))
    if opt.distill != 'kd':
        raise NotImplementedError(opt.distill)
    if split not in ("minimize", "maximize"):
        raise ValueError(
            "split must be 'minimize' or 'maximize', got {!r}".format(split))

    # set modules as train()
    for module in module_list:
        module.train()
    # set teacher as eval()
    module_list[-1].eval()

    criterion_cls = criterion_list[0]
    criterion_div = criterion_list[1]

    model_s = module_list[0]
    model_t = module_list[-1]

    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    kd_losses = AverageMeter()
    top1 = AverageMeter()

    use_cuda = torch.cuda.is_available()
    minimizing = split == "minimize"

    end = time.time()
    for idx, (inputs, target) in enumerate(train_loader):
        data_time.update(time.time() - end)

        if trans:
            # Cast index-like tensors straight to long instead of going
            # through float first, which could round large integer ids.
            inputs = {k: (v.long() if k in index_keys else v.float())
                      for k, v in inputs.items()}
        else:
            inputs = inputs.float()

        if use_cuda:
            if trans:
                inputs = {k: v.cuda() for k, v in inputs.items()}
            else:
                inputs = inputs.cuda()
            target = target.cuda()

        # ===================forward=====================
        logit_s = model_s(**inputs) if trans else model_s(inputs)
        with torch.no_grad():
            logit_t = model_t(**inputs) if trans else model_t(inputs)

        # cls + kl div
        loss_cls = criterion_cls(logit_s, target)
        loss_div = criterion_div(logit_s, logit_t)
        # plain 'kd' has no extra term beyond the divergence
        loss_kd = 0

        if minimizing:
            loss = opt.gamma * loss_cls + opt.alpha * loss_div + opt.beta * loss_kd
        else:
            loss = -loss_div

        loss = loss + param_dist(model_s, swa_model, opt.smoothing)

        if not quiet:
            if minimizing:
                acc1, _ = accuracy(logit_s, target, topk=(1, 1))
                losses.update(loss.item(), target.size(0))
                top1.update(acc1[0], target.size(0))
            else:
                kd_losses.update(loss.item(), target.size(0))

        # ===================backward=====================
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # ===================meters=====================
        batch_time.update(time.time() - end)
        end = time.time()

        # The original compared against the misspelling "mainimize", so this
        # progress line was never printed.
        if not quiet and minimizing and idx % opt.print_freq == 0:
            print('Epoch: [{0}][{1}/{2}]\t'
                  'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                  'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                  'Acc@1 {top1.val:.3f} ({top1.avg:.3f})'.format(
                      epoch, idx, len(train_loader), batch_time=batch_time,
                      data_time=data_time, loss=losses, top1=top1))
            sys.stdout.flush()

    if minimizing:
        if not quiet:
            print(' * Acc@1 {top1.avg:.3f} '
                  .format(top1=top1))
        return top1.avg, losses.avg
    return kd_losses.avg

def train_distill_hide(epoch, train_dataset, test_dataset, module_list, swa_model, criterion_list, optimizer, opt):
    """Run one epoch matching student outputs on train data to teacher outputs on test data.

    Two shuffled ``DataLoader``s (batch size 16) are built from the datasets;
    each train batch is paired with a batch from ``itertools.cycle`` over the
    test loader.  The loss is
    ``criterion_list[1](student(train_batch), teacher(test_batch))``, so both
    batches are trimmed to the same number of samples first.

    Fixes relative to the previous version:

    * When the test batch was larger than the train batch, the test batch was
      overwritten with the *train* inputs (``input_t = input[:len(input)]``);
      it is now truncated to the train batch size instead.
    * Feature-based ``opt.distill`` modes referenced names that are never
      assigned here and raised ``NameError``; they are now rejected with
      ``NotImplementedError`` before any work is done.
    * A teacher forward pass on the train batch whose result was never used
      has been removed (teacher is in eval mode and runs under ``no_grad``).

    Args:
        epoch: Unused; kept for signature compatibility.
        train_dataset: Dataset yielding ``(input, target)`` samples.
        test_dataset: Dataset yielding ``(input, target)`` samples.
        module_list: Sequence of modules; first is the student, last the teacher.
        swa_model: Unused; kept for signature compatibility.
        criterion_list: Sequence whose element ``1`` is the divergence criterion.
        optimizer: Object whose ``zero_grad()``/``step()`` are called per batch.
        opt: Options object; only ``opt.distill`` is read.

    Returns:
        ``kd_losses.avg`` -- the meter's average of the per-batch loss.

    Raises:
        NotImplementedError: ``opt.distill`` is anything other than ``'kd'``.
    """
    # Modes whose loss is defined on intermediate features rather than logits.
    feature_modes = ('hint', 'crd', 'attention', 'nst', 'similarity', 'rkd',
                     'pkt', 'kdsvd', 'correlation', 'vid')
    if opt.distill in feature_modes:
        raise NotImplementedError(
            "distill mode '{}' needs intermediate features, which "
            "train_distill_hide does not compute".format(opt.distill))
    if opt.distill != 'kd':
        raise NotImplementedError(opt.distill)

    # set modules as train()
    for module in module_list:
        module.train()
    # set teacher as eval()
    module_list[-1].eval()

    criterion_div = criterion_list[1]
    model_s = module_list[0]
    model_t = module_list[-1]

    kd_losses = AverageMeter()
    use_cuda = torch.cuda.is_available()

    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=16, num_workers=0, pin_memory=True, shuffle=True)
    test_loader = torch.utils.data.DataLoader(
        test_dataset, batch_size=16, num_workers=0, pin_memory=True, shuffle=True)

    # Targets are not part of this objective, so they are discarded.
    for (input, _), (input_t, _) in zip(train_loader, cycle(test_loader)):
        # The divergence compares the two batches row by row, so trim both to
        # the smaller one (the last batch of either loader may be short).
        common = min(len(input), len(input_t))
        input = input[:common].float()
        input_t = input_t[:common].float()

        if use_cuda:
            input = input.cuda()
            input_t = input_t.cuda()

        # ===================forward=====================
        logit_s = model_s(input)
        # Result unused, but the student is in train mode, so this forward
        # pass is kept in case it updates module state (e.g. normalisation
        # statistics) that earlier runs relied on -- TODO confirm and remove.
        model_s(input_t)
        with torch.no_grad():
            logit_t_t = model_t(input_t)

        loss = criterion_div(logit_s, logit_t_t)
        kd_losses.update(loss.item(), input.size(0))

        # ===================backward=====================
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    return kd_losses.avg

def train_distill_linear(epoch, train_loader, delete_loader, module_list, swa_model, criterion_list, optimizer, opt):
    """Run one distillation epoch with an extra negative term on delete batches.

    Each ``train_loader`` batch is paired with a batch from
    ``itertools.cycle(delete_loader)``.  Student (``module_list[0]``) and
    teacher (``module_list[-1]``, eval mode, ``no_grad``) are called with
    ``is_feat=True`` on both batches.  The optimised loss is::

        opt.gamma * cls + opt.alpha * div - opt.beta * div_del
            + param_dist(model_s, swa_model)

    where ``div``/``div_del`` are ``criterion_list[1]`` applied to
    student/teacher logits on the train/delete batch.  The mode-specific
    ``criterion_list[2]`` term is still evaluated for every supported
    ``opt.distill`` value, but it is not part of the loss above.

    Args:
        epoch: Epoch number; only used in the progress line.
        train_loader: Iterable of ``(input, target)`` batches, or 4-tuples
            ``(input, target, index, contrast_idx)`` when ``opt.distill`` is
            ``'crd'``.
        delete_loader: Iterable with the same batch layout as ``train_loader``.
        module_list: Sequence of modules; first is the student, last the teacher.
        swa_model: Second argument handed to ``param_dist``.
        criterion_list: ``[criterion_cls, criterion_div, criterion_kd]``.
        optimizer: Object whose ``zero_grad()``/``step()`` are called per batch.
        opt: Options object; reads ``distill``, ``gamma``, ``alpha``, ``beta``,
            ``print_freq`` and, for ``'hint'``, ``hint_layer``.

    Returns:
        ``(top1.avg, losses.avg)``.

    Raises:
        NotImplementedError: ``opt.distill`` is not one of the handled modes.
    """
    mode = opt.distill

    # Everything trains except the teacher (and one auxiliary module for
    # the 'abound' / 'factor' modes).
    for module in module_list:
        module.train()
    module_list[-1].eval()
    if mode == 'abound':
        module_list[1].eval()
    elif mode == 'factor':
        module_list[2].eval()

    criterion_cls, criterion_div, criterion_kd = (
        criterion_list[0], criterion_list[1], criterion_list[2])
    model_s, model_t = module_list[0], module_list[-1]

    batch_time, data_time = AverageMeter(), AverageMeter()
    losses, top1, top5 = AverageMeter(), AverageMeter(), AverageMeter()

    uses_crd = mode in ['crd']
    preact = mode in ['abound']
    use_cuda = torch.cuda.is_available()

    tick = time.time()
    for step, (retain_batch, delete_batch) in enumerate(zip(train_loader, cycle(delete_loader))):
        if uses_crd:
            inputs, target, index, contrast_idx = retain_batch
            inputs_del, target_del, index_del, contrast_idx_del = delete_batch
        else:
            inputs, target = retain_batch
            inputs_del, target_del = delete_batch

        data_time.update(time.time() - tick)

        inputs = inputs.float()
        inputs_del = inputs_del.float()

        if use_cuda:
            inputs, target = inputs.cuda(), target.cuda()
            inputs_del, target_del = inputs_del.cuda(), target_del.cuda()
            if uses_crd:
                contrast_idx, index = contrast_idx.cuda(), index.cuda()
                contrast_idx_del, index_del = contrast_idx_del.cuda(), index_del.cuda()

        # ===================forward=====================
        feat_s, logit_s = model_s(inputs, is_feat=True, preact=preact)
        feat_s_del, logit_s_del = model_s(inputs_del, is_feat=True, preact=preact)
        with torch.no_grad():
            feat_t, logit_t = model_t(inputs, is_feat=True, preact=preact)
            feat_t = [f.detach() for f in feat_t]
            feat_t_del, logit_t_del = model_t(inputs_del, is_feat=True, preact=preact)
            feat_t_del = [f.detach() for f in feat_t_del]

        # cls + kl div
        loss_cls = criterion_cls(logit_s, target)
        loss_div = criterion_div(logit_s, logit_t)
        loss_div_del = criterion_div(logit_s_del, logit_t_del)

        # other kd beyond KL divergence (evaluated, but not added to the loss)
        if mode in ('kd', 'abound', 'fsp'):
            loss_kd = 0
        elif mode == 'hint':
            loss_kd = criterion_kd(module_list[1](feat_s[opt.hint_layer]),
                                   feat_t[opt.hint_layer])
        elif mode == 'crd':
            loss_kd = criterion_kd(feat_s[-1], feat_t[-1], index, contrast_idx)
        elif mode in ('attention', 'nst', 'kdsvd'):
            # criterion returns one loss per intermediate layer
            loss_kd = sum(criterion_kd(feat_s[1:-1], feat_t[1:-1]))
        elif mode == 'similarity':
            loss_kd = sum(criterion_kd([feat_s[-2]], [feat_t[-2]]))
        elif mode in ('rkd', 'pkt'):
            loss_kd = criterion_kd(feat_s[-1], feat_t[-1])
        elif mode == 'correlation':
            loss_kd = criterion_kd(module_list[1](feat_s[-1]),
                                   module_list[2](feat_t[-1]))
        elif mode == 'vid':
            # here criterion_kd is iterable: one criterion per layer pair
            loss_kd = sum(c(f_s, f_t) for f_s, f_t, c
                          in zip(feat_s[1:-1], feat_t[1:-1], criterion_kd))
        elif mode == 'factor':
            factor_s = module_list[1](feat_s[-2])
            factor_t = module_list[2](feat_t[-2], is_factor=True)
            loss_kd = criterion_kd(factor_s, factor_t)
        else:
            raise NotImplementedError(opt.distill)

        loss = (opt.gamma * loss_cls + opt.alpha * loss_div - opt.beta * loss_div_del
                + param_dist(model_s, swa_model))

        n_samples = inputs.size(0)
        acc1, acc5 = accuracy(logit_s, target, topk=(1, 5))
        losses.update(loss.item(), n_samples)
        top1.update(acc1[0], n_samples)
        top5.update(acc5[0], n_samples)

        # ===================backward=====================
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # ===================meters=====================
        batch_time.update(time.time() - tick)
        tick = time.time()

        if step % opt.print_freq == 0:
            progress = ('Epoch: [{0}][{1}/{2}]\t'
                        'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                        'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                        'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                        'Acc@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                        'Acc@5 {top5.val:.3f} ({top5.avg:.3f})')
            print(progress.format(
                epoch, step, len(train_loader), batch_time=batch_time,
                data_time=data_time, loss=losses, top1=top1, top5=top5))
            sys.stdout.flush()

    summary = ' * Acc@1 {top1.avg:.3f} Acc@5 {top5.avg:.3f}'
    print(summary.format(top1=top1, top5=top5))

    return top1.avg, losses.avg

def train_bad_teacher(epoch, train_loader, delete_loader, module_list, criterion_list, optimizer, opt):
    """Run one epoch distilling from two teachers on two data streams.

    Each ``train_loader`` batch is paired with a batch from
    ``itertools.cycle(delete_loader)``.  The student (``module_list[0]``) is
    trained on::

        opt.alpha * criterion_div(student(train), model_gt(train))
            + opt.beta * criterion_div(student(delete), model_bt(delete))

    with ``model_gt = module_list[1]`` and ``model_bt = module_list[2]``, both
    in eval mode and evaluated under ``torch.no_grad()``.  (Judging by the
    names these are presumably the "good" and "bad" teacher -- confirm with
    the caller.)

    Changes relative to the previous version:

    * Feature-based ``opt.distill`` modes referenced ``feat_s``/``feat_t``,
      which this function never computes, and failed with ``NameError``.  They
      are now rejected with ``NotImplementedError`` before any work is done,
      like any other unsupported mode.
    * Teacher forward passes and the classification loss whose results were
      never used have been removed, as have duplicated ``.eval()`` calls.

    Args:
        epoch: Epoch number; only used in the progress line.
        train_loader: Iterable of ``(input, target)`` batches.
        delete_loader: Iterable of ``(input, target)`` batches; targets unused.
        module_list: ``[student, model_gt, model_bt, ...]``.
        criterion_list: Sequence whose element ``1`` is the divergence criterion.
        optimizer: Object whose ``zero_grad()``/``step()`` are called per batch.
        opt: Options object; reads ``distill``, ``alpha``, ``beta`` and
            ``print_freq``.

    Returns:
        ``(top1.avg, losses.avg)`` where accuracy is measured on the train
        batches.

    Raises:
        NotImplementedError: ``opt.distill`` is not ``'kd'``, ``'abound'`` or
            ``'fsp'``.
    """
    # Modes that add nothing beyond the divergence terms, so they work here.
    logit_only_modes = ('kd', 'abound', 'fsp')
    # Modes whose loss is defined on intermediate features.
    feature_modes = ('hint', 'crd', 'attention', 'nst', 'similarity', 'rkd',
                     'pkt', 'kdsvd', 'correlation', 'vid', 'factor')
    if opt.distill in feature_modes:
        raise NotImplementedError(
            "distill mode '{}' needs intermediate features, which "
            "train_bad_teacher does not compute".format(opt.distill))
    if opt.distill not in logit_only_modes:
        raise NotImplementedError(opt.distill)

    # set modules as train()
    for module in module_list:
        module.train()
    # set both teachers as eval()
    module_list[1].eval()
    module_list[2].eval()

    criterion_div = criterion_list[1]

    model_s = module_list[0]
    model_gt = module_list[1]
    model_bt = module_list[2]

    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()

    use_cuda = torch.cuda.is_available()

    end = time.time()
    for idx, ((inputs, target), (inputs_del, _)) in enumerate(zip(train_loader, cycle(delete_loader))):
        data_time.update(time.time() - end)

        inputs = inputs.float()
        inputs_del = inputs_del.float()

        if use_cuda:
            inputs = inputs.cuda()
            target = target.cuda()
            inputs_del = inputs_del.cuda()

        # ===================forward=====================
        logit_s = model_s(inputs)
        logit_s_del = model_s(inputs_del)
        with torch.no_grad():
            logit_gt = model_gt(inputs)
            logit_bt_del = model_bt(inputs_del)

        # Train batches follow model_gt, delete batches follow model_bt.
        loss_div = criterion_div(logit_s, logit_gt)
        loss_div_del = criterion_div(logit_s_del, logit_bt_del)

        loss = opt.alpha * loss_div + opt.beta * loss_div_del

        acc1, acc5 = accuracy(logit_s, target, topk=(1, 5))
        losses.update(loss.item(), inputs.size(0))
        top1.update(acc1[0], inputs.size(0))
        top5.update(acc5[0], inputs.size(0))

        # ===================backward=====================
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # ===================meters=====================
        batch_time.update(time.time() - end)
        end = time.time()

        if idx % opt.print_freq == 0:
            print('Epoch: [{0}][{1}/{2}]\t'
                  'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                  'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                  'Acc@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                  'Acc@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
                      epoch, idx, len(train_loader), batch_time=batch_time,
                      data_time=data_time, loss=losses, top1=top1, top5=top5))
            sys.stdout.flush()

    print(' * Acc@1 {top1.avg:.3f} Acc@5 {top5.avg:.3f}'
          .format(top1=top1, top5=top5))

    return top1.avg, losses.avg

def train_bcu(epoch, train_loader, delete_loader, module_list, criterion_list, optimizer, bin_cls_optimizer, opt):
    """Run one epoch alternating a ``bin_cls_optimizer`` step and an ``optimizer`` step.

    Each ``train_loader`` batch is paired with a batch from
    ``itertools.cycle(delete_loader)``.  Student (``module_list[0]``) and
    teacher (``module_list[-1]``, eval mode) are called with ``is_feat=True``
    on both batches.  Per step:

    1. ``criterion_list[1]`` is evaluated on student/teacher logits (when
       ``opt.bcu_vec == "logits"``) or on the last feature of each forward
       pass (otherwise); it returns ``(loss, accuracy)``.  That loss is
       back-propagated and ``bin_cls_optimizer`` steps.
    2. ``opt.gamma * cls_loss + (1 - opt.gamma) * bcu_loss.item()`` is
       back-propagated and ``optimizer`` steps.  Because ``.item()`` is a
       plain number, only the classification term contributes gradients here.

    Args:
        epoch: Epoch number; only used in the progress line.
        train_loader: Iterable of ``(input, target)`` batches.
        delete_loader: Iterable of ``(input, target)`` batches.
        module_list: Sequence of modules; first is the student, last the teacher.
        criterion_list: ``[classification criterion, BCU criterion, ...]``.
        optimizer: Stepped on the combined loss.
        bin_cls_optimizer: Stepped on the BCU loss (presumably it owns the
            parameters of the BCU criterion -- confirm with the caller).
        opt: Options object; reads ``distill``, ``bcu_vec``, ``gamma`` and
            ``print_freq``.

    Returns:
        ``(top1.avg, losses.avg, bcu_losses.avg, bcu_accuracy.avg)``.
    """
    module_list[0].train()
    module_list[-1].eval()

    if opt.distill == 'abound':
        module_list[2].eval()
    elif opt.distill == 'factor':
        module_list[3].eval()

    model_s, model_t = module_list[0], module_list[-1]

    batch_time, data_time = AverageMeter(), AverageMeter()
    losses, bcu_losses = AverageMeter(), AverageMeter()
    top1, top5 = AverageMeter(), AverageMeter()
    bcu_accuracy = AverageMeter()

    use_cuda = torch.cuda.is_available()

    tick = time.time()
    for step, (retain_batch, delete_batch) in enumerate(zip(train_loader, cycle(delete_loader))):
        inputs, target = retain_batch
        del_inputs, del_target = delete_batch
        data_time.update(time.time() - tick)

        inputs = inputs.float()
        del_inputs = del_inputs.float()
        if use_cuda:
            inputs, del_inputs = inputs.cuda(), del_inputs.cuda()
            target, del_target = target.cuda(), del_target.cuda()

        # ===================forward=====================
        feat_t_r, logit_t_r = model_t(inputs, is_feat=True)
        feat_t_d, logit_t_d = model_t(del_inputs, is_feat=True)
        feat_s_r, logit_s_r = model_s(inputs, is_feat=True)
        feat_s_d, logit_s_d = model_s(del_inputs, is_feat=True)

        # Last feature of every forward pass (student/teacher x retain/delete).
        last_feats = (feat_s_r[-1], feat_s_d[-1], feat_t_r[-1], feat_t_d[-1])

        cls_loss = criterion_list[0](logit_s_r, target)
        if opt.bcu_vec == "logits":
            bcu_inputs = (logit_s_r, logit_s_d, logit_t_r, logit_t_d)
        else:
            bcu_inputs = last_feats
        bcu_loss, bcu_acc = criterion_list[1](*bcu_inputs)

        # First update: BCU loss only.  The graph is retained because the
        # combined loss below is back-propagated through the same forward pass.
        bin_cls_optimizer.zero_grad()
        bcu_loss.backward(retain_graph=True)
        bin_cls_optimizer.step()

        total_loss = opt.gamma * cls_loss + (1 - opt.gamma) * bcu_loss.item()

        n_samples = inputs.size(0)
        acc1, acc5 = accuracy(logit_s_r, target, topk=(1, 5))
        losses.update(total_loss.item(), n_samples)
        bcu_losses.update(bcu_loss.item(), n_samples)
        top1.update(acc1[0], n_samples)
        top5.update(acc5[0], n_samples)
        bcu_accuracy.update(bcu_acc, n_samples)

        # ===================backward=====================
        optimizer.zero_grad()
        total_loss.backward()
        optimizer.step()

        # ===================meters=====================
        batch_time.update(time.time() - tick)
        tick = time.time()

        # print info
        if step % opt.print_freq == 0:
            progress = ('Epoch: [{0}][{1}/{2}]\t'
                        'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                        'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                        'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                        'BCULoss {loss2.val:.4f} ({loss2.avg:.4f})\t'
                        'Acc@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                        'Acc@5 {top5.val:.3f} ({top5.avg:.3f})\t'
                        'BCUAcc {bcu_accuracy.val:.3f} ({bcu_accuracy.avg:.3f})')
            print(progress.format(
                epoch, step, len(train_loader), batch_time=batch_time,
                data_time=data_time, loss=losses, loss2=bcu_losses,
                top1=top1, top5=top5, bcu_accuracy=bcu_accuracy))
            sys.stdout.flush()

    summary = ' * Acc@1 {top1.avg:.3f} Acc@5 {top5.avg:.3f}'
    print(summary.format(top1=top1, top5=top5))

    return top1.avg, losses.avg, bcu_losses.avg, bcu_accuracy.avg

def train_bcu_distill(epoch, train_loader, delete_loader, module_list, criterion_list, optimizer, bin_cls_optimizer, opt):
    """Run one epoch of distillation training with an auxiliary BCU criterion.

    Every batch of ``train_loader`` is paired with a batch drawn from an
    endlessly cycled ``delete_loader``; the epoch ends when ``train_loader``
    is exhausted.

    Args:
        epoch: epoch number, used only in the progress line.
        train_loader: iterable of ``(input, target)`` batches.
        delete_loader: iterable of ``(del_input, del_target)`` batches.
        module_list: ``module_list[0]`` is the student (put in train mode) and
            ``module_list[-1]`` the teacher (put in eval mode). For
            ``opt.distill == 'abound'`` / ``'factor'`` the module at index
            2 / 3 is put in eval mode as well.
        criterion_list: ``[0]`` classification criterion, ``[1]`` divergence
            criterion between student and teacher logits, ``[-1]`` BCU
            criterion returning ``(loss, accuracy)``.
        optimizer: optimizer stepped on the combined loss.
        bin_cls_optimizer: optimizer stepped on the BCU loss alone.
        opt: options object; reads ``distill``, ``bcu_vec``, ``gamma``,
            ``alpha``, ``beta`` and ``print_freq``.

    Returns:
        Tuple ``(top1.avg, losses.avg, bcu_losses.avg, bcu_accuracy.avg)``.
    """
    model_s = module_list[0]
    model_t = module_list[-1]
    model_s.train()
    model_t.eval()

    if opt.distill == 'abound':
        module_list[2].eval()
    elif opt.distill == 'factor':
        module_list[3].eval()

    criterion_cls, criterion_div = criterion_list[0], criterion_list[1]
    criterion_bcu = criterion_list[-1]

    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    bcu_losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()
    bcu_accuracy = AverageMeter()

    progress_line = ('Epoch: [{0}][{1}/{2}]\t'
                     'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                     'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                     'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                     'BCULoss {bcu_losses.val:.4f} ({bcu_losses.avg:.4f})\t'
                     'Acc@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                     'Acc@5 {top5.val:.3f} ({top5.avg:.3f})\t'
                     'BCUAcc {bcu_accuracy.val:.3f} ({bcu_accuracy.avg:.3f})')

    end = time.time()
    paired_batches = zip(train_loader, cycle(delete_loader))
    for step, ((x, y), (del_x, del_y)) in enumerate(paired_batches):
        data_time.update(time.time() - end)

        x = x.float()
        del_x = del_x.float()
        if torch.cuda.is_available():
            x = x.cuda()
            del_x = del_x.cuda()
            y = y.cuda()
            # NOTE(review): del_y is moved to the device but not used by any
            # criterion in this function.
            del_y = del_y.cuda()
        batch_size = x.size(0)

        # ===================forward=====================
        feat_t_r, logit_t_r = model_t(x, is_feat=True)
        feat_t_d, logit_t_d = model_t(del_x, is_feat=True)
        feat_s_r, logit_s_r = model_s(x, is_feat=True)
        feat_s_d, logit_s_d = model_s(del_x, is_feat=True)

        # Last entry of each feature list, in (student-train, student-delete,
        # teacher-train, teacher-delete) order.
        last_feats = (feat_s_r[-1], feat_s_d[-1], feat_t_r[-1], feat_t_d[-1])

        loss_cls = criterion_cls(logit_s_r, y)
        if opt.bcu_vec == "logits":
            bcu_inputs = (logit_s_r, logit_s_d, logit_t_r, logit_t_d)
        else:
            bcu_inputs = last_feats
        loss_bc, bcu_acc = criterion_bcu(*bcu_inputs)

        loss_div = criterion_div(logit_s_r, logit_t_r)

        # First backward pass: BCU loss only. retain_graph=True keeps the
        # forward graph alive for the second backward pass below.
        bin_cls_optimizer.zero_grad()
        loss_bc.backward(retain_graph=True)
        bin_cls_optimizer.step()

        # NOTE(review): the BCU term is added as a plain Python number
        # (.item()), so it shifts the reported loss value but contributes no
        # gradient to loss.backward() - confirm this is intended.
        bc_value = loss_bc.item()
        loss = opt.gamma * loss_cls + opt.alpha * loss_div + opt.beta * bc_value

        acc1, acc5 = accuracy(logit_s_r, y, topk=(1, 5))
        losses.update(loss.item(), batch_size)
        bcu_losses.update(bc_value, batch_size)
        top1.update(acc1[0], batch_size)
        top5.update(acc5[0], batch_size)
        bcu_accuracy.update(bcu_acc, batch_size)

        # ===================backward=====================
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # ===================meters=====================
        batch_time.update(time.time() - end)
        end = time.time()

        # print info
        if step % opt.print_freq == 0:
            print(progress_line.format(
                epoch, step, len(train_loader), batch_time=batch_time,
                data_time=data_time, loss=losses, bcu_losses=bcu_losses,
                top1=top1, top5=top5, bcu_accuracy=bcu_accuracy))
            sys.stdout.flush()

    print(' * Acc@1 {top1.avg:.3f} Acc@5 {top5.avg:.3f}'
          .format(top1=top1, top5=top5))

    return top1.avg, losses.avg, bcu_losses.avg, bcu_accuracy.avg

def validate(val_loader, model, criterion, opt, quiet=False, trans=False):
    """Evaluate ``model`` over ``val_loader`` with gradients disabled.

    Args:
        val_loader: iterable of ``(input, target)`` batches. With ``trans``
            set, ``input`` is handled as a dict of tensors passed to the model
            as keyword arguments; otherwise it is a single tensor.
        model: network to evaluate; switched to eval mode.
        criterion: called as ``criterion(output, target)``.
        opt: options object; ``print_freq`` is read unless ``quiet`` is set.
        quiet: when true, nothing is printed.
        trans: selects the dict-input code path described above.

    Returns:
        Tuple ``(top1.avg, top5.avg, losses.avg)``.

    NOTE(review): the second accuracy is requested with ``topk=(1, 2)``, so
    the value tracked and printed as "Acc@5" is presumably top-2 accuracy -
    confirm this is intended.
    """
    # Dict entries that are cast back to integer type after the float cast.
    integer_keys = ("input_ids", "token_type_ids", "labels")
    progress_line = ('Test: [{0}/{1}]\t'
                     'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                     'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                     'Acc@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                     'Acc@5 {top5.val:.3f} ({top5.avg:.3f})')

    batch_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()

    # switch to evaluate mode
    model.eval()

    with torch.no_grad():
        end = time.time()
        for step, (batch, labels) in enumerate(val_loader):
            if trans:
                # Everything goes through float first; index-like entries are
                # then converted to long.
                batch = {
                    name: (tensor.float().long() if name in integer_keys
                           else tensor.float())
                    for name, tensor in batch.items()
                }
            else:
                batch = batch.float()

            if torch.cuda.is_available():
                if trans:
                    batch = {name: tensor.cuda() for name, tensor in batch.items()}
                else:
                    batch = batch.cuda()
                labels = labels.cuda()

            # compute output
            output = model(**batch) if trans else model(batch)
            loss = criterion(output, labels)

            # measure accuracy and record loss
            n_samples = labels.size(0)
            acc1, acc5 = accuracy(output, labels, topk=(1, 2))
            losses.update(loss.item(), n_samples)
            top1.update(acc1[0], n_samples)
            top5.update(acc5[0], n_samples)

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            if not quiet and step % opt.print_freq == 0:
                print(progress_line.format(
                    step, len(val_loader), batch_time=batch_time, loss=losses,
                    top1=top1, top5=top5))

        if not quiet:
            print(' * Acc@1 {top1.avg:.3f} Acc@5 {top5.avg:.3f}'
                  .format(top1=top1, top5=top5))

    return top1.avg, top5.avg, losses.avg
