'''
Load lottery tickets and evaluate them.
Supported datasets: CIFAR-10, Fashion-MNIST, CIFAR-100.
'''

import os
import time 
import random
import shutil
import argparse
import numpy as np  
from copy import deepcopy
import matplotlib.pyplot as plt

import torch
import torch.optim
import torch.nn as nn
import torch.utils.data
import torch.nn.functional as F
import torchvision.models as models
import torch.backends.cudnn as cudnn
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from torch.utils.data.sampler import SubsetRandomSampler
from advertorch.utils import NormalizeByChannelMeanStd

from utils import *
from pruning_utils_2 import *
from pruning_utils_unprune import *
from pruning_utils import prune_model_custom_fillback
# Command-line interface; main() parses it into the module-level ``args``.
parser = argparse.ArgumentParser(description='PyTorch Evaluation Tickets')

##################################### general setting #################################################
parser.add_argument('--data', type=str, default='../../data', help='location of the data corpus')
parser.add_argument('--dataset', type=str, default='cifar10', help='dataset')
parser.add_argument('--arch', type=str, default='res18', help='model architecture')
parser.add_argument('--seed', default=None, type=int, help='random seed')
parser.add_argument('--save_dir', help='The directory used to save the trained models', default=None, type=str)
parser.add_argument('--gpu', type=int, default=0, help='gpu device id')
parser.add_argument('--save_model', action="store_true", help="whether saving model")

##################################### training setting #################################################
parser.add_argument('--optim', type=str, default='sgd', help='optimizer')
# NOTE: main() overwrites args.batch_size with 32 after parsing.
parser.add_argument('--batch_size', type=int, default=128, help='batch size')
parser.add_argument('--lr', default=0.1, type=float, help='initial learning rate')
parser.add_argument('--momentum', default=0.9, type=float, help='momentum')
parser.add_argument('--weight_decay', default=1e-4, type=float, help='weight decay')
parser.add_argument('--epochs', default=182, type=int, help='number of total epochs to run')
parser.add_argument('--warmup', default=0, type=int, help='warm up epochs')
parser.add_argument('--print_freq', default=50, type=int, help='print frequency')
# Comma-separated epoch list, kept as a string here.
parser.add_argument('--decreasing_lr', default='91,136', help='decreasing strategy')

##################################### Pruning setting #################################################
parser.add_argument('--pretrained', default=None, type=str, help='pretrained weight for pt')
parser.add_argument('--mask_dir', default=None, type=str, help='mask direction for ticket')
parser.add_argument('--conv1', action="store_true", help="whether pruning&rewind conv1")
parser.add_argument('--fc', action="store_true", help="whether rewind fc")

# The options below (except --checkpoint) are not read directly by the code in
# this file; presumably consumed by project helpers that receive ``args`` — verify.
parser.add_argument('--type', type=str, default=None, choices=['ewp', 'random_path', 'betweenness', 'hessian_abs', 'taylor1_abs','intgrads','identity', 'omp'])
parser.add_argument('--add-back', action="store_true", help="add back weights")
parser.add_argument('--prune-type', type=str, choices=["lt", 'pt', 'st', 'mt', 'trained', 'transfer'])
parser.add_argument('--num-paths', default=50000, type=int)
parser.add_argument('--evaluate', action="store_true")
parser.add_argument('--evaluate-p', type=float, default=0.00)
parser.add_argument('--evaluate-random', action="store_true")
parser.add_argument('--evaluate-full', action="store_true")

# Path of the (masked) checkpoint that main() loads and profiles.
parser.add_argument('--checkpoint', type=str)
parser.add_argument('--fillback-rate', type=float)



# Best SA accuracy tracker (cf. save_checkpoint's is_SA_best); declared
# global in main() but not updated by the code in this file.
best_sa = 0


def prune_model_custom_fillback_time(model, mask_dict, conv1=False, criteria="remain", train_loader=None, fillback_rate = 0.0):

    feature_maps = []
    channels = []
    def hook(module, input, output):
        feature_maps.append(output.detach().cpu())
    
    image, label = next(iter(train_loader))
    handles = []
    for name, m in model.named_modules():
        if isinstance(m, nn.Conv2d):
            if (name == 'conv1' and conv1) or (name != 'conv1'):
                handles.append(m.register_forward_hook(hook))
                device = m.weight.data.device
    model(image.to(device))
    counter = 0
    for i, (name,m) in enumerate(model.named_modules()):
        if isinstance(m, nn.Conv2d):
            if (name == 'conv1' and conv1) or (name != 'conv1'):
                mask = mask_dict[name+'.weight_mask']
                mask = mask.view(mask.shape[0], -1)
                count = torch.sum(mask != 0, 1) # [C]
                #sparsity = torch.sum(mask) / mask.numel()
                num_channel = count.sum().float() / mask.shape[1]
                num_channel = num_channel + (mask.shape[0] - num_channel) * fillback_rate
                print(num_channel)
                int_channel = int(num_channel)
                frac_channel = num_channel - int_channel
                channels.append(int(num_channel) + 1)
                
                if criteria == 'remain':
                    print(mask.shape[0] - int_channel)
                    threshold, _ = torch.kthvalue(count, max(mask.shape[0] - int_channel, 1))
                
                    mask[torch.where(count >= threshold)[0]] = 1
                    mask[torch.where(count < threshold)[0]] = 0
                    
                elif criteria == 'magnitude':
                    mask = mask_dict[name+'.weight_mask']
                    count = m.weight.data.view(mask.shape[0], -1).abs().sum(1)
                    threshold, _ = torch.kthvalue(count, mask.shape[0] - int_channel)
                
                    mask[torch.where(count > threshold)[0]] = 1
                    mask[torch.where(count < threshold)[0]] = 0
                    mask[torch.where(count == threshold)[0],:int(frac_channel * mask.shape[1])] = 1
                    mask[torch.where(count == threshold)[0],int(frac_channel * mask.shape[1]):] = 0
                
                elif criteria == 'l1':
                    mask = mask_dict[name+'.weight_mask']
                    count = feature_maps[counter].view(mask.shape[0], -1).abs().sum(1)
                    threshold, _ = torch.kthvalue(count, mask.shape[0] - int_channel)
                
                    mask[torch.where(count > threshold)[0]] = 1
                    mask[torch.where(count < threshold)[0]] = 0
                    mask[torch.where(count == threshold)[0],:int(frac_channel * mask.shape[1])] = 1
                    mask[torch.where(count == threshold)[0],int(frac_channel * mask.shape[1]):] = 0
                    counter += 1

                elif criteria == 'l2':
                    mask = mask_dict[name+'.weight_mask']
                    count = (feature_maps[counter].view(mask.shape[0], -1).abs() ** 2).sum(1)
                    threshold, _ = torch.kthvalue(count, mask.shape[0] - int_channel)
                
                    mask[torch.where(count > threshold)[0]] = 1
                    mask[torch.where(count < threshold)[0]] = 0
                    mask[torch.where(count == threshold)[0],:int(frac_channel * mask.shape[1])] = 1
                    mask[torch.where(count == threshold)[0],int(frac_channel * mask.shape[1]):] = 0
                    counter += 1
                
                mask = mask.view(*mask_dict[name+'.weight_mask'].shape)
                print('pruning layer with custom mask:', name)
                #prune.CustomFromMask.apply(m, 'weight', mask=mask.to(m.weight.device))

    for h in handles:
        h.remove()
    return channels

def main():
    """Profile inference of a masked checkpoint with sparse convolutions.

    Steps:
      1. Parse the CLI into the module-level ``args``.
      2. Build ``resnet50_official(num_classes=200, imagenet=True)``.
      3. Load ``args.checkpoint`` into it. Masked weights are materialised as
         ``mask * <param>_orig``, and a leading ``module.`` prefix is stripped.
      4. Replace every ``nn.Conv2d`` with ``SparseConv2D``.
      5. Profile 100 forward passes on a random (32, 3, 64, 64) batch with
         ``torchprof`` (CUDA and memory profiling enabled).
    """
    global args, best_sa
    args = parser.parse_args()
    args.use_sparse_conv = False
    args.batch_size = 32
    print(args)

    print('*' * 50)
    print('conv1 included for prune and rewind: {}'.format(args.conv1))
    print('fc included for rewind: {}'.format(args.fc))
    print('*' * 50)

    torch.cuda.set_device(int(args.gpu))
    # --save_dir defaults to None; os.makedirs(None) would raise.
    if args.save_dir:
        os.makedirs(args.save_dir, exist_ok=True)
    # Compare with None so that --seed 0 is honoured.
    if args.seed is not None:
        setup_seed(args.seed)

    # prepare dataset; the model built here is replaced just below and the
    # loaders are not used in this function.
    model, train_loader, val_loader, test_loader = setup_model_dataset(args)
    from models.resnet50_cfg import resnet50_official
    model = resnet50_official(num_classes=200, imagenet=True)

    # Accept both {'state_dict': ...} checkpoints and bare state dicts, loading once.
    checkpoint = torch.load(args.checkpoint, map_location="cpu")
    if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
        state_dict = checkpoint['state_dict']
    else:
        state_dict = checkpoint

    current_mask = extract_mask(state_dict)
    print(current_mask.keys())

    from models.conv import SparseConv2D

    def strip_module(key):
        # Checkpoints saved from a wrapped model carry a 'module.' prefix.
        return key[len('module.'):] if key.startswith('module.') else key

    new_model = deepcopy(model)
    combined_state_dict = {}
    for key in state_dict:
        if key in current_mask:
            # '<param>_mask' -> effective weight = mask * '<param>_orig'
            # (assumes mask keys end with the 5-character suffix '_mask').
            param = key[:-5]
            weight = current_mask[key] * state_dict[param + "_orig"]
            combined_state_dict[strip_module(param)] = weight
            print((weight.abs() != 0).float().sum() / weight.numel())
        elif 'orig' not in key:
            value = state_dict[key]
            combined_state_dict[strip_module(key)] = value
            if 'conv' in key:
                print((value.abs() == 0).float().sum() / value.numel())

    new_model.load_state_dict(combined_state_dict, strict=False)

    def replace_conv(m, name):
        """Recursively swap every nn.Conv2d child of ``m`` for a SparseConv2D."""
        print(name)
        for attr_str, child in m.named_children():
            print(attr_str)
            if isinstance(child, nn.Conv2d):
                record = deepcopy(child)
                new_conv = SparseConv2D(child.weight.shape[1], child.weight.shape[0], child.weight.shape[2], child.stride, child.padding, child.dilation, False)
                # The value returned by load() is not used: the conv is always replaced.
                new_conv.load(record.weight, None)
                setattr(m, attr_str, new_conv)
                print(f"DENSE BLOCKS GREATER THAN 0 in {name}")
            replace_conv(child, attr_str)

    replace_conv(new_model, "new_model")

    new_model.cuda()
    new_model.eval()

    import torchprof
    x = torch.randn((args.batch_size, 3, 64, 64)).cuda()
    # Inference only: no autograd graph, so timings/memory reflect the forward pass.
    with torch.no_grad():
        new_model(x)  # warm-up pass outside the profiler
        with torchprof.Profile(new_model, use_cuda=True, profile_memory=True) as prof:
            for _ in range(100):
                new_model(x)

    print(prof.display(show_events=False))
    print(prof.display(show_events=True))


def validate(val_loader, model, criterion):
    """
    Run evaluation over ``val_loader`` and return the average top-1 accuracy (%).

    Progress is printed every ``args.print_freq`` batches (``args`` is the
    module-level namespace set by ``main``). ``criterion`` is accepted for
    interface compatibility but is not used.
    """
    top1 = AverageMeter()

    # switch to evaluate mode
    model.eval()
    with torch.no_grad():
        for i, (image, target) in enumerate(val_loader):
            image = image.cuda()
            target = target.cuda()

            # compute output
            output = model(image)

            # measure accuracy
            prec1 = accuracy(output.data, target)[0]
            top1.update(prec1.item(), image.size(0))

            if i % args.print_freq == 0:
                print('Test: [{0}/{1}]\t'
                    'Accuracy {top1.val:.3f} ({top1.avg:.3f})'.format(
                        i, len(val_loader), top1=top1))

    return top1.avg

def save_checkpoint(state, is_SA_best, save_path, filename='checkpoint.pth.tar'):
    """Serialize ``state`` to ``save_path/filename`` with ``torch.save``.

    When ``is_SA_best`` is true, the written file is also copied to
    ``save_path/model_SA_best.pth.tar``.
    """
    checkpoint_file = os.path.join(save_path, filename)
    torch.save(state, checkpoint_file)
    if not is_SA_best:
        return
    best_file = os.path.join(save_path, 'model_SA_best.pth.tar')
    shutil.copyfile(checkpoint_file, best_file)

def load_ticket(model, args):
    """Load pretrained weights and/or a pruning mask into ``model`` in place.

    Weights (when ``args.pretrained`` is set):
        The checkpoint is loaded onto ``cuda:<args.gpu>``. It is unwrapped from
        its ``'init_weight'`` or ``'state_dict'`` entry if present and passed
        through ``extract_main_weight``. Missing ``normalize.mean/std`` entries
        are taken from the model. Unless ``args.prune_type`` is ``'lt'`` or
        ``'trained'``, every ``fc*``/``conv1*`` entry is dropped and
        ``fc.weight``, ``fc.bias`` and ``conv1.weight`` are taken from the model.
        The result is loaded with ``strict=True`` (the default).

    Mask (when ``args.mask_dir`` is set):
        Masks are extracted from the checkpoint (or its ``'state_dict'``), applied
        with ``prune_model_custom``, and sparsity is reported via ``check_sparsity``.
    """
    device = torch.device('cuda:' + str(args.gpu))

    # weight
    if args.pretrained:
        checkpoint = torch.load(args.pretrained, map_location=device)

        if 'init_weight' in checkpoint:
            print('loading from init_weight')
            checkpoint = checkpoint['init_weight']
        elif 'state_dict' in checkpoint:
            print('loading from state_dict')
            checkpoint = checkpoint['state_dict']

        loading_weight = extract_main_weight(checkpoint, fc=True, conv1=True)
        model_weight = model.state_dict()
        if 'normalize.std' not in loading_weight:
            loading_weight['normalize.std'] = model_weight['normalize.std']
            loading_weight['normalize.mean'] = model_weight['normalize.mean']

        if args.prune_type not in ('lt', 'trained'):
            # Only the trunk is transferred; the first conv and classifier keep
            # the model's own initialization.
            for key in list(loading_weight):
                if key.startswith(('fc', 'conv1')):
                    del loading_weight[key]
            for key in ('fc.weight', 'fc.bias', 'conv1.weight'):
                loading_weight[key] = model_weight[key]

        print('*number of loading weight={}'.format(len(loading_weight.keys())))
        print('*number of model weight={}'.format(len(model.state_dict().keys())))
        model.load_state_dict(loading_weight)

    # mask
    if args.mask_dir:
        print('loading mask')
        mask_checkpoint = torch.load(args.mask_dir, map_location=device)
        if 'state_dict' in mask_checkpoint:
            mask_checkpoint = mask_checkpoint['state_dict']
        current_mask = extract_mask(mask_checkpoint)
        prune_model_custom(model, current_mask)
        check_sparsity(model, conv1=args.conv1)

def warmup_lr(epoch, step, optimizer, one_epoch_step, config=None):
    """Linearly ramp every param group's learning rate towards the target rate.

    lr = target * (epoch * one_epoch_step + step) / (warmup * one_epoch_step),
    capped at the target rate.

    Args:
        epoch: current epoch index.
        step: step index within the current epoch.
        optimizer: object exposing ``param_groups`` (list of dicts with ``'lr'``).
        one_epoch_step: number of steps per epoch.
        config: namespace with ``warmup`` (epochs) and ``lr`` (target rate).
            Defaults to the module-level ``args`` set by ``main``.
    """
    cfg = args if config is None else config
    overall_steps = cfg.warmup * one_epoch_step
    current_steps = epoch * one_epoch_step + step

    if overall_steps > 0:
        lr = min(cfg.lr * current_steps / overall_steps, cfg.lr)
    else:
        # No warm-up configured: use the target rate instead of dividing by zero.
        lr = cfg.lr

    for group in optimizer.param_groups:
        group['lr'] = lr

class AverageMeter(object):
    """Track the latest value of a metric and its count-weighted running mean."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Clear the latest value, running sum, count and mean back to zero."""
        self.val = self.avg = self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as observed ``n`` times and refresh the running mean."""
        self.val = val
        self.count += n
        self.sum += val * n
        self.avg = self.sum / self.count

def accuracy(output, target, topk=(1,)):
    """Compute precision@k (in percent) for each k in ``topk``.

    Args:
        output: scores of shape (batch, num_classes); top-k is taken along dim 1.
        target: class indices, one per sample.
        topk: the k values to evaluate.

    Returns:
        list of scalar tensors, one per k: the percentage of samples whose
        target is among the k highest-scoring classes.
    """
    maxk = max(topk)
    batch_size = target.size(0)

    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()  # (maxk, batch)
    correct = pred.eq(target.view(1, -1).expand_as(pred))

    res = []
    for k in topk:
        # reshape, not view: ``correct`` inherits the transposed (non-contiguous)
        # layout of ``pred``, so view(-1) on its first k rows fails for k > 1.
        correct_k = correct[:k].reshape(-1).float().sum(0)
        res.append(correct_k.mul_(100.0 / batch_size))
    return res

def setup_seed(seed):
    """Seed torch (CPU and all CUDA devices), NumPy and ``random``; make cuDNN deterministic."""
    for seeder in (torch.manual_seed, torch.cuda.manual_seed_all, np.random.seed, random.seed):
        seeder(seed)
    torch.backends.cudnn.deterministic = True

if __name__ == '__main__':
    # Script entry point: parse CLI arguments, load the checkpoint and profile it.
    main()


