import numpy as np
import pandas as pd
import math
from scipy.stats import bernoulli
from math import pi
from matplotlib import pyplot as plt


# --- Experiment configuration -------------------------------------------
K = 2            # number of arms (arm 1 and the target arm)
T = 10000        # total pulls per simulated run
m1 = 0.85        # Bernoulli mean of arm 1 (the arm whose rewards get attacked)
mstar = 0.1      # Bernoulli mean of the target arm
delta0 = 0.1     # margin subtracted from the target mean in alpha0()
delta = 0.1      # confidence parameter inside alpha0()'s log term

# --- Cross-run accumulators / placeholders ------------------------------
# LB1, UB1, LBstar, UBstar and Ustarave are not used in the visible code.
LB1 = 0
UB1 = 0
LBstar = 0
UBstar = 0
U1ave = 0          # running mean (over runs) of arm 1's pre-attack empirical mean
Ustarave = 0
Cost = 0           # running mean (over runs) of the total attack cost
AverageNstar = 0   # running mean (over runs) of target-arm pull counts

# Per-run records; two distinct list objects.
N1List = []
Cost1List = []


def add_list(l1, l2):
    """Return the element-wise sum of two sequences.

    Pairs are formed with ``zip``, so the result is as long as the shorter
    input; any extra trailing elements of the longer input are ignored.
    """
    summed = []
    for left, right in zip(l1, l2):
        summed.append(left + right)
    return summed


def UCB(rounds, mu, N, sigma=0.5):
    """Return the upper-confidence-bound index of an arm.

    index = mu + 3 * sigma * sqrt(log(rounds) / N)

    Args:
        rounds: current round number t (must be >= 1 so that log is defined).
        mu: the arm's empirical mean reward.
        N: number of times the arm has been pulled so far.
        sigma: scale of the exploration bonus. The default 0.5 reproduces
            the previously hard-coded ``3 * sqrt(...) / 2`` bonus
            bit-for-bit. Presumably this is the sub-Gaussian scale of
            Bernoulli rewards; confirm if reusing with other reward
            distributions.

    Returns:
        The UCB index. If ``N == 0``, returns ``math.inf`` so that an arm
        that has never been pulled is always preferred. Previously this
        case raised ZeroDivisionError.
    """
    if N == 0:
        return math.inf
    # 3 * sigma with sigma = 0.5 is exactly 1.5, and halving is exact in
    # floating point, so this matches the original expression bit-for-bit.
    # Tie-breaking between arms is therefore unchanged.
    return mu + 3 * sigma * math.sqrt(math.log(rounds) / N)


def alpha0():
    """Compute a clipped target threshold and the corresponding amount for arm 1.

    Takes no arguments. It reads module-level state written by the
    simulation loop (``Ustar``, ``Nstar``, ``N1``, ``U1_0``, ``Attack1``)
    and the constants ``delta0``, ``delta`` and ``K``. It is not called
    anywhere in the visible code.

    threshold = max(Ustar - delta0 - 2*sqrt(log(pi^2*K*Nstar^2 / (3*delta)) / (2*Nstar)), 0)
    amount    = N1*U1_0 - Attack1 - threshold*N1

    NOTE(review): this looks like the attack-size rule for UCB from
    Jun et al. 2018 (with sigma = 1/2) -- TODO confirm. That paper clips the
    amount at 0, whereas this code clips the threshold.

    Returns:
        ``[amount, threshold]``. The threshold is the int ``0`` when the
        unclipped value is negative.

    Raises:
        ValueError: if ``Nstar`` is 0, because the log argument becomes 0.
    """
    target_mean = Ustar
    log_term = math.log(pi * pi * K * Nstar * Nstar / (3 * delta))
    bonus = 2 * math.sqrt(log_term / (2 * Nstar))
    threshold = max(target_mean - delta0 - bonus, 0)
    amount = N1 * U1_0 - Attack1 - threshold * N1
    return [amount, threshold]


# Run 500 independent simulations of UCB on two Bernoulli arms. Every time
# arm 1 pays a reward of 1, the attacker replaces it with 0 and pays one
# unit of cost.
for counter in range(500):

    U1, U1_0, Ustar = 0, 0, 0   # empirical means: arm 1 as seen by UCB (post-attack),
                                # arm 1 pre-attack, target arm
    N1, Nstar = 0, 0            # number of pulls of each arm
    Attack1 = 0                 # cumulative attack cost in this run
    Alpha1 = 0                  # not used in the visible code
    star1 = 0                   # not used in the visible code

    t = 0                       # round counter

    # Initialization: pull each of the two arms once
    # (arm 1 first, then the target).
    r1 = bernoulli.rvs(m1)
    t += 1
    N1 += 1
    if r1 == 1:
        Attack1 += 1            # attacker zeroes this reward at a cost of 1
    U1, U1_0 = 0, r1            # UCB observes 0; the true reward is kept in U1_0

    rstar = bernoulli.rvs(mstar)
    t += 1
    Nstar += 1
    Ustar = rstar

    # Remaining T - K rounds: play the arm with the larger UCB index.
    for index in range(K, T):
        t += 1
        UCB1 = UCB(t, U1, N1)
        UCBstar = UCB(t, Ustar, Nstar)
        if UCBstar >= UCB1:     # ties go to the target arm
            r = bernoulli.rvs(mstar)
            Nstar += 1
            Ustar = (Ustar * (Nstar - 1) + r) / Nstar
        else:
            r = bernoulli.rvs(m1)
            N1 += 1
            # Incremental mean of arm 1's true (pre-attack) rewards.
            U1_0 = (U1_0 * (N1 - 1) + r) / N1
            # Every observed reward of arm 1 is 0 after the attack, so the
            # mean UCB sees is updated with a 0 sample.
            U1 = U1 * (N1 - 1) / N1
            if r == 1:
                Attack1 += 1

    # Running means over the runs completed so far.
    Cost = (Cost * counter + Attack1) / (counter + 1)
    AverageNstar = (AverageNstar * counter + Nstar) / (counter + 1)
    U1ave = (U1ave * counter + U1_0) / (counter + 1)

    # Record per-run values. Store this run's own cost (not the running
    # average) so that np.std(Cost1List) is the spread of per-run costs,
    # matching how N1List stores per-run pull counts.
    N1List.append(N1)
    Cost1List.append(Attack1)

    print(counter)

# Population standard deviation across runs of arm-1 pull counts and of the
# values collected in Cost1List.
Std1 = np.std(N1List)
StdCost = np.std(Cost1List)

# Text summary of the simulation (nothing is plotted here). Each line is
# written as "<label><value>" with no separator.
print()
print("Cost = ", Cost, sep="")
print("Target: ", AverageNstar, sep="")
print("Sts_Times: ", Std1, sep="")
print("Sts_Cost: ", StdCost, sep="")
print("mean: ", U1ave, sep="")
