import time, datetime, glob, os, re, sys, random,  pickle as pickle, collections, itertools 
import pandas as pd, numpy as np, scipy, sklearn, argparse
from sklearn import metrics
from sklearn import preprocessing
from sklearn.metrics import roc_auc_score
from sklearn.model_selection import train_test_split
from sklearn.metrics.pairwise import euclidean_distances, manhattan_distances
from sklearn.linear_model import LogisticRegression
from sklearn.preprocessing import LabelBinarizer
from scipy.stats import spearmanr

import IPython.display
import matplotlib.pylab as plt
import torch
from IPython.display import clear_output
from functools import reduce

####
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset, Sampler, ConcatDataset
from torch.nn import init

# Fix every random number generator used by this script (Python's `random`,
# NumPy and PyTorch) to the same seed so that repeated runs start from the
# same random state.
seed = 3
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
# Ask cuDNN for deterministic kernels and turn off its benchmark mode.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

from dataloader import *
from utils import *
from model import *
####

# Widen pandas' console output limits (columns, total width, rows, column
# width) and show floats with 4 digits of precision.
pd.set_option('display.max_columns', 500)
pd.set_option('display.width', 2000)
pd.set_option('display.max_rows', 500)
pd.set_option('display.precision', 4)
pd.set_option('display.max_colwidth', 2000)
# Expose only CUDA device index 0 to this process.
os.environ["CUDA_VISIBLE_DEVICES"]="0" 

import warnings
# Silence FutureWarning messages for the rest of the run.
warnings.filterwarnings(action='ignore', category=FutureWarning)




def _count_class_predictions(pred_matrix, num_classes=2):
	"""Count, for every row of ``pred_matrix``, how often each label occurs.

	``pred_matrix`` is a 2-D array of non-negative integer labels with one row
	per sample and one column per repetition.  The result has one row per
	sample; column ``c`` holds the number of repetitions whose label was ``c``.
	``num_classes`` is forwarded to ``np.bincount`` as ``minlength``.
	"""
	return np.array([np.bincount(row, minlength=num_classes) for row in pred_matrix])


def run_exp(args):
	"""Run one PATE student experiment and write its results under ``ROOT_DIR``.

	The function
	  1. loads the student data and trains ``args.num_teacher`` teacher models,
	  2. fits a non-private logistic-regression student on the arg-max of the
	     teacher votes,
	  3. repeats the private student training ``num_seed`` times, each time on
	     freshly noised labels from ``aggregate_noisy_votes``,
	  4. appends one summary row to ``results/<dataset>_bound_parameter.csv``
	     and overwrites
	     ``results/<dataset>_bound_parameter_all_<num_teacher>_<sigma>_<lambda_reg>.txt``
	     with per-test-sample arrays.

	The label-flip computations use ``1 - label`` and the per-group statistics
	use ``z_test == 0`` / ``z_test == 1``, i.e. the code is written for binary
	labels and two groups.

	Args:
		args: namespace providing ``dataset``, ``lambda_reg``, ``max_iter``,
			``num_teacher`` and ``sigma``.

	Returns:
		None.  Results are written to disk.
	"""
	# ``json`` is not imported at the top of this file; import it here so the
	# final dump does not depend on a star import happening to re-export it.
	import json

	dataset = args.dataset
	lambda_reg = args.lambda_reg
	max_iter = args.max_iter
	num_teach = args.num_teacher
	sigma = args.sigma
	ROOT_DIR = "/home/pate/"
	num_seed = 1000  # number of noisy-label repetitions that are averaged over
	group_keys = ['group_0', 'group_1']

	# Create the output directory before the expensive training so that a
	# missing directory cannot discard the results of the whole run.
	os.makedirs(os.path.join(ROOT_DIR, 'results'), exist_ok=True)

	res_stack = {}   # scalar summary -> one CSV row
	res_stack2 = {}  # per-test-sample arrays -> JSON file
	x_train, y_train, z_train, x_test, y_test, z_test, _ = get_student_data(dataset)
	# The group membership masks do not change between repetitions.
	group_masks = [z_test == g for g in range(len(group_keys))]

	################## teacher training ######################
	teacher_loaders = get_loader(os.path.join(ROOT_DIR, "data/{}_train.csv".format(dataset)), num_teach)
	teacher_models = train_models(num_teach, teacher_loaders)

	################## non private training ######################
	votes, _ = get_teacher_votes(teacher_models, x_train)
	teacher_labels = np.argmax(votes, axis=1)

	student_model_non_priv = LogisticRegression(
		max_iter=max_iter, fit_intercept=False, C=1/lambda_reg
	).fit(x_train, teacher_labels)
	non_priv_param = student_model_non_priv.coef_
	y_test_pred_non_priv = student_model_non_priv.predict_proba(x_test)
	# Boolean vector: True where the non-private student is correct on the test set.
	ind_acc_test_non_priv = np.argmax(y_test_pred_non_priv, axis=1) == y_test
	ind_loss_test = cust_log_loss(y_test, y_test_pred_non_priv)
	acc_test_non_priv2 = student_model_non_priv.score(x_test, y_test)
	acc_test_non_priv = {
		key: student_model_non_priv.score(x_test[mask], y_test[mask])
		for key, mask in zip(group_keys, group_masks)
	}

	################# private training ######################
	# One column per repetition: noisy training labels / predicted test labels.
	seed_log_pred_train = np.zeros((len(y_train), num_seed), dtype=int)
	seed_log_pred_test = np.zeros((len(y_test), num_seed), dtype=int)
	lhs = 0
	acc_test_priv2 = 0
	loss_test_priv = {key: 0 for key in group_keys}
	acc_test_priv = {key: 0 for key in group_keys}
	for rep in range(num_seed):
		y_train_noisy, _ = aggregate_noisy_votes(votes, sigma, num_teach)
		# NOTE(review): `student_train_loader` and `num_epoch` are not defined in
		# this function or at module level in this file; presumably they are
		# provided by one of the star imports (dataloader / utils / model).
		# Confirm, otherwise this call raises NameError.
		student_model_priv = train_student(student_train_loader, y_train_noisy, num_epoch, x_train.shape[1], lambda_reg)
		y_test_pred = student_model_priv.predict_proba(x_test)
		seed_log_pred_train[:, rep] = y_train_noisy
		seed_log_pred_test[:, rep] = np.argmax(y_test_pred, axis=1)
		ind_loss_test_priv = cust_log_loss(y_test, y_test_pred)
		acc_test_priv2 += student_model_priv.score(x_test, y_test)
		for key, mask in zip(group_keys, group_masks):
			loss_test_priv[key] += np.mean(ind_loss_test_priv[mask])
			acc_test_priv[key] += student_model_priv.score(x_test[mask], y_test[mask])
		# Distance between the non-private and this repetition's private parameters.
		lhs += np.linalg.norm(non_priv_param - student_model_priv.coef_)

	# Per test sample: how many repetitions predicted class 0 / class 1.
	seed_log_pred_test = _count_class_predictions(seed_log_pred_test)
	# Fraction of repetitions in which the private student hit the true label.
	ind_acc_test_priv = seed_log_pred_test[np.arange(len(y_test)), y_test.astype(int)] / num_seed
	er_acc = ind_acc_test_non_priv - ind_acc_test_priv

	# Per training sample: how many repetitions produced noisy label 0 / 1.
	seed_log_pred_train = _count_class_predictions(seed_log_pred_train)
	# Fraction of repetitions whose noisy label differs from the teachers' arg-max label.
	flip_prob_vec = seed_log_pred_train[np.arange(len(seed_log_pred_train)), 1 - teacher_labels] / num_seed
	# Fraction of repetitions whose test prediction differs from the true label.
	flip_prob_vec_test = seed_log_pred_test[np.arange(len(seed_log_pred_test)), 1 - y_test.astype(int)] / num_seed

	rhs = flip_prob_vec * np.linalg.norm(x_train, axis=1)
	avg_rhs = np.mean(rhs)

	# Per-sample statistics, restricted to test points the non-private student gets right.
	res_stack2['input_norm'] = np.linalg.norm(x_test[ind_acc_test_non_priv], axis=1)
	res_stack2['flip_prob'] = flip_prob_vec_test[ind_acc_test_non_priv]
	res_stack2['acc_gap'] = er_acc[ind_acc_test_non_priv]

	# Key insertion order defines the CSV column order; keep it stable because
	# rows are appended to an existing file.
	res_stack['sigma'] = sigma
	res_stack['max_iter'] = max_iter
	res_stack['num_teacher'] = num_teach
	res_stack['rhs'] = avg_rhs/lambda_reg
	res_stack['lhs'] = lhs/num_seed
	res_stack['lambda_reg'] = lambda_reg
	res_stack['acc_priv'] = acc_test_priv2/num_seed
	res_stack['acc_nonpriv'] = acc_test_non_priv2
	for g, (key, mask) in enumerate(zip(group_keys, group_masks)):
		res_stack['err_loss_group_{}'.format(g)] = loss_test_priv[key]/num_seed - np.mean(ind_loss_test[mask])
		res_stack['err_acc_group_{}'.format(g)] = acc_test_non_priv[key] - acc_test_priv[key]/num_seed

	result_fp = os.path.join(ROOT_DIR, 'results/{}_bound_parameter.csv'.format(dataset))
	# Write the header only when the CSV is being created by this call.
	header_flag = not os.path.isfile(result_fp)
	df_private = pd.DataFrame(res_stack, index=[0])
	df_private.to_csv(result_fp, index=False, header=header_flag, mode='a')
	with open(os.path.join(ROOT_DIR, 'results/{}_bound_parameter_all_{}_{}_{}.txt'.format(dataset, args.num_teacher, args.sigma, args.lambda_reg)), 'w') as fp:
		json.dump(res_stack2, fp, cls=NumpyEncoder)

def main(argv=None):
	"""Parse the command-line options, run the experiment and print the elapsed time.

	Args:
		argv: optional list of argument strings.  ``None`` (the default) makes
			argparse read the options from ``sys.argv[1:]``.

	Raises:
		SystemExit: raised by argparse when a required option is missing or an
			option value cannot be converted.
	"""
	start = time.perf_counter()  # monotonic clock, intended for measuring intervals
	parser = argparse.ArgumentParser(description='PATE')
	# --dataset, --num_teacher and --sigma have no sensible default and are used
	# unconditionally by run_exp, so reject the command line early instead of
	# passing None on.
	parser.add_argument('--dataset', type=str, required=True,
		help='dataset name; selects data/<dataset>_train.csv and the result file names')
	parser.add_argument('--max_iter', type=int, default=5,
		help='max_iter of the non-private logistic-regression student')
	parser.add_argument('--lambda_reg', type=float, default=0.1,
		help='regularisation strength; the non-private student uses C = 1/lambda_reg')
	parser.add_argument('--num_teacher', type=int, required=True,
		help='number of teacher models to train')
	parser.add_argument('--sigma', type=float, required=True,
		help='noise parameter passed to the noisy vote aggregation')

	args = parser.parse_args(argv)
	run_exp(args)
	print('That took {} seconds'.format(time.perf_counter() - start))

# Run the experiment only when this file is executed as a script, not when it
# is imported as a module.
if __name__ == "__main__":
	main()
