# -*- coding: utf-8 -*-
r"""CMDP_Bandit_Single_vs_MultiCritic.ipynb

Automatically generated by Colab.

Original file is located at
    https://colab.research.google.com/drive/1fb5usAAEO9xZhd8wJlKEMgz6ufosAVks

# CMDP Bandit: Single vs Multi-Critic (Lagrangian RL)

This notebook reproduces all figures for the minimal **bandit CMDP** with **two constraints**, comparing:
- **Single mixed critic** (trained on `r - λᵀc`), and
- **Multi-critic** (separate critics for reward and each cost, combined at readout).

We plot:
1. Expected reward  
2. Constraint violation  
3. **Unconditional** gradient alignment (moving Pearson corr between \(\hat g_t\) and \(g_t\))  
4. **Conditional** gradient alignment (only when \(|g_t|>10^{-3}\))  
5. Dual oscillation magnitude (moving std of \(|\lambda_2-\lambda_1|\) with boundary normalization)  
6–8. **Timescale ablation** heatmaps for the single-critic: Violation AUC, Late Alignment, Dual Oscillation

All plots use **boundary-normalized** smoothers to avoid edge artifacts.
"""

import numpy as np
import matplotlib.pyplot as plt
from dataclasses import dataclass
import pickle, os

# Global matplotlib defaults: only the figure resolution is overridden here;
# colors and line styles are left at matplotlib's own defaults.
plt.rcParams["figure.dpi"] = 150

@dataclass
class Config:
    """Hyper-parameters of the two-action, two-constraint bandit experiment."""

    # --- experiment size ---
    steps: int = 5000            # number of interaction steps per run
    seeds: int = 20              # number of independent runs per method

    # --- learning rates ---
    alpha: float = 0.02          # actor (policy logit) step size
    beta: float = 0.02           # dual (Lagrange multiplier) step size
    eta_single: float = 0.03     # critic step size, single mixed critic
    eta_multi: float = 0.03      # critic step size, each head of the multi-critic

    # --- constraints ---
    d1: float = 0.5              # threshold subtracted from cost c1
    d2: float = 0.5              # threshold subtracted from cost c2
    lam_max: float = 10.0        # upper clamp applied to each multiplier

    # --- diagnostics ---
    eps_trueg: float = 0.001     # |true gradient| threshold for conditional alignment
    ma_win: int = 201            # window width of the moving statistics

    # --- reward structure ---
    r_a1: float = 0.0            # reward of action a1
    r_a2: float = 1.0            # reward of action a2 (high reward, but incurs cost c1)

# Module-wide configuration instance; the environment, runners and plotting
# code below read their hyper-parameters from this object.
cfg = Config()

def sigmoid(x):
    """Return the logistic function ``1 / (1 + exp(-x))`` of ``x``.

    Works element-wise on scalars and NumPy arrays.

    For very negative ``x`` the intermediate ``exp(-x)`` overflows to ``inf``;
    the quotient then evaluates to ``0.0``, which is the correct limit.  Only
    that benign overflow warning is silenced; the returned values are the same
    as those of the plain formula.
    """
    with np.errstate(over="ignore"):
        return 1.0/(1.0+np.exp(-x))

def normalized_moving_average(x, w):
    """Boundary-normalized centered moving average of a 1-D series.

    Each output sample is the mean of the input samples that fall inside a
    box window of width ``w`` centered on it; near the edges the sum is
    divided by the number of samples actually covered, so no zero padding
    leaks into the average.

    Args:
        x: 1-D sequence of numbers.
        w: window width in samples (positive integer).

    Returns:
        Array with the same length as ``x``.
    """
    x = np.asarray(x)
    n = len(x)
    kernel = np.ones(w)
    # np.convolve(..., mode='same') returns max(len(x), w) samples, which is
    # the wrong length when the window is longer than the series.  Taking the
    # centered slice of the 'full' convolution yields the same values as
    # mode='same' when w <= len(x) and always returns len(x) samples.
    start = (w - 1) // 2
    centered = slice(start, start + n)
    num = np.convolve(x, kernel, mode='full')[centered]
    den = np.convolve(np.ones(n), kernel, mode='full')[centered]
    return num/np.maximum(den, 1e-12)

def normalized_moving_average_2d(X, w):
    """Smooth every row of ``X`` with ``normalized_moving_average`` (window ``w``).

    Returns the smoothed rows stacked vertically into one 2-D array.
    """
    smoothed_rows = []
    for series in X:
        smoothed_rows.append(normalized_moving_average(series, w))
    return np.vstack(smoothed_rows)

def moving_std_norm(X, w):
    """Row-wise moving standard deviation with boundary normalization.

    For each row of ``X`` the windowed mean and mean-of-squares are computed
    with a box window of width ``w`` (divided by the number of samples the
    window actually covers), and the standard deviation is
    ``sqrt(max(0, E[x^2] - E[x]^2))``.

    Returns an array built from one result row per input row.
    """
    box = np.ones(w)

    def rolling_std(series):
        # Number of valid samples under the window at each position.
        count = np.convolve(np.ones_like(series), box, mode='same')
        mean = np.convolve(series, box, mode='same') / count
        mean_sq = np.convolve(series**2, box, mode='same') / count
        # Clamp at zero: rounding can make the difference slightly negative.
        return np.sqrt(np.maximum(0, mean_sq - mean**2))

    return np.array([rolling_std(row) for row in X])

def moving_corr(x, y, w):
    """Moving Pearson correlation between two 1-D arrays.

    All windowed moments use a box window of width ``w`` and are divided by
    the number of samples the window covers at each position.  Both variances
    are floored at ``1e-12`` so that constant windows do not divide by zero.

    Returns an array of per-position correlation values.
    """
    box = np.ones(w)

    def window_sum(series):
        return np.convolve(series, box, mode='same')

    count = window_sum(np.ones_like(x))
    mean_x = window_sum(x) / count
    mean_y = window_sum(y) / count
    var_x = np.maximum(window_sum(x*x) / count - mean_x*mean_x, 1e-12)
    var_y = np.maximum(window_sum(y*y) / count - mean_y*mean_y, 1e-12)
    cov_xy = window_sum(x*y) / count - mean_x*mean_y
    return cov_xy / np.sqrt(var_x*var_y)

def moving_corr_2d(X, Y, w):
    """Apply ``moving_corr`` to matching rows of ``X`` and ``Y`` and stack the results."""
    per_row = []
    for row_x, row_y in zip(X, Y):
        per_row.append(moving_corr(row_x, row_y, w))
    return np.vstack(per_row)

def moving_corr_cond(x, y, w, eps):
    """Moving Pearson correlation restricted to samples where ``|y| > eps``.

    Samples failing the condition are masked out of every windowed moment,
    and each moment is divided by the number of retained samples under the
    window.  Positions whose window retains fewer than 5 samples are set to
    NaN.

    Args:
        x, y: 1-D arrays of equal length.
        w: box-window width in samples.
        eps: magnitude threshold applied to ``y``.

    Returns:
        Array of correlations, NaN where too few samples were retained.
    """
    box = np.ones(w)
    keep = (np.abs(y) > eps).astype(float)
    x_kept = x*keep
    y_kept = y*keep

    def window_sum(series):
        return np.convolve(series, box, mode='same')

    count = window_sum(keep)
    # Guard the divisions; windows with too few samples are blanked below.
    count_safe = np.maximum(count, 1e-9)

    mean_x = window_sum(x_kept) / count_safe
    mean_y = window_sum(y_kept) / count_safe
    var_x = np.maximum(window_sum(x_kept*x) / count_safe - mean_x*mean_x, 1e-12)
    var_y = np.maximum(window_sum(y_kept*y) / count_safe - mean_y*mean_y, 1e-12)
    cov_xy = window_sum(x_kept*y) / count_safe - mean_x*mean_y

    corr = cov_xy / np.sqrt(var_x*var_y)
    corr[count < 5] = np.nan
    return corr

def moving_corr_cond_2d(X, Y, w, eps):
    """Apply ``moving_corr_cond`` to matching rows of ``X`` and ``Y`` and stack the results."""
    per_row = []
    for row_x, row_y in zip(X, Y):
        per_row.append(moving_corr_cond(row_x, row_y, w, eps))
    return np.vstack(per_row)

# Bandit environment
def sample_transition(rs, theta):
    """Draw one action from the sigmoid policy and return its outcome.

    Args:
        rs: random generator; one ``rs.random()`` draw is consumed per call.
        theta: policy logit; ``sigmoid(theta)`` is the probability of a1.

    Returns:
        Tuple ``(a, r, c1, c2, pi1)``: action index (0 for a1, 1 for a2), its
        reward, the two costs (a1 incurs c2 = 1, a2 incurs c1 = 1), and the
        probability of a1 that was used for sampling.
    """
    pi1 = sigmoid(theta)           # P(a1)
    if rs.random() < pi1:
        # Action a1: reward r_a1, triggers only the second cost.
        return 0, cfg.r_a1, 0, 1, pi1
    # Action a2: reward r_a2, triggers only the first cost.
    return 1, cfg.r_a2, 1, 0, pi1

def true_grad(theta, lam):
    """Closed-form derivative of the expected penalized return w.r.t. ``theta``.

    With ``p = sigmoid(theta)`` the expected penalized return is
    ``p * f(a1) + (1 - p) * f(a2)``, where ``f(a) = r(a) - lam . c(a)``, so
    the derivative is ``p * (1 - p) * (f(a1) - f(a2))``.

    Args:
        theta: policy logit.
        lam: pair of multipliers ``(lam[0] for c1, lam[1] for c2)``.
    """
    p = sigmoid(theta)
    penalized_a1 = cfg.r_a1 - lam[1]  # a1 pays cost c2 = 1
    penalized_a2 = cfg.r_a2 - lam[0]  # a2 pays cost c1 = 1
    gap = penalized_a1 - penalized_a2
    return p*(1-p) * gap

# Single-critic runner
def run_single(seed):
    """Run one seed of the single mixed-critic Lagrangian learner.

    A single table ``Q`` (one entry per action) tracks the penalized reward
    ``r - lam . c``; the actor follows ``score * Q[a]`` and both multipliers
    are updated by projected ascent on ``c_i - d_i``, clamped to
    ``[0, cfg.lam_max]``.

    Args:
        seed: seed for ``np.random.default_rng``.

    Returns:
        Tuple ``(vio, rew, ghat, gtrue, ldiff)`` of arrays of length
        ``cfg.steps``: expected constraint violation and expected reward of
        the sampling policy, the gradient estimate, the closed-form gradient,
        and ``|lam[1] - lam[0]|``.
    """
    rng = np.random.default_rng(seed)
    n_steps = cfg.steps
    theta = 0.0
    lam = np.array([0.1, 0.1], dtype=float)
    q_mixed = np.zeros(2, dtype=float)
    vio, rew, ghat, gtrue, ldiff = (np.zeros(n_steps) for _ in range(5))

    for t in range(n_steps):
        a, r, c1, c2, pi1 = sample_transition(rng, theta)

        # Critic: move Q[a] towards the penalized reward under the current multipliers.
        penalized = r - (lam[0]*c1 + lam[1]*c2)
        q_mixed[a] += cfg.eta_single * (penalized - q_mixed[a])

        # Actor: score-function estimate, d/dtheta log pi(a) times the critic value.
        score = (-pi1) if a else (1-pi1)
        g = score * q_mixed[a]
        ghat[t] = g
        theta += cfg.alpha * g

        # Dual: projected ascent on each sampled cost minus its threshold.
        for k, (cost, limit) in enumerate(((c1, cfg.d1), (c2, cfg.d2))):
            lam[k] = min(cfg.lam_max, max(0.0, lam[k] + cfg.beta * (cost - limit)))

        # Diagnostics of the policy that generated this sample (pre-update pi1).
        vio[t] = max(0, (1 - pi1) - cfg.d1) + max(0, pi1 - cfg.d2)
        rew[t] = cfg.r_a1*pi1 + cfg.r_a2*(1-pi1)
        # NOTE(review): evaluated at the already-updated (theta, lam), whereas
        # ghat[t] was formed from the pre-update values; confirm this one-step
        # offset is intended.
        gtrue[t] = true_grad(theta, lam)
        ldiff[t] = abs(lam[1]-lam[0])

    return vio, rew, ghat, gtrue, ldiff

# Multi-critic runner
def run_multi(seed):
    """Run one seed of the multi-critic Lagrangian learner.

    Three tables (reward, cost 1, cost 2; one entry per action) are learned
    separately and combined with the current multipliers only when the actor
    gradient is formed.  Both multipliers are updated by projected ascent on
    ``c_i - d_i``, clamped to ``[0, cfg.lam_max]``.

    Args:
        seed: seed for ``np.random.default_rng``.

    Returns:
        Tuple ``(vio, rew, ghat, gtrue, ldiff)`` of arrays of length
        ``cfg.steps``: expected constraint violation and expected reward of
        the sampling policy, the gradient estimate, the closed-form gradient,
        and ``|lam[1] - lam[0]|``.
    """
    rng = np.random.default_rng(seed)
    n_steps = cfg.steps
    theta = 0.0
    lam = np.array([0.1, 0.1], dtype=float)
    q_reward, q_cost1, q_cost2 = (np.zeros(2, dtype=float) for _ in range(3))
    vio, rew, ghat, gtrue, ldiff = (np.zeros(n_steps) for _ in range(5))

    for t in range(n_steps):
        a, r, c1, c2, pi1 = sample_transition(rng, theta)

        # Critics: each head tracks its own signal, independent of the multipliers.
        for head, target in ((q_reward, r), (q_cost1, c1), (q_cost2, c2)):
            head[a] += cfg.eta_multi * (target - head[a])

        # Readout: combine the heads with the current multipliers.
        mixed = q_reward[a] - (lam[0]*q_cost1[a] + lam[1]*q_cost2[a])

        # Actor: score-function estimate, d/dtheta log pi(a) times the readout.
        score = (-pi1) if a else (1-pi1)
        g = score * mixed
        ghat[t] = g
        theta += cfg.alpha * g

        # Dual: projected ascent on each sampled cost minus its threshold.
        for k, (cost, limit) in enumerate(((c1, cfg.d1), (c2, cfg.d2))):
            lam[k] = min(cfg.lam_max, max(0.0, lam[k] + cfg.beta * (cost - limit)))

        # Diagnostics of the policy that generated this sample (pre-update pi1).
        vio[t] = max(0, (1 - pi1) - cfg.d1) + max(0, pi1 - cfg.d2)
        rew[t] = cfg.r_a1*pi1 + cfg.r_a2*(1-pi1)
        # NOTE(review): evaluated at the already-updated (theta, lam), whereas
        # ghat[t] was formed from the pre-update values; confirm this one-step
        # offset is intended.
        gtrue[t] = true_grad(theta, lam)
        ldiff[t] = abs(lam[1]-lam[0])

    return vio, rew, ghat, gtrue, ldiff

def run_suite(run_fn, seeds):
    """Run ``run_fn`` for ``seeds`` consecutive seeds and stack the outputs.

    ``run_fn`` is called with the seeds ``1000, 1001, ..., 1000 + seeds - 1``
    and must return five per-run series.  The i-th returned array stacks the
    i-th series of every run along a new leading axis.

    Returns:
        Tuple of five arrays, in the order returned by ``run_fn``.
    """
    buckets = ([], [], [], [], [])
    for offset in range(seeds):
        vio, rew, g_hat, g_true, lam_gap = run_fn(1000 + offset)
        for bucket, series in zip(buckets, (vio, rew, g_hat, g_true, lam_gap)):
            bucket.append(series)
    return tuple(np.array(bucket) for bucket in buckets)

# Run experiments (this takes a few seconds).
# Each name holds one array per diagnostic, with one row per seed and one
# column per step: V = violation, R = expected reward, GH = estimated
# gradient, GT = closed-form gradient, LD = |lam[1] - lam[0]|.
sing_V, sing_R, sing_GH, sing_GT, sing_LD = run_suite(run_single, cfg.seeds)
mult_V, mult_R, mult_GH, mult_GT, mult_LD = run_suite(run_multi,  cfg.seeds)

# Shared x-axis (step index) and smoothing window used by the plots below.
t = np.arange(cfg.steps)
w = cfg.ma_win

def smooth_mean_std(X):
    """Smooth each row of ``X`` with the global window ``w``, then reduce over rows.

    Returns:
        ``(mean, std)`` of the smoothed rows, both taken along axis 0.
    """
    smoothed = normalized_moving_average_2d(X, w)
    mean_curve = smoothed.mean(0)
    std_curve = smoothed.std(0)
    return mean_curve, std_curve

import matplotlib.pyplot as plt

def plot_with_band(
    x, m1, s1, m2, s2,
    ylabel,
    labels=("Mixed-critic","Dedicated-critic"),
    font_scale=1.4  # 1.0 = default; bump to make larger
):
    """Plot two mean curves with shaded ``mean +/- spread`` bands and show the figure.

    Args:
        x: shared x-axis values.
        m1, s1: center curve and band half-width of the first series.
        m2, s2: center curve and band half-width of the second series.
        ylabel: y-axis label.
        labels: legend entries for the first and second series.
        font_scale: multiplier applied to a 12 pt base font size.
    """
    base = 12 * font_scale
    font_sizes = {
        "font.size": base,
        "axes.titlesize": base * 1.2,
        "axes.labelsize": base * 1.1,
        "legend.fontsize": base * 0.95,
        "xtick.labelsize": base,
        "ytick.labelsize": base,
    }
    # NOTE: this updates the global rcParams, so the font sizes persist for
    # every figure created after this call, not only for this one.
    plt.rcParams.update(font_sizes)

    plt.figure(figsize=(7, 4))
    curves = ((m1, s1, labels[0]), (m2, s2, labels[1]))
    # Draw both lines first, then both bands, so each pair walks the default
    # color cycle in the same order.
    for center, _, name in curves:
        plt.plot(x, center, label=name, linewidth=2)
    for center, spread, _ in curves:
        plt.fill_between(x, center - spread, center + spread, alpha=0.2)

    plt.xlabel("Steps")
    plt.ylabel(ylabel)
    plt.legend()
    plt.tight_layout()
    plt.show()

# Figure 1: expected reward (smoothed per seed, then mean/std across seeds).
m1,s1 = smooth_mean_std(sing_R)
m2,s2 = smooth_mean_std(mult_R)
plot_with_band(t, m1,s1, m2,s2, ylabel="Reward",
)

# Figure 2: constraint violation (smoothed per seed, then mean/std across seeds).
m1,s1 = smooth_mean_std(sing_V)
m2,s2 = smooth_mean_std(mult_V)
plot_with_band(t, m1,s1, m2,s2, ylabel="Constraint violation (smoothed)")

# Figure 3: unconditional gradient alignment, i.e. the moving correlation
# between the estimated and the closed-form gradient, computed per seed.
sg_corr = moving_corr_2d(sing_GH, sing_GT, w)
mc_corr = moving_corr_2d(mult_GH, mult_GT, w)

m1, s1 = sg_corr.mean(0), sg_corr.std(0)
m2, s2 = mc_corr.mean(0), mc_corr.std(0)
plot_with_band(t, m1,s1, m2,s2, ylabel="Moving corr(ĝ, g)")

# Figure 4: conditional gradient alignment, restricted to steps where the
# closed-form gradient exceeds cfg.eps_trueg in magnitude.
sg_corrc = moving_corr_cond_2d(sing_GH, sing_GT, w, cfg.eps_trueg)
mc_corrc = moving_corr_cond_2d(mult_GH, mult_GT, w, cfg.eps_trueg)

# NaN-aware reductions: positions with too few retained samples are NaN.
m1, s1 = np.nanmean(sg_corrc,0), np.nanstd(sg_corrc,0)
m2, s2 = np.nanmean(mc_corrc,0), np.nanstd(mc_corrc,0)
# The last 100 steps are left out of the plot.
plot_with_band(t[:-100], m1[:-100],s1[:-100], m2[:-100],s2[:-100],
               ylabel=r"Moving corr$(\hat g, g)$")

# Figure 5: dual oscillation, i.e. the moving std of |lam[1] - lam[0]| per seed.
sg_osc = moving_std_norm(sing_LD, w)
mc_osc = moving_std_norm(mult_LD, w)

m1, s1 = sg_osc.mean(0), sg_osc.std(0)
m2, s2 = mc_osc.mean(0), mc_osc.std(0)
# The last 100 steps are left out of the plot.
plot_with_band(t[:-100], m1[:-100],s1[:-100], m2[:-100],s2[:-100], ylabel=r"Moving std of $|\lambda_2-\lambda_1|$")

# Hyper-parameter grids for the single-critic timescale ablation:
# critic learning rates (eta) and dual learning rates (beta).
eta_grid = [0.01, 0.03, 0.1]
beta_grid = [0.005, 0.02, 0.08]

def violation_auc(V):
    """Seed-averaged violation, summed over steps and divided by the step count.

    ``V`` is a 2-D array with one row per seed and one column per step.
    """
    per_step = V.mean(0)
    n_steps = V.shape[1]
    return per_step.sum()/n_steps

def cond_align_final(GH, GT):
    """Late conditional alignment: NaN-aware mean over the last 500 steps.

    The conditional moving correlation between ``GH`` and ``GT`` is computed
    per row with the global window ``w`` and threshold ``cfg.eps_trueg``; the
    final 500 columns of all rows are then averaged, ignoring NaNs.
    """
    rolling = moving_corr_cond_2d(GH, GT, w, cfg.eps_trueg)
    return np.nanmean(rolling[:, -500:])

def osc_mag_last(LD):
    """Late dual oscillation: mean moving std of ``LD`` over the last 500 steps.

    The moving standard deviation is computed per row with the global window
    ``w``; the final 500 columns of all rows are then averaged.
    """
    rolling_std = moving_std_norm(LD, w)
    return np.mean(rolling_std[:, -500:])

# Result grids for the ablation; rows index beta_grid, columns index eta_grid.
heat_single_auc = np.zeros((len(beta_grid), len(eta_grid)))
heat_single_align = np.zeros((len(beta_grid), len(eta_grid)))
heat_single_osc = np.zeros((len(beta_grid), len(eta_grid)))

# Sweep the dual lr (beta) and critic lr (eta) for the single-critic learner.
# The inner loop repeats the single-critic update with `eta` and `beta`
# substituted for cfg.eta_single and cfg.beta; 8 seeds are run per grid cell.
for i,beta in enumerate(beta_grid):
    for j,eta in enumerate(eta_grid):
        V, GH, GT, LD = [], [], [], []
        for s in range(8):
            # NOTE(review): with strides 10 (i) and 3 (j) and 8 seeds per cell,
            # the seed ranges of neighbouring cells overlap (e.g. cell (0,0)
            # uses 6000..6007 and cell (0,1) uses 6003..6010) - confirm that
            # sharing random streams across cells is intended.
            rs = np.random.default_rng(6000+i*10+j*3+s)
            # custom single-critic run: fresh policy logit, multipliers and critic
            theta = 0.0; lam = np.array([0.1,0.1], float); Q = np.zeros(2)
            vio = np.zeros(cfg.steps); ghat = np.zeros(cfg.steps); gtrue = np.zeros(cfg.steps); ldiff = np.zeros(cfg.steps)
            for tstep in range(cfg.steps):
                # Sample an action from the sigmoid policy and read off its reward/costs.
                pi1 = sigmoid(theta)
                a = 0 if rs.random() < pi1 else 1
                r = cfg.r_a1 if a==0 else cfg.r_a2
                c1 = 0 if a==0 else 1
                c2 = 1 if a==0 else 0
                # Critic update towards the penalized reward r - lam . c.
                rlam = r - (lam[0]*c1 + lam[1]*c2)
                Q[a] += eta * (rlam - Q[a])
                # Actor update with the score-function gradient estimate.
                score = (1-pi1) if a==0 else (-pi1)
                ghat[tstep] = score * Q[a]
                theta += cfg.alpha * ghat[tstep]
                # Dual update: projected ascent, clamped to [0, cfg.lam_max].
                lam[0] = min(cfg.lam_max, max(0.0, lam[0] + beta * (c1 - cfg.d1)))
                lam[1] = min(cfg.lam_max, max(0.0, lam[1] + beta * (c2 - cfg.d2)))
                # Expected costs and violation of the policy that produced this sample.
                Jc1 = 1 - pi1; Jc2 = pi1
                vio[tstep] = max(0, Jc1 - cfg.d1) + max(0, Jc2 - cfg.d2)
                # true gradient at the already-updated (theta, lam); note that
                # ghat[tstep] above was formed from the pre-update values
                gtrue[tstep] = true_grad(theta, lam)
                ldiff[tstep] = abs(lam[1]-lam[0])
            V.append(vio); GH.append(ghat); GT.append(gtrue); LD.append(ldiff)
        # Stack the per-seed series (one row per seed) and summarize this cell.
        V = np.array(V); GH = np.array(GH); GT = np.array(GT); LD = np.array(LD)
        heat_single_auc[i,j] = violation_auc(V)
        heat_single_align[i,j] = cond_align_final(GH, GT)
        heat_single_osc[i,j] = osc_mag_last(LD)

def plot_heat(H, title, cbarlabel):
    """Show a 2-D array as an annotated heatmap over the ablation grids.

    Columns are labelled with ``eta_grid`` and rows with ``beta_grid`` (row 0
    at the bottom); every cell is annotated with its value to two decimals.

    Args:
        H: 2-D array of values, indexed ``[beta index, eta index]``.
        title: figure title.
        cbarlabel: label of the colorbar.
    """
    plt.figure(figsize=(6,4.2))
    image = plt.imshow(H, origin='lower', aspect='auto')

    eta_labels = [str(value) for value in eta_grid]
    beta_labels = [str(value) for value in beta_grid]
    plt.xticks(range(len(eta_grid)), eta_labels)
    plt.yticks(range(len(beta_grid)), beta_labels)
    plt.xlabel("Critic lr η")
    plt.ylabel("Dual lr β")
    plt.title(title)

    colorbar = plt.colorbar(image)
    colorbar.set_label(cbarlabel)

    # Write each cell's value at its center (x = column, y = row).
    for row, col in np.ndindex(H.shape[0], H.shape[1]):
        plt.text(col, row, f"{H[row,col]:.2f}", ha='center', va='center', fontsize=9, color='white')

    plt.tight_layout()
    plt.show()

# Figures 6-8: timescale-ablation heatmaps for the single-critic learner
# (violation AUC, late conditional alignment, late dual oscillation).
plot_heat(heat_single_auc,   "Single-critic: Violation AUC vs. (η, β)", "lower = better")
plot_heat(heat_single_align, "Single-critic: Conditional Alignment (late) vs. (η, β)", "higher = better")
plot_heat(heat_single_osc,   "Single-critic: Dual Oscillation (late) vs. (η, β)", "moving std of |λ2-λ1|")

