from __future__ import print_function

import argparse
import os
import random
import shutil
import time

import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim as optim
import torch.utils.data as data
import torchvision
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from spiking_resnet import *
from snnwrn import *
#from snnresnet6 import *
from utils import Bar, Logger, AverageMeter, accuracy, mkdir_p, savefig
from spikingjelly.clock_driven import functional
from torch.autograd import Variable



def _fraction(value):
    """argparse ``type`` for ``--percent``: parse a float and require 0 <= value < 1.

    The pruning code indexes a sorted tensor of ``total`` weights at
    ``int(total * percent)``; values >= 1 would index past the end and negative
    values would silently wrap around from the end.
    """
    fraction = float(value)
    if not 0.0 <= fraction < 1.0:
        raise argparse.ArgumentTypeError('{!r} is not in the range [0, 1)'.format(value))
    return fraction


parser = argparse.ArgumentParser(description='PyTorch CIFAR10/100 Training')
# Datasets
parser.add_argument('-d', '--dataset', default='cifar10', type=str)
parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
                    help='number of data loading workers (default: 4)')
# Optimization options
parser.add_argument('--epochs', default=300, type=int, metavar='N',
                    help='number of total epochs to run')
parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
                    help='manual epoch number (useful on restarts)')
parser.add_argument('--train-batch', default=128, type=int, metavar='N',
                    help='train batchsize')
parser.add_argument('--test-batch', default=100, type=int, metavar='N',
                    help='test batchsize')
parser.add_argument('--lr', '--learning-rate', default=0.1, type=float,
                    metavar='LR', help='initial learning rate')
parser.add_argument('--drop', '--dropout', default=0, type=float,
                    metavar='Dropout', help='Dropout ratio')
parser.add_argument('--schedule', type=int, nargs='+', default=[150, 225],
                    help='Decrease learning rate at these epochs.')
parser.add_argument('--gamma', type=float, default=0.1, help='LR is multiplied by gamma on schedule.')
parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
                    help='momentum')
parser.add_argument('--weight-decay', '--wd', default=1e-4, type=float,
                    metavar='W', help='weight decay (default: 1e-4)')
parser.add_argument('--resume', default='', type=str, metavar='PATH',
                    help='path to latest checkpoint (default: none)')
# Architecture

parser.add_argument('--depth', type=int, default=29, help='Model depth.')
parser.add_argument('--cardinality', type=int, default=8, help='Model cardinality (group).')
parser.add_argument('--widen-factor', type=int, default=4, help='Widen factor. 4 -> 64, 8 -> 128, ...')
parser.add_argument('--growthRate', type=int, default=12, help='Growth rate for DenseNet.')
parser.add_argument('--compressionRate', type=int, default=1, help='Compression Rate (theta) for DenseNet.')
# Miscs
parser.add_argument('--manualSeed', type=int, help='manual seed')
parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true',
                    help='evaluate model on validation set')

parser.add_argument('--save_dir', default='test_checkpoint/', type=str)
# NOTE: --percent is a fraction, not a percentage; '%%' is argparse's escape for '%'.
parser.add_argument('--percent', default=0.1, type=_fraction,
                    help='fraction of conv weights to prune, in [0, 1) (e.g. 0.1 prunes 10%%)')

args = parser.parse_args()
state = {k: v for k, v in args._get_kwargs()}

# Validate dataset. parser.error (rather than assert) keeps the check active under
# `python -O` and reports the problem as a normal CLI usage error.
if args.dataset not in ('cifar10', 'cifar100'):
    parser.error('Dataset can only be cifar10 or cifar100.')

# Use CUDA when it is available.
use_cuda = torch.cuda.is_available()

# Random seed: pick one when not given so every run is still seeded explicitly.
if args.manualSeed is None:
    args.manualSeed = random.randint(1, 10000)
random.seed(args.manualSeed)
torch.manual_seed(args.manualSeed)
if use_cuda:
    torch.cuda.manual_seed_all(args.manualSeed)

best_acc = 0  # best test accuracy
# Fall back to CPU so the script does not crash on hosts without a GPU.
device = "cuda:0" if use_cuda else "cpu"
# Per-channel normalization statistics applied to the CIFAR-100 test images
# (presumably the same ones used when the checkpoint was trained -- confirm).
_CIFAR100_MEAN = [0.5070751592371323, 0.48654887331495095, 0.4409178433670343]
_CIFAR100_STD = [0.2673342858792401, 0.2564384629170883, 0.27615047132568404]
# Checkpoint pruned when --resume is not given.
_DEFAULT_CHECKPOINT = './best.pth.tar'


def _build_cifar100_test_loader(batch_size, num_workers):
    """Return an unshuffled DataLoader over the CIFAR-100 test split.

    The dataset lives under ``./`` and is downloaded there if missing.
    """
    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(_CIFAR100_MEAN, _CIFAR100_STD),
    ])
    testset = torchvision.datasets.CIFAR100(
        root='./', train=False, download=True, transform=transform_test)
    return torch.utils.data.DataLoader(
        testset, batch_size=batch_size, shuffle=False, num_workers=num_workers)


def _load_model(checkpoint_path):
    """Build ``spiking_resnet18`` on ``device`` and load ``checkpoint['state_dict']`` into it."""
    model = spiking_resnet18().to(device)
    # Load on CPU first; load_state_dict copies the tensors onto the model's device.
    checkpoint = torch.load(checkpoint_path, map_location='cpu')
    model.load_state_dict(checkpoint['state_dict'])
    return model


def _prunable_convs(model):
    """Yield ``(index, module)`` for every ``nn.Conv2d`` in ``model`` whose kernel is not 1x1.

    ``index`` is the module's position in ``model.modules()`` and is only used for logging.
    """
    for index, module in enumerate(model.modules()):
        if isinstance(module, nn.Conv2d) and module.kernel_size != (1, 1):
            yield index, module


def _global_magnitude_prune(model, percent):
    """Zero, in place, the smallest-magnitude ``percent`` fraction of prunable conv weights.

    One global threshold is taken over |w| of every non-1x1 ``nn.Conv2d``; weights
    with |w| <= threshold are set to zero. The threshold is the value at sorted
    position ``int(total * percent)``, so at least one weight is pruned even when
    ``percent`` is 0 (same rule as the original script).

    Returns:
        ``(total, pruned, zero_flag)``: number of prunable weights (int), number of
        weights zeroed (int), and whether any layer was left with no weights.

    Raises:
        ValueError: if ``percent`` is outside [0, 1) or the model has no prunable weights.
    """
    if not 0.0 <= percent < 1.0:
        raise ValueError('percent must be in [0, 1), got {!r}'.format(percent))
    convs = list(_prunable_convs(model))
    with torch.no_grad():
        magnitudes = [m.weight.detach().abs().flatten().cpu() for _, m in convs]
        total = sum(t.numel() for t in magnitudes)
        if total == 0:
            raise ValueError('model has no prunable (non-1x1) nn.Conv2d weights')
        sorted_weights, _ = torch.sort(torch.cat(magnitudes))
        # percent < 1 keeps the index below total; min() also guards float rounding.
        thre_index = min(int(total * percent), total - 1)
        thre = sorted_weights[thre_index].item()
        print('Pruning threshold: {}'.format(thre))

        pruned = 0
        zero_flag = False
        for index, module in convs:
            keep = module.weight.abs().gt(thre)
            # Count on a bool mask (int64 sum) so the totals are exact Python ints.
            remaining = int(keep.sum().item())
            pruned += keep.numel() - remaining
            module.weight.mul_(keep)
            if remaining == 0:
                zero_flag = True
            print('layer index: {:d} \t total params: {:d} \t remaining params: {:d}'.format(
                index, keep.numel(), remaining))
    return total, pruned, zero_flag


def main():
    """Evaluate a trained spiking ResNet-18, magnitude-prune its conv weights, re-evaluate.

    1. Build the CIFAR-100 test loader (``--test-batch``, ``--workers``).
    2. Build ``spiking_resnet18`` and load weights from ``--resume``, or from
       ``./best.pth.tar`` when ``--resume`` is empty.
    3. Measure test accuracy, zero the smallest-magnitude ``--percent`` fraction of
       all non-1x1 conv weights (one global threshold), then measure accuracy again.
    4. Save the pruned weights via ``save_checkpoint`` into ``--save_dir`` and write
       a summary to ``<save_dir>/prune.txt``.
    """
    start_epoch = args.start_epoch  # only forwarded to test()

    if not os.path.isdir(args.save_dir):
        mkdir_p(args.save_dir)

    # Data: evaluation is always on CIFAR-100 here; --dataset is not honoured.
    if args.dataset != 'cifar100':
        print('Warning: --dataset %s is ignored; evaluating on CIFAR-100.' % args.dataset)
    print('==> Preparing dataset cifar100')
    testloader = _build_cifar100_test_loader(args.test_batch, args.workers)

    # Model that will be pruned.
    model = _load_model(args.resume or _DEFAULT_CHECKPOINT)
    cudnn.benchmark = True
    print('    Total params: %.2fM' % (sum(p.numel() for p in model.parameters()) / 1000000.0))

    criterion = nn.CrossEntropyLoss()

    print('\nEvaluation only')
    test_acc0 = test(testloader, model, criterion, start_epoch, use_cuda)
    print('Before pruning:Test Acc:  %.2f' % (test_acc0))

    total, pruned, zero_flag = _global_magnitude_prune(model, args.percent)
    print('Total conv params: {}, Pruned conv params: {}, Pruned ratio: {}'.format(
        total, pruned, pruned / total))

    print('\nTesting')
    test_acc1 = test(testloader, model, criterion, start_epoch, use_cuda)
    print('After Pruning:  Test Acc:  %.2f' % (test_acc1))
    save_checkpoint({
        'epoch': 0,
        'state_dict': model.state_dict(),
        'acc': test_acc1,
        'best_acc': 0.,
    }, False, checkpoint=args.save_dir)

    with open(os.path.join(args.save_dir, 'prune.txt'), 'w') as f:
        f.write('Before pruning: Test Acc:  %.2f\n' % (test_acc0))
        f.write('Total conv params: {}, Pruned conv params: {}, Pruned ratio: {}\n'.format(
            total, pruned, pruned / total))
        f.write('After Pruning: Test Acc:  %.2f\n' % (test_acc1))
        if zero_flag:
            f.write("There exists a layer with 0 parameters left.")

def test(dataloader, model, criterion, epoch, use_cuda):
    """Return the top-1 accuracy of ``model`` on ``dataloader`` as a fraction in [0, 1].

    The model is switched to eval mode and run under ``torch.no_grad()``, because
    evaluation needs no autograd graph. After every batch,
    ``functional.reset_net(model)`` resets the network's stateful (spiking)
    modules so the next batch starts from a clean state.

    Args:
        dataloader: iterable of ``(inputs, labels)`` batches; labels are compared
            against the index of the max output along dim 1.
        model: network to evaluate; inputs are moved to the module-level ``device``.
        criterion, epoch, use_cuda: unused; kept so existing call sites keep working.

    Raises:
        ValueError: if ``dataloader`` yields no samples (accuracy is undefined).
    """
    model.eval()

    correct_sum = 0
    test_sum = 0
    with torch.no_grad():
        for data_batch, labels_batch in dataloader:
            data_batch = data_batch.to(device)
            labels_batch = labels_batch.to(device)

            output_batch = model(data_batch)
            correct_sum += int((output_batch.max(1)[1] == labels_batch).sum().item())
            test_sum += labels_batch.numel()
            functional.reset_net(model)

    if test_sum == 0:
        raise ValueError('dataloader yielded no samples; accuracy is undefined')

    # Use the counted samples (not len(dataset)) so the printed and returned
    # accuracies always agree.
    print('\nTest set: Accuracy: {}/{} ({:.2f}%)\n'.format(
        correct_sum, test_sum, 100. * correct_sum / test_sum))
    return correct_sum / test_sum

def save_checkpoint(state, is_best, checkpoint, filename='pruned.pth.tar'):
    """Serialize ``state`` with ``torch.save`` to ``os.path.join(checkpoint, filename)``.

    Args:
        state: object to save (the caller passes a dict holding the pruned ``state_dict``).
        is_best: accepted for call-site compatibility; currently ignored.
        checkpoint: destination directory; created if it does not exist. An empty
            string means the current working directory.
        filename: file name inside ``checkpoint``.

    Returns:
        The path that was written.
    """
    if checkpoint:
        os.makedirs(checkpoint, exist_ok=True)
    filepath = os.path.join(checkpoint, filename)
    torch.save(state, filepath)
    return filepath

if __name__ == '__main__':
    # Script entry point: evaluate, prune, re-evaluate and save via main().
    main()
