import numpy as np
from utilityScript import *
from SparseCentroidencoder import SparseCE
import pickle
import socket
import matplotlib.pyplot as plt
import datetime

def orderFeatureByFrequency(dataSetName,nFeatures):
	"""Aggregate raw per-pass feature sets and save the most frequent ones.

	Every pickle in ./RawFeatures/<dataSetName>/ is loaded. Each is expected to be
	a dict with a 'features' sub-dict, such as the files written by
	FeatureSelection_IP. All values of each 'features' dict are collected into one
	list and passed to topNFeaturePruning(fSet, nFeatures). Its first return value
	is pickled to ./FeatureSet/<dataSetName>/<dataSetName>_features.p.

	Files are processed in sorted filename order, so the result does not depend
	on the directory listing order.

	Parameters
	----------
	dataSetName : sub-directory name under ./RawFeatures and ./FeatureSet.
	nFeatures : forwarded to topNFeaturePruning (presumably the number of top
		features to keep -- TODO confirm against utilityScript).

	Raises
	------
	FileNotFoundError : if no *.p files exist in ./RawFeatures/<dataSetName>/.
	"""
	import glob
	import os
	rawDir = './RawFeatures/'+dataSetName+'/'
	rawFiles = sorted(glob.glob(rawDir+'*.p'))
	if not rawFiles:
		raise FileNotFoundError('No raw feature files found in '+rawDir)
	fSet = []
	for rawFile in rawFiles:
		# NOTE: pickle.load can execute arbitrary code; only load trusted, locally produced files.
		with open(rawFile,'rb') as fh:
			D = pickle.load(fh)
		fSet.extend(D['features'].values())

	fea,_ = topNFeaturePruning(fSet,nFeatures)
	outDir = './FeatureSet/'+dataSetName
	os.makedirs(outDir,exist_ok=True)
	fileName = outDir+'/'+dataSetName+'_features.p'
	with open(fileName,'wb') as fh:
		pickle.dump(fea,fh)
	print('File written in',fileName)

def partitionData(data,label,noFold=5):
	"""Partition labeled data into noFold stratified folds.

	The samples are joined with their labels and shuffled in place with
	np.random.shuffle, so the result depends on NumPy's global RNG state. Each
	class is then cut into consecutive slices of int(n_c/noFold) rows, one per
	fold. The last fold also takes the remainder, so a class with fewer than
	noFold samples lands entirely in the last fold.

	Parameters
	----------
	data : [m x p] array of samples.
	label : numeric class labels, either an [m x 1] column or a 1-D array of
		length m. Any distinct numeric values are accepted; they no longer need
		to be 0..C-1.
	noFold : number of folds.

	Returns
	-------
	dict mapping 'fold1' .. 'fold<noFold>' to an array whose last column is the
	label. The rows are grouped by class in ascending label order.
	"""
	labelCol = np.asarray(label).reshape(-1,1)
	labeledData = np.hstack((data,labelCol))
	np.random.shuffle(labeledData)
	shuffledLabels = labeledData[:,-1]
	# Group by every distinct label value, so labels outside 0..C-1 are not dropped.
	classData = [labeledData[shuffledLabels==c,:] for c in np.unique(shuffledLabels)]

	dataSegment = {}
	for fold in range(noFold):
		foldData = []
		for rows in classData:
			noSample = int(len(rows)/noFold)
			sIndex = fold*noSample
			if fold+1 == noFold:
				# The last fold absorbs the remainder of the class.
				foldData.append(rows[sIndex:,:])
			else:
				foldData.append(rows[sIndex:sIndex+noSample,:])
		# With empty input there are no classes; return a correctly shaped empty fold.
		dataSegment['fold'+str(fold+1)] = np.vstack(foldData) if foldData else labeledData[:0]
	return dataSegment
	
def k_foldDataPartition(D,L,noFold=5,rep=10):
	"""Repeat a stratified noFold split of the data rep times.

	In each repetition the samples are joined with their labels and shuffled with
	np.random.shuffle, so results depend on NumPy's global RNG state. Each class
	is then cut into slices of floor(n_c/noFold) rows. The last fold takes the
	remainder, which is the same rule partitionData uses. A class with fewer than
	noFold samples therefore ends up entirely in the last fold.

	Floor is used instead of round: rounding up can use up a class before the
	last fold. For example, 6 samples in 4 folds would give 2,2,2,0.

	Parameters
	----------
	D : [m x p] data matrix.
	L : numeric labels, an [m x 1] column or a 1-D array of length m. Any distinct
		numeric values are accepted.
	noFold : number of folds per repetition.
	rep : number of independent repetitions.

	Returns
	-------
	list of rep*noFold single-key dicts {'Pass<r>_Fold<f>': array}, in
	pass-major order. The last column of each array is the label.
	"""
	baseData = np.hstack((D,np.asarray(L).reshape(-1,1)))
	dataSegment = []
	for r in range(rep):
		labeledData = baseData.copy()
		np.random.shuffle(labeledData)
		shuffledLabels = labeledData[:,-1]
		classData = [labeledData[shuffledLabels==c,:] for c in np.unique(shuffledLabels)]
		for fold in range(noFold):
			foldData = []
			for rows in classData:
				noSample = len(rows)//noFold
				sIndex = fold*noSample
				if fold+1 == noFold:
					foldData.append(rows[sIndex:,:])
				else:
					foldData.append(rows[sIndex:sIndex+noSample,:])
			partitionKey = 'Pass'+str(r+1)+'_Fold'+str(fold+1)
			dataSegment.append({partitionKey: np.vstack(foldData) if foldData else labeledData[:0]})
	return dataSegment
	
def splitRatioDataPartition(labeledData,splitRatio,rep=5):
	"""Build rep passes of (possibly subsampled) data sets, one per split ratio.

	For each pass r (1..rep) and each position f in splitRatio, exactly one entry
	{'Pass<r>_Fold<f>': subset} is appended:
	  - ratio == 1: subset is labeledData itself (not a copy).
	  - otherwise: subset is the first output of splitData(labeledData, ratio).
	    This is presumably a random subset holding that fraction of the rows --
	    TODO confirm against utilityScript.splitData.

	Previously a nested `for k in range(rep)` made rep draws per (pass, fold), all
	under the same key. That produced rep*rep entries for ratio != 1 but only rep
	entries for ratio == 1.

	Returns
	-------
	list of rep*len(splitRatio) single-key dicts, in pass-major order.
	"""
	D = []
	for r in range(rep):
		for fold,ratio in enumerate(splitRatio):
			partitionKey = 'Pass'+str(r+1)+'_Fold'+str(fold+1)
			if ratio == 1:
				subset = labeledData
			else:
				subset,_ = splitData(labeledData,ratio)
			D.append({partitionKey: subset})
	return D


def _timestampString():
	"""Return the current local time as a filename-safe string.

	str(datetime.datetime.now()) looks like 'YYYY-MM-DD HH:MM:SS[.ffffff]'. The
	date and time parts are concatenated, and ':' and '.' are replaced with '-'.
	"""
	stamp = str(datetime.datetime.now())
	stamp = stamp.split(' ')[0]+stamp.split(' ')[1]
	return stamp.replace(':','-').replace('.','-')


def _selectFeaturesIteratively(trData,trLabels,maxFCnt):
	"""Collect features by repeatedly training a SparseCE model and removing the picks.

	Each round:
	  1. Trains SparseCE (fixed hyper-parameters below) on the remaining columns.
	  2. Takes the columns chosen by returnImpFeaturesElbow(model.splWs).
	  3. Maps them back to original column indices.
	  4. Deletes them from trData.
	Rounds continue until at least maxFCnt features are collected (the last round
	may overshoot). The loop also stops early if no columns remain or a round
	selects nothing; the original loop would have crashed or spun forever there.

	Assumes returnImpFeaturesElbow returns (indices-or-mask into the current
	columns, their weights) -- TODO confirm against utilityScript.

	Returns (featureIndex, featureWeight) as 1-D arrays in selection order.
	"""
	remainingFIndices = np.arange(np.shape(trData)[1])
	featureIndex = []
	featureWeight = []
	fCnt = 0
	while fCnt < maxFCnt and np.shape(trData)[1] > 0:
		dict2 = {}
		dict2['inputL'] = np.shape(trData)[1]
		dict2['outputL'] = np.shape(trData)[1]
		dict2['hL'] = [np.shape(trData)[1],100]
		dict2['actFunc'] = ['SPL','tanh']
		dict2['outputActivation'] = 'linear'
		dict2['l1Penalty'] = 0.01
		dict2['nItrPre'] = 10
		dict2['nItrPost'] = 40
		dict2['errorFunc']='MSE'

		model = SparseCE(dict2)
		model.fit(trData,trLabels)
		featureList,featuresW = returnImpFeaturesElbow(model.splWs)
		# Translate positions in the shrunken trData back to original feature indices.
		newFList = remainingFIndices[featureList]
		if len(newFList) == 0:
			print('IFS: no feature selected in this round; stopping with',fCnt,'features.')
			break
		fCnt += len(newFList)
		featureIndex.append(newFList)
		featureWeight.append(featuresW)

		# Remove the selected features so the next round must pick new ones.
		remainingFIndices = np.delete(remainingFIndices,featureList)
		trData = np.delete(trData,featureList,axis=1)

	if not featureIndex:
		return np.array([],dtype=int),np.array([])
	return np.hstack(featureIndex),np.hstack(featureWeight)


def FeatureSelection_IP(dataSetName,nPass=2000,nFolds=1,maxFCnt=75):
	"""Run repeated iterative Sparse Centroid-Encoder feature selection.

	For each of nPass passes:
	  1. Loads the data with getApplicationData(dataSetName, 0); 0 means all samples.
	  2. Removes the training rows listed in '<dataSetName>_Test_Indices.p'.
	  3. Splits the remaining rows with partitionData into nFolds folds.
	  4. For each fold, trains on the other folds (all data when nFolds == 1) and
	     runs _selectFeaturesIteratively.
	  5. Stores the selected indices and weights under 'Pass<k>_Fold<f>' in
	     featureSet['features'] / featureSet['featureWeights'].
	  6. Pickles featureSet to
	     ./RawFeatures/<dataSetName>/<host>_<dataSetName>_features_partition_<nFolds>_timestamp_<time>.p
	     (one file per pass).

	The default arguments reproduce the original hard-coded settings. For
	dataSetName='IndianPine' the output directory is unchanged.
	"""
	import os
	import pdb  # used by the duplicate-feature sanity check below

	machine = socket.gethostname().split('.')[0]
	rawDir = './RawFeatures/'+dataSetName+'/'
	# Create the output directory up front so a missing directory doesn't fail after a long pass.
	os.makedirs(rawDir,exist_ok=True)

	for k in range(nPass):
		print('Iterative Feature Selection: Pass:',k+1)
		nSamplePerClass = 0 #nSamplePerClass=0 means return all the samples
		X_tr,Y_tr,_,_ = getApplicationData(dataSetName,nSamplePerClass)

		#remove the test samples from the training set
		with open(dataSetName+'_Test_Indices.p','rb') as fh:
			valIndices = pickle.load(fh)
		orgData = np.delete(X_tr,valIndices,axis=0)
		orgLabels = np.delete(Y_tr,valIndices,axis=0)

		fSetName = rawDir+machine+'_'+dataSetName+'_features_partition_'+str(nFolds)+'_timestamp_'+_timestampString()+'.p'
		featureSet = {'features':{},'featureWeights':{}}
		print('Raw features will be written in',fSetName)

		dataPartition = partitionData(orgData,orgLabels,nFolds)
		for fold in range(nFolds):
			# Train on every fold except the current one; with a single fold, use it all.
			trSet = np.vstack([dataPartition['fold'+str(cnt+1)] for cnt in range(nFolds) if nFolds == 1 or cnt != fold])
			trData,trLabels = trSet[:,:-1],trSet[:,-1]
			featureIndex,featureWeight = _selectFeaturesIteratively(trData,trLabels,maxFCnt)

			key = 'Pass'+str(k+1)+'_Fold'+str(fold+1)
			featureSet['features'][key] = featureIndex
			featureSet['featureWeights'][key] = featureWeight
			print('IFS: Fold',fold+1,'of Pass:',k+1,'is done. No. of extracted features:',len(featureIndex))
			# Sanity check: selected columns are removed each round, so duplicates indicate a bug.
			if len(featureIndex) != len(np.unique(featureIndex)):
				print('Issue in feature selection')
				pdb.set_trace()
		print('IFS: Pass',k+1,'is complete.')
		with open(fSetName,'wb') as fh:
			pickle.dump(featureSet,fh)

if __name__ == "__main__":
	# Script entry: run the iterative feature-selection passes for Indian Pine,
	# then aggregate the raw per-pass results into the top-N feature set.
	dataSetName = 'IndianPine'
	nTopFeatures = 80
	FeatureSelection_IP(dataSetName)
	orderFeatureByFrequency(dataSetName,nTopFeatures)

