# -*- coding: utf-8 -*-

import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm

def fp(p):
    """Return the number of pulls per active arm in an exploration phase.

    The length is ``10 * ln(T)`` truncated to an integer, where ``T`` is the
    module-level horizon. The phase index ``p`` is accepted but does not
    influence the current formula.
    """
    # Replacing this value with the constant 1 simulates the centralized setting.
    phase_length = 10 * np.log(T)
    return int(phase_length)


# Number of independent repetitions of the whole simulation.
N = 10

# Communication loss; the main loop subtracts C * M from the cumulative
# reward at every server elimination round.
C = 1

# Time horizon of a single repetition.
T = int(1e6)
# Scale parameter used in the confidence bound of the elimination test.
sigma = 1 / 2

# mu_local has one row per player (M rows) and one column per arm (K columns).
K = 8
M = 8
mu_local = np.array(
    [
        [0.9, 0.6, 0.8, 0.8, 0.6, 0.8, 0.7, 0.6],
        [0.8, 0.6, 0.7, 0.1, 0.6, 0.2, 0.3, 0.5],
        [0.7, 0.7, 0.6, 0.4, 0.6, 0.4, 0.5, 0.7],
        [0.8, 0.8, 0.6, 0.3, 0.6, 0.6, 0.6, 0.6],
        [0.7, 0.8, 0.8, 0.2, 0.6, 0.3, 0.3, 0.5],
        [0.9, 0.9, 0.6, 0.7, 0.6, 0.6, 0.6, 0.5],
        [0.8, 0.6, 0.9, 0.7, 0.6, 0.6, 0.3, 0.7],
        [0.8, 0.6, 0.6, 0.7, 0.6, 0.7, 0.5, 0.4],
    ]
)

# The global mean of an arm is the average of its local means over all players.
mu_global = mu_local.mean(axis=0)
# Index of the arm with the highest global mean.
best_arm = int(mu_global.argmax())
print(best_arm)
def pull_arm(mu):
    """Draw one Bernoulli reward: 1 with probability ``mu``, otherwise 0.

    A single uniform sample from [0, 1) is compared against ``mu``.
    """
    return int(np.random.uniform(0, 1) < mu)


def get_bits(num):
    """Return ``ceil(1 + log2(num))`` as an int, used as a bit count.

    Any value that is not strictly positive is replaced by 1 before the
    logarithm is taken, so such inputs yield 1.
    """
    if not num > 0:
        num = 1
    bit_count = np.ceil(np.log2(num) + 1)
    return int(bit_count)
# Counter that the main loop increases by M at every server elimination
# round; it is never reset between repetitions.
comm_c = 0

# Result containers; the main loop appends one entry per repetition to each.
regret_list, ind_reg_list = [], []
comm_times_list, comm_bits_list = [], []
for rep in tqdm(range(N)):
    # One independent run of the phased arm-elimination simulation.
    # Players explore every active arm for fp(p) steps, the server then
    # averages the local sample means and eliminates arms whose upper
    # confidence bound falls below the best lower confidence bound. Once a
    # single arm is left, the remaining horizon is filled in closed form.
    t = 1  # current time step; index 0 stays at zero as the cumulative baseline
    p = 0  # phase counter handed to fp()

    active_arm = np.arange(K, dtype=int)
    pull_num = np.zeros([M, K])
    reward_local = np.zeros([M, K])
    reward_t = np.zeros((M, T))
    reward_global = np.zeros(T)
    opt_rw = np.zeros((M, T))
    optimal_reward = np.zeros(T)
    regret = np.zeros((M, T))  # NOTE(review): never read in the visible code
    comm_times = np.zeros(T)
    comm_bits = np.zeros(T)

    data_local = np.zeros([M, K, T])  # M*K*T
    data_global = np.zeros([K, T])  # K*T

    # Pre-sample every Bernoulli reward path. Comparing a length-T vector of
    # uniform draws against the mean is the same draw pull_arm() performs one
    # value at a time, without T Python-level calls per arm. The boolean
    # result is stored as 0.0 / 1.0 in the float arrays.
    for j in range(M):
        for i in range(K):
            data_local[j, i] = np.random.uniform(0, 1, T) < mu_local[j, i]

    optimal_index = best_arm
    for i in range(K):
        data_global[i] = np.random.uniform(0, 1, T) < mu_global[i]

    while t < T:
        # ---- round p: local players explore ----
        if len(active_arm) > 1:
            expl_len = fp(p)
            p += 1
            for k in active_arm:
                # min(T - t, ...) keeps every index below T.
                for _ in range(min(T - t, expl_len)):
                    for m in range(M):
                        reward_local[m, k] += data_local[m, k, t]
                        reward_t[m, t] = reward_t[m, t - 1] + data_global[k, t]
                        opt_rw[m, t] = opt_rw[m, t - 1] + data_global[optimal_index, t]
                        pull_num[m, k] += 1
                        # NOTE(review): these two counters are charged once per
                        # player inside the player loop, i.e. M times per step;
                        # kept as in the original accounting.
                        comm_times[t] += M
                        comm_bits[t] += get_bits(reward_global[-1]) * M * len(reward_global)
                    reward_global[t] = reward_global[t - 1] + M * data_global[k, t]
                    optimal_reward[t] = optimal_reward[t - 1] + M * data_global[optimal_index, t]
                    t = t + 1
            mu_local_sample = reward_local / pull_num
            if t >= T:
                break
            comm_times[t] += 1
            comm_bits[t] += get_bits(reward_global[-1]) * len(reward_global)

        # ---- exploitation: a single arm is left ----
        if len(active_arm) == 1:
            # Steps t .. T-1 each add one expected reward, so the multiplier
            # must run 1 .. T-t. Starting it at 0 would credit nothing at
            # step t and leave every later step one increment short.
            steps = np.arange(1, T - t + 1)
            last_arm = active_arm[0]
            reward_global[t:] = reward_global[t - 1] + steps * M * mu_global[last_arm]
            optimal_reward[t:] = optimal_reward[t - 1] + steps * M * mu_global[optimal_index]
            for m in range(M):
                reward_t[m, t:] = reward_t[m, t - 1] + steps * mu_global[last_arm]
                opt_rw[m, t:] = opt_rw[m, t - 1] + steps * mu_global[optimal_index]
            comm_times[t] += M * (T - t)
            comm_bits[t] += get_bits(reward_global[-1]) * M * len(reward_global) * (T - t)
            break

        # ---- global server: aggregate and eliminate ----
        if len(active_arm) > 1:
            comm_c += M
            reward_global[t - 1] -= C * M  # comment this line out to ignore communication loss
            comm_times[t] += M
            comm_bits[t] += get_bits(mu_local_sample[-1, -1]) * M * len(mu_local_sample)
            mu_global_sample = 1 / M * sum(mu_local_sample)
            # All active arms have been pulled equally often, so the count of
            # the first active arm at player 0 stands for all of them. The
            # constants are tuned from the original ones in the paper to get
            # better performance.
            conf_bnd = np.sqrt(4 * sigma**2 * np.log(T) / (M * pull_num[0, active_arm[0]]))
            # Compare only against arms that are still active. Estimates of
            # eliminated arms are frozen; letting one of them set the
            # threshold could remove every remaining arm, after which this
            # loop would never advance t.
            active_means = mu_global_sample[active_arm]
            elm_max = np.nanmax(active_means) - conf_bnd
            eliminated = active_means + conf_bnd < elm_max
            active_arm = active_arm[~eliminated]

    # Turn the per-step communication charges into running totals.
    comm_times = np.cumsum(comm_times)
    comm_bits = np.cumsum(comm_bits)
    regret_list.append(optimal_reward - reward_global)
    comm_times_list.append(comm_times)
    comm_bits_list.append(comm_bits)
    # Per-player cumulative regret, shape (M, T).
    ind = opt_rw - reward_t
    ind_reg_list.append(ind)
# Stack the per-repetition curves into arrays with one row per repetition.
regret_list = np.asarray(regret_list)
comm_times_list = np.asarray(comm_times_list)
comm_bits_list = np.asarray(comm_bits_list)

# Write every metric to its own .npy file in the working directory.
# ind_reg_list is passed to np.save as the plain Python list it already is.
for out_path, payload in (
    ('group_regret_list_10.npy', regret_list),
    ('comm_times_list_10.npy', comm_times_list),
    ('comm_bits_list_10.npy', comm_bits_list),
    ('ind_reg_list_10.npy', ind_reg_list),
):
    np.save(out_path, payload)