import copy
import math
import os
import random

import matplotlib.pyplot as plt
import numpy as np
from tqdm import tqdm


def pull_arm(mu, i, j):
    """Sample one 0/1 reward for arm ``i`` pulled by agent ``j``.

    A uniform draw in [0, 1] is compared against ``mu[i][j]``: the reward is 1
    when the draw is strictly below that value and 0 otherwise.
    """
    draw = random.uniform(0, 1)
    success_prob = mu[i][j]
    return 1 if draw < success_prob else 0

def truncate(value, bits):
    """Truncate ``value`` toward zero to ``bits`` binary fractional digits.

    The result is the multiple of ``2 ** -bits`` obtained by dropping the
    fractional part of ``value * 2 ** bits``.
    """
    step_count = int(value * 2 ** bits)
    return step_count / 2 ** bits

def d2b(value, bits):
    """Quantize ``value`` with ``truncate`` and return it as a binary string.

    The truncated value is scaled by ``2 ** bits`` to an integer and formatted
    in binary, zero-padded to ``bits`` digits, and then stripped of leading
    zeros (``'0'`` if nothing remains). For a negative value the string starts
    with ``'-'``, so the padding zeros after the sign are kept.
    """
    scale = 2 ** bits
    as_int = int(truncate(value, bits) * scale)  # fixed-point integer form
    digits = format(as_int, f'0{bits}b')
    stripped = digits.lstrip('0')
    return stripped if stripped else '0'

def b2d(bit_str, bits):
    """Decode a (possibly signed) binary string back to a fixed-point value.

    ``bit_str`` is parsed as a base-2 integer and divided by ``2 ** bits``.
    """
    return int(bit_str, base=2) / 2 ** bits

def _merge_eliminations(B, W):
    """Gossip step: merge neighbours' elimination tables into each agent's own.

    ``B[j]`` maps arm index -> elimination deadline (a step count). For each
    agent ``j`` in index order, the earliest deadline per arm among all agents
    ``j_hat`` with ``W[j_hat][j] != 0`` is merged into ``B[j]``, keeping the
    earlier deadline on conflict. ``B`` is updated in place while iterating,
    so agent ``j`` already sees the merged tables of agents before it. New keys
    are inserted in the order they are received, which matters because the
    elimination pass in ``distributed_successive_eliminate_2`` walks ``B[j]``
    in insertion order.
    """
    for j in range(len(B)):
        received = {}
        for j_hat in range(len(W)):
            if W[j_hat][j] != 0 and B[j_hat]:
                for arm, deadline in B[j_hat].items():
                    if arm not in received or received[arm] > deadline:
                        received[arm] = deadline
        for arm, deadline in received.items():
            if arm not in B[j] or B[j][arm] > deadline:
                B[j][arm] = deadline


def distributed_successive_eliminate_2(W, T, C, gamma, K, N, mu, lambda_2, best_arm, mu_star=None):
    """Simulate decentralized successive elimination over N agents and K arms.

    Each agent ``j`` keeps a candidate arm set ``S[j]``. In every round each
    remaining arm is pulled ``a * (r1 + 1)`` times. Elimination deadlines are
    gossiped between neighbours (``W[j_hat][j] != 0``). Agents update a
    quantized estimate ``hat_mu`` mixed with weights ``W[j]``. Arms whose
    deadline has passed are dropped, always keeping at least one. New deadlines
    are announced for arms whose estimate falls ``gamma * U`` below the leader.

    Args:
        W: (N, N) weight matrix. ``W[j]`` weights the estimates in the
            consensus update; a non-zero ``W[j_hat][j]`` lets agent ``j``
            receive agent ``j_hat``'s elimination table.
        T: horizon. Rounds continue while the largest per-agent step count is
            ``<= T``; the communication outputs have ``T + 1`` entries.
        C, lambda_2: constants in the confidence width ``U``.
        gamma: multiplier on ``U`` in the elimination test.
        K, N: number of arms and number of agents.
        mu: arm means indexed as ``mu[i][j]`` (arm ``i``, agent ``j``), passed
            to ``pull_arm``.
        best_arm: arm index whose ``mu_star`` entry is the regret reference.
        mu_star: per-arm reference means used for regret. Defaults to
            ``np.mean(mu, axis=1)`` (each arm's mean over agents). This value
            used to be read from a module-level global of the same name.

    Returns:
        Tuple ``(sum_pull_time, t, S, sum_reward, regret_list, commucate_list,
        commucate_count)``:
        - pull counts per (arm, agent);
        - the per-agent step counter;
        - the remaining candidate sets;
        - reward sums per (arm, agent);
        - per-agent cumulative-regret lists;
        - cumulative communication-event counts, indexed by step ``0..T``;
        - cumulative communicated-bit counts, indexed by step ``0..T``.
    """
    if mu_star is None:
        mu_star = np.mean(mu, axis=1)

    # Initialization (LINE 1)
    D = 4  # deadline slack per remaining arm when an elimination is announced
    a = 5  # warm-up pulls per arm; also the per-round pull multiplier
    # Non-zero weights per row minus one (presumably excluding the self-loop W[j][j]).
    Nbs = np.count_nonzero(W, axis=1) - 1
    print(f'Nbs:{Nbs}')
    t = np.zeros(N)
    S = [list(range(K)) for _ in range(N)]
    B = [{} for _ in range(N)]  # per agent: arm -> elimination deadline
    hat_mu = np.zeros((K, N))
    bar_mu = np.zeros((K, N))
    tau = np.zeros((K, N))
    reward_x = np.zeros((K, N))
    average_reward_x = np.zeros((K, N))
    theta = np.zeros((K, N))
    # Quantized estimate corrections, stored as binary strings from d2b.
    delta = [["0" for _ in range(N)] for _ in range(K)]
    bits = np.zeros((K, N), dtype=int) + 8
    U = np.ones((K, N))
    sum_pull_time = np.zeros((K, N))
    sum_reward = np.zeros((K, N))
    regret_list = [[0] for _ in range(N)]  # one cumulative-regret curve per agent

    # Per-step communication tallies, sized 2T (at least T + 1 so that T = 0 works).
    # Only steps 0..T are reported.
    buf_len = max(T << 1, T + 1)
    commucate_list = [0] * buf_len
    commucate_count_time = [0] * buf_len
    print(len(commucate_list))

    def record(step, n_bits):
        """Log one communication event of ``n_bits`` bits at agent step ``step``."""
        # Steps past the buffer lie beyond T and are truncated from the outputs
        # anyway; skipping them avoids an IndexError when T is small.
        if step < buf_len:
            commucate_list[step] += 1
            commucate_count_time[step] += n_bits

    # Warm-up: every agent pulls every arm ``a`` times.
    for _ in range(a):
        for j in range(N):
            for i in S[j]:
                reward_x[i][j] = pull_arm(mu, i, j)
                sum_pull_time[i][j] += 1
                sum_reward[i][j] += reward_x[i][j]
    for j in range(N):
        for i in S[j]:
            bar_mu[i][j] = sum_reward[i][j] / sum_pull_time[i][j]
            delta[i][j] = d2b(bar_mu[i][j], 8)
            hat_mu[i][j] = sum_reward[i][j] / a
    tau[:, :] = a
    t[:] = K * a

    r = np.zeros((K, N))
    hat_mu_max = [0] * N
    r1 = 0  # largest round index seen so far across all agents and arms

    while t.max() <= T:
        for j in range(N):
            # First arm (in S[j] order) with the largest current estimate.
            column = hat_mu[:, j]
            i_max = max(S[j], key=lambda arm: column[arm])
            hat_mu_max[j] = hat_mu[i_max][j]
            # LINE 4: one more round for every remaining arm.
            for i in S[j]:
                r[i][j] += 1
                r1 = max(r1, r[i][j])

        average_reward_x[:, :] = 0
        for p_r in range(1, int(a * (r1 + 1)) + 1):
            for j in range(N):
                for i in S[j]:
                    # LINE 8: pull arm i.
                    reward_x[i][j] = pull_arm(mu, i, j)
                    sum_pull_time[i][j] += 1
                    sum_reward[i][j] += reward_x[i][j]
                    regret_list[j].append(regret_list[j][-1] + mu_star[best_arm] - mu_star[i])
                    # LINE 10: running mean of this round's rewards.
                    average_reward_x[i][j] = (average_reward_x[i][j] * (p_r - 1) + reward_x[i][j]) / p_r
                    t[j] += 1
                    tau[i][j] += 1
            # LINES 13-14: exchange elimination tables with neighbours.
            _merge_eliminations(B, W)

        # LINES 19-20: quantized consensus update. All reads use the pre-update
        # hat_mu/delta/bits; results go into copies that are swapped in afterwards.
        new_hat_mu = hat_mu.copy()
        new_delta = [row[:] for row in delta]
        new_bits = bits.copy()
        for j in range(N):
            for i in S[j]:
                U[i, j] = (np.sqrt(np.log(T) / (N * tau[i][j]))
                           + 2 * C / ((r[i][j] + 2) * (1 - lambda_2))
                           + 1 / (r[i][j] + 2))
                new_bits[i][j] = int(math.ceil(1 + math.log2(r[i][j])))
                mixed = sum(W[j] * (hat_mu[i] + b2d(delta[i][j], bits[i][j])))
                new_hat_mu[i][j] = (r[i][j] / (r[i][j] + 2)) * mixed + (2 * average_reward_x[i][j]) / (2 + r[i][j])
                if len(S[j]) > 1:
                    record(int(t[j]), bits[i][j])
                # NOTE(review): round() uses new_bits as *decimal* digits while
                # d2b quantizes in binary digits -- confirm this is intended.
                new_hat_mu[i][j] = round(new_hat_mu[i][j], new_bits[i][j])
                new_delta[i][j] = d2b(new_hat_mu[i][j] - hat_mu[i][j], new_bits[i][j])  # quantization error
        hat_mu, delta, bits = new_hat_mu, new_delta, new_bits

        # Drop arms whose announced deadline has passed, keeping at least one arm.
        for j in range(N):
            for i in B[j]:
                if B[j][i] <= t[j] and len(S[j]) > 1 and i in S[j]:
                    S[j].remove(i)
                    record(int(t[j]), (4 + int(np.floor(np.log2(B[j][i])))) * Nbs[j])

        # LINE 11: announce a deadline for arms that look clearly suboptimal.
        for j in range(N):
            for i in S[j]:
                if hat_mu[i][j] <= hat_mu_max[j] - gamma * U[i][j]:
                    theta[i][j] = t[j] + len(S[j]) * D
                    if i not in B[j] or B[j][i] > theta[i][j]:
                        B[j][i] = theta[i][j]

    # Cumulative totals over steps 0..T inclusive.
    commucate_count = np.cumsum(commucate_count_time[:T + 1])
    commucate_list = np.cumsum(commucate_list[:T + 1])
    print(f'S:{S}')
    print(f'r:{r}')
    print(f'commucate_count : {commucate_count.shape}')
    print(f'commucate_list : {commucate_list.shape}')
    return sum_pull_time, t, S, sum_reward, regret_list, commucate_list, commucate_count


if __name__ == '__main__':
    # Weight matrix for 8 agents on a ring: agent j is linked to itself and to
    # agents j±1 and j±2 (mod 8), each with weight 0.2, so every row sums to 1.
    W = np.array([[0.2, 0.2, 0.2, 0, 0, 0, 0.2, 0.2],
                  [0.2, 0.2, 0.2, 0.2, 0, 0, 0, 0.2],
                  [0.2, 0.2, 0.2, 0.2, 0.2, 0, 0, 0],
                  [0, 0.2, 0.2, 0.2, 0.2, 0.2, 0, 0],
                  [0, 0, 0.2, 0.2, 0.2, 0.2, 0.2, 0],
                  [0, 0, 0, 0.2, 0.2, 0.2, 0.2, 0.2],
                  [0.2, 0, 0, 0, 0.2, 0.2, 0.2, 0.2],
                  [0.2, 0.2, 0, 0, 0, 0.2, 0.2, 0.2]])
    # Arm means; after the transpose the shape is (K, N) = (15, 8) and
    # mu15[i][j] is the mean reward of arm i for agent j.
    mu15 = np.array([
        [0.9, 0.6, 0.8, 0.8, 0.8, 0.8, 0.7, 0.6, 0.6, 0.9, 0.7, 0.6, 0.4, 0.5, 0.6],
        [0.8, 0.6, 0.7, 0.1, 0.7, 0.2, 0.3, 0.5, 0.6, 0.7, 0.1, 0.7, 0.2, 0.3, 0.5],
        [0.7, 0.7, 0.6, 0.4, 0.8, 0.4, 0.5, 0.7, 0.7, 0.7, 0.6, 0.4, 0.8, 0.4, 0.5],
        [0.8, 0.8, 0.6, 0.3, 0.8, 0.9, 0.6, 0.6, 0.7, 0.2, 0.3, 0.5, 0.6, 0.7, 0.1],
        [0.7, 0.8, 0.8, 0.2, 0.8, 0.3, 0.3, 0.5, 0.7, 0.2, 0.3, 0.5, 0.6, 0.7, 0.1],
        [0.9, 0.9, 0.6, 0.7, 0.9, 0.7, 0.6, 0.5, 0.7, 0.2, 0.3, 0.5, 0.6, 0.7, 0.9],
        [0.8, 0.6, 0.9, 0.7, 0.7, 0.8, 0.3, 0.7, 0.7, 0.9, 0.7, 0.6, 0.5, 0.7, 0.9],
        [0.8, 0.6, 0.6, 0.7, 0.7, 0.7, 0.5, 0.4, 0.6, 0.5, 0.8, 0.1, 0.4, 0.3, 0.4]
    ]).T
    # Mean of each arm across agents; the best arm is the one maximizing it.
    mu_star = np.mean(mu15, axis=1)
    best_arm = np.argmax(mu_star)
    print(f'mu_star:{mu_star}')

    T = int(1e6)  # horizon (steps per agent)
    K, N = mu15.shape  # 15 arms, 8 agents
    if W.shape != (N, N):
        raise ValueError(f'W has shape {W.shape}, expected ({N}, {N})')
    # Enters the confidence width as 1 / (1 - LAMDA_2); presumably the
    # second-largest eigenvalue modulus of W -- TODO confirm.
    LAMDA_2 = 0.5759
    GAMMA = 0.5
    C = math.sqrt(N)  # constant C in the confidence width: sqrt(N)

    out_dir = '/home/amax/xuyang/var_arm/data/des'
    # Create the output directory before the long simulations so a bad path
    # fails immediately instead of after all runs have finished.
    os.makedirs(out_dir, exist_ok=True)

    repeated_time = 10
    regret_lists = []
    comm_bit_lists = []
    comm_lists = []

    for _ in tqdm(range(repeated_time)):
        _, _, _, _, regret_list, commucate_list, commucate_count = distributed_successive_eliminate_2(
            W, T, C, GAMMA, K, N, mu15, LAMDA_2, best_arm)
        regret_lists.append(regret_list)
        comm_bit_lists.append(commucate_count)
        comm_lists.append(commucate_list)

    # Regret curves differ in length per run/agent; truncate all to the shortest.
    min_len = min(len(regret_lists[i][j]) for i in range(repeated_time) for j in range(N))
    regret_lists_aligns = np.zeros((repeated_time, N, min_len))
    for i, run in enumerate(regret_lists):
        for j in range(N):
            regret_lists_aligns[i, j, :] = run[j][:min_len]

    comm_bit_lists = np.array(comm_bit_lists)
    comm_lists = np.array(comm_lists)

    np.save(os.path.join(out_dir, 'regret_lists_mu15.npy'), regret_lists_aligns)
    np.save(os.path.join(out_dir, 'comm_bits_lists_mu15.npy'), comm_bit_lists)
    np.save(os.path.join(out_dir, 'comm_times_lists_mu15.npy'), comm_lists)
