import numpy as np
import math
import matplotlib.pyplot as plt
# --- Experiment constants -------------------------------------------------
horizon = 500000  # number of time slots; one job index per slot
N = 10            # number of job types
L_max = 20        # processing lengths are drawn from 1..L_max inclusive
# Module-level confidence-bound arrays; explore_then_schedule builds its own
# local UCB/LCB, and no function visible in this file reads these two.
UCB = np.ones(N) * 10000
LCB = np.ones(N) * 10000
# Mean reward of each job type, indexed by type id.
mu = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0]

# --- Random problem instance ----------------------------------------------
# Drawn in bulk instead of one scalar RNG call per slot. The arrays stay
# float64 because the schedulers cast individual entries with int().
b = np.arange(horizon, dtype=float)                                # b[i] = i
ell = np.random.randint(1, L_max + 1, size=horizon).astype(float)  # length of job i
typen = np.random.randint(N, size=horizon).astype(float)           # type id of job i

def offline_optimal(horizon, b, ell, typen, rewards=None):
    """Compute the offline benchmark curve by backward dynamic programming.

    For every slot ``t`` the recursion keeps the best reward obtainable from
    ``t`` onward: either skip job ``t`` (inherit the value of slot ``t + 1``)
    or run it to completion and continue from slot ``b[t] + ell[t]``. A job
    is only eligible when ``b[t] + ell[t] < horizon``.

    Args:
        horizon: Number of time slots / jobs.
        b: ``b[t]`` is the start slot of job ``t`` (the script uses b[t] = t).
        ell: ``ell[t]`` is the processing length of job ``t``.
        typen: ``typen[t]`` is the type id of job ``t``, used to index
            ``rewards``.
        rewards: Optional per-type reward table. Defaults to the module-level
            ``mu`` so existing four-argument calls behave as before.

    Returns:
        A float array of length ``horizon`` whose entry ``t`` equals the
        optimal total reward minus the best reward obtainable from slot
        ``t`` onward. Empty when ``horizon`` is 0.
    """
    if rewards is None:
        rewards = mu  # fall back to the module-level reward table
    opt = np.zeros(horizon)
    if horizon == 0:
        # opt[0] does not exist for an empty horizon; nothing to schedule.
        return opt
    for t in range(horizon - 1, -1, -1):
        # Option 1: skip job t and inherit the value of the next slot.
        if t < horizon - 1:
            opt[t] = opt[t + 1]
        # Option 2: run job t to completion, if it finishes inside the horizon.
        finish = b[t] + ell[t]
        if finish < horizon:
            opt[t] = max(opt[t], rewards[int(typen[t])] + opt[int(finish)])
    # Turn value-to-go into "optimal total minus value-to-go" in one pass.
    return opt[0] - opt

# Smoke run on the instance generated above: compute and print the array
# returned by offline_optimal.
opt = offline_optimal(horizon, b, ell, typen)
print(opt)

def MRDF(horizon, b, ell, typen, rewards=None):
    """Simulate the online policy that keeps the job with the higher reward density.

    At every slot ``t`` job ``t`` becomes available. The policy holds at most
    one job ``M``. If the held job finishes exactly at ``t``
    (``b[M] + ell[M] == t``) its reward is collected. Otherwise the newcomer
    replaces the held job when its reward per processing slot exceeds the
    held job's reward per *remaining* slot; the replaced job is dropped and
    never resumed.

    Args:
        horizon: Number of time slots / jobs.
        b: ``b[i]`` is the start slot of job ``i`` (the script uses b[i] = i).
        ell: ``ell[i]`` is the processing length of job ``i``.
        typen: ``typen[i]`` is the type id of job ``i``, used to index
            ``rewards``.
        rewards: Optional per-type reward table. Defaults to the module-level
            ``mu`` so existing four-argument calls behave as before.

    Returns:
        Float array of length ``horizon``; entry ``t`` is the total reward of
        all jobs completed at or before slot ``t``.
    """
    if rewards is None:
        rewards = mu  # fall back to the module-level reward table
    M = -1        # index of the job currently held, -1 when idle
    sum_r = 0.0   # running total of collected reward
    cumu_r = np.zeros(horizon)
    for t in range(horizon):
        # Collect the reward of a job that completes exactly now.
        if M != -1 and b[M] + ell[M] == t:
            sum_r += rewards[int(typen[M])]
            M = -1
        # Written from the running total, so slot 0 never reads cumu_r[-1].
        cumu_r[t] = sum_r
        if M == -1:
            M = t
            continue
        # ell[M] - t + b[M] is the remaining length of the held job; it is
        # positive here because completion at t was handled above.
        new_density = rewards[int(typen[t])] / ell[t]
        held_density = rewards[int(typen[M])] / (ell[M] - t + b[M])
        if new_density > held_density:
            M = t
    return cumu_r

# Smoke run of MRDF on the same instance. Despite its name, sum_r receives
# the whole per-slot cumulative-reward array that MRDF returns.
sum_r = MRDF(horizon, b, ell, typen)
print(sum_r)

def max_reward(horizon, b, ell, typen, rewards=None):
    """Simulate the online policy that always keeps the job with the larger reward.

    At every slot ``t`` job ``t`` becomes available. The policy holds at most
    one job ``M``. If the held job finishes exactly at ``t``
    (``b[M] + ell[M] == t``) its reward is collected. Otherwise the newcomer
    replaces the held job when its type's reward is strictly larger; the
    replaced job is dropped and never resumed.

    Args:
        horizon: Number of time slots / jobs.
        b: ``b[i]`` is the start slot of job ``i`` (the script uses b[i] = i).
        ell: ``ell[i]`` is the processing length of job ``i``.
        typen: ``typen[i]`` is the type id of job ``i``, used to index
            ``rewards``.
        rewards: Optional per-type reward table. Defaults to the module-level
            ``mu`` so existing four-argument calls behave as before.

    Returns:
        Float array of length ``horizon``; entry ``t`` is the total reward of
        all jobs completed at or before slot ``t``.
    """
    if rewards is None:
        rewards = mu  # fall back to the module-level reward table
    M = -1        # index of the job currently held, -1 when idle
    sum_r = 0.0   # running total of collected reward
    cumu_r = np.zeros(horizon)
    for t in range(horizon):
        # Collect the reward of a job that completes exactly now.
        if M != -1 and b[M] + ell[M] == t:
            sum_r += rewards[int(typen[M])]
            M = -1
        # Written from the running total, so slot 0 never reads cumu_r[-1].
        cumu_r[t] = sum_r
        # Take the newcomer when idle, or when it pays strictly more.
        if M == -1 or rewards[int(typen[t])] > rewards[int(typen[M])]:
            M = t
    return cumu_r

# Smoke run of max_reward on the same instance; sum_r is overwritten with the
# per-slot cumulative-reward array that max_reward returns.
sum_r = max_reward(horizon, b, ell, typen)
print(sum_r)

def min_processing_time(horizon, b, ell, typen, rewards=None):
    """Simulate the online policy that keeps the job with the earlier finish slot.

    At every slot ``t`` job ``t`` becomes available. The policy holds at most
    one job ``M``. If the held job finishes exactly at ``t``
    (``b[M] + ell[M] == t``) its reward is collected. Otherwise the newcomer
    replaces the held job when ``b[t] + ell[t] < b[M] + ell[M]``; the
    replaced job is dropped and never resumed.

    Args:
        horizon: Number of time slots / jobs.
        b: ``b[i]`` is the start slot of job ``i`` (the script uses b[i] = i).
        ell: ``ell[i]`` is the processing length of job ``i``.
        typen: ``typen[i]`` is the type id of job ``i``, used to index
            ``rewards``.
        rewards: Optional per-type reward table. Defaults to the module-level
            ``mu`` so existing four-argument calls behave as before.

    Returns:
        Float array of length ``horizon``; entry ``t`` is the total reward of
        all jobs completed at or before slot ``t``.
    """
    if rewards is None:
        rewards = mu  # fall back to the module-level reward table
    M = -1        # index of the job currently held, -1 when idle
    sum_r = 0.0   # running total of collected reward
    cumu_r = np.zeros(horizon)
    for t in range(horizon):
        # Collect the reward of a job that completes exactly now.
        if M != -1 and b[M] + ell[M] == t:
            sum_r += rewards[int(typen[M])]
            M = -1
        # Written from the running total, so slot 0 never reads cumu_r[-1].
        cumu_r[t] = sum_r
        # Take the newcomer when idle, or when it would finish strictly earlier.
        if M == -1 or b[t] + ell[t] < b[M] + ell[M]:
            M = t
    return cumu_r

# Smoke run of min_processing_time on the same instance; sum_r is overwritten
# with the per-slot cumulative-reward array that the function returns.
sum_r = min_processing_time(horizon, b, ell, typen)
print(sum_r)


def explore_then_schedule(horizon, b, ell, typen, rewards=None):
    """Simulate a scheduler that learns per-type rewards with confidence bounds.

    At every slot ``t`` job ``t`` becomes available and at most one job is
    held. When the held job finishes exactly at ``t`` a noisy reward is drawn
    with ``np.random.normal(rewards[type], 1)``; the type's sample count,
    running mean and UCB/LCB (radius ``sqrt(6 * log(T) / T)``) are updated.
    While a job is held, the newcomer is compared with it on confidence
    bounds divided by (remaining) processing length:

    * held LCB density > newcomer UCB density: keep the held job;
    * newcomer LCB density > held UCB density: switch to the newcomer;
    * otherwise: switch only if the newcomer's type has fewer samples, and
      then block further switching until that job completes.

    Args:
        horizon: Number of time slots / jobs.
        b: ``b[i]`` is the start slot of job ``i`` (the script uses b[i] = i).
        ell: ``ell[i]`` is the processing length of job ``i``.
        typen: ``typen[i]`` is the type id of job ``i``, used to index
            ``rewards``.
        rewards: Optional per-type mean-reward table. Defaults to the
            module-level ``mu``; the number of types is ``len(rewards)``.

    Returns:
        Float array of length ``horizon``; entry ``t`` is the sum of the
        sampled rewards of all jobs completed at or before slot ``t``.
    """
    if rewards is None:
        rewards = mu  # fall back to the module-level reward table
    n_types = len(rewards)
    ucb = np.ones(n_types) * 10000
    # NOTE(review): LCB starts at +10000, the same value as UCB, so an
    # unsampled type has a zero-width interval at 10000 rather than a wide
    # one. Kept as in the existing experiments; confirm this is intended.
    lcb = np.ones(n_types) * 10000
    pulls = np.zeros(n_types)     # number of completed jobs per type
    hat_mu = np.zeros(n_types)    # running mean of sampled rewards per type
    may_switch = True             # False while committed to an exploration job
    current = -1                  # index of the held job, -1 when idle
    total = 0.0                   # running total of sampled reward
    cumu_r = np.zeros(horizon)
    for t in range(horizon):
        # Completion: sample a reward and refresh the statistics of its type.
        if current != -1 and b[current] + ell[current] == t:
            k = int(typen[current])
            r = np.random.normal(rewards[k], 1)
            total += r
            pulls[k] += 1
            hat_mu[k] = (hat_mu[k] * (pulls[k] - 1) + r) / pulls[k]
            radius = math.sqrt(6 * math.log(pulls[k]) / pulls[k])
            ucb[k] = hat_mu[k] + radius
            lcb[k] = hat_mu[k] - radius
            current = -1
            may_switch = True
        # Written from the running total, so slot 0 never reads cumu_r[-1].
        cumu_r[t] = total
        if not may_switch:
            continue
        if current == -1:
            current = t
            continue

        held_type = int(typen[current])
        new_type = int(typen[t])
        # Remaining length of the held job; positive because completion at t
        # was handled above.
        remaining = ell[current] - t + b[current]
        u_held = ucb[held_type] / remaining
        l_held = lcb[held_type] / remaining
        u_new = ucb[new_type] / ell[t]
        l_new = lcb[new_type] / ell[t]

        if l_held > u_new:
            continue  # held job is confidently better
        if l_new > u_held:
            current = t  # newcomer is confidently better
        elif pulls[held_type] > pulls[new_type]:
            # Intervals overlap: explore the less-sampled type and commit.
            current = t
            may_switch = False
    return cumu_r

# Smoke run of explore_then_schedule on the same instance; sum_r is overwritten
# with the per-slot cumulative-reward array that the function returns.
sum_r = explore_then_schedule(horizon, b, ell, typen)
print(sum_r)

# --- Repeated experiment ---------------------------------------------------
# Each epoch draws a fresh random instance and records the per-slot curve of
# the offline benchmark and of the three deterministic online policies.
epochs = 20
co, cm, cr, cp = [], [], [], []
for epoch in range(epochs):
    print(epoch)
    # horizon, N, L_max and mu keep the values set at the top of the file;
    # they are not re-assigned per epoch, so a change there applies here too.
    b = np.arange(horizon, dtype=float)
    ell = np.random.randint(1, L_max + 1, size=horizon).astype(float)
    typen = np.random.randint(N, size=horizon).astype(float)

    opt = offline_optimal(horizon, b, ell, typen)
    co.append(opt)
    sum_m = MRDF(horizon, b, ell, typen)
    cm.append(sum_m)
    sum_r = max_reward(horizon, b, ell, typen)
    cr.append(sum_r)
    sum_p = min_processing_time(horizon, b, ell, typen)
    cp.append(sum_p)

# Stack the runs into (epochs, horizon) arrays.
cm = np.array(cm)
co = np.array(co)
cr = np.array(cr)
cp = np.array(cp)

# Per-slot spread across epochs. Despite the *_variance names these are
# standard deviations (np.std).
o_variance = np.std(co, axis=0)
m_variance = np.std(cm, axis=0)
r_variance = np.std(cr, axis=0)
p_variance = np.std(cp, axis=0)

# Per-slot mean across epochs.
cO = co.mean(axis=0)
cM = cm.mean(axis=0)
cR = cr.mean(axis=0)
cP = cp.mean(axis=0)

# y0, yLim = 0, cO[-1]+200

# plt.figure()
# x = np.linspace(1, horizon, horizon)
# plt.grid(True)
# plt.xlabel("time")
# plt.ylabel("reward")
# plt.ylim((y0, yLim))
# plt.xlim((-1, horizon+1))
# plt.plot(x, cO, label = 'OPT', color="red", linewidth="2")
# plt.plot(x, cO - o_variance, color='pink', lw=0.8)
# plt.plot(x, cO + o_variance, color='pink', lw=0.8)
# plt.fill_between(x, cO - o_variance, cO + o_variance, alpha=0.25, color='pink')

# plt.plot(x, cM, label = 'MRDF', color="blue", linewidth="2")
# plt.plot(x, cM - m_variance, color='skyblue', lw=0.8)
# plt.plot(x, cM + m_variance, color='skyblue', lw=0.8)
# plt.fill_between(x, cM - m_variance, cM + m_variance, alpha=0.25, color='skyblue')

# plt.plot(x, cR, label = 'MR', color="green", linewidth="2")
# plt.plot(x, cR + r_variance, color='lightgreen', lw=0.8)
# plt.plot(x, cR - r_variance, color='lightgreen', lw=0.8)
# plt.fill_between(x, cR - r_variance, cR + r_variance, alpha=0.25, color='lightgreen')

# plt.plot(x, cP, label = 'SRPT', color="darkorange", linewidth="2")
# plt.plot(x, cP + p_variance, color='orange', lw=0.8)
# plt.plot(x, cP - p_variance, color='orange', lw=0.8)
# plt.fill_between(x, cP - p_variance, cP + p_variance, alpha=0.25, color='orange')

# plt.xlabel("Time", fontsize=20)
# plt.ylabel("Reward", fontsize=20)
# plt.ticklabel_format(style='sci', axis='both', scilimits=(0, 0))
# plt.legend(fontsize=12)
# plt.xticks(fontsize=14)
# plt.yticks(fontsize=14)
# plt.savefig("fig_bar.pdf")
# plt.show()


# Time range covered by the averaged curves, plus the y-limits kept for the
# (commented-out) curve plot above.
horizon = len(cO)
y0 = 0
yLim = cO[-1] + 200

# 1. Bar-chart data: end-of-horizon mean and spread for each policy.
methods = ['OPT', 'MRDF', 'MR', 'SRPT']
final_values = [curve[-1] for curve in (cO, cM, cR, cP)]
variance_values = [spread[-1] for spread in (o_variance, m_variance, r_variance, p_variance)]
colors = ['#FF6B6B', '#4D96FF', '#6BCB77', '#FFA34D']

# Global font sizes; must be set before the figure is created.
plt.rcParams.update({
    'font.size': 14,          # default font size
    'axes.titlesize': 16,     # title font size
    'axes.labelsize': 15,     # axis-label font size
    'xtick.labelsize': 14,    # x tick-label font size
    'ytick.labelsize': 13     # y tick-label font size
})

fig, ax = plt.subplots(figsize=(10, 6))

# 2. Thin bars (width 0.35), one per policy.
bar_width = 0.35
x_pos = np.arange(len(methods))
bars = ax.bar(
    x_pos,
    final_values,
    width=bar_width,
    color=colors,
    edgecolor='black',
    linewidth=1.2,
    zorder=3,
)

# 3. Error bars with enlarged caps, drawn above the bars.
ax.errorbar(
    x_pos,
    final_values,
    yerr=variance_values,
    fmt='none',
    ecolor='black',
    elinewidth=1.5,
    capsize=8,
    capthick=1.5,
    zorder=4,
)

# 4. Bold tick labels, axis label and title.
ax.set_xticks(x_pos)
ax.set_xticklabels(methods, weight='bold')
ax.set_ylabel('Final Reward Value', weight='bold', fontsize=16)
ax.set_title('Comparison of Final Reward Values', weight='bold', pad=20, fontsize=18)

# 5. Horizontal grid lines behind the bars.
ax.grid(axis='y', linestyle=':', alpha=0.6, zorder=0)

# 6. Print "value ± spread" above each bar, clear of the error-bar cap.
label_gap = 0.05 * max(final_values)
for bar, var in zip(bars, variance_values):
    height = bar.get_height()
    ax.text(
        bar.get_x() + bar.get_width() / 2.,
        height + var + label_gap,
        f'{height:.1f} ± {var:.1f}',
        ha='center',
        va='bottom',
        fontsize=13,
        weight='bold',
    )

# 7. Leave head-room above the tallest bar for the labels.
ax.set_ylim(0, max(final_values) * 1.25)

# 8. Tighten the layout, save and display.
fig.tight_layout()
fig.savefig('final_values_comparison.png', dpi=300, bbox_inches='tight')
plt.show()