import pandas as pd
import numpy as np
import torch
from FedFB.DP_server import *
from FairFed.DP_server import *
from util import FedFB_processing,Fed_Fair

def FedFB_model(num_features, info,args, prn = True, trial = False, select_round = False,seed = 0):
    """Train with FedFB through ``Server`` and return the accuracy and classifier.

    Args:
        num_features: forwarded to ``mlp`` as ``num_features``.
        info: forwarded unchanged to ``Server`` (its contents are not
            inspected here).
        args: namespace; ``dataset``, ``model``, ``method``, ``hidden`` and
            ``num_classes`` are read here, and the whole object is forwarded
            to ``Server`` and ``server.FedFB``.
        prn: forwarded to ``Server`` as ``prn``.
        trial: forwarded to ``Server`` as ``trial``; it does not change what
            this function returns.
        select_round: forwarded to ``Server`` as ``select_round``.
        seed: seeds NumPy and PyTorch (CPU and CUDA generators) and is also
            forwarded to ``mlp`` and ``Server``.

    Returns:
        A tuple ``({'accuracy': acc}, classifier)`` built from the first and
        third values returned by ``server.FedFB(args)``.

    Raises:
        SystemExit: with exit status 1, after emitting a warning, when
            ``args.dataset``, ``args.model`` or ``args.method`` is not
            supported.
    """
    # Imported locally so the function does not rely on a module-level import.
    import warnings

    def _abort(message):
        # ``Warning(message)`` on its own only builds an exception object and
        # shows nothing; ``warnings.warn`` actually reports it.  SystemExit is
        # raised directly because ``exit`` is a ``site`` helper that is not
        # guaranteed to exist.
        warnings.warn(message, stacklevel=2)
        raise SystemExit(1)

    #set seed
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    dataset = args.dataset
    model = args.model
    method = args.method

    # Value handed to ``Server`` as ``Z`` for each supported dataset
    # (presumably the number of sensitive-attribute groups -- TODO confirm).
    z_by_dataset = {'adult': 2, 'compas': 2}
    if dataset not in z_by_dataset:
        # Previously ``Z`` stayed unbound here and a NameError surfaced later.
        _abort('Does not support this dataset!')
    Z = z_by_dataset[dataset]

    if model != 'mlp':
        _abort('Does not support this model!')
    arc = mlp(num_features=num_features, num_hidden=args.hidden, num_classes=args.num_classes, seed = seed)

    # set up the server
    server = Server(arc, info, args,train_prn = False, seed = seed, Z = Z, ret = True, prn = prn, trial = trial, select_round = select_round)

    # execute
    if method != 'fedfb':
        _abort('Does not support this method!')
    # The second returned value is not used by this function.
    acc, _dpdisp, classifier = server.FedFB(args)

    # The same pair is returned regardless of ``trial``.
    return {'accuracy': acc}, classifier

def FairFed_model(num_features, info,args, prn = True, trial = False, select_round = False,seed = 0):
    """Train with FairFed through ``FairFed_Server`` and return its results.

    Args:
        num_features: forwarded to ``mlp`` as ``num_features``.
        info: forwarded unchanged to ``FairFed_Server`` (its contents are not
            inspected here).
        args: namespace; ``dataset``, ``model``, ``method``, ``hidden`` and
            ``num_classes`` are read here, and the whole object is forwarded
            to ``FairFed_Server`` and ``server.FairFed``.
        prn: forwarded to ``FairFed_Server`` as ``prn``.
        trial: forwarded to ``FairFed_Server`` as ``trial``; it does not
            change what this function returns.
        select_round: forwarded to ``FairFed_Server`` as ``select_round``.
        seed: seeds NumPy and PyTorch (CPU and CUDA generators) and is also
            forwarded to ``mlp`` and ``FairFed_Server``.

    Returns:
        A tuple ``({'accuracy': acc}, classifier, FairFed_Server)`` where the
        first two entries come from ``server.FairFed(args)`` and the third is
        the ``FairFed_Server`` class object itself.

    Raises:
        SystemExit: with exit status 1, after emitting a warning, when
            ``args.dataset``, ``args.model`` or ``args.method`` is not
            supported.
    """
    # Imported locally so the function does not rely on a module-level import.
    import warnings

    def _abort(message):
        # ``Warning(message)`` on its own only builds an exception object and
        # shows nothing; ``warnings.warn`` actually reports it.  SystemExit is
        # raised directly because ``exit`` is a ``site`` helper that is not
        # guaranteed to exist.
        warnings.warn(message, stacklevel=2)
        raise SystemExit(1)

    #set seed
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    dataset = args.dataset
    model = args.model
    method = args.method

    # Value handed to ``FairFed_Server`` as ``Z`` for each supported dataset
    # (presumably the number of sensitive-attribute groups -- TODO confirm).
    z_by_dataset = {'adult': 2, 'compas': 2}
    if dataset not in z_by_dataset:
        # Previously ``Z`` stayed unbound here and a NameError surfaced later.
        _abort('Does not support this dataset!')
    Z = z_by_dataset[dataset]

    if model != 'mlp':
        _abort('Does not support this model!')
    arc = mlp(num_features=num_features, num_hidden=args.hidden, num_classes=args.num_classes, seed = seed)

    # set up the server
    server = FairFed_Server(arc, info,args, train_prn = False, seed = seed, Z = Z, ret = True, prn = prn, trial = trial, select_round = select_round)

    # execute
    if method != 'fairfed':
        _abort('Does not support this method!')
    # The second returned value is not used by this function.
    acc, _dpdisp, classifier = server.FairFed(args)

    # NOTE(review): the third element is the FairFed_Server class, not the
    # ``server`` instance built above -- confirm this is intended.  It is kept
    # unchanged so existing callers see the same tuple.
    return {'accuracy': acc}, classifier,FairFed_Server

def FedFB_FedFair(times,args,zeta=0.05,alpha=0.09,depth=8,k=300,generate_method='random',generate_param=1):
    """
    FedFB+Fair
    Input:
        times: number of experiments; seeds 0 .. times-1 are run in order
        args: hyperparameters and paths; dataset, data_path, num_users and
              out_path are read here, and args is forwarded to FedFB_model
              and Fed_Fair
        zeta: fairness constraint confidence level
        alpha: fairness constraint
        depth: depth of Q digest
        k: compression parameter of Q digest
        generate_method: forwarded to FedFB_processing as generate_method
        generate_param: forwarded to FedFB_processing as dir_alpha
    Output:
        None. For every seed whose data preparation succeeds, the global
        model's state_dict is saved to
        args.out_path + 'FedFB_<seed>_alpha_<alpha>.pth'. For each of the
        metrics 'DEOO' and 'DEO', when Fed_Fair returns a non-None table,
        '<metric>_<seed>_alpha_<alpha>.csv' and a matching '.pkl' holding
        [best_K, min_error, K] are written under the same prefix.
    """
    # Imported locally: none of the module's visible import lines provides it.
    import pickle

    for random_seed in range(times):
        try:
            (adult_info, sensitive_attribute, adult_num_features, dataset_train,
             dataset_test, Sensitive_attribute_train, Sensitive_attribute_test,
             dict_users, test_dict_users) = FedFB_processing(
                data_name=args.dataset, data_path=args.data_path, num=args.num_users,
                train_rate=0.8, random_seed=random_seed,
                generate_method=generate_method, dir_alpha=generate_param)
        except Exception as err:
            # Skip seeds whose data preparation fails, but let
            # KeyboardInterrupt/SystemExit propagate and show the cause.
            print('Random seed:',random_seed,'failed')
            print('Cause:', repr(err))
            continue
        net_glob = FedFB_model(adult_num_features, adult_info, args, seed=random_seed)[1]

        #save global model
        torch.save(net_glob.state_dict(), args.out_path+'FedFB_{}_alpha_{}.pth'.format(random_seed,alpha))

        # Same post-processing for both metrics, DEOO first and then DEO.
        for metric in ('DEOO', 'DEO'):
            indicator_df, best_K, min_error, K = Fed_Fair(
                net_glob, dataset_train, dataset_test, dict_users,
                Sensitive_attribute_train, Sensitive_attribute_test,
                zeta=zeta, alpha=alpha, random_seed=random_seed, args=args,
                depth=depth, k=k, metric=metric, model='Select_1')
            if indicator_df is None:
                continue
            stem = args.out_path + '{}_{}_alpha_{}'.format(metric, random_seed, alpha)
            indicator_df.to_csv(stem + '.csv')
            with open(stem + '.pkl', 'wb') as f:
                pickle.dump([best_K, min_error, K], f)
            print(indicator_df)
            if metric == 'DEO':
                # Printed only when the DEO run produced a table.
                print('Fair_{} finished'.format(random_seed))

def FairFed_FedFair(times,args,zeta=0.05,alpha=0.09,depth=8,k=300,generate_method='random',generate_param=1,start_seed=35):
    """
    FairFed+Fair
    Input:
        times: exclusive upper bound of the seed range; seeds
               start_seed .. times-1 are run in order, so nothing runs when
               times <= start_seed
        args: hyperparameters and paths; dataset, data_path, num_users and
              out_path are read here, and args is forwarded to FairFed_model
              and Fed_Fair
        zeta: fairness constraint confidence level
        alpha: fairness constraint
        depth: depth of Q digest
        k: compression parameter of Q digest
        generate_method: forwarded to FedFB_processing as generate_method
        generate_param: forwarded to FedFB_processing as dir_alpha
        start_seed: first seed to run; the default of 35 keeps the range this
                    function used when the value was hard-coded
    Output:
        None. For every seed whose data preparation succeeds, the global
        model's state_dict is saved to
        args.out_path + 'FedFB_<seed>_alpha_<alpha>.pth'. For each of the
        metrics 'DEOO' and 'DEO', when Fed_Fair returns a non-None table,
        '<metric>_<seed>_alpha_<alpha>.csv' and a matching '.pkl' holding
        [best_K, min_error, K] are written under the same prefix.
    """
    # Imported locally: none of the module's visible import lines provides it.
    import pickle

    for random_seed in range(start_seed, times):
        try:
            (adult_info, sensitive_attribute, adult_num_features, dataset_train,
             dataset_test, Sensitive_attribute_train, Sensitive_attribute_test,
             dict_users, test_dict_users) = FedFB_processing(
                data_name=args.dataset, data_path=args.data_path, num=args.num_users,
                train_rate=0.8, random_seed=random_seed,
                generate_method=generate_method, dir_alpha=generate_param)
        except Exception as err:
            # Skip seeds whose data preparation fails, but let
            # KeyboardInterrupt/SystemExit propagate and show the cause.
            print('Random seed:',random_seed,'failed')
            print('Cause:', repr(err))
            continue
        net_glob = FairFed_model(adult_num_features, adult_info, args, seed=random_seed)[1]

        #save global model
        # NOTE(review): the checkpoint name uses the 'FedFB_' prefix even
        # though the model comes from FairFed_model -- confirm whether this is
        # intended; it is kept unchanged so existing consumers find the file.
        torch.save(net_glob.state_dict(), args.out_path+'FedFB_{}_alpha_{}.pth'.format(random_seed,alpha))

        # Same post-processing for both metrics, DEOO first and then DEO.
        for metric in ('DEOO', 'DEO'):
            indicator_df, best_K, min_error, K = Fed_Fair(
                net_glob, dataset_train, dataset_test, dict_users,
                Sensitive_attribute_train, Sensitive_attribute_test,
                zeta=zeta, alpha=alpha, random_seed=random_seed, args=args,
                depth=depth, k=k, metric=metric, model='Select_1')
            if indicator_df is None:
                continue
            stem = args.out_path + '{}_{}_alpha_{}'.format(metric, random_seed, alpha)
            indicator_df.to_csv(stem + '.csv')
            with open(stem + '.pkl', 'wb') as f:
                pickle.dump([best_K, min_error, K], f)
            print(indicator_df)
            if metric == 'DEO':
                # Printed only when the DEO run produced a table.
                print('Fair_{} finished'.format(random_seed))