import torch
import numpy as np

from tool.logger import *


def _subgroup_mean_loss(client_X, client_y, mask, client_model, criterion, device, hypothesis):
    """Return the mean per-sample loss over the samples selected by ``mask`` (0 if none are selected)."""
    # Flatten so that labels stored as (n, 1) still select rows of X.
    mask = mask.reshape(-1)
    X_group = client_X[mask].to(device)
    y_group = client_y[mask].to(device)
    m = int(mask.sum())

    group_loss = 0
    # An empty subgroup never enters the loop and therefore contributes 0;
    # the previous torch.stack([]) raised a RuntimeError in that case.
    for X, y in zip(X_group, y_group):
        prediction = client_model(X).to(device)
        if "LR" in hypothesis:
            loss = criterion(prediction, y.reshape(-1).float())
        else:
            loss = criterion(prediction, y.long())
        group_loss += loss / m
    return group_loss


def D_hat_θ(client_dataset, mask_s1_flag, client_model, criterion, device, hypothesis):
    """Compute the client fairness gap D_hat_θ of Eq. 7.

    D_hat_θ = L_hat_{a,c} - L_hat_{b,c}, where a = 1 and b = 0 are the values of the
    sensitive attribute. For each group g, L_hat_{g,c} is the sum over the labels
    c in {0, 1} of the mean per-sample loss of the samples with s == g and y == c.

    Args:
        client_dataset: mapping with tensors under "X", "y" and "s1"/"s2"
            (only the sensitive-attribute key selected by ``mask_s1_flag`` is read).
        mask_s1_flag: if True, use "s2" as the sensitive attribute, otherwise "s1".
        client_model: callable applied to one sample at a time.
        criterion: loss called as ``criterion(prediction, target)``.
        device: device each subgroup and prediction is moved to.
        hypothesis: if it contains "LR", targets are passed as a flat float
            tensor; otherwise as a long tensor.

    Returns:
        The gap as a scalar tensor that keeps the autograd graph of the
        per-sample losses (plain 0 if every subgroup is empty). A subgroup
        with no samples contributes 0 to its group's loss.
    """
    client_X = client_dataset["X"].detach()
    client_y = client_dataset["y"]
    client_s = client_dataset["s2"] if mask_s1_flag else client_dataset["s1"]

    a, b = 1, 0
    c0, c1 = (client_y == 0), (client_y == 1)
    sa, sb = (client_s == a), (client_s == b)

    def group_loss(mask):
        return _subgroup_mean_loss(client_X, client_y, mask, client_model, criterion, device, hypothesis)

    # Evaluated in the same order as before (a/c0, a/c1, b/c0, b/c1), which
    # matters when the model is stochastic in train mode (e.g. dropout).
    L_hat_ac0 = group_loss(sa & c0)
    L_hat_ac1 = group_loss(sa & c1)
    L_hat_bc0 = group_loss(sb & c0)
    L_hat_bc1 = group_loss(sb & c1)

    L_hat_ac = L_hat_ac0 + L_hat_ac1
    L_hat_bc = L_hat_bc0 + L_hat_bc1
    return L_hat_ac - L_hat_bc


# NOTE: Not yet debugged — please do not use this yet.
def Fed_Fair_NN(device,
                global_model,
                algorithm_step_T, num_clients_N,
                training_dataloaders,
                training_dataset,
                client_dataset_list,
                ϵ
                ):
    """Train ``global_model`` with the federated fairness primal-dual procedure.

    Each outer iteration computes, for every client, the local empirical risk
    (Eq. 4) and the fairness gap ``D_hat_θ`` (Eq. 7), combines them into the
    per-client objective of Eq. 11, takes one SGD step on the sum of those
    objectives (Eq. 10), then updates the dual variables λ_a, λ_b by Eq. 12 & 13
    (clipped at 0).

    Args:
        device: device the model and batches are moved to.
        global_model: model shared by all clients; updated in place and returned.
        algorithm_step_T: number of outer iterations.
        num_clients_N: number of clients.
        training_dataloaders: per-client iterables of batches with "X" and "y".
        training_dataset: indexed with a client's ``indices`` to obtain the
            dict consumed by ``D_hat_θ`` ("X", "y", "s1").
        client_dataset_list: per-client subsets exposing ``.indices``
            (presumably ``torch.utils.data.Subset`` — confirm). Not modified.
        ϵ: fairness tolerance used in the dual updates.

    Returns:
        ``global_model``. If an exception is raised, it is logged and the model
        is returned in whatever state training had reached.
    """
    import math
    import traceback

    try:
        # Materialize each client's data as a dict. Build a local list rather
        # than overwriting client_dataset_list in place: the in-place version
        # made a second call fail on `.indices` (then hidden by the except).
        client_datasets = []
        client_datasets_size_list = []
        for i in range(num_clients_N):
            client_datasets_indices = client_dataset_list[i].indices
            client_datasets.append(training_dataset[client_datasets_indices])
            client_datasets_size_list.append(len(client_datasets_indices))

        m_i_list = torch.tensor(client_datasets_size_list)
        m_total = sum(client_datasets_size_list)

        # Training process
        logger.info("Training process")

        # α: SGD step size for the model parameters (decayed below).
        # β: step size of the dual updates.
        # γ: weight of the -(γ/2)(λ_a² + λ_b²) regularizer (Eq. 9), which
        #    produces the (1 - γβ) shrink factor in Eq. 12 & 13.
        α = 0.05
        β = 0.05
        γ = 0.001
        λ_a, λ_b = 0.15, 1

        # Handed to D_hat_θ, which evaluates one sample at a time.
        criterion = torch.nn.CrossEntropyLoss()
        # Summed over each batch so that dividing the client total by m_i gives
        # the per-sample average of Eq. 4. Summing mean-reduced batch losses
        # and then dividing by m_i shrank the risk by about the batch size.
        risk_criterion = torch.nn.CrossEntropyLoss(reduction="sum")

        for iter_t in range(algorithm_step_T):
            if iter_t % 20000 == 0 and iter_t != 0:
                α = 0.1 * α
            client_loss_list = []
            client_D_hat_list = []

            # Simulate the clients sequentially.
            for i in range(num_clients_N):
                # Algorithm 1 has no local update, so every client evaluates
                # the current global model.
                client_model = global_model.to(device)
                client_model.train()

                client_i_loss = 0
                for batch in training_dataloaders[i]:
                    X = batch["X"].to(device)
                    y = batch["y"].to(device)
                    local_prediction = client_model(X).to(device)
                    client_i_loss += risk_criterion(local_prediction, y.long())

                # L_hat_i_θ in Eq. 4
                client_i_loss = client_i_loss / m_i_list[i]
                D_hat_i_θ = D_hat_θ(client_datasets[i], False, client_model, criterion, device, "NN")

                # Equation 11
                first_term_in_Eq11 = (m_i_list[i] / m_total) * client_i_loss
                second_term_in_Eq11 = (λ_a - λ_b) * D_hat_i_θ / num_clients_N
                client_i_loss = first_term_in_Eq11 + second_term_in_Eq11
                # A NaN client objective is replaced by 0 so it does not poison the global loss.
                client_i_loss = torch.where(torch.isnan(client_i_loss), torch.full_like(client_i_loss, 0),
                                            client_i_loss)
                client_loss_list.append(client_i_loss)
                client_D_hat_list.append(D_hat_i_θ)

            # Parameter update by Equation 10
            optimizer = torch.optim.SGD(global_model.parameters(), lr=α)
            # Clear gradients from the previous iteration; otherwise backward()
            # accumulates into .grad and each step applies the sum of all past
            # gradients.
            optimizer.zero_grad()
            global_loss = sum(client_loss_list)
            global_loss.backward()
            optimizer.step()

            # Equations 12 & 13 (dual ascent on λ_a, λ_b, clipped at 0).
            shrink = 1 - γ * β
            # Only the values of the gaps feed the dual update; float() detaches
            # them from the autograd graph.
            D_hat_sum = sum(float(D_hat) for D_hat in client_D_hat_list)
            second_term_in_Eq12_Eq13 = (β / num_clients_N) * D_hat_sum
            # A non-finite gap (inf, or NaN which would otherwise make λ NaN
            # permanently) contributes nothing to this update.
            if not math.isfinite(second_term_in_Eq12_Eq13):
                second_term_in_Eq12_Eq13 = 0.0
            third_term_in_Eq12_Eq13 = β * ϵ

            λ_a = max(shrink * λ_a + second_term_in_Eq12_Eq13 - third_term_in_Eq12_Eq13, 0)
            λ_b = max(shrink * λ_b - second_term_in_Eq12_Eq13 - third_term_in_Eq12_Eq13, 0)

        logger.info("Training finish, return global model")
        return global_model
    except Exception:
        # Best-effort: still return the model, but do not hide the failure.
        logger.error(f"Fed_Fair_NN aborted; returning the current global model.\n{traceback.format_exc()}")
        return global_model
