import json
import torch
from moudle.model_testing import Experiment_Model_testing
from tool.logger import *
from algorithm.FederatedFair.FederatedFair_LR import Fed_Fair_LR
from algorithm.FederatedFair.FederatedFair_NN import Fed_Fair_NN


def _client_size_weights(client_dataset_list, num_clients, training_dataset_size):
    """Return each of the first ``num_clients`` clients' share of the training data.

    Client ``i``'s weight is ``client_dataset_list[i]['y'].size`` divided by
    ``training_dataset_size``; the result is a list of Python floats.
    """
    # NOTE(review): `.size` is read as an attribute (not called), which presumably
    # means 'y' is a NumPy array — for a torch tensor `.size` is a method. Confirm.
    return [
        client_dataset_list[client_idx]['y'].size / training_dataset_size
        for client_idx in range(num_clients)
    ]


def _save_result_json(log_path, result_dict):
    """Write ``result_dict`` as 4-space-indented JSON to ``<log_path>_result.json``."""
    with open(log_path + "_result.json", "w", encoding="utf-8") as json_file:
        json.dump(result_dict, json_file, indent=4)


def Experiment_Federated_Fair(param_dict, global_model, training_dataloaders, training_dataset, client_dataset_list,
                              testing_dataloader):
    """Train a global model with FederatedFair, evaluate it, and save the metrics.

    The trainer is ``Fed_Fair_LR`` when ``param_dict['hypothesis'] == "LR"`` and
    ``Fed_Fair_NN`` otherwise. The returned model is evaluated with
    ``Experiment_Model_testing`` (accuracy plus the fairness metrics DEO, EOD,
    SPD, FR, HM). Two client-averaged accuracies are also derived from the global
    accuracy:

    * ``uniform_client_acc``: every client weighted ``1 / num_clients_K``.
    * ``uniform_distribution_acc``: each client weighted by its share of the
      training data (``client_dataset_list[i]['y'].size / len(training_dataset)``).

    All metrics are logged and written as JSON to ``param_dict['log_path'] + "_result.json"``.

    Args:
        param_dict: Experiment configuration. Keys read here are ``device``,
            ``hypothesis``, ``epsilon``, ``algorithm_step_T``, ``num_clients_K``,
            ``mask_s1_flag`` and ``log_path``.
        global_model: Initial global model passed to the trainer.
        training_dataloaders: Per-client training dataloaders, passed to the trainer.
        training_dataset: Full training dataset; only ``len()`` is used here
            besides forwarding it to the trainer.
        client_dataset_list: Per-client datasets; entry ``i`` must support ``['y'].size``.
        testing_dataloader: Dataloader used for the final evaluation.

    Raises:
        ValueError: If ``num_clients_K`` is not positive (checked before training)
            or ``training_dataset`` is empty.
    """
    device = param_dict['device']
    hypothesis = param_dict['hypothesis']
    epsilon = param_dict['epsilon']

    num_clients = int(param_dict['num_clients_K'])
    # Fail before the (expensive) training run instead of with a ZeroDivisionError afterwards.
    if num_clients <= 0:
        raise ValueError(f"num_clients_K must be a positive integer, got {param_dict['num_clients_K']!r}")

    fed_fair = Fed_Fair_LR if hypothesis == "LR" else Fed_Fair_NN

    updated_global_model = fed_fair(
        device,
        global_model,
        param_dict['algorithm_step_T'], param_dict['num_clients_K'],
        training_dataloaders,
        training_dataset,
        client_dataset_list,
        epsilon
    )
    logger.info("-----------------------------------------------------------------------------")

    # Model testing
    logger.info("Global model testing")
    global_acc, DEO, EOD, SPD, FR, HM = Experiment_Model_testing(device=device,
                                                                testing_dataloader=testing_dataloader,
                                                                mask_s1_flag=param_dict['mask_s1_flag'],
                                                                testing_model=updated_global_model,
                                                                hypothesis=hypothesis,
                                                                only_acc=False
                                                                )
    logger.info("Client models testing")

    training_dataset_size = len(training_dataset)
    if training_dataset_size == 0:
        raise ValueError("training_dataset is empty; cannot compute client data-share weights")

    # Keep the aggregates on the same device as the accuracy when it is a tensor;
    # a plain Python number falls back to CPU.
    acc_device = global_acc.device if torch.is_tensor(global_acc) else torch.device("cpu")

    uniform_client_weight = 1 / num_clients
    uniform_distribution_weight = torch.tensor(
        _client_size_weights(client_dataset_list, num_clients, training_dataset_size),
        device=acc_device,
    )

    # FederatedFair yields a single global model, so every client's accuracy is the global accuracy.
    client_acc_list = torch.full((num_clients,), float(global_acc), device=acc_device)

    # Tensor.sum() reduces in one op rather than iterating elements via builtin sum().
    uniform_client_acc = (uniform_client_weight * client_acc_list).sum()
    uniform_distribution_acc = (uniform_distribution_weight * client_acc_list).sum()

    logger.info(f" ****** global_acc: {global_acc} ******")
    logger.info(f" ****** DEO: {DEO} ******")
    logger.info(f" ****** EOD: {EOD} ******")
    logger.info(f" ****** SPD: {SPD} ******")
    logger.info(f" ****** FR: {FR} ******")
    logger.info(f" ****** HM: {HM} ******")
    logger.info(f" ****** uniform_client_acc: {uniform_client_acc} ******")
    logger.info(f" ****** uniform_distribution_acc: {uniform_distribution_acc} ******")

    result_dict = {
        "global_acc": float(global_acc),
        "DEO": float(DEO),
        "EOD": float(EOD),
        "SPD": float(SPD),
        "FR": float(FR),
        "HM": float(HM),
        "uniform_client_acc": float(uniform_client_acc),
        "uniform_distribution_acc": float(uniform_distribution_acc)
    }
    _save_result_json(param_dict['log_path'], result_dict)
