import numpy as np
import random
from random import choice
from scipy.stats import bernoulli
from scipy.stats import norm
import matplotlib.pyplot as plt
from tqdm import tqdm
import math

# Small numerical tolerance (1e-11). Not referenced by the visible code.
EPS = 1e-11

# Experiment configuration.
T = 50000                        # rounds iterated by the main loop
K = 12                           # arms handed to Bandit and to every agent
V = 10                           # agents instantiated per algorithm
C = [5000 for _ in range(5)]     # corruption budgets; the main loop spends C[0]..C[3]
target = K - 2                   # while corrupting: arms < target report 0, the rest report 1
# Enters the DRAA and IND-BARBAR lambda terms as log(4 / confidence_level).
confidence_level = 1 / (V * K * T * np.log(T))


class Bandit:
    """Stochastic bandit with fixed arm means and clipped Gaussian noise.

    The ``K`` arm means are evenly spaced between 0.96 and 0.02 in
    decreasing order, so arm 0 has the largest mean.
    """

    def __init__(self, K):
        self.K = K
        # Reverse the ascending grid so that means decrease with the arm index.
        self.U = np.linspace(0.02, 0.96, K)[::-1]
        self.best_U = np.max(self.U)

    def generate_award(self):
        """Draw one reward per arm: mean plus N(0, 0.1) noise, clipped to [0, 1].

        The draw is also stored on ``self.generate_U``.
        """
        noisy = self.U + np.random.normal(0, 0.1, size=self.K)
        self.generate_U = np.clip(noisy, 0, 1)
        return self.generate_U

    def generate_regret(self, action):
        """Return the gap between the best arm mean and the mean of ``action``."""
        return self.best_U - self.U[action]


class Meta_Algorithm:
    """Base class for the agents in this file.

    Stores the number of arms ``K`` and provides ``sample_action``; the
    remaining methods are placeholders that raise ``NotImplementedError``.
    """

    def __init__(self, K):
        self.K = K

    def next(self):
        """Return the next arm to play (must be overridden)."""
        raise NotImplementedError

    def reset(self) -> object:
        """Reset the agent (must be overridden)."""
        raise NotImplementedError

    def update(self, action, feedback):
        """Incorporate feedback for ``action`` (must be overridden)."""
        raise NotImplementedError

    def sample_action(self, x):
        """Sample one arm index from the probability vector ``x``.

        A single multinomial trial is drawn and the position of its 1 entry
        among the first ``K`` slots is returned as a Python ``int``.
        """
        one_hot = np.random.multinomial(1, x, size=1)[0]
        hits = np.flatnonzero(one_hot[: self.K] == 1)
        return int(hits[0]) if hits.size else None


class MABARBAT(Meta_Algorithm):
    """Agent that samples arms in proportion to per-epoch pull budgets."""

    def __init__(self, K):
        super().__init__(K)
        self.W = np.zeros(self.K)

    def generate_probability(self, N_m, n_m, r_m):
        """Set the sampling distribution for an epoch and return the pull counts.

        Args:
            N_m: total number of pulls in the epoch.
            n_m: array with the nominal number of pulls per arm.
            r_m: array with the current reward estimate per arm.

        Returns:
            Array of per-arm pull counts that sums to ``N_m``: ``n_m`` with
            the leftover pulls assigned to the arm with the highest estimate.
        """
        # The arm with the largest estimate absorbs the leftover probability.
        # (The previous scan compared in the wrong direction starting from 0
        # and skipped the last arm, so it always selected arm 0 for
        # non-negative estimates.) Ties resolve to the lowest index.
        k_m = int(np.argmax(r_m))
        self.W = n_m / N_m
        leftover = 1 - self.W.sum()
        self.W[k_m] += leftover
        return self.W * N_m

    def next(self):
        """Sample an arm from the current distribution ``W``."""
        return self.sample_action(self.W)


class DRAA(Meta_Algorithm):
    """Agent whose distribution is built from gap estimates and an active-arm mask."""

    def __init__(self, K):
        super().__init__(K)
        self.W = np.zeros(self.K)

    def generate_probability(self, Delta, Active, DR_m):
        """Rebuild and return the sampling distribution ``W``.

        Arms flagged inactive (``Active[k] != 1``) receive
        ``2**(1 - 2*DR_m) * Delta[k]**-2 / sum(Delta**-2)``; the remaining
        mass is split evenly over the active arms, and the vector is then
        renormalised to sum to one.
        """
        inv_sq_gap = np.power(Delta, -2)
        is_active = np.array([Active[k] == 1 for k in range(self.K)], dtype=bool)
        n_active = int(np.count_nonzero(is_active))
        inactive = np.flatnonzero(~is_active)

        self.W = np.zeros(self.K)
        if inactive.size:
            shrink = np.power(float(2), 1 - 2 * DR_m)
            self.W[inactive] = shrink * inv_sq_gap[inactive] / np.sum(inv_sq_gap)

        # Whatever the inactive arms did not take is shared by the active ones.
        uniform_share = (1 - np.sum(self.W)) / n_active
        self.W[is_active] = uniform_share
        self.W /= self.W.sum()
        return self.W

    def next(self):
        """Sample an arm from the current distribution ``W``."""
        return self.sample_action(self.W)


class IND_BARBAR(Meta_Algorithm):
    """Agent that samples arms in proportion to its own per-arm pull budgets."""

    def __init__(self, K):
        super().__init__(K)
        self.W = np.zeros(self.K)

    def generate_probability(self, N_m, n_m):
        """Set ``W`` to ``n_m / N_m`` rescaled so that it sums to one."""
        shares = n_m / N_m
        shares /= shares.sum()
        self.W = shares

    def next(self):
        """Sample an arm from the current distribution ``W``."""
        return self.sample_action(self.W)


class OSMD(Meta_Algorithm):
    """Agent whose distribution is obtained by a Newton iteration on a normaliser.

    Attributes:
        W: current sampling distribution over the arms.
        L: cumulative loss estimate per arm.
        x: normaliser carried over between calls as the iteration's start point.
    """

    def __init__(self, K):
        super().__init__(K)
        self.W = np.zeros(self.K)
        self.L = np.zeros(self.K)
        self.x = 1

    def generate_probability(self, eta_t, tol=1):
        """Recompute ``W`` from the cumulative losses ``L`` for learning rate ``eta_t``.

        Each pass sets ``W[k] = 4 * (eta_t * (L[k] - x))**-2`` and takes one
        Newton step on ``x`` towards ``sum(W) == 1``. ``W`` is renormalised
        after the loop, so it sums to one even if the iteration stops early.

        Args:
            eta_t: learning rate.
            tol: stop once two successive values of ``x`` differ by at most
                this amount. The default of 1 keeps the historical behaviour,
                which is very loose; pass a smaller value for a tighter solve.

        Raises:
            FloatingPointError: if the iterate becomes NaN or infinite (for
                example when ``x`` coincides with some ``L[k]``). Previously
                this case looped forever.
        """
        while True:
            x_pre = self.x
            sum0 = 0
            sum1 = 0
            # Accumulate sequentially so results match the original bit for bit.
            for k in range(self.K):
                self.W[k] = 4 * np.power(eta_t * (self.L[k] - self.x), -2)
                sum0 += self.W[k]
                sum1 += np.power(self.W[k], 3 / 2)
            self.x = self.x - (sum0 - 1) / (eta_t * sum1)
            if not np.isfinite(self.x):
                # A non-finite iterate can never satisfy the stopping test.
                raise FloatingPointError(
                    "OSMD normaliser became non-finite during the Newton iteration"
                )
            if abs(x_pre - self.x) <= tol:
                break
        self.W /= self.W.sum()

    def next(self):
        """Sample an arm from the current distribution ``W``."""
        return self.sample_action(self.W)


class IW_FTRL(OSMD):
    """OSMD agent fed with importance-weighted loss estimates."""

    # Was misspelled ``__init`` and therefore never ran as the constructor;
    # behaviour is unchanged because it only delegates to the parent.
    def __init__(self, K):
        super().__init__(K)

    def loss(self, eta_t, action, loss):
        """Add ``loss / W[action]`` to the cumulative loss of ``action``.

        ``eta_t`` is accepted for interface compatibility but is not used.
        """
        self.L[action] += loss / self.W[action]


# ---- MA-BARBAT: state for the first epoch ----
MA_m = 1  # epoch counter
# Uses the same formula, including the leading K factor, as the per-epoch
# recomputation in the main loop; the factor was missing here.
MA_delta_m = K * ((MA_m + 4) * 2 ** (MA_m + 4)) * np.log(K)
MA_lambda_m = np.log(4 * K * MA_delta_m) / V
MA_Delta_m = np.ones(K)  # per-arm gap estimates, initialised to 1
MA_N_m = int(K * MA_lambda_m * np.power(2, 2 * MA_m - 2) + 1)  # epoch length
MA_communication = [MA_N_m]  # rounds at which the agents pool their statistics
MA_n_m = MA_lambda_m * np.power(MA_Delta_m, -2)  # nominal pulls per arm
MA_award_m = np.zeros([V, K])  # rewards accumulated per agent and arm in the epoch
MA_k_m = 0
MA_Agents = []
for v in range(V):
    MA_Agents.append(MABARBAT(K))
    MA_Agents[v].generate_probability(MA_N_m, MA_n_m, np.ones(K))
# Replace the nominal counts by the counts actually implied by the distribution.
MA_n_m = MA_Agents[0].generate_probability(MA_N_m, MA_n_m, np.ones(K))
MA_Regret = np.zeros([V, T])  # cumulative regret per agent and round

# ---- DRAA: state for the first epoch ----
DR_m = 1  # epoch counter
DR_Delta_m = np.ones(K)  # per-arm gap estimates
DR_lambda = 16 * np.log(4 / confidence_level) / V
DR_N_m = int(K * DR_lambda * np.power(2, 2 * DR_m - 2))  # epoch length
DR_communication = [DR_N_m]  # rounds at which the agents pool their statistics
DR_award_m = np.zeros([V, K])  # rewards accumulated per agent and arm in the epoch
DR_r_max_m = 1
DR_active_arm = np.ones(K)  # 1 marks an arm as active; all arms start active

DR_Agents = [DRAA(K) for _ in range(V)]
for draa_agent in DR_Agents:
    draa_agent.generate_probability(DR_Delta_m, DR_active_arm, DR_m)
# Every agent holds the same distribution; keep agent 0's copy for the estimates.
DR_p_m = DR_Agents[0].generate_probability(DR_Delta_m, DR_active_arm, DR_m)
DR_Regret = np.zeros([V, T])  # cumulative regret per agent and round

# ---- IND-BARBAR: every agent keeps its own, unshared epoch state ----
IN_Delta_m = np.ones([V, K])  # per-agent, per-arm gap estimates
IN_lambda = np.log(4 / confidence_level)
IN_communication = [[] for _ in range(V)]  # per-agent epoch boundaries
IN_m = np.zeros(V)  # per-agent epoch counter
IN_N_m = np.zeros(V)  # per-agent epoch length
IN_n_m = np.zeros([V, K])  # per-agent pull budget per arm
IN_award_m = np.zeros([V, K])  # per-agent rewards accumulated in the epoch

IN_Agents = []
for v in range(V):
    IN_n_m[v] = IN_lambda * np.power(IN_Delta_m[v], -2)
    IN_N_m[v] = int(np.sum(IN_n_m[v]))
    IN_communication[v].append(IN_N_m[v])
    IN_m[v] += 1
    IN_Agents.append(IND_BARBAR(K))
    # Argument order is (N_m, n_m); the two were previously passed swapped.
    IN_Agents[v].generate_probability(IN_N_m[v], IN_n_m[v])
IN_Regret = np.zeros([V, T])  # cumulative regret per agent and round

# ---- FTRL agents and the shared environment ----
IW_Agents = [IW_FTRL(K) for _ in range(V)]
for ftrl_agent in IW_Agents:
    # Initialise each distribution with learning rate 2.
    ftrl_agent.generate_probability(2)
IW_Regret = np.zeros([V, T])  # cumulative regret per agent and round

bandit = Bandit(K)

def _observed_reward(budget_slot, action, clean_reward):
    """Return the reward an algorithm observes for ``action``.

    While ``C[budget_slot]`` is positive, one unit of that budget is spent
    and the reward is replaced: arms below ``target`` report 0, all other
    arms report 1. Once the budget is used up, ``clean_reward`` is returned
    unchanged.
    """
    if C[budget_slot] > 0:
        C[budget_slot] -= 1
        return 0 if action < target else 1
    return clean_reward


for t in (tqdm(range(T))):
    for v in range(V):
        if t > 0:
            # The regret arrays hold running totals: start from the previous round.
            MA_Regret[v][t] = MA_Regret[v][t - 1]
            DR_Regret[v][t] = DR_Regret[v][t - 1]
            IN_Regret[v][t] = IN_Regret[v][t - 1]
            IW_Regret[v][t] = IW_Regret[v][t - 1]
        # One noisy reward vector per (round, agent), shared by all four algorithms.
        concurrent_award = bandit.generate_award()

        # MA-BARBAT (corruption budget C[0])
        action = MA_Agents[v].next()
        MA_Regret[v][t] += bandit.generate_regret(action)
        MA_award_m[v][action] += _observed_reward(0, action, concurrent_award[action])

        # DRAA (corruption budget C[1])
        action = DR_Agents[v].next()
        DR_Regret[v][t] += bandit.generate_regret(action)
        DR_award_m[v][action] += _observed_reward(1, action, concurrent_award[action])

        # IND-BARBAR (corruption budget C[2])
        action = IN_Agents[v].next()
        IN_Regret[v][t] += bandit.generate_regret(action)
        IN_award_m[v][action] += _observed_reward(2, action, concurrent_award[action])

        # FTRL (corruption budget C[3]); it consumes losses, i.e. 1 - reward.
        eta_t = 2 * np.sqrt(1 / (t + 1))
        # Refresh the distribution BEFORE sampling so that the importance
        # weight 1 / W[action] applied in loss() is the probability the arm
        # was actually drawn with. Previously the arm came from the old W and
        # was weighted by the new one.
        IW_Agents[v].generate_probability(eta_t)
        action = IW_Agents[v].next()
        IW_Regret[v][t] += bandit.generate_regret(action)
        IW_Agents[v].loss(eta_t, action, 1 - _observed_reward(3, action, concurrent_award[action]))

    # ---- MA-BARBAT: end of epoch, agents pool their rewards ----
    # NOTE(review): the first boundary is t == N_1 with t starting at 0, so the
    # first epoch spans N_1 + 1 rounds; confirm whether that is intended.
    if t == MA_communication[-1]:
        MA_award_sum = np.zeros(K)
        for v in range(V):
            MA_award_sum += MA_award_m[v]
        # Per-arm mean reward over all agents, capped at 1.
        MA_r_k = np.zeros(K)
        for k in range(K):
            MA_r_k[k] = min(MA_award_sum[k] / (V * MA_n_m[k]), 1)
        # Reference reward: the best estimate after subtracting a confidence term.
        MA_r_loop = MA_r_k - np.sqrt(MA_lambda_m * np.power(MA_Delta_m, -2) / (V * MA_n_m)) / 8
        MA_r_star = np.max(MA_r_loop)
        for k in range(K):
            # Gap estimates are floored at 2^-m for the epoch that just ended.
            MA_Delta_m[k] = max(MA_r_star - MA_r_k[k], 2 ** (0 - MA_m))
        MA_m += 1
        MA_delta_m = K * ((MA_m + 4) * 2 ** (MA_m + 4)) * np.log(K)
        MA_lambda_m = np.log(4 * K * MA_delta_m) / V
        MA_N_m = int(K * MA_lambda_m * np.power(2, 2 * MA_m - 2) + 1)
        MA_communication.append(MA_N_m + t)
        MA_award_m = np.zeros([V, K])
        for k in range(K):
            MA_n_m[k] = MA_lambda_m * np.power(MA_Delta_m[k], -2)

        for v in range(V):
            MA_Agents[v].generate_probability(MA_N_m, MA_n_m, MA_r_k)
        # Replace the nominal counts by the counts implied by the distribution.
        MA_n_m = MA_Agents[0].generate_probability(MA_N_m, MA_n_m, MA_r_k)

    # ---- DRAA: end of epoch, agents pool their rewards ----
    if t == DR_communication[-1]:
        DR_award_sum = np.zeros(K)
        DR_p_sum = DR_p_m * V
        for i in range(V):
            DR_award_sum += DR_award_m[i]
        # Divide by the expected number of pulls across all agents.
        DR_r_k = DR_award_sum / (DR_p_sum * DR_N_m)
        DR_r_max_m = np.max(DR_r_k - DR_Delta_m / 16)
        DR_active_arm = np.ones(K)
        for k in range(K):
            DR_Delta_m[k] = max(0.125, DR_r_max_m - DR_r_k[k] + 3 * np.power(float(2), -7))
            # Arms whose estimated gap exceeds the threshold are marked inactive.
            if DR_r_max_m - DR_r_k[k] >= np.power(float(2), 0 - DR_m) - 3 * np.power(float(2), 0 - 7):
                DR_active_arm[k] = 0
        # NOTE(review): the new distribution and epoch length are computed with
        # DR_m *before* it is incremented, so the second epoch reuses the first
        # epoch's values. The MA-BARBAT section increments its counter first;
        # confirm which ordering is intended.
        for v in range(V):
            DR_Agents[v].generate_probability(DR_Delta_m, DR_active_arm, DR_m)
        DR_p_m = DR_Agents[0].generate_probability(DR_Delta_m, DR_active_arm, DR_m)
        DR_N_m = int(K * DR_lambda * np.power(2, 2 * DR_m - 2))
        DR_communication.append(DR_N_m + t)
        DR_m += 1
        DR_award_m = np.zeros([V, K])

    # ---- IND-BARBAR: each agent closes its own epoch independently ----
    for v in range(V):
        if t == IN_communication[v][-1]:
            r_i = IN_award_m[v] / IN_n_m[v]
            r_star = np.max(r_i - IN_Delta_m[v] / 16)
            for k in range(K):
                IN_Delta_m[v][k] = max(np.power(float(2), 0 - IN_m[v]), r_star - r_i[k])
                IN_n_m[v][k] = IN_lambda * np.power(IN_Delta_m[v][k], -2)
            IN_N_m[v] = int(np.sum(IN_n_m[v]))
            IN_communication[v].append(t + IN_N_m[v])
            IN_m[v] += 1
            IN_Agents[v].generate_probability(IN_N_m[v], IN_n_m[v])
            IN_award_m[v] = np.zeros(K)

# ---- Average the cumulative regret over agents and plot one curve per algorithm ----
num = 0  # appended to the y-axis label
MA_IN_Regret = np.zeros(T)
DR_IN_Regret = np.zeros(T)
IN_IN_Regret = np.zeros(T)
IW_IN_Regret = np.zeros(T)
for ma_row, dr_row, in_row, iw_row in zip(MA_Regret, DR_Regret, IN_Regret, IW_Regret):
    MA_IN_Regret += ma_row / V
    DR_IN_Regret += dr_row / V
    IN_IN_Regret += in_row / V
    IW_IN_Regret += iw_row / V

X = np.arange(1, T + 1)  # rounds are plotted starting from 1
plt.xticks(fontsize=16)
plt.yticks(fontsize=16)
plt.rcParams['font.size'] = 16

curves = (
    (MA_IN_Regret, 'r-', 'MA-BARBAT'),
    (DR_IN_Regret, 'm-', 'DRAA'),
    (IN_IN_Regret, 'y-', 'IND-BARBAR'),
    (IW_IN_Regret, 'g-', 'FTRL'),
)
for series, line_style, legend_name in curves:
    plt.plot(X, series, line_style, label=legend_name)

plt.xlabel('Rounds')
plt.ylabel('Regret' + str(num))
plt.grid(True)
plt.legend()
# Scientific notation on both axes.
for axis_name in ('x', 'y'):
    plt.ticklabel_format(style='sci', scilimits=(0, 0), axis=axis_name)
plt.show()

