from __future__ import print_function

import argparse
import socket

import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn

from torchvision import datasets, transforms


from models import model_dict
from dataset.cifar100 import get_cifar100_dataloaders
from dataset.cifar10 import get_cifar10_dataloaders

from helper.loops import validate


def parse_option(argv=None):
    """Build and parse the command-line options for checkpoint validation.

    Args:
        argv: optional list of argument strings to parse instead of
            ``sys.argv[1:]`` (useful for tests and programmatic use).

    Returns:
        argparse.Namespace holding the parsed options.
    """
    parser = argparse.ArgumentParser('argument for validation')

    parser.add_argument('--print_freq', type=int, default=100, help='print frequency')
    parser.add_argument('--batch_size', type=int, default=64, help='batch_size')
    parser.add_argument('--num_workers', type=int, default=8, help='num of workers to use')
    # dataset
    parser.add_argument('--dataset', type=str, default='cifar100', choices=['cifar100', 'cifar10'], help='dataset')
    parser.add_argument('--data_dir', type=str, help='path to data directory')
    # model / checkpoint: marked required because evaluation cannot proceed
    # without them (an unset value would only fail later, less clearly).
    parser.add_argument('--model', type=str, required=True, choices=['CustomResNet18', 'Custom_wrn_22_8'],
                        help='model architecture')
    parser.add_argument('--path', type=str, required=True, help='checkpoint path')

    # nargs='+' makes the parsed value a list of floats, matching the list default.
    parser.add_argument('--width_mult_list', '-wml', default=[0.5, 1.0], type=float, nargs='+',
                        help='supporting width multiplier values')
    return parser.parse_args(argv)

def load_model(model_path, n_cls, opt):
    """Instantiate ``opt.model`` and restore its weights and saved masks.

    Args:
        model_path: path of a checkpoint readable by ``torch.load``; it must
            contain a ``'model'`` state dict and a ``'mask_epoch'`` entry.
        n_cls: number of output classes passed to the model constructor.
        opt: parsed options; only ``opt.model`` is read here.

    Returns:
        (model, mask_epoch): the restored model and the checkpoint's
        ``'mask_epoch'`` value.
    """
    print('==> loading model')
    print(opt.model)
    model_name = opt.model
    model = model_dict[model_name](num_classes=n_cls)
    # Deserialize the checkpoint only once. On CPU-only hosts, remap tensors
    # that were saved on a GPU to the CPU so loading does not fail. When CUDA
    # is available, keep the saved devices (same behavior as before).
    map_location = None if torch.cuda.is_available() else 'cpu'
    checkpoint = torch.load(model_path, map_location=map_location)
    model.load_state_dict(checkpoint['model'])
    print('==> done')
    return model, checkpoint['mask_epoch']

def cal_mask_relu(mask_list):
    """Sum the entries of every mask, print the total, and return it.

    The printed total is labelled ``Total_relu``; presumably the masks mark
    which ReLUs are kept, so this is the kept-ReLU count (TODO confirm).

    Args:
        mask_list: iterable of objects supporting ``.sum()`` (e.g. torch
            tensors). An empty iterable yields a total of 0.

    Returns:
        float: the same total that is printed.
    """
    total_relu = sum(mask.sum() for mask in mask_list)
    print('====================================================================')
    print('Total_relu:', float(total_relu))
    return float(total_relu)

def main():
    """Evaluate a saved checkpoint on the selected CIFAR validation split.

    Parses the options, builds the data loaders for the chosen dataset,
    restores the model and its saved ``mask_epoch``, prints the mask total,
    then runs validation and prints the resulting accuracy.
    """
    opt = parse_option()
    # Flag presumably consumed downstream (e.g. by validate); TODO confirm.
    opt.od_training = True

    # dataset name -> (loader factory, number of classes)
    loader_table = {
        'cifar100': (get_cifar100_dataloaders, 100),
        'cifar10': (get_cifar10_dataloaders, 10),
    }
    if opt.dataset not in loader_table:
        raise NotImplementedError(opt.dataset)
    make_loaders, n_cls = loader_table[opt.dataset]
    _train_loader, val_loader, _n_data = make_loaders(
        data_folder=opt.data_dir,
        batch_size=opt.batch_size,
        num_workers=opt.num_workers,
        is_instance=True,
    )

    model, mask_epoch = load_model(opt.path, n_cls, opt)
    criterion_cls = nn.CrossEntropyLoss()

    if torch.cuda.is_available():
        model.cuda()
        criterion_cls.cuda()
        cudnn.benchmark = True

    cal_mask_relu(mask_epoch)
    outcome = validate(val_loader, model, criterion_cls, opt, mask_epoch)
    test_acc, mask_epoch = outcome[0], outcome[3]
    accuracy_text = str(float(test_acc))
    print('==========================================================')
    print('The acc of {} is:{}'.format(opt.model, accuracy_text))
    
# Script entry point: run the evaluation only when executed directly, not on import.
if __name__ == '__main__':
    main()
