import math
import numpy as np
import random
import matplotlib.pyplot as plt

def pull_arm(mu,i,a):
    """Simulate one pull of arm ``a`` by agent ``i``.

    Draws a uniform number in [0, 1) and compares it with ``mu[i][a]``, so the
    result is a Bernoulli sample with success probability ``mu[i][a]``.

    Returns:
        1 on success, 0 otherwise.
    """
    draw = random.uniform(0, 1)
    return 1 if draw < mu[i][a] else 0
def Confidence(t, n_i_k_t, N_i, beta_i):
    """Return the UCB confidence radius used for arm selection.

    Computes ``(1 + beta_i) * sqrt(3 * ln(t) / (N_i * n_i_k_t)) + 1 / (2 * t)``.

    Args:
        t: current round (must be >= 1 so that ``ln(t)`` is defined).
        n_i_k_t: number of pulls of the arm counted by the agent (non-zero).
        N_i: factor multiplying the pull count inside the square root.
        beta_i: per-agent inflation factor for the exploration term.
    """
    exploration = math.sqrt(3 * math.log(t) / (N_i * n_i_k_t))
    return (1 + beta_i) * exploration + (1 / (2 * t))

def DUCB(N,M,T,mu,mu_star,best_arm,W,beta,Neighbor,Neighbor_x):
    """Simulate distributed UCB where agents share estimates over a graph.

    Every agent first pulls every arm ``retrain_time`` times to seed its local
    statistics. In each later round every agent picks one arm:

    - forced exploration: it picks uniformly among arms whose local pull count
      lags the largest count known in its neighbourhood by at least ``M``;
    - otherwise it picks the arm maximising ``z + Confidence(...)``.

    Then all agents synchronously mix their estimates ``z`` with weights ``W``
    over their neighbourhoods. Each agent also adds the change in its own
    local sample mean and refreshes its neighbourhood maximum pull counts.

    Args:
        N: number of agents.
        M: number of arms (also the forced-exploration lag threshold).
        T: horizon; rounds ``1 .. T-1`` are simulated after initialisation.
        mu: ``mu[i][k]`` is passed to ``pull_arm`` as the success probability
            of arm ``k`` for agent ``i``.
        mu_star: per-arm mean rewards used to compute regret.
        best_arm: index of the arm regret is measured against.
        W: ``N x N`` mixing weights; ``W[i][j]`` is used for ``j`` in ``Neighbor[i]``.
        beta: ``beta[i]`` is agent ``i``'s inflation factor for ``Confidence``.
        Neighbor: ``Neighbor[i]`` lists the agent indices that agent ``i``
            mixes with; include ``i`` itself if it should weigh its own estimate.
        Neighbor_x: unused; kept for interface compatibility.

    Returns:
        A list of ``N`` lists. Entry ``i`` is agent ``i``'s cumulative regret
        (relative to ``mu_star[best_arm]``). It starts at 0 and has one entry
        per round, ``T`` entries in total.
    """
    retrain_time = 5  # initial pulls of every arm by every agent
    W = np.asarray(W, dtype=float)
    neighbours = [np.asarray(nbrs, dtype=int) for nbrs in Neighbor]

    # hat_x[i, k] = [reward sum, pull count] observed locally by agent i on arm k.
    hat_x = np.zeros((N, M, 2))
    for _ in range(retrain_time):
        for i in range(N):
            for k in range(M):
                hat_x[i, k, 0] += pull_arm(mu, i, k)
                hat_x[i, k, 1] += 1

    # z[i, k]: agent i's consensus estimate of arm k, seeded with its local mean.
    z = hat_x[:, :, 0] / hat_x[:, :, 1]

    # n: local pull counts; m: largest pull count known in the neighbourhood.
    # They must be distinct arrays. If they were aliased, n == m always held,
    # so the forced-exploration test (n <= m - M) could never trigger.
    m = np.full((N, M), float(retrain_time))
    n = m.copy()

    regret_list = [[0] for _ in range(N)]

    for t in range(1, T):
        a = np.zeros(N, dtype=int)
        # Copy (not alias) so hat_x keeps last round's statistics. The
        # consensus step needs the change in each agent's local mean.
        new_hat_x = hat_x.copy()

        for i in range(N):
            lagging = [k for k in range(M) if n[i, k] <= m[i, k] - M]
            if lagging:
                a[i] = random.choice(lagging)
            else:
                n_i = len(Neighbor[i])
                Q = [z[i, k] + Confidence(t, n[i, k], n_i, beta[i]) for k in range(M)]
                a[i] = np.argmax(Q)

            arm = a[i]
            reward = pull_arm(mu, i, arm)
            new_hat_x[i, arm, 0] += reward
            new_hat_x[i, arm, 1] += 1
            regret_list[i].append(regret_list[i][-1] + mu_star[best_arm] - mu_star[arm])

        # Each agent pulled exactly one arm this round.
        n[np.arange(N), a] += 1

        # Synchronous exchange: every agent combines its neighbours' values
        # from the previous round, so the result does not depend on agent order.
        mean_change = new_hat_x[:, :, 0] / new_hat_x[:, :, 1] - hat_x[:, :, 0] / hat_x[:, :, 1]
        new_z = np.empty((N, M))
        new_m = np.empty((N, M))
        for i, nbrs in enumerate(neighbours):
            new_z[i] = W[i, nbrs] @ z[nbrs] + mean_change[i]
            new_m[i] = np.vstack([n[i], m[i], m[nbrs]]).max(axis=0)

        hat_x, z, m = new_hat_x, new_z, new_m

    return regret_list

if __name__ == '__main__':
    N = 8  # number of agents
    M = 10  # number of arms
    # Mixing weights: each row puts weight 0.2 on the agent itself and four neighbours.
    W = np.array([[0.2, 0.2, 0.2, 0, 0, 0, 0.2, 0.2],
                  [0.2, 0.2, 0.2, 0.2, 0, 0, 0, 0.2],
                  [0.2, 0.2, 0.2, 0.2, 0.2, 0, 0, 0],
                  [0, 0.2, 0.2, 0.2, 0.2, 0.2, 0, 0],
                  [0, 0, 0.2, 0.2, 0.2, 0.2, 0.2, 0],
                  [0, 0, 0, 0.2, 0.2, 0.2, 0.2, 0.2],
                  [0.2, 0, 0, 0, 0.2, 0.2, 0.2, 0.2],
                  [0.2, 0.2, 0, 0, 0, 0.2, 0.2, 0.2]])
    # mu[i][k]: success probability of arm k for agent i (used by pull_arm).
    mu = np.array([
        [0.8, 0.9, 0.95, 0.85, 0.85, 0.8, 0.7, 0.65, 0.75, 0.75],
        [0.7, 0.6, 0.2, 0.1, 0.7, 0.2, 0.3, 0.5, 0.7, 0.4],
        [0.3, 0.3, 0.6, 0.4, 0.8, 0.4, 0.5, 0.7, 0.6, 0.6],
        [0.6, 0.5, 0.7, 0.3, 0.85, 0.9, 0.6, 0.6, 0.3, 0.3],
        [0.7, 0.6, 0.8, 0.2, 0.85, 0.3, 0.3, 0.5, 0.3, 0.2],
        [0.5, 0.7, 0.9, 0.7, 0.9, 0.7, 0.6, 0.5, 0.3, 0.6],
        [0.4, 0.5, 0.7, 0.7, 0.7, 0.8, 0.3, 0.7, 0.7, 0.5],
        [0.3, 0.4, 0.6, 0.7, 0.75, 0.7, 0.5, 0.4, 0.2, 0.3]
    ])
    T = int(1e4)  # horizon (rounds per run)
    # Per-arm mean over agents; regret is measured against the best such arm.
    mu_star = np.mean(mu, axis=0)
    best_arm = int(np.argmax(mu_star))  # arm 4 for the mu above

    beta = [0.5] * N
    # Each agent's neighbourhood (itself included) is the support of its row of W.
    # Neighbor_x is the same graph as a 0/1 adjacency matrix.
    Neighbor = [np.flatnonzero(row) for row in W]
    Neighbor_x = (W > 0).astype(int)

    repeated_time = 10
    regret_list = np.zeros((repeated_time, T))
    for repeat_time in range(repeated_time):
        per_agent_regret = DUCB(N, M, T, mu, mu_star, best_arm, W, beta, Neighbor, Neighbor_x)
        # Network regret: sum of all agents' cumulative regrets at each round.
        regret_list[repeat_time] = np.sum(per_agent_regret, axis=0)

    regret_mean = np.mean(regret_list, axis=0)
    regret_std = np.std(regret_list, axis=0)
    index_list = np.arange(len(regret_mean))

    # Save results before plt.show(), which blocks until the window is closed.
    with open('list_data_DistributedUCB_mean.txt', 'w') as f:
        f.writelines('%s\n' % item for item in regret_mean)

    with open('list_data_DistributedUCB_std.txt', 'w') as f:
        # NOTE(review): the unscaled std is written here, while the plot shades
        # std/100. Confirm whether a scaled value (e.g. std/3) was intended.
        f.writelines('%s\n' % item for item in regret_std)

    MARKERS = ["o", "D", "s", "^", "v", "p", "*"]
    COLORS = ["#0e5ad3", "#bc2d14", "#22aa16", "#a011a3", "#d1ba0e", "#14ccc2", "#d67413"]
    LINES = ["solid", "dashed", "dashdot", "dotted", "solid", "dashed", "dashdot", "dotted"]

    plt.plot(index_list, regret_mean, linewidth=0.1, color=COLORS[0], marker=MARKERS[0],
             markersize=0.1, linestyle=LINES[0], label='mean regret')
    plt.fill_between(index_list, regret_mean - regret_std / 100, regret_mean + regret_std / 100,
                     facecolor=COLORS[1], edgecolor='gray', alpha=0.7, label='mean ± std/100')

    plt.title('regret')
    plt.xlabel('time')
    plt.ylabel('value')
    plt.legend()
    plt.show()
