import json
import torch
from moudle.model_testing import Experiment_Model_testing
from tool.logger import *
from algorithm.FederatedProximal.FederatedProximal_LR import Fed_Prox_LR
from algorithm.FederatedProximal.FederatedProximal_NN import Fed_Prox_NN


def Experiment_Federated_Proximal(param_dict, global_model, training_dataloaders, training_dataset, client_dataset_list,
                                  testing_dataloader):
    """Train with Federated Proximal, evaluate the result, and dump the metrics to JSON.

    Steps performed:
      1. Choose ``Fed_Prox_LR`` when ``param_dict['hypothesis'] == "LR"``,
         otherwise ``Fed_Prox_NN``, and run it to obtain the updated global
         model plus the list of client models.
      2. Evaluate the global model with ``Experiment_Model_testing``
         (``only_acc=False``), which yields six values unpacked here as
         global accuracy, DEO, EOD, SPD, FR and HM.
      3. Build per-client accuracies and combine them twice: once with equal
         weight ``1 / K`` per client, once weighted by each client's share of
         the training set.
      4. Log every metric and write them to
         ``param_dict['log_path'] + "_result.json"``.

    Args:
        param_dict: Experiment configuration. Keys read here: ``device``,
            ``hypothesis``, ``algorithm_step_T``, ``num_clients_K``,
            ``communication_round_I``, ``FL_fraction``, ``FL_drop_rate``,
            ``local_step_size``, ``mask_s1_flag`` and ``log_path``.
        global_model: Initial global model forwarded to the trainer.
        training_dataloaders: Forwarded unchanged to the trainer.
        training_dataset: Full training set; only its ``len()`` is used here
            (it is also forwarded to the trainer).
        client_dataset_list: Per-client datasets; each of the first K entries
            must expose ``.indices`` (presumably ``torch.utils.data.Subset``
            objects -- TODO confirm against the caller).
        testing_dataloader: Loader used for every evaluation call.

    Returns:
        None. Results are reported through the logger and the JSON file.
    """
    device = param_dict['device']
    hypothesis = param_dict['hypothesis']

    # Logistic-regression runs use the LR trainer; anything else uses the NN trainer.
    Fed_Prox = Fed_Prox_LR if hypothesis == "LR" else Fed_Prox_NN

    updated_global_model, client_model_list = Fed_Prox(
        device,
        global_model,
        param_dict['algorithm_step_T'],
        param_dict['num_clients_K'],
        param_dict['communication_round_I'],
        param_dict['FL_fraction'],
        param_dict['FL_drop_rate'],
        param_dict['local_step_size'],
        training_dataloaders,
        training_dataset,
        client_dataset_list
    )
    logger.info("-----------------------------------------------------------------------------")

    # Model persistence is currently disabled; the original snippet is kept for reference.
    # logger.info("Global model Saving")
    # check_and_make_the_path(param_dict['model_path'])
    # torch.save(updated_global_model, param_dict['model_path'] + "/global_model.pkl")
    # logger.info("Client Models Saving")
    # for client_id, client_model in enumerate(client_model_list):
    #     _ = os.path.join(param_dict['model_path'], "client_" + str(client_id + 1) + "/client_model.pkl")
    #     check_and_make_the_path(os.path.join(param_dict['model_path'], "client_" + str(client_id + 1)))
    #     torch.save(client_model, _)

    logger.info("Global model testing")
    # Arguments common to the global evaluation and every per-client evaluation.
    shared_eval_kwargs = {
        "device": device,
        "testing_dataloader": testing_dataloader,
        "mask_s1_flag": param_dict['mask_s1_flag'],
        "hypothesis": hypothesis,
    }
    global_acc, deo, eod, spd, fr, hm = Experiment_Model_testing(
        testing_model=updated_global_model,
        only_acc=False,
        **shared_eval_kwargs
    )

    num_clients = int(param_dict['num_clients_K'])
    equal_weight = 1 / num_clients
    total_samples = len(training_dataset)

    # Weight of client k = fraction of the training set held by client k.
    # global_acc is expected to expose ``.device`` (i.e. to be a tensor).
    size_weights = torch.tensor(
        [len(client_dataset_list[k].indices) / total_samples for k in range(num_clients)]
    ).to(global_acc.device)

    leftover_local_steps = param_dict['algorithm_step_T'] % param_dict['communication_round_I']
    if leftover_local_steps:
        # T is not a multiple of I: evaluate each returned client model on its own.
        per_client_acc = [
            Experiment_Model_testing(testing_model=client_model, only_acc=True, **shared_eval_kwargs)
            for client_model in client_model_list
        ]
    else:
        # T is a multiple of I: every client is credited with the global accuracy
        # (presumably the client models equal the global model at that point -- TODO confirm).
        per_client_acc = [global_acc] * num_clients
    per_client_acc = torch.tensor(per_client_acc).to(global_acc.device)

    # Builtin sum is used deliberately so the element-by-element accumulation is unchanged.
    uniform_client_acc = sum(equal_weight * per_client_acc)
    uniform_distribution_acc = sum(size_weights * per_client_acc)

    # Insertion order drives both the log order and the key order of the JSON file.
    metrics = {
        "global_acc": global_acc,
        "DEO": deo,
        "EOD": eod,
        "SPD": spd,
        "FR": fr,
        "HM": hm,
        "uniform_client_acc": uniform_client_acc,
        "uniform_distribution_acc": uniform_distribution_acc,
    }
    for label, value in metrics.items():
        logger.info(f" ****** {label}: {value} ******")

    # Convert to plain floats so the values are JSON-serialisable.
    result_dict = {label: float(value) for label, value in metrics.items()}
    json_str = json.dumps(result_dict, indent=4)
    with open(param_dict['log_path'] + "_result.json", "w") as json_file:
        json_file.write(json_str)
