# Main script: Top-Down pruning (CIFAR CIL experiments).

import os
import torch
import random 
import pickle
import argparse
import numpy as np 
import torch.optim
import torch.nn as nn
import torch.utils.data
import torch.nn.functional as F
from torch.autograd import Variable
import torch.backends.cudnn as cudnn
import torchvision.datasets as datasets
import torchvision.transforms as transforms

from model import PreActResNet18 as ResNet18  
from utils import *

# Command-line interface. Options are grouped by the stage of the pipeline
# that consumes them; main() parses them into the module-level ``args``.
parser = argparse.ArgumentParser(description='PyTorch Cifar10_100 CIL Top-Down pruning')

#################### base setting #########################
parser.add_argument('--data_dir', help='The directory for data', default='CIL_data', type=str)
parser.add_argument('--dataset', type=str, default='cifar10', help='default dataset')
parser.add_argument('--save_dir', help='The directory used to save the trained models', default='TD_cifar10', type=str)
parser.add_argument('--save_data_path', help='The directory used to save the data', default='TD_cifar10/data', type=str)
parser.add_argument('--print_freq', default=50, type=int, help='print frequency')
parser.add_argument('--gpu', type=int, default=0, help='gpu device id')
parser.add_argument('--seed', type=int, default=None, help='random seed')

################## training setting ###########################
parser.add_argument('--epochs', default=100, type=int, help='number of total epochs to run')
parser.add_argument('--lr', default=0.01, type=float, help='initial learning rate')
# Comma-separated epoch milestones handed to MultiStepLR in main().
parser.add_argument('--decreasing_lr', default='60,80', help='decreasing strategy')
parser.add_argument('--momentum', default=0.9, type=float, help='momentum')
parser.add_argument('--weight_decay', default=5e-4, type=float, help='weight decay')

################## CIL setting ##################################
parser.add_argument('--classes_per_classifier', type=int, default=2, help='number of classes per classifier')
parser.add_argument('--classifiers', type=int, default=5, help='number of classifiers')
parser.add_argument('--unlabel_num', type=int, default=45, help='number of unlabel images')

################## pruning setting ##################################
# Used as the inner-loop epoch count after every iterative pruning step in
# main(); the previous help text was a copy of the one for --epochs.
parser.add_argument('--iter_epochs', default=30, type=int, help='number of training epochs after each iterative pruning step')
parser.add_argument('--percent', default=0.2, type=float, help='pruning rate')
# main() recognises 'best', 'zero' and 'rand'; any other value skips rewinding.
parser.add_argument('--rewind', type=str, default='zero', help='rewind_type')
# One comma-separated entry per classifier (main() checks the count).
parser.add_argument('--flag', default='1,1,2,3,5', help='pruning times for each task')


# Best validation accuracy seen so far for the task currently being trained;
# main() updates it through a ``global`` declaration.
best_prec1 = 0

def _make_loader(dataset, batch_size, shuffle):
    """Wrap *dataset* in a DataLoader with the worker settings used throughout this script."""
    return torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size, shuffle=shuffle,
        num_workers=2, pin_memory=True)


def _evaluate_tasks(model, criterion, test_path, val_trans, sequence, class_per_state, num_tasks, fc_num):
    """Validate *model* on the test data of each of the first *num_tasks* tasks.

    Every task is validated twice with ``out_idx=fc_num``: once with
    ``if_main=True`` (printed as TA_balance) and once with ``if_main=False``
    (printed as TA_random).

    Returns:
        list with the ``if_main=True`` accuracy of each task, in task order.
    """
    bal_acc = []
    for test_iter in range(num_tasks):
        task_classes = sequence[test_iter*class_per_state:(test_iter+1)*class_per_state]
        val_dataset_test = Labeled_dataset(test_path, val_trans, task_classes, offset=test_iter*class_per_state)
        val_loader_test = _make_loader(val_dataset_test, 256, False)

        print('dataset'+str(test_iter+1), 'classes = ', task_classes)
        ta_bal = validate(val_loader_test, model, criterion, fc_num=fc_num, if_main=True)
        ta_rand = validate(val_loader_test, model, criterion, fc_num=fc_num, if_main=False)

        bal_acc.append(ta_bal)
        print('TA_balance = ', ta_bal)
        print('TA_random = ', ta_rand)
    return bal_acc


def _build_lwf_loaders(train_path, old_img_path, save_data_path, train_trans, sequence, class_per_state, state):
    """Build the four training loaders consumed by train_lwf() for task index *state* (0-based).

    Returns:
        ``(random_loader, balance_new_loader, balance_old_loader, unlabel_loader)``.
        The new/old balanced batch sizes split 64 in the ratio 1 : state.
    """
    train_dataset = Labeled_dataset(train_path, train_trans, sequence[state*class_per_state:(state+1)*class_per_state], offset=state*class_per_state)
    train_old_dataset = Labeled_dataset(old_img_path, train_trans, sequence[:state*class_per_state], offset=0)
    train_random_dataset = torch.utils.data.dataset.ConcatDataset((train_dataset,train_old_dataset))
    unlabel_dataset = unlabel_feature_dataset(os.path.join(save_data_path,'task'+str(state)))

    train_loader_random = _make_loader(train_random_dataset, 64, True)
    train_loader_balance_new = _make_loader(train_dataset, int(64/(1+state)), True)
    train_loader_balance_old = _make_loader(train_old_dataset, int(64*state/(1+state)), True)
    unlabel_loader = _make_loader(unlabel_dataset, 64, True)
    return train_loader_random, train_loader_balance_new, train_loader_balance_old, unlabel_loader


def main():
    """Run the whole Top-Down pruning pipeline over all incremental tasks.

    Steps, in order:
      1. Parse the command line into the global ``args`` and save the initial
         weights as ``task0_checkpoint_weight.pt``.
      2. Train task 1 with train(), checkpointing every epoch.
      3. For every state ``1..classifiers``: load that state's best model,
         evaluate all tasks seen so far, select unlabeled images with
         select_knn() and label them with label_extract(), prune according to
         ``--flag``, rewind according to ``--rewind``, evaluate again, then
         train the next task with train_lwf(). In the last state there is no
         next task, so the last task is re-trained and stored under index
         ``classifiers + 1``.
      4. Load the final best model and evaluate it on every task.

    ``--flag`` entry ``k`` for a state means ``k - 1`` prune-then-train rounds
    followed by one final pruning; ``k = 0`` disables pruning for that state.

    Raises:
        ValueError: if ``--flag`` does not contain one entry per classifier,
            or if ``--dataset`` is not ``'cifar10'``.
    """

    global args, best_prec1
    args = parser.parse_args()
    print(args)

    #pre-define pruning schedule
    prune_steps = [int(x)-1 for x in args.flag.split(',')]
    # Explicit check instead of ``assert`` so it also runs under ``python -O``.
    if len(prune_steps) != args.classifiers:
        raise ValueError('--flag needs {} comma-separated entries (one per classifier), got {}'.format(
            args.classifiers, len(prune_steps)))
    pruning_flag = False 
    prune_stage = 0

    decreasing_lr = [int(x) for x in args.decreasing_lr.split(',')]

    if args.dataset == 'cifar10':

        normalize = transforms.Normalize(mean=[0.4914, 0.4822, 0.4465], std=[0.2023, 0.1994, 0.2010])
        # dataset transform
        train_trans = transforms.Compose([
                transforms.ToPILImage(),
                transforms.RandomHorizontalFlip(),
                transforms.RandomCrop(32, 4),
                transforms.ToTensor(),
                normalize
            ])

        val_trans = transforms.Compose([
                transforms.ToPILImage(),
                transforms.ToTensor(),
                normalize
            ])

        path_head = args.data_dir
        train_path = os.path.join(path_head,'cifar10_train.pkl')
        val_path = os.path.join(path_head,'cifar10_val.pkl')
        old_img_path = os.path.join(path_head,'cifar10_save_100.pkl') # saved images for learned tasks
        test_path = os.path.join(path_head,'cifar10_test.pkl')
        unlabel_path = os.path.join(path_head,'cifar10_80m_150k.pkl')

        # fixed order in which the classes are introduced task by task
        sequence = [9,8,7,1,5,0,3,4,6,2]

        ## with another random task sequence
        # if os.path.isfile('cifar10_class_order.txt'):
        #     sequence = np.loadtxt('cifar10_class_order.txt')
        # else:
        #     sequence = np.random.permutation(10)
        #     np.savetxt('cifar10_class_order.txt', sequence)
        # print('cifar10 incremental task sequence:', sequence)

    else:
        raise ValueError('do not support dataset of '+args.dataset)

    all_states = args.classifiers
    class_per_state = args.classes_per_classifier

    torch.cuda.set_device(int(args.gpu))
    device = torch.device('cuda:'+str(args.gpu))

    # ``is not None`` so that an explicit ``--seed 0`` is honoured as well.
    if args.seed is not None:
        setup_seed(args.seed)

    os.makedirs(args.save_dir, exist_ok=True)
    os.makedirs(args.save_data_path, exist_ok=True)

    #setup logger
    log_result = Logger(os.path.join(args.save_dir, 'log_results.txt'))
    name_list = ['Task{}'.format(i+1) for i in range(all_states)]
    name_list.append('Mean Acc')
    log_result.append(['current state = 1'])
    log_result.append(name_list)

    criterion = nn.CrossEntropyLoss()

    model = ResNet18(num_classes_per_classifier=class_per_state, num_classifier=all_states)
    model.cuda()
    print('*****************************************************************************')
    print('start training task1 (only the first classifier is trained)')
    print('*****************************************************************************')

    # initial weights, used later by the 'zero' rewinding mode
    torch.save({
        'state_dict': model.state_dict(),
    }, os.path.join(args.save_dir, 'task0_checkpoint_weight.pt'))

    #dataset 
    train_dataset = Labeled_dataset(train_path, train_trans, sequence[:class_per_state], offset=0)
    val_dataset = Labeled_dataset(val_path , val_trans, sequence[:class_per_state], offset=0)

    train_loader = _make_loader(train_dataset, 256, True)
    val_loader = _make_loader(val_dataset, 256, False)

    optimizer = torch.optim.SGD(model.parameters(), args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)
    scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=decreasing_lr, gamma=0.1)

    for epoch in range(args.epochs):
        print("The learning rate is {}".format(optimizer.param_groups[0]['lr']))

        train(train_loader, model, criterion, optimizer, epoch)

        prec1 = validate(val_loader, model, criterion, fc_num=1, if_main=True)

        scheduler.step()

        # remember best prec@1 and save checkpoint
        is_best = prec1 > best_prec1
        best_prec1 = max(prec1, best_prec1)

        save_checkpoint({
            'epoch': epoch + 1,
            'state_dict': model.state_dict(),
            'best_prec1': best_prec1,
            'optimizer': optimizer,
        }, is_best, args.save_dir, filename='task1_checkpoint.pt', best_name='task1_best_model.pt')

    for current_state in range(1, all_states+1):

        best_prec1 = 0
        model = ResNet18(num_classes_per_classifier=class_per_state, num_classifier=all_states)
        model.cuda()
        model_path = os.path.join(args.save_dir, 'task'+str(current_state)+'_best_model.pt')
        new_dict = torch.load(model_path, map_location=device)
        
        if pruning_flag:
            # re-create the pruning masks on the fresh model so the pruned
            # checkpoint can be loaded into it
            print('pruning with custom mask')
            current_mask = extract_mask(new_dict['state_dict'])
            prune_model_custom(model, current_mask)

        remain_weight = check_sparsity(model)
        model.load_state_dict(new_dict['state_dict'])
        
        print('*****************************************************************************')
        print('start training task'+str(current_state+1))
        print('best epoch', new_dict['epoch'])
        print('model loaded', model_path)
        # NOTE(review): this prints the raw check_sparsity() value, while the
        # log file below and the final report use 100 - value under the same
        # label -- confirm which one is intended.
        print('remain weight size = {}'.format(remain_weight))
        print('*****************************************************************************')


        # testing accuracy & generate feature of unlabeled data using original model
        bal_acc = _evaluate_tasks(model, criterion, test_path, val_trans, sequence, class_per_state,
                                  num_tasks=current_state, fc_num=current_state)
        log_acc = ['None' for i in range(all_states+1)]
        log_acc[:len(bal_acc)] = bal_acc

        mean_acc = np.mean(np.array(bal_acc))
        log_acc[-1] = mean_acc
        print('******************************************************')
        print('mean_acc = ', mean_acc, current_state)
        print('******************************************************')
        log_result.append(log_acc)
        log_result.append(['remain weight size = {}'.format(100-remain_weight)])
        log_result.append(['*'*50])
        log_result.append(['current state = {}'.format(current_state+1)])
        log_result.append(name_list)

        old_dataset =  Labeled_dataset(old_img_path, val_trans, sequence[:current_state*class_per_state], offset=0)
        k15_dataset = k150_dataset(unlabel_path, val_trans)
        old_loader = _make_loader(old_dataset, 256, False)
        k15_loader = _make_loader(k15_dataset, 256, False)

        all_feature = feature_extract_old(k15_loader, model, criterion)
        target_feature = feature_extract_old(old_loader, model, criterion)

        index_select = select_knn(target_feature, all_feature, args.unlabel_num)
        np.save(os.path.join(args.save_data_path,'task'+str(current_state)+'_select_index.npy'), np.array(index_select))
        # NOTE: pickle.load can execute arbitrary code; only point --data_dir
        # at trusted files.
        with open(unlabel_path,'rb') as unlabel_file:
            all_data = pickle.load(unlabel_file)
        all_image = all_data['data']
        all_label = all_data['label']
        new_data = {}
        new_data['data'] = all_image[index_select,:,:,:]
        new_data['label'] = all_label[index_select]
        print(new_data['data'].shape)
        print(new_data['label'].shape)
        selected_path = os.path.join(args.save_data_path,'selected_unlabel_task'+str(current_state)+'.pkl')
        # The file is read back right below, so it must be flushed and closed
        # here rather than whenever the file object gets garbage-collected.
        with open(selected_path, 'wb') as selected_file:
            pickle.dump(new_data, selected_file)
        
        extract_dataset = k150_dataset(selected_path, train_trans)
        extract_loader = _make_loader(extract_dataset, 256, True)
        label_extract(extract_loader, model, criterion, args.save_data_path, fc_num=current_state)


        #pruning stage (loading last epoch model from last task) 
        last_checkpoint_weight = torch.load(os.path.join(args.save_dir, 'task'+str(current_state)+'_checkpoint.pt'), map_location=device)
        model.load_state_dict(last_checkpoint_weight['state_dict'])

        num_prune_rounds = prune_steps[prune_stage]

        if num_prune_rounds > 0:

            if current_state == 1:
                # only task 1 has been learned: plain supervised training
                train_dataset = Labeled_dataset(train_path, train_trans, sequence[:class_per_state], offset=0)
                train_loader = _make_loader(train_dataset, 256, True)
            else:
                # re-use the data setup of the task that was just trained
                # (0-based index current_state - 1)
                prune_state = current_state - 1
                (train_loader_random, train_loader_balance_new,
                 train_loader_balance_old, unlabel_loader) = _build_lwf_loaders(
                    train_path, old_img_path, args.save_data_path, train_trans,
                    sequence, class_per_state, prune_state)

            # constant, 100x smaller learning rate between pruning steps
            optimizer = torch.optim.SGD(model.parameters(), args.lr/100,
                                        momentum=args.momentum,
                                        weight_decay=args.weight_decay)

            for prune_iter in range(num_prune_rounds):
                print('starting pruning',num_prune_rounds)
                pruning_model(model,args.percent)
                check_sparsity(model)
                pruning_flag = True

                for epoch in range(args.iter_epochs):
                    print("The learning rate is {}".format(optimizer.param_groups[0]['lr']))
                    if current_state == 1:
                        train(train_loader, model, criterion, optimizer, epoch)
                    else:
                        train_lwf(train_loader_random, train_loader_balance_new, train_loader_balance_old, unlabel_loader, model, criterion, optimizer, epoch, prune_state+1)
            
        # final pruning of this state; skipped only when the --flag entry is 0
        if num_prune_rounds >= 0:
            pruning_model(model,args.percent)
            check_sparsity(model)
            pruning_flag = True
            
        prune_stage += 1

        #rewinding stage
        # Any --rewind value other than 'best' / 'zero' / 'rand' leaves the
        # weights untouched.
        rewind_dict = None
        if args.rewind == 'best':
            print('rewind best weight')
            model_path = os.path.join(args.save_dir, 'task'+str(current_state)+'_best_model.pt')
            rewind_dict = torch.load(model_path, map_location=device)['state_dict']

        elif args.rewind == 'zero':
            print('rewind zero weight')
            model_path = os.path.join(args.save_dir, 'task0_checkpoint_weight.pt')
            rewind_dict = torch.load(model_path, map_location=device)['state_dict']

        elif args.rewind == 'rand':
            random_seed = np.random.randint(1000)
            setup_seed(random_seed)
            print('seed = ', random_seed)
            random_model = ResNet18(num_classes_per_classifier=class_per_state, num_classifier=all_states)
            rewind_dict = random_model.state_dict()

        if rewind_dict is not None:
            weight_orig_dict = rewind(model, rewind_dict, pruning_flag)
            model.load_state_dict(weight_orig_dict, strict=False)

        check_sparsity(model)
        

        # test after rewinding
        bal_acc = _evaluate_tasks(model, criterion, test_path, val_trans, sequence, class_per_state,
                                  num_tasks=current_state, fc_num=current_state)
        print('******************************************************')
        print('prun_mean_acc = ', np.mean(np.array(bal_acc)), current_state)
        print('******************************************************')


        # In the last state there is no new task left, so the last task is
        # trained once more on the pruned network. Either way the result is
        # stored under index current_state + 1.
        if current_state == all_states:
            print('re-train task{}'.format(current_state))
            train_state = current_state - 1
        else:
            train_state = current_state
        save_state = current_state + 1

        # training stage
        (train_loader_random, train_loader_balance_new,
         train_loader_balance_old, unlabel_loader) = _build_lwf_loaders(
            train_path, old_img_path, args.save_data_path, train_trans,
            sequence, class_per_state, train_state)

        val_dataset = Labeled_dataset(val_path, val_trans, sequence[:(train_state+1)*class_per_state], offset=0)
        val_loader = _make_loader(val_dataset, 256, False)

        optimizer = torch.optim.SGD(model.parameters(), args.lr,
                                    momentum=args.momentum,
                                    weight_decay=args.weight_decay)
        scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=decreasing_lr, gamma=0.1)

        for epoch in range(args.epochs):
            print("The learning rate is {}".format(optimizer.param_groups[0]['lr']))

            train_lwf(train_loader_random, train_loader_balance_new, train_loader_balance_old, unlabel_loader, model, criterion, optimizer, epoch, train_state+1)

            prec1 = validate(val_loader, model, criterion, fc_num=train_state+1, if_main=True)

            scheduler.step()

            # remember best prec@1 and save checkpoint
            is_best = prec1 > best_prec1
            best_prec1 = max(prec1, best_prec1)

            save_checkpoint({
                'epoch': epoch + 1,
                'state_dict': model.state_dict(),
                'best_prec1': best_prec1,
                'optimizer': optimizer,
            }, is_best, args.save_dir, filename='task{}_checkpoint.pt'.format(save_state), best_name='task{}_best_model.pt'.format(save_state))

    # test
    current_state = all_states+1
    model = ResNet18(num_classes_per_classifier=class_per_state, num_classifier=all_states)
    model.cuda()    
    model_path = os.path.join(args.save_dir, 'task'+str(current_state)+'_best_model.pt')
    new_dict = torch.load(model_path, map_location=device)

    if pruning_flag:
        print('pruning with custom mask')
        current_mask = extract_mask(new_dict['state_dict'])
        prune_model_custom(model, current_mask)

    model.load_state_dict(new_dict['state_dict'])
    remain_weight = check_sparsity(model)

    print('*****************************************************************************')
    print('start testing task'+str(current_state))
    print('remain weight size = {}'.format(100-remain_weight))
    print('*****************************************************************************')


    # testing accuracy of the final model on every task
    bal_acc = _evaluate_tasks(model, criterion, test_path, val_trans, sequence, class_per_state,
                              num_tasks=all_states, fc_num=all_states)
    log_acc = ['None' for i in range(all_states+1)]
    log_acc[:len(bal_acc)] = bal_acc

    mean_acc = np.mean(np.array(bal_acc))
    log_acc[-1] = mean_acc
    print('******************************************************')
    print('mean_acc = ', mean_acc, current_state)
    print('******************************************************')
    log_result.append(log_acc)
    log_result.append(['remain weight size = {}'.format(100-remain_weight)])
    log_result.append(['*'*50])



def train(train_loader, model, criterion, optimizer, epoch):
    """Train *model* for one epoch on *train_loader* using output index 1.

    Every batch is forwarded twice with ``out_idx=1`` -- first with
    ``main_fc=False``, then with ``main_fc=True`` -- and the two criterion
    losses are summed before a single optimizer step. Accuracy is measured on
    the ``main_fc=True`` output. Progress is printed every ``args.print_freq``
    batches (``args`` is the module-level global set by main()).

    Returns:
        Running average of the top-1 accuracy over the epoch.
    """
    loss_meter = AverageMeter()
    acc_meter = AverageMeter()

    # switch to train mode
    model.train()

    for step, (images, labels) in enumerate(train_loader):

        images = images.cuda()
        labels = labels.long().cuda()

        optimizer.zero_grad()

        # Forward order is kept (main_fc=False first) because each pass in
        # train mode may update layer statistics.
        logits_rand = model(x=images, out_idx=1, main_fc=False)
        loss_rand = criterion(logits_rand, labels)

        logits_main = model(x=images, out_idx=1, main_fc=True)
        loss_balance = criterion(logits_main, labels)

        total_loss = loss_rand+loss_balance

        total_loss.backward()
        optimizer.step()

        # measure accuracy and record loss
        batch_count = images.size(0)
        prec1 = accuracy(logits_main.float().data, labels)[0]
        loss_meter.update(total_loss.float().item(), batch_count)
        acc_meter.update(prec1.item(), batch_count)

        if step % args.print_freq != 0:
            continue

        print('Epoch: [{0}][{1}/{2}]\t'
            'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
            'Accuracy {top1.val:.3f} ({top1.avg:.3f})'.format(
                epoch, step, len(train_loader), loss=loss_meter, top1=acc_meter))

    print('train_accuracy {top1.avg:.3f}'.format(top1=acc_meter))

    return acc_meter.avg

def validate(val_loader, model, criterion, fc_num=1, if_main=False):
    """Evaluate *model* on *val_loader* and return the average top-1 accuracy.

    Args:
        val_loader: iterable of ``(input, target)`` batches.
        model: network called as ``model(x=..., out_idx=fc_num, main_fc=if_main)``.
        criterion: loss used for the reported loss value.
        fc_num: forwarded to the model as ``out_idx``.
        if_main: forwarded to the model as ``main_fc``.

    Progress is printed every ``args.print_freq`` batches (``args`` is the
    module-level global set by main()).
    """
    loss_meter = AverageMeter()
    acc_meter = AverageMeter()

    # switch to evaluate mode
    model.eval()

    for step, (images, labels) in enumerate(val_loader):
        images = images.cuda()
        labels = labels.long().cuda()

        # compute output without building an autograd graph
        with torch.no_grad():
            logits = model(x=images, out_idx=fc_num, main_fc=if_main)
            batch_loss = criterion(logits, labels)

        # measure accuracy and record loss
        batch_count = images.size(0)
        prec1 = accuracy(logits.float().data, labels)[0]
        loss_meter.update(batch_loss.float().item(), batch_count)
        acc_meter.update(prec1.item(), batch_count)

        if step % args.print_freq != 0:
            continue

        print('Test: [{0}/{1}]\t'
            'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
            'Accuracy {top1.val:.3f} ({top1.avg:.3f})'.format(
                step, len(val_loader), loss=loss_meter, top1=acc_meter))

    print('valid_accuracy {top1.avg:.3f}'.format(top1=acc_meter))

    return acc_meter.avg

def _next_or_restart(batch_iter, loader):
    """Return ``(batch, iterator)``; the iterator is re-created from *loader* once exhausted."""
    try:
        batch = next(batch_iter)
    except StopIteration:
        batch_iter = iter(loader)
        batch = next(batch_iter)
    return batch, batch_iter


def train_lwf(rand_loader, new_balance_loader, old_balance_loader, unlabel_loader, model, criterion, optimizer, epoch, fc_num):
    """Train *model* for one epoch on labeled data plus distillation on unlabeled data.

    One epoch is one pass over *rand_loader*; the three auxiliary loaders are
    cycled (restarted whenever they run out). Per step the loss is::

        (CE_rand + KD_rand) + 0.5 * (CE_balance + KD_balance)

    * ``CE_rand``: criterion on the *rand_loader* batch, ``out_idx=fc_num``,
      ``main_fc=False``.
    * ``CE_balance``: criterion on the new and old balanced batches
      (``main_fc=True``), weighted by ``coef_new`` / ``coef_old``.
    * ``KD_*``: ``loss_fn_kd(..., T=2)`` between the ``out_idx=fc_num-1``
      outputs on the unlabeled batch and the targets stored with that batch.

    Accuracy is measured on the ``main_fc=False`` output for the
    *rand_loader* batch. Progress is printed every ``args.print_freq``
    batches (``args`` is the module-level global set by main()).

    Returns:
        Running average of the top-1 accuracy over the epoch.
    """
    loss_meter = AverageMeter()
    acc_meter = AverageMeter()

    # Same integer split of 64 that main() uses for the balanced loaders'
    # batch sizes, expressed as fractions of 64.
    coef_old = int(64*(fc_num-1)/fc_num)/64
    coef_new = int(64/fc_num)/64

    # switch to train mode
    model.train()

    # Iterator creation order is kept: new, old, unlabeled, then rand_loader.
    new_balance = iter(new_balance_loader)
    old_balance = iter(old_balance_loader)
    unlabel = iter(unlabel_loader)

    for step, (images, labels) in enumerate(rand_loader):

        (bal_new_img, bal_new_target), new_balance = _next_or_restart(new_balance, new_balance_loader)
        (bal_old_img, bal_old_target), old_balance = _next_or_restart(old_balance, old_balance_loader)
        (unlab_img, unlab_target, unlab_target_main), unlabel = _next_or_restart(unlabel, unlabel_loader)

        # move inputs to the GPU
        bal_new_img = bal_new_img.cuda()
        bal_old_img = bal_old_img.cuda()
        unlab_img = unlab_img.cuda()
        images = images.cuda()

        # hard labels are cast to long; distillation targets keep their dtype
        bal_new_target = bal_new_target.long().cuda()
        bal_old_target = bal_old_target.long().cuda()
        labels = labels.long().cuda()

        unlab_target = unlab_target.cuda()
        unlab_target_main = unlab_target_main.cuda()

        optimizer.zero_grad()

        # random input
        logits_rand = model(x=images, out_idx=fc_num, main_fc=False)
        loss_rand = criterion(logits_rand, labels)

        # balance inputs
        logits_bal_new = model(x=bal_new_img, out_idx=fc_num, main_fc=True)
        logits_bal_old = model(x=bal_old_img, out_idx=fc_num, main_fc=True)
        loss_balance = criterion(logits_bal_new, bal_new_target)*coef_new + criterion(logits_bal_old, bal_old_target)*coef_old

        #unlabel
        unlab_logits_rand = model(x=unlab_img, out_idx=fc_num-1, main_fc=False)
        loss_unlabel_rand = loss_fn_kd(unlab_logits_rand, unlab_target, T=2)
        unlab_logits_bal = model(x=unlab_img, out_idx=fc_num-1, main_fc=True)
        loss_unlabel_balance = loss_fn_kd(unlab_logits_bal, unlab_target_main, T=2)

        all_rand_loss = loss_rand + loss_unlabel_rand
        all_bal_loss = loss_balance + loss_unlabel_balance

        total_loss = all_rand_loss + all_bal_loss*0.5

        total_loss.backward()
        optimizer.step()

        # measure accuracy and record loss
        batch_count = images.size(0)
        prec1 = accuracy(logits_rand.float().data, labels)[0]
        loss_meter.update(total_loss.float().item(), batch_count)
        acc_meter.update(prec1.item(), batch_count)

        if step % args.print_freq != 0:
            continue

        print('Epoch: [{0}][{1}/{2}]\t'
            'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
            'Accuracy {top1.val:.3f} ({top1.avg:.3f})'.format(
                epoch, step, len(rand_loader), loss=loss_meter, top1=acc_meter))

    print('train_accuracy {top1.avg:.3f}'.format(top1=acc_meter))

    return acc_meter.avg


# Run the pipeline only when executed as a script, not when imported.
if __name__ == '__main__':
    main()