import os
import json


# Configure CUDA_VISIBLE_DEVICES from ./json/COMMON.json *before* `import torch` below.
# NOTE(review): presumably this ordering is deliberate so the GPU mask is in place before
# torch touches CUDA -- keep this block above the torch import.
# The path is relative, so it is resolved against the current working directory.
# The JSON value must be a string, because os.environ only accepts str values.
with open("./json/COMMON.json", "r") as f:
    temp_dict = json.load(f)
    os.environ["CUDA_VISIBLE_DEVICES"] = temp_dict['CUDA_VISIBLE_DEVICES']
import torch

from moudle.experiment_setup import Experiment_Create_dataset, Experiment_Create_dataloader, \
    Experiment_Model_construction
from tool.logger import *
from tool.utils import check_and_make_the_path
from moudle.model_testing import Experiment_Model_testing
from algorithm.FederatedRenyi.FederatedRenyi_component import *

# Hyperparameters of FL from json file
def load_fl_hyperparameter_from_json(dataset_name):
    """Merge the shared and the dataset-specific hyperparameter JSON files.

    ``./json/COMMON.json`` is read first and ``./json/<dataset_name>.json``
    second, so a key present in both takes the dataset file's value.
    Both paths are relative to the current working directory.

    Args:
        dataset_name: Base name (without ``.json``) of the dataset config file.

    Returns:
        dict: The merged hyperparameters.
    """
    merged = {}
    config_paths = ("./json/COMMON.json", os.path.join("./json/", dataset_name + ".json"))
    for config_path in config_paths:
        with open(config_path, "r") as config_file:
            merged.update(**json.load(config_file))
    return merged


def _local_full_batch_step(model, client_dataloader, device, mask_s1_flag, criterion,
                           local_step_size, client_dataset_size, lamda, global_v):
    """Apply one full-batch SGD step (BCE loss + lamda * G regulariser) to ``model`` in place.

    The BCE loss of every batch (normalised by ``client_dataset_size``) is back-propagated
    so that gradients accumulate over the whole client dataset. One extra forward pass
    then adds the regularisation term, and a single SGD step is taken.

    Returns:
        float: sum of the per-batch BCE losses plus the loss of the regularisation pass,
        i.e. the value reported in the training log.
    """
    model.train()
    optimizer = torch.optim.SGD(model.parameters(), lr=local_step_size)
    # The local model persists across the local steps of a communication round, so
    # gradients left over from the previous step must be cleared; otherwise every
    # earlier step's gradient would be applied again.
    optimizer.zero_grad()

    step_loss = 0.0
    all_y_hat_θ = []
    all_s = []
    for batch in client_dataloader:
        X = batch["X"].to(device)
        y = batch["y"].reshape(-1, 1).to(device)
        s = batch["s2"] if mask_s1_flag else batch["s1"]
        local_prediction = model(X).to(device)
        # Cross-entropy loss, normalised by the client's dataset size.
        loss = criterion(local_prediction, y.float()).sum() / client_dataset_size
        # Thresholded (hard) predictions feed the fairness statistics below.
        all_y_hat_θ.append((local_prediction >= 0.5).reshape(-1).to(device))
        all_s.append(s)
        step_loss += float(loss)
        loss.backward()

    all_y_hat_θ = torch.concat(all_y_hat_θ)
    all_s = torch.concat(all_s)

    # Build a temporary computation graph on a single batch (its BCE term is zeroed out)
    # to add the regularisation term lamda * G.
    # NOTE(review): Q is built from thresholded predictions (``>= 0.5`` is not
    # differentiable), so unless get_Q_hat_θ / get_G_hat_θ_hat_v create a differentiable
    # path, this backward contributes no regularisation gradient -- confirm.
    for batch in client_dataloader:
        X = batch["X"].to(device)
        y = batch["y"].reshape(-1, 1).to(device)
        local_prediction = model(X).to(device)
        loss = criterion(local_prediction, y.float()).sum() * 0
        Q, _, _, _, _, _, _ = get_Q_hat_θ(all_y_hat_θ, all_s, device)
        G = get_G_hat_θ_hat_v(Q, global_v, device).to(device)
        loss = loss + lamda * G
        step_loss += float(loss)
        loss.backward()
        break

    optimizer.step()
    return step_loss


def _aggregate_into_global_model(global_model, local_model_list, γ_k_list, num_clients_K):
    """Overwrite ``global_model`` with the γ_k-weighted sum of the client model parameters."""
    weighted_client_params = [
        # Scale each parameter array on its own. Stacking them first with np.array()
        # (without dtype=object) raises on NumPy >= 1.24 as soon as the parameter shapes
        # differ (presumably one array per parameter tensor, e.g. weight vs. bias).
        [float(γ_k_list[k]) * np.asarray(param) for param in get_parameters(local_model_list[k])]
        for k in range(num_clients_K)
    ]
    theta_avg = np.sum(np.array(weighted_client_params, dtype=object), 0).tolist()
    set_parameters(global_model, theta_avg)


def Fed_Renyi_LR(device,
                 mask_s1_flag,
                 lamda,
                 global_model,
                 tolerance_τ,
                 algorithm_step_T, num_clients_K, communication_round_I,
                 local_step_size,
                 training_dataloaders,
                 training_dataset,
                 testing_dataloader,
                 client_dataset_list,
                 straggler_rate_α,
                 rho,
                 γ_k_style
                 ):
    """Run FedRenyi training for the logistic-regression ("LR") hypothesis.

    Every step, each client (simulated sequentially) takes one full-batch SGD step on
    BCE loss plus ``lamda`` times the Renyi regulariser G. After every
    ``communication_round_I``-th step:
      * the global model is set to the γ_k-weighted sum of the client parameters;
      * the global model is evaluated on ``testing_dataloader``;
      * the global ``v`` is re-estimated with ``get_argmax_v``; on failure the previous
        value is kept and the error is logged;
      * every client receives a deep copy of the global model.

    Args:
        device: torch device string used for all tensors.
        mask_s1_flag: if true, the sensitive attribute is ``batch["s2"]``, else ``batch["s1"]``.
        lamda: weight of the regularisation term.
        global_model: model that is aggregated into (modified in place).
        tolerance_τ: number of local steps actually run.
        algorithm_step_T: nominal step count, only used in log messages.
        num_clients_K: number of clients.
        communication_round_I: communicate every this many steps.
        local_step_size: SGD learning rate.
        training_dataloaders: one dataloader per client.
        training_dataset: full training set (forwarded to initialization / get_argmax_v).
        testing_dataloader: dataloader used for global evaluation.
        client_dataset_list: per-client datasets.
        straggler_rate_α: accepted for interface compatibility; not used here.
        rho: accepted for interface compatibility; not used here.
        γ_k_style: client-weighting style forwarded to initialization / get_argmax_v.

    Returns:
        tuple: ``(global_model, local_model_list)`` after the last step.
    """
    logger.info("Initialization")

    client_datasets_size_list, local_model_list, \
    global_v, r_bar_k_p0_list, r_bar_k_p1_list, \
    γ_k_list, r_hat_p0, r_hat_p1, v_hat_1 = initialization(client_dataset_list, global_model, num_clients_K,
                                                           mask_s1_flag, training_dataset, γ_k_style)

    criterion = torch.nn.BCELoss(reduction='none')

    for iter_t in range(tolerance_τ):
        # Clients are simulated one after another.
        avg_loss_over_client = 0
        for i in range(num_clients_K):
            client_loss = _local_full_batch_step(local_model_list[i], training_dataloaders[i], device,
                                                 mask_s1_flag, criterion, local_step_size,
                                                 client_datasets_size_list[i], lamda, global_v)
            # Record the γ_k-weighted average loss over clients.
            avg_loss_over_client += client_loss * γ_k_list[i]

        logger.info(f"########## Step: {iter_t + 1} / {algorithm_step_T}; "
                    f"Avg Loss over Client: {round(float(avg_loss_over_client), 4)}; ##########")

        # Communicate
        if (iter_t + 1) % communication_round_I == 0:
            logger.info(f"********** Communicate: {(iter_t + 1) / communication_round_I} **********")

            # Global operation
            logger.info("********** Parameter aggregation **********")
            _aggregate_into_global_model(global_model, local_model_list, γ_k_list, num_clients_K)

            # Global testing
            logger.info("********** Global testing **********")
            logger.info("Global model testing")
            global_acc, DEO, EOD, SPD, FR, HM = Experiment_Model_testing(device=device,
                                                                         testing_dataloader=testing_dataloader,
                                                                         mask_s1_flag=mask_s1_flag,
                                                                         testing_model=global_model,
                                                                         hypothesis="LR",
                                                                         only_acc=False
                                                                         )
            global_acc, DEO, EOD, SPD, FR, HM = (round(float(metric), 4)
                                                 for metric in (global_acc, DEO, EOD, SPD, FR, HM))
            logger.info(f" Global Acc:{global_acc}, FR:{FR}, HM:{HM}, DEO:{DEO}, EOD:{EOD}, SPD:{SPD}")

            logger.info("********** Global v update **********")
            try:
                global_v = get_argmax_v(list(range(num_clients_K)), local_model_list, mask_s1_flag,
                                        training_dataset, client_dataset_list,
                                        r_hat_p0, r_hat_p1, device, γ_k_style, hypothesis="LR")
            except Exception:
                # Best effort: keep the previous global_v, but do not hide the failure.
                logger.exception("Global v update failed; keeping the previous global_v")

            # Parameter Distribution
            logger.info("********** Parameter distribution **********")
            local_model_list = [copy.deepcopy(global_model) for _ in range(num_clients_K)]

    return global_model, local_model_list


def Experiment_Federated_Renyi(param_dict, global_model, training_dataloaders, training_dataset, client_dataset_list,
                               testing_dataloader):
    """Train with FedRenyi (``Fed_Renyi_LR``), evaluate the result and save the metrics.

    Args:
        param_dict: experiment configuration. Reads 'device', 'γ_k_style', 'mask_s1_flag',
            'lamda', 'tolerance_τ', 'algorithm_step_T', 'num_clients_K',
            'communication_round_I', 'local_step_size', 'FL_drop_rate', 'rho',
            'hypothesis' and 'log_path'.
        global_model: initial global model.
        training_dataloaders: one training dataloader per client.
        training_dataset: full training set; its length normalises the client weights.
        client_dataset_list: per-client datasets (their ``.indices`` sizes are used).
        testing_dataloader: dataloader used for every evaluation.

    Side effects:
        Logs the metrics and writes them to ``param_dict['log_path'] + "_result.json"``.
    """
    device = param_dict['device']
    num_clients_K = int(param_dict['num_clients_K'])

    updated_global_model, client_model_list = Fed_Renyi_LR(
        device,
        param_dict['mask_s1_flag'],
        param_dict['lamda'],
        global_model,
        param_dict['tolerance_τ'],
        param_dict['algorithm_step_T'],
        param_dict['num_clients_K'],
        param_dict['communication_round_I'],
        param_dict['local_step_size'],
        training_dataloaders,
        training_dataset,
        testing_dataloader,
        client_dataset_list,
        param_dict['FL_drop_rate'],
        param_dict['rho'],
        param_dict['γ_k_style']
    )
    logger.info("-----------------------------------------------------------------------------")
    # Model checkpointing (global and client models) is intentionally not performed here.

    logger.info("Global model testing")
    global_acc, DEO, EOD, SPD, FR, HM = Experiment_Model_testing(device=device,
                                                                 testing_dataloader=testing_dataloader,
                                                                 mask_s1_flag=param_dict['mask_s1_flag'],
                                                                 testing_model=updated_global_model,
                                                                 hypothesis=param_dict['hypothesis'],
                                                                 only_acc=False
                                                                 )
    logger.info("Client models testing")
    # Two ways of averaging client accuracies: equal weight per client, and weight
    # proportional to each client's share of the training data.
    uniform_client_weight = 1 / num_clients_K
    training_dataset_size = len(training_dataset)
    uniform_distribution_weight = torch.tensor(
        [len(client_dataset_list[k].indices) / training_dataset_size for k in range(num_clients_K)]
    ).to(global_acc.device)

    # Fed_Renyi_LR runs ``tolerance_τ`` steps (not ``algorithm_step_T``) and redistributes
    # the global model after every ``communication_round_I``-th step. If its last step was
    # a communication step, every client model is a copy of the global model, so the
    # global accuracy can be reused instead of re-testing each client.
    if param_dict['tolerance_τ'] % param_dict['communication_round_I'] == 0:
        client_acc_list = [global_acc] * num_clients_K
    else:
        client_acc_list = [
            Experiment_Model_testing(device=device,
                                     testing_dataloader=testing_dataloader,
                                     mask_s1_flag=param_dict['mask_s1_flag'],
                                     testing_model=client_model,
                                     hypothesis=param_dict['hypothesis'],
                                     only_acc=True
                                     )
            for client_model in client_model_list
        ]
    client_acc_list = torch.tensor(client_acc_list).to(global_acc.device)

    uniform_client_acc = sum(uniform_client_weight * client_acc_list)
    uniform_distribution_acc = sum(uniform_distribution_weight * client_acc_list)

    logger.info(f" ****** global_acc: {global_acc} ******")
    logger.info(f" ****** DEO: {DEO} ******")
    logger.info(f" ****** EOD: {EOD} ******")
    logger.info(f" ****** SPD: {SPD} ******")
    logger.info(f" ****** FR: {FR} ******")
    logger.info(f" ****** HM: {HM} ******")
    logger.info(f" ****** uniform_client_acc: {uniform_client_acc} ******")
    logger.info(f" ****** uniform_distribution_acc: {uniform_distribution_acc} ******")

    result_dict = {
        "global_acc": float(global_acc),
        "DEO": float(DEO),
        "EOD": float(EOD),
        "SPD": float(SPD),
        "FR": float(FR),
        "HM": float(HM),
        "uniform_client_acc": float(uniform_client_acc),
        "uniform_distribution_acc": float(uniform_distribution_acc)
    }
    with open(param_dict['log_path'] + "_result.json", "w") as json_file:
        json.dump(result_dict, json_file, indent=4)



def Experiment(param_dict):
    """Build datasets, dataloaders and the model described by ``param_dict``, then run FedRenyi.

    Args:
        param_dict: experiment configuration (also read by the project's setup helpers);
            'split_strategy' selects how data is split across clients.
    """
    logger.info("Creating dataset")
    (train_set, positive_train_set, negative_train_set,
     test_set, input_size) = Experiment_Create_dataset(param_dict, no_pickle=False)

    logger.info("Creating dataloader")
    loaders = Experiment_Create_dataloader(
        param_dict, train_set, positive_train_set, negative_train_set, test_set,
        param_dict['split_strategy'])
    (train_loaders, _positive_loaders, _negative_loaders,
     _validation_loaders, client_subsets, test_loader) = loaders

    model = Experiment_Model_construction(param_dict, input_size)
    logger.info("-----------------------------------------------------------------------------")

    # Federated Renyi training + evaluation
    Experiment_Federated_Renyi(param_dict, model, train_loaders,
                               train_set, client_subsets, test_loader)


def _run_single_experiment(param_dict, experiment_no, total_experiment_no):
    """Prepare output paths and logging for one grid point, then run ``Experiment``.

    Creates the log and model directories and sets ``param_dict['log_path']``,
    ``param_dict['model_path']`` and ``param_dict['Experiment_NO']``. It then writes
    ``<log_path>_Parameter.json`` and ``<model_path>/Parameter.json``. A file handler
    writing to ``<log_path>.txt`` is attached to ``logger`` for the duration of the run.
    It is always detached and closed afterwards, even if the experiment raises.
    """
    # Create the log. "FederatedRenyi" contains "Renyi", so one substring test covers both cases.
    if "Renyi" in param_dict["algorithm"]:
        algorithm_name = param_dict['algorithm'] + "_" + param_dict['γ_k_style']
    else:
        algorithm_name = param_dict['algorithm']

    log_dir = os.path.join("./log_path/convergence_analysis",
                           param_dict['dataset_name'],
                           algorithm_name,
                           param_dict['hypothesis'],
                           param_dict['sensitive_attribute_skew'],
                           param_dict['split_strategy'],
                           )
    check_and_make_the_path(log_dir)
    log_path = os.path.join(log_dir, str(experiment_no))
    param_dict['log_path'] = log_path
    file_handler = logging.FileHandler(log_path + ".txt")
    file_handler.setFormatter(formatter)
    logger.addHandler(file_handler)
    try:
        # Create the model path, with one sub-directory per client.
        model_path = os.path.join("./save_path/model", param_dict['dataset_name'],
                                  param_dict['algorithm'],
                                  param_dict['hypothesis'], str(experiment_no))
        check_and_make_the_path(model_path)
        param_dict['model_path'] = model_path
        for k in range(param_dict["num_clients_K"]):
            check_and_make_the_path(os.path.join(model_path, "client_" + str(k + 1)))

        # Record the experiment number *before* serialising, so the saved parameter
        # files describe this run rather than the previous one.
        param_dict['Experiment_NO'] = str(experiment_no)
        logger.info(f"Experiment {experiment_no}/{total_experiment_no} setup finish")
        json_str = json.dumps(param_dict, indent=4)
        for json_path in (log_path + "_Parameter.json", os.path.join(model_path, "Parameter.json")):
            with open(json_path, "w") as json_file:
                json_file.write(json_str)

        # Parameter announcement
        logger.info("Parameter announcement")
        for para_key, para_value in param_dict.items():
            if "_common" in para_key:
                continue
            logger.info(f"****** {para_key} : {para_value} ******")
        logger.info(
            "-----------------------------------------------------------------------------")

        Experiment(param_dict)
    finally:
        # Detach and close the per-run handler so file descriptors do not leak across runs.
        logger.removeHandler(file_handler)
        file_handler.close()


def main(dataset_name, algorithm, hypothesis, γ_k_style, device):
    """Run the FedRenyi convergence-analysis hyperparameter grid for one dataset.

    Args:
        dataset_name: dataset config name (``./json/<dataset_name>.json``).
        algorithm: algorithm name stored in the parameters (e.g. "FederatedRenyi").
        hypothesis: model hypothesis (e.g. "LR").
        γ_k_style: client weighting style (e.g. "uniform_distribution").
        device: any string containing "gpu" requests CUDA when available; otherwise CPU.
    """
    import itertools

    param_dict = load_fl_hyperparameter_from_json(dataset_name)

    # NOTE(review): torch is already imported at this point; the module-level assignment
    # made before `import torch` is presumably the one that actually takes effect.
    os.environ["CUDA_VISIBLE_DEVICES"] = param_dict['CUDA_VISIBLE_DEVICES']

    if "gpu" in device:
        param_dict['device'] = "cuda" if torch.cuda.is_available() else "cpu"  # Get cpu or gpu device for experiment
    else:
        param_dict['device'] = "cpu"

    step_T_communication_I_list = [(100, 4)]
    split_strategy_list = ["Dirichlet0.5", "Uniform"]
    sensitive_attribute_skew_list = ["no_attribute_skew", ]

    param_dict['γ_k_style'] = γ_k_style
    param_dict['dataset_name'] = dataset_name
    param_dict['algorithm'] = algorithm
    param_dict['hypothesis'] = hypothesis
    param_dict["save_checkpoint_rounds"] = 99999999999

    lamda_list = [1]
    rho_list, tolerance_rate_list = [0.1], [1]
    FL_drop_rate_list = [0]

    # Full grid, enumerated in the same order as the equivalent nested loops (last list
    # varies fastest). Without drop-outs, Renyi algorithms only run tolerance_rate == 1;
    # other tolerance rates are skipped individually instead of aborting the rest.
    experiment_grid = [
        (split_strategy, skew, lamda, step_comm, rho, drop_rate, tolerance_rate)
        for split_strategy, skew, lamda, step_comm, rho, drop_rate, tolerance_rate in itertools.product(
            split_strategy_list, sensitive_attribute_skew_list, lamda_list, step_T_communication_I_list,
            rho_list, FL_drop_rate_list, tolerance_rate_list)
        if not (drop_rate == 0 and "Renyi" in algorithm and tolerance_rate != 1)
    ]
    total_Experiment_NO = len(experiment_grid)

    # Main Loop
    for Experiment_NO, (split_strategy, sensitive_attribute_skew, lamda,
                        (algorithm_step_T, communication_round_I),
                        rho, FL_drop_rate, tolerance_rate) in enumerate(experiment_grid, start=1):
        param_dict['split_strategy'] = split_strategy
        param_dict['sensitive_attribute_skew'] = sensitive_attribute_skew
        param_dict['lamda'] = lamda
        param_dict['algorithm_step_T'] = algorithm_step_T
        param_dict['communication_round_I'] = communication_round_I
        param_dict['rho'] = rho
        param_dict['FL_drop_rate'] = FL_drop_rate
        param_dict['tolerance_τ'] = int(tolerance_rate * algorithm_step_T)

        _run_single_experiment(param_dict, Experiment_NO, total_Experiment_NO)
        logger.info(
            "|||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||")
        logger.info(
            "|||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||")


if __name__ == '__main__':
    # `sys` was previously only available through a star import; import it explicitly.
    import sys

    # Example:
    #   nohup python FedRenyi_convergence_analysis.py COMPAS FederatedRenyi LR uniform_distribution gpu &
    # Datasets used so far: ADULT, COMPAS, DRUG, DUTCH; γ_k styles: uniform_distribution, uniform_client.
    if len(sys.argv) < 6:
        sys.exit(f"usage: {sys.argv[0]} DATASET ALGORITHM HYPOTHESIS GAMMA_K_STYLE DEVICE\n"
                 f"  e.g. {sys.argv[0]} COMPAS FederatedRenyi LR uniform_distribution gpu")
    main(*sys.argv[1:6])
