import os
import glob
import copy
import pdb
import json
import random 
import itertools

import pandas as pd
import numpy as np
from PIL import Image

import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import datasets, transforms
from torch.utils.data.sampler import Sampler
import torch.nn.functional as F


def split_clients(input_feature, sensitive_feature, target_feature, num_users):
    """Partition samples across users with a user-specific mix of four groups.

    Samples are bucketed by their (target, sensitive) pair into four groups,
    in this order: (1, 1), (1, 0), (0, 1) and (0, 0).  Samples whose values
    match none of these pairs are never assigned.  For every user ``i`` a
    generator seeded with ``i`` draws four distinct integers from 1..10; they
    are normalised into the fraction of the user's quota
    (``len(input_feature) // num_users``) taken from each group.  Selected
    samples are removed from the pool, so no sample is handed to two users.

    The per-user generator is a local ``np.random.RandomState(i)``.  It yields
    the same stream as ``np.random.seed(i)`` followed by ``np.random.*`` calls,
    but leaves the process-wide NumPy RNG (and any seed set by the caller)
    untouched.

    Args:
        input_feature: Per-sample inputs; must support ``len()`` and indexing
            with a list of integer positions (e.g. a NumPy array).
        sensitive_feature: Per-sample sensitive attribute, compared against
            0 and 1 and indexed the same way as ``input_feature``.
        target_feature: Per-sample label, compared against 0 and 1 and indexed
            the same way as ``input_feature``.
        num_users: Number of users to create; must be positive.

    Returns:
        dict mapping each user index in ``range(num_users)`` to
        ``[user_data, user_sensitive, user_label]``, three parallel lists.

    Raises:
        ValueError: If ``num_users`` is not positive.
    """
    if num_users <= 0:
        raise ValueError("num_users must be a positive integer, got {!r}".format(num_users))

    # Per-user sample quota before it is divided among the four groups.
    num_items = len(input_feature) // num_users

    group_masks = [
        (target_feature == 1) & (sensitive_feature == 1),
        (target_feature == 1) & (sensitive_feature == 0),
        (target_feature == 0) & (sensitive_feature == 1),
        (target_feature == 0) & (sensitive_feature == 0),
    ]
    # Remaining (not yet assigned) sample indices of every group.
    group_idx = [np.where(mask)[0] for mask in group_masks]

    dict_users = {}
    for i in range(num_users):
        # One generator per user, seeded with the user index, decides both the
        # group mix and which samples are drawn.
        rng = np.random.RandomState(i)
        group_prob = rng.choice(range(1, 11), 4, replace=False)
        group_prob = group_prob / group_prob.sum()

        user_data = []
        user_sensitive = []
        user_label = []
        for j in range(4):
            num_samples = int(num_items * group_prob[j])

            # Not enough samples left in this group: fall back to a quarter
            # of what remains instead of failing.
            if num_samples > len(group_idx[j]):
                num_samples = len(group_idx[j]) // 4

            # The set round-trips are kept on purpose: they define the order
            # of the remaining pool and of the emitted samples, so the split
            # is the same as the one produced by the seed-based original.
            selected_set = set(rng.choice(group_idx[j], num_samples, replace=False))
            group_idx[j] = list(set(group_idx[j]) - selected_set)
            selected = list(selected_set)

            user_data.extend(input_feature[selected])
            user_sensitive.extend(sensitive_feature[selected])
            user_label.extend(target_feature[selected])

        dict_users[i] = [user_data, user_sensitive, user_label]

    return dict_users

# Custom torch Dataset that serves (input, sensitive attribute, target) triples
class CustomizedDataset(Dataset):
    """Map-style dataset over parallel input / sensitive / target containers.

    The three containers are stored as given and indexed with the same
    position; the dataset length is the length of ``inputs_rows``.
    """

    def __init__(self, inputs_rows=None, sensitive_rows=None, target_rows=None):
        self.inputs_rows, self.sensitive_rows, self.target_rows = (
            inputs_rows,
            sensitive_rows,
            target_rows,
        )

    def __len__(self):
        """Return the number of entries in ``inputs_rows``."""
        return len(self.inputs_rows)

    def __getitem__(self, idx):
        """Return ``(inputs, sensitive_feature, targets)`` for position ``idx``.

        ``inputs`` is converted with ``torch.from_numpy`` and cast to float;
        the other two values are wrapped in tensors and cast to long.
        """
        return (
            torch.from_numpy(self.inputs_rows[idx]).float(),
            torch.tensor(self.sensitive_rows[idx]).long(),
            torch.tensor(self.target_rows[idx]).long(),
        )

def get_dataset(dataset, num_users, label_y=100, label_a=100):
    """Load a dataset, split its training part across users, build the test set.

    Args:
        dataset: Either ``'compas'`` or ``'income'``.
        num_users: Number of users handed to ``split_clients``.
        label_y: Unused; kept so existing call sites keep working.
        label_a: Positional index of the sensitive-attribute column in the
            COMPAS CSV.  Only read when ``dataset == 'compas'``.

    Returns:
        ``(user_groups, test_dataset)`` where ``user_groups`` is the dict
        returned by ``split_clients`` for the training data and
        ``test_dataset`` is a ``CustomizedDataset`` over the test data.

    Raises:
        ValueError: If ``dataset`` is not one of the supported names.
    """
    if dataset not in ('compas', 'income'):
        # Previously this fell through both branches and failed with an
        # UnboundLocalError at the return statement.
        raise ValueError(
            "Unsupported dataset {!r}; expected 'compas' or 'income'".format(dataset)
        )

    if dataset == 'compas':
        compas_data = pd.read_csv('data/compas/propublica_data_for_fairml.csv')
        sensitive_feature = compas_data.iloc[:, [label_a]].values

        num_samples = len(sensitive_feature)
        # The permutation must come from seed 0 (per the original author's
        # note); presumably it has to line up with the row order of the saved
        # .npy files -- TODO confirm.  A local RandomState(0) produces the same
        # permutation as np.random.seed(0) without reseeding the global RNG.
        ids = np.random.RandomState(0).choice(num_samples, num_samples, replace=False)
        split_at = int(num_samples * 0.8)
        train_sensitive = sensitive_feature[ids[:split_at]].squeeze()
        test_sensitive = sensitive_feature[ids[split_at:]].squeeze()

        train_x = np.load('data/compas_new/train_data.npy')
        train_y = np.load('data/compas_new/train_label.npy')

        test_x = np.load('data/compas_new/test_data.npy')
        test_y = np.load('data/compas_new/test_label.npy')

    else:  # dataset == 'income'
        train_x = np.load('data/adult_new/train_data.npy')
        train_sensitive = np.load('data/adult_new/train_gender.npy')  # 1 = female
        train_y = np.load('data/adult_new/train_label.npy')

        test_x = np.load('data/adult_new/test_data.npy')
        test_sensitive = np.load('data/adult_new/test_gender.npy')  # 1 = female
        test_y = np.load('data/adult_new/test_label.npy')

    user_groups = split_clients(train_x, train_sensitive, train_y, num_users)
    test_dataset = CustomizedDataset(
        inputs_rows=test_x,
        sensitive_rows=test_sensitive,
        target_rows=test_y,
    )

    return user_groups, test_dataset

def Gaussian_stats(user_groups):
    """Return sample-count-weighted averages of per-user means and stds.

    For every user index in ``range(len(user_groups))`` the first entry of
    ``user_groups[index]`` is turned into an array, and its mean and standard
    deviation along axis 0 are computed.  Both are then averaged over users,
    each user weighted by its number of samples.

    Returns:
        ``(weighted_mean, weighted_std)``.
    """
    num_users = len(user_groups)

    per_user_data = [np.array(user_groups[u][0]) for u in range(num_users)]

    # Column vector of sample counts so it broadcasts over the feature axis.
    counts = np.array([len(data) for data in per_user_data]).reshape(num_users, 1)
    means = np.array([np.mean(data, axis=0) for data in per_user_data])
    stds = np.array([np.std(data, axis=0) for data in per_user_data])

    total = np.sum(counts)
    weighted_mean = np.sum(means * counts, axis=0) / total
    weighted_std = np.sum(stds * counts, axis=0) / total

    return weighted_mean, weighted_std
        
if __name__ == '__main__':

    # Manual smoke test: build the COMPAS split for 10 users, reading the
    # sensitive attribute from CSV column index 5.
    num_users = 10
    user_groups, test_dataset = get_dataset(dataset='compas', num_users=num_users, label_a=5)

    # Disabled debugging code, kept for reference: it wraps the first user's
    # data in a CustomizedDataset, iterates a DataLoader over it for a few
    # epochs, and drops into pdb.
    # for i in range(num_users):
    #     current_user = user_groups[i]
    #     current_dataset = CustomizedDataset(inputs_rows=current_user[0], 
    #                                         sensitive_rows=current_user[1], 
    #                                         target_rows=current_user[2])
        
    #     current_dataloader = torch.utils.data.DataLoader(current_dataset, batch_size=4, shuffle=True)
    #     num_step = len(current_dataloader)
    #     for epoch in range(10):
    #         for j, (inputs, sensitive_attributes, targets) in enumerate(current_dataloader):
    #             if j == 122:
    #                 print(inputs)
            
    #         pdb.set_trace()
        
    #     break
    # pdb.set_trace()
    # NOTE(review): global_data_stats is not defined in the visible part of
    # this file -- confirm it still exists before re-enabling the call below.
    # global_data_stats(dataset='compas', label_a=5)