import math
import numpy as np
import random
import matplotlib.pyplot as plt
from tqdm import tqdm
import multiprocessing

def pull_arm(i, a, mu):
    """Draw one Bernoulli reward for agent ``i`` pulling arm ``a``.

    A single ``random.uniform(0, 1)`` sample is compared against the
    success probability ``mu[i][a]``.

    Returns:
        1 if the sample is strictly below ``mu[i][a]``, otherwise 0.
    """
    # Draw first so the RNG is consumed even if indexing ``mu`` fails.
    draw = random.uniform(0, 1)
    return 1 if draw < mu[i][a] else 0
def get_bits(num):
    """Return the number of bits charged for transmitting the value ``num``.

    For ``num >= 1`` this is ``ceil(1 + log2(num))``. Any input below 1
    (zero, negatives, and fractional floats) is charged a single bit.
    Without that clamp, values in (0, 1) make the formula return 0 or even
    negative counts (e.g. 0.1 -> -2), which would silently shrink the
    accumulated communication totals.

    NOTE(review): for integers that are not powers of two this is one more
    than ``int(num).bit_length()`` (e.g. 3 -> 3 bits); confirm the extra
    bit is intended.
    """
    num = max(num, 1)
    return int(np.ceil(1 + np.log2(num)))

def CI(i, k, t, N, n, alpha_1):
    """Confidence width for agent ``i`` on arm ``k`` at round ``t``.

    Computes ``sqrt(2 * N * ln(t) / n[i][k]) + alpha_1``, where ``n[i][k]``
    is the agent's pull count of the arm.
    """
    numerator = 2 * N * math.log(t)
    return (numerator / n[i][k]) ** 0.5 + alpha_1

def ArgMaxK(Q, agent, M):
    """Return the arm in ``range(M)`` with the largest ``Q[agent][arm]``.

    Ties resolve to the lowest arm index (a later arm only wins when it is
    strictly greater), and ``M <= 0`` yields 0 without reading ``Q``.
    """
    return max(range(M), key=lambda arm: Q[agent][arm], default=0)

def RandomSelect(A, agent):
    """Pick a uniformly random position into ``A[agent]``.

    Returns an index in ``[0, len(A[agent]) - 1]`` (a position, not an
    element). Raises ValueError when ``A[agent]`` is empty.
    """
    size = len(A[agent])
    return random.randrange(size)

def GossipUCB(N, M, alpha1, T, mu, Neighbor, best_arm):
    """Run one Gossip UCB simulation with ``N`` agents and ``M`` arms for ``T`` rounds.

    Each round every agent:
      1. updates ``n_tilde``, the largest pull count of each arm seen among
         itself and its neighbours;
      2. pulls a random lagging arm if it trails that maximum by more than
         ``N`` pulls, otherwise the arm maximising ``sita + CI``;
      3. averages its estimate ``sita`` with one randomly chosen neighbour,
         plus the change in its own empirical mean caused by this round's pull.

    Args:
        N: number of agents.
        M: number of arms.
        alpha1: constant added to every confidence width (passed to ``CI``).
        T: number of rounds.
        mu: indexable as ``mu[agent][arm]``; the per-agent reward
            probabilities passed to ``pull_arm``.
        Neighbor: ``Neighbor[agent]`` is a non-empty sequence of neighbour
            agent indices (``random.choice`` raises on an empty one).
        best_arm: arm index used as the regret benchmark.

    Returns:
        ``(regret_list, comm_count_list, comm_bits_list)``. ``regret_list``
        holds one cumulative-regret list per agent (length ``T + 1``,
        starting at 0), where regret is measured against ``mu`` averaged
        over agents. The two communication lists hold the cumulative message
        count and bit count after each round (length ``T + 1``, starting at 0).
    """
    mu_star = np.mean(mu, axis=0)  # per-arm mean of mu averaged over agents
    X_tilde = np.zeros((N, M, 2))  # per (agent, arm): [sum of rewards, pull count]
    n = np.zeros((N, M))           # local pull counts
    sita = np.zeros((N, M))        # estimate of each arm's global value
    X = np.zeros((N, M))           # most recent local reward
    n_tilde = np.zeros((N, M))     # max pull count of the arm seen in the neighbourhood
    # One cumulative-regret trace per agent (sized by N, not a fixed 8).
    regret_list = [[0] for _ in range(N)]

    # Communication statistics.
    comm_count = 0
    comm_bits = 0
    comm_count_list = [0]
    comm_bits_list = [0]

    # Initialization: seed every (agent, arm) with its value from ``mu``
    # instead of a sampled reward, counted as one pull.
    for agent in range(N):
        for arm in range(M):
            X[agent][arm] = mu[agent][arm]
            X_tilde[agent][arm][0] += X[agent][arm]
            X_tilde[agent][arm][1] += 1
            sita[agent][arm] = X[agent][arm]
            n[agent][arm] = 1
            n_tilde[agent][arm] = 1

    for t in tqdm(range(1, T + 1)):
        # Line 5: new n_tilde = max(own count, neighbours' previous n_tilde).
        # The running maximum is also what gets charged in comm_bits.
        new_n_tilde = np.zeros((N, M))
        for agent in range(N):
            for arm in range(M):
                for nei in Neighbor[agent]:
                    new_n_tilde[agent][arm] = max(
                        n_tilde[nei][arm], n[agent][arm], new_n_tilde[agent][arm]
                    )
                    comm_count += 1
                    comm_bits += (get_bits(n_tilde[nei][arm])
                                  + get_bits(n[agent][arm])
                                  + get_bits(new_n_tilde[agent][arm]))

        # Line 6: arms on which the agent trails the neighbourhood by more
        # than N pulls.
        # NOTE(review): this uses the previous round's n_tilde; the values
        # computed above are only installed afterwards. Confirm this ordering
        # is intended.
        A = [set() for _ in range(N)]
        for agent in range(N):
            for arm in range(M):
                if n[agent][arm] < n_tilde[agent][arm] - N:
                    A[agent].add(arm)
        n_tilde = new_n_tilde

        a = np.zeros(N, dtype=int)  # arm chosen by each agent this round
        Q = np.zeros((N, M))
        for agent in range(N):
            if not A[agent]:
                # Lines 9-10: UCB index, then greedy choice.
                for arm in range(M):
                    Q[agent][arm] = sita[agent][arm] + CI(agent, arm, t, N, n, alpha1)
                a[agent] = ArgMaxK(Q, agent, M)
            else:
                # Catch-up step: pull a random lagging arm.
                a[agent] = list(A[agent])[RandomSelect(A, agent)]

        # Empirical means before this round's pulls; the sita update adds
        # the change caused by the new reward.
        X_tilde_old = X_tilde[:, :, 0] / X_tilde[:, :, 1]

        for agent in range(N):
            arm = int(a[agent])
            reward = pull_arm(agent, arm, mu)
            X[agent][arm] = reward
            X_tilde[agent][arm][0] += reward
            X_tilde[agent][arm][1] += 1
            n[agent][arm] += 1
            regret_list[agent].append(
                regret_list[agent][-1] + mu_star[best_arm] - mu_star[arm]
            )

        # Gossip step: average with one random neighbour plus the local
        # change in empirical mean.
        new_sita = np.zeros((N, M))
        for i in range(N):
            gossip_agent = random.choice(Neighbor[i])
            for k in range(M):
                new_sita[i][k] = ((sita[i][k] + sita[gossip_agent][k]) / 2
                                  + X_tilde[i][k][0] / X_tilde[i][k][1]
                                  - X_tilde_old[i][k])
                comm_count += 1
                comm_bits += get_bits(sita[gossip_agent][k])
        sita = new_sita
        comm_count_list.append(comm_count)
        comm_bits_list.append(comm_bits)
    return regret_list, comm_count_list, comm_bits_list

def run_gossip_ucb(args):
    """Run a single Gossip UCB trial on the fixed ``mu4`` instance.

    ``args`` is the trial index passed in by ``Pool.map`` and is not used.
    The trial uses 8 agents and 10 arms for 1e6 rounds on a circulant graph.
    Every agent is linked to the agents one and two positions away (mod 8).
    The actual simulation is delegated to ``GossipUCB``.
    """
    num_agents = 8
    num_arms = 10
    exploration_bonus = 64 / (num_agents ** 17)
    horizon = int(1e6)

    # Row = agent, column = arm; each entry is the probability used by
    # pull_arm for that (agent, arm) pair.
    mu4 = np.array([
        [0.68269, 0.86294, 0.1709, 0.82458, 0.3762,
         0.47955, 0.80783, 0.49867, 0.58147, 0.03561],
        [0.43568, 0.55027, 0.74445, 0.93006, 0.28862,
         0.34975, 0.90536, 0.919, 0.92289, 0.78626],
        [0.9078, 0.98804, 0.86752, 0.06922, 0.94168,
         0.85524, 0.81858, 0.2751, 0.66068, 0.87343],
        [0.78629, 0.68128, 0.91343, 0.82603, 0.99824,
         0.90554, 0.57215, 0.99398, 0.64304, 0.88964],
        [0.87981, 0.90911, 0.91894, 0.96126, 0.99583,
         0.99903, 0.66639, 0.90084, 0.82415, 0.99424],
        [0.98665, 0.75137, 0.99478, 0.98812, 0.98992,
         0.99487, 0.89606, 0.9882, 0.976, 0.99476],
        [0.92108, 0.85349, 0.98298, 0.99023, 0.99551,
         0.99852, 0.91263, 0.99971, 0.96377, 0.99456],
        [0.8, 0.7995, 0.799, 0.7985, 0.798,
         0.7975, 0.797, 0.7965, 0.796, 0.7955],
    ])

    # Agent i's neighbours are (i +/- 1) and (i +/- 2) mod num_agents, in
    # ascending order, e.g. agent 0 -> [1, 2, 6, 7].
    neighbors = np.array([
        sorted((agent + offset) % num_agents for offset in (-2, -1, 1, 2))
        for agent in range(num_agents)
    ])

    best_arm = np.argmax(np.mean(mu4, axis=0))
    return GossipUCB(num_agents, num_arms, exploration_bonus, horizon,
                     mu4, neighbors, best_arm)

def GOS(repeated_time=5, out_dir='~/var_delta/data/gossip'):
    """Run ``repeated_time`` Gossip UCB trials in parallel and save the results.

    Each trial is one ``run_gossip_ucb`` call. The per-trial regret traces,
    cumulative message counts and cumulative bit counts are each stacked
    into an array whose leading axis is the trial. The shapes are printed,
    and each array is written as a ``.npy`` file under ``out_dir``.

    Args:
        repeated_time: number of independent trials to run.
        out_dir: output directory. ``~`` is expanded and missing directories
            are created.
    """
    import os
    from multiprocessing import Pool, cpu_count

    # No point starting more worker processes than there are trials.
    processes = max(1, min(repeated_time, cpu_count()))
    with Pool(processes=processes) as pool:
        results = pool.map(run_gossip_ucb, range(repeated_time))

    regret_lists = np.array([regret for regret, _, _ in results])
    comm_count_lists = np.array([counts for _, counts, _ in results])
    comm_bits_lists = np.array([bits for _, _, bits in results])
    print(regret_lists.shape)
    print(comm_count_lists.shape)
    print(comm_bits_lists.shape)

    # np.save does not expand "~"; without expanduser the files would land in
    # a literal "./~/..." directory, or the save would fail because that
    # directory does not exist.
    out_dir = os.path.expanduser(out_dir)
    os.makedirs(out_dir, exist_ok=True)
    np.save(os.path.join(out_dir, 'regret_lists_mu4_5.npy'), regret_lists)
    np.save(os.path.join(out_dir, 'comm_count_lists_mu4_5.npy'), comm_count_lists)
    np.save(os.path.join(out_dir, 'comm_bits_lists_mu4_5.npy'), comm_bits_lists)

if __name__ == "__main__":
    # Script entry point: run the repeated Gossip UCB experiment. The guard
    # also stops worker processes that re-import this module (multiprocessing
    # spawn/forkserver start methods) from launching the experiment again.
    GOS()