# -*- coding: utf-8 -*-

import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm

def fp(p, horizon=None):
    """Return the number of pulls per active arm in exploration round ``p``.

    The length is ``int(10 * ln(horizon))``; it is currently the same for
    every round, so ``p`` (the round index supplied by the main loop) is
    unused.

    Args:
        p: exploration round index.
        horizon: time horizon; defaults to the module-level ``T``.

    Returns:
        The exploration length as an int.
    """
    if horizon is None:
        horizon = T
    length = 10 * np.log(horizon)
    # length = 1  # simulate the centralized setting
    return int(length)


N = 50  # number of independent repetitions

C = 1  # communication loss subtracted from each player's reward per aggregation

T = int(1e6)  # horizon: time steps per repetition
sigma = 1/2  # noise scale used in the elimination confidence bound

# True mean reward of every arm for every player (rows: players, columns: arms).
mu_local = np.array([
        [0.9, 0.6, 0.8, 0.8, 0.6, 0.8, 0.7, 0.6, 0.6, 0.9, 0.7, 0.6, 0.4, 0.5],
        [0.8, 0.6, 0.7, 0.1, 0.6, 0.2, 0.3, 0.5, 0.6, 0.7, 0.1, 0.7, 0.2, 0.3],
        [0.7, 0.7, 0.6, 0.4, 0.6, 0.4, 0.5, 0.7, 0.7, 0.7, 0.6, 0.4, 0.8, 0.4],
        [0.8, 0.8, 0.6, 0.3, 0.6, 0.6, 0.6, 0.6, 0.7, 0.2, 0.3, 0.5, 0.6, 0.7],
        [0.7, 0.8, 0.8, 0.2, 0.6, 0.3, 0.3, 0.5, 0.7, 0.2, 0.3, 0.5, 0.6, 0.7],
        [0.9, 0.9, 0.6, 0.7, 0.6, 0.6, 0.6, 0.5, 0.7, 0.2, 0.3, 0.5, 0.6, 0.7],
        [0.8, 0.6, 0.9, 0.7, 0.6, 0.6, 0.3, 0.7, 0.7, 0.9, 0.7, 0.6, 0.5, 0.7],
        [0.8, 0.6, 0.6, 0.7, 0.6, 0.7, 0.5, 0.4, 0.6, 0.5, 0.8, 0.1, 0.4, 0.3]
    ])
# Derive the problem size from the matrix so the counts can never disagree
# with it (here M = 8 players, K = 14 arms).
M, K = mu_local.shape

# Global arm means are the per-arm averages over players; the benchmark arm
# is the one with the highest global mean.
mu_global = np.mean(mu_local, axis=0)
best_arm = int(np.argmax(mu_global))
print(best_arm)
def pull_arm(mu):
    """Sample one Bernoulli reward.

    Args:
        mu: success probability of the arm.

    Returns:
        1 with probability ``mu`` and 0 otherwise (a Python int).
    """
    draw = np.random.uniform(0, 1)
    if draw < mu:
        return 1
    return 0


def get_bits(num):
    """Return the number of bits charged for transmitting the value ``num``.

    Values >= 1 cost ``ceil(1 + log2(num))`` bits. Anything smaller -- zero,
    negatives, fractions in (0, 1) and NaN -- costs a single bit, so the
    result is always a positive int (fractions previously produced zero or
    negative counts, e.g. -2 for 0.1).
    """
    # `nan >= 1` is False, so NaN also falls back to 1.
    num = num if num >= 1 else 1
    return int(np.ceil(1 + np.log2(num)))
# Running count of player->server uploads across all repetitions
# (incremented at every server aggregation; not saved).
comm_c = 0

regret_list = []
comm_times_list = []
comm_bits_list = []
for rep in tqdm(range(N)):
    t = 1  # current time step; index 0 is never used
    p = 0  # exploration round counter handed to fp()

    active_arm = np.arange(K)
    pull_num = np.zeros([M, K])
    reward_local = np.zeros([M, K])
    # Per-step rewards (one row per player) of the pulled arm and of the
    # optimal arm; converted to running totals after the run.
    reward_global = np.zeros((M, T))
    optimal_reward = np.zeros((M, T))
    comm_times = np.zeros(T)
    comm_bits = np.zeros(T)

    # Pre-draw an independent Bernoulli reward for every (player, arm, step)
    # and every (global arm, step): the vectorised equivalent of calling
    # pull_arm() T times per cell. Stored as bool (0/1) to cut memory 8x;
    # numpy promotes the values to float when they are accumulated.
    data_local = np.empty((M, K, T), dtype=bool)
    for j in range(M):
        for i in range(K):
            data_local[j, i] = np.random.uniform(0, 1, T) < mu_local[j, i]
    data_global = np.empty((K, T), dtype=bool)
    for i in range(K):
        data_global[i] = np.random.uniform(0, 1, T) < mu_global[i]

    optimal_index = best_arm
    while t < T:
        if len(active_arm) == 1:
            # Commit phase: every player plays the surviving arm for the rest
            # of the horizon. Add the expected per-step reward at each
            # remaining step; running totals are formed after the loop, so
            # adding a ramp here would be accumulated twice.
            commit_mean = mu_global[active_arm[0]]
            reward_global[:, t:] += commit_mean
            optimal_reward[:, t:] += mu_global[optimal_index]
            comm_times[t] += M * (T - t)
            # Message size: bits for the expected cumulative commit-phase
            # reward, (T - 1 - t) * commit_mean.
            commit_payload = (T - 1 - t) * commit_mean
            comm_bits[t] += get_bits(commit_payload) * M * len(reward_global) * (T - t)
            break

        # Local players: round p pulls every active arm fp(p) times.
        expl_len = fp(p)
        p += 1
        for k in active_arm:
            # min() keeps t <= T, so the horizon is never overrun.
            for _ in range(min(T - t, expl_len)):
                reward_local[:, k] += data_local[:, k, t]
                pull_num[:, k] += 1
                reward_global[:, t] += data_global[k, t]
                optimal_reward[:, t] += data_global[optimal_index, t]
                # NOTE(review): M messages are charged per player per step
                # (M * M in total), whereas the commit phase charges M per
                # step -- confirm which is intended.
                comm_times[t] += M * M
                comm_bits[t] += get_bits(reward_global[-1, -1]) * M * len(reward_global) * M
                t += 1
        if t >= T:
            break
        mu_local_sample = reward_local / pull_num
        comm_times[t] += 1
        comm_bits[t] += get_bits(reward_global[-1, -1]) * len(reward_global)

        # Global server: average the local estimates and eliminate arms.
        comm_c += M
        reward_global[:, t - 1] -= C  # comment this line out to ignore communication loss
        comm_times[t] += M
        comm_bits[t] += get_bits(mu_local_sample[-1, -1]) * M * len(mu_local_sample)
        mu_global_sample = mu_local_sample.mean(axis=0)
        # the constants are tuned from the original ones in the paper to get better performance
        conf_bnd = np.sqrt(4 * sigma**2 * np.log(T) / (M * pull_num[0, active_arm[0]]))
        # Threshold only against arms still in play: eliminated arms keep
        # stale estimates, and letting one of them set the threshold could
        # eliminate every active arm, after which the loop never terminates.
        active_mean = mu_global_sample[active_arm]
        elm_max = np.nanmax(active_mean) - conf_bnd
        active_arm = active_arm[~(active_mean + conf_bnd < elm_max)]

    # Turn per-step increments into running totals along the time axis.
    comm_times = np.cumsum(comm_times)
    comm_bits = np.cumsum(comm_bits)
    optimal_reward = np.cumsum(optimal_reward, axis=1)
    reward_global = np.cumsum(reward_global, axis=1)

    regret_list.append(optimal_reward - reward_global)
    comm_times_list.append(comm_times)
    comm_bits_list.append(comm_bits)
regret_list = np.array(regret_list)
comm_times_list = np.array(comm_times_list)
comm_bits_list = np.array(comm_bits_list)
np.save('/home/amax/xuyang/var_reward_gap/data/feducb/regret_list_mu14.npy', regret_list)
np.save('/home/amax/xuyang/var_reward_gap/data/feducb/comm_times_list_mu14.npy', comm_times_list)
np.save('/home/amax/xuyang/var_reward_gap/data/feducb/comm_bits_list_mu14.npy', comm_bits_list)
