import math
import numpy as np
import random
import matplotlib.pyplot as plt
from tqdm import tqdm

def pull_arm(mu, i, a):
    """Sample one Bernoulli reward for agent ``i`` pulling arm ``a``.

    Args:
        mu: nested sequence; ``mu[i][a]`` is used as the success probability.
        i: agent index.
        a: arm index.

    Returns:
        ``1`` if a uniform draw on [0, 1) falls below ``mu[i][a]``, else ``0``.
    """
    # Draw before indexing so the RNG stream advances exactly once per call.
    draw = random.uniform(0, 1)
    return 1 if draw < mu[i][a] else 0
def Confidence(t, n_i_k_t, N_i, beta_i):
    """Return the UCB exploration bonus for one (agent, arm) pair.

    Computes ``(1 + beta_i) * sqrt(3 * ln(t) / (N_i * n_i_k_t)) + 1 / (2 * t)``.

    Args:
        t: current round; must be >= 1 (``ln`` of it is taken).
        n_i_k_t: pull count for this arm; must be non-zero.
        N_i: scaling count in the denominator (DUCB passes ``len(Neighbor[i])``).
        beta_i: per-agent multiplier on the square-root term.
    """
    exploration = math.sqrt(3 * math.log(t) / (N_i * n_i_k_t))
    return (1 + beta_i) * exploration + (1 / (2 * t))

def get_bits(num):
    """Return the number of bits charged for transmitting ``num``.

    Uses ``ceil(1 + log2(num))``, with a floor of one bit. Every transmitted
    value costs at least one bit. Without the floor, inputs in (0, 0.5] would
    yield zero or negative counts and decrease a running total; DUCB feeds
    this estimates in [0, 1].

    Returns:
        A numpy float scalar >= 1.
    """
    if num < 1:
        # Covers the original `num <= 0` case and also fractional values,
        # whose log2 is negative.
        num = 1
    return np.ceil(1 + np.log2(num))

def DUCB(N,M,T,mu,best_arm,W,beta,Neighbor):
    """Simulate decentralized gossip UCB with ``N`` agents and ``M`` arms.

    Warm start: every agent pulls every arm ``retrain_time`` times.

    Each round ``t = 1 .. T-1``, every agent does the following:
      * Pick an arm. It is forced to explore an arm whose local count lags the
        best count it has heard of by at least ``M``. Otherwise it maximizes
        ``z[i][k] + Confidence(...)``.
      * Observe a Bernoulli reward via ``pull_arm``.
      * Update the gossiped estimate:
        ``z_i <- sum_j W[i][j] * z_j + (new local mean - old local mean)``.
      * Update the max-known pull counts ``m`` from its neighbors' values.

    All agents update synchronously from the previous round's ``z`` and ``m``.

    Args:
        N: number of agents.
        M: number of arms.
        T: horizon; rounds ``1 .. T-1`` are simulated.
        mu: ``mu[i][k]`` is the reward probability of arm ``k`` for agent ``i``.
        best_arm: arm index regret is measured against (using the mean of
            ``mu`` over agents).
        W: mixing weights; ``W[i][j]`` weights agent ``j``'s estimate at ``i``.
        beta: ``beta[i]`` is passed to ``Confidence`` for agent ``i``.
        Neighbor: ``Neighbor[i]`` is an iterable of agent *indices* that agent
            ``i`` gossips with.

    Returns:
        ``(regret, comm_times, comm_bits)``:
          * ``regret``: array of shape (N, T), holding cumulative regret per
            agent.
          * ``comm_times``: array of shape (T,), cumulative count of
            transmitted values.
          * ``comm_bits``: array of shape (T,), cumulative bits
            (``2 * get_bits(value)`` per transmitted value).
    """
    mu_star = np.mean(mu, axis=0)
    regret_list = [[0] for _ in range(N)]
    comm_times = [0]
    comm_bits = [0]

    # Warm start so every empirical mean and pull count is non-zero.
    retrain_time = 10
    hat_x = np.zeros((N, M, 2))  # [..., 0] = reward sum, [..., 1] = pull count
    for _ in range(retrain_time):
        for i in range(N):
            for k in range(M):
                hat_x[i][k][0] += pull_arm(mu, i, k)
                hat_x[i][k][1] += 1
    z = hat_x[:, :, 0] / hat_x[:, :, 1]

    # n[i][k]: agent i's own pull count of arm k.
    # m[i][k]: largest pull count of arm k that agent i has heard of.
    # They must be distinct arrays. Aliasing them makes `n <= m - M` never
    # hold, which disables forced exploration.
    m = np.full((N, M), float(retrain_time))
    n = m.copy()

    for t in range(1, T):
        # Copy, not alias: the z update needs both the old and new local means.
        new_hat_x = hat_x.copy()
        a = np.zeros(N, dtype=int)
        comm_times.append(comm_times[-1])
        comm_bits.append(comm_bits[-1])

        for i in range(N):
            # Forced exploration of arms lagging the best-known count by >= M.
            lagging = [k for k in range(M) if n[i][k] <= m[i][k] - M]
            if lagging:
                a[i] = random.choice(lagging)
            else:
                ucb = [z[i][k] + Confidence(t, n[i][k], len(Neighbor[i]), beta[i])
                       for k in range(M)]
                a[i] = np.argmax(ucb)

            arm = int(a[i])
            reward = pull_arm(mu, i, arm)
            new_hat_x[i][arm][0] += reward
            new_hat_x[i][arm][1] += 1
            regret_list[i].append(regret_list[i][-1] + mu_star[best_arm] - mu_star[arm])

        for i in range(N):
            n[i][a[i]] += 1

        # Synchronous gossip: every agent reads the previous round's z and m.
        new_z = np.zeros((N, M))
        new_m = m.copy()
        for i in range(N):
            for k in range(M):
                for nei in Neighbor[i]:
                    new_z[i][k] += W[i][nei] * z[nei][k]
                    comm_times[-1] += 1
                    comm_bits[-1] += get_bits(z[nei][k]) * 2
                new_z[i][k] += (new_hat_x[i][k][0] / new_hat_x[i][k][1]
                                - hat_x[i][k][0] / hat_x[i][k][1])
                new_m[i][k] = max(n[i][k], m[i][k], *(m[j][k] for j in Neighbor[i]))

        hat_x = new_hat_x
        z = new_z
        m = new_m

    return np.array(regret_list), np.array(comm_times), np.array(comm_bits)

if __name__ == '__main__':
    import os

    M = 20  # number of arms
    N = 8  # number of agents
    # mu8[i][k]: reward probability of arm k for agent i.
    mu8 = [[0.95 - 0.03 * k - 0.01 * i for k in range(M)] for i in range(N)]
    # Fully connected graph. DUCB iterates Neighbor[i] as neighbor *indices*,
    # so each row lists agent ids rather than 0/1 adjacency flags.
    Neighbor8 = [list(range(N)) for _ in range(N)]
    W8 = [[1 / N for _ in range(N)] for _ in range(N)]
    T = int(1e6)  # horizon
    best_arm = int(np.argmax(np.mean(mu8, axis=0)))

    repeated_time = 2

    beta = [0.1] * N
    regret_lists = []
    comm_times_lists = []
    comm_bits_lists = []
    for repeat_time in tqdm(range(repeated_time)):
        regret_list, comm_times, comm_bits = DUCB(N, M, T, mu8, best_arm, W8, beta, Neighbor8)
        regret_lists.append(regret_list)
        comm_times_lists.append(comm_times)
        comm_bits_lists.append(comm_bits)

    regret_lists = np.array(regret_lists)
    comm_times_lists = np.array(comm_times_lists)
    comm_bits_lists = np.array(comm_bits_lists)

    # np.save does not expand "~", so resolve it and create the directory first.
    out_dir = os.path.expanduser('~/var_agent/data/ducb')
    os.makedirs(out_dir, exist_ok=True)
    np.save(os.path.join(out_dir, 'regret_lists_agent8_2.npy'), regret_lists)
    np.save(os.path.join(out_dir, 'comm_times_list_agent8_2.npy'), comm_times_lists)
    np.save(os.path.join(out_dir, 'comm_bits_list_agent8_2.npy'), comm_bits_lists)