import numpy as np
import random
from random import choice
from scipy.stats import bernoulli
from scipy.stats import norm
import matplotlib.pyplot as plt
from tqdm import tqdm
import math

# Numerical tolerance used by the OSMD constrained solver and the Shannon
# Newton iteration.
# NOTE(review): 10E-12 equals 1e-11, not 1e-12 — confirm the intended value.
EPS = 10E-12

# Experiment configuration.
T = 50000           # number of rounds simulated
K = 16              # number of arms
V = 1               # number of agents per algorithm
C = [2000] * 5      # corruption budgets (rounds); C[0] SOG-BARBAT, C[1] Shannon, C[2] Tsallis are read below
target = K - 2      # while corrupted: arms >= target get reward 1 / loss 0, all others reward 0 / loss 1
num = 2             # multiplier in the SOG-BARBAT lambda_m parameter

# Feedback-graph adjacency matrix: graph[k][i] == 1 means that pulling arm k
# yields an observation of arm i (this is how SOGBARBAT counts observations).
# The matrix is not symmetric, and arms 0-3 have no self-loop.
graph = [[0, 1, 1, 1, 1, 0, 0, 0, 1, 0, 1, 1, 0, 1, 1, 0],
         [1, 0, 1, 1, 0, 1, 0, 0, 0, 1, 0, 0, 1, 1, 0, 0],
         [1, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 1],
         [1, 1, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 1, 0],
         [1, 1, 1, 1, 1, 0, 0, 1, 1, 0, 0, 1, 0, 0, 0, 0],
         [1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 1, 0, 0, 1, 1],
         [1, 1, 1, 1, 1, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0],
         [1, 1, 1, 1, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 1],
         [1, 1, 1, 1, 1, 0, 1, 0, 1, 0, 0, 0, 1, 0, 0, 1],
         [1, 1, 1, 1, 1, 0, 0, 0, 0, 1, 0, 1, 0, 0, 1, 0],
         [1, 1, 1, 1, 1, 1, 0, 0, 1, 1, 1, 1, 0, 1, 0, 0],
         [1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 0],
         [1, 1, 1, 1, 0, 0, 0, 1, 1, 0, 0, 0, 1, 1, 0, 0],
         [1, 1, 1, 1, 0, 0, 0, 0, 1, 0, 1, 0, 1, 1, 0, 0],
         [1, 1, 1, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 1],
         [1, 1, 1, 1, 0, 1, 1, 0, 0, 1, 1, 0, 1, 0, 0, 1]]

# Alternative 12-arm feedback graph (currently unused; set K = 12 to use it).
# graph = [[0, 1, 1, 0, 1, 1, 0, 0, 1, 0, 1, 0],
#         [1, 0, 1, 1, 1, 0, 0, 1, 1, 0, 0, 0],
#         [1, 1, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0],
#         [1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0],
#         [1, 1, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0],
#         [1, 1, 1, 0, 0, 1, 0, 1, 1, 0, 1, 0],
#         [1, 1, 1, 0, 0, 0, 1, 0, 0, 1, 0, 0],
#         [1, 1, 1, 1, 1, 0, 0, 1, 0, 1, 1, 0],
#         [1, 1, 1, 0, 0, 1, 0, 1, 1, 0, 0, 1],
#         [1, 1, 1, 0, 1, 0, 1, 0, 0, 1, 0, 0],
#         [1, 1, 1, 0, 0, 0, 0, 0, 1, 1, 1, 0],
#         [1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1]]

class Bandit:
    """Stochastic K-armed bandit with fixed, linearly spaced mean rewards.

    Means are ``linspace(0.02, 0.96, K)`` in reverse order, so for K >= 2 arm 0
    has mean 0.96 and is the optimal arm.
    """

    def __init__(self, K):
        self.K = K
        ascending = np.linspace(0.02, 0.96, K)
        self.U = ascending[::-1]
        self.best_U = np.max(self.U)

    def generate_award(self):
        """Draw one reward per arm: mean plus N(0, 0.1^2) noise, clipped to [0, 1].

        The vector is stored on ``self.generate_U`` and also returned.
        """
        rewards = np.zeros(self.K)
        self.generate_U = rewards
        for arm in range(self.K):
            # One scalar draw per arm, in arm order.
            noisy = self.U[arm] + np.random.normal(0, 0.1)
            rewards[arm] = noisy
            rewards[arm] = max(min(rewards[arm], 1), 0)
        return rewards

    def generate_regret(self, action):
        """Return the pseudo-regret of ``action``: best mean minus its mean."""
        gap = self.best_U - self.U[action]
        return gap


class Meta_Algorithm:
    """Shared base for the bandit agents: stores the arm count and a sampler.

    ``next``, ``reset`` and ``update`` are left to subclasses; the base
    versions only raise ``NotImplementedError``.
    """

    def __init__(self, K):
        self.unconstrained = False
        self.K = K

    def next(self):
        """Pick the next arm to pull (implemented by subclasses)."""
        raise NotImplementedError

    def reset(self) -> object:
        """Reinitialise the agent's state (implemented by subclasses)."""
        raise NotImplementedError

    def update(self, action, feedback):
        """Absorb ``feedback`` for ``action`` (implemented by subclasses)."""
        raise NotImplementedError

    def sample_action(self, x):
        """Sample one arm index from the probability vector ``x``.

        Draws a single one-hot multinomial sample and scans arms 0..K-1 for
        the hit; returns ``None`` if no arm in that range was hit.
        """
        one_hot = np.random.multinomial(1, x, size=1)[0]
        arm = 0
        while arm < self.K:
            if one_hot[arm] == 1:
                return arm
            arm += 1
        return None


class SOGBARBAT(Meta_Algorithm):
    """SOG-BARBAT agent: samples arms from a per-epoch distribution ``self.W``.

    ``generate_probability`` is called once per epoch to rebuild ``self.W``
    and ``next`` samples an arm from it.  The feedback graph is read from the
    module-level ``graph`` adjacency matrix, where ``graph[k][i] == 1`` is
    treated as "pulling arm k yields an observation of arm i".
    """

    def __init__(self, K):
        super().__init__(K)
        self.W = np.zeros(self.K)

    def _claim(self, arm, candidates, chosen):
        """Add ``arm`` to ``chosen`` and drop it and its out-neighbours from ``candidates``."""
        chosen.append(arm)
        candidates[arm] = 0
        for i in range(self.K):
            if graph[arm][i] == 1:
                candidates[i] = 0

    def _select_pull_set(self, G_m):
        """Greedily choose the arms to pull in one planning pass.

        First takes every active arm that no other active arm observes, then
        adds remaining active arms in ascending order of out-degree (counted
        among active arms), removing each pick's out-neighbours from the pool.

        Args:
            G_m: 0/1 vector; 1 marks arms that still need observations.

        Returns:
            List of chosen arm indices (empty when no arm is active).
        """
        n_arms = self.K
        out_deg = np.zeros(n_arms)  # edges leaving each active arm
        in_deg = np.zeros(n_arms)   # number of other active arms observing each arm
        candidates = np.copy(G_m)
        chosen = []
        for k in range(n_arms):
            if candidates[k] == 0:
                continue
            for i in range(n_arms):
                if graph[k][i] == 1:
                    out_deg[k] += 1
                    if k != i:
                        in_deg[i] += 1
        order = np.argsort(out_deg)
        # Arms that no other active arm can observe must be pulled directly.
        for k in range(n_arms):
            if in_deg[k] == 0 and candidates[k] == 1:
                self._claim(k, candidates, chosen)
        for j in order:
            if candidates[j] != 1:
                continue
            self._claim(j, candidates, chosen)
            if not np.any(candidates == 1):
                break
        return chosen

    def generate_probability(self, N_m, n_m, r_m):
        """Rebuild ``self.W`` for an epoch of ``N_m`` rounds.

        Repeatedly selects a pull set and adds pulls until every arm's planned
        observation count reaches its target ``n_m[k]``.  The arm with the
        largest ``r_m`` starts fully observed (``N_m``) and receives whatever
        probability mass the planned pulls leave over.

        Args:
            N_m: epoch length (number of rounds).
            n_m: per-arm target observation counts.
            r_m: per-arm reward estimates from the previous epoch.

        Returns:
            hat_n_m: hat_n_m[i] = sum of Z_m[k] over arms k with graph[k][i] == 1,
            where Z_m are the planned pulls per arm.
        """
        n_arms = self.K
        # Empirically best arm (first one on ties).  The previous hand-written
        # search compared with the wrong operator and skipped the last arm, so
        # it always returned 0.
        k_m = int(np.argmax(r_m))
        Z_m = np.zeros(n_arms)  # planned pulls per arm
        H_m = np.zeros(n_arms)  # planned observations per arm
        H_m[k_m] = N_m
        # Active arms: still below target.  k_m is never pulled through the set.
        G_m = (H_m < n_m).astype(float)
        G_m[k_m] = 0
        while np.any(H_m < n_m):
            D_m = self._select_pull_set(G_m)
            if not D_m:
                # No active arm left: further passes cannot make progress.
                break
            # Largest step that keeps every chosen arm at or below its target.
            step = min(n_m[x] - H_m[x] for x in D_m)
            for x in D_m:
                Z_m[x] += step
                for k in range(n_arms):
                    if graph[x][k] == 1:
                        H_m[k] += step
                        if H_m[k] >= n_m[k]:
                            G_m[k] = 0
                if graph[x][x] == 0:
                    # Count the pull as an observation of x itself even without
                    # a self-loop.  Use >= like above: with > an arm could sit
                    # at H == n while still active, giving step == 0 forever.
                    H_m[x] += step
                    if H_m[x] >= n_m[x]:
                        G_m[x] = 0
        self.W = Z_m / N_m
        # Leftover mass goes to the best arm.  NOTE(review): this is negative
        # if the planned pulls exceed N_m, which would make W invalid for sampling.
        self.W[k_m] += 1 - self.W.sum()
        # NOTE(review): unlike H_m, this count does not add a pull's own
        # observation for arms without a self-loop — confirm intended.
        hat_n_m = np.zeros(n_arms)
        for k in range(n_arms):
            for i in range(n_arms):
                if graph[k][i] == 1:
                    hat_n_m[i] += Z_m[k]
        return hat_n_m

    def next(self):
        """Sample an arm from the current epoch distribution ``self.W``."""
        return self.sample_action(self.W)

class OSMD(Meta_Algorithm):
    """Base class for FTRL/OSMD agents over the probability simplex.

    Subclasses supply ``solve_unconstrained`` (solve one coordinate's
    first-order condition for a given scaled loss) and ``hessian_inverse``
    (the slope used by the Newton step on the normalisation ``bias``).
    ``reset`` must be called before the first ``next``, since it creates
    ``self.L``.
    """
    L = None              # cumulative loss estimates per arm (created by reset)
    x = None              # current distribution over arms
    time_step = 0.0       # number of next() calls since reset
    gamma = 1.0
    learning_rate = None
    bias = None           # scalar shift added to all losses so that sum(x) == 1

    def __init__(self, K):
        super().__init__(K)
        self.x = np.zeros(K)

    def next(self):
        """Advance one round: recompute ``self.x`` and sample an arm from it."""
        self.time_step += 1.0
        self.learning_rate = self.get_learning_rate(self.time_step)
        self.solve_optimization()
        return self.sample_action(self.x)

    def reset(self):
        """Zero the losses and restart from uniform x (0.5 per arm if unconstrained)."""
        self.L = np.array([0.0] * self.K)
        self.x = [1 / self.K if not self.unconstrained else 0.5 for _ in range(self.K)]
        self.time_step = 0.0
        self.bias = 0

    def solve_optimization(self):
        """Recompute ``self.x`` from the current cumulative losses.

        Unconstrained: every coordinate is solved independently.
        Constrained: searches for the scalar ``bias`` that makes the
        per-coordinate solutions sum to 1.  It uses Newton steps safeguarded by
        a bracket ``[lower, upper]``: the step doubles while one side of the
        bracket is unknown and falls back to bisection when Newton leaves it.
        Stops after ``max_iter`` iterations or when ``|sum(x) - 1| < 100 * EPS``.
        Non-convergence is accepted silently.
        """
        if self.unconstrained:
            # Pass the arm index, matching the constrained branch and the
            # three-argument solve_unconstrained of the subclasses.
            self.x = np.array([
                self.solve_unconstrained(loss * self.learning_rate, warm, i)
                for i, (loss, warm) in enumerate(zip(self.L, self.x))
            ])
            return

        max_iter = 1000
        iteration = 0
        upper = None
        lower = None
        step_size = 1
        # Loop-invariant coordinate indices handed to solve_unconstrained.
        self.index = np.arange(0, self.K, 1)
        while True:
            iteration += 1
            self.x = np.array([
                self.solve_unconstrained((loss + self.bias) * self.learning_rate, warm, i)
                for loss, warm, i in zip(self.L, self.x, self.index)
            ])
            f = self.x.sum() - 1
            df = self.hessian_inverse()
            next_bias = self.bias + f / df
            if f > 0:
                # sum(x) > 1: the current bias becomes the bracket's lower end.
                lower = self.bias
                self.bias = next_bias
                if upper is None:
                    step_size *= 2
                    if next_bias > lower + step_size:
                        self.bias = lower + step_size
                elif next_bias > upper:
                    self.bias = (lower + upper) / 2
            else:
                # sum(x) <= 1: the current bias becomes the bracket's upper end.
                upper = self.bias
                self.bias = next_bias
                if lower is None:
                    step_size *= 2
                    if next_bias < upper - step_size:
                        self.bias = upper - step_size
                elif next_bias < lower:
                    self.bias = (lower + upper) / 2

            if iteration > max_iter or abs(f) < 100 * EPS:
                break

    def get_learning_rate(self, time):
        """Default schedule: ``1 / sqrt(time)``."""
        return 1.0 / np.sqrt(time)

    def solve_unconstrained(self, loss, warmstart, i=None):
        """Solve one coordinate's first-order condition (subclass responsibility).

        Args:
            loss: scaled (and possibly bias-shifted) loss for the coordinate.
            warmstart: previous value of the coordinate, used as a start point.
            i: coordinate index; optional for implementations that ignore it.
        """
        raise NotImplementedError

    def hessian_inverse(self):
        """Derivative term for the Newton step on ``bias`` (subclass responsibility)."""
        raise NotImplementedError

class Shannon(OSMD):
    """Shannon-entropy FTRL with an adaptive rate and uniform exploration.

    ``beta`` acts as the inverse learning rate.  After every update it grows by
    ``c / sqrt(1 + a_s / log K)``, where ``a_s`` accumulates the entropy of the
    FTRL solutions computed in ``next``.  The sampled distribution is that
    solution mixed with the uniform distribution using weight ``gamma``.
    """

    def __init__(self, K):
        super().__init__(K)
        self.c = 1
        self.beta = self.c   # inverse learning rate, increased in update()
        self.gamma = 0.5     # uniform-exploration weight used in next()
        self.a_s = 0         # running sum of entropies of the FTRL solutions

    def solve_unconstrained(self, loss, warmstart, i=None):
        """Solve ``loss + beta * (1 - log(1/x)) = 0`` for ``x`` by safeguarded Newton.

        Args:
            loss: scaled, bias-shifted cumulative loss of one arm.
            warmstart: starting iterate (assumed to lie in (0, 1]).
            i: arm index; unused, accepted for the OSMD interface.

        Returns:
            The root, accepted once a Newton step is smaller than ``EPS``.
        """
        x_val = warmstart
        while True:
            residual = loss + self.beta * (1 - np.log(1.0 / x_val))
            slope = self.beta / x_val
            dif_x = residual / slope
            # Damp steps that would leave (0, 1): halve the distance to the bound.
            if dif_x > x_val:
                dif_x = x_val / 2
            elif dif_x < x_val - 1.0:
                dif_x = (x_val - 1.0) / 2
            if abs(dif_x) < EPS:
                break
            x_val -= dif_x
        return x_val

    def hessian_inverse(self):
        """Return ``sum(x) / beta``, the slope used for the Newton step on ``bias``."""
        return self.x.sum() / self.beta

    def get_learning_rate(self, time):
        """Constant 1.0; the step size is carried by ``beta`` instead."""
        return 1.0

    def update(self, action, feedback):
        """Add importance-weighted losses for the pulled arms and adapt ``beta``.

        Args:
            action: sequence of pulled arm indices.
            feedback: losses aligned with ``action``.
        """
        # Divide by the probability with which the arm was sampled (self.x is
        # the mixed distribution produced by next()).
        # NOTE(review): a graph-feedback estimator would also update observed
        # neighbours and divide by their observation probability; only the
        # pulled arm's loss is available here.
        for idx, arm in enumerate(action):
            self.L[arm] += feedback[idx] / self.x[arm]
        if self.unconstrained:
            self.L -= 1
        else:
            # Fold the normalisation shift into the losses and restart it at 0.
            self.L += self.bias
            self.bias = 0
        self.beta += self.c / np.sqrt(1 + self.a_s / np.log(self.K))
        self.gamma = 0.5 / self.beta

    def next(self):
        """Solve FTRL, accumulate its entropy, mix in exploration, and sample."""
        self.time_step += 1.0
        self.learning_rate = self.get_learning_rate(self.time_step)
        self.solve_optimization()
        for k in range(self.K):
            self.a_s += self.x[k] * np.log(1.0 / self.x[k])
        self.x = (1 - self.gamma) * self.x + self.gamma / self.K * np.ones(self.K)
        return self.sample_action(self.x)

class Tsallis(OSMD):
    """Tsallis-entropy FTRL agent with its own closed-form normalisation.

    ``next`` bypasses the OSMD solver.  It sets
    ``x_k = 4 * (learning_rate * (L_k - w)) ** -2`` and updates the scalar
    ``w`` with Newton-style steps until ``w`` moves by at most 1.  It then
    renormalises ``x`` to sum to 1.
    """

    def __init__(self, K):
        super().__init__(K)
        self.beta = 1 - 1.0 / np.log(K)
        self.aplha = 2  # (sic) name kept for compatibility; used in get_learning_rate
        self.w = 1      # normalisation variable solved in next()

    def get_learning_rate(self, time):
        """Return ``sqrt(log K / (time * aplha * log K + 1))``."""
        return np.sqrt((np.log(self.K) / (time * (self.aplha * np.log(self.K)) + 1)))

    def update(self, action, feedback):
        """Add importance-weighted losses for the pulled arms.

        Args:
            action: sequence of pulled arm indices.
            feedback: losses aligned with ``action``.
        """
        for idx, arm in enumerate(action):
            self.L[arm] += feedback[idx] / self.x[arm]
        if self.unconstrained:
            self.L -= 1
        else:
            self.bias = 0

    def next(self):
        """Compute the Tsallis FTRL distribution and sample an arm from it."""
        self.time_step += 1.0
        self.learning_rate = self.get_learning_rate(self.time_step)
        while True:
            w_prev = self.w
            # Build x as a float array explicitly: reset() stores it as a list.
            self.x = 4 * np.power(self.learning_rate * (self.L - self.w), -2)
            total = self.x.sum()
            slope = np.power(self.x, 2 - self.beta).sum()
            self.w = self.w - (total - 1) / (self.learning_rate * slope)
            # NOTE(review): a tolerance of 1 on w is very loose; the explicit
            # renormalisation below is what guarantees sum(x) == 1.
            if abs(w_prev - self.w) <= 1:
                break
        self.x = self.x / self.x.sum()
        return self.sample_action(self.x)

# --- SOG-BARBAT set-up -------------------------------------------------------
# Epoch index and epoch-dependent parameters; the main loop advances them at
# every communication round.
SOG_m = 1
# NOTE(review): the per-epoch update in the main loop multiplies this term by
# K, but this initial value does not — confirm which form is intended.
SOG_delta_m = ((SOG_m + 4) * 2 ** (SOG_m + 4)) * np.log(K)
SOG_lambda_m = num * np.log(4 * K * SOG_delta_m) / V
# Gap estimates; every arm starts with gap 1.
SOG_Delta_m = np.ones(K)
# Rounds at which epochs end (the main loop compares t with the last entry).
SOG_N_m = int(K * SOG_lambda_m * np.power(2, 2 * SOG_m - 2) + 1)
SOG_communication = [SOG_N_m]
# Target observation count per arm: lambda_m / Delta_k^2.
SOG_n_m = SOG_lambda_m * np.power(SOG_Delta_m, -2)
SOG_award_m = np.zeros([V, K])
SOG_k_m = 0  # currently unused
SOG_Agents = [SOGBARBAT(K) for _ in range(V)]
# Every agent receives identical inputs and generate_probability is
# deterministic, so agent 0's result gives the shared expected observation
# counts; no second call is needed.
SOG_hat_n = [agent.generate_probability(SOG_N_m, SOG_n_m, np.ones(K)) for agent in SOG_Agents]
SOG_n_m = SOG_hat_n[0]
SOG_Regret = np.zeros(T)

# Shannon-entropy FTRL agents, one per simulated agent, each freshly reset.
Shannon_Agents = [Shannon(K) for _ in range(V)]
for agent in Shannon_Agents:
    agent.reset()
Shannon_Regret = np.zeros(T)

# Tsallis FTRL agents, built the same way.
Tsallis_Agents = [Tsallis(K) for _ in range(V)]
for agent in Tsallis_Agents:
    agent.reset()
Tsallis_Regret = np.zeros(T)

# Environment shared by all three algorithms.
bandit = Bandit(K)

# Main simulation loop.  In every round each algorithm pulls one arm per agent
# and accumulates pseudo-regret (best mean minus the pulled arm's mean).
# NOTE(review): round t == 0 is skipped entirely (its regret stays 0), so only
# T - 1 rounds are actually played.
for t in (tqdm(range(T))):
    if t > 0:
        # Regret is cumulative: start from the previous round's total.
        SOG_Regret[t] = SOG_Regret[t - 1]
        Shannon_Regret[t] = Shannon_Regret[t - 1]
        Tsallis_Regret[t] = Tsallis_Regret[t - 1]
        for v in range(V):
            # One reward vector per agent per round, shared by all algorithms.
            concurrent_award = bandit.generate_award()

            action = SOG_Agents[v].next()
            SOG_Regret[t] += bandit.generate_regret(action)
            # While budget C[0] lasts the reward is replaced: 0 for arms below
            # `target`, 1 otherwise.
            # NOTE(review): only the pulled arm's reward is recorded, yet the
            # epoch update divides by SOG_n_m, which generate_probability derives
            # from graph observations — confirm this estimator is intended.
            if C[0] > 0:
                if action < target:
                    SOG_award_m[v][action] += 0
                else:
                    SOG_award_m[v][action] += 1
                C[0] -= 1
            else:
                SOG_award_m[v][action] += concurrent_award[action]

            # Shannon-FTRL receives a loss: corrupted loss is 1 below `target`
            # and 0 otherwise; the true loss is 1 - reward.
            action = Shannon_Agents[v].next()
            Shannon_Regret[t] += bandit.generate_regret(action)
            feedback = []
            action_set = [action]
            if C[1] > 0:
                if action < target:
                    feedback.append(1)
                else:
                    feedback.append(0)
                C[1] -= 1
            else:
                feedback.append(1 -concurrent_award[action])
            Shannon_Agents[v].update(action_set, feedback)

            # Tsallis-FTRL: same loss feedback, with its own budget C[2].
            action = Tsallis_Agents[v].next()
            Tsallis_Regret[t] += bandit.generate_regret(action)
            feedback = []
            action_set = [action]
            if C[2] > 0:
                if action < target:
                    feedback.append(1)
                else:
                    feedback.append(0)
                C[2] -= 1
            else:
                feedback.append(1 - concurrent_award[action])
            Tsallis_Agents[v].update(action_set, feedback)

        # End of a SOG-BARBAT epoch: pool rewards over agents, re-estimate the
        # per-arm rewards and gaps, advance the epoch and rebuild distributions.
        if t == SOG_communication[-1]:
            SOG_award_sum = np.zeros(K)
            for v in range(V):
                SOG_award_sum += SOG_award_m[v]
            SOG_r_k = np.zeros(K)
            SOG_r_star = 0
            # Reward estimate: pooled reward over expected observations, capped at 1.
            for k in range(K):
                SOG_r_k[k] = min(SOG_award_sum[k] / (V * SOG_n_m[k]), 1)
            # Confidence-adjusted estimates; their maximum is the reference r*.
            SOG_r_loop = SOG_r_k - np.sqrt(SOG_lambda_m * np.power(SOG_Delta_m, -2) / (V * SOG_n_m)) / 8
            SOG_r_loop.sort()
            SOG_r_star = SOG_r_loop[K - 1]
            # New gap estimates, floored at 2^-m.
            for k in range(K):
                SOG_Delta_m[k] = max(SOG_r_star - SOG_r_k[k], 2 ** (0 - SOG_m))
            SOG_m += 1
            # NOTE(review): this includes a factor K that the initial
            # SOG_delta_m does not — confirm which form is intended.
            SOG_delta_m = K * ((SOG_m + 4) * 2 ** (SOG_m + 4)) * np.log(K)
            SOG_lambda_m = num * np.log(4 * K * SOG_delta_m) / V
            SOG_N_m = int(K * SOG_lambda_m * np.power(2, 2 * SOG_m - 2) + 1)
            # The next epoch ends N_m rounds from now.
            SOG_communication.append(SOG_N_m + t)
            SOG_award_m = np.zeros([V, K])
            for k in range(K):
                SOG_n_m[k] = SOG_lambda_m * np.power(SOG_Delta_m[k], -2)
            for v in range(V):
                SOG_Agents[v].generate_probability(SOG_N_m, SOG_n_m, SOG_r_k)
            # Agent 0 is recomputed with identical inputs only to obtain the
            # expected observation counts; the call above already produced them.
            SOG_n_m = SOG_Agents[0].generate_probability(SOG_N_m, SOG_n_m, SOG_r_k)

# Plot the cumulative pseudo-regret curves of the three algorithms.
X = np.arange(1, T + 1)
plt.xticks(fontsize=16)
plt.yticks(fontsize=16)
plt.rcParams['font.size'] = 16

curves = (
    (SOG_Regret, 'r-', 'SOG-BARBAT'),
    (Shannon_Regret, 'b-', 'Shannon-FTRL'),
    (Tsallis_Regret, 'y-', 'Tsallis-FTRL'),
)
for series, fmt, legend_name in curves:
    plt.plot(X, series, fmt, label=legend_name)

plt.xlabel('Rounds')
plt.ylabel('Regret')
plt.grid(True)
plt.legend()
for axis_name in ('x', 'y'):
    plt.ticklabel_format(style='sci', scilimits=(0, 0), axis=axis_name)
plt.show()