import json
import random
from open_fl_net import net, CNN_SVHN, SimpleVGG, ResNet10
from open_fl_dataloader import get_dataset_dict
from open_fl_utils import Fedavg_local_weight, test_inference, client, create_distinct_half_SVHN, create_distinct_half_mnist_fmnist, create_distinct_labels_for_10_clients_mnist_fmnist, Dirichlet_disbuted_classes, distinct_class_each_device, filter_dataset_by_classes, distinct_half, generate_device_lists, distribute_labels_in_batches,  distribute_labels_slight_overlap_10_clients
from open_fl_parser import parse_arguments
import time
from matplotlib import pyplot as plt
import torch
import numpy as np
import math
import os
import copy
import matplotlib.ticker as mtick
from datetime import datetime
import pytz
from torch.utils.data import DataLoader, Subset
from torchvision.models import resnet18, densenet121, squeezenet1_1, resnet34

def _zero_float_tensor(shape, device):
    """Return a float32 tensor of zeros with ``shape`` allocated on ``device``."""
    return torch.zeros(shape, dtype=torch.float32, device=device)


def _build_global_model(args, dataset_name, device):
    """Instantiate the global model for ``dataset_name`` and move it to ``device``.

    Raises:
        ValueError: if ``args.cifar_model`` is not supported for the CIFAR dataset.
    """
    if dataset_name == "SVHN":
        model = CNN_SVHN()

    elif dataset_name == "cifar10":
        if args.cifar_model == "resnet18":
            model = resnet18(weights='DEFAULT')
            model.fc = torch.nn.Linear(model.fc.in_features, 10)
        elif args.cifar_model == "simpleVGG":
            model = SimpleVGG(num_classes=10)
        else:
            raise ValueError("cifar_10 only accepts resnet18 or simpleVGG")

    elif dataset_name == "cifar100":
        if args.cifar_model == "resnet18":
            model = resnet18(weights='DEFAULT')
            model.fc = torch.nn.Linear(model.fc.in_features, 100)
        elif args.cifar_model == "resnet34":
            model = resnet34(weights='DEFAULT')
            model.fc = torch.nn.Linear(model.fc.in_features, 100)
        elif args.cifar_model == "densenet121":
            model = densenet121(weights='DEFAULT')
            model.classifier = torch.nn.Linear(model.classifier.in_features, 100)
        else:
            raise ValueError("cifar100 only accepts resnet18, resnet34 or densenet121")

    else:
        model = net(dataset_name)

    model.to(device)
    return model


def _round_schedule(sync_idx, warm_iters, apply_iters, num_warm_sets):
    """Locate round ``sync_idx`` inside the sequence of device sets.

    The first ``warm_iters * num_warm_sets`` rounds are split into sets of
    ``warm_iters`` rounds; every later set lasts ``apply_iters`` rounds.

    Returns:
        (set_idx, is_first_round_of_set, is_last_round_of_set)
    """
    warm_rounds = warm_iters * num_warm_sets
    if sync_idx < warm_rounds:
        set_idx, position = divmod(sync_idx, warm_iters)
        period = warm_iters
    else:
        set_idx, position = divmod(sync_idx - warm_rounds, apply_iters)
        set_idx += num_warm_sets
        period = apply_iters
    return set_idx, position == 0, position == period - 1


def _next_learning_rate(current_lr, base_lr, dataset_name, sync_idx, change_point):
    """Reset the learning rate on a device-set change, else halve it every 100th round on CIFAR."""
    if change_point:
        return base_lr
    if dataset_name in {"cifar10", "cifar100"} and sync_idx % 100 == 0:
        return current_lr * 0.5
    return current_lr


def _sum_scaled_models(target_state_dict, scaled_models, device):
    """Overwrite every entry of ``target_state_dict`` with the sum of the scaled client tensors.

    ``scaled_models`` holds one dict per client slot; slots left as ``None``
    (clients that did not train this round) are skipped.  The passed dict is
    modified in place and returned; only its values are rebound, the tensors
    it originally referenced are left untouched.
    """
    for key, reference in target_state_dict.items():
        total = _zero_float_tensor(reference.shape, device)
        for scaled in scaled_models:
            if scaled is not None:
                total += scaled[key]
        target_state_dict[key] = total
    return target_state_dict


def _train_and_aggregate(client_list, current_client_idx_list, local_weight_dict_type, global_model,
                         max_num_client, gpu_cpu_device, dataset_name, lr, momentum,
                         regularized_or_not, announce_clients):
    """Run one round of local training on the selected clients and aggregate the results.

    Returns:
        (aggregated_state_dict, unique_classes) where ``aggregated_state_dict``
        is the weighted sum of the client models and ``unique_classes`` is the
        union of ``set_of_classes`` over the selected clients.
    """
    scaled_models = [None] * max_num_client
    unique_classes = set()

    for client_idx in current_client_idx_list:
        if announce_clients:
            print(f'Client Index {client_idx}')
        local_state = client_list[client_idx].local_training_sync(
            gpu_cpu_device, global_model.state_dict(), dataset_name, lr, momentum, regularized_or_not)
        scaled_models[client_idx] = {
            k: local_weight_dict_type[client_idx] * v for k, v in local_state.items()
        }
        unique_classes.update(client_list[client_idx].set_of_classes)

    aggregated = _sum_scaled_models(global_model.state_dict(), scaled_models, gpu_cpu_device)
    return aggregated, unique_classes


def _evaluate_on_classes(args, gpu_cpu_device, global_model, test_dataset, classes):
    """Return the accuracy of ``global_model`` on the test samples restricted to ``classes``."""
    test_dataloader = filter_dataset_by_classes(test_dataset, classes, args.batch_size_training)
    acc, _ = test_inference(args, gpu_cpu_device, global_model, test_dataloader)
    return acc


def _pilot_similarities(pilot_grad_list, pilot_idx, similarity):
    """Compare the pilot gradient at ``pilot_idx`` with every earlier pilot gradient.

    For each earlier entry the per-parameter similarity of the norm-normalized
    tensors is averaged over all keys whose name does not contain ``"bias"``.

    Raises:
        ValueError: if ``similarity`` is neither ``inner_product`` nor ``two_norm``.
    """
    current_grads = pilot_grad_list[pilot_idx]
    similarity_values = []

    for p in range(pilot_idx):
        accumulated = 0
        counted = 0
        for k, current_parameter_grad in current_grads.items():
            if "bias" in k:
                continue
            counted += 1
            earlier_grad = pilot_grad_list[p][k]
            normalized_first = earlier_grad / torch.norm(earlier_grad)
            normalized_second = current_parameter_grad / torch.norm(current_parameter_grad)

            if similarity == "inner_product":
                accumulated += torch.sum(normalized_first * normalized_second)
            elif similarity == "two_norm":
                accumulated += torch.norm(normalized_first - normalized_second, p=2)
            else:
                raise ValueError("give wrong values to similarity. Only accepts inner_product or two_norm")
        similarity_values.append(accumulated / counted)

    return similarity_values


def synchronous_FL_training_mnist_noniid(args, dataset_name, client_list, type_testing_dataset_dict, training_order_epoch_task_list, datetime_string):
    """Run the synchronous FL experiment with and without similarity-based initialization.

    The rounds in ``training_order_epoch_task_list`` are split into a pilot
    (warm) stage -- the first ``num_of_iterations_client_fixed_warm *
    num_of_set_of_device_to_average_warm_model`` rounds -- and an application
    stage (the rest).  The pilot stage is trained once.  The application stage
    is then trained twice, both times starting from the initial model:

    1. without initial selection: plain aggregation round after round;
    2. with initial selection: on every change of device set a pilot gradient
       is computed from the averaged warm model, compared with the earlier
       pilot gradients, and the global model is re-initialized as a
       similarity-weighted mix of the models stored for the earlier sets.

    The raw and normalized similarity values are written to
    ``<cwd>/<args.similarity>_value/<dataset_name>/<datetime_string>/``
    as ``original.json`` and ``normalized.json``.

    Args:
        args: parsed options; the fields read here include the pilot-mode flags,
            ``gpu_index``, ``sync_num_training``, the per-dataset learning rates,
            ``momentum``, ``max_num_client``, ``batch_size_training``,
            ``similarity`` and ``similarity_scale``.
        dataset_name: selects the model, the learning rate and the test dataset.
        client_list: client objects indexed by client id.
        type_testing_dataset_dict: maps dataset name to its test dataset.
        training_order_epoch_task_list: one list of client ids per round.
        datetime_string: name of the output sub-directory.

    Returns:
        ``(acc_with_initial_selection, acc_without_initial_selection)``; both
        lists have ``len(training_order_epoch_task_list) + 1`` entries and entry
        0 is the accuracy of the untrained model.

    Raises:
        ValueError: if not exactly one pilot-mode flag is set, if an iteration
            count is below 1, or if the model / similarity option is unknown.
    """
    seed = 0
    torch.manual_seed(seed)
    random.seed(seed)
    np.random.seed(seed)

    regularized_or_not = args.regularized_or_not
    trained_pilot_grad_or_not = args.trained_pilot_grad_or_not
    fixed_pilot_grad_or_not = args.fixed_pilot_grad_or_not
    trained_pilot_diff_as_grad_or_not = args.trained_pilot_diff_as_grad_or_not

    if trained_pilot_grad_or_not + trained_pilot_diff_as_grad_or_not + fixed_pilot_grad_or_not > 1 or trained_pilot_grad_or_not + trained_pilot_diff_as_grad_or_not + fixed_pilot_grad_or_not <= 0:
        raise ValueError("Give only 1 to one of trained_pilot_grad_or_not, trained_pilot_diff_as_grad_or_not and fixed_pilot_grad_or_not")

    gpu_index = args.gpu_index
    num_of_set_of_device_to_average_warm_model = args.num_of_set_of_device_to_average_warm_model
    sync_num_training = args.sync_num_training
    num_of_iterations_client_fixed_warm = args.num_of_iterations_client_fixed_warm
    num_of_iterations_client_fixed_apply = args.num_of_iterations_client_fixed_apply

    if num_of_iterations_client_fixed_warm < 1:
        raise ValueError("num_of_iterations_client_fixed_warm should be at least 1")

    if num_of_iterations_client_fixed_apply < 1:
        raise ValueError("num_of_iterations_client_fixed_apply should be at least 1")

    learning_rate_map = {
        "SVHN": args.SVHN_lr,
        "mnist": args.mnist_lr,
        "fmnist": args.fmnist_lr,
        "cifar10": args.cifar10_lr,
        "cifar100": args.cifar100_lr
    }
    base_lr = learning_rate_map[dataset_name]
    lr = base_lr

    momentum = args.momentum

    gpu_cpu_device = torch.device(f'cuda:{gpu_index}' if torch.cuda.is_available() else 'cpu')

    global_model = _build_global_model(args, dataset_name, gpu_cpu_device)

    # state_dict() hands out tensors that share storage with the live parameters
    # and load_state_dict() copies into them in place, so the snapshot has to be
    # a deep copy; otherwise it would follow every later update of the model and
    # "resetting to the initial model" would do nothing.
    initial_random_global_model = copy.deepcopy(global_model.state_dict())

    pilot_parameter_name_to_grad_dict = {
        name: _zero_float_tensor(param.shape, gpu_cpu_device)
        for name, param in global_model.named_parameters()
    }

    # One pilot-gradient slot per device set of the application stage.
    pilot_iteration_to_model_grad_l_d = [copy.deepcopy(pilot_parameter_name_to_grad_dict) for _ in range(sync_num_training - num_of_set_of_device_to_average_warm_model)]

    # Model stored at the last round of every device set, for both runs.
    init_warm_global_model_weight_list = [dict() for _ in range(sync_num_training)]
    no_init_warm_global_model_weight_list = [dict() for _ in range(sync_num_training)]

    sync_aggre_acc_list_no_init_select = [0] * (len(training_order_epoch_task_list) + 1)
    sync_aggre_acc_list_init_select = [0] * (len(training_order_epoch_task_list) + 1)

    test_dataset = type_testing_dataset_dict[dataset_name]
    initial_test_dataloder = DataLoader(test_dataset, batch_size=args.batch_size_training, shuffle=True)
    init_global_acc, _ = test_inference(args, gpu_cpu_device, global_model, initial_test_dataloder)

    sync_aggre_acc_list_no_init_select[0] = init_global_acc
    sync_aggre_acc_list_init_select[0] = init_global_acc

    similarity_normalized_value_check_dict = {}
    similarity_original_value_check_dict = {}

    average_warm_model_as_pilot_model_dict = {
        k: _zero_float_tensor(v.shape, gpu_cpu_device) for k, v in global_model.state_dict().items()
    }

    num_warm_rounds = num_of_iterations_client_fixed_warm * num_of_set_of_device_to_average_warm_model
    pilot_stage = training_order_epoch_task_list[:num_warm_rounds]
    application_stage = training_order_epoch_task_list[num_warm_rounds:]

    # ------------------------------------------------------------------
    # Stage 1: pilot (warm) rounds, shared by both runs.
    # ------------------------------------------------------------------
    for sync_idx, current_client_idx_list in enumerate(pilot_stage):

        print("")
        print("sync_idx", sync_idx)

        local_weight_dict_type = Fedavg_local_weight(client_list, current_client_idx_list)

        num_changes_in_set_of_device, change_point_indicator, save_warm_model_indicator = _round_schedule(
            sync_idx, num_of_iterations_client_fixed_warm, num_of_iterations_client_fixed_apply,
            num_of_set_of_device_to_average_warm_model)

        lr = _next_learning_rate(lr, base_lr, dataset_name, sync_idx, change_point_indicator)

        global_state_dict_weighted, unique_classes = _train_and_aggregate(
            client_list, current_client_idx_list, local_weight_dict_type, global_model,
            args.max_num_client, gpu_cpu_device, dataset_name, lr, momentum, regularized_or_not,
            announce_clients=True)

        if save_warm_model_indicator:
            init_warm_global_model_weight_list[num_changes_in_set_of_device] = global_state_dict_weighted
            no_init_warm_global_model_weight_list[num_changes_in_set_of_device] = global_state_dict_weighted

        global_model.load_state_dict(global_state_dict_weighted)

        acc = _evaluate_on_classes(args, gpu_cpu_device, global_model, test_dataset, unique_classes)

        sync_aggre_acc_list_no_init_select[sync_idx + 1] = acc
        sync_aggre_acc_list_init_select[sync_idx + 1] = acc

    # ------------------------------------------------------------------
    # Stage 2: application rounds without initial selection (baseline).
    # ------------------------------------------------------------------
    for stage_offset, current_client_idx_list in enumerate(application_stage):

        sync_idx = stage_offset + num_warm_rounds
        print("")
        print("sync_idx", sync_idx)

        local_weight_dict_type = Fedavg_local_weight(client_list, current_client_idx_list)

        num_changes_in_set_of_device, change_point_indicator, save_warm_model_indicator = _round_schedule(
            sync_idx, num_of_iterations_client_fixed_warm, num_of_iterations_client_fixed_apply,
            num_of_set_of_device_to_average_warm_model)

        lr = _next_learning_rate(lr, base_lr, dataset_name, sync_idx, change_point_indicator)

        # The application stage starts again from the initial model.
        if sync_idx == num_warm_rounds:
            global_model.load_state_dict(initial_random_global_model)

        global_state_dict_weighted, unique_classes = _train_and_aggregate(
            client_list, current_client_idx_list, local_weight_dict_type, global_model,
            args.max_num_client, gpu_cpu_device, dataset_name, lr, momentum, regularized_or_not,
            announce_clients=False)

        if save_warm_model_indicator and num_changes_in_set_of_device <= len(no_init_warm_global_model_weight_list) - 1:
            no_init_warm_global_model_weight_list[num_changes_in_set_of_device] = global_state_dict_weighted

        global_model.load_state_dict(global_state_dict_weighted)

        acc = _evaluate_on_classes(args, gpu_cpu_device, global_model, test_dataset, unique_classes)
        sync_aggre_acc_list_no_init_select[sync_idx + 1] = acc

    # ------------------------------------------------------------------
    # Stage 3: application rounds with similarity-based initial selection.
    # ------------------------------------------------------------------
    for stage_offset, current_client_idx_list in enumerate(application_stage):

        sync_idx = stage_offset + num_warm_rounds
        print("")
        print("sync_idx", sync_idx)

        local_weight_dict_type = Fedavg_local_weight(client_list, current_client_idx_list)

        num_changes_in_set_of_device, change_point_indicator, save_warm_model_indicator = _round_schedule(
            sync_idx, num_of_iterations_client_fixed_warm, num_of_iterations_client_fixed_apply,
            num_of_set_of_device_to_average_warm_model)

        lr = _next_learning_rate(lr, base_lr, dataset_name, sync_idx, change_point_indicator)

        if num_changes_in_set_of_device == num_of_set_of_device_to_average_warm_model and change_point_indicator:

            # Take average of the warm model to be the pilot model
            for warm_model_till_now in init_warm_global_model_weight_list[:num_of_set_of_device_to_average_warm_model]:
                for w_k, w_v in warm_model_till_now.items():
                    average_warm_model_as_pilot_model_dict[w_k] += w_v / num_of_set_of_device_to_average_warm_model

        if change_point_indicator and num_changes_in_set_of_device >= num_of_set_of_device_to_average_warm_model:

            pilot_idx = num_changes_in_set_of_device - num_of_set_of_device_to_average_warm_model

            if trained_pilot_grad_or_not:

                for client_idx in current_client_idx_list:
                    pilot_iteration_to_model_grad_l_d[pilot_idx] = client_list[client_idx].local_training_sync_train_pilot_grad(sync_idx, client_idx, gpu_cpu_device,  average_warm_model_as_pilot_model_dict, pilot_iteration_to_model_grad_l_d[pilot_idx], dataset_name, lr, momentum, regularized_or_not, local_weight_dict_type)

            if fixed_pilot_grad_or_not:

                for client_idx in current_client_idx_list:
                    pilot_iteration_to_model_grad_l_d[pilot_idx] = client_list[client_idx].local_training_sync_fixed_pilot_grad(args, client_idx, gpu_cpu_device,  average_warm_model_as_pilot_model_dict, pilot_iteration_to_model_grad_l_d[pilot_idx], dataset_name, lr, momentum, regularized_or_not, local_weight_dict_type)

            if trained_pilot_diff_as_grad_or_not:

                # Starting point of the footprint: the averaged warm model, or the
                # initial model when the cold-model test is requested.
                if not args.test_cold_model:
                    footprint_dict = copy.deepcopy(average_warm_model_as_pilot_model_dict)
                else:
                    footprint_dict = copy.deepcopy(initial_random_global_model)

                if args.footprint_fl:
                    # Several aggregation rounds, each starting from the previous footprint.
                    footprint_scaled_models = [None] * args.max_num_client

                    for _ in range(args.footprint_num_iteration):

                        for client_idx in current_client_idx_list:

                            print(f'Client Index {client_idx}')

                            local_model_based_on_warm_model_dict = client_list[client_idx].local_training_sync_model_diff(gpu_cpu_device, footprint_dict, dataset_name, lr, momentum, regularized_or_not)

                            footprint_scaled_models[client_idx] = {
                                k: local_weight_dict_type[client_idx] * v
                                for k, v in local_model_based_on_warm_model_dict.items()
                            }

                        _sum_scaled_models(footprint_dict, footprint_scaled_models, gpu_cpu_device)
                else:
                    # Single aggregation: every client trains from the same starting model.
                    footprint_dict = {k: _zero_float_tensor(v.shape, gpu_cpu_device) for k, v in footprint_dict.items()}

                    for client_idx in current_client_idx_list:

                        if not args.test_cold_model:
                            local_model_based_on_warm_model_dict = client_list[client_idx].local_training_sync_model_diff(gpu_cpu_device, average_warm_model_as_pilot_model_dict, dataset_name, lr, momentum, regularized_or_not)
                        else:
                            local_model_based_on_warm_model_dict = client_list[client_idx].local_training_sync_model_diff(gpu_cpu_device, initial_random_global_model, dataset_name, lr, momentum, regularized_or_not)

                        for k, v in local_model_based_on_warm_model_dict.items():
                            footprint_dict[k] += local_weight_dict_type[client_idx] * v

                # The model difference divided by the learning rate stands in for the gradient.
                difference_model_dict = {
                    keys: (values - average_warm_model_as_pilot_model_dict[keys]) / lr
                    for keys, values in footprint_dict.items()
                }

                if pilot_idx <= len(pilot_iteration_to_model_grad_l_d) - 1:
                    pilot_iteration_to_model_grad_l_d[pilot_idx] = difference_model_dict

        if num_changes_in_set_of_device >= (1 + num_of_set_of_device_to_average_warm_model) and change_point_indicator:

            pilot_idx = num_changes_in_set_of_device - num_of_set_of_device_to_average_warm_model

            similarity_value_list = _pilot_similarities(pilot_iteration_to_model_grad_l_d, pilot_idx, args.similarity)
            similarity_value_tensor = torch.stack(similarity_value_list, dim=0)

            if args.similarity == "inner_product":
                similarity_value_normalized_tensor = torch.nn.Softmax(dim=0)(similarity_value_tensor * args.similarity_scale)
            elif args.similarity == "two_norm":
                # Smaller distance means more similar, hence softmin.
                similarity_value_normalized_tensor = torch.nn.Softmin(dim=0)(similarity_value_tensor * args.similarity_scale)
            else:
                raise ValueError("give wrong values to similarity. Only accepts inner_product or two_norm")

            similarity_normalized_value_check_dict[sync_idx] = similarity_value_normalized_tensor.cpu().detach().numpy().tolist()
            similarity_original_value_check_dict[sync_idx] = similarity_value_tensor.cpu().detach().numpy().tolist()

            # Re-initialize the global model as the similarity-weighted mix of the
            # models stored for the earlier application-stage device sets.
            store_global_model_dict = global_model.state_dict()
            for k in store_global_model_dict.keys():
                initialized_pilot_zero = _zero_float_tensor(store_global_model_dict[k].shape, gpu_cpu_device)
                for px in range(similarity_value_normalized_tensor.shape[0]):
                    initialized_pilot_zero += init_warm_global_model_weight_list[px + num_of_set_of_device_to_average_warm_model][k] * similarity_value_normalized_tensor[px]
                store_global_model_dict[k] = initialized_pilot_zero
            global_model.load_state_dict(store_global_model_dict)

        # The application stage starts again from the initial model.
        if sync_idx == num_warm_rounds:
            global_model.load_state_dict(initial_random_global_model)

        global_state_dict_weighted, unique_classes = _train_and_aggregate(
            client_list, current_client_idx_list, local_weight_dict_type, global_model,
            args.max_num_client, gpu_cpu_device, dataset_name, lr, momentum, regularized_or_not,
            announce_clients=True)

        first_two_application_sets = (
            num_changes_in_set_of_device == num_of_set_of_device_to_average_warm_model
            or num_changes_in_set_of_device == num_of_set_of_device_to_average_warm_model + 1
        )

        if save_warm_model_indicator and num_changes_in_set_of_device < len(init_warm_global_model_weight_list):
            if first_two_application_sets:
                # The first two application sets reuse the model of the baseline run.
                init_warm_global_model_weight_list[num_changes_in_set_of_device] = no_init_warm_global_model_weight_list[num_changes_in_set_of_device]
            else:
                init_warm_global_model_weight_list[num_changes_in_set_of_device] = global_state_dict_weighted

        # NOTE(review): the original first loaded the baseline model for the first
        # two application sets and then unconditionally loaded the aggregated
        # model right after, so only the aggregated model ever took effect.  The
        # overwritten load is dropped; confirm that this is the intended state.
        global_model.load_state_dict(global_state_dict_weighted)

        acc = _evaluate_on_classes(args, gpu_cpu_device, global_model, test_dataset, unique_classes)

        sync_aggre_acc_list_init_select[sync_idx + 1] = acc

    similarity_values_dir_path = os.path.join(os.getcwd(), args.similarity + "_value", dataset_name)
    similarity_values_dt_dir_path = os.path.join(similarity_values_dir_path, f"{datetime_string}")
    os.makedirs(similarity_values_dt_dir_path, exist_ok=True)

    with open(os.path.join(similarity_values_dt_dir_path, 'original.json'), 'w') as f:
        json.dump(similarity_original_value_check_dict, f, indent=4)

    with open(os.path.join(similarity_values_dt_dir_path, 'normalized.json'), 'w') as f:
        json.dump(similarity_normalized_value_check_dict, f, indent=4)

    # For the first two application-stage device sets the run with initial
    # selection reports the accuracies of the baseline run.
    copy_start = num_warm_rounds + 1
    copy_stop = num_warm_rounds + 2 * num_of_iterations_client_fixed_apply + 1
    sync_aggre_acc_list_init_select[copy_start:copy_stop] = sync_aggre_acc_list_no_init_select[copy_start:copy_stop]

    return sync_aggre_acc_list_init_select, sync_aggre_acc_list_no_init_select


if __name__ == "__main__":

    start_time = time.time()

    args = parse_arguments()

    # Format the datetime object to the desired format
    dt = datetime.now(pytz.timezone('America/New_York')).strftime('%Y-%m-%d_%H:%M:%S')

    program_start_time = time.time()

    sync_num_training = args.sync_num_training
    max_num_client = args.max_num_client
    num_of_set_of_device_to_average_warm_model = args.num_of_set_of_device_to_average_warm_model

    num_of_iterations_client_fixed_warm = args.num_of_iterations_client_fixed_warm
    num_of_iterations_client_fixed_apply = args.num_of_iterations_client_fixed_apply

    if num_of_set_of_device_to_average_warm_model > sync_num_training:
        raise ValueError("num_of_set_of_device_to_average_warm_model should no larger than sync_num_training!")

    

    # Parse dataset_name from JSON string
    current_dataset_collection_list = json.loads(args.dataset_name)

    accuracy_w_initial_selection_dict = {}
    accuracy_wo_initial_selection_dict = {}

    for dataset_name in current_dataset_collection_list:

        type_training_dataset_dict = get_dataset_dict(args, dataset_name, train=True)
        type_testing_dataset_dict = get_dataset_dict(args, dataset_name, train=False)
        
        if args.class_distribution == "half":
            clients_data_ids, client_classes = distinct_half(type_training_dataset_dict[dataset_name])

        elif args.class_distribution == "distinct":
            clients_data_ids, client_classes = distinct_class_each_device(type_training_dataset_dict[dataset_name])
        
        elif args.class_distribution == "2-shard":

            clients_data_ids, client_classes = distribute_labels_in_batches(type_training_dataset_dict[dataset_name], args.max_num_client)
        
        elif args.class_distribution == "slight_overlap":
            
            clients_data_ids, client_classes = distribute_labels_slight_overlap_10_clients(type_training_dataset_dict[dataset_name])

        elif args.class_distribution == "Dirichlet":
            clients_data_ids, client_classes = Dirichlet_disbuted_classes(type_training_dataset_dict[dataset_name], max_num_client, args.Dirichlet_alpha)
        else:
            raise ValueError("Wrong values to class distribution!")
        
        client_list = [0] * max_num_client

        for i in range(max_num_client):
            training_dataloader = DataLoader(Subset(type_training_dataset_dict[dataset_name], clients_data_ids[i]), batch_size=args.batch_size_training, shuffle=True)
            grad_cal_dataloader = DataLoader(Subset(type_training_dataset_dict[dataset_name], clients_data_ids[i]), batch_size=args.batch_size_grad_cal, shuffle=True)
            client_list[i] = client(training_dataloader, grad_cal_dataloader, args.num_SGD_training, args.num_SGD_grad_cal, client_classes[i], args.cifar_model)

        # Build the device schedule for this dataset.
        #   set_of_device_every_round_dict: round index -> sorted list of the
        #       client indices selected for that round.
        #   sync_training_order: the same per-round list repeated once per
        #       iteration of the round (warm rounds and apply rounds use
        #       different repeat counts).
        sync_training_order = []
        set_of_device_every_round_dict = {}

        if args.alternating_order:
            # Even rounds use the first half of the clients, odd rounds the
            # second half.
            half_point = max_num_client // 2
            for round_idx in range(sync_num_training):
                if round_idx % 2 == 0:
                    set_of_device_every_round_dict[round_idx] = list(range(half_point))
                else:
                    set_of_device_every_round_dict[round_idx] = list(range(half_point, max_num_client))

        elif args.randomly_select_order:
            # Random selection where two consecutive rounds must not use the
            # same set of clients.
            num_of_client_selected = args.num_client_per_round

            # Sample from every available client rather than a hard-coded 10,
            # so the schedule always matches the size of client_list.
            client_population = range(max_num_client)

            # A different set can only be drawn when more than one subset of
            # this size exists; otherwise redrawing would never terminate.
            redraw_possible = 0 < num_of_client_selected < max_num_client

            previous_selection = None
            for round_idx in range(sync_num_training):
                # Sort before comparing so that the same clients drawn in a
                # different order are recognised as a repeat.
                current_selection = sorted(random.sample(client_population, num_of_client_selected))
                while redraw_possible and current_selection == previous_selection:
                    current_selection = sorted(random.sample(client_population, num_of_client_selected))

                set_of_device_every_round_dict[round_idx] = current_selection
                # Track the immediately preceding round, not only round 0.
                previous_selection = current_selection

        elif args.handcrafted_order:
            generated_list = generate_device_lists(client_classes, sync_num_training, args.min_num_client_per_round, args.max_num_client_per_round, 10**6, 0.05)
            for round_idx in range(sync_num_training):
                set_of_device_every_round_dict[round_idx] = sorted(generated_list[round_idx])

        else:
            raise ValueError("Wrong values to randomly_select_order, handcrafted_order or alternating_order!")

        # Expand the per-round selection into the per-iteration order: the
        # first num_of_set_of_device_to_average_warm_model rounds are repeated
        # num_of_iterations_client_fixed_warm times, the remaining rounds
        # num_of_iterations_client_fixed_apply times.
        for round_idx in range(sync_num_training):
            if round_idx < num_of_set_of_device_to_average_warm_model:
                repeats_this_round = num_of_iterations_client_fixed_warm
            else:
                repeats_this_round = num_of_iterations_client_fixed_apply
            sync_training_order += [set_of_device_every_round_dict[round_idx]] * repeats_this_round
            

        # Run synchronous FL training; the call returns two accuracy sequences
        # (with / without initial point selection) which are stored per dataset.
        with_selection_curve, without_selection_curve = synchronous_FL_training_mnist_noniid(
            args, dataset_name, client_list, type_testing_dataset_dict, sync_training_order, dt
        )
        accuracy_w_initial_selection_dict[dataset_name] = with_selection_curve
        accuracy_wo_initial_selection_dict[dataset_name] = without_selection_curve

        # Plot both accuracy curves as step plots on a fresh figure, with the
        # y axis rendered as a percentage (1.0 == 100%).
        plt.figure()
        ax = plt.gca()
        ax.yaxis.set_major_formatter(mtick.PercentFormatter(1.0))
        curve_specs = (
            (accuracy_w_initial_selection_dict[dataset_name], "p", "With Initial Point Selection"),
            (accuracy_wo_initial_selection_dict[dataset_name], "o", "Without Initial Point Selection"),
        )
        for curve_values, marker_shape, curve_label in curve_specs:
            plt.step(range(len(curve_values)), curve_values, linestyle='--', linewidth=3, marker=marker_shape, markersize=3, label=curve_label)

        # One x tick at the end of every round: the cumulative sum of the
        # iteration count of each round (warm rounds first, then apply rounds).
        num_apply_rounds = sync_num_training - num_of_set_of_device_to_average_warm_model
        iterations_per_round = (
            [num_of_iterations_client_fixed_warm] * num_of_set_of_device_to_average_warm_model
            + [num_of_iterations_client_fixed_apply] * num_apply_rounds
        )
        plt.xticks(np.cumsum(np.array(iterations_per_round)))
        plt.xticks(fontsize=8)
        ax.legend()
        ax.set_title(dataset_name, fontsize=16)

        # ---- Persist this dataset's artefacts under the working directory ----
        # Layout (all file names are keyed by dt):
        #   args_namespace/<dataset>/<dt>.json            parsed arguments + elapsed time
        #   training_order/<dt>.json                      round index -> selected clients
        #   training_values/<dataset>/<dt>/init_select.json
        #   training_values/<dataset>/<dt>/no_init_select.json
        #   training_plot/<dataset>/<dt>.png              the current figure
        # Directories are created with exist_ok=True, which avoids the
        # check-then-create race of a separate os.path.exists() test.

        # Wall-clock time since start_time, stored on args so that it is
        # written out together with the rest of the argument namespace.
        elapsed_time = time.time() - start_time
        days, remainder = divmod(elapsed_time, 86400)
        hours, remainder = divmod(remainder, 3600)
        minutes, seconds = divmod(remainder, 60)
        args.elapsed_time = f"{int(days)}d {int(hours)}h {int(minutes)}m {seconds:.2f}s"

        dataset_argument_dir_path = os.path.join(os.getcwd(), "args_namespace", dataset_name)
        os.makedirs(dataset_argument_dir_path, exist_ok=True)
        with open(os.path.join(dataset_argument_dir_path, f'{dt}.json'), 'w') as f:
            json.dump(vars(args), f, indent=4)

        # NOTE(review): this path is not namespaced by dataset, so when several
        # datasets are processed with the same dt each iteration overwrites the
        # previous schedule file -- confirm whether that is intended.
        training_order_dir_path = os.path.join(os.getcwd(), "training_order")
        os.makedirs(training_order_dir_path, exist_ok=True)
        with open(os.path.join(training_order_dir_path, f"{dt}.json"), 'w') as f:
            json.dump(set_of_device_every_round_dict, f, indent=4)

        training_value_time_dir_path = os.path.join(os.getcwd(), "training_values", dataset_name, f"{dt}")
        os.makedirs(training_value_time_dir_path, exist_ok=True)
        with open(os.path.join(training_value_time_dir_path, 'init_select.json'), 'w') as f:
            json.dump(accuracy_w_initial_selection_dict[dataset_name], f, indent=4)
        with open(os.path.join(training_value_time_dir_path, 'no_init_select.json'), 'w') as f:
            json.dump(accuracy_wo_initial_selection_dict[dataset_name], f, indent=4)

        training_plot_dir_path = os.path.join(os.getcwd(), "training_plot", dataset_name)
        os.makedirs(training_plot_dir_path, exist_ok=True)
        plt.savefig(os.path.join(training_plot_dir_path, f"{dt}.png"))
        # Release the figure once it is on disk; a new figure is created for
        # every dataset, so leaving them open accumulates memory across the loop.
        plt.close()




