import math

from algorithm.FederatedRenyi.FederatedRenyi_component import *
from tool.logger import *


def _fed_renyi_local_step(model, dataloader, criterion, local_step_size, lamda, global_v,
                          mask_s1_flag, device):
    """Run one local SGD step of a client model on the first batch of its dataloader.

    The objective is ``criterion(model(X), y) + lamda * G``, where ``G`` is the
    value returned by ``get_G_hat_θ_hat_v`` for the batch statistics ``Q`` and
    the current ``global_v``.

    Args:
        model: Client model; it is updated in place.
        dataloader: Iterable of dict batches with keys ``"X"``, ``"y"`` and
            ``"s1"`` / ``"s2"``. Only the first batch is consumed.
        criterion: Loss applied to the model output and the float targets.
        local_step_size: Learning rate of the freshly created SGD optimizer.
        lamda: Multiplier applied to the regularization value ``G``.
        global_v: Current global ``v`` forwarded to ``get_G_hat_θ_hat_v``.
        mask_s1_flag: When truthy the sensitive attribute is read from
            ``batch["s2"]``, otherwise from ``batch["s1"]``.
        device: Device the inputs, targets and predictions are moved to.

    Returns:
        ``(step, elapsed)``. ``elapsed`` is the wall-clock time of the batch
        loop in seconds. ``step`` is ``None`` when the dataloader yields no
        batch; otherwise it is ``(ce_loss, regularization_term, total_loss,
        statistics)`` where the three losses are detached tensors and
        ``statistics`` is the 6-tuple ``(j_bar_0_0, j_bar_0_1, j_bar_1_0,
        j_bar_1_1, u_bar_0, u_bar_1)`` returned by ``get_Q_hat_θ``.
    """
    model.train()
    optimizer = torch.optim.SGD(model.parameters(), lr=local_step_size)

    # Starts timing the local operation
    start_time = time.time()

    step = None
    for batch in dataloader:
        X = batch["X"].to(device)
        y = batch["y"].reshape(-1, 1).to(device)
        s = batch["s2"] if mask_s1_flag else batch["s1"]

        # Gradients live on the parameters, not on the optimizer: a model that
        # is kept across steps (e.g. a straggler that is not refreshed from the
        # global model) would otherwise accumulate stale gradients.
        optimizer.zero_grad()

        local_prediction = model(X).to(device)
        ce_loss = criterion(local_prediction, y.float())

        # Hard 0/1 decisions obtained by thresholding the prediction at 0.5
        y_hat_θ = (local_prediction >= 0.5).reshape(-1)

        Q, j_bar_0_0, j_bar_0_1, j_bar_1_0, j_bar_1_1, u_bar_0, u_bar_1 = get_Q_hat_θ(y_hat_θ, s, device)
        G = get_G_hat_θ_hat_v(Q, global_v, device).to(device)
        regularization_term = lamda * G

        total_loss = ce_loss + regularization_term
        total_loss.backward()
        optimizer.step()

        # Detached copies are enough for logging and do not keep the autograd
        # graph of this step alive.
        step = (ce_loss.detach(), regularization_term.detach(), total_loss.detach(),
                (j_bar_0_0, j_bar_0_1, j_bar_1_0, j_bar_1_1, u_bar_0, u_bar_1))

        # Only one batch is used per local step
        break

    # Stops timing for local operations
    elapsed = time.time() - start_time
    return step, elapsed


def Fed_Renyi_LR(device,
                 mask_s1_flag,
                 lamda,
                 global_model,
                 tolerance_τ,
                 algorithm_step_T, num_clients_K, communication_round_I,
                 local_step_size,
                 training_dataloaders,
                 training_dataset,
                 client_dataset_list,
                 straggler_rate_α,
                 rho,
                 γ_k_style
                 ):
    """Train the "LR" hypothesis with a synchronous phase followed by an asynchronous phase.

    The first ``tolerance_τ`` steps are synchronous: every client performs one
    local SGD step per iteration and, every ``communication_round_I``
    iterations, all client parameters are aggregated (weighted by
    ``γ_k_list``), ``global_v`` is refreshed through ``get_argmax_v`` and the
    global model is copied back to every client.

    The remaining ``algorithm_step_T - tolerance_τ`` steps are asynchronous: at
    each communication round ``get_communication_idxs_list`` splits the clients
    into connected clients and stragglers. Connected clients upload their
    current statistics and parameters; stragglers are replaced by the result of
    ``localized_approximation``. ``global_v`` is then recomputed from the SVD
    of the aggregated 2x2 matrix ``Q_hat`` and only the connected clients
    receive the new global model.

    Args:
        device: Device tensors are moved to.
        mask_s1_flag: When truthy the sensitive attribute is read from
            ``batch["s2"]``, otherwise from ``batch["s1"]``; also forwarded to
            the project helpers.
        lamda: Multiplier of the regularization term; also forwarded to
            ``get_similarity_matrix``.
        global_model: Model that receives the aggregated parameters.
        tolerance_τ: Number of synchronous steps.
        algorithm_step_T: Total number of steps (synchronous + asynchronous).
        num_clients_K: Number of simulated clients.
        communication_round_I: A communication round happens every this many
            local steps.
        local_step_size: Learning rate of the local SGD optimizers.
        training_dataloaders: Per-client dataloaders, indexed by client id.
        training_dataset: Forwarded to the project helpers.
        client_dataset_list: Forwarded to the project helpers.
        straggler_rate_α: Forwarded to ``get_communication_idxs_list``.
        rho: Forwarded to ``get_similarity_matrix``.
        γ_k_style: Forwarded to ``initialization`` and ``get_argmax_v``.

    Returns:
        Tuple ``(global_model, local_model_list)``.
    """
    # Initialization
    logger.info("Initialization")

    _, local_model_list, \
    global_v, _, _, \
    γ_k_list, r_hat_p0, r_hat_p1, _ = initialization(client_dataset_list, global_model, num_clients_K,
                                                     mask_s1_flag, training_dataset, γ_k_style)

    criterion = torch.nn.BCELoss(reduction='mean')

    # Time consumption
    total_time_consumption = 0

    # Synchronous algorithm
    logger.info("Synchronous Training process")

    local_time_consumption_list = []
    statistical_tuple_list = []  # [ {j_1(c, p), u_1(c), θ_1}, …, {j_k(c, p), u_k(c), θ_k} ]
    for iter_t in range(tolerance_τ):
        # The statistics needed by the asynchronous phase are collected during
        # the last synchronous step, unless there is no asynchronous phase.
        prepare_for_asynchronous = (iter_t == tolerance_τ - 1) and (tolerance_τ != algorithm_step_T)

        # Simulate Client Parallel
        avg_loss_over_step, avg_ce_over_step, avg_regularization_term_over_step = 0, 0, 0
        for i in range(num_clients_K):
            model = local_model_list[i]

            # local optimization
            step, used_time = _fed_renyi_local_step(model, training_dataloaders[i], criterion,
                                                    local_step_size, lamda, global_v,
                                                    mask_s1_flag, device)
            if step is not None:
                ce_loss, regularization_term, total_loss, _ = step
                avg_ce_over_step += ce_loss
                avg_regularization_term_over_step += regularization_term
                avg_loss_over_step += total_loss

            if prepare_for_asynchronous:
                # Starts timing the preparation about asynchronous
                preparation_start_time = time.time()

                # {j_bar_0_0, j_bar_0_1, j_bar_1_0, j_bar_1_1, u_bar_0, u_bar_1, np.array(get_parameters(model))}
                statistical_tuple_list.append(
                    get_statistical_tuple(training_dataset, client_dataset_list, i, model,
                                          mask_s1_flag, "LR", device)
                )

                # The preparation counts as part of the client's local time
                used_time += time.time() - preparation_start_time

            local_time_consumption_list.append(used_time)

        logger.info(f"########## Step: {iter_t + 1} / {algorithm_step_T}; "
                    f"Avg CE over Client: {round(float(avg_ce_over_step/num_clients_K), 4)}; "
                    f"Avg RT over Client: {round(float(avg_regularization_term_over_step/num_clients_K), 4)}; "
                    f"Avg Loss over Client: {round(float(avg_loss_over_step/num_clients_K), 4)}; ##########")

        # Communicate
        if (iter_t + 1) % communication_round_I == 0:
            logger.info(f"********** Communicate: {(iter_t + 1) / communication_round_I} **********")
            # Simulate the communication cost of each client.
            # NOTE(review): the first entry is used as the maximum cost, which
            # presumes the returned list is sorted in descending order — confirm.
            communication_cost_list, _ = communication_cost_simulated_by_beta_distribution(num_clients_K)
            maximum_communication_cost = communication_cost_list[0]

            # Global operation
            logger.info("********** Parameter aggregation **********")

            # Starts timing the global operation
            start_time = time.time()

            # γ_k-weighted sum of the client parameters
            theta_list = [
                list(float(γ_k_list[i]) * np.array(get_parameters(local_model_list[i])))
                for i in range(num_clients_K)
            ]
            theta_avg = np.sum(np.array(theta_list, dtype=object), 0).tolist()
            set_parameters(global_model, theta_avg)

            logger.info("********** Global v update **********")
            try:
                global_v = get_argmax_v(list(range(num_clients_K)), local_model_list, mask_s1_flag,
                                        training_dataset, client_dataset_list,
                                        r_hat_p0, r_hat_p1, device, γ_k_style, hypothesis="LR")
            except Exception:
                # Best effort: global_v keeps its previous value when the update fails.
                logger.info("Some Errors happen in Global v update, Using the backup v")

            # Stops timing for global operations
            global_operation_used_time = time.time() - start_time

            # Parameter Distribution
            logger.info("********** Parameter distribution **********")
            local_model_list = [copy.deepcopy(global_model) for _ in range(num_clients_K)]

            # Record the total time consumption (upload + download => 2x communication).
            # NOTE(review): local_time_consumption_list is never cleared inside
            # this phase, so max() spans every step recorded so far, not only
            # the steps since the previous communication — confirm this is intended.
            total_time_consumption += max(local_time_consumption_list) \
                                      + 2 * maximum_communication_cost \
                                      + global_operation_used_time

    # Prepare for the asynchronous algorithm.
    # The similarity matrix is only consumed by the asynchronous phase, so it is
    # skipped when no asynchronous step will run (the statistical tuples are not
    # collected in that case either).
    similarity_matrix = None
    if algorithm_step_T > tolerance_τ:
        preparation_start_time = time.time()
        similarity_matrix = get_similarity_matrix(statistical_tuple_list, lamda, rho)
        total_time_consumption += time.time() - preparation_start_time

    # Asynchronous algorithm
    logger.info("Asynchronous Training process")

    local_time_consumption_list = []
    for iter_t in range(algorithm_step_T - tolerance_τ):

        # Record the j,u matrix of each client (one entry per client, in client order)
        j_bar_0_0_list, j_bar_0_1_list, j_bar_1_0_list, j_bar_1_1_list, u_bar_0_list, u_bar_1_list = [], [], [], [], [], []
        statistic_lists = (j_bar_0_0_list, j_bar_0_1_list, j_bar_1_0_list, j_bar_1_1_list,
                           u_bar_0_list, u_bar_1_list)

        # Simulate Client Parallel
        avg_loss_over_step, avg_ce_over_step, avg_regularization_term_over_step = 0, 0, 0
        for i in range(num_clients_K):
            # local optimization
            step, used_time = _fed_renyi_local_step(local_model_list[i], training_dataloaders[i],
                                                    criterion, local_step_size, lamda, global_v,
                                                    mask_s1_flag, device)
            if step is not None:
                ce_loss, regularization_term, total_loss, statistics = step
                avg_ce_over_step += ce_loss
                avg_regularization_term_over_step += regularization_term
                avg_loss_over_step += total_loss
                for statistic_list, value in zip(statistic_lists, statistics):
                    statistic_list.append(value)

            # Record the local operation time for each client
            local_time_consumption_list.append(used_time)

        logger.info(f"########## Step: {iter_t + 1 + tolerance_τ} / {algorithm_step_T}; "
                    f"Avg CE over Client: {round(float(avg_ce_over_step / num_clients_K), 4)}; "
                    f"Avg RT over Client: {round(float(avg_regularization_term_over_step / num_clients_K), 4)}; "
                    f"Avg Loss over Client: {round(float(avg_loss_over_step / num_clients_K), 4)}; ##########")

        # Communicate
        if (iter_t + 1) % communication_round_I == 0:
            logger.info(f"********** Communicate: {math.floor((iter_t + tolerance_τ + 1) / communication_round_I)} **********")
            # Simulate the communication cost of each client
            communication_cost_list, descending_order_list = communication_cost_simulated_by_beta_distribution(num_clients_K)
            maximum_communication_cost = communication_cost_list[0]

            idxs_users, straggler_ids = get_communication_idxs_list(num_clients_K, straggler_rate_α,
                                                                    descending_order_list)

            # Upgrade the statistical tuple list by the latest uploaded info.
            # The statistics computed during this step's local update are reused
            # directly; call get_statistical_tuple(...) here instead if they
            # must be recomputed.
            # NOTE: staleness-aware compensation (weighting the discrepancy
            # between a client's upload and its previous approximation by
            # staleness_function_beta_β) is currently disabled.
            for i in idxs_users:
                statistical_tuple_list[i] = (
                    j_bar_0_0_list[i], j_bar_0_1_list[i], j_bar_1_0_list[i], j_bar_1_1_list[i],
                    u_bar_0_list[i], u_bar_1_list[i], np.array(get_parameters(local_model_list[i]))
                )

            # Starts timing the process of approximation and aggregation
            start_time = time.time()

            # The approximation process of stragglers; entries of clients that
            # are not stragglers stay 0 and are never read.
            # Each entry: (j_tilde_i_0_0, j_tilde_i_0_1, j_tilde_i_1_0, j_tilde_i_1_1, u_tilde_i_0, u_tilde_i_1, θ_tilde_i)
            tilde_list = [0] * num_clients_K
            for i in straggler_ids:
                tilde_list[i] = localized_approximation(j_bar_0_0_list, j_bar_0_1_list,
                                                        j_bar_1_0_list, j_bar_1_1_list,
                                                        u_bar_0_list, u_bar_1_list,
                                                        i, local_model_list, similarity_matrix)

            # Aggregation: well-connected clients contribute their uploaded
            # tuple, stragglers contribute their approximation (in that order).
            contributions = [(i, statistical_tuple_list[i]) for i in idxs_users]
            contributions += [(i, tilde_list[i]) for i in straggler_ids]

            theta_list = []
            j_hat_0_0, j_hat_0_1, j_hat_1_0, j_hat_1_1 = 0, 0, 0, 0
            u_hat_0, u_hat_1 = 0, 0
            for i, client_tuple in contributions:
                γ_k = float(γ_k_list[i])

                theta_list.append(list(γ_k * client_tuple[6]))

                j_hat_0_0 += γ_k * client_tuple[0]
                j_hat_0_1 += γ_k * client_tuple[1]
                j_hat_1_0 += γ_k * client_tuple[2]
                j_hat_1_1 += γ_k * client_tuple[3]

                u_hat_0 += γ_k * client_tuple[4]
                u_hat_1 += γ_k * client_tuple[5]

            theta_hat = np.array(theta_list, dtype=object).sum(axis=0).tolist()
            set_parameters(global_model, theta_hat)

            logger.info("********** Global v update **********")
            try:
                q_00 = get_q_c_p(j_hat_0_0, u_hat_0, r_hat_p0, device)
                q_01 = get_q_c_p(j_hat_0_1, u_hat_0, r_hat_p1, device)
                q_10 = get_q_c_p(j_hat_1_0, u_hat_1, r_hat_p0, device)
                q_11 = get_q_c_p(j_hat_1_1, u_hat_1, r_hat_p1, device)
                Q_hat = torch.tensor([
                    [q_00, q_01],
                    [q_10, q_11]
                ]).to(device)

                _, _, v = torch.svd(Q_hat)

                # The new global v is row 1 of the V factor, as a column vector
                global_v = v[1].reshape(-1, 1).to(device)
            except Exception:
                # Best effort: global_v keeps its previous value when the update fails.
                logger.info("Some Errors happen in Global v update, Using the backup v")

            # Stops timing the process of approximation and aggregation
            approximation_and_aggregation_cost = time.time() - start_time

            # Parameter Distribution: only the connected clients receive the new global model
            logger.info("********** Parameter distribution **********")
            for i in idxs_users:
                local_model_list[i] = copy.deepcopy(global_model)

            # NOTE(review): as in the synchronous phase, max() spans every local
            # time recorded since the start of this phase — confirm this is intended.
            total_time_consumption += max(local_time_consumption_list) \
                                      + maximum_communication_cost \
                                      + approximation_and_aggregation_cost

            # NOTE: the "dynamic update" mode is currently disabled. It would
            # write the stragglers' approximations back into
            # statistical_tuple_list (which at this point only holds fresh data
            # for the connected clients) and then recompute similarity_matrix
            # via get_similarity_matrix(statistical_tuple_list, lamda, rho).

    logger.info(f" Training finish, total use time: {total_time_consumption} s.")

    return global_model, local_model_list
