import os

os.environ["CUDA_LAUNCH_BLOCKING"] = "1"

import sys
import random
import numpy as np
import pandas as pd
import torch
import torchvision

from sklearn import datasets
from sklearn import model_selection
from sklearn.preprocessing import MinMaxScaler
scaler = MinMaxScaler()

from dataset2 import Data_Preprocess
from fed_bl import Bilevel_Federation
from local_train import Benchmark_Local_Train
from fed_avg import Benchmark_Fed_Avg
from fed_ditto import Benchmark_Ditto
from fed_pme import Benchmark_Fed_Pme
from model2 import Net

# Default locations; override with the COVTYPE_DATA_HOME / COVTYPE_RESULTS_ROOT
# environment variables to run on a machine with a different layout.
DEFAULT_DATA_HOME = "/Users/project_bilevel/bilevel/Datasets/"
DEFAULT_RESULTS_ROOT = "/Users/project_bilevel/bilevel/Experiment_results/"


def parse_cli_args(argv):
    """Parse ``SETTING SEED NUM_STAGES`` from the command line.

    Args:
        argv: Full argument vector (normally ``sys.argv``); ``argv[0]`` is the
            program name.

    Returns:
        Tuple ``(setting, seed, num_stages)`` of ints.

    Raises:
        SystemExit: with a usage message if arguments are missing, are not
            integers, or SETTING is not 1 or 2.
    """
    usage = "usage: {} SETTING(1|2) SEED NUM_STAGES".format(argv[0] if argv else "script")
    if len(argv) < 4:
        sys.exit(usage)
    try:
        setting, seed, num_stages = (int(arg) for arg in argv[1:4])
    except ValueError:
        sys.exit(usage)
    if setting not in (1, 2):
        # Only settings 1 and 2 have a dataset-generation branch in main();
        # any other value would build loaders without generating node data.
        sys.exit(usage)
    return setting, seed, num_stages


def seed_everything(seed):
    """Seed Python's ``random``, NumPy's global RNG and torch (all devices)."""
    random.seed(a=seed)
    np.random.seed(seed)
    torch.manual_seed(seed)


def result_csv_path(output_path, setting, seed, tag):
    """Return the CSV path for one result table, e.g. ``tag="favg_metric_dict"``."""
    filename = "Covtype_Setting{}_Seed{}_Result_{}.csv".format(setting, seed, tag)
    return os.path.join(output_path, filename)


def fresh_model(in_features, out_features, device, state_dict):
    """Build a new ``Net`` on ``device`` with weights copied from ``state_dict``."""
    model = Net(in_features=in_features, out_features=out_features).to(device)
    model.load_state_dict(state_dict)
    return model


def fresh_model_dict(num_nodes, in_features, out_features, device, state_dict):
    """Return ``{node_idx: fresh_model(...)}`` for every node index in ``range(num_nodes)``."""
    return {
        node_idx: fresh_model(in_features, out_features, device, state_dict)
        for node_idx in range(num_nodes)
    }


def main():
    """Run every Covtype federated-learning experiment and save the results as CSV.

    Command line: ``SETTING SEED NUM_STAGES`` (see ``parse_cli_args``).
    """
    SETTING, SEED, NUM_STAGES = parse_cli_args(sys.argv)
    # Seed every RNG up front so a run is reproducible from SEED; previously
    # the torch initialisation of model_pivot was left unseeded.
    seed_everything(SEED)

    data_home = os.environ.get("COVTYPE_DATA_HOME", DEFAULT_DATA_HOME)
    results_root = os.environ.get("COVTYPE_RESULTS_ROOT", DEFAULT_RESULTS_ROOT)
    output_path = os.path.join(results_root, "Covtype_Setting{}_Seed{}/".format(SETTING, SEED))
    # DataFrame.to_csv does not create missing directories; create it now so
    # a bad path fails before training rather than after.
    os.makedirs(output_path, exist_ok=True)

    print("Start preparing datasets.")
    print("-" * 30)

    # Covtype part
    X, y = datasets.fetch_covtype(data_home=data_home, return_X_y=True)
    print("X has the shape ", X.shape)
    print("y has the shape ", y.shape)
    print("-" * 30)

    y -= 1  # change labels from 1,2,3,4,5,6,7 to 0,1,2,3,4,5,6

    X_train, X_test, y_train, y_test = model_selection.train_test_split(X, y, test_size=0.2, random_state=SEED)

    y_train_unique, y_train_counts = np.unique(y_train, return_counts=True)
    print("The numbers of labels are: ", dict(zip(y_train_unique, y_train_counts)))
    print("-" * 30)

    y_test_unique, y_test_counts = np.unique(y_test, return_counts=True)
    print("The numbers of labels are: ", dict(zip(y_test_unique, y_test_counts)))
    print("-" * 30)

    # Scale features with the module-level MinMaxScaler fitted on the training
    # split only, then convert to float feature / long label tensors.
    X_train = torch.FloatTensor(scaler.fit_transform(X_train))
    y_train = torch.LongTensor(y_train)
    X_test = torch.FloatTensor(scaler.transform(X_test))
    y_test = torch.LongTensor(y_test)

    sys.stdout.flush()

    NUM_NODES = 15  # K = number of nodes
    NUM_SIMILAR_BOUND = 3  # b in the probability simplex
    NUM_SIMILAR = 5  # J = number of nodes with similar data distribution

    near_node_list = sorted(random.sample(range(NUM_NODES), NUM_SIMILAR))
    print("The near ones among {} nodes in the simulation are:".format(NUM_NODES), near_node_list)
    print("-" * 30)

    far_node_list = sorted(set(range(NUM_NODES)) - set(near_node_list))
    print("The far ones among {} nodes in the simulation are:".format(NUM_NODES), far_node_list)
    print("-" * 30)

    near_center_idx = random.choice(seq=near_node_list)
    print("For bilevel experiment near version, the center node in the simulation is node {}.".format(near_center_idx))
    print("-" * 30)

    far_center_idx = random.choice(seq=far_node_list)
    print("For bilevel experiment far version, the center node in the simulation is node {}.".format(far_center_idx))
    print("-" * 30)

    sys.stdout.flush()

    TRAIN_SIZE = 40000
    VAL_SIZE = 5000
    TEST_SIZE = 50000

    near_p_dict = {"label_list_1": 0.72, "label_list_2": 0.14, "label_list_3": 0.12, "label_list_4": 0.02}
    far_p_dict = {"label_list_1": 0.10, "label_list_2": 0.82, "label_list_3": 0.06, "label_list_4": 0.02}

    label_list_1 = [0]
    label_list_2 = [1]
    label_list_3 = [2]
    label_list_4 = [3, 4, 5, 6]

    label_permute = {0: 1, 1: 2, 2: 0}

    NUM_BATCH_LB = 100  # to decide how many batches there are in a data loader

    # preprocessing covtype data
    covtype = Data_Preprocess(X_train=X_train, y_train=y_train,
                              X_test=X_test, y_test=y_test,
                              label_list_1=label_list_1, label_list_2=label_list_2,
                              label_list_3=label_list_3, label_list_4=label_list_4,
                              near_center_idx=near_center_idx, far_center_idx=far_center_idx)

    covtype.Decide_Dataset_Sizes(num_nodes=NUM_NODES, train_size=TRAIN_SIZE, val_size=VAL_SIZE, test_size=TEST_SIZE)

    covtype.Resample_Data(near_node_list=near_node_list, seed=SEED, near_p_dict=near_p_dict, far_p_dict=far_p_dict, margin=10.0)

    sys.stdout.flush()

    if SETTING == 1:
        covtype.Dataset_Generating(seed=SEED)
    elif SETTING == 2:
        covtype.Dataset_Generating_and_Label_Permute(seed=SEED, label_permute=label_permute)

    # train / validation / test loaders shared by every algorithm below
    train_loaders, val_loaders, test_loaders = covtype.Prepare_Dataloaders(num_batch_lb=NUM_BATCH_LB)

    sys.stdout.flush()

    loss_function = torch.nn.CrossEntropyLoss()  # multi-class (7 cover types) classification loss
    # Tabular input: one feature vector per sample.
    in_features = X_train.size(dim=1)  # 54
    out_features = y_train_unique.size  # 7 for torch.nn.CrossEntropyLoss()

    # pivot model as initialization of theta; every algorithm starts from a copy
    model_pivot = Net(in_features=in_features, out_features=out_features)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model_pivot = model_pivot.to(device)
    print(torch.nn.utils.parameters_to_vector(model_pivot.parameters()).size())
    pivot_state = model_pivot.state_dict()

    sys.stdout.flush()

    NUM_EPOCH_IN = 5  # this times NUM_BATCH_LB is T1 in four algorithms
    NUM_EPOCH_HV = 5  # this times NUM_BATCH_LB is T2 in Bi-level algorithm
    CP = 25  # tau in Loopless-Local-SVRG of Bi-level algorithm

    LR_INNER = 5e-2  # learning rate in Loopless-Local-SVRG for updating theta of Bi-level algorithm
    LR_HV = 5e-4  # learning rate in Loopless-Local-SVRG for computing hv product of Bi-level algorithm
    # L = 5000 # constant in HIV for computing hv product of Bi-level algorithm

    LR_OUTER = 2.5e-2  # learning rate in projected gradient descent for updating w of Bi-level algorithm
    # LR_OUTER_HIV = 0.75 # learning rate in HIV for computing hv product of Bi-level algorithm

    # Bi-Level algorithm, run once centred on the near node and once on the far node.
    NUM_STAGE_BL = NUM_STAGES  # S in Bi-level Algorithm
    for bl_type in ("near", "far"):
        covtype_fl = Bilevel_Federation(num_nodes=NUM_NODES, near_center_idx=near_center_idx, far_center_idx=far_center_idx,
                                        loss_function=loss_function, bl_type=bl_type,
                                        train_size=TRAIN_SIZE, val_size=VAL_SIZE, test_size=TEST_SIZE)

        fl_model_dict = fresh_model_dict(NUM_NODES, in_features, out_features, covtype_fl.device, pivot_state)

        covtype_fl.Initialize_Results()
        covtype_fl.Initialize_Variables(model_dict=fl_model_dict, train_loaders=train_loaders,
                                        val_loaders=val_loaders, test_loaders=test_loaders, num_batches=NUM_BATCH_LB)

        for s in range(NUM_STAGE_BL):
            covtype_fl.Update_Models_SVRG(train_loaders=train_loaders, val_loaders=val_loaders, test_loaders=test_loaders,
                                          num_epoch=NUM_EPOCH_IN, num_batches=NUM_BATCH_LB,
                                          communicate_period=CP, lr_inner=LR_INNER)

            val_grad = covtype_fl.Compute_Full_Gradient_Val(val_loaders=val_loaders)
            # account for the extra synchronisation / points used by the validation gradient
            covtype_fl.current_num_syn += 1
            covtype_fl.current_num_points += covtype_fl.val_size

            hv_g = covtype_fl.Compute_Hv_Product_SVRG(train_loaders=train_loaders, b=val_grad, num_batches=NUM_BATCH_LB,
                                                      num_epoch=NUM_EPOCH_HV, communicate_period=CP, lr_hv=LR_HV)

            covtype_fl.Update_Weights(outer_hv_g=hv_g, train_loaders=train_loaders, lr_outer=LR_OUTER,
                                      stage=s, w_ub=NUM_SIMILAR_BOUND)

        bl_w_dict, bl_metric_dict = covtype_fl.Output_Results()
        pd.DataFrame(bl_w_dict).to_csv(result_csv_path(output_path, SETTING, SEED, "bl_{}_w_dict".format(bl_type)))
        pd.DataFrame(bl_metric_dict).to_csv(result_csv_path(output_path, SETTING, SEED, "bl_{}_metric_dict".format(bl_type)))

    # local-train algorithm, near and far variants
    NUM_STAGE_LC = int(NUM_STAGES * 16)  # S in local-train Algorithm
    for bm_type in ("near", "far"):
        covtype_local_train_bm = Benchmark_Local_Train(num_nodes=NUM_NODES,
                                                       near_center_idx=near_center_idx,
                                                       far_center_idx=far_center_idx,
                                                       bm_type=bm_type,
                                                       loss_function=loss_function,
                                                       val_size=VAL_SIZE, test_size=TEST_SIZE)

        local_bm_model = fresh_model(in_features, out_features, covtype_local_train_bm.device, pivot_state)

        covtype_local_train_bm.Initialize_Results()
        covtype_local_train_bm.Initialize_Variables(model=local_bm_model,
                                                    train_loaders=train_loaders,
                                                    val_loaders=val_loaders,
                                                    test_loaders=test_loaders)
        for _ in range(NUM_STAGE_LC):
            covtype_local_train_bm.Update_Model_SVRG(val_loaders=val_loaders, test_loaders=test_loaders,
                                                     num_epoch=NUM_EPOCH_IN, num_batches=NUM_BATCH_LB,
                                                     communicate_period=CP, lr_inner=LR_INNER)

        local_metric_dict = covtype_local_train_bm.Output_Results()
        pd.DataFrame(local_metric_dict).to_csv(
            result_csv_path(output_path, SETTING, SEED, "local_{}_metric_dict".format(bm_type)))

    # Fed-avg algorithm:
    NUM_STAGE_FA = int(NUM_STAGES * 2.5)  # S in Fed-avg Algorithm

    covtype_fed_avg_bm = Benchmark_Fed_Avg(num_nodes=NUM_NODES, near_center_idx=near_center_idx, far_center_idx=far_center_idx,
                                           loss_function=loss_function, train_size=TRAIN_SIZE, val_size=VAL_SIZE, test_size=TEST_SIZE)

    fed_avg_bm_model_dict = fresh_model_dict(NUM_NODES, in_features, out_features, covtype_fed_avg_bm.device, pivot_state)

    covtype_fed_avg_bm.Initialize_Results()
    covtype_fed_avg_bm.Initialize_Variables(model_dict=fed_avg_bm_model_dict, train_loaders=train_loaders,
                                            val_loaders=val_loaders, test_loaders=test_loaders, num_batches=NUM_BATCH_LB)

    for _ in range(NUM_STAGE_FA):
        covtype_fed_avg_bm.Update_Models_SVRG(train_loaders=train_loaders, val_loaders=val_loaders,
                                              test_loaders=test_loaders, num_epoch=NUM_EPOCH_IN,
                                              num_batches=NUM_BATCH_LB, communicate_period=CP,
                                              lr_inner=LR_INNER)

    favg_metric_dict = covtype_fed_avg_bm.Output_Results()
    pd.DataFrame(favg_metric_dict).to_csv(result_csv_path(output_path, SETTING, SEED, "favg_metric_dict"))

    # Ditto algorithm:
    NUM_STAGE_FDITTO = int(NUM_STAGES * 2.5)  # S in Ditto Algorithm

    covtype_ditto_bm = Benchmark_Ditto(num_nodes=NUM_NODES, near_center_idx=near_center_idx, far_center_idx=far_center_idx,
                                       loss_function=loss_function, train_size=TRAIN_SIZE, val_size=VAL_SIZE, test_size=TEST_SIZE)

    ditto_bm_model_shared_dict = fresh_model_dict(NUM_NODES, in_features, out_features, covtype_ditto_bm.device, pivot_state)
    ditto_bm_model_personal_dict = fresh_model_dict(NUM_NODES, in_features, out_features, covtype_ditto_bm.device, pivot_state)
    # A single model is passed as model_shared_old; it used to be rebuilt on every
    # node iteration with only the last instance kept.
    ditto_bm_model_shared_old = fresh_model(in_features, out_features, covtype_ditto_bm.device, pivot_state)

    covtype_ditto_bm.Initialize_Results()
    covtype_ditto_bm.Initialize_Variables(model_shared_dict=ditto_bm_model_shared_dict, model_shared_old=ditto_bm_model_shared_old,
                                          model_personal_dict=ditto_bm_model_personal_dict, train_loaders=train_loaders,
                                          val_loaders=val_loaders, test_loaders=test_loaders, num_batches=NUM_BATCH_LB)
    for _ in range(NUM_STAGE_FDITTO):
        covtype_ditto_bm.Update_Shared_and_Personal_Models_SGD(train_loaders=train_loaders, val_loaders=val_loaders, test_loaders=test_loaders,
                                                               num_epoch=NUM_EPOCH_IN, num_batches=NUM_BATCH_LB, communicate_period=CP, lr_inner=LR_INNER,
                                                               num_batches_rand=50, lambda_val=0.1)

    fditto_metric_dict = covtype_ditto_bm.Output_Results()
    pd.DataFrame(fditto_metric_dict["near"]).to_csv(result_csv_path(output_path, SETTING, SEED, "fditto_near_metric_dict"))
    pd.DataFrame(fditto_metric_dict["far"]).to_csv(result_csv_path(output_path, SETTING, SEED, "fditto_far_metric_dict"))

    # Fed-personalized with ME algorithm:
    NUM_STAGE_FPME = int(NUM_STAGES * 2.5)  # S in Fed-personalized with ME algorithm

    covtype_fpme_bm = Benchmark_Fed_Pme(num_nodes=NUM_NODES, near_center_idx=near_center_idx, far_center_idx=far_center_idx,
                                        loss_function=loss_function, train_size=TRAIN_SIZE, val_size=VAL_SIZE, test_size=TEST_SIZE)

    fpme_bm_model_shared_dict = fresh_model_dict(NUM_NODES, in_features, out_features, covtype_fpme_bm.device, pivot_state)
    fpme_bm_model_personal_dict = fresh_model_dict(NUM_NODES, in_features, out_features, covtype_fpme_bm.device, pivot_state)

    covtype_fpme_bm.Initialize_Results()
    covtype_fpme_bm.Initialize_Variables(model_shared_dict=fpme_bm_model_shared_dict, model_personal_dict=fpme_bm_model_personal_dict,
                                         train_loaders=train_loaders, val_loaders=val_loaders, test_loaders=test_loaders, num_batches=NUM_BATCH_LB)

    for _ in range(NUM_STAGE_FPME):
        covtype_fpme_bm.Update_Shared_and_Personal_Models_SGD(val_loaders=val_loaders, test_loaders=test_loaders, num_epoch=NUM_EPOCH_IN, num_batches=NUM_BATCH_LB,
                                                              num_steps_ub=50, delta=0.005, communicate_period=CP, lr_inner=LR_INNER, lambda_val=15)

    fpme_metric_dict = covtype_fpme_bm.Output_Results()
    pd.DataFrame(fpme_metric_dict["near"]).to_csv(result_csv_path(output_path, SETTING, SEED, "fpme_near_metric_dict"))
    pd.DataFrame(fpme_metric_dict["far"]).to_csv(result_csv_path(output_path, SETTING, SEED, "fpme_far_metric_dict"))

    # Flush rather than close: closing sys.stdout makes any later write
    # (e.g. from an atexit handler) raise ValueError.
    sys.stdout.flush()


if __name__ == "__main__":
    main()
    sys.exit(0)