import os
import json
import statistics
from tool.utils import communication_cost_simulated_by_beta_distribution

# Read the GPU mask from ./json/COMMON.json (relative to the current working directory)
# and export it as CUDA_VISIBLE_DEVICES. This deliberately runs before `import torch`
# below, presumably so the device mask is in place before torch initialises CUDA.
with open("./json/COMMON.json", "r") as f:
    temp_dict = json.load(f)
    os.environ["CUDA_VISIBLE_DEVICES"] = temp_dict['CUDA_VISIBLE_DEVICES']
import torch

from moudle.experiment_setup import Experiment_Create_dataset, Experiment_Create_dataloader, \
    Experiment_Model_construction
from tool.logger import *
from tool.utils import check_and_make_the_path
from algorithm.FederatedProximal.FederatedProximal_LR import Fed_Prox_LR

# Hyperparameters of FL from json file
def load_fl_hyperparameter_from_json(dataset_name):
    """Return the merged FL hyper-parameter dict for ``dataset_name``.

    ``./json/COMMON.json`` is read first, then ``./json/<dataset_name>.json``
    (both relative to the current working directory). Keys from the
    dataset-specific file override keys of the same name from the common file.
    """
    merged = {}
    # Later files win, so the common settings must be loaded first.
    for stem in ("COMMON", dataset_name):
        with open(os.path.join("./json/", stem + ".json"), "r") as handle:
            merged.update(**json.load(handle))
    return merged


def Experiment_Federated_Proximal(param_dict, global_model, training_dataloaders,
                                  training_dataset, client_dataset_list, testing_dataloader):
    """Train ``global_model`` with FedProx, then log its accuracy/fairness on the test set.

    Args:
        param_dict: experiment configuration. Read keys: 'device', 'algorithm_step_T',
            'num_clients_K', 'communication_round_I', 'FL_fraction', 'FL_drop_rate',
            'local_step_size', 'mask_s1_flag', 'hypothesis'.
        global_model: initial global model handed to ``Fed_Prox_LR``.
        training_dataloaders, training_dataset, client_dataset_list: client training
            data, forwarded unchanged to ``Fed_Prox_LR``.
        testing_dataloader: test batches passed to ``Experiment_Model_testing``.
    """
    run_device = param_dict['device']

    # Positional FedProx hyper-parameters, in the order Fed_Prox_LR expects them.
    fedprox_hyperparameters = (
        param_dict['algorithm_step_T'],
        param_dict['num_clients_K'],
        param_dict['communication_round_I'],
        param_dict['FL_fraction'],
        param_dict['FL_drop_rate'],
        param_dict['local_step_size'],
    )
    trained_model, _client_models = Fed_Prox_LR(
        run_device, global_model, *fedprox_hyperparameters,
        training_dataloaders, training_dataset, client_dataset_list,
    )

    logger.info("-----------------------------------------------------------------------------")
    logger.info("Global model testing")
    accuracy, _deo, fairness, harmonic = Experiment_Model_testing(
        device=run_device,
        testing_dataloader=testing_dataloader,
        mask_s1_flag=param_dict['mask_s1_flag'],
        testing_model=trained_model,
        hypothesis=param_dict['hypothesis'],
        only_acc=False,
    )

    for label, value in (("global_acc", accuracy), ("FR", fairness), ("HM", harmonic)):
        logger.info(f" ****** {label}: {round(float(value), 2)} ******")


def Experiment_Model_testing(device, testing_dataloader, mask_s1_flag, testing_model, hypothesis,
                             only_acc=False):
    """Evaluate ``testing_model``: accuracy and, unless ``only_acc``, Equality-of-Opportunity fairness.

    Every batch is indexed with the keys "X" and "y", plus "s1"/"s2" when fairness is
    computed. For the "LR" hypothesis the model output is thresholded at 0.5; for any
    other hypothesis the prediction is the argmax over dim 1 of the model output.

    Fairness is measured only over samples with y == 1:
        x1 = P(pred = 1 | y = 1, s = 1),   x2 = P(pred = 1 | y = 1, s = 0)
        DEO = |x1 - x2|,   FR = 1 - DEO,   HM = harmonic_mean(acc, FR)
    where s is batch["s2"] if ``mask_s1_flag`` is truthy, else batch["s1"].

    The forward passes run under ``torch.no_grad()``.
    NOTE(review): the model's train/eval mode is left untouched; call ``.eval()``
    beforehand if the model contains dropout or batch-norm layers.

    Args:
        device: target passed to ``Tensor.to`` for inputs, outputs and labels.
        testing_dataloader: iterable of batch mappings (see above).
        mask_s1_flag: use "s2" instead of "s1" as the sensitive attribute when truthy.
        testing_model: callable mapping a batch of ``X`` to model outputs.
        hypothesis: "LR" for thresholded binary output, anything else for argmax.
        only_acc: if true, skip the fairness metrics.

    Returns:
        ``acc`` (0-d tensor) when ``only_acc`` is true; otherwise ``(acc, DEO, FR, HM)``
        with DEO/FR 0-d tensors and HM a float. A sensitive group with no y == 1
        samples gives a 0/0 (NaN) rate, which propagates into DEO and FR.

    Raises:
        ZeroDivisionError: if ``testing_dataloader`` yields no samples.
    """
    acc_numerator = 0
    acc_denominator = 0

    num_s1_pred1 = 0
    num_s1_pred0 = 0
    num_s0_pred1 = 0
    num_s0_pred0 = 0

    # Pure evaluation: do not build an autograd graph for every test batch.
    with torch.no_grad():
        for batch in testing_dataloader:
            X = batch["X"].to(device)
            # Flatten so an (N, 1) label tensor cannot broadcast against the (N,)
            # predictions into an (N, N) comparison.
            y = batch["y"].to(device).reshape(-1)
            output = testing_model(X).to(device)
            if hypothesis == "LR":
                prediction = (output >= 0.5).reshape(-1)
            else:
                prediction = output.argmax(dim=1)
            # Tensor .sum() rather than Python's builtin sum(), which iterates per element.
            acc_numerator += prediction.eq(y).sum()
            acc_denominator += X.shape[0]

            if only_acc:
                continue

            s = batch["s2"] if mask_s1_flag else batch["s1"]
            # Same flattening as y, for the same broadcasting reason.
            s = s.to(device).reshape(-1)

            y_1 = (y == 1).int()
            s_1 = (s == 1).int()
            s_0 = (s == 0).int()
            pred_1 = (prediction == 1).int()
            pred_0 = (prediction == 0).int()

            num_s1_pred1 += (y_1 * s_1 * pred_1).sum()
            num_s1_pred0 += (y_1 * s_1 * pred_0).sum()
            num_s0_pred1 += (y_1 * s_0 * pred_1).sum()
            num_s0_pred0 += (y_1 * s_0 * pred_0).sum()

    acc = acc_numerator / acc_denominator
    if only_acc:
        return acc

    # True-positive rates of the two sensitive groups.
    x1 = num_s1_pred1 / (num_s1_pred1 + num_s1_pred0)
    x2 = num_s0_pred1 / (num_s0_pred1 + num_s0_pred0)
    DEO = abs(x1 - x2)
    FR = 1 - DEO
    HM = statistics.harmonic_mean([float(acc), float(FR)])
    return acc, DEO, FR, HM


def Experiment(param_dict, communication_cost_list_list, descending_order_list_list):
    """Build datasets, dataloaders and the model for one configuration, then run FedProx.

    Args:
        param_dict: experiment configuration; 'split_strategy' selects how the training
            data is split across clients, and the whole dict is forwarded to the helpers.
        communication_cost_list_list, descending_order_list_list: accepted for interface
            compatibility; this function does not use them.
    """
    logger.info("Creating dataset")
    dataset_bundle = Experiment_Create_dataset(param_dict, no_pickle=False)
    (training_dataset, positive_training_dataset, negative_training_dataset,
     testing_dataset, nn_input_size) = dataset_bundle

    logger.info("Creating dataloader")
    split_strategy = param_dict['split_strategy']
    loader_bundle = Experiment_Create_dataloader(
        param_dict, training_dataset, positive_training_dataset, negative_training_dataset,
        testing_dataset, split_strategy)
    # Only the plain training loaders, the client split and the test loader are used below.
    (training_dataloaders, _positive_loaders, _negative_loaders,
     _validation_loaders, client_dataset_list, testing_dataloader) = loader_bundle

    global_model = Experiment_Model_construction(param_dict, nn_input_size)
    logger.info("-----------------------------------------------------------------------------")

    logger.info("~~~~~~ Algorithm: Federated Prox ~~~~~~")
    Experiment_Federated_Proximal(param_dict, global_model, training_dataloaders,
                                  training_dataset, client_dataset_list, testing_dataloader)


def _attach_experiment_log(log_path):
    """Attach a FileHandler writing to ``log_path + ".txt"`` to ``logger`` and return it."""
    file_handler = logging.FileHandler(log_path + ".txt")
    file_handler.setFormatter(formatter)
    logger.addHandler(file_handler)
    return file_handler


def _detach_experiment_log(file_handler):
    """Remove ``file_handler`` from ``logger`` and close it so file descriptors do not pile up."""
    logger.removeHandler(file_handler)
    file_handler.close()


def _write_parameter_files(param_dict, log_path, model_path):
    """Dump ``param_dict`` as indented JSON next to the log file and into the model directory."""
    json_str = json.dumps(param_dict, indent=4)
    for target in (log_path + "_Parameter.json", os.path.join(model_path, "Parameter.json")):
        with open(target, "w") as json_file:
            json_file.write(json_str)


def _announce_parameters(param_dict):
    """Log every parameter whose key does not contain "_common", then a separator line."""
    logger.info("Parameter announcement")
    for para_key, para_value in param_dict.items():
        if "_common" in para_key:
            continue
        logger.info(f"****** {para_key} : {para_value} ******")
    logger.info(
        "-----------------------------------------------------------------------------")


def main(dataset_name, algorithm, hypothesis, device):
    """Run the FedProx experiment grid for one dataset.

    Hyper-parameters come from ``./json/COMMON.json`` merged with
    ``./json/<dataset_name>.json``. One experiment is run for every combination of
    split strategy, client drop rate and (algorithm_step_T, communication_round_I) pair;
    each gets its own log file, parameter dump and model directory.

    Args:
        dataset_name: dataset name; also selects the dataset-specific JSON file.
        algorithm: algorithm label, used in the log/model directory layout.
        hypothesis: hypothesis label (e.g. "LR"), stored in ``param_dict``.
        device: any string containing "gpu" (case-insensitive, e.g. "gpu" or "GPU")
            uses CUDA when available; anything else selects the CPU.
    """
    param_dict = load_fl_hyperparameter_from_json(dataset_name)

    os.environ["CUDA_VISIBLE_DEVICES"] = param_dict['CUDA_VISIBLE_DEVICES']

    # Case-insensitive so that "GPU" selects CUDA just like "gpu".
    if "gpu" in device.lower():
        param_dict['device'] = "cuda" if torch.cuda.is_available() else "cpu"
    else:
        param_dict['device'] = "cpu"

    step_T_communication_I_list = [(100, 4)]
    FL_drop_rate_list = [0, 0.3, 0.5]
    split_strategy_list = ["Dirichlet0.5", "Uniform"]

    param_dict['dataset_name'] = dataset_name
    param_dict['algorithm'] = algorithm
    param_dict['hypothesis'] = hypothesis
    param_dict['sensitive_attribute_skew'] = "no_attribute_skew"

    # Serial number of the current experiment and the size of the whole grid.
    Experiment_NO = 1
    total_Experiment_NO = (len(split_strategy_list) * len(FL_drop_rate_list)
                           * len(step_T_communication_I_list))

    # Simulate every network condition up front so all experiments are compared under
    # the same ones; one entry per algorithm step, enough for the longest run.
    max_step_T = max(step_T for step_T, _ in step_T_communication_I_list)
    communication_cost_list_list, descending_order_list_list = [], []
    for _ in range(max_step_T):
        communication_cost_list, descending_order_list = \
            communication_cost_simulated_by_beta_distribution(param_dict['num_clients_K'])
        communication_cost_list_list.append(communication_cost_list)
        descending_order_list_list.append(descending_order_list)

    algorithm_name = param_dict['algorithm']
    for split_strategy in split_strategy_list:
        param_dict['split_strategy'] = split_strategy
        for FL_drop_rate in FL_drop_rate_list:
            param_dict['FL_drop_rate'] = FL_drop_rate
            for algorithm_step_T, communication_round_I in step_T_communication_I_list:
                param_dict['algorithm_step_T'] = algorithm_step_T
                param_dict['communication_round_I'] = communication_round_I

                # Create the per-experiment log file.
                log_dir = os.path.join("./log_path/log_for_acceleration_and_bias",
                                       param_dict['dataset_name'],
                                       algorithm_name,
                                       param_dict['hypothesis'],
                                       param_dict['split_strategy'])
                check_and_make_the_path(log_dir)
                log_path = os.path.join(log_dir, str(Experiment_NO))
                param_dict['log_path'] = log_path
                file_handler = _attach_experiment_log(log_path)
                try:
                    # Create the model directory plus one sub-directory per client.
                    model_path = os.path.join("./save_path/model", param_dict['dataset_name'],
                                              param_dict['algorithm'],
                                              param_dict['hypothesis'], str(Experiment_NO))
                    check_and_make_the_path(model_path)
                    param_dict['model_path'] = model_path
                    for k in range(param_dict["num_clients_K"]):
                        check_and_make_the_path(os.path.join(model_path, "client_" + str(k + 1)))

                    # Set the serial number before dumping, so the saved parameters
                    # describe this experiment rather than the previous one.
                    param_dict['Experiment_NO'] = str(Experiment_NO)
                    logger.info(f"Experiment {Experiment_NO}/{total_Experiment_NO} setup finish")
                    _write_parameter_files(param_dict, log_path, model_path)

                    _announce_parameters(param_dict)

                    Experiment(param_dict, communication_cost_list_list, descending_order_list_list)
                finally:
                    # Detach and close even if the experiment raises, so later output
                    # does not leak into this experiment's log file.
                    _detach_experiment_log(file_handler)
                Experiment_NO += 1
                logger.info(
                    "|||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||")
                logger.info(
                    "|||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||")


if __name__ == '__main__':
    # Usage: python FedProx_time_acceleration_and_bias_analysis.py DATASET ALGORITHM HYPOTHESIS DEVICE
    # Examples:
    #   nohup python FedProx_time_acceleration_and_bias_analysis.py ADULT FederatedProximal LR gpu &
    #   nohup python FedProx_time_acceleration_and_bias_analysis.py ARRHYTHMIA FederatedProximal LR gpu &
    #   nohup python FedProx_time_acceleration_and_bias_analysis.py COMPAS FederatedProximal LR gpu &
    #   nohup python FedProx_time_acceleration_and_bias_analysis.py DUTCH FederatedProximal LR gpu &
    #   nohup python FedProx_time_acceleration_and_bias_analysis.py DRUG FederatedProximal LR gpu &
    # Direct call equivalent: main("COMPAS", "FederatedProximal", "LR", "gpu")
    import sys  # explicit, instead of relying on `from tool.logger import *` to provide it

    if len(sys.argv) < 5:
        sys.exit("usage: python " + sys.argv[0]
                 + " DATASET ALGORITHM HYPOTHESIS DEVICE  (e.g. ADULT FederatedProximal LR gpu)")
    main(sys.argv[1], sys.argv[2], sys.argv[3], sys.argv[4])

