import math
import time
import torch
import numpy as np
import copy
from tool.logger import *
from tool.utils import get_parameters, set_parameters, communication_cost_simulated_by_beta_distribution


def get_argmax_v(idxs_users, local_model_list, mask_s1_flag, training_dataset, client_dataset_list,
                 r_hat_p0, r_hat_p1, device, γ_k_style, hypothesis="LR"):
    """Aggregate client prediction statistics into Q_hat and return its second right-singular vector.

    For every client in ``idxs_users`` the client's model predicts on the
    client's own samples; the per-client statistics ``j_bar(c, p)`` and
    ``u_bar(c)`` are accumulated with weight ``γ_k``.  The aggregated values are
    turned into a 2x2 matrix ``Q_hat`` via ``get_q_c_p`` and decomposed with SVD.

    Args:
        idxs_users: Client ids; each indexes ``local_model_list`` and
            ``client_dataset_list``.
        local_model_list: Per-client models (moved to ``device`` before use).
        mask_s1_flag: When truthy the ``'s2'`` field of each record is used as
            the sensitive attribute, otherwise ``'s1'``.
        training_dataset: Indexable collection of records exposing ``'X'``,
            ``'s1'`` and ``'s2'``.
        client_dataset_list: Per-client subsets exposing ``.indices`` into
            ``training_dataset``.
        r_hat_p0: Aggregated sensitive-group statistic for ``p = 0``.
        r_hat_p1: Aggregated sensitive-group statistic for ``p = 1``.
        device: Torch device used for all tensors.
        γ_k_style: If it contains ``"uniform_client"`` every client weighs
            ``1 / len(idxs_users)``; otherwise a client weighs its sample count
            divided by ``len(training_dataset)``.
        hypothesis: If it contains ``"LR"`` the model output is thresholded at
            0.5; otherwise ``argmax(dim=1)`` is taken.

    Returns:
        A ``(2, 1)`` tensor on ``device``: row 1 of ``Vh`` from
        ``torch.linalg.svd(Q_hat)``, reshaped to a column.
    """
    training_dataset_size = len(training_dataset)
    # The two sensitive-attribute branches only differed in the record key.
    sensitive_key = 's2' if mask_s1_flag else 's1'
    binary_values = (0, 1)

    # Weighted aggregates, keyed by (c, p) for j_hat and by c for u_hat.
    j_hat = {(c, p): 0 for c in binary_values for p in binary_values}
    u_hat = {c: 0 for c in binary_values}

    for client_id in idxs_users:
        selected_model = local_model_list[client_id].to(device)
        sample_indices = client_dataset_list[client_id].indices

        sensitive = torch.tensor(
            [training_dataset[idx][sensitive_key] for idx in sample_indices]).to(device)
        client_X = torch.tensor(
            np.array([training_dataset[idx]['X'] for idx in sample_indices])).to(device)

        if "LR" in hypothesis:
            y_hat_θ = (selected_model(client_X) >= 0.5).reshape(-1).to(device)
        else:  # NN
            y_hat_θ = selected_model(client_X).argmax(dim=1).to(device)

        if "uniform_client" in γ_k_style:
            γ_k = float(1 / len(idxs_users))
        else:
            γ_k = float(len(client_X) / training_dataset_size)

        for c in binary_values:
            for p in binary_values:
                j_hat[c, p] += γ_k * get_j_bar_c_p(y_hat_θ, c, sensitive, p, device)
            u_hat[c] += γ_k * get_u_bar_c(y_hat_θ, c, device)

    r_hat = (r_hat_p0, r_hat_p1)
    q = {
        (c, p): get_q_c_p(j_hat[c, p], u_hat[c], r_hat[p], device)
        for c in binary_values for p in binary_values
    }
    Q_hat = torch.tensor([
        [q[0, 0], q[0, 1]],
        [q[1, 0], q[1, 1]]
    ]).to(device)

    # Only Vh is needed; its rows are the right-singular vectors.
    _, _, vh = torch.linalg.svd(Q_hat)

    second_singular_vector_of_Q_hat = vh[1].reshape(-1, 1).to(device)
    return second_singular_vector_of_Q_hat


def get_communication_idxs_list(num_clients_K, straggler_rate_α, descending_order_list):
    """Split the client ids ``0 .. num_clients_K - 1`` into participants and stragglers.

    The first ``ceil(straggler_rate_α * num_clients_K)`` entries of
    ``descending_order_list`` are treated as stragglers and removed from the
    participant list.

    Args:
        num_clients_K: Total number of clients.
        straggler_rate_α: Fraction of clients that straggle; ``0`` disables
            straggler selection.
        descending_order_list: Client ids in the order stragglers are taken from.

    Returns:
        ``(idxs_users, straggler_ids)``: the remaining client ids in ascending
        order and the removed ids in selection order.

    Raises:
        IndexError: If more stragglers are requested than
            ``descending_order_list`` holds.
        ValueError: If a selected id is not (or no longer) among the clients.
    """
    idxs_users = list(range(num_clients_K))
    straggler_ids = []
    if straggler_rate_α != 0:
        # Round away binary floating-point noise before taking the ceiling:
        # 0.07 * 100 evaluates to 7.000000000000001, which would otherwise be
        # rounded up to 8 stragglers instead of the intended 7.
        straggle_count = math.ceil(round(straggler_rate_α * num_clients_K, 9))
        for rank in range(straggle_count):
            straggler_id = descending_order_list[rank]
            straggler_ids.append(straggler_id)
            # Remove by value (list.remove), not by position (list.pop).
            idxs_users.remove(straggler_id)
    return idxs_users, straggler_ids


def get_gamma_k_list(γ_k_style, client_datasets_size_list, num_clients_K):
    """Return the aggregation weight γ_k of each of the first ``num_clients_K`` clients.

    Args:
        γ_k_style: If it contains ``"uniform_distribution"`` the weight is
            ``n_k / n`` (client size over the sum of all listed sizes);
            otherwise every client gets ``1 / num_clients_K``.
        client_datasets_size_list: Number of samples held by each client.
        num_clients_K: Number of weights to produce.

    Returns:
        A list of ``num_clients_K`` Python floats.
    """
    if "uniform_distribution" in γ_k_style:
        # Uniform over the data distribution: γ_k = n_k / n.
        total_samples = sum(client_datasets_size_list)
        return [
            float(client_datasets_size_list[k] / total_samples)
            for k in range(num_clients_K)
        ]

    # Uniform over clients: γ_k = 1 / K.
    return [float(1 / num_clients_K) for _ in range(num_clients_K)]


def get_j_bar_c_p(y_hat_θ, c, s, p, device):
    """Empirical conditional frequency of prediction ``c`` given sensitive value ``p``.

    Computes ``P(y_hat == c, s == p) / P(s == p)`` over the samples in ``s``.

    Args:
        y_hat_θ: Tensor of predictions, aligned with ``s``.
        c: Prediction value to match.
        s: Tensor of sensitive-attribute values.
        p: Sensitive-attribute value to match.
        device: Torch device for the intermediate and returned tensors.

    Returns:
        A tensor on ``device``.  It is NaN when no sample has ``s == p``
        (0 / 0), which callers must be prepared for.

    Raises:
        ZeroDivisionError: If ``s`` is empty.
    """
    num_samples = len(s)
    if num_samples == 0:
        # Keep the loud failure the builtin-sum implementation had on empty input.
        raise ZeroDivisionError("cannot estimate frequencies from an empty sample")

    predicted_c = (y_hat_θ == c).to(device)
    in_group_p = (s == p).to(device)
    joint = predicted_c & in_group_p

    # Tensor.sum reduces in one native call; builtin sum() walks the tensor
    # element by element in Python.  dim=0 matches what iterating would reduce.
    P_s_p = (in_group_p.sum(dim=0) / num_samples).to(device)  # r_bar(p)
    P_joint = (joint.sum(dim=0) / num_samples).to(device)
    P_conditional = (P_joint / P_s_p).to(device)  # j_bar(c, p)
    return P_conditional


def get_q_c_p(j_c_p, u_c, r_p, device):
    """Return one entry of the Q matrix: ``j_c_p * r_p / sqrt(u_c * r_p)``.

    Args:
        j_c_p: Statistic for prediction ``c`` and sensitive value ``p``.
        u_c: Statistic for prediction ``c``.
        r_p: Statistic for sensitive value ``p``.
        device: Torch device for the returned tensor.

    Returns:
        A tensor on ``device``; exactly ``0.`` whenever any of the three inputs
        compares equal to zero (this also avoids dividing by ``sqrt(0)``).
    """
    has_zero_term = j_c_p == 0 or r_p == 0 or u_c == 0
    if not has_zero_term:
        return (j_c_p * r_p / torch.sqrt(u_c * r_p)).to(device)
    return torch.tensor(0.).to(device)


def get_Q_hat_θ(y_hat_θ, s, device):
    """Build the 2x2 Q matrix of one prediction/sensitive-attribute sample.

    Args:
        y_hat_θ: Tensor of predictions.
        s: Tensor of sensitive-attribute values aligned with ``y_hat_θ``.
        device: Torch device for all tensors.

    Returns:
        ``(Q, j_bar_0_0, j_bar_0_1, j_bar_1_0, j_bar_1_1, u_bar_0, u_bar_1)``
        where ``Q[c][p] = get_q_c_p(j_bar(c, p), u_bar(c), r_bar(p))``.
    """
    binary_values = (0, 1)

    # Statistics keyed by (c, p), by c and by p respectively.
    j_bar = {
        (c, p): get_j_bar_c_p(y_hat_θ, c, s, p, device)
        for c in binary_values for p in binary_values
    }
    u_bar = {c: get_u_bar_c(y_hat_θ, c, device) for c in binary_values}
    r_bar = {p: get_r_bar_p(s, p, device) for p in binary_values}

    # Row index is the prediction c, column index the sensitive value p.
    rows = [
        [get_q_c_p(j_bar[c, p], u_bar[c], r_bar[p], device) for p in binary_values]
        for c in binary_values
    ]
    Q = torch.tensor(rows).to(device)
    return Q, j_bar[0, 0], j_bar[0, 1], j_bar[1, 0], j_bar[1, 1], u_bar[0], u_bar[1]


def get_G_hat_θ_hat_v(Q, v, device):
    """Evaluate the quadratic form ``vᵀ Qᵀ Q v`` as a scalar tensor.

    Args:
        Q: Matrix tensor.
        v: Vector tensor; it is reshaped to a column before use.
        device: Torch device for the computation and the result.

    Returns:
        A 0-d tensor on ``device``; a NaN result is replaced by 0.
    """
    Q_dev = Q.to(device)
    column = v.reshape(-1, 1).to(device)

    # The product is a 1x1 matrix; pull out its single entry.
    quadratic_form = (column.T @ Q_dev.T @ Q_dev @ column)[0][0].to(device)

    nan_mask = torch.isnan(quadratic_form)
    return torch.where(nan_mask, torch.zeros_like(quadratic_form), quadratic_form)


def get_r_bar_p(s, p, device):
    """Empirical frequency ``r_bar(p)`` of sensitive value ``p`` in ``s``.

    Args:
        s: Tensor of sensitive-attribute values.
        p: Value to count.
        device: Torch device for the returned tensor.

    Returns:
        A tensor on ``device`` holding ``count(s == p) / len(s)``.

    Raises:
        ZeroDivisionError: If ``s`` is empty.
    """
    num_samples = len(s)
    if num_samples == 0:
        # Keep the loud failure the builtin-sum implementation had on empty input.
        raise ZeroDivisionError("cannot estimate frequencies from an empty sample")

    in_group_p = (s == p).to(device)
    # Tensor.sum reduces in one native call; builtin sum() walks the tensor
    # element by element in Python.
    P_s_p = (in_group_p.sum(dim=0) / num_samples).to(device)  # r_bar(p)
    return P_s_p


def get_r_bar_k_p_list(num_clients_K, mask_s1_flag, training_dataset, client_dataset_list, p):
    """Per-client empirical frequency of sensitive value ``p``.

    Args:
        num_clients_K: Number of clients to evaluate (clients ``0 .. K - 1``).
        mask_s1_flag: When truthy the ``'s2'`` field of each record is used as
            the sensitive attribute, otherwise ``'s1'``.
        training_dataset: Indexable collection of records exposing ``'s1'`` and
            ``'s2'``.
        client_dataset_list: Per-client subsets exposing ``.indices`` into
            ``training_dataset``.
        p: Sensitive-attribute value to count.

    Returns:
        A list with one tensor per client: ``count(s == p) / n_k``.

    Raises:
        ZeroDivisionError: If a client holds no samples.
    """
    # The two sensitive-attribute branches only differed in the record key.
    sensitive_key = 's2' if mask_s1_flag else 's1'

    r_bar_k_p_list = []
    for k in range(num_clients_K):
        sensitive_attribute = torch.tensor(
            [training_dataset[idx][sensitive_key] for idx in client_dataset_list[k].indices])

        num_samples = len(sensitive_attribute)
        if num_samples == 0:
            # Keep the loud failure the builtin-sum implementation had here.
            raise ZeroDivisionError(f"client {k} has no samples")

        # Tensor.sum reduces in one native call; builtin sum() walks the tensor
        # element by element in Python.
        r_bar_k_p = (sensitive_attribute == p).sum(dim=0) / num_samples
        r_bar_k_p_list.append(r_bar_k_p)

    return r_bar_k_p_list


def get_statistical_distance(tuple_a, tuple_b):
    """Distances between two statistical tuples, one per component group.

    Each tuple is laid out as
    ``(j_bar_0_0, j_bar_0_1, j_bar_1_0, j_bar_1_1, u_bar_0, u_bar_1, ..., parameters)``;
    the first six entries must offer ``.cpu()`` and the last entry holds the
    model parameters.

    Args:
        tuple_a: First statistical tuple.
        tuple_b: Second statistical tuple.

    Returns:
        ``(j_distance, u_distance, θ_distance)``: norms of the differences of
        the 2x2 ``j`` matrices, the ``u`` vectors and the parameters.
    """

    def _joint_matrix(stat_tuple):
        # Entries 0..3 arranged as [[j_00, j_01], [j_10, j_11]].
        return np.array([
            [stat_tuple[0].cpu(), stat_tuple[1].cpu()],
            [stat_tuple[2].cpu(), stat_tuple[3].cpu()]
        ])

    def _marginal_vector(stat_tuple):
        # Entries 4..5 as [u_0, u_1].
        return np.array([stat_tuple[4].cpu(), stat_tuple[5].cpu()])

    joint_a, joint_b = _joint_matrix(tuple_a), _joint_matrix(tuple_b)
    marginal_a, marginal_b = _marginal_vector(tuple_a), _marginal_vector(tuple_b)
    params_a, params_b = np.array(tuple_a[-1]), np.array(tuple_b[-1])

    j_distance = np.linalg.norm(joint_a - joint_b)
    u_distance = np.linalg.norm(marginal_a - marginal_b)
    try:
        θ_distance = np.linalg.norm(params_a - params_b)
    except ValueError:
        # A single norm over the whole parameter array failed; fall back to
        # summing the norms of the individual entries.
        θ_distance = sum([np.linalg.norm(entry_gap) for entry_gap in params_a - params_b])
    return j_distance, u_distance, θ_distance


def get_statistical_similarity(tuple_a, tuple_b, λ, ρ):
    """Turn the three distances between two statistical tuples into weights.

    Each distance ``d`` from ``get_statistical_distance`` is mapped to
    ``exp(-d / ρ)``, except that a distance of exactly zero maps to ``0``.

    Args:
        tuple_a: First statistical tuple.
        tuple_b: Second statistical tuple.
        λ: Not used by the current weights (only by the retired formula below).
        ρ: Scale of the exponential kernel.

    Returns:
        ``(W_θ_a_b, W_u_a_b, W_j_a_b)``.
    """

    def _kernel_weight(distance):
        # NOTE(review): a zero distance yields weight 0 rather than exp(0) = 1,
        # so identical statistics contribute nothing — confirm this is intended.
        if distance == 0:
            return 0
        return math.exp(-distance / ρ)

    j_distance, u_distance, θ_distance = get_statistical_distance(tuple_a, tuple_b)

    # Retired combined score, kept for reference:
    # DSim_a_b = θ_distance \
    #            + (-0.5 + (math.sqrt(4 * λ + 1) / 2)) * u_distance \
    #            + (λ + 1 / 2 - (math.sqrt(4 * λ + 1) / 2)) * j_distance
    W_θ_a_b = _kernel_weight(θ_distance)
    W_u_a_b = _kernel_weight(u_distance)
    W_j_a_b = _kernel_weight(j_distance)

    return (W_θ_a_b, W_u_a_b, W_j_a_b)


def get_statistical_tuple(training_dataset, client_dataset_list, client_id, model, mask_s1_flag, hypothesis, device):
    """Compute one client's statistical tuple from ``model``'s predictions on its data.

    Args:
        training_dataset: Indexable collection of records exposing ``'X'``
            (tensors, they are stacked), ``'s1'`` and ``'s2'``.
        client_dataset_list: Per-client subsets exposing ``.indices`` into
            ``training_dataset``.
        client_id: Index of the client to evaluate.
        model: Callable model applied to the client's stacked inputs.
        mask_s1_flag: When truthy the ``'s2'`` field is used as the sensitive
            attribute, otherwise ``'s1'``.
        hypothesis: If it contains ``"LR"`` the model output is thresholded at
            0.5; otherwise ``argmax(dim=1)`` is taken.
        device: Torch device for all tensors.

    Returns:
        ``(j_bar_0_0, j_bar_0_1, j_bar_1_0, j_bar_1_1, u_bar_0, u_bar_1, parameters)``
        where ``parameters`` is ``get_parameters(model)`` wrapped in an
        object-dtype NumPy array.
    """
    sample_indices = client_dataset_list[client_id].indices
    # The two sensitive-attribute branches only differed in the record key.
    sensitive_key = 's2' if mask_s1_flag else 's1'

    client_X = torch.stack([training_dataset[idx]['X'] for idx in sample_indices]).to(device)
    local_prediction = model(client_X).to(device)
    s = torch.tensor(
        np.array([training_dataset[idx][sensitive_key] for idx in sample_indices])).to(device)

    if "LR" in hypothesis:
        y_hat_θ = (local_prediction >= 0.5).reshape(-1).to(device)
    else:  # NN
        y_hat_θ = local_prediction.argmax(dim=1).to(device)

    # There is exactly one evaluation per call, so the statistics are returned
    # directly instead of being appended to one-element lists and averaged.
    _, j_bar_0_0, j_bar_0_1, j_bar_1_0, j_bar_1_1, u_bar_0, u_bar_1 = get_Q_hat_θ(y_hat_θ, s, device)

    return (j_bar_0_0, j_bar_0_1, j_bar_1_0, j_bar_1_1, u_bar_0, u_bar_1, np.array(get_parameters(model), dtype=object))


def get_similarity_matrix(tuple_list, λ, ρ):
    """Build the K x K table of pairwise similarities between statistical tuples.

    Args:
        tuple_list: One statistical tuple per client.
        λ: Forwarded to ``get_statistical_similarity``.
        ρ: Forwarded to ``get_statistical_similarity``.

    Returns:
        A list of K lists; entry ``[i][j]`` is the similarity of client ``i``
        towards client ``j``.  Diagonal entries are the constant
        ``(1, 1, 1, 1)``.
    """
    K = len(tuple_list)
    # The matrix is symmetric, so the number of similarity evaluations could be
    # halved by reusing entry [j][i] for [i][j].
    # NOTE(review): the diagonal holds a 4-tuple while get_statistical_similarity
    # returns 3-tuples — confirm no caller relies on the extra element.
    return [
        [
            (1, 1, 1, 1) if row == col
            else get_statistical_similarity(tuple_list[row], tuple_list[col], λ, ρ)
            for col in range(K)
        ]
        for row in range(K)
    ]


def get_u_bar_c(y_hat_θ, c, device):
    """Empirical frequency ``u_bar(c)`` of prediction ``c`` in ``y_hat_θ``.

    Args:
        y_hat_θ: Tensor of predictions.
        c: Prediction value to count.
        device: Torch device for the returned tensor.

    Returns:
        A tensor on ``device`` holding ``count(y_hat_θ == c) / len(y_hat_θ)``.

    Raises:
        ZeroDivisionError: If ``y_hat_θ`` is empty.
    """
    num_samples = len(y_hat_θ)
    if num_samples == 0:
        # Keep the loud failure the builtin-sum implementation had on empty input.
        raise ZeroDivisionError("cannot estimate frequencies from an empty sample")

    predicted_c = (y_hat_θ == c).to(device)
    # Tensor.sum reduces in one native call; builtin sum() walks the tensor
    # element by element in Python.
    P_y_hat_θ_c = (predicted_c.sum(dim=0) / num_samples).to(device)  # u_bar(c)
    return P_y_hat_θ_c


def initialization(client_dataset_list, global_model, num_clients_K, mask_s1_flag, training_dataset, γ_k_style):
    """Prepare the per-client state and the aggregated sensitive-group statistics.

    Args:
        client_dataset_list: Per-client subsets of ``training_dataset``.
        global_model: Model that is deep-copied once per client.
        num_clients_K: Number of clients.
        mask_s1_flag: Forwarded to ``get_r_bar_k_p_list`` to pick the sensitive
            attribute.
        training_dataset: Full training dataset.
        γ_k_style: Forwarded to ``get_gamma_k_list`` to pick the weighting scheme.

    Returns:
        ``(client_datasets_size_list, local_model_list, global_v,
        r_bar_k_p0_list, r_bar_k_p1_list, γ_k_list, r_hat_p0, r_hat_p1, v_hat_1)``
        where ``global_v`` is a random ``(2, 1)`` tensor, ``r_hat_p`` is the
        γ-weighted sum of the per-client ``r_bar_k_p`` and
        ``v_hat_1 = [sqrt(r_hat_p0), sqrt(r_hat_p1)]``.
    """
    client_datasets_size_list = list(map(len, client_dataset_list))

    # Every client starts from its own independent copy of the global model.
    local_model_list = []
    for _ in range(num_clients_K):
        local_model_list.append(copy.deepcopy(global_model))

    global_v = torch.rand(2, 1)

    r_bar_k_p0_list, r_bar_k_p1_list = (
        get_r_bar_k_p_list(num_clients_K, mask_s1_flag, training_dataset, client_dataset_list, p=group)
        for group in (0, 1)
    )

    γ_k_list = get_gamma_k_list(γ_k_style, client_datasets_size_list, num_clients_K)

    # Aggregate the per-client frequencies with the client weights.
    r_hat_p0 = sum([r_bar_k_p0_list[k] * γ_k for k, γ_k in enumerate(γ_k_list)])
    r_hat_p1 = sum([r_bar_k_p1_list[k] * γ_k for k, γ_k in enumerate(γ_k_list)])

    v_hat_1 = [math.sqrt(r_hat_p0), math.sqrt(r_hat_p1)]
    return (client_datasets_size_list, local_model_list, global_v,
            r_bar_k_p0_list, r_bar_k_p1_list, γ_k_list, r_hat_p0, r_hat_p1, v_hat_1)


def localized_approximation(j_bar_0_0_list, j_bar_0_1_list, j_bar_1_0_list, j_bar_1_1_list,
                            u_bar_0_list, u_bar_1_list,
                            i, local_model_list, similarity_matrix):
    """Approximate client ``i``'s statistical tuple from all other clients.

    Every other client ``j`` contributes its parameters and its ``j_bar`` /
    ``u_bar`` statistics, weighted by the three similarity weights stored in
    ``similarity_matrix[j][i]``; each group is normalised by the sum of its
    weights.

    Args:
        j_bar_0_0_list: Per-client ``j_bar(0, 0)`` values.
        j_bar_0_1_list: Per-client ``j_bar(0, 1)`` values.
        j_bar_1_0_list: Per-client ``j_bar(1, 0)`` values.
        j_bar_1_1_list: Per-client ``j_bar(1, 1)`` values.
        u_bar_0_list: Per-client ``u_bar(0)`` values.
        u_bar_1_list: Per-client ``u_bar(1)`` values.
        i: Index of the client being approximated (it is skipped itself).
        local_model_list: Per-client models; parameters are read with
            ``get_parameters``.
        similarity_matrix: Table whose entry ``[j][i]`` starts with the weights
            ``(W_θ, W_u, W_j)``.

    Returns:
        ``(j_tilde_0_0, j_tilde_0_1, j_tilde_1_0, j_tilde_1_1, u_tilde_0,
        u_tilde_1, θ_tilde)``.  A group whose weights sum to zero is not
        normalised: the ``j`` / ``u`` entries are ``0`` and ``θ_tilde`` keeps
        its zero-weighted sum (``0`` when there are no other clients).
    """
    θ_tilde_i = 0
    j_tilde_i_0_0, j_tilde_i_0_1, j_tilde_i_1_0, j_tilde_i_1_1 = 0, 0, 0, 0
    u_tilde_i_0, u_tilde_i_1 = 0, 0
    W_θ_sum, W_u_sum, W_j_sum = 0, 0, 0

    for peer in range(len(local_model_list)):
        if peer == i:
            continue

        weights = similarity_matrix[peer][i]
        W_θ, W_u, W_j = weights[0], weights[1], weights[2]
        W_θ_sum += W_θ
        W_u_sum += W_u
        W_j_sum += W_j

        # NOTE(review): get_statistical_tuple wraps the same parameters with
        # dtype=object; if get_parameters returns differently shaped layers this
        # plain np.array call may need the same treatment — confirm.
        θ_tilde_i += W_θ * np.array(get_parameters(local_model_list[peer]))

        j_tilde_i_0_0 += W_j * j_bar_0_0_list[peer]
        j_tilde_i_0_1 += W_j * j_bar_0_1_list[peer]
        j_tilde_i_1_0 += W_j * j_bar_1_0_list[peer]
        j_tilde_i_1_1 += W_j * j_bar_1_1_list[peer]

        u_tilde_i_0 += W_u * u_bar_0_list[peer]
        u_tilde_i_1 += W_u * u_bar_1_list[peer]

    # Guard the parameter normalisation exactly like the j / u groups below:
    # dividing by a zero weight sum would give NaN parameters, or raise
    # ZeroDivisionError when no other client exists.
    if W_θ_sum != 0:
        θ_tilde_i = θ_tilde_i / W_θ_sum

    if W_j_sum == 0:
        j_tilde_i_0_0, j_tilde_i_0_1, j_tilde_i_1_0, j_tilde_i_1_1 = 0, 0, 0, 0
    else:
        j_tilde_i_0_0 = j_tilde_i_0_0 / W_j_sum
        j_tilde_i_0_1 = j_tilde_i_0_1 / W_j_sum
        j_tilde_i_1_0 = j_tilde_i_1_0 / W_j_sum
        j_tilde_i_1_1 = j_tilde_i_1_1 / W_j_sum

    if W_u_sum == 0:
        u_tilde_i_0, u_tilde_i_1 = 0, 0
    else:
        u_tilde_i_0 = u_tilde_i_0 / W_u_sum
        u_tilde_i_1 = u_tilde_i_1 / W_u_sum

    return (j_tilde_i_0_0, j_tilde_i_0_1, j_tilde_i_1_0, j_tilde_i_1_1, u_tilde_i_0, u_tilde_i_1, θ_tilde_i)


def staleness_function_beta_β(zeta_ζ, phi_φ=0.1, function_style="exponential"):
    """Return the staleness compensation factor for staleness ``zeta_ζ``.

    Args:
        zeta_ζ: Staleness value.
        phi_φ: Shape parameter of the chosen function.
        function_style: Selects the formula by substring, checked in this order:
            ``"exponential"`` -> ``exp(-ζ * φ)``, ``"polynomial"`` -> ``ζ ** -φ``,
            ``"linear"`` -> ``φ / ζ``.  Any other value uses the exponential form.

    Returns:
        The compensation factor as a float.
    """
    if "exponential" not in function_style:
        if "polynomial" in function_style:
            return math.pow(zeta_ζ, -phi_φ)
        if "linear" in function_style:
            return phi_φ / zeta_ζ
    # Explicit "exponential" and every unrecognised style end up here.
    return math.exp(-zeta_ζ * phi_φ)

