import math
import numpy as np
import random
import matplotlib.pyplot as plt
from tqdm import tqdm

def pull_arm(mu, i, a):
    """Draw one Bernoulli reward for agent ``i`` pulling arm ``a``.

    A single uniform sample on [0, 1] is compared with ``mu[i][a]``.

    Args:
        mu: Two-level indexable of success probabilities, read as ``mu[i][a]``.
        i: Agent index.
        a: Arm index.

    Returns:
        1 if the sample is strictly below ``mu[i][a]``, otherwise 0.
    """
    draw = random.uniform(0, 1)
    return 1 if draw < mu[i][a] else 0
def Confidence(t, n_i_k_t, N_i, beta_i):
    """Return the confidence width added to an arm's estimate at round ``t``.

    The value is
    ``(1 + beta_i) * sqrt(3 * ln(t) / (N_i * n_i_k_t)) + 1 / (2 * t)``.

    Args:
        t: Current round; must be positive (it is passed to ``math.log``).
        n_i_k_t: Pull count used in the denominator; must be non-zero.
        N_i: Scaling factor in the denominator; must be non-zero.
        beta_i: Inflation factor applied to the square-root term.

    Returns:
        The confidence width as a float.
    """
    exploration = math.sqrt(3 * math.log(t) / (N_i * n_i_k_t))
    decay = 1 / (2 * t)
    return (1 + beta_i) * exploration + decay

def get_bits(num):
    """Return the number of bits charged for transmitting ``num``.

    The cost is ``ceil(1 + log2(num))``. Any value below 1 (fractions, zero
    and negatives) is charged the 1-bit minimum: without the clamp,
    ``log2`` of a fraction is negative and the result would be a zero or
    negative bit count, which would make a cumulative bit counter decrease.

    Args:
        num: The scalar value being transmitted.

    Returns:
        The bit count as a NumPy float (the result of ``np.ceil``), always
        at least 1 for finite input.
    """
    if num < 1:
        num = 1
    return np.ceil(1 + np.log2(num))

def DUCB(N,M,T,mu,best_arm,W,beta,Neighbor,Neighbor_x):
    """Simulate a decentralized UCB bandit in which agents gossip estimates.

    Every agent first pulls every arm ``retrain_time`` times to initialise
    its statistics. Then, for each round ``t`` in ``1 .. T-1``, every agent
    picks an arm, observes a Bernoulli reward from ``pull_arm`` and mixes its
    estimate ``z`` with those of its neighbours using the weights ``W``.

    Args:
        N: Number of agents.
        M: Number of arms.
        T: Horizon; the main loop runs for rounds ``1 .. T-1``.
        mu: Success probabilities, indexed as ``mu[agent][arm]``.
        best_arm: Index of the arm regret is measured against.
        W: Mixing weights, indexed as ``W[agent][neighbour]``.
        beta: Per-agent factor passed to ``Confidence``.
        Neighbor: ``Neighbor[i]`` is the iterable of agent indices that
            agent ``i`` mixes with.
        Neighbor_x: Unused; kept for interface compatibility.

    Returns:
        A tuple ``(regret, comm_times, comm_bits)`` of NumPy arrays:
        ``regret`` has one row per agent holding its cumulative regret per
        round (measured on the across-agent mean of ``mu``); ``comm_times``
        and ``comm_bits`` hold the cumulative number of exchanged values and
        the cumulative bit cost per round. Each series starts with a 0 entry.
    """
    # Regret is measured on the reward averaged over all agents.
    mu_star = np.mean(mu, axis=0)
    regret_list = [[0] for _ in range(N)]
    comm_times = [0]
    comm_bits = [0]

    # Warm-up: hat_x[i][k] = [total reward, number of pulls] of agent i, arm k.
    retrain_time = 10
    hat_x = np.zeros((N, M, 2))
    for _ in range(retrain_time):
        for i in range(N):
            for k in range(M):
                hat_x[i][k][0] += pull_arm(mu, i, k)
                hat_x[i][k][1] += 1
    # z[i][k] is agent i's running estimate of arm k, seeded with the
    # warm-up empirical mean.
    z = hat_x[:, :, 0] / hat_x[:, :, 1]

    # n[i][k]: pulls of arm k made by agent i itself.
    # m[i][k]: largest pull count of arm k that agent i has heard of.
    # They must be separate arrays: if they alias, the lag test below can
    # never fire.
    m = np.full((N, M), float(retrain_time))
    n = m.copy()

    for t in tqdm(range(1, T)):
        a = np.zeros(N, dtype=int)
        # Work on a copy so that hat_x still holds last round's statistics
        # when the innovation term is computed in the mixing step.
        new_hat_x = hat_x.copy()
        comm_times.append(comm_times[-1])
        comm_bits.append(comm_bits[-1])

        for i in range(N):
            # Arms on which this agent lags the known maximum by at least M.
            lagging = [k for k in range(M) if n[i][k] <= m[i][k] - M]
            if lagging:
                a[i] = random.choice(lagging)
            else:
                n_neighbors = len(Neighbor[i])
                ucb = [
                    z[i][k] + Confidence(t, n[i][k], n_neighbors, beta[i])
                    for k in range(M)
                ]
                a[i] = np.argmax(ucb)

            arm = int(a[i])
            reward = pull_arm(mu, i, arm)
            new_hat_x[i][arm][0] += reward
            new_hat_x[i][arm][1] += 1
            regret_list[i].append(
                regret_list[i][-1] + mu_star[best_arm] - mu_star[arm]
            )

        new_z = np.zeros((N, M))
        for i in range(N):
            n[i][int(a[i])] += 1
            for k in range(M):
                for nei in Neighbor[i]:
                    new_z[i][k] += W[i][nei] * z[nei][k]
                    comm_times[-1] += 1
                    comm_bits[-1] += get_bits(z[nei][k]) * 2
                # Innovation: change of the local empirical mean this round.
                new_z[i][k] += (
                    new_hat_x[i][k][0] / new_hat_x[i][k][1]
                    - hat_x[i][k][0] / hat_x[i][k][1]
                )
            for k in range(M):
                for j in Neighbor[i]:
                    m[i][k] = max(n[i][k], m[j][k], m[i][k])

        hat_x = new_hat_x
        z = new_z

    return np.array(regret_list), np.array(comm_times), np.array(comm_bits)

if __name__ == '__main__':
    # Experiment driver: run DUCB several times on the ``mu2`` instance and
    # save the per-run regret and communication curves as .npy files.
    import os

    # Mixing weights: every agent weights each of its five neighbours
    # (itself included) by 0.2.
    W = np.array([[0.2, 0.2, 0.2, 0, 0, 0, 0.2, 0.2],
                  [0.2, 0.2, 0.2, 0.2, 0, 0, 0, 0.2],
                  [0.2, 0.2, 0.2, 0.2, 0.2, 0, 0, 0],
                  [0, 0.2, 0.2, 0.2, 0.2, 0.2, 0, 0],
                  [0, 0, 0.2, 0.2, 0.2, 0.2, 0.2, 0],
                  [0, 0, 0, 0.2, 0.2, 0.2, 0.2, 0.2],
                  [0.2, 0, 0, 0, 0.2, 0.2, 0.2, 0.2],
                  [0.2, 0.2, 0, 0, 0, 0.2, 0.2, 0.2]])

    # Success probabilities, one row per agent and one column per arm.
    mu2 = np.array([[0.45, 0.68, 0.89, 0.47, 0.73, 0.65, 0.2 , 0.58, 0.76, 0.37],
                    [0.67, 0.68, 0.13, 0.89, 0.1 , 0.83, 0.15, 0.33, 0.56, 0.54],
                    [0.52, 0.1 , 0.59, 0.82, 0.21, 0.89, 0.4 , 0.32, 0.29, 0.06],
                    [0.99, 0.86, 0.66, 0.38, 0.99, 0.5 , 0.33, 0.75, 0.35, 0.39],
                    [0.99, 0.95, 0.4 , 0.26, 0.88, 0.3 , 0.45, 0.24, 0.01, 0.18],
                    [0.99, 0.99, 0.77, 0.87, 0.77, 0.2 , 0.99, 0.36, 0.01, 0.8 ],
                    [0.99, 0.99, 0.76, 0.86, 0.52, 0.48, 0.98, 0.57, 0.82, 0.11],
                    [0.8 , 0.75, 0.6 , 0.65, 0.6 , 0.55, 0.5 , 0.45, 0.4 , 0.35]])
    T = int(1e6)  # time
    # The best arm is the one with the highest mean across agents.
    best_arm = np.argmax(np.mean(mu2, axis=0))

    # Neighbour index lists; each row ends with the agent's own index.
    Neighbor = np.array([[1, 2, 6, 7, 0],
                         [0, 2, 3, 7, 1],
                         [0, 1, 3, 4, 2],
                         [1, 2, 4, 5, 3],
                         [2, 3, 5, 6, 4],
                         [3, 4, 6, 7, 5],
                         [0, 4, 5, 7, 6],
                         [0, 1, 5, 6, 7]])

    # 0/1 adjacency matrix with the same sparsity pattern as W.
    Neighbor_x = np.array([[1, 1, 1, 0, 0, 0, 1, 1],
                           [1, 1, 1, 1, 0, 0, 0, 1],
                           [1, 1, 1, 1, 1, 0, 0, 0],
                           [0, 1, 1, 1, 1, 1, 0, 0],
                           [0, 0, 1, 1, 1, 1, 1, 0],
                           [0, 0, 0, 1, 1, 1, 1, 1],
                           [1, 0, 0, 0, 1, 1, 1, 1],
                           [1, 1, 0, 0, 0, 1, 1, 1]])

    repeated_time = 5
    # Agent and arm counts follow the reward table (8 agents, 10 arms).
    N, M = mu2.shape
    beta = [0.1] * N
    regret_lists = []
    comm_times_lists = []
    comm_bits_lists = []
    for repeat_time in tqdm(range(repeated_time)):
        regret_list, comm_times, comm_bits = DUCB(N, M, T, mu2, best_arm, W, beta, Neighbor, Neighbor_x)
        regret_lists.append(regret_list)
        comm_times_lists.append(comm_times)
        comm_bits_lists.append(comm_bits)

    regret_lists = np.array(regret_lists)
    comm_times_lists = np.array(comm_times_lists)
    comm_bits_lists = np.array(comm_bits_lists)

    # np.save does not expand "~", so resolve the home directory explicitly
    # and make sure the target directory exists before writing.
    out_dir = os.path.expanduser('~/var_delta/data/ducb')
    os.makedirs(out_dir, exist_ok=True)
    np.save(os.path.join(out_dir, 'regret_lists_mu2_0.npy'), regret_lists)
    np.save(os.path.join(out_dir, 'comm_times_list_mu2_0.npy'), comm_times_lists)
    np.save(os.path.join(out_dir, 'comm_bits_list_mu2_0.npy'), comm_bits_lists)