import numpy as np
import math
from math import pi
import matplotlib.pyplot as plt


def K_best(m, k=None):
    """Return the indices of the ``k`` largest entries of ``m`` and their reward.

    Parameters
    ----------
    m : array_like
        One score per arm (in this script: clipped UCB indices or ``Mu``).
    k : int, optional
        Number of arms to select. Defaults to the module-level ``K`` so that
        existing ``K_best(m)`` calls behave exactly as before.

    Returns
    -------
    Lis : numpy.ndarray
        Indices of the ``k`` largest entries, ordered from largest to smallest.
        If ``k`` exceeds ``len(m)``, all indices are returned.
    reward : numpy.float64
        ``1 - prod(1 - m[Lis])``, i.e. the probability that at least one
        selected arm succeeds if the entries are independent success
        probabilities. It is 0 when nothing is selected.
    """
    if k is None:
        k = K
    m = np.asarray(m)
    # Reverse the ascending argsort, then take a prefix. For k >= 1 this yields
    # exactly flipud(argsort(m)[-k:]). For k == 0 it yields an empty selection
    # instead of every arm, because ``[-0:]`` would slice the whole array.
    Lis = np.argsort(m)[::-1][:k]
    reward = 1 - np.prod(1 - m[Lis])
    return Lis, reward


def alpha(arm):
    """Return the attack score for a click on ``arm``; a positive value means "suppress it".

    The simulation loop calls this only for a click (``temp == 1``) on an arm
    outside ``List``, and zeroes the click when the result is > 0.

    Reads module-level state: ``List``, ``L``, ``delta``, ``delta0``,
    ``S0`` (clicks before any attack), ``TiP`` (1 + accumulated position
    weights), ``Attack`` (clicks already suppressed per arm) and ``Tstars``
    (an (L, L) cache of thresholds).

    For every arm ``j`` in ``List`` it computes a threshold ``tem`` from ``j``'s
    statistics and a score ``rtemp`` for ``arm``. It returns the largest
    ``rtemp`` seen, or 0 if none is positive.

    Side effects: may raise ``TiP[j]`` from 1 to 2 and may overwrite
    ``Tstars[arm, j]``.
    """
    ret = 0
    for j in List:
        # TiP[j] is still 1 if j has never been recommended. Bump it so that the
        # (TiP[j] - 1) terms below are non-zero.
        # NOTE(review): this mutates the global TiP, which the caller also uses
        # for its UCB mean/radius -- confirm this side effect is intended.
        if TiP[j] == 1:
            TiP[j] += 1
        # j's pre-attack click rate, minus the margin delta0 and a log/sqrt
        # confidence width, floored at 0. It looks like a Hoeffding-style lower
        # confidence bound -- verify against the derivation.
        tem = S0[j] / TiP[j] - delta0 - 2 * math.sqrt(math.log(pi * pi * L * (TiP[j] - 1) * (TiP[j] - 1) / (3 * delta)) / (2 * (TiP[j] - 1)))
        tem = max(tem, 0)
        # Score for `arm` against threshold tem:
        #   its pre-attack clicks scaled by (TiP-1)/TiP
        #   + 1 for the current click
        #   - clicks already suppressed
        #   - tem times its TiP.
        rtemp = S0[arm] / TiP[arm] * (TiP[arm] - 1) + 1 - Attack[arm] - tem * TiP[arm]
        if rtemp <= 1:
            # Small score: remember this threshold for the (arm, j) pair.
            Tstars[arm, j] = tem
        else:
            # Large score: re-score against the previously cached threshold instead.
            rtemp = S0[arm] / TiP[arm] * (TiP[arm] - 1) + 1 - Attack[arm] - Tstars[arm, j] * TiP[arm]
        if rtemp > ret:
            ret = rtemp
    return ret


# ---- Problem size ---------------------------------------------------------
L = 16       # total number of arms
K = 8        # arms recommended per round (one per entry of Position)
T = 100000   # rounds per run

# ---- Confidence parameters ------------------------------------------------
delta0 = 0.1  # margin subtracted from the threshold inside alpha()
delta = 1.1   # used in the UCB radius and in alpha()'s log term

# ---- Experiment layout ----------------------------------------------------
Repeat = 20   # independent runs per experiment
M = 9         # arm whose recommendation count is tracked and plotted
List = [6, 7, 8, 9, 10, 11, 12, 13]  # arms alpha() compares against; run 1 never attacks them
Regret = np.zeros((Repeat, T))       # NOTE(review): not used in the visible code

# Raw per-run counts of how often arm M was recommended: row = round, column = run.
Times1 = np.zeros((T, Repeat))
Times2 = np.zeros((T, Repeat))

# Running means over runs, indexed by round.
A = np.zeros(T)        # cumulative suppressed clicks
Target = np.zeros(T)   # cumulative recommendations of arm M

# Per-slot success probability, combined with Mu[item] for a realised click.
Position = np.array([0.8, 0.8, 0.8, 0.4, 0.4, 0.4, 0.2, 0.2])

FLAG = False  # NOTE(review): not used in the visible code
# NOTE(review): with every entry 0, the draw `rand() < Mu[item]` never
# succeeds, so no clicks (and therefore no attacks) occur -- confirm values.
Mu = np.zeros(L)

optK, optR = K_best(Mu)
# One line: a header, then the mean of each of the K best arms, each followed by a space.
print("Optimal Arms:\t" + "".join(str(Mu[arm]) + " " for arm in optK))
# Then every arm's mean on its own line.
for mean in Mu:
    print(mean)


# Record the arm means and slot probabilities used by this experiment.
# The context manager closes the file even if a write raises.
with open('wwout_Mu.txt', 'w') as out:
    out.write(str(Mu))
    out.write('\n')
    out.write(str(Position))


# Run 1: each round the K arms with the highest clipped UCB are recommended.
# A click on an arm outside `List` is suppressed whenever alpha() is positive.
for counter in range(Repeat):
    print()
    print("Counter: " + str(counter))

    # Per-run statistics. These stay module-level on purpose: alpha() reads
    # S0, TiP, Attack and Tstars as globals (and may modify TiP and Tstars).
    Radius = np.zeros(L)
    S, S0 = np.zeros(L), np.zeros(L)           # clicks after / before the attack
    Ti0, TiP = np.ones(L), np.ones(L)          # 1 + recommendations / 1 + summed slot weights
    Attack, Ntimes = np.zeros(L), np.zeros(L)  # suppressed clicks / recommendations
    Tstars = np.zeros((L, L))

    for t in range(T):
        Radius = np.sqrt(delta * Ti0 * np.log(t + 1) / (2 * TiP * TiP))
        UCB = np.clip(S / TiP + Radius, a_min=0, a_max=1)
        Rec, _ = K_best(UCB)
        Ntimes[Rec] += 1  # Rec holds distinct arm indices

        for clock, item in enumerate(Rec):
            Ti0[item] += 1
            TiP[item] += Position[clock]
            # Keep the RNG stream fixed: always two draws, Mu first, then Position.
            hit_mu = np.random.rand() < Mu[item]
            hit_pos = np.random.rand() < Position[clock]
            temp = (hit_mu * hit_pos).astype(int)
            S0[item] += temp
            if temp == 1 and item not in List:
                Alpha = alpha(item)
                if Alpha > 0:
                    temp = 0
                    Attack[item] += 1
            S[item] += temp

        # Running means over runs 0..counter, plus the raw count for this run.
        Target[t] = (Target[t] * counter + Ntimes[M]) / (counter + 1)
        A[t] = (A[t] * counter + Attack.sum()) / (counter + 1)
        Times1[t, counter] = Ntimes[M]


# Run 2: same learner as run 1, but every click on an arm other than M is
# suppressed unconditionally (alpha() is not consulted).
AT = np.zeros(T)       # running mean over runs of cumulative suppressed clicks, per round
TargetT = np.zeros(T)  # running mean over runs of arm M's cumulative recommendations


for counter in range(Repeat):
    print()
    print("Counter: " + str(counter))

    # Per-run statistics, same layout as run 1.
    Radius = np.zeros(L)
    S = np.zeros(L)        # clicks the learner observes (after suppression)
    S0 = np.zeros(L)       # clicks before suppression
    Ti0 = np.ones(L)       # 1 + number of recommendations
    TiP = np.ones(L)       # 1 + summed slot weights
    Attack = np.zeros(L)   # suppressed clicks per arm
    Ntimes = np.zeros(L)   # recommendations per arm
    Tstars = np.zeros((L, L))

    for t in range(T):
        Radius = np.sqrt(delta * Ti0 * np.log(t + 1) / (2 * TiP * TiP))
        UCB = np.clip(S / TiP + Radius, a_min=0, a_max=1)
        Rec, _ = K_best(UCB)
        Ntimes[Rec] += 1  # Rec holds distinct arm indices

        for clock, item in enumerate(Rec):
            Ti0[item] += 1
            TiP[item] += Position[clock]
            temp = ((np.random.rand() < Mu[item])
                    * (np.random.rand() < Position[clock])).astype(int)
            S0[item] += temp
            # Rec holds NumPy integers, so compare by value. An identity test
            # (`item is not M`) is always True and would suppress M's clicks too.
            if temp == 1 and item != M:
                temp = 0
                Attack[item] += 1
            S[item] += temp

        TargetT[t] = (TargetT[t] * counter + Ntimes[M]) / (counter + 1)
        AT[t] = (AT[t] * counter + sum(Attack)) / (counter + 1)
        Times2[t, counter] = Ntimes[M]


# Per-round spread of arm M's recommendation count across runs.
# Despite the "Var" names these are standard deviations; they are plotted
# as +/- bands. One vectorised reduction per array replaces a Python loop
# of 2*T separate np.std calls.
Var1 = np.std(Times1, axis=1)
Var2 = np.std(Times2, axis=1)


# Plot arm M's mean cumulative recommendation count per round for both runs,
# shading +/- one standard deviation (across runs) around each curve.
plt.figure(figsize=(8, 6), dpi=600)
plt.grid(True)
x = np.linspace(1, T, T)  # rounds 1..T

plt.plot(x, Target, label="Attacked", color='red', lw=2.4)
plt.plot(x, Target + Var1, color='pink', lw=0.8)   # upper band edge
plt.plot(x, Target - Var1, color='pink', lw=0.8)   # lower band edge
plt.fill_between(x, Target + Var1, Target - Var1, alpha=0.25, color='pink')

# NOTE(review): TargetT comes from the run that suppresses every click on
# arms other than M, so the "Without attack" legend may be a mislabel --
# confirm the intended wording.
plt.plot(x, TargetT, label="Without attack", color='blue', lw=2.4)
plt.plot(x, TargetT + Var2, color='skyblue', lw=0.8)
plt.plot(x, TargetT - Var2, color='skyblue', lw=0.8)
plt.fill_between(x, TargetT + Var2, TargetT - Var2, alpha=0.25, color='skyblue')

plt.xlabel("t", fontsize=28)
plt.ylabel("Chosen times", fontsize=28)
plt.legend(fontsize=24)
plt.tick_params(labelsize=24)
plt.ticklabel_format(axis="y", style="sci", scilimits=(0, 0))
plt.ticklabel_format(axis="x", style="sci", scilimits=(0, 0))

plt.savefig("WithWithoutAttack.png", bbox_inches='tight')
# plt.show()
