import json
import torch
from moudle.model_testing import Experiment_Model_testing
from tool.logger import *
from algorithm.SeparateTraining.SeparateTraining_LR import ST_LR
from algorithm.SeparateTraining.SeparateTraining_NN import ST_NN


def Experiment_Separate(param_dict, global_model, training_dataloaders, training_dataset, client_dataset_list,
                        testing_dataloader):
    """Train a separate model per client, then log and save averaged test metrics.

    Each client trains its own model via ``ST_LR`` (when
    ``param_dict['hypothesis'] == "LR"``) or ``ST_NN`` (otherwise). Every
    resulting client model is evaluated on ``testing_dataloader`` with
    ``Experiment_Model_testing``. The five metrics returned alongside accuracy
    (DEO, EOD, SPD, FR, HM) are averaged uniformly over clients. Accuracy is
    aggregated two ways:

    * ``uniform_client_acc``: every client weighted ``1 / num_clients_K``.
    * ``uniform_distribution_acc``: each client weighted by its share of
      ``training_dataset`` (``len(client_dataset.indices) / len(training_dataset)``).
      If the client subsets do not cover ``training_dataset`` exactly, these
      weights do not sum to 1.

    Args:
        param_dict: Experiment configuration. Keys read here: ``'device'``,
            ``'num_clients_K'``, ``'hypothesis'``, ``'algorithm_step_T'``,
            ``'local_step_size'``, ``'mask_s1_flag'`` and ``'log_path'``.
        global_model: Initial model handed to the separate-training routine.
        training_dataloaders: Per-client training loaders, passed through to
            the separate-training routine.
        training_dataset: Full training set; only its length is used here.
        client_dataset_list: Per-client datasets exposing ``.indices``
            (presumably ``torch.utils.data.Subset``; confirm against the caller).
        testing_dataloader: Loader used to evaluate every client model.

    Side effects:
        Logs the metrics and writes them as JSON to
        ``param_dict['log_path'] + "_result.json"``.
    """
    device = param_dict['device']
    num_clients_K = param_dict['num_clients_K']
    # Normalize once. The value may arrive non-int (e.g. a string from a CLI or
    # config). Previously only some uses were int()-converted, so the metric
    # averaging below could fail with a TypeError.
    num_clients = int(num_clients_K)

    if param_dict['hypothesis'] == "LR":
        Separate_Training = ST_LR
    else:
        Separate_Training = ST_NN

    # The raw config value is forwarded unchanged so the training routines see
    # exactly what they did before.
    client_model_list = Separate_Training(
        device,
        global_model,
        param_dict['algorithm_step_T'],
        num_clients_K,
        param_dict['local_step_size'],
        training_dataloaders
    )
    logger.info("-----------------------------------------------------------------------------")

    uniform_client_weight = 1 / num_clients
    training_dataset_size = len(training_dataset)
    # Per-client weight = fraction of the full training set held by that client.
    uniform_distribution_weight = torch.tensor(
        [len(client_dataset_list[i].indices) / training_dataset_size for i in range(num_clients)]
    ).to(device)

    logger.info("Client model list testing")
    client_acc_list = []
    DEO, EOD, SPD, FR, HM = 0, 0, 0, 0, 0
    for client_model in client_model_list:
        client_acc, tmp_DEO, tmp_EOD, tmp_SPD, tmp_FR, tmp_HM = Experiment_Model_testing(
            device=device,
            testing_dataloader=testing_dataloader,
            mask_s1_flag=param_dict['mask_s1_flag'],
            testing_model=client_model,
            hypothesis=param_dict['hypothesis'],
            only_acc=False
        )
        # Uniform average across clients, accumulated incrementally.
        DEO += tmp_DEO / num_clients
        EOD += tmp_EOD / num_clients
        SPD += tmp_SPD / num_clients
        FR += tmp_FR / num_clients
        HM += tmp_HM / num_clients

        client_acc_list.append(client_acc)
    client_acc_list = torch.tensor(client_acc_list).to(device)

    # Tensor.sum() reduces on-device instead of iterating elements in Python.
    uniform_client_acc = (uniform_client_weight * client_acc_list).sum()
    uniform_distribution_acc = (uniform_distribution_weight * client_acc_list).sum()

    logger.info(f" ****** DEO: {DEO} ******")
    logger.info(f" ****** EOD: {EOD} ******")
    logger.info(f" ****** SPD: {SPD} ******")
    logger.info(f" ****** FR: {FR} ******")
    logger.info(f" ****** HM: {HM} ******")
    logger.info(f" ****** uniform_client_acc: {uniform_client_acc} ******")
    logger.info(f" ****** uniform_distribution_acc: {uniform_distribution_acc} ******")

    result_dict = {
        "DEO": float(DEO),
        "EOD": float(EOD),
        "SPD": float(SPD),
        "FR": float(FR),
        "HM": float(HM),
        "uniform_client_acc": float(uniform_client_acc),
        "uniform_distribution_acc": float(uniform_distribution_acc)
    }
    with open(param_dict['log_path'] + "_result.json", "w", encoding="utf-8") as json_file:
        json.dump(result_dict, json_file, indent=4)
