import time, datetime, glob, os, re, sys, random,  pickle as pickle, collections, itertools 
import pandas as pd, numpy as np, scipy, json
from sklearn import metrics
from sklearn import preprocessing
from sklearn.metrics import roc_auc_score
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LogisticRegression
import IPython.display
import matplotlib.pylab as plt
import torch
from IPython.display import clear_output
from functools import reduce
from sklearn.preprocessing import LabelBinarizer

####
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset, Sampler, ConcatDataset
from torch.nn import init
import torch.optim as optim
# Global RNG seed; train_student also re-seeds torch with it before each run.
seed = 3
# Seed Python's, NumPy's and torch's global RNGs so runs are reproducible.
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
# Request deterministic cuDNN kernels and disable cuDNN autotuning, which can
# otherwise select different (nondeterministic) algorithms between runs.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
from model import *
from dataloader import *

####

# Expose only GPU 0 to this process.
# NOTE(review): this only takes effect if CUDA has not been initialised yet;
# torch is imported above and the star imports from model/dataloader run
# before this line — confirm neither of them touches CUDA at import time.
os.environ["CUDA_VISIBLE_DEVICES"]="0" 

import warnings
# Silence all FutureWarnings, presumably library deprecation notices cluttering
# the training output — TODO confirm nothing important is being hidden.
warnings.filterwarnings(action='ignore', category=FutureWarning)


# Base directory for experiment files; not referenced in the code shown here.
ROOT_DIR = "/home/pate/"




def aggregated_teacher(models, dataloader, epsilon, num_student_train, num_classes=2):
    """Label student samples by a Laplace noisy-argmax over teacher predictions.

    Every teacher predicts a label for every student sample. For each sample,
    the per-class vote histogram gets independent Laplace(0, 1/epsilon) noise
    added to each class count, and the argmax becomes the aggregated label.

    Args:
        models: fitted classifiers exposing ``predict(X)`` that returns
            non-negative integer labels, one per row of ``X``.
        dataloader: student feature array handed to each ``model.predict``.
        epsilon: Laplace mechanism parameter; the noise scale is ``1 / epsilon``.
            Must be > 0.
        num_student_train: number of student samples, i.e. the length of each
            teacher's prediction vector.
        num_classes: minimum length of each vote histogram (default 2, binary).

    Returns:
        (preds, labels):
            preds  -- int64 array of shape (len(models), num_student_train)
                      holding every teacher's raw prediction.
            labels -- int array of shape (num_student_train,) with the noisy
                      argmax label for each student sample.

    Raises:
        ValueError: if ``epsilon`` is not positive.
    """
    if epsilon <= 0:
        raise ValueError(f"epsilon must be positive, got {epsilon}")
    beta = 1.0 / epsilon  # Laplace scale; loop-invariant

    preds = np.zeros((len(models), num_student_train), dtype=np.int64)
    for i, model in enumerate(models):
        # Call model.predict directly: the module-level predict() helper returns
        # a (labels, probabilities) tuple, not a label array.
        preds[i] = np.asarray(model.predict(dataloader)).reshape(-1)

    labels = np.empty(num_student_train, dtype=int)
    for j, sample_votes in enumerate(preds.T):
        # Hold counts as floats so the Laplace noise is added exactly; adding it
        # into an int array would truncate the noise toward zero.
        counts = np.bincount(sample_votes, minlength=num_classes).astype(float)
        counts += np.random.laplace(0.0, beta, size=counts.shape)
        labels[j] = np.argmax(counts)
    return preds, labels



def train(cur_teacher_data, max_iter=200):
    """Fit one teacher's logistic-regression model on its data partition.

    Args:
        cur_teacher_data: object exposing ``features`` (a table supporting
            ``.loc[...]``, column selection and a ``y`` column), ``indices``
            (row labels selecting this teacher's rows) and ``cols`` (feature
            column names). Presumably a pandas DataFrame wrapper from the
            project's dataloader module — TODO confirm.
        max_iter: maximum solver iterations for LogisticRegression (default 200).

    Returns:
        The fitted ``LogisticRegression`` model.
    """
    teacher_rows = cur_teacher_data.features.loc[cur_teacher_data.indices]
    x_teacher = teacher_rows[cur_teacher_data.cols].values
    y_teacher = teacher_rows.y.values
    # The ``z`` column is not needed to fit the model, so it is not read here;
    # this lets partitions without a ``z`` column train as well.
    return LogisticRegression(max_iter=max_iter).fit(X=x_teacher, y=y_teacher)


def train_teacher_models(num_teachers, teacher_loaders):
    """Train one teacher model per data partition.

    Args:
        num_teachers: how many partitions to train on; partitions are taken
            from ``teacher_loaders`` by integer index ``0 .. num_teachers-1``.
        teacher_loaders: indexable collection of per-teacher data, each item
            passed to ``train``.

    Returns:
        List of fitted models, in partition order.
    """
    return [train(teacher_loaders[k]) for k in range(num_teachers)]


def train_student(x_loader, new_label, num_epoch, num_feat, lambda_reg, verbose=True, lr=1e-4):
    """Train the student classifier on teacher-aggregated labels.

    Args:
        x_loader: student inputs, passed with ``new_label`` to ``student_loader``
            to build the batches for each epoch.
        new_label: aggregated (noisy) labels for the student inputs.
        num_epoch: number of training epochs.
        num_feat: input feature dimension handed to ``Classifier``.
        lambda_reg: Adam ``weight_decay`` (L2 regularisation strength).
        verbose: if True, print the mean batch loss every 3rd epoch.
        lr: Adam learning rate (default 1e-4).

    Returns:
        The trained ``Classifier`` instance (left in train mode).
    """
    # Re-seed so every call starts from the same initial weights.
    torch.manual_seed(seed)
    # The model output is treated as logits. BCEWithLogitsLoss fuses the sigmoid
    # into the loss, which is numerically more stable than sigmoid + BCELoss.
    criterion = nn.BCEWithLogitsLoss()
    student_model = Classifier(num_feat)
    init_weights(student_model, 'xavier')
    optimizer = optim.Adam(student_model.parameters(), lr=lr, weight_decay=lambda_reg)
    student_model.train()
    for e in range(num_epoch):
        # A new loader is built each epoch, as before.
        train_loader = student_loader(x_loader, new_label)
        epoch_loss = 0.0
        n_batches = 0
        for x, y in train_loader:
            optimizer.zero_grad()
            output = student_model(x)
            loss = criterion(output, y.float().unsqueeze(1))
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()
            n_batches += 1
        if verbose and (e + 1) % 3 == 0:
            # loss.item() is a per-batch mean (default reduction), so average
            # over batches rather than over samples.
            print("Epoch: {}/{}.. ".format(e + 1, num_epoch),
                  "Train Loss: {} ".format(epoch_loss / max(n_batches, 1)))
    return student_model


def predict(model, x_array):
    """Return ``(hard_labels, class_probabilities)`` from a fitted classifier.

    Calls ``model.predict`` first and then ``model.predict_proba`` on the same
    input, and returns both results unchanged.
    """
    return model.predict(x_array), model.predict_proba(x_array)


def get_teacher_votes(models, student_x_train):
    """Tally teacher votes and a mean p*(1-p) term for each student sample.

    Args:
        models: fitted classifiers exposing ``predict`` and ``predict_proba``.
        student_x_train: student inputs; one prediction per element/row.

    Returns:
        (vote_counts, dist_bound):
            vote_counts -- int array, one row per student sample; column c is
                the number of teachers that predicted label c (at least 2
                columns).
            dist_bound -- per-sample average over teachers of p*(1-p), where p
                is the first column of ``predict_proba``.
    """
    n_teachers = len(models)
    vote_matrix = np.zeros((n_teachers, len(student_x_train)))
    variance_total = 0
    for row, teacher in enumerate(models):
        # Same as predict(): hard labels first, then probabilities.
        hard = teacher.predict(student_x_train)
        proba = teacher.predict_proba(student_x_train)
        variance_total += proba * (1 - proba)
        vote_matrix[row] = hard

    per_sample = vote_matrix.astype(int).T
    vote_counts = np.array([np.bincount(sample, minlength=2) for sample in per_sample])
    mean_variance = variance_total / n_teachers
    return vote_counts, mean_variance[:, 0]


def aggregate_noisy_votes(votes, sigma, n_teachers, orders=None):
    """Pick one label per sample by Gaussian noisy-argmax over vote counts.

    Args:
        votes: 2-D array indexed as ``votes[i, :]``; row i is the vote
            histogram for sample i.
        sigma: standard deviation of the Gaussian noise added to every count.
        n_teachers: not used by this implementation.
        orders: not used; intended RDP orders for privacy accounting.

    Returns:
        (final_votes, rdp_eps_by_order): a list with the argmax index of each
        noisy row, and the accumulated RDP epsilon.

    NOTE(review): the RDP accounting step is disabled, so the returned epsilon
    is always 0. It does NOT reflect the privacy cost of these noisy votes.
    """
    rdp_eps_by_order = 0

    def _noisy_argmax(row):
        # One independent N(0, sigma) draw per class, shaped like the row.
        return np.argmax(row + np.random.normal(np.zeros_like(row), sigma))

    final_votes = [_noisy_argmax(votes[i, :]) for i in range(len(votes))]
    return final_votes, rdp_eps_by_order


