import numpy as np
from models import *
import callbacks
import tensorflow as tf
from tensorflow.keras import models
from effective_masks import *
from utils import *
import logging

def lamp(weights,target_sparsity):
  """Compute LAMP scores for all weights and a global mask removing the lowest-scoring ones.

  Within one layer, with squared weights sorted ascending ``u[0] <= u[1] <= ...``, the score of
  ``u[i]`` is ``u[i] / (sum_{j>=i} u[j] + 1e-8)``.  Scores are mapped back to the weights'
  original positions and then ranked jointly across layers.

  Args:
    weights: sequence of per-layer weight arrays.
    target_sparsity: fraction of all weights to prune; the pruned count is
      ``int(target_sparsity * total)`` (truncated).

  Returns:
    (scores, masks): ``scores`` is a flat float64 array with one score per weight (layers
    concatenated in order); ``masks`` is a list of 0/1 float arrays shaped like the weights.
  """
  shapes=[weight.shape for weight in weights]
  counts=[np.prod(shape) for shape in shapes]
  counts_sum=[sum(counts[:layer]) for layer in range(len(shapes)+1)]
  layer_scores=[]
  for weight in weights:
    squared=(np.asarray(weight,dtype=np.float64)**2).reshape(-1)
    order=np.argsort(squared,kind='stable')
    sorted_squared=squared[order]
    # suffix[i] = sum(sorted_squared[i:]); vectorised instead of a per-element Python loop
    suffix=np.cumsum(sorted_squared[::-1])[::-1]
    score=np.empty_like(sorted_squared)
    # order[k] is the original position of the k-th smallest weight
    score[order]=sorted_squared/(suffix+1e-8)
    layer_scores.append(score)
  scores=np.concatenate(layer_scores)
  masks=np.ones((sum(counts)))
  masks[scores.argsort()[:int(target_sparsity*sum(counts))]]=0.
  masks=[masks[counts_sum[layer]:counts_sum[layer+1]].reshape(shapes[layer]) for layer in range(len(shapes))]
  return scores,masks

def magnitude_layerwise(weight,target_sparsity):
  """Return a 0/1 float mask (shaped like ``weight``) zeroing its smallest-magnitude entries.

  The number pruned is ``round(weight.size * target_sparsity)``, clipped to ``[0, weight.size]``.
  The clipping matters for slightly negative quotas, which ``redistribute_invalid_quotas`` can
  leave on the first layer.  Without it a negative count would slice ``argsort[:-k]`` and prune
  almost the whole layer instead of nothing.
  """
  shape=weight.shape
  flat_abs=np.abs(weight.reshape(-1))
  prune_count=int(round(len(flat_abs)*target_sparsity))
  prune_count=min(max(prune_count,0),len(flat_abs))
  mask=np.ones(flat_abs.shape)
  mask[flat_abs.argsort()[:prune_count]]=0
  return mask.reshape(shape)

def magnitude_global(weights,target_sparsity):
  """Globally prune the smallest-magnitude weights across all layers at once.

  Args:
    weights: sequence of per-layer weight arrays.
    target_sparsity: fraction of all weights to remove; the count is ``int(target_sparsity * total)``.

  Returns:
    (scores, masks): ``scores`` is the flat concatenation of ``|w|`` over the layers in order;
    ``masks`` holds one 0/1 float array per layer, shaped like that layer.
  """
  layer_shapes=[w.shape for w in weights]
  layer_sizes=[np.prod(s) for s in layer_shapes]
  total=sum(layer_sizes)
  scores=np.concatenate([np.abs(w.reshape(-1)) for w in weights])
  keep=np.ones((total))
  smallest_first=scores.argsort()
  keep[smallest_first[:int(target_sparsity*total)]]=0.
  # split the flat mask back at the layer boundaries
  pieces=np.split(keep,np.cumsum(layer_sizes)[:-1])
  return scores,[piece.reshape(s) for piece,s in zip(pieces,layer_shapes)]

def erk_quotas(target_sparsity,shapes,**kwargs):
  """Per-layer sparsities following the ERK rule.

  Each layer's density is proportional to ``sum(shape) / prod(shape)``.  The common scale is
  chosen so that the total number of kept weights is ``(1 - target_sparsity) * total``.  Layers
  whose sparsity comes out negative are then passed through ``redistribute_invalid_quotas``.
  Extra keyword arguments are accepted and ignored.
  """
  sizes=[np.prod(shape) for shape in shapes]
  ratios=[sum(shape)/size for shape,size in zip(shapes,sizes)]
  weighted_total=sum([size*ratio for size,ratio in zip(sizes,ratios)])
  scale=(sum(sizes)*(1-target_sparsity))/weighted_total
  return redistribute_invalid_quotas([1-scale*ratio for ratio in ratios],shapes)

def bs_force_igq(areas,Lengths,target_sparsity,tolerance,f_low,f_high):
  """Bisect on a scalar force to obtain per-layer sparsities meeting ``target_sparsity``.

  Under force ``f`` layer ``i`` keeps ``Lengths[i] / (f / areas[i] + 1)`` weights, and the
  overall sparsity ``1 - sum(kept) / sum(Lengths)`` grows with ``f``.  The bounds ``f_low`` and
  ``f_high`` are checked first.  The interval is then halved until a force gives an overall
  sparsity within ``tolerance`` of the target.

  This is written as a loop rather than recursion.  The recursive form never terminated when a
  midpoint hit the target exactly (neither bound moved), or when the bounds collapsed to
  adjacent floats.  In the latter case the midpoint's quotas are returned with a warning.

  Returns:
    list of per-layer sparsities ``1 - kept / Length``.
  """
  total=sum(Lengths)

  def kept_at(force):
    # number of weights each layer keeps under the given force
    return [Length/(force/area+1) for Length,area in zip(Lengths,areas)]

  def sparsities_of(lengths):
    return [1-length/Length for length,Length in zip(lengths,Lengths)]

  for bound in (f_low,f_high):
    lengths=kept_at(bound)
    if abs(1-sum(lengths)/total-target_sparsity)<tolerance:
      return sparsities_of(lengths)
  while True:
    force=float(f_low+f_high)/2
    lengths=kept_at(force)
    overall_sparsity=1-sum(lengths)/total
    if abs(overall_sparsity-target_sparsity)<tolerance or overall_sparsity==target_sparsity:
      return sparsities_of(lengths)
    if force==f_low or force==f_high:
      logging.warning(f"<pruning> IGQ bisection exhausted float precision: overall sparsity {overall_sparsity} vs target {target_sparsity}")
      return sparsities_of(lengths)
    if overall_sparsity<target_sparsity:
      f_low=force
    else:
      f_high=force

def igq_quotas(target_sparsity,shapes,**kwargs):
  """Per-layer sparsities from the force bisection in ``bs_force_igq``.

  Each layer of ``count`` weights enters with area ``1/count`` and length ``count``.  The force
  is searched over ``[0, 1e20]`` with a tolerance of one weight out of the total.  Extra keyword
  arguments are accepted and ignored.
  """
  sizes=[np.prod(shape) for shape in shapes]
  return bs_force_igq([1./size for size in sizes],list(sizes),target_sparsity,1./sum(sizes),0,1e20)

def redistribute_invalid_quotas(sparsities,shapes,**kwargs):
  """Push negative per-layer sparsities (density above 1) onto the preceding layer, in place.

  The walk goes from the last layer to the first.  A layer with negative sparsity is set to 0,
  and its shortfall of ``-sparsity * count`` pruned weights is taken from the previous layer, so
  the total number of pruned weights is unchanged.  The first layer has no predecessor.  If it is
  still negative by more than two weights, a warning is logged and the value is left as it is.
  Mutates and returns ``sparsities``.
  """
  sizes=[np.prod(shape) for shape in shapes]
  for idx in reversed(range(len(sparsities))):
    current=sparsities[idx]
    if not current<0:
      continue
    if idx>0:
      sparsities[idx-1]=(sizes[idx-1]*sparsities[idx-1]+sparsities[idx]*sizes[idx])/sizes[idx-1]
      sparsities[idx]=0
    elif abs(current*sizes[0])>2:
      logging.warning(f"<pruning> unable to redistribute density quotas: {sparsities}")
  return sparsities

def uniform_plus_quotas(target_sparsity,shapes,**kwargs):
  """Uniform+ quotas: first layer dense, last layer at most 0.8 sparse, the rest uniform.

  The pruned-weight budget of ``target_sparsity`` is spread over every layer except the first.
  Whatever the last layer cannot absorb beyond sparsity 0.8 is shared evenly by the
  intermediate layers.  With no intermediate layers that excess cannot be placed: a warning is
  logged and the target is not reached (previously this raised ZeroDivisionError, even when
  there was nothing to redistribute).

  Returns:
    numpy array with one sparsity per layer.  Values can exceed 1 when the target is infeasible.

  Raises:
    AssertionError: if the first layer is not convolutional (``len(shapes[0]) <= 2``).  This is
      raised explicitly so the check survives ``python -O``.
  """
  if len(shapes[0])<=2:
    logging.error("<pruning> uniform+ supports convolutional networks only.")
    raise AssertionError("<pruning> uniform+ supports convolutional networks only.")
  counts=[np.prod(shape) for shape in shapes]
  sparsity=target_sparsity*sum(counts)/sum(counts[1:])
  # pruned weights the last layer would need beyond its 0.8 cap
  to_distribute=max([0,(sparsity-0.8)*counts[-1]])
  middle_total=sum(counts[1:-1])
  if middle_total>0:
    additional_sparsity=to_distribute/middle_total
  else:
    additional_sparsity=0.
    if to_distribute>0:
      logging.warning("<pruning> uniform+ has no intermediate layers to absorb sparsity beyond 0.8 in the last layer; target not reached.")
  return np.concatenate([[0.],[sparsity+additional_sparsity]*(len(counts)-2),[min([sparsity,0.8])]])

def effective_correction_from_global_scores(model,tensors,scores,target_sparsity):
  """Masks pruning the lowest-scoring weights whose effective sparsity is at most ``target_sparsity``.

  ``scores`` is a flat array over all weights of the layers ``tensors``, with layers concatenated
  in order.  Pruning always removes a prefix of its ascending order.  Bisection finds the largest
  prefix length whose masks have an overall sparsity ``<= target_sparsity`` after
  ``effective_masks_synflow``.

  Returns:
    list of 0/1 float masks, one per layer, shaped like that layer's ``get_weights()[0]``.
  """
  shapes=[model.layers[layer].get_weights()[0].shape for layer in tensors]
  counts=[np.prod(shape) for shape in shapes]
  counts_sum=[sum(counts[:layer]) for layer in range(len(shapes)+1)]
  # The ranking never changes, so sort once instead of on every bisection step.
  ranking=scores.argsort()

  def masks_pruning(count):
    # 0/1 masks with the `count` lowest-scoring weights removed, split back into layers
    flat=np.ones(counts_sum[-1])
    flat[ranking[:count]]=0.
    return [flat[counts_sum[layer]:counts_sum[layer+1]].reshape(shapes[layer]) for layer in range(len(shapes))]

  # invariant: pruning `low` weights stays within the target; pruning `high` is taken to exceed it
  low,high=0,sum(counts)
  while high-low>1:
    middle=(high+low)//2
    middle_effective_sparsity=get_overall_direct_sparsity(effective_masks_synflow(model,tensors,masks_pruning(middle)))
    if middle_effective_sparsity<=target_sparsity:
      low=middle
    else:
      high=middle
  return masks_pruning(low)

def effective_correction_layerwise_scores_magnitude_pruning(model,tensors,func,scores,target_sparsity):
  """Layerwise magnitude-pruning masks whose effective sparsity is at most ``target_sparsity``.

  ``func`` maps an overall sparsity to per-layer sparsities and is called as
  ``func(target_sparsity=..., shapes=...)``.  A quota vector with any entry outside [0, 1] is
  treated as infeasible.  The search runs over the number of pruned weights ``c``, with
  requested overall sparsity ``c / total``:

  1. Bisect for the smallest ``c`` with feasible quotas (only if ``c = 0`` is infeasible).
  2. Bisect for the largest ``c`` with feasible quotas (only if ``c = total`` is infeasible).
  3. Within that range, bisect for the largest ``c`` whose masks have an overall sparsity of at
     most ``target_sparsity`` after ``effective_masks_synflow``.  The masks come from
     ``magnitude_layerwise`` applied to each layer of ``scores``.

  Args:
    model, tensors: model and the indices of its pruned layers (weights read via ``get_weights()[0]``).
    func: per-layer quota function.
    scores: per-layer arrays ranked by ``magnitude_layerwise``.
    target_sparsity: requested effective sparsity.

  Returns:
    list with one mask per layer (from step 3).  Also logs the requested, direct and effective
    sparsities.
  """
  shapes=[model.layers[layer].get_weights()[0].shape for layer in tensors]
  counts=[np.prod(shape) for shape in shapes]
  total=sum(counts)

  def quotas(count):
    # per-layer sparsities for `count` pruned weights overall
    return np.array(func(target_sparsity=count/total,shapes=shapes))

  def infeasible(sparsities):
    return (sparsities<0).any() or (sparsities>1).any()

  # 1. smallest feasible count; invariant: `low` infeasible, `high` feasible
  low,high=0,total
  if infeasible(quotas(0)) and high-low>1:
    while high-low>1:
      middle=(low+high)//2
      if infeasible(quotas(middle)):
        low=middle
      else:
        high=middle
    actual_low=high
  else:
    actual_low=0
  # 2. largest feasible count; invariant: `low` feasible, `high` infeasible.
  # The upper endpoint must be evaluated through `func`, and the search must start from the
  # full range [actual_low, total] rather than from where step 1 stopped.
  low,high=actual_low,total
  if infeasible(quotas(total)):
    while high-low>1:
      middle=(low+high)//2
      if infeasible(quotas(middle)):
        high=middle
      else:
        low=middle
    actual_high=low
  else:
    actual_high=total
  low,high=actual_low,actual_high
  if low>=high:
    logging.error(f"<pruning> low: ({low}) >= high ({high}). target sparsity might be incompatible with the pruning method.")
  # 3. invariant: masks at `low` are within the target; masks at `high` are taken to exceed it
  low_masks=[magnitude_layerwise(score,sparsity) for score,sparsity in zip(scores,quotas(low))]
  while high-low>1:
    middle=(high+low)//2
    middle_masks=[magnitude_layerwise(score,sparsity) for score,sparsity in zip(scores,quotas(middle))]
    effective_middle_sparsity=get_overall_direct_sparsity(effective_masks_synflow(model,tensors,middle_masks))
    if effective_middle_sparsity<=target_sparsity:
      low,low_masks=middle,middle_masks
    else:
      high=middle
  sparsities=func(target_sparsity=target_sparsity,shapes=shapes)
  logging.info(f'<pruning> requested: ({target_sparsity:.6f})')
  direct_masks=[magnitude_layerwise(score,sparsity) for score,sparsity in zip(scores,sparsities)]
  logging.info(f'<pruning> direct pruning overall sparsity: {get_overall_direct_sparsity(effective_masks_synflow(model,tensors,direct_masks)):.6f}')
  logging.info(f'<pruning> effective pruning overall sparsity: {get_overall_direct_sparsity(effective_masks_synflow(model,tensors,low_masks)):.6f}')
  return low_masks

def uniform_quotas(target_sparsity,shapes,**kwargs):
  """Give every layer the same sparsity ``target_sparsity`` (one entry per shape; kwargs ignored)."""
  layer_count=len(shapes)
  return np.repeat(target_sparsity,layer_count)

class Pruner(object):
  def __init__(self,mode):
    self.mode=mode
  def prune(self,model,tensors,sparsity,pruning_type,**kwargs):
    shapes=[model.layers[layer].get_weights()[0].shape for layer in tensors]
    if self.mode=='lamp':
      try:
        final_weights=np.load(f'{kwargs["path_to_dense"]}/{kwargs["sample"]}_1_final_weights.npy',allow_pickle=True)
        inits=np.load(f'{kwargs["path_to_dense"]}/{kwargs["sample"]}_1_inits.npy',allow_pickle=True)
      except:
        logging.error(f'<pruning> required files in ({kwargs["path_to_dense"]}/{kwargs["sample"]}) do not exist.')
        raise FileNotFoundError
      set_weights_model(model,tensors,inits)
      scores,masks=lamp(final_weights,sparsity)
      corrected_masks=effective_correction_from_global_scores(model,tensors,scores,sparsity)
      if pruning_type=='effective':
        return corrected_masks
      elif pruning_type=='direct':
        return masks
    elif self.mode.split('/')[0]=='magnitude':
      model_name=model.name.replace('-','').lower()
      try:
        final_weights=np.load(f'{kwargs["path_to_dense"]}/{kwargs["sample"]}_1_final_weights.npy',allow_pickle=True)
        inits=np.load(f'{kwargs["path_to_dense"]}/{kwargs["sample"]}_1_inits.npy',allow_pickle=True)
      except:
        logging.error(f'<pruning> required files in ({kwargs["path_to_dense"]}/{kwargs["sample"]}) do not exist.')
        raise FileNotFoundError
      if self.mode=='magnitude/erk':
        corrected_masks=effective_correction_layerwise_scores_magnitude_pruning(model,tensors,erk_quotas,abs(final_weights),sparsity)
        sparsities=check_valid_sparsities(erk_quotas(sparsity,shapes))
        masks=[magnitude_layerwise(final_weight,s) for final_weight,s in zip(final_weights,sparsities)]
      if self.mode=='magnitude/igq':
        corrected_masks=effective_correction_layerwise_scores_magnitude_pruning(model,tensors,igq_quotas,abs(final_weights),sparsity)
        sparsities=check_valid_sparsities(igq_quotas(sparsity,shapes))
        masks=[magnitude_layerwise(final_weight,s) for final_weight,s in zip(final_weights,sparsities)]
      if self.mode=='magnitude/uniform':
        corrected_masks=effective_correction_layerwise_scores_magnitude_pruning(model,tensors,uniform_quotas,abs(final_weights),sparsity)
        sparsities=check_valid_sparsities(uniform_quotas(sparsity,shapes))
        masks=[magnitude_layerwise(final_weight,s) for final_weight,s in zip(final_weights,sparsities)]
      if self.mode=='magnitude/uniform_plus':
        corrected_masks=effective_correction_layerwise_scores_magnitude_pruning(model,tensors,uniform_plus_quotas,abs(final_weights),sparsity)
        sparsities=check_valid_sparsities(uniform_plus_quotas(sparsity,shapes))
        masks=[magnitude_layerwise(final_weight,s) for final_weight,s in zip(final_weights,sparsities)]
      if self.mode=='magnitude/global':
        scores,masks=magnitude_global(final_weights,sparsity)
        corrected_masks=effective_correction_from_global_scores(model,tensors,scores,sparsity)
      set_weights_model(model,tensors,inits)
      if pruning_type=='effective':
        return corrected_masks
      elif pruning_type=='direct':
        return masks
    elif self.mode.split('/')[0]=='random':
      if self.mode=='random/snip':
        func=lambda **kwargs: get_direct_sparsity(self.prune_snip(**kwargs))
      if self.mode=='random/synflow':
        func=lambda **kwargs: get_direct_sparsity(self.prune_synflow(**kwargs))
      if self.mode=='random/uniform':
        func=uniform_quotas
      if self.mode=='random/igq':
        func=igq_quotas
      if self.mode=='random/erk':
        func=erk_quotas
      if self.mode=='random/uniform_plus':
        func=uniform_plus_quotas
      if pruning_type=='effective':
        return self.prune_random(model,tensors,'effective',target_sparsity=sparsity,func=func,shapes=shapes,**kwargs)
      elif pruning_type=='direct':
        sparsities=check_valid_sparsities(func(target_sparsity=sparsity,shapes=shapes,model=model,tensors=tensors,pruning_type='direct',**kwargs))
        return self.prune_random(model,tensors,'direct',sparsities=sparsities)
    elif self.mode=='snip/iterative':
      return self.prune_iterative_snip(model,tensors,sparsity,pruning_type,**kwargs)
    elif self.mode=='snip':
      return self.prune_snip(model,tensors,sparsity,pruning_type,**kwargs)
    elif self.mode=='synflow':
      return self.prune_synflow(model,tensors,sparsity,pruning_type,**kwargs)
    elif self.mode=='dense':
      return [np.ones(shape) for shape in shapes]
    else:
      logging.error(f'<pruning> unknown pruner "{self.mode}" encountered.')
      raise NotImplementedError

  def prune_synflow(self,model,tensors,target_sparsity,pruning_type,train_X,train_y,**kwargs):
    shapes=[model.layers[layer].get_weights()[0].shape for layer in tensors]
    masks=[np.ones(shape) for shape in shapes]
    counts=[np.prod(shape) for shape in shapes]
    counts_sum=[sum(counts[:layer]) for layer in range(len(shapes)+1)]
    linear_model,linear_tensors=get_model(shape=train_X[0].shape,architecture=model.name.replace('-','').lower(),batchnorm=False,activation=False,pool='average',output_classes=len(train_y[0]))
    abs_inits=[abs(model.layers[layer].get_weights()[0]) for layer in tensors]
    already_pruned,weight_scores=0,np.zeros(counts_sum[-1])
    for iteration in range(1,101):
      set_weights_model(linear_model,linear_tensors,abs_inits,masks=masks)
      to_prune=int(counts_sum[-1]-counts_sum[-1]*(1-target_sparsity)**(float(iteration)/100))-already_pruned
      already_pruned+=to_prune
      with tf.GradientTape(persistent=False) as tape:
        output=linear_model(np.ones([1]+linear_model.inputs[0].shape[1:]))
        saliency=tf.reduce_sum(output)
        weights=[linear_model.layers[layer].trainable_weights[0] for layer in linear_tensors]
      gradients=tape.gradient(saliency,weights)
      scores=np.concatenate([(gradient.numpy()*abs_init*mask).reshape(-1) for gradient,abs_init,mask in zip(gradients,abs_inits,masks)])
      masks=np.concatenate([mask.reshape(-1) for mask in masks])
      indices_to_prune=scores.argsort()[np.isin(scores.argsort(),np.where(masks==1)[0])][:to_prune]
      masks[indices_to_prune]=0.
      if len(indices_to_prune)>0:
        weight_scores[indices_to_prune]=iteration*2+(scores[indices_to_prune]-np.min(scores[indices_to_prune]))/(np.max(scores[indices_to_prune])-np.min(scores[indices_to_prune])+1e-10)
      masks=[masks[counts_sum[layer]:counts_sum[layer+1]].reshape(shapes[layer]) for layer in range(len(linear_tensors))]
      if iteration==100:
        last_batch=np.where(weight_scores==0)[0]
        weight_scores[last_batch]=(iteration+1)*2+(scores[last_batch]-np.min(scores[last_batch]))/(np.max(scores[last_batch])-np.min(scores[last_batch])+1e-10)
    if pruning_type=='effective':
      corrected_masks=effective_correction_from_global_scores(model,tensors,weight_scores,target_sparsity)
      logging.info(f'<pruning> direct pruning overall sparsity: {get_overall_direct_sparsity(effective_masks_synflow(model,tensors,masks)):.6f}')
      logging.info(f'<pruning> effective pruning overall sparsity: {get_overall_direct_sparsity(effective_masks_synflow(model,tensors,corrected_masks)):.6f}')
      return corrected_masks
    elif pruning_type=='direct':
      return masks

  def prune_snip(self,model,tensors,target_sparsity,pruning_type,train_X,train_y,config,**kwargs):
    inits=[model.layers[layer].get_weights()[0] for layer in tensors]
    shapes=[model.layers[layer].get_weights()[0].shape for layer in tensors]
    counts=[np.prod(shape) for shape in shapes]
    counts_sum=[sum(counts[:layer]) for layer in range(len(shapes)+1)]
    choice=np.random.choice(range(len(train_X)),size=min([config['batch_size_snip']*300,len(train_X)]),replace=False)
    batch_X,batch_y=train_X[choice],train_y[choice]
    masks=np.ones(counts_sum[-1])
    gradients={layer:[] for layer in tensors}
    for minibatch_X,minibatch_y in zip(np.split(batch_X,range(128,len(batch_X),128)),np.split(batch_y,range(128,len(batch_y),128))):
      with tf.GradientTape(persistent=False) as tape:
        output=model(minibatch_X)
        weights=[model.layers[layer].trainable_weights[0] for layer in tensors]
        loss=tf.reduce_mean(tf.keras.losses.categorical_crossentropy(output,minibatch_y))
      gradients_minibatch=tape.gradient(loss,weights)
      for gradient_minibatch,layer in zip(gradients_minibatch,tensors):
        gradients[layer].append(gradient_minibatch.numpy())
    for layer in tensors:
      gradients[layer]=np.mean(gradients[layer],axis=0)
    cs=np.concatenate([abs(gradients[layer]*init*mask).reshape(-1) for layer,init,mask in zip(tensors,inits,masks)])
    masks[cs.argsort()[:int(target_sparsity*counts_sum[-1])]]=0.
    masks=[masks[counts_sum[layer]:counts_sum[layer+1]].reshape(shapes[layer]) for layer in range(len(tensors))]
    if pruning_type=='effective':
      corrected_masks=effective_correction_from_global_scores(model,tensors,cs,target_sparsity)
      logging.info(f'<pruning> direct pruning: overall effective sparsity: {get_overall_direct_sparsity(effective_masks_synflow(model,tensors,masks)):.6f}')
      logging.info(f'<pruning> effective pruning: overall effective sparsity: {get_overall_direct_sparsity(effective_masks_synflow(model,tensors,corrected_masks)):.6f}')
      return corrected_masks
    elif pruning_type=='direct':
      return masks

  def prune_iterative_snip(self,model,tensors,target_sparsity,pruning_type,train_X,train_y,config,**kwargs):
    inits=[model.layers[layer].get_weights()[0] for layer in tensors]
    shapes=[init.shape for init in inits]
    masks=[np.ones(shape) for shape in shapes]
    counts=[np.prod(shape) for shape in shapes]
    counts_sum=[sum(counts[:layer]) for layer in range(len(shapes)+1)]
    scores,already_pruned=np.zeros(counts_sum[-1]),0
    for iteration in range(1,301):
      choice=np.random.choice(range(len(train_X)),size=min([config['batch_size_snip'],len(train_X)]),replace=False)
      batch_X,batch_y=train_X[choice],train_y[choice]
      to_prune=int(counts_sum[-1]-counts_sum[-1]*(1-target_sparsity)**(float(iteration)/300))-already_pruned
      already_pruned+=to_prune
      gradients={layer:[] for layer in tensors}
      set_weights_model(model,tensors,inits,masks=masks)
      for minibatch_X,minibatch_y in zip(np.split(batch_X,range(128,len(batch_X),128)),np.split(batch_y,range(128,len(batch_y),128))):
        with tf.GradientTape(persistent=False) as tape:
          output=model(minibatch_X)
          weights=[model.layers[layer].trainable_weights[0] for layer in tensors]
          loss=tf.reduce_mean(tf.keras.losses.categorical_crossentropy(output,minibatch_y))
        gradients_minibatch=tape.gradient(loss,weights)
        for gradient_minibatch,layer in zip(gradients_minibatch,tensors):
          gradients[layer].append(gradient_minibatch.numpy())
      for layer in tensors:
        gradients[layer]=np.mean(gradients[layer],axis=0)
      cs=np.concatenate([abs(gradients[layer]*init*mask).reshape(-1) for layer,init,mask in zip(tensors,inits,masks)])
      masks=np.concatenate([mask.reshape(-1) for mask in masks])
      indices_to_prune=cs.argsort()[np.isin(cs.argsort(),np.where(masks==1)[0])][:to_prune]
      masks[indices_to_prune]=0.
      if len(indices_to_prune)>0:
        scores[indices_to_prune]=iteration*2+(cs[indices_to_prune]-np.min(cs[indices_to_prune]))/(np.max(cs[indices_to_prune])-np.min(cs[indices_to_prune])+1e-10)
      masks=[masks[counts_sum[layer]:counts_sum[layer+1]].reshape(shapes[layer]) for layer in range(len(tensors))]
    scores[np.where(scores==0)[0]]=2*(iteration+2)
    set_weights_model(model,tensors,inits)
    if pruning_type=='effective':
      corrected_masks=effective_correction_from_global_scores(model,tensors,scores,target_sparsity)
      logging.info(f'<pruning> direct pruning overall sparsity: {get_overall_direct_sparsity(effective_masks_synflow(model,tensors,masks)):.6f}')
      logging.info(f'<pruning> effective pruning overall sparsity: {get_overall_direct_sparsity(effective_masks_synflow(model,tensors,corrected_masks)):.6f}')
      return corrected_masks
    elif pruning_type=='direct':
      return masks

  def prune_random(self,model,tensors,pruning_type,**kwargs):
    if pruning_type=='direct':
      sparsities=kwargs["sparsities"]
      masks=[np.ones(model.layers[layer].get_weights()[0].shape) for layer in tensors]
      inds=[np.random.choice(range(np.prod(mask.shape)),size=int(sparsity*np.prod(mask.shape)),replace=False) for sparsity,mask in zip(sparsities,masks)]
      for ind,mask in enumerate(masks):
        mask.reshape(-1)[inds[ind]]=0.
    elif pruning_type=='effective':
      shapes=[model.layers[layer].get_weights()[0].shape for layer in tensors]
      counts=[np.prod(shape) for shape in shapes]
      target_sparsity=kwargs["target_sparsity"]
      del kwargs["target_sparsity"]
      low,high,func,flag=0,sum(counts),kwargs["func"],False
      low_val=func(model=model,tensors=tensors,pruning_type='direct',target_sparsity=low/sum(counts),**kwargs)
      while ((np.array(low_val)<0).any() or (np.array(low_val)>1).any()) and high-low>1:
        flag=True
        middle=(low+high)//2
        middle_val=func(model=model,tensors=tensors,pruning_type='direct',target_sparsity=middle/sum(counts),**kwargs)
        middle_sparsities=np.array(middle_val)
        if (middle_sparsities<0).any() or (middle_sparsities>1).any():
          low,high=middle,high
        else:
          low,high=low,middle
      low=high if flag else 0
      high=sum(counts)
      low_val=func(model=model,tensors=tensors,pruning_type='direct',target_sparsity=low/sum(counts),**kwargs)
      low_sparsities=np.array(low_val)
      low_masks=self.prune_random(model,tensors,'direct',sparsities=low_sparsities)
      high_masks=[np.zeros(shape) for shape in shapes]
      while high-low>1:
        middle=(high+low)//2
        middle_val=func(model=model,tensors=tensors,pruning_type='direct',target_sparsity=middle/sum(counts),**kwargs)
        middle_sparsities=check_valid_sparsities(middle_val)
        middle_to_prune=[round(s*count) for s,count in zip(middle_sparsities,counts)]
        indices_to_prune=[np.random.choice(np.where(np.logical_and(low_masks[layer].reshape(-1)==1,high_masks[layer].reshape(-1)==0))[0],size=min([max([0,int(middle_to_prune[layer]-np.sum(1-low_masks[layer]))]),len(np.where(np.logical_and(low_masks[layer].reshape(-1)==1,high_masks[layer].reshape(-1)==0))[0])]),replace=False) for layer in range(len(tensors))]
        middle_masks=[np.array(low_masks[layer]) for layer in range(len(tensors))]
        for i,middle_mask in enumerate(middle_masks):
          middle_mask.reshape(-1)[indices_to_prune[i]]=0
        effective_middle_masks=effective_masks_synflow(model,tensors,middle_masks)
        effective_middle_sparsity=get_overall_direct_sparsity(effective_middle_masks)
        if effective_middle_sparsity<=target_sparsity:
          low,high=middle,high
          low_masks,high_masks=middle_masks,high_masks
        else:
          low,high=low,middle
          low_masks,high_masks=low_masks,middle_masks
      masks=low_masks
      sparsities=func(model=model,tensors=tensors,pruning_type='direct',target_sparsity=target_sparsity,**kwargs)
      sparsities=check_valid_sparsities(sparsities)
      direct_masks=self.prune_random(model,tensors,'direct',sparsities=sparsities)
      logging.info(f'<pruning> direct pruning overall sparsity: {get_overall_direct_sparsity(effective_masks_synflow(model,tensors,direct_masks)):.6f}')
      logging.info(f'<pruning> effective pruning overall sparsity: {get_overall_direct_sparsity(effective_masks_synflow(model,tensors,masks)):.6f}')
    return masks