import os
import itertools
import json
import numpy as np
import pandas as pd
import argparse
from utils.api import save, load, makedir_exist_ok
import matplotlib.pyplot as plt
from collections import defaultdict
from jenkspy import JenksNaturalBreaks

# Intel OpenMP: tolerate more than one copy of the OpenMP runtime in this process
# instead of aborting at import time.
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'

# Command-line interface. ``--type`` selects the experiment family to analyze.
# NOTE(review): main() only accepts a fixed set of types, and the default 'dp' is
# not among them, so running without ``--type`` raises ValueError('Not valid type').
parser = argparse.ArgumentParser(description='analyze_data')
parser.add_argument('--type', default='dp', type=str)
args = vars(parser.parse_args())

# Module-level defaults read by the helper functions below; main() rebinds
# result_path, vis_path, num_experiments and exp for the selected --type.
result_path = './output/result'
save_format = 'png'
vis_path = './output/vis/' + save_format
num_experiments = 1
exp = list(map(str, range(num_experiments)))

# Metric the figures focus on; the commented alternative is kept for quick switching.
global_figure_indicator = 'test_acc'
# global_figure_indicator = 'average_participation_costs'

def make_controls(control_name):
    """Expand option grids into ``(experiment_index, control_string)`` tuples.

    ``control_name`` is a list of grids, and each grid is a list of option lists.
    For every grid, all combinations (Cartesian product, in order) are joined
    with ``'_'``. The resulting strings are then crossed with the module-level
    ``exp`` list, so each tuple's first element comes from ``exp``.
    """
    joined_names = [
        '_'.join(combination)
        for grid in control_name
        for combination in itertools.product(*grid)
    ]
    return list(itertools.product(exp, joined_names))


def make_control_list(file):
    """Return the controls to analyze for the result group ``file``.

    Each returned element is a tuple built by :func:`make_controls`: an
    experiment index string from ``exp`` paired with a ``'_'``-joined control
    string. An unrecognised ``file`` yields an empty list.

    The option lists inside each grid are positional. Their order presumably
    has to match the naming scheme used when the result ``.pt`` files were
    saved (the joined string becomes the file name in ``extract_result``).
    TODO: confirm against the training script.
    """
    controls = []
    if file in ('incentivize', 'incentivize_25', 'incentivize_noln',
                'incentivize_oldln', 'change_label', 'latest'):
        # test_fedavg_trial_rest: FedAvg baseline.
        fedavg_grid = [[['CINIC10', 'CIFAR10', 'FashionMNIST'], ['cnn'], ['0.1'], ['100'], ['0.03'],
                        ['labelflipping', 'signflipping'],
                        ['0.2', '0.3', '0.4'], ['1'], ['iid-equal'],
                        ['fedavg'], ['epoch-5'], ['0.1'],
                        ['0.005'], ['0.005'], ['0.3'], ['0.5'], ['3'], ['1']]]
        controls.extend(make_controls(fedavg_grid))

        # simpfedincen runs; note this grid has one extra trailing option list.
        simpfedincen_grid = [[['FashionMNIST', 'CIFAR10', 'CINIC10'], ['cnn'], ['0.1'], ['100'], ['0.03'],
                              ['labelflipping', 'signflipping'],
                              ['0.2', '0.3', '0.4'], ['1'], ['iid-equal'],
                              ['simpfedincen'], ['epoch-5'], ['0.1'],
                              ['0.005'], ['0.0001'], ['-0.1'], ['0.3'], ['3'], ['1'], ['1', '2', '3']]]
        controls.extend(make_controls(simpfedincen_grid))
    return controls


def main():
    """Process and visualize the saved results for the ``--type`` experiment family.

    Steps:

    1. Point the module-level ``result_path``/``vis_path`` at the per-type
       folders and build the control list for that type.
    2. Load and summarize each saved result.
    3. Write ``processed_result_exp.json`` plus ``.pt`` copies of both
       processed trees into ``result_path``.
    4. Pass the flattened history results to :func:`make_vis`.

    Raises:
        ValueError: if ``--type`` is not one of the supported experiment types.
    """
    # main() rebinds the module-level settings that the helper functions read.
    global result_path, vis_path, num_experiments, exp

    result_path = './output/result/{}'.format(args['type'])
    vis_path = './output/vis/{}'.format(args['type'])
    files = [args['type']]

    supported_types = ('incentivize', 'incentivize_25', 'incentivize_noln',
                       'incentivize_oldln', 'change_label', 'latest')
    if args['type'] not in supported_types:
        raise ValueError('Not valid type')
    num_experiments = 1
    exp = [str(x) for x in range(num_experiments)]

    controls = []
    for file in files:
        controls += make_control_list(file)
    processed_result_exp, processed_result_history = process_result(controls)
    with open(os.path.join(result_path, 'processed_result_exp.json'), 'w', encoding='utf-8') as fp:
        json.dump(processed_result_exp, fp, indent=2)
    save(processed_result_exp, os.path.join(result_path, 'processed_result_exp.pt'))
    save(processed_result_history, os.path.join(result_path, 'processed_result_history.pt'))

    # Flatten the nested history tree into {experiment_name: {'<metric>_history': ...}}.
    extracted_processed_result_history = {}
    extract_processed_result(extracted_processed_result_history, processed_result_history, [])
    if extracted_processed_result_history:
        make_df_history(extracted_processed_result_history)
    make_vis(extracted_processed_result_history)
    return


def process_result(controls):
    """Load the saved result of every control and summarize it in place.

    Each control is joined with ``'_'`` to form the model tag (result file
    stem). ``extract_result`` fills the two nested trees, and any non-empty
    tree is then passed to ``summarize_result``.

    Returns:
        ``(processed_result_exp, processed_result_history)``, both nested dicts.
    """
    exp_tree = {}
    history_tree = {}
    for control in controls:
        extract_result(list(control), '_'.join(control), exp_tree, history_tree)
    for tree in (exp_tree, history_tree):
        if tree:
            summarize_result(tree)
    return exp_tree, history_tree


def extract_result(control, model_tag, processed_result_exp, processed_result_history):
    """Load the saved result for one control and record its training history.

    ``control`` is ``[experiment_index, key_1, key_2, ...]``. The keys after
    the index are used to descend, in parallel, into both nested dicts,
    creating empty levels as needed. At the bottom, every metric in the loaded
    ``logger['train'].history`` is stored at
    ``processed_result_history[...][metric]['history'][experiment_position]``.
    ``processed_result_exp`` only receives the empty intermediate levels.
    A missing result file is reported with a print and otherwise skipped.
    """
    exp_node = processed_result_exp
    history_node = processed_result_history
    for key in control[1:]:
        # Both trees are created together, keyed by presence in the exp tree.
        if key not in exp_node:
            exp_node[key] = {}
            history_node[key] = {}
        exp_node = exp_node[key]
        history_node = history_node[key]

    exp_idx = exp.index(control[0])
    base_result_path_i = os.path.join(result_path, '{}.pt'.format(model_tag))
    if not os.path.exists(base_result_path_i):
        print('Missing {}'.format(base_result_path_i))
        return
    base_result = load(base_result_path_i)
    for metric_name, metric_history in base_result['logger']['train'].history.items():
        metric_entry = history_node.setdefault(metric_name, {'history': [None] * num_experiments})
        metric_entry['history'][exp_idx] = metric_history


def summarize_result(processed_result):
    """Recursively summarize a nested result tree in place.

    - A node containing ``'exp'``: its entries are stacked. ``mean``, ``std``,
      ``max``, ``min``, ``argmax`` and ``argmin`` over axis 0 are stored as
      Python scalars via ``.item()``, which assumes each entry is a scalar.
      ``'exp'`` is then stored back as a list.
    - Otherwise, a node containing ``'history'``: its entries are stacked
      along a new axis 0 and stored back as a nested list.
    - Any other node: recurse into each of its values.
    """
    if 'exp' in processed_result:
        stacked = np.stack(processed_result['exp'], axis=0)
        processed_result['exp'] = stacked
        reducers = (
            ('mean', np.mean),
            ('std', np.std),
            ('max', np.max),
            ('min', np.min),
            ('argmax', np.argmax),
            ('argmin', np.argmin),
        )
        for stat_name, reducer in reducers:
            processed_result[stat_name] = reducer(stacked, axis=0).item()
        processed_result['exp'] = stacked.tolist()
    elif 'history' in processed_result:
        runs = list(processed_result['history'])
        processed_result['history'] = runs
        processed_result['history'] = np.stack(runs, axis=0).tolist()
    else:
        for child in processed_result.values():
            summarize_result(child)


def extract_processed_result(extracted_processed_result, processed_result, control):
    """Flatten a nested result tree into ``{experiment_name: {'<metric>_history': ...}}``.

    ``control`` is the key path walked so far. When a leaf node (one that has
    ``'exp'`` or ``'history'``) is reached, all path keys except the last are
    joined with ``'_'`` to form the experiment name. The last key is the metric
    name, and the leaf's ``'history'`` is stored under ``'<metric>_history'``.
    Results are written into ``extracted_processed_result`` in place.
    """
    if not ('exp' in processed_result or 'history' in processed_result):
        for key, child in processed_result.items():
            extract_processed_result(extracted_processed_result, child, control + [key])
        return

    exp_name = '_'.join(control[:-1])
    metric_name = control[-1]
    bucket = extracted_processed_result.setdefault(exp_name, defaultdict())
    bucket['{}_history'.format(metric_name)] = processed_result['history']


def write_xlsx(path, df, startrow=0):
    """Write every group of DataFrames in ``df`` to ``path``, stacked on one sheet.

    Args:
        path: Output ``.xlsx`` path (written with the ``xlsxwriter`` engine).
        df: Mapping of group name -> list of DataFrames. Each list is replaced
            in place by its ``pd.concat`` result, and callers such as
            ``make_df_exp`` return this mutated mapping.
        startrow: Row at which the first group's title is written.

    Each group is written as a title row (the group name in column 0), followed
    by the concatenated frame, followed by two blank rows.
    """
    # The context manager saves and closes the workbook even if a write fails.
    # ``ExcelWriter.save()`` was deprecated in pandas 1.5 and removed in 2.0.
    with pd.ExcelWriter(path, engine='xlsxwriter') as writer:
        for df_name in df:
            df[df_name] = pd.concat(df[df_name])
            df[df_name].to_excel(writer, sheet_name='Sheet1', startrow=startrow + 1)
            writer.sheets['Sheet1'].write_string(startrow, 0, df_name)
            startrow = startrow + len(df[df_name].index) + 3
    return


def make_df_exp(extracted_processed_result_exp):
    """Group per-experiment summaries into DataFrames and write them to ``result_exp.xlsx``.

    Each experiment name is split on ``'_'``:

    - 3 fields: the whole name is the group, and the row index is ``'1'``.
    - 10 or 11 fields: fields 7 and 8 (``local_epoch``, ``gm``) form the row
      index. The group name is fields 0-6 plus field 9 (``sbn``), plus field 10
      (``ft``) when present.

    Returns:
        The ``defaultdict`` of group name -> DataFrame, after ``write_xlsx``
        has concatenated each group in place.

    Raises:
        ValueError: for any other number of fields.
    """
    df = defaultdict(list)
    for exp_name, metrics in extracted_processed_result_exp.items():
        fields = exp_name.split('_')
        field_count = len(fields)
        if field_count == 3:
            index_name = ['1']
            df_name = '_'.join(fields)
        elif field_count in (10, 11):
            index_name = ['_'.join(fields[7:9])]
            df_name = '_'.join(fields[:7] + fields[9:])
        else:
            raise ValueError('Not valid control')
        df[df_name].append(pd.DataFrame(data=metrics, index=index_name))
    write_xlsx('{}/result_exp.xlsx'.format(result_path), df)
    return df


def make_df_history(extracted_processed_result_history):
    """Stub for exporting per-round history tables; currently does nothing.

    The argument is accepted for interface compatibility but is not read or
    modified. Always returns ``None``.
    """
    return None

def make_vis(extracted_processed_result_history):
    data_split_mode_dict = {'iid': 'IID', 'non-iid-l-2': 'Non-IID, $K=2$',
                            'non-iid-d-0.1': 'Non-IID, $\operatorname{Dir}(0.1)$',
                            'non-iid-d-0.3': 'Non-IID, $\operatorname{Dir}(0.3)$', 'fix-fsgd': 'DynamicSgd + FixMatch',
                            'fix-batch': 'FedAvg + FixMatch', 'fs': 'Fully Supervised', 'ps': 'Partially Supervised'}
    

    color = {'5_0.5': 'red', '1_0.5': 'orange', '5_0': 'dodgerblue', '5_0.9': 'blue', '5_0.5_nomixup': 'green',
             '5_0_nomixup': 'green', 'iid': 'red', 'non-iid-l-2': 'orange', 'non-iid-d-0.1': 'dodgerblue',
             'non-iid-d-0.3': 'green', 'fix-fsgd': 'red', 'fix-batch': 'blue',
             'fs': 'black', 'ps': 'orange',
             'active_rate_0.1': 'green',
             'active_rate_0.3': 'red',
             'active_rate_0.5': 'dodgerblue',
             'malicious_clients': 'green',
             'benign_clients': 'red',
             'partcipation_clients': 'dodgerblue',
             'active_clients': 'orange',
             'FedAvg_0.05': 'red',
             'FedAvg_0.1': 'green',
             'FedAvg_0.2': 'orange',
             'FedAvg': 'red',
             'FedAvg': 'green',
             'ICL_0.1': 'black',
             'ICL_0.3': 'dodgerblue',
             'ICL_0.1': 'red',
             'ICL_0.3': 'blue',
             'ICL_3_interval_benign': 'red',
             'ICL_3_interval_malicious': 'blue',
             'ICL_cur_round_benign': 'black',
             'ICL_cur_round_malicious': 'green',
             'ICL_plan_1': 'black',
            'ICL_plan_2': 'red',
            'ICL_plan_3': 'blue',
             'ICL_0.05_0.1': 'blue',
             'ICL_0.05_0.3': 'blue',
             'ICL_0.1_0.1': 'blue',
             'ICL_0.1_0.3': 'blue',
             'ICL_0.05_0.1_malicious': 'red',
             'ICL_0.05_0.3_malicious': 'red',
             'ICL_0.1_0.1_malicious': 'red',
             'ICL_0.1_0.3_malicious': 'red',
             }
    linestyle = {'5_0.5': '-', '1_0.5': '--', '5_0': ':', '5_0.5_nomixup': '-.', '5_0_nomixup': '-.',
                 '5_0.9': (0, (1, 5)), 'iid': '-', 'non-iid-l-2': '--', 'non-iid-d-0.1': '-.', 'non-iid-d-0.3': ':',
                 'fix-fsgd': '--', 'fix-batch': ':', 'fs': '-', 'ps': '-.', ''
                 'active_rate_0.1': ':',
                 'active_rate_0.3': '-',
                 'active_rate_0.5': '-.',
                 'malicious_clients': ':',
                'benign_clients': '-',
                'partcipation_clients': '-.',
                'active_clients': '--',
                'FedAvg_0.05': '-',
                'FedAvg_0.1': '-.',
                'FedAvg_0.2': (0, (3, 1, 1, 1, 1, 1)),
                'FedAvg': '-',
                'FedAvg': '-.',
                'ICL_0.1': ':',
                'ICL_0.3': (0, (1, 5)),
                'ICL_plan_1': ':',
                'ICL_plan_2': (0, (1, 5)),
                'ICL_plan_3': ':',

                'ICL_0.1_benign': ':',
                'ICL_0.1_malicious': (0, (1, 5)),
                'ICL_3_interval_benign': ':',
                'ICL_3_interval_malicious': (0, (1, 5)),
                'ICL_cur_round_benign': '-.',
                'ICL_cur_round_malicious': '--',

                'ICL_0.05_0.1': ':',
                'ICL_0.05_0.3': (0, (1, 5)),
                'ICL_0.1_0.1': '--',
                'ICL_0.1_0.3': '-.'
                 }

    loc_dict = {'Accuracy': 'upper right', 'Loss': 'upper right', 'average': 'center right'}
    fontsize = {'legend': 16, 'label': 16, 'ticks': 16}
    fig = {}
    reorder_fig = []
    for exp_name in extracted_processed_result_history:
        # control = exp_name.split('_')
    # for df_name in df_history:
        control = exp_name.split('_')
        print(len(control))
        # if len(control) == 16:
        #     data_name, model_name, active_rate, num_clients, lr, malicious_way, \
        #     malicious_ratio, _, data_split, algo_name, _, lamda, objective_sigmoid_s,\
        #         objective_func_lr, _, _ = control
            
        #     # participation_clients_participation_costs_history = []
        #     # active_clients_participation_costs_history = []
        #     # malicious_clients_participation_costs = []
        #     # benign_clients_participation_costs = []
        #     # malicious_clients_participation_costs_num = []
        #     # benign_clients_participation_costs_num = []
        #     # malicious_client_ids = []
        #     # for k in extracted_processed_result_history[exp_name]:
        #     #     a = k
        #         # if 'participation_clients_participation_costs_history' in k:
        #         #     participation_clients_participation_costs_history.append(row)
        #         # elif 'active_clients_participation_costs_history' in k:
        #         #     active_clients_participation_costs_history.append(row)
        #         # elif 'malicious_clients_participation_costs' in k:
        #         #     malicious_clients_participation_costs.append(row)
        #         # elif 'benign_clients_participation_costs' in k:
        #         #     benign_clients_participation_costs.append(row)
        #         # elif 'malicious_clients_participation_costs_num' in k:
        #         #     malicious_clients_participation_costs_num.append(row)
        #         # elif 'benign_clients_participation_costs_num' in k:
        #         #     benign_clients_participation_costs_num.append(row)
        #         # elif 'malicious_client_ids' in k:
        #         #     malicious_client_ids.append(row)
        #         # print('row', row)

        #     if global_figure_indicator == 'test_acc':
        #         # x_axis = [i for i in range(1,11)]
        #         # for KL, get odd indices
                
        #     #     data_name, model_name, active_rate, num_clients, lr, malicious_way, \
        #     # malicious_ratio, _, data_split, algo_name, _, lamda, objective_sigmoid_s,\
        #     #     objective_func_lr, _, _ = control

        #         fig_name = '_'.join([data_name, malicious_way, malicious_ratio])
        #         fig[fig_name] = plt.figure(fig_name)
        #         y = np.array(extracted_processed_result_history[exp_name]['test_server/Accuracy_history'][0])
        #         # y = np.array(mean_list[1::2])
        #         # yerr = np.array(std_list[1::2])
        #         x = np.arange(len(y))
        #         if 'fedincen' in exp_name:
        #             algo_name = 'ICL'
        #             key = f'{algo_name}_{lamda}'
        #         else:
        #             algo_name = 'FedAvg'
        #             key = f'{algo_name}'
        #         plt.plot(x, y, color=color[key], linestyle=linestyle[key], label=key)
        #         # plt.fill_between(x, (y - yerr), (y + yerr), color=color[key], alpha=.1)
        #         # label = f'{key}'
        #         # plt.errorbar(x, y, color=color[key], linestyle=linestyle[key],
        #         #     label=label)
        #         plt.xlabel('Communication Rounds', fontsize=fontsize['label'])
        #         plt.ylabel('Accuracy', fontsize=fontsize['label'])
        #         plt.xticks(fontsize=fontsize['ticks'])
        #         plt.yticks(fontsize=fontsize['ticks'])
        #         plt.legend(loc=loc_dict['Accuracy'], fontsize=fontsize['legend'])
        #         # a = 5
        #     if global_figure_indicator == 'average_participation_costs':
                
        #         if 'fedavg' in exp_name:
        #             continue
                
        #         malicious_client_ids = extracted_processed_result_history[exp_name]['train/malicious_client_ids_history'][0][0]
        #         malicious_client_ids.sort()
        #         print(type(malicious_client_ids[0]))
        #         total_clients = set([i for i in range(int(num_clients))])
        #         benign_client_ids = list(total_clients - set(malicious_client_ids))
        #         benign_client_ids.sort()

        #         malicious_average_participation_costs = []
        #         benign_average_participation_costs = []
        #         for i in range(int(num_clients)):
        #             cur_key = f'train_{i}_participation_cost/participation_cost_history'
        #             participation_cost_list = extracted_processed_result_history[exp_name][cur_key][0]
        #             if i in malicious_client_ids:
        #                 print('zheli', i)
        #                 malicious_average_participation_costs.append(sum(participation_cost_list)/len(participation_cost_list))
        #             else:
        #                 benign_average_participation_costs.append(sum(participation_cost_list)/len(participation_cost_list))
                    
        #         # malicious_clients


        #         fig_name = '_'.join(['average_participation_costs', data_name, malicious_way, malicious_ratio, lamda])
        #         fig[fig_name] = plt.figure(fig_name)
        #         y = np.array(malicious_average_participation_costs)
        #         # y = np.array(mean_list[1::2])
        #         # yerr = np.array(std_list[1::2])
        #         # x = np.arange(len(y))
        #         x = np.array(malicious_client_ids)
        #         if 'fedincen' in exp_name:
        #             algo_name = 'ICL'
        #             key = f'{algo_name}_{malicious_ratio}_{lamda}'
        #         # else:
        #         #     key = f'{algo_name}_{malicious_ratio}'
        #         y1 = np.array(benign_average_participation_costs)
        #         # y = np.array(mean_list[1::2])
        #         # yerr = np.array(std_list[1::2])
        #         # x = np.arange(len(y))
        #         x1 = np.array(benign_client_ids)
        #         # if 'fedincen' in exp_name:
        #         #     key = f'{algo_name}_{malicious_ratio}_{lamda}'
        #         # else:
        #         #     key = f'{algo_name}_{malicious_ratio}'
        #         benign_average_participation_costs = [40*i for i in benign_average_participation_costs]
        #         plt.scatter(x1, y1, s=benign_average_participation_costs, color=color[key], linestyle=linestyle[key], label='benign_clients')

        #         malicious_average_participation_costs = [60*i for i in malicious_average_participation_costs]
        #         key = f'{algo_name}_{malicious_ratio}_{lamda}_malicious'
        #         plt.scatter(x, y, s=malicious_average_participation_costs, color=color[key], label='malicious_clients')

              
        #         # plt.fill_between(x, (y - yerr), (y + yerr), color=color[key], alpha=.1)
        #         # label = f'{key}'
        #         # plt.errorbar(x, y, color=color[key], linestyle=linestyle[key],
        #         #     label=label)
        #         plt.xlabel('All Clients', fontsize=fontsize['label'])
        #         plt.ylabel('Average Participation Costs', fontsize=fontsize['label'])
        #         plt.xticks(fontsize=fontsize['ticks'])
        #         plt.yticks(fontsize=fontsize['ticks'])
        #         plt.legend(loc=loc_dict['average'], fontsize=fontsize['legend'])
        if len(control) == 18:
            data_name, model_name, active_rate, num_clients, lr, malicious_way, \
            malicious_ratio, _, data_split, algo_name, _, lamda, objective_sigmoid_s,\
                objective_func_lr, _, _, interval, sample_portion = control
            print('~~~')
            # participation_clients_participation_costs_history = []
            # active_clients_participation_costs_history = []
            # malicious_clients_participation_costs = []
            # benign_clients_participation_costs = []
            # malicious_clients_participation_costs_num = []
            # benign_clients_participation_costs_num = []
            # malicious_client_ids = []
            # for k in extracted_processed_result_history[exp_name]:
            #     a = k
                # if 'participation_clients_participation_costs_history' in k:
                #     participation_clients_participation_costs_history.append(row)
                # elif 'active_clients_participation_costs_history' in k:
                #     active_clients_participation_costs_history.append(row)
                # elif 'malicious_clients_participation_costs' in k:
                #     malicious_clients_participation_costs.append(row)
                # elif 'benign_clients_participation_costs' in k:
                #     benign_clients_participation_costs.append(row)
                # elif 'malicious_clients_participation_costs_num' in k:
                #     malicious_clients_participation_costs_num.append(row)
                # elif 'benign_clients_participation_costs_num' in k:
                #     benign_clients_participation_costs_num.append(row)
                # elif 'malicious_client_ids' in k:
                #     malicious_client_ids.append(row)
                # print('row', row)

            # for performance
            if global_figure_indicator == 'test_acc':
                print('!!!!!')

                fig_name = '_'.join([data_name, malicious_way, malicious_ratio, sample_portion, 'z_diff'])
                fig[fig_name] = plt.figure(fig_name)
                
                # use all local datasets
                if sample_portion == '1':
                    if algo_name != 'fedavg':
                        y = np.array(extracted_processed_result_history[exp_name]['train/malicious_client_z_diff_mean_history'][0])
                        yerr = np.array(extracted_processed_result_history[exp_name]['train/malicious_client_z_diff_std_history'][0])
                        x = np.arange(len(y))
                        if 'simpfedincen' in exp_name:
                            algo_name = 'ICL'
                            key = f'{algo_name}_{interval}_interval_malicious'
                        else:
                            algo_name = 'FedAvg'
                            key = f'{algo_name}'
                        plt.plot(x, y, color=color[key], linestyle=linestyle[key], label=key)
                        plt.fill_between(x, (y - yerr), (y + yerr), color=color[key], alpha=.1)

                        if 'simpfedincen' in exp_name:
                            algo_name = 'ICL'
                            key = f'{algo_name}_{interval}_interval_benign'
                        else:
                            algo_name = 'FedAvg'
                            key = f'{algo_name}'
                        y1 = np.array(extracted_processed_result_history[exp_name]['train/z_diff_mean_history'][0])
                        yerr1 = np.array(extracted_processed_result_history[exp_name]['train/z_diff_std_history'][0])
                        plt.plot(x, y1, color=color[key], linestyle=linestyle[key], label=key)
                        plt.fill_between(x, (y1 - yerr1), (y1 + yerr1), color=color[key], alpha=.1)
                        plt.xlabel('Communication Rounds', fontsize=fontsize['label'])
                        plt.ylabel('z_diff', fontsize=fontsize['label'])
                        plt.xticks(fontsize=fontsize['ticks'])
                        plt.yticks(fontsize=fontsize['ticks'])
                        plt.legend(loc=loc_dict['Accuracy'], fontsize=fontsize['legend'])


                        fig_name = '_'.join([data_name, malicious_way, malicious_ratio, sample_portion, 'grouping'])
                        fig[fig_name] = plt.figure(fig_name)
                        z_diff_list = extracted_processed_result_history[exp_name]['train/z_diff_history'][0]
                        malicious_client_z_diff = extracted_processed_result_history[exp_name]['train/malicious_client_z_diff_history'][0]

                        # print('z_diff_list', z_diff_list)
                        # print("\n")
                        print('malicious_client_z_diff', malicious_client_z_diff)
                        print("\n")

                        benign_client_num = len(z_diff_list[0])
                        malicious_client_num = len(malicious_client_z_diff[0])

                        if malicious_client_num != 0:
                            benign_identify_ratio = []
                            malicious_identify_ratio = []
                            # for i in range(3, len(z_diff_list)):
                            #     cur_combine_list = z_diff_list[i] + malicious_client_z_diff[i]
                            #     jnb = JenksNaturalBreaks(2)
                            #     jnb.fit(cur_combine_list)
                            #     print(jnb.breaks_)
                            #     for group in jnb.groups_:
                            #         if len(group) > 50:
                            #             intersect_num = len(np.intersect1d(np.array(group), np.array(z_diff_list[i])))
                            #             # print('benign', intersect_num / benign_client_num, "\n")
                            #             benign_identify_ratio.append(intersect_num / benign_client_num)
                            #         else:
                            #             intersect_num = len(np.intersect1d(np.array(group), np.array(malicious_client_z_diff[i])))
                            #             # print('malicious', intersect_num / malicious_client_num, "\n")
                            #             malicious_identify_ratio.append(intersect_num / malicious_client_num)

                            benign_identify_ratio = extracted_processed_result_history[exp_name]['train/benign_identify_ratio_history'][0]
                            malicious_identify_ratio = extracted_processed_result_history[exp_name]['train/malicious_identify_ratio_history'][0]
                                # jnb.fit(all_z_diff)
                                # for i in range(cfg['num_clients']):
                                #     # benign_client
                                #     if all_z_diff[i] <= jnb.breaks_[1]:
                                #         cur_round_participation_client_ids.append(i)
                                
                            y = np.array(benign_identify_ratio)
                            yerr = np.array(y)
                            x = np.arange(len(y))
                            if 'simpfedincen' in exp_name:
                                algo_name = 'ICL'
                                key = f'{algo_name}_cur_round_benign'
                            else:
                                algo_name = 'FedAvg'
                                key = f'{algo_name}'
                            points_to_kept = y!=0
                            new_x = x[points_to_kept]
                            new_y = y[points_to_kept]
                            new_yerr = yerr[points_to_kept]
                            plt.plot(new_x, new_y, color=color[key], linestyle=linestyle[key], label=key)
                            # plt.fill_between(x, (y1 - yerr1), (y1 + yerr1), color=color[key], alpha=.1)
                            plt.xlabel('Communication Rounds', fontsize=fontsize['label'])
                            plt.ylabel('Grouping Accuracy', fontsize=fontsize['label'])
                            plt.xticks(fontsize=fontsize['ticks'])
                            plt.yticks(fontsize=fontsize['ticks'])
                            plt.legend(loc=loc_dict['Accuracy'], fontsize=fontsize['legend'])


                        # plt.fill_between(new_x, (new_y - new_yerr), (new_y + new_yerr), color=color[key], alpha=.1)

                            y = np.array(malicious_identify_ratio)
                            yerr = np.array(y)
                            x = np.arange(len(y))
                            if 'simpfedincen' in exp_name:
                                algo_name = 'ICL'
                                key = f'{algo_name}_cur_round_malicious'
                            else:
                                algo_name = 'FedAvg'
                                key = f'{algo_name}'
                            points_to_kept = y!=0
                            new_x = x[points_to_kept]
                            new_y = y[points_to_kept]
                            new_yerr = yerr[points_to_kept]
                            plt.plot(new_x, new_y, color=color[key], linestyle=linestyle[key], label=key)
                            plt.xlabel('Communication Rounds', fontsize=fontsize['label'])
                            plt.ylabel('Grouping Accuracy', fontsize=fontsize['label'])
                            plt.xticks(fontsize=fontsize['ticks'])
                            plt.yticks(fontsize=fontsize['ticks'])
                            plt.legend(loc=loc_dict['Accuracy'], fontsize=fontsize['legend'])


                            # cur_round_z_diff
                            fig_name = '_'.join([data_name, malicious_way, malicious_ratio, sample_portion, 'cur_round_z_diff'])
                            fig[fig_name] = plt.figure(fig_name)


                            y = np.array(extracted_processed_result_history[exp_name]['train/cur_round_malicious_client_z_diff_mean_history'][0])
                            yerr = np.array(extracted_processed_result_history[exp_name]['train/cur_round_malicious_client_z_diff_std_history'][0])
                            x = np.arange(len(y))
                            if 'simpfedincen' in exp_name:
                                algo_name = 'ICL'
                                key = f'{algo_name}_cur_round_malicious'
                            else:
                                algo_name = 'FedAvg'
                                key = f'{algo_name}'
                            points_to_kept = y!=0
                            new_x = x[points_to_kept]
                            new_y = y[points_to_kept]
                            new_yerr = yerr[points_to_kept]
                            plt.plot(new_x, new_y, color=color[key], linestyle=linestyle[key], label=key)
                            plt.fill_between(new_x, (new_y - new_yerr), (new_y + new_yerr), color=color[key], alpha=.1)
                            # plt.plot(x, y, color=color[key], linestyle=linestyle[key], label=key)
                            # plt.fill_between(x, (y - yerr), (y + yerr), color=color[key], alpha=.1)

                            if 'simpfedincen' in exp_name:
                                algo_name = 'ICL'
                                key = f'{algo_name}_cur_round_benign'
                            else:
                                algo_name = 'FedAvg'
                                key = f'{algo_name}'
                            y1 = np.array(extracted_processed_result_history[exp_name]['train/cur_round_z_diff_mean_history'][0])
                            yerr1 = np.array(extracted_processed_result_history[exp_name]['train/cur_round_z_diff_std_history'][0])
                            plt.plot(x, y1, color=color[key], linestyle=linestyle[key], label=key)
                            plt.fill_between(x, (y1 - yerr1), (y1 + yerr1), color=color[key], alpha=.1)
                            plt.xlabel('Communication Rounds', fontsize=fontsize['label'])
                            plt.ylabel('Accuracy', fontsize=fontsize['label'])
                            plt.xticks(fontsize=fontsize['ticks'])
                            plt.yticks(fontsize=fontsize['ticks'])
                            plt.legend(loc=loc_dict['Accuracy'], fontsize=fontsize['legend'])


                    # plt.fill_between(new_x, (new_y - new_yerr), (new_y + new_yerr), color=color[key], alpha=.1)

                    fig_name = '_'.join([data_name, malicious_way, malicious_ratio, 'acc'])

                    # fig_name = '_'.join([data_name, 'acc'])
                    fig[fig_name] = plt.figure(fig_name)
                    y = np.array(extracted_processed_result_history[exp_name]['test_server/Accuracy_history'][0])
                    # y = np.array(mean_list[1::2])
                    # yerr = np.array(std_list[1::2])
                    x = np.arange(len(y))
                    if 'fedincen' in exp_name:
                        algo_name = 'ICL'
                        key = f'{algo_name}_{lamda}'
                    else:
                        algo_name = 'FedAvg'
                        key = f'{algo_name}_{active_rate}'
                    plt.plot(x, y, color=color[key], linestyle=linestyle[key], label=key)
                    # plt.fill_between(x, (y - yerr), (y + yerr), color=color[key], alpha=.1)
                    # label = f'{key}'
                    # plt.errorbar(x, y, color=color[key], linestyle=linestyle[key],
                    #     label=label)
                    plt.xlabel('Communication Rounds', fontsize=fontsize['label'])
                    plt.ylabel('Accuracy', fontsize=fontsize['label'])
                    plt.xticks(fontsize=fontsize['ticks'])
                    plt.yticks(fontsize=fontsize['ticks'])
                    plt.legend(loc=loc_dict['Accuracy'], fontsize=fontsize['legend'])





                    

                    # if 'simpfedincen' in exp_name:
                    #     algo_name = 'ICL'
                    #     key = f'{algo_name}_cur_round_benign'
                    # else:
                    #     algo_name = 'FedAvg'
                    #     key = f'{algo_name}'
                    # y1 = np.array(extracted_processed_result_history[exp_name]['train/cur_round_z_diff_mean_history'][0])
                    # yerr1 = np.array(extracted_processed_result_history[exp_name]['train/cur_round_z_diff_std_history'][0])
                    # plt.plot(x, y1, color=color[key], linestyle=linestyle[key], label=key)
                    # plt.fill_between(x, (y1 - yerr1), (y1 + yerr1), color=color[key], alpha=.1)
                    # plt.xlabel('Communication Rounds', fontsize=fontsize['label'])
                    # plt.ylabel('Accuracy', fontsize=fontsize['label'])
                    # plt.xticks(fontsize=fontsize['ticks'])
                    # plt.yticks(fontsize=fontsize['ticks'])
                    # plt.legend(loc=loc_dict['Accuracy'], fontsize=fontsize['legend'])
                elif sample_portion == '0.1':
                    y = np.array(extracted_processed_result_history[exp_name]['train/malicious_client_z_diff_mean_history'][0])
                    yerr = np.array(extracted_processed_result_history[exp_name]['train/malicious_client_z_diff_std_history'][0])
                    x = np.arange(len(y))
                    if 'simpfedincen' in exp_name:
                        algo_name = 'ICL'
                        key = f'{algo_name}_{interval}_interval_malicious'
                    else:
                        algo_name = 'FedAvg'
                        key = f'{algo_name}'
                    plt.plot(x, y, color=color[key], linestyle=linestyle[key], label=key)
                    plt.fill_between(x, (y - yerr), (y + yerr), color=color[key], alpha=.1)

                    if 'simpfedincen' in exp_name:
                        algo_name = 'ICL'
                        key = f'{algo_name}_{interval}_interval_benign'
                    else:
                        algo_name = 'FedAvg'
                        key = f'{algo_name}'
                    y1 = np.array(extracted_processed_result_history[exp_name]['train/z_diff_mean_history'][0])
                    yerr1 = np.array(extracted_processed_result_history[exp_name]['train/z_diff_std_history'][0])
                    plt.plot(x, y1, color=color[key], linestyle=linestyle[key], label=key)
                    plt.fill_between(x, (y1 - yerr1), (y1 + yerr1), color=color[key], alpha=.1)
                    plt.xlabel('Communication Rounds', fontsize=fontsize['label'])
                    plt.ylabel('Accuracy', fontsize=fontsize['label'])
                    plt.xticks(fontsize=fontsize['ticks'])
                    plt.yticks(fontsize=fontsize['ticks'])
                    plt.legend(loc=loc_dict['Accuracy'], fontsize=fontsize['legend'])

                    y = np.array(extracted_processed_result_history[exp_name]['train/cur_round_malicious_client_z_diff_mean_history'][0])
                    yerr = np.array(extracted_processed_result_history[exp_name]['train/cur_round_malicious_client_z_diff_std_history'][0])
                    x = np.arange(len(y))
                    if 'simpfedincen' in exp_name:
                        algo_name = 'ICL'
                        key = f'{algo_name}_cur_round_malicious'
                    else:
                        algo_name = 'FedAvg'
                        key = f'{algo_name}'
                    points_to_kept = y!=0
                    new_x = x[points_to_kept]
                    new_y = y[points_to_kept]
                    new_yerr = yerr[points_to_kept]
                    plt.plot(new_x, new_y, color=color[key], linestyle=linestyle[key], label=key)
                    plt.fill_between(new_x, (new_y - new_yerr), (new_y + new_yerr), color=color[key], alpha=.1)
                    # plt.plot(x, y, color=color[key], linestyle=linestyle[key], label=key)
                    # plt.fill_between(x, (y - yerr), (y + yerr), color=color[key], alpha=.1)

                    if 'simpfedincen' in exp_name:
                        algo_name = 'ICL'
                        key = f'{algo_name}_cur_round_benign'
                    else:
                        algo_name = 'FedAvg'
                        key = f'{algo_name}'
                    y1 = np.array(extracted_processed_result_history[exp_name]['train/cur_round_z_diff_mean_history'][0])
                    yerr1 = np.array(extracted_processed_result_history[exp_name]['train/cur_round_z_diff_std_history'][0])
                    plt.plot(x, y1, color=color[key], linestyle=linestyle[key], label=key)
                    plt.fill_between(x, (y1 - yerr1), (y1 + yerr1), color=color[key], alpha=.1)
                    plt.xlabel('Communication Rounds', fontsize=fontsize['label'])
                    plt.ylabel('Accuracy', fontsize=fontsize['label'])
                    plt.xticks(fontsize=fontsize['ticks'])
                    plt.yticks(fontsize=fontsize['ticks'])
                    plt.legend(loc=loc_dict['Accuracy'], fontsize=fontsize['legend'])


            if global_figure_indicator == 'average_participation_costs':
                
                if 'fedavg' in exp_name:
                    continue
                
                malicious_client_ids = extracted_processed_result_history[exp_name]['train/malicious_client_ids_history'][0][0]
                malicious_client_ids.sort()
                print(type(malicious_client_ids[0]))
                total_clients = set([i for i in range(int(num_clients))])
                benign_client_ids = list(total_clients - set(malicious_client_ids))
                benign_client_ids.sort()

                malicious_average_participation_costs = []
                benign_average_participation_costs = []
                for i in range(int(num_clients)):
                    cur_key = f'train_{i}_participation_cost/participation_cost_history'
                    participation_cost_list = extracted_processed_result_history[exp_name][cur_key][0]
                    if i in malicious_client_ids:
                        print('zheli', i)
                        malicious_average_participation_costs.append(sum(participation_cost_list)/len(participation_cost_list))
                    else:
                        benign_average_participation_costs.append(sum(participation_cost_list)/len(participation_cost_list))
                    
                # malicious_clients


                fig_name = '_'.join(['average_participation_costs', data_name, malicious_way, malicious_ratio, lamda])
                fig[fig_name] = plt.figure(fig_name)
                y = np.array(malicious_average_participation_costs)
                # y = np.array(mean_list[1::2])
                # yerr = np.array(std_list[1::2])
                # x = np.arange(len(y))
                x = np.array(malicious_client_ids)
                if 'fedincen' in exp_name:
                    algo_name = 'ICL'
                    key = f'{algo_name}_{malicious_ratio}_{lamda}'
                # else:
                #     key = f'{algo_name}_{malicious_ratio}'
                y1 = np.array(benign_average_participation_costs)
                # y = np.array(mean_list[1::2])
                # yerr = np.array(std_list[1::2])
                # x = np.arange(len(y))
                x1 = np.array(benign_client_ids)
                # if 'fedincen' in exp_name:
                #     key = f'{algo_name}_{malicious_ratio}_{lamda}'
                # else:
                #     key = f'{algo_name}_{malicious_ratio}'
                benign_average_participation_costs = [40*i for i in benign_average_participation_costs]
                plt.scatter(x1, y1, s=benign_average_participation_costs, color=color[key], linestyle=linestyle[key], label='benign_clients')

                malicious_average_participation_costs = [60*i for i in malicious_average_participation_costs]
                key = f'{algo_name}_{malicious_ratio}_{lamda}_malicious'
                plt.scatter(x, y, s=malicious_average_participation_costs, color=color[key], label='malicious_clients')

            
                # plt.fill_between(x, (y - yerr), (y + yerr), color=color[key], alpha=.1)
                # label = f'{key}'
                # plt.errorbar(x, y, color=color[key], linestyle=linestyle[key],
                #     label=label)
                plt.xlabel('All Clients', fontsize=fontsize['label'])
                plt.ylabel('Average Participation Costs', fontsize=fontsize['label'])
                plt.xticks(fontsize=fontsize['ticks'])
                plt.yticks(fontsize=fontsize['ticks'])
                plt.legend(loc=loc_dict['average'], fontsize=fontsize['legend'])



                # for quadratic, get even indeices
            #     fig_name = '_'.join([data_split_mode, 'Quadratic Loss'])
            #     fig[fig_name] = plt.figure(fig_name)
            #     y = np.array(mean_list[::2])
            #     yerr = np.array(std_list[::2])
            #     x = np.arange(len(y))
            #     key = f'active_rate_{active_rate}'
            #     plt.plot(x, y, color=color[key], linestyle=linestyle[key])
            #     # plt.fill_between(x, (y - yerr), (y + yerr), color=color[key], alpha=.1)
            #     label = f'{key}'
            #     plt.errorbar(x, y, yerr=yerr, color=color[key], linestyle=linestyle[key],
            #         label=label)
            #     plt.xlabel('Group Client Size', fontsize=fontsize['label'])
            #     plt.ylabel('Quadratic Loss', fontsize=fontsize['label'])
            #     plt.xticks(x, x_axis, fontsize=fontsize['ticks'])
            #     plt.yticks(fontsize=fontsize['ticks'])
            #     plt.legend(loc=loc_dict['Loss'], fontsize=fontsize['legend'])
            # a = 5

        elif len(control) == 19:
            data_name, model_name, active_rate, num_clients, lr, malicious_way, \
            malicious_ratio, _, data_split, algo_name, _, lamda, objective_sigmoid_s,\
                objective_func_lr, _, _, interval, sample_portion, pricing_plan = control
            print('~~~')
            # participation_clients_participation_costs_history = []
            # active_clients_participation_costs_history = []
            # malicious_clients_participation_costs = []
            # benign_clients_participation_costs = []
            # malicious_clients_participation_costs_num = []
            # benign_clients_participation_costs_num = []
            # malicious_client_ids = []
            # for k in extracted_processed_result_history[exp_name]:
            #     a = k
                # if 'participation_clients_participation_costs_history' in k:
                #     participation_clients_participation_costs_history.append(row)
                # elif 'active_clients_participation_costs_history' in k:
                #     active_clients_participation_costs_history.append(row)
                # elif 'malicious_clients_participation_costs' in k:
                #     malicious_clients_participation_costs.append(row)
                # elif 'benign_clients_participation_costs' in k:
                #     benign_clients_participation_costs.append(row)
                # elif 'malicious_clients_participation_costs_num' in k:
                #     malicious_clients_participation_costs_num.append(row)
                # elif 'benign_clients_participation_costs_num' in k:
                #     benign_clients_participation_costs_num.append(row)
                # elif 'malicious_client_ids' in k:
                #     malicious_client_ids.append(row)
                # print('row', row)

            # for performance
            if global_figure_indicator == 'test_acc':
                print('!!!!!')

                # fig_name = '_'.join([data_name, malicious_way, malicious_ratio, sample_portion, 'z_diff'])
                # fig[fig_name] = plt.figure(fig_name)
                
                # # use all local datasets
                # if sample_portion == '1':
                #     if algo_name != 'fedavg':
                #         y = np.array(extracted_processed_result_history[exp_name]['train/malicious_client_z_diff_mean_history'][0])
                #         yerr = np.array(extracted_processed_result_history[exp_name]['train/malicious_client_z_diff_std_history'][0])
                #         x = np.arange(len(y))
                #         if 'simpfedincen' in exp_name:
                #             algo_name = 'ICL'
                #             key = f'{algo_name}_{interval}_interval_malicious'
                #         else:
                #             algo_name = 'FedAvg'
                #             key = f'{algo_name}'
                #         plt.plot(x, y, color=color[key], linestyle=linestyle[key], label=key)
                #         plt.fill_between(x, (y - yerr), (y + yerr), color=color[key], alpha=.1)

                #         if 'simpfedincen' in exp_name:
                #             algo_name = 'ICL'
                #             key = f'{algo_name}_{interval}_interval_benign'
                #         else:
                #             algo_name = 'FedAvg'
                #             key = f'{algo_name}'
                #         y1 = np.array(extracted_processed_result_history[exp_name]['train/z_diff_mean_history'][0])
                #         yerr1 = np.array(extracted_processed_result_history[exp_name]['train/z_diff_std_history'][0])
                #         plt.plot(x, y1, color=color[key], linestyle=linestyle[key], label=key)
                #         plt.fill_between(x, (y1 - yerr1), (y1 + yerr1), color=color[key], alpha=.1)
                #         plt.xlabel('Communication Rounds', fontsize=fontsize['label'])
                #         plt.ylabel('z_diff', fontsize=fontsize['label'])
                #         plt.xticks(fontsize=fontsize['ticks'])
                #         plt.yticks(fontsize=fontsize['ticks'])
                #         plt.legend(loc=loc_dict['Accuracy'], fontsize=fontsize['legend'])


                #         fig_name = '_'.join([data_name, malicious_way, malicious_ratio, sample_portion, 'grouping'])
                #         fig[fig_name] = plt.figure(fig_name)
                #         z_diff_list = extracted_processed_result_history[exp_name]['train/z_diff_history'][0]
                #         malicious_client_z_diff = extracted_processed_result_history[exp_name]['train/malicious_client_z_diff_history'][0]

                #         # print('z_diff_list', z_diff_list)
                #         # print("\n")
                #         print('malicious_client_z_diff', malicious_client_z_diff)
                #         print("\n")

                #         benign_client_num = len(z_diff_list[0])
                #         malicious_client_num = len(malicious_client_z_diff[0])

                #         if malicious_client_num != 0:
                #             benign_identify_ratio = []
                #             malicious_identify_ratio = []
                #             # for i in range(3, len(z_diff_list)):
                #             #     cur_combine_list = z_diff_list[i] + malicious_client_z_diff[i]
                #             #     jnb = JenksNaturalBreaks(2)
                #             #     jnb.fit(cur_combine_list)
                #             #     print(jnb.breaks_)
                #             #     for group in jnb.groups_:
                #             #         if len(group) > 50:
                #             #             intersect_num = len(np.intersect1d(np.array(group), np.array(z_diff_list[i])))
                #             #             # print('benign', intersect_num / benign_client_num, "\n")
                #             #             benign_identify_ratio.append(intersect_num / benign_client_num)
                #             #         else:
                #             #             intersect_num = len(np.intersect1d(np.array(group), np.array(malicious_client_z_diff[i])))
                #             #             # print('malicious', intersect_num / malicious_client_num, "\n")
                #             #             malicious_identify_ratio.append(intersect_num / malicious_client_num)

                #             benign_identify_ratio = extracted_processed_result_history[exp_name]['train/benign_identify_ratio_history'][0]
                #             malicious_identify_ratio = extracted_processed_result_history[exp_name]['train/malicious_identify_ratio_history'][0]
                #                 # jnb.fit(all_z_diff)
                #                 # for i in range(cfg['num_clients']):
                #                 #     # benign_client
                #                 #     if all_z_diff[i] <= jnb.breaks_[1]:
                #                 #         cur_round_participation_client_ids.append(i)
                                
                #             y = np.array(benign_identify_ratio)
                #             yerr = np.array(y)
                #             x = np.arange(len(y))
                #             if 'simpfedincen' in exp_name:
                #                 algo_name = 'ICL'
                #                 key = f'{algo_name}_cur_round_benign'
                #             else:
                #                 algo_name = 'FedAvg'
                #                 key = f'{algo_name}'
                #             points_to_kept = y!=0
                #             new_x = x[points_to_kept]
                #             new_y = y[points_to_kept]
                #             new_yerr = yerr[points_to_kept]
                #             plt.plot(new_x, new_y, color=color[key], linestyle=linestyle[key], label=key)
                #             # plt.fill_between(x, (y1 - yerr1), (y1 + yerr1), color=color[key], alpha=.1)
                #             plt.xlabel('Communication Rounds', fontsize=fontsize['label'])
                #             plt.ylabel('Grouping Accuracy', fontsize=fontsize['label'])
                #             plt.xticks(fontsize=fontsize['ticks'])
                #             plt.yticks(fontsize=fontsize['ticks'])
                #             plt.legend(loc=loc_dict['Accuracy'], fontsize=fontsize['legend'])


                #         # plt.fill_between(new_x, (new_y - new_yerr), (new_y + new_yerr), color=color[key], alpha=.1)

                #             y = np.array(malicious_identify_ratio)
                #             yerr = np.array(y)
                #             x = np.arange(len(y))
                #             if 'simpfedincen' in exp_name:
                #                 algo_name = 'ICL'
                #                 key = f'{algo_name}_cur_round_malicious'
                #             else:
                #                 algo_name = 'FedAvg'
                #                 key = f'{algo_name}'
                #             points_to_kept = y!=0
                #             new_x = x[points_to_kept]
                #             new_y = y[points_to_kept]
                #             new_yerr = yerr[points_to_kept]
                #             plt.plot(new_x, new_y, color=color[key], linestyle=linestyle[key], label=key)
                #             plt.xlabel('Communication Rounds', fontsize=fontsize['label'])
                #             plt.ylabel('Grouping Accuracy', fontsize=fontsize['label'])
                #             plt.xticks(fontsize=fontsize['ticks'])
                #             plt.yticks(fontsize=fontsize['ticks'])
                #             plt.legend(loc=loc_dict['Accuracy'], fontsize=fontsize['legend'])


                #             # cur_round_z_diff
                #             fig_name = '_'.join([data_name, malicious_way, malicious_ratio, sample_portion, 'cur_round_z_diff'])
                #             fig[fig_name] = plt.figure(fig_name)


                #             y = np.array(extracted_processed_result_history[exp_name]['train/cur_round_malicious_client_z_diff_mean_history'][0])
                #             yerr = np.array(extracted_processed_result_history[exp_name]['train/cur_round_malicious_client_z_diff_std_history'][0])
                #             x = np.arange(len(y))
                #             if 'simpfedincen' in exp_name:
                #                 algo_name = 'ICL'
                #                 key = f'{algo_name}_cur_round_malicious'
                #             else:
                #                 algo_name = 'FedAvg'
                #                 key = f'{algo_name}'
                #             points_to_kept = y!=0
                #             new_x = x[points_to_kept]
                #             new_y = y[points_to_kept]
                #             new_yerr = yerr[points_to_kept]
                #             plt.plot(new_x, new_y, color=color[key], linestyle=linestyle[key], label=key)
                #             plt.fill_between(new_x, (new_y - new_yerr), (new_y + new_yerr), color=color[key], alpha=.1)
                #             # plt.plot(x, y, color=color[key], linestyle=linestyle[key], label=key)
                #             # plt.fill_between(x, (y - yerr), (y + yerr), color=color[key], alpha=.1)

                #             if 'simpfedincen' in exp_name:
                #                 algo_name = 'ICL'
                #                 key = f'{algo_name}_cur_round_benign'
                #             else:
                #                 algo_name = 'FedAvg'
                #                 key = f'{algo_name}'
                #             y1 = np.array(extracted_processed_result_history[exp_name]['train/cur_round_z_diff_mean_history'][0])
                #             yerr1 = np.array(extracted_processed_result_history[exp_name]['train/cur_round_z_diff_std_history'][0])
                #             plt.plot(x, y1, color=color[key], linestyle=linestyle[key], label=key)
                #             plt.fill_between(x, (y1 - yerr1), (y1 + yerr1), color=color[key], alpha=.1)
                #             plt.xlabel('Communication Rounds', fontsize=fontsize['label'])
                #             plt.ylabel('Accuracy', fontsize=fontsize['label'])
                #             plt.xticks(fontsize=fontsize['ticks'])
                #             plt.yticks(fontsize=fontsize['ticks'])
                #             plt.legend(loc=loc_dict['Accuracy'], fontsize=fontsize['legend'])


                    # plt.fill_between(new_x, (new_y - new_yerr), (new_y + new_yerr), color=color[key], alpha=.1)

                fig_name = '_'.join([data_name, malicious_way, malicious_ratio, 'acc'])

                # fig_name = '_'.join([data_name, 'acc'])
                fig[fig_name] = plt.figure(fig_name)
                y = np.array(extracted_processed_result_history[exp_name]['test_server/Accuracy_history'][0])
                # y = np.array(mean_list[1::2])
                # yerr = np.array(std_list[1::2])
                x = np.arange(len(y))
                if 'fedincen' in exp_name:
                    algo_name = 'ICL'
                    key = f'{algo_name}_plan_{pricing_plan}'
                else:
                    algo_name = 'FedAvg'
                    key = f'{algo_name}_{active_rate}'
                plt.plot(x, y, color=color[key], linestyle=linestyle[key], label=key)
                # plt.fill_between(x, (y - yerr), (y + yerr), color=color[key], alpha=.1)
                # label = f'{key}'
                # plt.errorbar(x, y, color=color[key], linestyle=linestyle[key],
                #     label=label)
                plt.xlabel('Communication Rounds', fontsize=fontsize['label'])
                plt.ylabel('Accuracy', fontsize=fontsize['label'])
                plt.xticks(fontsize=fontsize['ticks'])
                plt.yticks(fontsize=fontsize['ticks'])
                plt.legend(loc=loc_dict['Accuracy'], fontsize=fontsize['legend'])


        # control_name = [[['FashionMNIST', 'CIFAR10', 'CINIC10'], ['cnn'], ['0.1'], ['100'], ['0.03'], ['labelflipping', 'signflipping'], 
        #                 ['0.2', '0.3', '0.4'], ['1'], ['iid-equal'], 
        #                  ['simpfedincen'], ['epoch-5'], ['0.1'], 
        #                 ['0.005'], ['0.0001'], ['-0.1'], ['0.3'], ['3'], ['1'], ['1', '2', '3']]]
        #     a = 5
        # if len(df_name_list) == 9:
        #     data_name, model_name, active_rate, num_clients, data_split_mode, algo_mode, \
        #     max_gradient_update_num, max_combination_size, selection_method = df_name_list
        #     # _, _, _, _, data_split_mode, _, \
        #     # _, _, _ = df_name_list
        #     # df_name_std = '_'.join([data_name, model_name, num_supervised, metric_name, 'std'])
        #     # fig_name = '_'.join([data_name, model_name, num_supervised, metric_name])
        #     # fig[fig_name] = plt.figure(fig_name)
            
        #     # data_split_mode_dict = {
        #     #     'non-iid-d-0.1': []
        #     #     'non-iid-d-0.3': []
        #     #     'non-iid-l-1': []
        #     #     'non-iid-l-2': []
        #     # }

        #     index_list = []
        #     mean_list = []
        #     std_list = []
        #     for index, row in df_history[df_name].iterrows():
        #         if 'mean' in index:
        #             index_list.append(index)
        #             mean_list.append(np.mean(row))
        #         elif 'std' in index:
        #             std_list.append(np.mean(row))
        #         else:
        #             raise ValueError('wrong index')
            
        #     x_axis = [i for i in range(1,11)]
        #     # for KL, get odd indices
        #     fig_name = '_'.join([data_split_mode, 'KL Divergence'])
        #     fig[fig_name] = plt.figure(fig_name)
        #     y = np.array(mean_list[1::2])
        #     yerr = np.array(std_list[1::2])
        #     x = np.arange(len(y))
        #     key = f'active_rate_{active_rate}'
        #     plt.plot(x, y, color=color[key], linestyle=linestyle[key])
        #     # plt.fill_between(x, (y - yerr), (y + yerr), color=color[key], alpha=.1)
        #     label = f'{key}'
        #     plt.errorbar(x, y, yerr=yerr, color=color[key], linestyle=linestyle[key],
        #         label=label)
        #     plt.xlabel('Group Client Size', fontsize=fontsize['label'])
        #     plt.ylabel('KL Divergence', fontsize=fontsize['label'])
        #     plt.xticks(x, x_axis, fontsize=fontsize['ticks'])
        #     plt.yticks(fontsize=fontsize['ticks'])
        #     plt.legend(loc=loc_dict['Loss'], fontsize=fontsize['legend'])
        #     a = 5

        #     # for quadratic, get even indeices
        #     fig_name = '_'.join([data_split_mode, 'Quadratic Loss'])
        #     fig[fig_name] = plt.figure(fig_name)
        #     y = np.array(mean_list[::2])
        #     yerr = np.array(std_list[::2])
        #     x = np.arange(len(y))
        #     key = f'active_rate_{active_rate}'
        #     plt.plot(x, y, color=color[key], linestyle=linestyle[key])
        #     # plt.fill_between(x, (y - yerr), (y + yerr), color=color[key], alpha=.1)
        #     label = f'{key}'
        #     plt.errorbar(x, y, yerr=yerr, color=color[key], linestyle=linestyle[key],
        #         label=label)
        #     plt.xlabel('Group Client Size', fontsize=fontsize['label'])
        #     plt.ylabel('Quadratic Loss', fontsize=fontsize['label'])
        #     plt.xticks(x, x_axis, fontsize=fontsize['ticks'])
        #     plt.yticks(fontsize=fontsize['ticks'])
        #     plt.legend(loc=loc_dict['Loss'], fontsize=fontsize['legend'])
        #     a = 5
   
    for fig_name in fig:
        fig[fig_name] = plt.figure(fig_name)
        plt.grid()
        fig_path = '{}/{}.{}'.format(vis_path, fig_name, save_format)
        makedir_exist_ok(vis_path)
        plt.savefig(fig_path, dpi=500, bbox_inches='tight', pad_inches=0)
        plt.close(fig_name)
    return

# def make_vis(df_exp, df_history):
#     data_split_mode_dict = {'iid': 'IID', 'non-iid-l-2': 'Non-IID, $K=2$',
#                             'non-iid-d-0.1': 'Non-IID, $\operatorname{Dir}(0.1)$',
#                             'non-iid-d-0.3': 'Non-IID, $\operatorname{Dir}(0.3)$', 'fix-fsgd': 'DynamicSgd + FixMatch',
#                             'fix-batch': 'FedAvg + FixMatch', 'fs': 'Fully Supervised', 'ps': 'Partially Supervised'}
    

#     color = {'5_0.5': 'red', '1_0.5': 'orange', '5_0': 'dodgerblue', '5_0.9': 'blue', '5_0.5_nomixup': 'green',
#              '5_0_nomixup': 'green', 'iid': 'red', 'non-iid-l-2': 'orange', 'non-iid-d-0.1': 'dodgerblue',
#              'non-iid-d-0.3': 'green', 'fix-fsgd': 'red', 'fix-batch': 'blue',
#              'fs': 'black', 'ps': 'orange',
#              'active_rate_0.1': 'green',
#              'active_rate_0.3': 'red',
#              'active_rate_0.5': 'dodgerblue',
#              'malicious_clients': 'green',
#              'benign_clients': 'red',
#              'partcipation_clients': 'dodgerblue',
#              'active_clients': 'orange'
#              }
#     linestyle = {'5_0.5': '-', '1_0.5': '--', '5_0': ':', '5_0.5_nomixup': '-.', '5_0_nomixup': '-.',
#                  '5_0.9': (0, (1, 5)), 'iid': '-', 'non-iid-l-2': '--', 'non-iid-d-0.1': '-.', 'non-iid-d-0.3': ':',
#                  'fix-fsgd': '--', 'fix-batch': ':', 'fs': '-', 'ps': '-.', ''
#                  'active_rate_0.1': ':',
#                  'active_rate_0.3': '-',
#                  'active_rate_0.5': '-.',
#                  'malicious_clients': ':',
#                 'benign_clients': '-',
#                 'partcipation_clients': '-.',
#                 'active_clients': '--'
#                  }

#     loc_dict = {'Accuracy': 'lower right', 'Loss': 'upper right'}
#     fontsize = {'legend': 16, 'label': 16, 'ticks': 16}
#     fig = {}
#     reorder_fig = []
#     for df_name in df_history:
#         df_name_list = df_name.split('_')
#         if len(df_name_list) == 8:
#             data_name, model_name, active_rate, num_clients, malicious_way, malicious_ratio, \
#                 algo_name, lamda = df_name_list
            
#             participation_clients_participation_costs_history = []
#             active_clients_participation_costs_history = []
#             malicious_clients_participation_costs = []
#             benign_clients_participation_costs = []
#             malicious_clients_participation_costs_num = []
#             benign_clients_participation_costs_num = []
#             malicious_client_ids = []
#             for index, row in df_history[df_name].iterrows():
#                 a = index 
#                 b = row[index]
#                 print('index', index)
#                 if 'participation_clients_participation_costs_history' in index:
#                     participation_clients_participation_costs_history.append(row)
#                 elif 'active_clients_participation_costs_history' in index:
#                     active_clients_participation_costs_history.append(row)
#                 elif 'malicious_clients_participation_costs' in index:
#                     malicious_clients_participation_costs.append(row)
#                 elif 'benign_clients_participation_costs' in index:
#                     benign_clients_participation_costs.append(row)
#                 elif 'malicious_clients_participation_costs_num' in index:
#                     malicious_clients_participation_costs_num.append(row)
#                 elif 'benign_clients_participation_costs_num' in index:
#                     benign_clients_participation_costs_num.append(row)
#                 elif 'malicious_client_ids' in index:
#                     malicious_client_ids.append(row)
#                 # print('row', row)

#             x_axis = [i for i in range(1,11)]
#             # for KL, get odd indices
#             fig_name = '_'.join([data_split_mode, 'KL Divergence'])
#             fig[fig_name] = plt.figure(fig_name)
#             y = np.array(mean_list[1::2])
#             yerr = np.array(std_list[1::2])
#             x = np.arange(len(y))
#             key = f'active_rate_{active_rate}'
#             plt.plot(x, y, color=color[key], linestyle=linestyle[key])
#             # plt.fill_between(x, (y - yerr), (y + yerr), color=color[key], alpha=.1)
#             label = f'{key}'
#             plt.errorbar(x, y, yerr=yerr, color=color[key], linestyle=linestyle[key],
#                 label=label)
#             plt.xlabel('Group Client Size', fontsize=fontsize['label'])
#             plt.ylabel('KL Divergence', fontsize=fontsize['label'])
#             plt.xticks(x, x_axis, fontsize=fontsize['ticks'])
#             plt.yticks(fontsize=fontsize['ticks'])
#             plt.legend(loc=loc_dict['Loss'], fontsize=fontsize['legend'])
#             a = 5

#             # for quadratic, get even indices
#             fig_name = '_'.join([data_split_mode, 'Quadratic Loss'])
#             fig[fig_name] = plt.figure(fig_name)
#             y = np.array(mean_list[::2])
#             yerr = np.array(std_list[::2])
#             x = np.arange(len(y))
#             key = f'active_rate_{active_rate}'
#             plt.plot(x, y, color=color[key], linestyle=linestyle[key])
#             # plt.fill_between(x, (y - yerr), (y + yerr), color=color[key], alpha=.1)
#             label = f'{key}'
#             plt.errorbar(x, y, yerr=yerr, color=color[key], linestyle=linestyle[key],
#                 label=label)
#             plt.xlabel('Group Client Size', fontsize=fontsize['label'])
#             plt.ylabel('Quadratic Loss', fontsize=fontsize['label'])
#             plt.xticks(x, x_axis, fontsize=fontsize['ticks'])
#             plt.yticks(fontsize=fontsize['ticks'])
#             plt.legend(loc=loc_dict['Loss'], fontsize=fontsize['legend'])
#             a = 5



#             a = 5
#         if len(df_name_list) == 9:
#             data_name, model_name, active_rate, num_clients, data_split_mode, algo_mode, \
#             max_gradient_update_num, max_combination_size, selection_method = df_name_list
#             # _, _, _, _, data_split_mode, _, \
#             # _, _, _ = df_name_list
#             # df_name_std = '_'.join([data_name, model_name, num_supervised, metric_name, 'std'])
#             # fig_name = '_'.join([data_name, model_name, num_supervised, metric_name])
#             # fig[fig_name] = plt.figure(fig_name)
            
#             # data_split_mode_dict = {
#             #     'non-iid-d-0.1': []
#             #     'non-iid-d-0.3': []
#             #     'non-iid-l-1': []
#             #     'non-iid-l-2': []
#             # }

#             index_list = []
#             mean_list = []
#             std_list = []
#             for index, row in df_history[df_name].iterrows():
#                 if 'mean' in index:
#                     index_list.append(index)
#                     mean_list.append(np.mean(row))
#                 elif 'std' in index:
#                     std_list.append(np.mean(row))
#                 else:
#                     raise ValueError('wrong index')
            
#             x_axis = [i for i in range(1,11)]
#             # for KL, get odd indices
#             fig_name = '_'.join([data_split_mode, 'KL Divergence'])
#             fig[fig_name] = plt.figure(fig_name)
#             y = np.array(mean_list[1::2])
#             yerr = np.array(std_list[1::2])
#             x = np.arange(len(y))
#             key = f'active_rate_{active_rate}'
#             plt.plot(x, y, color=color[key], linestyle=linestyle[key])
#             # plt.fill_between(x, (y - yerr), (y + yerr), color=color[key], alpha=.1)
#             label = f'{key}'
#             plt.errorbar(x, y, yerr=yerr, color=color[key], linestyle=linestyle[key],
#                 label=label)
#             plt.xlabel('Group Client Size', fontsize=fontsize['label'])
#             plt.ylabel('KL Divergence', fontsize=fontsize['label'])
#             plt.xticks(x, x_axis, fontsize=fontsize['ticks'])
#             plt.yticks(fontsize=fontsize['ticks'])
#             plt.legend(loc=loc_dict['Loss'], fontsize=fontsize['legend'])
#             a = 5

#             # for quadratic, get even indices
#             fig_name = '_'.join([data_split_mode, 'Quadratic Loss'])
#             fig[fig_name] = plt.figure(fig_name)
#             y = np.array(mean_list[::2])
#             yerr = np.array(std_list[::2])
#             x = np.arange(len(y))
#             key = f'active_rate_{active_rate}'
#             plt.plot(x, y, color=color[key], linestyle=linestyle[key])
#             # plt.fill_between(x, (y - yerr), (y + yerr), color=color[key], alpha=.1)
#             label = f'{key}'
#             plt.errorbar(x, y, yerr=yerr, color=color[key], linestyle=linestyle[key],
#                 label=label)
#             plt.xlabel('Group Client Size', fontsize=fontsize['label'])
#             plt.ylabel('Quadratic Loss', fontsize=fontsize['label'])
#             plt.xticks(x, x_axis, fontsize=fontsize['ticks'])
#             plt.yticks(fontsize=fontsize['ticks'])
#             plt.legend(loc=loc_dict['Loss'], fontsize=fontsize['legend'])
#             a = 5





#             # for ((index, row), (_, row_std)) in zip(df_history[df_name].iterrows(), df_history[df_name_std].iterrows()):
#             #     y = row.to_numpy()
#             #     yerr = row_std.to_numpy()
#             #     x = np.arange(len(y))
#             #     plt.plot(x, y, color='r', linestyle='-')
#             #     plt.fill_between(x, (y - yerr), (y + yerr), color='r', alpha=.1)
#             #     plt.xlabel('Communication Rounds', fontsize=fontsize['label'])
#             #     plt.ylabel(metric_name, fontsize=fontsize['label'])
#             #     plt.xticks(fontsize=fontsize['ticks'])
#             #     plt.yticks(fontsize=fontsize['ticks'])

#         # if len(df_name_list) == 5:
#         #     data_name, model_name, num_supervised, metric_name, stat = df_name.split('_')
#         #     if stat == 'std':
#         #         continue
#         #     df_name_std = '_'.join([data_name, model_name, num_supervised, metric_name, 'std'])
#         #     fig_name = '_'.join([data_name, model_name, num_supervised, metric_name])
#         #     fig[fig_name] = plt.figure(fig_name)
#         #     for ((index, row), (_, row_std)) in zip(df_history[df_name].iterrows(), df_history[df_name_std].iterrows()):
#         #         y = row.to_numpy()
#         #         yerr = row_std.to_numpy()
#         #         x = np.arange(len(y))
#         #         plt.plot(x, y, color='r', linestyle='-')
#         #         plt.fill_between(x, (y - yerr), (y + yerr), color='r', alpha=.1)
#         #         plt.xlabel('Communication Rounds', fontsize=fontsize['label'])
#         #         plt.ylabel(metric_name, fontsize=fontsize['label'])
#         #         plt.xticks(fontsize=fontsize['ticks'])
#         #         plt.yticks(fontsize=fontsize['ticks'])
#         # elif len(df_name_list) == 10:
#         #     data_name, model_name, num_supervised, loss_mode, num_clients, active_rate, data_split_mode, sbn, \
#         #     metric_name, stat = df_name.split('_')
#         #     if stat == 'std':
#         #         continue
#         #     df_name_std = '_'.join(
#         #         [data_name, model_name, num_supervised, loss_mode, num_clients, active_rate, data_split_mode, sbn,
#         #          metric_name, 'std'])
#         #     for ((index, row), (_, row_std)) in zip(df_history[df_name].iterrows(), df_history[df_name_std].iterrows()):
#         #         y = row.to_numpy()
#         #         yerr = row_std.to_numpy()
#         #         x = np.arange(len(y))
#         #         if index == '5_0.5' and loss_mode == 'fix-mix':
#         #             fig_name = '_'.join(
#         #                 [data_name, model_name, num_supervised, loss_mode, num_clients, active_rate, sbn,
#         #                  metric_name])
#         #             reorder_fig.append(fig_name)
#         #             label_name = '{}'.format(data_split_mode_dict[data_split_mode])
#         #             style = data_split_mode
#         #             fig[fig_name] = plt.figure(fig_name)
#         #             plt.plot(x, y, color=color[style], linestyle=linestyle[style], label=label_name)
#         #             plt.fill_between(x, (y - yerr), (y + yerr), color=color[style], alpha=.1)
#         #             plt.legend(loc=loc_dict[metric_name], fontsize=fontsize['legend'])
#         #             plt.xlabel('Communication Rounds', fontsize=fontsize['label'])
#         #             plt.ylabel(metric_name, fontsize=fontsize['label'])
#         #             plt.xticks(fontsize=fontsize['ticks'])
#         #             plt.yticks(fontsize=fontsize['ticks'])
#         #         if data_split_mode in ['iid', 'non-iid-l-2'] and loss_mode not in ['fix-batch', 'fix-fsgd', 'fix-frgd']:
#         #             fig_name = '_'.join(
#         #                 [data_name, model_name, num_supervised, num_clients, active_rate, data_split_mode, sbn,
#         #                  metric_name])
#         #             reorder_fig.append(fig_name)
#         #             fig[fig_name] = plt.figure(fig_name)
#         #             local_epoch, gm = index.split('_')
#         #             if loss_mode == 'fix':
#         #                 label_name = '$E={}$, $\\beta_g={}$, No mixup'.format(local_epoch, gm)
#         #                 style = '{}_nomixup'.format(index)
#         #             else:
#         #                 label_name = '$E={}$, $\\beta_g={}$'.format(local_epoch, gm)
#         #                 style = index
#         #             plt.plot(x, y, color=color[style], linestyle=linestyle[style], label=label_name)
#         #             plt.fill_between(x, (y - yerr), (y + yerr), color=color[style], alpha=.1)
#         #             plt.legend(loc=loc_dict[metric_name], fontsize=fontsize['legend'])
#         #             plt.xlabel('Communication Rounds', fontsize=fontsize['label'])
#         #             plt.ylabel(metric_name, fontsize=fontsize['label'])
#         #             plt.xticks(fontsize=fontsize['ticks'])
#         #             plt.yticks(fontsize=fontsize['ticks'])
#         #         if data_split_mode in ['iid', 'non-iid-l-2'] and loss_mode == 'fix-fsgd':
#         #             fix_batch_df_name = '_'.join(
#         #                 [data_name, model_name, num_supervised, 'fix-batch', num_clients, active_rate, data_split_mode,
#         #                  sbn, '0', metric_name, 'mean'])
#         #             fix_batch_df_name_std = '_'.join(
#         #                 [data_name, model_name, num_supervised, 'fix-batch', num_clients, active_rate, data_split_mode,
#         #                  sbn, '0', metric_name, 'std'])
#         #             fix_batch_y = list(df_history[fix_batch_df_name].iterrows())[0][1]
#         #             fix_batch_y_yerr = list(df_history[fix_batch_df_name_std].iterrows())[0][1]
#         #             fs_df_name = '_'.join([data_name, model_name, 'fs'])
#         #             fs_df_name_std = '_'.join([data_name, model_name, 'fs'])
#         #             fs_y = list(df_exp[fs_df_name].iterrows())[0][1]['{}_mean'.format(metric_name)]
#         #             fs_y_yerr = list(df_exp[fs_df_name_std].iterrows())[0][1]['{}_std'.format(metric_name)]
#         #             ps_df_name = '_'.join([data_name, model_name, num_supervised])
#         #             ps_df_name_std = '_'.join([data_name, model_name, num_supervised])
#         #             ps_y = list(df_exp[ps_df_name].iterrows())[0][1]['{}_mean'.format(metric_name)]
#         #             ps_y_yerr = list(df_exp[ps_df_name_std].iterrows())[0][1]['{}_std'.format(metric_name)]
#         #             fig_name = '_'.join(
#         #                 [data_name, model_name, num_supervised, num_clients, active_rate, data_split_mode, sbn,
#         #                  metric_name, 'fsgd'])
#         #             reorder_fig.append(fig_name)
#         #             fig[fig_name] = plt.figure(fig_name)
#         #             label_name = '{}'.format(data_split_mode_dict['fix-fsgd'])
#         #             style = 'fix-fsgd'
#         #             plt.plot(x, y, color=color[style], linestyle=linestyle[style], label=label_name)
#         #             plt.fill_between(x, (y - yerr), (y + yerr), color=color[style], alpha=.1)
#         #             label_name = '{}'.format(data_split_mode_dict['fix-batch'])
#         #             style = 'fix-batch'
#         #             plt.plot(x, fix_batch_y, color=color[style], linestyle=linestyle[style], label=label_name)
#         #             plt.fill_between(x, (fix_batch_y - fix_batch_y_yerr), (fix_batch_y + fix_batch_y_yerr),
#         #                              color=color[style], alpha=.1)
#         #             label_name = '{}'.format(data_split_mode_dict['fs'])
#         #             style = 'fs'
#         #             plt.plot(x, np.repeat(fs_y, len(x)), color=color[style], linestyle=linestyle[style],
#         #                      label=label_name)
#         #             plt.fill_between(x, np.repeat(fs_y - fs_y_yerr, len(x)), np.repeat(fs_y + fs_y_yerr, len(x)),
#         #                              color=color[style], alpha=.1)
#         #             label_name = '{}'.format(data_split_mode_dict['ps'])
#         #             style = 'ps'
#         #             plt.plot(x, np.repeat(ps_y, len(x)), color=color[style], linestyle=linestyle[style],
#         #                      label=label_name)
#         #             plt.fill_between(x, np.repeat(ps_y - ps_y_yerr, len(x)), np.repeat(ps_y + ps_y_yerr, len(x)),
#         #                              color=color[style], alpha=.1)
#         #             plt.legend(loc=loc_dict[metric_name], fontsize=fontsize['legend'])
#         #             plt.xlabel('Communication Rounds', fontsize=fontsize['label'])
#         #             plt.ylabel(metric_name, fontsize=fontsize['label'])
#         #             plt.xticks(fontsize=fontsize['ticks'])
#         #             plt.yticks(fontsize=fontsize['ticks'])
#     # for fig_name in reorder_fig:
#     #     fig_name_list = fig_name.split('_')
#     #     data_name, model_name, num_supervised, loss_mode, num_clients, active_rate, sbn, metric_name = fig_name_list[:8]
#     #     plt.figure(fig_name)
#     #     handles, labels = plt.gca().get_legend_handles_labels()
#     #     if len(fig_name_list) == 9:
#     #         if len(handles) == 4:
#     #             handles = [handles[2], handles[3], handles[0], handles[1]]
#     #             labels = [labels[2], labels[3], labels[0], labels[1]]
#     #             plt.legend(handles, labels, loc=loc_dict[metric_name], fontsize=fontsize['legend'])
#     #     else:
#     #         if len(handles) == 4:
#     #             handles = [handles[0], handles[3], handles[2], handles[1]]
#     #             labels = [labels[0], labels[3], labels[2], labels[1]]
#     #             plt.legend(handles, labels, loc=loc_dict[metric_name], fontsize=fontsize['legend'])
#     #         if len(handles) == 5:
#     #             handles = [handles[0], handles[4], handles[2], handles[3], handles[1]]
#     #             labels = [labels[0], labels[4], labels[2], labels[3], labels[1]]
#     #             plt.legend(handles, labels, loc=loc_dict[metric_name], fontsize=fontsize['legend'])
#     for fig_name in fig:
#         fig[fig_name] = plt.figure(fig_name)
#         plt.grid()
#         fig_path = '{}/{}.{}'.format(vis_path, fig_name, save_format)
#         makedir_exist_ok(vis_path)
#         plt.savefig(fig_path, dpi=500, bbox_inches='tight', pad_inches=0)
#         plt.close(fig_name)
#     return


# Script entry point: run the analysis pipeline only when this file is executed
# directly, so that importing the module does not invoke main().
if __name__ == "__main__":
    main()



# import os
# import itertools
# import json
# import numpy as np
# import pandas as pd
# from utils import save, load, makedir_exist_ok
# import matplotlib.pyplot as plt
# from collections import defaultdict

# result_path = './output/result'
# save_format = 'png'
# vis_path = './output/vis/{}'.format(save_format)
# num_experiments = 4
# exp = [str(x) for x in list(range(num_experiments))]


# def make_controls(data_names, model_names, control_name):
#     control_names = []
#     for i in range(len(control_name)):
#         control_names.extend(list('_'.join(x) for x in itertools.product(*control_name[i])))
#     controls = [exp] + data_names + model_names + [control_names]
#     controls = list(itertools.product(*controls))
#     return controls


# def make_control_list(file):
#     if file == 'fs':
#         control_name = [[['fs']]]
#         data_names = [['CIFAR10']]
#         model_names = [['wresnet28x2']]
#         cifar10_controls = make_controls(data_names, model_names, control_name)
#         data_names = [['SVHN']]
#         model_names = [['wresnet28x2']]
#         svhn_controls = make_controls(data_names, model_names, control_name)
#         data_names = [['CIFAR100']]
#         model_names = [['wresnet28x8']]
#         cifar100_controls = make_controls(data_names, model_names, control_name)
#         controls = cifar10_controls + svhn_controls + cifar100_controls
#     elif file == 'ps':
#         control_name = [[['250', '4000']]]
#         data_names = [['CIFAR10']]
#         model_names = [['wresnet28x2']]
#         cifar10_controls = make_controls(data_names, model_names, control_name)
#         control_name = [[['250', '1000']]]
#         data_names = [['SVHN']]
#         model_names = [['wresnet28x2']]
#         svhn_controls = make_controls(data_names, model_names, control_name)
#         control_name = [[['2500', '10000']]]
#         data_names = [['CIFAR100']]
#         model_names = [['wresnet28x8']]
#         cifar100_controls = make_controls(data_names, model_names, control_name)
#         controls = cifar10_controls + svhn_controls + cifar100_controls
#     elif file == 'cd':
#         control_name = [[['250', '4000'], ['fix-mix'], ['100'], ['0.1'], ['iid', 'non-iid-l-2'], ['5'], ['0.5'], ['1']]]
#         data_names = [['CIFAR10']]
#         model_names = [['wresnet28x2']]
#         cifar10_controls = make_controls(data_names, model_names, control_name)
#         control_name = [[['250', '1000'], ['fix-mix'], ['100'], ['0.1'], ['iid', 'non-iid-l-2'], ['5'], ['0.5'], ['1']]]
#         data_names = [['SVHN']]
#         model_names = [['wresnet28x2']]
#         svhn_controls = make_controls(data_names, model_names, control_name)
#         control_name = [[['2500', '10000'], ['fix-mix'], ['100'], ['0.1'], ['iid', 'non-iid-l-2'], ['5'], ['0.5'],
#                          ['1']]]
#         data_names = [['CIFAR100']]
#         model_names = [['wresnet28x8']]
#         cifar100_controls = make_controls(data_names, model_names, control_name)
#         controls = cifar10_controls + svhn_controls + cifar100_controls
#     elif file == 'ub':
#         control_name = [
#             [['250', '4000'], ['fix-mix'], ['100'], ['0.1'], ['non-iid-d-0.1', 'non-iid-d-0.3'], ['5'], ['0.5'], ['1']]]
#         data_names = [['CIFAR10']]
#         model_names = [['wresnet28x2']]
#         cifar10_controls = make_controls(data_names, model_names, control_name)
#         control_name = [
#             [['250', '1000'], ['fix-mix'], ['100'], ['0.1'], ['non-iid-d-0.1', 'non-iid-d-0.3'], ['5'], ['0.5'], ['1']]]
#         data_names = [['SVHN']]
#         model_names = [['wresnet28x2']]
#         svhn_controls = make_controls(data_names, model_names, control_name)
#         control_name = [[['2500', '10000'], ['fix-mix'], ['100'], ['0.1'], ['non-iid-d-0.1', 'non-iid-d-0.3'], ['5'],
#                          ['0.5'], ['1']]]
#         data_names = [['CIFAR100']]
#         model_names = [['wresnet28x8']]
#         cifar100_controls = make_controls(data_names, model_names, control_name)
#         controls = cifar10_controls + svhn_controls + cifar100_controls
#     elif file == 'loss':
#         control_name = [[['4000'], ['fix'], ['100'], ['0.1'], ['iid', 'non-iid-l-2'], ['5'], ['0.5'], ['1']]]
#         data_names = [['CIFAR10']]
#         model_names = [['wresnet28x2']]
#         cifar10_controls = make_controls(data_names, model_names, control_name)
#         controls = cifar10_controls
#     elif file == 'local-epoch':
#         control_name = [[['4000'], ['fix-mix'], ['100'], ['0.1'], ['iid', 'non-iid-l-2'], ['1'], ['0.5'], ['1']]]
#         data_names = [['CIFAR10']]
#         model_names = [['wresnet28x2']]
#         cifar10_controls = make_controls(data_names, model_names, control_name)
#         controls = cifar10_controls
#     elif file == 'gm':
#         control_name = [[['4000'], ['fix-mix'], ['100'], ['0.1'], ['iid', 'non-iid-l-2'], ['5'], ['0'], ['1']]]
#         data_names = [['CIFAR10']]
#         model_names = [['wresnet28x2']]
#         cifar10_controls = make_controls(data_names, model_names, control_name)
#         controls = cifar10_controls
#     elif file == 'sbn':
#         control_name = [[['250', '4000'], ['fix-mix'], ['100'], ['0.1'], ['iid', 'non-iid-l-2'], ['5'], ['0.5'], ['0']]]
#         data_names = [['CIFAR10']]
#         model_names = [['wresnet28x2']]
#         cifar10_controls = make_controls(data_names, model_names, control_name)
#         controls = cifar10_controls
#     elif file == 'alternate':
#         control_name = [[['4000'], ['fix-batch'], ['100'], ['0.1'], ['iid', 'non-iid-l-2'], ['5'], ['0.5'],
#                          ['1']]]
#         data_names = [['CIFAR10']]
#         model_names = [['wresnet28x2']]
#         cifar10_controls_1 = make_controls(data_names, model_names, control_name)
#         control_name = [[['4000'], ['fix', 'fix-batch'], ['100'], ['0.1'], ['iid', 'non-iid-l-2'], ['5'], ['0.5'],
#                          ['1'], ['0']]]
#         data_names = [['CIFAR10']]
#         model_names = [['wresnet28x2']]
#         cifar10_controls_2 = make_controls(data_names, model_names, control_name)
#         controls = cifar10_controls_1 + cifar10_controls_2
#     elif file == 'fl':
#         control_name = [
#             [['fs'], ['sup'], ['100'], ['0.1'], ['iid', 'non-iid-d-0.1', 'non-iid-d-0.3', 'non-iid-l-2'], ['5'],
#              ['0.5'], ['1']]]
#         data_names = [['CIFAR10']]
#         model_names = [['wresnet28x2']]
#         cifar10_controls = make_controls(data_names, model_names, control_name)
#         control_name = [
#             [['fs'], ['sup'], ['100'], ['0.1'], ['iid', 'non-iid-d-0.1', 'non-iid-d-0.3', 'non-iid-l-2'], ['5'],
#              ['0.5'], ['1']]]
#         data_names = [['SVHN']]
#         model_names = [['wresnet28x2']]
#         svhn_controls = make_controls(data_names, model_names, control_name)
#         control_name = [
#             [['fs'], ['sup'], ['100'], ['0.1'], ['iid', 'non-iid-d-0.1', 'non-iid-d-0.3', 'non-iid-l-2'], ['5'],
#              ['0.5'], ['1']]]
#         data_names = [['CIFAR100']]
#         model_names = [['wresnet28x8']]
#         cifar100_controls = make_controls(data_names, model_names, control_name)
#         controls = cifar10_controls + svhn_controls + cifar100_controls
#     elif file == 'fsgd':
#         control_name = [[['4000'], ['fix-fsgd'], ['100'], ['0.1'], ['iid', 'non-iid-l-2'], ['0'], ['0'], ['1']]]
#         data_names = [['CIFAR10']]
#         model_names = [['wresnet28x2']]
#         cifar10_controls = make_controls(data_names, model_names, control_name)
#         controls = cifar10_controls
#     elif file == 'frgd':
#         control_name = [
#             [['250', '4000'], ['fix-frgd'], ['100'], ['0.1'], ['iid', 'non-iid-d-0.1', 'non-iid-d-0.3', 'non-iid-l-2'],
#              ['5'], ['0.5'], ['1'], ['0']]]
#         data_names = [['CIFAR10']]
#         model_names = [['wresnet28x2']]
#         cifar10_controls = make_controls(data_names, model_names, control_name)
#         control_name = [
#             [['250', '1000'], ['fix-frgd'], ['100'], ['0.1'], ['iid', 'non-iid-d-0.1', 'non-iid-d-0.3', 'non-iid-l-2'],
#              ['5'], ['0.5'], ['1'], ['0']]]
#         data_names = [['SVHN']]
#         model_names = [['wresnet28x2']]
#         svhn_controls = make_controls(data_names, model_names, control_name)
#         control_name = [[['2500', '10000'], ['fix-frgd'], ['100'], ['0.1'],
#                          ['iid', 'non-iid-d-0.1', 'non-iid-d-0.3', 'non-iid-l-2'], ['5'], ['0.5'], ['1'], ['0']]]
#         data_names = [['CIFAR100']]
#         model_names = [['wresnet28x8']]
#         cifar100_controls = make_controls(data_names, model_names, control_name)
#         controls = cifar10_controls + svhn_controls + cifar100_controls
#     elif file == 'fmatch':
#         control_name = [[['250', '4000'], ['fix-fmatch'], ['100'], ['0.1'],
#                          ['iid', 'non-iid-d-0.1', 'non-iid-d-0.3', 'non-iid-l-2'], ['5'], ['0.5'], ['1'], ['0']]]
#         data_names = [['CIFAR10']]
#         model_names = [['wresnet28x2']]
#         cifar10_controls = make_controls(data_names, model_names, control_name)
#         control_name = [[['250', '1000'], ['fix-fmatch'], ['100'], ['0.1'],
#                          ['iid', 'non-iid-d-0.1', 'non-iid-d-0.3', 'non-iid-l-2'], ['5'], ['0.5'], ['1'], ['0']]]
#         data_names = [['SVHN']]
#         model_names = [['wresnet28x2']]
#         svhn_controls = make_controls(data_names, model_names, control_name)
#         control_name = [[['2500', '10000'], ['fix-fmatch'], ['100'], ['0.1'],
#                          ['iid', 'non-iid-d-0.1', 'non-iid-d-0.3', 'non-iid-l-2'], ['5'], ['0.5'], ['1'], ['0']]]
#         data_names = [['CIFAR100']]
#         model_names = [['wresnet28x8']]
#         cifar100_controls = make_controls(data_names, model_names, control_name)
#         controls = cifar10_controls + svhn_controls + cifar100_controls
#     else:
#         raise ValueError('Not valid file')
#     return controls


# def main():
#     files = ['fs', 'ps', 'cd', 'ub', 'loss', 'local-epoch', 'gm', 'sbn', 'alternate', 'fl', 'fsgd', 'frgd', 'fmatch']
#     controls = []
#     for file in files:
#         controls += make_control_list(file)
#     processed_result_exp, processed_result_history = process_result(controls)
#     with open('{}/processed_result_exp.json'.format(result_path), 'w') as fp:
#         json.dump(processed_result_exp, fp, indent=2)
#     save(processed_result_exp, os.path.join(result_path, 'processed_result_exp.pt'))
#     save(processed_result_history, os.path.join(result_path, 'processed_result_history.pt'))
#     extracted_processed_result_exp = {}
#     extracted_processed_result_history = {}
#     extract_processed_result(extracted_processed_result_exp, processed_result_exp, [])
#     extract_processed_result(extracted_processed_result_history, processed_result_history, [])
#     df_exp = make_df_exp(extracted_processed_result_exp)
#     df_history = make_df_history(extracted_processed_result_history)
#     make_vis(df_exp, df_history)
#     return


# def process_result(controls):
#     processed_result_exp, processed_result_history = {}, {}
#     for control in controls:
#         model_tag = '_'.join(control)
#         extract_result(list(control), model_tag, processed_result_exp, processed_result_history)
#     summarize_result(processed_result_exp)
#     summarize_result(processed_result_history)
#     return processed_result_exp, processed_result_history


# def extract_result(control, model_tag, processed_result_exp, processed_result_history):
#     if len(control) == 1:
#         exp_idx = exp.index(control[0])
#         base_result_path_i = os.path.join(result_path, '{}.pt'.format(model_tag))
#         if os.path.exists(base_result_path_i):
#             base_result = load(base_result_path_i)
#             for k in base_result['logger']['test'].mean:
#                 metric_name = k.split('/')[1]
#                 if metric_name not in processed_result_exp:
#                     processed_result_exp[metric_name] = {'exp': [None for _ in range(num_experiments)]}
#                     processed_result_history[metric_name] = {'history': [None for _ in range(num_experiments)]}
#                 processed_result_exp[metric_name]['exp'][exp_idx] = base_result['logger']['test'].mean[k]
#                 processed_result_history[metric_name]['history'][exp_idx] = base_result['logger']['train'].history[k]
#         else:
#             print('Missing {}'.format(base_result_path_i))
#     else:
#         if control[1] not in processed_result_exp:
#             processed_result_exp[control[1]] = {}
#             processed_result_history[control[1]] = {}
#         extract_result([control[0]] + control[2:], model_tag, processed_result_exp[control[1]],
#                        processed_result_history[control[1]])
#     return


# def summarize_result(processed_result):
#     if 'exp' in processed_result:
#         pivot = 'exp'
#         processed_result[pivot] = np.stack(processed_result[pivot], axis=0)
#         processed_result['mean'] = np.mean(processed_result[pivot], axis=0).item()
#         processed_result['std'] = np.std(processed_result[pivot], axis=0).item()
#         processed_result['max'] = np.max(processed_result[pivot], axis=0).item()
#         processed_result['min'] = np.min(processed_result[pivot], axis=0).item()
#         processed_result['argmax'] = np.argmax(processed_result[pivot], axis=0).item()
#         processed_result['argmin'] = np.argmin(processed_result[pivot], axis=0).item()
#         processed_result[pivot] = processed_result[pivot].tolist()
#     elif 'history' in processed_result:
#         pivot = 'history'
#         filter_length = []
#         for i in range(len(processed_result[pivot])):
#             x = processed_result[pivot][i]
#             if len(processed_result[pivot][i]) in [400, 800]:
#                 filter_length.append(x)
#             elif len(processed_result[pivot][i]) == 801:
#                 filter_length.append(x[:800])
#             else:
#                 filter_length.append(x + [x[-1]] * (800 - len(x)))
#         processed_result[pivot] = filter_length
#         processed_result[pivot] = np.stack(processed_result[pivot], axis=0)
#         processed_result['mean'] = np.mean(processed_result[pivot], axis=0)
#         processed_result['std'] = np.std(processed_result[pivot], axis=0)
#         processed_result['max'] = np.max(processed_result[pivot], axis=0)
#         processed_result['min'] = np.min(processed_result[pivot], axis=0)
#         processed_result['argmax'] = np.argmax(processed_result[pivot], axis=0)
#         processed_result['argmin'] = np.argmin(processed_result[pivot], axis=0)
#         processed_result[pivot] = processed_result[pivot].tolist()
#     else:
#         for k, v in processed_result.items():
#             summarize_result(v)
#         return
#     return


# def extract_processed_result(extracted_processed_result, processed_result, control):
#     if 'exp' in processed_result or 'history' in processed_result:
#         exp_name = '_'.join(control[:-1])
#         metric_name = control[-1]
#         if exp_name not in extracted_processed_result:
#             extracted_processed_result[exp_name] = defaultdict()
#         extracted_processed_result[exp_name]['{}_mean'.format(metric_name)] = processed_result['mean']
#         extracted_processed_result[exp_name]['{}_std'.format(metric_name)] = processed_result['std']
#     else:
#         for k, v in processed_result.items():
#             extract_processed_result(extracted_processed_result, v, control + [k])
#     return


# def write_xlsx(path, df, startrow=0):
#     writer = pd.ExcelWriter(path, engine='xlsxwriter')
#     for df_name in df:
#         df[df_name] = pd.concat(df[df_name])
#         df[df_name].to_excel(writer, sheet_name='Sheet1', startrow=startrow + 1)
#         writer.sheets['Sheet1'].write_string(startrow, 0, df_name)
#         startrow = startrow + len(df[df_name].index) + 3
#     writer.save()
#     return


# def make_df_exp(extracted_processed_result_exp):
#     df = defaultdict(list)
#     for exp_name in extracted_processed_result_exp:
#         control = exp_name.split('_')
#         if len(control) == 3:
#             data_name, model_name, num_supervised = control
#             index_name = ['1']
#             df_name = '_'.join([data_name, model_name, num_supervised])
#         elif len(control) == 10:
#             data_name, model_name, num_supervised, loss_mode, num_clients, active_rate, data_split_mode, \
#             local_epoch, gm, sbn = control
#             index_name = ['_'.join([local_epoch, gm])]
#             df_name = '_'.join(
#                 [data_name, model_name, num_supervised, loss_mode, num_clients, active_rate, data_split_mode, sbn])
#         elif len(control) == 11:
#             data_name, model_name, num_supervised, loss_mode, num_clients, active_rate, data_split_mode, \
#             local_epoch, gm, sbn, ft = control
#             index_name = ['_'.join([local_epoch, gm])]
#             df_name = '_'.join(
#                 [data_name, model_name, num_supervised, loss_mode, num_clients, active_rate, data_split_mode, sbn,
#                  ft])
#         else:
#             raise ValueError('Not valid control')
#         df[df_name].append(pd.DataFrame(data=extracted_processed_result_exp[exp_name], index=index_name))
#     write_xlsx('{}/result_exp.xlsx'.format(result_path), df)
#     return df


# def make_df_history(extracted_processed_result_history):
#     df = defaultdict(list)
#     for exp_name in extracted_processed_result_history:
#         control = exp_name.split('_')
#         if len(control) == 3:
#             data_name, model_name, num_supervised = control
#             index_name = ['1']
#             for k in extracted_processed_result_history[exp_name]:
#                 df_name = '_'.join([data_name, model_name, num_supervised, k])
#                 df[df_name].append(
#                     pd.DataFrame(data=extracted_processed_result_history[exp_name][k].reshape(1, -1), index=index_name))
#         elif len(control) == 10:
#             data_name, model_name, num_supervised, loss_mode, num_clients, active_rate, data_split_mode, \
#             local_epoch, gm, sbn = control
#             index_name = ['_'.join([local_epoch, gm])]
#             for k in extracted_processed_result_history[exp_name]:
#                 df_name = '_'.join(
#                     [data_name, model_name, num_supervised, loss_mode, num_clients, active_rate, data_split_mode,
#                      sbn, k])
#                 df[df_name].append(
#                     pd.DataFrame(data=extracted_processed_result_history[exp_name][k].reshape(1, -1), index=index_name))
#         elif len(control) == 11:
#             data_name, model_name, num_supervised, loss_mode, num_clients, active_rate, data_split_mode, \
#             local_epoch, gm, sbn, ft = control
#             index_name = ['_'.join([local_epoch, gm])]
#             for k in extracted_processed_result_history[exp_name]:
#                 df_name = '_'.join(
#                     [data_name, model_name, num_supervised, loss_mode, num_clients, active_rate, data_split_mode,
#                      sbn, ft, k])
#                 df[df_name].append(
#                     pd.DataFrame(data=extracted_processed_result_history[exp_name][k].reshape(1, -1), index=index_name))
#         else:
#             raise ValueError('Not valid control')
#     write_xlsx('{}/result_history.xlsx'.format(result_path), df)
#     return df


# def make_vis(df_exp, df_history):
#     data_split_mode_dict = {'iid': 'IID', 'non-iid-l-2': 'Non-IID, $K=2$',
#                             'non-iid-d-0.1': 'Non-IID, $\operatorname{Dir}(0.1)$',
#                             'non-iid-d-0.3': 'Non-IID, $\operatorname{Dir}(0.3)$', 'fix-fsgd': 'DynamicSgd + FixMatch',
#                             'fix-batch': 'FedAvg + FixMatch', 'fs': 'Fully Supervised', 'ps': 'Partially Supervised'}
#     color = {'5_0.5': 'red', '1_0.5': 'orange', '5_0': 'dodgerblue', '5_0.9': 'blue', '5_0.5_nomixup': 'green',
#              '5_0_nomixup': 'green', 'iid': 'red', 'non-iid-l-2': 'orange', 'non-iid-d-0.1': 'dodgerblue',
#              'non-iid-d-0.3': 'green', 'fix-fsgd': 'red', 'fix-batch': 'blue',
#              'fs': 'black', 'ps': 'orange'}
#     linestyle = {'5_0.5': '-', '1_0.5': '--', '5_0': ':', '5_0.5_nomixup': '-.', '5_0_nomixup': '-.',
#                  '5_0.9': (0, (1, 5)), 'iid': '-', 'non-iid-l-2': '--', 'non-iid-d-0.1': '-.', 'non-iid-d-0.3': ':',
#                  'fix-fsgd': '--', 'fix-batch': ':', 'fs': '-', 'ps': '-.'}
#     loc_dict = {'Accuracy': 'lower right', 'Loss': 'upper right'}
#     fontsize = {'legend': 16, 'label': 16, 'ticks': 16}
#     fig = {}
#     reorder_fig = []
#     for df_name in df_history:
#         df_name_list = df_name.split('_')
#         if len(df_name_list) == 5:
#             data_name, model_name, num_supervised, metric_name, stat = df_name.split('_')
#             if stat == 'std':
#                 continue
#             df_name_std = '_'.join([data_name, model_name, num_supervised, metric_name, 'std'])
#             fig_name = '_'.join([data_name, model_name, num_supervised, metric_name])
#             fig[fig_name] = plt.figure(fig_name)
#             for ((index, row), (_, row_std)) in zip(df_history[df_name].iterrows(), df_history[df_name_std].iterrows()):
#                 y = row.to_numpy()
#                 yerr = row_std.to_numpy()
#                 x = np.arange(len(y))
#                 plt.plot(x, y, color='r', linestyle='-')
#                 plt.fill_between(x, (y - yerr), (y + yerr), color='r', alpha=.1)
#                 plt.xlabel('Communication Rounds', fontsize=fontsize['label'])
#                 plt.ylabel(metric_name, fontsize=fontsize['label'])
#                 plt.xticks(fontsize=fontsize['ticks'])
#                 plt.yticks(fontsize=fontsize['ticks'])
#         elif len(df_name_list) == 10:
#             data_name, model_name, num_supervised, loss_mode, num_clients, active_rate, data_split_mode, sbn, \
#             metric_name, stat = df_name.split('_')
#             if stat == 'std':
#                 continue
#             df_name_std = '_'.join(
#                 [data_name, model_name, num_supervised, loss_mode, num_clients, active_rate, data_split_mode, sbn,
#                  metric_name, 'std'])
#             for ((index, row), (_, row_std)) in zip(df_history[df_name].iterrows(), df_history[df_name_std].iterrows()):
#                 y = row.to_numpy()
#                 yerr = row_std.to_numpy()
#                 x = np.arange(len(y))
#                 if index == '5_0.5' and loss_mode == 'fix-mix':
#                     fig_name = '_'.join(
#                         [data_name, model_name, num_supervised, loss_mode, num_clients, active_rate, sbn,
#                          metric_name])
#                     reorder_fig.append(fig_name)
#                     label_name = '{}'.format(data_split_mode_dict[data_split_mode])
#                     style = data_split_mode
#                     fig[fig_name] = plt.figure(fig_name)
#                     plt.plot(x, y, color=color[style], linestyle=linestyle[style], label=label_name)
#                     plt.fill_between(x, (y - yerr), (y + yerr), color=color[style], alpha=.1)
#                     plt.legend(loc=loc_dict[metric_name], fontsize=fontsize['legend'])
#                     plt.xlabel('Communication Rounds', fontsize=fontsize['label'])
#                     plt.ylabel(metric_name, fontsize=fontsize['label'])
#                     plt.xticks(fontsize=fontsize['ticks'])
#                     plt.yticks(fontsize=fontsize['ticks'])
#                 if data_split_mode in ['iid', 'non-iid-l-2'] and loss_mode not in ['fix-batch', 'fix-fsgd', 'fix-frgd']:
#                     fig_name = '_'.join(
#                         [data_name, model_name, num_supervised, num_clients, active_rate, data_split_mode, sbn,
#                          metric_name])
#                     reorder_fig.append(fig_name)
#                     fig[fig_name] = plt.figure(fig_name)
#                     local_epoch, gm = index.split('_')
#                     if loss_mode == 'fix':
#                         label_name = '$E={}$, $\\beta_g={}$, No mixup'.format(local_epoch, gm)
#                         style = '{}_nomixup'.format(index)
#                     else:
#                         label_name = '$E={}$, $\\beta_g={}$'.format(local_epoch, gm)
#                         style = index
#                     plt.plot(x, y, color=color[style], linestyle=linestyle[style], label=label_name)
#                     plt.fill_between(x, (y - yerr), (y + yerr), color=color[style], alpha=.1)
#                     plt.legend(loc=loc_dict[metric_name], fontsize=fontsize['legend'])
#                     plt.xlabel('Communication Rounds', fontsize=fontsize['label'])
#                     plt.ylabel(metric_name, fontsize=fontsize['label'])
#                     plt.xticks(fontsize=fontsize['ticks'])
#                     plt.yticks(fontsize=fontsize['ticks'])
#                 if data_split_mode in ['iid', 'non-iid-l-2'] and loss_mode == 'fix-fsgd':
#                     fix_batch_df_name = '_'.join(
#                         [data_name, model_name, num_supervised, 'fix-batch', num_clients, active_rate, data_split_mode,
#                          sbn, '0', metric_name, 'mean'])
#                     fix_batch_df_name_std = '_'.join(
#                         [data_name, model_name, num_supervised, 'fix-batch', num_clients, active_rate, data_split_mode,
#                          sbn, '0', metric_name, 'std'])
#                     fix_batch_y = list(df_history[fix_batch_df_name].iterrows())[0][1]
#                     fix_batch_y_yerr = list(df_history[fix_batch_df_name_std].iterrows())[0][1]
#                     fs_df_name = '_'.join([data_name, model_name, 'fs'])
#                     fs_df_name_std = '_'.join([data_name, model_name, 'fs'])
#                     fs_y = list(df_exp[fs_df_name].iterrows())[0][1]['{}_mean'.format(metric_name)]
#                     fs_y_yerr = list(df_exp[fs_df_name_std].iterrows())[0][1]['{}_std'.format(metric_name)]
#                     ps_df_name = '_'.join([data_name, model_name, num_supervised])
#                     ps_df_name_std = '_'.join([data_name, model_name, num_supervised])
#                     ps_y = list(df_exp[ps_df_name].iterrows())[0][1]['{}_mean'.format(metric_name)]
#                     ps_y_yerr = list(df_exp[ps_df_name_std].iterrows())[0][1]['{}_std'.format(metric_name)]
#                     fig_name = '_'.join(
#                         [data_name, model_name, num_supervised, num_clients, active_rate, data_split_mode, sbn,
#                          metric_name, 'fsgd'])
#                     reorder_fig.append(fig_name)
#                     fig[fig_name] = plt.figure(fig_name)
#                     label_name = '{}'.format(data_split_mode_dict['fix-fsgd'])
#                     style = 'fix-fsgd'
#                     plt.plot(x, y, color=color[style], linestyle=linestyle[style], label=label_name)
#                     plt.fill_between(x, (y - yerr), (y + yerr), color=color[style], alpha=.1)
#                     label_name = '{}'.format(data_split_mode_dict['fix-batch'])
#                     style = 'fix-batch'
#                     plt.plot(x, fix_batch_y, color=color[style], linestyle=linestyle[style], label=label_name)
#                     plt.fill_between(x, (fix_batch_y - fix_batch_y_yerr), (fix_batch_y + fix_batch_y_yerr),
#                                      color=color[style], alpha=.1)
#                     label_name = '{}'.format(data_split_mode_dict['fs'])
#                     style = 'fs'
#                     plt.plot(x, np.repeat(fs_y, len(x)), color=color[style], linestyle=linestyle[style],
#                              label=label_name)
#                     plt.fill_between(x, np.repeat(fs_y - fs_y_yerr, len(x)), np.repeat(fs_y + fs_y_yerr, len(x)),
#                                      color=color[style], alpha=.1)
#                     label_name = '{}'.format(data_split_mode_dict['ps'])
#                     style = 'ps'
#                     plt.plot(x, np.repeat(ps_y, len(x)), color=color[style], linestyle=linestyle[style],
#                              label=label_name)
#                     plt.fill_between(x, np.repeat(ps_y - ps_y_yerr, len(x)), np.repeat(ps_y + ps_y_yerr, len(x)),
#                                      color=color[style], alpha=.1)
#                     plt.legend(loc=loc_dict[metric_name], fontsize=fontsize['legend'])
#                     plt.xlabel('Communication Rounds', fontsize=fontsize['label'])
#                     plt.ylabel(metric_name, fontsize=fontsize['label'])
#                     plt.xticks(fontsize=fontsize['ticks'])
#                     plt.yticks(fontsize=fontsize['ticks'])
#     for fig_name in reorder_fig:
#         fig_name_list = fig_name.split('_')
#         data_name, model_name, num_supervised, loss_mode, num_clients, active_rate, sbn, metric_name = fig_name_list[:8]
#         plt.figure(fig_name)
#         handles, labels = plt.gca().get_legend_handles_labels()
#         if len(fig_name_list) == 9:
#             if len(handles) == 4:
#                 handles = [handles[2], handles[3], handles[0], handles[1]]
#                 labels = [labels[2], labels[3], labels[0], labels[1]]
#                 plt.legend(handles, labels, loc=loc_dict[metric_name], fontsize=fontsize['legend'])
#         else:
#             if len(handles) == 4:
#                 handles = [handles[0], handles[3], handles[2], handles[1]]
#                 labels = [labels[0], labels[3], labels[2], labels[1]]
#                 plt.legend(handles, labels, loc=loc_dict[metric_name], fontsize=fontsize['legend'])
#             if len(handles) == 5:
#                 handles = [handles[0], handles[4], handles[2], handles[3], handles[1]]
#                 labels = [labels[0], labels[4], labels[2], labels[3], labels[1]]
#                 plt.legend(handles, labels, loc=loc_dict[metric_name], fontsize=fontsize['legend'])
#     for fig_name in fig:
#         fig[fig_name] = plt.figure(fig_name)
#         plt.grid()
#         fig_path = '{}/{}.{}'.format(vis_path, fig_name, save_format)
#         makedir_exist_ok(vis_path)
#         plt.savefig(fig_path, dpi=500, bbox_inches='tight', pad_inches=0)
#         plt.close(fig_name)
#     return


# if __name__ == '__main__':
#     main()

