from algorithm.FederatedRenyi.FederatedRenyi_component import *
from tool.logger import *

def Fed_Renyi_NN(device,
                 mask_s1_flag,
                 lamda,
                 global_model,
                 tolerance_τ,
                 algorithm_step_T, num_clients_K, communication_round_I,
                 local_step_size,
                 training_dataloaders,
                 training_dataset,
                 client_dataset_list,
                 straggler_rate_α,
                 rho,
                 γ_k_style
                 ):
    """Train a neural network with Federated Rényi fairness regularisation.

    Training runs in two phases:

    1. Synchronous phase (``tolerance_τ`` iterations). Every client takes one
       local SGD step per iteration (a single mini-batch) on
       ``cross-entropy + lamda * G``. Every ``communication_round_I``
       iterations, all client parameters are averaged with weights ``γ_k``
       into ``global_model``, ``global_v`` is recomputed with
       ``get_argmax_v``, and every client restarts from a copy of the global
       model. On the last synchronous iteration (when an asynchronous phase
       follows), each client's statistical tuple is collected to build the
       client similarity matrix.
    2. Asynchronous phase (``algorithm_step_T - tolerance_τ`` iterations). At
       each communication, only the clients in ``idxs_users`` upload their
       statistics. The statistics and parameters of stragglers are replaced
       by ``localized_approximation``. ``global_v`` is taken from the SVD of
       the aggregated ``Q_hat``. Only uploading clients receive the new
       global model.

    Simulated wall-clock time (local compute + twice the simulated
    communication cost + aggregation) is accumulated for two runs:

    * the straggler-aware run ("total");
    * a reference run that is charged ``communication_cost_list[0]`` (the
      maximum cost, per the naming used in the synchronous phase) every
      round.

    Args:
        device: torch device used for the tensors created here.
        mask_s1_flag: if truthy, ``batch["s2"]`` is the sensitive attribute;
            otherwise ``batch["s1"]`` is.
        lamda: weight of the Rényi regulariser ``G``.
        global_model: model that receives the aggregated parameters.
        tolerance_τ: number of synchronous iterations.
        algorithm_step_T: total number of iterations (synchronous + asynchronous).
        num_clients_K: number of clients.
        communication_round_I: communicate every this many iterations.
        local_step_size: SGD learning rate for local updates.
        training_dataloaders: one dataloader per client. Batches are indexed
            with the keys "X", "y", "s1" and "s2".
        training_dataset: passed through to the statistics helpers.
        client_dataset_list: passed through to the statistics helpers.
        straggler_rate_α: passed to ``get_communication_idxs_list`` to split
            clients into uploaders and stragglers.
        rho: passed to ``get_similarity_matrix``.
        γ_k_style: aggregation-weight style passed to ``initialization`` and
            ``get_argmax_v``.

    Returns:
        tuple: ``(global_model, local_model_list)``.
    """
    # Initialization
    logger.info("Training process")
    client_datasets_size_list, local_model_list, \
    global_v, r_bar_k_p0_list, r_bar_k_p1_list, \
    γ_k_list, r_hat_p0, r_hat_p1, v_hat_1 = initialization(client_dataset_list, global_model, num_clients_K,
                                                           mask_s1_flag, training_dataset, γ_k_style)
    criterion = torch.nn.CrossEntropyLoss(reduction='mean')

    # Simulated time bookkeeping ("total" = straggler-aware run).
    total_time_consumption = 0
    total_communication_time_consumption = 0

    # ------------------------------------------------------------------
    # Synchronous phase
    # ------------------------------------------------------------------
    # NOTE(review): this list is never cleared between communications. The
    # max() below is therefore taken over every local step recorded so far,
    # not only over the current round. This is kept as-is so that reported
    # timings do not change.
    local_time_consumption_list = []
    # Filled on the last synchronous iteration, one entry per client:
    # (j_bar_0_0, j_bar_0_1, j_bar_1_0, j_bar_1_1, u_bar_0, u_bar_1, θ)
    statistical_tuple_list = []
    straggler_statistical_tuple_list = [0 for _ in range(num_clients_K)]
    for iter_t in range(tolerance_τ):
        # Simulate the clients running in parallel.
        for i in range(num_clients_K):
            model = local_model_list[i]
            model.train()
            optimizer = torch.optim.SGD(model.parameters(), lr=local_step_size)

            start_time = time.time()
            for batch in training_dataloaders[i]:
                X = batch["X"].to(device)
                y = batch["y"].to(device)
                s = batch["s2"] if mask_s1_flag else batch["s1"]
                local_prediction = model(X).to(device)
                loss = criterion(local_prediction, y.long())
                y_hat_θ = torch.argmax(local_prediction, dim=1).to(device)

                Q, _, _, _, _, _, _ = get_Q_hat_θ(y_hat_θ, s, device)
                G = get_G_hat_θ_hat_v(Q, global_v, device).to(device)
                regularization_term = lamda * G

                logger.info(f"@@@@  Cross Entropy Loss: {round(float(loss), 4)}; "
                            f"Regularization term: {round(float(regularization_term), 4)}; "
                            f"Total Loss: {round(float(loss + regularization_term), 4)};  @@@@")

                loss += regularization_term
                # backward() accumulates into .grad, and the local model
                # persists across iterations. Clear the gradients of the
                # previous step first.
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                # One mini-batch per local step.
                break
            used_time = time.time() - start_time

            local_model_list[i] = model

            # On the last synchronous iteration, collect the statistics that
            # the asynchronous phase needs. The collection time is charged to
            # the client's local time.
            if (iter_t == tolerance_τ - 1) and (tolerance_τ != algorithm_step_T):
                preparation_start_time = time.time()
                client_tuple = get_statistical_tuple(training_dataset, client_dataset_list, i, model, mask_s1_flag,
                                                     "NN", device)
                statistical_tuple_list.append(client_tuple)
                used_time += time.time() - preparation_start_time

            local_time_consumption_list.append(used_time)

        # Communicate
        if (iter_t + 1) % communication_round_I == 0:
            logger.info(f"********** Communicate: {(iter_t + 1) / communication_round_I} **********")
            communication_cost_list, descending_order_list = communication_cost_simulated_by_beta_distribution(num_clients_K)
            maximum_communication_cost = communication_cost_list[0]

            logger.info("********** Parameter aggregation **********")
            start_time = time.time()

            # γ_k-weighted sum of all client parameters.
            theta_list = []
            for i in range(num_clients_K):
                γ_k = float(γ_k_list[i])
                theta_list.append(list(γ_k * np.array(get_parameters(local_model_list[i]), dtype=object)))
            theta_avg = np.sum(np.array(theta_list, dtype=object), 0).tolist()
            set_parameters(global_model, theta_avg)

            logger.info("********** Global v update **********")
            backup_v = global_v
            try:
                # NOTE(review): hypothesis="LR" in the NN trainer, whereas
                # get_statistical_tuple above is called with "NN". Confirm.
                global_v = get_argmax_v(list(range(num_clients_K)), local_model_list, mask_s1_flag,
                                        training_dataset,
                                        client_dataset_list,
                                        r_hat_p0, r_hat_p1, device, γ_k_style, hypothesis="LR")
            except Exception as exc:
                # Best effort: keep the previous v, but leave a trace.
                logger.info(f"Global v update failed ({exc!r}); keeping the previous global_v")
                global_v = backup_v

            global_operation_used_time = time.time() - start_time

            logger.info("********** Parameter distribution **********")
            local_model_list = [copy.deepcopy(global_model) for _ in range(num_clients_K)]

            total_time_consumption += max(local_time_consumption_list) \
                                      + 2 * maximum_communication_cost \
                                      + global_operation_used_time
            total_communication_time_consumption += 2 * maximum_communication_cost

    # Both runs are identical up to this point.
    reference_communication_time_consumption = total_communication_time_consumption

    # Build the similarity matrix for the asynchronous phase (timed).
    preparation_start_time = time.time()
    similarity_matrix = get_similarity_matrix(statistical_tuple_list, lamda, rho)
    total_time_consumption += time.time() - preparation_start_time
    reference_time_consumption = total_time_consumption

    # ------------------------------------------------------------------
    # Asynchronous phase
    # ------------------------------------------------------------------
    local_time_consumption_list = []

    # Bias analysis compares each straggler's true statistics with their
    # localized approximation. It records the absolute error and the error
    # relative to the approximation.
    bias_keys = ("j_0_0", "j_0_1", "j_1_0", "j_1_1", "u_0", "u_1")
    discrepancy_lists = {key: [] for key in bias_keys}
    discrepancy_percentage_lists = {key: [] for key in bias_keys}

    for iter_t in range(algorithm_step_T - tolerance_τ):
        # Per-client statistics of this iteration, indexed by client id.
        # This assumes every client's dataloader yields at least one batch,
        # so each client appends exactly once.
        j_bar_0_0_list, j_bar_0_1_list, j_bar_1_0_list, j_bar_1_1_list, u_bar_0_list, u_bar_1_list = [], [], [], [], [], []

        # Simulate the clients running in parallel.
        for i in range(num_clients_K):
            model = local_model_list[i]
            model.train()
            optimizer = torch.optim.SGD(model.parameters(), lr=local_step_size)

            start_time = time.time()
            for batch in training_dataloaders[i]:
                X = batch["X"].to(device)
                y = batch["y"].to(device)
                s = batch["s2"] if mask_s1_flag else batch["s1"]

                local_prediction = model(X).to(device)
                loss = criterion(local_prediction, y.long())
                # NOTE(review): criterion already averages over the batch
                # (reduction='mean'). This divides by the batch size a second
                # time, which the synchronous phase does not do. Confirm this
                # is intended.
                loss = loss / len(batch["X"])
                y_hat_θ = torch.argmax(local_prediction, dim=1).to(device)

                Q, j_bar_0_0, j_bar_0_1, j_bar_1_0, j_bar_1_1, u_bar_0, u_bar_1 = get_Q_hat_θ(y_hat_θ, s, device)
                j_bar_0_0_list.append(j_bar_0_0)
                j_bar_0_1_list.append(j_bar_0_1)
                j_bar_1_0_list.append(j_bar_1_0)
                j_bar_1_1_list.append(j_bar_1_1)
                u_bar_0_list.append(u_bar_0)
                u_bar_1_list.append(u_bar_1)

                G = get_G_hat_θ_hat_v(Q, global_v, device).to(device)
                regularization_term = lamda * G

                if (iter_t + tolerance_τ + 1) % 10 == 0:
                    logger.info(f"@@@@  Cross Entropy Loss: {round(float(loss), 4)}; "
                                f"Regularization term: {round(float(regularization_term), 4)}; "
                                f"Total Loss: {round(float(loss + regularization_term), 4)};  @@@@")

                loss += regularization_term
                # Stragglers keep their local model across rounds, so
                # gradients must be cleared explicitly.
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                # One mini-batch per local step.
                break

            local_time_consumption_list.append(time.time() - start_time)
            local_model_list[i] = model

        # Communicate
        if (iter_t + 1) % communication_round_I == 0:
            logger.info(f"********** Communicate: {(iter_t + tolerance_τ + 1) / communication_round_I} **********")
            communication_cost_list, descending_order_list = communication_cost_simulated_by_beta_distribution(num_clients_K)

            idxs_users, straggler_ids = get_communication_idxs_list(num_clients_K, straggler_rate_α,
                                                                    descending_order_list)
            real_maximum_communication_cost = max(communication_cost_list[i] for i in idxs_users)
            reference_maximum_communication_cost = communication_cost_list[0]

            tilde_list = [0 for _ in range(num_clients_K)]

            # Uploading clients refresh their entry with this iteration's
            # statistics and parameters. The statistics are reused as-is
            # rather than recomputed with get_statistical_tuple(..., "NN", device).
            # (A staleness-based compensation of these uploads was prototyped
            # here and is currently disabled.)
            for i in idxs_users:
                statistical_tuple_list[i] = (j_bar_0_0_list[i], j_bar_0_1_list[i],
                                             j_bar_1_0_list[i], j_bar_1_1_list[i],
                                             u_bar_0_list[i], u_bar_1_list[i],
                                             np.array(get_parameters(local_model_list[i]), dtype=object))

            # The stragglers' true statistics are kept only to measure the
            # approximation error. dtype=object is required because the
            # per-layer parameter arrays have different shapes.
            for i in straggler_ids:
                straggler_statistical_tuple_list[i] = (j_bar_0_0_list[i], j_bar_0_1_list[i],
                                                       j_bar_1_0_list[i], j_bar_1_1_list[i],
                                                       u_bar_0_list[i], u_bar_1_list[i],
                                                       np.array(get_parameters(local_model_list[i]), dtype=object))

            # Timing of approximation + aggregation starts here.
            start_time = time.time()

            for i in straggler_ids:
                # tilde = (j_tilde_0_0, j_tilde_0_1, j_tilde_1_0, j_tilde_1_1,
                #          u_tilde_0, u_tilde_1, θ_tilde)
                tilde = localized_approximation(j_bar_0_0_list, j_bar_0_1_list,
                                                j_bar_1_0_list, j_bar_1_1_list,
                                                u_bar_0_list, u_bar_1_list,
                                                i, local_model_list, similarity_matrix)
                true_tuple = straggler_statistical_tuple_list[i]
                for position, key in enumerate(bias_keys):
                    discrepancy = abs(float(true_tuple[position] - tilde[position]))
                    discrepancy_lists[key].append(discrepancy)
                    # NOTE(review): relative to the approximation; a zero
                    # approximation raises or yields inf here.
                    discrepancy_percentage_lists[key].append(discrepancy / tilde[position])
                tilde_list[i] = tilde

            # Aggregation: uploading clients contribute their real
            # statistics, stragglers their approximations. Both are weighted
            # by γ_k.
            theta_list = []
            j_hat_0_0, j_hat_0_1, j_hat_1_0, j_hat_1_1 = 0, 0, 0, 0
            u_hat_0, u_hat_1 = 0, 0
            contributions = [(i, statistical_tuple_list[i]) for i in idxs_users] + \
                            [(i, tilde_list[i]) for i in straggler_ids]
            for i, stats in contributions:
                γ_k = float(γ_k_list[i])
                theta_list.append(list(γ_k * stats[6]))

                j_hat_0_0 += γ_k * stats[0]
                j_hat_0_1 += γ_k * stats[1]
                j_hat_1_0 += γ_k * stats[2]
                j_hat_1_1 += γ_k * stats[3]

                u_hat_0 += γ_k * stats[4]
                u_hat_1 += γ_k * stats[5]

            theta_hat = np.array(theta_list, dtype=object).sum(axis=0).tolist()
            set_parameters(global_model, theta_hat)

            logger.info("********** Global v update **********")
            backup_v = global_v
            try:
                q_00 = get_q_c_p(j_hat_0_0, u_hat_0, r_hat_p0, device)
                q_01 = get_q_c_p(j_hat_0_1, u_hat_0, r_hat_p1, device)
                q_10 = get_q_c_p(j_hat_1_0, u_hat_1, r_hat_p0, device)
                q_11 = get_q_c_p(j_hat_1_1, u_hat_1, r_hat_p1, device)
                Q_hat = torch.tensor([
                    [q_00, q_01],
                    [q_10, q_11]
                ]).to(device)

                _, _, v = torch.svd(Q_hat)

                logger.info(f" Q_hat: {str(Q_hat.tolist())};")

                # NOTE(review): torch.svd returns V (not V^T), so v[1] is the
                # second *row* of V. The second right-singular vector would be
                # v[:, 1]. Confirm which one is intended.
                logger.info(f" V: {str(v[1].reshape(-1, 1).tolist())}; global_v: {str(global_v.tolist())}")
                global_v = v[1].reshape(-1, 1).to(device)
            except Exception as exc:
                # Best effort: keep the previous v, but leave a trace.
                logger.info(f"Global v update failed ({exc!r}); keeping the previous global_v")
                global_v = backup_v

            approximation_and_aggregation_cost = time.time() - start_time

            # Only clients that uploaded receive the new global model;
            # stragglers continue from their local models.
            logger.info("********** Parameter distribution **********")
            for i in idxs_users:
                local_model_list[i] = copy.deepcopy(global_model)

            total_time_consumption += max(local_time_consumption_list) \
                                      + 2 * real_maximum_communication_cost \
                                      + approximation_and_aggregation_cost

            reference_time_consumption += max(local_time_consumption_list) \
                                          + 2 * reference_maximum_communication_cost \
                                          + approximation_and_aggregation_cost

            # Communication time of the asynchronous rounds also belongs in
            # the communication-time proportion reported at the end.
            total_communication_time_consumption += 2 * real_maximum_communication_cost
            reference_communication_time_consumption += 2 * reference_maximum_communication_cost

    _log_bias_analysis(discrepancy_lists, discrepancy_percentage_lists)

    logger.info(f" The proportion of communication time cost."
                f" Reference: {_safe_ratio(reference_communication_time_consumption, reference_time_consumption)};"
                f" Total: {_safe_ratio(total_communication_time_consumption, total_time_consumption)}")
    logger.info(
        f" Training finish, total use time: {total_time_consumption} s, reference use time: {reference_time_consumption},"
        f" speed up: {abs(total_time_consumption - reference_time_consumption)}."
        f" Return global model and local model list")

    return global_model, local_model_list


def _safe_ratio(numerator, denominator):
    """Return numerator / denominator, or NaN when the denominator is zero."""
    if denominator == 0:
        return float("nan")
    return numerator / denominator


def _rounded_group_means(means, ndigits):
    """Return the rounded (j, u, overall) averages of per-statistic means.

    Additions are left-associative in the same order as j_0_0, j_0_1, j_1_0,
    j_1_1, u_0, u_1, so the floating-point results match a plain written-out
    sum.
    """
    j_sum = means["j_0_0"] + means["j_0_1"] + means["j_1_0"] + means["j_1_1"]
    j_mean = round(float(j_sum / 4), ndigits)
    u_mean = round(float((means["u_0"] + means["u_1"]) / 2), ndigits)
    overall_mean = round(float((j_sum + means["u_0"] + means["u_1"]) / 6), ndigits)
    return j_mean, u_mean, overall_mean


def _log_bias_analysis(discrepancy_lists, discrepancy_percentage_lists):
    """Log the mean error of the stragglers' approximated statistics.

    Args:
        discrepancy_lists: dict keyed by "j_0_0", "j_0_1", "j_1_0", "j_1_1",
            "u_0" and "u_1". Each value is a list of absolute errors
            |true - approximated|.
        discrepancy_percentage_lists: dict with the same keys. Each value is
            a list of those errors divided by the approximated value.

    Per-statistic means are rounded (4 decimals for absolute errors,
    2 decimals for percentages) before the j / u / overall averages are
    formed from them.
    """
    if not discrepancy_lists["j_0_0"]:
        logger.info("Without straggler, No analysis!")
        return

    d = {key: round(float(sum(values) / len(values)), 4)
         for key, values in discrepancy_lists.items()}
    p = {key: round(float(sum(values) * 100 / len(values)), 2)
         for key, values in discrepancy_percentage_lists.items()}

    final_j_discrepancy, final_u_discrepancy, final_statistical_discrepancy = _rounded_group_means(d, 4)
    final_j_discrepancy_percentage, final_u_discrepancy_percentage, \
        final_statistical_discrepancy_percentage = _rounded_group_means(p, 2)

    logger.info(f" Analysising the bias: \n"
                f" final_j_0_0_discrepancy: {d['j_0_0']} "
                f" final_j_0_1_discrepancy: {d['j_0_1']} \n"
                f" final_j_1_0_discrepancy: {d['j_1_0']} "
                f" final_j_1_1_discrepancy: {d['j_1_1']} ; \n"
                f" final_u_0_discrepancy: {d['u_0']} "
                f" final_u_1_discrepancy: {d['u_1']} ; \n"
                f" final_j_discrepancy: {final_j_discrepancy} ; \n"
                f" final_u_discrepancy: {final_u_discrepancy} ; \n"
                f" final_statistical_discrepancy: {final_statistical_discrepancy} ; \n\n\n"
                f" final_j_0_0_discrepancy_percentage: {p['j_0_0']} "
                f" final_j_0_1_discrepancy_percentage: {p['j_0_1']} \n"
                f" final_j_1_0_discrepancy_percentage: {p['j_1_0']} "
                f" final_j_1_1_discrepancy_percentage: {p['j_1_1']} ; \n"
                f" final_u_0_discrepancy_percentage: {p['u_0']} "
                f" final_u_1_discrepancy_percentage: {p['u_1']}  ; \n"
                f" final_j_discrepancy_percentage: {final_j_discrepancy_percentage} ; \n"
                f" final_u_discrepancy_percentage: {final_u_discrepancy_percentage} ; \n"
                f" final_statistical_discrepancy_percentage: {final_statistical_discrepancy_percentage} ; \n\n\n"
                )
