import torch
from torchvision import models
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import time
import pdb


# Exception classes used to deliberately interrupt a model's forward pass.
# They are user-defined exceptions, raised as control flow rather than to signal a failure.
class ModelBreak(Exception):
	"""Common parent of the exceptions raised to interrupt a model's forward pass."""

class TargetReached(ModelBreak):
	"""Signals that the output target of a subgraph was reached, so the model does not need to run forward any farther."""


class dissected_Conv2d(torch.nn.Module):
	"""2d convolution that exposes the per-edge feature maps it normally sums away.

	The weights of ``from_conv`` are copied into a grouped convolution
	(``preadd_conv``) with ``in_channels*out_channels`` output channels, one per
	(input channel, output channel) pair ("edge"). ``forward`` then sums those
	maps per output channel and adds the bias, giving the "node" output
	(``postbias_out``).

	Optionally the module stores its activations (``store_activations``) and
	accumulates activation / gradient based ranks through backward hooks
	(``store_ranks``). When ``target_node`` is set, ``forward`` back-propagates
	from that node and raises ``TargetReached`` instead of returning.

	The channel bookkeeping uses ``weight.shape[1]`` as the number of input
	channels, i.e. it assumes ``from_conv.groups == 1``.
	"""

	_RANK_TYPES = ('act','grad','actxgrad')
	_NORM_STATS = ('mean','std','l2','l1','max')

	def __init__(self, from_conv,name=None,store_activations=False, store_ranks = False, clear_ranks=False, target_node=None,absolute_rank = True, rank_field = 'image', device='cuda'):
		"""Build the dissected layer from an existing ``nn.Conv2d``.

		Args:
			from_conv: convolution to pull weights, bias and geometry from.
			name: identifier stored on the module (used by the model-level helpers).
			store_activations: keep ``input``, ``preadd_out`` and ``postbias_out`` of the last forward pass.
			store_ranks: register backward hooks that accumulate edge and node ranks.
			clear_ranks: reset the accumulated ranks at the start of every forward pass.
			target_node: output channel to back-propagate from; ``None`` disables this.
			absolute_rank: accumulate absolute values (see ``compute_edge_rank`` / ``compute_node_rank``).
			rank_field: 'image', 'max', 'min', or a list with an ``[row, col]`` position
				(or one such position per image); selects how the target value is computed.
			device: device on which the rank accumulators are allocated.
		"""
		super(dissected_Conv2d, self).__init__()
		self.name = name
		self.in_channels = from_conv.weight.shape[1]
		self.out_channels = from_conv.weight.shape[0]
		self.target_node = target_node
		self.device = device
		self.store_activations = store_activations
		self.store_ranks = store_ranks
		self.clear_ranks = clear_ranks
		self.rank_field = rank_field    #'image' means average over activation map, 'max' means rank with respect to maximum activation
		self.absolute_rank = absolute_rank

		self.edge_ablations = None
		self.node_ablations = None

		# node (postbias) and edge (preadd) rank accumulators, one entry per rank type
		self.postbias_ranks = {}
		self.preadd_ranks = {}
		self._reset_rank_buffers()

		self.normalizations = {
			structure: {
				rank_type: {stat: None for stat in self._NORM_STATS}
				for rank_type in self._RANK_TYPES + ('weight',)
			}
			for structure in ('nodes','edges')
		}

		self.images_seen = 0
		self.weight_perm,self.add_perm,self.add_indices = self.gen_inout_permutation()
		self.preadd_conv = self.make_preadd_conv(from_conv)
		if from_conv.bias is not None:
			# reshaped to (out_channels,1,1) so it broadcasts over the feature maps
			self.bias = nn.Parameter(from_conv.bias.unsqueeze(1).unsqueeze(1))
		else:
			self.bias = None

		self.preadd_ranks['weight'],self.postbias_ranks['weight'] = self.gen_weight_ranks()

		# Hook handles always exist, so 'store_ranks' can also be switched on after construction.
		self.preadd_out_hook = None
		self.postbias_out_hook = None

	def _reset_rank_buffers(self):
		"""Replace the act/grad/actxgrad accumulators with zeros (weight ranks are left untouched)."""
		for rank_type in self._RANK_TYPES:
			self.postbias_ranks[rank_type] = torch.zeros(self.out_channels, device=self.device)
			self.preadd_ranks[rank_type] = torch.zeros(self.out_channels*self.in_channels, device=self.device)

	def gen_inout_permutation(self):
		'''
		When we flatten out all the output channels not to be grouped by 'output channel', we still want the outputs sorted
		such that they can be conveniently added based on input channel later.

		Returns (weight_perm, add_perm, add_indices):
		weight_perm reorders the flattened (out,in) kernels into input-channel-major order,
		add_perm reorders the preadd channels into output-channel-major order,
		add_indices maps each output channel to the list of preadd channels that sum to it.
		'''
		in_chan = self.in_channels
		out_chan = self.out_channels
		flat = torch.arange(in_chan*out_chan)

		# position i*out_chan+j holds i+j*in_chan
		weight_perm = flat.view(out_chan,in_chan).t().reshape(-1)
		# position o*in_chan+i holds o+i*out_chan
		add_perm = flat.view(in_chan,out_chan).t().reshape(-1)
		add_indices = {o: [o+i*out_chan for i in range(in_chan)] for o in range(out_chan)}
		return weight_perm,add_perm,add_indices

	def make_preadd_conv(self,from_conv):
		'''
		nn.Conv2d takes in 'in_channel' number of feature maps, and outputs 'out_channel' number of maps.
		Internally it has in_channel*out_channel number of 2d conv kernels. Normally, feature maps associated
		with a particular output channel resultant from these kernel convolutions are all added together;
		this function changes a nn.Conv2d module into a module where this final addition doesn't happen.
		The final addition can be performed separately with permute_add_featuremaps.

		Kernel size, stride, padding, dilation and padding mode are copied from 'from_conv'.
		'''
		in_chan = self.in_channels
		out_chan = self.out_channels

		kernel_size = from_conv.kernel_size
		new_conv = nn.Conv2d(in_chan,in_chan*out_chan,kernel_size = kernel_size,
							 bias = False, padding=from_conv.padding,stride=from_conv.stride,
							 dilation=from_conv.dilation,padding_mode=from_conv.padding_mode,
							 groups= in_chan)
		# one single-channel kernel per edge, ordered so group g holds every kernel reading input channel g
		new_conv.weight = torch.nn.parameter.Parameter(
				from_conv.weight.view(in_chan*out_chan,1,kernel_size[0],kernel_size[1])[self.weight_perm])
		return new_conv

	def permute_add_featuremaps(self,feature_map):
		'''
		Perform the sum within output channels step.

		Preadd channel i*out_channels+o belongs to input channel i and output channel o,
		so viewing the channel axis as (in_channels,out_channels) and summing over the
		first of the two gives one map per output channel.
		'''
		grouped = feature_map.reshape(feature_map.shape[0],self.in_channels,self.out_channels,
									  feature_map.shape[2],feature_map.shape[3])
		return grouped.sum(dim=1)

	def gen_weight_ranks(self):
		"""Rank edges by mean absolute kernel weight, and nodes by the mean of their edges.

		Returns (edge ranks flat in preadd channel order, node ranks of length out_channels).
		The node ranks are returned on the CPU.
		"""
		weight_ranks_flat = torch.abs(self.preadd_conv.weight).mean(dim=(2,3)).detach().squeeze(1)
		edge_weight_ranks = weight_ranks_flat[self.add_perm].reshape(self.out_channels,self.in_channels).cpu()
		node_weight_ranks = edge_weight_ranks.mean(dim=1)
		return weight_ranks_flat, node_weight_ranks

	def compute_edge_rank(self,grad):
		"""Backward hook on 'preadd_out': accumulate per-edge act, grad and act*grad ranks.

		Each quantity is averaged over the spatial dimensions and summed over the batch;
		with 'absolute_rank' the absolute value is taken element-wise first.
		"""
		activation = self.preadd_out
		taylor = activation * grad
		if self.absolute_rank:
			rank_key  = {'act':torch.abs(activation),'grad':torch.abs(grad),'actxgrad':torch.abs(taylor)}
		else:
			rank_key  = {'act':activation,'grad':grad,'actxgrad':taylor}

		for key, values in rank_key.items():
			if self.preadd_ranks[key] is None: #initialize at 0
				self.preadd_ranks[key] = torch.zeros(activation.size(1), device=self.device)
			# we sum up the mean activations over all images; after all batches
			# have passed through, average_ranks divides by the number of images seen
			self.preadd_ranks[key] += values.mean(dim=(2, 3)).sum(dim=0).detach()

	def compute_node_rank(self,grad):
		"""Backward hook on 'postbias_out': accumulate per-node act, grad and act*grad ranks.

		Each quantity is averaged over the spatial dimensions and summed over the batch.
		"""
		activation = self.postbias_out
		taylor = activation * grad
		if self.absolute_rank:
			# NOTE(review): 'actxgrad' stays signed here although compute_edge_rank takes
			# torch.abs(taylor); abs_ranks later takes the absolute value of the accumulated
			# sum instead. Confirm whether this asymmetry is intended.
			rank_key  = {'act':torch.abs(activation),'grad':torch.abs(grad),'actxgrad':taylor}
		else:
			rank_key  = {'act':activation,'grad':grad,'actxgrad':taylor}

		for key, values in rank_key.items():
			if self.postbias_ranks[key] is None: #initialize at 0
				self.postbias_ranks[key] = torch.zeros(activation.size(1), device=self.device)
			# we sum up the mean activations over all images; after all batches
			# have passed through, average_ranks divides by the number of images seen
			self.postbias_ranks[key] += values.mean(dim=(2, 3)).sum(dim=0).detach()

	def format_edges(self, data= 'activations',rank_type='actxgrad',weight_rank=False):
		"""Return edge data regrouped by output channel as a float32 numpy array.

		data == 'activations': stored preadd activations as [img,out_channel,in_channel,h,w];
			requires 'store_activations', otherwise a message is printed and None is returned.
		any other value: preadd ranks of 'rank_type' as [out_chan,in_chan].
		weight_rank=True selects the 'weight' ranks regardless of 'rank_type'.
		"""
		if weight_rank:
			rank_type = 'weight'

		if data == 'activations':
			if not self.store_activations:
				print('activations arent stored, use "store_activations=True" on model init. returning None')
				return None
			acts = self.preadd_out[:, self.add_perm, :, :]   # channels now output-channel-major
			acts = acts.reshape(acts.shape[0],self.out_channels,self.in_channels,acts.shape[2],acts.shape[3])
			return acts.cpu().detach().numpy().astype('float32')

		ranks = self.preadd_ranks[rank_type][self.add_perm].reshape(self.out_channels,self.in_channels)
		return ranks.cpu().detach().numpy().astype('float32')

	def average_ranks(self):
		"""Divide the accumulated act/grad/actxgrad ranks by the number of images seen.

		Not idempotent: every call divides the stored ranks again.
		"""
		if self.images_seen > 0:
			for rank_type in self._RANK_TYPES:
				self.preadd_ranks[rank_type] = self.preadd_ranks[rank_type]/self.images_seen
				self.postbias_ranks[rank_type] = self.postbias_ranks[rank_type]/self.images_seen

	def abs_ranks(self):
		"""Replace the accumulated act/grad/actxgrad ranks by their absolute values."""
		for rank_type in self._RANK_TYPES:
			self.preadd_ranks[rank_type] = torch.abs(self.preadd_ranks[rank_type])
			self.postbias_ranks[rank_type] = torch.abs(self.postbias_ranks[rank_type])

	def gen_normalizations(self,rank_type):
		"""Fill 'self.normalizations' (std, mean, max, l1, l2) for the given rank type.

		Does nothing for non-weight rank types until at least one image has been seen.
		"""
		if not (self.images_seen > 0 or rank_type == 'weight'):
			return
		magnitudes = {'nodes': torch.abs(self.postbias_ranks[rank_type]),
					  'edges': torch.abs(self.preadd_ranks[rank_type])}
		for structure, values in magnitudes.items():
			stats = self.normalizations[structure][rank_type]
			stats['std'] = float(torch.std(values))
			stats['mean'] = float(torch.mean(values))
			stats['max'] = float(torch.max(values))
			stats['l1'] = float(torch.sum(values))
			# torch.sqrt keeps the computation on the tensor's own device
			stats['l2'] = float(torch.sqrt(torch.sum(values * values)))

	def clear_ranks_func(self): #clear ranks, info that otherwise accumulates with images
		"""Reset the image counter and zero the act/grad/actxgrad accumulators."""
		self.images_seen = 0
		self._reset_rank_buffers()

	def _compute_optim_target(self, postbias_out):
		"""Reduce the target node's activations to the value that is back-propagated, per 'rank_field'."""
		field = self.rank_field
		node = self.target_node
		if isinstance(field,(list,tuple)):
			if isinstance(field[0],(list,tuple)):
				# we have a list of target positions, a different target position for each image in the batch
				act_targets_sum = postbias_out.new_zeros(1)
				for img, position in enumerate(field):
					act_targets_sum += postbias_out[img,node,int(position[0]),int(position[1])]
			else:
				act_targets_sum = postbias_out[:,node,int(field[0]),int(field[1])].sum()
			return act_targets_sum/postbias_out.shape[0]
		if field == 'image':
			return postbias_out.mean(dim=(0, 2, 3))[node]
		if field == 'max':
			return postbias_out.flatten(start_dim=2).max(dim=-1).values[:,node].mean()
		if field == 'min':
			return postbias_out.flatten(start_dim=2).min(dim=-1).values[:,node].mean()
		raise ValueError('rank_field "%s" is not valid, use "image", "max", "min" or a list of positions'%str(field))

	def forward(self, x):
		"""Run the dissected convolution.

		Returns the node output (sum of the edge maps per output channel, plus bias).
		If 'target_node' is set, the target value is stored in 'self.optim_target',
		back-propagated, and TargetReached is raised instead of returning.
		"""
		if self.clear_ranks:
			self.clear_ranks_func()

		self.images_seen += x.shape[0]    #keep track of how many images weve seen so we know what to divide by when we average ranks
		if self.store_activations:
			self.input = x

		preadd_out = self.preadd_conv(x)  #get output of convolutions

		#set ablated edges to 0
		if self.edge_ablations is not None and len(self.edge_ablations) > 0:
			ablated_edges = torch.as_tensor(self.edge_ablations,dtype=torch.long,device=preadd_out.device)
			preadd_out.index_fill_(1,ablated_edges,0)

		#store values of intermediate outputs after convolution (the rank hook reads them too)
		if self.store_activations or self.store_ranks:
			self.preadd_out = preadd_out

		#Set hooks for calculating rank on backward pass
		if self.store_ranks:
			if self.preadd_out_hook is not None:
				self.preadd_out_hook.remove()
			self.preadd_out_hook = preadd_out.register_hook(self.compute_edge_rank)

		added_out = self.permute_add_featuremaps(preadd_out)    #add convolution outputs by output channel
		if self.bias is not None:
			postbias_out = added_out + self.bias
		else:
			postbias_out = added_out

		#set ablated nodes to 0
		if self.node_ablations is not None and len(self.node_ablations) > 0:
			ablated_nodes = torch.as_tensor(self.node_ablations,dtype=torch.long,device=postbias_out.device)
			postbias_out.index_fill_(1,ablated_nodes,0)

		#Store values of final module output (the rank hook reads them too)
		if self.store_activations or self.store_ranks:
			self.postbias_out = postbias_out

		#Set hooks for calculating rank on backward pass
		if self.store_ranks:
			if self.postbias_out_hook is not None:
				self.postbias_out_hook.remove()
			self.postbias_out_hook = postbias_out.register_hook(self.compute_node_rank)

		if self.target_node is not None:
			# use this pass's output directly; self.postbias_out only exists when a store flag is set
			optim_target = self._compute_optim_target(postbias_out)
			self.optim_target = optim_target
			optim_target.backward()
			raise TargetReached

		return postbias_out






class hooked_Conv2d(torch.nn.Module):
	"""Wrapper around an ``nn.Conv2d`` that records node-level activations and ranks.

	The wrapped convolution ('from_conv') is called unchanged; this module only adds
	node ablation, optional storage of input/output, backward hooks that accumulate
	per-output-channel ranks, and the 'target_node' early exit: when 'target_node'
	is set, 'forward' back-propagates from that node and raises TargetReached.
	"""

	_RANK_TYPES = ('act','grad','actxgrad')

	def __init__(self, from_conv,name=None,store_activations=False, store_ranks = False, clear_ranks=False, target_node=None,absolute_rank = True, rank_field = 'image', device='cuda'):
		"""Wrap an existing ``nn.Conv2d``.

		Args:
			from_conv: convolution that performs the actual computation.
			name: identifier stored on the module (used by the model-level helpers).
			store_activations: keep 'input' and 'postbias_out' of the last forward pass.
			store_ranks: register a backward hook that accumulates node ranks.
			clear_ranks: reset the accumulated ranks at the start of every forward pass.
			target_node: output channel to back-propagate from; None disables this.
			absolute_rank: accumulate absolute values (see 'compute_node_rank').
			rank_field: 'image', 'max', 'min', or a list with an [row, col] position
				(or one such position per image); selects how the target value is computed.
			device: device on which the rank accumulators are allocated.
		"""
		super(hooked_Conv2d, self).__init__()
		self.from_conv = from_conv
		self.name = name
		self.in_channels = from_conv.weight.shape[1]
		self.out_channels = from_conv.weight.shape[0]
		self.target_node = target_node
		self.device = device
		self.store_activations = store_activations
		self.store_ranks = store_ranks
		self.clear_ranks = clear_ranks
		self.rank_field = rank_field    #'image' means average over activation map, 'max' means rank with respect to maximum activation
		self.absolute_rank = absolute_rank

		self.edge_ablations = None
		self.node_ablations = None

		self.postbias_ranks = {}
		self._reset_rank_buffers()

		self.normalizations = {'nodes': {
			rank_type: {stat: None for stat in ('mean','std','l2','l1','max')}
			for rank_type in self._RANK_TYPES
		}}

		self.images_seen = 0

		# The hook handle always exists, so 'store_ranks' can also be switched on after construction.
		self.postbias_out_hook = None

	def _reset_rank_buffers(self):
		"""Replace the act/grad/actxgrad accumulators with zeros."""
		for rank_type in self._RANK_TYPES:
			self.postbias_ranks[rank_type] = torch.zeros(self.out_channels, device=self.device)

	def compute_node_rank(self,grad):
		"""Backward hook on 'postbias_out': accumulate per-node act, grad and act*grad ranks.

		Each quantity is averaged over the spatial dimensions and summed over the batch.
		"""
		activation = self.postbias_out
		taylor = activation * grad
		if self.absolute_rank:
			# NOTE(review): 'actxgrad' stays signed here; abs_ranks later takes the absolute
			# value of the accumulated sum. Confirm this is the intended criterion.
			rank_key  = {'act':torch.abs(activation),'grad':torch.abs(grad),'actxgrad':taylor}
		else:
			rank_key  = {'act':activation,'grad':grad,'actxgrad':taylor}

		for key, values in rank_key.items():
			if self.postbias_ranks[key] is None: #initialize at 0
				self.postbias_ranks[key] = torch.zeros(activation.size(1), device=self.device)
			# we sum up the mean activations over all images; after all batches
			# have passed through, average_ranks divides by the number of images seen
			self.postbias_ranks[key] += values.mean(dim=(2, 3)).sum(dim=0).detach()

	def average_ranks(self):
		"""Divide the accumulated ranks by the number of images seen.

		Not idempotent: every call divides the stored ranks again.
		"""
		if self.images_seen > 0:
			for rank_type in self._RANK_TYPES:
				self.postbias_ranks[rank_type] = self.postbias_ranks[rank_type]/self.images_seen

	def abs_ranks(self):
		"""Replace the accumulated ranks by their absolute values."""
		for rank_type in self._RANK_TYPES:
			self.postbias_ranks[rank_type] = torch.abs(self.postbias_ranks[rank_type])

	def clear_ranks_func(self): #clear ranks, info that otherwise accumulates with images
		"""Reset the image counter and zero the rank accumulators."""
		self.images_seen = 0
		self._reset_rank_buffers()

	def _compute_optim_target(self, postbias_out):
		"""Reduce the target node's activations to the value that is back-propagated, per 'rank_field'."""
		field = self.rank_field
		node = self.target_node
		if isinstance(field,(list,tuple)):
			if isinstance(field[0],(list,tuple)):
				# we have a list of target positions, a different target position for each image in the batch
				act_targets_sum = postbias_out.new_zeros(1)
				for img, position in enumerate(field):
					act_targets_sum += postbias_out[img,node,int(position[0]),int(position[1])]
			else:
				act_targets_sum = postbias_out[:,node,int(field[0]),int(field[1])].sum()
			return act_targets_sum/postbias_out.shape[0]
		if field == 'image':
			return postbias_out.mean(dim=(0, 2, 3))[node]
		if field == 'max':
			return postbias_out.flatten(start_dim=2).max(dim=-1).values[:,node].mean()
		if field == 'min':
			return postbias_out.flatten(start_dim=2).min(dim=-1).values[:,node].mean()
		raise ValueError('rank_field "%s" is not valid, use "image", "max", "min" or a list of positions'%str(field))

	def forward(self, x):
		"""Run the wrapped convolution.

		Returns the convolution output (with ablated nodes zeroed). If 'target_node'
		is set, the target value is stored in 'self.optim_target', back-propagated,
		and TargetReached is raised instead of returning.
		"""
		if self.clear_ranks:
			self.clear_ranks_func()

		self.images_seen += x.shape[0]    #keep track of how many images weve seen so we know what to divide by when we average ranks
		if self.store_activations:
			self.input = x

		postbias_out = self.from_conv(x)

		#set ablated nodes to 0
		if self.node_ablations is not None and len(self.node_ablations) > 0:
			ablated_nodes = torch.as_tensor(self.node_ablations,dtype=torch.long,device=postbias_out.device)
			postbias_out.index_fill_(1,ablated_nodes,0)

		#Store values of final module output (the rank hook reads them too)
		if self.store_activations or self.store_ranks:
			self.postbias_out = postbias_out

		#Set hooks for calculating rank on backward pass
		if self.store_ranks:
			if self.postbias_out_hook is not None:
				self.postbias_out_hook.remove()
			self.postbias_out_hook = postbias_out.register_hook(self.compute_node_rank)

		if self.target_node is not None:
			# use this pass's output directly; self.postbias_out only exists when a store flag is set
			optim_target = self._compute_optim_target(postbias_out)
			self.optim_target = optim_target
			optim_target.backward()
			raise TargetReached

		return postbias_out



### MODEL LEVEL FUNCTIONS ###

'''
These functions all deal with dissected_Conv2d modules across an entire model
'''

# takes a full model and replaces all conv2d instances with dissected conv 2d instances
def dissect_model(model,mod_names = None,store_activations=True,store_ranks=True,clear_ranks = False,rank_field = 'image',dissect=True,device='cuda:0'):
	"""Recursively replace every nn.Conv2d in 'model' with a dissected (or hooked) version, in place.

	Args:
		model: module whose children are rewritten.
		mod_names: names of the enclosing modules, used to build each new module's name
			as '<parent>_..._<child>'; callers normally leave this unset.
		store_activations, store_ranks, clear_ranks, rank_field, device: forwarded to the new modules.
		dissect: True builds dissected_Conv2d modules, False builds hooked_Conv2d modules.

	Besides the replacement, nn.Dropout children are put in eval mode and every child
	that has an 'inplace' attribute gets it set to False.
	NOTE(review): batch-norm layers are not touched here; a later model.train()
	also re-enables the dropout layers.

	Returns:
		The same 'model' object.
	"""
	if mod_names is None:    # a list default would be shared between calls
		mod_names = []

	for name, module in reversed(list(model._modules.items())):
		if module is None:
			continue
		if len(list(module.children())) > 0:
			if isinstance(module, (dissected_Conv2d, hooked_Conv2d)):
				# already converted: do not dissect the convolution it wraps
				continue
			mod_names.append(str(name))
			# recurse
			model._modules[name] = dissect_model(module,mod_names =mod_names, store_activations=store_activations,store_ranks=store_ranks,rank_field=rank_field,clear_ranks=clear_ranks,dissect=dissect,device=device)
			mod_names.pop()

		if isinstance(module, torch.nn.modules.conv.Conv2d):    # found a 2d conv module to transform
			conv_class = dissected_Conv2d if dissect else hooked_Conv2d
			model._modules[name] = conv_class(module, name='_'.join(mod_names+[name]), store_activations=store_activations,store_ranks=store_ranks,rank_field=rank_field,clear_ranks=clear_ranks,device=device)
		elif isinstance(module, torch.nn.modules.Dropout):    #make dropout layers not dropout
			module.eval()
		elif hasattr(module, 'inplace'):    #make activation functions not 'inplace'
			module.inplace = False

	return model




def _assign_target_node(model,target_layer,within_layer_id,layer):
	"""Depth-first search for the target conv layer; returns (next layer index, found)."""
	for module in model._modules.values():
		if module is None:
			continue
		if isinstance(module, (dissected_Conv2d, hooked_Conv2d)):
			if layer==target_layer or module.name ==target_layer:
				module.target_node = within_layer_id
				return layer, True
			layer+=1
		elif len(list(module.children())) > 0:
			layer, found = _assign_target_node(module,target_layer,within_layer_id,layer)
			if found:
				return layer, True
	return layer, False


def set_model_target_node(model,target_layer,within_layer_id,layer=0):
	"""Set 'target_node' on one dissected/hooked conv layer of 'model'.

	Args:
		model: module to search (depth first, in registration order).
		target_layer: either the layer's 'name' or its integer position among the
			dissected/hooked conv layers of the whole model.
		within_layer_id: value assigned to the layer's 'target_node'.
		layer: index given to the first conv layer encountered.

	The layer index is carried across nested containers, so an integer
	'target_layer' addresses exactly one layer; the search stops at the first match.

	Returns:
		The same 'model' object.
	"""
	_assign_target_node(model,target_layer,within_layer_id,layer)
	return model


def _find_optim_target(model,target_layer,layer,default):
	"""Depth-first search for the target conv layer; returns (next layer index, found, value)."""
	for module in model._modules.values():
		if module is None:
			continue
		if isinstance(module, (dissected_Conv2d, hooked_Conv2d)):
			if layer==target_layer or module.name ==target_layer:
				# the attribute is absent if the layer has not produced a target value
				return layer, True, getattr(module, 'optim_target', default)
			layer+=1
		elif len(list(module.children())) > 0:
			layer, found, value = _find_optim_target(module,target_layer,layer,default)
			if found:
				return layer, True, value
	return layer, False, default


def get_optim_target_from_model(model,target_layer,layer=0,optim_target=None):
	"""Return the 'optim_target' attribute of one dissected/hooked conv layer.

	Args:
		model: module to search (depth first, in registration order).
		target_layer: either the layer's 'name' or its integer position among the
			dissected/hooked conv layers of the whole model.
		layer: index given to the first conv layer encountered.
		optim_target: value returned when no layer matches, or when the matching
			layer has no 'optim_target' attribute.

	The layer index is carried across nested containers, so an integer
	'target_layer' addresses exactly one layer.
	"""
	_, _, value = _find_optim_target(model,target_layer,layer,optim_target)
	return value





def set_across_model(model,setting,value):
	"""Apply one setting to every dissected/hooked conv layer of 'model'.

	Args:
		model: module to traverse recursively.
		setting: one of 'target_node', 'clear_ranks', 'rank_field', 'store_activations',
			'store_ranks', 'absolute_rank' (the attribute of that name is set to 'value'),
			or 'clear_ablations' (edge and node ablations are emptied; 'value' is ignored).
		value: new value for the setting.

	An invalid setting prints an error message once and leaves the model untouched.

	Returns:
		The same 'model' object.
	"""
	attribute_settings = ('target_node','clear_ranks','rank_field','store_activations','store_ranks','absolute_rank')
	if setting != 'clear_ablations' and setting not in attribute_settings:
		print('Error! setting "%s" is not valid'%str(setting))
		return model

	for module in model._modules.values():
		if module is None:
			continue
		if len(list(module.children())) > 0:
			# recurse
			set_across_model(module,setting,value)

		if isinstance(module, (dissected_Conv2d, hooked_Conv2d)):    # found a conv module to configure
			if setting == 'clear_ablations':
				module.edge_ablations = []
				module.node_ablations = []
				continue
			setattr(module, setting, value)
			if setting == 'store_ranks':
				# forward() reads these handles; make sure they exist when ranks are enabled late
				if not hasattr(module, 'postbias_out_hook'):
					module.postbias_out_hook = None
				if isinstance(module, dissected_Conv2d) and not hasattr(module, 'preadd_out_hook'):
					module.preadd_out_hook = None

	return model


def clear_ranks_across_model(model):
	"""Reset the accumulated ranks and image counters of every dissected/hooked conv layer.

	Calls 'clear_ranks_func' on each dissected_Conv2d and hooked_Conv2d found while
	traversing 'model' recursively.

	Returns:
		The same 'model' object.
	"""
	for module in model._modules.values():
		if module is None:
			continue
		if len(list(module.children())) > 0:# recurse
			clear_ranks_across_model(module)

		if isinstance(module, (dissected_Conv2d, hooked_Conv2d)):    # both classes accumulate ranks
			module.clear_ranks_func()

	return model


def layer_2_dissected_conv2d(target_layer,module, index=0, found=None):
	"""Locate the dissected_Conv2d with position 'target_layer' inside 'module'.

	dissected_Conv2d modules are counted depth first, starting at 'index'.
	Returns (found, index): the matching module ('found' is passed through when
	nothing matches) and the running count after the traversal.
	"""
	for child in module._modules.values():
		if not isinstance(child, dissected_Conv2d):
			# containers are searched recursively, carrying the running count along
			if len(list(child.children())) > 0:
				found, index = layer_2_dissected_conv2d(target_layer,child, index=index, found=found)
			continue
		if index==target_layer:
			found = child
		index+=1
	return found, index

def get_activations_from_dissected_Conv2d_modules(module,layer_activations=None):
	"""Collect the stored activations of every dissected/hooked conv layer, recursively.

	Returns a dict of lists of numpy arrays, one entry per layer in traversal order:
	'nodes' holds each layer's 'postbias_out', 'edges_in' its stored 'input', and
	'edges_out' the result of 'format_edges(data="activations")' (dissected_Conv2d only).
	The layers must have run a forward pass with their activations stored.
	"""
	if layer_activations is None:    #initialize the output dictionary if we are not recursing and havent done so yet
		layer_activations = {'nodes':[],'edges_in':[],'edges_out':[]}
	for submodule in module._modules.values():
		if submodule is None:
			continue
		if isinstance(submodule, (dissected_Conv2d, hooked_Conv2d)):
			layer_activations['nodes'].append(submodule.postbias_out.cpu().detach().numpy())
			layer_activations['edges_in'].append(submodule.input.cpu().detach().numpy())
			if isinstance(submodule, dissected_Conv2d):    # only dissected layers expose edge maps
				layer_activations['edges_out'].append(submodule.format_edges(data= 'activations'))
		elif len(list(submodule.children())) > 0:
			layer_activations = get_activations_from_dissected_Conv2d_modules(submodule,layer_activations=layer_activations)   #module has modules inside it, so recurse on this module

	return layer_activations


def get_ranks_from_dissected_Conv2d_modules(module,layer_ranks=None,weight_rank=False):     #run through all model modules recursively, and pull the ranks stored in dissected_Conv2d modules
	"""Collect the ranks stored in every dissected/hooked conv layer, recursively.

	Args:
		module: module to traverse.
		layer_ranks: accumulator used by the recursion; callers normally leave it unset.
		weight_rank: True collects the 'weight' ranks, False the 'act', 'grad' and 'actxgrad' ranks.

	Returns:
		{'nodes': {rank_type: [[name, array], ...]}, 'edges': {rank_type: [[name, array], ...]}}
		with float32 numpy arrays. Edge entries exist only for dissected_Conv2d layers;
		a layer that does not hold a requested rank type is skipped for that type.

	Side effect: each layer's 'abs_ranks' (when its 'absolute_rank' is set) and
	'average_ranks' are called, which modifies the ranks stored on the layer.
	NOTE(review): calling this twice without new data therefore divides the stored ranks twice.
	"""
	rank_types = ['weight'] if weight_rank else ['act','grad','actxgrad']

	if layer_ranks is None:    #initialize the output dictionary if we are not recursing and havent done so yet
		layer_ranks = {'nodes':{rank_type:[] for rank_type in rank_types},
					   'edges':{rank_type:[] for rank_type in rank_types}}

	for submodule in module._modules.values():
		if submodule is None:
			continue
		if isinstance(submodule, (dissected_Conv2d, hooked_Conv2d)):
			if submodule.absolute_rank:
				submodule.abs_ranks()
			submodule.average_ranks()

			for rank_type in rank_types:
				if rank_type not in submodule.postbias_ranks:
					continue    # this layer does not hold the requested rank type
				layer_ranks['nodes'][rank_type].append([submodule.name, submodule.postbias_ranks[rank_type].cpu().detach().numpy().astype('float32')])
				if isinstance(submodule, dissected_Conv2d):
					layer_ranks['edges'][rank_type].append([submodule.name, submodule.format_edges(data= 'ranks',rank_type=rank_type,weight_rank=weight_rank)])
		elif len(list(submodule.children())) > 0:
			layer_ranks = get_ranks_from_dissected_Conv2d_modules(submodule,layer_ranks=layer_ranks,weight_rank=weight_rank)   #module has modules inside it, so recurse on this module
	return layer_ranks


def get_ranklist_from_dissected_Conv2d_modules(model,structure='edges',method='actxgrad'):
	"""Return the per-layer ranks of 'model' as a list of CPU tensors.

	'structure' is 'edges' (alias 'kernels') or 'nodes' (alias 'filters');
	'method' selects the rank type. Layers whose ranks are all zero are left out.
	"""
	if structure == 'kernels':
		structure = 'edges'
	if structure == 'filters':
		structure = 'nodes'

	layer_entries = get_ranks_from_dissected_Conv2d_modules(model)[structure][method]

	# edge ranks are inspected through their second index array, node ranks through the first
	axis = 1 if structure == 'edges' else 0

	rank_list = []
	for entry in layer_entries:
		ranks = entry[1]
		if len(ranks.nonzero()[axis])>0:
			rank_list.append(torch.tensor(ranks).to('cpu'))

	return rank_list