import itertools
import json
import os
import statistics
import sys

from tool.utils import communication_cost_simulated_by_beta_distribution

# Export CUDA_VISIBLE_DEVICES from COMMON.json before torch is imported below
# (presumably so the restriction is in place before CUDA is initialised - TODO confirm).
with open("./json/COMMON.json", "r") as f:
    temp_dict = json.load(f)
    os.environ["CUDA_VISIBLE_DEVICES"] = temp_dict['CUDA_VISIBLE_DEVICES']
import torch

from moudle.experiment_setup import Experiment_Create_dataset, Experiment_Create_dataloader, \
    Experiment_Model_construction
from tool.logger import *
from tool.utils import check_and_make_the_path
from algorithm.FederatedRenyi.FederatedRenyi_LR_for_acceleration_and_bias_analysis import Fed_Renyi_LR


# Hyperparameters of FL from json file
def load_fl_hyperparameter_from_json(dataset_name):
    """Build the hyperparameter dict for one dataset.

    Reads ``./json/COMMON.json`` first and then ``./json/<dataset_name>.json``
    (both relative to the current working directory).  Entries of the
    dataset-specific file override entries of the common file that share a key.

    Args:
        dataset_name: Base name (without ``.json``) of the dataset-specific file.

    Returns:
        A new dict holding the merged settings.
    """

    def _read_json(json_path):
        # Parse one JSON file and hand back its top-level object.
        with open(json_path, "r") as json_file:
            return json.load(json_file)

    merged_params = dict(**_read_json("./json/COMMON.json"))
    dataset_json_path = os.path.join("./json/", dataset_name + ".json")
    merged_params.update(**_read_json(dataset_json_path))
    return merged_params


def Experiment_Federated_Renyi(param_dict, global_model, training_dataloaders, training_dataset, client_dataset_list,
                               testing_dataloader, communication_cost_list_list, descending_order_list_list):
    """Train with ``Fed_Renyi_LR`` and log the test metrics of the resulting global model.

    The training settings are read from ``param_dict`` and forwarded positionally
    to ``Fed_Renyi_LR`` together with the model, the data objects and the two
    pre-simulated lists.  The returned global model is then evaluated with
    ``Experiment_Model_testing`` and its accuracy, FR and HM are logged, each
    rounded to two decimals.

    Returns:
        None.
    """
    device = param_dict['device']
    gamma_style = param_dict['γ_k_style']

    # The order of this tuple is the positional argument order of the trainer call.
    trainer_arguments = (
        device,
        param_dict['mask_s1_flag'],
        param_dict['lamda'],
        global_model,
        param_dict['tolerance_τ'],
        param_dict['algorithm_step_T'],
        param_dict['num_clients_K'],
        param_dict['communication_round_I'],
        param_dict['local_step_size'],
        training_dataloaders,
        training_dataset,
        client_dataset_list,
        param_dict['FL_drop_rate'],
        param_dict['rho'],
        gamma_style,
        communication_cost_list_list,
        descending_order_list_list,
    )
    # The per-client models are returned as well but are not used here.
    updated_global_model, _client_models = Fed_Renyi_LR(*trainer_arguments)
    logger.info("-----------------------------------------------------------------------------")

    logger.info("Global model testing")
    testing_arguments = {
        "device": device,
        "testing_dataloader": testing_dataloader,
        "mask_s1_flag": param_dict['mask_s1_flag'],
        "testing_model": updated_global_model,
        "hypothesis": param_dict['hypothesis'],
        "only_acc": False,
    }
    global_acc, _deo, FR, HM = Experiment_Model_testing(**testing_arguments)

    # DEO is computed by the test routine but intentionally not part of the report.
    for metric_name, metric_value in (("global_acc", global_acc), ("FR", FR), ("HM", HM)):
        logger.info(f" ****** {metric_name}: {round(float(metric_value), 2)} ******")


def Experiment_Model_testing(device, testing_dataloader, mask_s1_flag, testing_model, hypothesis,
                             only_acc=False):
    """Evaluate ``testing_model`` on ``testing_dataloader``.

    Every batch must be a mapping with the entries ``"X"`` and ``"y"`` and,
    unless ``only_acc`` is set, ``"s1"`` / ``"s2"`` (binary sensitive attribute).
    Labels and sensitive attributes are flattened before they are compared with
    the flat prediction vector, so both ``(N,)`` and ``(N, 1)`` layouts work.

    Args:
        device: Device the batch tensors and predictions are moved to.
        testing_dataloader: Iterable of batches as described above.
        mask_s1_flag: If true the sensitive attribute is read from ``"s2"``,
            otherwise from ``"s1"``.
        testing_model: Callable mapping ``X`` to scores.
        hypothesis: ``"LR"`` thresholds the model output at 0.5; any other value
            takes the arg-max over dimension 1.
        only_acc: If true only the accuracy is computed and returned.

    Returns:
        ``acc`` if ``only_acc`` is true, otherwise ``(acc, DEO, FR, HM)`` where
        ``DEO = |P(pred=1 | y=1, s=1) - P(pred=1 | y=1, s=0)|``, ``FR = 1 - DEO``
        and ``HM`` is the harmonic mean of ``acc`` and ``FR`` (a Python float).
        ``acc``, ``DEO`` and ``FR`` are 0-dim tensors.

    Raises:
        ZeroDivisionError: If ``testing_dataloader`` yields no batch.

    Note:
        If one sensitive group has no sample with ``y == 1`` its rate is 0/0, so
        the fairness metrics are not meaningful (presumably NaN) in that case.
    """
    acc_numerator = 0
    acc_denominator = 0

    # Counts over the positive-label samples (y == 1), split by sensitive group and prediction.
    num_s1_pred1 = 0
    num_s1_pred0 = 0
    num_s0_pred1 = 0
    num_s0_pred0 = 0

    # Pure evaluation: no autograd graph is needed.
    with torch.no_grad():
        for batch in testing_dataloader:
            X = batch["X"].to(device)
            # Flatten so that a column-shaped label tensor cannot broadcast against the prediction.
            y = batch["y"].to(device).reshape(-1)
            if hypothesis == "LR":
                prediction = (testing_model(X) >= 0.5).to(device).reshape(-1)
            else:
                prediction = testing_model(X).to(device).argmax(dim=1)
            acc_numerator += prediction.eq(y).sum()
            acc_denominator += X.shape[0]

            if only_acc:
                continue

            s = (batch["s2"] if mask_s1_flag else batch["s1"]).to(device).reshape(-1)

            y_1 = (y == 1).int()
            s_1 = (s == 1).int()
            s_0 = (s == 0).int()
            pred_1 = (prediction == 1).int()
            pred_0 = (prediction == 0).int()

            num_s1_pred1 += (y_1 * s_1 * pred_1).sum()
            num_s1_pred0 += (y_1 * s_1 * pred_0).sum()
            num_s0_pred1 += (y_1 * s_0 * pred_1).sum()
            num_s0_pred0 += (y_1 * s_0 * pred_0).sum()

    acc = acc_numerator / acc_denominator
    if only_acc:
        return acc

    # Positive-prediction rate among positive-label samples, per sensitive group.
    x1 = num_s1_pred1 / (num_s1_pred1 + num_s1_pred0)
    x2 = num_s0_pred1 / (num_s0_pred1 + num_s0_pred0)
    # Difference of Equality of Opportunity violation.
    DEO = abs(x1 - x2)
    # Fairness measurement.
    FR = 1 - DEO
    # Harmonic mean of fairness and accuracy.
    HM = statistics.harmonic_mean([float(acc), float(FR)])

    return acc, DEO, FR, HM


def Experiment(param_dict, communication_cost_list_list, descending_order_list_list):
    """Run one experiment: build data and model, then launch the selected algorithm.

    The datasets, dataloaders and the global model are created through the
    ``Experiment_Create_dataset`` / ``Experiment_Create_dataloader`` /
    ``Experiment_Model_construction`` helpers.  Federated Renyi training is
    started only when ``param_dict["algorithm"]`` contains ``"FederatedRenyi"``
    or ``"Renyi"``; otherwise nothing further happens.

    Returns:
        None.
    """
    # Datasets
    logger.info("Creating dataset")
    dataset_bundle = Experiment_Create_dataset(param_dict, no_pickle=False)
    (training_dataset, positive_training_dataset, negative_training_dataset,
     testing_dataset, nn_input_size) = dataset_bundle

    # Dataloaders (only some of the returned objects are needed below)
    logger.info("Creating dataloader")
    dataloader_bundle = Experiment_Create_dataloader(
        param_dict, training_dataset, positive_training_dataset, negative_training_dataset, testing_dataset,
        param_dict['split_strategy'])
    (training_dataloaders, _positive_dataloaders, _negative_dataloaders,
     _validation_dataloaders, client_dataset_list, testing_dataloader) = dataloader_bundle

    # Global model
    global_model = Experiment_Model_construction(param_dict, nn_input_size)
    logger.info("-----------------------------------------------------------------------------")

    # Federated Renyi is the only algorithm dispatched from here.
    selected_algorithm = param_dict["algorithm"]
    if not any(tag in selected_algorithm for tag in ("FederatedRenyi", "Renyi")):
        return

    logger.info("~~~~~~ Algorithm: Federated Renyi ~~~~~~")
    Experiment_Federated_Renyi(
        param_dict,
        global_model,
        training_dataloaders,
        training_dataset,
        client_dataset_list,
        testing_dataloader,
        communication_cost_list_list,
        descending_order_list_list,
    )


def main(dataset_name, algorithm, hypothesis, γ_k_style, device):
    """Run the acceleration-and-bias experiment grid for one dataset.

    Loads the hyperparameters for ``dataset_name`` and iterates over a hard-coded
    grid of (split strategy, sensitive-attribute skew, lamda, drop rate,
    (step T, communication round I), rho, tolerance rate).  Settings with
    ``FL_drop_rate == 0`` and ``tolerance_rate != 1`` are not run.  For every
    remaining setting a log file, a model directory tree and two parameter JSON
    dumps are written and ``Experiment`` is called.  Experiments are numbered
    consecutively from 1 over the settings that are actually run.

    Args:
        dataset_name: Name of the dataset; selects ``./json/<dataset_name>.json``.
        algorithm: Algorithm name stored in ``param_dict['algorithm']``.
        hypothesis: Model family stored in ``param_dict['hypothesis']``.
        γ_k_style: Value stored in ``param_dict['γ_k_style']``.
        device: A string containing ``"gpu"`` requests CUDA (falling back to CPU
            when CUDA is unavailable); anything else selects the CPU.
    """
    param_dict = load_fl_hyperparameter_from_json(dataset_name)

    os.environ["CUDA_VISIBLE_DEVICES"] = param_dict['CUDA_VISIBLE_DEVICES']

    if "gpu" in device:
        param_dict['device'] = "cuda" if torch.cuda.is_available() else "cpu"  # Get cpu or gpu device for experiment
    else:
        param_dict['device'] = "cpu"

    # Experiment grid.  Alternatives used at other times:
    #   step_T_communication_I_list = [(100, 10), (100, 4), (100, 2), (500, 50)]
    #   lamda_list = [1000, 5, 1, 0.5, 0.1]
    #   split_strategy_list = ["Uniform", "Dirichlet0.5", "Dirichlet1", "Dirichlet8"]
    step_T_communication_I_list = [(100, 4)]
    lamda_list = [1]
    FL_drop_rate_list = [0, 0.3, 0.5]
    rho_list = [0.1]
    tolerance_rate_list = [0.5, 0.75, 1]
    split_strategy_list = ["Dirichlet0.5", "Uniform"]
    sensitive_attribute_skew_list = ["no_attribute_skew"]

    param_dict['γ_k_style'] = γ_k_style
    param_dict['dataset_name'] = dataset_name
    param_dict['algorithm'] = algorithm
    param_dict['hypothesis'] = hypothesis

    # Enumerate the settings that are really run.  itertools.product varies the last
    # list fastest, i.e. the same order as nesting the loops in the order given here.
    experiment_grid = []
    for setting in itertools.product(split_strategy_list, sensitive_attribute_skew_list, lamda_list,
                                     FL_drop_rate_list, step_T_communication_I_list, rho_list,
                                     tolerance_rate_list):
        FL_drop_rate, tolerance_rate = setting[3], setting[6]
        # Without client drop-out only the full tolerance (rate 1) is run.
        if (FL_drop_rate == 0) and (tolerance_rate != 1):
            continue
        experiment_grid.append(setting)
    # Derived from the grid so the "n/total" log line stays correct when the lists change.
    total_Experiment_NO = len(experiment_grid)

    # Simulate all network conditions up front so that every experiment is compared
    # under the same conditions.
    # NOTE(review): only the step count T of the first (T, I) pair is used here - confirm
    # this is sufficient if pairs with a larger T are added to the grid.
    communication_cost_list_list, descending_order_list_list = [], []
    for _ in range(step_T_communication_I_list[0][0]):
        communication_cost_list, descending_order_list = communication_cost_simulated_by_beta_distribution(
            param_dict['num_clients_K'])
        communication_cost_list_list.append(communication_cost_list)
        descending_order_list_list.append(descending_order_list)

    # Main loop
    for Experiment_NO, setting in enumerate(experiment_grid, start=1):
        (split_strategy, sensitive_attribute_skew, lamda, FL_drop_rate,
         (algorithm_step_T, communication_round_I), rho, tolerance_rate) = setting
        param_dict['split_strategy'] = split_strategy
        param_dict['sensitive_attribute_skew'] = sensitive_attribute_skew
        param_dict['lamda'] = lamda
        param_dict['FL_drop_rate'] = FL_drop_rate
        param_dict['algorithm_step_T'] = algorithm_step_T
        param_dict['communication_round_I'] = communication_round_I
        param_dict['rho'] = rho
        param_dict['tolerance_τ'] = int(tolerance_rate * algorithm_step_T)

        # Create the log
        algorithm_name = param_dict['algorithm'] + "_" + param_dict['γ_k_style']
        log_path = os.path.join("./log_path/log_for_acceleration_and_bias",
                                param_dict['dataset_name'],
                                algorithm_name,
                                param_dict['hypothesis'],
                                param_dict['split_strategy'],
                                )
        check_and_make_the_path(log_path)
        log_path = os.path.join(log_path, str(Experiment_NO))
        param_dict['log_path'] = log_path
        file_handler = logging.FileHandler(log_path + ".txt")
        file_handler.setFormatter(formatter)
        logger.addHandler(file_handler)
        try:
            # Create the model path
            model_path = os.path.join("./save_path/model", param_dict['dataset_name'],
                                      param_dict['algorithm'],
                                      param_dict['hypothesis'], str(Experiment_NO))
            check_and_make_the_path(model_path)
            param_dict['model_path'] = model_path
            for k in range(param_dict["num_clients_K"]):
                check_and_make_the_path(os.path.join(model_path, "client_" + str(k + 1)))

            # Create the json file.  The experiment number is stored first so that the
            # dumped parameters describe this experiment and not the previous one.
            logger.info(f"Experiment {Experiment_NO}/{total_Experiment_NO} setup finish")
            param_dict['Experiment_NO'] = str(Experiment_NO)
            json_str = json.dumps(param_dict, indent=4)
            with open(log_path + "_Parameter.json", "w") as json_file:
                json_file.write(json_str)
            with open(os.path.join(model_path, "Parameter.json"), "w") as json_file:
                json_file.write(json_str)

            # Parameter announcement
            logger.info("Parameter announcement")
            for para_key, para_value in param_dict.items():
                if "_common" in para_key:
                    continue
                logger.info(f"****** {para_key} : {para_value} ******")
            logger.info(
                "-----------------------------------------------------------------------------")

            # Experiment
            Experiment(param_dict, communication_cost_list_list, descending_order_list_list)
        finally:
            # Detach and close the per-experiment log file even when the experiment fails,
            # so later experiments do not write into it and the file handle is released.
            logger.removeHandler(file_handler)
            file_handler.close()
        logger.info(
            "|||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||")
        logger.info(
            "|||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||")


if __name__ == '__main__':
    # Example:   nohup python FedRenyi_time_acceleration_and_bias_analysis.py COMPAS FederatedRenyi LR uniform_distribution gpu &
    # Fail with a usage message instead of an IndexError when arguments are missing.
    if len(sys.argv) < 6:
        sys.exit("usage: python " + sys.argv[0] + " <dataset_name> <algorithm> <hypothesis> <γ_k_style> <device>")
    main(sys.argv[1], sys.argv[2], sys.argv[3], sys.argv[4], sys.argv[5])

    # main("DRUG", "FederatedRenyi", "LR", "uniform_distribution", "gpu")
    # main("DRUG", "FederatedRenyi", "NN", "uniform_client", "gpu")
    # main("ARRHYTHMIA", "FederatedRenyi", "NN", "uniform_client", "gpu")

    # main("DRUG", "FederatedRenyi", "LR", "uniform_distribution", "gpu")
    # main("ARRHYTHMIA", "FederatedRenyi", "NN", "uniform_distribution", "gpu")
    # main("DRUG", "FederatedRenyi", "NN", "uniform_distribution", "gpu")

