import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
import torch
import copy





def data_processing(data_name, num=100, train_rate=0.8,random_seed=0):
    """Load a dataset and split it into per-client train/test shards.

    Only ``data_name == 'adult'`` is handled. For any other name the function
    falls off the end and implicitly returns ``None``.

    For ``'adult'`` the function reads ``adult.data`` and ``adult.test`` from a
    hard-coded directory, drops every row containing ``'?'``, concatenates the
    two files, derives a binary label from ``income`` and a binary sensitive
    attribute from ``sex``, one-hot encodes the categorical columns, randomly
    partitions the rows into ``num`` clients, splits every client's rows into
    train/test parts, and standardises the features with statistics fitted on
    the training rows only.

    Args:
        data_name: Dataset identifier; only ``'adult'`` is recognised.
        num: Number of client partitions.
        train_rate: Fraction of each client's rows kept for training
            (``1 - train_rate`` is passed to ``train_test_split`` as
            ``test_size``).
        random_seed: Seed for the global NumPy RNG (client shuffle) and for
            ``train_test_split``.

    Returns:
        A 6-tuple ``(dataset_train, dataset_test, dict_users, test_dict_users,
        Sensitive_attribute_train, Sensitive_attribute_test)`` where

        * ``dataset_train`` / ``dataset_test`` are lists of
          ``(feature_tensor, label)`` tuples; features are ``float32`` tensors,
          labels are left as the 0/1 values taken from the pandas Series;
        * ``dict_users`` / ``test_dict_users`` map a client index to the
          positions (``np.arange``) of that client's samples inside
          ``dataset_train`` / ``dataset_test``;
        * the two sensitive-attribute arrays are aligned position-by-position
          with ``dataset_train`` / ``dataset_test``.
    """
    if data_name == 'adult':
        # NOTE(review): absolute, machine-specific path; the function cannot run
        # elsewhere without editing this line.
        path='C:/Users/yinqi/Desktop/FedFaiREE/'
        # NOTE(review): `sep` is two characters long; per the pandas docs such
        # separators are handled by the python parsing engine — confirm whether
        # passing engine='python' explicitly is wanted to silence the warning.
        data=pd.read_csv(path+'data/adult/adult.data',header=None,sep=', ')
        test=pd.read_csv(path+'data/adult/adult.test',header=None,sep=', ')
        # Drop every row that has the placeholder '?' in any column.
        for col in data.columns:
            data=data[data[col]!='?']
        print(data.shape)
        for col in test.columns:
            test=test[test[col]!='?']
        print(test.shape)
        # Pool both files; the train/test split is redone per client below.
        data=pd.concat([data,test])
        # Re-number rows 0..len-1 so index labels are unique after the concat.
        data.index=range(len(data))
        print(data.shape)
        column_list=['age','workclass','fnlwgt','education','education-num','marital-status','occupation','relationship','race','sex','capital-gain','capital-loss','hours-per-week','native-country','income']
        # This list is reassigned further down (with 'sex' added) before it is
        # ever read.
        column_cate=['workclass','marital-status','occupation','relationship','race','native-country']
        data.columns=column_list
        data.drop(['education'],axis=1,inplace=True)
        Y=data['income'].copy()
        Sensitive_attribute=data['sex'].copy()
        # Binary label: 0 when the income string starts with '<=50K', else 1.
        # Only the first five characters are compared, so any trailing
        # characters after '<=50K' are ignored.
        for Y_i in Y.index:
            if Y[Y_i][:5]=='<=50K':
                Y[Y_i]=0
            else:
                Y[Y_i]=1

        # Binary sensitive attribute: 1 for 'Male', 0 for anything else.
        for S_i in Sensitive_attribute.index:
            if Sensitive_attribute[S_i]=='Male':
                Sensitive_attribute[S_i]=1
            else:
                Sensitive_attribute[S_i]=0
        # Seed the global NumPy RNG so the shuffle in dataSplit_by_random is
        # reproducible.
        np.random.seed(random_seed)
        # Features are every remaining column except the label; note that 'sex'
        # stays in X and is one-hot encoded below.
        X=data.drop(['income'],axis=1)

        def dataSplit_by_random(data, n):
            # Shuffle the row labels of `data` with the global NumPy RNG and cut
            # them into n nearly equal parts; returns a list of n index arrays.
            index_list = []
            index = data.index
            index = np.array(index)
            np.random.shuffle(index)
            index_list = np.array_split(index, n)
            return index_list
        # `label` is only referenced by the commented-out, label-based split
        # below; the active code splits purely at random.
        label='workclass'
        #data_list, index_list = dataSplit(data, label)
        index_list=dataSplit_by_random(data,num)
        #X=X.drop([label],axis=1)
        column_cate=['workclass','marital-status','sex','occupation','relationship','race','native-country']
        #column_cate.remove(label)
        # One-hot encode the categorical columns (including 'sex').
        X=pd.get_dummies(X,columns=column_cate)
        # client_dataset and client_Sensitive_attribute are never read.
        client_dataset=[]
        client_Sensitive_attribute=[]
        train_idx_list=[]
        test_idx_list=[]
        dict_users={}
        test_dict_users={}
        # Running offsets into the concatenated train / test index lists.
        count=0
        test_count=0
        np.random.seed(random_seed)
        for idx in range(len(index_list)):
            # Split this client's row labels into a train and a test part.
            idx_train, idx_test = train_test_split(index_list[idx], test_size = 1-train_rate, random_state = random_seed)
            train_idx_list+=list(idx_train)
            test_idx_list+=list(idx_test)
            # Each client owns a contiguous block of positions in the
            # concatenated train (resp. test) list.
            dict_users[idx]=np.arange(count,count+len(idx_train))
            count+=len(idx_train)
            test_dict_users[idx]=np.arange(test_count,test_count+len(idx_test))
            test_count+=len(idx_test)
            #print(len(idx_train),len(idx_test))

        # Materialise features / labels / sensitive attribute in the same order
        # as train_idx_list and test_idx_list, so positions line up with
        # dict_users / test_dict_users.
        X_test=X.loc[test_idx_list].values
        Y_test=Y.loc[test_idx_list].values
        Sensitive_attribute_test=Sensitive_attribute.loc[test_idx_list].values
        X_train=X.loc[train_idx_list].values
        Y_train=Y.loc[train_idx_list].values
        Sensitive_attribute_train=Sensitive_attribute.loc[train_idx_list].values
        # Standardise with statistics fitted on the training rows only, then
        # apply the same transform to the test rows.
        scaler=StandardScaler()
        X_train=scaler.fit_transform(X_train)
        X_test=scaler.transform(X_test)
        # Features become float32 tensors; the labels are not converted.
        X_train=torch.tensor(X_train,dtype=torch.float32)
        #Y_train=torch.tensor(Y_train.values,dtype=torch.long)
        X_test=torch.tensor(X_test,dtype=torch.float32)
        #Y_test=torch.tensor(Y_test.values,dtype=torch.long)

        # Pair every feature row with its label as a (features, label) tuple.
        dataset_train=[]
        for idx in range(len(X_train)):
            dataset_train.append(tuple([X_train[idx],Y_train[idx]]))
        dataset_test=[]
        for idx in range(len(X_test)):
            dataset_test.append(tuple([X_test[idx],Y_test[idx]]))

        return dataset_train,dataset_test,dict_users,test_dict_users,Sensitive_attribute_train,Sensitive_attribute_test
    


class qdigestNode():
    """One node of the binary tree used by the simulated q-digest.

    Attributes:
        weight: first constructor argument; also the node's string form.
        min: second constructor argument (``min_w``).
        left, right: child nodes, or ``None``.
        back: parent node, or ``None``.
        count: counter attached to the node, starts at 0.
    """

    def __init__(self, weight, min_w, back=None, left=None, right=None):
        self.weight, self.min = weight, min_w
        self.left, self.right, self.back = left, right, back
        self.count = 0

    def __str__(self):
        # A node prints as its `weight`.
        return str(self.weight)

    # repr() and str() are intentionally the same text.
    __repr__ = __str__
        
class qdigest_simu():
    """Simulated q-digest over values in [0, 1], stored as a complete binary tree.

    ``self.data[i]`` is the list of the ``2**i`` nodes on level ``i``. Node ``j``
    of level ``i`` keeps the lower bound ``j/2**i`` in ``min`` and the upper
    bound ``(j+1)/2**i`` in ``weight``. Input values are counted on the deepest
    level (``depth-1``), whose bin width is ``self.accuracy``.

    A node whose ``count`` is ``None`` stands for a node a real q-digest would
    have deleted; traversals skip such a node together with its subtree.
    ``self.count`` is the total number of values represented, and
    ``self.check_list`` caches the live nodes sorted by ``(weight, min)``.
    """

    def __init__(self, depth, data_list):
        """Build the empty tree and count every value of ``data_list`` in a leaf.

        Args:
            depth: Number of tree levels (at least 1).
            data_list: Sized iterable of values in [0, 1].

        Raises:
            ValueError: If a value lies outside [0, 1]. (A negative value would
                otherwise be turned into a negative list index and silently
                land in a bin at the other end of the range.)
        """
        self.data = {}
        self.depth = depth
        self.range = 2**(depth-1)
        self.accuracy = 1/self.range
        self.count = len(data_list)
        self.root = qdigestNode(1, 0)
        self.data[0] = [self.root]
        for level in range(1, depth):
            width = 2**level
            parents = self.data[level-1]
            nodes = []
            for j in range(width):
                parent = parents[j//2]
                node = qdigestNode((j+1)/width, j/width, back=parent)
                # Even positions are left children, odd ones right children.
                if j % 2 == 0:
                    parent.left = node
                else:
                    parent.right = node
                nodes.append(node)
            self.data[level] = nodes
        leaves = self.data[depth-1]
        for value in data_list:
            if not 0 <= value <= 1:
                raise ValueError("qdigest_simu only accepts values in [0, 1], got %r" % (value,))
            # The small offset sends a value sitting exactly on a bin's upper
            # bound into that bin instead of the next one.
            leaves[int(value/self.accuracy-1e-10)].count += 1
        # Built on demand, by compress() or by the first rank/quantile query.
        self.check_list = None

    @staticmethod
    def _has_live_children(node):
        """Return True if at least one child of ``node`` exists and is not deleted."""
        return any(child is not None and child.count is not None
                   for child in (node.left, node.right))

    def _ensure_check_list(self):
        """Build ``check_list`` when it has not been built (or was invalidated)."""
        if self.check_list is None:
            self.check_list = self.construct_check_list()
        return self.check_list

    def get_node(self, current_list, node):
        """Append the live nodes below ``node`` to ``current_list`` in in-order.

        ``node`` itself is always appended; a child is visited only when it
        exists and its count is not ``None``. Returns ``current_list``.
        """
        left, right = node.left, node.right
        if left is not None and left.count is not None:
            self.get_node(current_list, left)
        current_list.append(node)
        if right is not None and right.count is not None:
            self.get_node(current_list, right)
        return current_list

    def construct_check_list(self):
        """Return the live nodes sorted by upper bound, ties broken by lower bound.

        The pair ``(weight, min)`` identifies a node uniquely, so the order is
        fully determined.
        """
        node_list = self.get_node([], self.root)
        node_list.sort(key=lambda node: (node.weight, node.min))
        return node_list

    def merge(self, qd):
        """Add the counts of digest ``qd`` to this digest, node by node.

        Nodes deleted in ``qd`` (count ``None``) contribute nothing; a node
        deleted here is revived when ``qd`` holds a count for it. The cached
        ``check_list`` is dropped because the set of live nodes may change;
        it is rebuilt on the next query or by an explicit
        ``construct_check_list()`` call.
        """
        for dep in range(qd.depth):
            for mine, theirs in zip(self.data[dep], qd.data[dep]):
                if theirs.count is None:
                    continue
                if mine.count is None:
                    mine.count = theirs.count
                else:
                    mine.count += theirs.count
        self.count += qd.count
        self.check_list = None

    def get_quantile(self, q, type='quantile'):
        """Return the upper bound of the node at which the cumulative count reaches a rank.

        Args:
            q: A fraction in [0, 1] when ``type == 'quantile'`` (converted to
                the rank ``q * self.count``); otherwise ``q`` is the rank itself.
            type: ``'quantile'`` or anything else for a raw rank.

        Returns:
            The ``weight`` of the first node in ``check_list`` order whose
            cumulative count reaches the rank. If the rank exceeds the total
            count, the largest upper bound is returned. ``None`` (after
            printing ``"error"``) when a quantile is outside [0, 1].
        """
        if type == 'quantile':
            if q < 0 or q > 1:
                print("error")
                return None
            rank = q*self.count
        else:
            rank = q

        check_list = self._ensure_check_list()
        for node in check_list:
            rank -= node.count
            if rank <= 0:
                return node.weight
        # Rank beyond everything stored: saturate at the top of the range
        # rather than handing the leftover rank back as if it were a value.
        return check_list[-1].weight

    def get_rank(self, weight):
        """Return the summed count of all live nodes whose upper bound is <= ``weight``.

        Returns ``None`` (after printing ``"error"``) when ``weight`` is
        outside [0, 1].
        """
        if weight < 0 or weight > 1:
            print("error")
            return None
        return sum(node.count for node in self._ensure_check_list()
                   if node.weight <= weight)

    def compress(self, compress_factor=100):
        """Fold sparsely populated sibling pairs into their parent, bottom-up.

        A parent absorbs both children when
        ``parent + left + right < int(self.count / compress_factor)`` and
        neither child has live children of its own, so no count stored deeper
        in the tree becomes unreachable. Absorbed children get ``count = None``
        (this is a simulation; a real implementation would delete them).
        Finishes by rebuilding ``check_list``.

        The fold is applied to the live nodes themselves. Working on a
        ``copy.deepcopy`` of the node list would clone the whole linked tree
        (through the ``back``/``left``/``right`` references) and leave this
        digest unchanged.
        """
        compress_bound = int(self.count/compress_factor)
        for level in range(self.depth-2, -1, -1):
            for node in self.data[level]:
                left, right = node.left, node.right
                if node.count is None or left.count is None or right.count is None:
                    continue
                if self._has_live_children(left) or self._has_live_children(right):
                    continue
                folded = node.count+left.count+right.count
                if folded < compress_bound:
                    node.count = folded
                    left.count = None
                    right.count = None

        self.check_list = self.construct_check_list()

def Fed_FaiREE(net,dataset_train,dataset_test,client_dict,Sensitive_attribute_train,Sensitive_attribute_test,zeta,alpha,random_seed,args,depth=7,k=100):
    """Pick one score threshold per sensitive group from client digests and report metrics.

    Procedure:

    1. Every training sample of every client is scored with
       ``net.pred_prob(x)[1]`` and put into one of four groups keyed by
       ``(label, sensitive attribute)``. Each client's group is summarised by a
       ``qdigest_simu`` of the given ``depth`` and compressed with factor ``k``.
    2. The label-1 digests are merged across clients.
    3. For every rank ``r0`` of the merged ``(1, 0)`` digest a partner rank
       ``r1`` in the ``(1, 1)`` digest is derived (``trans``). The pair is kept
       when the sampled estimate ``func(r0, r1, alpha)`` is at most ``zeta``;
       among the kept pairs the one with the smallest ``cal_error`` wins.
    4. The winning ranks are turned into thresholds by indexing the pooled,
       sorted training scores, and ``cal_DEOO`` is evaluated on the train and
       test scores with the default and with the selected thresholds.

    Args:
        net: Model exposing ``eval()`` and ``pred_prob(x)``.
        dataset_train, dataset_test: Indexable collections of ``(x, label)``.
        client_dict: Maps a client index to its sample positions in
            ``dataset_train``.
        Sensitive_attribute_train, Sensitive_attribute_test: 0/1 attribute per
            sample, aligned with the datasets.
        zeta: Upper bound a candidate's ``func`` value must not exceed.
        alpha: Level passed to ``func``.
        random_seed: Seed for the global NumPy RNG.
        args: Needs ``num_users``; also forwarded to ``cal_DEOO``.
        depth: Depth of every digest.
        k: Compression factor, also used for the rank slack inside ``func``.

    Returns:
        ``(indicator_df, best_K, min_error, K)``: the metrics table (rows
        ``train``/``test``), the selected ``[rank_0, rank_1]``, its error, and
        the list of all feasible rank pairs.

    Raises:
        ValueError: If a positive-label group has no training sample, or if no
            rank pair satisfies the ``zeta`` constraint.
    """
    np.random.seed( random_seed )
    groups=('00','01','10','11')

    def group_of(label,attribute):
        # Group key "<label><attribute>", or None for any other combination.
        if label==0 and attribute==0:
            return '00'
        elif label==0 and attribute==1:
            return '01'
        elif label==1 and attribute==0:
            return '10'
        elif label==1 and attribute==1:
            return '11'
        return None

    def positive_score(sample):
        # NOTE(review): the sample is moved to CUDA unconditionally, unlike
        # cal_DEOO which honours args.gpu — confirm before running on CPU.
        return net.pred_prob(sample.cuda())[1].cpu().detach().numpy()

    client_score={g:{} for g in groups}   # group -> client index -> digest
    total_score={g:[] for g in groups}    # group -> pooled raw scores
    data_num={g:[] for g in groups}       # group -> per-client sample counts
    # Only the label-1 groups are ever queried globally.
    total_q={g:qdigest_simu(depth,[]) for g in ('10','11')}

    net.eval()

    for idx in range(args.num_users):
        local_score={g:[] for g in groups}
        for sample_idx in client_dict[idx]:
            data, label = dataset_train[sample_idx]
            g=group_of(label,Sensitive_attribute_train[sample_idx])
            if g is not None:
                local_score[g].append(positive_score(data))

        for g in groups:
            total_score[g]+=local_score[g]
            data_num[g].append(len(local_score[g]))
            digest=qdigest_simu(depth,np.array(local_score[g]))
            digest.compress(k)
            client_score[g][idx]=digest
        for g in total_q:
            total_q[g].merge(client_score[g][idx])

    for g in total_q:
        total_q[g].check_list=total_q[g].construct_check_list()
    score_train={g:np.sort(np.array(total_score[g])) for g in groups}
    total_score_10=score_train['10']
    total_score_11=score_train['11']

    num_10=total_q['10'].count
    num_11=total_q['11'].count
    if num_10==0 or num_11==0:
        raise ValueError('Fed_FaiREE needs at least one training sample with label 1 in each sensitive group')

    def func(r0,r1,alpha=0.04,n=1000,random_seed=0):
        # Sampled check of a rank pair: draws n Beta samples per client,
        # combines them with weights (client count / group count) and returns
        # the fraction of draws with sum_above >= alpha plus the fraction with
        # sum_below >= alpha. The global RNG is re-seeded on every call, so
        # all candidates are judged on the same random stream.
        np.random.seed(random_seed)
        th_0=total_q['10'].get_quantile(r0,type='rank')
        th_1=total_q['11'].get_quantile(r1,type='rank')
        sum_above=np.zeros(n)
        sum_below=np.zeros(n)
        for idx in range(args.num_users):
            cnt_10=data_num['10'][idx]
            cnt_11=data_num['11'][idx]
            local_r0=client_score['10'][idx].get_rank(th_0)
            local_r1=client_score['11'][idx].get_rank(th_1)
            # Rank slack granted to each client digest.
            epsi_11=int(cnt_11*(depth-1)/k)
            epsi_10=int(cnt_10*(depth-1)/k)
            if local_r1+epsi_11>=cnt_11:
                sum_above+=np.ones(n)*cnt_11/num_11
            else:
                sum_above+=np.random.beta(local_r1+1+epsi_11,cnt_11-local_r1-epsi_11,n)*cnt_11/num_11
            if local_r0!=0:
                sum_above-=np.random.beta(local_r0,cnt_10-local_r0+1,n)*cnt_10/num_10
            if local_r0+epsi_10>=cnt_10:
                sum_below+=np.ones(n)*cnt_10/num_10
            else:
                sum_below+=np.random.beta(local_r0+1+epsi_10,cnt_10-local_r0-epsi_10,n)*cnt_10/num_10
            if local_r1!=0:
                sum_below-=np.random.beta(local_r1,cnt_11-local_r1+1,n)*cnt_11/num_11

        rate_above=(sum_above>=alpha).astype(int).sum()/n
        rate_below=(sum_below>=alpha).astype(int).sum()/n

        return rate_above+rate_below

    def trans(r0):
        # Map a rank of the (1, 0) digest to a rank of the (1, 1) digest via a
        # threshold clipped to [0, 1]; never returns less than 1.
        th_0= total_q['10'].get_quantile(r0,type='rank')
        # NOTE(review): the denominator can be exactly zero for some th_0,
        # which raises ZeroDivisionError — confirm the intended limit value.
        threshold=num_11/(2*num_11-(1/th_0-2)*num_10)
        if threshold<0:
            threshold=0
        elif threshold>1:
            threshold=1
        rank= total_q['11'].get_rank(threshold)
        return 1 if rank==0 else rank

    def cal_error(r0,r1):
        # Error score of a rank pair, summed over clients and the four groups;
        # smaller is better.
        th_0=total_q['10'].get_quantile(r0,type='rank')
        th_1=total_q['11'].get_quantile(r1,type='rank')
        error=0
        for idx in range(args.num_users):
            local_r0=client_score['10'][idx].get_rank(th_0)
            local_r1=client_score['11'][idx].get_rank(th_1)
            local_r00=client_score['00'][idx].get_rank(th_0)
            local_r01=client_score['01'][idx].get_rank(th_1)

            error+=((local_r0+0.5)/(data_num['10'][idx]+1)*data_num['10'][idx] \
                    +(local_r1+0.5)/(data_num['11'][idx]+1)*data_num['11'][idx] \
                    +(data_num['00'][idx]+0.5-local_r00)/(data_num['00'][idx]+1)*data_num['00'][idx] \
                    +(data_num['01'][idx]+0.5-local_r01)/(data_num['01'][idx]+1)*data_num['01'][idx])

        return error

    K=[]
    best_K=None
    # Infinite sentinel: any feasible candidate beats it, whatever the scale
    # of the error.
    min_error=float('inf')
    for rank_0 in range(1,num_10+1):
        rank_1 = trans(rank_0)
        if func(rank_0,rank_1,alpha)<=zeta:
            K.append([rank_0,rank_1])
            error=cal_error(rank_0,rank_1)
            if error<min_error:
                min_error=error
                best_K=[rank_0,rank_1]
    if best_K is None:
        raise ValueError('Fed_FaiREE found no rank pair satisfying the zeta constraint')

    test_score={g:[] for g in groups}
    for sample_idx in range(len(dataset_test)):
        data, label = dataset_test[sample_idx]
        g=group_of(label,Sensitive_attribute_test[sample_idx])
        if g is not None:
            test_score[g].append(positive_score(data))
    score_test={g:np.array(test_score[g]) for g in groups}
    test_score_10=score_test['10']
    test_score_11=score_test['11']

    # Ranks are 1-based; the pooled training scores are sorted ascending.
    th_0=total_score_10[best_K[0]-1]
    th_1=total_score_11[best_K[1]-1]
    pred_10=np.sum(test_score_10>=th_0)
    pred_11=np.sum(test_score_11>=th_1)

    print('len test_score_10',len(test_score_10))
    print('len test_score_11',len(test_score_11))
    print('pred 10',pred_10)
    print('pred 11',pred_11)
    print('DEOO',pred_11/len(test_score_11)-pred_10/len(test_score_10))
    print('K',best_K)

    indicator_df=pd.DataFrame(index=['train','test'],columns=['accuracy','DEOO','DPE','after accuracy','after DEOO','after DPE'])
    splits=(('train',score_train,Sensitive_attribute_train),('test',score_test,Sensitive_attribute_test))
    # Metrics with the default thresholds of cal_DEOO ...
    for split,scores,sensitive in splits:
        fair,indicator_dict=cal_DEOO(net, scores, args,sensitive,type='score')
        indicator_df.loc[split,'DEOO']=fair
        indicator_df.loc[split,'accuracy']=indicator_dict['accuracy']
        indicator_df.loc[split,'DPE']=indicator_dict['DPE']
    # ... and with the selected per-group thresholds.
    for split,scores,sensitive in splits:
        fair,indicator_dict=cal_DEOO(net, scores, args,sensitive,th_0=th_0,th_1=th_1,type='score')
        indicator_df.loc[split,'after DEOO']=fair
        indicator_df.loc[split,'after accuracy']=indicator_dict['accuracy']
        indicator_df.loc[split,'after DPE']=indicator_dict['DPE']

    return indicator_df,best_K,min_error,K

def cal_DEOO(net_g,dataset,args,sensitive,th_0=None,th_1=None,type='data'):
    """Compute per-group correct-decision rates and two between-group gaps.

    Samples fall into four groups keyed ``"<label><attribute>"``. A sample with
    sensitive attribute 1 is judged against ``th_1``, otherwise against
    ``th_0``. A label-1 sample counts as correct when ``score >= threshold``,
    a label-0 sample when ``score < threshold``.

    Args:
        net_g: Model; only used when ``type == 'data'`` (``eval()`` and
            ``forward(x)[0][1]`` provide the score).
        dataset: With ``type == 'data'`` an indexable collection of
            ``(x, target)``; otherwise a dict with keys ``'00'``, ``'01'``,
            ``'10'``, ``'11'`` holding the scores of each group.
        args: Only ``args.gpu`` is read, and only when ``type == 'data'``
            (inputs are moved to CUDA unless it equals -1).
        sensitive: Per-sample 0/1 attribute; only used when ``type == 'data'``.
        th_0, th_1: Thresholds for attribute 0 / 1; ``None`` means 0.5.
        type: ``'data'`` to score samples with the model; any other value
            treats ``dataset`` as precomputed scores.

    Returns:
        ``(fairness, indicator_dict)``. ``fairness`` is
        ``|acc_11 - acc_10|``; ``indicator_dict`` holds the four rates
        ``acc_*``, the four group sizes ``n_*``, ``fairness``, the overall
        ``accuracy`` and ``DPE = |acc_00 - acc_01|``. A rate whose group is
        empty is ``nan`` (and so is every value derived from it) instead of
        raising ``ZeroDivisionError``.
    """
    if th_0 is None:
        th_0 = 0.5
    if th_1 is None:
        th_1 = 0.5

    groups=('00','01','10','11')
    n={g:0 for g in groups}
    acc={g:0 for g in groups}

    if type=='data':

        net_g.eval()

        for idx in range(len(dataset)):
            data, target = dataset[idx]
            if args.gpu != -1:
                data= data.cuda()
            score = net_g.forward(data)[0][1]
            positive = target==1
            if sensitive[idx]==1:
                g='11' if positive else '01'
                th=th_1
            else:
                g='10' if positive else '00'
                th=th_0
            n[g]+=1
            if positive:
                if score>=th:
                    acc[g]+=1
            else:
                if score<th:
                    acc[g]+=1
    else:
        for g in groups:
            # Second character of the key is the sensitive attribute.
            th=th_1 if g[1]=='1' else th_0
            n[g]=len(dataset[g])
            if g[0]=='1':
                acc[g]=sum(1 for score in dataset[g] if score>=th)
            else:
                acc[g]=sum(1 for score in dataset[g] if score<th)

    def ratio(numerator,denominator):
        # A rate over an empty group is undefined; report nan.
        return numerator/denominator if denominator else float('nan')

    rate={g:ratio(acc[g],n[g]) for g in groups}
    fairness=abs(rate['11']-rate['10'])
    indicator_dict={'acc_00':rate['00'],'acc_01':rate['01'],'acc_10':rate['10'],'acc_11':rate['11'],
                    'n_00':n['00'],'n_01':n['01'],'n_10':n['10'],'n_11':n['11'],
                    'fairness':fairness,
                    'accuracy':ratio(sum(acc.values()),sum(n.values())),
                    'DPE':abs(rate['00']-rate['01'])}

    return fairness,indicator_dict

