import copy
import warnings
from functools import partial

import torch
import torch.nn as nn
import torch.optim as optim
import torch_pruning as tp

from utils.options import args


def get_pruner(model, example_inputs, num_classes):
    """Construct a global ``tp.pruner.GroupNormPruner`` for ``model``.

    The pruner scores groups with L2 ``GroupNormImportance``. It prunes
    globally over a 100-step schedule toward a channel sparsity of 1.0.
    Every ``nn.Linear`` whose ``out_features`` equals ``num_classes`` is
    excluded from pruning; presumably that is the classification head.

    When ``args.cfg`` contains ``'vit'``, two extra settings apply:
    ``model.encoder.pos_embedding`` and ``model.class_token`` are passed as
    unwrapped parameters, and the first encoder layer's ``num_heads`` is
    passed as ``round_to``.

    Args:
        model: Network to be pruned.
        example_inputs: Inputs the pruner uses to trace ``model``.
        num_classes: Output width used to identify the layers to skip.

    Returns:
        The configured pruner instance.
    """
    is_vit = 'vit' in args.cfg

    if is_vit:
        # Parameters not owned by a prunable layer must be handed to the
        # pruner explicitly so their channels are pruned consistently.
        extra_params = [model.encoder.pos_embedding, model.class_token]
    else:
        extra_params = None

    # Linear layers producing exactly `num_classes` outputs are left intact.
    skip_layers = [
        layer for layer in model.modules()
        if isinstance(layer, nn.Linear) and layer.out_features == num_classes
    ]

    head_multiple = model.encoder.layers[0].num_heads if is_vit else None

    return tp.pruner.GroupNormPruner(
        model,
        example_inputs,
        global_pruning=True,
        importance=tp.importance.GroupNormImportance(p=2),
        iterative_steps=100,
        ch_sparsity=1.0,
        ch_sparsity_dict={},
        max_ch_sparsity=1.0,
        ignored_layers=skip_layers,
        round_to=head_multiple,
        unwrapped_parameters=extra_params,
    )


def prune_to_target_flops(pruner, model, example_inputs):
    """Step ``pruner`` until ``model``'s op count reaches the FLOPs target.

    The target is ``(1 - args.target_flops_PR) * ori_ops``. ``ori_ops`` is
    the count that ``tp.utils.count_ops_and_params`` reports before this call
    prunes anything. ``model`` is switched to eval mode and pruned in place.

    The loop is bounded by the pruner's ``iterative_steps`` schedule, when the
    pruner exposes one. If the target has not been reached by then, a warning
    is emitted and the loop stops instead of spinning forever.

    Args:
        pruner: torch_pruning pruner constructed for ``model``.
        model: Network being pruned; modified in place.
        example_inputs: Inputs passed to ``count_ops_and_params``.

    Returns:
        The op count of ``model`` after the final pruning step. This may
        still be above the target if the pruning schedule was exhausted.
    """
    model.eval()
    ori_ops, _ = tp.utils.count_ops_and_params(model,
                                               example_inputs=example_inputs)
    target_ops = (1 - args.target_flops_PR) * ori_ops

    # torch_pruning pruners follow a finite schedule of `iterative_steps`.
    # Past it, step() prunes nothing, so without a bound an unreachable target
    # (e.g. due to round_to or ignored layers) would loop forever.
    max_steps = getattr(pruner, 'iterative_steps', None)

    pruned_ops = ori_ops
    num_steps = 0
    while pruned_ops > target_ops:
        if max_steps is not None and num_steps >= max_steps:
            warnings.warn(
                f'Pruning schedule exhausted after {num_steps} steps: '
                f'{pruned_ops} ops remain, target was {target_ops}.')
            break
        pruner.step()
        num_steps += 1
        if 'vit' in args.cfg:
            # Mirror the pruned patch-projection width into `hidden_dim`.
            # Presumably the ViT forward reads it when reshaping patches
            # (torchvision-style ViT); confirm against the model definition.
            model.hidden_dim = model.conv_proj.out_channels
        pruned_ops, _ = tp.utils.count_ops_and_params(
            model, example_inputs=example_inputs)

    return pruned_ops


def prune_groupnorm(model, example_inputs, num_classes):
    """Return a GroupNorm-pruned deep copy of ``model``.

    The caller's ``model`` is left untouched. A deep copy is built, a pruner
    from ``get_pruner`` is attached to it, and the copy is pruned in place
    by ``prune_to_target_flops``. The pruned network is printed before it is
    returned.

    Args:
        model: Network to prune; not modified.
        example_inputs: Inputs used for tracing and op counting.
        num_classes: Output width whose ``nn.Linear`` layers are not pruned.

    Returns:
        The pruned copy of ``model``.
    """
    pruned_model = copy.deepcopy(model)
    group_pruner = get_pruner(pruned_model, example_inputs, num_classes)
    prune_to_target_flops(group_pruner, pruned_model, example_inputs)
    print('model:', pruned_model)
    return pruned_model
