import numpy as np
import math
from math import pi
import matplotlib.pyplot as plt


def K_best(m, k=None):
    """Return the indices of the ``k`` largest entries of ``m`` and their reward.

    Args:
        m: 1-D NumPy array of per-arm values (the script passes either the
            mean vector or the clipped UCB vector).
        k: Number of entries to select. Defaults to the module-level ``K``,
            so existing ``K_best(m)`` calls behave exactly as before.

    Returns:
        A tuple ``(Lis, reward)``. ``Lis`` holds the selected indices ordered
        from the largest value to the smallest; ``reward`` is
        ``1 - prod(1 - m[Lis])`` (0 for an empty selection).

    Raises:
        ValueError: If ``k`` is negative.
    """
    if k is None:
        k = K
    if k < 0:
        raise ValueError("k must be non-negative")
    # Reverse the ascending argsort and keep a prefix.  The previous
    # ``[-k:]`` slice returned *every* index when k == 0 because ``[-0:]``
    # is the whole array; a prefix slice yields the empty selection instead.
    Lis = np.argsort(m)[::-1][:k]
    reward = 1 - np.prod(1 - m[Lis])
    return Lis, reward


def alpha(arm):
    """Return the largest margin of ``arm`` against the arms in ``List``.

    For each ``j`` in the module-level ``List`` a threshold is built from
    ``Mu_hat_0[j] - delta0`` minus a confidence width that depends on
    ``Ti[j] - 1``, floored at zero.  The margin of ``arm`` against that
    threshold is
    ``Mu_hat_0[arm] * (Ti[arm] - 1) + 1 - Attack[arm] - threshold * Ti[arm]``.

    Side effects on module-level state:
        * ``Ti[j]`` equal to 1 is bumped to 2 before it is used.
        * ``Tstars[arm, j]`` is overwritten with the fresh threshold when the
          margin is at most 1; otherwise the stored ``Tstars[arm, j]`` is
          used to recompute the margin.

    Returns:
        The maximum margin over ``List``, or 0 when no margin exceeds 0.
    """
    best = 0
    for j in List:
        # With Ti[j] == 1 the divisor ``Ti[j] - 1`` below would be zero.
        if Ti[j] == 1:
            Ti[j] += 1

        pulls = Ti[j] - 1
        width = 2 * math.sqrt(
            math.log(pi * pi * L * pulls * pulls / (3 * delta)) / (2 * pulls)
        )
        threshold = max(Mu_hat_0[j] - delta0 - width, 0)

        # Part of the margin that does not depend on the threshold.  It is
        # evaluated after the bump above because ``arm`` may coincide with j.
        gain = Mu_hat_0[arm] * (Ti[arm] - 1) + 1 - Attack[arm]
        margin = gain - threshold * Ti[arm]
        if margin <= 1:
            Tstars[arm, j] = threshold
        else:
            # Fall back to the threshold remembered from an earlier call.
            margin = gain - Tstars[arm, j] * Ti[arm]

        best = max(best, margin)
    return best


# --- Experiment configuration -------------------------------------------
# L arms in total, K arms recommended per round, T rounds per repetition.
L = 16
K = 8
T = 100000
# Constants used by alpha() when it builds its threshold.
delta0 = 0.1
delta = 0.1
# Number of independent repetitions of each run.
Repeat = 20
# Arm whose recommendation count is recorded (Ntimes[M]) in both runs.
M = 9
# Arms that the attacked run treats differently from all other arms.
List = list(range(6, 14))
# Per-round, per-repetition recommendation counts of arm M:
# Times1 for the attacked run, Times2 for the run without attack.
Times1 = np.zeros((T, Repeat))
Times2 = np.zeros((T, Repeat))


FLAG = False
# NOTE(review): every mean is 0 here, so every ``rand() < Mu[...]`` draw in
# this file evaluates to False - confirm this is the intended setup.
Mu = np.zeros(L)

# Report the K arms with the largest means, then every mean on its own line.
optK, optR = K_best(Mu)
print("Optimal Arms:\t" + "".join(str(Mu[arm_index]) + " " for arm_index in optK))
for mean in Mu:
    print(mean)


# Running means over repetitions, indexed by round t:
#   A[t]      - mean of sum(Attack) after round t
#   Target[t] - mean of Ntimes[M] after round t
A = np.zeros(T)
Target = np.zeros(T)


# Attacked run: a top-K UCB learner whose observed feedback may be altered.
for counter in range(Repeat):
    print()
    print("Counter: " + str(counter))


    # Fresh per-repetition state.
    Radius = np.zeros(L)
    # Per-arm observation counter; starts at 1 and grows each time the arm is
    # examined (alpha() may also bump an entry from 1 to 2).
    Ti = np.ones(L)
    # Estimate used for the UCB index; seeded with one draw per arm and then
    # updated with the (possibly altered) feedback.
    Mu_hat = (np.random.rand(L) < Mu).astype(float)
    # Second running mean read by alpha(); for arms outside List it is updated
    # before the alpha()-triggered flip is applied.
    Mu_hat_0 = np.zeros(L)
    # Per-arm count of altered observations.
    Attack = np.zeros(L)
    # Per-arm count of rounds in which the arm was recommended.
    Ntimes = np.zeros(L)
    # Threshold table read and written by alpha().
    Tstars = np.zeros((L, L))


    for t in range(T):
        # UCB index per arm, clipped to [0, 1]; recommend the K largest.
        Radius = np.sqrt(3 * np.log(t + 1) / (2 * Ti))
        UCB = np.clip(Mu_hat + Radius, a_min=0, a_max=1)
        Rec, _ = K_best(UCB)

        for item in Rec:
            Ntimes[item] += 1

        # Becomes True once alpha() has triggered a flip in this round; every
        # arm examined afterwards gets forced feedback.
        flag = False
        # Examine the recommended arms in order (largest UCB first).
        for item in Rec:
            Ti[item] += 1
            temp = (np.random.rand() < Mu[item]).astype(int)
            if flag:
                # Forced feedback: 0 for arms outside List, 1 (counted as an
                # attack) for arms in List.
                if item not in List:
                    temp = 0
                else:
                    temp = 1
                    Attack[item] += 1
            if item not in List:
                if temp == 1:
                    # Record the 1 in Mu_hat_0 first, then let alpha() decide
                    # whether the learner should see a 0 instead.
                    Mu_hat_0[item] += (1 - Mu_hat_0[item]) / Ti[item]
                    Alpha = alpha(item)
                    if Alpha > 0:
                        temp = 0
                        Attack[item] += 1
                        flag = True
                else:
                    Mu_hat_0[item] += (0 - Mu_hat_0[item]) / Ti[item]
            else:
                Mu_hat_0[item] += (temp - Mu_hat_0[item]) / Ti[item]
            # The learner's estimate only ever sees the final value of temp.
            Mu_hat[item] += (temp - Mu_hat[item]) / Ti[item]
            # Stop at the first feedback of 1; the remaining arms in Rec are
            # not examined or updated this round.
            if temp == 1:
                break

        # Fold this repetition into the running means and keep the raw count
        # for the spread computed later.
        Target[t] = (Target[t] * counter + Ntimes[M]) / (counter + 1)
        A[t] = (A[t] * counter + sum(Attack)) / (counter + 1)
        Times1[t][counter] = Ntimes[M]


# Mean recommendation count of arm M divided by the number of rounds played
# so far (round index t corresponds to t + 1 rounds).
Ratio = Target / np.arange(1, T + 1)


# Running means over repetitions for the run without attack:
#   AT[t]      - mean of sum(Attack) after round t (Attack is never modified here)
#   TargetT[t] - mean of Ntimes[M] after round t
AT = np.zeros(T)
TargetT = np.zeros(T)


# Baseline run: the same top-K UCB learner, fed unaltered feedback.
for counter in range(Repeat):
    print()
    print("Counter: " + str(counter))

    # Fresh per-repetition state.
    Radius = np.zeros(L)
    Ti = np.ones(L)
    Mu_hat = (np.random.rand(L) < Mu).astype(float)
    Attack = np.zeros(L)
    Ntimes = np.zeros(L)
    Tstars = np.zeros((L, L))

    runs_so_far = counter + 1
    for t in range(T):
        # UCB index per arm, clipped to [0, 1]; recommend the K largest.
        Radius = np.sqrt(3 * np.log(t + 1) / (2 * Ti))
        UCB = np.clip(Mu_hat + Radius, 0, 1)
        Rec, _ = K_best(UCB)

        # Rec holds distinct indices, so one indexed increment counts each
        # recommended arm exactly once.
        Ntimes[Rec] += 1

        # Examine the recommended arms in order and stop at the first
        # feedback of 1.
        for arm in Rec:
            Ti[arm] += 1
            click = (np.random.rand() < Mu[arm]).astype(int)
            Mu_hat[arm] += (click - Mu_hat[arm]) / Ti[arm]
            if click == 1:
                break

        # Fold this repetition into the running means and keep the raw count.
        target_count = Ntimes[M]
        TargetT[t] = (TargetT[t] * counter + target_count) / runs_so_far
        AT[t] = (AT[t] * counter + Attack.sum()) / runs_so_far
        Times2[t, counter] = target_count


# Per-round standard deviation across repetitions of arm M's recommendation
# count (despite the ``Var`` name these are standard deviations, as before).
# Reducing along axis 1 replaces 2 * T separate Python-level np.std calls
# with two array reductions.
Var1 = np.std(Times1, axis=1)
Var2 = np.std(Times2, axis=1)


# Mean recommendation count of arm M in the run without attack, divided by
# the number of rounds played so far (round index t is t + 1 rounds).
RatioT = TargetT / np.arange(1, T + 1)


# Plot the mean recommendation count of arm M per round for both runs, each
# with a band of +/- one standard deviation across repetitions.
plt.figure(figsize=(8, 6), dpi=600)
plt.grid(True)
x = np.linspace(1, T, T)

# Attacked run: mean curve, band edges, shaded band.
plt.plot(x, Target, label="Attacked", color='red', lw=2.4)
plt.plot(x, Target + Var1, color='pink', lw=0.8)
# Lower band edge.  This line previously repeated ``Target + Var1``, so the
# lower edge was never drawn (compare the TargetT - Var2 line below).
plt.plot(x, Target - Var1, color='pink', lw=0.8)
plt.fill_between(x, Target + Var1, Target - Var1, alpha=0.25, color='pink')

# Run without attack: mean curve, band edges, shaded band.
plt.plot(x, TargetT, label="Without attack", color='blue', lw=2.4)
plt.plot(x, TargetT + Var2, color='skyblue', lw=0.8)
plt.plot(x, TargetT - Var2, color='skyblue', lw=0.8)
plt.fill_between(x, TargetT + Var2, TargetT - Var2, alpha=0.25, color='skyblue')

plt.xlabel("t", fontsize=28)
plt.ylabel("Chosen times", fontsize=28)
plt.legend(fontsize=24)
plt.tick_params(labelsize=24)
# Scientific notation on both axes.
plt.ticklabel_format(axis="y", style="sci", scilimits=(0, 0))
plt.ticklabel_format(axis="x", style="sci", scilimits=(0, 0))

plt.savefig("WithWithoutAttack.png", bbox_inches='tight')
# plt.show()
