import os

# Set before `import torch` below so the variable is already in the process
# environment when CUDA is initialised.
# NOTE(review): CUDA_LAUNCH_BLOCKING=1 makes CUDA kernel launches synchronous,
# which slows training and is normally only wanted while debugging; confirm it
# is intended for full experiment runs.
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"

import argparse
import random
import sys

import pandas as pd
import torch
import torchvision

from dataset import Data_Preprocess
from fed_bl import Bilevel_Federation
from local_train import Benchmark_Local_Train
from fed_avg import Benchmark_Fed_Avg
from fed_ditto import Benchmark_Ditto
from fed_pme import Benchmark_Fed_Pme
from model import Net

# (x - 0.5) / 0.5 maps ToTensor's [0, 1] output to [-1, 1] on every channel.
_NORMALIZE_MEAN = (0.5, 0.5, 0.5)
_NORMALIZE_STD = (0.5, 0.5, 0.5)


def _cifar10_transform():
    """Return a fresh ToTensor + Normalize pipeline for CIFAR-10 images."""
    steps = [
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize(_NORMALIZE_MEAN, _NORMALIZE_STD),
    ]
    return torchvision.transforms.Compose(steps)


# Train and test use the same preprocessing but separate Compose objects.
# NOTE(review): torchvision applies these only to samples fetched through the
# dataset object; this script builds its tensors from `train_set.data` /
# `test_set.data` directly, which bypasses them.
train_transform = _cifar10_transform()
test_transform = _cifar10_transform()


# Dataset and result roots. Python does not expand environment variables inside
# string literals, so "$HOME" is expanded explicitly; otherwise files would be
# read/written under a literal "$HOME" directory relative to the working directory.
DATA_ROOT = os.path.expandvars("$HOME/project_bilevel/bilevel/Datasets/")
RESULTS_ROOT = os.path.expandvars("$HOME/project_bilevel/bilevel/Experiment_results/")

# Federation layout.
NUM_NODES = 15  # K = number of nodes
NUM_SIMILAR_BOUND = 3  # b in the probability simplex (passed as w_ub to Update_Weights)
NUM_SIMILAR = 5  # J = number of nodes with similar data distribution

TRAIN_SIZE = 4000
VAL_SIZE = 500
TEST_SIZE = 5000

# Weights over the four label groups below (each dict sums to 1), presumably
# sampling proportions; consumed by Data_Preprocess.Resample_Data.
NEAR_P_DICT = {"label_list_1": 0.36, "label_list_2": 0.04, "label_list_3": 0.54, "label_list_4": 0.06}
FAR_P_DICT = {"label_list_1": 0.04, "label_list_2": 0.36, "label_list_3": 0.06, "label_list_4": 0.54}

# CIFAR-10 class indices partitioned into four groups (together they cover 0-9).
LABEL_LIST_1 = [1, 9]
LABEL_LIST_2 = [0, 8]
LABEL_LIST_3 = [2, 3, 4]
LABEL_LIST_4 = [5, 6, 7]

# Label remapping passed to the data generators of settings 2 and 4.
LABEL_PERMUTE = {1: 0, 0: 2, 2: 5, 5: 1}

NUM_BATCH_LB = 50  # to decide how many batches there are in a data loader

# Net architecture for CIFAR-10 images in N x 3 x 32 x 32 layout.
IN_CHANNELS = 3
OUT_CHANNELS = 5
OUT_FEATURES = 10  # one logit per CIFAR-10 class, for torch.nn.CrossEntropyLoss

# Optimisation hyper-parameters.
NUM_EPOCH_IN = 5  # this times NUM_BATCH_LB is T1 in four algorithms
NUM_EPOCH_HV = 5  # this times NUM_BATCH_LB is T2 in Bi-level algorithm
CP = 10  # tau in Loopless-Local-SVRG of Bi-level algorithm
LR_INNER = 5e-2  # learning rate in Loopless-Local-SVRG for updating theta of Bi-level algorithm
LR_HV = 5e-4  # learning rate in Loopless-Local-SVRG for computing hv product of Bi-level algorithm
LR_OUTER = 1.5e-2  # learning rate in projected gradient descent for updating w of Bi-level algorithm


def _parse_args(argv=None):
    """Parse the positional command-line arguments ``SETTING SEED NUM_STAGES``.

    Args:
        argv: Argument list without the program name; ``None`` means ``sys.argv[1:]``.

    Returns:
        argparse.Namespace with integer attributes ``setting``, ``seed`` and
        ``num_stages``. argparse exits with a usage message on bad input, so the
        run fails before any dataset download.
    """
    parser = argparse.ArgumentParser(
        description="CIFAR-10 bilevel federated learning experiment and benchmarks.")
    parser.add_argument("setting", type=int, choices=(1, 2, 3, 4),
                        help="which Data_Preprocess generation routine to use")
    parser.add_argument("seed", type=int,
                        help="seed for `random`, torch and Data_Preprocess")
    parser.add_argument("num_stages", type=int,
                        help="S, the number of stages of the bi-level algorithm")
    return parser.parse_args(argv)


def _load_cifar10_tensors():
    """Download CIFAR-10 if needed and return ``(train_x, train_y, test_x, test_y)``.

    Images are float tensors permuted to N x C x H x W; labels are int64 tensors.
    The tensors are built from the datasets' raw ``.data`` / ``.targets``, so
    ``train_transform`` / ``test_transform`` are not applied to them.
    NOTE(review): pixel values therefore keep their raw range; confirm that
    Data_Preprocess or Net rescales them if normalisation is intended.
    """
    train_set = torchvision.datasets.CIFAR10(root=DATA_ROOT,
                                             train=True,
                                             transform=train_transform,
                                             target_transform=None,
                                             download=True)
    # for testing set after all stages
    test_set = torchvision.datasets.CIFAR10(root=DATA_ROOT,
                                            train=False,
                                            transform=test_transform,
                                            target_transform=None,
                                            download=True)

    # .data is N x H x W x C (torch.Size([50000, 32, 32, 3]) for training);
    # permute to N x C x H x W where C = 3.
    train_x = torch.FloatTensor(train_set.data).permute(0, 3, 1, 2)
    print("All training data have size: ", train_x.size())
    train_y = torch.LongTensor(train_set.targets)
    print("All training targets have size: ", train_y.size())

    test_x = torch.FloatTensor(test_set.data).permute(0, 3, 1, 2)
    print("All testing data have size: ", test_x.size())
    test_y = torch.LongTensor(test_set.targets)
    print("All testing targets have size: ", test_y.size())

    return train_x, train_y, test_x, test_y


def _clone_pivot(model_pivot, size, device):
    """Return a new ``Net`` on *device* whose weights are copied from *model_pivot*."""
    model = Net(in_channels=IN_CHANNELS,
                out_channels=OUT_CHANNELS,
                out_features=OUT_FEATURES,
                size=size).to(device)
    model.load_state_dict(model_pivot.state_dict())
    return model


def _save_result(table, output_path, setting, seed, name):
    """Write *table* to ``Cifar10_Setting{setting}_Seed{seed}_Result_{name}.csv``."""
    file_name = "Cifar10_Setting{}_Seed{}_Result_{}.csv".format(setting, seed, name)
    pd.DataFrame(table).to_csv(os.path.join(output_path, file_name))


def main(argv=None):
    """Run every algorithm for one (SETTING, SEED) pair and write the result CSVs.

    Args:
        argv: Command-line arguments without the program name (see ``_parse_args``).

    Returns:
        The process exit status, 0 on success.
    """
    args = _parse_args(argv)
    setting, seed, num_stages = args.setting, args.seed, args.num_stages

    # Create the output directory up front: DataFrame.to_csv does not create
    # missing parent directories, and failing after hours of training is costly.
    output_path = os.path.join(RESULTS_ROOT, "Cifar10_Setting{}_Seed{}".format(setting, seed))
    os.makedirs(output_path, exist_ok=True)

    print("Start preparing datasets.")
    print("-" * 30)
    train_set_data, train_set_targets, test_set_data, test_set_targets = _load_cifar10_tensors()
    size = train_set_data.size(dim=2)  # H (== W) for the neural network architecture

    random.seed(a=seed)
    # Also seed torch so the pivot model's random initialisation (and any other
    # torch RNG use below) is reproducible for a given SEED.
    torch.manual_seed(seed)

    near_node_list = sorted(random.sample(range(NUM_NODES), NUM_SIMILAR))
    print("The near ones among {} nodes in the simulation are:".format(NUM_NODES), near_node_list)
    print("-" * 30)

    far_node_list = sorted(set(range(NUM_NODES)) - set(near_node_list))
    print("The far ones among {} nodes in the simulation are:".format(NUM_NODES), far_node_list)
    print("-" * 30)

    near_center_idx = random.choice(seq=near_node_list)
    print("For bilevel experiment near version, the center node in the simulation is node {}.".format(near_center_idx))
    print("-" * 30)

    far_center_idx = random.choice(seq=far_node_list)
    print("For bilevel experiment far version, the center node in the simulation is node {}.".format(far_center_idx))
    print("-" * 30)

    sys.stdout.flush()

    # preprocessing cifar10 data
    cifar10 = Data_Preprocess(train_set_data=train_set_data, train_set_targets=train_set_targets,
                              test_set_data=test_set_data, test_set_targets=test_set_targets,
                              label_list_1=LABEL_LIST_1, label_list_2=LABEL_LIST_2,
                              label_list_3=LABEL_LIST_3, label_list_4=LABEL_LIST_4,
                              near_center_idx=near_center_idx, far_center_idx=far_center_idx)
    cifar10.Decide_Dataset_Sizes(num_nodes=NUM_NODES, train_size=TRAIN_SIZE,
                                 val_size=VAL_SIZE, test_size=TEST_SIZE)
    cifar10.Resample_Data(near_node_list=near_node_list, seed=seed,
                          near_p_dict=NEAR_P_DICT, far_p_dict=FAR_P_DICT, margin=4.0)
    sys.stdout.flush()

    if setting == 1:
        cifar10.Dataset_Generating(seed=seed)
    elif setting == 2:
        cifar10.Dataset_Generating_and_Label_Permute(seed=seed, label_permute=LABEL_PERMUTE)
    elif setting == 3:
        cifar10.Dataset_Generating_and_Rotation(seed=seed)
    elif setting == 4:
        cifar10.Dataset_Generating_Label_Permute_and_Rotation(seed=seed, label_permute=LABEL_PERMUTE)
    else:
        # Unreachable via the CLI (argparse restricts choices), but never fall
        # through to training on data that was never generated.
        raise ValueError("SETTING must be 1, 2, 3 or 4, got {}".format(setting))

    # train / validation / test loaders shared by every algorithm below
    train_loaders, val_loaders, test_loaders = cifar10.Prepare_Dataloaders(num_batch_lb=NUM_BATCH_LB)
    sys.stdout.flush()

    # 10-way multi-class classification loss over OUT_FEATURES logits.
    loss_function = torch.nn.CrossEntropyLoss()

    # pivot model: every algorithm's models start from a copy of its weights
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model_pivot = Net(in_channels=IN_CHANNELS,
                      out_channels=OUT_CHANNELS,
                      out_features=OUT_FEATURES,
                      size=size).to(device)
    print(torch.nn.utils.parameters_to_vector(model_pivot.parameters()).size())
    sys.stdout.flush()

    # Bi-level algorithm (S = num_stages), once per center-node choice.
    for bl_type in ("near", "far"):
        cifar10_fl = Bilevel_Federation(num_nodes=NUM_NODES, near_center_idx=near_center_idx,
                                        far_center_idx=far_center_idx,
                                        loss_function=loss_function, bl_type=bl_type,
                                        train_size=TRAIN_SIZE, val_size=VAL_SIZE, test_size=TEST_SIZE)
        fl_model_dict = {node_idx: _clone_pivot(model_pivot, size, cifar10_fl.device)
                         for node_idx in range(NUM_NODES)}

        cifar10_fl.Initialize_Results()
        cifar10_fl.Initialize_Variables(model_dict=fl_model_dict, train_loaders=train_loaders,
                                        val_loaders=val_loaders, test_loaders=test_loaders,
                                        num_batches=NUM_BATCH_LB)

        for s in range(num_stages):
            cifar10_fl.Update_Models_SVRG(train_loaders=train_loaders, val_loaders=val_loaders,
                                          test_loaders=test_loaders, num_epoch=NUM_EPOCH_IN,
                                          num_batches=NUM_BATCH_LB, communicate_period=CP,
                                          lr_inner=LR_INNER)

            val_grad = cifar10_fl.Compute_Full_Gradient_Val(val_loaders=val_loaders)
            # Bookkeeping for the full validation-gradient round:
            # one more sync and val_size more data points.
            cifar10_fl.current_num_syn += 1
            cifar10_fl.current_num_points += cifar10_fl.val_size

            hv_g = cifar10_fl.Compute_Hv_Product_SVRG(train_loaders=train_loaders, b=val_grad,
                                                      num_batches=NUM_BATCH_LB, num_epoch=NUM_EPOCH_HV,
                                                      communicate_period=CP, lr_hv=LR_HV)

            cifar10_fl.Update_Weights(outer_hv_g=hv_g, train_loaders=train_loaders, lr_outer=LR_OUTER,
                                      stage=s, w_ub=NUM_SIMILAR_BOUND)

        bl_w_dict, bl_metric_dict = cifar10_fl.Output_Results()
        _save_result(bl_w_dict, output_path, setting, seed, "bl_{}_w_dict".format(bl_type))
        _save_result(bl_metric_dict, output_path, setting, seed, "bl_{}_metric_dict".format(bl_type))

    # local-train algorithm (S = 16 * num_stages), once per center-node choice.
    for bm_type in ("near", "far"):
        cifar10_local_train_bm = Benchmark_Local_Train(num_nodes=NUM_NODES,
                                                       near_center_idx=near_center_idx,
                                                       far_center_idx=far_center_idx,
                                                       bm_type=bm_type,
                                                       loss_function=loss_function,
                                                       val_size=VAL_SIZE, test_size=TEST_SIZE)
        local_bm_model = _clone_pivot(model_pivot, size, cifar10_local_train_bm.device)

        cifar10_local_train_bm.Initialize_Results()
        cifar10_local_train_bm.Initialize_Variables(model=local_bm_model,
                                                    train_loaders=train_loaders,
                                                    val_loaders=val_loaders,
                                                    test_loaders=test_loaders)
        for _ in range(num_stages * 16):
            cifar10_local_train_bm.Update_Model_SVRG(val_loaders=val_loaders, test_loaders=test_loaders,
                                                     num_epoch=NUM_EPOCH_IN, num_batches=NUM_BATCH_LB,
                                                     communicate_period=CP, lr_inner=LR_INNER)

        local_metric_dict = cifar10_local_train_bm.Output_Results()
        _save_result(local_metric_dict, output_path, setting, seed, "local_{}_metric_dict".format(bm_type))

    # Fed-avg algorithm (S = 2 * num_stages).
    cifar10_fed_avg_bm = Benchmark_Fed_Avg(num_nodes=NUM_NODES, near_center_idx=near_center_idx,
                                           far_center_idx=far_center_idx, loss_function=loss_function,
                                           train_size=TRAIN_SIZE, val_size=VAL_SIZE, test_size=TEST_SIZE)
    fed_avg_bm_model_dict = {node_idx: _clone_pivot(model_pivot, size, cifar10_fed_avg_bm.device)
                             for node_idx in range(NUM_NODES)}

    cifar10_fed_avg_bm.Initialize_Results()
    cifar10_fed_avg_bm.Initialize_Variables(model_dict=fed_avg_bm_model_dict, train_loaders=train_loaders,
                                            val_loaders=val_loaders, test_loaders=test_loaders,
                                            num_batches=NUM_BATCH_LB)
    for _ in range(num_stages * 2):
        cifar10_fed_avg_bm.Update_Models_SVRG(train_loaders=train_loaders, val_loaders=val_loaders,
                                              test_loaders=test_loaders, num_epoch=NUM_EPOCH_IN,
                                              num_batches=NUM_BATCH_LB, communicate_period=CP,
                                              lr_inner=LR_INNER)

    favg_metric_dict = cifar10_fed_avg_bm.Output_Results()
    _save_result(favg_metric_dict, output_path, setting, seed, "favg_metric_dict")

    # Ditto algorithm (S = 2 * num_stages).
    cifar10_ditto_bm = Benchmark_Ditto(num_nodes=NUM_NODES, near_center_idx=near_center_idx,
                                       far_center_idx=far_center_idx, loss_function=loss_function,
                                       train_size=TRAIN_SIZE, val_size=VAL_SIZE, test_size=TEST_SIZE)
    ditto_bm_model_shared_dict = {node_idx: _clone_pivot(model_pivot, size, cifar10_ditto_bm.device)
                                  for node_idx in range(NUM_NODES)}
    ditto_bm_model_personal_dict = {node_idx: _clone_pivot(model_pivot, size, cifar10_ditto_bm.device)
                                    for node_idx in range(NUM_NODES)}
    # A single model is passed as model_shared_old, so build it once rather than
    # once per node (only the last of those copies was ever used).
    ditto_bm_model_shared_old = _clone_pivot(model_pivot, size, cifar10_ditto_bm.device)

    cifar10_ditto_bm.Initialize_Results()
    cifar10_ditto_bm.Initialize_Variables(model_shared_dict=ditto_bm_model_shared_dict,
                                          model_shared_old=ditto_bm_model_shared_old,
                                          model_personal_dict=ditto_bm_model_personal_dict,
                                          train_loaders=train_loaders, val_loaders=val_loaders,
                                          test_loaders=test_loaders, num_batches=NUM_BATCH_LB)
    for _ in range(num_stages * 2):
        cifar10_ditto_bm.Update_Shared_and_Personal_Models_SGD(train_loaders=train_loaders, val_loaders=val_loaders,
                                                               test_loaders=test_loaders, num_epoch=NUM_EPOCH_IN,
                                                               num_batches=NUM_BATCH_LB, communicate_period=CP,
                                                               lr_inner=LR_INNER, num_batches_rand=25,
                                                               lambda_val=0.1)

    fditto_metric_dict = cifar10_ditto_bm.Output_Results()
    _save_result(fditto_metric_dict["near"], output_path, setting, seed, "fditto_near_metric_dict")
    _save_result(fditto_metric_dict["far"], output_path, setting, seed, "fditto_far_metric_dict")

    # Fed-personalized with ME algorithm (S = 2 * num_stages).
    cifar10_fpme_bm = Benchmark_Fed_Pme(num_nodes=NUM_NODES, near_center_idx=near_center_idx,
                                        far_center_idx=far_center_idx, loss_function=loss_function,
                                        train_size=TRAIN_SIZE, val_size=VAL_SIZE, test_size=TEST_SIZE)
    fpme_bm_model_shared_dict = {node_idx: _clone_pivot(model_pivot, size, cifar10_fpme_bm.device)
                                 for node_idx in range(NUM_NODES)}
    fpme_bm_model_personal_dict = {node_idx: _clone_pivot(model_pivot, size, cifar10_fpme_bm.device)
                                   for node_idx in range(NUM_NODES)}

    cifar10_fpme_bm.Initialize_Results()
    cifar10_fpme_bm.Initialize_Variables(model_shared_dict=fpme_bm_model_shared_dict,
                                         model_personal_dict=fpme_bm_model_personal_dict,
                                         train_loaders=train_loaders, val_loaders=val_loaders,
                                         test_loaders=test_loaders, num_batches=NUM_BATCH_LB)
    for _ in range(num_stages * 2):
        cifar10_fpme_bm.Update_Shared_and_Personal_Models_SGD(val_loaders=val_loaders, test_loaders=test_loaders,
                                                              num_epoch=NUM_EPOCH_IN, num_batches=NUM_BATCH_LB,
                                                              num_steps_ub=25, delta=0.005, communicate_period=CP,
                                                              lr_inner=LR_INNER, lambda_val=15)

    fpme_metric_dict = cifar10_fpme_bm.Output_Results()
    _save_result(fpme_metric_dict["near"], output_path, setting, seed, "fpme_near_metric_dict")
    _save_result(fpme_metric_dict["far"], output_path, setting, seed, "fpme_far_metric_dict")

    # Flush rather than close sys.stdout: writing to a closed stream raises
    # ValueError, e.g. if anything prints during interpreter shutdown.
    sys.stdout.flush()
    return 0


if __name__ == "__main__":
    sys.exit(main())