import torch
import numpy as np
import copy
import gc

from tool.logger import *
from tool.utils import get_parameters, set_parameters
from algorithm.client_selection import client_selection
from moudle.model_testing import Experiment_Model_testing


def weighted_loss(criterion, logits, targets, weights, mean=True):
    """
    Apply ``criterion`` per element and reduce it with per-sample weights.

    Args:
        criterion: loss module called as ``criterion(logits, targets.float())``.
            It should use ``reduction='none'`` so that it returns one loss per
            element.
        logits: model outputs passed to ``criterion``.
        targets: targets passed to ``criterion`` (cast to float).
        weights: per-sample weight tensor. If it has the same number of
            elements as the per-element loss but a different shape (e.g.
            ``(B,)`` weights for a ``(B, 1)`` loss), it is reshaped to the loss
            shape. This keeps the product element-wise instead of letting it
            broadcast to ``(B, B)``. The weights are also moved to the loss's
            device.
        mean: if True, return the weighted mean
            (``sum(loss * w) / sum(w)``); otherwise return the weighted sum.

    Returns:
        A 0-dim tensor.
    """
    acc_loss = criterion(logits, targets.float())
    weights = weights.to(acc_loss.device)
    if weights.shape != acc_loss.shape and weights.numel() == acc_loss.numel():
        weights = weights.reshape(acc_loss.shape)
    if mean:
        # NOTE: an all-zero weight vector divides by zero here (inf/NaN result).
        weights_sum = weights.sum().item()
        acc_loss = torch.sum(acc_loss * weights / weights_sum)
    else:
        acc_loss = torch.sum(acc_loss * weights)
    return acc_loss


def weighted_average_weights(w, nc, n):
    """
    Weighted average of a list of state dicts.

    For every key of ``w[0]`` this computes ``sum_i nc[i] * w[i][key] / n``.
    The caller passes ``n = sum(nc)`` to get a convex combination.

    Args:
        w: list of state dicts (key -> tensor) sharing the same keys.
        nc: per-entry weights, one per element of ``w``.
        n: normaliser the weighted sum is divided by.

    Returns:
        A new state dict (a deep copy of ``w[0]`` with replaced values). The
        inputs are not modified.
    """
    # deepcopy keeps the state-dict container type (and any metadata it carries).
    w_avg = copy.deepcopy(w[0])
    # The first entry must be weighted too, not just the others.
    for key in w_avg.keys():
        w_avg[key] = w_avg[key] * nc[0]
    for i in range(1, len(w)):
        for key in w_avg.keys():
            # Out-of-place so integer buffers are promoted instead of failing an
            # in-place float add.
            w_avg[key] = w_avg[key] + w[i][key] * nc[i]

    for key in w_avg.keys():
        w_avg[key] = torch.div(w_avg[key], n)
    return w_avg


def get_logits_from_logistic(p):
    """
    Convert probabilities to log-odds element-wise: ``log(p / (1 - p))``.

    No clamping is applied, so ``p == 0`` gives ``-inf`` and ``p == 1`` gives
    ``+inf``.
    """
    odds = p / (1 - p)
    return torch.log(odds)


def FedFB_style_inference(DEVICE, model, inference_dataloader, bits=False, truem_yz=None):
    """
    Evaluate ``model`` on ``inference_dataloader`` and compute FedFB statistics.

    Args:
        DEVICE: device every batch tensor is moved to.
        model: binary classifier whose output is treated as a probability. It
            is thresholded at 0.5 and fed to BCELoss against ``(B, 1)``
            targets, so it must have shape ``(B, 1)``.
        inference_dataloader: iterable of dict batches with keys ``"X"``,
            ``"y"`` (0/1 labels) and ``"s1"`` (0/1 sensitive attribute).
        bits: unused; kept for interface compatibility.
        truem_yz: dict ``(y, z) -> count`` of samples with label y and
            sensitive attribute z, used to normalise the group losses in
            ``f_z``. When None, the counts observed in
            ``inference_dataloader`` are used instead.

    Returns:
        accuracy: fraction of samples whose thresholded prediction equals the
            label (0-dim tensor).
        loss: BCE loss summed over all samples (float).
        n_yz: dict ``(y, z) -> number of samples in group z predicted as y``.
        acc_loss / num_batch: per-batch summed BCE loss, averaged over batches.
        fair_loss / num_batch: per-batch covariance-style fairness penalty,
            averaged over batches.
        f_z: dict ``{1: value}``, the FedFB multiplier-update term built from
            the per-(y, z) summed BCE losses.

    Raises:
        ZeroDivisionError: if ``inference_dataloader`` yields no batches.
    """
    model.eval()
    loss, total, correct, fair_loss, acc_loss, num_batch = 0.0, 0.0, 0.0, 0.0, 0.0, 0
    n_yz, loss_yz, m_yz, f_z = {}, {}, {}, {}

    criterion = torch.nn.BCELoss(reduction='sum')

    for y in [0, 1]:
        for z in range(2):
            loss_yz[(y, z)] = 0
            n_yz[(y, z)] = 0
            m_yz[(y, z)] = 0

    with torch.no_grad():
        for batch in inference_dataloader:
            features = batch["X"].to(DEVICE)
            # Flatten so every mask below is an element-wise (B,) mask. Mixing
            # (B,) with (B, 1) would silently broadcast to (B, B).
            labels = batch["y"].to(DEVICE).reshape(-1)
            sensitive = batch["s1"].to(DEVICE).reshape(-1)
            targets = labels.reshape(-1, 1).float()

            # Inference & Prediction
            outputs = model(features).to(DEVICE)
            pred_labels = (outputs >= 0.5).reshape(-1)
            correct += pred_labels.eq(labels).sum()
            total += len(labels)
            num_batch += 1

            for yz in n_yz:
                in_group = (labels == yz[0]) & (sensitive == yz[1])
                n_yz[yz] += torch.sum((pred_labels == yz[0]) & (sensitive == yz[1])).item()
                m_yz[yz] += in_group.sum().item()

                if in_group.any():
                    # The objective function has no Lagrangian term.
                    loss_yz[yz] += criterion(outputs[in_group], targets[in_group])

            # Clamp away from {0, 1}: saturated probabilities would give +-inf
            # logits and turn the fairness penalty into NaN.
            eps = torch.finfo(outputs.dtype).eps
            logits_1 = get_logits_from_logistic(outputs.clamp(eps, 1 - eps)).reshape(-1)
            # Under centering and squaring, 1 - logit gives the same penalty as -logit.
            logits_0 = 1 - logits_1

            sensitive_f = sensitive.float()
            centered_s = sensitive_f - sensitive_f.mean()
            fair_term0 = torch.mul(centered_s, logits_0 - torch.mean(logits_0))
            fair_term1 = torch.mul(centered_s, logits_1 - torch.mean(logits_1))
            batch_fair_loss = (torch.mean(torch.mul(fair_term0, fair_term0))
                               + torch.mean(torch.mul(fair_term1, fair_term1)))
            batch_acc_loss = criterion(outputs, targets)
            batch_loss = batch_acc_loss

            # Separate names: the running totals must not be clobbered by batch values.
            loss += batch_loss.item()
            acc_loss += batch_acc_loss.item()
            fair_loss += batch_fair_loss.item()

        accuracy = correct / total

        if truem_yz is None:
            truem_yz = m_yz
        z0_total = truem_yz[(0, 0)] + truem_yz[(1, 0)]
        for z in range(1, 2):
            z_total = truem_yz[(0, z)] + truem_yz[(1, z)]
            f_z[z] = (- loss_yz[(0, 0)] / z0_total + loss_yz[(1, 0)] / z0_total
                      + loss_yz[(0, z)] / z_total - loss_yz[(1, z)] / z_total)

        return accuracy, loss, n_yz, acc_loss / num_batch, fair_loss / num_batch, f_z


# FedFB with Logistic Regression Model
def FedFB_LR(device,
             global_model,
             algorithm_step_T, num_clients_K, communication_round_I, FL_fraction, FL_drop_rate, local_step_size,
             training_dataloaders,
             training_dataset,
             client_dataset_list
             ):
    """
    Train ``global_model`` with the FedFB (FairBatch-style) federated algorithm.

    At every step:
    - each client takes one SGD step on one mini-batch, with per-sample weights
      ``lbd[(y, z)] / n_z``;
    - the server averages the client parameters, weighting each client by its
      total sample weight;
    - the server evaluates the averaged model on every client's loader to build
      the fairness signal ``f_z``;
    - the server updates the multipliers ``lbd``.

    Args:
        device: device the models and batches are moved to.
        global_model: model returning probabilities of shape ``(B, 1)``.
        algorithm_step_T: number of training steps (rounds).
        num_clients_K: number of clients; indexes ``training_dataloaders``.
        communication_round_I, FL_fraction, FL_drop_rate: unused here.
        local_step_size: SGD learning rate for the local step.
        training_dataloaders: per-client iterables of dict batches with keys
            ``"X"``, ``"y"``, ``"s1"``.
        training_dataset: full training set. It must support ``len()`` and
            expose ``.y`` and ``.s1`` for the group counts.
        client_dataset_list: unused here.

    Returns:
        ``(global_model, local_model_list)``, where ``local_model_list`` holds
        ``num_clients_K`` deep copies of the final global model.
    """
    # Training process
    logger.info("Training process")
    criterion = torch.nn.BCELoss(reduction='none')

    training_dataset_size = len(training_dataset)

    # m_yz[(y, z)]: number of training samples with label y and sensitive attribute z.
    m_yz, lbd = {}, {}
    for y in [0, 1]:
        for z in range(2):
            m_yz[(y, z)] = ((training_dataset.y == y) & (training_dataset.s1 == z)).sum()
    # Initially lbd[(0, z)] = lbd[(1, z)] = P(z), so lbd[(0, z)] + lbd[(1, z)] = 2 P(z).
    for y in [0, 1]:
        for z in range(2):
            lbd[(y, z)] = (m_yz[(1, z)] + m_yz[(0, z)]) / training_dataset_size
    alpha = 0.3
    train_loss = []

    # Parameter Initialization
    global_model.train()
    local_model_list = [copy.deepcopy(global_model) for _ in range(num_clients_K)]

    client_loss_list = [[] for _ in range(num_clients_K)]

    for iter_t in range(algorithm_step_T):
        avg_loss_over_step = 0
        # Per-client total sample weight for THIS step only. It must be reset
        # every step so the aggregation weights line up with this step's models.
        global_nc = []
        for i in range(num_clients_K):
            model = local_model_list[i]
            model.train()
            optimizer = torch.optim.SGD(model.parameters(), lr=local_step_size)
            model.zero_grad()
            model.to(device)

            nc = 0
            for batch in training_dataloaders[i]:
                batch_X = batch["X"].to(device)
                batch_y = batch["y"].reshape(-1, 1).to(device)
                batch_sensitive = batch["s1"].to(device)

                # Flat CPU copies so the group masks are element-wise (B,) masks
                # (a (B, 1) column vs a (B,) row would broadcast to (B, B)) and
                # so the indices live on the same device as `v`.
                flat_y = batch_y.reshape(-1).cpu()
                flat_s = batch_sensitive.reshape(-1).cpu()
                v = torch.ones(len(batch_y)).type(torch.DoubleTensor)
                for y, z in lbd:
                    idx = torch.where((flat_y == y) & (flat_s == z))[0]
                    v[idx] = lbd[(y, z)] / (m_yz[(1, z)] + m_yz[(0, z)])
                    nc += v[idx].sum().item()

                local_prediction = model(batch_X).to(device)
                # (B, 1) weights to match the (B, 1) per-sample BCE losses.
                sample_weights = v.reshape(-1, 1).to(device)
                loss = weighted_loss(criterion, local_prediction, batch_y, sample_weights, mean=False)
                if not np.isnan(loss.item()):
                    loss.backward()

                loss_value = round(float(loss), 4)
                avg_loss_over_step += loss_value * 1 / num_clients_K
                client_loss_list[i].append(loss_value)

                optimizer.step()
                optimizer.zero_grad()
                # Only one local SGD step (one mini-batch) per training step.
                break

            global_nc.append(nc)
        train_loss.append(avg_loss_over_step)

        logger.info(f"########## Step: {iter_t + 1} / {algorithm_step_T}; "
                    f"Avg Loss over Client: {round(float(avg_loss_over_step), 4)} ##########")

        # Aggregation: client parameters weighted by their total sample weight.
        theta_list = [local_model.state_dict() for local_model in local_model_list]
        theta_avg = weighted_average_weights(theta_list, global_nc, sum(global_nc))
        global_model.load_state_dict(theta_avg)

        # Evaluate the aggregated model on every client's data.
        list_acc = []
        # n_yz[(y, z)]: number of samples assigned to class y in sensitive group z.
        n_yz = {(y, z): 0 for z in range(2) for y in [0, 1]}
        f_z = {z: 0 for z in range(2)}
        global_model.to(device)
        global_model.eval()
        for i in range(num_clients_K):
            acc, loss, n_yz_c, acc_loss, fair_loss, f_z_c = FedFB_style_inference(
                device, global_model, training_dataloaders[i], False, m_yz)
            list_acc.append(acc)

            for yz in n_yz:
                n_yz[yz] += n_yz_c[yz]
            for z in range(1, 2):
                f_z[z] += f_z_c[z]

        # Correction term computed from the global counts: add it once. Adding it
        # per client would scale it by num_clients_K and make f_z depend on how
        # the data is partitioned.
        for z in range(1, 2):
            f_z[z] += m_yz[(0, 0)] / (m_yz[(0, 0)] + m_yz[(1, 0)]) - m_yz[(0, z)] / (
                m_yz[(0, z)] + m_yz[(1, z)])

        # Multiplier update. Each pair keeps lbd[(0, z)] + lbd[(1, z)] = 2 P(z)
        # (its value at initialisation), with lbd[(0, z)] clipped to [0, 2 P(z)].
        step_size = alpha / (iter_t + 1) ** .5
        for z in range(2):
            upper = 2 * (m_yz[(1, z)] + m_yz[(0, z)]) / training_dataset_size
            if z == 0:
                lbd[(0, z)] -= step_size * sum(f_z[k] for k in range(1, 2))
            else:
                lbd[(0, z)] += step_size * f_z[z]
            lbd[(0, z)] = lbd[(0, z)].item()
            lbd[(0, z)] = max(0, min(lbd[(0, z)], upper))
            lbd[(1, z)] = upper - lbd[(0, z)]

        local_model_list = [copy.deepcopy(global_model) for _ in range(num_clients_K)]

    logger.info("Training finish, return global model and local model list")
    return global_model, local_model_list
