import math
import numpy as np
import random
import matplotlib.pyplot as plt
from tqdm import tqdm

def pull_arm(mu, i, a):
    """Draw one Bernoulli reward for agent ``i`` pulling arm ``a``.

    A uniform sample is compared with ``mu[i][a]``: the pull pays 1 when the
    sample lies strictly below that value and 0 otherwise.
    """
    sample = random.uniform(0, 1)
    return 1 if sample < mu[i][a] else 0
def Confidence(t, n_i_k_t, N_i, beta_i):
    """Return the UCB exploration bonus for one (agent, arm) pair at round ``t``.

    The bonus is ``(1 + beta_i) * sqrt(3 ln t / (N_i * n_i_k_t)) + 1 / (2 t)``,
    where ``n_i_k_t`` is the agent's pull count for the arm and ``N_i`` is the
    neighbourhood size passed in by the caller.
    """
    radius = math.sqrt(3 * math.log(t) / (N_i * n_i_k_t))
    return (1 + beta_i) * radius + (1 / (2 * t))

def get_bits(num):
    """Return ``ceil(1 + log2(num))`` as a float.

    Non-positive inputs are treated as 1, so they cost exactly one bit.
    Inputs in (0, 1) can give values of 0 or below.
    """
    value = 1 if num <= 0 else num
    return np.ceil(1 + np.log2(value))

def DUCB(N, M, T, mu, best_arm, W, beta, Neighbor, Neighbor_x, retrain_time=10):
    """Run one decentralized-UCB simulation and return regret/communication traces.

    Each agent first pulls every arm ``retrain_time`` times to seed its local
    sample means. Then, for each round ``t = 1 .. T-1``:

    * Each agent either picks, uniformly at random, an arm on which its own
      pull count ``n`` lags the neighbourhood-propagated count ``m`` by at
      least ``M``, or picks the arm maximising ``z + Confidence(...)``.
    * Each agent mixes its neighbours' estimates ``z`` with weights ``W`` and
      adds the change in its own local sample mean.
    * ``m`` is raised to the neighbourhood maximum of the counts.

    Args:
        N: number of agents.
        M: number of arms.
        T: horizon; the main loop runs ``T - 1`` rounds.
        mu: ``mu[i][k]`` is the success probability used by ``pull_arm`` for
            agent ``i`` on arm ``k``.
        best_arm: arm index that regret is measured against.
        W: mixing weights; ``W[i][j]`` weights agent ``j``'s estimate at agent ``i``.
        beta: per-agent inflation factor passed to ``Confidence``.
        Neighbor: ``Neighbor[i]`` lists the agents whose estimates agent ``i``
            mixes (may include ``i`` itself).
        Neighbor_x: unused; kept for interface compatibility.
        retrain_time: number of initial pulls of every arm by every agent.

    Returns:
        ``(regret, comm_times, comm_bits)``:

        * ``regret`` has shape ``(N, T)`` and holds each agent's cumulative
          regret against the agent-averaged means ``mean(mu, axis=0)``.
        * ``comm_times`` and ``comm_bits`` have shape ``(T,)`` and hold the
          cumulative number of scalar messages and the bits charged for them.

    Raises:
        ValueError: if ``retrain_time`` < 1 (the initial sample means would be 0/0).
    """
    if retrain_time < 1:
        raise ValueError("retrain_time must be at least 1")

    mu_star = np.mean(mu, axis=0)
    # One cumulative-regret trace per agent (was hard-coded to 8 agents).
    regret_list = [[0] for _ in range(N)]
    comm_times = [0]
    comm_bits = [0]

    # hat_x[i][k] = [total reward, number of pulls] observed by agent i on arm k.
    hat_x = np.zeros((N, M, 2))
    for _ in range(retrain_time):
        for i in range(N):
            for k in range(M):
                hat_x[i][k][0] += pull_arm(mu, i, k)
                hat_x[i][k][1] += 1

    # z[i][k]: agent i's estimate of arm k, seeded with its local sample mean.
    z = hat_x[:, :, 0] / hat_x[:, :, 1]

    # n[i][k]: agent i's own pull count of arm k.
    # m[i][k]: largest count of arm k that has reached agent i from its neighbourhood.
    # They must be distinct arrays: when aliased, `n <= m - M` can never hold.
    m = np.full((N, M), float(retrain_time))
    n = m.copy()

    for t in tqdm(range(1, T)):
        a = np.zeros(N, dtype=int)
        # A real copy: the consensus step needs both the old and the new local means.
        new_hat_x = hat_x.copy()
        comm_times.append(comm_times[-1])
        comm_bits.append(comm_bits[-1])

        for i in range(N):
            # Arms on which agent i lags the propagated count by at least M pulls.
            lagging = [k for k in range(M) if n[i][k] <= m[i][k] - M]
            if lagging:
                a[i] = random.choice(lagging)
            else:
                Q = [z[i][k] + Confidence(t, n[i][k], len(Neighbor[i]), beta[i])
                     for k in range(M)]
                a[i] = np.argmax(Q)

            arm = int(a[i])
            reward = pull_arm(mu, i, arm)
            new_hat_x[i][arm][0] += reward
            new_hat_x[i][arm][1] += 1
            regret_list[i].append(regret_list[i][-1] + mu_star[best_arm] - mu_star[arm])

        new_z = np.zeros((N, M))
        # Snapshot so every agent updates m from the same round's values,
        # mirroring the synchronous update of z (which reads only the old z).
        m_prev = m.copy()
        for i in range(N):
            n[i][a[i]] += 1
            for k in range(M):
                for nei in Neighbor[i]:
                    new_z[i][k] += W[i][nei] * z[nei][k]
                    comm_times[-1] += 1
                    comm_bits[-1] += get_bits(z[nei][k]) * 2
                # Add the change in agent i's own sample mean for arm k.
                new_z[i][k] += (new_hat_x[i][k][0] / new_hat_x[i][k][1]
                                - hat_x[i][k][0] / hat_x[i][k][1])
            for k in range(M):
                m[i][k] = max([n[i][k], m_prev[i][k]]
                              + [m_prev[j][k] for j in Neighbor[i]])

        hat_x = new_hat_x
        z = new_z

    return np.array(regret_list), np.array(comm_times), np.array(comm_bits)

if __name__ == '__main__':
    # Experiment driver: run DUCB `repeated_time` times on the mu4 instance and
    # save the stacked regret / communication traces as .npy files.
    from pathlib import Path

    # Mixing weights for 8 agents on a ring: each row gives weight 0.2 to the
    # agent itself and to the agents up to two positions away on either side.
    W = np.array([[0.2, 0.2, 0.2, 0, 0, 0, 0.2, 0.2],
                  [0.2, 0.2, 0.2, 0.2, 0, 0, 0, 0.2],
                  [0.2, 0.2, 0.2, 0.2, 0.2, 0, 0, 0],
                  [0, 0.2, 0.2, 0.2, 0.2, 0.2, 0, 0],
                  [0, 0, 0.2, 0.2, 0.2, 0.2, 0.2, 0],
                  [0, 0, 0, 0.2, 0.2, 0.2, 0.2, 0.2],
                  [0.2, 0, 0, 0, 0.2, 0.2, 0.2, 0.2],
                  [0.2, 0.2, 0, 0, 0, 0.2, 0.2, 0.2]])

    # mu4[i][k]: success probability of arm k for agent i (8 agents x 10 arms).
    mu4 = np.array([
        [0.68269, 0.86294, 0.1709, 0.82458, 0.3762, 0.47955, 0.80783, 0.49867,
         0.58147, 0.03561],
        [0.43568, 0.55027, 0.74445, 0.93006, 0.28862, 0.34975, 0.90536, 0.919,
         0.92289, 0.78626],
        [0.9078, 0.98804, 0.86752, 0.06922, 0.94168, 0.85524, 0.81858, 0.2751,
         0.66068, 0.87343],
        [0.78629, 0.68128, 0.91343, 0.82603, 0.99824, 0.90554, 0.57215, 0.99398,
         0.64304, 0.88964],
        [0.87981, 0.90911, 0.91894, 0.96126, 0.99583, 0.99903, 0.66639, 0.90084,
         0.82415, 0.99424],
        [0.98665, 0.75137, 0.99478, 0.98812, 0.98992, 0.99487, 0.89606, 0.9882,
         0.976, 0.99476],
        [0.92108, 0.85349, 0.98298, 0.99023, 0.99551, 0.99852, 0.91263, 0.99971,
         0.96377, 0.99456],
        [0.8, 0.7995, 0.799, 0.7985, 0.798, 0.7975, 0.797, 0.7965,
         0.796, 0.7955],
    ])
    T = int(1e6)  # time horizon
    best_arm = np.argmax(np.mean(mu4, axis=0))

    # Neighbor[i]: agents with a non-zero weight in W[i] (agent i itself listed last).
    Neighbor = np.array([[1, 2, 6, 7, 0],
                         [0, 2, 3, 7, 1],
                         [0, 1, 3, 4, 2],
                         [1, 2, 4, 5, 3],
                         [2, 3, 5, 6, 4],
                         [3, 4, 6, 7, 5],
                         [0, 4, 5, 7, 6],
                         [0, 1, 5, 6, 7]])

    # 0/1 adjacency matching the non-zero pattern of W (DUCB ignores it).
    Neighbor_x = np.array([[1, 1, 1, 0, 0, 0, 1, 1],
                           [1, 1, 1, 1, 0, 0, 0, 1],
                           [1, 1, 1, 1, 1, 0, 0, 0],
                           [0, 1, 1, 1, 1, 1, 0, 0],
                           [0, 0, 1, 1, 1, 1, 1, 0],
                           [0, 0, 0, 1, 1, 1, 1, 1],
                           [1, 0, 0, 0, 1, 1, 1, 1],
                           [1, 1, 0, 0, 0, 1, 1, 1]])

    repeated_time = 5
    N = 8  # agents
    M = 10  # arms
    beta = [0.1] * N

    # np.save does not expand "~" or create directories. Resolve and create the
    # output directory up front so a bad path fails before the long simulation.
    out_dir = Path('~/var_delta/data/ducb').expanduser()
    out_dir.mkdir(parents=True, exist_ok=True)

    regret_lists = []
    comm_times_lists = []
    comm_bits_lists = []
    for _ in tqdm(range(repeated_time)):
        regret_list, comm_times, comm_bits = DUCB(N, M, T, mu4, best_arm, W, beta, Neighbor, Neighbor_x)
        regret_lists.append(regret_list)
        comm_times_lists.append(comm_times)
        comm_bits_lists.append(comm_bits)

    regret_lists = np.array(regret_lists)
    comm_times_lists = np.array(comm_times_lists)
    comm_bits_lists = np.array(comm_bits_lists)

    np.save(out_dir / 'regret_lists_mu4_0.npy', regret_lists)
    np.save(out_dir / 'comm_times_list_mu4_0.npy', comm_times_lists)
    np.save(out_dir / 'comm_bits_list_mu4_0.npy', comm_bits_lists)