import io
import pickle
import matplotlib
import numpy as np
import torch
from matplotlib import pyplot as plt
from matplotlib.pyplot import figure
import re

class CPU_Unpickler(pickle.Unpickler):
    """Unpickler that restores pickled torch storages onto the CPU.

    Whenever the pickle stream asks for ``torch.storage._load_from_bytes``,
    a replacement is returned that deserializes the storage bytes with
    ``torch.load(..., map_location='cpu')``. Every other class lookup is
    delegated unchanged to :class:`pickle.Unpickler`.
    """

    def find_class(self, module, name):
        if (module, name) != ('torch.storage', '_load_from_bytes'):
            return super().find_class(module, name)

        def load_storage_on_cpu(payload):
            # Rebuild the serialized storage, mapping any device onto the CPU.
            return torch.load(io.BytesIO(payload), map_location='cpu')

        return load_storage_on_cpu

def circle_plot():
    """Plot final test accuracy against communication cost (CIFAR-100, hetero0.1).

    For each method the result pickle(s) under ``<dataset>/`` are loaded and
    the last ``test_acc`` entry is plotted against ``sum_comm_params / 1e9``.
    FedSpa (legend "Ours") and Sub-FedAvg are swept over several dense ratios
    and drawn as curves whose points are annotated with the sparsity
    ``1 - dense ratio``; FedAvg, Ditto and Local are single labelled points.
    The figure is laid out, saved to ``fig/thumb_pic``, shown, then closed.
    """
    datasets = ["cifar100"]
    alphas = {"cifar100": ["hetero0.1"]}
    methods = ["fedspa", "SubAVG", "fedavg", "ditto", "local"]
    methods_name = ["FedSpa", "Sub-FedAvg", "FedAvg", "Ditto", "Local"]
    colors = ["#51C1C8", "#E96279", "#44A2D6", "#536D84",
              "#FA84F5", "b", "y", "#536D84"]
    markers = ["*", ".", "^", "v", "D"]
    dense_ratio = ["0.2", "0.4", "0.5", "0.6", "0.8", "1.0"]
    # Annotation text for each sweep point: sparsity = 1 - dense ratio.
    sparse_ratio = ["0.8", "0.6", "0.5", "0.4", "0.2", "0"]

    def load_results(identity):
        # Read one result pickle and close the handle; CPU_Unpickler maps any
        # pickled torch storage onto the CPU.
        with open(identity, 'rb') as f:
            return CPU_Unpickler(f).load()

    fig = figure(num=None, figsize=(7, 5), dpi=300, facecolor='w', edgecolor='k')
    for dataset in datasets:
        for alpha in alphas[dataset]:
            # The colour/marker/name index restarts for every data split so it
            # always matches the method's position in ``methods``.
            for index, method in enumerate(methods):
                x_array, y_array = [], []
                if method == "fedspa":
                    for dr in dense_ratio:
                        results = load_results(
                            dataset + "/" + method + "-dr" + dr + "riglTrue-" + alpha
                            + "-staticFalse-shared0-strict_avgTrue-seed0")
                        y_array.append(results["test_acc"][-1])
                        x_array.append(results["sum_comm_params"] / 1e9)
                        # Debug output: penultimate-round accuracy.
                        print(results["test_acc"][-2])
                    plt.plot(x_array, y_array, marker="*", label="Ours", linewidth=4,
                             markersize=15, color=colors[index])
                    # The last (fully dense) point is intentionally left unannotated.
                    for x, y, text in zip(x_array[:-1], y_array[:-1], sparse_ratio):
                        plt.annotate(text, xy=(x, y), xytext=(x - 5, y + 0.01), fontsize=12)
                elif method == "SubAVG":
                    for dr in dense_ratio:
                        results = load_results(dataset + "/" + method + "-dr" + dr + "-" + alpha + "-seed0")
                        x_array.append(results["sum_comm_params"] / 1e9)
                        y_array.append(results["test_acc"][-1])
                    plt.plot(x_array, y_array, marker=".", label="Sub-FedAvg", linewidth=4,
                             markersize=15, color=colors[index])
                    for x, y, text in zip(x_array[:-1], y_array[:-1], sparse_ratio):
                        plt.annotate(text, xy=(x, y), xytext=(x - 5, y + 0.01), fontsize=12)
                    # The last point's label is placed directly above it instead.
                    plt.annotate(sparse_ratio[-1], xy=(x_array[-1], y_array[-1]),
                                 xytext=(x_array[-1], y_array[-1] + 0.01), fontsize=12)
                else:
                    results = load_results(dataset + "/" + method + "-" + alpha + "-seed0")
                    x = results["sum_comm_params"] / 1e9
                    y = results["test_acc"][-1]
                    plt.plot([x], [y], marker=markers[index], color=colors[index], markersize=12)
                    # FedAvg's name is shifted further left than the other baselines.
                    x_offset = 20 if method == "fedavg" else 11
                    plt.annotate(text=methods_name[index], xy=(x, y),
                                 xytext=(x - x_offset, y - 0.03), fontsize=15)
            plt.ylim(0.3, 0.70)
        # NOTE(review): x values are sum_comm_params / 1e9 while the axis says GB;
        # the table reports in this file multiply by 32 / 8 for bytes -- confirm the unit.
        plt.xlabel("Communication Overhead (GB)", fontsize=14)
        plt.ylabel("Top-1 Accuracy", fontsize=14)
        plt.legend(loc='lower right', fontsize=12, ncol=2)
        plt.grid(True, linestyle='-', linewidth=1)
        # Hand-placed arrows and captions, in data coordinates of this plot.
        plt.arrow(x=12, y=0.42, dy=0.20, dx=0, head_length=0.015, head_width=6, width=1.8,
                  color="red", zorder=20)
        plt.arrow(x=160, y=0.38, dy=0, dx=-80, head_length=7, head_width=0.01, width=0.004,
                  color="red", zorder=20)
        plt.annotate("higher the better", xy=(2.5, 0.44), rotation=90, fontsize=15)
        plt.annotate("smaller the better", xy=(80, 0.36), fontsize=15)
        # Lay out before saving; tight_layout after savefig does not affect the file.
        plt.tight_layout()
        plt.savefig("fig/thumb_pic")
        plt.show()
    plt.close(fig)

def acc_round_report():
    """Print LaTeX table rows with the rounds needed to reach target accuracies.

    For every dataset, method and data split, the ``test_acc`` curve of each
    seed is scanned for the first round index whose accuracy exceeds each
    threshold in ``thresholds[dataset]``. When every seed reaches the
    threshold the cell is the mean +/- std (one decimal) of those indices;
    otherwise the cell is ``>1000``. The rows are printed, not returned.
    """
    datasets = ["cifar10", "cifar100"]
    alphas = {"emnist": ["homo", "hetero0.2", "hetero0.1"], "cifar10": ["homo", "hetero0.5", "hetero0.3"],
              "cifar100": ["homo", "hetero0.2", "hetero0.1"]}
    methods = ["fedspa-0.5", "fedspa-RST-0.5", "ditto", "fedavg", "SubAVG", "local", "subsampling"]
    methods_name = ["FedSpa (DST)", "FedSpa (RSM)", "Ditto", "FedAvg", "SubAVG", "Local", "Subsampling"]
    thresholds = {"cifar10": [0.7, 0.75, 0.8],
                  "cifar100": [0.4, 0.5, 0.625]}
    # Dense ratio of the Sub-FedAvg and subsampling runs that are reported;
    # defined explicitly so no branch depends on another branch's loop variable.
    dr = "0.5"
    seeds = ["0", "1", "2"]
    string = ""
    for dataset in datasets:
        for method, method_name in zip(methods, methods_name):
            string += method_name
            for alpha in alphas[dataset]:
                if "fedspa" in method:
                    # "fedspa-<dr>" -> staticFalse runs ("FedSpa (DST)");
                    # "fedspa-RST-<dr>" -> staticTrue runs ("FedSpa (RSM)").
                    static = "True" if "RST" in method else "False"
                    prefix = (dataset + "/" + "fedspa" + "-dr" + method.rsplit("-", 1)[-1] + "riglTrue-"
                              + alpha + "-static" + static + "-shared0-strict_avgTrue-seed")
                elif method == "SubAVG":
                    prefix = dataset + "/" + method + "-dr" + dr + "-" + alpha + "-seed"
                elif method == "subsampling":
                    prefix = dataset + "/" + method + "-" + dr + "-" + alpha + "-seed"
                else:
                    prefix = dataset + "/" + method + "-" + alpha + "-seed"

                # Load every seed once and reuse the curves for all thresholds.
                curves = []
                for seed in seeds:
                    with open(prefix + seed, 'rb') as f:
                        curves.append(CPU_Unpickler(f).load()["test_acc"])

                for threshold in thresholds[dataset]:
                    # First round index strictly above the threshold, or None if never reached.
                    first_rounds = [next((r for r, acc in enumerate(curve) if acc > threshold), None)
                                    for curve in curves]
                    if any(r is None for r in first_rounds):
                        string += " & " + ">1000"
                    else:
                        string += " & " + str(round(np.average(first_rounds), 1))
                        string += r"$\pm$" + str(round(np.std(first_rounds), 1))
            # LaTeX row terminator.
            string += "\\\\ \n"
    print(string)

def main_report():
    """Print the main LaTeX results table (accuracy, communication, FLOPs).

    For every dataset, method and data split the per-seed result pickles are
    loaded and three cells are appended to the method's row: final test
    accuracy in percent as mean +/- std over seeds, the mean of
    ``sum_comm_params / 1e9 * 32 / 8`` and the mean of
    ``sum_training_flops / 1e14``, each rounded to one decimal.
    """
    datasets = ["emnist", "cifar10", "cifar100"]
    alphas = {"emnist": ["homo", "hetero0.2", "hetero0.1"], "cifar10": ["homo", "hetero0.5", "hetero0.3"],
              "cifar100": ["homo", "hetero0.2", "hetero0.1"]}
    methods = ["fedspa-0.5", "fedspa-RST-0.5", "ditto", "fedavg", "SubAVG", "local", "subsampling"]
    methods_name = ["FedSpa (DST)", "FedSpa (RSM)", "Ditto", "FedAvg", "Sub-FedAVG", "Local", "Subsampling"]
    # Dense ratio of the Sub-FedAvg and subsampling runs that are reported;
    # defined explicitly so no branch depends on another branch's loop variable.
    dr = "0.5"
    seeds = ["0", "1", "2"]
    string = ""
    for dataset in datasets:
        for method, method_name in zip(methods, methods_name):
            string += "&" + method_name
            for alpha in alphas[dataset]:
                if "fedspa" in method:
                    # "fedspa-<dr>" -> staticFalse runs ("FedSpa (DST)");
                    # "fedspa-RST-<dr>" -> staticTrue runs ("FedSpa (RSM)").
                    static = "True" if "RST" in method else "False"
                    prefix = (dataset + "/" + "fedspa" + "-dr" + method.rsplit("-", 1)[-1] + "riglTrue-"
                              + alpha + "-static" + static + "-shared0-strict_avgTrue-seed")
                elif method == "SubAVG":
                    prefix = dataset + "/" + method + "-dr" + dr + "-" + alpha + "-seed"
                elif method == "subsampling":
                    prefix = dataset + "/" + method + "-" + dr + "-" + alpha + "-seed"
                else:
                    prefix = dataset + "/" + method + "-" + alpha + "-seed"

                accs, comms, flops = [], [], []
                for seed in seeds:
                    with open(prefix + seed, 'rb') as f:
                        results = CPU_Unpickler(f).load()
                    accs.append(float(results["test_acc"][-1]) * 100)
                    # presumably 32-bit parameters: * 32 / 8 bytes each, / 1e9 -> GB (confirm)
                    comms.append(float(results["sum_comm_params"]) / 1e9 * 32 / 8)
                    flops.append(float(results["sum_training_flops"].cpu() / 1e14))
                string += " & " + str(round(np.average(accs), 1)) + r"$\pm$" + str(round(np.std(accs), 1))
                string += " & " + str(round(np.average(comms), 1))
                string += " & " + str(round(np.average(flops), 1))
            # LaTeX row terminator.
            string += "\\\\ \n"
    print(string)

def sparse_report():
    """Print LaTeX rows of FedSpa results versus sparsity on CIFAR-100.

    One row per sparsity level (label from ``sparse``, file name from the
    paired ``dense_ratio`` entry). For each data split three cells are added:
    final test accuracy in percent as mean +/- std over seeds, the mean of
    ``sum_comm_params / 1e9 * 32 / 8`` and the mean of
    ``sum_training_flops / 1e16``, each rounded to one decimal.
    """
    datasets = ["cifar100"]
    alphas = {"cifar100": ["homo", "hetero0.2", "hetero0.1"]}
    methods = ["fedspa"]
    # Row label (sparsity) paired element-wise with the dense ratio of the run:
    # sparsity = 1 - dense ratio.
    sparse = ["0.2", "0.4", "0.5", "0.6", "0.8"]
    dense_ratio = ["0.8", "0.6", "0.5", "0.4", "0.2"]
    seeds = ["0", "1", "2"]
    string = ""
    for dataset in datasets:
        for sparsity, dr in zip(sparse, dense_ratio):
            string += sparsity
            for alpha in alphas[dataset]:
                for method in methods:
                    accs, comms, flops = [], [], []
                    for seed in seeds:
                        identity = (dataset + "/" + method + "-dr" + dr + "riglTrue-" + alpha
                                    + "-staticFalse-shared0-strict_avgTrue-seed" + seed)
                        # CPU_Unpickler keeps pickled torch tensors loadable without CUDA.
                        with open(identity, 'rb') as f:
                            results = CPU_Unpickler(f).load()
                        accs.append(results["test_acc"][-1] * 100)
                        comms.append(results["sum_comm_params"] / 1e9 * 32 / 8)
                        flops.append(results["sum_training_flops"].cpu() / 1e16)
                    string += " & " + str(round(np.average(accs), 1)) + r"$\pm$" + str(round(np.std(accs), 1))
                    string += " & " + str(round(np.average(comms), 1))
                    string += " & " + str(round(np.average(flops), 1))
            # LaTeX row terminator.
            string += "\\\\ \n"
    print(string)

def ablation_sparse():
    """Plot FedSpa (DST) test-accuracy curves for several sparsity levels on CIFAR-100.

    For each data split, runs with dense ratio 0.2/0.4/0.5/0.6/0.8 (legend shows
    the sparsity 1 - dense ratio) are averaged over seeds and drawn as
    mean +/- std curves in one panel of a 1x3 figure, which is saved to
    ``fig/sparsity<dataset>.png``, shown and closed. Non-IID panels use the
    fixed y-range from ``min_y``/``max_y``.
    """
    datasets = ["cifar100"]
    alphas = {"cifar100": ["homo", "hetero0.2", "hetero0.1"]}
    methods = ["fedspa"]
    task_name = ["CIFAR100 (ResNet18)"]
    methods_name = ["FedSpa (DST)-0.8", "FedSpa (DST)-0.6", "FedSpa (DST)-0.5", "FedSpa (DST)-0.4",
                    "FedSpa (DST)-0.2"]
    seeds = ["0", "1", "2"]
    min_y = {"cifar100": 0.2}
    max_y = {"cifar100": 0.7}
    data_setting_name = ["IID", "Non-IID (" + r'$\gamma=0.2$)', "Non-IID (" + r'$\gamma=0.1$)']
    dense_ratio = ["0.2", "0.4", "0.5", "0.6", "0.8"]
    fontsize = 15
    for dataset_index, dataset in enumerate(datasets):
        fig = figure(num=None, figsize=(18, 5), dpi=300, facecolor='w', edgecolor='k')
        fig.subplots_adjust(left=0.08, bottom=0.21, right=0.97, top=0.95, wspace=0.15, hspace=0.4)
        for alpha_index, alpha in enumerate(alphas[dataset]):
            ax = fig.add_subplot(1, 3, alpha_index + 1)
            if alpha_index % 3 == 0:
                ax.set_ylabel('Accuracy', fontsize=fontsize)
            # One file-name prefix per (dense ratio, method); the seed is appended below.
            prefixes = [dataset + "/" + method + "-dr" + dr + "riglTrue-" + alpha
                        + "-staticFalse-shared0-strict_avgTrue-seed"
                        for dr in dense_ratio for method in methods]
            for i, prefix in enumerate(prefixes):
                # Only the first panel supplies labels, so fig.legend lists each curve once.
                label = methods_name[i] if alpha_index % 3 == 0 else None
                seed_result = []
                for seed in seeds:
                    with open(prefix + seed, 'rb') as f:
                        seed_result.append(CPU_Unpickler(f).load()["test_acc"])
                rounds = range(len(seed_result[-1]))
                mean_acc = np.average(seed_result, 0)
                std = np.std(seed_result, 0)
                # No explicit colour: line and band presumably pair up through
                # matplotlib's default colour cycles, so keep one plot + one fill per curve.
                ax.plot(rounds, mean_acc, linewidth=2, markersize=3, label=label)
                ax.fill_between(rounds, mean_acc + std, mean_acc - std, alpha=0.3)
            if alpha_index != 0:
                ax.set_ylim(bottom=min_y[dataset], top=max_y[dataset])
            ax.set_title(data_setting_name[alpha_index], fontsize=fontsize)
            ax.set_xlabel('Communication Rounds', fontsize=fontsize)
            ax.grid(True, linestyle='-', linewidth=1)
        fig.legend(loc='lower center', bbox_to_anchor=(0.0, 0.0, 1, 1), fancybox=False, shadow=False, ncol=6,
                   fontsize=fontsize, frameon=False)
        fig.suptitle(task_name[dataset_index], fontweight='bold', rotation='vertical', fontsize=15, x=0.025, y=0.8)
        fig.savefig("fig/sparsity" + dataset + ".png")
        plt.show()
        # Close this figure explicitly; figure(num=None) never gets number 0.
        plt.close(fig)


# def ablation_static():
#     datasets= ["cifar100"]
#     alphas = {  "cifar100": ["homo","hetero0.1","hetero0.2"]   }
#     methods = [ "fedspa","fedavg","ditto"]
#     methods_name = ["FedSlim (w/ ms)", "FedSlim (w/o ms)", "FedAvg", "Ditto"]
#     min_y = {"emnist": 0.7, "cifar10": 0.5, "cifar100": 0.2}
#     # methods = ["fedavg","fedspa","ditto"]
#     for dataset in datasets:
#         for alpha in alphas[dataset]:
#             dense_ratio = ["0.5"]
#             for dr in dense_ratio:
#                 fig = plt.figure(0)
#                 identity_array = []
#                 for method in methods:
#                     if method == "fedspa":
#                             identity_array.append( dataset+"/"+method +  "-dr"  + dr+ "riglTrue-" +alpha+ "-staticFalse-shared0-strict_avgTrue-seed0")
#                             identity_array.append(
#                                 dataset + "/" + method + "-dr" + dr + "riglTrue-" + alpha + "-staticTrue-shared0-strict_avgTrue-seed0")
#                     elif method == "SubAVG":
#                         for dr in dense_ratio:
#                             identity_array.append( dataset+"/"+method + "-dr" + dr + "-" +alpha+"-seed0")
#                     else:
#                         identity_array.append( dataset+"/"+ method+ "-"+alpha+"-seed0" )
#
#                 for i in range(len(identity_array)):
#                     identity = identity_array[i]
#                     f=open(identity,'rb')
#                     results = pickle.load(f)
#                     rounds = range(len(results["test_acc"]) )
#                     plt.plot(rounds, results["test_acc"], label=methods_name[i])
#                     plt.legend()
#                     print(results["test_acc"])
#                 plt.legend(fontsize=15)
#                 plt.grid(c='#d9d9d9')
#                 plt.ylim(bottom=min_y[dataset])
#                 plt.ylabel('Test Accuracy', fontsize=15)
#                 plt.xlabel('Communication Rounds', fontsize=15)
#                 plt.savefig("fig/static_"+alpha+".png")
#                 plt.show()
#                 plt.close(0)
def ablation_erk():
    """Compare ERK and uniform sparsity distributions for FedSpa (DST) on CIFAR-100.

    For each data split, the seed-averaged test-accuracy curves of the ERK runs
    and of the uniform runs (file names suffixed ``-u``) are drawn as
    mean +/- std bands in one panel of a 1x3 figure, which is saved to
    ``fig/erk<dataset>.png``, shown and closed. One LaTeX row per
    (variant, data split) is also printed with final accuracy in percent
    (mean +/- std), mean ``sum_comm_params / 1e9 * 32 / 8`` and mean
    ``sum_training_flops / 1e16``.
    """
    datasets = ["cifar100"]
    alphas = {"cifar100": ["homo", "hetero0.2", "hetero0.1"]}
    methods = ["fedspa"]
    task_name = ["CIFAR100 (ResNet18)"]
    methods_name = ["FedSpa-DST (ERK)", "FedSpa-DST (Uniform)"]
    # (row name in the printed table, file-name suffix) in the order of methods_name.
    variants = [("ERK", ""), ("Uniform", "-u")]
    data_setting_name = ["IID", "Non-IID (" + r'$\gamma=0.2$)', "Non-IID (" + r'$\gamma=0.1$)']
    colors = ["#51C1C8", "#E96279", "#44A2D6", "#536D84",
              "#FA84F5", "b", "y", "#536D84"]
    seeds = ["0", "1", "2"]
    string = ""
    fontsize = 15
    for dataset_index, dataset in enumerate(datasets):
        fig = figure(num=None, figsize=(18, 5), dpi=300, facecolor='w', edgecolor='k')
        fig.subplots_adjust(left=0.08, bottom=0.21, right=0.97, top=0.95, wspace=0.15, hspace=0.4)
        for method in methods:
            for alpha_index, alpha in enumerate(alphas[dataset]):
                ax = fig.add_subplot(1, 3, alpha_index + 1)
                if alpha_index % 3 == 0:
                    ax.set_ylabel('Accuracy', fontsize=fontsize)
                for i, (row_name, suffix) in enumerate(variants):
                    # Only the first panel supplies labels, so fig.legend lists each variant once.
                    label = methods_name[i] if alpha_index % 3 == 0 else None
                    seed_acc, seed_final, seed_comm, seed_flops = [], [], [], []
                    for seed in seeds:
                        identity = (dataset + "/" + method + "-dr0.5" + "riglTrue-" + alpha
                                    + "-staticFalse-shared0-strict_avgTrue-seed" + seed + suffix)
                        with open(identity, 'rb') as f:
                            results = CPU_Unpickler(f).load()
                        seed_acc.append(np.array(results["test_acc"]))
                        seed_final.append(results["test_acc"][-1] * 100)
                        seed_comm.append(results["sum_comm_params"] / 1e9 * 32 / 8)
                        seed_flops.append(results["sum_training_flops"].cpu() / 1e16)
                    # Table row for this variant and data split.
                    string += row_name + " (" + alpha + ")"
                    string += (" & " + str(round(np.average(seed_final), 1))
                               + r"$\pm$" + str(round(np.std(seed_final), 1)))
                    string += " & " + str(round(np.average(seed_comm), 1))
                    string += " & " + str(round(np.average(seed_flops), 1))
                    string += "\\\\ \n"
                    mean_acc = np.average(seed_acc, axis=0)
                    std = np.std(seed_acc, axis=0)
                    rounds = range(len(results["test_acc"]))
                    ax.fill_between(rounds, mean_acc + std, mean_acc - std, color=colors[i], alpha=0.3)
                    ax.plot(rounds, mean_acc, linewidth=2, markersize=3, label=label, color=colors[i])
                ax.set_title(data_setting_name[alpha_index], fontsize=fontsize)
                ax.grid(True, linestyle='-', linewidth=1)
                ax.set_xlabel('Communication Rounds', fontsize=fontsize)
        fig.legend(loc='lower center', bbox_to_anchor=(0.0, 0.0, 1, 1), fancybox=False, shadow=False, ncol=6,
                   fontsize=fontsize, frameon=False)
        fig.suptitle(task_name[dataset_index], fontweight='bold', rotation='vertical', fontsize=15, x=0.025, y=0.80)
        fig.savefig("fig/erk" + dataset + ".png")
        plt.show()
        # Close this figure explicitly; figure(num=None) never gets number 0.
        plt.close(fig)
        print(string)

def ablation_rigl():
    """Compare gradient-based and random mask regrowth for FedSpa (DST) on CIFAR-100.

    Runs whose file names contain ``riglTrue`` (legend: "gradient information")
    and ``riglFalse`` (legend: "random recovery") are averaged over seeds. For
    each data split, their mean +/- std test-accuracy curves are drawn in one
    panel of a 1x3 figure, which is saved to ``fig/rigl<dataset>.png``, shown
    and closed.
    """
    datasets = ["cifar100"]
    alphas = {"cifar100": ["homo", "hetero0.2", "hetero0.1"]}
    task_name = ["CIFAR100 (ResNet18)"]
    methods_name = ["FedSpa (DST with gradient information)", "FedSpa (DST with random recovery)"]
    # File-name flag of each variant, in the order of methods_name.
    rigl_flags = ["riglTrue", "riglFalse"]
    data_setting_name = ["IID", "Non-IID (" + r'$\gamma=0.2$)', "Non-IID (" + r'$\gamma=0.1$)']
    colors = ["#51C1C8", "#E96279", "#44A2D6", "#536D84",
              "#FA84F5", "b", "y", "#536D84"]
    seeds = ["0", "1", "2"]
    fontsize = 16
    for dataset_index, dataset in enumerate(datasets):
        fig = figure(num=None, figsize=(18, 5), dpi=300, facecolor='w', edgecolor='k')
        fig.subplots_adjust(left=0.08, bottom=0.21, right=0.97, top=0.94, wspace=0.15, hspace=0.4)
        for alpha_index, alpha in enumerate(alphas[dataset]):
            ax = fig.add_subplot(1, 3, alpha_index + 1)
            if alpha_index % 3 == 0:
                ax.set_ylabel('Accuracy', fontsize=fontsize)
            for i, rigl_flag in enumerate(rigl_flags):
                # Only the first panel supplies labels, so fig.legend lists each variant once.
                label = methods_name[i] if alpha_index % 3 == 0 else None
                seed_acc = []
                for seed in seeds:
                    identity = (dataset + "/" + "fedspa-dr0.5" + rigl_flag + "-" + alpha
                                + "-staticFalse-shared0-strict_avgTrue-seed" + seed)
                    with open(identity, 'rb') as f:
                        results = CPU_Unpickler(f).load()
                    seed_acc.append(np.array(results["test_acc"]))
                mean_acc = np.average(seed_acc, axis=0)
                std = np.std(seed_acc, axis=0)
                rounds = range(len(results["test_acc"]))
                ax.fill_between(rounds, mean_acc + std, mean_acc - std, color=colors[i], alpha=0.2)
                ax.plot(rounds, mean_acc, linewidth=2, markersize=3, label=label, color=colors[i])
            ax.set_xlabel('Communication Rounds', fontsize=fontsize)
            ax.set_title(data_setting_name[alpha_index], fontsize=fontsize)
            ax.grid(True, linestyle='-', linewidth=1)

        fig.legend(loc='lower center', bbox_to_anchor=(0.0, 0.0, 1, 1), fancybox=False, shadow=False, ncol=6,
                   fontsize=fontsize, frameon=False)
        fig.suptitle(task_name[dataset_index], fontweight='bold', rotation='vertical', fontsize=15, x=0.025, y=0.80)
        fig.savefig("fig/rigl" + dataset + ".png")
        plt.show()
        # Close this figure explicitly; figure(num=None) never gets number 0.
        plt.close(fig)

def ablation_initialization():
    """Compare same vs. different initialization for FedSpa (DST and RSM) on CIFAR-100.

    Four variants are averaged over seeds: ``staticFalse`` (DST) and
    ``staticTrue`` (RSM) runs, each without and with the ``-d`` file-name suffix
    (legend: same vs. different initialization). For each data split, their
    mean +/- std test-accuracy curves are drawn in one panel of a 1x3 figure,
    which is saved to ``fig/intialization<dataset>.png``, shown and closed.
    """
    datasets = ["cifar100"]
    alphas = {"cifar100": ["homo", "hetero0.2", "hetero0.1"]}
    task_name = ["CIFAR100 (ResNet18)"]
    methods_name = ["FedSpa (DST, same initialization)", "FedSpa (RSM, same initialization)",
                    "FedSpa (DST, different initialization)", "FedSpa (RSM, different initialization)"]
    # (static flag, file-name suffix) of each variant, in the order of methods_name.
    variants = [("False", ""), ("True", ""), ("False", "-d"), ("True", "-d")]
    data_setting_name = ["IID", "Non-IID (" + r'$\gamma=0.2$)', "Non-IID (" + r'$\gamma=0.1$)']
    colors = ["#51C1C8", "#E96279", "#44A2D6", "#536D84",
              "#FA84F5", "b", "y", "#536D84"]
    seeds = ["0", "1", "2"]
    fontsize = 15
    for dataset_index, dataset in enumerate(datasets):
        fig = figure(num=None, figsize=(18, 5), dpi=300, facecolor='w', edgecolor='k')
        fig.subplots_adjust(left=0.08, bottom=0.21, right=0.97, top=0.95, wspace=0.15, hspace=0.4)
        for alpha_index, alpha in enumerate(alphas[dataset]):
            ax = fig.add_subplot(1, 3, alpha_index + 1)
            if alpha_index % 3 == 0:
                ax.set_ylabel('Accuracy', fontsize=fontsize)
            for i, (static, suffix) in enumerate(variants):
                # Only the first panel supplies labels, so fig.legend lists each variant once.
                label = methods_name[i] if alpha_index % 3 == 0 else None
                seed_acc = []
                for seed in seeds:
                    identity = (dataset + "/" + "fedspa-dr0.5" + "riglTrue-" + alpha + "-static" + static
                                + "-shared0-strict_avgTrue-seed" + seed + suffix)
                    with open(identity, 'rb') as f:
                        results = CPU_Unpickler(f).load()
                    seed_acc.append(np.array(results["test_acc"]))
                mean_acc = np.average(seed_acc, axis=0)
                std = np.std(seed_acc, axis=0)
                rounds = range(len(results["test_acc"]))
                ax.fill_between(rounds, mean_acc + std, mean_acc - std, color=colors[i], alpha=0.2)
                ax.plot(rounds, mean_acc, linewidth=2, markersize=3, label=label, color=colors[i])
            ax.set_title(data_setting_name[alpha_index], fontsize=fontsize)
            ax.set_xlabel('Communication Rounds', fontsize=fontsize)
            ax.grid(True, linestyle='-', linewidth=1)

        fig.legend(loc='lower center', bbox_to_anchor=(0.0, 0.0, 1, 1), fancybox=False, shadow=False, ncol=6,
                   fontsize=fontsize, frameon=False)
        fig.suptitle(task_name[dataset_index], fontweight='bold', rotation='vertical', fontsize=15, x=0.025, y=0.8)
        # Output path kept as-is (including its spelling) for existing consumers.
        fig.savefig("fig/intialization" + dataset + ".png")
        plt.show()
        # Close this figure explicitly; figure(num=None) never gets number 0.
        plt.close(fig)

# def acc_report():
#     fig = figure(num=None, figsize=(15, 5), dpi=120, facecolor='w', edgecolor='k')
#     fontsize = 15
#     resnset20 = fig.add_subplot(2, 3, 1)
#     resnset20.plot(x_axis, gradient_ERK_RN20_mean[:length], '-', color='orange', linewidth=linewidth,
#                    markersize=markersize, label='ERK, Top-1 Acc = 91.57%')
#     resnset20.fill_between(x_axis, gradient_ERK_RN20_mean[:length] + gradient_ERK_RN56_std[:length],
#                            gradient_ERK_RN20_mean[:length] - gradient_ERK_RN56_std[:length], color='orange',
#                            alpha=alpha, linewidth=0)
#     resnset20.plot(x_axis, gradient_snip_RN20_mean[:length], '-', color='green', linewidth=linewidth,
#                    markersize=markersize, label='SNIP, Top-1 Acc = 91.86%')
#     resnset20.fill_between(x_axis, gradient_snip_RN20_mean[:length] + gradient_snip_RN20_std[:length],
#                            gradient_snip_RN20_mean[:length] - gradient_snip_RN20_std[:length], color='green',
#                            alpha=alpha, linewidth=0)
#     resnset20.plot(x_axis, gradient_rn20_dense_mean[:length], '-', color='blue', linewidth=linewidth,
#                    markersize=markersize, label='Dense,Top-1 Acc = $\mathbf{92.37\%}$')
#     resnset20.fill_between(x_axis, gradient_rn20_dense_mean[:length] + gradient_rn20_dense_std[:length],
#                            gradient_rn20_dense_mean[:length] - gradient_rn20_dense_std[:length], color='cornflowerblue',
#                            alpha=alpha, linewidth=0)
#     resnset20.set_title('ResNet-20, Sparsity=50%', fontsize=fontsize)
#     # resnset20.set_xticks(x_axis)
#     resnset20.legend(fontsize=legend_fontsize)
#     resnset20.set_ylabel('Gradient Norm', fontsize=fontsize)
#     # resnset20.set_xlabel('Training Iterations [#]', fontsize=fontsize)
#     # resnset20.set_xticklabels([0.9, 0.8, 0.7, 0.6, 0.5, 0.4, 0.3, 0.2, 0.1], fontsize=fontsize )
#     resnset20.tick_params(axis='both', which='major', labelsize=fontsize)
#     resnset20.grid(True, linestyle='-', linewidth=0.5, )

def _wall_time_extract(log_text, label):
    """Return every number that follows ``"<label> time elapse: "`` in *log_text* as a float array."""
    pattern = re.escape(label) + r" time elapse: (\d+\.?\d*)"
    return np.array([float(value) for value in re.findall(pattern, log_text)])


def _wall_time_stat(values, suffix=""):
    r"""Format ``<mean><suffix>$\pm$<std>`` of *values*, each rounded to 2 decimals."""
    return str(round(np.average(values), 2)) + suffix + r"$\pm$" + str(round(np.std(values), 2))


def wall_time(datasets=("emnist", "cifar10", "cifar100"), devices=("cpu", "gpu"), log_dir="wall_time"):
    r"""Summarise training and mask-search wall time from log files as LaTeX table rows.

    For every ``device`` (outer loop) and ``dataset`` (inner loop), reads
    ``<log_dir>/wall_time_<device>_<dataset>.out``. It extracts every number that
    follows ``"train time elapse: "`` and ``"mask searching time elapse: "``.
    Each row reports, rounded to 2 decimals:

    * mean $\pm$ std of the training time;
    * mean $\pm$ std of the mask-search time;
    * mean $\pm$ std of the per-entry ratio ``mask / train * 100``, in percent.

    Args:
        datasets: Dataset names, in row order within each device.
        devices: Device names; rows are grouped by device.
        log_dir: Directory that contains the ``.out`` log files.

    Returns:
        The LaTeX table body. It is also printed.

    Raises:
        ValueError: If a log has a different number of training-time and
            mask-search-time entries. The ratio is taken entry by entry, so
            mismatched counts would otherwise broadcast silently or fail obscurely.
    """
    rows = []
    for device in devices:
        for dataset in datasets:
            file_name = f"{log_dir}/wall_time_{device}_{dataset}.out"
            with open(file_name, "r", encoding="utf-8", errors="replace") as file:
                log_text = file.read()
            train_time = _wall_time_extract(log_text, "train")
            mask_search_time = _wall_time_extract(log_text, "mask searching")
            if train_time.shape != mask_search_time.shape:
                raise ValueError(
                    f"{file_name}: found {train_time.size} train-time entries but "
                    f"{mask_search_time.size} mask-search-time entries"
                )
            ratio = mask_search_time / train_time * 100
            rows.append(
                f"{dataset}-({device})"
                + " & " + _wall_time_stat(train_time)
                + " & " + _wall_time_stat(mask_search_time)
                + " & " + _wall_time_stat(ratio, r"\%")
                + "\\\\ \n"  # LaTeX row terminator: two backslashes, a space, a newline
            )
    table = "".join(rows)
    print(table)
    return table

def _acc_fig_prefixes(dataset, alpha, methods, dense_ratio):
    """Return the result-file prefixes (seed suffix not yet appended), in plot order.

    ``"fedspa"`` contributes two prefixes per dense ratio (``staticFalse`` then
    ``staticTrue``). ``"SubAVG"`` and ``"subsampling"`` contribute one prefix per
    dense ratio. Every other method contributes ``<dataset>/<method>-<alpha>-seed``.
    """
    prefixes = []
    for method in methods:
        if method == "fedspa":
            for dr in dense_ratio:
                for static in ("False", "True"):
                    prefixes.append(
                        dataset + "/" + method + "-dr" + dr + "riglTrue-" + alpha
                        + "-static" + static + "-shared0-strict_avgTrue-seed")
        elif method == "SubAVG":
            prefixes.extend(dataset + "/" + method + "-dr" + dr + "-" + alpha + "-seed"
                            for dr in dense_ratio)
        elif method == "subsampling":
            # Loop explicitly instead of relying on a `dr` leaked from another branch,
            # so the result does not depend on the order of `methods`.
            prefixes.extend(dataset + "/" + method + "-" + dr + "-" + alpha + "-seed"
                            for dr in dense_ratio)
        else:
            prefixes.append(dataset + "/" + method + "-" + alpha + "-seed")
    return prefixes


def _acc_fig_load_test_acc(path):
    """Unpickle the result dict at *path* (torch storages mapped to CPU) and return its ``"test_acc"`` array."""
    # NOTE: unpickling can execute arbitrary code; only load result files you produced.
    with open(path, 'rb') as f:
        results = CPU_Unpickler(f).load()
    return np.array(results["test_acc"])


def acc_fig():
    """Plot test-accuracy curves for every method on a 3x3 grid and save ``fig/main_result.png``.

    Layout:
    * Rows are datasets and columns are data settings (``alphas[dataset]``).
    * Each panel has one curve per result-file prefix.

    Each curve is built from the pickled ``"test_acc"`` histories of seeds 0-2.
    It is drawn as the mean with a one-standard-deviation band. Legend labels
    come only from the top-left panel, so ``fig.legend`` lists each method once.
    """
    datasets = ["emnist", "cifar10", "cifar100"]
    alphas = {"emnist": ["homo", "hetero0.2", "hetero0.1"],
              "cifar10": ["homo", "hetero0.5", "hetero0.3"],
              "cifar100": ["homo", "hetero0.2", "hetero0.1"]}
    methods = ["fedspa", "ditto", "fedavg", "SubAVG", "local", "subsampling"]
    task_name = ["EMNIST-L", "CIFAR10", "CIFAR100"]
    # Indexed in step with _acc_fig_prefixes(): "fedspa" yields two prefixes,
    # hence two names here.
    methods_name = ["FedSpa (DST)", "FedSpa (RSM)", "Ditto", "FedAvg", "Sub-FedAvg", "Local", "Subsampling"]
    data_setting_name = ["IID", "Non-IID (setting A)", "Non-IID (setting B)"]
    min_y = {"emnist": 0.7, "cifar10": 0.5, "cifar100": 0}
    seeds = ["0", "1", "2"]
    dense_ratio = ["0.5"]
    fontsize = 15

    fig = figure(num=None, figsize=(12, 6), dpi=300, facecolor='w', edgecolor='k')
    fig.subplots_adjust(left=0.08, bottom=0.17, right=0.99, top=0.95, wspace=0.15, hspace=0.3)
    for dataset_index, dataset in enumerate(datasets):
        for alpha_index, alpha in enumerate(alphas[dataset]):
            identity_array = _acc_fig_prefixes(dataset, alpha, methods, dense_ratio)
            subplob = fig.add_subplot(3, 3, dataset_index * 3 + alpha_index + 1)
            if alpha_index == 0:
                subplob.set_ylabel(task_name[dataset_index] + ' \n Accuracy', fontsize=fontsize)

            show_labels = dataset_index == 0 and alpha_index == 0
            for i, prefix in enumerate(identity_array):
                acc = [_acc_fig_load_test_acc(prefix + seed) for seed in seeds]
                mean_acc = np.average(acc, axis=0)
                std = np.std(acc, axis=0)
                # x-axis length follows the plotted series itself.
                rounds = range(len(mean_acc))
                subplob.plot(rounds, mean_acc, linewidth=2, markersize=3,
                             label=methods_name[i] if show_labels else None)
                subplob.fill_between(rounds, mean_acc + std, mean_acc - std, alpha=0.3)

            subplob.set_ylim(bottom=min_y[dataset])
            if dataset_index == len(datasets) - 1:
                subplob.set_xlabel("Communication Rounds", fontsize=fontsize)
            if dataset_index == 0:
                subplob.set_title(data_setting_name[alpha_index], fontsize=fontsize)
            subplob.grid(True, linestyle='-', linewidth=1)

    fig.legend(loc='lower center', bbox_to_anchor=(0.0, 0.0, 1, 1), fancybox=False, shadow=False, ncol=7,
               fontsize=12, frameon=False)
    fig.savefig("fig/" + "main_result" + ".png")
    plt.show()
    # figure(num=None) auto-numbers the figure, so close it by handle rather than by number 0.
    plt.close(fig)

from sklearn.decomposition import PCA

# def flop_report():
#     datasets = ["emnist","cifar10","cifar100"]
#     alphas = {"emnist": ["homo", "hetero0.1", "hetero0.2"], "cifar10": ["homo", "hetero0.3", "hetero0.5"],
#               "cifar100": ["homo", "hetero0.2", "hetero0.1"]}
#     methods = ["fedspa", "ditto", "fedavg", "SubAVG", "local"]
#     methods_name = ["FedSlim", "Ditto", "FedAvg", "SubAVG", "Local"]
#     # methods = [ "fedavg", "fedspa"]
#     # methods = ["fedavg","fedspa","ditto"]
#     for dataset in datasets:
#         for alpha in alphas[dataset]:
#             fig = plt.figure(0)
#             identity_array = []
#             dense_ratio = ["0.5"]
#             for method in methods:
#                 if method == "fedspa":
#                     for dr in dense_ratio:
#                         # identity_array.append(
#                         #     dataset + "/" + method + "-dr" + dr + "riglTrue-" + alpha + "-staticFalse-shared0-strict_avgFalse")
#                         identity_array.append(
#                             dataset + "/" + method + "-dr" + dr + "riglTrue-" + alpha + "-staticFalse-shared0-strict_avgTrue-seed0")
#                 elif method == "SubAVG":
#                     for dr in dense_ratio:
#                         identity_array.append(dataset + "/" + method + "-dr" + dr + "-" + alpha + "-seed0" )
#                 else:
#                     identity_array.append(dataset + "/" + method + "-" + alpha + "-seed0")
#             # cmap = matplotlib.cm.get_cmap("rainbow", len(identity_array))
#             # # get the list of colors from the colormap
#             # colors = cmap(np.linspace(0, 1, len(identity_array)))
#             for i in range(len(identity_array)):
#                 identity = identity_array[i]
#                 f = open(identity, 'rb')
#                 results = pickle.load(f)
#                 plt.plot( results["sum_training_flops"].cpu() , label=methods_name[i])
#                 plt.legend()
#                 print( identity+str(results["sum_training_flops"].cpu()))
#                 print(identity + str(results["sum_comm_params"]))
#             plt.legend(fontsize=15)
#             plt.grid(c='#d9d9d9')
#             plt.ylabel('Test Accuracy', fontsize=15)
#             plt.xlabel('Communication Rounds', fontsize=15)
#             plt.savefig("fig/" + dataset + alpha + ".png")
#             plt.show()
#             plt.close(0)

def erk_report():
    r"""Build and print LaTeX table rows comparing the "ERK" and "Uniform" FedSpa runs.

    For each dataset, one row is emitted per method. For every data setting
    (``alphas[dataset]``) the row gets three cells, aggregated over seeds 0-2:

    * final test accuracy x 100, as mean $\pm$ std rounded to 1 decimal;
    * mean of ``sum_comm_params / 1e9 * 32 / 8``, rounded to 1 decimal;
    * mean of ``sum_training_flops / 1e16``, rounded to 1 decimal.

    "ERK" results are read from ``<dataset>/fedspa-...-seed<seed>``. The other
    method reads the same path with a ``-u`` suffix.

    Returns:
        The concatenated table text for all datasets. Each dataset's rows are
        also printed.
    """
    datasets = ["cifar100"]
    alphas = {"cifar100": ["homo", "hetero0.2", "hetero0.1"]}
    methods = ["ERK", "Uniform"]
    seeds = ["0", "1", "2"]
    dense_ratio = ["0.5"]
    tables = []
    for dataset in datasets:
        # Start fresh per dataset so each printed table contains only its own rows.
        string = ""
        for method in methods:
            file_suffix = "" if method == "ERK" else "-u"
            string += method
            for alpha in alphas[dataset]:
                for dr in dense_ratio:
                    seed_results = []
                    seed_sum_comm_params = []
                    seed_sum_training_flops = []
                    for seed in seeds:
                        identity = (dataset + "/fedspa" + "-dr" + dr + "riglTrue-" + alpha
                                    + "-staticFalse-shared0-strict_avgTrue-seed" + seed + file_suffix)
                        # NOTE: pickle can execute arbitrary code; only load result files you produced.
                        with open(identity, 'rb') as f:
                            results = pickle.load(f)
                        seed_results.append(results["test_acc"][-1] * 100)
                        seed_sum_comm_params.append(results["sum_comm_params"] / 1e9 * 32 / 8)
                        # Presumably a torch tensor (it exposes .cpu()); moved to CPU before averaging.
                        seed_sum_training_flops.append(results["sum_training_flops"].cpu() / 1e16)

                    string += (" & " + str(round(np.average(seed_results), 1)) + r"$\pm$"
                               + str(round(np.std(seed_results), 1)))
                    string += " & " + str(round(np.average(seed_sum_comm_params), 1))
                    string += " & " + str(round(np.average(seed_sum_training_flops), 1))
            string += "\\\\ \n"  # LaTeX row terminator: two backslashes, a space, a newline

        print(string)
        tables.append(string)
    return "".join(tables)



# def ablation_global():
#     def cluster():
#         path = "cifar100/fedspa-dr0.5riglTrue-homo-staticFalse-shared0-strict_avgTrue-masks"
#         f = open(path, 'rb')
#         stat_info = pickle.load(f)
#         for i in range(len(stat_info["label_num"])):
#             stat_info["label_num"][i] = np.array(stat_info["label_num"][i]) / np.sum(stat_info["label_num"][i])
#
#         kmeans = KMeans(n_clusters=10, random_state=0).fit(stat_info["label_num"])
#         labels = [[] for i in range(len(stat_info["label_num"]))]
#         for i in range(len(stat_info["label_num"])):
#             labels[i] = np.argmax(stat_info["label_num"][i])
#         print(labels)
#         print(kmeans.labels_)
#         print(stat_info["label_num"])
#         return kmeans.labels_



# def pca_decompose(labels):
#     colors = ['c', 'b', 'g', 'r', 'm', 'y', 'k', 'deepskyblue','darkviolet','deeppink','deepskyblue','moccasin', 'orchid','royalblue','seagreen']
#     path = "cifar100/fedspa-dr0.5riglTrue-homo-staticFalse-shared0-strict_avgTrue-masks"
#     f = open(path, 'rb')
#     results = pickle.load(f)
#     final_masks = results["final_masks"]
#     flatten_mask = []
#     for mask in final_masks:
#             flatten_mask.append(torch.cat([ mask[name].flatten() for name in mask]).numpy())
#     pca = PCA(n_components=2)
#     newX = pca.fit_transform(flatten_mask)
#     for i in range(len(newX)):
#         plt.scatter(newX[i,0],newX[i,1], marker="o",  c=colors[labels[i]])
#     plt.savefig("fig/cluster"   + ".png")
#     plt.show()

# labels= cluster()
# pca_decompose(labels)
def _ablation_global_curves(dataset, alpha, seeds):
    """Load the per-seed accuracy series plotted by ablation_global.

    Returns three lists, each with one array per seed, in plot order:
    1. ``"test_acc"`` from the FedSpa ``-g`` result file;
    2. ``"global_model_acc"`` from the same file;
    3. ``"test_acc"`` from the FedAvg result file.
    """
    personalized, global_model, fedavg = [], [], []
    for seed in seeds:
        fedspa_path = (dataset + "/" + "fedspa-dr0.5" + "riglTrue-" + alpha
                       + "-staticFalse-shared0-strict_avgTrue-seed" + seed + "-g")
        # Both FedSpa series live in the same file, so unpickle it once.
        # NOTE: pickle can execute arbitrary code; only load result files you produced.
        with open(fedspa_path, 'rb') as f:
            results = pickle.load(f)
        personalized.append(np.array(results["test_acc"]))
        global_model.append(np.array(results["global_model_acc"]))

        fedavg_path = dataset + "/" + "fedavg" + "-" + alpha + "-seed" + seed
        with open(fedavg_path, 'rb') as f:
            results = pickle.load(f)
        fedavg.append(np.array(results["test_acc"]))
    return personalized, global_model, fedavg


def ablation_global():
    """Plot FedSpa personalized vs. global-model accuracy against FedAvg.

    One figure is made per dataset, with one panel per data setting. Each panel
    has three curves (see _ablation_global_curves). Each curve is the mean over
    seeds 0-2 with a one-standard-deviation band. The figure is saved to
    ``fig/global<dataset>.png`` and shown.
    """
    datasets = ["cifar100"]
    alphas = {"cifar100": ["homo", "hetero0.2", "hetero0.1"]}
    task_name = ["CIFAR100 (ResNet18)"]
    # Indexed in step with the curves returned by _ablation_global_curves().
    methods_name = ["FedSpa (DST with personalized model)", "FedSpa (DST with global model)", "FedAvg (global model)"]
    data_setting_name = ["IID", "Non-IID (" + r'$\gamma=0.2$)', "Non-IID (" + r'$\gamma=0.1$)']
    colors = ["#51C1C8", "#E96279", "#44A2D6", "#536D84",
              "#FA84F5", "b", "y", "#536D84"]
    seeds = ["0", "1", "2"]
    fontsize = 15
    for dataset_index, dataset in enumerate(datasets):
        fig = figure(num=None, figsize=(18, 5), dpi=300, facecolor='w', edgecolor='k')
        fig.subplots_adjust(left=0.08, bottom=0.21, right=0.97, top=0.95, wspace=0.15, hspace=0.4)
        for alpha_index, alpha in enumerate(alphas[dataset]):
            curves = _ablation_global_curves(dataset, alpha, seeds)
            subplob = fig.add_subplot(1, 3, alpha_index + 1)
            if alpha_index == 0:
                subplob.set_ylabel('Accuracy', fontsize=fontsize)
            for i, seed_acc in enumerate(curves):
                # Only the first panel supplies labels so fig.legend lists each curve once.
                label = methods_name[i] if alpha_index == 0 else None
                mean_acc = np.average(seed_acc, axis=0)
                std = np.std(seed_acc, axis=0)
                # x-axis length follows the plotted series (global_model_acc may not match test_acc).
                rounds = range(len(mean_acc))
                subplob.fill_between(rounds, mean_acc + std, mean_acc - std, color=colors[i], alpha=0.2)
                subplob.plot(rounds, mean_acc, linewidth=2,
                             markersize=3, label=label, color=colors[i])
            subplob.set_title(data_setting_name[alpha_index], fontsize=fontsize)
            subplob.grid(True, linestyle='-', linewidth=1)
            subplob.set_xlabel('Communication Rounds', fontsize=fontsize)

        fig.legend(loc='lower center', bbox_to_anchor=(0.0, 0.0, 1, 1), fancybox=False, shadow=False, ncol=6,
                   fontsize=fontsize, frameon=False)
        fig.suptitle(task_name[dataset_index], fontweight='bold', rotation='vertical', fontsize=15, x=0.025, y=0.8)
        fig.savefig("fig/global" + dataset + ".png")
        plt.show()
        # figure(num=None) auto-numbers the figure, so close it by handle rather than by number 0.
        plt.close(fig)

# Script entry point: uncomment the report/figure you want to regenerate.
# NOTE(review): the uncommented call below is a module-level statement, so it also
# runs whenever this module is imported. Consider wrapping it in
# `if __name__ == "__main__":`.
# NOTE(review): `acc_round_report` is not defined in this part of the file;
# presumably it is defined earlier — confirm.
# acc_fig()
# wall_time()
# main_report()
# ablation_rigl()
# ablation_erk()
# erk_report()
# ablation_global()
# sparse_report()
# ablation_sparse()
# flop_report()
# ablation_initialization()
acc_round_report()
# circle_plot()