from pathlib import Path

import numpy as np
import matplotlib
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
from matplotlib.colors import LogNorm
from matplotlib.offsetbox import AnchoredOffsetbox, HPacker, TextArea

# ══════════════════════════════════════════════════════════
# Font & style setup — matches NeurIPS 2026 style file
# ══════════════════════════════════════════════════════════
USE_LATEX = True

# Options that only apply when text is rendered by a real LaTeX install:
# Times (ptm) as the roman family, plus the AMS packages and ulem.
_usetex_opts = (
    {
        "text.usetex": True,
        "text.latex.preamble": "".join([
            r"\renewcommand{\rmdefault}{ptm}",
            r"\usepackage{amsmath}",
            r"\usepackage{amsfonts}",
            r"\usepackage{amssymb}",
            r"\usepackage[normalem]{ulem}",
        ]),
    }
    if USE_LATEX
    else {}
)

matplotlib.rcParams.update({
    # Typeface: serif text; Computer Modern for matplotlib's own mathtext
    # renderer (mathtext is bypassed entirely when text.usetex is on).
    "font.family": "serif",
    "mathtext.fontset": "cm",
    **_usetex_opts,

    # Font sizes, in points.
    "font.size": 9,
    "axes.labelsize": 10,
    "axes.titlesize": 9,
    "xtick.labelsize": 8,
    "ytick.labelsize": 8,
    "legend.fontsize": 8,

    # Thin axis frame and ticks, ticks pointing outward.
    "axes.linewidth": 0.5,
    "xtick.major.width": 0.5,
    "ytick.major.width": 0.5,
    "xtick.direction": "out",
    "ytick.direction": "out",

    # On-screen rendering resolution.
    "figure.dpi": 200,
})

C1, C2 = '#3266ad', '#c0392b'   # curve colours for J1 / J2
nS, nA = 3, 2                   # number of MDP states / actions
TEXTWIDTH = 5.5                 # inches
LABELPAD = -6                   # negative pad pulls axis labels towards the ticks
DOT_LW = 1.2                    # linewidth of the dotted guide lines
PI_FONTSIZE = 10                # fontsize of the pi / pi' labels under the x-axis

# Width of the 3-panel figure, chosen so each panel is as wide as a panel of
# the companion 1x4 figure (figwidth 5.5in, left=0.06, right=0.98,
# wspace=0.15). With n panels and spacing wspace (a fraction of a panel
# width), the plot area spans n + (n-1)*wspace panel widths:
#   1x4: panel w = 0.92*5.5 / 4.45 ≈ 1.137in
#   1x3: figwidth = 3.3*w / 0.92 ≈ 4.08in
# NOTE(review): this only holds if the 3-panel GridSpec also uses
# wspace=0.15 and left/right=0.06/0.98; a larger wspace there makes each
# panel narrower than a 1x4 panel.
_REF_FIGWIDTH = 5.5
_REF_LEFT, _REF_RIGHT = 0.06, 0.98
_REF_WSPACE = 0.15
_REF_FRAC = _REF_RIGHT - _REF_LEFT          # fraction of width used by axes
_REF_PANEL_W = _REF_FRAC * _REF_FIGWIDTH / (4 + 3 * _REF_WSPACE)
# Rounded to 2 decimals to keep the historical value (4.08in).
TRIPANEL_WIDTH = round((3 + 2 * _REF_WSPACE) * _REF_PANEL_W / _REF_FRAC, 2)

# ══════════════════════════════════════════════════════════
# MDP setup
# ══════════════════════════════════════════════════════════

def val3(T, R, theta, gamma):
    """Return the uniform-start discounted value of a two-action mixed policy.

    The policy plays action 0 with probability ``theta`` and action 1 with
    probability ``1 - theta`` in every state. Its value vector is the exact
    solution of ``V = R_pi + gamma * P_pi @ V``. The result is the average of
    ``V`` over states, i.e. the value under a uniform initial-state
    distribution.

    Args:
        T: Transitions indexable as ``T[s][a][s']`` (nested lists or an array
            of shape ``(n_states, 2, n_states)``).
        R: Rewards indexable as ``R[s][a]`` (shape ``(n_states, 2)``).
        theta: Probability of choosing action 0.
        gamma: Discount factor; ``I - gamma * P_pi`` must be non-singular.

    Returns:
        ``mean_s V(s)`` as a Python float.

    Raises:
        ValueError: if ``T`` or ``R`` do not have the shapes above.
    """
    T = np.asarray(T, dtype=float)
    R = np.asarray(R, dtype=float)
    pi = np.array([theta, 1 - theta], dtype=float)
    if T.ndim != 3 or T.shape[1] != pi.size or T.shape[2] != T.shape[0]:
        raise ValueError(f"T must have shape (n_states, 2, n_states), got {T.shape}")
    if R.shape != T.shape[:2]:
        raise ValueError(f"R must have shape {T.shape[:2]}, got {R.shape}")
    n_states = T.shape[0]

    # Reductions over the (length-2) action axis add pi0*x0 + pi1*x1 in the
    # same order as an explicit loop, so results match a loop bit-for-bit.
    r_pi = (R * pi).sum(axis=1)                       # expected reward per state
    p_pi = (pi[None, :, None] * T).sum(axis=1)        # state->state kernel under pi
    V = np.linalg.solve(np.eye(n_states) - gamma * p_pi, r_pi)
    return float(np.ones(n_states) / n_states @ V)

def sweep(T, R, ps, gamma):
    """Evaluate ``val3`` for each mixing probability in ``ps``.

    Returns a 1-D float array with one policy value per entry of ``ps``.
    """
    return np.fromiter((val3(T, R, theta, gamma) for theta in ps), dtype=float)

def H_safe(eps, delta):
    """Compute ``H(eps, delta) = ((1+eps) + sqrt((1-eps)^2 + 4*eps/delta)) / 2``.

    Both inputs are floored at 1e-10 before use, so zero or negative values
    (and array grids that include them) produce finite results. Inputs may be
    scalars or broadcastable arrays.
    """
    floor = 1e-10
    e = np.maximum(eps, floor)
    d = np.maximum(delta, floor)
    root = np.sqrt((1 - e) ** 2 + 4 * e / d)
    return (1 + e + root) / 2

def colored_title(ax, segments, y=1.25, fontsize=9):
    """Draw a centred title above ``ax`` built from individually coloured pieces.

    Args:
        ax: Target axes.
        segments: Iterable of ``(text, color)`` pairs, laid out left to right
            on a shared baseline with 2pt gaps.
        y: Vertical anchor in axes coordinates (1.0 is the top of the axes).
        fontsize: Font size for every segment.
    """
    def _text_style(color):
        style = {"fontsize": fontsize, "color": color, "fontfamily": "serif"}
        # Without LaTeX, force Computer Modern for any $...$ math.
        if not USE_LATEX:
            style["math_fontfamily"] = "cm"
        return style

    pieces = [TextArea(text, textprops=_text_style(color)) for text, color in segments]
    row = HPacker(children=pieces, align='baseline', pad=0, sep=2)
    # loc code 9 == 'upper center': the row's top-centre sits at (0.5, y).
    anchored = AnchoredOffsetbox(
        9,
        child=row,
        bbox_to_anchor=(0.5, y),
        bbox_transform=ax.transAxes,
        frameon=False,
        pad=0,
    )
    ax.add_artist(anchored)

def draw_annotations(ax, J1, J2, ps, ylim, pair):
    """Mark both policies of a J1/J2 preference-reversal pair on ``ax``.

    For each index of ``pair`` this draws J1 (C1) and J2 (C2) markers, a
    dotted vertical line down to ``ylim[0]``, and dotted horizontal lines
    back to x=0. The index with the higher J1 value is labelled ``pi'`` and
    the other ``pi``, just below the bottom of ``ylim``. Does nothing when
    ``pair`` is None.
    """
    if pair is None:
        return
    first, second = pair
    # pi' is whichever index J1 prefers; pi is the other one.
    pip_idx, pi_idx = (first, second) if J1[first] > J1[second] else (second, first)

    for k in (pi_idx, pip_idx):
        x, j1, j2 = ps[k], J1[k], J2[k]
        for value, colour in ((j1, C1), (j2, C2)):
            ax.plot(x, value, 'o', color=colour, ms=3.5, zorder=5)
        ax.plot([x, x], [ylim[0], max(j1, j2)],
                color='#555', ls=':', alpha=0.6, lw=DOT_LW, zorder=1)
        for value, colour in ((j1, C1), (j2, C2)):
            ax.plot([0, x], [value, value],
                    color=colour, ls=':', alpha=0.5, lw=DOT_LW, zorder=1)

    # Policy labels sit 7% of the y-range below the bottom of the axis.
    label_y = ylim[0] - (ylim[1] - ylim[0]) * 0.07
    for k, tex in ((pi_idx, r"$\pi$"), (pip_idx, r"$\pi'$")):
        ax.text(ps[k], label_y, tex,
                ha='center', va='top', fontsize=PI_FONTSIZE, color='#333')

def find_best_pair(J1, J2, ps, min_theta=0.01, stride=2):
    """Find the policy pair with the largest J1/J2 preference reversal.

    A pair ``(i, j)`` qualifies when ``J1[i] > J1[j]`` and ``J2[j] > J2[i]``.
    J1 prefers policy i and J2 prefers policy j. Among qualifying pairs, this
    maximizes the smaller of the two gaps, ``min(J1[i]-J1[j], J2[j]-J2[i])``.

    Args:
        J1, J2: Value curves sampled on the grid ``ps``.
        ps: Array of theta values.
        min_theta: Only indices with ``min_theta < ps < 1 - min_theta`` are
            considered.
        stride: Keep every ``stride``-th eligible index (default 2).

    Returns:
        ``(gap, (i, j))``, where ``i`` and ``j`` index into ``ps``, or
        ``(0, None)`` if no pair qualifies. Ties go to the first pair in
        (i outer, j inner) order.
    """
    J1 = np.asarray(J1)
    J2 = np.asarray(J2)
    ps = np.asarray(ps)
    candidates = np.flatnonzero((ps > min_theta) & (ps < 1 - min_theta))[::stride]
    if candidates.size == 0:
        return 0, None

    a1, a2 = J1[candidates], J2[candidates]
    # Row r corresponds to candidate i, column c to candidate j.
    d1 = a1[:, None] - a1[None, :]          # J1[i] - J1[j]
    d2 = a2[None, :] - a2[:, None]          # J2[j] - J2[i]
    reversed_ = (d1 > 0) & (d2 > 0)
    if not reversed_.any():
        return 0, None

    gaps = np.where(reversed_, np.minimum(d1, d2), 0)
    # argmax returns the first maximum in row-major order, matching a nested
    # "for i: for j: if g > best" scan.
    r, c = np.unravel_index(np.argmax(gaps), gaps.shape)
    return gaps[r, c], (int(candidates[r]), int(candidates[c]))

# ══════════════════════════════════════════════════════════
# Transition models
# ══════════════════════════════════════════════════════════

# Shared reward table R[s][a]: reward 1 in state 0 (for both actions), 0 in
# states 1 and 2.
R = [[1,1],[0,0],[0,0]]
# Reference transition model T1[s][a][s'] (each row sums to 1). Action 0
# puts the most mass on s'=0, action 1 on s'=2.
T1 = [[[0.7, 0.2, 0.1], [0.1, 0.3, 0.6]],
      [[0.5, 0.3, 0.2], [0.1, 0.2, 0.7]],
      [[0.4, 0.3, 0.3], [0.1, 0.1, 0.8]]]

# Grid of theta = P(action 0), excluding the deterministic endpoints 0 and 1.
ps = np.linspace(0.01, 0.99, 300)
# Value curve J1(theta) of the reference model, gamma = 0.9.
J1 = sweep(T1, R, ps, 0.9)

# U-shape (panel a): random search (fixed seed) for a second transition model
# T2 whose value curve J2(theta) dips in the interior of the theta grid while
# staying strictly above J1 and within a band around J1's range. Of the
# accepted candidates, the one with the deepest dip is kept.
np.random.seed(7)
T2_ushape = None
best_dip = 0
j1s = J1.max() - J1.min()   # range of J1; loop-invariant
for _ in range(3000):
    T2 = [[[*np.random.dirichlet([0.3]*3)] for a in range(nA)] for s in range(nS)]
    J2 = sweep(T2, R, ps, 0.9)
    mi = np.argmin(J2)
    if not 40 < mi < 260:        # minimum must lie away from both grid ends
        continue
    dip = min(J2[0], J2[-1]) - J2[mi]
    if (dip > 0.4 and J2.min() > J1.min() - 0.3*j1s
        and J2.max() < J1.max() + 0.5*j1s
        and np.all(J2 > J1 + 0.05)):
        if dip > best_dip:
            best_dip = dip
            T2_ushape = [[[float(x) for x in T2[s][a]] for a in range(nA)] for s in range(nS)]

if T2_ushape is None:
    raise RuntimeError("U-shape search found no admissible T2; relax the filters "
                       "or increase the number of samples")

# Crossing (panel b): random search (fixed seed) for a T2 whose value curve
# J2(theta) falls by more than 1.5 from theta≈0 to theta≈1 and crosses J1
# exactly once, with the crossing index inside (80, 220). Candidates are
# scored by the end-to-end drop, penalized by how far J2's max/min deviate
# from J1's.
np.random.seed(42)
T2_cross = None
best_score = 0
for _ in range(500):
    T2 = [[[*np.random.dirichlet([0.5]*3)] for a in range(nA)] for s in range(nS)]
    J2t = sweep(T2, R, ps, 0.9)
    drop = J2t[0] - J2t[-1]      # drop > 1.5 already implies J2t[0] > J2t[-1]
    if not drop > 1.5:
        continue
    diff = J1 - J2t
    # Indices where sign(J1 - J2) changes; an exact zero counts twice.
    crossings = np.where(np.diff(np.sign(diff)))[0]
    if len(crossings) == 1 and 80 < crossings[0] < 220:
        rp = abs(J2t.max() - J1.max()) + abs(J2t.min() - J1.min())
        score = drop / (1 + 0.3*rp)
        if score > best_score:
            best_score = score
            T2_cross = [[[float(x) for x in T2[s][a]] for a in range(nA)] for s in range(nS)]

if T2_cross is None:
    raise RuntimeError("crossing search found no admissible T2; relax the filters "
                       "or increase the number of samples")

J2_ushape = sweep(T2_ushape, R, ps, 0.9)
J2_cross = sweep(T2_cross, R, ps, 0.9)

# Most strongly reversed policy pair for each panel, restricted to interior
# theta values (a wider margin for panel b).
gap_a, pair_a = find_best_pair(J1, J2_ushape, ps, min_theta=0.10)
gap_b, pair_b = find_best_pair(J1, J2_cross, ps, min_theta=0.25)

for _tag, _gap, _pair in (("(a)", gap_a, pair_a), ("(b)", gap_b, pair_b)):
    if _pair is None:
        # find_best_pair returns (0, None) when no pair qualifies; indexing it
        # would crash, and draw_annotations already skips a None pair.
        print(f"Panel {_tag}: no preference-reversing pair found")
    else:
        print(f"Panel {_tag}: gap={_gap:.3f}, theta=({ps[_pair[0]]:.2f}, {ps[_pair[1]]:.2f})")

# Shared y-limits for panels (a) and (b): the union of all three curves,
# padded by 7% of their combined range on each side.
all_ab = np.concatenate([J1, J2_ushape, J2_cross])
pad_v = (all_ab.max() - all_ab.min()) * 0.07
ylim_ab = (all_ab.min() - pad_v, all_ab.max() + pad_v)

# ══════════════════════════════════════════════════════════
# Figure: 3 panels, figure width taken from TRIPANEL_WIDTH
# ══════════════════════════════════════════════════════════
# NOTE(review): TRIPANEL_WIDTH (4.08in) matches 1x4 panel widths only for
# wspace=0.15. This GridSpec uses wspace=0.25, so each panel is ≈1.07in wide
# rather than ≈1.14in. The layout is left unchanged; align the two values if
# exact matching is required.
fig = plt.figure(figsize=(TRIPANEL_WIDTH, 1.8))

gs = gridspec.GridSpec(1, 3, width_ratios=[1, 1, 1], wspace=0.25,
                       left=0.06, right=0.98, top=0.75, bottom=0.18)


def _draw_value_panel(ax, J2, pair, title_segments, *, ylabel=None, legend=False):
    """Plot J1 and J2 against theta on ``ax`` with the shared (a)/(b) styling.

    Both panels use ``ylim_ab`` so their curves are directly comparable.
    An optional y-label and a lower-right legend can be added.
    """
    ax.plot(ps, J1, color=C1, lw=1.5, label=r'$J_1$')
    ax.plot(ps, J2, color=C2, lw=1.5, label=r'$J_2$')
    draw_annotations(ax, J1, J2, ps, ylim_ab, pair)
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)
    ax.set_xlim(0, 1)
    ax.set_xticks([0, 1])
    ax.set_yticks([])
    ax.set_ylim(ylim_ab)
    ax.set_xlabel(r'$\theta$', labelpad=LABELPAD)
    if ylabel is not None:
        ax.set_ylabel(ylabel)
    if legend:
        ax.legend(frameon=False, loc='lower right',
                  handlelength=1.5, borderpad=0.2, labelspacing=0.25)
    colored_title(ax, title_segments)


# ── (a) U-shape, exploitable ──
# NOTE(review): the "0.3" / "2" in the (a)/(b) titles are hard-coded; confirm
# they agree with the intended values (e.g. gap_a / gap_b computed above).
ax_a = fig.add_subplot(gs[0, 0])
_draw_value_panel(ax_a, J2_ushape, pair_a,
                  [('(a) ', 'black'), ('$0.3$-exploitable', 'black')],
                  ylabel=r'$J(\pi)$', legend=True)

# ── (b) Crossing, exploitable ──
ax_b = fig.add_subplot(gs[0, 1])
_draw_value_panel(ax_b, J2_cross, pair_b,
                  [('(b) ', 'black'), ('$2$-exploitable', 'black')])

# ── (c) Safe horizon H(eps, delta) — axes swapped: x=delta, y=epsilon ──
ax_c = fig.add_subplot(gs[0, 2])

# Grid padded slightly past the visible [0, 1] x [0, 10] window.
n = 300
delta_grid = np.linspace(-0.02, 1.02, n)
eps_grid = np.linspace(-0.02, 10.02, n)
D, E = np.meshgrid(delta_grid, eps_grid)
Z = H_safe(E, D)

# H_safe clamps delta<=0 to 1e-10, so the padded columns give H far above
# 1000. extend='max' fills those values instead of leaving them (and a thin
# sliver just right of delta=0) blank.
cf = ax_c.contourf(D, E, Z, levels=np.logspace(0, 3, 40),
                   norm=LogNorm(vmin=1, vmax=1000),
                   cmap='RdYlGn', alpha=0.75, extend='max')

levels_h = [2, 5, 10, 20]
cs = ax_c.contour(D, E, Z, levels=levels_h,
                  colors='black', linewidths=0.7)

# Label only H=2, 5, 10, at hand-picked positions.
ax_c.clabel(cs, levels=[2, 5, 10], inline=True, fontsize=6.5,
            fmt={lev: f'$H={lev}$' for lev in levels_h},
            manual=[(0.55, 0.65), (0.60, 2.8), (0.60, 7.5)])

ax_c.set_xlabel(r'$\delta(\mathcal{T}_1, \mathcal{T}_2)$', labelpad=LABELPAD)
ax_c.set_ylabel(r'$\varepsilon$', labelpad=-16)

ax_c.spines['top'].set_visible(False)
ax_c.spines['right'].set_visible(False)

ax_c.set_xlim(0, 1); ax_c.set_ylim(0, 10)
ax_c.set_xticks([0, 1]); ax_c.set_yticks([0, 10])

colored_title(ax_c, [('(c) Safe horizon ', 'black'),
                     (r'$H(\varepsilon, \delta)$', 'black')])

# Create the output directory if needed; savefig does not create it.
out_dir = Path('figures')
out_dir.mkdir(parents=True, exist_ok=True)
fig.savefig(out_dir / 'fig3_eps.png', bbox_inches='tight', dpi=300)
fig.savefig(out_dir / 'fig3_eps.pdf', bbox_inches='tight')
print("Saved fig3_eps")