import time, datetime, glob, os, re, sys, random,  pickle as pickle, collections, itertools 
import pandas as pd, numpy as np, scipy, json
from sklearn import metrics
from sklearn import preprocessing
from sklearn.metrics import roc_auc_score
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LogisticRegression
import IPython.display
import matplotlib.pylab as plt
import torch
from IPython.display import clear_output
from functools import reduce
from sklearn.preprocessing import LabelBinarizer

####
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset, Sampler, ConcatDataset
from torch.nn import init
import torch.optim as optim
# Seed the Python `random`, NumPy global, and torch RNGs so runs are
# reproducible (NumPy's global RNG is what the helpers below draw from).
seed = 3
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
# Ask cuDNN for deterministic algorithms and disable its autotuner, which may
# otherwise pick different kernels from run to run.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
from model import *
from dataloader import *

####

# Restrict CUDA to GPU 0. NOTE(review): this only takes effect if CUDA has not
# been initialized yet (torch and the star-imports above are already loaded) —
# confirm, or move it before `import torch`.
os.environ["CUDA_VISIBLE_DEVICES"]="0" 

import warnings
# Suppress every FutureWarning for the rest of the process.
warnings.filterwarnings(action='ignore', category=FutureWarning)


# NOTE(review): hard-coded, user-specific root path; its use is not visible here.
ROOT_DIR = "/home/pate/"


class NumpyEncoder(json.JSONEncoder):
    """JSON encoder that additionally serializes NumPy arrays and scalars.

    Use as ``json.dumps(obj, cls=NumpyEncoder)``. ``np.ndarray`` values become
    (nested) lists; NumPy integer, floating and bool scalars become the matching
    Python scalar. Any other unsupported object falls back to
    ``json.JSONEncoder.default``, which raises ``TypeError``.
    """

    def default(self, obj):
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        # NumPy scalars (e.g. np.int64 from indexing or reductions, np.float32,
        # np.bool_) are not serializable by the stdlib encoder.
        if isinstance(obj, (np.integer, np.floating, np.bool_)):
            return obj.item()
        return json.JSONEncoder.default(self, obj)

        
def init_weights(net, init_type='normal', init_gain=0.02):
    """Initialize network weights in place.

    Parameters:
        net (network)     -- module to be initialized; the init function is applied
                             to it and every submodule via ``net.apply``.
        init_type (str)   -- the name of an initialization method:
                             normal | xavier | kaiming_uniform (alias: kaiming) | orthogonal
        init_gain (float) -- scaling factor for normal, xavier and orthogonal.

    For every module whose class name contains 'Conv' or 'Linear' and that has a
    weight, the weight is initialized according to ``init_type`` and the bias (if
    any) is set to 0. For 'BatchNorm2d' modules the weight (if any) is drawn from
    N(1.0, init_gain) and the bias (if any) is set to 0.

    Raises:
        NotImplementedError -- if ``init_type`` is unknown (raised when the first
                               Conv/Linear module is visited).

    We use 'normal' in the original pix2pix and CycleGAN paper. But xavier and kaiming might
    work better for some applications. Feel free to try yourself.
    """

    def init_func(m):  # define the initialization function
        classname = m.__class__.__name__
        # Modules built with bias=False / affine=False carry None here.
        weight = getattr(m, 'weight', None)
        bias = getattr(m, 'bias', None)
        if weight is not None and ('Conv' in classname or 'Linear' in classname):
            if init_type == 'normal':
                init.normal_(weight.data, 0.0, init_gain)
            elif init_type == 'xavier':
                init.xavier_normal_(weight.data, gain=init_gain)
            elif init_type in ('kaiming_uniform', 'kaiming'):
                init.kaiming_uniform_(weight.data, a=0, mode='fan_in')
            elif init_type == 'orthogonal':
                init.orthogonal_(weight.data, gain=init_gain)
            else:
                raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
            # Biases are 1-D, so fan-based / orthogonal schemes cannot be applied
            # to them; they are always reset to zero.
            if bias is not None:
                init.constant_(bias.data, 0.0)
        elif 'BatchNorm2d' in classname:
            # BatchNorm Layer's weight is not a matrix; only normal distribution applies.
            if weight is not None:
                init.normal_(weight.data, 1.0, init_gain)
            if bias is not None:
                init.constant_(bias.data, 0.0)

    print('initialize network with %s' % init_type)
    net.apply(init_func)


def cust_log_loss(y_true, y_pred, eps=1e-15):
    """Return the per-sample log loss (cross-entropy) of predicted probabilities.

    Classes are mapped to columns of ``y_pred`` in sorted order of the distinct
    values in ``y_true`` (the same convention as sklearn's ``LabelBinarizer``).
    If ``y_true`` contains a single class and ``y_pred`` has two columns, that
    class is scored against the first column, as the binarizer-based version did.

    Args:
        y_true: array-like of class labels, one per sample.
        y_pred: array-like of shape (n_samples, n_classes) with class
            probabilities, or shape (n_samples,) holding the probability of the
            second (larger) class of a binary problem.
        eps: probabilities are clipped to [eps, 1 - eps] before the log.

    Returns:
        np.ndarray of shape (n_samples,): -log p(true class) after clipping and
        row renormalization (no averaging).

    Raises:
        ValueError: if the number of samples or classes does not match ``y_pred``.
    """
    y_pred = np.asarray(y_pred, dtype=float)
    if y_pred.ndim == 1:
        # Binary shorthand: the single column is P(second class).
        y_pred = np.column_stack([1 - y_pred, y_pred])
    classes, class_idx = np.unique(np.asarray(y_true).ravel(), return_inverse=True)
    class_idx = class_idx.ravel()
    n_samples, n_cols = y_pred.shape
    if len(class_idx) != n_samples:
        raise ValueError('y_true has %d samples but y_pred has %d'
                         % (len(class_idx), n_samples))
    single_class_binary = len(classes) == 1 and n_cols == 2
    if len(classes) != n_cols and not single_class_binary:
        raise ValueError('y_true has %d classes but y_pred has %d columns'
                         % (len(classes), n_cols))
    one_hot = np.eye(n_cols)[class_idx]
    # np.clip returns a new array, so the in-place normalization below never
    # touches the caller's data.
    y_pred = np.clip(y_pred, eps, 1 - eps)
    y_pred /= y_pred.sum(axis=1)[:, np.newaxis]
    return -(one_hot * np.log(y_pred)).sum(axis=1)


def get_partition(pd00, num_public):
    """Randomly assign each row of a DataFrame to 'public', 'private' or 'test'.

    ``num_public`` rows are sampled without replacement and labelled 'public';
    every remaining row is independently labelled 'test' when a uniform draw is
    >= 0.8 (probability 0.2) and 'private' otherwise. The combined frame is
    shuffled again before it is returned.

    Public rows are chosen by their position in ``pd00``, so any index (not only
    a default RangeIndex) works, and an existing 'index' column is preserved.

    Args:
        pd00: input DataFrame; it is not modified.
        num_public: number of 'public' rows; ``np.random.choice`` raises
            ValueError if it exceeds ``len(pd00)``.

    Returns:
        A new DataFrame with the input's rows and columns plus a 'partition'
        column.

    NOTE(review): ``shuffle`` is not defined in this chunk — presumably
    ``sklearn.utils.shuffle`` or a star-import from ``model``/``dataloader``;
    it is assumed to return a row-shuffled copy. ``np.random.choice`` and
    ``np.random.rand`` draw from NumPy's global RNG.
    """
    pos_col = '__partition_pos__'
    # Tag rows with their original position before shuffling so the public
    # sample is matched by position rather than by index label.
    pd00 = shuffle(pd00.assign(**{pos_col: np.arange(len(pd00))}))
    public_idx = np.random.choice(range(len(pd00)), num_public, replace = False)
    is_public = pd00[pos_col].isin(public_idx)
    public_pd = pd00[is_public].assign(partition='public')
    res_pd = pd00[~is_public]
    # One uniform draw per remaining row, in row order — the same RNG stream a
    # per-row np.random.rand() loop would consume.
    draws = np.random.rand(len(res_pd))
    res_pd = res_pd.assign(partition=np.where(draws >= 0.8, 'test', 'private'))
    return shuffle(pd.concat([res_pd, public_pd])).drop(columns=[pos_col])


def rand_response(z, num_z, eps):
    """Apply the k-ary randomized response mechanism to one categorical value.

    With k = ``num_z`` categories, the true category ``z`` is reported with
    probability e^eps / (k - 1 + e^eps) and each other category with probability
    (1 - that) / (k - 1). If k == 1, or e^eps overflows to infinity, the true
    category is always reported.

    Args:
        z: true category index, in [0, num_z).
        num_z: number of categories (>= 1).
        eps: privacy parameter epsilon.

    Returns:
        The reported category index (a NumPy integer), drawn with NumPy's global RNG.

    Raises:
        ValueError: if ``num_z < 1`` or ``z`` is outside [0, num_z).
    """
    num_z = int(num_z)
    z = int(z)
    if num_z < 1:
        raise ValueError('num_z must be >= 1, got %d' % num_z)
    if not 0 <= z < num_z:
        raise ValueError('z must be in [0, %d), got %d' % (num_z, z))
    with np.errstate(over='ignore'):
        exp_eps = np.exp(eps)
    if num_z == 1 or np.isinf(exp_eps):
        # Degenerate cases: a single category, or e^eps = inf (which would make
        # the probabilities inf/inf = NaN). The true value is always reported.
        main_prob, aux_prob = 1.0, 0.0
    else:
        main_prob = exp_eps / (num_z - 1 + exp_eps)
        aux_prob = (1 - main_prob) / float(num_z - 1)
    pr = [aux_prob] * num_z
    pr[z] = main_prob
    return np.random.choice(list(range(num_z)), p=pr)