import torch
import numpy as np
import copy
import gc
from tool.logger import *
from tool.utils import get_parameters, set_parameters
from torch.utils.data import DataLoader
from algorithm.client_selection import client_selection
from algorithm.FairBatch.FairBatchSampler import FairBatch, CustomDataset


def construct_fairbatch_dataset(device, client_training_dataset):
    """Build the ``CustomDataset`` that FairBatch samples from for one client.

    Args:
        device: Torch device the label and ``'s1'`` tensors are moved to.
        client_training_dataset: Object exposing ``indices`` and ``dataset``
            attributes (e.g. a ``torch.utils.data.Subset``), where
            ``dataset[indices]`` returns a mapping with the keys ``'X'``,
            ``'y'`` and ``'s1'``.

    Returns:
        ``CustomDataset(x, y, z)`` built from the client's rows. ``x`` is
        passed through as returned by the underlying dataset (it is NOT moved
        to ``device`` here); ``y`` and ``z`` are new tensors on ``device``.
    """
    # Index the underlying dataset once and reuse the result for all three
    # fields instead of repeating the (potentially expensive) lookup per field.
    client_rows = client_training_dataset.dataset[client_training_dataset.indices]

    x = client_rows['X']
    y = torch.tensor(client_rows['y']).to(device)
    # 's1' is handed to CustomDataset as its third argument `z`
    # (presumably the sensitive attribute -- confirm against CustomDataset).
    z = torch.tensor(client_rows['s1']).to(device)
    return CustomDataset(x, y, z)


def construct_the_fairsampler(device, model, criterion, train_data, batch_size, alpha, target_fairness):
    """Create a ``FairBatch`` sampler for the requested fairness notion.

    The free-form ``target_fairness`` string is mapped (case-insensitively, by
    substring) onto one of the three identifiers passed to ``FairBatch``:

    * contains ``"opportunity"`` or ``"eqopp"``  -> ``'eqopp'``
    * otherwise contains ``"odds"``              -> ``'eqodds'``
    * anything else                              -> ``'dp'``

    Args:
        device: Unused; kept so existing callers keep working.
        model: Model handed to ``FairBatch``.
        criterion: Unused; kept so existing callers keep working.
        train_data: Object exposing ``x``, ``y`` and ``z`` attributes, which
            are forwarded to ``FairBatch``.
        batch_size: Batch size forwarded to ``FairBatch``.
        alpha: Forwarded to ``FairBatch`` as its ``alpha`` argument.
        target_fairness: String naming the fairness criterion.

    Returns:
        The constructed ``FairBatch`` sampler (``replacement=False``).
    """
    requested = target_fairness.lower()

    # "oportunity" is the historical misspelling this function used to test
    # for. It is NOT a substring of "opportunity", so the correct spelling was
    # silently treated as demographic parity. Accept both spellings.
    if any(tag in requested for tag in ("opportunity", "oportunity", "eqopp")):
        # case 1: Equal opportunity
        fairness_type = 'eqopp'
    elif "odds" in requested:
        # case 2: Equalized odds ("eqodds" also contains "odds")
        fairness_type = 'eqodds'
    else:
        # case 3: Demographic parity
        fairness_type = 'dp'

    return FairBatch(model, train_data.x, train_data.y, train_data.z, batch_size,
                     alpha, target_fairness=fairness_type, replacement=False)


def _fairbatch_select_clients(num_clients_K, FL_fraction, FL_drop_rate,
                              training_dataset_size, client_datasets_size_list):
    """Run one "FedAvg"-style client selection and log the selected client ids."""
    logger.info("********** Client selection **********")
    idxs_users = client_selection(
        client_num=num_clients_K,
        fraction=FL_fraction,
        dataset_size=training_dataset_size,
        client_dataset_size_list=client_datasets_size_list,
        drop_rate=FL_drop_rate,
        style="FedAvg",
    )
    logger.info(f"********** Select client list: {idxs_users} **********")
    return idxs_users


# FairBatch with Logistic Regression Model
def FairBatch_LR(device,
                 global_model,
                 algorithm_step_T, num_clients_K, communication_round_I, FL_fraction, FL_drop_rate, local_step_size,
                 origin_training_dataloaders,
                 training_dataset,
                 client_dataset_list,
                 fairbatch_alpha=0.005,
                 target_fairness='eqopp'
                 ):
    """Federated training where every local step draws its batch via FairBatch.

    Each of the ``algorithm_step_T`` steps performs ONE SGD update per selected
    client on a single FairBatch-sampled batch. Every ``communication_round_I``
    steps the selected clients' parameters are averaged into the global model,
    the result is copied to all clients and a new set of clients is selected.

    Args:
        device: Torch device used for the models and batches.
        global_model: Model to train. It is trained with ``torch.nn.BCELoss``,
            so it is assumed to output probabilities in [0, 1] -- confirm
            against the model definition.
        algorithm_step_T: Total number of local steps.
        num_clients_K: Number of clients (one local model copy each).
        communication_round_I: Aggregate every this many steps.
        FL_fraction: Forwarded to ``client_selection`` as ``fraction``.
        FL_drop_rate: Forwarded to ``client_selection`` as ``drop_rate``.
        local_step_size: Learning rate of the per-client SGD optimizer.
        origin_training_dataloaders: Per-client dataloaders; only their
            ``batch_size`` is read here.
        training_dataset: Full training set; only its length is used.
        client_dataset_list: Per-client datasets, each accepted by
            ``construct_fairbatch_dataset``.
        fairbatch_alpha: ``alpha`` forwarded to the FairBatch sampler
            (default keeps the previous hard-coded value).
        target_fairness: Fairness criterion forwarded to the FairBatch sampler
            (default keeps the previous hard-coded value).

    Returns:
        Tuple ``(global_model, local_model_list)``.
    """
    logger.info("Training process")
    criterion = torch.nn.BCELoss(reduction='mean')

    training_dataset_size = len(training_dataset)
    client_datasets_size_list = [len(item) for item in client_dataset_list]
    # Per-client share of the full training set; used as aggregation weights.
    average_weight = np.array([float(size / training_dataset_size) for size in client_datasets_size_list])

    # Parameter Initialization
    global_model.train()
    local_model_list = [copy.deepcopy(global_model) for _ in range(num_clients_K)]

    idxs_users = _fairbatch_select_clients(num_clients_K, FL_fraction, FL_drop_rate,
                                           training_dataset_size, client_datasets_size_list)

    for iter_t in range(algorithm_step_T):
        # Simulate The Client Parallel Process
        avg_loss_over_step = 0
        for i in idxs_users:
            # The local model is updated in place, so the entry in
            # local_model_list needs no re-assignment afterwards.
            model = local_model_list[i]
            model.train()
            optimizer = torch.optim.SGD(model.parameters(), lr=local_step_size)
            model.zero_grad()
            model.to(device)

            # The dataset and sampler are rebuilt every step so the sampler is
            # bound to the client's current model object.
            client_i_fairbatch_dataset = construct_fairbatch_dataset(device, client_dataset_list[i])
            fair_sampler = construct_the_fairsampler(device, model, criterion, client_i_fairbatch_dataset,
                                                     origin_training_dataloaders[i].batch_size,
                                                     fairbatch_alpha, target_fairness)
            client_i_dataloader = torch.utils.data.DataLoader(client_i_fairbatch_dataset, sampler=fair_sampler,
                                                              num_workers=0)

            # Local Optimizing By SGD: exactly one batch per client per step.
            for batch in client_i_dataloader:
                X = batch[0].to(device)
                y = batch[1].reshape(-1, 1).to(device)

                local_prediction = model(X).to(device).reshape(-1, 1)
                loss = criterion(local_prediction, y.float())
                loss.backward()
                avg_loss_over_step += round(float(loss), 4)

                optimizer.step()
                optimizer.zero_grad()
                break

        # len() rather than truthiness: idxs_users may be a NumPy array.
        num_selected = len(idxs_users)
        # Avoid ZeroDivisionError when the selection came back empty.
        mean_loss = avg_loss_over_step / num_selected if num_selected else float('nan')
        logger.info(f"########## Step: {iter_t + 1} / {algorithm_step_T}; "
                    f"Avg Loss over Client: {round(float(mean_loss), 4)} ##########")

        # Communicate
        if (iter_t + 1) % communication_round_I == 0:
            logger.info(f"********** Communicate: {(iter_t + 1) / communication_round_I} **********")

            # Global operation
            logger.info("********** Parameter aggregation **********")
            if num_selected:
                theta_list = np.array(
                    [get_parameters(local_model_list[client_id]) for client_id in idxs_users],
                    dtype=object,
                )
                # np.average normalizes by the sum of the given weights, so the
                # result is weighted by relative dataset size among the
                # selected clients only.
                theta_avg = np.average(theta_list, axis=0,
                                       weights=[average_weight[j] for j in idxs_users]).tolist()
                global_model = set_parameters(global_model, theta_avg)
            else:
                # Nothing was trained this round; keep the global model as is.
                logger.info("********** No client selected, aggregation skipped **********")

            # Parameter Distribution
            logger.info("********** Parameter distribution **********")
            local_model_list = [copy.deepcopy(global_model) for _ in range(num_clients_K)]

            # Client Reselection
            idxs_users = _fairbatch_select_clients(num_clients_K, FL_fraction, FL_drop_rate,
                                                   training_dataset_size, client_datasets_size_list)

    logger.info("Training finish, return global model and local model list")
    return global_model, local_model_list
