import os

os.environ["CUDA_LAUNCH_BLOCKING"] = "1"

import sys
import random

import pandas as pd
import torch
import torchvision

from dataset import Data_Preprocess
from fed_bl import Bilevel_Federation
from local_train import Benchmark_Local_Train
from fed_avg import Benchmark_Fed_Avg
from fed_ditto import Benchmark_Ditto
from fed_pme import Benchmark_Fed_Pme
from model import Net

# Per-sample transforms handed to torchvision.datasets.MNIST in the main block.
# Only ToTensor is active; the augmentations below are kept but disabled.
# NOTE(review): these presumably apply only when samples are fetched through
# the dataset objects -- the main block reads `.data` / `.targets` directly,
# so confirm whether they have any effect at all.
train_transform = torchvision.transforms.Compose([
    # torchvision.transforms.RandomRotation(degrees=180),
    # torchvision.transforms.RandomGrayscale(),
    torchvision.transforms.ToTensor(),
])

test_transform = torchvision.transforms.Compose([
    # torchvision.transforms.RandomRotation(degrees=180),
    torchvision.transforms.ToTensor(),
])


if __name__ == "__main__":

    print("Start preparing datasets.")
    print("-" * 30)

    # MNIST part.
    # Python never performs shell-variable expansion on string literals, so
    # "$HOME" must be expanded explicitly; otherwise a directory literally
    # named "$HOME" is created under the current working directory.
    DATA_ROOT = os.path.expandvars("$HOME/project_bilevel/bilevel/Datasets/")

    train_set = torchvision.datasets.MNIST(root=DATA_ROOT,
                                           train=True,
                                           transform=train_transform,
                                           target_transform=None,
                                           download=True)

    test_set = torchvision.datasets.MNIST(root=DATA_ROOT,
                                          train=False,
                                          transform=test_transform,
                                          target_transform=None,
                                          download=True)

    # Reshape the image stacks from N*H*W to N*C*H*W (C = 1) as float tensors,
    # and cast the labels to long.
    # NOTE(review): reading `.data` bypasses the dataset transforms, so the
    # pixels are presumably left in their raw 0-255 range here -- confirm
    # whether Data_Preprocess rescales them.
    train_set_data = train_set.data.float().unsqueeze(dim=1)
    train_set_targets = train_set.targets.long()
    test_set_data = test_set.data.float().unsqueeze(dim=1)
    test_set_targets = test_set.targets.long()

    divider = "-" * 30
    print(divider)
    for caption, values in (("All training data have size: ", train_set_data),
                            ("All training targets have size: ", train_set_targets),
                            ("All testing data have size: ", test_set_data),
                            ("All testing targets have size: ", test_set_targets)):
        print(caption, values.size())
        print(divider)
    
    SIZE = train_set_data.size(dim=2) # get H (i.e. W) for neural network architecture

    # Command-line arguments: SETTING (1-4), SEED, NUM_STAGES.
    if len(sys.argv) < 4:
        sys.exit("Usage: python {} SETTING SEED NUM_STAGES".format(sys.argv[0]))

    SETTING = int(sys.argv[1])
    SEED = int(sys.argv[2])
    if SETTING not in (1, 2, 3, 4):
        sys.exit("SETTING must be one of 1, 2, 3, 4 (got {}).".format(SETTING))

    # Python does not expand "$HOME" inside string literals, so expand it
    # explicitly. Also create the directory up front, because every result CSV
    # is later written into it with DataFrame.to_csv, which does not create it.
    output_path = os.path.expandvars(
        "$HOME/project_bilevel/bilevel/Experiment_results/MNIST_Setting{}_Seed{}/".format(str(SETTING), str(SEED)))
    os.makedirs(output_path, exist_ok=True)

    NUM_STAGES = int(sys.argv[3])
    NUM_NODES = 15 # K = number of nodes
    NUM_SIMILAR_BOUND = 3  # b in the probability simplex
    NUM_SIMILAR = 5 # J = number of nodes with similar data distribution

    # The `random` draws below (sample, then two choices) pick the near/far
    # node split and the two center nodes; keep their order fixed so a given
    # SEED reproduces the same selection.
    random.seed(a=SEED)
    near_node_list = sorted(random.sample(range(NUM_NODES), NUM_SIMILAR))
    print("The near ones among {} nodes in the simulation are:".format(NUM_NODES), near_node_list)
    print("-" * 30)

    far_node_list = sorted(set(range(NUM_NODES)) - set(near_node_list))
    print("The far ones among {} nodes in the simulation are:".format(NUM_NODES), far_node_list)
    print("-" * 30)

    near_center_idx = random.choice(near_node_list)
    print("For bilevel experiment near version, the center node in the simulation is node {}.".format(near_center_idx))
    print("-" * 30)

    far_center_idx = random.choice(far_node_list)
    print("For bilevel experiment far version, the center node in the simulation is node {}.".format(far_center_idx))
    print("-" * 30)

    sys.stdout.flush()

    # Split sizes handed to Data_Preprocess.Decide_Dataset_Sizes.
    TRAIN_SIZE, VAL_SIZE, TEST_SIZE = 4000, 500, 5000

    # Four disjoint groups of MNIST digits (together covering 0-9).
    label_list_1 = [2, 4, 6]
    label_list_2 = [0, 3]
    label_list_3 = [1, 8]
    label_list_4 = [5, 7, 9]

    # Per-label-group probabilities passed to Resample_Data for near and far
    # nodes; the far profile is the near profile in reverse group order.
    group_names = ("label_list_1", "label_list_2", "label_list_3", "label_list_4")
    near_probs = (0.42, 0.08, 0.38, 0.12)
    near_p_dict = dict(zip(group_names, near_probs))
    far_p_dict = dict(zip(group_names, reversed(near_probs)))

    # Label mapping used by Settings 2 and 4 (the cycle 2 -> 0 -> 1 -> 5 -> 2).
    label_permute = {2: 0, 0: 1, 1: 5, 5: 2}

    # Decides how many batches each data loader holds.
    NUM_BATCH_LB = 50

    # Preprocess the MNIST tensors into per-node datasets.
    mnist = Data_Preprocess(train_set_data=train_set_data,
                            train_set_targets=train_set_targets,
                            test_set_data=test_set_data,
                            test_set_targets=test_set_targets,
                            label_list_1=label_list_1,
                            label_list_2=label_list_2,
                            label_list_3=label_list_3,
                            label_list_4=label_list_4,
                            near_center_idx=near_center_idx,
                            far_center_idx=far_center_idx)

    mnist.Decide_Dataset_Sizes(num_nodes=NUM_NODES, train_size=TRAIN_SIZE,
                               val_size=VAL_SIZE, test_size=TEST_SIZE)
    mnist.Resample_Data(near_node_list=near_node_list, seed=SEED,
                        near_p_dict=near_p_dict, far_p_dict=far_p_dict, margin=4.0)

    sys.stdout.flush()
    
    # Generate the per-node datasets for the requested experimental setting.
    if SETTING == 1:
        # Setting 1: Dataset_Generating only
        mnist.Dataset_Generating(seed=SEED)
    elif SETTING == 2:
        # Setting 2: generation plus the label_permute mapping
        mnist.Dataset_Generating_and_Label_Permute(seed=SEED, label_permute=label_permute)
    elif SETTING == 3:
        # Setting 3: generation plus rotation
        mnist.Dataset_Generating_and_Rotation(seed=SEED)
    elif SETTING == 4:
        # Setting 4: generation plus label_permute and rotation
        mnist.Dataset_Generating_Label_Permute_and_Rotation(seed=SEED, label_permute=label_permute)
    else:
        # Without this guard an unknown SETTING would silently skip dataset
        # generation and fail later (or run on ungenerated data).
        raise ValueError("SETTING must be 1, 2, 3 or 4, got {}".format(SETTING))

    # Train / validation / test loaders shared by every algorithm below.
    train_loaders, val_loaders, test_loaders = mnist.Prepare_Dataloaders(num_batch_lb=NUM_BATCH_LB)

    sys.stdout.flush()
    
    # Device for the pivot model: GPU when available, otherwise CPU.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # 10-class digit classification (out_features = 10 below).
    loss_function = torch.nn.CrossEntropyLoss()
    # Alternative kept for reference: loss_function = torch.nn.BCELoss()

    # Net hyper-parameters for MNIST inputs shaped N*1*SIZE*SIZE.
    in_channels = 1     # single grayscale channel
    out_channels = 2    # forwarded to Net as out_channels
    out_features = 10   # number of output classes for CrossEntropyLoss

    # Pivot model: the common initialisation that every algorithm below copies.
    # NOTE(review): only Python's `random` is seeded in this script
    # (random.seed(a=SEED)); unless Data_Preprocess seeds torch, this
    # initialisation differs between runs with the same SEED -- confirm.
    model_pivot = Net(in_channels=in_channels,
                      out_channels=out_channels,
                      out_features=out_features,
                      size=SIZE).to(device)
    print(torch.nn.utils.parameters_to_vector(model_pivot.parameters()).size())

    sys.stdout.flush()

    # Iteration budgets, measured in epochs of NUM_BATCH_LB batches:
    #   NUM_EPOCH_IN * NUM_BATCH_LB = T1 (all algorithms)
    #   NUM_EPOCH_HV * NUM_BATCH_LB = T2 (Bi-level hv-product solve)
    # CP is tau, the communication period of Loopless-Local-SVRG.
    NUM_EPOCH_IN, NUM_EPOCH_HV, CP = 5, 5, 10

    # Step sizes of the Bi-level algorithm:
    #   LR_INNER - Loopless-Local-SVRG updates of theta
    #   LR_HV    - Loopless-Local-SVRG solve for the hv product
    #   LR_OUTER - projected gradient descent on the weights w
    LR_INNER, LR_HV, LR_OUTER = 5e-2, 5e-4, 1.5e-2

    # Unused HIV-based alternatives, kept for reference:
    # L = 5000            # constant in HIV for computing the hv product
    # LR_OUTER_HIV = 0.75 # learning rate in HIV for computing the hv product

    # Bi-Level algorithm, run once per variant ("near", then "far"). Each run
    # starts from fresh per-node copies of the pivot weights and writes its own
    # w_dict / metric_dict CSVs. The loop replaces two copy-pasted runs that
    # differed only in bl_type and output file names.
    NUM_STAGE_BL = NUM_STAGES # S in Bi-level Algorithm

    for bl_type in ("near", "far"):
        mnist_fl = Bilevel_Federation(num_nodes=NUM_NODES, near_center_idx=near_center_idx, far_center_idx=far_center_idx,
                                      loss_function=loss_function, bl_type=bl_type,
                                      train_size=TRAIN_SIZE, val_size=VAL_SIZE, test_size=TEST_SIZE)

        # One independent model per node, all initialised from the pivot.
        fl_model_dict = {}
        for node_idx in range(NUM_NODES):
            node_model = Net(in_channels=in_channels,
                             out_channels=out_channels,
                             out_features=out_features,
                             size=SIZE)
            fl_model_dict[node_idx] = node_model.to(mnist_fl.device)
            fl_model_dict[node_idx].load_state_dict(model_pivot.state_dict())

        mnist_fl.Initialize_Results()
        mnist_fl.Initialize_Variables(model_dict=fl_model_dict, train_loaders=train_loaders,
                                      val_loaders=val_loaders, test_loaders=test_loaders, num_batches=NUM_BATCH_LB)

        for s in range(NUM_STAGE_BL):
            # Inner problem: update the node models (theta).
            mnist_fl.Update_Models_SVRG(train_loaders=train_loaders, val_loaders=val_loaders, test_loaders=test_loaders,
                                        num_epoch=NUM_EPOCH_IN, num_batches=NUM_BATCH_LB,
                                        communicate_period=CP, lr_inner=LR_INNER)

            # Full validation gradient; the bookkeeping bumps the sync counter
            # by one and the processed-point counter by val_size.
            val_grad = mnist_fl.Compute_Full_Gradient_Val(val_loaders=val_loaders)
            mnist_fl.current_num_syn += 1
            mnist_fl.current_num_points += mnist_fl.val_size

            # Hessian-vector product against the validation gradient, then the
            # outer update of the weights w (bounded by NUM_SIMILAR_BOUND).
            hv_g = mnist_fl.Compute_Hv_Product_SVRG(train_loaders=train_loaders, b=val_grad, num_batches=NUM_BATCH_LB,
                                                    num_epoch=NUM_EPOCH_HV, communicate_period=CP, lr_hv=LR_HV)

            mnist_fl.Update_Weights(outer_hv_g=hv_g, train_loaders=train_loaders, lr_outer=LR_OUTER,
                                    stage=s, w_ub=NUM_SIMILAR_BOUND)

        bl_w_dict, bl_metric_dict = mnist_fl.Output_Results()

        bl_prefix = "MNIST_Setting{}_Seed{}_Result_bl_{}".format(str(SETTING), str(SEED), bl_type)
        pd.DataFrame(bl_w_dict).to_csv(os.path.join(output_path, bl_prefix + "_w_dict.csv"))
        pd.DataFrame(bl_metric_dict).to_csv(os.path.join(output_path, bl_prefix + "_metric_dict.csv"))

    # Local-train benchmark, run once per variant ("near", then "far"), each
    # starting from a fresh copy of the pivot weights. The loop replaces two
    # copy-pasted runs that differed only in bm_type and the output file name.
    NUM_STAGE_LC = int(NUM_STAGES*16) # S in local-train Algorithm

    for bm_type in ("near", "far"):
        mnist_local_train_bm = Benchmark_Local_Train(num_nodes=NUM_NODES,
                                                     near_center_idx=near_center_idx,
                                                     far_center_idx=far_center_idx,
                                                     bm_type=bm_type,
                                                     loss_function=loss_function,
                                                     val_size=VAL_SIZE, test_size=TEST_SIZE)

        local_bm_model = Net(in_channels=in_channels,
                             out_channels=out_channels,
                             out_features=out_features,
                             size=SIZE)
        local_bm_model = local_bm_model.to(mnist_local_train_bm.device)
        local_bm_model.load_state_dict(model_pivot.state_dict())

        mnist_local_train_bm.Initialize_Results()
        mnist_local_train_bm.Initialize_Variables(model=local_bm_model,
                                                  train_loaders=train_loaders,
                                                  val_loaders=val_loaders,
                                                  test_loaders=test_loaders)
        for s in range(NUM_STAGE_LC):
            mnist_local_train_bm.Update_Model_SVRG(val_loaders=val_loaders, test_loaders=test_loaders,
                                                   num_epoch=NUM_EPOCH_IN, num_batches=NUM_BATCH_LB,
                                                   communicate_period=CP, lr_inner=LR_INNER)

        local_metric_dict = mnist_local_train_bm.Output_Results()
        local_csv = "MNIST_Setting{}_Seed{}_Result_local_{}_metric_dict.csv".format(str(SETTING), str(SEED), bm_type)
        pd.DataFrame(local_metric_dict).to_csv(os.path.join(output_path, local_csv))

    # FedAvg benchmark: one model per node, trained for twice as many stages
    # as the bi-level runs.
    NUM_STAGE_FA = int(NUM_STAGES*2) # S in Fed-avg Algorithm

    mnist_fed_avg_bm = Benchmark_Fed_Avg(num_nodes=NUM_NODES,
                                         near_center_idx=near_center_idx,
                                         far_center_idx=far_center_idx,
                                         loss_function=loss_function,
                                         train_size=TRAIN_SIZE,
                                         val_size=VAL_SIZE,
                                         test_size=TEST_SIZE)

    # Per-node replicas of the pivot model.
    fed_avg_bm_model_dict = {}
    for idx in range(NUM_NODES):
        replica = Net(in_channels=in_channels, out_channels=out_channels,
                      out_features=out_features, size=SIZE).to(mnist_fed_avg_bm.device)
        replica.load_state_dict(model_pivot.state_dict())
        fed_avg_bm_model_dict[idx] = replica

    mnist_fed_avg_bm.Initialize_Results()
    mnist_fed_avg_bm.Initialize_Variables(model_dict=fed_avg_bm_model_dict,
                                          train_loaders=train_loaders,
                                          val_loaders=val_loaders,
                                          test_loaders=test_loaders,
                                          num_batches=NUM_BATCH_LB)

    for _ in range(NUM_STAGE_FA):
        mnist_fed_avg_bm.Update_Models_SVRG(train_loaders=train_loaders,
                                            val_loaders=val_loaders,
                                            test_loaders=test_loaders,
                                            num_epoch=NUM_EPOCH_IN,
                                            num_batches=NUM_BATCH_LB,
                                            communicate_period=CP,
                                            lr_inner=LR_INNER)

    favg_metric_dict = mnist_fed_avg_bm.Output_Results()
    favg_csv = "MNIST_Setting{}_Seed{}_Result_favg_metric_dict.csv".format(str(SETTING), str(SEED))
    pd.DataFrame(favg_metric_dict).to_csv(os.path.join(output_path, favg_csv))

    # Ditto algorithm: per-node shared and personal models plus one "old"
    # shared model, all initialised from the pivot weights.
    NUM_STAGE_FDITTO = int(NUM_STAGES*2) # S in Ditto Algorithm

    mnist_ditto_bm = Benchmark_Ditto(num_nodes=NUM_NODES, near_center_idx=near_center_idx, far_center_idx=far_center_idx,
                                     loss_function=loss_function, train_size=TRAIN_SIZE, val_size=VAL_SIZE, test_size=TEST_SIZE)

    ditto_bm_model_shared_dict = {}
    ditto_bm_model_personal_dict = {}
    for node_idx in range(NUM_NODES):

        node_shared_model = Net(in_channels=in_channels,
                                out_channels=out_channels,
                                out_features=out_features,
                                size=SIZE)
        ditto_bm_model_shared_dict[node_idx] = node_shared_model.to(mnist_ditto_bm.device)
        ditto_bm_model_shared_dict[node_idx].load_state_dict(model_pivot.state_dict())

        node_personal_model = Net(in_channels=in_channels,
                                  out_channels=out_channels,
                                  out_features=out_features,
                                  size=SIZE)
        ditto_bm_model_personal_dict[node_idx] = node_personal_model.to(mnist_ditto_bm.device)
        ditto_bm_model_personal_dict[node_idx].load_state_dict(model_pivot.state_dict())

    # Initialize_Variables takes a single (not per-node) "old" shared model, so
    # build it once instead of rebuilding and discarding it on every node.
    ditto_bm_model_shared_old = Net(in_channels=in_channels,
                                    out_channels=out_channels,
                                    out_features=out_features,
                                    size=SIZE)
    ditto_bm_model_shared_old = ditto_bm_model_shared_old.to(mnist_ditto_bm.device)
    ditto_bm_model_shared_old.load_state_dict(model_pivot.state_dict())

    mnist_ditto_bm.Initialize_Results()
    mnist_ditto_bm.Initialize_Variables(model_shared_dict=ditto_bm_model_shared_dict, model_shared_old=ditto_bm_model_shared_old,
                                        model_personal_dict=ditto_bm_model_personal_dict, train_loaders=train_loaders,
                                        val_loaders=val_loaders, test_loaders=test_loaders, num_batches=NUM_BATCH_LB)
    for s in range(NUM_STAGE_FDITTO):
        mnist_ditto_bm.Update_Shared_and_Personal_Models_SGD(train_loaders=train_loaders, val_loaders=val_loaders, test_loaders=test_loaders,
                                                             num_epoch=NUM_EPOCH_IN, num_batches=NUM_BATCH_LB, communicate_period=CP, lr_inner=LR_INNER,
                                                             num_batches_rand=25, lambda_val=0.1)

    # Output_Results returns separate "near" and "far" metric dicts.
    fditto_metric_dict = mnist_ditto_bm.Output_Results()
    for variant in ("near", "far"):
        fditto_csv = "MNIST_Setting{}_Seed{}_Result_fditto_{}_metric_dict.csv".format(str(SETTING), str(SEED), variant)
        pd.DataFrame(fditto_metric_dict[variant]).to_csv(os.path.join(output_path, fditto_csv))
    
    # Fed-personalized with ME benchmark: per-node shared and personal models
    # initialised from the pivot, trained for twice the bi-level stage count.
    NUM_STAGE_FPME = int(NUM_STAGES*2) # S in Fed-personalized with ME algorithm

    mnist_fpme_bm = Benchmark_Fed_Pme(num_nodes=NUM_NODES,
                                      near_center_idx=near_center_idx,
                                      far_center_idx=far_center_idx,
                                      loss_function=loss_function,
                                      train_size=TRAIN_SIZE,
                                      val_size=VAL_SIZE,
                                      test_size=TEST_SIZE)

    fpme_bm_model_shared_dict = {}
    fpme_bm_model_personal_dict = {}
    for idx in range(NUM_NODES):
        # Build the shared replica first, then the personal one, for each node.
        for registry in (fpme_bm_model_shared_dict, fpme_bm_model_personal_dict):
            replica = Net(in_channels=in_channels, out_channels=out_channels,
                          out_features=out_features, size=SIZE).to(mnist_fpme_bm.device)
            replica.load_state_dict(model_pivot.state_dict())
            registry[idx] = replica

    mnist_fpme_bm.Initialize_Results()
    mnist_fpme_bm.Initialize_Variables(model_shared_dict=fpme_bm_model_shared_dict,
                                       model_personal_dict=fpme_bm_model_personal_dict,
                                       train_loaders=train_loaders,
                                       val_loaders=val_loaders,
                                       test_loaders=test_loaders,
                                       num_batches=NUM_BATCH_LB)

    for _ in range(NUM_STAGE_FPME):
        mnist_fpme_bm.Update_Shared_and_Personal_Models_SGD(val_loaders=val_loaders, test_loaders=test_loaders,
                                                            num_epoch=NUM_EPOCH_IN, num_batches=NUM_BATCH_LB,
                                                            num_steps_ub=25, delta=0.005,
                                                            communicate_period=CP, lr_inner=LR_INNER,
                                                            lambda_val=15)

    # Output_Results returns separate "near" and "far" metric dicts.
    fpme_metric_dict = mnist_fpme_bm.Output_Results()
    for variant in ("near", "far"):
        fpme_csv = "MNIST_Setting{}_Seed{}_Result_fpme_{}_metric_dict.csv".format(str(SETTING), str(SEED), variant)
        pd.DataFrame(fpme_metric_dict[variant]).to_csv(os.path.join(output_path, fpme_csv))

    # Flush rather than close stdout. Closing sys.stdout makes any later write
    # (atexit handlers, warnings, library teardown) raise
    # "ValueError: I/O operation on closed file".
    sys.stdout.flush()
    sys.exit(0)