import numpy as np
import math
import matplotlib.pyplot as plt
# Experiment configuration.
horizon = 500000             # number of time slots; one job arrives per slot
N = 10                       # number of job types
L_max = 30                   # job lengths are drawn uniformly from 1..L_max
W = 5                        # NOTE(review): not read anywhere in this file
# NOTE(review): these two globals are not read by the policies below; each
# learning policy builds its own per-type confidence-bound arrays.
UCB = np.ones(N) * 10000
LCB = np.ones(N) * 10000
# Mean reward of each job type (index = type); the learning policies observe
# Gaussian samples around these values.
mu = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0]

# Job instance: job i arrives at slot b[i] == i, needs ell[i] slots and has
# type typen[i]. Arrays are float64 holding integer values. Drawn in one
# vectorised call each instead of 2 * horizon scalar RNG calls.
b = np.arange(horizon, dtype=float)
ell = np.random.randint(1, L_max + 1, size=horizon).astype(float)
typen = np.random.randint(N, size=horizon).astype(float)

def offline_optimal(horizon, b, ell, typen, rewards=None):
    """Return the offline-optimal reward curve for a job instance.

    Runs a backward dynamic program: ``value[t]`` is the best total reward
    obtainable from jobs with index >= ``t``, where accepting job ``t`` earns
    ``rewards[typen[t]]`` and occupies the server until slot
    ``b[t] + ell[t]``. Jobs that would finish at or after ``horizon`` are
    never accepted.

    Args:
        horizon: Number of slots.
        b, ell, typen: Per-job arrival slot, length and type index (numeric
            arrays holding integer values, at least ``horizon`` long).
        rewards: Reward per type; defaults to the module-level ``mu``.

    Returns:
        Array of length ``horizon`` with entry ``t`` equal to
        ``value[0] - value[t]``. It is non-decreasing in ``t``, starts at 0 and
        ends at the optimal total ``value[0]``. Empty when ``horizon == 0``.
    """
    if rewards is None:
        rewards = mu
    value = np.zeros(horizon)
    if horizon == 0:
        # Nothing to schedule; avoid indexing value[0] on an empty array.
        return value
    for t in range(horizon - 1, -1, -1):
        if t < horizon - 1:
            value[t] = value[t + 1]          # option 1: skip job t
        finish = b[t] + ell[t]
        if finish < horizon:                 # option 2: run job t to completion
            value[t] = max(value[t], rewards[int(typen[t])] + value[int(finish)])
    return value[0] - value

# Smoke run of the offline DP on the instance drawn above; printed for inspection.
opt = offline_optimal(horizon=horizon, b=b, ell=ell, typen=typen)
print(opt)

def MRDF(horizon, b, ell, typen, rewards=None):
    """Simulate a greedy preemptive scheduler keyed on reward per remaining slot.

    One job arrives per slot; job ``t`` needs ``ell[t]`` slots and has known
    mean reward ``rewards[typen[t]]``. The single server holds at most one
    job. When job ``t`` arrives while job ``current`` is in service, the
    scheduler switches to ``t`` if::

        rewards[typen[t]] / ell[t]  >  rewards[typen[current]] / remaining

    where ``remaining`` is the number of slots ``current`` still needs. A
    preempted job is dropped and never resumed.

    Args:
        horizon: Number of slots to simulate.
        b, ell, typen: Per-job arrival slot, length and type index. Assumes
            ``b[i] == i`` (job ``i`` arrives at slot ``i``), as in this script.
        rewards: Mean reward per type; defaults to the module-level ``mu``.

    Returns:
        Array of length ``horizon``; entry ``t`` is the total reward of jobs
        completed at or before slot ``t``.
    """
    if rewards is None:
        rewards = mu
    cumu_r = np.zeros(horizon)
    collected = 0.0
    current = -1  # index of the job in service, -1 when the server is idle
    for t in range(horizon):
        # The job in service started at its arrival slot, so it completes
        # exactly at b[current] + ell[current].
        if current != -1 and b[current] + ell[current] == t:
            collected += rewards[int(typen[current])]
            current = -1
        cumu_r[t] = collected
        if current == -1:
            current = t
            continue
        remaining = b[current] + ell[current] - t
        if rewards[int(typen[t])] / ell[t] > rewards[int(typen[current])] / remaining:
            current = t
    return cumu_r

# Smoke run of the known-mean MRDF policy; printed for inspection.
sum_r = MRDF(horizon=horizon, b=b, ell=ell, typen=typen)
print(sum_r)

def max_reward(horizon, b, ell, typen, rewards=None):
    """Simulate a greedy preemptive scheduler keyed on per-job reward alone.

    One job arrives per slot. When job ``t`` arrives while job ``current`` is
    in service, the scheduler switches to ``t`` if its type's reward is
    strictly higher (``rewards[typen[t]] > rewards[typen[current]]``),
    regardless of lengths. A preempted job is dropped and never resumed.

    Args:
        horizon: Number of slots to simulate.
        b, ell, typen: Per-job arrival slot, length and type index. Assumes
            ``b[i] == i`` (job ``i`` arrives at slot ``i``), as in this script.
        rewards: Mean reward per type; defaults to the module-level ``mu``.

    Returns:
        Array of length ``horizon``; entry ``t`` is the total reward of jobs
        completed at or before slot ``t``.
    """
    if rewards is None:
        rewards = mu
    cumu_r = np.zeros(horizon)
    collected = 0.0
    current = -1  # index of the job in service, -1 when the server is idle
    for t in range(horizon):
        if current != -1 and b[current] + ell[current] == t:
            collected += rewards[int(typen[current])]
            current = -1
        cumu_r[t] = collected
        if current == -1:
            current = t
        elif rewards[int(typen[t])] > rewards[int(typen[current])]:
            current = t
    return cumu_r

# Smoke run of the reward-greedy baseline; printed for inspection.
sum_r = max_reward(horizon=horizon, b=b, ell=ell, typen=typen)
print(sum_r)

def min_processing_time(horizon, b, ell, typen, rewards=None):
    """Simulate a greedy preemptive scheduler keyed on earliest completion.

    One job arrives per slot. When job ``t`` arrives while job ``current`` is
    in service, the scheduler switches to ``t`` if ``t`` would finish strictly
    earlier (``b[t] + ell[t] < b[current] + ell[current]``). Rewards do not
    affect decisions. A preempted job is dropped and never resumed.

    Args:
        horizon: Number of slots to simulate.
        b, ell, typen: Per-job arrival slot, length and type index. Assumes
            ``b[i] == i`` (job ``i`` arrives at slot ``i``), as in this script.
        rewards: Mean reward per type; defaults to the module-level ``mu``.

    Returns:
        Array of length ``horizon``; entry ``t`` is the total reward of jobs
        completed at or before slot ``t``.
    """
    if rewards is None:
        rewards = mu
    cumu_r = np.zeros(horizon)
    collected = 0.0
    current = -1  # index of the job in service, -1 when the server is idle
    for t in range(horizon):
        if current != -1 and b[current] + ell[current] == t:
            collected += rewards[int(typen[current])]
            current = -1
        cumu_r[t] = collected
        if current == -1:
            current = t
        elif b[t] + ell[t] < b[current] + ell[current]:
            current = t
    return cumu_r

# Smoke run of the earliest-completion baseline; printed for inspection.
sum_r = min_processing_time(horizon=horizon, b=b, ell=ell, typen=typen)
print(sum_r)


def explore_then_schedule(horizon, b, ell, typen, rewards=None, noise_std=2,
                          switch_threshold=100):
    """Preemptive scheduler that learns per-type rewards via confidence intervals.

    Each completed job yields a Gaussian sample (mean ``rewards[type]``, std
    ``noise_std``). The sample updates that type's running mean ``hat_mu`` and
    its interval ``hat_mu +/- sqrt(6 ln n / n)``, where ``n`` is the type's
    sample count. Types never sampled start with both bounds at 10000.

    When job ``t`` arrives while job ``current`` is in service, the
    reward-per-slot intervals are compared. The in-service job uses
    ``remaining`` slots; the newcomer uses ``ell[t]``.

    * In-service lower bound above newcomer upper bound: keep ``current``.
    * Newcomer lower bound above in-service upper bound: switch to ``t``.
    * Otherwise, switch only if the newcomer's type has fewer samples than
      the in-service type. After such a switch, ignore every arrival until
      that job completes.

    Once both types have more than ``switch_threshold`` samples, the
    comparison uses the true means instead. A preempted job is dropped.

    Args:
        horizon: Number of slots to simulate.
        b, ell, typen: Per-job arrival slot, length and type index. Assumes
            ``b[i] == i`` (job ``i`` arrives at slot ``i``), as in this script.
        rewards: True mean reward per type; defaults to the module-level
            ``mu``. Its length sets the number of types.
        noise_std: Standard deviation of the observed reward samples.
        switch_threshold: Per-type sample count above which true means are used.

    Returns:
        Array of length ``horizon``. Entry ``t`` is the sum of TRUE mean
        rewards (not the noisy samples) of jobs completed at or before slot ``t``.
    """
    if rewards is None:
        rewards = mu
    n_types = len(rewards)
    ucb = np.ones(n_types) * 10000   # optimistic start for unsampled types
    lcb = np.ones(n_types) * 10000
    counts = np.zeros(n_types)       # reward samples observed per type
    hat_mu = np.zeros(n_types)       # running sample mean per type
    may_preempt = True               # False while an exploration job is locked in
    current = -1                     # index of the job in service, -1 when idle
    collected = 0.0
    cumu_r = np.zeros(horizon)
    for t in range(horizon):
        if current != -1 and b[current] + ell[current] == t:
            k = int(typen[current])
            sample = np.random.normal(rewards[k], noise_std)
            collected += rewards[k]
            counts[k] += 1
            hat_mu[k] = (hat_mu[k] * (counts[k] - 1) + sample) / counts[k]
            # NOTE(review): the radius is 0 when counts[k] == 1 (ln 1 = 0), so
            # after one sample the interval collapses to that noisy sample.
            radius = math.sqrt(6 * math.log(counts[k]) / counts[k])
            ucb[k] = hat_mu[k] + radius
            lcb[k] = hat_mu[k] - radius
            current = -1
            may_preempt = True
        cumu_r[t] = collected
        if not may_preempt:
            continue
        if current == -1:
            current = t
            continue
        k_cur = int(typen[current])
        k_new = int(typen[t])
        remaining = b[current] + ell[current] - t
        if counts[k_cur] > switch_threshold and counts[k_new] > switch_threshold:
            # NOTE(review): compares the true means, not the estimates — confirm intended.
            if rewards[k_new] / ell[t] > rewards[k_cur] / remaining:
                current = t
            continue
        upper_cur, lower_cur = ucb[k_cur] / remaining, lcb[k_cur] / remaining
        upper_new, lower_new = ucb[k_new] / ell[t], lcb[k_new] / ell[t]
        if lower_cur > upper_new:
            continue                      # current job is confidently better
        if lower_new > upper_cur:
            current = t                   # newcomer is confidently better
        elif counts[k_cur] > counts[k_new]:
            # Intervals overlap: serve the less-sampled type and lock it in.
            current = t
            may_preempt = False
    return cumu_r

def SUCB(horizon, b, ell, typen, rewards=None, noise_std=2, switch_threshold=1000):
    """Preemptive scheduler comparing upper confidence bounds per remaining slot.

    Each completed job yields a Gaussian sample (mean ``rewards[type]``, std
    ``noise_std``). The sample updates its type's running mean ``hat_mu`` and
    upper bound ``hat_mu + sqrt(6 ln n / n)``, where ``n`` is the type's
    sample count. Unsampled types keep an upper bound of 10000.

    When job ``t`` arrives while ``current`` is in service, ``current`` is
    kept only if ``ucb[type of current] / remaining`` strictly exceeds
    ``ucb[type of t] / ell[t]``. Otherwise the scheduler switches to ``t`` and
    drops ``current``. Once both types have more than ``switch_threshold``
    samples, the known-mean comparison is used instead.

    Args:
        horizon: Number of slots to simulate.
        b, ell, typen: Per-job arrival slot, length and type index. Assumes
            ``b[i] == i`` (job ``i`` arrives at slot ``i``), as in this script.
        rewards: True mean reward per type; defaults to the module-level
            ``mu``. Its length sets the number of types.
        noise_std: Standard deviation of the observed reward samples.
        switch_threshold: Per-type sample count above which true means are used.

    Returns:
        Array of length ``horizon``. Entry ``t`` is the sum of TRUE mean
        rewards of jobs completed at or before slot ``t``.
    """
    if rewards is None:
        rewards = mu
    n_types = len(rewards)
    ucb = np.ones(n_types) * 10000   # optimistic start for unsampled types
    counts = np.zeros(n_types)       # reward samples observed per type
    hat_mu = np.zeros(n_types)       # running sample mean per type
    cumu_r = np.zeros(horizon)
    collected = 0.0
    current = -1                     # index of the job in service, -1 when idle
    for t in range(horizon):
        if current != -1 and b[current] + ell[current] == t:
            k = int(typen[current])
            sample = np.random.normal(rewards[k], noise_std)
            collected += rewards[k]
            counts[k] += 1
            hat_mu[k] = (hat_mu[k] * (counts[k] - 1) + sample) / counts[k]
            # NOTE(review): the radius is 0 when counts[k] == 1 (ln 1 = 0), so
            # after one sample the bound equals that noisy sample.
            ucb[k] = hat_mu[k] + math.sqrt(6 * math.log(counts[k]) / counts[k])
            current = -1
        cumu_r[t] = collected
        if current == -1:
            current = t
            continue
        k_cur = int(typen[current])
        k_new = int(typen[t])
        remaining = b[current] + ell[current] - t
        if counts[k_cur] > switch_threshold and counts[k_new] > switch_threshold:
            # NOTE(review): compares the true means, not the estimates — confirm intended.
            if rewards[k_new] / ell[t] > rewards[k_cur] / remaining:
                current = t
            continue
        if not (ucb[k_cur] / remaining > ucb[k_new] / ell[t]):
            current = t
    return cumu_r

# Smoke run of the UCB-based learning policy; printed for inspection.
sum_r = SUCB(horizon=horizon, b=b, ell=ell, typen=typen)
print(sum_r)

def MRDF_est(horizon, b, ell, typen, rewards=None, noise_std=2, switch_threshold=1000):
    """MRDF-style preemptive scheduler driven by plug-in reward estimates.

    Each completed job yields a Gaussian sample (mean ``rewards[type]``, std
    ``noise_std``) that updates its type's running mean ``hat_mu``.
    Unsampled types have ``hat_mu == 0``. When job ``t`` arrives while
    ``current`` is in service, ``current`` is kept only if
    ``hat_mu[type of current] / remaining`` strictly exceeds
    ``hat_mu[type of t] / ell[t]``. Ties switch to the newcomer, which
    includes the case where both types are unsampled. Once both types have
    more than ``switch_threshold`` samples, the known-mean comparison is used
    instead. A preempted job is dropped.

    Args:
        horizon: Number of slots to simulate.
        b, ell, typen: Per-job arrival slot, length and type index. Assumes
            ``b[i] == i`` (job ``i`` arrives at slot ``i``), as in this script.
        rewards: True mean reward per type; defaults to the module-level
            ``mu``. Its length sets the number of types.
        noise_std: Standard deviation of the observed reward samples.
        switch_threshold: Per-type sample count above which true means are used.

    Returns:
        Array of length ``horizon``. Entry ``t`` is the sum of TRUE mean
        rewards of jobs completed at or before slot ``t``.
    """
    if rewards is None:
        rewards = mu
    n_types = len(rewards)
    counts = np.zeros(n_types)       # reward samples observed per type
    hat_mu = np.zeros(n_types)       # running sample mean per type
    cumu_r = np.zeros(horizon)
    collected = 0.0
    current = -1                     # index of the job in service, -1 when idle
    for t in range(horizon):
        if current != -1 and b[current] + ell[current] == t:
            k = int(typen[current])
            sample = np.random.normal(rewards[k], noise_std)
            collected += rewards[k]
            counts[k] += 1
            hat_mu[k] = (hat_mu[k] * (counts[k] - 1) + sample) / counts[k]
            current = -1
        cumu_r[t] = collected
        if current == -1:
            current = t
            continue
        k_cur = int(typen[current])
        k_new = int(typen[t])
        remaining = b[current] + ell[current] - t
        if counts[k_cur] > switch_threshold and counts[k_new] > switch_threshold:
            # NOTE(review): compares the true means, not the estimates — confirm intended.
            if rewards[k_new] / ell[t] > rewards[k_cur] / remaining:
                current = t
            continue
        if not (hat_mu[k_cur] / remaining > hat_mu[k_new] / ell[t]):
            current = t
    return cumu_r

# Smoke run of MRDF with estimated means; printed for inspection.
sum_t = MRDF_est(horizon=horizon, b=b, ell=ell, typen=typen)
print(sum_t)

# Repeated experiment: run every policy `epochs` times and record, per epoch,
# the OPT and MRDF reward curves plus the regret of each learning policy
# (MRDF with known means minus the learning policy's cumulative reward).
epochs = 20
co, cm, ce, cr, cs, ct = [], [], [], [], [], []

# NOTE(review): the job instance is drawn once, here, with the module-level
# L_max and N, and is reused by every epoch. As a result, offline_optimal and
# MRDF produce the same curve each epoch, and only the reward-sampling noise
# inside the learning policies varies. Move these three lines into the loop
# if a fresh instance per epoch is intended.
b = np.arange(horizon, dtype=float)
ell = np.random.randint(1, L_max + 1, size=horizon).astype(float)
typen = np.random.randint(N, size=horizon).astype(float)

for epoch in range(epochs):
    print(epoch)
    opt = offline_optimal(horizon, b, ell, typen)
    co.append(opt)
    sum_m = MRDF(horizon, b, ell, typen)
    cm.append(sum_m)
    sum_e = SUCB(horizon, b, ell, typen)
    ce.append(sum_e)
    sum_r = sum_m - sum_e
    cr.append(sum_r)
    sum_s = sum_m - explore_then_schedule(horizon, b, ell, typen)
    cs.append(sum_s)
    sum_t = sum_m - MRDF_est(horizon, b, ell, typen)
    ct.append(sum_t)

# Aggregate the per-epoch curves (one row per epoch) into a mean curve and a
# one-standard-deviation band per series, then plot the regret curves.
cm = np.array(cm)
co = np.array(co)
cr = np.array(cr)
ce = np.array(ce)
cs = np.array(cs)
ct = np.array(ct)

# Standard deviation across epochs (np.std), used as the band half-width.
o_std = np.std(co, axis=0)
m_std = np.std(cm, axis=0)
r_std = np.std(cr, axis=0)
e_std = np.std(ce, axis=0)
s_std = np.std(cs, axis=0)
t_std = np.std(ct, axis=0)

cO = co.mean(axis=0)
cM = cm.mean(axis=0)
cR = cr.mean(axis=0)
cE = ce.mean(axis=0)
cS = cs.mean(axis=0)
cT = ct.mean(axis=0)


def _plot_band(x, mean, spread, label, line_color, band_color):
    """Plot `mean` as a labelled line plus a shaded band at mean +/- spread."""
    plt.plot(x, mean, label=label, color=line_color, linewidth=2)
    plt.plot(x, mean + spread, color=band_color, lw=0.8)
    plt.plot(x, mean - spread, color=band_color, lw=0.8)
    plt.fill_between(x, mean - spread, mean + spread, alpha=0.25, color=band_color)


# NOTE(review): the upper y-limit comes from the final MRDF-with-estimation
# regret only, so the S-UCB / explore-then-schedule curves may be clipped.
y0, yLim = 0, cT[-1] + 200

plt.figure(dpi=600)
x = np.linspace(1, horizon, horizon)
plt.grid(True)
plt.ylim((y0, yLim))
plt.xlim((-1, horizon + 1))

# Cumulative-reward reference curves (not regret), disabled:
# _plot_band(x, cO, o_std, 'OPT', "red", 'pink')
# _plot_band(x, cM, m_std, 'MRDF', "blue", 'skyblue')

_plot_band(x, cR, r_std, 'S-UCB', "green", 'lightgreen')
_plot_band(x, cS, s_std, 'Explore then Scheduling', "yellow", 'lightyellow')
_plot_band(x, cT, t_std, 'MRDF with estimation', "blue", 'lightblue')

plt.xlabel("Time", fontsize=20)
plt.ylabel("Regret", fontsize=20)
plt.ticklabel_format(style='sci', axis='both', scilimits=(0, 0))
plt.legend(fontsize=12)
plt.xticks(fontsize=14)
plt.yticks(fontsize=14)
plt.savefig("regret_bar.png", dpi=600, bbox_inches='tight')
plt.show()