import torch
import numpy as np
import copy
import gc

from tool.logger import *
from tool.utils import get_parameters, set_parameters
from algorithm.client_selection import client_selection


def get_full_dataset_acc(device, model, full_dataloader):
    """Return the binary-classification accuracy of ``model`` over a whole dataloader.

    Each batch is a mapping with an ``"X"`` feature tensor and a ``"y"`` label
    tensor. A sample is predicted positive when ``model(X) >= 0.5``.

    Args:
        device: Device the batch tensors (and model output) are moved to.
        model: Callable mapping ``X`` to one score per sample (presumably a
            sigmoid probability -- confirm for new models).
        full_dataloader: Iterable of batches as described above.

    Returns:
        A 0-d float tensor equal to ``correct / total``.

    Raises:
        ZeroDivisionError: If the dataloader yields no samples.
    """
    correct = 0
    total = 0

    # Evaluation only: no autograd graph is needed for these forward passes.
    with torch.no_grad():
        for batch in full_dataloader:
            X = batch["X"].to(device)
            # Flatten labels so both (N,) and (N, 1) layouts compare element-wise
            # with the (N,) predictions instead of broadcasting to (N, N).
            y = batch["y"].to(device).reshape(-1)
            prediction = (model(X).to(device) >= 0.5).reshape(-1)
            correct += prediction.eq(y).sum()
            total += X.shape[0]

    return correct / total


def get_full_dataset_F(device, model, full_dataloader, pr_y1_s0, pr_y1_s1, mask_s1_flag=False):
    """Return this client's local equal-opportunity component ``F``.

    Over all n_k samples yielded by ``full_dataloader`` it computes

        A = Pr(ŷ=1, s=0, y=1 | client) / pr_y1_s0
        B = Pr(ŷ=1, s=1, y=1 | client) / pr_y1_s1
        F = A - B

    i.e. ``Pr(ŷ=1 | s, y=1) * Pr(s, y=1) / pr_y1_s`` with the conditional
    probability expanded. A sample is predicted positive when ``model(X) >= 0.5``.

    Args:
        device: Device the batch tensors (and model output) are moved to.
        model: Callable returning one score per sample for ``X``.
        full_dataloader: Iterable of batches with keys ``"X"``, ``"y"`` and the
            sensitive attribute ``"s1"`` (or ``"s2"``, see ``mask_s1_flag``).
        pr_y1_s0: Normaliser for the s=0 term; presumably the global
            Pr(y=1, s=0) -- confirm against the caller.
        pr_y1_s1: Normaliser for the s=1 term; presumably the global
            Pr(y=1, s=1) -- confirm against the caller.
        mask_s1_flag: If True, read the sensitive attribute from ``"s2"``
            instead of ``"s1"``.

    Returns:
        A 0-d float tensor ``F``.

    Raises:
        ValueError: If ``full_dataloader`` yields no samples.
    """
    s_key = "s2" if mask_s1_flag else "s1"

    n_total = 0
    n_s0_pred1_y1 = 0
    n_s1_pred1_y1 = 0

    # Evaluation only: no autograd graph is needed for these forward passes.
    with torch.no_grad():
        for batch in full_dataloader:
            X = batch["X"].to(device)
            # Flatten labels and attributes so (N,) and (N, 1) layouts both line up
            # element-wise with the (N,) predictions (no (N, N) broadcasting).
            y = batch["y"].to(device).reshape(-1)
            s = batch[s_key].to(device).reshape(-1)
            prediction = (model(X).to(device) >= 0.5).reshape(-1)

            pred1_y1 = prediction & (y == 1)
            n_s0_pred1_y1 += (pred1_y1 & (s == 0)).sum()
            n_s1_pred1_y1 += (pred1_y1 & (s == 1)).sum()
            n_total += X.shape[0]

    if n_total == 0:
        raise ValueError("full_dataloader yielded no samples")

    # Joint probabilities over the whole client dataset (not per batch).
    # Using the joint term directly equals Pr(ŷ=1|s,y=1) * Pr(s,y=1), but yields 0
    # instead of a 0/0 NaN when a group has no positive samples on this client.
    pr_pred1_s0_y1 = n_s0_pred1_y1 / n_total
    pr_pred1_s1_y1 = n_s1_pred1_y1 / n_total

    A = pr_pred1_s0_y1 / pr_y1_s0
    B = pr_pred1_s1_y1 / pr_y1_s1
    return A - B


# FairFed with Logistic Regression Model
def FairFed_LR(device,
               global_model,
               algorithm_step_T, num_clients_K, communication_round_I, FL_fraction, FL_drop_rate, local_step_size,
               training_dataloaders,
               training_dataset,
               client_dataset_list
               ):
    """Train ``global_model`` with FairFed-style reweighted federated averaging.

    Each of the ``algorithm_step_T`` steps:
      1. every selected client takes one SGD step on the first minibatch its
         dataloader yields;
      2. each selected client's accuracy on its whole local dataloader is
         compared with the size-weighted mean accuracy of the selected clients,
         and the aggregation weights are updated as
         ``w_bar_k <- w_bar_k - beta * (delta_k - mean(delta))`` then normalised
         (the in-code comments refer to this as Eq. 6);
      3. every ``communication_round_I`` steps, the selected local models are
         averaged with those weights into ``global_model``, copied back to all
         clients, and a new client subset is selected.

    Args:
        device: Device models and batches are moved to.
        global_model: Model whose parameters are read/written through the
            project helpers ``get_parameters`` / ``set_parameters``.
        algorithm_step_T: Total number of local steps.
        num_clients_K: Number of clients.
        communication_round_I: Aggregate every this many steps.
        FL_fraction: Forwarded to ``client_selection`` as ``fraction``.
        FL_drop_rate: Forwarded to ``client_selection`` as ``drop_rate``.
        local_step_size: SGD learning rate for the local step.
        training_dataloaders: Per-client dataloaders indexed by client id; each
            batch is a mapping with ``"X"`` and ``"y"``.
        training_dataset: Only its length (n) is used.
        client_dataset_list: Only the length of each entry (n_k) is used.

    Returns:
        ``(global_model, local_model_list)`` after the last step.
    """
    # Training process
    logger.info("Training process")
    criterion = torch.nn.BCELoss(reduction='mean')

    training_dataset_size = len(training_dataset)
    client_datasets_size_list = [len(item) for item in client_dataset_list]
    # n_k / n for every client; stays constant for the whole run.
    basic_average_weight = np.array([float(size / training_dataset_size) for size in client_datasets_size_list])
    # Independent copies: both arrays are mutated in place below. Sharing storage
    # with basic_average_weight (as plain assignment would) corrupts the n_k / n
    # weights and lets normalised weights overwrite w_bar mid-normalisation.
    average_weight = basic_average_weight.copy()
    w_bar_k_t_array = basic_average_weight.copy()

    # Step size of the weight update (β in Eq. 6).
    beta = 1

    global_model.train()
    local_model_list = [copy.deepcopy(global_model) for _ in range(num_clients_K)]

    client_loss_list = [[] for _ in range(num_clients_K)]

    # Client selection
    logger.info("********** Client selection **********")
    idxs_users = client_selection(
        client_num=num_clients_K,
        fraction=FL_fraction,
        dataset_size=training_dataset_size,
        client_dataset_size_list=client_datasets_size_list,
        drop_rate=FL_drop_rate,
        style="FedAvg",
    )
    logger.info(f"********** Select client list: {idxs_users} **********")

    for iter_t in range(algorithm_step_T):
        # Simulate the client parallel process
        avg_loss_over_step = 0

        client_acc_list = [0 for _ in range(num_clients_K)]
        for i in idxs_users:
            model = local_model_list[i]
            model.train()
            model.to(device)
            optimizer = torch.optim.SGD(model.parameters(), lr=local_step_size)
            model.zero_grad()
            client_i_dataloader = training_dataloaders[i]

            # Local optimisation by SGD: only the first minibatch is used per step
            # (note the `break`); presumably intentional -- confirm.
            for batch in client_i_dataloader:
                X = batch["X"].to(device)
                y = batch["y"].reshape(-1, 1).to(device)
                local_prediction = model(X).to(device)

                loss = criterion(local_prediction, y.float())
                loss.backward()

                tmp = round(float(loss), 4)
                avg_loss_over_step += tmp
                client_loss_list[i].append(tmp)

                optimizer.step()
                optimizer.zero_grad()
                break

            # Plain float so it can be written into the NumPy weight arrays below
            # (a CUDA tensor cannot be stored in a NumPy array).
            client_acc_list[i] = float(get_full_dataset_acc(device, model, client_i_dataloader))

        # Size-weighted mean accuracy of the selected clients, normalised by their
        # total weight so it is a true average under partial participation
        # (unchanged from the plain weighted sum when every client is selected).
        selected_weight = sum(basic_average_weight[i] for i in idxs_users)
        global_acc = sum(basic_average_weight[i] * client_acc_list[i] for i in idxs_users) / selected_weight

        # △_k^t in Eq.6
        delta_k_list = [0 for _ in range(num_clients_K)]
        for i in idxs_users:
            delta_k_list[i] = abs(client_acc_list[i] - global_acc)
        delta_k_list_mean = sum(delta_k_list[i] for i in idxs_users) / len(idxs_users)

        # w_bar_k^t in Eq.6
        for i in idxs_users:
            w_bar_k_t_array[i] = w_bar_k_t_array[i] - beta * (delta_k_list[i] - delta_k_list_mean)

        # w_k^t in Eq.6: normalise by the sum over all clients' w_bar.
        w_bar_sum = w_bar_k_t_array.sum()
        for i in idxs_users:
            average_weight[i] = w_bar_k_t_array[i] / w_bar_sum

        logger.info(f"########## Step: {iter_t + 1} / {algorithm_step_T}; "
                    f"Avg Loss over Client: {round(float(avg_loss_over_step/len(idxs_users)), 4)} ##########")

        # Communicate
        if (iter_t + 1) % communication_round_I == 0:
            logger.info(f"********** Communicate: {(iter_t + 1) / communication_round_I} **********")

            # Global operation
            logger.info("********** Parameter aggregation **********")
            theta_list = np.array(
                [get_parameters(local_model_list[client_id]) for client_id in idxs_users],
                dtype=object,
            )
            theta_avg = np.average(theta_list, axis=0, weights=[average_weight[j] for j in idxs_users]).tolist()
            global_model = set_parameters(global_model, theta_avg)

            # Parameter Distribution
            logger.info("********** Parameter distribution **********")
            local_model_list = [copy.deepcopy(global_model) for _ in range(num_clients_K)]

            # Client Reselection
            logger.info("********** Client selection **********")
            idxs_users = client_selection(
                client_num=num_clients_K,
                fraction=FL_fraction,
                dataset_size=training_dataset_size,
                client_dataset_size_list=client_datasets_size_list,
                drop_rate=FL_drop_rate,
                style="FedAvg",
            )
            logger.info(f"********** Select client list: {idxs_users} **********")

    logger.info("Training finish, return global model and local model list")
    return global_model, local_model_list

