import os
from collections import deque

import numpy as np
from tqdm import tqdm


class TOMF:
    """Simulate multi-agent UCB with threshold-triggered broadcasts.

    ``N`` agents each play one of ``K`` arms per round for ``T`` rounds.
    Every agent keeps

    * local statistics ``Xt`` (reward sums), ``Nt`` (pull counts, one slice
      per round) and ``mu_tidle`` (local means), updated on every pull; and
    * a network view ``mu_hat`` / ``n_hat`` that drives arm selection and is
      refreshed only when a broadcast message is processed.

    Agent ``i`` broadcasts arm ``k`` once its local count ``Nt[t, i, k]`` has
    reached ``ceil(beta * Nt[tau[i, k], i, k])``, where ``tau[i, k]`` is the
    round of its previous broadcast of that arm.

    Args:
        T: number of rounds (round 0 is the warm start).
        K: number of arms.
        N: number of agents.
        alpha: confidence-width multiplier used by ``update_Ct``.
        beta: broadcast growth factor; presumably expected to be > 1, since
            ``1 - 1 / beta`` is used as a non-negative weight fraction.
        eta: stored, but not read by any method of this class.
        best_arm: arm used as the regret reference for every agent.
        mu: true mean rewards indexed ``mu[i, k]``, shape ``(N, K)``.
    """

    def __init__(self, T, K, N, alpha, beta, eta, best_arm, mu):
        self.T = T
        self.K = K
        self.N = N
        self.alpha = alpha
        self.beta = beta
        self.eta = eta
        self.best_arm = best_arm
        self.mu = mu  # shape (N, K)
        # Network view per agent: (mean, count) pair used by select_arm.
        self.n_hat = np.zeros((N, K), dtype=float)
        self.mu_hat = np.zeros((N, K), dtype=float)
        # Local view per agent: running mean and reward sums.
        self.mu_tidle = np.zeros((N, K), dtype=float)
        self.Xt = np.zeros((N, K), dtype=float)
        # NOTE(review): one (N, K) count slice is kept per round so that
        # Nt[tau[i, k]] can be looked up; this is T * N * K float64 values
        # (about 1.3 GB for T=1e6, N=8, K=20).
        self.Nt = np.zeros((T, N, K), dtype=float)
        # Cumulative communication counters, one entry per round.
        self.comm_times = np.zeros(T)
        self.comm_bits = np.zeros(T)
        # Cumulative regret, indexed [round, agent].
        self.regret = np.zeros((T, N), dtype=float)
        # Round of the most recent broadcast of arm k by agent i.
        self.tau = np.zeros((N, K), dtype=int)
        # Candidate arm set, rewritten by update_Ct.
        self.Ct = set(range(K))
        # Not read by any method of this class.
        self.group_regret = []
        self.individual_regret = []
        self.t = 0

    def pull_arm(self, i, a):
        """Return a Gaussian reward for agent ``i`` pulling arm ``a``.

        The mean is ``mu[i, a]`` and the standard deviation is 0.25.
        """
        return np.random.normal(loc=self.mu[i, a], scale=0.25)

    def get_bits(self, num):
        """Return the bit cost charged for sending ``num``: ``ceil(1 + log2(num))``.

        Values below 1 (including zero, negatives and fractions such as
        estimated means) are clamped to 1, so every value costs at least one
        bit instead of zero or a negative amount.
        """
        num = max(num, 1)
        return int(np.ceil(1 + np.log2(num)))

    def select_arm(self, t, i):
        """Return the UCB arm for agent ``i`` at round ``t``.

        The index is ``mu_hat[i, k] + sqrt(2 ln t / n_hat[i, k])``; arms with
        ``n_hat == 0`` get an infinite index. Ties go to the lowest arm index
        (``np.argmax``).
        """
        n = self.n_hat[i, :]
        mu = self.mu_hat[i, :]
        # Zero counts divide by zero here; those entries are replaced by inf.
        with np.errstate(divide='ignore', invalid='ignore'):
            bonus = np.sqrt(2 * np.log(t) / n)
        ucb = np.where(n == 0, np.inf, mu + bonus)
        return np.argmax(ucb)

    def update_Ct(self, i, t):
        """Recompute the candidate set ``self.Ct`` from agent ``i``'s network view.

        Arm ``k`` stays a candidate unless some other arm's lower confidence
        bound exceeds ``k``'s upper confidence bound (successive-elimination
        rule). The confidence half-width is ``alpha * sqrt(ln t / n_hat[i, k])``.

        NOTE(review): ``self.Ct`` is one set shared by all agents, so after a
        loop over agents it holds the last agent's set. No method of this class
        reads it.
        """
        # The epsilon guards the division for arms with n_hat == 0.
        half_width = self.alpha * np.sqrt(np.log(t) / (self.n_hat[i, :] + 1e-8))
        lcb = self.mu_hat[i, :] - half_width
        ucb = self.mu_hat[i, :] + half_width
        # dominated[k, kp] is True when arm kp's LCB lies above arm k's UCB.
        dominated = lcb[np.newaxis, :] > ucb[:, np.newaxis]
        np.fill_diagonal(dominated, False)
        self.Ct = {int(k) for k in np.flatnonzero(~dominated.any(axis=1))}

    def message(self, t, i, k):
        """Build agent ``i``'s broadcast for arm ``k`` at round ``t``.

        Returns ``(local mean, local count at round t, (sender, arm))``.
        """
        return (self.mu_tidle[i][k], self.Nt[t][i][k], (i, k))

    def tomf_mt(self):
        """Run the simulation for all ``T`` rounds.

        Round 0 is a warm start in which every agent pulls every arm
        ``initial_pulls`` times. In each later round, every agent picks an arm
        with ``select_arm``, updates its local statistics, and queues a
        broadcast when the ``beta`` growth threshold is met. Queued broadcasts
        are then folded into every agent's network view (``mu_hat``, ``n_hat``).

        Prints the arm each agent plays in the last round and the final
        ``mu_hat``.

        Returns:
            tuple: ``(regret, comm_times, comm_bits)``. ``regret`` has shape
            ``(N, T)`` (cumulative regret per agent). The other two have shape
            ``(T,)`` (cumulative message count and bit count).
        """
        self.regret[0, :] = 0
        initial_pulls = 5
        for i in range(self.N):
            for k in range(self.K):
                r = np.mean([self.pull_arm(i, k) for _ in range(initial_pulls)])
                # Network view: (mu_hat, n_hat) is a (mean, count) pair that the
                # weighted average below extends, so the mean must describe the
                # same initial_pulls samples that the count claims.
                self.mu_hat[i][k] = r
                self.n_hat[i][k] = initial_pulls
                self.tau[i][k] = 0
                # Local view.
                self.mu_tidle[i][k] = r
                self.Nt[0][i][k] = initial_pulls
                self.Xt[i][k] = r * initial_pulls

        center_q = deque()
        for t in range(1, self.T):
            q = []
            # Carry cumulative counters, regret and local counts into round t.
            self.comm_times[t] = self.comm_times[t - 1]
            self.comm_bits[t] = self.comm_bits[t - 1]
            self.regret[t] = self.regret[t - 1]
            self.Nt[t] = self.Nt[t - 1]

            for i in range(self.N):
                k = self.select_arm(t, i)
                if t == self.T - 1:
                    print(f'Agent: {i}, Arm: {k}')
                reward_i_k = self.pull_arm(i, k)
                # Local update happens on every pull, whether or not k is in Ct.
                self.Xt[i][k] += reward_i_k
                self.Nt[t][i][k] += 1
                self.mu_tidle[i][k] = self.Xt[i][k] / self.Nt[t][i][k]
                if self.Nt[t][i][k] >= np.ceil(self.beta * self.Nt[self.tau[i][k]][i][k]):
                    q.append(self.message(t, i, k))
                    self.tau[i][k] = t
                    # Cost is multiplied by N (one copy per agent).
                    self.comm_times[t] += self.N
                    self.comm_bits[t] += self.N * self.K * (
                        self.get_bits(self.mu_tidle[i][k])
                        + self.get_bits(self.Nt[t][i][k])
                        + self.get_bits(1)
                    )
                self.regret[t][i] += self.mu[i][self.best_arm] - self.mu[i][k]

            if q:
                # Center-server merge; it is delivered one message-round later.
                # NOTE(review): this reads only agent N - 1 (where the agent loop
                # variable ends up) and keeps running totals across arms, so
                # entry k holds sums over arms 0..k rather than a per-arm merge
                # of all agents. Confirm against the intended algorithm.
                src = self.N - 1
                merge_Xt = 0.0
                merge_nt = 0.0
                meg_list = []
                for k in range(self.K):
                    merge_nt += self.Nt[t][src][k]
                    merge_Xt += self.Xt[src][k]
                    meg_list.append((k, merge_Xt, merge_nt))
                self.comm_times[t] += self.N
                self.comm_bits[t] += self.N * self.get_bits(merge_Xt) * 3
                center_q.append(meg_list)

            for mu_tilde, N_tilde, (_sender, k) in q:
                n_add = np.floor(N_tilde * (1 - 1 / self.beta))
                for i in range(self.N):
                    # Fold the broadcast into agent i's (mean, count) pair.
                    if self.n_hat[i][k] + n_add > 0:
                        self.mu_hat[i][k] = (
                            self.mu_hat[i][k] * self.n_hat[i][k] + mu_tilde * n_add
                        ) / (self.n_hat[i][k] + n_add)
                    self.n_hat[i][k] += n_add

                    # Deliver the previous center-server merge.
                    # NOTE(review): popleft() runs inside the message and agent
                    # loops, so in each round only the first agent visited
                    # (agent 0) receives it. The merged totals are also added on
                    # top of that agent's own local counts. Confirm both are
                    # intended.
                    if len(center_q) > 1:
                        merged = center_q[-2]
                        center_q.popleft()
                        # Entries are (arm, reward_sum, count). A separate loop
                        # variable keeps the broadcast's arm `k` intact for the
                        # remaining agents.
                        for arm, m_Xt, m_nt in merged:
                            self.Nt[t][i][arm] += m_nt
                            self.Xt[i][arm] += m_Xt
                            self.mu_tidle[i][arm] = self.Xt[i][arm] / self.Nt[t][i][arm]
                            self.comm_times[t] += self.N
                            self.comm_bits[t] += self.N * (
                                self.get_bits(self.Nt[t][i][arm])
                                + self.get_bits(self.Xt[i][arm])
                                + self.get_bits(1)
                            )

            for i in range(self.N):
                self.update_Ct(i, t)

        print(self.mu_hat)
        return np.array(self.regret).T, np.array(self.comm_times), np.array(self.comm_bits)

def save_array(path, data):
    """Write *data* to *path* with ``np.save``, expanding ``~`` and creating directories.

    ``np.save`` opens the path as given, so it neither expands a leading ``~``
    nor creates missing parent directories. Both are handled here.

    Args:
        path: destination file path; may start with ``~``.
        data: anything ``np.save`` accepts (an array, or a list of arrays).
    """
    path = os.path.expanduser(path)
    parent = os.path.dirname(path)
    if parent:
        os.makedirs(parent, exist_ok=True)
    np.save(path, data)


def main(repetitions=20, T=int(1e6), K=20, N=8, out_dir='~/var_agent/data/tomf'):
    """Run ``repetitions`` independent TOMF simulations and save their outputs.

    Writes ``regret_list_agent{N}.npy``, ``comm_times_list_agent{N}.npy`` and
    ``comm_bits_list_agent{N}.npy`` under *out_dir*.
    """
    # Arm means decrease by 0.03 per arm index and 0.01 per agent index.
    mu = np.array([[0.95 - 0.03 * m - 0.01 * n for m in range(K)] for n in range(N)])
    best_arm = np.argmax(np.mean(mu, axis=0))
    regret_list = []
    comm_times_list = []
    comm_bits_list = []
    for _ in tqdm(range(repetitions)):
        tomf = TOMF(T=T, K=K, N=N, alpha=1, beta=3, eta=0.5, best_arm=best_arm, mu=mu)
        regret, comm_times, comm_bits = tomf.tomf_mt()
        regret_list.append(regret)
        comm_times_list.append(comm_times)
        comm_bits_list.append(comm_bits)
    regret_list = np.array(regret_list)
    print(f'regret_list shape: {regret_list.shape}')
    save_array(os.path.join(out_dir, f'regret_list_agent{N}.npy'), regret_list)
    save_array(os.path.join(out_dir, f'comm_times_list_agent{N}.npy'), comm_times_list)
    save_array(os.path.join(out_dir, f'comm_bits_list_agent{N}.npy'), comm_bits_list)


if __name__ == "__main__":
    main()
