from __future__ import division

import os, time, random, math
import torch
import torch.nn as nn
from torchvision import models, transforms
import numpy as np
from thop import profile
import copy

def _is_scored_bn(name, module):
    """Return True for BatchNorm2d layers that take part in the global channel ranking.

    Layers whose qualified name contains 'bn3' or 'downsample' are excluded.
    Presumably these are the Bottleneck output BN and the shortcut BN, whose
    widths are tied to residual additions (confirm against the caller).
    """
    return (
        isinstance(module, nn.BatchNorm2d)
        and 'bn3' not in name
        and 'downsample' not in name
    )


def get_resnet_honey(model, channel_PR):
    """Count surviving channels per BN layer under a global |gamma| threshold.

    All scored BatchNorm2d layers (see ``_is_scored_bn``) are pooled. Their
    absolute scale factors are sorted, and the value at rank
    ``int(total * channel_PR)`` becomes a single global threshold. For each
    scored layer, in ``model.named_modules()`` order, the number of channels
    whose |gamma| is strictly greater than the threshold is reported. Every
    layer keeps at least one channel.

    Args:
        model: an ``nn.Module``, expected to be a ResNet-style network.
        channel_PR: fraction of the pooled channels used to locate the
            threshold; 0 selects the smallest scale, 1 the largest. Values
            outside [0, 1] are clamped to that range.

    Returns:
        list[int]: one channel count (>= 1) per scored BatchNorm2d layer, or
        ``[]`` if the model has no scored layers.
    """
    # Per-layer |gamma| vectors, moved to CPU so no GPU is required.
    gammas = [
        module.weight.detach().abs().flatten().cpu()
        for name, module in model.named_modules()
        if _is_scored_bn(name, module)
    ]
    if not gammas:
        return []

    ranked, _ = torch.sort(torch.cat(gammas))
    total = ranked.numel()
    # Clamp so that channel_PR == 1.0 (index == total) selects the largest
    # scale instead of indexing past the end of the sorted tensor.
    thre_index = min(max(int(total * channel_PR), 0), total - 1)
    thre = ranked[thre_index]

    honey = []
    for gamma in gammas:
        # Channels tied with the threshold are pruned (strict '>').
        kept = int((gamma > thre).sum())
        # Never prune a layer away entirely.
        honey.append(max(kept, 1))
    return honey