import time, datetime, glob, os, re, sys, random,  pickle as pickle, collections, itertools 
import pandas as pd, numpy as np, scipy, sklearn, xgboost as xgb
from sklearn import metrics
from sklearn import preprocessing
from sklearn.metrics import roc_auc_score
from sklearn.model_selection import train_test_split
import IPython.display
import matplotlib.pylab as plt
import torch
from IPython.display import clear_output
from functools import reduce

####
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset, Sampler, ConcatDataset
from torch.nn import init
import torch.optim as optim
# Seed Python's, NumPy's and PyTorch's RNGs with the same fixed value so that
# repeated runs draw the same random numbers.
seed = 3
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
# Request deterministic cuDNN kernels and switch off cuDNN's benchmark-based
# algorithm auto-selection.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

from dataloader import *
from utils import *
from model import *
####

# Widen pandas' console display limits so printed frames are not truncated.
pd.set_option('display.max_columns', 500)
pd.set_option('display.width', 2000)
pd.set_option('display.max_rows', 500)
pd.set_option('display.precision', 4)
pd.set_option('display.max_colwidth', 2000)
# Expose only CUDA device "0" to this process.
# NOTE(review): this is set after `import torch`; it presumably only takes
# effect if nothing has initialized CUDA before this line - confirm.
os.environ["CUDA_VISIBLE_DEVICES"]="0" 

# Suppress FutureWarning messages for the rest of the run.
import warnings
warnings.filterwarnings(action='ignore', category=FutureWarning)


# Base directory of the experiment. run_exp() joins it with
# "data/income_train.csv", "data/income_test.csv" and the "results/..." output path.
ROOT_DIR = "/home/pate/"



def run_exp(args):
	"""Run one teacher/student experiment and write the per-epoch metrics to disk.

	Steps, in order:
	  1. Build the teacher data loaders from ``data/income_train.csv``.
	  2. Split ``data/income_test.csv`` by row position: the first
	     ``ceil(0.75 * n)`` rows train the student, the remaining rows
	     evaluate it.
	  3. Train ``args.num_teachers`` teacher models and let
	     ``aggregated_teacher`` label the student training rows using
	     ``args.eps``.
	  4. Train a ``Classifier`` student for 20 epochs (Adam, lr=1e-3, BCE on
	     sigmoid outputs) and evaluate it after every epoch, overall and per
	     value of the ``z`` column (groups 0 and 1).
	  5. Save one row per epoch to
	     ``<ROOT_DIR>/results/pate_result_xavier_<num_teachers>_<eps>``.

	Recorded columns: ``train_loss`` and ``group_<i>_loss`` are running means
	over the epochs seen so far; ``test_loss``, ``test_acc`` and
	``group_<i>_acc`` are the values of the current epoch.

	Args:
		args: object with attributes ``num_teachers`` and ``eps`` (the
			namespace produced by the argument parser in ``main``).

	Returns:
		None. Results are printed and written to the CSV file.
	"""
	num_teachers = args.num_teachers
	epsilon = args.eps
	batch_size = 32
	epochs = 20
	num_groups = 2  # evaluation is reported separately for z == 0 and z == 1

	#load data
	teacher_loaders = get_loader(os.path.join(ROOT_DIR,"data/income_train.csv"), num_teachers)

	test_meta_path = os.path.join(ROOT_DIR,"data/income_test.csv")
	test_data = pd.read_csv(test_meta_path)
	num_student_train = int(np.ceil(len(test_data) * 0.75))

	print('Total test data: {}. Total student train student data: {}'.format(len(test_data), num_student_train))
	student_train_data = CustomDataset(test_meta_path, np.arange(0, num_student_train))
	student_test_data = CustomDataset(test_meta_path, np.arange(num_student_train, len(test_data)))
	student_train_loader = DataLoader(student_train_data, batch_size=batch_size)

	# Held-out rows as dense tensors so the whole evaluation set is scored in one pass.
	# NOTE(review): assumes the last two columns of `features` are the
	# `label` and `z` columns - confirm against CustomDataset.
	df_student_test = student_test_data.features.iloc[student_test_data.indices]
	x_test = torch.Tensor(df_student_test.iloc[:, :-2].values)
	y_test = torch.Tensor(df_student_test.label.values).reshape(-1, 1)
	z_test = torch.Tensor(df_student_test.z.values)
	# The group membership never changes, so build the masks once.
	group_masks = [z_test == i for i in range(num_groups)]

	#train teacher
	# NOTE(review): `feats` is not defined in this file; presumably it comes
	# from one of the star imports - confirm.
	teacher_models = train_models(teacher_loaders,num_teachers, len(feats))

	#get label for students train data
	_, student_labels = aggregated_teacher(teacher_models, student_train_loader, epsilon, num_student_train)

	# Key order defines the column order of the result file.
	res_stack = {
		key: []
		for key in (
			'train_loss', 'test_acc', 'test_loss',
			'group_0_acc', 'group_1_acc',
			'group_0_loss', 'group_1_loss',
		)
	}

	student_model = Classifier(len(feats))
	criterion = nn.BCELoss()
	optimizer = optim.Adam(student_model.parameters(), lr=1e-3)

	running_loss = 0
	# One accumulator per group: a single shared accumulator would fold the
	# losses of both groups into each group's running mean.
	running_group_loss = [0.0] * num_groups
	for e in range(epochs):
		# --- train for one epoch on the teacher-provided labels ---
		student_model.train()
		cur_training_epoch_loss = 0
		train_loader = student_loader(student_train_loader, student_labels)
		for images, labels in train_loader:
			optimizer.zero_grad()
			output = student_model(images)
			loss = criterion(torch.sigmoid(output), labels.float().unsqueeze(1))
			loss.backward()
			optimizer.step()
			cur_training_epoch_loss += loss.item()
		running_loss += cur_training_epoch_loss/len(student_train_loader)
		res_stack['train_loss'].append(running_loss/(e+1))

		print("Epoch: {}/{}.. ".format(e+1, epochs),"Train Loss: {:.3f}.. ".format(running_loss/(e+1)))

		# --- evaluate on the held-out rows, overall and per group ---
		student_model.eval()
		with torch.no_grad():
			all_output = student_model(x_test)
			test_loss = criterion(torch.sigmoid(all_output), y_test).item()
			accuracy = bin_acc(all_output, y_test)
			for i, mask in enumerate(group_masks):
				y_group_pred = student_model(x_test[mask])
				y_group_true = y_test[mask]
				group_loss = criterion(torch.sigmoid(y_group_pred), y_group_true.float())
				acc = bin_acc(y_group_pred, y_group_true)
				res_stack['group_{}_acc'.format(i)].append(acc.item())
				running_group_loss[i] += group_loss.item()
				res_stack['group_{}_loss'.format(i)].append(running_group_loss[i]/(e+1))
		print("Test Loss: {:.3f}.. ".format(test_loss),
		    "Accuracy: {:.3f}".format(accuracy))
		res_stack['test_loss'].append(test_loss)
		res_stack['test_acc'].append(accuracy.item())

	df_result = pd.DataFrame(res_stack)
	df_result['epoch']= df_result.index.values
	df_result['epsilon'] = epsilon
	df_result['num_teachers'] = num_teachers

	result_path = os.path.join(ROOT_DIR, 'results/pate_result_xavier_{}_{}'.format(num_teachers, epsilon))
	# Create the output directory if needed so a finished run is not lost to
	# a missing folder.
	os.makedirs(os.path.dirname(result_path), exist_ok=True)
	df_result.to_csv(result_path, index=False)

def main():
	"""Parse the command-line options, run one experiment and print the elapsed time.

	Options:
		--num_teachers (int, required): forwarded to ``run_exp`` as
			``args.num_teachers``.
		--eps (float, default 0.1): forwarded to ``run_exp`` as ``args.eps``.
	"""
	# Imported here because this file has no top-level `import argparse`;
	# without it the parser construction below depends on a star import
	# happening to expose the name.
	import argparse

	starttime = time.time()
	parser = argparse.ArgumentParser(description='Test')
	# Required: without a value run_exp would receive None as the teacher count.
	parser.add_argument('--num_teachers', type=int, required=True,
	                    help='number of teacher models to train')
	parser.add_argument('--eps', type=float, default=0.1,
	                    help='epsilon passed to the teacher aggregation step (default: 0.1)')

	args = parser.parse_args()
	print(args)
	run_exp(args)
	print('That took {} seconds'.format(time.time() - starttime))

# Run the experiment only when this file is executed as a script, not on import.
if __name__ == "__main__":
	main()