import time, datetime, glob, os, re, sys, random,  pickle as pickle, collections, itertools 
import pandas as pd, numpy as np, scipy, sklearn
from sklearn import metrics
from sklearn import preprocessing
from sklearn.metrics import roc_auc_score
from sklearn.model_selection import train_test_split
from sklearn.model_selection import StratifiedKFold
import IPython.display
import matplotlib.pylab as plt
import torch
from IPython.display import clear_output
from functools import reduce

####
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset, Sampler, ConcatDataset
from torch.nn import init
import torch.optim as optim
# Single seed shared by every RNG this script touches so runs are repeatable.
seed = 3
for _seed_rng in (random.seed, np.random.seed, torch.manual_seed):
    _seed_rng(seed)
del _seed_rng
# Request deterministic cuDNN kernels and turn off its autotuner
# (trades some speed for run-to-run reproducibility).
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False


####

# Widen pandas' console rendering so large frames print without truncation.
for _opt_name, _opt_value in (
    ('display.max_columns', 500),
    ('display.width', 2000),
    ('display.max_rows', 500),
    ('display.precision', 4),
    ('display.max_colwidth', 2000),
):
    pd.set_option(_opt_name, _opt_value)
del _opt_name, _opt_value
# Expose only GPU 0 to this process. CUDA reads this variable when it is
# initialised, so it has to be set before any GPU work happens.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

import warnings
warnings.filterwarnings(action='ignore', category=FutureWarning)



class CustomDataset(Dataset):
    """Map-style dataset over a subset of the rows of one CSV file.

    Each item is ``(x, y, z)``: ``x`` is a float tensor of the feature columns
    (every column except ``label``, ``y``, ``z``, ``is_train``, ``race`` and
    ``is_train2``, in sorted column-name order), ``y`` and ``z`` are tensors
    holding that row's ``y`` and ``z`` values.

    Args:
        meta_path: Path to the CSV file. The dataset name is the part of the
            file name before the first ``_``; for ``creditcard`` a constant
            ``1.0`` column named ``constant`` is added as an extra feature.
        indices: Row labels of ``self.features`` that form this subset. Because
            ``read_csv`` yields a default RangeIndex, these are also the row
            positions in the file.
        mode: Stored as ``self.mode``; it no longer changes behaviour (see
            ``preprocess``).
    """

    def __init__(self, meta_path, indices, mode='train'):
        self.features = pd.read_csv(meta_path)
        # basename() splits on the OS path separator(s) and accepts PathLike
        # objects; on POSIX it matches the previous split("/")[-1].
        dataset = os.path.basename(meta_path).split('_')[0]
        if dataset == 'creditcard':
            self.features['constant'] = 1.0
        self.indices = indices
        # np.setdiff1d returns the remaining column names sorted, which fixes
        # the feature order used by __getitem__.
        self.cols = np.setdiff1d(
            self.features.columns,
            ['label', 'y', 'z', 'is_train', 'race', 'is_train2'],
        )
        self.mode = mode
        self.meta_path = meta_path
        self.preprocess()

    def preprocess(self):
        """Re-seed Python's global ``random`` module with 1234.

        The previous version also ran
        ``self.features.sample(frac=1).reset_index(drop=True, inplace=True)``
        in train mode. ``sample`` returns a new frame and the in-place reset
        applied to that discarded copy, so ``self.features`` was never
        shuffled; the call only cost a full copy of the data. It is removed
        rather than turned into a real shuffle: re-ordering and re-indexing
        ``self.features`` would silently re-map ``self.indices`` (labels
        chosen by the caller, e.g. a stratified fold or the ``is_train2``
        split) onto different rows. Shuffle via ``DataLoader(shuffle=True)``
        instead. Note the removed call drew from NumPy's global RNG, so
        downstream NumPy random streams differ from the old code.
        """
        # Kept for compatibility: every construction resets Python's global RNG.
        random.seed(1234)

    def __getitem__(self, idx):
        """Return ``(x, y, z)`` tensors for the ``idx``-th row of this subset."""
        cur_data = self.features.loc[self.indices[idx]]
        # A single row is a pandas Series, so mixed int/float columns share one
        # dtype; y and z therefore take the row-wide dtype.
        x = torch.tensor(cur_data[self.cols].values, dtype=torch.float)
        y = torch.tensor(cur_data.y)
        z = torch.tensor(cur_data.z)
        return x, y, z

    def __len__(self):
        """Number of rows in this subset."""
        return len(self.indices)


def get_loader(meta_path, num_teachers):
    """Split a training CSV into ``num_teachers`` disjoint, label-stratified subsets.

    Despite the name, this returns ``CustomDataset`` objects, not
    ``DataLoader``s; wrap each one in a ``DataLoader`` to batch it.

    Args:
        meta_path: Path to the training CSV; it must contain a ``y`` column,
            which is used for stratification.
        num_teachers: Number of folds, i.e. teachers (``StratifiedKFold``
            requires at least 2).

    Returns:
        list of ``CustomDataset``, one per fold, each covering that fold's rows.
    """
    train_data = pd.read_csv(meta_path)
    data_size = np.ceil((len(train_data))/num_teachers)
    print('total train_data', len(train_data))
    print('Total data per teacher:', data_size)
    skf = StratifiedKFold(n_splits=num_teachers, random_state=123, shuffle=True)
    # StratifiedKFold takes only the sample count from X and stratifies on y,
    # so a placeholder X gives identical folds without copying every feature.
    placeholder_x = np.zeros(len(train_data))
    labels = train_data.y.values
    # Fold indices are row positions; CustomDataset uses them as .loc labels,
    # which coincide because read_csv yields a default RangeIndex.
    return [
        CustomDataset(meta_path, fold_rows)
        for _, fold_rows in skf.split(placeholder_x, labels)
    ]


def get_student_data(dataset, batch_size=32, data_dir="/home/pate/data"):
    """Load the student train/test split from ``<data_dir>/<dataset>_test.csv``.

    Rows with ``is_train2 == 1`` form the student training set; all other
    rows form the student test set.

    Args:
        dataset: Dataset name used to build the CSV file name.
        batch_size: Batch size of the returned training ``DataLoader``.
        data_dir: Directory containing the CSV files.

    Returns:
        ``(x_train, y_train, z_train, x_test, y_test, z_test, student_train_loader)``.
        The arrays are NumPy arrays, with features in ``CustomDataset.cols``
        order. The loader iterates the training rows in file order (it is
        not shuffled).
    """
    test_meta_path = os.path.join(data_dir, "{}_test.csv".format(dataset))
    print(test_meta_path)
    test_data = pd.read_csv(test_meta_path)
    is_student_train = test_data['is_train2'] == 1
    train_idx = test_data[is_student_train].index.values
    test_idx = test_data[~is_student_train].index.values
    # Report the actual split size (this was previously a hard-coded 200).
    print('Total test data: {}. Total student train student data: {}'.format(len(test_data), len(train_idx)))
    student_train_data = CustomDataset(test_meta_path, train_idx)
    student_test_data = CustomDataset(test_meta_path, test_idx)
    student_train_loader = DataLoader(student_train_data, batch_size=batch_size)

    # Read from the datasets' own frames so the feature columns (including the
    # 'constant' column added for creditcard) match what the loader yields.
    df_student_train = student_train_data.features.loc[student_train_data.indices]
    df_student_test = student_test_data.features.loc[student_test_data.indices]
    x_train = df_student_train[student_train_data.cols].values
    x_test = df_student_test[student_test_data.cols].values
    y_train = df_student_train.y.values
    z_train = df_student_train.z.values
    y_test = df_student_test.y.values
    z_test = df_student_test.z.values
    return x_train, y_train, z_train, x_test, y_test, z_test, student_train_loader

