
import numpy as	np
import time
import sys
import matplotlib.pyplot as	plt
import os
import torch

#--------------------------------#
from FwdBwdNeuralEqV3 import *
from Tx	import *
from eq	import *
from neuralEQ import *
from utils import *
import device



# Outer nnUnit indices that are analysed (only the second stacked unit).
UNIT_INDICES = (1,)
# Accumulator shape: (outer nnUnit index, sub-unit index, layer slot).
ROW = 2
COL = 12
LAYER_SLOTS = 2  # only for 2 stacked-nnUnit


def count_redundant_neurons(weight_mask, threshold):
    """Return how many rows of a pruning mask count as redundant neurons.

    Row ``m`` of ``weight_mask`` is counted when the sum of its entries is
    ``<= threshold``.  For an ``nn.Linear`` layer the weight (and therefore
    its mask) has one row per output feature, so a counted row is a neuron
    whose incoming weights have (almost) all been pruned.

    Args:
        weight_mask: 2-D mask tensor, one row per output neuron.
            NOTE(review): assumes the mask created by ``torch.nn.utils.prune``
            on an ``nn.Linear`` weight — confirm for other layer types.
        threshold: Largest row sum still considered redundant.

    Returns:
        Number of redundant rows as a Python ``int``.
    """
    per_row = weight_mask.sum(dim=1)
    return int((per_row <= threshold).sum().item())


def tally_redundant_neurons(model, unit_indices, n_cols, threshold, totals):
    """Print and accumulate redundant-neuron counts of every Linear layer.

    For each ``k`` in ``unit_indices`` and each ``i`` in ``range(n_cols)``,
    walks the layers of ``model.nnUnit[k][i]``.  Every ``torch.nn.Linear``
    layer at position ``j`` has its redundant-neuron count printed and added
    to ``totals[k][i][j // 2]``.

    Args:
        model: Object with a nested, indexable ``nnUnit`` attribute; Linear
            leaves must carry a ``weight_mask`` tensor.
        unit_indices: Outer ``nnUnit`` indices to analyse.
        n_cols: Number of sub-units visited per outer index.
        threshold: Forwarded to :func:`count_redundant_neurons`.
        totals: Array indexed ``[k][i][slot]``; updated in place.
    """
    for k in unit_indices:
        for i in range(n_cols):
            # Walk unit i itself so units of different lengths are fully visited.
            for j, layer in enumerate(model.nnUnit[k][i]):
                if not isinstance(layer, torch.nn.Linear):
                    continue
                count = count_redundant_neurons(layer.weight_mask, threshold)
                print(f"Redundant neurons in nEqLoad.nnUnit[{k}][{i}][{j}]", end=": ")
                print(f"{count}")
                # Presumably each Linear is followed by one non-Linear layer
                # (e.g. an activation), so j // 2 is the Linear's slot — confirm.
                totals[k][i][j // 2] += count


def main():
    """Average redundant-neuron counts over the models listed in the config.

    Loads ``config_<args.config>`` from ``./config``, then for every file in
    ``cfg['pruneAnalysis']['modelFileList']`` loads the model, tallies the
    redundant neurons of its Linear layers (threshold
    ``cfg['pruneAnalysis']['threshold']``) and finally prints the totals for
    ``nnUnit[1]`` divided by the number of models.
    """
    import importlib

    #*************************HEADER***********************#
    np.random.seed(1)
    args = parsing_def()
    sys.path.insert(0, './config')
    config_module = importlib.import_module('config_{}'.format(args.config))
    cfg = config_module.config
    #******************************************************#

    prune_cfg = cfg['pruneAnalysis']
    redundant_totals = np.zeros((ROW, COL, LAYER_SLOTS))
    n_models = 0
    for modelFile in prune_cfg['modelFileList']:
        # map_location lets checkpoints saved on another device (e.g. a GPU)
        # be loaded straight onto the target device.
        # NOTE(review): torch>=2.6 defaults to weights_only=True, which rejects
        # whole pickled modules; pass weights_only=False for trusted files.
        nEqLoad = torch.load(modelFile, map_location=device.device)
        nEqLoad = nEqLoad.to(device.device)

        print(f"Modelfile: {modelFile}")
        tally_redundant_neurons(
            nEqLoad, UNIT_INDICES, COL, prune_cfg['threshold'], redundant_totals
        )
        n_models += 1

    if n_models == 0:
        # Avoid dividing by zero / referencing an unset loop index.
        print("No model files listed in cfg['pruneAnalysis']['modelFileList'].")
        return
    print(f"sumRedundantNeuronArray[1]: \n{redundant_totals[1] / n_models}")


if __name__ == "__main__":
    main()
