import itertools
import json
import os

# Export the GPU selection from the shared config *before* `import torch` below: the CUDA
# runtime reads CUDA_VISIBLE_DEVICES when it initialises, so it must be set first.
with open("./json/COMMON.json", "r", encoding="utf-8") as f:
    temp_dict = json.load(f)
    # os.environ only accepts str; str() also tolerates the value being stored as a JSON number.
    os.environ["CUDA_VISIBLE_DEVICES"] = str(temp_dict['CUDA_VISIBLE_DEVICES'])
import torch

from moudle.experiment_setup import Experiment_Create_dataset, Experiment_Create_dataloader, \
    Experiment_Model_construction
from tool.logger import *
from tool.utils import check_and_make_the_path
from experiment.Experiment_FedAvg import Experiment_Federated_Average
from experiment.Experiment_FedFair import Experiment_Federated_Fair
from experiment.Experiment_LCO import Experiment_LCO
from experiment.Experiment_FedRenyi import Experiment_Federated_Renyi
from experiment.Experiment_Separate import Experiment_Separate
from experiment.Experiment_FedProx import Experiment_Federated_Proximal
from experiment.Experiment_Scaffold import Experiment_Scaffold
from experiment.Experiment_FairFed import Experiment_FairFed
from experiment.Experiment_FedBatch import Experiment_Federated_Batch
from experiment.Experiment_FedFB import Experiment_FedFB


# Hyperparameters of FL from json file
def load_fl_hyperparameter_from_json(dataset_name, json_dir="./json/"):
    """Load the FL hyperparameters for ``dataset_name``.

    ``<json_dir>/COMMON.json`` is read first, then ``<json_dir>/<dataset_name>.json`` is
    merged over it, so dataset-specific values override common ones.

    Args:
        dataset_name: base name of the dataset-specific JSON file (without ``.json``).
        json_dir: directory holding the JSON config files (default ``"./json/"``).

    Returns:
        A new dict with the merged parameters.
    """
    param_dict = {}
    for file_name in ("COMMON.json", dataset_name + ".json"):
        # JSON text is UTF-8 by specification; don't depend on the platform's default encoding.
        with open(os.path.join(json_dir, file_name), "r", encoding="utf-8") as f:
            param_dict.update(json.load(f))
    return param_dict


def Experiment(param_dict):
    """Build the data pipeline and model once, then run every matching FL algorithm.

    An algorithm runs when any of its aliases occurs as a substring of
    ``param_dict["algorithm"]``. Several algorithms can match; they run one after
    another, in the table order below, and all receive the same ``global_model`` object.
    """
    logger.info("Creating dataset")
    (training_dataset, positive_training_dataset, negative_training_dataset,
     testing_dataset, nn_input_size) = Experiment_Create_dataset(param_dict, no_pickle=False)

    logger.info("Creating dataloader")
    (training_dataloaders, positive_training_dataloaders, negative_training_dataloaders,
     validation_dataloaders, client_dataset_list, testing_dataloader) = Experiment_Create_dataloader(
        param_dict, training_dataset, positive_training_dataset, negative_training_dataset,
        testing_dataset, param_dict['split_strategy'])

    global_model = Experiment_Model_construction(param_dict, nn_input_size)
    logger.info("-----------------------------------------------------------------------------")

    # (aliases matched by substring, name shown in the log, experiment runner)
    dispatch_table = (
        (("FederatedAverage", "Average"), "Federated Average", Experiment_Federated_Average),
        (("FederatedRenyi", "Renyi"), "Federated Renyi", Experiment_Federated_Renyi),
        (("FederatedFair", "FedFair"), "Federated Fair", Experiment_Federated_Fair),
        (("LCO",), "LCO", Experiment_LCO),
        (("FederatedProximal", "FedProx"), "Federated Proximal", Experiment_Federated_Proximal),
        (("Scaffold",), "Scaffold", Experiment_Scaffold),
        (("FairFed",), "FairFed", Experiment_FairFed),  # FairFed, AAAI 2023
        (("Separate", "Sepa"), "Separate Training", Experiment_Separate),
        (("FairBatch",), "Fair Batch", Experiment_Federated_Batch),  # FairBatch, ICLR 2021
        (("FedFB",), "Federated Batch", Experiment_FedFB),
    )

    for aliases, display_name, run_algorithm in dispatch_table:
        # Re-read the algorithm name for every entry, as a runner may update param_dict.
        if not any(alias in param_dict["algorithm"] for alias in aliases):
            continue
        logger.info(f"~~~~~~ Algorithm: {display_name} ~~~~~~")
        run_algorithm(
            param_dict, global_model, training_dataloaders, training_dataset,
            client_dataset_list, testing_dataloader
        )


def _select_hyperparameter_grid(dataset_name, algorithm, hypothesis):
    """Choose the per-algorithm hyperparameter value lists to sweep.

    Args:
        dataset_name: dataset identifier, matched by substring (e.g. "ADULT", "COMPAS").
        algorithm: algorithm identifier, matched by substring (e.g. "Average", "Renyi", "FedFair").
        hypothesis: model hypothesis identifier; only tested for the substring "LR".

    Returns:
        Tuple ``(ϵ_list, lamda_list, FL_drop_rate_list, rho_list, tolerance_rate_list)``.
    """
    lamda_list = [0]
    rho_list, tolerance_rate_list = [0], [1]
    if "Average" in algorithm:  # also covers "FederatedAverage"
        FL_drop_rate_list = [0, 0.3, 0.5]
    elif "Renyi" in algorithm:  # also covers "FederatedRenyi"
        # TODO: fine-tune the λ in FedRenyi
        tuned_datasets = ("ADULT", "COMPAS", "DRUG", "DUTCH", "ARRHYTHMIA")
        if any(name in dataset_name for name in tuned_datasets):
            if "LR" in hypothesis:
                # λ = 5 was the best fairness setting found in local tests so far (noted for DRUG).
                lamda_list = [1000, 5, 1, 0.5, 0.1]
            # Non-LR hypotheses on these datasets keep λ = 0 (a [1000, 100, 10, 1] sweep is disabled).
        else:
            lamda_list = [1000, 100, 10, 1]
        FL_drop_rate_list = [0, 0.3, 0.5]
        rho_list = [0.1]  # parameter for the similarity_matrix in asynchronous algorithm
        tolerance_rate_list = [0.5, 0.75, 1]
    else:
        # Skipping the unnecessary loop
        FL_drop_rate_list = [0]

    # Hyperparameter fine-tune for the FedFair
    if ("FederatedFair" in algorithm) or ("FedFair" in algorithm):
        if "ADULT" in dataset_name:
            ϵ_list = [0.05]
        elif "COMPAS" in dataset_name:
            ϵ_list = [0.1] if "LR" in hypothesis else [0.05, 0.001, 0.1, 1]
        elif "DRUG" in dataset_name:
            ϵ_list = [0.035] if "LR" in hypothesis else [0.01, 0.00125, 0.1, 1]
        else:
            ϵ_list = [0.05]
    else:
        ϵ_list = [0]

    return ϵ_list, lamda_list, FL_drop_rate_list, rho_list, tolerance_rate_list


def _run_single_experiment(param_dict, Experiment_NO, total_Experiment_NO):
    """Prepare the log/model directories and parameter dumps for one sweep point, then run it.

    A per-run file handler is attached to ``logger`` while the experiment runs. It is
    always detached and closed afterwards, including when the experiment raises.
    """
    # Renyi runs are additionally separated by their γ_k style.
    if "Renyi" in param_dict["algorithm"]:
        algorithm_name = param_dict['algorithm'] + "_" + param_dict['γ_k_style']
    else:
        algorithm_name = param_dict['algorithm']

    # Create the log
    log_dir = os.path.join("./log_path",
                           param_dict['dataset_name'],
                           algorithm_name,
                           param_dict['hypothesis'],
                           param_dict['sensitive_attribute_skew'],
                           param_dict['split_strategy'],
                           )
    check_and_make_the_path(log_dir)
    log_path = os.path.join(log_dir, str(Experiment_NO))
    param_dict['log_path'] = log_path
    file_handler = logging.FileHandler(log_path + ".txt")
    file_handler.setFormatter(formatter)
    logger.addHandler(file_handler)
    try:
        # Create the model path. It uses algorithm_name (like the log path) so that Renyi runs
        # with different γ_k styles, which each restart Experiment_NO at 1, do not overwrite
        # each other's models and Parameter.json.
        model_path = os.path.join("./save_path/model", param_dict['dataset_name'],
                                  algorithm_name,
                                  param_dict['hypothesis'], str(Experiment_NO))
        check_and_make_the_path(model_path)
        param_dict['model_path'] = model_path
        for k in range(param_dict["num_clients_K"]):
            check_and_make_the_path(os.path.join(model_path, "client_" + str(k + 1)))

        # Record this run's number *before* dumping, so each Parameter.json carries its own
        # number rather than the previous run's.
        param_dict['Experiment_NO'] = str(Experiment_NO)

        # Create the json file
        logger.info(f"Experiment {Experiment_NO}/{total_Experiment_NO} setup finish")
        json_str = json.dumps(param_dict, indent=4)
        with open(log_path + "_Parameter.json", "w") as json_file:
            json_file.write(json_str)
        with open(os.path.join(model_path, "Parameter.json"), "w") as json_file:
            json_file.write(json_str)

        # Parameter announcement (keys containing "_common" are not logged)
        logger.info("Parameter announcement")
        for para_key, para_value in param_dict.items():
            if "_common" in para_key:
                continue
            logger.info(f"****** {para_key} : {para_value} ******")
        logger.info(
            "-----------------------------------------------------------------------------")

        # Experiment
        Experiment(param_dict)
    finally:
        logger.removeHandler(file_handler)
        file_handler.close()  # release the log file; removeHandler alone leaves it open


def main(dataset_name, algorithm, hypothesis, γ_k_style, device):
    """Run the hyperparameter sweep for one (dataset, algorithm, hypothesis) combination.

    Args:
        dataset_name: dataset config name; ``./json/<dataset_name>.json`` is merged over
            ``./json/COMMON.json``.
        algorithm: algorithm identifier, dispatched by substring in ``Experiment``.
        hypothesis: model hypothesis identifier (e.g. "LR").
        γ_k_style: stored as ``param_dict['γ_k_style']``; also appended to the log/model
            directory names of Renyi runs.
        device: if it contains "gpu", CUDA is used when available; otherwise CPU.
    """
    param_dict = load_fl_hyperparameter_from_json(dataset_name)

    # NOTE(review): torch is already imported at module load and main() is called several times
    # per process, so re-exporting this only has an effect while CUDA is not yet initialised.
    # str() tolerates the value being stored as a JSON number.
    os.environ["CUDA_VISIBLE_DEVICES"] = str(param_dict['CUDA_VISIBLE_DEVICES'])

    # Get cpu or gpu device for experiment
    if "gpu" in device:
        param_dict['device'] = "cuda" if torch.cuda.is_available() else "cpu"
    else:
        param_dict['device'] = "cpu"

    # step_T_communication_I_list = [(500, 50), (500, 20), (500, 10), (2500, 250)]
    # step_T_communication_I_list = [(100, 10), (100, 4), (100, 2), (500, 50)]
    step_T_communication_I_list = [(100, 4)]

    # split_strategy_list = ["Dirichlet0.1", "Dirichlet0.2", "Dirichlet0.5", "Dirichlet1", "Dirichlet8", "Dirichlet64", "Uniform"]
    split_strategy_list = ["Dirichlet0.5"]

    # sensitive_attribute_skew_list = ["no_attribute_skew", "positive_0.1", "positive_0.5"]
    sensitive_attribute_skew_list = ["no_attribute_skew"]

    param_dict['γ_k_style'] = γ_k_style
    param_dict['dataset_name'] = dataset_name
    param_dict['algorithm'] = algorithm
    param_dict['hypothesis'] = hypothesis
    param_dict["save_checkpoint_rounds"] = 99999999999  # presumably "never checkpoint" — confirm

    ϵ_list, lamda_list, FL_drop_rate_list, rho_list, tolerance_rate_list = \
        _select_hyperparameter_grid(dataset_name, algorithm, hypothesis)

    # Forcibly shrink the sweep: these override the lists selected above
    # (presumably a temporary setting; remove to run the full grid).
    FL_drop_rate_list = [0]
    tolerance_rate_list = [1]
    lamda_list = [1]

    # Cartesian sweep; the last list varies fastest, i.e. the nesting order is
    # split_strategy > skew > ϵ > λ > (T, I) > ρ > drop rate > tolerance rate.
    sweep = itertools.product(split_strategy_list, sensitive_attribute_skew_list, ϵ_list, lamda_list,
                              step_T_communication_I_list, rho_list, FL_drop_rate_list,
                              tolerance_rate_list)
    # For Renyi with FL_drop_rate == 0, only tolerance_rate == 1 is run. Filtering up front
    # (instead of `break`-ing at the first tolerance_rate != 1) keeps that run regardless of list
    # order and makes the reported total exact. run[-2] is the drop rate, run[-1] the tolerance rate.
    is_renyi = "Renyi" in algorithm
    runs = [run for run in sweep if not (is_renyi and run[-2] == 0 and run[-1] != 1)]
    total_Experiment_NO = len(runs)

    # Main Loop
    for Experiment_NO, run in enumerate(runs, start=1):
        (split_strategy, sensitive_attribute_skew, ϵ, lamda,
         (algorithm_step_T, communication_round_I), rho, FL_drop_rate, tolerance_rate) = run
        param_dict['split_strategy'] = split_strategy
        param_dict['sensitive_attribute_skew'] = sensitive_attribute_skew
        param_dict['epsilon'] = ϵ
        param_dict['lamda'] = lamda
        param_dict['algorithm_step_T'] = algorithm_step_T
        param_dict['communication_round_I'] = communication_round_I
        param_dict['rho'] = rho
        param_dict['FL_drop_rate'] = FL_drop_rate
        param_dict['tolerance_τ'] = int(tolerance_rate * algorithm_step_T)

        _run_single_experiment(param_dict, Experiment_NO, total_Experiment_NO)

        logger.info(
            "|||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||")
        logger.info(
            "|||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||")


if __name__ == '__main__':
    # Command-line usage example (argv is currently ignored by this batch script):
    #   nohup python main.py COMPAS FederatedRenyi LR uniform_distribution gpu &
    #   main(sys.argv[1], sys.argv[2], sys.argv[3], sys.argv[4], sys.argv[5])
    #
    # Other single runs that have been used:
    #   main("DRUG", "FedFair", "LR", "__", "gpu")
    #   main("DRUG", "FederatedAverage", "LR", "__", "gpu")
    #   main("DRUG", "Separate", "LR", "__", "gpu")
    #   main("DRUG", "FedProx", "LR", "__", "gpu")
    #   main("DRUG", "LCO", "LR", "_", "gpu")

    # FederatedRenyi with the LR hypothesis on every dataset: first all datasets with the
    # "uniform_distribution" γ_k style, then all datasets with "unifor_client".
    # NOTE(review): "unifor_client" looks like a typo of "uniform_client"; it is passed through
    # verbatim, so confirm what the FedRenyi implementation expects before changing it.
    for style in ("uniform_distribution", "unifor_client"):
        for dataset in ("ADULT", "ARRHYTHMIA", "COMPAS", "DRUG", "DUTCH"):
            main(dataset, "FederatedRenyi", "LR", style, "gpu")

    # Datasets intentionally left out:
    # - BANK (LR): under FedAvg, FR is always 1 and EOD always 0, and the more Dirichlet-skewed
    #   the split, the closer SPD gets to 0 (i.e. fairer), so BANK is not considered for now.
    # - GERMAN: FedAvg already scores very high on fairness, leaving nothing to improve.
