import time

import torch
import numpy as np

from tool.logger import *


def _subgroup_mean_loss(X_group, y_group, client_model, criterion, device, hypothesis):
    """Return the mean loss of ``client_model`` over one (sensitive value, label) subgroup.

    Samples are passed through the model one at a time and ``loss / m`` is
    accumulated, ``m`` being the subgroup size. An empty subgroup contributes
    the integer ``0``.
    """
    m = len(X_group)
    total = 0
    for X, y in zip(X_group, y_group):
        prediction = client_model(X).to(device)
        if "LR" in hypothesis:
            # Binary criterion: float target with the same (flattened) shape as the output.
            loss = criterion(prediction, y.reshape(-1).float())
        else:
            # Otherwise the target is passed as an integer class index.
            loss = criterion(prediction, y.long())
        total += loss * 1 / m
    return total


def D_hat_θ(client_dataset, mask_s1_flag, client_model, criterion, device, hypothesis):
    """Compute the fairness gap D_hat_θ of Eq. 7 on one client's data.

    D_hat_θ = L_hat_{a,c} - L_hat_{b,c}, where a = 1 and b = 0 are values of the
    sensitive attribute and L_hat_{s,c} is the sum, over labels c in {0, 1}, of the
    mean loss on the samples having attribute value s and label c.

    Args:
        client_dataset: dict with "X" (a tensor, one row per sample) and "y", "s1",
            "s2" (labels / sensitive attributes, tensors or arrays with one entry
            per sample).
        mask_s1_flag: if True the sensitive attribute is read from "s2",
            otherwise from "s1".
        client_model: model evaluated sample by sample.
        criterion: per-sample loss (``Fed_Fair_LR`` passes
            ``BCELoss(reduction="none")``).
        device: device the subgroup tensors are moved to.
        hypothesis: model family name; if it contains "LR" the target is cast
            to float, otherwise to long.

    Returns:
        ``L_hat_ac - L_hat_bc`` (a tensor, or the integer 0 when every subgroup
        is empty). Empty subgroups contribute 0.
    """
    client_X = client_dataset["X"].detach()
    # Convert once up front so each subgroup is a single boolean-mask index
    # (rebuilding the full label tensor per sample was O(n^2)).
    client_y = torch.as_tensor(client_dataset["y"])
    client_s = torch.as_tensor(client_dataset["s2"] if mask_s1_flag else client_dataset["s1"])

    a, b = 1, 0  # values of the sensitive attribute

    # Flattened 1-D boolean masks, kept on the CPU and moved next to whatever they index.
    y_flat = client_y.reshape(-1).cpu()
    s_flat = client_s.reshape(-1).cpu()
    c0, c1 = y_flat == 0, y_flat == 1
    sa, sb = s_flat == a, s_flat == b

    def group_loss(mask):
        X_group = client_X[mask.to(client_X.device)].to(device)
        y_group = client_y[mask.to(client_y.device)].to(device)
        return _subgroup_mean_loss(X_group, y_group, client_model, criterion, device, hypothesis)

    # Subgroups are evaluated in the order (a,c0), (a,c1), (b,c0), (b,c1).
    L_hat_ac = group_loss(sa & c0) + group_loss(sa & c1)
    L_hat_bc = group_loss(sb & c0) + group_loss(sb & c1)
    return L_hat_ac - L_hat_bc


def Fed_Fair_LR(device,
                global_model,
                algorithm_step_T, num_clients_N,
                training_dataloaders,
                training_dataset,
                client_dataset_list,
                ϵ
                ):
    """Train ``global_model`` with the federated fair logistic-regression scheme (Eqs. 10-13).

    At every global step each (sequentially simulated) client computes its
    size-weighted empirical loss plus its share ``(λ_a - λ_b) * D_hat_i_θ / N``
    of the fairness penalty (Eq. 11). The summed objective is minimised with one
    SGD step (Eq. 10), then the multipliers λ_a and λ_b are updated by
    Eqs. 12 & 13 and clipped at 0.

    Args:
        device: device the model and batches are moved to.
        global_model: model to train; updated in place and also returned.
        algorithm_step_T: number of global steps.
        num_clients_N: number of clients.
        training_dataloaders: per-client dataloaders yielding dicts with "X" and "y".
        training_dataset: dataset indexed by a client subset's ``indices``; the
            result is the dict consumed by ``D_hat_θ``.
        client_dataset_list: per-client subsets exposing ``.indices``. NOTE: each
            entry is replaced in place by ``training_dataset[indices]``.
        ϵ: fairness tolerance used in Eqs. 12 & 13.

    Returns:
        The trained ``global_model``.
    """
    client_datasets_size_list = []
    for i in range(num_clients_N):
        client_datasets_indices = client_dataset_list[i].indices
        client_datasets_size_list.append(len(client_datasets_indices))
        # NOTE: mutates the caller's list (kept for backward compatibility).
        client_dataset_list[i] = training_dataset[client_datasets_indices]

    m_i_list = torch.tensor(client_datasets_size_list)
    m_total = sum(client_datasets_size_list)

    logger.info("Training process")

    α = 0.05   # SGD learning rate, divided by 10 every 20000 steps
    β = 0.05   # step size of the multiplier updates (Eqs. 12 & 13)
    γ = 0.001  # shrinkage factor on the multipliers in Eqs. 12 & 13
    λ_a, λ_b = 0.15, 1

    # Per-sample losses (reduction="none"); averaging is done manually below.
    criterion = torch.nn.BCELoss(reduction="none")
    optimizer = torch.optim.SGD(global_model.parameters(), lr=α)
    for iter_t in range(algorithm_step_T):
        if iter_t % 20000 == 0 and iter_t != 0:
            α = 0.1 * α
            # The optimizer keeps its own copy of the rate; push the decayed value
            # into it, otherwise the decay has no effect.
            for param_group in optimizer.param_groups:
                param_group["lr"] = α

        client_loss_list = []
        client_D_hat_list = []

        # Clients are simulated sequentially on the shared global model.
        for i in range(num_clients_N):
            print("Client :", i)
            client_i_dataloader = training_dataloaders[i]
            # No local update happens before aggregation, so the local model is the global one.
            client_model = global_model.to(device)
            client_i_loss = 0
            client_model.train()
            for batch in client_i_dataloader:
                X = batch["X"].to(device)
                y = batch["y"].reshape(-1, 1).to(device)
                try:
                    local_prediction = client_model(X).to(device)
                    loss = criterion(local_prediction, y.float())
                    client_i_loss += loss.sum()
                except RuntimeError:
                    # Best-effort: a failing batch is skipped (its samples still
                    # count in m_i below).
                    logger.info("Something wrong happen in inference. Skipping this batch of data !")
            # L_hat_i_θ in Eq. 3: summed per-sample loss divided by the client size.
            client_i_loss = client_i_loss / m_i_list[i]

            D_hat_i_θ = D_hat_θ(client_dataset_list[i], False, client_model, criterion, device, "LR")
            # Eq. 11: size-weighted client loss plus this client's share of the fairness penalty.
            first_term_in_Eq11 = (m_i_list[i] / m_total) * client_i_loss
            second_term_in_Eq11 = (λ_a - λ_b) * D_hat_i_θ / num_clients_N
            client_loss_list.append(first_term_in_Eq11 + second_term_in_Eq11)
            client_D_hat_list.append(D_hat_i_θ)

        # Eq. 10: one gradient step on the summed client objectives.
        global_loss = sum(client_loss_list).squeeze().sum()
        optimizer.zero_grad()  # clear any stale gradients before backpropagating
        global_loss.backward()
        optimizer.step()

        # Eqs. 12 & 13: multiplier updates, clipped at 0.
        decay = 1 - γ * β
        # float() detaches each D_hat from the graph and also accepts a client
        # whose D_hat is the plain integer 0 (all of its subgroups empty).
        D_hat_sum = sum(float(d) for d in client_D_hat_list)
        second_term_in_Eq12_Eq_13 = (β / num_clients_N) * D_hat_sum
        third_term_in_Eq12_Eq_13 = β * ϵ

        eq_12 = decay * λ_a + second_term_in_Eq12_Eq_13 - third_term_in_Eq12_Eq_13
        eq_13 = decay * λ_b - second_term_in_Eq12_Eq_13 - third_term_in_Eq12_Eq_13
        λ_a = max(eq_12, 0)
        λ_b = max(eq_13, 0)

        # The objective already averages over clients (Eq. 11), so it is not
        # divided by the number of clients again.
        logger.info(f"########## Step: {iter_t + 1} / {algorithm_step_T}; "
                    f"Avg Loss over Client: {round(float(global_loss), 4)}; "
                    f"λ_a: {round(float(λ_a), 4)}; "
                    f"λ_b: {round(float(λ_b), 4)} ##########")

    logger.info("Training finish, return global model")
    return global_model