#main Top-Down pruning

import os
import torch
import random 
import pickle
import argparse
import numpy as np 
import torch.optim
import torch.nn as nn
import torch.utils.data
import torch.nn.functional as F
from torch.autograd import Variable
import torch.backends.cudnn as cudnn
import torchvision.datasets as datasets
import torchvision.transforms as transforms

from model import PreActResNet18 as ResNet18  
from utils import *

# Command-line interface of the script; main() parses it into the global `args`.
parser = argparse.ArgumentParser(
    description='PyTorch Cifar10_100 CIL Top-Down pruning',
)

# ------------------------------ base setting ------------------------------
parser.add_argument(
    '--data_dir',
    type=str,
    default='CIL_data',
    help='The directory for data',
)
parser.add_argument(
    '--dataset',
    type=str,
    default='cifar10',
    help='default dataset',
)
parser.add_argument(
    '--pretrained',
    type=str,
    default=None,
    help='pretrained models',
)
parser.add_argument(
    '--print_freq',
    type=int,
    default=50,
    help='print frequency',
)
parser.add_argument(
    '--gpu',
    type=int,
    default=0,
    help='gpu device id',
)
parser.add_argument(
    '--pruned',
    action='store_true',
    help='whether the checkpoint has been pruned',
)

# ------------------------------ CIL setting -------------------------------
parser.add_argument(
    '--classes_per_classifier',
    type=int,
    default=2,
    help='number of classes per classifier',
)
parser.add_argument(
    '--classifiers',
    type=int,
    default=5,
    help='number of classifiers',
)


# Module-level value initialised to 0; nothing visible in this file reassigns it.
best_prec1 = 0

def main():
    """Evaluate a (possibly pruned) checkpoint on every task split and log the results.

    Parses the command line into the module-level ``args`` (``validate`` reads
    ``args.print_freq``), builds the model, loads the ``--pretrained``
    checkpoint into it, runs ``validate`` once per task split of the CIFAR-10
    test set, and appends the per-task and mean accuracy to ``test.txt``
    through ``Logger``.

    Raises:
        ValueError: if ``--dataset`` is not ``cifar10``, if ``--pretrained``
            was not given, or if ``classifiers * classes_per_classifier``
            exceeds the number of classes in the task sequence.
    """
    global args
    args = parser.parse_args()
    print(args)

    # Explicit raises instead of `assert 0`: asserts are stripped under
    # `python -O`, which would let an unsupported dataset fall through to a
    # NameError further down.
    if args.dataset != 'cifar10':
        raise ValueError('do not support dataset of ' + args.dataset)

    # Fail before touching the GPU rather than inside torch.load(None).
    if args.pretrained is None:
        raise ValueError('--pretrained must point to a checkpoint file to evaluate')

    normalize = transforms.Normalize(mean=[0.4914, 0.4822, 0.4465], std=[0.2023, 0.1994, 0.2010])
    # Evaluation-only transform; no augmentation is applied at test time.
    val_trans = transforms.Compose([
        transforms.ToPILImage(),
        transforms.ToTensor(),
        normalize,
    ])

    test_path = os.path.join(args.data_dir, 'cifar10_test.pkl')

    # Fixed class order; consecutive slices of it form the individual tasks.
    sequence = [9, 8, 7, 1, 5, 0, 3, 4, 6, 2]

    all_states = args.classifiers
    class_per_state = args.classes_per_classifier

    # Slicing past the end of `sequence` silently yields short or empty class
    # lists, so reject configurations the sequence cannot cover.
    if all_states * class_per_state > len(sequence):
        raise ValueError(
            'classifiers ({}) * classes_per_classifier ({}) exceeds the {} classes available'.format(
                all_states, class_per_state, len(sequence)))

    torch.cuda.set_device(int(args.gpu))

    # setup logger: header row is one column per task plus the mean
    log_result = Logger('test.txt')
    name_list = ['Task{}'.format(i + 1) for i in range(all_states)]
    name_list.append('Mean Acc')
    log_result.append(name_list)

    criterion = nn.CrossEntropyLoss()

    model = ResNet18(num_classes_per_classifier=class_per_state, num_classifier=all_states)
    model.cuda()
    # NOTE(security): torch.load unpickles the file; only load checkpoints
    # from a trusted source.
    new_dict = torch.load(args.pretrained, map_location=torch.device('cuda:' + str(args.gpu)))

    if args.pruned:
        # The mask is extracted from the checkpoint and applied to the model
        # before load_state_dict is called.
        print('pruning with custom mask')
        current_mask = extract_mask(new_dict)
        prune_model_custom(model, current_mask)

    model.load_state_dict(new_dict)
    remain_weight = check_sparsity(model)
    # NOTE(review): check_sparsity's return convention is not visible here;
    # the original "100 - value" report and its label are kept unchanged.
    reported_size = 100 - remain_weight

    print('*****************************************************************************')
    print('start testing ')
    print('remain weight size = {}'.format(reported_size))
    print('*****************************************************************************')

    # One accuracy per task, plus a trailing slot for the mean.
    bal_acc = []
    log_acc = ['None'] * (all_states + 1)

    for test_iter in range(all_states):
        first = test_iter * class_per_state
        task_classes = sequence[first:first + class_per_state]

        # `offset` is forwarded as-is; presumably it remaps labels for this
        # task -- confirm against Labeled_dataset.
        val_dataset_test = Labeled_dataset(test_path, val_trans, task_classes, offset=first)
        val_loader_test = torch.utils.data.DataLoader(
            val_dataset_test,
            batch_size=256, shuffle=False,
            num_workers=2, pin_memory=True)

        print('dataset' + str(test_iter + 1), 'classes = ', task_classes)
        ta_bal = validate(val_loader_test, model, criterion, fc_num=all_states, if_main=True)

        bal_acc.append(ta_bal)
        log_acc[test_iter] = ta_bal
        print('TA = ', ta_bal)

    mean_acc = np.mean(np.array(bal_acc))
    log_acc[-1] = mean_acc
    print('******************************************************')
    print('mean_acc = ', mean_acc)
    print('******************************************************')
    log_result.append(log_acc)
    log_result.append(['remain weight size = {}'.format(reported_size)])
    log_result.append(['*' * 50])


def validate(val_loader, model, criterion, fc_num=1, if_main=False):
    """Run ``model`` over ``val_loader`` in eval mode and return the averaged accuracy.

    Args:
        val_loader: iterable yielding ``(input, target)`` batches.
        model: network, called as ``model(x=..., out_idx=fc_num, main_fc=if_main)``.
        criterion: loss callable applied to ``(output, target)``.
        fc_num: forwarded to the model as ``out_idx``.
        if_main: forwarded to the model as ``main_fc``.

    Returns:
        ``avg`` of the accuracy ``AverageMeter`` after the whole loader has
        been consumed.

    Progress is printed every ``args.print_freq`` batches, so the module-level
    ``args`` must already be set (``main`` does this) when this is called.
    """
    loss_meter = AverageMeter()
    acc_meter = AverageMeter()

    # switch to evaluate mode
    model.eval()

    for batch_idx, (images, labels) in enumerate(val_loader):
        images = images.cuda()
        labels = labels.long().cuda()
        batch_size = images.size(0)

        # forward pass and loss without building an autograd graph
        with torch.no_grad():
            logits = model(x=images, out_idx=fc_num, main_fc=if_main)
            batch_loss = criterion(logits, labels)

        logits = logits.float()
        batch_loss = batch_loss.float()

        # first entry returned by accuracy() is the value tracked as top-1
        batch_prec1 = accuracy(logits.data, labels)[0]
        loss_meter.update(batch_loss.item(), batch_size)
        acc_meter.update(batch_prec1.item(), batch_size)

        if batch_idx % args.print_freq != 0:
            continue

        progress = ('Test: [{0}/{1}]\t'
                    'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                    'Accuracy {top1.val:.3f} ({top1.avg:.3f})').format(
                        batch_idx, len(val_loader), loss=loss_meter, top1=acc_meter)
        print(progress)

    print('valid_accuracy {top1.avg:.3f}'.format(top1=acc_meter))

    return acc_meter.avg


# Script entry point: run the evaluation only when executed directly, not on import.
if __name__ == "__main__":
    main()