import torch
import numpy as np
from tool.logger import *


def _d_hat_group_loss(X_group, y_group, client_model, criterion, device, hypothesis):
    """Return the mean per-sample loss of one (sensitive attribute, label) group, or 0 if empty.

    Each sample is passed through ``client_model`` on its own and its loss is accumulated with
    weight ``1 / len(X_group)``.
    """
    count = len(X_group)
    group_loss = 0
    for X, y in zip(X_group, y_group):
        prediction = client_model(X).to(device)
        if "LR" in hypothesis:
            # "LR" path: float targets flattened to 1-D.
            loss = criterion(prediction, y.reshape(-1).float())
        else:
            # Other hypotheses: integer class targets.
            loss = criterion(prediction, y.long())
        group_loss += loss * 1 / count
    return group_loss


def D_hat_θ(client_dataset, mask_s1_flag, client_model, criterion, device, hypothesis):
    """Compute a client's summed group loss and its fairness gap D_hat_θ (Eq. 7).

    Samples are split by the sensitive attribute s (value a=1 vs. b=0) and by the label
    y (0 vs. 1). The per-sample loss is averaged within each of the four groups, giving
    L_hat_a,c = L_hat_a,0 + L_hat_a,1 and L_hat_b,c = L_hat_b,0 + L_hat_b,1.
    According to Eq. 7, D_hat_θ = L_hat_a,c - L_hat_b,c. Samples whose s or y is neither
    0 nor 1 are ignored; an empty group contributes 0.

    Args:
        client_dataset: mapping with a tensor ``"X"`` plus ``"y"``, ``"s1"`` and ``"s2"``.
        mask_s1_flag: if True the sensitive attribute is ``"s2"``, otherwise ``"s1"``.
        client_model: callable applied to one sample at a time.
        criterion: loss function called as ``criterion(prediction, target)``.
        device: device the group tensors and the predictions are moved to.
        hypothesis: if it contains ``"LR"`` the targets are passed as flattened floats,
            otherwise as longs.

    Returns:
        ``(L_hat_a,c + L_hat_b,c, L_hat_a,c - L_hat_b,c)``; the second element is D_hat_θ.
    """
    client_X = client_dataset["X"].detach()
    client_y = client_dataset["y"]
    client_s = client_dataset["s2"] if mask_s1_flag else client_dataset["s1"]
    # Convert the labels to a tensor once instead of once per sample.
    y_all = torch.as_tensor(client_y)

    a, b = 1, 0
    c0, c1 = (client_y == 0), (client_y == 1)
    sa, sb = (client_s == a), (client_s == b)

    def group_loss(mask):
        # Select the group's rows (in original order) with a flat boolean mask placed on the
        # same device as the tensor being indexed.
        mask = torch.as_tensor(mask, dtype=torch.bool).reshape(-1)
        X_group = client_X[mask.to(client_X.device)].to(device)
        y_group = y_all[mask.to(y_all.device)].to(device)
        return _d_hat_group_loss(X_group, y_group, client_model, criterion, device, hypothesis)

    L_hat_ac = group_loss(sa * c0) + group_loss(sa * c1)
    L_hat_bc = group_loss(sb * c0) + group_loss(sb * c1)

    return L_hat_ac + L_hat_bc, L_hat_ac - L_hat_bc


def LCO_LR(device,
           global_model,
           algorithm_step_T, num_clients_N,
           training_dataloaders,
           training_dataset,
           client_dataset_list,
           ϵ
           ):
    """Train ``global_model`` with the Lagrangian fairness-constrained scheme (LR hypothesis).

    Each step:
    - every client evaluates its group loss and fairness gap via ``D_hat_θ`` on the current
      global model (Algorithm 1 has no local update);
    - the Eq. 11 client objectives are summed and one SGD step is taken (Eq. 10);
    - each client's dual variables λ_a, λ_b are updated from their own previous values
      (Eq. 12 & 13).

    Args:
        device: torch device for the model and data.
        global_model: model to train; it is updated in place and returned.
        algorithm_step_T: number of training steps.
        num_clients_N: number of clients.
        training_dataloaders: unused; kept for interface compatibility.
        training_dataset: indexed as ``training_dataset[indices]``; the result is passed to
            ``D_hat_θ`` (presumably a dict with "X", "y", "s1", "s2" — TODO confirm).
        client_dataset_list: per-client subsets exposing ``.indices``. Not modified.
        ϵ: fairness slack used in Eq. 12 & 13.

    Returns:
        ``global_model``. On any exception, the error and its traceback are logged and the
        (possibly partially trained) model is still returned.
    """
    import math
    import traceback

    try:
        # Look up every client's samples once. Build a new list rather than overwriting the
        # entries of client_dataset_list, which belongs to the caller.
        client_datasets = []
        client_datasets_size_list = []
        for i in range(num_clients_N):
            client_datasets_indices = client_dataset_list[i].indices
            client_datasets.append(training_dataset[client_datasets_indices])
            client_datasets_size_list.append(len(client_datasets_indices))

        m_i_list = torch.tensor(client_datasets_size_list)
        m_total = sum(client_datasets_size_list)

        # Training process
        logger.info("Training process")

        # Parameter initialization: α is the model learning rate, β the dual step size and
        # γ the dual decay coefficient in Eq. 12 & 13.
        α = 0.05
        β = 0.05
        γ = 0.001

        # Keep per-sample losses (reduction="none"); D_hat_θ averages them manually.
        criterion = torch.nn.BCELoss(reduction="none")

        # Algorithm 1 has no local update, so every client's model is the global model.
        client_model = global_model.to(device)
        client_model.train()
        # Vanilla SGD keeps no per-parameter state, so changing lr in place on decay is
        # equivalent to building a fresh optimizer every step.
        optimizer = torch.optim.SGD(global_model.parameters(), lr=α)

        # One pair of dual variables per client.
        λ_a_list = [0.15] * num_clients_N
        λ_b_list = [1] * num_clients_N
        dual_decay = 1 - γ * β
        slack_term = β * ϵ

        for iter_t in range(algorithm_step_T):
            # Decay the learning rate by 10x every 20000 steps.
            if iter_t % 20000 == 0 and iter_t != 0:
                α = 0.1 * α
                for param_group in optimizer.param_groups:
                    param_group["lr"] = α

            client_loss_list = []
            client_D_hat_list = []

            # Simulate the clients' parallel computation sequentially.
            for i in range(num_clients_N):
                client_i_loss, D_hat_i_θ = D_hat_θ(client_datasets[i], False, client_model, criterion, device, "LR")
                # Equation 11: size-weighted client loss plus the Lagrangian fairness term.
                λ_a, λ_b = λ_a_list[i], λ_b_list[i]
                first_term_in_Eq11 = (m_i_list[i] / m_total) * client_i_loss
                second_term_in_Eq11 = (λ_a - λ_b) * D_hat_i_θ
                client_i_loss = first_term_in_Eq11 + second_term_in_Eq11
                # Replace NaN loss values with 0 before they are summed into the global loss.
                client_i_loss = torch.where(torch.isnan(client_i_loss), torch.full_like(client_i_loss, 0), client_i_loss)
                client_loss_list.append(client_i_loss)
                client_D_hat_list.append(D_hat_i_θ)

            # Parameter update by Equation 10
            global_loss = sum(client_loss_list)
            global_loss.backward()
            optimizer.step()
            optimizer.zero_grad()

            # Equation 12 & 13: projected dual update of each client's own λ_a, λ_b.
            for i in range(num_clients_N):
                gap_term = β * float(client_D_hat_list[i])
                # An infinite gap contributes nothing to the dual update.
                if math.isinf(gap_term):
                    gap_term = 0.0
                λ_a_list[i] = max(dual_decay * λ_a_list[i] + gap_term - slack_term, 0)
                λ_b_list[i] = max(dual_decay * λ_b_list[i] - gap_term - slack_term, 0)

            # The loss design already includes averaging, so "Avg Loss over Client" is not
            # divided by the number of clients. The λ values logged are the last client's
            # values used in this step's Eq. 11 (before the dual update above).
            logger.info(f"########## Step: {iter_t + 1} / {algorithm_step_T}; "
                        f"Avg Loss over Client: {round(float(global_loss), 4)}; "
                        f"λ_a: {round(float(λ_a), 4)}; "
                        f"λ_b: {round(float(λ_b), 4)} ##########")

        logger.info("Training finish, return global model")
        return global_model
    except Exception:
        # Best effort: still return the model, but keep the traceback instead of discarding it.
        logger.info("Some error happen in training process, return global model")
        logger.info(traceback.format_exc())
        return global_model
