from Environment import *
from Algorithms import *
import numpy as np
import matplotlib.pyplot as plt
import pickle
from pathlib import Path
import matplotlib.gridspec as gridspec
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
from matplotlib.ticker import FixedLocator, FuncFormatter

def _tick_positions(count, half_width=0.15):
    """Return ``count`` evenly spaced x positions in ``[-half_width, half_width]``.

    A single position is centred at 0. For ``count == 3`` this yields
    ``[-0.15, 0, 0.15]``, the positions previously hard-coded in ``plot``.
    """
    if count == 1:
        return np.zeros(1)
    return np.linspace(-half_width, half_width, count)


def _load_ratios(algorithms, n, d, repeat, noise_std):
    """Return ``{alg: summed alg reward / summed oracle reward}`` over ``repeat`` runs.

    For run ``i`` the files
    ``./result/{alg}n{n}d{d}repeat{i}noise_std{noise_std}alg.txt`` and
    ``...oracle.txt`` are unpickled, and element ``[0]`` of each object is
    accumulated (presumably a scalar total reward -- confirm against the
    code that writes these files). The ratio is taken of the sums, not the
    mean of per-run ratios.
    """
    ratios = {}
    for alg in algorithms:
        exp_total = 0.0
        oracle_total = 0.0
        for i in range(repeat):
            stem = f'./result/{alg}n{n}d{d}repeat{i}noise_std{noise_std}'
            # NOTE: pickle.load can execute arbitrary code; these files are
            # assumed to be this project's own experiment outputs.
            with open(f'{stem}alg.txt', "rb") as f:
                exp_total += pickle.load(f)[0]
            with open(f'{stem}oracle.txt', "rb") as f:
                oracle_total += pickle.load(f)[0]
        ratios[alg] = exp_total / oracle_total
    return ratios


def plot(n, repeat, d, noise_std=0.1):
    """Plot each algorithm's competitive ratio against ``1 - 1/e`` and save a PDF.

    The ratio for an algorithm is its reward summed over ``repeat`` runs
    divided by the oracle reward summed over the same runs (see
    ``_load_ratios``).

    Args:
        n: ``n`` value embedded in the result/plot file names.
        repeat: number of runs (``repeat0`` .. ``repeat{repeat-1}``) to aggregate.
        d: ``d`` value embedded in the result/plot file names.
        noise_std: noise level embedded in the result/plot file names.

    The figure is written to
    ``./plot/n{n}d{d}repeat{repeat}noise_std{noise_std}_iid.pdf``.
    """
    algorithms = ['ETD-LCBT(iid)', 'ε-Greedy-LCBT', 'Gusein-Zade']
    ratio = _load_ratios(algorithms, n, d, repeat, noise_std)

    # Plot all algorithms' ratios in one figure.
    fig, ax = plt.subplots(figsize=(10, 6))
    colors = ['gray', 'royalblue', 'gold', 'green']
    markers = ['P', '<', 'v', 'o']
    # One x position per algorithm, so the tick and label counts always match.
    x = _tick_positions(len(algorithms))

    for idx, alg in enumerate(algorithms):
        ax.plot(
            x[idx], ratio[alg],
            # Cycle styles so more algorithms than styles does not IndexError.
            marker=markers[idx % len(markers)],
            color=colors[idx % len(colors)],
            label=alg, markersize=14, linestyle='None', zorder=10,
        )

    ax.set_xticks(x)
    ax.set_xticklabels(algorithms, fontsize=18)

    ax.set_xlim(-0.2, 0.2)  # can be narrowed further if needed
    ax.margins(x=0)         # remove automatic x padding

    ref_value = 1 - 1 / np.e
    ref_label = r'$1-1/e$'  # rendered as math text

    # Dashed orange reference line, drawn once and shown in the legend.
    ax.axhline(
        ref_value,
        color='orange',
        linestyle='--',
        linewidth=1.5,
        label=r'$1 - 1/e$',
    )

    ax.set_ylim(0, 2)

    # Add ref_value to the default y ticks if it is not already one of them.
    yticks = ax.get_yticks()
    if not np.any(np.isclose(yticks, ref_value)):
        yticks = np.sort(np.append(yticks, ref_value))

    # Fixed tick positions plus a formatter that labels only ref_value as "1-1/e".
    ax.yaxis.set_major_locator(FixedLocator(yticks))

    def custom_formatter(v, pos, rv=ref_value, rl=ref_label):
        return rl if np.isclose(v, rv) else f"{v:.2f}"

    ax.yaxis.set_major_formatter(FuncFormatter(custom_formatter))

    # Materialize the tick labels, then make only the ref_value label orange/bold.
    # Each label's own y position is used, so ticks that fall outside the view
    # limits cannot misalign labels and values.
    fig.canvas.draw()
    for lbl in ax.get_yticklabels():
        if np.isclose(lbl.get_position()[1], ref_value):
            lbl.set_color('orange')
            lbl.set_fontweight('bold')

    ax.set_ylabel(r'$\mathrm{Ratio}$', fontsize=18)
    ax.set_title(f'Competitive Ratio of Algorithms (n={n}, d={d})', fontsize=20)
    ax.legend(fontsize=18)
    fig.tight_layout()

    # Ensure the output directory exists even when plot() is called directly.
    Path("./plot").mkdir(parents=True, exist_ok=True)
    fig.savefig(f'./plot/n{n}d{d}repeat{repeat}noise_std{noise_std}_iid.pdf', bbox_inches="tight")
    # Release the figure so repeated calls do not accumulate open figures.
    plt.close(fig)



if __name__ == '__main__':
    # Experiment configuration; plot() reads the matching files under ./result/.
    n, d, repeat = 100000, 2, 10
    noise_std = 0.8
    # The PDF is saved under ./plot/, so make sure that directory exists.
    Path("./plot").mkdir(exist_ok=True, parents=True)
    plot(n=n, repeat=repeat, d=d, noise_std=noise_std)
