from sparselearning.pruning_utils import *
# from utils import *
import torch
import os 
import numpy as np 
from torch.utils.data import DataLoader, Subset
from torchvision import transforms
from torchvision.datasets import CIFAR10, CIFAR100
from models.preactivate_resnet import *
from models.vgg import *
from models.wideresnet import *
import torch.nn as nn
import argparse
from utils import setup_seed

def prune_loop(model, loss, pruner, dataloader, device, sparsity, scope, epochs, schedule = 'exponential', train_mode=False):
    """Score and mask ``model`` for ``epochs`` rounds, ending at ``sparsity``.

    Each round calls ``pruner.score(...)`` and then ``pruner.mask(value, scope)``.
    ``value`` is interpolated according to ``schedule``. On the final round it
    equals ``sparsity`` exactly.

    Args:
        model: module to prune; put in train mode if ``train_mode`` else eval mode.
        loss: loss criterion forwarded to ``pruner.score``.
        pruner: object exposing ``score(model, loss, dataloader, device)`` and
            ``mask(value, scope)``.
        dataloader: data forwarded to ``pruner.score``.
        device: device forwarded to ``pruner.score``.
        sparsity: target value passed to ``pruner.mask`` on the last round.
        scope: scope string forwarded to ``pruner.mask`` (e.g. ``'global'``).
        epochs: number of score/mask rounds; ``0`` does nothing.
        schedule: ``'exponential'`` (``sparsity ** k``) or ``'linear'``
            (``1 - (1 - sparsity) * k``), where ``k = (round + 1) / epochs``.
        train_mode: keep the model in train mode while scoring.

    Raises:
        ValueError: if ``schedule`` is not ``'exponential'`` or ``'linear'``.
    """
    # Validate before any (potentially expensive) scoring happens. Previously
    # an unknown schedule left ``sparse`` unbound and crashed mid-loop.
    if schedule not in ('exponential', 'linear'):
        raise ValueError(
            "Unknown pruning schedule {!r}; expected 'exponential' or 'linear'".format(schedule))

    # For nn.Module, eval() is equivalent to train(False).
    model.train(train_mode)

    for epoch in range(epochs):
        pruner.score(model, loss, dataloader, device)
        progress = (epoch + 1) / epochs
        if schedule == 'exponential':
            sparse = sparsity ** progress
        else:  # 'linear'
            sparse = 1.0 - (1.0 - sparsity) * progress
        pruner.mask(sparse, scope)

def getPruneDataloader(batch_size, length, data_dir='data/cifar10', num_train=45000, dataset_cls=None):
    """Build a shuffled DataLoader over a random subset of a CIFAR training split.

    Args:
        batch_size: batch size of the returned loader.
        length: number of samples drawn at random (via ``torch.randperm``) from
            the first ``num_train`` training images. ``None`` uses all of them.
        data_dir: dataset root; the dataset is downloaded there if missing.
        num_train: size of the leading slice of the training split to sample
            from (the remaining images are left out, presumably as a validation
            split — confirm).
        dataset_cls: torchvision dataset class taking
            ``(root, train=..., transform=..., download=...)``. Defaults to
            ``CIFAR10``.

    Returns:
        A ``DataLoader`` with random-crop/flip augmentation, ``shuffle=True``,
        ``num_workers=2`` and ``pin_memory=True``.
    """
    # Resolved at call time so the default does not depend on import order.
    if dataset_cls is None:
        dataset_cls = CIFAR10

    train_transform = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
    ])
    full_train = dataset_cls(data_dir, train=True, transform=train_transform, download=True)
    train_set = Subset(full_train, list(range(num_train)))

    # Default to the whole slice. Previously, length=None left `dataset`
    # unbound and raised UnboundLocalError.
    dataset = train_set
    if length is not None:
        # Plain ints keep Subset indexing independent of tensor __index__ support.
        indices = torch.randperm(len(train_set))[:length].tolist()
        dataset = Subset(train_set, indices)

    return DataLoader(dataset=dataset,
                      batch_size=batch_size,
                      shuffle=True, num_workers=2, pin_memory=True)

def save_mask(model, sparsity, pruner, iteration_number, save_flag, seed, train_mode = True, class_num = 10, save_dir = "./omp_mask"):
    """Prune ``model`` with ``pruner``, report sparsity and save the mask.

    Args:
        model: network whose parameters ``pruner`` was built over.
        sparsity: target value passed to ``prune_loop``. Per the ``--sparsity``
            CLI help this is the mask density. It also appears in the file name.
        pruner: object exposing ``score``, ``mask`` and ``apply_mask``.
        iteration_number: number of score/mask rounds (``epochs`` of ``prune_loop``).
        save_flag: prefix of the output file name (also printed).
        seed: used only in the output file name.
        train_mode: forwarded to ``prune_loop``.
        class_num: the pruning loader draws ``class_num * 10`` samples.
        save_dir: output directory, created if missing.

    The mask from ``extract_mask(model)`` is written with ``torch.save`` to
    ``<save_dir>/<save_flag>-<sparsity>-seed<seed>-mask.pt``.
    """
    print(save_flag)
    # NOTE(review): getPruneDataloader draws CIFAR-10 images by default even when
    # class_num != 10 — confirm this is intended for data-dependent pruners.
    prune_loader = getPruneDataloader(batch_size = 128, length = class_num * 10)
    criterion = nn.CrossEntropyLoss()

    os.makedirs(save_dir, exist_ok=True)

    # Score on the device the model already lives on instead of assuming
    # cuda:0; fall back to cuda:0 for a parameter-less model.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cuda:{}'.format(0))

    prune_loop(model, criterion, pruner, prune_loader, device,
               sparsity, scope='global', epochs=iteration_number, train_mode=train_mode)

    pruner.apply_mask()

    print('sparsity = {}'.format(sparsity))
    check_sparsity(model)
    check_sparsity_layers(model)
    current_mask = extract_mask(model)
    check_sparsity_dict(current_mask)
    mask_path = os.path.join(save_dir, '{0}-{1}-seed{2}-mask.pt'.format(save_flag, sparsity, seed))
    torch.save(current_mask, mask_path)

def _build_parser():
    """Return the CLI parser; both options are mandatory.

    ``--sparsity`` is the mask density (per its help text); ``--seed`` is the
    random seed.
    """
    cli = argparse.ArgumentParser(description='Init Sparse Training Mask')
    cli.add_argument('--sparsity', required=True, type=float, help='the density of mask')
    cli.add_argument('--seed', required=True, type=int, help='random seed')
    return cli


parser = _build_parser()

if __name__ == '__main__':
    # Script entry point: build ResNet-18, load a dense checkpoint (the directory
    # name suggests a CIFAR-100 run), then compute and save a one-shot
    # magnitude-pruning ("omp") mask at the requested density.
    args = parser.parse_args()
    sparsity = args.sparsity  # mask density, per the --sparsity help text
    setup_seed(args.seed)

    class_num = 100
    device = torch.device('cuda:{}'.format(0))

    # Pruner keys accepted by get_pruner:
    #RP, OMP, GMP, TP, SNIP,
    # 'rand' : Rand,
    # 'mag' : Mag,
    # 'snip' : SNIP,
    # 'grasp': GraSP,
    # 'synflow' : SynFlow,
    # 'taylor1scorerabs':Taylor1ScorerAbs,

    # Disabled configurations below are kept as an inert string literal.
    # NOTE(review): their save_mask calls omit class_num, so save_mask's
    # default class_num=10 applies even though the models use class_num=100.
    #model, sparsity, pruner, iteration_number, save_flag
    '''
    ####################### rp ###########################
    model = ResNet18(num_classes = class_num)
    # model = WideResNet(34, class_num, 10, dropRate=0.0)

    prune_conv_linear(model)
    pruner_type = "rand"
    device = torch.device('cuda:{}'.format(0))
    model.to(device)
    pruner = get_pruner(pruner_type)(masked_parameters(model)) 
    iteration_number = 1
    save_flag = "rp"

    save_mask(model, sparsity, pruner, iteration_number, save_flag, args.seed)
    
    ####################### snip ###########################
    model = vgg16_bn(num_classes = class_num)
    # model = WideResNet(34, class_num, 10, dropRate=0.0)

    prune_conv_linear(model)
    # for key in model.state_dict():
    #     print(key, " : ", model.state_dict()[key].shape)

    pruner_type = "snip"
    device = torch.device('cuda:{}'.format(0))
    model.to(device)
    pruner = get_pruner(pruner_type)(masked_parameters(model)) 
    iteration_number = 1
    save_flag = "snip"

    save_mask(model, sparsity, pruner, iteration_number, save_flag, args.seed)

    

    ####################### Grasp ##############################
    model = vgg16_bn(num_classes = class_num)
    # model = WideResNet(34, class_num, 10, dropRate=0.0)

    prune_conv_linear(model)
    # for key in model.state_dict():
    #     print(key, " : ", model.state_dict()[key].shape)

    pruner_type = "grasp"
    device = torch.device('cuda:{}'.format(0))
    model.to(device)
    pruner = get_pruner(pruner_type)(masked_parameters(model)) 
    iteration_number = 1
    save_flag = "grasp"

    save_mask(model, sparsity, pruner, iteration_number, save_flag, args.seed)
    
    ####################### SynFlow ##############################
    # model = vgg16_bn(num_classes = class_num)
    model = ResNet18(num_classes = class_num)

    # model = WideResNet(34, class_num, 10, dropRate=0.0)

    prune_conv_linear(model)
    # for key in model.state_dict():
    #     print(key, " : ", model.state_dict()[key].shape)

    pruner_type = "synflow"
    device = torch.device('cuda:{}'.format(0))
    model.to(device)
    pruner = get_pruner(pruner_type)(masked_parameters(model)) 
    iteration_number = 100
    save_flag = "synflow"

    save_mask(model, sparsity, pruner, iteration_number, save_flag, args.seed, train_mode = False)

    '''

    ####################### omp ###########################
    model = ResNet18(num_classes = class_num)
    prune_conv_linear(model)
    checkpoint = torch.load(os.path.join(r'model_dense/dense_resnet18_cifar100_b128_ce05', 'checkpoint.pth.tar'), map_location = device)
    # strict=False silently skips mismatched keys. Report them so that a
    # checkpoint that did not actually load (leaving random weights to be
    # magnitude-pruned) is visible rather than silent.
    load_result = model.load_state_dict(checkpoint['state_dict'], strict = False)
    if load_result.missing_keys or load_result.unexpected_keys:
        print('Warning: {} missing / {} unexpected keys while loading checkpoint'.format(
            len(load_result.missing_keys), len(load_result.unexpected_keys)))
        print('  missing keys: {}'.format(load_result.missing_keys))
        print('  unexpected keys: {}'.format(load_result.unexpected_keys))
    pruner_type = "mag"
    model.to(device)
    pruner = get_pruner(pruner_type)(masked_parameters(model))  # args.pruner in [Mag, SynFlow, Taylor1ScorerAbs, Rand, SNIP, GraSP]
    iteration_number = 1
    save_flag = "resnet18-100-omp"

    save_mask(model, sparsity, pruner, iteration_number, save_flag, args.seed, train_mode = False, class_num = class_num)

# Disabled experiment, kept as an inert module-level string: builds a TP mask
# with the "taylor1scorerabs" pruner from a trained dense VGG16-BN checkpoint.
# The non-English note inside the string reads: "tp uses the trained dense model".
# NOTE(review): if re-enabled, the save_mask(...) call below omits the required
# `seed` argument and would raise TypeError. The checkpoint path is an absolute,
# machine-specific path.
'''
    ####################### tp ###########################
    model = vgg16_bn(num_classes = class_num)
    prune_conv_linear(model)
    #tp 使用训练的dense model
    checkpoint = torch.load(os.path.join(r'/home/pengjun/code_workspace/dst-adv/new_res_dense/vgg16_bn_cifar10_dense_b128_9fc4', 'checkpoint.pth.tar'), map_location = torch.device('cuda:0'))
    model.load_state_dict(checkpoint['state_dict'], strict = False)
    pruner_type = "taylor1scorerabs"
    device = torch.device('cuda:{}'.format(0))
    model.to(device)
    pruner = get_pruner(pruner_type)(masked_parameters(model)) 
    iteration_number = 1
    save_flag = "tp"

    save_mask(model, sparsity, pruner, iteration_number, save_flag)
    
'''
 

    