import os

import numpy as np
import matplotlib
import matplotlib.pyplot as plt

# ══════════════════════════════════════════════════════════
# Font & style setup — matches NeurIPS 2026 style file
# figsize width = TEXTWIDTH = 5.5in, so fonts are 1:1 on page
# ══════════════════════════════════════════════════════════
matplotlib.rcParams.update({
    # All text is rendered through LaTeX. \rmdefault=ptm switches the roman
    # font to Times; the AMS packages make their math commands available in
    # labels and titles.
    "text.usetex": True,
    "font.family": "serif",
    "text.latex.preamble": (
        r"\renewcommand{\rmdefault}{ptm}"
        r"\usepackage{amsmath}"
        r"\usepackage{amsfonts}"
        r"\usepackage{amssymb}"
    ),

    # Font sizes in points. Figures are created at their printed width
    # (TEXTWIDTH), so these are the sizes that appear on the page.
    "font.size": 9,
    "axes.labelsize": 10,
    "axes.titlesize": 9,
    "xtick.labelsize": 8,
    "ytick.labelsize": 8,
    "legend.fontsize": 8,

    # Thin axis lines and outward-facing ticks.
    "axes.linewidth": 0.5,
    "xtick.major.width": 0.5,
    "ytick.major.width": 0.5,
    "xtick.direction": "out",
    "ytick.direction": "out",
    # Default figure resolution; the PNG exports later in this file pass
    # dpi=300 explicitly.
    "figure.dpi": 200,
})

# Curve colours: C1 for T1 / J1, C2 for T2 / J2.
C1, C2 = '#3266ad', '#c0392b'
# Number of states, number of actions, and discount factor used by val3().
nS, nA, GAMMA = 3, 2, 0.9

TEXTWIDTH = 5.5    # NeurIPS textwidth in inches — no scaling
LABELPAD = -6      # theta to x-axis distance
DOT_LW = 1.2       # dotted guide line width
PI_FONTSIZE = 10   # pi, pi' labels match body text

# ══════════════════════════════════════════════════════════
# MDP setup
# ══════════════════════════════════════════════════════════

def val3(T, R, theta, gamma=None):
    """Return the mean state value of a stationary two-action policy.

    The policy takes action 0 with probability ``theta`` and action 1 with
    probability ``1 - theta`` in every state. Its value function solves
    ``(I - gamma * P_pi) V = R_pi``. The result is the average of ``V`` over
    all states, i.e. the return under a uniform initial-state distribution.

    Args:
        T: transition probabilities indexed ``T[s][a][s']``, shape
           (n_states, 2, n_states). The number of states is taken from ``T``
           rather than a module-level constant.
        R: rewards indexed ``R[s][a]``, shape (n_states, 2).
        theta: probability of choosing action 0.
        gamma: discount factor. ``None`` (the default) uses the module-level
           ``GAMMA``.

    Returns:
        ``mean_s V(s)`` as a Python float.

    Raises:
        numpy.linalg.LinAlgError: if ``I - gamma * P_pi`` is singular.
    """
    if gamma is None:
        gamma = GAMMA
    T = np.asarray(T, dtype=float)
    R = np.asarray(R, dtype=float)
    pi = np.array([theta, 1.0 - theta])
    n_states = T.shape[0]
    # Policy-averaged reward vector and transition matrix (sum over actions).
    r_pi = R @ pi
    p_pi = np.einsum('a,sap->sp', pi, T)
    v = np.linalg.solve(np.eye(n_states) - gamma * p_pi, r_pi)
    return float(v.mean())

def sweep(T, R, ps):
    """Evaluate ``val3(T, R, theta)`` for each theta in *ps*.

    Returns a float64 array with one entry per element of *ps*, in order.
    """
    values = (val3(T, R, theta) for theta in ps)
    return np.fromiter(values, dtype=float)

# Reward 1 in state 0 (for either action), 0 elsewhere.
R = [[1, 1], [0, 0], [0, 0]]
# First transition model, indexed T1[s][a][s'].
T1 = [
    [[0.7, 0.2, 0.1], [0.1, 0.3, 0.6]],
    [[0.5, 0.3, 0.2], [0.1, 0.2, 0.7]],
    [[0.4, 0.3, 0.3], [0.1, 0.1, 0.8]],
]

# Policy grid: theta = P(action 0), kept just inside (0, 1).
ps = np.linspace(0.01, 0.99, 300)
J1 = sweep(T1, R, ps)
# Range of J1, used below to constrain the randomly searched T2 curves.
j1_min = J1.min()
j1_max = J1.max()
j1_spread = j1_max - j1_min

# ── Panel (a): Trivial ──
# Both actions share the same next-state distribution in every state.
T2_triv = [
    [[0.4, 0.3, 0.3], [0.4, 0.3, 0.3]],
    [[0.2, 0.5, 0.3], [0.2, 0.5, 0.3]],
    [[0.3, 0.3, 0.4], [0.3, 0.3, 0.4]],
]

# ── Panel (b): Equivalent ──
# Each T1 row mixed with the uniform next-state distribution, weight alpha_eq.
alpha_eq = 0.3
T2_equiv = [
    [((1-alpha_eq)*np.array(T1[s][a]) + alpha_eq*np.array([1/3, 1/3, 1/3])).tolist()
     for a in range(nA)]
    for s in range(nS)
]

# ── Panel (c): U-shape ──
# Seeded random search over Dirichlet-sampled T2 for a value curve that:
#   * has its minimum well inside the theta grid,
#   * dips more than 0.4 below the lower of its two endpoints,
#   * stays within a band around J1's range, and
#   * lies above J1 + 0.05 everywhere.
# The candidate with the deepest dip is kept.
T2_ushape = None
np.random.seed(7)
best_dip = 0
for _ in range(3000):
    T2 = [[[*np.random.dirichlet([0.3]*3)] for a in range(nA)] for s in range(nS)]
    J2 = sweep(T2, R, ps)
    mi = np.argmin(J2)
    # The grid has 300 points; reject minima near either edge.
    if 40 < mi < 260:
        dip = min(J2[0], J2[-1]) - J2[mi]
        if (dip > 0.4 and J2.min() > j1_min - 0.3*j1_spread
            and J2.max() < j1_max + 0.5*j1_spread
            and np.all(J2 > J1 + 0.05)):
            if dip > best_dip:
                best_dip = dip
                # Store as plain Python floats (the Dirichlet draws are numpy scalars).
                T2_ushape = [[[float(x) for x in T2[s][a]] for a in range(nA)] for s in range(nS)]

if T2_ushape is not None:
    print(f"U-shape dip: {best_dip:.3f}")
else:
    T2_ushape = [[[0.430, 0.0, 0.570], [0.979, 0.0, 0.021]],
                 [[0.010, 0.309, 0.680], [0.085, 0.914, 0.000]],
                 [[0.810, 0.000, 0.190], [0.031, 0.707, 0.261]]]
    # The hand-rounded rows above do not all sum to 1 (several sum to 0.999),
    # so renormalise each row to keep T2_ushape a valid transition kernel.
    T2_ushape = [[[x / sum(row) for x in row] for row in state] for state in T2_ushape]
    print("Using fallback U-shape")

# ── Panel (d): X-crossing ──
# Seeded random search for a T2 whose value curve:
#   * drops by more than 1.5 from ps[0] to ps[-1], and
#   * changes sign relative to J1 exactly once, at a grid index in (80, 220).
# Candidates are scored by that endpoint drop, penalised by how far J2's
# min/max stray from J1's.
T2_cross = None
np.random.seed(42)
best_score = 0
for _ in range(500):
    T2 = [[[*np.random.dirichlet([0.5]*3)] for a in range(nA)] for s in range(nS)]
    J2 = sweep(T2, R, ps)
    if J2[0] > J2[-1] and (J2[0] - J2[-1]) > 1.5:
        diff = J1 - J2
        crossings = np.where(np.diff(np.sign(diff)))[0]
        if len(crossings) == 1 and 80 < crossings[0] < 220:
            range_penalty = abs(J2.max() - j1_max) + abs(J2.min() - j1_min)
            score = (J2[0] - J2[-1]) / (1 + 0.3*range_penalty)
            if score > best_score:
                best_score = score
                T2_cross = [[[float(x) for x in T2[s][a]] for a in range(nA)] for s in range(nS)]

# Unlike panel (c) there is no hand-picked fallback. Fail here with a clear
# message instead of an opaque TypeError when sweep() is called on None.
if T2_cross is None:
    raise RuntimeError("No X-crossing T2 found for panel (d); "
                       "try another seed or relax the search criteria.")
print(f"X-crossing score: {best_score:.3f}")

# Value curves J2(theta) for the four T2 models, in panel order (a)-(d).
T2_all = [sweep(T2_model, R, ps)
          for T2_model in (T2_triv, T2_equiv, T2_ushape, T2_cross)]
J2a, J2b, J2c, J2d = T2_all

# Shared y-limits for every value panel: full data range plus 7% padding.
all_vals = np.concatenate([J1, *T2_all])
ymin = all_vals.min()
ymax = all_vals.max()
pad = (ymax - ymin) * 0.07
YLIM = (ymin - pad, ymax + pad)

for name, J2 in zip(('triv', 'equiv', 'ushape', 'cross'), T2_all):
    print(f"{name}: [{J2.min():.2f}, {J2.max():.2f}]")

# ══════════════════════════════════════════════════════════
# Shared helpers
# ══════════════════════════════════════════════════════════

def find_exploit(J1, J2, ps, lo_range=(0.15, 0.40), hi_range=(0.60, 0.85), min_theta=0.10):
    """Return grid indices ``(a, b)`` whose ranking flips between J1 and J2.

    A pair qualifies when ``J1[a] > J1[b]`` while ``J2[b] > J2[a]``. Its score
    is the smaller of those two margins. The highest-scoring pair wins, and
    ties go to the pair encountered first.

    First pass: every 2nd index with ``ps`` strictly inside *lo_range* is
    paired with every 2nd index strictly inside *hi_range*, in both orders.
    Only if that finds nothing is the fallback pass run: every ordered pair of
    every 3rd index with ``min_theta < ps < 1 - min_theta``.

    Returns ``None`` when neither pass finds a qualifying pair.
    """
    def window(lower, upper):
        # Indices of grid points strictly between the two bounds.
        return np.where((ps > lower) & (ps < upper))[0]

    def best_flip(pairs):
        winner, top = None, 0
        for a, b in pairs:
            if not (J1[a] > J1[b] and J2[b] > J2[a]):
                continue
            margin = min(J1[a] - J1[b], J2[b] - J2[a])
            if margin > top:
                winner, top = (a, b), margin
        return winner

    lo_idx = window(lo_range[0], lo_range[1])[::2]
    hi_idx = window(hi_range[0], hi_range[1])[::2]
    banded = ((p, q) for x in lo_idx for y in hi_idx for p, q in ((x, y), (y, x)))
    found = best_flip(banded)
    if found is not None:
        return found

    inner = window(min_theta, 1 - min_theta)[::3]
    return best_flip((x, y) for x in inner for y in inner)

def colored_title(ax, segments, y=1.18, fontsize=9):
    """Draw a multi-coloured title above *ax*.

    *segments* is a sequence of ``(text, colour)`` pairs. They are rendered
    with LaTeX side by side, baseline-aligned and ``sep=2`` apart. The packed
    row is added to *ax* as an ``AnchoredOffsetbox`` with loc
    ``'upper center'``, anchored at ``(0.5, y)`` in axes coordinates.
    """
    from matplotlib.offsetbox import AnchoredOffsetbox, HPacker, TextArea

    pieces = [
        TextArea(text, textprops={'fontsize': fontsize, 'color': colour,
                                  'usetex': True, 'fontfamily': 'serif'})
        for text, colour in segments
    ]
    row = HPacker(children=pieces, align='baseline', pad=0, sep=2)
    anchored = AnchoredOffsetbox('upper center', child=row,
                                 bbox_to_anchor=(0.5, y),
                                 bbox_transform=ax.transAxes,
                                 frameon=False, pad=0)
    ax.add_artist(anchored)

# Per-panel titles as (text, colour) segments for colored_title(). The
# \mathcal{T}_1 / \mathcal{T}_2 symbols use the same colours (C1 / C2) as the
# J_1 / J_2 curves.
title_segs = [
    [('(a) ', 'black'), (r'$\mathcal{T}_2$', C2), (' is trivial', 'black')],
    [('(b) ', 'black'), (r'$\mathcal{T}_1$', C1), (', ', 'black'),
     (r'$\mathcal{T}_2$', C2), (' equivalent', 'black')],
    [('(c) ', 'black'), (r'$\mathcal{T}_1$', C1), (', ', 'black'),
     (r'$\mathcal{T}_2$', C2), (' exploitable', 'black')],
    [('(d) ', 'black'), (r'$\mathcal{T}_1$', C1), (', ', 'black'),
     (r'$\mathcal{T}_2$', C2), (' exploitable', 'black')],
]

# Keyword arguments passed to find_exploit() for the panels that get exploit
# annotations. Keys are panel indices (2 -> panel (c), 3 -> panel (d)), and
# the values are theta search windows. Panels without an entry are left
# unannotated by draw_exploit_annotations().
exploit_ranges = {
    2: dict(lo_range=(0.10, 0.20), hi_range=(0.25, 0.45)),
    3: dict(lo_range=(0.20, 0.35), hi_range=(0.55, 0.75)),
}

def draw_exploit_annotations(ax, J1, J2, ps, idx, ylim):
    """Mark a ranking-flip pair of policies on value panel *idx*.

    Only panels with an entry in ``exploit_ranges`` are annotated. For the
    pair returned by ``find_exploit``, each policy gets:
      * J1 and J2 dots,
      * a dotted vertical guide up from ``ylim[0]``,
      * dotted horizontal guides from x=0 to each dot.
    The policy with the larger J1 is labelled pi', the other pi, just below
    the axis. The chosen thetas are printed.
    """
    search_kw = exploit_ranges.get(idx)
    if search_kw is None:
        return
    pair = find_exploit(J1, J2, ps, **search_kw)
    if not pair:
        return
    i, j = pair
    # The policy J1 prefers is tagged pi'; the other one is tagged pi.
    prime_k, plain_k = (i, j) if J1[i] > J1[j] else (j, i)

    for k in (i, j):
        x = ps[k]
        ax.plot(x, J1[k], 'o', color=C1, ms=4, zorder=5)
        ax.plot(x, J2[k], 'o', color=C2, ms=4, zorder=5)
        ax.plot([x, x], [ylim[0], max(J1[k], J2[k])],
                color='#555', ls=':', alpha=0.6, lw=DOT_LW, zorder=1)
        for level, colour in ((J1[k], C1), (J2[k], C2)):
            ax.plot([0, x], [level, level],
                    color=colour, ls=':', alpha=0.5, lw=DOT_LW, zorder=1)

    label_y = ylim[0] - (ylim[1] - ylim[0]) * 0.07
    for k, tex in ((prime_k, r"$\pi'$"), (plain_k, r"$\pi$")):
        ax.text(ps[k], label_y, tex, ha='center', va='top',
                fontsize=PI_FONTSIZE, color='#333')
    print(f"Panel {idx}: exploit at theta={ps[i]:.2f}, {ps[j]:.2f}")

def draw_deriv_panel(ax, dJ2, idx, r=None, c=None, is_1x4=False):
    """Draw derivative panel *idx*: a zero line, dJ1 and *dJ2* against ``ps``.

    Reads the module-level ``ps``, ``dJ1``, ``BOT_YLIM``, ``LABELPAD`` and
    ``title_segs``.

    Layout rules:
      * 1x4 layout (*is_1x4*): only panel 0 gets the y label, every panel
        gets the theta label, and panels 2 and 3 use ``BOT_YLIM``.
      * 2x2 layout: column ``c == 0`` gets the y label, and row ``r == 1``
        gets the theta label and ``BOT_YLIM``.
    A grey "0" tick is added at the right edge when zero is within the
    y-limits.
    """
    ax.axhline(0, color='#888', lw=0.8, alpha=0.7, zorder=0)
    for series, colour, tag in ((dJ1, C1, r"$\nabla J_1$"), (dJ2, C2, r"$\nabla J_2$")):
        ax.plot(ps, series, color=colour, lw=1.5, label=tag)
    for side in ('top', 'right'):
        ax.spines[side].set_visible(False)
    ax.set_xlim(0, 1)
    ax.set_xticks([0, 1])
    ax.set_yticks([])

    if is_1x4:
        show_ylabel, show_xlabel, use_bot = idx == 0, True, idx >= 2
    else:
        show_ylabel, show_xlabel, use_bot = c == 0, r == 1, r == 1
    if show_ylabel:
        ax.set_ylabel(r"$\nabla J$")
    if show_xlabel:
        ax.set_xlabel(r'$\theta$', labelpad=LABELPAD)
    if use_bot:
        ax.set_ylim(BOT_YLIM)

    lo, hi = ax.get_ylim()
    if lo < 0 < hi:
        ax.text(1.02, 0, '0', transform=ax.get_yaxis_transform(),
                fontsize=7, va='center', ha='left', color='#888')
    colored_title(ax, title_segs[idx])

# ══════════════════════════════════════════════════════════
# VALUE FIGURE — 2x2
# ══════════════════════════════════════════════════════════
# savefig does not create missing directories, so make sure the output
# directory exists before the first figure is written.
os.makedirs('figures', exist_ok=True)

fig, axes = plt.subplots(2, 2, figsize=(TEXTWIDTH, 3.8))
fig.subplots_adjust(hspace=0.45, wspace=0.12, left=0.07, right=0.97, top=0.90, bottom=0.10)

for idx, (J2, ax) in enumerate(zip(T2_all, axes.flatten())):
    r, c = idx // 2, idx % 2
    ax.plot(ps, J1, color=C1, lw=1.5, label=r'$J_1$')
    ax.plot(ps, J2, color=C2, lw=1.5, label=r'$J_2$')
    draw_exploit_annotations(ax, J1, J2, ps, idx, YLIM)
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)
    ax.set_xlim(0, 1); ax.set_xticks([0, 1]); ax.set_yticks([])
    # Every panel shares YLIM so the value curves are directly comparable.
    ax.set_ylim(YLIM)
    if c == 0: ax.set_ylabel(r'$J(\pi)$')
    if r == 1: ax.set_xlabel(r'$\theta$', labelpad=LABELPAD)
    colored_title(ax, title_segs[idx])

axes[0,0].legend(frameon=False, loc='upper left',
                 handlelength=1.5, borderpad=0.2, labelspacing=0.25)
# Row labels in the left margin (figure coordinates): top row, bottom row.
fig.text(0.015, 0.74, r'\textit{Unexploitable}', rotation=90, va='center',
         ha='center', color='#555')
fig.text(0.015, 0.30, r'\textit{Exploitable}', rotation=90, va='center',
         ha='center', color='#555')

# Save through the figure handle rather than pyplot's implicit current figure.
fig.savefig('figures/fig1_v3.png', bbox_inches='tight', dpi=300)
fig.savefig('figures/fig1_v3.pdf', bbox_inches='tight')
print("Saved fig1_v3 (2x2)")

# ══════════════════════════════════════════════════════════
# VALUE FIGURE — 1x4
# ══════════════════════════════════════════════════════════
fig1r, axes1r = plt.subplots(1, 4, figsize=(TEXTWIDTH, 1.8))
fig1r.subplots_adjust(wspace=0.15, left=0.06, right=0.98, top=0.75, bottom=0.18)

# Single row: every panel gets a theta label; only the first gets J(pi).
for idx, ax in enumerate(axes1r.flat):
    J2 = T2_all[idx]
    for curve, colour, tag in ((J1, C1, r'$J_1$'), (J2, C2, r'$J_2$')):
        ax.plot(ps, curve, color=colour, lw=1.5, label=tag)
    draw_exploit_annotations(ax, J1, J2, ps, idx, YLIM)
    for side in ('top', 'right'):
        ax.spines[side].set_visible(False)
    ax.set_xlim(0, 1)
    ax.set_xticks([0, 1])
    ax.set_yticks([])
    ax.set_ylim(YLIM)
    if idx == 0:
        ax.set_ylabel(r'$J(\pi)$')
    ax.set_xlabel(r'$\theta$', labelpad=LABELPAD)
    colored_title(ax, title_segs[idx])

axes1r[0].legend(frameon=False, loc='upper left',
                 handlelength=1.5, borderpad=0.2, labelspacing=0.25)

fig1r.savefig('figures/fig1_v3_1x4.png', bbox_inches='tight', dpi=300)
fig1r.savefig('figures/fig1_v3_1x4.pdf', bbox_inches='tight')
print("Saved fig1_v3_1x4")

# ══════════════════════════════════════════════════════════
# DERIVATIVE — shared data
# ══════════════════════════════════════════════════════════
# Finite-difference derivatives of each value curve with respect to theta.
dJ1 = np.gradient(J1, ps)
dJ2_all = [np.gradient(curve, ps) for curve in T2_all]

# Shared y-limits for the bottom-row derivative panels (c) and (d): the
# combined range of dJ1 and their dJ2 curves, padded by 8%.
bot_vals = np.concatenate([dJ1, *dJ2_all[2:]])
bot_lo, bot_hi = bot_vals.min(), bot_vals.max()
bot_pad = (bot_hi - bot_lo) * 0.08
BOT_YLIM = (bot_lo - bot_pad, bot_hi + bot_pad)

# ══════════════════════════════════════════════════════════
# DERIVATIVE FIGURE — 2x2
# ══════════════════════════════════════════════════════════
fig2, axes2 = plt.subplots(2, 2, figsize=(TEXTWIDTH, 3.8))
fig2.subplots_adjust(hspace=0.45, wspace=0.12, left=0.07, right=0.97, top=0.90, bottom=0.10)

for idx, ax in enumerate(axes2.flat):
    dJ2 = dJ2_all[idx]
    r, c = divmod(idx, 2)
    draw_deriv_panel(ax, dJ2, idx, r=r, c=c)

axes2[0,0].legend(frameon=False, loc='best',
                  handlelength=1.5, borderpad=0.2, labelspacing=0.25)
# Row labels in the left margin (figure coordinates): top row, bottom row.
for y_pos, row_tag in ((0.74, r'\textit{Unexploitable}'), (0.30, r'\textit{Exploitable}')):
    fig2.text(0.015, y_pos, row_tag, rotation=90, va='center',
              ha='center', color='#555')

fig2.savefig('figures/deriv_v3.png', bbox_inches='tight', dpi=300)
fig2.savefig('figures/deriv_v3.pdf', bbox_inches='tight')
print("Saved deriv_v3 (2x2)")

# ══════════════════════════════════════════════════════════
# DERIVATIVE FIGURE — 1x4
# ══════════════════════════════════════════════════════════
fig2r, axes2r = plt.subplots(1, 4, figsize=(TEXTWIDTH, 1.8))
fig2r.subplots_adjust(wspace=0.15, left=0.06, right=0.98, top=0.75, bottom=0.18)

# Single-row layout: draw_deriv_panel handles labels/limits via is_1x4.
for idx, ax in enumerate(axes2r.flat):
    dJ2 = dJ2_all[idx]
    draw_deriv_panel(ax, dJ2, idx, is_1x4=True)

axes2r[0].legend(frameon=False, loc='best',
                 handlelength=1.5, borderpad=0.2, labelspacing=0.25)

fig2r.savefig('figures/deriv_v3_1x4.png', bbox_inches='tight', dpi=300)
fig2r.savefig('figures/deriv_v3_1x4.pdf', bbox_inches='tight')
print("Saved deriv_v3_1x4")