import argparse
import time
from copy import deepcopy
from math import log, sqrt
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import json
import sys
from Utils import load
from Utils import generator
from Utils import metrics
from Layers import layers
from train import *
from prune import *
import warnings


def jls_extract_def(args, optimizer, milestones, gamma, train_loader, iters, last_epoch):
    """Build the learning-rate scheduler selected by ``args``.

    Args:
        args: Namespace providing ``lr_scheduler``, ``post_epochs``,
            ``lr_step_size``, ``lr_drops`` and ``lr_drop_rate``.
        optimizer: Optimizer whose learning rate is scheduled.
        milestones: Unused; kept so existing call sites keep working.
        gamma: Unused; kept so existing call sites keep working.
        train_loader: Sized iterable of training batches; its length times
            ``args.post_epochs`` is the horizon of the ``'linear'`` schedule.
        iters: Unused; kept so existing call sites keep working.
        last_epoch: Unused; kept so existing call sites keep working.

    Returns:
        A ``MultiStepLR`` when ``args.lr_scheduler == 'drop'`` or
        ``args.post_epochs == 0``; a ``LambdaLR`` with a linearly decaying
        multiplier when ``args.lr_scheduler == 'linear'``.

    Raises:
        ValueError: If ``args.lr_scheduler`` names no supported schedule
            (and ``args.post_epochs`` is non-zero).
    """
    if args.lr_scheduler == 'drop' or args.post_epochs == 0:
        if args.lr_step_size == 0:
            # No periodic step configured: use the explicit drop epochs.
            drop_epochs = args.lr_drops
        else:
            # Every multiple of the step size is a milestone; epoch 0 is a
            # multiple too, so it is included.
            drop_epochs = [epoch for epoch in range(args.post_epochs) if epoch % args.lr_step_size == 0]
        return torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=drop_epochs, gamma=args.lr_drop_rate)

    if args.lr_scheduler == 'linear':
        total_iters = len(train_loader) * args.post_epochs
        # Multiplier falls linearly from 1 to 0 over total_iters scheduler
        # steps and stays at 0 afterwards.
        return torch.optim.lr_scheduler.LambdaLR(
            optimizer,
            lambda step: (1.0 - step / total_iters) if step <= total_iters else 0,
            last_epoch=-1)

    # Previously this case fell through to `return scheduler` with the name
    # unbound, which surfaced as an opaque UnboundLocalError.
    raise ValueError("Unsupported lr_scheduler: {!r} (expected 'drop' or 'linear')".format(args.lr_scheduler))


def run(args):
    """Run one pruning experiment: prune a model, post-train it, report results.

    The steps, in order: seed the RNGs and pick the device, build the data
    loaders and the model, compute pruning masks according to ``args.pruner``
    and ``args.mask_scope``, optionally run a speed test, post-train the
    pruned model, then write a summary and (optionally) the artefacts.

    Args:
        args: Namespace of experiment options; every ``args.<name>`` read in
            the body must be present.

    Side effects:
        * Writes the model definition to ``{args.result_dir}/model_def``.
        * Appends the result summary to ``{args.result_dir}/prune_result``.
        * Appends a one-line result tuple to
          ``Results/data/{args.gather_result_path}`` when that option is set.
        * When ``args.save`` is truthy, saves the result frames and the
          model/optimizer/scheduler state dicts under ``args.result_dir``.
        * Calls ``sys.exit()`` when every mask density is zero, and after a
          speed test when ``args.speed_test`` is set.

    Raises:
        AssertionError: If a requested sparsity is outside ``(0, 1]``.
        NotImplementedError: For an unsupported ``args.mask_scope`` with the
            ``opt_*`` pruners, or an unsupported ``args.speed_test`` mode.
    """
    ## Random Seed and Device ##
    warnings.filterwarnings("ignore")
    if args.seed is not None:
        torch.manual_seed(args.seed)
        np.random.seed(args.seed)
    device = load.device(args.gpu)

    ## Data ##
    # print('Loading {} prune and validation set.'.format(args.dataset))
    input_shape, num_classes = load.dimension(args.dataset)
    prune_loader = load.dataloader(args.dataset,
                                   args.prune_batch_size,
                                   True,
                                   args.workers,
                                   args.prune_dataset_ratio * num_classes,
                                   data_dir=args.data_dir)
    train_loader = load.dataloader(args.dataset, args.train_batch_size, True, args.workers, data_dir=args.data_dir)
    test_loader = load.dataloader(args.dataset, args.test_batch_size, False, args.workers, data_dir=args.data_dir)

    ## Model, Loss, Optimizer ##
    if args.verbose:
        print('Creating {}-{} model.'.format(args.dataset, args.model))
    model = load.model(args.model, args.dataset)(groups=args.groups, width_factor=args.width_factor)
    model = model.to(device)

    # args.prune_pw_only = False
    # if args.mask_scope == 'global':
    #     args.no_prune_linear = True
    #     if args.pruner == 'synflow':
    #         # args.prune_epochs = 100
    #         args.train_mode = False
    #     else:
    #         args.prune_epochs = 1
    #         args.train_mode = True
    # else:
    #     args.prune_epochs = 1
    #     args.no_prune_linear = False
    #     if 'resnet' not in args.model and 'vgg' not in args.model:
    #         args.prune_pw_only = True

    # The returned counts are not used below; the call is kept in case
    # metrics.stats has side effects on the model.
    conv1x1, conv3x3, batchnorm, other_conv = metrics.stats(model, skip_last=args.skip_last)
    sparsity = args.compression
    assert 0 < sparsity <= 1
    if args.pruner in ['opt_flops', 'opt_both']:
        # A separate FLOP target is optional; fall back to the parameter target.
        if args.compression_flops is None:
            flops_sparsity = args.compression
        else:
            flops_sparsity = args.compression_flops
        assert 0 < flops_sparsity <= 1
    loss = nn.CrossEntropyLoss().to(device)
    opt_class, opt_kwargs = load.optimizer(args.optimizer)
    optimizer = opt_class(generator.parameters(model), lr=args.lr, weight_decay=args.weight_decay, **opt_kwargs)
    scheduler = load.scheduler(args, optimizer, train_loader)

    # Prune ##
    if args.prune_pw_only:
        params_generator = generator.pointwise_parameters
    else:
        params_generator = generator.masked_parameters
    params = params_generator(model, args.prune_bias, args.prune_batchnorm, args.prune_residual, args.no_prune_linear, args.prune_shortcut)
    pruner = load.pruner(args.pruner)(masked_parameters=params, skip_last=args.skip_last, prune_pw_only=args.prune_pw_only)
    # Second pruner over all parameters; used for detached-weight pruning and
    # for the final per-layer scores / mask densities.
    all_pruner = load.pruner('synflow')(masked_parameters=generator.all_parameters(model), skip_last=False, prune_pw_only=False)
    all_pruner.dataloader = prune_loader
    all_pruner.device = device
    all_pruner.model = model
    if args.pruner in ['opt_params', 'opt_flops', 'opt_both']:
        if args.pruner != 'opt_both':
            # NOTE(review): these two calls pass fewer positional arguments
            # than the 'opt_both' branch below (no args.no_prune_linear);
            # confirm which form matches the signatures in Utils.load.
            if args.pruner == 'opt_params':
                alpha = load.get_params(model, args.prune_pw_only)
            elif args.pruner == 'opt_flops':
                alpha = load.get_flops(model, args.prune_pw_only, input_shape, device)
            sparsity_list = single_solve(alpha, sparsity)
        else:
            alpha = load.get_params(model, args.prune_pw_only, args.no_prune_linear)
            beta = load.get_flops(model, args.prune_pw_only, args.no_prune_linear, input_shape, device)
            assert len(alpha) == len(beta)
            expand_limit = 4.0 if args.expand else 1.0
            sparsity_list = double_solve(alpha, beta, sparsity, flops_sparsity, args.expand, expand_limit)
            # prevent extreme expansion
        # rule form sanity-check paper, with L = 21
        # sparsity_list = np.array([(L - l)**2 + (L - l) for l in range(L)], dtype=np.float)
        # sparsity_list /= max(sparsity_list)
        # sparsity_list = np.clip(sparsity_list * (sparsity / 0.14)**2, 0, 1)
        if args.verbose:
            print('solved sparsity = ', sparsity_list)
            print('length of modules = ', len(model.pruned_types))
        # set parameters for pruning detached weights
        if args.mask_scope == 'global':
            pruner.sparsity_list = sparsity_list
            pruner.mask(sparsity, 'random_weight', args.verbose)
            all_pruner.prune_detatched()
        elif args.mask_scope == 'filter':
            pruner.sparsity_list = sparsity_list
            pruner.mask(sparsity, 'random_filter', args.verbose)
            all_pruner.prune_detatched()
        elif args.mask_scope == 'precropping':
            # use sparsity_list[:-1] to skip the FC layer
            pruner.score(model, loss, prune_loader, device)
            pruner.precropping(sparsity_list[:-1], sparsity, verbose=args.verbose)
        else:
            raise NotImplementedError
    elif args.pruner != 'lottery':
        prune_loop(model, loss, pruner, prune_loader, device, sparsity, args.compression_schedule, args.mask_scope, args.prune_epochs,
                   args.reinitialize, args.reinitialize_sparse, args.prune_train_mode, args.shuffle, args.uniform_shuffle, args.invert, args.verbose)
        # all_pruner.prune_detatched()
    else:
        # 'lottery': lt_prune works with the original model and a deep copy,
        # each with its own pruner; the returned model is used from here on.
        new_model = deepcopy(model)
        new_params = params_generator(new_model, args.prune_bias, args.prune_batchnorm, args.prune_residual, args.no_prune_linear,
                                      args.prune_shortcut)
        new_pruner = load.pruner(args.pruner)(masked_parameters=new_params, skip_last=args.skip_last, prune_pw_only=args.prune_pw_only)
        model = lt_prune(model, new_model, pruner, new_pruner, sparsity, args.mask_scope, loss, optimizer, scheduler, train_loader, test_loader,
                         device, args.post_epochs, args.verbose, args.result_dir, args.resume)
        pruner = new_pruner
        if args.shuffle:
            new_pruner.shuffle()
        if args.reinitialize:
            model._initialize_weights(sparse_init=args.reinitialize_sparse)

    all_pruner.score(model, None, prune_loader, device, False)
    # Per-parameter fraction of mask entries that are set.
    sparsity_list = [m.sum().item() / m.numel() for m, p in all_pruner.masked_parameters]
    if args.verbose:
        print(sparsity_list)
    # Nothing left to train when every mask density is zero.
    if sum(sparsity_list) == 0:
        sys.exit()

    # norm_list = []
    # for name, module in model.named_modules():
    #     if isinstance(module, layers.Conv2d):
    #         norm_list.append((module.weight * module.weight_mask).norm().item())
    # print(norm_list)
    # assert False

    if args.mask_scope == 'pregrouping':
        if not hasattr(pruner, 'model'):
            pruner.model = model
        pruner.set_group_number()
        # Rebuild the pruner over the masked parameters and re-initialise the
        # weights; sparsity is reset to 1 afterwards.
        params = generator.masked_parameters(model, args.prune_bias, args.prune_batchnorm, args.prune_residual, args.no_prune_linear,
                                             args.prune_shortcut)
        pruner = load.pruner(args.pruner)(params, skip_last=args.skip_last, prune_pw_only=False)
        model._initialize_weights()
        model.to(device)
        sparsity = 1

    if args.parallel:
        model = nn.DataParallel(model).to(device)
    # Use a context manager so the file is flushed and closed immediately.
    with open(f'{args.result_dir}/model_def', 'w') as model_def_file:
        print(model, file=model_def_file)

    if args.speed_test:
        print('Speed testing')
        torch.backends.cudnn.benchmark = True
        print(sum([p.numel() for p in model.parameters()]))
        print(model)
        if args.no_cuda:
            model.cpu()
        if args.speed_test == 'inference':
            backward = False
        elif args.speed_test == 'training':
            backward = True
        else:
            raise NotImplementedError()
        run_iter = 100
        dummy_input = next(iter(test_loader))
        avg_time = speed_test(model, loss, dummy_input, device, run_iter=run_iter, backward=backward, cuda=not args.no_cuda)
        print(f'avg time per batch = {avg_time}')
        # Speed-test mode ends here; no post-training is performed.
        sys.exit()

    # Post-Train ##
    if args.print_model:
        print(model)
    torch.cuda.empty_cache()
    # Fresh optimizer and scheduler for post-training: the model object may
    # have been replaced or wrapped since the first pair was created.
    optimizer = opt_class(generator.parameters(model), lr=args.lr, weight_decay=args.weight_decay, **opt_kwargs)
    scheduler = load.scheduler(args, optimizer, train_loader)
    post_result, best_acc = train_eval_loop(model, loss, optimizer, scheduler, train_loader, test_loader, device, args.post_epochs, args.verbose)
    # total_params is recomputed from the summary below; remaining_params is
    # not used. The call is kept as in the original flow.
    remaining_params, total_params = pruner.stats()

    ## Display Results ##
    frames = [post_result.head(1), post_result.tail(1)]
    train_result = pd.concat(frames, keys=['Post-Prune', 'Final'])
    prune_result = metrics.summary(model, all_pruner.scores, metrics.flop(model, input_shape, device, pw_only=False),
                                   lambda p: generator.prunable(p, batchnorm=True, residual=True, linear=True, shortcut=True))
    # lambda p: generator.prunable(p, args.prune_batchnorm, args.prune_residual, args.no_prune_linear, args.prune_shortcut))
    total_params = int((prune_result['sparsity'] * prune_result['size']).sum())
    possible_params = prune_result['size'].sum()
    total_flops = int((prune_result['sparsity'] * prune_result['flops']).sum())
    possible_flops = prune_result['flops'].sum()

    result_string = f"Best validation accuracy = {best_acc} \n"
    result_string += f"Train results:{train_result}\n"
    result_string += f"Prune results:{prune_result}\n"
    result_string += "Parameter Sparsity: {}/{} ({:.4f})\n".format(total_params, possible_params, total_params / possible_params)
    result_string += "FLOP Sparsity: {}/{} ({:.4f})\n".format(total_flops, possible_flops, total_flops / possible_flops)

    if args.verbose:
        print(result_string)
    with open(f'{args.result_dir}/prune_result', 'a') as result_file:
        print(result_string, file=result_file)

    if args.gather_result_path is not None:
        # Formats the three values as a tuple literal, e.g. "(acc, params, flops)".
        gather_result_string = f"{best_acc, total_params, total_flops}"
        with open(f'Results/data/{args.gather_result_path}', 'a') as gather_file:
            print(gather_result_string, file=gather_file)

    ## Save Results and Model ##
    if args.save:
        # print('Saving results.')
        post_result.to_pickle("{}/post-train.pkl".format(args.result_dir))
        prune_result.to_pickle("{}/compression.pkl".format(args.result_dir))
        torch.save(model.state_dict(), "{}/model.pt".format(args.result_dir))
        torch.save(optimizer.state_dict(), "{}/optimizer.pt".format(args.result_dir))
        torch.save(scheduler.state_dict(), "{}/scheduler.pt".format(args.result_dir))
