import pdb
from copy import copy
import torch
import numpy as np
import pickle
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as Data
from torch.autograd import Variable
from utilityScript import *
from SoftmaxClassifierPyTorch import *
from simpleANNClassifierPyTorch import *
from SupervisedCentroidencodeVisualizerPyTorch import *
from sklearn.metrics import accuracy_score
from sklearn.metrics import roc_auc_score

def returnBottleneckArc(dataSetName):
	"""Return the hidden-layer sizes used for *dataSetName*.

	The name is matched case-insensitively. Unknown data sets yield None.
	"""
	if dataSetName.upper() == 'INDIANPINE':
		return [500,400,250]
	return None
	
def returnActFunc(dataSetName):
	"""Return the per-hidden-layer activation functions for *dataSetName*.

	The name is matched case-insensitively, as in returnBottleneckArc.
	Unknown data sets yield None. Previously they raised UnboundLocalError,
	because the function initialised ``arch`` but returned ``actF``.
	"""
	actF = None
	if dataSetName.upper() in ['INDIANPINE']:
		actF = ['relu','relu','relu']
	return actF

def prob_of_class_greater_label(tstLabels,tstPredProbs):
	"""Build per-sample scores for sklearn.metrics.roc_auc_score.

	For each sample i the score is ``1 - tstPredProbs[i]`` when
	``tstLabels[i] == 0``, and ``tstPredProbs[i]`` otherwise. The result is a
	float64 numpy array with one entry per label.

	NOTE(review): assumes tstPredProbs[i] is a scalar probability aligned with
	tstLabels[i]. What that probability refers to is decided by the caller,
	so confirm it there.
	"""
	aucScores = np.ones(len(tstLabels))
	for idx, label in enumerate(tstLabels):
		prob = tstPredProbs[idx]
		if label != 0:
			aucScores[idx] = prob
		else:
			# Label 0: flip the probability (start from 1.0 and subtract).
			aucScores[idx] -= prob
	return aucScores

def returnHyperParams(dataSetName):
	"""Return the training hyper-parameters for *dataSetName* as a new dict.

	The name is matched case-insensitively, consistent with
	returnBottleneckArc and returnActFunc. Previously 'indianpine' got an
	architecture but an empty hyper-parameter dict. Unknown data sets still
	yield an empty dict.
	"""
	hyperParams = {}
	if dataSetName.upper() == 'INDIANPINE':
		hyperParams = {
			'normalizationFlag': True,
			'preTrFlag': True,
			'epochsPre': 10,
			'epochsPost': 300,
			'epochsFinetune': 150,
			'miniBatchSize': 50,
			'lrnCE': 0.001,
			'lrnSoftmax': 0.001,
			'lrnFinetune': 0.0001,
			'multiCenter': False,
		}
	return hyperParams

def CEClassifier(conf_ce,trData,trLabels,tstData,tstLabels,dataSetName,param,gpuId):
	"""Train a centroid-encoder-initialised ANN classifier and return its test accuracy.

	Pipeline, as implemented below:
	  1. fit a supervised centroid-encoder (SCEVisualizer) on the training data;
	  2. pre-train a softmax layer on the encoder's bottleneck output;
	  3. build a NeuralNet whose hidden layers are initialised from the encoder
	     and whose output layer is initialised from the softmax layer;
	  4. fine-tune that network and score it on the test split.

	Parameters
	----------
	conf_ce : dict
		Centroid-encoder configuration. Keys 'inputDim' and 'hL' are read here,
		and the whole dict is passed to SCEVisualizer.
	trData, trLabels : numpy arrays
		Training samples and labels. The labels are flattened and cast to int
		for the softmax data loader.
	tstData, tstLabels : numpy arrays
		Test samples and labels.
	dataSetName : str
		Not used by this function; kept for interface compatibility.
	param : dict
		Hyper-parameters as produced by returnHyperParams. The optional key
		'epochsSoftmax' (default 10, the previously hard-coded value) sets
		the number of softmax pre-training epochs.
	gpuId : int
		CUDA device id passed to the encoder and ANN training routines.

	Returns
	-------
	float
		Test accuracy in percent.
	"""
	nClass = len(np.unique(trLabels))

	# 1. Train the centroid-encoder, then move it to the CPU for the steps below.
	ce = SCEVisualizer(conf_ce)
	ce.bottleneckCons = False
	ce.fit(trData, trLabels, preTraining=param['preTrFlag'], learningRate=param['lrnCE'], miniBatchSize=param['miniBatchSize'],
		numEpochsPreTrn=param['epochsPre'], numEpochsPostTrn=param['epochsPost'], standardizeFlag=param['normalizationFlag'],
		multiClass=param['multiCenter'],cudaDeviceId=gpuId, verbose=False)
	ce = ce.to('cpu')

	# Keep the encoder's hidden-layer weights/biases; they initialise the ANN's hidden layers.
	preW = [ce.hidden[i].weight for i in range(len(ce.hLayer))]
	preB = [ce.hidden[i].bias for i in range(len(ce.hLayer))]

	# 2. Pre-train a softmax layer on the bottleneck representation.
	# Entry len(hL) of predict()'s output is the bottleneck data.
	bottleneckData = ce.predict(trData)[len(conf_ce['hL'])].to('cpu').numpy()
	softmaxModel = Softmax(bottleneckData.shape[1],nClass)
	softmaxDataTorch = Data.TensorDataset(torch.from_numpy(bottleneckData).float(),torch.from_numpy(trLabels.flatten().astype(int)))
	softmaxLoader = Data.DataLoader(dataset=softmaxDataTorch,batch_size=param['miniBatchSize'],shuffle=True)
	softmaxModel.fit(softmaxLoader,'Adam',learningRate=param['lrnSoftmax'],numEpochs=param.get('epochsSoftmax',10))

	# The softmax weights/biases initialise the ANN's output layer.
	preW.append(softmaxModel.softmaxLayer.weight)
	preB.append(softmaxModel.softmaxLayer.bias)

	# 3. Build the fully connected ANN and load the pre-trained weights and biases.
	ann = NeuralNet(conf_ce['inputDim'], conf_ce['hL'] , nClass)
	ann.setHiddenWeight(preW[:-1],preB[:-1])
	ann.setOutputWeight(preW[-1],preB[-1])

	# 4. Fine-tune the whole network, then evaluate on the test split.
	ann.fit(trData,trLabels,param['normalizationFlag'],param['miniBatchSize'],'Adam',param['lrnFinetune'],
		numEpochs=param['epochsFinetune'],cudaDeviceId=gpuId)
	ann = ann.to('cpu')
	_,tstPredLabel = ann.predict(tstData)
	accuracy = 100 * accuracy_score(tstLabels.flatten(), tstPredLabel)
	return accuracy

def classificationOnTestData(dataSetName,nSamplePerClass,nFeatures,nRuns=10,gpuId=0):
	"""Evaluate CEClassifier on a fixed train/test split using the first *nFeatures* features.

	Loads a feature-index list from 'FeatureSet/<name>/<name>_features.p' and
	the test-row indices from '<name>_Test_Indices.p'. It keeps the first
	nFeatures columns, splits the data into train/test, and trains and scores
	CEClassifier nRuns times. The mean test accuracy is printed and returned.

	Parameters
	----------
	dataSetName : str
		Data set name used for file paths and for model configuration.
	nSamplePerClass : int
		Forwarded to getApplicationData (0 is used by the driver to mean "all samples").
	nFeatures : int
		Number of leading entries of the feature-index list to keep.
	nRuns : int, optional
		Number of independent training runs to average (default 10, the previous hard-coded value).
	gpuId : int, optional
		CUDA device id for training (default 0, the previous hard-coded value).

	Returns
	-------
	float
		Mean test accuracy in percent over the runs. With nRuns == 0 this is
		nan, because np.mean of an empty list is nan.
	"""
	# NOTE: pickle.load can execute arbitrary code; only load trusted local files.
	# `with` closes the handles (the previous open() calls were never closed).
	feaFile = 'FeatureSet/'+dataSetName+'/'+dataSetName+'_features.p'
	with open(feaFile,'rb') as fh:
		fea = pickle.load(fh)
	with open(dataSetName+'_Test_Indices.p','rb') as fh:
		tstIndices = pickle.load(fh)

	D,L,_,_ = getApplicationData(dataSetName,nSamplePerClass)
	D = D[:,fea[:nFeatures]]
	D = np.hstack((D,L)) # attach labels as the last column

	# Split D into test rows (tstIndices) and training rows (everything else).
	X_tst = D[tstIndices,:]
	X_tr = np.delete(D,tstIndices,axis=0)
	trData,trLabels = X_tr[:,:-1],X_tr[:,-1]
	tstData,tstLabels = X_tst[:,:-1],X_tst[:,-1]
	print('Data is loaded. No of training samples',len(trData),'No of test samples',len(tstData))

	# Hyper-parameters depend only on the data set name, so look them up once.
	param = returnHyperParams(dataSetName)

	accuracy = []
	for _ in range(nRuns):
		# Build a fresh configuration for every run, as before, so nothing
		# a previous model may have changed in it carries over.
		conf_ce={}
		conf_ce['inputDim'] = np.shape(trData)[1]
		conf_ce['hL'] = returnBottleneckArc(dataSetName)
		conf_ce['hActFunc'] = returnActFunc(dataSetName)
		conf_ce['oActFunc'] = 'linear'
		conf_ce['errorFunc'] = 'MSE'
		conf_ce['l2Penalty'] = 0.0001
		accuracy.append(CEClassifier(conf_ce,trData,trLabels,tstData,tstLabels,dataSetName,param,gpuId))

	meanAccuracy = np.mean(accuracy)
	print('No of features',nFeatures,'Test accuracy:{:.2f}%'.format(meanAccuracy))
	return meanAccuracy

# Script entry: sweep the number of selected features on the full IndianPine data set.
dataSetName = 'IndianPine'
nSamplePerClass = 0  # 0 means: return all the samples
for nFeatures in (1,2,3,4,5,10,20,40,60,80):
	classificationOnTestData(dataSetName,nSamplePerClass,nFeatures)

