############################################

## This PyTorch code implements our method NEPENTHE
## for a ResNet-18 trained on Tiny ImageNet.

############################################

# Import libraries
import torch
import torchvision
import torchvision.transforms as transforms
from torch import nn
from torch import optim
import numpy as np
from tqdm import tqdm
import copy
import random
import os
import re
from torch.utils.data import DataLoader


DATA_DIR = 'Your_Data_Location'                # Root folder of Tiny ImageNet (train/, val/, test/ are read below)

# Checkpoints are written to ./model/<run name>; exist_ok avoids the check-then-create race.
os.makedirs('./model/', exist_ok=True)

# Reproducibility: seed every RNG the script uses and request deterministic kernels.
# CUBLAS_WORKSPACE_CONFIG has to be in place before the first cuBLAS call.
torch.manual_seed(43)
os.environ["CUBLAS_WORKSPACE_CONFIG"]=":16:8"
random.seed(43)
np.random.seed(43)
# warn_only=True: some CUDA kernels this model needs (e.g. the backward of the
# AdaptiveAvgPool2d in torchvision's ResNet) have no deterministic implementation in
# many PyTorch versions; in strict mode they raise RuntimeError on the first
# backward pass. With warn_only they emit a warning instead and training proceeds.
torch.use_deterministic_algorithms(True, warn_only=True)

# Device configuration
device = torch.device('cuda:0')


# Training parameters (shared by the dense run and every pruning/retraining iteration)
epochs = 160
learning_rate = 0.1
momentum = 0.9
gamma=0.1                          # MultiStepLR decay factor
weight_decay = 1e-4
milestones=[80,120]                # epochs at which the LR is multiplied by gamma
batch_size =128
fixed_amount_of_pruning = 0.5      # fraction of surviving candidate weights pruned per iteration


class Hook():
	"""Keep a reference to the first positional argument seen by a module hook.

	Forward mode (the default): a forward hook is registered and ``self.output``
	holds ``input[0]``, i.e. the tensor fed *into* the module (for a ReLU, its
	pre-activation) - the attribute name notwithstanding. If the module works
	in place, that tensor will already have been overwritten by the module's result.

	Backward mode: the legacy ``register_backward_hook`` is used; its callback
	receives ``(module, grad_input, grad_output)``, so ``self.output`` holds
	``grad_input[0]``.
	"""

	def __init__(self, module, backward=False):
		# Equality (not truthiness) decides the mode, exactly as before.
		forward_mode = backward == False  # noqa: E712
		register = module.register_forward_hook if forward_mode else module.register_backward_hook
		self.hook = register(self.hook_fn)

	def hook_fn(self, module, input, output):
		first, *_ = input
		self.output = first

	def close(self):
		"""Detach the hook from its module."""
		self.hook.remove()

# function to calculate the probability for 'ON' state
def on_probability(input_tensor):
	"""Return the per-position fraction of 'ON' entries among the non-zero ones.

	*input_tensor* is expected to hold signs (values in {-1, 0, 1}) and is reduced
	over dim 0 (the batch dimension at the call site in this file). Entries equal
	to 1 count as 'ON', entries equal to -1 as 'OFF', zeros are ignored. Positions
	where every entry is 0 get probability 0; the 1e-12 keeps that division finite.

	Returns:
		Float tensor of shape ``input_tensor.shape[1:]``.
	"""
	# Each count is reduced once (the ON count was previously computed twice).
	on_count = torch.eq(input_tensor, 1).float().sum(dim=0)
	off_count = torch.eq(input_tensor, -1).float().sum(dim=0)
	return on_count / (on_count + off_count + 1e-12)


def fuse_conv_and_bn(conv, bn):
	"""Fold a BatchNorm2d (running statistics) into the Conv2d that precedes it.

	Returns a new ``nn.Conv2d`` with bias whose output equals ``bn(conv(x))`` when
	``bn`` is in eval mode. ``conv``'s current ``weight`` is used (for a pruned conv
	that is the masked weight). The fused parameters are placed on ``conv.weight``'s
	device and carry no autograd history; ``conv`` and ``bn`` are left untouched.

	Args:
		conv: the convolution.
		bn: the BatchNorm2d applied to ``conv``'s output.
	"""
	# Same geometry as ``conv`` (dilation, groups and padding_mode included, so the
	# result is correct for any Conv2d, not just the defaults). The module is built
	# with its default init on the CPU; its parameters are replaced below.
	fusedconv = torch.nn.Conv2d(
		conv.in_channels,
		conv.out_channels,
		kernel_size=conv.kernel_size,
		stride=conv.stride,
		padding=conv.padding,
		dilation=conv.dilation,
		groups=conv.groups,
		bias=True,
		padding_mode=conv.padding_mode,
	)
	target = conv.weight.device
	with torch.no_grad():
		# Per-output-channel BN scale gamma / sqrt(var + eps); scaling row i of the
		# flattened weight is what multiplying by diag(scale) used to do.
		scale = bn.weight.div(torch.sqrt(bn.eps + bn.running_var)).to(target)
		w_conv = conv.weight.reshape(conv.out_channels, -1)
		fused_weight = (w_conv * scale.unsqueeze(1)).view(fusedconv.weight.size())

		if conv.bias is not None:
			b_conv = conv.bias.to(target)
		else:
			b_conv = torch.zeros(conv.out_channels, device=target)
		b_bn = (bn.bias - bn.weight.mul(bn.running_mean).div(torch.sqrt(bn.running_var + bn.eps))).to(target)
		fused_bias = scale * b_conv + b_bn

	fusedconv.weight = torch.nn.Parameter(fused_weight)
	fusedconv.bias = torch.nn.Parameter(fused_bias)
	return fusedconv


def train(model, epoch, optimizer):
	"""Run one training epoch over the module-level ``train_loader``.

	Args:
		model: network to optimise; it is put in train mode.
		epoch: epoch number, only used in the progress message.
		optimizer: stepped once per mini-batch.

	Returns:
		(accuracy in percent, mean per-batch cross-entropy loss) on the training data.
	"""
	print('\nEpoch : %d'%epoch)
	model.train()
	loss_sum, n_correct, n_seen = 0, 0, 0
	criterion = torch.nn.CrossEntropyLoss()
	for batch in tqdm(train_loader):
		images = batch[0].to(device)
		targets = batch[1].to(device)
		logits = model(images)
		batch_loss = criterion(logits, targets)

		optimizer.zero_grad()
		batch_loss.backward()
		optimizer.step()

		loss_sum += batch_loss.item()
		predicted = logits.max(1).indices
		n_seen += targets.size(0)
		n_correct += predicted.eq(targets).sum().item()

	mean_loss = loss_sum / len(train_loader)
	accuracy = 100. * n_correct / n_seen
	print('Train Loss: %.3f | Accuracy: %.3f' % (mean_loss, accuracy))
	return accuracy, mean_loss


def val(model, epoch=None):
	"""Evaluate *model* on the module-level ``val_loader`` without gradients.

	Args:
		model: network to evaluate; it is put in eval mode.
		epoch: ignored; accepted so callers may pass the current epoch number.

	Returns:
		(accuracy in percent, mean per-batch cross-entropy loss) on the validation set.
	"""
	model.eval()
	running_loss=0
	correct=0
	total=0
	loss_fn=torch.nn.CrossEntropyLoss()
	with torch.no_grad():
		for data in tqdm(val_loader):
			images,labels=data[0].to(device),data[1].to(device)
			outputs=model(images)
			loss= loss_fn(outputs,labels)
			running_loss+=loss.item()
			_, predicted = outputs.max(1)
			total += labels.size(0)
			correct += predicted.eq(labels).sum().item()
	# Average over the number of *validation* batches (the loop above iterates val_loader).
	val_loss=running_loss/len(val_loader)
	accu=100.*correct/total
	print('Val Loss: %.3f | Accuracy: %.3f'%(val_loss,accu))
	return(accu, val_loss)


def test(model, epoch=None):
	"""Evaluate *model* on the module-level ``test_loader`` without gradients.

	Args:
		model: network to evaluate; it is put in eval mode.
		epoch: ignored; accepted so callers may pass the current epoch number
			(calling ``test(model, epoch)`` used to raise TypeError).

	Returns:
		(accuracy in percent, mean per-batch cross-entropy loss) on the test set.
	"""
	model.eval()
	running_loss=0
	correct=0
	total=0
	loss_fn=torch.nn.CrossEntropyLoss()
	with torch.no_grad():
		for data in tqdm(test_loader):
			images,labels=data[0].to(device),data[1].to(device)
			outputs=model(images)
			loss= loss_fn(outputs,labels)
			running_loss+=loss.item()
			_, predicted = outputs.max(1)
			total += labels.size(0)
			correct += predicted.eq(labels).sum().item()
	test_loss=running_loss/len(test_loader)
	accu=100.*correct/total
	print('Test Loss: %.3f | Accuracy: %.3f'%(test_loss,accu))
	return(accu, test_loss)


# function to calculate entropy
def test_entropy(model, hooks):
	"""Measure activation entropy per hooked layer and per channel over the training set.

	Runs *model* in eval mode, without gradients, over every batch of the
	module-level ``train_loader``. After each forward pass, for every key in
	*hooks*, the tensor recorded by the Hook (``hooks[key].output``) is turned into
	signs and ``on_probability`` gives, per position (reduced over dim 0), the
	fraction of positive ('ON') entries among the non-zero ones. Those probabilities
	are averaged down to one value per index of their first dimension
	(presumably the channel dimension of a conv activation - TODO confirm for other
	layer types), and the binary entropy of those per-channel probabilities is
	accumulated for batches where at least one channel lies strictly between 0 and 1.

	Args:
		model: network to evaluate.
		hooks: dict of module name -> Hook; every Hook must fire during the forward pass.

	Returns:
		(accu, loss, layers_entropy, neurons_entropy, ave_Pon_output) where, per key:
		  layers_entropy  - mean-over-channels binary entropy, summed over batches and
		                    divided by len(train_loader) (stays 0 if no batch contributed);
		  neurons_entropy - per-channel binary entropy, same averaging (stays the int 0
		                    if no batch contributed);
		  ave_Pon_output  - per-channel ON probability averaged over batches. In batches
		                    that take the entropy branch, channels whose probability is
		                    exactly 0 or 1 are masked to 0 before being added.
		accu/loss are computed on the *training* loader although the message says 'Test'.
	"""
	model.eval()
	layers_entropy = {}
	neurons_entropy = {}
	ave_Pon_output = {}
	entropy_per_neuron = {}
	neuron_Pon = {}
	entropy = {}
	# Running sums, one per hooked layer (start as the int 0 and become tensors once added to).
	for key in hooks.keys():
		entropy[key] = 0
		neuron_Pon[key] = 0
		entropy_per_neuron[key] = 0
	running_loss=0
	correct=0
	total=0
	loss_fn=torch.nn.CrossEntropyLoss()
	with torch.no_grad():
		for data in tqdm(train_loader):
			images,labels=data[0].to(device),data[1].to(device)
			outputs=model(images)
			loss= loss_fn(outputs,labels)
			running_loss+=loss.item()
			_, predicted = outputs.max(1)
			total += labels.size(0)
			correct += predicted.eq(labels).sum().item()
			for key in hooks.keys():         # For different layers
				# Signs of the recorded activation input: 1 = ON, -1 = OFF, 0 = ignored.
				full_p_one = torch.sign(hooks[key].output)
				# Per-position ON probability across the batch (dim 0 reduced).
				p_one = on_probability(full_p_one).to(device)
				state = p_one
				# Flatten everything after the first remaining dim and average -> one value per channel.
				state = state.reshape(state.shape[0], -1).to(device)
				state_sum = torch.mean(state*1.0 , dim=1).to(device)
				# Number of channels that are neither always OFF (0) nor always ON (1) in this batch.
				state_sum_num = torch.sum((state_sum!= 0) * (state_sum!= 1))
				# Reduce p_one to 1-D (per channel) by repeatedly averaging dim 1.
				while len(p_one.shape) > 1:
						p_one = torch.mean(p_one,dim=1)
				if state_sum_num != 0:
					# Zero out channels that are exactly 0 or 1 (their entropy is 0).
					p_one = (p_one*(state_sum!= 0) * (state_sum!= 1)*1.0)
					# Binary entropy H(p) = -p*log2(p) - (1-p)*log2(1-p); clamp avoids log2(0).
					entropy[key] -= torch.mean(p_one*torch.log2(torch.clamp(p_one, min=1e-5))+((1-p_one)*torch.log2(torch.clamp(1-p_one, min=1e-5))))
					entropy_per_neuron[key]  -= (p_one*torch.log2(torch.clamp(p_one, min=1e-5))+((1-p_one)*torch.log2(torch.clamp(1-p_one, min=1e-5))))
					neuron_Pon[key] += p_one
				else:
					# No uncertain channel in this batch: entropy sums unchanged,
					# the (unmasked) ON probability is still accumulated.
					entropy[key] -= 0
					entropy_per_neuron[key] = entropy_per_neuron[key]
					neuron_Pon[key] = neuron_Pon[key]
					neuron_Pon[key] += p_one
	# Average the per-batch sums over the number of training batches.
	for key in hooks.keys():
		neurons_entropy[key] = 	entropy_per_neuron[key]	/ len(train_loader)
		layers_entropy[key] = entropy[key] / len(train_loader)
		ave_Pon_output[key] = neuron_Pon[key] / len(train_loader)

	test_loss=running_loss/len(train_loader)
	accu=100.*correct/total

	print('Test Loss: %.3f | Accuracy: %.3f'%(test_loss,accu))
	return(accu, test_loss, layers_entropy, neurons_entropy, ave_Pon_output)





# Split directories under DATA_DIR.
# NOTE(review): ImageFolder expects one sub-folder per class in every split. The stock
# Tiny ImageNet release keeps val/ (and the unlabeled test/) images in a flat images/
# folder plus an annotation file, so those splits presumably have to be reorganised
# beforehand - confirm.
TRAIN_DIR, VALID_DIR, TEST_DIR = [os.path.join(DATA_DIR, split) for split in ('train', 'val', 'test')]

# Standard ImageNet per-channel mean / std, shared by every split.
_NORM_MEAN = [0.485, 0.456, 0.406]
_NORM_STD = [0.229, 0.224, 0.225]


def _eval_transform():
	"""Deterministic pre-processing: tensor conversion followed by normalisation."""
	return transforms.Compose([transforms.ToTensor(), transforms.Normalize(_NORM_MEAN, _NORM_STD)])


def _make_loader(dataset, shuffle):
	"""Wrap *dataset* in a DataLoader using the global batch size and 8 worker processes."""
	return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, num_workers=8)


# Training data additionally gets a random horizontal flip.
transform_train = transforms.Compose([
	transforms.RandomHorizontalFlip(),
	transforms.ToTensor(),
	transforms.Normalize(_NORM_MEAN, _NORM_STD)])
transform_val = _eval_transform()
transform_test = _eval_transform()

train_dataset = torchvision.datasets.ImageFolder(TRAIN_DIR, transform=transform_train)
train_loader = _make_loader(train_dataset, shuffle=True)
val_dataset = torchvision.datasets.ImageFolder(VALID_DIR, transform=transform_val)
val_loader = _make_loader(val_dataset, shuffle=False)
test_dataset = torchvision.datasets.ImageFolder(TEST_DIR, transform=transform_test)
test_loader = _make_loader(test_dataset, shuffle=False)



from torchvision.models.resnet import BasicBlock
from torchvision.models.resnet import ResNet

class BasicBlock_new(BasicBlock):
	"""ResNet BasicBlock whose two activations are separate, out-of-place ReLU modules.

	The parent's single ``relu`` attribute is deleted and replaced by ``relu1``
	(after bn1) and ``relu2`` (after the residual addition), so a forward Hook can be
	attached to each activation and observe the tensor fed into it.
	"""

	def __init__(self, *args, **kwargs):
		super().__init__(*args, **kwargs)
		# One module per activation, never in place, so each hook sees its own un-modified input.
		self.relu1 = nn.ReLU(inplace=False)
		self.relu2 = nn.ReLU(inplace=False)
		delattr(self, 'relu')

	def forward(self, x):
		# Main branch: conv1 -> bn1 -> relu1 -> conv2 -> bn2.
		h = self.relu1(self.bn1(self.conv1(x)))
		h = self.bn2(self.conv2(h))
		# Shortcut branch, projected when a downsample module is present.
		shortcut = x if self.downsample is None else self.downsample(x)
		h += shortcut
		return self.relu2(h)


class ResNet_new(ResNet):
	"""torchvision ResNet with an out-of-place stem ReLU and a module-replacement helper."""

	def __init__(self, *args, **kwargs):
		super(ResNet_new, self).__init__(*args, **kwargs)
		# torchvision builds the stem activation as nn.ReLU(inplace=True). The entropy
		# Hook stores input[0] of each ReLU *after* forward has run, so with an in-place
		# ReLU it would receive the already-rectified tensor (no negative entries) and the
		# ON/OFF statistics of 'conv1' would be wrong. Use an out-of-place ReLU, like
		# BasicBlock_new.relu1/relu2. ReLU has no parameters, so state_dicts are unaffected.
		self.relu = nn.ReLU(inplace=False)

	def _replace_module(self, name, module):
		"""Replace the submodule at dotted path *name* (e.g. 'layer1.0.relu1') with *module*."""
		modules = name.split('.')
		curr_mod = self

		# Walk down to the parent of the target; numeric parts index nn.Sequential children.
		for mod_name in modules[:-1]:
			curr_mod = getattr(curr_mod, mod_name)

		setattr(curr_mod, modules[-1], module)

# define the model: ResNet-18 layout (BasicBlock_new x [2, 2, 2, 2]) with 200 output classes.
model = ResNet_new(BasicBlock_new, [2, 2, 2, 2], num_classes=200)
model.to(device)

# Attach a forward Hook to every ReLU module, keyed by its qualified module name
# ('relu' for the stem, 'layerX.Y.relu1' / 'layerX.Y.relu2' inside the blocks).
hooks = {
	module_name: Hook(submodule)
	for module_name, submodule in model.named_modules()
	if type(submodule) is torch.nn.ReLU
}


# Vanilla training to obtain the dense model (checkpointed as ./model/iteration_0).
name_of_run = 'iteration_0'
name_model = name_of_run

optimizer = optim.SGD(model.parameters(), lr=learning_rate, momentum=momentum, weight_decay=weight_decay)
loss_fn=nn.CrossEntropyLoss().to(device)
scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=milestones, gamma=gamma)
for epoch in range(1, epochs + 1):
	train_acc, train_loss = train(model, epoch, optimizer)
	# test() only needs the model (passing the epoch did not match its signature).
	test_acc, test_loss = test(model)
	scheduler.step()
	# Overwrite the checkpoint every epoch so the latest dense weights are on disk.
	torch.save(model.state_dict(), './model/'+ name_model)  # path of saving model


#NEPENTHE pruning
# torch.nn.utils.prune is a submodule; import it explicitly so the
# torch.nn.utils.prune.custom_from_mask calls below do not depend on some other
# module having loaded it already.
import torch.nn.utils.prune

# Nine prune-and-retrain iterations; iteration i starts from the model left by
# iteration i-1 and is checkpointed as ./model/iteration_<i>.
for i in range(1, 10):
	name_of_run = 'iteration_'+str(i)
	name_model = name_of_run

	# Every still-hooked ReLU maps to the conv feeding it: the stem 'relu' -> 'conv1',
	# 'layerX.Y.reluK' -> 'layerX.Y.convK'.
	layers_to_prune = ['conv1' if key == 'relu' else key.replace('relu', 'conv') for key in hooks]

	test_entropy_acc, test_entropy_loss, layers_entropy, neurons_entropy, ave_Pon_output = test_entropy(model, hooks)  # calculate entropy

	# Layers with zero entropy are not pruned anymore: detach their hook and replace the
	# ReLU with nn.Identity (through ResNet_new._replace_module rather than exec).
	zero_entropy_relus = [key for key in layers_entropy if layers_entropy[key] == 0]
	for relu_name in zero_entropy_relus:
		hooks.pop(relu_name).close()
		model._replace_module(relu_name, nn.Identity())

	# Re-key the statistics from each ReLU to the conv layer that feeds it.
	for key in list(layers_entropy):
		conv_name = 'conv1' if key == 'relu' else key.replace('relu', 'conv')
		layers_entropy[conv_name] = layers_entropy.pop(key)
		neurons_entropy[conv_name] = neurons_entropy.pop(key)
		ave_Pon_output[conv_name] = ave_Pon_output.pop(key)

	# For zero-entropy layers, mask every output channel whose activation was never 'ON'.
	for name, module in model.named_modules():
		if name in layers_entropy and layers_entropy[name] == 0:
			weight = module.weight
			channel_on = (ave_Pon_output[name] != 0).to(weight.device)
			layer_mask = torch.ones_like(weight) * channel_on.view(-1, *([1] * (weight.dim() - 1)))
			torch.nn.utils.prune.custom_from_mask(module, name="weight", mask=layer_mask)

	# Zero-entropy layers are excluded from the magnitude-pruning step.
	for key in layers_entropy:
		if layers_entropy[key] == 0:
			layers_to_prune.remove(key)

	# Layer irrelevance score: layer entropy x mean |w| of its surviving (non-zero)
	# weights. Candidates without any surviving weight are dropped.
	layer_entro_magni = {}
	for name, module in model.named_modules():
		if name in layers_to_prune:
			surviving = module.weight[module.weight != 0]
			if torch.numel(surviving) == 0:
				layers_to_prune.remove(name)
			else:
				layer_entro_magni[name] = layers_entropy[name] * torch.mean(torch.abs(surviving))
	total_layers_entro_magni = sum(score.item() for score in layer_entro_magni.values())

	# Global budget: a fixed fraction of the surviving weights of all candidate layers.
	total_layers_weight_paras = 0
	for name, module in model.named_modules():
		if name in layers_to_prune:
			total_layers_weight_paras += torch.numel(module.weight[module.weight != 0])
	total_layers_weight_paras_to_prun = fixed_amount_of_pruning * total_layers_weight_paras

	# Split the budget across layers with a softmax over total/score, so layers with a
	# smaller entropy x magnitude share receive a larger share of the budget. A layer
	# whose share exceeds its surviving weights is capped and removed from the pool, and
	# the remaining budget is redistributed until no cap is hit.
	wait_distri_paras = list(layers_to_prune)
	fix_prun_amount = {}
	left_amount = {}
	entropy_layer_head_expo = {}
	while wait_distri_paras:  # an empty pool would make max() below raise
		amout_changed = False
		entropy_magni_layer_head = {}
		for name, module in model.named_modules():
			if name in wait_distri_paras:
				entropy_magni_layer_head[name] = total_layers_entro_magni / layer_entro_magni[name]
		max_value_entropy_magni_layer_head = max(entropy_magni_layer_head.values())

		total_entropy_layer_head_expo = 0
		for name, module in model.named_modules():
			if name in wait_distri_paras:
				# Subtract the max before exponentiating for numerical stability.
				entropy_layer_head_expo[name] = torch.exp(entropy_magni_layer_head[name] - max_value_entropy_magni_layer_head).item()
				total_entropy_layer_head_expo += entropy_layer_head_expo[name]

		for name, module in model.named_modules():
			if name in wait_distri_paras:
				fix_prun_amount[name] = int(total_layers_weight_paras_to_prun * (entropy_layer_head_expo[name] / total_entropy_layer_head_expo))

		for name, module in model.named_modules():
			if name in wait_distri_paras:
				left_amount[name] = torch.numel(module.weight[module.weight != 0])
				if left_amount[name] < fix_prun_amount[name]:
					fix_prun_amount[name] = left_amount[name]
					total_layers_weight_paras_to_prun -= left_amount[name]
					total_layers_entro_magni -= layer_entro_magni[name]
					wait_distri_paras.remove(name)
					amout_changed = True
		if not amout_changed:
			break

	# Fold each candidate conv with its BatchNorm ('...convK' -> '...bnK'); the in-layer
	# magnitude threshold is measured on these fused weights.
	modules_by_name = dict(model.named_modules())
	fuseed_layer = {}
	for name in layers_to_prune:
		fuseed_layer[name] = fuse_conv_and_bn(modules_by_name[name], modules_by_name[name.replace("conv", "bn")])

	# Per-layer threshold: the fix_prun_amount-th smallest non-zero |fused weight| over the
	# channels with non-zero entropy. Layers with a zero budget are skipped.
	threshold = {}
	for name, module in model.named_modules():
		if name in layers_to_prune:
			if fix_prun_amount[name] == 0:
				layers_to_prune.remove(name)
			else:
				fused_weight = fuseed_layer[name].weight.data
				active_channels = neurons_entropy[name] != 0
				candidates = fused_weight[active_channels].reshape(-1)
				all_positive_weights = torch.abs(candidates[candidates != 0])
				if all_positive_weights.numel() > fix_prun_amount[name]:
					threshold[name] = torch.topk(all_positive_weights, fix_prun_amount[name], largest=False)[0][-1]
				else:
					# Budget covers every candidate weight: prune all of them.
					threshold[name] = float('inf')

	# Mask weights of active channels whose fused magnitude is below the threshold;
	# channels with zero entropy keep all their weights.
	for name, module in model.named_modules():
		if name in layers_to_prune:
			fused_weight = fuseed_layer[name].weight.data
			active_channels = (neurons_entropy[name] != 0).view(-1, *([1] * (fused_weight.dim() - 1)))
			keep = torch.abs(fused_weight) >= threshold[name]
			layer_mask = torch.where(active_channels, keep, torch.ones_like(keep)).float()
			torch.nn.utils.prune.custom_from_mask(module, name="weight", mask=layer_mask)

	# Training to recover the performance (fresh optimizer and LR schedule each iteration).
	optimizer = optim.SGD(model.parameters(), lr=learning_rate, momentum=momentum, weight_decay=weight_decay)
	loss_fn=nn.CrossEntropyLoss().to(device)
	scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=milestones, gamma=gamma)

	for epoch in range(1, epochs + 1):
		train_acc, train_loss = train(model, epoch, optimizer)
		# val() only needs the model (passing the epoch did not match its signature).
		val_acc, val_loss = val(model)
		scheduler.step()

	torch.save(model.state_dict(), './model/'+ name_model)      # path of saving model