import os

# Set before `import torch` (further down) so the variable is already in the
# environment when torch is loaded.
# NOTE(review): "1" presumably forces synchronous CUDA kernel launches, which is
# a debugging aid that slows GPU work — confirm it is still wanted for real runs.
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"

import sys
import random
import numpy as np
import pandas as pd
import torch
import torchvision
import pickle
import zipfile

from dataset import Data_Preprocess
from fed_bl import Bilevel_Federation
from local_train import Benchmark_Local_Train
from fed_avg import Benchmark_Fed_Avg
from fed_ditto import Benchmark_Ditto
from fed_pme import Benchmark_Fed_Pme
from model import Net

# Downsampled-ImageNet part
# "$HOME" inside a Python string is not expanded by open() or os.path.join(),
# so resolve it here; otherwise every path below points into a directory that
# is literally named "$HOME".
root_dictory = os.path.expandvars("$HOME/project_bilevel/bilevel/Datasets/")
train_data_path = os.path.join(root_dictory, "Imagenet32_train")

# One-off extraction of the archive into a folder with the same name
# (kept for reference):
# os.mkdir(train_data_path)
# file_name = "Imagenet32_train.zip"
# with zipfile.ZipFile(os.path.join(root_dictory, file_name), 'r') as zip_ref:
#     zip_ref.extractall(train_data_path)

train_list = ["train_data_batch_{}".format(i + 1) for i in range(10)]

# Gather the per-file arrays first and concatenate once at the end. Growing the
# array with np.append inside the loop re-copies everything loaded so far on
# every iteration.
train_data_chunks = []
train_target_chunks = []
for filename in train_list:
    # NOTE: pickle.load can execute arbitrary code stored in the file; only
    # point this at dataset files that come from a trusted source.
    with open(os.path.join(train_data_path, filename), "rb") as f:
        entry = pickle.load(f)
    # View the file's "data" array as a stack of 3x32x32 images.
    train_data_chunks.append(entry["data"].reshape(-1, 3, 32, 32))
    train_target_chunks.append(np.asarray(entry["labels"]))

train_data = np.concatenate(train_data_chunks, axis=0)
train_targets = np.concatenate(train_target_chunks, axis=0)
# Release the chunk lists before the float conversion below so that they do
# not add to peak memory.
del train_data_chunks, train_target_chunks

train_data = torch.FloatTensor(train_data)
print("All training data have size: ", train_data.size())
train_targets = torch.LongTensor(train_targets)
print("All training targets have size: ", train_targets.size())
print("-" * 30)

# "$HOME" in the dataset root is not expanded by open()/os.path.join(), so
# resolve environment variables explicitly before the path is used.
val_data_path = os.path.expandvars(os.path.join(root_dictory, "Imagenet32_val"))

# One-off extraction of the archive into a folder with the same name
# (kept for reference):
# os.mkdir(val_data_path)
# file_name = "Imagenet32_val.zip"
# with zipfile.ZipFile(os.path.join(root_dictory, file_name), "r") as zip_ref:
#     zip_ref.extractall(val_data_path)

val_list = ["val_data"]

# Same accumulate-then-concatenate scheme as any multi-file split: collect the
# per-file arrays and join them once instead of re-copying with np.append.
val_data_chunks = []
val_target_chunks = []
for filename in val_list:
    # NOTE: pickle.load can execute arbitrary code stored in the file; only
    # point this at dataset files that come from a trusted source.
    with open(os.path.join(val_data_path, filename), "rb") as f:
        entry = pickle.load(f)
    # View the file's "data" array as a stack of 3x32x32 images.
    val_data_chunks.append(entry["data"].reshape(-1, 3, 32, 32))
    val_target_chunks.append(np.asarray(entry["labels"]))

val_data = np.concatenate(val_data_chunks, axis=0)
val_targets = np.concatenate(val_target_chunks, axis=0)
del val_data_chunks, val_target_chunks

val_data = torch.FloatTensor(val_data)
print("All testing data have size: ", val_data.size())
val_targets = torch.LongTensor(val_targets)
print("All testing targets have size: ", val_targets.size())
print("-" * 30)

if __name__ == "__main__":

    print("Start preparing datasets.")
    print("-" * 30)

    # Command line: SETTING SEED NUM_STAGES (all integers).
    if len(sys.argv) < 4:
        sys.exit("usage: {} SETTING SEED NUM_STAGES".format(sys.argv[0]))

    SETTING = int(sys.argv[1])
    SEED = int(sys.argv[2])

    # "$HOME" is not expanded automatically inside a Python string, so resolve
    # it here, and create the directory now so that the CSV writes at the end
    # of each (long) training run cannot fail on a missing folder.
    output_path = os.path.expandvars(
        "$HOME/project_bilevel/bilevel/Experiment_results/ImageNet_Setting{}_Seed{}/".format(str(SETTING), str(SEED)))
    os.makedirs(output_path, exist_ok=True)

    NUM_STAGES = int(sys.argv[3])
    NUM_NODES = 15 # K = number of nodes
    NUM_SIMILAR_BOUND = 3  # b in the probability simplex
    NUM_SIMILAR = 5 # J = number of nodes with similar data distribution

    # All node choices below come from Python's `random`, seeded once here; the
    # order of the sample/choice calls therefore determines the outcome.
    random.seed(a=SEED)
    near_node_list = sorted(random.sample(range(NUM_NODES), NUM_SIMILAR))
    print("The near ones among {} nodes in the simulation are:".format(NUM_NODES), near_node_list)
    print("-" * 30)

    # Every node that was not sampled as "near" is a "far" node.
    far_node_list = sorted(list(set(range(NUM_NODES)) - set(near_node_list)))
    print("The far ones among {} nodes in the simulation are:".format(NUM_NODES), far_node_list)
    print("-" * 30)

    near_center_idx = random.choice(seq=near_node_list)
    print("For bilevel experiment near version, the center node in the simulation is node {}.".format(near_center_idx))
    print("-" * 30)

    far_center_idx = random.choice(seq=far_node_list)
    print("For bilevel experiment far version, the center node in the simulation is node {}.".format(far_center_idx))
    print("-" * 30)

    sys.stdout.flush()

    # Split sizes handed to the preprocessing step and to every algorithm.
    TRAIN_SIZE, VAL_SIZE, TEST_SIZE = 4000, 1500, 1000

    # Probability assigned to each label group, for the near and the far nodes.
    group_keys = ("label_list_1", "label_list_2", "label_list_3", "label_list_4")
    near_p_dict = dict(zip(group_keys, (0.36, 0.04, 0.54, 0.06)))
    far_p_dict = dict(zip(group_keys, (0.04, 0.36, 0.06, 0.54)))

    # Raw dataset labels kept for the experiment, in four groups.
    raw_label_list_1 = [7, 9, 10, 29, 54, 75, 84, 189]
    raw_label_list_2 = [61, 66, 68, 101, 114, 124, 131, 148]
    raw_label_list_3 = [383, 397, 403, 404, 405, 406, 412, 414, 420, 426, 433, 434]
    raw_label_list_4 = [224, 441, 442, 443, 444, 445, 449, 453, 454, 498, 499, 500]
    raw_label_list = [*raw_label_list_1, *raw_label_list_2, *raw_label_list_3, *raw_label_list_4]

    # The permutation is a single 12-element cycle over these raw labels: each
    # label maps to the next one and the last maps back to the first. It is
    # stored in terms of positions inside raw_label_list.
    permute_cycle = [7, 61, 383, 224, 9, 66, 397, 441, 10, 68, 403, 442]
    label_permute = {
        raw_label_list.index(src): raw_label_list.index(dst)
        for src, dst in zip(permute_cycle, permute_cycle[1:] + permute_cycle[:1])
    }

    # Positions inside raw_label_list that belong to each of the four groups.
    label_list_1 = list(range(0, 8))
    label_list_2 = list(range(8, 16))
    label_list_3 = list(range(16, 28))
    label_list_4 = list(range(28, 40))

    NUM_BATCH_LB = 50 # to decide how many batches there are in a data loader
    
    # Raw label -> position in raw_label_list. One dict lookup per sample
    # replaces scanning the list (`in` / `.index`) for every label in the
    # dataset. setdefault keeps the first position, exactly like list.index.
    raw_to_new_label = {}
    for new_label, raw_label in enumerate(raw_label_list):
        raw_to_new_label.setdefault(raw_label, new_label)

    # Keep only the training samples whose raw label is one of the selected ones.
    train_set_index_list = torch.LongTensor(
        [index for index, label in enumerate(train_targets.tolist()) if label in raw_to_new_label])
    train_set_data = train_data.index_select(dim=0, index=train_set_index_list)
    train_set_raw_targets = train_targets.index_select(dim=0, index=train_set_index_list)
    # Relabel to positions in raw_label_list.
    train_set_targets = torch.LongTensor(
        [raw_to_new_label[raw_label] for raw_label in train_set_raw_targets.tolist()])
    print("All selected training data have size: ", train_set_data.size())
    print("All selected training targets have size: ", train_set_targets.size())
    print("-" * 30)
    # print(ct.Counter(train_set_targets.tolist()))

    # Same selection and relabelling for the held-out (validation-file) samples.
    test_set_index_list = torch.LongTensor(
        [index for index, label in enumerate(val_targets.tolist()) if label in raw_to_new_label])
    test_set_data = val_data.index_select(dim=0, index=test_set_index_list)
    test_set_raw_targets = val_targets.index_select(dim=0, index=test_set_index_list)
    test_set_targets = torch.LongTensor(
        [raw_to_new_label[raw_label] for raw_label in test_set_raw_targets.tolist()])
    print("All selected testing data have size: ", test_set_data.size())
    print("All selected testing targets have size: ", test_set_targets.size())
    print("-" * 30)
    # print(ct.Counter(test_set_targets.tolist()))

    SIZE = train_set_data.size(dim=2) # get H (i.e. W) for neural network architecture

    # Only settings 1-4 are handled below. Reject anything else before the
    # expensive preprocessing, so an unknown value cannot silently skip dataset
    # generation.
    if SETTING not in (1, 2, 3, 4):
        raise ValueError("Unsupported SETTING {}: expected 1, 2, 3 or 4.".format(SETTING))

    # preprocessing imagenet data
    imagenet = Data_Preprocess(train_set_data=train_set_data, train_set_targets=train_set_targets,
                               test_set_data=test_set_data, test_set_targets=test_set_targets,
                               label_list_1=label_list_1, label_list_2=label_list_2,
                               label_list_3=label_list_3, label_list_4=label_list_4,
                               near_center_idx=near_center_idx, far_center_idx=far_center_idx)

    imagenet.Decide_Dataset_Sizes(num_nodes=NUM_NODES, train_size=TRAIN_SIZE, val_size=VAL_SIZE, test_size=TEST_SIZE)

    imagenet.Resample_Data(near_node_list=near_node_list, seed=SEED, near_p_dict=near_p_dict, far_p_dict=far_p_dict, margin=0.1)

    sys.stdout.flush()

    # The setting selects which dataset-generation variant is applied.
    if SETTING == 1:
        # Setting 1: plain generation
        imagenet.Dataset_Generating(seed=SEED)
    elif SETTING == 2:
        # Setting 2: generation with label permutation
        imagenet.Dataset_Generating_and_Label_Permute(seed=SEED, label_permute=label_permute)
    elif SETTING == 3:
        # Setting 3: generation with rotation
        imagenet.Dataset_Generating_and_Rotation(seed=SEED)
    elif SETTING == 4:
        # Setting 4: generation with label permutation and rotation
        imagenet.Dataset_Generating_Label_Permute_and_Rotation(seed=SEED, label_permute=label_permute)

    # obtain the train / validation / test loaders shared by all algorithms below
    train_loaders, val_loaders, test_loaders = imagenet.Prepare_Dataloaders(num_batch_lb=NUM_BATCH_LB)

    sys.stdout.flush()
    
    # Cross-entropy over the model's 40 output classes (see out_features below).
    loss_function = torch.nn.CrossEntropyLoss()
    # loss_function = torch.nn.BCELoss()

    # Net arguments for downsampled-imagenet inputs in the dimension of N*3*32*32;
    # 40 output features for torch.nn.CrossEntropyLoss().
    in_channels, out_channels, out_features = 3, 5, 40

    # Pivot model: its state_dict is the common initialization of theta that
    # every model created further down is loaded from.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model_pivot = Net(in_channels=in_channels, out_channels=out_channels,
                      out_features=out_features, size=SIZE).to(device)
    # Print the size of the flattened parameter vector.
    pivot_param_vector = torch.nn.utils.parameters_to_vector(model_pivot.parameters())
    print(pivot_param_vector.size())

    sys.stdout.flush()

    # Epoch counts and communication period.
    NUM_EPOCH_IN = 5 # this times NUM_BATCH_LB is T1 in four algorithms
    NUM_EPOCH_HV = 5  # this times NUM_BATCH_LB is T2 in Bi-level algorithm
    CP = 10   # tau in Loopless-Local-SVRG of Bi-level algorithm

    # Learning rates.
    LR_INNER = 5e-2 # learning rate in Loopless-Local-SVRG for updating theta of Bi-level algorithm
    LR_HV = 5e-4 # learning rate in Loopless-Local-SVRG for computing hv product of Bi-level algorithm
    # L = 5000 # constant in HIV for computing hv product of Bi-level algorithm

    LR_OUTER = 1.5e-2 # learning rate in projected gradient descent for updating w of Bi-level algorithm
    # LR_OUTER_HIV = 0.75 # learning rate in HIV for computing hv product of Bi-level algorithm

    # Bi-Level algorithm, run twice with identical hyper-parameters: first with
    # bl_type="near", then with bl_type="far". Each run writes its own pair of
    # CSV files.
    NUM_STAGE_BL = NUM_STAGES # S in Bi-level Algorithm

    # (bl_type, weight-CSV name template, metric-CSV name template)
    bl_runs = (
        ("near",
         "ImageNet_Setting{}_Seed{}_Result_bl_near_w_dict.csv",
         "ImageNet_Setting{}_Seed{}_Result_bl_near_metric_dict.csv"),
        ("far",
         "ImageNet_Setting{}_Seed{}_Result_bl_far_w_dict.csv",
         "ImageNet_Setting{}_Seed{}_Result_bl_far_metric_dict.csv"),
    )

    for bl_type, w_csv_name, metric_csv_name in bl_runs:

        imagenet_fl = Bilevel_Federation(num_nodes=NUM_NODES, near_center_idx=near_center_idx,
                                         far_center_idx=far_center_idx,
                                         loss_function=loss_function, bl_type=bl_type,
                                         train_size=TRAIN_SIZE, val_size=VAL_SIZE, test_size=TEST_SIZE)

        # One model per node, each starting from the pivot model's parameters.
        fl_model_dict = {}
        for node_idx in range(NUM_NODES):
            node_model = Net(in_channels=in_channels, out_channels=out_channels,
                             out_features=out_features, size=SIZE)
            fl_model_dict[node_idx] = node_model.to(imagenet_fl.device)
            fl_model_dict[node_idx].load_state_dict(model_pivot.state_dict())

        imagenet_fl.Initialize_Results()
        imagenet_fl.Initialize_Variables(model_dict=fl_model_dict, train_loaders=train_loaders,
                                         val_loaders=val_loaders, test_loaders=test_loaders,
                                         num_batches=NUM_BATCH_LB)

        for s in range(NUM_STAGE_BL):

            # Inner step: update the node models.
            imagenet_fl.Update_Models_SVRG(train_loaders=train_loaders, val_loaders=val_loaders,
                                           test_loaders=test_loaders,
                                           num_epoch=NUM_EPOCH_IN, num_batches=NUM_BATCH_LB,
                                           communicate_period=CP, lr_inner=LR_INNER)

            # Validation gradient; the driver itself advances the federation's
            # sync and data-point counters for this step.
            val_grad = imagenet_fl.Compute_Full_Gradient_Val(val_loaders=val_loaders)
            imagenet_fl.current_num_syn += 1
            imagenet_fl.current_num_points += imagenet_fl.val_size

            # hv product computed from the validation gradient, then the outer
            # update of the weights w.
            hv_g = imagenet_fl.Compute_Hv_Product_SVRG(train_loaders=train_loaders, b=val_grad,
                                                       num_batches=NUM_BATCH_LB,
                                                       num_epoch=NUM_EPOCH_HV, communicate_period=CP,
                                                       lr_hv=LR_HV)

            imagenet_fl.Update_Weights(outer_hv_g=hv_g, train_loaders=train_loaders, lr_outer=LR_OUTER,
                                       stage=s, w_ub=NUM_SIMILAR_BOUND)

        bl_w_dict, bl_metric_dict = imagenet_fl.Output_Results()

        pd.DataFrame(bl_w_dict).to_csv(
            os.path.join(output_path, w_csv_name.format(str(SETTING), str(SEED))))
        pd.DataFrame(bl_metric_dict).to_csv(
            os.path.join(output_path, metric_csv_name.format(str(SETTING), str(SEED))))

    # local-train algorithm, run twice with identical hyper-parameters: first
    # with bm_type="near", then with bm_type="far".
    NUM_STAGE_LC = int(NUM_STAGES*16) # S in local-train Algorithm

    # (bm_type, metric-CSV name template)
    local_runs = (
        ("near", "ImageNet_Setting{}_Seed{}_Result_local_near_metric_dict.csv"),
        ("far", "ImageNet_Setting{}_Seed{}_Result_local_far_metric_dict.csv"),
    )

    for bm_type, local_csv_name in local_runs:

        imagenet_local_train_bm = Benchmark_Local_Train(num_nodes=NUM_NODES,
                                                        near_center_idx=near_center_idx,
                                                        far_center_idx=far_center_idx,
                                                        bm_type=bm_type,
                                                        loss_function=loss_function,
                                                        val_size=VAL_SIZE, test_size=TEST_SIZE)

        # A single model, started from the pivot model's parameters.
        local_bm_model = Net(in_channels=in_channels, out_channels=out_channels,
                             out_features=out_features, size=SIZE)
        local_bm_model = local_bm_model.to(imagenet_local_train_bm.device)
        local_bm_model.load_state_dict(model_pivot.state_dict())

        imagenet_local_train_bm.Initialize_Results()
        imagenet_local_train_bm.Initialize_Variables(model=local_bm_model,
                                                     train_loaders=train_loaders,
                                                     val_loaders=val_loaders,
                                                     test_loaders=test_loaders)

        for s in range(NUM_STAGE_LC):
            imagenet_local_train_bm.Update_Model_SVRG(val_loaders=val_loaders, test_loaders=test_loaders,
                                                      num_epoch=NUM_EPOCH_IN, num_batches=NUM_BATCH_LB,
                                                      communicate_period=CP, lr_inner=LR_INNER)

        local_metric_dict = imagenet_local_train_bm.Output_Results()
        pd.DataFrame(local_metric_dict).to_csv(
            os.path.join(output_path, local_csv_name.format(str(SETTING), str(SEED))))

    # Fed-avg algorithm:
    NUM_STAGE_FA = int(NUM_STAGES*2) # S in Fed-avg Algorithm

    imagenet_fed_avg_bm = Benchmark_Fed_Avg(num_nodes=NUM_NODES,
                                            near_center_idx=near_center_idx,
                                            far_center_idx=far_center_idx,
                                            loss_function=loss_function,
                                            train_size=TRAIN_SIZE, val_size=VAL_SIZE, test_size=TEST_SIZE)

    # One model per node, each starting from the pivot model's parameters.
    fed_avg_bm_model_dict = {}
    for node_idx in range(NUM_NODES):
        node_model = Net(in_channels=in_channels, out_channels=out_channels,
                         out_features=out_features, size=SIZE).to(imagenet_fed_avg_bm.device)
        node_model.load_state_dict(model_pivot.state_dict())
        fed_avg_bm_model_dict[node_idx] = node_model

    imagenet_fed_avg_bm.Initialize_Results()
    imagenet_fed_avg_bm.Initialize_Variables(model_dict=fed_avg_bm_model_dict,
                                             train_loaders=train_loaders,
                                             val_loaders=val_loaders,
                                             test_loaders=test_loaders,
                                             num_batches=NUM_BATCH_LB)

    for s in range(NUM_STAGE_FA):
        imagenet_fed_avg_bm.Update_Models_SVRG(train_loaders=train_loaders, val_loaders=val_loaders,
                                               test_loaders=test_loaders, num_epoch=NUM_EPOCH_IN,
                                               num_batches=NUM_BATCH_LB, communicate_period=CP,
                                               lr_inner=LR_INNER)

    favg_metric_dict = imagenet_fed_avg_bm.Output_Results()
    favg_csv_name = "ImageNet_Setting{}_Seed{}_Result_favg_metric_dict.csv".format(str(SETTING), str(SEED))
    pd.DataFrame(favg_metric_dict).to_csv(os.path.join(output_path, favg_csv_name))

    # Ditto algorithm:
    NUM_STAGE_FDITTO = int(NUM_STAGES*2) # S in Ditto Algorithm

    imagenet_ditto_bm = Benchmark_Ditto(num_nodes=NUM_NODES,
                                        near_center_idx=near_center_idx,
                                        far_center_idx=far_center_idx,
                                        loss_function=loss_function,
                                        train_size=TRAIN_SIZE, val_size=VAL_SIZE, test_size=TEST_SIZE)

    ditto_device = imagenet_ditto_bm.device
    ditto_bm_model_shared_dict = {}
    ditto_bm_model_personal_dict = {}
    for node_idx in range(NUM_NODES):

        # Shared model of this node, started from the pivot parameters.
        node_shared_model = Net(in_channels=in_channels, out_channels=out_channels,
                                out_features=out_features, size=SIZE).to(ditto_device)
        node_shared_model.load_state_dict(model_pivot.state_dict())
        ditto_bm_model_shared_dict[node_idx] = node_shared_model

        # NOTE(review): this "old shared" model is rebuilt on every iteration and
        # only the last instance is passed on below. It is deliberately left
        # inside the loop: constructing it once outside would change how many
        # models are created, and possibly how much of the RNG stream is used.
        ditto_bm_model_shared_old = Net(in_channels=in_channels, out_channels=out_channels,
                                        out_features=out_features, size=SIZE).to(ditto_device)
        ditto_bm_model_shared_old.load_state_dict(model_pivot.state_dict())

        # Personal model of this node, started from the pivot parameters.
        node_personal_model = Net(in_channels=in_channels, out_channels=out_channels,
                                  out_features=out_features, size=SIZE).to(ditto_device)
        node_personal_model.load_state_dict(model_pivot.state_dict())
        ditto_bm_model_personal_dict[node_idx] = node_personal_model

    imagenet_ditto_bm.Initialize_Results()
    imagenet_ditto_bm.Initialize_Variables(model_shared_dict=ditto_bm_model_shared_dict,
                                           model_shared_old=ditto_bm_model_shared_old,
                                           model_personal_dict=ditto_bm_model_personal_dict,
                                           train_loaders=train_loaders,
                                           val_loaders=val_loaders, test_loaders=test_loaders,
                                           num_batches=NUM_BATCH_LB)

    for s in range(NUM_STAGE_FDITTO):
        imagenet_ditto_bm.Update_Shared_and_Personal_Models_SGD(train_loaders=train_loaders,
                                                                val_loaders=val_loaders,
                                                                test_loaders=test_loaders,
                                                                num_epoch=NUM_EPOCH_IN,
                                                                num_batches=NUM_BATCH_LB,
                                                                communicate_period=CP,
                                                                lr_inner=LR_INNER,
                                                                num_batches_rand=25, lambda_val=0.1)

    # The results hold separate "near" and "far" metric tables.
    fditto_metric_dict = imagenet_ditto_bm.Output_Results()
    fditto_near_csv = "ImageNet_Setting{}_Seed{}_Result_fditto_near_metric_dict.csv".format(str(SETTING), str(SEED))
    fditto_far_csv = "ImageNet_Setting{}_Seed{}_Result_fditto_far_metric_dict.csv".format(str(SETTING), str(SEED))
    pd.DataFrame(fditto_metric_dict["near"]).to_csv(os.path.join(output_path, fditto_near_csv))
    pd.DataFrame(fditto_metric_dict["far"]).to_csv(os.path.join(output_path, fditto_far_csv))
    
    # Fed-personalized with ME algorithm:
    NUM_STAGE_FPME = int(NUM_STAGES*2) # S in Fed-personalized with ME algorithm

    imagenet_fpme_bm = Benchmark_Fed_Pme(num_nodes=NUM_NODES,
                                         near_center_idx=near_center_idx,
                                         far_center_idx=far_center_idx,
                                         loss_function=loss_function,
                                         train_size=TRAIN_SIZE, val_size=VAL_SIZE, test_size=TEST_SIZE)

    fpme_device = imagenet_fpme_bm.device
    fpme_bm_model_shared_dict = {}
    fpme_bm_model_personal_dict = {}
    for node_idx in range(NUM_NODES):

        # Shared model first, then the personal model; both start from the
        # pivot model's parameters.
        node_shared_model = Net(in_channels=in_channels, out_channels=out_channels,
                                out_features=out_features, size=SIZE).to(fpme_device)
        node_shared_model.load_state_dict(model_pivot.state_dict())
        fpme_bm_model_shared_dict[node_idx] = node_shared_model

        node_personal_model = Net(in_channels=in_channels, out_channels=out_channels,
                                  out_features=out_features, size=SIZE).to(fpme_device)
        node_personal_model.load_state_dict(model_pivot.state_dict())
        fpme_bm_model_personal_dict[node_idx] = node_personal_model

    imagenet_fpme_bm.Initialize_Results()
    imagenet_fpme_bm.Initialize_Variables(model_shared_dict=fpme_bm_model_shared_dict,
                                          model_personal_dict=fpme_bm_model_personal_dict,
                                          train_loaders=train_loaders, val_loaders=val_loaders,
                                          test_loaders=test_loaders, num_batches=NUM_BATCH_LB)

    for s in range(NUM_STAGE_FPME):
        imagenet_fpme_bm.Update_Shared_and_Personal_Models_SGD(val_loaders=val_loaders,
                                                               test_loaders=test_loaders,
                                                               num_epoch=NUM_EPOCH_IN,
                                                               num_batches=NUM_BATCH_LB,
                                                               num_steps_ub=25, delta=0.005,
                                                               communicate_period=CP,
                                                               lr_inner=LR_INNER, lambda_val=15)

    # The results hold separate "near" and "far" metric tables.
    fpme_metric_dict = imagenet_fpme_bm.Output_Results()
    fpme_near_csv = "ImageNet_Setting{}_Seed{}_Result_fpme_near_metric_dict.csv".format(str(SETTING), str(SEED))
    fpme_far_csv = "ImageNet_Setting{}_Seed{}_Result_fpme_far_metric_dict.csv".format(str(SETTING), str(SEED))
    pd.DataFrame(fpme_metric_dict["near"]).to_csv(os.path.join(output_path, fpme_near_csv))
    pd.DataFrame(fpme_metric_dict["far"]).to_csv(os.path.join(output_path, fpme_far_csv))

    # Flush rather than close: stdout is process-wide, and closing it would make
    # any output emitted while the interpreter shuts down (exit handlers,
    # library teardown) fail with "I/O operation on closed file".
    sys.stdout.flush()
    sys.exit(0)