import os
import pickle
import re

from hypothesis.LogisticRegression import RenyiLogisticRegression
from hypothesis.NeuralNetwork import RenyiNeuralNetwork, FedFairNeuralNetwork, Two_D_NeuralNetwork
from .dataset import get_ADULT_dataset, get_COMPAS_dataset, get_DRUG_dataset,\
    get_BANK_dataset, get_GERMAN_dataset, get_ARRHYTHMIA_dataset, get_DUTCH_dataset
from .dataloader import get_FL_dataloader
from tool.logger import *


def Experiment_Create_dataset(param_dict, no_pickle=True):
    """Load the training/testing datasets selected by ``param_dict``.

    The dataset is chosen by substring match of ``param_dict['dataset_name']``
    against the known names below (first match wins, in the listed order).
    An unrecognised name falls back to GERMAN, and a warning is logged.

    Args:
        param_dict: mapping providing ``dataset_name``, ``mask_s1_flag``,
            ``mask_s2_flag`` and ``mask_s1_s2_flag``. The three flags are
            forwarded unchanged to the dataset loader.
        no_pickle: if True, always build the datasets through the loader and
            never read or write the pickle cache. If False, load from the
            cache when it exists; otherwise build the datasets and write it.

    Returns:
        ``(training_dataset, positive_training_dataset,
        negative_training_dataset, testing_dataset, nn_input_size)``, where
        ``nn_input_size`` is ``training_dataset.X.shape[1]``.
    """
    dataset_name = param_dict['dataset_name']
    mask_s1_flag = param_dict['mask_s1_flag']
    mask_s2_flag = param_dict['mask_s2_flag']
    mask_s1_s2_flag = param_dict['mask_s1_s2_flag']

    # Substring matching in this exact order; the first hit wins.
    known_datasets = (
        ("ADULT", get_ADULT_dataset),
        ("ARRHYTHMIA", get_ARRHYTHMIA_dataset),
        ("BANK", get_BANK_dataset),
        ("COMPAS", get_COMPAS_dataset),
        ("DRUG", get_DRUG_dataset),
        ("DUTCH", get_DUTCH_dataset),
        ("GERMAN", get_GERMAN_dataset),
    )
    for resolved_name, get_dataset in known_datasets:
        if resolved_name in dataset_name:
            break
    else:
        # Keep the historical GERMAN fallback, but make it visible.
        logger.warning(f"Unknown dataset name {dataset_name!r}; falling back to GERMAN")
        resolved_name, get_dataset = "GERMAN", get_GERMAN_dataset

    data_path = f"./dataset/{resolved_name}"
    # The loader's output depends on the mask flags, so they must be part of
    # the cache key. Otherwise a cache written under one masking setting would
    # be silently reused for another. The unmasked cache keeps its old name.
    mask_suffix = "".join(
        f"_{tag}"
        for flag, tag in (
            (mask_s1_flag, "mask_s1"),
            (mask_s2_flag, "mask_s2"),
            (mask_s1_s2_flag, "mask_s1_s2"),
        )
        if flag
    )
    pickle_path = f"{data_path}/{resolved_name}{mask_suffix}.pickle"

    if no_pickle or not os.path.exists(pickle_path):
        training_dataset, positive_training_dataset, negative_training_dataset, testing_dataset = get_dataset(
            data_path, mask_s1_flag, mask_s2_flag, mask_s1_s2_flag
        )
        if not no_pickle:
            pickle_dict = {
                "training_dataset": training_dataset,
                "positive_training_dataset": positive_training_dataset,
                "negative_training_dataset": negative_training_dataset,
                "testing_dataset": testing_dataset,
            }
            # Write to a temp file and rename, so an interrupted dump cannot
            # leave a truncated cache behind that later runs would load.
            tmp_path = pickle_path + ".tmp"
            with open(tmp_path, 'wb') as p:
                pickle.dump(pickle_dict, p)
            os.replace(tmp_path, pickle_path)
    else:
        # NOTE: pickle.load executes arbitrary code; only acceptable because
        # this is a local cache written by this function.
        with open(pickle_path, 'rb') as r:
            pickle_dict = pickle.load(r)
        training_dataset = pickle_dict['training_dataset']
        positive_training_dataset = pickle_dict['positive_training_dataset']
        negative_training_dataset = pickle_dict['negative_training_dataset']
        testing_dataset = pickle_dict['testing_dataset']

    logger.info(f"Data Info (Training): {len(training_dataset)}")
    logger.info(f"Data Info (Positive Training): {len(positive_training_dataset)}")
    logger.info(f"Data Info (Negative Training): {len(negative_training_dataset)}")
    logger.info(f"Data Info (Testing): {len(testing_dataset)}")

    # Feature count of the (possibly masked) training data; used as model input width.
    nn_input_size = training_dataset.X.shape[1]

    # Report which sensitive-attribute masking setting is in effect.
    if mask_s1_flag:
        logger.info("Masking the sensitive attribute s1")
    elif mask_s2_flag:
        logger.info("Masking the sensitive attribute s2")
    elif mask_s1_s2_flag:
        logger.info("Masking the sensitive attribute s1 and s2")
    else:
        logger.info("Do not masking the sensitive attribute")

    return training_dataset, positive_training_dataset, negative_training_dataset, testing_dataset, nn_input_size


def Experiment_Create_dataloader(param_dict, training_dataset, positive_training_dataset, negative_training_dataset, testing_dataset, split_strategy="Uniform"):
    """Build per-client federated dataloaders plus a testing dataloader.

    ``param_dict['sensitive_attribute_skew']`` selects how clients are formed:

    * contains ``"no"`` (e.g. ``"no_attribute_skew"``): ``training_dataset``
      is split across all ``num_clients_K`` clients using ``split_strategy``.
    * otherwise it must contain ``"positive_<fraction>"`` (e.g.
      ``"positive_0.1"``): ``int(num_clients_K * fraction)`` clients are built
      from ``positive_training_dataset`` and the remaining clients from
      ``negative_training_dataset``. Positive clients come first in the
      returned lists.

    The testing dataloader is always built with the ``"Uniform"`` split.

    Args:
        param_dict: mapping providing ``num_clients_K``, ``batch_size``,
            ``need_validation`` and ``sensitive_attribute_skew``.
        training_dataset, positive_training_dataset,
        negative_training_dataset, testing_dataset: datasets passed to
            ``get_FL_dataloader``.
        split_strategy: split strategy for the training dataloaders.

    Returns:
        ``(training_dataloaders, positive_training_dataloaders,
        negative_training_dataloaders, validation_dataloaders,
        client_dataset_list, testing_dataloader)``. The positive/negative
        entries are ``None`` in the no-skew case. In the skewed case,
        ``validation_dataloaders`` is ``None`` when ``need_validation`` is
        falsy.

    Raises:
        ValueError: if the skew string matches neither form, or the fraction
            lies outside ``[0, 1]``.
    """
    num_clients_K = param_dict['num_clients_K']
    batch_size = param_dict['batch_size']
    need_validation = param_dict['need_validation']
    sensitive_attribute_skew = param_dict['sensitive_attribute_skew']

    # "no" is checked first and also matches "no_attribute_skew".
    attribute_skew = "no" not in sensitive_attribute_skew
    if attribute_skew:
        # Validate before building any dataloader. Previously only
        # "positive_0.1"/"positive_0.5" were recognised. Any other "positive"
        # value raised UnboundLocalError, and other strings returned None.
        match = re.search(r"positive_(\d+(?:\.\d+)?)", sensitive_attribute_skew)
        if match is None:
            raise ValueError(
                f"Unsupported sensitive_attribute_skew {sensitive_attribute_skew!r}; "
                "expected a value containing 'no' or 'positive_<fraction>'"
            )
        positive_fraction = float(match.group(1))
        if not 0.0 <= positive_fraction <= 1.0:
            raise ValueError(
                f"Positive client fraction must be in [0, 1], got {positive_fraction}"
            )

    testing_dataloader = get_FL_dataloader(
        testing_dataset, num_clients_K, split_strategy="Uniform",
        do_train=False, batch_size=batch_size, num_workers=0
    )

    if not attribute_skew:
        training_dataloaders, validation_dataloaders, client_dataset_list = get_FL_dataloader(
            training_dataset, num_clients_K, split_strategy=split_strategy,
            do_train=True, need_validation=need_validation, batch_size=batch_size,
            num_workers=0, do_shuffle=True
        )
        return training_dataloaders, None, None, validation_dataloaders, client_dataset_list, testing_dataloader

    # int() truncates, matching the original int(num_clients_K * 0.1) behaviour.
    positive_num_clients_K = int(num_clients_K * positive_fraction)
    negative_num_clients_K = num_clients_K - positive_num_clients_K

    positive_training_dataloaders, positive_validation_dataloaders, positive_client_dataset_list = get_FL_dataloader(
        positive_training_dataset, positive_num_clients_K, split_strategy=split_strategy,
        do_train=True, need_validation=need_validation, batch_size=batch_size,
        num_workers=0, do_shuffle=True
    )
    negative_training_dataloaders, negative_validation_dataloaders, negative_client_dataset_list = get_FL_dataloader(
        negative_training_dataset, negative_num_clients_K, split_strategy=split_strategy,
        do_train=True, need_validation=need_validation, batch_size=batch_size,
        num_workers=0, do_shuffle=True
    )

    training_dataloaders = positive_training_dataloaders + negative_training_dataloaders
    if need_validation:
        validation_dataloaders = positive_validation_dataloaders + negative_validation_dataloaders
    else:
        validation_dataloaders = None
    client_dataset_list = positive_client_dataset_list + negative_client_dataset_list
    return training_dataloaders, positive_training_dataloaders, negative_training_dataloaders, \
           validation_dataloaders, client_dataset_list, testing_dataloader


def Experiment_Model_construction(param_dict, nn_input_size):
    """Instantiate the model named by ``param_dict['hypothesis']`` and move it to a device.

    ``"LR"`` selects ``RenyiLogisticRegression``. Any other value selects
    ``RenyiNeuralNetwork`` with a hidden width of 12. The model is moved with
    ``.to(param_dict['device'])``, and the model object itself is returned.
    """
    use_logistic_regression = param_dict['hypothesis'] == "LR"
    if not use_logistic_regression:
        logger.info("Model construction (Neural Network)")
        # Hidden width 12 is copied from the Renyi reference setup.
        # FedFairNeuralNetwork / Two_D_NeuralNetwork (imported above) are
        # drop-in alternatives that can be substituted here.
        net = RenyiNeuralNetwork(input_size=nn_input_size, hidden_size=12)
    else:
        logger.info("Model construction (Logistic Regression)")
        net = RenyiLogisticRegression(input_size=nn_input_size)

    net.to(param_dict['device'])
    return net


