import numpy as np
import math
from math import pi
import matplotlib.pyplot as plt


def K_best(m, k=None):
    """Pick the ``k`` largest entries of ``m`` and compute their joint reward.

    Args:
        m: 1-D array-like of per-arm values (true means or index values).
        k: Number of arms to select. When omitted, the module-level ``K``
            is used, which is the behaviour existing callers rely on.

    Returns:
        A tuple ``(Lis, reward)`` where ``Lis`` holds the selected indices
        ordered from the largest value to the smallest, and ``reward`` is
        ``1 - prod(1 - m[Lis])``.
    """
    if k is None:
        k = K
    m = np.asarray(m)
    # argsort is ascending. Reversing first and slicing afterwards is the
    # same as flipping ``argsort(m)[-k:]`` for k >= 1, but it also handles
    # k == 0 correctly (``[-0:]`` would select every arm).
    Lis = np.argsort(m)[::-1][:k]
    reward = 1 - np.prod(1 - m[Lis])
    return Lis, reward


def alpha(arm):
    """Return the largest per-target margin for ``arm``, floored at 0.

    Works entirely on module-level state: reads ``List``, ``Ti``,
    ``Mu_hat_0``, ``Attack``, ``Tstars``, ``delta``, ``delta0`` and ``L``.

    Side effects:
        * ``Ti[j]`` is bumped from 1 to 2 for any ``j`` in ``List`` so that
          ``Ti[j] - 1`` is never zero in the divisions below.
        * ``Tstars[arm, j]`` is overwritten with the current threshold
          whenever the margin computed from it is at most 1.

    Args:
        arm: Index of the arm being evaluated.

    Returns:
        The maximum margin over all ``j`` in ``List``; 0 if none is positive.
    """
    best = 0
    for target in List:
        if Ti[target] == 1:
            Ti[target] += 1

        pulls = Ti[target] - 1
        log_term = math.log(pi * pi * L * pulls * pulls / (3 * delta))
        # Threshold for this target: its running mean minus delta0 and a
        # confidence width, never allowed below zero.
        threshold = Mu_hat_0[target] - delta0 - 2 * math.sqrt(log_term / (2 * pulls))
        threshold = max(threshold, 0)

        # Part of the margin that does not depend on the threshold. Read
        # inside the loop because Ti may have been bumped just above.
        base = Mu_hat_0[arm] * (Ti[arm] - 1) + 1 - Attack[arm]
        margin = base - threshold * Ti[arm]
        if margin <= 1:
            Tstars[arm, target] = threshold
        else:
            # Fall back to the threshold stored on an earlier call.
            margin = base - Tstars[arm, target] * Ti[arm]

        best = max(best, margin)
    return best


# ---------------------------------------------------------------------------
# Experiment configuration
# ---------------------------------------------------------------------------
L = 16        # number of arms (length of every per-arm array)
K = 8         # number of arms K_best selects per step
T = 100000    # steps per run
delta = 0.1   # appears inside the log term of alpha()
Repeat = 20   # independent runs per delta0 setting
M = 9         # not referenced in the visible part of the script
# Arms that alpha() iterates over and whose feedback the main loop forces
# to 1 once an attack has been triggered within a step.
List = [6, 7, 8, 9, 10, 11, 12, 13]

# Cost tables, one per delta0 setting: entry [t][run] receives the summed
# Attack counts after step t of that run. Allocated directly as
# (T, Repeat) zero matrices instead of via a T-element Python list.
Cost1 = np.zeros((T, Repeat))
Cost2 = np.zeros((T, Repeat))
Cost5 = np.zeros((T, Repeat))


FLAG = False  # not referenced in the visible part of the script
# True mean of every arm.
# NOTE(review): with all means at 0, every ``np.random.rand() < Mu[i]`` draw
# in this script is False - presumably a placeholder for real values; confirm.
Mu = np.zeros(L)


# Report the means of the arms K_best selects under the true means (largest
# first) on one line, then list the mean of every arm, one per line.
optK, optR = K_best(Mu)
best_means = "".join(str(Mu[idx]) + " " for idx in optK)
print("Optimal Arms:\t" + best_means)
for item in Mu:
    print(item)


# Save the string form of the true arm means next to the figure. The context
# manager closes the handle even if the write raises.
with open('delta0.txt', 'w') as file:
    file.write(str(Mu))


# ---------------------------------------------------------------------------
# Attack-cost experiments, one per delta0 setting (0.1, 0.2, 0.5).
#
# The same simulation runs for every setting; only delta0 and the arrays that
# receive the results differ, so the settings are driven from one table
# instead of three copies of the loop.
#
#   A / A2 / A5          running mean over runs of the summed Attack counts
#   Cost1 / Cost2 / Cost5  the same quantity per run, indexed [t][run]
#   Target / Target2 / Target5  allocated but never written in this script
# ---------------------------------------------------------------------------
A, A2, A5 = np.zeros(T), np.zeros(T), np.zeros(T)
Target, Target2, Target5 = np.zeros(T), np.zeros(T), np.zeros(T)

_experiments = (
    (0.1, A, Cost1),
    (0.2, A2, Cost2),
    (0.5, A5, Cost5),
)

for delta0_value, mean_cost, run_cost in _experiments:
    # alpha() reads ``delta0`` as a module-level global, so it has to be
    # rebound here rather than passed as an argument.
    delta0 = delta0_value

    for counter in range(Repeat):
        print()
        print("Counter: " + str(counter))

        # Fresh per-run state. alpha() reads and writes Ti, Mu_hat_0, Attack
        # and Tstars as module-level globals, so they stay at module scope.
        Ti = np.ones(L)                                  # observation counts (start at 1)
        Mu_hat = (np.random.rand(L) < Mu).astype(float)  # means as fed to the UCB index
        Mu_hat_0 = np.zeros(L)                           # running means consumed by alpha()
        Attack = np.zeros(L)                             # forced-feedback count per arm
        Ntimes = np.zeros(L)                             # times each arm was selected
        Tstars = np.zeros((L, L))                        # thresholds cached by alpha()

        for t in range(T):
            # Upper-confidence index, clipped to [0, 1]; take the K largest.
            Radius = np.sqrt(3 * np.log(t + 1) / (2 * Ti))
            UCB = np.clip(Mu_hat + Radius, a_min=0, a_max=1)
            Rec, _ = K_best(UCB)

            # K_best returns distinct indices, so one fancy-indexed add
            # matches incrementing them one at a time.
            Ntimes[Rec] += 1

            # Walk the selected arms in order and stop at the first arm whose
            # (possibly forced) feedback is 1. ``flag`` records that an
            # attack has already happened during this step.
            flag = False
            for item in Rec:
                is_target = item in List
                Ti[item] += 1
                temp = (np.random.rand() < Mu[item]).astype(int)

                if flag:
                    # After an attack in this step, arms in List are forced
                    # to 1 (and counted); all other arms are forced to 0.
                    if is_target:
                        temp = 1
                        Attack[item] += 1
                    else:
                        temp = 0

                # Mu_hat_0 is updated with the feedback as it stands *before*
                # the alpha() check below may zero it.
                Mu_hat_0[item] += (temp - Mu_hat_0[item]) / Ti[item]

                if not is_target and temp == 1:
                    # A 1 on an arm outside List: suppress it whenever
                    # alpha() reports a positive margin.
                    Alpha = alpha(item)
                    if Alpha > 0:
                        temp = 0
                        Attack[item] += 1
                        flag = True

                Mu_hat[item] += (temp - Mu_hat[item]) / Ti[item]
                if temp == 1:
                    break

            # Summed attack count so far in this run, computed once and
            # folded into both the running mean and the per-run table.
            total_attacks = Attack.sum()
            mean_cost[t] = (mean_cost[t] * counter + total_attacks) / (counter + 1)
            run_cost[t][counter] = total_attacks


# Spread of the cost across runs at every step. Despite the "Var" names these
# are standard deviations (np.std), taken along the run axis of each
# (T, Repeat) cost table in a single vectorised call per table.
Var1 = np.std(Cost1, axis=1)
Var2 = np.std(Cost2, axis=1)
Var5 = np.std(Cost5, axis=1)


# Figure: mean cost against step for each delta0 setting, with a shaded band
# of one standard deviation on either side of the mean.
plt.figure(figsize=(8, 6), dpi=600)
plt.grid(True)
x = np.linspace(1, T, T)

# (mean curve, spread, legend label, line colour, band colour)
for curve, spread, legend_text, line_color, band_color in (
    (A, Var1, "△₀=0.1", 'red', 'pink'),
    (A2, Var2, "△₀=0.2", 'blue', 'skyblue'),
    (A5, Var5, "△₀=0.5", 'green', 'lightgreen'),
):
    upper = curve + spread
    lower = curve - spread
    plt.plot(x, curve, label=legend_text, color=line_color, lw=2.4)
    plt.plot(x, upper, color=band_color, lw=0.8)
    plt.plot(x, lower, color=band_color, lw=0.8)
    plt.fill_between(x, upper, lower, alpha=0.25, color=band_color)

plt.xlabel("t", fontsize=28)
plt.ylabel("Cost", fontsize=28)
plt.legend(fontsize=24)
plt.tick_params(labelsize=24)
# Scientific notation on both axes, y first and then x.
for axis_name in ("y", "x"):
    plt.ticklabel_format(axis=axis_name, style="sci", scilimits=(0, 0))
plt.savefig("Delta0.png", bbox_inches='tight')
# plt.show()
