import seaborn as sns
import os
import tqdm
import torch
import pickle
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm
import numpy as np
import copy
import time
from tools import *
from server_utils import *
from models import *
from clean import *
from tools import *
from server_utils import *
from cluster_alg import *
from apcluster import *
from new_server_utils import *
from options import args_parser
from utils import exp_details, get_datasets, get_pub_datasets,get_public_datasets
from update import LocalUpdate

from test import test_inference
from tools_client import *
from sample import *
import matplotlib
import matplotlib.pyplot as plt
matplotlib.use('Agg')
import warnings
warnings.filterwarnings('ignore')



# Flattened input size (H * W * C) handed to the reassembled models, per
# supported public dataset. All entries are 32x32x3 here, MNIST included
# (presumably converted to 32x32 RGB upstream -- confirm in the data loaders).
_PUBLIC_DATASET_INPUT_SIZE = {
    'cifar10': 32 * 32 * 3,
    'mnist': 32 * 32 * 3,
    'svhn': 32 * 32 * 3,
}

# Model-name -> single-letter indicator used when every client shares a model.
_MODEL_INDICATOR = {'CNN1': 'A', 'CNN2': 'B', 'CNN3': 'C', 'CNN4': 'D'}


def _resolve_input_size(public_dataset):
    """Return the flattened input size for *public_dataset*.

    Raises:
        ValueError: if the dataset name is not supported. (Previously the
            script only printed a warning and crashed later with a NameError.)
    """
    try:
        return _PUBLIC_DATASET_INPUT_SIZE[public_dataset]
    except KeyError:
        raise ValueError(
            'wrong public dataset name: {!r}'.format(public_dataset)) from None


def _tail_mean(values, k=3):
    """Return the mean of the last *k* entries of *values* (0.0 when empty)."""
    tail = list(values)[-k:]
    return sum(tail) / len(tail) if tail else 0.0


def _write_list(path, items):
    """Write each item of *items* on its own line (``str`` formatting)."""
    with open(path, 'w') as filehandle:
        filehandle.writelines('%s\n' % item for item in items)


def _write_timing_file(path, header, times):
    """Write *header* followed by one ``Round i: t`` line per entry of *times*."""
    with open(path, 'w') as filehandle:
        filehandle.write(header + '\n')
        for round_no, seconds in enumerate(times, start=1):
            filehandle.write(f'Round {round_no}: {seconds:.2f}\n')


def _save_curve(path, title, ys, ylabel, color, timing_style=False):
    """Plot *ys* against communication rounds and save the figure to *path*.

    ``timing_style`` reproduces the per-stage timing plots: 1-based x axis,
    circle markers and a grid. Otherwise the x axis is 0-based with no markers.
    The figure is closed afterwards so repeated calls do not accumulate open
    figures.
    """
    import matplotlib.pyplot as plt  # backend is selected at the top of the file

    fig = plt.figure()
    plt.title(title)
    if timing_style:
        plt.plot(range(1, len(ys) + 1), ys, color=color, marker='o')
        plt.grid(True)
    else:
        plt.plot(range(len(ys)), ys, color=color)
    plt.xlabel('Communication Rounds')
    plt.ylabel(ylabel)
    plt.savefig(path)
    plt.close(fig)


if __name__ == '__main__':
    path_project = os.path.abspath('..')

    # The original file used `args` without defining it. It may already be
    # provided by one of the wildcard imports above (NOTE(review): confirm
    # which module, if any, defines it). Otherwise parse the command line here.
    if 'args' not in globals():
        args = args_parser()

    exp_details(args)
    device = args.device
    # Fail fast on an unsupported public dataset, before loading any data.
    input_size = _resolve_input_size(args.public_dataset)

    # Per-round wall-clock timings (seconds) for each pipeline stage.
    round_times = []
    ap_cluster_times = []
    reassembly_times = []
    global_times = []
    Personal_times = []
    c2prkd_distll_times = []
    cleanup_memory()

    train_dataset, test_dataset, dict_users_train, dict_users_test, server_train_data = get_datasets(args)

    # NOTE(review): data_at_server is computed but never used below; every
    # server-side step uses server_train_data. Confirm which one is intended.
    if args.public_dataset == args.dataset:
        data_at_server = server_train_data
    else:
        data_at_server = get_public_datasets(args)

    if args.model_same == 1:
        if args.model not in _MODEL_INDICATOR:
            raise ValueError('wrong model name: {!r}'.format(args.model))
        model_indicator = _MODEL_INDICATOR[args.model]
        client_model_dict = {client_idx: model_indicator for client_idx in range(args.num_users)}

    global_layer_order = build_global_layer_order_detailed()
    # e.g. {'conv1': 0, 'conv2': 1, ..., 'fc4': 9, 'linear': 10}

    # NOTE: this unconditionally replaces any client_model_dict built above.
    # Key is the client index, value is the model name,
    # e.g. {0: 'A', 1: 'A', 2: 'B', ...}.
    client_model_dict = model_generationhom(args.num_users)

    local_avg_train_losses_list = []
    local_avg_test_losses_list = []
    local_avg_test_accuracy_list = []
    print_every = 1

    # model_assign_dict: client index -> model instance.
    # client_id_model_name: client index -> model name.
    model_assign_dict, client_id_model_name = new_model_generatorhom(args.num_users)

    print("start server and client communication:")
    previous_user_list = []
    current_user_id_modelweights_dict = {}  # client index -> latest local model
    final_model_dict = {}  # client index -> personalized model from the server
    for epoch in tqdm(range(args.epochs)):
        round_start_time = time.time()
        round_kd_time = 0.0  # total c2pr-KD distillation time in this round
        local_losses, local_test_losses, local_test_accuracy = [], [], []
        print('communication round: {} \n'.format(epoch))

        m = max(int(args.frac * args.num_users), 1)
        idxs_users = np.random.choice(range(args.num_users), m, replace=False).tolist()

        for idx in idxs_users:
            # NOTE(review): evaluation indexes train_dataset with
            # dict_users_test[idx], and test_dataset is otherwise unused.
            # Confirm that dict_users_test holds indices into the training set.
            test_loader_for_each_client = torch.utils.data.DataLoader(
                dataset=DatasetSplit(train_dataset, dict_users_test[idx]),
                shuffle=True,
            )
            local_model = LocalUpdate(args=args, dataset=train_dataset, idxs=dict_users_train[idx])

            # A client that took part last round and has a personalized model
            # distills from it (teacher) into its previous local model
            # (student). Every other client trains its assigned model locally.
            use_kd = (
                idx in previous_user_list
                and idx in final_model_dict
                and idx in current_user_id_modelweights_dict
            )
            if use_kd:
                teacher_model = final_model_dict[idx]
                student_model = current_user_id_modelweights_dict[idx]
                c2prkd_start_time = time.time()
                # The third return value of k_distll used to overwrite the
                # timing list. The distillation time is now measured here
                # instead, and the third value is discarded.
                w, loss, _ = local_model.k_distll(student_model=student_model, teacher_model=teacher_model)
                round_kd_time += time.time() - c2prkd_start_time
                trained_local_model = copy.deepcopy(student_model)
            else:
                w, loss = local_model.update_weights(model=copy.deepcopy(model_assign_dict[idx]), global_round=epoch)
                trained_local_model = copy.deepcopy(model_assign_dict[idx])
            trained_local_model.load_state_dict(w)
            trained_local_model.to(device)
            test_acc, test_loss = test_inference(args, trained_local_model, test_loader_for_each_client)

            local_losses.append(loss.item() if torch.is_tensor(loss) else float(loss))
            local_test_losses.append(test_loss)
            local_test_accuracy.append(test_acc)

            # Store a separate copy as this client's latest local model; it is
            # the KD student the next time the client is selected.
            stored_model = copy.deepcopy(model_assign_dict[idx])
            stored_model.load_state_dict(w)
            current_user_id_modelweights_dict[idx] = stored_model

        c2prkd_distll_times.append(round_kd_time)

        loss_avg = sum(local_losses) / len(local_losses)
        local_avg_train_losses_list.append(loss_avg)
        loss_avg_test_loss = sum(local_test_losses) / len(local_test_losses)
        local_avg_test_losses_list.append(loss_avg_test_loss)
        loss_avg_test_accuracy = sum(local_test_accuracy) / len(local_test_accuracy)
        local_avg_test_accuracy_list.append(loss_avg_test_accuracy)

        if (epoch + 1) % print_every == 0:
            print(f' \nAvg Stats after {epoch+1} global rounds:')
            print(f'Local Avg Training Loss : {loss_avg}')
            print(f'Local Avg Test Loss : {loss_avg_test_loss}')
            print(f'Local Avg Test Accuracy : {loss_avg_test_accuracy}')
        del local_losses, local_test_losses, local_test_accuracy
        cleanup_memory()

        ##### Step 1: layer-wise features and cross-model layer clustering #####
        local_model_dict = {
            idx: current_user_id_modelweights_dict[idx]
            for idx in idxs_users
            if idx in current_user_id_modelweights_dict
        }
        client_model_info = client_model_dict

        # server_output_dict: key is the string '[client id, layer index]',
        # value is [total_feat_in, total_feat_out, layer_name,
        #           total_feat_in_size, total_feat_out_size].
        server_output_dict = layer_feature(idxs_users, local_model_dict, server_train_data, client_model_info, args)
        # '[client id, layer index]' -> [client id (str), layer name],
        # e.g. {'[1, 0]': ['1', 'conv1'], '[1, 1]': ['1', 'conv2'], ...}
        model_layer_index_to_model_layer_name = process_format(server_output_dict)
        # '[client id, layer name]' -> [total_feat_in_size, total_feat_out_size]
        prepare_layer_size_dict = {
            str(model_layer_index_to_model_layer_name[key]): value[-2:]
            for key, value in server_output_dict.items()
        }

        # NOTE(review): these two results are never used below; the calls are
        # kept only in case they have side effects on server_output_dict.
        server_output_dict_only_embedding = extract_embedding(server_output_dict)
        server_output_dict_only_size = extrac_size(server_output_dict)
        # '[client id, layer index]' -> [total_feat_in, total_feat_out]
        server_output_dict_same_embedding = embedding_process(server_output_dict)

        cluster_start_time = time.time()
        cluster_dict, center_layer_info, similarity_matrix = cross_model_layer_clustering(
            server_output_dict_same_embedding, damping=0.8, max_iter=500, bs=2048, plot_matrix=False)
        ap_cluster_times.append(time.time() - cluster_start_time)
        cleanup_memory()

        ##### Step 2: assemble layers into a candidate pool #####
        # cluster id -> [[client id, layer name]] of the cluster's center layer,
        # e.g. {0: [['1', 'conv2']], 1: [['5', 'conv2']], ...}
        center_layer_name = {
            key: [model_layer_index_to_model_layer_name[value]]
            for key, value in center_layer_info.items()
        }
        # cluster id -> list of [client id, layer name] members,
        # e.g. ['[5, 0]', '[5, 1]'] -> [['5', 'conv1'], ['5', 'conv2']]
        for_find_comb_input = {
            key: [model_layer_index_to_model_layer_name[item] for item in value]
            for key, value in cluster_dict.items()
        }

        rea_start_time = time.time()
        candidate_model_combine = newsample_models(
            for_find_comb_input, 4, 6, args.expected_num_models, global_layer_order, center_layer_name)
        reassembly_times.append(time.time() - rea_start_time)
        print(candidate_model_combine)

        # [['1', 'conv2'], ['5', 'fc2'], ...] -> [[1, 'conv2'], [5, 'fc2'], ...]
        candidate_model_combine_show_model = process_type_list(candidate_model_combine)

        ##### Step 2-1: training-free selection of one global model #####
        global_start_time = time.time()
        # Combined candidate models with randomly initialized weights.
        model_pool_with_mlp = newcomb_with_mlp(
            candidate_model_combine_show_model, model_assign_dict, prepare_layer_size_dict, input_size=input_size)
        best_global_model, selection_info = global_model_selection(
            model_pool_with_mlp, server_train_data, args.zero_proxy, args.zero_bs)
        print(selection_info)
        cleanup_memory()
        # selection_info looks like {'best_model_index': ..., 'best_score': ...,
        # 'all_scores': [...], 'proxy_metric': ..., ...}. Map the best index
        # back to its original layer combination.
        best_global_model_index_to_best_candidate_model, best_model_index = get_best_combined_model(
            selection_info, candidate_model_combine_show_model)
        global_times.append(time.time() - global_start_time)

        ##### Step 3: personalized model reassembly and fine-tuning #####
        local_model_index_to_client_name = get_local_model_index_to_client_name(prepare_layer_size_dict)
        personal_start_time = time.time()
        personalized_combin_models = personalized_model_reassembly(
            best_global_model_index_to_best_candidate_model, local_model_index_to_client_name)
        final_personalized_models_list_ = comb_with_mlp(
            personalized_combin_models, local_model_dict, prepare_layer_size_dict, input_size=input_size)
        if args.supervised:
            final_model_dict = personal_model_fine_tuning(final_personalized_models_list_, local_model_dict, server_train_data)
        else:
            final_model_dict = personal_model_fine_tuning2(final_personalized_models_list_, local_model_dict, server_train_data)
        Personal_times.append(time.time() - personal_start_time)

        # Move personalized models of clients not selected this round to CPU
        # (presumably to limit accelerator memory); selected ones are unchanged.
        for idx in list(final_model_dict):
            if idx not in idxs_users:
                final_model_dict[idx] = final_model_dict[idx].cpu()
        previous_user_list = idxs_users

        del (
            local_model_dict,
            server_output_dict, model_layer_index_to_model_layer_name,
            prepare_layer_size_dict, cluster_dict, center_layer_info,
            similarity_matrix, best_global_model, selection_info,
            best_global_model_index_to_best_candidate_model,
            local_model_index_to_client_name, personalized_combin_models,
            final_personalized_models_list_
        )
        cleanup_memory()
        round_times.append(time.time() - round_start_time)

    save_path = './exp_result/{}_{}_com{}_iid_{}_E_{}_sfine_{}_teacherweight_{}_cluster_{}_frac_{}_user_num_{}/'.format(
        args.dataset, args.public_dataset, args.epochs,
        args.iid, args.local_ep, args.sfine, args.alpha, args.cluster_num, args.frac, args.num_users)
    if not os.path.exists(save_path):
        os.makedirs(save_path, exist_ok=True)
        print("The new directory is created!")

    _write_list(save_path + 'local_avg_train_loss.txt', local_avg_train_losses_list)
    _write_list(save_path + 'local_avg_test_losses_list.txt', local_avg_test_losses_list)
    _write_list(save_path + 'local_avg_test_accuracy_list.txt', local_avg_test_accuracy_list)

    _write_timing_file(save_path + 'every round_times.txt', 'Round Times (seconds):', round_times)
    _write_timing_file(save_path + 'AP_cluster_times.txt', 'AP-Clustering Times (seconds):', ap_cluster_times)
    _write_timing_file(save_path + 'c2prKD_times.txt', 'c2pr-KD Times (seconds):', c2prkd_distll_times)
    _write_timing_file(save_path + 'reassembly_times.txt', 'Reassembly Times (seconds):', reassembly_times)
    _write_timing_file(save_path + 'matching_times.txt', 'Matching Times (seconds):', Personal_times)
    _write_timing_file(save_path + 'global_times.txt', 'global Times (seconds):', global_times)

    print("*******last 3 avg local acc****************")
    print("{}".format(_tail_mean(local_avg_test_accuracy_list, 3)))
    print("********************************************")

    _save_curve(save_path + 'fed_train_loss.png', 'Local Average Training Loss vs Communication rounds',
                local_avg_train_losses_list, 'Training loss', 'r')
    _save_curve(save_path + 'fed_test_loss.png', 'Local Average Test Loss vs Communication rounds',
                local_avg_test_losses_list, 'Test Loss', 'k')
    _save_curve(save_path + 'fed_test_accuracy.png', 'Local Average Test Accuracy vs Communication rounds',
                local_avg_test_accuracy_list, 'Test accuracy', 'r')
    _save_curve(save_path + 'round_times.png', 'Round Time vs Communication rounds',
                round_times, 'Time (seconds)', 'b', timing_style=True)
    _save_curve(save_path + 'ap_cluster_times.png', 'ap-Clustering Time vs Communication rounds',
                ap_cluster_times, 'Time (seconds)', 'g', timing_style=True)
    _save_curve(save_path + 'c2prKD_times.png', 'c2pr-KD Time vs Communication rounds',
                c2prkd_distll_times, 'Time (seconds)', 'y', timing_style=True)
    _save_curve(save_path + 'reassembly_times.png', 'Reassembly Time vs Communication rounds',
                reassembly_times, 'Time (seconds)', 'm', timing_style=True)
    _save_curve(save_path + 'personal_times.png', 'Matching Time vs Communication rounds',
                Personal_times, 'Time (seconds)', 'c', timing_style=True)
    _save_curve(save_path + 'globalmodel_times.png', 'global model Time vs Communication rounds',
                global_times, 'Time (seconds)', 'c', timing_style=True)





        
        












