import argparse
import os
import torch
import random 
import numpy as np 
import torch.optim
import pickle
import torch.nn as nn
import torch.utils.data
from copy import deepcopy
import torch.nn.functional as F
from torch.autograd import Variable
import torch.backends.cudnn as cudnn
import torchvision.datasets as datasets
import torchvision.transforms as transforms

from utils import *
from model import PreActResNet18 as ResNet18  

# Command-line interface for the class-incremental (CIL) training + iterative
# pruning pipeline implemented in main().
parser = argparse.ArgumentParser(description='PyTorch Cifar10 Training')

#################### base setting #########################
# Directory holding the pickled splits that main() loads
# (cifar10_train.pkl, cifar10_val.pkl, cifar10_test.pkl, cifar10_save_100.pkl,
# cifar10_80m_150k.pkl).
parser.add_argument('--data_dir', help='The directory for data', default='CIL_data', type=str)
# main() only accepts 'cifar10'; any other value aborts.
parser.add_argument('--dataset', type=str, default='cifar10', help='default dataset')
# Checkpoints, masks, logs and all_result.pkl are written here.
parser.add_argument('--save_dir', help='The directory used to save the trained models', default='BU_cifar10', type=str)
# Selected unlabeled-data pickles and selection indices are written here.
parser.add_argument('--save_data_path', help='The directory used to save the data', default='BU_cifar10/data', type=str)
# NOTE(review): not referenced in the visible part of this script; presumably
# consumed by a helper elsewhere -- TODO confirm.
parser.add_argument('--print_freq', default=50, type=int, help='print frequency')
parser.add_argument('--gpu', type=int, default=0, help='gpu device id')
# main() seeds only when this value is truthy, so None (and 0) skip seeding.
parser.add_argument('--seed', type=int, default=None, help='random seed')

################## training setting ###########################
parser.add_argument('--epochs', default=100, type=int, help='number of total epochs to run')
parser.add_argument('--lr', default=0.01, type=float, help='initial learning rate')
# Comma-separated epoch milestones; main() parses them into ints and passes
# them to MultiStepLR (gamma=0.1).
parser.add_argument('--decreasing_lr', default='60,80', help='decreasing strategy')
parser.add_argument('--momentum', default=0.9, type=float, help='momentum')
parser.add_argument('--weight_decay', default=5e-4, type=float, help='weight decay')

################## CIL setting ##################################
# Number of classes introduced per task (one classifier head per task).
parser.add_argument('--classes_per_classifier', type=int, default=2, help='number of classes per classifier')
# Number of classifier heads, i.e. the number of incremental tasks.
parser.add_argument('--classifiers', type=int, default=5, help='number of classifiers')
parser.add_argument('--unlabel_num', type=int, default=45, help='number of unlabel images')

################## pruning setting ##################################
# Fine-tuning epochs run after every pruning step (at lr/100). Previously the
# help text duplicated '--epochs' and misdescribed this option.
parser.add_argument('--iter_epochs', default=30, type=int, help='number of fine-tuning epochs after each pruning step')
# Fraction passed to pruning_model() at every pruning step.
parser.add_argument('--percent', default=0.2, type=float, help='pruning rate')
# A pruned model is kept when (full-model mean acc - pruned-model mean acc)
# is below this value; otherwise main() re-prunes from the full model.
parser.add_argument('--accept_decay', default=1, type=float, help='accuracy decay which can be accepted')
# Pruning keeps going while (baseline acc - best pruned acc) < deacc
# or the sparsity is below --base_sparsity.
parser.add_argument('--deacc', default=2, type=float, help='decay accuracy threshold')
# Note: the task-1 pruning loop in main() breaks only once the pruning count
# exceeds this value.
parser.add_argument('--max_iter_prun', default=26, type=int, help='maximum times for iterative pruning during each task')
# Compared against the zero rate reported by check_sparsity(); presumably a
# percentage (default 90) -- TODO confirm against utils.check_sparsity.
parser.add_argument('--base_sparsity', default=90, type=int, help='base sparsity during iterative pruning to escape the impact of randomness')


# Best validation top-1 seen so far; main() declares it global and updates it.
best_prec1 = 0

def main():

    global args, best_prec1
    args = parser.parse_args()
    print(args)

    decreasing_lr = list(map(int, args.decreasing_lr.split(',')))
    overall_result = {}

    if args.dataset == 'cifar10':

        normalize = transforms.Normalize(mean=[0.4914, 0.4822, 0.4465], std=[0.2023, 0.1994, 0.2010])
        # dataset transform
        train_trans = transforms.Compose([
                transforms.ToPILImage(),
                transforms.RandomHorizontalFlip(),
                transforms.RandomCrop(32, 4),
                transforms.ToTensor(),
                normalize
            ])

        val_trans = transforms.Compose([
                transforms.ToPILImage(),
                transforms.ToTensor(),
                normalize
            ])

        path_head = args.data_dir
        train_path = os.path.join(path_head,'cifar10_train.pkl')
        val_path = os.path.join(path_head,'cifar10_val.pkl')
        old_img_path = os.path.join(path_head,'cifar10_save_100.pkl') # saved images for learned tasks
        test_path = os.path.join(path_head,'cifar10_test.pkl')
        unlabel_path = os.path.join(path_head,'cifar10_80m_150k.pkl')

        sequence = [9,8,7,1,5,0,3,4,6,2]

        ## with another random task sequence
        # if os.path.isfile('cifar10_class_order.txt'):
        #     sequence = np.loadtxt('cifar10_class_order.txt')
        # else:
        #     sequence = np.random.permutation(10)
        #     np.savetxt('cifar10_class_order.txt', sequence)
        # print('cifar10 incremental task sequence:', sequence)

    else:
        print('do not support dataset of '+args.dataset)
        assert 0

    all_states = args.classifiers
    class_per_state = args.classes_per_classifier

    torch.cuda.set_device(int(args.gpu))

    if args.seed:
        setup_seed(args.seed)

    os.makedirs(args.save_dir, exist_ok=True)
    os.makedirs(args.save_data_path, exist_ok=True)

    #setup logger
    log_result = Logger(os.path.join(args.save_dir, 'log_results.txt'))
    name_list = ['Task{}'.format(i+1) for i in range(all_states)]
    name_list.append('Mean Acc')
    log_result.append(['current state = 1'])
    

    criterion = nn.CrossEntropyLoss()

    #define model 
    model = ResNet18(num_classes_per_classifier=class_per_state, num_classifier=all_states)
    model.cuda()

    #if need to retrain base_task 
    print('*****************************************************************************')
    print('start training task1 (only the first classifier is trained)')
    print('*****************************************************************************')

    torch.save({
        'state_dict': model.state_dict(),
    }, os.path.join(args.save_dir, 'task0_checkpoint_weight.pt'))

    #dataset 
    train_dataset = Labeled_dataset(train_path, train_trans, sequence[:class_per_state], offset=0)
    val_dataset = Labeled_dataset(val_path, val_trans, sequence[:class_per_state], offset=0)

    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=256, shuffle=True,
        num_workers=2, pin_memory=True)

    val_loader = torch.utils.data.DataLoader(
        val_dataset,
        batch_size=256, shuffle=False,
        num_workers=2, pin_memory=True)

    optimizer = torch.optim.SGD(model.parameters(), args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)
    scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=decreasing_lr, gamma=0.1)

    train_acc = []
    ta=[]
    for epoch in range(args.epochs):
        print("The learning rate is {}".format(optimizer.param_groups[0]['lr']))

        train_accuracy = train(train_loader, model, criterion, optimizer, epoch)

        prec1 = validate(val_loader, model, criterion, fc_num=1, if_main=True)

        train_acc.append(train_accuracy)
        ta.append(prec1)

        scheduler.step()

        # remember best prec@1 and save checkpoint
        is_best = prec1 > best_prec1
        best_prec1 = max(prec1, best_prec1)

        save_checkpoint({
            'epoch': epoch + 1,
            'state_dict': model.state_dict(),
            'best_prec1': best_prec1,
            'optimizer': optimizer,
        }, is_best, args.save_dir, filename='task1_checkpoint.pt', best_name='task1_best_model.pt')

    # loading best weight to test
    best_weight = torch.load(os.path.join(args.save_dir, 'task1_best_model.pt'), map_location=torch.device('cuda:'+str(args.gpu)))
    model.load_state_dict(best_weight['state_dict'])
    print('loading the best weight of task1 (base_task)')
    print('starting test the accuracy of test_dataset ')

    # testing accuracy & generate feature of unlabeled data using original model
    num_test_dataset = 1
    test_result = np.zeros((2,num_test_dataset))
    log_acc = ['None' for i in range(all_states+1)]
    
    for test_iter in range(num_test_dataset):
        val_dataset_test = Labeled_dataset(test_path, val_trans, sequence[test_iter*class_per_state:(test_iter+1)*class_per_state], offset=test_iter*class_per_state)
        val_loader_test = torch.utils.data.DataLoader(
            val_dataset_test,
            batch_size=256, shuffle=False,
            num_workers=2, pin_memory=True)

        print('dataset'+str(test_iter+1), 'classes = ', sequence[test_iter*class_per_state:(test_iter+1)*class_per_state])
        ta_bal = validate(val_loader_test, model, criterion, fc_num = 1, if_main= True)
        ta_rand = validate(val_loader_test, model, criterion, fc_num = 1, if_main= False)

        print('TA_balance = %.2f'%ta_bal, 'TA_random = %.2f'%ta_rand)
        test_result[0,test_iter] = ta_rand
        test_result[1,test_iter] = ta_bal
        log_acc[test_iter] = ta_bal

    mean_acc = np.mean(test_result[1,test_iter])
    log_acc[-1] = mean_acc
    log_result.append(name_list)
    log_result.append(log_acc)
    
    overall_result['task1'] = test_result
    pickle.dump(overall_result, open(os.path.join(args.save_dir, 'all_result.pkl'),'wb'))

    # generate unlabel softlogits according to best model
    best_weight = torch.load(os.path.join(args.save_dir, 'task1_best_model.pt'), map_location=torch.device('cuda:'+str(args.gpu)))
    model.load_state_dict(best_weight['state_dict'])
    print('loading the best weight of task1 (base_task)')

    print('starting generate unlabel dataset according to task1 best model')
    old_dataset =  Labeled_dataset(old_img_path, val_trans, sequence[:class_per_state], offset=0)
    k15_dataset = k150_dataset(unlabel_path, val_trans)
    old_loader = torch.utils.data.DataLoader(
        old_dataset,
        batch_size=256, shuffle=False,
        num_workers=2, pin_memory=True)

    k15_loader = torch.utils.data.DataLoader(
        k15_dataset,
        batch_size=256, shuffle=False,
        num_workers=2, pin_memory=True)

    all_feature = feature_extract_old(k15_loader, model, criterion)
    target_feature = feature_extract_old(old_loader, model, criterion)

    index_select = select_knn(target_feature, all_feature, args.unlabel_num)
    np.save(os.path.join(args.save_data_path,'task1_select_index.npy'), np.array(index_select))

    all_data = pickle.load(open(unlabel_path,'rb'))
    all_image = all_data['data']
    all_label = all_data['label']
    new_data = {}
    new_data['data'] = all_image[index_select,:,:,:]
    new_data['label'] = all_label[index_select]
    print(new_data['data'].shape)
    print(new_data['label'].shape)
    pickle.dump(new_data, open(os.path.join(args.save_data_path,'selected_unlabel_task1.pkl'), 'wb'))
    
    extract_dataset = k150_dataset(os.path.join(args.save_data_path,'selected_unlabel_task1.pkl'), train_trans)
    extract_loader = torch.utils.data.DataLoader(
        extract_dataset,
        batch_size=256, shuffle=False,
        num_workers=2, pin_memory=True)
    label_extract(extract_loader, model, criterion, args.save_data_path, fc_num=1)

    best_weight = torch.load(os.path.join(args.save_dir, 'task1_best_model.pt'), map_location=torch.device('cuda:'+str(args.gpu)))
    model.load_state_dict(best_weight['state_dict'])
    print('loading the last weight of task1 (checkpoint) (base_task)')

    overall_result = pickle.load(open(os.path.join(args.save_dir, 'all_result.pkl'),'rb'))
    baseline_result = overall_result['task1']
    baseline_acc = np.mean(baseline_result[1,:])
    print(baseline_result)
    print('baseline acc (balance branch) = ', baseline_acc)

    acc_decay = 0
    pruning_times = 0
    zero_rate = 0
    train_dataset = Labeled_dataset(train_path, train_trans, sequence[:class_per_state], offset=0)
    test_dataset = Labeled_dataset(test_path, val_trans, sequence[:class_per_state], offset=0)

    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=256, shuffle=True,
        num_workers=2, pin_memory=True)

    test_loader = torch.utils.data.DataLoader(
        test_dataset,
        batch_size=256, shuffle=False,
        num_workers=2, pin_memory=True)

    log_result.append(['*'*50])
    log_result.append(['Pruning record'])
    log_result.append(['pruning_times', 'best_ta', 'baseline', 'zero_rate'])     

    while(acc_decay < args.deacc or zero_rate < args.base_sparsity):

        best_test_acc = 0

        # maybe share memory
        save_model_dict = deepcopy(model.state_dict())

        print('starting pruning')
        pruning_model(model,args.percent)
        zero_rate = check_sparsity(model)
        pruning_times +=1

        optimizer = torch.optim.SGD(model.parameters(), args.lr/100,
                                    momentum=args.momentum,
                                    weight_decay=args.weight_decay)

        for epoch in range(args.iter_epochs):
            print("The learning rate is {}".format(optimizer.param_groups[0]['lr']))
            train_accuracy = train(train_loader, model, criterion, optimizer, epoch)
            test_acc = validate(test_loader, model, criterion, fc_num=1, if_main=True)

            best_test_acc = max(best_test_acc, test_acc)

        acc_decay = baseline_acc - best_test_acc
        print('********************************************')
        print('pruning_times, best_ta, baseline, zero_rate')
        print(pruning_times, best_test_acc, baseline_acc, zero_rate)
        print('********************************************')

        log_result.append([pruning_times, best_test_acc, baseline_acc, zero_rate])

        if pruning_times > args.max_iter_prun:
            break

    torch.save(save_model_dict, os.path.join(args.save_dir, 'task1_prune_weight.pt'))
    mask_dict = extract_mask(save_model_dict)
    print('*************current mask*******************')
    check_sparsity_mask(mask_dict)
    print('********************************************')
    torch.save(mask_dict, os.path.join(args.save_dir, 'current_mask.pt'))

    print('*****************************************************************************')
    print('start re-training task1')
    print('*****************************************************************************')
    model = ResNet18(num_classes_per_classifier=class_per_state, num_classifier=all_states)
    model.cuda()
    init_weight = torch.load(os.path.join(args.save_dir, 'task0_checkpoint_weight.pt'), map_location=torch.device('cuda:'+str(args.gpu)))['state_dict']
    init_mask = torch.load(os.path.join(args.save_dir, 'current_mask.pt'), map_location=torch.device('cuda:'+str(args.gpu)))
    model.load_state_dict(init_weight)
    prune_model_custom(model, init_mask)
    remain_weight = check_sparsity(model,True)

    #dataset 
    train_dataset = Labeled_dataset(train_path, train_trans, sequence[:class_per_state], offset=0)
    val_dataset = Labeled_dataset(val_path, val_trans, sequence[:class_per_state], offset=0)

    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=256, shuffle=True,
        num_workers=2, pin_memory=True)

    val_loader = torch.utils.data.DataLoader(
        val_dataset,
        batch_size=256, shuffle=False,
        num_workers=2, pin_memory=True)

    optimizer = torch.optim.SGD(model.parameters(), args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)
    scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=decreasing_lr, gamma=0.1)
    
    best_prec1 = 0
    train_acc = []
    ta=[]
    for epoch in range(args.epochs):
        print("The learning rate is {}".format(optimizer.param_groups[0]['lr']))

        train_accuracy = train(train_loader, model, criterion, optimizer, epoch)

        prec1 = validate(val_loader, model, criterion, fc_num=1, if_main=True)

        train_acc.append(train_accuracy)
        ta.append(prec1)

        scheduler.step()

        # remember best prec@1 and save checkpoint
        is_best = prec1 > best_prec1
        best_prec1 = max(prec1, best_prec1)

        save_checkpoint({
            'epoch': epoch + 1,
            'state_dict': model.state_dict(),
            'best_prec1': best_prec1,
            'optimizer': optimizer,
        }, is_best, args.save_dir, filename='re_task1_checkpoint.pt', best_name='re_task1_best_model.pt')

    print('testing after re-train task1')
    # retrain pruned model
    best_weight = torch.load(os.path.join(args.save_dir, 're_task1_best_model.pt'), map_location=torch.device('cuda:'+str(args.gpu)))
    model.load_state_dict(best_weight['state_dict'])
    remain_weight = check_sparsity(model)

    # testing accuracy & generate feature of unlabeled data using original model
    log_acc = ['None' for i in range(all_states+1)]
    test_result = []
    for test_iter in range(1):
        val_dataset_test = Labeled_dataset(test_path, val_trans, sequence[test_iter*class_per_state:(test_iter+1)*class_per_state], offset=test_iter*class_per_state)
        val_loader_test = torch.utils.data.DataLoader(
            val_dataset_test,
            batch_size=256, shuffle=False,
            num_workers=2, pin_memory=True)

        print('dataset'+str(test_iter+1), 'classes = ', sequence[test_iter*class_per_state:(test_iter+1)*class_per_state])
        ta_bal = validate(val_loader_test, model, criterion, fc_num = 1, if_main= True)
        ta_rand = validate(val_loader_test, model, criterion, fc_num = 1, if_main= False)

        print('TA_balance = %.2f'%ta_bal, 'TA_random = %.2f'%ta_rand)
        test_result.append(ta_bal)
        log_acc[test_iter] = ta_bal

    mean_acc = np.mean(np.array(test_result))
    log_acc[-1] = mean_acc

    log_result.append(['*'*50])
    log_result.append(['re-train task1 result'])
    log_result.append(['remain weight size = {}'.format(100-remain_weight)])
    log_result.append(name_list)
    log_result.append(log_acc)

    for current_state in range(1, all_states):

        #start training next task 
        model_path = os.path.join(args.save_dir, 'task'+str(current_state)+'_best_model.pt')
        new_dict = torch.load(model_path, map_location=torch.device('cuda:'+str(args.gpu)))['state_dict'] 

        full_model = ResNet18(num_classes_per_classifier=class_per_state, num_classifier=all_states)
        full_model.load_state_dict(new_dict)
        full_model.cuda()

        prun_model = ResNet18(num_classes_per_classifier=class_per_state, num_classifier=all_states)
        prun_model.load_state_dict(new_dict)
        prun_model.cuda()
        current_mask = torch.load(os.path.join(args.save_dir, 'current_mask.pt'), map_location=torch.device('cuda:'+str(args.gpu)))

        prune_model_custom(prun_model, current_mask)
        check_sparsity(prun_model)

        
        # training mode 
        train_dataset = Labeled_dataset(train_path, train_trans, sequence[current_state*class_per_state:(current_state+1)*class_per_state], offset=current_state*class_per_state)
        val_dataset = Labeled_dataset(val_path, val_trans, sequence[:(current_state+1)*class_per_state], offset=0)

        train_old_dataset = Labeled_dataset(old_img_path, train_trans, sequence[:current_state*class_per_state], offset=0)
        train_random_dataset = torch.utils.data.dataset.ConcatDataset((train_dataset,train_old_dataset))

        unlabel_dataset = unlabel_feature_dataset(os.path.join(args.save_data_path,'task'+str(current_state)))
        test_dataset = Labeled_dataset(test_path , val_trans, sequence[:(current_state+1)*class_per_state], offset=0)

        train_loader_random = torch.utils.data.DataLoader(
            train_random_dataset,
            batch_size=64, shuffle=True,
            num_workers=2, pin_memory=True)

        train_loader_balance_new = torch.utils.data.DataLoader(
            train_dataset,
            batch_size=int(64/(1+current_state)), shuffle=True,
            num_workers=2, pin_memory=True)

        train_loader_balance_old = torch.utils.data.DataLoader(
            train_old_dataset,
            batch_size=int(64*current_state/(1+current_state)), shuffle=True,
            num_workers=2, pin_memory=True)

        unlabel_loader = torch.utils.data.DataLoader(
            unlabel_dataset,
            batch_size=64, shuffle=True,
            num_workers=2, pin_memory=True)

        val_loader = torch.utils.data.DataLoader(
            val_dataset,
            batch_size=256, shuffle=False,
            num_workers=2, pin_memory=True)

        test_loader = torch.utils.data.DataLoader(
            test_dataset,
            batch_size=256, shuffle=False,
            num_workers=2, pin_memory=True)


        #training for full model     
        optimizer = torch.optim.SGD(full_model.parameters(), args.lr,
                                    momentum=args.momentum,
                                    weight_decay=args.weight_decay)
        scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=decreasing_lr, gamma=0.1)

        best_prec1 = 0
        train_acc = []
        ta=[]
        for epoch in range(args.epochs):
            print("The learning rate is {}".format(optimizer.param_groups[0]['lr']))

            train_accuracy = train_lwf(train_loader_random, train_loader_balance_new, train_loader_balance_old, unlabel_loader, full_model, criterion, optimizer, epoch, current_state+1)

            prec1 = validate(val_loader, full_model, criterion, fc_num=current_state+1, if_main=True)

            train_acc.append(train_accuracy)
            ta.append(prec1)

            scheduler.step()

            # remember best prec@1 and save checkpoint
            is_best = prec1 > best_prec1
            best_prec1 = max(prec1, best_prec1)

            save_checkpoint({
                'epoch': epoch + 1,
                'state_dict': full_model.state_dict(),
                'best_prec1': best_prec1,
                'optimizer': optimizer,
            }, is_best, args.save_dir, filename='task{}_checkpoint.pt'.format(current_state+1), best_name='task{}_best_model.pt'.format(current_state+1))

        #training for prune model     
        optimizer = torch.optim.SGD(prun_model.parameters(), args.lr,
                                    momentum=args.momentum,
                                    weight_decay=args.weight_decay)
        scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=decreasing_lr, gamma=0.1)

        best_prec1 = 0
        train_acc = []
        ta=[]
        for epoch in range(args.epochs):
            print("The learning rate is {}".format(optimizer.param_groups[0]['lr']))

            train_accuracy = train_lwf(train_loader_random, train_loader_balance_new, train_loader_balance_old, unlabel_loader, prun_model, criterion, optimizer, epoch, current_state+1)

            prec1 = validate(val_loader, prun_model, criterion, fc_num=current_state+1, if_main=True)

            train_acc.append(train_accuracy)
            ta.append(prec1)

            scheduler.step()

            # remember best prec@1 and save checkpoint
            is_best = prec1 > best_prec1
            best_prec1 = max(prec1, best_prec1)

            save_checkpoint({
                'epoch': epoch + 1,
                'state_dict': prun_model.state_dict(),
                'best_prec1': best_prec1,
                'optimizer': optimizer,
            }, is_best, args.save_dir, filename='task{}_prun_checkpoint.pt'.format(current_state+1), best_name='task{}_best_prun_model.pt'.format(current_state+1))


        #compare full model with prun model 
        #rewind to best weight 
        full_model.load_state_dict(torch.load(os.path.join(args.save_dir, 'task'+str(current_state+1)+'_best_model.pt'), map_location=torch.device('cuda:'+str(args.gpu)))['state_dict'])
        prun_model.load_state_dict(torch.load(os.path.join(args.save_dir, 'task'+str(current_state+1)+'_best_prun_model.pt'), map_location=torch.device('cuda:'+str(args.gpu)))['state_dict'])


        log_result.append(['*'*50])
        log_result.append(['full model acc state{}'.format(current_state+1)])
        log_acc = ['None' for i in range(all_states+1)]
        
        num_test_dataset = current_state+1
        test_result = np.zeros((2,num_test_dataset))
        for test_iter in range(num_test_dataset):
            val_dataset_test = Labeled_dataset(test_path, val_trans, sequence[test_iter*class_per_state:(test_iter+1)*class_per_state], offset=test_iter*class_per_state)
            val_loader_test = torch.utils.data.DataLoader(
                val_dataset_test,
                batch_size=256, shuffle=False,
                num_workers=2, pin_memory=True)

            print('dataset'+str(test_iter+1), 'classes = ', sequence[test_iter*class_per_state:(test_iter+1)*class_per_state])
            ta_bal = validate(val_loader_test, full_model, criterion, fc_num = current_state+1, if_main= True)
            ta_rand = validate(val_loader_test, full_model, criterion, fc_num = current_state+1, if_main= False)

            print('TA_balance = %.2f'%ta_bal, 'TA_random = %.2f'%ta_rand)
            test_result[0,test_iter] = ta_rand
            test_result[1,test_iter] = ta_bal
            log_acc[test_iter] = ta_bal
        full_model_mean_acc = np.mean(test_result[1,:])
        log_acc[-1] = full_model_mean_acc
        log_result.append(name_list)
        log_result.append(log_acc)
        print('******************************************************')
        print('full_model_mean_acc = ', full_model_mean_acc)
        print('******************************************************')
        overall_result = pickle.load(open(os.path.join(args.save_dir, 'all_result.pkl'),'rb'))
        overall_result['task'+str(current_state+1)+'_full'] = test_result
        pickle.dump(overall_result, open(os.path.join(args.save_dir, 'all_result.pkl'),'wb'))


        log_result.append(['*'*50])
        log_result.append(['prune model acc state{}'.format(current_state+1)])
        log_acc = ['None' for i in range(all_states+1)]

        test_result = np.zeros((2,num_test_dataset))
        for test_iter in range(num_test_dataset):
            val_dataset_test = Labeled_dataset(test_path, val_trans, sequence[test_iter*class_per_state:(test_iter+1)*class_per_state], offset=test_iter*class_per_state)
            val_loader_test = torch.utils.data.DataLoader(
                val_dataset_test,
                batch_size=256, shuffle=False,
                num_workers=2, pin_memory=True)

            print('dataset'+str(test_iter+1), 'classes = ', sequence[test_iter*class_per_state:(test_iter+1)*class_per_state])
            ta_bal = validate(val_loader_test, prun_model, criterion, fc_num = current_state+1, if_main= True)
            ta_rand = validate(val_loader_test, prun_model, criterion, fc_num = current_state+1, if_main= False)

            print('TA_balance = %.2f'%ta_bal, 'TA_random = %.2f'%ta_rand)
            test_result[0,test_iter] = ta_rand
            test_result[1,test_iter] = ta_bal
            log_acc[test_iter] = ta_bal
        prun_model_mean_acc = np.mean(test_result[1,:])
        log_acc[-1] = prun_model_mean_acc
        log_result.append(name_list)
        log_result.append(log_acc)
        print('******************************************************')
        print('prun_model_mean_acc = ', prun_model_mean_acc)
        print('******************************************************')
        overall_result = pickle.load(open(os.path.join(args.save_dir, 'all_result.pkl'),'rb'))
        overall_result['task'+str(current_state+1)+'_prun'] = test_result
        pickle.dump(overall_result, open(os.path.join(args.save_dir, 'all_result.pkl'),'wb'))

        if current_state < all_states-1:
            #using best full model to generate data for next task 
            old_dataset =  Labeled_dataset(old_img_path, val_trans, sequence[:(current_state+1)*class_per_state], offset=0)
            k15_dataset = k150_dataset(unlabel_path, val_trans)
            old_loader = torch.utils.data.DataLoader(
                old_dataset,
                batch_size=256, shuffle=False,
                num_workers=2, pin_memory=True)

            k15_loader = torch.utils.data.DataLoader(
                k15_dataset,
                batch_size=256, shuffle=False,
                num_workers=2, pin_memory=True)

            all_feature = feature_extract_old(k15_loader, full_model, criterion)
            target_feature = feature_extract_old(old_loader, full_model, criterion)

            index_select = select_knn(target_feature, all_feature, args.unlabel_num)
            np.save(os.path.join(args.save_data_path,'task'+str(current_state+1)+'_select_index.npy'), np.array(index_select))
            all_data = pickle.load(open(unlabel_path,'rb'))
            all_image = all_data['data']
            all_label = all_data['label']
            new_data = {}
            new_data['data'] = all_image[index_select,:,:,:]
            new_data['label'] = all_label[index_select]
            print(new_data['data'].shape)
            print(new_data['label'].shape)
            pickle.dump(new_data, open(os.path.join(args.save_data_path,'selected_unlabel_task'+str(current_state+1)+'.pkl'), 'wb'))
            
            extract_dataset = k150_dataset(os.path.join(args.save_data_path,'selected_unlabel_task'+str(current_state+1)+'.pkl'), train_trans)
            extract_loader = torch.utils.data.DataLoader(
                extract_dataset,
                batch_size=256, shuffle=True,
                num_workers=2, pin_memory=True)
            label_extract(extract_loader, full_model, criterion, args.save_data_path, fc_num=current_state+1)


        if full_model_mean_acc-prun_model_mean_acc < args.accept_decay:
            print('current prun model is ok!', current_state+1)
            continue
        
        else:
            print('need to re_prune from full model', current_state+1)
            model = ResNet18(num_classes_per_classifier=class_per_state, num_classifier=all_states)
            model.cuda()            
            model_start_dict = torch.load(os.path.join(args.save_dir, 'task'+str(current_state+1)+'_best_model.pt'), map_location=torch.device('cuda:'+str(args.gpu)))
            model.load_state_dict(model_start_dict['state_dict'])

            overall_result = pickle.load(open(os.path.join(args.save_dir, 'all_result.pkl'),'rb'))
            baseline_result = overall_result['task'+str(current_state+1)+'_full']
            baseline_acc = np.mean(baseline_result[1,:])
            print('baseline acc(balance branch) = ', baseline_acc)

            current_mask = torch.load(os.path.join(args.save_dir, 'current_mask.pt'), map_location=torch.device('cuda:'+str(args.gpu)))
            reverse_current_mask = reverse_mask(current_mask)

            # pruning during fix
            print('starting pruning for task', current_state+1)
            prune_model_custom(model, reverse_current_mask)
            check_sparsity(model)
            pruning_model(model, args.percent)
            check_sparsity(model)     

            new_mask = extract_mask(model.state_dict())
            update_mask = concat_mask(new_mask, current_mask)
            model = ResNet18(num_classes_per_classifier=class_per_state, num_classifier=all_states)
            model_start_dict = torch.load(os.path.join(args.save_dir, 'task'+str(current_state+1)+'_best_model.pt'), map_location=torch.device('cuda:'+str(args.gpu)))
            model.load_state_dict(model_start_dict['state_dict'])
            model.cuda()
            prune_model_custom(model, update_mask)
            check_sparsity(model)   

            check_mask(current_mask, model.state_dict())  

            # maybe share memory
            save_model_dict =  deepcopy(model.state_dict())

            acc_decay = 0
            pruning_times = 1
            zero_rate = 0

            log_result.append(['*'*50])
            log_result.append(['Pruning record'])
            log_result.append(['pruning_times', 'best_ta', 'baseline', 'zero_rate'])     


            while(acc_decay < args.deacc or zero_rate < args.base_sparsity):

                best_test_acc = 0
                optimizer = torch.optim.SGD(model.parameters(), args.lr/100,
                                            momentum=args.momentum,
                                            weight_decay=args.weight_decay)

                for epoch in range(args.iter_epochs):
                    print("The learning rate is {}".format(optimizer.param_groups[0]['lr']))
                    train_accuracy = train_lwf(train_loader_random, train_loader_balance_new, train_loader_balance_old, unlabel_loader, model, criterion, optimizer, epoch, current_state+1)
                    test_acc = validate(test_loader, model, criterion, fc_num=current_state+1, if_main=True)

                    best_test_acc = max(best_test_acc, test_acc)

                acc_decay = baseline_acc - best_test_acc
                print('********************************************')
                print('pruning_times, best_ta, baseline, zero_rate')
                print(pruning_times, best_test_acc, baseline_acc, zero_rate)
                print('********************************************')
                
                log_result.append([pruning_times, best_test_acc, baseline_acc, zero_rate])

                if acc_decay > args.deacc and zero_rate > args.base_sparsity:
                    break
                else:
                    # maybe share memory
                    save_model_dict =  deepcopy(model.state_dict())

                    last_model_dict = deepcopy(model.state_dict())
                    new_weight = extract_weight(last_model_dict)
                    no_orig_new_weight = reverse_rewind(new_weight)
                    model = ResNet18(num_classes_per_classifier=class_per_state, num_classifier=all_states)
                    model.load_state_dict(no_orig_new_weight)
                    model.cuda()
                    prune_model_custom(model, new_mask)
                    check_sparsity(model)
                    pruning_model(model,args.percent)
                    check_sparsity(model)   
                    new_mask = extract_mask(model.state_dict())
                    update_mask = concat_mask(new_mask, current_mask)

                    model = ResNet18(num_classes_per_classifier=class_per_state, num_classifier=all_states)
                    model.load_state_dict(no_orig_new_weight)
                    model.cuda()
                    prune_model_custom(model, update_mask)
                    zero_rate = check_sparsity(model)   
                    check_mask(current_mask, model.state_dict())        
                    pruning_times+=1

                if pruning_times > args.max_iter_prun:
                    break     

            torch.save(save_model_dict, os.path.join(args.save_dir, 'task'+str(current_state+1)+'_prune_weight.pt'))
            mask_dict = extract_mask(save_model_dict)
            print('*************current mask*******************')
            check_sparsity_mask(mask_dict)
            print('********************************************')
            torch.save(mask_dict, os.path.join(args.save_dir, 'current_mask.pt'))


            #re-train
            model_path = os.path.join(args.save_dir, 'task'+str(current_state)+'_best_model.pt')
            new_dict = torch.load(model_path, map_location=torch.device('cuda:'+str(args.gpu)))['state_dict'] 

            prun_model = ResNet18(num_classes_per_classifier=class_per_state, num_classifier=all_states)
            prun_model.load_state_dict(new_dict)
            prun_model.cuda()
            current_mask = torch.load(os.path.join(args.save_dir, 'current_mask.pt'), map_location=torch.device('cuda:'+str(args.gpu)))

            prune_model_custom(prun_model, current_mask)
            remain_weight = check_sparsity(prun_model)

            
            # training mode 
            train_dataset = Labeled_dataset(train_path, train_trans, sequence[current_state*class_per_state:(current_state+1)*class_per_state], offset=current_state*class_per_state)
            val_dataset = Labeled_dataset(val_path, val_trans, sequence[:(current_state+1)*class_per_state], offset=0)

            train_old_dataset = Labeled_dataset(old_img_path, train_trans, sequence[:current_state*class_per_state], offset=0)
            train_random_dataset = torch.utils.data.dataset.ConcatDataset((train_dataset,train_old_dataset))

            unlabel_dataset = unlabel_feature_dataset(os.path.join(args.save_data_path,'task'+str(current_state)))
            test_dataset = Labeled_dataset(test_path , val_trans, sequence[:(current_state+1)*class_per_state], offset=0)

            train_loader_random = torch.utils.data.DataLoader(
                train_random_dataset,
                batch_size=64, shuffle=True,
                num_workers=2, pin_memory=True)

            train_loader_balance_new = torch.utils.data.DataLoader(
                train_dataset,
                batch_size=int(64/(1+current_state)), shuffle=True,
                num_workers=2, pin_memory=True)

            train_loader_balance_old = torch.utils.data.DataLoader(
                train_old_dataset,
                batch_size=int(64*current_state/(1+current_state)), shuffle=True,
                num_workers=2, pin_memory=True)

            unlabel_loader = torch.utils.data.DataLoader(
                unlabel_dataset,
                batch_size=64, shuffle=True,
                num_workers=2, pin_memory=True)

            val_loader = torch.utils.data.DataLoader(
                val_dataset,
                batch_size=256, shuffle=False,
                num_workers=2, pin_memory=True)

            test_loader = torch.utils.data.DataLoader(
                test_dataset,
                batch_size=256, shuffle=False,
                num_workers=2, pin_memory=True)

            optimizer = torch.optim.SGD(prun_model.parameters(), args.lr,
                                        momentum=args.momentum,
                                        weight_decay=args.weight_decay)
            scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=decreasing_lr, gamma=0.1)

            best_prec1 = 0
            train_acc = []
            ta=[]
            for epoch in range(args.epochs):
                print("The learning rate is {}".format(optimizer.param_groups[0]['lr']))

                train_accuracy = train_lwf(train_loader_random, train_loader_balance_new, train_loader_balance_old, unlabel_loader, prun_model, criterion, optimizer, epoch, current_state+1)

                prec1 = validate(val_loader, prun_model, criterion, fc_num=current_state+1, if_main=True)

                train_acc.append(train_accuracy)
                ta.append(prec1)

                scheduler.step()

                # remember best prec@1 and save checkpoint
                is_best = prec1 > best_prec1
                best_prec1 = max(prec1, best_prec1)

                save_checkpoint({
                    'epoch': epoch + 1,
                    'state_dict': prun_model.state_dict(),
                    'best_prec1': best_prec1,
                    'optimizer': optimizer,
                }, is_best, args.save_dir, filename='re_task{}_prun_checkpoint.pt'.format(current_state+1), best_name='re_task{}_best_prun_model.pt'.format(current_state+1))

            prun_model.load_state_dict(torch.load(os.path.join(args.save_dir, 're_task'+str(current_state+1)+'_best_prun_model.pt'), map_location=torch.device('cuda:'+str(args.gpu)))['state_dict'])
            
            remain_weight = check_sparsity(model)
            
            log_result.append(['*'*50])
            log_result.append(['remain weight size = {}'.format(100-remain_weight)])
            log_result.append(['re_prune model acc state{}'.format(current_state+1)])
            log_acc = ['None' for i in range(all_states+1)]

            test_result = np.zeros((2,num_test_dataset))
            for test_iter in range(num_test_dataset):
                val_dataset_test = Labeled_dataset(test_path, val_trans, sequence[test_iter*class_per_state:(test_iter+1)*class_per_state], offset=test_iter*class_per_state)
                val_loader_test = torch.utils.data.DataLoader(
                    val_dataset_test,
                    batch_size=256, shuffle=False,
                    num_workers=2, pin_memory=True)

                print('dataset'+str(test_iter+1), 'classes = ', sequence[test_iter*class_per_state:(test_iter+1)*class_per_state])
                ta_bal = validate(val_loader_test, prun_model, criterion, fc_num = current_state+1, if_main= True)
                ta_rand = validate(val_loader_test, prun_model, criterion, fc_num = current_state+1, if_main= False)

                print('TA_balance = %.2f'%ta_bal, 'TA_random = %.2f'%ta_rand)
                test_result[0,test_iter] = ta_rand
                test_result[1,test_iter] = ta_bal
                log_acc[test_iter] = ta_bal
            prun_model_mean_acc = np.mean(test_result[1,:])
            log_acc[-1] = prun_model_mean_acc
            log_result.append(['*'*50])
            log_result.append(['re-train task{} result'.format(current_state+1)])
            log_result.append(['remain weight size = {}'.format(100-remain_weight)])
            log_result.append(name_list)
            log_result.append(log_acc)
            print('******************************************************')
            print('re_prun_model_mean_acc = ', prun_model_mean_acc)
            print('******************************************************')



def train(train_loader, model, criterion, optimizer, epoch, fc_num=1):
    """Run one training epoch that supervises both output branches of one head.

    Each batch goes through the model twice on classifier head
    ``out_idx=fc_num``: once with ``main_fc=False`` (the "random" branch) and
    once with ``main_fc=True`` (the "balance" branch). The two ``criterion``
    losses are summed and back-propagated in a single optimizer step.

    Args:
        train_loader: iterable of ``(input, target)`` batches.
        model: network called as ``model(x=..., out_idx=..., main_fc=...)``.
        criterion: supervised loss, called as ``criterion(output, target)``.
        optimizer: optimizer over ``model``'s parameters.
        epoch: epoch index, used only for progress printing.
        fc_num: value passed to the model as ``out_idx``. The default of 1
            reproduces the previously hard-coded head, so existing callers
            are unaffected.

    Returns:
        Running-average top-1 accuracy of the ``main_fc=True`` branch over
        the epoch.
    """
    losses = AverageMeter()
    top1 = AverageMeter()

    # switch to train mode
    model.train()

    for i, (inputs, target) in enumerate(train_loader):
        inputs = inputs.cuda()
        target = target.long().cuda()

        optimizer.zero_grad()

        # "random" branch (main_fc=False)
        output_rand = model(x=inputs, out_idx=fc_num, main_fc=False)
        loss_rand = criterion(output_rand, target)

        # "balance" branch (main_fc=True); also the branch used for accuracy
        output_main = model(x=inputs, out_idx=fc_num, main_fc=True)
        loss_balance = criterion(output_main, target)

        loss = loss_rand + loss_balance
        loss.backward()
        optimizer.step()

        # measure accuracy and record loss
        prec1 = accuracy(output_main.float().data, target)[0]
        losses.update(loss.float().item(), inputs.size(0))
        top1.update(prec1.item(), inputs.size(0))

        if i % args.print_freq == 0:
            print('Epoch: [{0}][{1}/{2}]\t'
                'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                'Accuracy {top1.val:.3f} ({top1.avg:.3f})'.format(
                    epoch, i, len(train_loader), loss=losses, top1=top1))

    print('train_accuracy {top1.avg:.3f}'.format(top1=top1))

    return top1.avg

def validate(val_loader, model, criterion, fc_num=1, if_main=False):
    """Evaluate ``model`` on ``val_loader`` through one chosen output branch.

    The model is put in eval mode and called as
    ``model(x=..., out_idx=fc_num, main_fc=if_main)`` under ``torch.no_grad()``.
    Progress is printed every ``args.print_freq`` batches, and a final summary
    line is printed at the end.

    Args:
        val_loader: iterable of ``(input, target)`` batches.
        model: network to evaluate.
        criterion: loss function, called as ``criterion(output, target)``.
        fc_num: classifier head index forwarded as ``out_idx``.
        if_main: forwarded as ``main_fc``. True selects the "balance"
            branch and False the "random" branch.

    Returns:
        Running-average top-1 accuracy over all batches.
    """
    loss_meter = AverageMeter()
    acc_meter = AverageMeter()

    model.eval()

    for step, (images, labels) in enumerate(val_loader):
        images = images.cuda()
        labels = labels.long().cuda()
        batch_size = images.size(0)

        # Forward pass and loss only; no autograd graph is needed here.
        with torch.no_grad():
            logits = model(x=images, out_idx=fc_num, main_fc=if_main)
            batch_loss = criterion(logits, labels)

        logits = logits.float()
        batch_loss = batch_loss.float()

        top1_acc = accuracy(logits.data, labels)[0]
        loss_meter.update(batch_loss.item(), batch_size)
        acc_meter.update(top1_acc.item(), batch_size)

        if step % args.print_freq == 0:
            print('Test: [{0}/{1}]\t'
                  'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                  'Accuracy {top1.val:.3f} ({top1.avg:.3f})'.format(
                      step, len(val_loader), loss=loss_meter, top1=acc_meter))

    print('valid_accuracy {top1.avg:.3f}'.format(top1=acc_meter))
    return acc_meter.avg

def _next_cycled(iterator, loader, name):
    """Return ``(iterator, batch)``, restarting from ``loader`` once when exhausted.

    Raises:
        RuntimeError: if ``loader`` yields no batch even after a restart.
            This replaces a bare ``StopIteration`` escaping to the caller.
    """
    try:
        return iterator, next(iterator)
    except StopIteration:
        pass
    iterator = iter(loader)
    try:
        return iterator, next(iterator)
    except StopIteration:
        raise RuntimeError('{} loader yielded no batches'.format(name)) from None


def train_lwf(rand_loader, new_balance_loader, old_balance_loader, unlabel_loader, model, criterion, optimizer, epoch, fc_num):
    """Train ``model`` for one epoch with supervised, balanced and distillation losses.

    For every batch from ``rand_loader``, one batch is also drawn from each of
    ``new_balance_loader``, ``old_balance_loader`` and ``unlabel_loader``.
    Those three are restarted whenever they run out, so ``rand_loader`` alone
    sets the epoch length.

    The optimized loss is::

        (loss_rand + kd_rand)
            + 0.5 * (coef_new * loss_bal_new + coef_old * loss_bal_old + kd_bal)

    The ``loss_*`` terms apply ``criterion`` on head ``out_idx=fc_num``: the
    random branch uses ``main_fc=False`` and the balanced terms use
    ``main_fc=True``. The ``kd_*`` terms apply ``loss_fn_kd(..., T=2)`` to the
    unlabeled images on head ``out_idx=fc_num-1``.

    Args:
        rand_loader: ``(input, target)`` batches that drive the epoch.
        new_balance_loader: ``(img, target)`` batches for the new-class
            balanced loss.
        old_balance_loader: ``(img, target)`` batches for the old-class
            balanced loss.
        unlabel_loader: ``(img, target, target_main)`` batches. The two
            targets feed the random-branch and balance-branch distillation
            terms respectively. Presumably they are soft outputs recorded
            from an earlier model (TODO confirm).
        model: network called as ``model(x=..., out_idx=..., main_fc=...)``.
        criterion: supervised loss, called as ``criterion(output, target)``.
        optimizer: optimizer over ``model``'s parameters.
        epoch: epoch index, used only for progress printing.
        fc_num: classifier head index for the current task. ``main()``
            passes ``current_state+1``.

    Returns:
        Running-average top-1 accuracy of the random branch on
        ``rand_loader`` batches.

    Raises:
        RuntimeError: if one of the cycled loaders yields no batches at all.
    """
    losses = AverageMeter()
    top1 = AverageMeter()

    # Balanced-loss weights. They use the same int() rounding as the
    # balanced-loader batch sizes built in main():
    # int(64*k/(1+k)) and int(64/(1+k)) with k = fc_num-1.
    coef_old = int(64*(fc_num-1)/fc_num)/64
    coef_new = int(64/fc_num)/64

    # switch to train mode
    model.train()

    new_balance = iter(new_balance_loader)
    old_balance = iter(old_balance_loader)
    unlabel = iter(unlabel_loader)

    for i, (inputs, target) in enumerate(rand_loader):

        new_balance, (bal_new_img, bal_new_target) = _next_cycled(new_balance, new_balance_loader, 'new_balance')
        old_balance, (bal_old_img, bal_old_target) = _next_cycled(old_balance, old_balance_loader, 'old_balance')
        unlabel, (unlab_img, unlab_target, unlab_target_main) = _next_cycled(unlabel, unlabel_loader, 'unlabel')

        bal_new_img = bal_new_img.cuda()
        bal_old_img = bal_old_img.cuda()
        unlab_img = unlab_img.cuda()
        inputs = inputs.cuda()

        bal_new_target = bal_new_target.long().cuda()
        bal_old_target = bal_old_target.long().cuda()
        target = target.long().cuda()

        unlab_target = unlab_target.cuda()
        unlab_target_main = unlab_target_main.cuda()

        optimizer.zero_grad()

        # NOTE: the order of the forward passes below is kept unchanged on
        # purpose. In train mode each pass may update normalization running
        # statistics, so reordering could change the results.

        # random input
        output_gt = model(x=inputs, out_idx=fc_num, main_fc=False)
        loss_rand = criterion(output_gt, target)

        # balance inputs
        output_bal_new = model(x=bal_new_img, out_idx=fc_num, main_fc=True)
        output_bal_old = model(x=bal_old_img, out_idx=fc_num, main_fc=True)
        loss_balance = criterion(output_bal_new, bal_new_target)*coef_new + criterion(output_bal_old, bal_old_target)*coef_old

        # unlabel: distillation on the previous head (out_idx = fc_num-1)
        unlab_output_rand = model(x=unlab_img, out_idx=fc_num-1, main_fc=False)
        loss_unlabel_rand = loss_fn_kd(unlab_output_rand, unlab_target, T=2)
        unlab_output_bal = model(x=unlab_img, out_idx=fc_num-1, main_fc=True)
        loss_unlabel_balance = loss_fn_kd(unlab_output_bal, unlab_target_main, T=2)

        all_rand_loss = loss_rand + loss_unlabel_rand
        all_bal_loss = loss_balance + loss_unlabel_balance

        loss = all_rand_loss + all_bal_loss*0.5

        loss.backward()
        optimizer.step()

        # measure accuracy (random branch) and record loss
        prec1 = accuracy(output_gt.float().data, target)[0]
        losses.update(loss.float().item(), inputs.size(0))
        top1.update(prec1.item(), inputs.size(0))

        if i % args.print_freq == 0:
            print('Epoch: [{0}][{1}/{2}]\t'
                  'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                  'Accuracy {top1.val:.3f} ({top1.avg:.3f})'.format(
                      epoch, i, len(rand_loader), loss=losses, top1=top1))

    print('train_accuracy {top1.avg:.3f}'.format(top1=top1))

    return top1.avg


if __name__ == '__main__':
    # Call main() only when this file is executed as a script, not when imported.
    main()