import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
import torch
import copy
from torch.utils.data import Dataset



#---------------------------------------------------------data processing---------------------------------------------------------

def data_processing(data_name, data_path,num=100, train_rate=0.8,random_seed=0,generate_method='random', dir_alpha=1):
    """Load a tabular fairness dataset and partition it across federated clients.

    Args:
        data_name: ``'adult'`` or ``'compas'``.
        data_path: String prefix prepended to the file names, so it must end
            with a path separator. For ``'adult'``, ``None`` means
            ``'data/adult/'``. ``'compas'`` requires an explicit path.
        num: Number of clients.
        train_rate: Fraction of each client's rows that goes to training.
        random_seed: Seeds the global NumPy RNG, which drives the client split.
            Also passed as ``random_state`` to ``train_test_split``.
        generate_method: ``'random'`` shuffles all rows and cuts them into
            ``num`` near-equal shards. ``'dirichlet'`` draws
            Dirichlet(``dir_alpha``) proportions separately for each
            sensitive group; these decide how that group's rows are spread
            over the clients.
        dir_alpha: Dirichlet concentration. Only used by ``'dirichlet'``.

    Returns:
        ``(dataset_train, dataset_test, dict_users, test_dict_users,
        Sensitive_attribute_train, Sensitive_attribute_test)``.

        - The datasets are lists of ``(float32 feature tensor, 0/1 label)``
          tuples. Features are standardised with a scaler fit on the
          training rows only.
        - ``dict_users[c]`` / ``test_dict_users[c]`` hold client ``c``'s
          positions in those lists.
        - The sensitive attributes are 0/1 NumPy arrays aligned with the
          datasets (adult: Male -> 0; compas: Female -> 1).

    Raises:
        ValueError: Unknown ``data_name`` or ``generate_method``, or a missing
            ``data_path`` for compas.

    The Adult feature matrix keeps the one-hot ``sex`` columns; COMPAS drops
    ``sex`` from the features.

    NOTE(review): with a small ``dir_alpha`` a client can receive too few rows
    for the stratified per-client split, which ``train_test_split`` rejects.
    """
    if generate_method not in ('random', 'dirichlet'):
        raise ValueError(
            f"unknown generate_method {generate_method!r}; expected 'random' or 'dirichlet'")

    if data_name == 'adult':
        prefix = 'data/adult/' if data_path is None else data_path
        data = pd.read_csv(prefix + 'adult.data', header=None, sep=', ', engine='python')
        test = pd.read_csv(prefix + 'adult.test', header=None, sep=', ', engine='python')
        # '?' marks a missing value; drop every row that contains one.
        data = data[(data != '?').all(axis=1)]
        print(data.shape)
        test = test[(test != '?').all(axis=1)]
        print(test.shape)
        data = pd.concat([data, test])
        data.index = range(len(data))
        print(data.shape)
        data.columns = ['age', 'workclass', 'fnlwgt', 'education', 'education-num',
                        'marital-status', 'occupation', 'relationship', 'race', 'sex',
                        'capital-gain', 'capital-loss', 'hours-per-week', 'native-country',
                        'income']
        data = data.drop(['education'], axis=1)
        # Vectorised encodings (integer dtype instead of object).
        # Labels: anything starting with '<=50K' (e.g. '<=50K.') -> 0, otherwise 1.
        # Sensitive attribute: 'Male' -> 0, otherwise 1.
        Y = (~data['income'].str.startswith('<=50K', na=False)).astype(np.int64)
        Sensitive_attribute = (data['sex'] != 'Male').astype(np.int64)
        X = pd.get_dummies(
            data.drop(['income'], axis=1),
            columns=['workclass', 'marital-status', 'sex', 'occupation',
                     'relationship', 'race', 'native-country'])
    elif data_name == 'compas':
        if data_path is None:
            raise ValueError("data_path is required for the 'compas' dataset")
        data = pd.read_csv(data_path + 'compas-scores-two-years.csv')
        features_to_keep = ['sex', 'age', 'age_cat', 'race', 'juv_fel_count', 'juv_misd_count',
                            'juv_other_count', 'priors_count', 'c_charge_degree',
                            'c_charge_desc', 'two_year_recid']
        data = pd.get_dummies(data[features_to_keep],
                              columns=['age_cat', 'c_charge_degree', 'c_charge_desc', 'race'])
        Y = data['two_year_recid'].copy()
        # 'Female' -> 1, otherwise 0.
        Sensitive_attribute = (data['sex'] == 'Female').astype(np.int64)
        X = data.drop(['two_year_recid', 'sex'], axis=1)
    else:
        raise ValueError(f"unknown data_name {data_name!r}; expected 'adult' or 'compas'")

    np.random.seed(random_seed)
    if generate_method == 'random':
        shuffled = np.array(X.index)
        np.random.shuffle(shuffled)
        index_list = np.array_split(shuffled, num)
    else:
        # Both proportion vectors are drawn before any shuffling (fixed RNG order).
        prop_0 = np.random.dirichlet(np.ones(num) * dir_alpha, size=1)
        prop_1 = np.random.dirichlet(np.ones(num) * dir_alpha, size=1)
        # np.array copies, so the in-place shuffles below never touch the Series index.
        idx_array_0 = np.array(Sensitive_attribute[Sensitive_attribute == 0].index)
        idx_array_1 = np.array(Sensitive_attribute[Sensitive_attribute == 1].index)
        cut_0 = (prop_0 * len(idx_array_0)).astype(int).cumsum()
        cut_1 = (prop_1 * len(idx_array_1)).astype(int).cumsum()
        # Truncation leaves a remainder; the last client absorbs it.
        cut_0[-1] = len(idx_array_0)
        cut_1[-1] = len(idx_array_1)
        np.random.shuffle(idx_array_0)
        np.random.shuffle(idx_array_1)
        # Splitting at the first num-1 cut points yields exactly num slices per group.
        index_list = [np.concatenate(pair)
                      for pair in zip(np.split(idx_array_0, cut_0[:-1]),
                                      np.split(idx_array_1, cut_1[:-1]))]

    train_idx_list, test_idx_list = [], []
    dict_users, test_dict_users = {}, {}
    for client, client_idx in enumerate(index_list):
        # Dirichlet shards are split stratified by label; random shards are not.
        idx_train, idx_test = train_test_split(
            client_idx, test_size=1 - train_rate, random_state=random_seed,
            stratify=Y.loc[client_idx] if generate_method == 'dirichlet' else None)
        dict_users[client] = np.arange(len(train_idx_list), len(train_idx_list) + len(idx_train))
        test_dict_users[client] = np.arange(len(test_idx_list), len(test_idx_list) + len(idx_test))
        train_idx_list.extend(idx_train)
        test_idx_list.extend(idx_test)

    X_train = X.loc[train_idx_list].to_numpy(dtype=np.float64)
    X_test = X.loc[test_idx_list].to_numpy(dtype=np.float64)
    Y_train = Y.loc[train_idx_list].values
    Y_test = Y.loc[test_idx_list].values
    Sensitive_attribute_train = Sensitive_attribute.loc[train_idx_list].values
    Sensitive_attribute_test = Sensitive_attribute.loc[test_idx_list].values

    scaler = StandardScaler()
    X_train = torch.tensor(scaler.fit_transform(X_train), dtype=torch.float32)
    X_test = torch.tensor(scaler.transform(X_test), dtype=torch.float32)

    dataset_train = list(zip(X_train, Y_train))
    dataset_test = list(zip(X_test, Y_test))

    return dataset_train,dataset_test,dict_users,test_dict_users,Sensitive_attribute_train,Sensitive_attribute_test
    
def celeba_data_processing(data,label,sensitive,random_seed,split_size=508):
    """Shard CelebA samples over ``split_size`` clients and build train/test datasets.

    Sample ids ``1 .. 20*split_size`` are shuffled with the global NumPy RNG,
    which this function seeds with ``random_seed``. They are then cut into
    ``split_size`` groups of 20, and each group is split 80/20 into train and
    test with ``train_test_split(random_state=random_seed)``.

    ``data[k]``, ``label[k]`` and ``sensitive[k]`` are looked up for every id
    ``k`` and concatenated along axis 0.
    NOTE(review): presumably each entry carries a leading sample axis, and
    labels / sensitive values are encoded as +1 / -1. ``(1 - v) / 2`` maps
    +1 -> 0 and -1 -> 1. Confirm against the caller.

    Images are divided by 255 and permuted with ``(0, 3, 1, 2)``.
    NOTE(review): this assumes channels-last 4-D input.

    Returns:
        ``(dataset_train, dataset_test, dict_users, dict_users_test,
        Sen_train, Sen_test)``.

        - The datasets are lists of ``(float32 image tensor, long label
          tensor)`` tuples.
        - ``dict_users[c]`` / ``dict_users_test[c]`` are client ``c``'s
          positions in them.
        - ``Sen_*`` are NumPy arrays aligned with the datasets.
    """
    np.random.seed(random_seed)
    sample_ids = np.arange(1, 1 + split_size * 20)
    np.random.shuffle(sample_ids)

    x_tr, x_te = [], []
    y_tr, y_te = [], []
    s_tr, s_te = [], []
    dict_users, dict_users_test = {}, {}
    n_train = n_test = 0

    for client, ids in enumerate(np.array_split(sample_ids, split_size)):
        feats = np.concatenate([data[k] for k in ids], axis=0)
        targets = np.concatenate([(1 - label[k]) / 2 for k in ids], axis=0).astype(int)
        groups = np.concatenate([(1 - sensitive[k]) / 2 for k in ids], axis=0)

        tr, te = train_test_split(np.arange(len(targets)), test_size=0.2, random_state=random_seed)
        x_tr.extend(feats[tr])
        x_te.extend(feats[te])
        y_tr.extend(targets[tr])
        y_te.extend(targets[te])
        s_tr.extend(groups[tr])
        s_te.extend(groups[te])

        dict_users[client] = np.arange(n_train, n_train + len(tr))
        dict_users_test[client] = np.arange(n_test, n_test + len(te))
        n_train += len(tr)
        n_test += len(te)

    def _to_images(rows):
        # Scale to [0, 1] as float32, then reorder axes (0, 3, 1, 2).
        return torch.tensor(np.array(rows) / 255, dtype=torch.float32).permute(0, 3, 1, 2)

    X_train = _to_images(x_tr)
    X_test = _to_images(x_te)
    Y_train = torch.tensor(np.array(y_tr), dtype=torch.long)
    Y_test = torch.tensor(np.array(y_te), dtype=torch.long)

    dataset_train = list(zip(X_train, Y_train))
    dataset_test = list(zip(X_test, Y_test))
    return dataset_train, dataset_test, dict_users, dict_users_test, np.array(s_tr), np.array(s_te)

class LoadData(Dataset):
    """Map-style dataset over parallel label / feature / sensitive-attribute containers.

    Note the unusual constructor order ``(y, x, sen)``. Items are returned as
    ``(x[index], y[index], sen[index])``, each wrapped with ``torch.tensor``,
    which copies the data rather than sharing memory.
    """

    def __init__(self, y, x, sen):
        self.y = y
        self.x = x
        self.sen = sen

    def __getitem__(self, index):
        return torch.tensor(self.x[index]), torch.tensor(self.y[index]), torch.tensor(self.sen[index])

    def __len__(self):
        # len() works for NumPy arrays, tensors and plain sequences alike,
        # whereas ``self.y.shape[0]`` required an object with a ``.shape``.
        return len(self.y)
    
def FedFB_processing(data_name,data_path,num=100,train_rate=0.8,random_seed=0,generate_method='random',dir_alpha=1):
    """Load and partition a tabular fairness dataset in the layout FedFB expects.

    Loading, encoding and client partitioning work as follows:

    - ``data_name``: ``'adult'`` or ``'compas'``.
    - ``data_path``: a string prefix prepended to the file names. For adult,
      ``None`` means ``'data/adult/'``; compas requires a path.
    - ``num``: number of clients.
    - ``train_rate``: per-client training fraction.
    - ``random_seed``: seeds the global NumPy RNG and is passed as
      ``random_state`` to ``train_test_split``.
    - ``generate_method``: ``'random'`` gives near-equal random shards.
      ``'dirichlet'`` draws Dirichlet(``dir_alpha``) proportions per
      sensitive group and splits each shard stratified by label.

    Returns:
        A 9-tuple:

        ``([train_ds, test_ds, dict_users],
           [Sensitive_attribute_train, Sensitive_attribute_test],
           n_features,
           dataset_train, dataset_test,
           Sensitive_attribute_train, Sensitive_attribute_test,
           dict_users, test_dict_users)``

        - ``train_ds`` / ``test_ds`` are ``LoadData`` objects over the
          standardised float64 NumPy features, 0/1 labels and 0/1 sensitive
          attributes.
        - ``dataset_*`` are lists of ``(float32 feature tensor, label)``.
        - ``dict_users[c]`` / ``test_dict_users[c]`` are client ``c``'s
          positions in the train / test rows.

    Raises:
        ValueError: Unknown ``data_name`` or ``generate_method``, or a missing
            ``data_path`` for compas.

    The Adult feature matrix keeps the one-hot ``sex`` columns; COMPAS drops
    ``sex`` from the features.
    """
    if generate_method not in ('random', 'dirichlet'):
        raise ValueError(
            f"unknown generate_method {generate_method!r}; expected 'random' or 'dirichlet'")

    if data_name == 'adult':
        prefix = 'data/adult/' if data_path is None else data_path
        data = pd.read_csv(prefix + 'adult.data', header=None, sep=', ', engine='python')
        test = pd.read_csv(prefix + 'adult.test', header=None, sep=', ', engine='python')
        # '?' marks a missing value; drop every row that contains one.
        data = data[(data != '?').all(axis=1)]
        print(data.shape)
        test = test[(test != '?').all(axis=1)]
        print(test.shape)
        data = pd.concat([data, test])
        data.index = range(len(data))
        print(data.shape)
        data.columns = ['age', 'workclass', 'fnlwgt', 'education', 'education-num',
                        'marital-status', 'occupation', 'relationship', 'race', 'sex',
                        'capital-gain', 'capital-loss', 'hours-per-week', 'native-country',
                        'income']
        data = data.drop(['education'], axis=1)
        # Vectorised encodings (integer dtype instead of object).
        # Labels: anything starting with '<=50K' (e.g. '<=50K.') -> 0, otherwise 1.
        # Sensitive attribute: 'Male' -> 0, otherwise 1.
        Y = (~data['income'].str.startswith('<=50K', na=False)).astype(np.int64)
        Sensitive_attribute = (data['sex'] != 'Male').astype(np.int64)
        X = pd.get_dummies(
            data.drop(['income'], axis=1),
            columns=['workclass', 'marital-status', 'sex', 'occupation',
                     'relationship', 'race', 'native-country'])
    elif data_name == 'compas':
        if data_path is None:
            raise ValueError("data_path is required for the 'compas' dataset")
        data = pd.read_csv(data_path + 'compas-scores-two-years.csv')
        features_to_keep = ['sex', 'age', 'age_cat', 'race', 'juv_fel_count', 'juv_misd_count',
                            'juv_other_count', 'priors_count', 'c_charge_degree',
                            'c_charge_desc', 'two_year_recid']
        data = pd.get_dummies(data[features_to_keep],
                              columns=['age_cat', 'c_charge_degree', 'c_charge_desc', 'race'])
        Y = data['two_year_recid'].copy()
        # 'Female' -> 1, otherwise 0.
        Sensitive_attribute = (data['sex'] == 'Female').astype(np.int64)
        X = data.drop(['two_year_recid', 'sex'], axis=1)
    else:
        raise ValueError(f"unknown data_name {data_name!r}; expected 'adult' or 'compas'")

    np.random.seed(random_seed)
    if generate_method == 'random':
        shuffled = np.array(X.index)
        np.random.shuffle(shuffled)
        index_list = np.array_split(shuffled, num)
    else:
        # Both proportion vectors are drawn before any shuffling (fixed RNG order).
        prop_0 = np.random.dirichlet(np.ones(num) * dir_alpha, size=1)
        prop_1 = np.random.dirichlet(np.ones(num) * dir_alpha, size=1)
        # np.array copies, so the in-place shuffles below never touch the Series index.
        idx_array_0 = np.array(Sensitive_attribute[Sensitive_attribute == 0].index)
        idx_array_1 = np.array(Sensitive_attribute[Sensitive_attribute == 1].index)
        cut_0 = (prop_0 * len(idx_array_0)).astype(int).cumsum()
        cut_1 = (prop_1 * len(idx_array_1)).astype(int).cumsum()
        # Truncation leaves a remainder; the last client absorbs it.
        cut_0[-1] = len(idx_array_0)
        cut_1[-1] = len(idx_array_1)
        np.random.shuffle(idx_array_0)
        np.random.shuffle(idx_array_1)
        # Splitting at the first num-1 cut points yields exactly num slices per group.
        index_list = [np.concatenate(pair)
                      for pair in zip(np.split(idx_array_0, cut_0[:-1]),
                                      np.split(idx_array_1, cut_1[:-1]))]

    train_idx_list, test_idx_list = [], []
    dict_users, test_dict_users = {}, {}
    for client, client_idx in enumerate(index_list):
        # Dirichlet shards are split stratified by label; random shards are not.
        idx_train, idx_test = train_test_split(
            client_idx, test_size=1 - train_rate, random_state=random_seed,
            stratify=Y.loc[client_idx] if generate_method == 'dirichlet' else None)
        dict_users[client] = np.arange(len(train_idx_list), len(train_idx_list) + len(idx_train))
        test_dict_users[client] = np.arange(len(test_idx_list), len(test_idx_list) + len(idx_test))
        train_idx_list.extend(idx_train)
        test_idx_list.extend(idx_test)

    X_train = X.loc[train_idx_list].to_numpy(dtype=np.float64)
    X_test = X.loc[test_idx_list].to_numpy(dtype=np.float64)
    Y_train = Y.loc[train_idx_list].values
    Y_test = Y.loc[test_idx_list].values
    Sensitive_attribute_train = Sensitive_attribute.loc[train_idx_list].values
    Sensitive_attribute_test = Sensitive_attribute.loc[test_idx_list].values

    scaler = StandardScaler()
    X_train = scaler.fit_transform(X_train)
    X_test = scaler.transform(X_test)

    # The LoadData views keep the float64 NumPy features; the tuple lists use float32 tensors.
    adult_test = LoadData(Y_test, X_test, Sensitive_attribute_test)
    adult_train = LoadData(Y_train, X_train, Sensitive_attribute_train)

    dataset_train = list(zip(torch.tensor(X_train, dtype=torch.float32), Y_train))
    dataset_test = list(zip(torch.tensor(X_test, dtype=torch.float32), Y_test))

    return [adult_train, adult_test,dict_users], [Sensitive_attribute_train,Sensitive_attribute_test], X.shape[1],dataset_train,dataset_test, Sensitive_attribute_train,Sensitive_attribute_test,dict_users,test_dict_users
#---------------------------------------------------------qdigest---------------------------------------------------------

class qdigestNode():
    """A single node of the simulated q-digest tree built by ``qdigest_simu``.

    ``qdigest_simu`` gives each node the interval ``[min, weight]``: ``min`` is
    the lower edge and ``weight`` the upper edge. Both ``repr`` and ``str``
    show only ``weight``.
    """

    def __init__(self, weight, min_w, back=None, left=None, right=None):
        # Interval bounds of the value range this node covers.
        self.weight, self.min = weight, min_w
        # Tree links: parent ("back") and the two children.
        self.back, self.left, self.right = back, left, right
        # Number of values attributed to this node; qdigest_simu.compress
        # writes None here to mark a node folded into its parent.
        self.count = 0

    def __repr__(self):
        return str(self.weight)

    __str__ = __repr__
        
class qdigest_simu():
    """Simulated q-digest for values in [0, 1].

    The value range is modelled as a complete binary tree with ``depth``
    levels, stored level by level in ``self.data`` (``self.data[0] == [self.root]``).
    Node ``j`` of level ``i`` covers ``[j / 2**i, (j + 1) / 2**i]``, stored as
    ``node.min`` / ``node.weight``.

    Input values are counted in the ``2**(depth-1)`` leaves. ``compress``
    folds small sibling pairs into their parent and builds ``check_list``,
    the live nodes sorted by ``(weight, min)``. ``get_quantile`` and
    ``get_rank`` scan ``check_list``.

    A node whose ``count`` is ``None`` has been folded into its parent; this
    simulates deleting it from a real digest. Such nodes always come in
    sibling pairs, and their whole subtree is ``None`` too, so every recorded
    value is counted by exactly one live node.
    """

    def __init__(self, depth, data_list):
        """Build the full tree and count ``data_list`` into the leaves.

        Args:
            depth: Number of tree levels (>= 1). The leaves are
                ``2**(depth-1)`` buckets of width ``self.accuracy``.
            data_list: Sized iterable of values, assumed to lie in [0, 1].
                Out-of-range values land in the wrong bucket or raise
                IndexError.
        """
        self.data = {}
        self.depth = depth
        self.range = 2**(depth-1)
        self.accuracy = 1/self.range
        self.count = len(data_list)
        self.root = qdigestNode(1, 0)
        self.data[0] = [self.root]
        for level in range(1, depth):
            parents = self.data[level - 1]
            width = 2**level
            nodes = []
            for j in range(width):
                parent = parents[j // 2]
                node = qdigestNode((j + 1) / width, j / width, back=parent)
                if j % 2 == 0:
                    parent.left = node
                else:
                    parent.right = node
                nodes.append(node)
            self.data[level] = nodes
        leaves = self.data[depth - 1]
        for value in data_list:
            # The tiny epsilon sends value == 1 into the last bucket instead of past it.
            leaves[int(value / self.accuracy - 1e-10)].count += 1
        # Built by compress(), or lazily on the first query.
        self.check_list = None

    def get_node(self,current_list,node):
        """Append the live nodes of ``node``'s subtree to ``current_list`` in order
        (left subtree, node, right subtree) and return the list.

        Descent stops at a node whose children have both been folded away.
        """
        if node.left is None and node.right is None:
            current_list.append(node)
            return current_list
        if node.left.count is None and node.right.count is None:
            current_list.append(node)
            return current_list
        if node.left is not None and node.left.count is not None:
            current_list = self.get_node(current_list, node.left)
        current_list.append(node)
        if node.right is not None and node.right.count is not None:
            current_list = self.get_node(current_list, node.right)
        return current_list

    def construct_check_list(self):
        """Return the live nodes sorted by ``(weight, min)``.

        Every node has a distinct interval, so this key yields a total order.
        The order matches the previous O(n^2) pairwise-exchange sort.
        """
        return sorted(self.get_node([], self.root), key=lambda node: (node.weight, node.min))

    def merge(self, qd):
        """Add the node counts of ``qd`` (same ``depth``) into this digest.

        A ``None`` count in ``qd`` means that node was folded into an
        ancestor, so it contributes nothing here. A ``None`` count in
        ``self`` is replaced by ``qd``'s count. The cached ``check_list`` is
        dropped so the next query rebuilds it from the merged tree.
        """
        for dep in range(qd.depth):
            for mine, theirs in zip(self.data[dep], qd.data[dep]):
                if theirs.count is None:
                    continue
                if mine.count is None:
                    mine.count = theirs.count
                else:
                    mine.count += theirs.count
        self.count += qd.count
        self.check_list = None

    def get_quantile(self,q,type='quantile'):
        """Return the upper edge (``weight``) of the node where the cumulative count reaches the target.

        With ``type='quantile'`` the target rank is ``q * self.count``, and
        ``q`` must be in [0, 1]; otherwise "error" is printed and None is
        returned. With any other ``type``, ``q`` itself is the target rank.
        If the counts run out first, the remaining positive rank is returned.
        """
        if type=='quantile':
            if q<0 or q>1:
                print("error")
                return None
            rank=q*self.count
        else:
            rank=q

        if self.check_list is None:
            self.check_list = self.construct_check_list()
        for node in self.check_list:
            rank-=node.count
            if rank<=0:
                return node.weight
        return rank

    def get_rank(self,weight):
        """Return the total count of live nodes whose upper edge is <= ``weight``.

        ``weight`` must be in [0, 1]; otherwise "error" is printed and None is
        returned.
        """
        if weight<0 or weight>1:
            print("error")
            return None
        if self.check_list is None:
            self.check_list = self.construct_check_list()
        return sum(node.count for node in self.check_list if node.weight<=weight)

    @staticmethod
    def _has_no_live_children(node):
        """Return True if ``node`` is a leaf or both its children were folded away."""
        return node.left is None or (node.left.count is None and node.right.count is None)

    def compress(self,compress_factor=100):
        """Fold small sibling pairs into their parent, bottom-up, then rebuild ``check_list``.

        A node absorbs its two children when
        ``node.count + left.count + right.count < int(self.count / compress_factor)``
        and neither child still has live descendants, so no count is lost.
        The absorbed children get ``count = None`` to simulate deleting them;
        a real implementation would delete the nodes.

        The queue holds the tree's own nodes. The previous
        ``copy.deepcopy(self.data[depth-2])`` duplicated the whole linked tree
        and compressed the copy, leaving ``self`` untouched.
        """
        compress_bound = int(self.count / compress_factor)
        queue = list(self.data.get(self.depth - 2, []))
        queued = {id(node) for node in queue}
        head = 0
        # FIFO over a list with a read pointer: each level is handled before its parents.
        while head < len(queue):
            node = queue[head]
            head += 1
            left, right = node.left, node.right
            if left.count is None and right.count is None:
                # Already folded by an earlier compress(); its parent may now qualify.
                merged = True
            elif (self._has_no_live_children(left) and self._has_no_live_children(right)
                  and node.count + left.count + right.count < compress_bound):
                node.count += left.count + right.count
                left.count = None
                right.count = None
                merged = True
            else:
                merged = False
            parent = node.back
            if merged and parent is not None and id(parent) not in queued:
                queue.append(parent)
                queued.add(id(parent))

        self.check_list = self.construct_check_list()

def Fed_Fair(net,dataset_train,dataset_test,client_dict,Sensitive_attribute_train,Sensitive_attribute_test,zeta,alpha,random_seed,args,depth=7,k=100,metric='DEOO',model='Select_1'):
    """Choose group-specific decision thresholds for ``net`` from federated score digests.

    Samples are split into groups '00', '01', '10', '11' ('<label><sensitive
    attribute>'). The score of a sample is ``net.pred_prob(x)[1]``, presumably
    the positive-class probability (TODO confirm).

    For each client, the training scores of every group are summarised with
    ``qdigest_simu``, compressed with factor ``k`` and merged into one global
    digest per group.

    Candidate pairs ``[rank_0, rank_1]`` are then swept. These are 1-based
    ranks into the label-1 scores of attribute 0 and attribute 1. A pair is
    feasible when the Monte-Carlo estimate ``func`` is ``<= zeta``. Among the
    feasible pairs, the one with the smallest ``cal_error`` is selected. If
    the sweep chosen by ``model`` finds no feasible pair, the other sweep is
    tried. The selected ranks are turned into thresholds using the sorted
    training scores and evaluated on train and test with ``cal_DEOO``.

    Args:
        net: model exposing ``eval()`` and ``pred_prob(x)``.
        dataset_train, dataset_test: indexable; each item is ``(data, label)``.
        client_dict: maps client index ``0..args.num_users-1`` to training
            sample indices.
        Sensitive_attribute_train, Sensitive_attribute_test: sensitive
            attribute, indexed like the corresponding dataset.
        zeta: upper bound on ``func``'s estimate for a feasible pair.
        alpha: gap level passed to ``func``.
        random_seed: NumPy global seed. It is set here and again on every
            ``func`` call.
        args: needs ``num_users``. ``args.gpu == -1`` keeps inputs on the CPU.
            Without a ``gpu`` attribute, inputs are moved with ``.cuda()``.
        depth: ``qdigest_simu`` depth, also used in the rank slack.
        k: compression factor, also used in the rank slack
            ``int(n*(depth-1)/k)``.
        metric: 'DEOO' (label-1 groups only) or 'DEO' (label-1 and label-0).
        model: 'Select_0' (sweep rank_0 first) or 'Select_1' (sweep rank_1
            first). Any other value finds no solution.

    Returns:
        ``(indicator_df, best_K, min_error, K)``: the train/test indicator
        table, the selected pair, its ``cal_error``, and all feasible pairs.
        Returns ``(None, None, None, None)`` when no feasible pair exists.

    Raises:
        ValueError: if ``metric`` is neither 'DEOO' nor 'DEO' and a candidate
            has to be checked.
    """
    np.random.seed(random_seed)
    groups = ('00', '01', '10', '11')  # '<label><sensitive attribute>'

    # Inputs go to the GPU unless args.gpu == -1 (matching cal_DEOO). An args
    # object without `gpu` keeps the previous always-CUDA behaviour.
    use_cuda = getattr(args, 'gpu', 0) != -1

    def score_of(x):
        """Return ``net.pred_prob(x)[1]`` as a NumPy value."""
        if use_cuda:
            x = x.cuda()
        return net.pred_prob(x)[1].cpu().detach().numpy()

    def group_of(label, attribute):
        """Return the group key of a (label, attribute) pair, or None if it is not 0/1 x 0/1."""
        for y in (0, 1):
            for s in (0, 1):
                if label == y and attribute == s:
                    return f'{y}{s}'
        return None

    # ---- per-client digests and global merge --------------------------------
    client_q = {g: {} for g in groups}       # group -> client idx -> digest
    total_scores = {g: [] for g in groups}   # exact scores, for thresholds/report
    total_q = {g: qdigest_simu(depth, []) for g in groups}
    data_num = {g: [] for g in groups}       # group -> per-client sample count

    net.eval()

    for idx in range(args.num_users):
        local_scores = {g: [] for g in groups}
        for sample_idx in client_dict[idx]:
            data, label = dataset_train[sample_idx]
            g = group_of(label, Sensitive_attribute_train[sample_idx])
            if g is not None:
                local_scores[g].append(score_of(data))

        for g in groups:
            total_scores[g] += local_scores[g]
            data_num[g].append(len(local_scores[g]))
            client_q[g][idx] = qdigest_simu(depth, np.array(local_scores[g]))
        for g in groups:
            client_q[g][idx].compress(k)
        for g in groups:
            total_q[g].merge(client_q[g][idx])

    for g in groups:
        total_q[g].check_list = total_q[g].construct_check_list()
        sorted_scores = np.array(total_scores[g])
        sorted_scores.sort()
        total_scores[g] = sorted_scores
    num = {g: total_q[g].count for g in groups}
    total_score_10 = total_scores['10']
    total_score_11 = total_scores['11']

    # ---- feasibility estimate -----------------------------------------------
    def add_gap(acc, r_top, top, r_bot, bot, idx, n):
        """Add one client's sampled rate-gap contribution to ``acc`` in place.

        The ``top`` group's share is a Beta draw (or its upper bound 1 when
        the slackened rank covers the whole client), weighted by the client's
        share of that group. From it, a Beta draw for the ``bot`` group is
        subtracted when its local rank is non-zero.
        """
        n_top = data_num[top][idx]
        eps_top = int(n_top*(depth-1)/k)
        if r_top+eps_top >= n_top:
            acc += np.ones(n)*n_top/num[top]
        else:
            acc += np.random.beta(r_top+1+eps_top, n_top-r_top-eps_top, n)*n_top/num[top]
        if r_bot != 0:
            n_bot = data_num[bot][idx]
            acc -= np.random.beta(r_bot, n_bot-r_bot+1, n)*n_bot/num[bot]

    def func(r0,r1,alpha=0.04,n=1000,random_seed=0,metric='DEOO'):
        """Estimate how often the sampled rate gaps reach ``alpha``.

        ``n`` samples are drawn per accumulator. The returned value is the sum,
        over the checked accumulators, of the fraction of samples ``>= alpha``.
        'DEOO' checks both directions for the label-1 groups; 'DEO' also checks
        both directions for the label-0 groups.
        """
        np.random.seed(random_seed)
        if metric not in ('DEOO', 'DEO'):
            raise ValueError(f"unsupported metric: {metric!r}")
        th_0 = total_q['10'].get_quantile(r0, type='rank')
        th_1 = total_q['11'].get_quantile(r1, type='rank')
        sum_above = np.zeros(n)
        sum_below = np.zeros(n)
        sum_above_0 = np.zeros(n)
        sum_below_0 = np.zeros(n)
        for idx in range(args.num_users):
            local_r0 = client_q['10'][idx].get_rank(th_0)
            local_r1 = client_q['11'][idx].get_rank(th_1)
            # Draw order matches the original implementation (reproducibility).
            add_gap(sum_above, local_r1, '11', local_r0, '10', idx, n)
            add_gap(sum_below, local_r0, '10', local_r1, '11', idx, n)
            if metric == 'DEO':
                local_r00 = client_q['00'][idx].get_rank(th_0)
                local_r01 = client_q['01'][idx].get_rank(th_1)
                add_gap(sum_above_0, local_r01, '01', local_r00, '00', idx, n)
                add_gap(sum_below_0, local_r00, '00', local_r01, '01', idx, n)

        gaps = [sum_above, sum_below]
        if metric == 'DEO':
            gaps += [sum_above_0, sum_below_0]
        return sum((gap>=alpha).astype(int).sum()/n for gap in gaps)

    # ---- candidate generation and objective ---------------------------------
    def cross_rank(r, src, dst):
        """Map a rank in group ``src`` to the matching rank in group ``dst`` (never 0)."""
        th = total_q[src].get_quantile(r, type='rank')
        threshold = num[dst]/(2*num[dst]-(1/th-2)*num[src])
        if threshold < 0:
            threshold = 0
        elif threshold > 1:
            threshold = 1
        rank = total_q[dst].get_rank(threshold)
        return 1 if rank == 0 else rank

    def trans(r0):
        return cross_rank(r0, '10', '11')

    def trans_0(r1):
        return cross_rank(r1, '11', '10')

    def cal_error(r0,r1):
        """Smoothed expected number of misclassified training samples for a rank pair."""
        th_0 = total_q['10'].get_quantile(r0, type='rank')
        th_1 = total_q['11'].get_quantile(r1, type='rank')
        error = 0
        for idx in range(args.num_users):
            local_r0 = client_q['10'][idx].get_rank(th_0)
            local_r1 = client_q['11'][idx].get_rank(th_1)
            local_r00 = client_q['00'][idx].get_rank(th_0)
            local_r01 = client_q['01'][idx].get_rank(th_1)
            n_00, n_01, n_10, n_11 = (data_num[g][idx] for g in groups)
            error += ((local_r0+0.5)/(n_10+1)*n_10
                      +(local_r1+0.5)/(n_11+1)*n_11
                      +(n_00+0.5-local_r00)/(n_00+1)*n_00
                      +(n_01+0.5-local_r01)/(n_01+1)*n_01)
        return error

    def search(candidates):
        """Return (feasible pairs, best pair, best error) over ``candidates``."""
        feasible, best, best_error = [], None, float('inf')
        for rank_0, rank_1 in candidates:
            if func(rank_0, rank_1, alpha=alpha, random_seed=random_seed, metric=metric) <= zeta:
                feasible.append([rank_0, rank_1])
                error = cal_error(rank_0, rank_1)
                if error < best_error:
                    best_error = error
                    best = [rank_0, rank_1]
        return feasible, best, best_error

    def sweep_rank_0():
        for rank_0 in range(1, num['10']+1):
            yield rank_0, trans(rank_0)

    def sweep_rank_1():
        for rank_1 in range(1, num['11']+1):
            yield trans_0(rank_1), rank_1

    if model == 'Select_0':
        primary, fallback = sweep_rank_0, sweep_rank_1
    elif model == 'Select_1':
        primary, fallback = sweep_rank_1, sweep_rank_0
    else:
        primary = fallback = None

    # An infinite sentinel (instead of a fixed 100000) guarantees that best_K is
    # set whenever K is non-empty, however large the error values are.
    if primary is not None:
        K, best_K, min_error = search(primary())
    else:
        K, best_K, min_error = [], None, float('inf')
    if not K:
        print('no solution')
        if fallback is not None:
            K, best_K, min_error = search(fallback())
        if not K:
            print('No solution for both')
            return None,None,None,None

    # ---- evaluation ---------------------------------------------------------
    test_scores = {g: [] for g in groups}
    for sample_idx in range(len(dataset_test)):
        data, label = dataset_test[sample_idx]
        g = group_of(label, Sensitive_attribute_test[sample_idx])
        if g is not None:
            test_scores[g].append(score_of(data))
    test_scores = {g: np.array(test_scores[g]) for g in groups}
    test_score_10 = test_scores['10']
    test_score_11 = test_scores['11']

    # Ranks are 1-based positions in the sorted exact training scores.
    th_0 = total_score_10[best_K[0]-1]
    th_1 = total_score_11[best_K[1]-1]
    pred_10 = np.sum(test_score_10>=th_0)
    pred_11 = np.sum(test_score_11>=th_1)
    print('len test_score_10',len(test_score_10))
    print('len test_score_11',len(test_score_11))
    print('pred 10',pred_10)
    print('pred 11',pred_11)
    print('DEOO',pred_11/len(test_score_11)-pred_10/len(test_score_10))
    print('K',best_K)

    indicator_df=pd.DataFrame(index=['train','test'],columns=['accuracy','DEOO','DPE','after accuracy','after DEOO','after DPE'])
    splits = (('train', total_scores, Sensitive_attribute_train),
              ('test', test_scores, Sensitive_attribute_test))
    # First the default 0.5 thresholds, then the selected ones ("after ...").
    for prefix, thresholds in (('', {}), ('after ', {'th_0': th_0, 'th_1': th_1})):
        for split, scores, sensitive in splits:
            fair, indicator_dict = cal_DEOO(net, scores, args, sensitive, type='score', **thresholds)
            indicator_df.loc[split, prefix+'DEOO'] = fair
            indicator_df.loc[split, prefix+'accuracy'] = indicator_dict['accuracy']
            indicator_df.loc[split, prefix+'DPE'] = indicator_dict['DPE']

    return indicator_df,best_K,min_error,K

def _group_indicators(n, acc):
    """Build cal_DEOO's ``(fairness, indicator_dict)`` from per-group totals.

    ``n`` and ``acc`` map the group keys '00', '01', '10', '11'
    ('<label><sensitive attribute>') to the sample count and the number of
    correctly classified samples. Raises ZeroDivisionError if a group is empty.
    """
    fairness = abs(acc['11']/n['11']-acc['10']/n['10'])
    indicator_dict = {
        'acc_00': acc['00']/n['00'],
        'acc_01': acc['01']/n['01'],
        'acc_10': acc['10']/n['10'],
        'acc_11': acc['11']/n['11'],
        'n_00': n['00'],
        'n_01': n['01'],
        'n_10': n['10'],
        'n_11': n['11'],
        'fairness': fairness,
        'accuracy': (acc['00']+acc['01']+acc['10']+acc['11'])/(n['00']+n['01']+n['10']+n['11']),
        'DPE': abs(acc['00']/n['00']-acc['01']/n['01']),
    }
    return fairness, indicator_dict


def cal_DEOO(net_g,dataset,args,sensitive,th_0=None,th_1=None,type='data'):
    """Evaluate per-attribute thresholds: equal-opportunity gap, accuracy and DPE.

    Attribute-0 samples are thresholded with ``th_0`` and attribute-1 samples
    with ``th_1`` (both default to 0.5). A label-1 sample is correct when its
    score is ``>=`` the threshold; a label-0 sample when it is ``<`` it.

    Args:
        net_g: model with ``eval()`` and ``pred_prob(x)``; used only when
            ``type == 'data'``.
        dataset: for ``type == 'data'``, an indexable of ``(data, target)``
            pairs (targets other than 1 count as label 0). For any other
            ``type``, a dict mapping '00', '01', '10', '11' to iterables of
            scores.
        args: ``type == 'data'`` only. ``args.gpu == -1`` keeps inputs on the
            CPU; otherwise they are moved with ``.cuda()``.
        sensitive: ``type == 'data'`` only. Sensitive attribute indexed like
            ``dataset`` (values other than 1 count as attribute 0).
        th_0, th_1: decision thresholds for attribute 0 / attribute 1.
        type: 'data' to score ``dataset`` with ``net_g``; anything else means
            ``dataset`` already holds scores.

    Returns:
        ``(fairness, indicator_dict)``, where ``fairness = |acc_11 - acc_10|``
        (per-group correct rates). ``indicator_dict`` also holds per-group
        rates/counts, overall ``accuracy`` and ``DPE = |acc_00 - acc_01|``.

    Raises:
        ZeroDivisionError: if any of the four groups is empty.
    """
    if th_0 is None:
        th_0 = 0.5
    if th_1 is None:
        th_1 = 0.5
    groups = ('00', '01', '10', '11')

    if type=='data':
        net_g.eval()
        n = dict.fromkeys(groups, 0)
        acc = dict.fromkeys(groups, 0)
        for idx in range(len(dataset)):
            data, target = dataset[idx]
            # Respect args.gpu: the old code called .cuda() unconditionally,
            # which breaks CPU-only runs.
            if args.gpu != -1:
                data = data.cuda()
            # NOTE(review): assumes pred_prob(x)[1] is indexable and its [1]
            # entry is the positive-class score. Fed_Fair instead uses
            # pred_prob(x)[1] itself as the score; confirm which is right.
            score = net_g.pred_prob(data)[1][1]
            attr = '1' if sensitive[idx] == 1 else '0'
            th = th_1 if attr == '1' else th_0
            if target == 1:
                key, correct = '1'+attr, score >= th
            else:
                key, correct = '0'+attr, score < th
            n[key] += 1
            if correct:
                acc[key] += 1
    else:
        n = {g: len(dataset[g]) for g in groups}
        acc = {
            '00': sum(1 for s in dataset['00'] if s < th_0),
            '01': sum(1 for s in dataset['01'] if s < th_1),
            '10': sum(1 for s in dataset['10'] if s >= th_0),
            '11': sum(1 for s in dataset['11'] if s >= th_1),
        }

    return _group_indicators(n, acc)

