from algorithm.FederatedRenyi.FederatedRenyi_component import *
from tool.logger import *


# A specific version of Fed_Renyi_LR for observing the effect of acceleration process and analyzing the bias

def Fed_Renyi_LR(device,
                 mask_s1_flag,
                 lamda,
                 global_model,
                 tolerance_τ,
                 algorithm_step_T, num_clients_K, communication_round_I,
                 local_step_size,
                 training_dataloaders,
                 training_dataset,
                 client_dataset_list,
                 straggler_rate_α,
                 rho,
                 γ_k_style,
                 communication_cost_list_list,
                 descending_order_list_list
                 ):
    """Run federated Renyi-fair logistic regression: a synchronous warm-up followed by an asynchronous phase.

    This variant is instrumented to observe the effect of the acceleration
    (straggler approximation) process and to analyse the bias it introduces.

    Synchronous phase (``tolerance_τ`` steps): every client takes one SGD step per
    iteration on the cross-entropy loss plus ``lamda * G``. Every
    ``communication_round_I`` steps the server sets the global parameters to the
    γ_k-weighted sum of the client parameters, recomputes ``global_v`` with
    ``get_argmax_v`` and redistributes the global model to every client. On the last
    synchronous step each client's statistical tuple is collected; these tuples feed
    ``get_similarity_matrix`` for the asynchronous phase.

    Asynchronous phase (``algorithm_step_T - tolerance_τ`` steps): on each
    communication round the clients are split into uploaders (``idxs_users``) and
    stragglers (``straggler_ids``). The statistics/parameters of the stragglers are
    replaced by ``localized_approximation`` before aggregation, and the distance
    between the approximated and stored tuples is logged.

    Simulated times (local compute + 2 x communication + server work) are accumulated
    both for the actual schedule and for a reference schedule that always charges
    ``communication_cost_list_list[0][0]`` as the communication cost.

    Args:
        device: torch device that batches and model outputs are moved to.
        mask_s1_flag: if truthy use ``batch["s2"]`` as the sensitive attribute, otherwise ``batch["s1"]``.
        lamda: weight of the fairness regularization term ``G``.
        global_model: server model; overwritten in place by every aggregation and returned.
        tolerance_τ: number of synchronous steps before the asynchronous phase starts.
        algorithm_step_T: total number of steps (synchronous + asynchronous).
        num_clients_K: number of clients.
        communication_round_I: the server communicates every ``communication_round_I`` steps.
        local_step_size: SGD learning rate used by the clients.
        training_dataloaders: per-client iterables yielding dict batches with keys
            ``"X"``, ``"y"``, ``"s1"``, ``"s2"``; one batch is consumed per step.
        training_dataset: forwarded to the project helpers.
        client_dataset_list: forwarded to the project helpers.
        straggler_rate_α: forwarded to ``get_communication_idxs_list``.
        rho: forwarded to ``get_similarity_matrix``.
        γ_k_style: forwarded to ``initialization`` / ``get_argmax_v``.
        communication_cost_list_list: only element 0 (a per-client cost list) is used;
            its first entry is charged as the maximum/reference cost.
        descending_order_list_list: per-round client orderings forwarded to
            ``get_communication_idxs_list``.

    Returns:
        tuple: ``(global_model, local_model_list)``.
    """

    def _ratio(numerator, denominator):
        # Avoid ZeroDivisionError when no time was recorded (e.g. no communication round happened).
        return numerator / denominator if denominator else float("nan")

    # Initialization
    logger.info("Training process")

    client_datasets_size_list, local_model_list, \
    global_v, r_bar_k_p0_list, r_bar_k_p1_list, \
    γ_k_list, r_hat_p0, r_hat_p1, v_hat_1 = initialization(client_dataset_list, global_model, num_clients_K,
                                                           mask_s1_flag, training_dataset, γ_k_style)
    criterion = torch.nn.BCELoss(reduction='mean')

    # Simulated time consumption
    total_time_consumption = 0
    reference_time_consumption = 0
    total_communication_time_consumption = 0
    reference_communication_time_consumption = 0

    # ------------------------------------------------------------------ Synchronous algorithm
    logger.info("Synchronous Training process")

    # NOTE(review): this list is never reset between communication rounds, so max() below
    # is taken over the whole history rather than the current round -- confirm intended.
    local_time_consumption_list = []
    statistical_tuple_list = []  # [ {j_1(c, p), u_1(c), θ_1}, …, {j_k(c, p), u_k(c), θ_k} ]
    straggler_statistical_tuple_list = [0 for _ in range(num_clients_K)]
    for iter_t in range(tolerance_τ):
        # Accumulated as Python floats so that no autograd graph is kept alive between steps.
        avg_loss_over_step, avg_ce_over_step, avg_regularization_term_over_step = 0.0, 0.0, 0.0
        # Simulate the clients running in parallel
        for i in range(num_clients_K):
            model = local_model_list[i]
            model.train()
            optimizer = torch.optim.SGD(model.parameters(), lr=local_step_size)
            client_i_dataloader = training_dataloaders[i]

            # Start timing the local operation
            start_time = time.time()
            # Local optimization: a single SGD step on the first batch
            for batch in client_i_dataloader:
                X = batch["X"].to(device)
                y = batch["y"].reshape(-1, 1).to(device)
                s = batch["s2"] if mask_s1_flag else batch["s1"]

                local_prediction = model(X).to(device)
                loss = criterion(local_prediction, y.float())
                avg_ce_over_step += float(loss)

                # Hard 0/1 predictions used to estimate the fairness statistics
                y_hat_θ = (local_prediction >= 0.5).reshape(-1).to(device)

                Q, _, _, _, _, _, _ = get_Q_hat_θ(y_hat_θ, s, device)
                G = get_G_hat_θ_hat_v(Q, global_v, device).to(device)
                regularization_term = lamda * G
                avg_regularization_term_over_step += float(regularization_term)

                loss += regularization_term
                avg_loss_over_step += float(loss)

                # Clear gradients left over from previous local steps before back-propagating.
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                break

            # Stop timing the local operation and record it for this client
            used_time = time.time() - start_time

            # Update the local model list
            local_model_list[i] = model

            # On the last synchronous step, collect the statistics the asynchronous phase needs
            if (iter_t == tolerance_τ - 1) and (tolerance_τ != algorithm_step_T):
                preparation_start_time = time.time()

                # (j_bar_0_0, j_bar_0_1, j_bar_1_0, j_bar_1_1, u_bar_0, u_bar_1, np.array(get_parameters(model)))
                client_tuple = get_statistical_tuple(training_dataset, client_dataset_list, i, model, mask_s1_flag,
                                                     "LR", device)

                statistical_tuple_list.append(client_tuple)

                # The preparation time is charged to this client's local time
                used_time += time.time() - preparation_start_time

            local_time_consumption_list.append(used_time)

        logger.info(f"########## Step: {iter_t + 1} / {algorithm_step_T}; "
                    f"Avg CE over Client: {round(float(avg_ce_over_step / num_clients_K), 4)}; "
                    f"Avg RT over Client: {round(float(avg_regularization_term_over_step / num_clients_K), 4)}; "
                    f"Avg Loss over Client: {round(float(avg_loss_over_step / num_clients_K), 4)}; ##########")

        # Communicate
        if (iter_t + 1) % communication_round_I == 0:
            # Simulated communication cost of each client; entry 0 is charged as the maximum
            communication_cost_list = communication_cost_list_list[0]
            maximum_communication_cost = communication_cost_list[0]

            # Start timing the global operation
            start_time = time.time()

            # Parameter aggregation: θ = Σ_k γ_k θ_k
            theta_list = []
            for i in range(num_clients_K):
                γ_k = float(γ_k_list[i])
                theta_list.append(list(γ_k * np.array(get_parameters(local_model_list[i]))))

            theta_list = np.array(theta_list, dtype=object)
            theta_avg = np.sum(theta_list, 0).tolist()
            set_parameters(global_model, theta_avg)

            # Global v update; keep the previous v if the update fails
            backup_v = global_v
            try:
                global_v = get_argmax_v(list(range(num_clients_K)), local_model_list, mask_s1_flag,
                                        training_dataset,
                                        client_dataset_list,
                                        r_hat_p0, r_hat_p1, device, γ_k_style, hypothesis="LR")
            except Exception as err:
                logger.info(f"Global v update failed ({type(err).__name__}: {err}); keeping the previous v")
                global_v = backup_v

            # Stop timing the global operation
            global_operation_used_time = time.time() - start_time

            # Parameter distribution
            local_model_list = [copy.deepcopy(global_model) for _ in range(num_clients_K)]

            # Record the total time consumption
            total_time_consumption += max(local_time_consumption_list) \
                                      + 2 * maximum_communication_cost \
                                      + global_operation_used_time
            total_communication_time_consumption += 2 * maximum_communication_cost

    reference_communication_time_consumption = total_communication_time_consumption

    # Prepare for the asynchronous algorithm (only possible when statistics were collected above)
    similarity_matrix = None
    if statistical_tuple_list:
        preparation_start_time = time.time()
        similarity_matrix = get_similarity_matrix(statistical_tuple_list, lamda, rho)
        total_time_consumption += time.time() - preparation_start_time
    reference_time_consumption = total_time_consumption

    # ------------------------------------------------------------------ Asynchronous algorithm
    logger.info("Asynchronous Training process")
    # Sum, over all asynchronous rounds, of the per-straggler average approximation error
    total_j_distance, total_u_distance, total_θ_distance = 0, 0, 0

    local_time_consumption_list = []
    client_latest_time_stamp = [0 for _ in range(num_clients_K)]

    # Prepare for the bias analysis (used by the commented-out analysis code below)
    j_0_0_discrepancy_list, j_0_1_discrepancy_list, j_1_0_discrepancy_list, j_1_1_discrepancy_list = [], [], [], []
    u_0_discrepancy_list, u_1_discrepancy_list = [], []
    j_0_0_discrepancy_percentage_list, j_0_1_discrepancy_percentage_list = [], []
    j_1_0_discrepancy_percentage_list, j_1_1_discrepancy_percentage_list = [], []
    u_0_discrepancy_percentage_list, u_1_discrepancy_percentage_list = [], []

    for iter_t in range(algorithm_step_T - tolerance_τ):
        # j/u statistics of each client, stored by client id so that a client whose step fails
        # cannot shift the statistics of the clients after it onto the wrong index.
        j_bar_0_0_list = [None] * num_clients_K
        j_bar_0_1_list = [None] * num_clients_K
        j_bar_1_0_list = [None] * num_clients_K
        j_bar_1_1_list = [None] * num_clients_K
        u_bar_0_list = [None] * num_clients_K
        u_bar_1_list = [None] * num_clients_K

        # Simulate the clients running in parallel
        for i in range(num_clients_K):
            model = local_model_list[i]
            model.train()
            optimizer = torch.optim.SGD(model.parameters(), lr=local_step_size)
            client_i_dataloader = training_dataloaders[i]
            # Start timing the local operation
            start_time = time.time()
            # Local optimization: a single SGD step on the first batch that succeeds
            for batch in client_i_dataloader:
                X = batch["X"].to(device)
                y = batch["y"].reshape(-1, 1).to(device)
                s = batch["s2"] if mask_s1_flag else batch["s1"]

                local_prediction = model(X).to(device)
                try:
                    loss = criterion(local_prediction, y.float())

                    y_hat_θ = (local_prediction >= 0.5).reshape(-1).to(device)

                    Q, j_bar_0_0, j_bar_0_1, j_bar_1_0, j_bar_1_1, u_bar_0, u_bar_1 = get_Q_hat_θ(y_hat_θ, s, device)
                    j_bar_0_0_list[i] = j_bar_0_0
                    j_bar_0_1_list[i] = j_bar_0_1
                    j_bar_1_0_list[i] = j_bar_1_0
                    j_bar_1_1_list[i] = j_bar_1_1
                    u_bar_0_list[i] = u_bar_0
                    u_bar_1_list[i] = u_bar_1

                    G = get_G_hat_θ_hat_v(Q, global_v, device).to(device)
                    regularization_term = lamda * G

                    loss += regularization_term
                    # Stragglers keep their local model across rounds, so stale gradients must be cleared.
                    optimizer.zero_grad()
                    loss.backward()
                    optimizer.step()

                    break

                except Exception as err:
                    logger.info("Some Error happen in this batch of data, skipping this batch !")
                    logger.info(f"  Cause: {type(err).__name__}: {err}")
                    continue

            # Stop timing the local operation and record it for this client
            used_time = time.time() - start_time
            local_time_consumption_list.append(used_time)

            # Update the local model list
            local_model_list[i] = model

        # Communicate
        if (iter_t + 1) % communication_round_I == 0:
            logger.info(f"********** Communicate: {math.floor((iter_t + tolerance_τ + 1) / communication_round_I)} **********")

            # Simulated communication cost of each client
            communication_cost_list = communication_cost_list_list[0]
            # descending_order_list = descending_order_list_list[iter_t + tolerance_τ]
            # NOTE(review): the offset below is 0.5 * algorithm_step_T rather than tolerance_τ
            # (cf. the line above) -- confirm this is the intended schedule.
            descending_order_list = descending_order_list_list[
                int((iter_t + 1) / communication_round_I + (0.5 * algorithm_step_T))]
            idxs_users, straggler_ids = get_communication_idxs_list(num_clients_K, straggler_rate_α,
                                                                    descending_order_list)

            logger.info(f" idxs_users: {idxs_users}")
            logger.info(f" straggler_ids: {straggler_ids}")

            real_time_consumption_list = [communication_cost_list[i] for i in idxs_users]
            real_maximum_communication_cost = max(real_time_consumption_list)
            reference_maximum_communication_cost = communication_cost_list[0]

            tilde_list = [0 for _ in range(num_clients_K)]

            # Update the statistical tuple list with the latest uploads of the well-connected clients
            for i in idxs_users:
                # zeta_ζ = iter_t - client_latest_time_stamp[i]
                client_latest_time_stamp[i] = iter_t

                # To recompute the statistics from scratch instead, call:
                # client_tuple = get_statistical_tuple(training_dataset, client_dataset_list, i, model, mask_s1_flag, "NN", device)
                # Otherwise reuse the statistics gathered during the local step:
                client_tuple = (j_bar_0_0_list[i], j_bar_0_1_list[i], j_bar_1_0_list[i], j_bar_1_1_list[i],
                                u_bar_0_list[i], u_bar_1_list[i], np.array(get_parameters(local_model_list[i])))

                # Enable this block to compensate for staleness:
                # if (zeta_ζ != 1) and (tilde_list[i] != 0):
                #     compensatory_weight = staleness_function_beta_β(zeta_ζ)
                #
                #     j_0_0_discrepancy = abs(client_tuple[0] - tilde_list[i][0])
                #     j_0_1_discrepancy = abs(client_tuple[1] - tilde_list[i][1])
                #     j_1_0_discrepancy = abs(client_tuple[2] - tilde_list[i][2])
                #     j_1_1_discrepancy = abs(client_tuple[3] - tilde_list[i][3])
                #
                #     u_0_discrepancy = abs(client_tuple[4] - tilde_list[i][4])
                #     u_1_discrepancy = abs(client_tuple[5] - tilde_list[i][5])
                #
                #     θ_discrepancy = abs(client_tuple[6] - tilde_list[i][6])
                #
                #     compensatory_j_0_0_list[i] = compensatory_weight * j_0_0_discrepancy
                #     compensatory_j_0_1_list[i] = compensatory_weight * j_0_1_discrepancy
                #     compensatory_j_1_0_list[i] = compensatory_weight * j_1_0_discrepancy
                #     compensatory_j_1_1_list[i] = compensatory_weight * j_1_1_discrepancy
                #
                #     compensatory_u_0_list[i] = compensatory_weight * u_0_discrepancy
                #     compensatory_u_1_list[i] = compensatory_weight * u_1_discrepancy
                #
                #     compensatory_θ_list[i] = compensatory_weight * θ_discrepancy

                statistical_tuple_list[i] = client_tuple

            for i in straggler_ids:
                # To recompute the statistics from scratch instead, call:
                # client_tuple = get_statistical_tuple(training_dataset, client_dataset_list, i, model, mask_s1_flag, "NN", device)
                # Otherwise reuse the statistics gathered during the local step:
                client_tuple = (j_bar_0_0_list[i], j_bar_0_1_list[i], j_bar_1_0_list[i], j_bar_1_1_list[i],
                                u_bar_0_list[i], u_bar_1_list[i], np.array(get_parameters(local_model_list[i])))
                straggler_statistical_tuple_list[i] = client_tuple

            # Start timing the process of approximation and aggregation
            start_time = time.time()

            # Replace the statistics of the stragglers by their localized approximation
            for i in straggler_ids:
                # (j_tilde_i_0_0, j_tilde_i_0_1, j_tilde_i_1_0, j_tilde_i_1_1, u_tilde_i_0, u_tilde_i_1, θ_tilde_i)
                tilde = localized_approximation(j_bar_0_0_list, j_bar_0_1_list,
                                                j_bar_1_0_list, j_bar_1_1_list,
                                                u_bar_0_list, u_bar_1_list,
                                                i, local_model_list, similarity_matrix)

                # j_0_0_discrepancy = abs(float(straggler_statistical_tuple_list[i][0] - tilde[0]))
                # j_0_1_discrepancy = abs(float(straggler_statistical_tuple_list[i][1] - tilde[1]))
                # j_1_0_discrepancy = abs(float(straggler_statistical_tuple_list[i][2] - tilde[2]))
                # j_1_1_discrepancy = abs(float(straggler_statistical_tuple_list[i][3] - tilde[3]))
                #
                # u_0_discrepancy = abs(float(straggler_statistical_tuple_list[i][4] - tilde[4]))
                # u_1_discrepancy = abs(float(straggler_statistical_tuple_list[i][5] - tilde[5]))
                #
                # j_0_0_discrepancy_list.append(j_0_0_discrepancy)
                # j_0_1_discrepancy_list.append(j_0_1_discrepancy)
                # j_1_0_discrepancy_list.append(j_1_0_discrepancy)
                # j_1_1_discrepancy_list.append(j_1_1_discrepancy)
                #
                # u_0_discrepancy_list.append(u_0_discrepancy)
                # u_1_discrepancy_list.append(u_1_discrepancy)
                #
                # if tilde[0] != 0:
                #     j_0_0_discrepancy_percentage_list.append(j_0_0_discrepancy / tilde[0])
                # else:
                #     j_0_0_discrepancy_percentage_list.append(0)
                #
                # if tilde[1] != 0:
                #     j_0_1_discrepancy_percentage_list.append(j_0_1_discrepancy / tilde[1])
                # else:
                #     j_0_1_discrepancy_percentage_list.append(0)
                #
                # if tilde[2] != 0:
                #     j_1_0_discrepancy_percentage_list.append(j_1_0_discrepancy / tilde[2])
                # else:
                #     j_1_0_discrepancy_percentage_list.append(0)
                #
                # if tilde[3] != 0:
                #     j_1_1_discrepancy_percentage_list.append(j_1_1_discrepancy / tilde[3])
                # else:
                #     j_1_1_discrepancy_percentage_list.append(0)
                #
                # if tilde[4] != 0:
                #     u_0_discrepancy_percentage_list.append(u_0_discrepancy / tilde[4])
                # else:
                #     u_0_discrepancy_percentage_list.append(0)
                #
                # if tilde[5] != 0:
                #     u_1_discrepancy_percentage_list.append(u_1_discrepancy / tilde[5])
                # else:
                #     u_1_discrepancy_percentage_list.append(0)

                tilde_list[i] = tilde

            # Aggregation
            theta_list = []
            j_hat_0_0, j_hat_0_1, j_hat_1_0, j_hat_1_1 = 0, 0, 0, 0
            u_hat_0, u_hat_1 = 0, 0

            # Aggregation of the well-connected clients (and, if enabled, their compensation)
            for i in idxs_users:
                θ_i = statistical_tuple_list[i][6]
                γ_k = float(γ_k_list[i])

                theta_list.append(
                    list(
                        # γ_k * (θ_i + compensatory_θ_list[i])
                        γ_k * θ_i
                    )
                )

                j_hat_0_0 += γ_k * statistical_tuple_list[i][0]
                j_hat_0_1 += γ_k * statistical_tuple_list[i][1]
                j_hat_1_0 += γ_k * statistical_tuple_list[i][2]
                j_hat_1_1 += γ_k * statistical_tuple_list[i][3]

                u_hat_0 += γ_k * statistical_tuple_list[i][4]
                u_hat_1 += γ_k * statistical_tuple_list[i][5]

            # Aggregation of the stragglers, using their approximated statistics and parameters
            for i in straggler_ids:
                γ_k = float(γ_k_list[i])

                θ_tilde_i = tilde_list[i][6]
                theta_list.append(
                    list(γ_k * θ_tilde_i)
                )

                j_hat_0_0 += γ_k * tilde_list[i][0]
                j_hat_0_1 += γ_k * tilde_list[i][1]
                j_hat_1_0 += γ_k * tilde_list[i][2]
                j_hat_1_1 += γ_k * tilde_list[i][3]

                u_hat_0 += γ_k * tilde_list[i][4]
                u_hat_1 += γ_k * tilde_list[i][5]

            theta_hat = np.array(theta_list, dtype=object).sum(axis=0).tolist()
            set_parameters(global_model, theta_hat)

            logger.info("********** Global v update **********")
            backup_v = global_v
            try:
                q_00 = get_q_c_p(j_hat_0_0, u_hat_0, r_hat_p0, device)
                q_01 = get_q_c_p(j_hat_0_1, u_hat_0, r_hat_p1, device)
                q_10 = get_q_c_p(j_hat_1_0, u_hat_1, r_hat_p0, device)
                q_11 = get_q_c_p(j_hat_1_1, u_hat_1, r_hat_p1, device)
                Q_hat = torch.tensor([
                    [q_00, q_01],
                    [q_10, q_11]
                ]).to(device)

                _, _, v = torch.svd(Q_hat)

                # logger.info(f" Q_hat: {str(Q_hat.tolist())};")
                # logger.info(f" V: {str(v[1].reshape(-1, 1).tolist())}; global_v: {str(global_v.tolist())}")
                global_v = v[1].reshape(-1, 1).to(device)
            except Exception as err:
                logger.info(f"Global v update failed ({type(err).__name__}: {err}); keeping the previous v")
                global_v = backup_v

            # Stop timing the process of approximation and aggregation
            end_time = time.time()

            # Parameter distribution (stragglers keep their local model)
            logger.info("********** Parameter distribution **********")
            for i in idxs_users:
                local_model_list[i] = copy.deepcopy(global_model)

            # Record the time of approximation and aggregation
            approximation_and_aggregation_cost = end_time - start_time

            total_time_consumption += max(local_time_consumption_list) \
                                      + 2 * real_maximum_communication_cost \
                                      + approximation_and_aggregation_cost

            reference_time_consumption += max(local_time_consumption_list) \
                                          + 2 * reference_maximum_communication_cost \
                                          + approximation_and_aggregation_cost

            # Keep the communication share consistent with the totals above
            total_communication_time_consumption += 2 * real_maximum_communication_cost
            reference_communication_time_consumption += 2 * reference_maximum_communication_cost

            # Approximation error of all stragglers in this round
            total_j_distance_in_t_round, total_u_distance_in_t_round, total_θ_distance_in_t_round = 0, 0, 0
            # At this point statistical_tuple_list only holds fresh uploads for the well-connected
            # clients; the straggler entries still hold the values stored in an earlier round.
            # NOTE(review): straggler_statistical_tuple_list holds the stragglers' actual current
            # statistics -- confirm which of the two is meant as the reference here.
            for i in straggler_ids:
                client_tuple = (tilde_list[i][0], tilde_list[i][1], tilde_list[i][2], tilde_list[i][3],
                                tilde_list[i][4], tilde_list[i][5], np.array(tilde_list[i][6]))

                # Compute the error introduced by the approximation while updating the statistics
                approximation_tuple, real_tuple = client_tuple, statistical_tuple_list[i]
                j_distance, u_distance, θ_distance = get_statistical_distance(approximation_tuple, real_tuple)
                total_j_distance_in_t_round += j_distance
                total_u_distance_in_t_round += u_distance
                total_θ_distance_in_t_round += θ_distance

                statistical_tuple_list[i] = client_tuple
            # Dynamically recompute the approximation weights:
            # similarity_matrix = get_similarity_matrix(statistical_tuple_list, lamda, rho)

            # Average approximation error per straggler in this round (none when there are no stragglers)
            straggler_count = len(straggler_ids)
            if straggler_count > 0:
                avg_j_distance = total_j_distance_in_t_round / straggler_count
                avg_u_distance = total_u_distance_in_t_round / straggler_count
                avg_θ_distance = total_θ_distance_in_t_round / straggler_count
                total_j_distance += avg_j_distance
                total_u_distance += avg_u_distance
                total_θ_distance += avg_θ_distance
                logger.info(f"*** {iter_t + 1}-th Step average approximation j_distance over stragglers: {avg_j_distance} ***")
                logger.info(f"*** {iter_t + 1}-th Step average approximation u_distance over stragglers: {avg_u_distance} ***")
                logger.info(f"*** {iter_t + 1}-th Step average approximation θ_distance over stragglers: {avg_θ_distance} ***")

    # if len(j_0_0_discrepancy_list) == 0:
    #     logger.info("Without straggler, No analysis!")
    # else:
        # final_j_0_0_discrepancy = round(float(sum(j_0_0_discrepancy_list) / len(j_0_0_discrepancy_list)), 4)
        # final_j_0_1_discrepancy = round(float(sum(j_0_1_discrepancy_list) / len(j_0_1_discrepancy_list)), 4)
        # final_j_1_0_discrepancy = round(float(sum(j_1_0_discrepancy_list) / len(j_1_0_discrepancy_list)), 4)
        # final_j_1_1_discrepancy = round(float(sum(j_1_1_discrepancy_list) / len(j_1_1_discrepancy_list)), 4)
        # final_u_0_discrepancy = round(float(sum(u_0_discrepancy_list) / len(u_0_discrepancy_list)), 4)
        # final_u_1_discrepancy = round(float(sum(u_1_discrepancy_list) / len(u_1_discrepancy_list)), 4)
        #
        # final_j_discrepancy = round(float((final_j_0_0_discrepancy + final_j_0_1_discrepancy
        #                                    + final_j_1_0_discrepancy + final_j_1_1_discrepancy) / 4), 4)
        # final_u_discrepancy = round(float((final_u_0_discrepancy + final_u_1_discrepancy) / 2), 4)
        # final_statistical_discrepancy = round(float((final_j_0_0_discrepancy + final_j_0_1_discrepancy
        #                                              + final_j_1_0_discrepancy + final_j_1_1_discrepancy
        #                                              + final_u_0_discrepancy + final_u_1_discrepancy) / 6), 4)
        #
        # final_j_0_0_discrepancy_percentage = round(float(
        #     sum(j_0_0_discrepancy_percentage_list) * 100 / len(j_0_0_discrepancy_percentage_list)
        # ), 2)
        # final_j_0_1_discrepancy_percentage = round(float(
        #     sum(j_0_1_discrepancy_percentage_list) * 100 / len(j_0_1_discrepancy_percentage_list)
        # ), 2)
        # final_j_1_0_discrepancy_percentage = round(float(
        #     sum(j_1_0_discrepancy_percentage_list) * 100 / len(j_1_0_discrepancy_percentage_list)
        # ), 2)
        # final_j_1_1_discrepancy_percentage = round(float(
        #     sum(j_1_1_discrepancy_percentage_list) * 100 / len(j_1_1_discrepancy_percentage_list)
        # ), 2)
        # final_u_0_discrepancy_percentage = round(float(
        #     sum(u_0_discrepancy_percentage_list) * 100 / len(u_0_discrepancy_percentage_list)
        # ), 2)
        # final_u_1_discrepancy_percentage = round(float(
        #     sum(u_1_discrepancy_percentage_list) * 100 / len(u_1_discrepancy_percentage_list)
        # ), 2)
        #
        # final_j_discrepancy_percentage = round(float(
        #     (final_j_0_0_discrepancy_percentage + final_j_0_1_discrepancy_percentage
        #      + final_j_1_0_discrepancy_percentage + final_j_1_1_discrepancy_percentage) / 4
        # ), 2)
        # final_u_discrepancy_percentage = round(float(
        #     (final_u_0_discrepancy_percentage + final_u_1_discrepancy_percentage) / 2
        # ), 2)
        # final_statistical_discrepancy_percentage = round(float(
        #     (final_j_0_0_discrepancy_percentage + final_j_0_1_discrepancy_percentage
        #      + final_j_1_0_discrepancy_percentage + final_j_1_1_discrepancy_percentage
        #      + final_u_0_discrepancy_percentage + final_u_1_discrepancy_percentage) / 6
        # ), 2)
        # logger.info(f" Analysising the bias: \n"
        #             f" final_j_0_0_discrepancy: {final_j_0_0_discrepancy} "
        #             f" final_j_0_1_discrepancy: {final_j_0_1_discrepancy} \n"
        #             f" final_j_1_0_discrepancy: {final_j_1_0_discrepancy} "
        #             f" final_j_1_1_discrepancy: {final_j_1_1_discrepancy} ; \n"
        #             f" final_u_0_discrepancy: {final_u_0_discrepancy} "
        #             f" final_u_1_discrepancy: {final_u_1_discrepancy} ; \n"
        #             f" final_j_discrepancy: {final_j_discrepancy} ; \n"
        #             f" final_u_discrepancy: {final_u_discrepancy} ; \n"
        #             f" final_statistical_discrepancy: {final_statistical_discrepancy} ; \n\n\n"
        #             f" final_j_0_0_discrepancy_percentage: {final_j_0_0_discrepancy_percentage} "
        #             f" final_j_0_1_discrepancy_percentage: {final_j_0_1_discrepancy_percentage} \n"
        #             f" final_j_1_0_discrepancy_percentage: {final_j_1_0_discrepancy_percentage} "
        #             f" final_j_1_1_discrepancy_percentage: {final_j_1_1_discrepancy_percentage} ; \n"
        #             f" final_u_0_discrepancy_percentage: {final_u_0_discrepancy_percentage} "
        #             f" final_u_1_discrepancy_percentage: {final_u_1_discrepancy_percentage}  ; \n"
        #             f" final_j_discrepancy_percentage: {final_j_discrepancy_percentage} ; \n"
        #             f" final_u_discrepancy_percentage: {final_u_discrepancy_percentage} ; \n"
        #             f" final_statistical_discrepancy_percentage: {final_statistical_discrepancy_percentage} ; \n\n\n"
        #             )

    if (straggler_rate_α != 0) and (algorithm_step_T != tolerance_τ):
        # Average approximation error per asynchronous step
        logger.info(
            f"*** Average approximation j_distance: {total_j_distance / (algorithm_step_T - tolerance_τ)} ***")
        logger.info(
            f"*** Average approximation u_distance: {total_u_distance / (algorithm_step_T - tolerance_τ)} ***")
        logger.info(
            f"*** Average approximation θ_distance: {total_θ_distance / (algorithm_step_T - tolerance_τ)} ***")

    # End of the algorithm
    logger.info("|||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||")
    logger.info("|||||||||||||||||||||||||||| Algorithm End ||||||||||||||||||||||||||||||||||")
    logger.info("|||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||")

    # Report the simulated times
    speed_up = abs(total_time_consumption - reference_time_consumption)
    logger.info(f" The proportion of communication time cost."
                f" Reference: {_ratio(reference_communication_time_consumption, reference_time_consumption)};"
                f" Total: {_ratio(total_communication_time_consumption, total_time_consumption)}")
    logger.info(
        f" Training finish, total use time: {total_time_consumption} s, reference use time: {reference_time_consumption},"
        f" speed up: {speed_up}."
        f" speed up rate: {_ratio(speed_up, reference_time_consumption)}."
        f" Return global model and local model list")

    return global_model, local_model_list
