{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "d4bc4192-f6a5-407d-a9c5-4109b310964f",
   "metadata": {},
   "source": [
    "# helper functions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "6894ac8e-634a-4b8e-97fc-0c62994d3956",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "# ============================================================\n",
    "# Minimal soft Bellman + policy\n",
    "# ============================================================\n",
    "\n",
    "def compute_v_star_from_r(P, r, gamma=0.99, tol=1e-12, max_iter=50_000, eps=1e-12):\n",
    "    \"\"\"\n",
    "    Soft (temperature-1) value iteration for a fixed reward.\n",
    "\n",
    "    Iterates V(s) <- log sum_a exp{ r(s,a) + γ * E[V(S') | s,a] }, starting from the\n",
    "    one-step value logsumexp_a r(s,a), until the sup-norm change is below `tol`.\n",
    "\n",
    "    Args:\n",
    "        P: (S, A, S) transition array; each (s, a) row is divided by its sum\n",
    "           (floored at `eps`) so it becomes a distribution over next states.\n",
    "        r: (S, A) reward array.\n",
    "        gamma: discount; the soft Bellman map is a contraction for 0 <= gamma < 1.\n",
    "        tol: absolute sup-norm convergence threshold.\n",
    "        max_iter: maximum number of sweeps.\n",
    "        eps: floor for row sums and small constant added inside the log.\n",
    "\n",
    "    Returns:\n",
    "        (V, v): V has shape (S,); v(s, a) = E[V(S') | s, a] has shape (S, A).\n",
    "\n",
    "    Raises:\n",
    "        RuntimeError: if no sweep reaches the tolerance within `max_iter`.\n",
    "    \"\"\"\n",
    "    P = np.asarray(P, float)\n",
    "    r = np.asarray(r, float)\n",
    "    S, A, Sp = P.shape\n",
    "    assert Sp == S and r.shape == (S, A)\n",
    "\n",
    "    # Row-normalize every (s, a) slice of P.\n",
    "    P = P / np.maximum(P.sum(axis=2, keepdims=True), eps)\n",
    "\n",
    "    def _soft_max_over_actions(U):\n",
    "        # Max-shifted log-sum-exp over the action axis; returns shape (S,).\n",
    "        peak = U.max(axis=1, keepdims=True)\n",
    "        return (peak + np.log(np.exp(U - peak).sum(axis=1, keepdims=True) + eps)).ravel()\n",
    "\n",
    "    V = _soft_max_over_actions(r)  # warm start: one-step soft value\n",
    "    converged = False\n",
    "    for _ in range(max_iter):\n",
    "        expected_next = np.einsum('saq,q->sa', P, V)  # E[V(S') | s, a]\n",
    "        V_next = _soft_max_over_actions(r + gamma * expected_next)\n",
    "        delta = np.max(np.abs(V_next - V))\n",
    "        V = V_next\n",
    "        if delta < tol:\n",
    "            converged = True\n",
    "            break\n",
    "    if not converged:\n",
    "        raise RuntimeError(\"compute_v_star_from_r: did not converge.\")\n",
    "\n",
    "    return V, np.einsum('saq,q->sa', P, V)\n",
    "\n",
    "\n",
    "def policy_from_reward(P, r, gamma=0.97, eps=1e-12):\n",
    "    \"\"\"\n",
    "    Soft-optimal (Boltzmann) policy for reward r under dynamics P.\n",
    "\n",
    "    Solves the soft Bellman fixed point with `compute_v_star_from_r`, forms\n",
    "    Q(s,a) = r(s,a) + γ * E[V*(S') | s,a], and returns pi(a|s) ∝ exp Q(s,a)\n",
    "    as an (S, A) array (max-shifted softmax, normalizer floored at eps).\n",
    "    \"\"\"\n",
    "    _, expected_next_value = compute_v_star_from_r(P, r, gamma=gamma, eps=eps)\n",
    "    q_values = r + gamma * expected_next_value\n",
    "    shifted = np.exp(q_values - q_values.max(axis=1, keepdims=True))  # overflow-safe\n",
    "    normalizer = np.maximum(shifted.sum(axis=1, keepdims=True), eps)\n",
    "    return shifted / normalizer\n",
    "\n",
    "\n",
    " \n",
    " "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "cb968670-9ab0-4e12-9d4e-c181340379c6",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "\n",
    "\n",
    "# ============================================================\n",
    "# Richer features (state + action, with landmarks & distances)\n",
    "# ============================================================\n",
    "\n",
    "def make_phi_hard(nrow, ncol, A=5, masks=None, n_rbf=6, rbf_scale=0.7):\n",
    "    \"\"\"\n",
    "    Build a hand-crafted state-action feature map for an nrow x ncol grid.\n",
    "\n",
    "    The returned phi_sa(s, a) decodes the flat state as y, x = divmod(s, ncol) and\n",
    "    returns a float64 vector laid out as\n",
    "      - base (6): bias, gx, gy, gx*gy, gx^2, gy^2 with gx, gy scaled to [0, 1]\n",
    "      - rbf (<= n_rbf): exp(-d^2 / (2 * rbf_scale^2)) to landmarks on a coarse\n",
    "        grid, with d measured in grid cells\n",
    "      - indicators (4): ice, pit, portal endpoint, near_wall (a 4-neighbour is a wall)\n",
    "      - action one-hot (A), then gx * one-hot (A) and gy * one-hot (A)\n",
    "    phi_sa.dim holds the total length.\n",
    "\n",
    "    Args:\n",
    "        nrow, ncol: grid size.\n",
    "        A: number of actions.\n",
    "        masks: optional dict with (nrow, ncol) arrays under 'wall', 'ice', 'pit'\n",
    "            and an iterable of ((y1, x1), (y2, x2)) under 'portal_pairs'.\n",
    "            Missing or None entries default to all-False grids / no portals.\n",
    "        n_rbf: maximum number of RBF landmarks.\n",
    "        rbf_scale: RBF bandwidth in grid cells.\n",
    "    \"\"\"\n",
    "    # landmarks on a ceil(sqrt(n_rbf))-per-axis grid, truncated to n_rbf points\n",
    "    ys = np.linspace(0, nrow-1, int(np.ceil(np.sqrt(n_rbf))))\n",
    "    xs = np.linspace(0, ncol-1, int(np.ceil(np.sqrt(n_rbf))))\n",
    "    L = []\n",
    "    for yy in ys:\n",
    "        for xx in xs:\n",
    "            L.append((yy, xx))\n",
    "            if len(L) >= n_rbf:\n",
    "                break\n",
    "        if len(L) >= n_rbf:\n",
    "            break\n",
    "    L = np.array(L, float)\n",
    "    inv_two_sigma2 = 1.0 / (2.0 * (rbf_scale**2))\n",
    "\n",
    "    masks = masks or {}\n",
    "\n",
    "    def _grid_mask(key):\n",
    "        # A mask dict may omit some layers; fall back to an all-False grid.\n",
    "        m = masks.get(key)\n",
    "        return np.zeros((nrow, ncol), bool) if m is None else m\n",
    "\n",
    "    wall = _grid_mask(\"wall\")\n",
    "    ice  = _grid_mask(\"ice\")\n",
    "    pit  = _grid_mask(\"pit\")\n",
    "    portal_pairs = masks.get(\"portal_pairs\")\n",
    "    if portal_pairs is None:\n",
    "        portal_pairs = []\n",
    "    portal_mask = np.zeros((nrow, ncol), bool)\n",
    "    for (y1, x1), (y2, x2) in portal_pairs:\n",
    "        portal_mask[int(y1), int(x1)] = True\n",
    "        portal_mask[int(y2), int(x2)] = True\n",
    "\n",
    "    I = np.eye(A)\n",
    "\n",
    "    def near_wall(y, x):\n",
    "        # 1.0 if any in-bounds 4-neighbour is a wall cell\n",
    "        for dy, dx in [(-1,0),(1,0),(0,-1),(0,1)]:\n",
    "            yy, xx = y+dy, x+dx\n",
    "            if 0 <= yy < nrow and 0 <= xx < ncol and wall[yy, xx]:\n",
    "                return 1.0\n",
    "        return 0.0\n",
    "\n",
    "    def phi_sa(s, a):\n",
    "        y, x = divmod(s, ncol)\n",
    "        yf, xf = float(y), float(x)\n",
    "        gy = yf / max(1, nrow-1)\n",
    "        gx = xf / max(1, ncol-1)\n",
    "\n",
    "        base = [1.0, gx, gy, gx*gy, gx*gx, gy*gy]\n",
    "\n",
    "        # RBF similarities (distances in raw grid-cell units)\n",
    "        rbf = []\n",
    "        for (ly, lx) in L:\n",
    "            d2 = (yf - ly)**2 + (xf - lx)**2\n",
    "            rbf.append(np.exp(-d2 * inv_two_sigma2))\n",
    "\n",
    "        ind = [\n",
    "            float(ice[y, x]),\n",
    "            float(pit[y, x]),\n",
    "            float(portal_mask[y, x]),\n",
    "            near_wall(y, x),\n",
    "        ]\n",
    "\n",
    "        a_onehot = I[a]\n",
    "\n",
    "        # position x action interactions (help disambiguate action-dependent rewards)\n",
    "        ax = gx * a_onehot\n",
    "        ay = gy * a_onehot\n",
    "\n",
    "        return np.concatenate([base, rbf, ind, a_onehot, ax, ay]).astype(np.float64)\n",
    "\n",
    "    # feature dimension: base + rbf + indicators + one-hot + two interaction blocks\n",
    "    phi_sa.dim = 6 + len(L) + 4 + A + 2 * A\n",
    "    return phi_sa\n",
    "\n",
    "\n",
    "# ============================================================\n",
    "# Hard reward (mixed state + action structure, sparse hotspots)\n",
    "# ============================================================\n",
    "\n",
    "def build_true_reward_hard(P, shape, masks, w_scale=1.0, goal=(0, -1), A=5, action_bias=(0.05, -0.03, 0.0, 0.02, 0.0)):\n",
    "    \"\"\"\n",
    "    Construct a synthetic 'true' reward R(s, a) = phi(s, a) @ w for the hard grid.\n",
    "\n",
    "    phi = make_phi_hard(nrow, ncol, A=A, masks=masks, n_rbf=6, rbf_scale=0.9) and\n",
    "    w is a fixed hand-set weight vector (every entry scaled by `w_scale`):\n",
    "      - base terms favour larger x (right) and smaller y (up),\n",
    "      - every RBF landmark gets the same +0.3 weight,\n",
    "      - pits are penalised strongly, ice and wall-adjacent cells mildly,\n",
    "        portal endpoints get a small bonus,\n",
    "      - a global per-action bias plus x/y-by-action interaction terms.\n",
    "\n",
    "    Args:\n",
    "        P: (S, A, S) transition array; only its shape is used.\n",
    "        shape: (nrow, ncol); states are decoded as y, x = divmod(s, ncol).\n",
    "        masks: mask dict forwarded to make_phi_hard.\n",
    "        w_scale: global multiplier on all weights.\n",
    "        goal: currently unused -- the RBF weights are uniform, so no goal cell is\n",
    "            singled out. Kept for interface compatibility.\n",
    "        A: number of actions; the interaction weights are defined only for A == 5.\n",
    "        action_bias: per-action bias, length A.\n",
    "\n",
    "    Returns:\n",
    "        (S, A) float array of rewards.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if A != 5 or len(action_bias) != A.\n",
    "    \"\"\"\n",
    "    S, A_check, _ = P.shape\n",
    "    assert A_check == A\n",
    "    nrow, ncol = shape\n",
    "\n",
    "    if A != 5:\n",
    "        raise ValueError(f\"build_true_reward_hard: interaction weights are defined for A == 5, got A={A}\")\n",
    "    action_bias = np.asarray(action_bias, float)\n",
    "    if action_bias.shape != (A,):\n",
    "        raise ValueError(f\"build_true_reward_hard: action_bias must have length {A}, got shape {action_bias.shape}\")\n",
    "\n",
    "    n_rbf = 6\n",
    "    phi = make_phi_hard(nrow, ncol, A=A, masks=masks, n_rbf=n_rbf, rbf_scale=0.9)\n",
    "    w = np.zeros(phi.dim, float)\n",
    "\n",
    "    # feature layout (n_rbf=6, A=5):\n",
    "    #   base 0..5 | rbf 6..11 | ind 12..15 | act 16..20 | x*act 21..25 | y*act 26..30\n",
    "    i_rbf = 6\n",
    "    i_ind = i_rbf + n_rbf\n",
    "    i_act = i_ind + 4\n",
    "    i_ax = i_act + A\n",
    "    i_ay = i_ax + A\n",
    "\n",
    "    # base terms\n",
    "    w[0] = 0.0                     # bias\n",
    "    w[1] = +0.8 * w_scale          # +gx (rightwards)\n",
    "    w[2] = -0.9 * w_scale          # -gy (upwards)\n",
    "    w[3] = +0.2 * w_scale          # gx*gy\n",
    "    w[4] = -0.1 * w_scale          # gx^2 slight penalty\n",
    "    w[5] = -0.1 * w_scale          # gy^2 slight penalty\n",
    "\n",
    "    # uniform bump on every RBF landmark (no landmark is singled out)\n",
    "    w[i_rbf:i_ind] = 0.3 * w_scale\n",
    "\n",
    "    # indicators: ice, pit, portal, near_wall\n",
    "    w[i_ind + 0] = -0.10 * w_scale  # ice mildly bad\n",
    "    w[i_ind + 1] = -2.50 * w_scale  # pits very bad\n",
    "    w[i_ind + 2] = +0.10 * w_scale  # portals slightly good\n",
    "    w[i_ind + 3] = -0.20 * w_scale  # near wall mildly bad (discourages hugging)\n",
    "\n",
    "    # global action bias\n",
    "    w[i_act:i_ax] = action_bias * w_scale\n",
    "\n",
    "    # gx ⊗ action and gy ⊗ action interactions\n",
    "    w[i_ax:i_ay] = np.array([+0.00, +0.15, +0.00, -0.10, +0.00]) * w_scale\n",
    "    w[i_ay:i_ay + A] = np.array([-0.10, -0.10, +0.00, +0.00, +0.00]) * w_scale\n",
    "\n",
    "    R = np.zeros((S, A), float)\n",
    "    for s in range(S):\n",
    "        for a in range(A):\n",
    "            R[s, a] = phi(s, a) @ w\n",
    "    return R\n",
    "\n",
    " \n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e9f58f67-ceb8-4135-99e0-129a0baab5e9",
   "metadata": {},
   "source": [
    "# linear maxent IRL"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "9b1b0d22-968d-4f2f-922c-b8b25ed4fde6",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import torch as th\n",
    "\n",
    "# =========================\n",
    "# Soft VI (temperature = 1)\n",
    "# =========================\n",
    " \n",
    "import numpy as np\n",
    "import torch as th\n",
    "\n",
    "class DiffSoftVI:\n",
    "    \"\"\"\n",
    "    Differentiable soft value iteration (temperature 1) in PyTorch.\n",
    "\n",
    "    The (S, A, S) transition tensor is stored once as a flat (S*A, S) matrix.\n",
    "    Calling the object with a reward tensor r of shape (S, A) unrolls\n",
    "    V <- logsumexp_a [ r + gamma * E[V(S') | s, a] ] under autograd, so the\n",
    "    returned Q, V and pi are differentiable with respect to r.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, P_np, gamma=0.99, iters=50_000, tol=1e-12,\n",
    "                 dtype=th.float64, device=\"cpu\", eps=1e-12, renorm_P=True):\n",
    "        \"\"\"\n",
    "        Args:\n",
    "            P_np: numpy transition array of shape (S, A, S).\n",
    "            gamma: discount factor.\n",
    "            iters: maximum number of sweeps per call.\n",
    "            tol: sup-norm threshold on the change between successive V iterates.\n",
    "            dtype, device: torch dtype / device for the internal tensors.\n",
    "            eps: floor for row sums / normalizers and constant added inside logs.\n",
    "            renorm_P: if True, divide each (s, a) row of P by its eps-floored sum.\n",
    "        \"\"\"\n",
    "        trans = th.from_numpy(P_np).to(device=device, dtype=dtype).contiguous()\n",
    "        S, A, S_next = trans.shape\n",
    "        assert S_next == S\n",
    "        if renorm_P:\n",
    "            trans = trans / trans.sum(dim=2, keepdim=True).clamp_min(eps)\n",
    "        self.P_flat = trans.view(S * A, S)\n",
    "        self.S, self.A = S, A\n",
    "        self.gamma, self.iters, self.tol = gamma, iters, tol\n",
    "        self.device, self.dtype, self.eps = device, dtype, eps\n",
    "\n",
    "    def __call__(self, r):\n",
    "        \"\"\"\n",
    "        Run soft VI for the reward tensor r of shape (S, A).\n",
    "\n",
    "        Returns:\n",
    "            Q: (S, A) soft Q-values r + gamma * E[V(S') | s, a].\n",
    "            V: (S,) log of the eps-floored softmax normalizer of Q.\n",
    "            pi: (S, A) softmax policy over actions.\n",
    "            vi_resid: last sup-norm change between V iterates (python float).\n",
    "            vi_used: number of sweeps performed.\n",
    "\n",
    "        Raises:\n",
    "            RuntimeError: if the residual never drops below `tol` within `iters` sweeps.\n",
    "        \"\"\"\n",
    "        S, A = self.S, self.A\n",
    "        gamma, eps = self.gamma, self.eps\n",
    "        assert r.shape == (S, A)\n",
    "\n",
    "        def backup(U):\n",
    "            # Max-shifted log-sum-exp over actions -> shape (S,).\n",
    "            top = U.max(dim=1, keepdim=True).values\n",
    "            return (top + th.log(th.exp(U - top).sum(dim=1, keepdim=True) + eps)).squeeze(1)\n",
    "\n",
    "        V = backup(r)  # warm start: one-step soft value\n",
    "        vi_resid, vi_used, converged = float(\"inf\"), 0, False\n",
    "        for sweep in range(self.iters):\n",
    "            V_next = backup(r + gamma * (self.P_flat @ V).view(S, A))\n",
    "            vi_resid = th.max(th.abs(V_next - V)).item()\n",
    "            vi_used = sweep + 1\n",
    "            V = V_next\n",
    "            if vi_resid < self.tol:\n",
    "                converged = True\n",
    "                break\n",
    "        if not converged:\n",
    "            raise RuntimeError(\"DiffSoftVI did not converge.\")\n",
    "\n",
    "        # final Q, V, pi from the converged V\n",
    "        Q = r + gamma * (self.P_flat @ V).view(S, A)\n",
    "        top = Q.max(dim=1, keepdim=True).values\n",
    "        E = th.exp(Q - top)\n",
    "        Z = E.sum(dim=1, keepdim=True).clamp_min(eps)\n",
    "        pi = E / Z\n",
    "        V = (top + th.log(Z)).squeeze(1)\n",
    "        return Q, V, pi, vi_resid, vi_used\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "def fit_maxent_irl(\n",
    "    P, S_iid, A_iid, feature_fn, nrow, ncol, A,\n",
    "    gamma=0.97, lr=0.03, steps=600, iters=1000,\n",
    "    dtype=th.float64, device=\"cpu\",\n",
    "    weight_decay=0.0,          # keep off for calibrated logits\n",
    "    check_every=25,\n",
    "    tol_vi=1e-10,              # VI residual threshold\n",
    "    tol_mu_L2=5e-5,            # moment-matching threshold (L2)\n",
    "    plateau_patience=6,        # early stop if no LL improvement\n",
    "):\n",
    "    \"\"\"\n",
    "    Linear MaxEnt IRL by maximum likelihood through differentiable soft VI.\n",
    "\n",
    "    The reward is r(s,a) = phi(s,a) . theta; theta is trained with Adam (amsgrad)\n",
    "    to maximise the mean log-probability of the demo pairs (S_iid, A_iid) under\n",
    "    the soft-optimal policy for r computed by DiffSoftVI.\n",
    "\n",
    "    Args:\n",
    "        P: (S, A, S) numpy transition array.\n",
    "        S_iid, A_iid: integer numpy arrays of demo states and actions.\n",
    "        feature_fn: a feature map phi(s, a) -> length-D vector, or a factory\n",
    "            feature_fn(nrow, ncol, A) returning one (e.g. make_phi_hard).\n",
    "        nrow, ncol, A: grid size and number of actions (A also sizes phi).\n",
    "        gamma: discount used by soft VI.\n",
    "        lr, steps: Adam learning rate and maximum number of gradient steps.\n",
    "        iters: max soft-VI sweeps per step (DiffSoftVI raises if exceeded).\n",
    "        weight_decay: L2 penalty passed to Adam.\n",
    "        check_every: period (in steps) of diagnostics and early-stop checks.\n",
    "        tol_vi, tol_mu_L2: stop once the VI residual and the L2 gap between the\n",
    "            empirical and model feature expectations are both below these.\n",
    "        plateau_patience: stop after this many consecutive checks without a\n",
    "            log-likelihood gain above 1e-4.\n",
    "\n",
    "    Returns:\n",
    "        (theta, history, phi_np): fitted (D,) weights, per-step mean demo\n",
    "        log-likelihood (evaluated before each update), and the (S, A, D) features.\n",
    "    \"\"\"\n",
    "    S = P.shape[0]\n",
    "\n",
    "    # Accept either a feature map or a factory. A factory whose extra arguments\n",
    "    # have defaults (e.g. make_phi_hard(nrow, ncol, A=5, ...)) does not raise on a\n",
    "    # two-argument probe -- it returns a callable -- so treat that as a factory too.\n",
    "    try:\n",
    "        is_factory = callable(feature_fn(0, 0))\n",
    "    except TypeError:\n",
    "        is_factory = True\n",
    "    if is_factory:\n",
    "        feature_fn = feature_fn(nrow, ncol, A)\n",
    "\n",
    "    D = np.asarray(feature_fn(0, 0)).shape[0]\n",
    "    phi_np = np.zeros((S, A, D), dtype=np.float64)\n",
    "    for s in range(S):\n",
    "        for a in range(A):\n",
    "            phi_np[s, a] = feature_fn(s, a)\n",
    "\n",
    "    phi = th.from_numpy(phi_np).to(device=device, dtype=dtype)\n",
    "\n",
    "    S_iid_t = th.from_numpy(S_iid).to(device=device, dtype=th.long)\n",
    "    A_iid_t = th.from_numpy(A_iid).to(device=device, dtype=th.long)\n",
    "\n",
    "    # empirical feature mean over the demo pairs\n",
    "    mu_emp = phi_np[S_iid, A_iid].mean(axis=0)  # [D]\n",
    "\n",
    "    # empirical state weights (data distribution over states)\n",
    "    state_counts = np.bincount(S_iid, minlength=S).astype(np.float64)\n",
    "    w_emp = state_counts / max(1, state_counts.sum())      # [S]\n",
    "\n",
    "    vi = DiffSoftVI(P, gamma=gamma, iters=iters, tol=tol_vi, dtype=dtype, device=device)\n",
    "\n",
    "    theta = th.zeros(D, device=device, dtype=dtype, requires_grad=True)\n",
    "    opt   = th.optim.Adam([{\"params\": [theta], \"weight_decay\": weight_decay}], lr=lr, amsgrad=True)\n",
    "    # mode='min' because sched.step receives the loss (-ll) below\n",
    "    sched = th.optim.lr_scheduler.ReduceLROnPlateau(opt, mode='min', factor=0.5, patience=2)\n",
    "\n",
    "    history, best_ll, no_improve = [], -1e9, 0\n",
    "\n",
    "    for t in range(steps):\n",
    "        opt.zero_grad()\n",
    "\n",
    "        # reward model: r(s,a) = φ(s,a)·θ\n",
    "        r = (phi @ theta)  # [S,A]\n",
    "\n",
    "        Q, V, pi, vi_resid, vi_used = vi(r)  # soft VI (differentiable in r)\n",
    "\n",
    "        # mean log-likelihood of demo actions; the loss is its negative\n",
    "        logp = th.log(th.clamp(pi, min=1e-12))\n",
    "        ll = logp[S_iid_t, A_iid_t].mean()\n",
    "        loss = -ll\n",
    "\n",
    "        loss.backward()\n",
    "        th.nn.utils.clip_grad_norm_([theta], max_norm=5.0)\n",
    "        opt.step()\n",
    "        history.append(ll.item())\n",
    "        sched.step(-ll.item())\n",
    "\n",
    "        if (t % check_every == 0) or (t == steps-1):\n",
    "            with th.no_grad():\n",
    "                # model feature expectation under the empirical state distribution\n",
    "                pi_np  = pi.detach().cpu().numpy()\n",
    "                mu_hat = (w_emp[:,None,None] * pi_np[:,:,None] * phi_np).sum(axis=(0,1))  # [D]\n",
    "\n",
    "                mu_gap = mu_emp - mu_hat\n",
    "                mu_L2  = float(np.linalg.norm(mu_gap))\n",
    "                mu_Linf= float(np.max(np.abs(mu_gap)))\n",
    "                med_H  = float(th.median((-(pi * logp).sum(dim=1))).item())\n",
    "\n",
    "                print(f\"step {t:4d}  ll: {ll.item():.6f}  H(pi): {med_H:.3f}  \"\n",
    "                      f\"VIres:{vi_resid:.2e}({vi_used})  ||μ_emp-μ̂||_2:{mu_L2:.4f}  _inf:{mu_Linf:.4f}\")\n",
    "\n",
    "                if ll.item() > best_ll + 1e-4:\n",
    "                    best_ll, no_improve = ll.item(), 0\n",
    "                else:\n",
    "                    no_improve += 1\n",
    "\n",
    "                if (vi_resid < tol_vi) and (mu_L2 < tol_mu_L2):\n",
    "                    print(f\"[EARLY STOP] VI residual < {tol_vi} and moment gap < {tol_mu_L2}\")\n",
    "                    break\n",
    "                if no_improve >= plateau_patience:\n",
    "                    print(f\"[EARLY STOP] No LL improvement for {plateau_patience} checks.\")\n",
    "                    break\n",
    "\n",
    "    return theta.detach().cpu().numpy(), np.array(history), phi_np\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ee53d843-545f-4a63-92ae-8a8259bee391",
   "metadata": {},
   "source": [
    "# our method: classify and solve fixed point"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "2da3334e-d807-48ca-bcfb-ed526ec26760",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from sklearn.linear_model import LogisticRegressionCV\n",
    "from sklearn.model_selection import StratifiedKFold\n",
    "\n",
    "# --- tiny helper so we don't need a gym env just for n_states/n_actions\n",
    "class _Space:\n",
    "    \"\"\"Minimal stand-in for a discrete space: only exposes the count `.n` as an int.\"\"\"\n",
    "\n",
    "    def __init__(self, n):\n",
    "        self.n = int(n)\n",
    "\n",
    "class Demos:\n",
    "    \"\"\"Demonstration container: integer arrays of visited states (`obs`) and taken actions (`acts`).\"\"\"\n",
    "\n",
    "    def __init__(self, obs, acts):\n",
    "        self.obs, self.acts = np.asarray(obs, dtype=int), np.asarray(acts, dtype=int)\n",
    "\n",
    "def train_tabular_logreg_with_logreward(\n",
    "    demos,\n",
    "    observation_space,\n",
    "    action_space,\n",
    "    *,\n",
    "    temperature: float = 1.0,\n",
    "    C_grid=None,\n",
    "    max_iter: int = 5000,\n",
    "    solver: str = \"lbfgs\",\n",
    "    random_state: int = 0,\n",
    "    cv_splits: int = 5,\n",
    "):\n",
    "    \"\"\"\n",
    "    Fit a tabular behaviour-cloning classifier and expose log pi_hat(a|s) as a reward.\n",
    "\n",
    "    States are one-hot encoded; an L2-penalised LogisticRegressionCV is fit on the\n",
    "    demo (state, action) pairs with C chosen by shuffled stratified K-fold CV on\n",
    "    log-loss.\n",
    "\n",
    "    Args:\n",
    "        demos: object with integer arrays `.obs` (states) and `.acts` (actions).\n",
    "        observation_space, action_space: objects exposing the counts as `.n`.\n",
    "        temperature: reward_fn returns log-probabilities divided by this value.\n",
    "        C_grid: candidate inverse regularisation strengths (default 1e-6 .. 1e6).\n",
    "        max_iter, solver, random_state, cv_splits: forwarded to sklearn.\n",
    "\n",
    "    Returns:\n",
    "        reward_fn(s, a): array of log pi_hat(a|s) / temperature.\n",
    "        act_fn(s): array of predicted (most probable) actions.\n",
    "        clf: the fitted LogisticRegressionCV.\n",
    "\n",
    "    Note:\n",
    "        sklearn orders probability columns by `clf.classes_`, which lists only the\n",
    "        actions present in the demos. reward_fn maps action ids through `classes_`;\n",
    "        an action never seen in the demos gets log-probability -inf.\n",
    "    \"\"\"\n",
    "    n_states = observation_space.n\n",
    "    n_actions = action_space.n\n",
    "    inv_tau = 1.0 / float(temperature)\n",
    "\n",
    "    eye_S = np.eye(n_states, dtype=np.float32)\n",
    "    X = eye_S[demos.obs]  # [N, S] one-hot states\n",
    "    y = np.asarray(demos.acts, dtype=int)\n",
    "\n",
    "    if C_grid is None:\n",
    "        C_grid = np.logspace(-6, 6, 13)\n",
    "\n",
    "    clf = LogisticRegressionCV(\n",
    "        Cs=list(C_grid),\n",
    "        cv=StratifiedKFold(n_splits=cv_splits, shuffle=True, random_state=random_state),\n",
    "        penalty=\"l2\",\n",
    "        solver=solver,\n",
    "        # multi_class deprecated in sklearn>=1.5; omit to silence warning\n",
    "        fit_intercept=True,\n",
    "        max_iter=max_iter,\n",
    "        scoring=\"neg_log_loss\",\n",
    "        refit=True,\n",
    "        n_jobs=None,\n",
    "        random_state=random_state,\n",
    "    )\n",
    "    clf.fit(X, y)\n",
    "\n",
    "    # Column j of predict_log_proba belongs to action clf.classes_[j].\n",
    "    action_cols = np.asarray(clf.classes_, dtype=int)\n",
    "\n",
    "    def reward_fn(s, a):\n",
    "        s = np.atleast_1d(s).astype(int)\n",
    "        a = np.atleast_1d(a).astype(int)\n",
    "        # scatter the classifier's columns onto the full action axis\n",
    "        log_probs = np.full((len(s), n_actions), -np.inf)\n",
    "        log_probs[:, action_cols] = clf.predict_log_proba(eye_S[s])\n",
    "        return log_probs[np.arange(len(s)), a] * inv_tau\n",
    "\n",
    "    def act_fn(s):\n",
    "        s = np.atleast_1d(s).astype(int)\n",
    "        return clf.predict(eye_S[s])\n",
    "\n",
    "    return reward_fn, act_fn, clf\n",
    "\n",
    " "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "1c1a0c4e-21e5-485a-9b97-05d0f3cdcda5",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "# ---------- helpers ----------\n",
    "def _row_normalize_probs(pi, eps=1e-12):\n",
    "    \"\"\"\n",
    "    Clip entries to [0, 1] and divide each row by its sum, floored at eps.\n",
    "\n",
    "    An all-zero row stays all-zero. Accepts any array-like and returns a float\n",
    "    ndarray. The evaluation cell redefines this helper with the same semantics,\n",
    "    so callers behave identically whichever definition is currently bound.\n",
    "    \"\"\"\n",
    "    pi = np.clip(np.asarray(pi, float), 0.0, 1.0)\n",
    "    return pi / np.maximum(pi.sum(axis=1, keepdims=True), eps)\n",
    "\n",
    "def policy_metrics(pi_true, pi_hat, eps=1e-12):\n",
    "    \"\"\"\n",
    "    Unweighted policy comparison, averaged uniformly over states.\n",
    "\n",
    "    Both inputs are row-normalised with `_row_normalize_probs` first.\n",
    "    Returns (KL, TV, top1) as floats: mean KL(pi_true || pi_hat) with eps-smoothed\n",
    "    logs, mean total-variation distance, and the fraction of states whose argmax\n",
    "    action agrees.\n",
    "    \"\"\"\n",
    "    p = _row_normalize_probs(np.asarray(pi_true, float), eps)\n",
    "    q = _row_normalize_probs(np.asarray(pi_hat, float), eps)\n",
    "    per_state_kl = np.sum(p * (np.log(p + eps) - np.log(q + eps)), axis=1)\n",
    "    per_state_tv = 0.5 * np.sum(np.abs(p - q), axis=1)\n",
    "    agree = np.argmax(p, axis=1) == np.argmax(q, axis=1)\n",
    "    return float(per_state_kl.mean()), float(per_state_tv.mean()), float(agree.mean())\n",
    "\n",
    "def _policy_softmax_from_rvc(r_star, v_star, gamma, eps=1e-12):\n",
    "    \"\"\"\n",
    "    Boltzmann policy implied by a reward and its expected next-state value.\n",
    "\n",
    "    Forms u(s,a) = r*(s,a) + γ v*(s,a) and returns π(a|s) ∝ exp u(s,a) row-wise;\n",
    "    a row whose normaliser is <= eps is divided by 1 instead.\n",
    "    \"\"\"\n",
    "    logits = r_star + gamma * v_star\n",
    "    logits = logits - logits.max(axis=1, keepdims=True)  # shift for numerical stability\n",
    "    weights = np.exp(logits)\n",
    "    totals = weights.sum(axis=1, keepdims=True)\n",
    "    totals[totals <= eps] = 1.0\n",
    "    return weights / totals\n",
    "\n",
    "# ---------- normalized maxent “gauge” solve ----------\n",
    "def solve_normalized_maxent(P: np.ndarray, pi: np.ndarray, gamma: float, eps: float = 1e-12):\n",
    "    \"\"\"\n",
    "    Recover the normalized-gauge MaxEnt reward that rationalizes a given policy.\n",
    "\n",
    "    With u(s,a) = log pi(a|s) (pi floored at eps), u_bar(s) = sum_a pi(a|s) u(s,a)\n",
    "    and P_pi[s,s'] = sum_a pi(a|s) P[s,a,s'], solve\n",
    "        (I - γ P_pi) c = -u_bar,\n",
    "    then set v*(s,a) = E[c(S') | s,a] and r*(s,a) = u(s,a) + c(s) - γ v*(s,a).\n",
    "    By construction r* + γ v* = u + c, so a softmax over actions gives back pi\n",
    "    (up to the eps floor), and sum_a pi(a|s) r*(s,a) = 0 in every state up to\n",
    "    solver precision.\n",
    "\n",
    "    Args:\n",
    "        P: (S, A, S) transition array-like; clipped to [0, 1] and row-normalized.\n",
    "        pi: (S, A) policy array-like; clipped and row-normalized.\n",
    "        gamma: discount in [0, 1).\n",
    "        eps: floor used for log(pi) and for row normalization.\n",
    "\n",
    "    Returns:\n",
    "        (c_star, r_star, v_star, info) with shapes (S,), (S, A), (S, A) and\n",
    "        info['norm_equation_inf_residual'] = sup-norm residual of the linear system.\n",
    "    \"\"\"\n",
    "    # Accept lists / other array-likes as well as ndarrays.\n",
    "    P = np.asarray(P, float)\n",
    "    pi = np.asarray(pi, float)\n",
    "    S, A, Sp = P.shape\n",
    "    assert Sp == S and pi.shape == (S, A)\n",
    "    assert 0.0 <= gamma < 1.0\n",
    "\n",
    "    # normalize inputs\n",
    "    P = np.clip(P, 0.0, 1.0)\n",
    "    P = P / np.maximum(P.sum(axis=2, keepdims=True), eps)\n",
    "    pi = _row_normalize_probs(pi, eps)\n",
    "\n",
    "    u = np.log(np.clip(pi, eps, 1.0))           # [S,A]\n",
    "    bar_u_mu = (pi * u).sum(axis=1)             # [S]\n",
    "\n",
    "    P_mu = (pi[:, :, None] * P).sum(axis=1)     # [S,S]  state-to-state kernel under pi\n",
    "\n",
    "    A_mat = np.eye(S) - gamma * P_mu\n",
    "    b_vec = -bar_u_mu\n",
    "    try:\n",
    "        c_star = np.linalg.solve(A_mat, b_vec)\n",
    "    except np.linalg.LinAlgError:\n",
    "        # I - γ P_pi is nonsingular for (sub)stochastic P_pi and γ < 1; this only\n",
    "        # guards against numerical degeneracy.\n",
    "        c_star, *_ = np.linalg.lstsq(A_mat, b_vec, rcond=None)\n",
    "\n",
    "    v_star = np.einsum('saq,q->sa', P, c_star)  # E[c*(S') | s,a]\n",
    "    r_star = u + c_star[:, None] - gamma * v_star\n",
    "\n",
    "    resid = np.linalg.norm(c_star - gamma * (P_mu @ c_star) + bar_u_mu, ord=np.inf)\n",
    "    return c_star, r_star, v_star, {\"norm_equation_inf_residual\": resid}\n",
    "\n",
    " \n",
    "\n",
    "\n",
    "  "
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6c877e0d-f8dc-4059-9d1d-b1f0d5779fca",
   "metadata": {},
   "source": [
    "# code for evaluating methods in sims"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "82fef4bf-426c-402a-84bf-3d2ad6e662c4",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "# ================= minimal helpers =================\n",
    "\n",
    "def _mse(a, b):\n",
    "    \"\"\"Mean squared error over positions where both inputs are finite; NaN if none are.\"\"\"\n",
    "    a, b = np.asarray(a), np.asarray(b)\n",
    "    finite = np.isfinite(a) & np.isfinite(b)\n",
    "    if not finite.any():\n",
    "        return np.nan\n",
    "    diff = (a - b)[finite]\n",
    "    return float(np.mean(diff ** 2))\n",
    "\n",
    "def _rmse_norm(y_true, y_pred, denom=\"std\", eps=1e-12):\n",
    "    \"\"\"\n",
    "    RMSE between y_true and y_pred divided by a scale of y_true (non-finite pairs dropped).\n",
    "\n",
    "    denom='std' uses std(y_true) + eps; any other value uses RMS(y_true) + eps.\n",
    "    Returns NaN when no finite pair remains.\n",
    "    \"\"\"\n",
    "    truth = np.asarray(y_true).ravel()\n",
    "    pred = np.asarray(y_pred).ravel()\n",
    "    keep = np.isfinite(truth) & np.isfinite(pred)\n",
    "    if not np.any(keep):\n",
    "        return np.nan\n",
    "    truth = truth[keep]\n",
    "    pred = pred[keep]\n",
    "    err = np.sqrt(np.mean((truth - pred) ** 2))\n",
    "    scale = (np.std(truth) if denom == \"std\" else np.sqrt(np.mean(truth ** 2))) + eps\n",
    "    return float(err / scale)\n",
    "\n",
    "def _safe_corr(a, b):\n",
    "    \"\"\"Pearson correlation over jointly finite entries; NaN if fewer than 2 remain or either side is constant.\"\"\"\n",
    "    a = np.asarray(a).ravel()\n",
    "    b = np.asarray(b).ravel()\n",
    "    ok = np.isfinite(a) & np.isfinite(b)\n",
    "    if ok.sum() < 2:\n",
    "        return np.nan\n",
    "    a = a[ok]\n",
    "    b = b[ok]\n",
    "    if a.var() <= 0 or b.var() <= 0:\n",
    "        return np.nan\n",
    "    return float(np.corrcoef(a, b)[0, 1])\n",
    "\n",
    "def _weighted_mse_vec(x, y, w, eps=1e-12):\n",
    "    \"\"\"Weighted mean of (x - y)^2; weights are clipped at 0 and normalised (NaN if their sum <= eps).\"\"\"\n",
    "    x = np.asarray(x, float).ravel()\n",
    "    y = np.asarray(y, float).ravel()\n",
    "    weights = np.clip(np.asarray(w, float).ravel(), 0, None)\n",
    "    total = weights.sum()\n",
    "    if total <= eps:\n",
    "        return np.nan\n",
    "    weights = weights / total\n",
    "    resid = x - y\n",
    "    return float((weights * resid * resid).sum())\n",
    "\n",
    "def _weighted_rmse_norm(y_true, y_pred, w, denom=\"std\", eps=1e-12):\n",
    "    \"\"\"\n",
    "    Weighted RMSE of y_pred against y_true, divided by a weighted scale of y_true.\n",
    "\n",
    "    Weights are clipped at 0 and normalised (NaN if their sum is <= eps).\n",
    "    denom='std' uses the weighted standard deviation of y_true; otherwise the\n",
    "    weighted RMS. The scale is floored at sqrt(eps) and eps is added before dividing.\n",
    "    \"\"\"\n",
    "    truth = np.asarray(y_true, float).ravel()\n",
    "    pred = np.asarray(y_pred, float).ravel()\n",
    "    wt = np.clip(np.asarray(w, float).ravel(), 0, None)\n",
    "    total = wt.sum()\n",
    "    if total <= eps:\n",
    "        return np.nan\n",
    "    wt = wt / total\n",
    "    rmse = np.sqrt((wt * (truth - pred) ** 2).sum())\n",
    "    if denom == \"std\":\n",
    "        centered = truth - (wt * truth).sum()\n",
    "        spread = (wt * centered ** 2).sum()\n",
    "    else:\n",
    "        spread = (wt * truth ** 2).sum()\n",
    "    scale = np.sqrt(max(spread, eps))\n",
    "    return float(rmse / (scale + eps))\n",
    "\n",
    "def _weighted_corr_vec(x, y, w, eps=1e-12):\n",
    "    \"\"\"\n",
    "    Weighted Pearson correlation; weights clipped at 0 and normalised (NaN if sum <= eps).\n",
    "\n",
    "    Each weighted variance is floored at eps before taking its square root.\n",
    "    \"\"\"\n",
    "    x = np.asarray(x, float).ravel()\n",
    "    y = np.asarray(y, float).ravel()\n",
    "    wt = np.clip(np.asarray(w, float).ravel(), 0, None)\n",
    "    total = wt.sum()\n",
    "    if total <= eps:\n",
    "        return np.nan\n",
    "    wt = wt / total\n",
    "    dx = x - (wt * x).sum()\n",
    "    dy = y - (wt * y).sum()\n",
    "    cov = (wt * dx * dy).sum()\n",
    "    sd_x = np.sqrt(max((wt * dx * dx).sum(), eps))\n",
    "    sd_y = np.sqrt(max((wt * dy * dy).sum(), eps))\n",
    "    return float(cov / (sd_x * sd_y))\n",
    "\n",
    "def _fit_affine_1d(x, y, w=None, eps=1e-12):\n",
    "    \"\"\"\n",
    "    Least-squares affine fit y ≈ alpha * x + beta, optionally weighted.\n",
    "\n",
    "    With w given, weights are clipped at 0 and normalised; if their sum is <= eps\n",
    "    the fit is skipped and (1.0, 0.0, x) is returned.\n",
    "\n",
    "    Returns:\n",
    "        (alpha, beta, y_hat) with y_hat = alpha * x + beta.\n",
    "    \"\"\"\n",
    "    x = np.asarray(x, float).ravel()\n",
    "    y = np.asarray(y, float).ravel()\n",
    "    design = np.column_stack([x, np.ones_like(x)])\n",
    "    if w is None:\n",
    "        coef, *_ = np.linalg.lstsq(design, y, rcond=None)\n",
    "    else:\n",
    "        wt = np.clip(np.asarray(w, float).ravel(), 0, None)\n",
    "        total = wt.sum()\n",
    "        if total <= eps:\n",
    "            return 1.0, 0.0, x  # degenerate weights: identity map\n",
    "        root_w = np.sqrt(wt / total)\n",
    "        coef, *_ = np.linalg.lstsq(design * root_w[:, None], y * root_w, rcond=None)\n",
    "    alpha = float(coef[0])\n",
    "    beta = float(coef[1])\n",
    "    return alpha, beta, alpha * x + beta\n",
    "\n",
    " \n",
    "\n",
    "\n",
    "# Gauge-invariant Q-differences: Q(s,a)-Q(s,a0)\n",
    "def _qdiff_from_r(P, r, gamma=0.99, a0=0):\n",
    "    \"\"\"\n",
    "    Action-gap matrix D(s,a) = Q(s,a) - Q(s,a0) for the soft-optimal Q under reward r.\n",
    "\n",
    "    Q(s,a) = r(s,a) + γ E[V*(S') | s,a] with V* from compute_v_star_from_r.\n",
    "    Subtracting the reference column removes any per-state offset of Q.\n",
    "    \"\"\"\n",
    "    _, v_next = compute_v_star_from_r(P, r, gamma=gamma)\n",
    "    q = r + gamma * v_next\n",
    "    return q - q[:, [a0]]  # (S, 1) reference column broadcast across actions\n",
    "\n",
    "def _expand_state_weights_to_sa(w_states, A, a0):\n",
    "    \"\"\"\n",
    "    Repeat per-state weights across the A-1 non-reference actions.\n",
    "\n",
    "    Returns a flat vector of length S*(A-1) in state-major order, i.e. aligned\n",
    "    with D[mask].ravel() when `mask` drops one action column per state.\n",
    "    `a0` is accepted for call-site symmetry but unused: every kept column of a\n",
    "    state gets the same weight, so which column was dropped does not matter.\n",
    "    \"\"\"\n",
    "    w_states = np.asarray(w_states, float).ravel()\n",
    "    return np.repeat(w_states, A - 1)\n",
    "\n",
    "# ================= main report =================\n",
    "\n",
    "def _row_normalize_probs(pi, eps=1e-12):\n",
    "    \"\"\"\n",
    "    Clip entries to [0, 1] and divide each row by its sum, floored at eps.\n",
    "\n",
    "    An all-zero row stays all-zero. This redefines the helper of the same name from\n",
    "    the maxent-solve cell; the most recent definition is the one callers see.\n",
    "    \"\"\"\n",
    "    clipped = np.clip(np.asarray(pi, float), 0.0, 1.0)\n",
    "    return clipped / np.maximum(clipped.sum(axis=1, keepdims=True), eps)\n",
    "\n",
    "def weighted_policy_metrics(pi_true, pi_hat, w, eps=1e-12):\n",
    "    \"\"\"\n",
    "    Compare two policies state by state and average with state weights w (>= 0).\n",
    "\n",
    "    Both policies are row-normalised first; w is clipped at 0 and normalised.\n",
    "    Returns (KL, TV, top1) as floats, or (nan, nan, nan) if sum(w) <= eps:\n",
    "      KL   = sum_s w[s] * sum_a pi_true log(pi_true / pi_hat)   (eps-smoothed logs)\n",
    "      TV   = sum_s w[s] * 0.5 * ||pi_true(s,·) - pi_hat(s,·)||_1\n",
    "      top1 = sum_s w[s] * 1{argmax pi_true(s,·) == argmax pi_hat(s,·)}\n",
    "    \"\"\"\n",
    "    p = _row_normalize_probs(pi_true, eps)\n",
    "    q = _row_normalize_probs(pi_hat, eps)\n",
    "\n",
    "    weights = np.clip(np.asarray(w, float).ravel(), 0.0, None)\n",
    "    total = weights.sum()\n",
    "    if total <= eps:\n",
    "        return np.nan, np.nan, np.nan\n",
    "    weights = weights / total\n",
    "\n",
    "    kl_per_state = (p * (np.log(p + eps) - np.log(q + eps))).sum(axis=1)\n",
    "    tv_per_state = 0.5 * np.abs(p - q).sum(axis=1)\n",
    "    hit_per_state = (np.argmax(p, axis=1) == np.argmax(q, axis=1)).astype(float)\n",
    "\n",
    "    KL = float((weights * kl_per_state).sum())\n",
    "    TV = float((weights * tv_per_state).sum())\n",
    "    top1 = float((weights * hit_per_state).sum())\n",
    "    return KL, TV, top1\n",
    "\n",
    "\n",
    "def run_comparison_weighted(\n",
    "    P, pi_true, pi_hat,\n",
    "    R, r_hat, r_star=None, r_star_oracle=None,\n",
    "    gamma=0.99, w=None, a0=0, use_weighted_q=False\n",
    "):\n",
    "    \"\"\"\n",
    "    Gauge-invariant Q-diff check:\n",
    "      For each reward r, compute D_r(s,a) = Q_r(s,a) - Q_r(s,a0),\n",
    "      then report NRMSE (RMSE/std of true Q-diffs) and Corr on the SAME arrays.\n",
    "    Adds an affine-calibracted block that fits alpha,beta on Q-diffs (same arrays).\n",
    "    If use_weighted_q=True, weight by state weights w expanded across actions != a0.\n",
    "    \"\"\"\n",
    "    S, A, Sp = P.shape\n",
    "    assert Sp == S\n",
    "    if w is None:\n",
    "        w = np.full(S, 1.0 / S)\n",
    "\n",
    "    results = {\"qdiff\": {}, \"policy\": {}}\n",
    "\n",
    "    # ----- Q-diff metrics -----\n",
    "    print(\"\\n# (B) Gauge-invariant Q-diff metrics (Q(s,a)-Q(s,a0))\")\n",
    "    D_true = _qdiff_from_r(P, R, gamma=gamma, a0=a0)\n",
    "    mask = np.ones_like(D_true, dtype=bool); mask[:, a0] = False  # exclude ref col\n",
    "\n",
    "    def _report_qdiff(tag, r_cmp):\n",
    "        out = {}\n",
    "        D_cmp = _qdiff_from_r(P, r_cmp, gamma=gamma, a0=a0)\n",
    "        x = D_cmp[mask].ravel(); y = D_true[mask].ravel()\n",
    "\n",
    "        if not use_weighted_q:\n",
    "            nrmse = _rmse_norm(y, x, denom=\"std\")\n",
    "            corr  = _safe_corr(x, y)\n",
    "            print(f\"  {tag:14s}  Qdiff-NRMSE={nrmse:.6f}  Qdiff-Corr={corr:.6f}\")\n",
    "            out.update({\"Qdiff-NRMSE\": nrmse, \"Qdiff-Corr\": corr})\n",
    "        else:\n",
    "            w_vec = _expand_state_weights_to_sa(w, A, a0)\n",
    "            nrmse_w = _weighted_rmse_norm(y, x, w_vec, denom=\"std\")\n",
    "            corr_w  = _weighted_corr_vec(x, y, w_vec)\n",
    "            print(f\"  {tag:14s}  Qdiff-NRMSE(μ)={nrmse_w:.6f}  Qdiff-Corr(μ)={corr_w:.6f}\")\n",
    "            out.update({\"Qdiff-NRMSE(μ)\": nrmse_w, \"Qdiff-Corr(μ)\": corr_w})\n",
    "\n",
    "        # ---- affine calibration on the SAME arrays ----\n",
    "        print(\" \" * 17 + \"→ After affine calibration on Q-diffs:\")\n",
    "        if not use_weighted_q:\n",
    "            alpha, beta, yhat = _fit_affine_1d(x, y, w=None)\n",
    "            nrmse_aff = _rmse_norm(y, yhat, denom=\"std\")\n",
    "            corr_aff  = _safe_corr(yhat, y)\n",
    "            print(f\"  {tag:14s}  Qdiff-NRMSE[aff]={nrmse_aff:.6f}  Qdiff-Corr[aff]={corr_aff:.6f}  (alpha={alpha:.6f}, beta={beta:.6f})\")\n",
    "            out.update({\"Qdiff-NRMSE[aff]\": nrmse_aff, \"Qdiff-Corr[aff]\": corr_aff,\n",
    "                        \"alpha\": alpha, \"beta\": beta})\n",
    "        else:\n",
    "            w_vec = _expand_state_weights_to_sa(w, A, a0)\n",
    "            alpha, beta, yhat = _fit_affine_1d(x, y, w=w_vec)\n",
    "            nrmse_aff = _weighted_rmse_norm(y, yhat, w_vec, denom=\"std\")\n",
    "            corr_aff  = _weighted_corr_vec(yhat, y, w_vec)\n",
    "            print(f\"  {tag:14s}  Qdiff-NRMSE(μ)[aff]={nrmse_aff:.6f}  Qdiff-Corr(μ)[aff]={corr_aff:.6f}  (alpha={alpha:.6f}, beta={beta:.6f})\")\n",
    "            out.update({\"Qdiff-NRMSE(μ)[aff]\": nrmse_aff, \"Qdiff-Corr(μ)[aff]\": corr_aff,\n",
    "                        \"alpha\": alpha, \"beta\": beta})\n",
    "\n",
    "        results[\"qdiff\"][tag] = out\n",
    "\n",
    "    if r_hat is not None:        _report_qdiff(\"MaxEnt\",        r_hat)\n",
    "    if r_star is not None:       _report_qdiff(\"r_star\",        r_star)\n",
    "    if r_star_oracle is not None:_report_qdiff(\"r_star_oracle\", r_star_oracle)\n",
    "\n",
    "    # ---------- (C) Policy comparisons (weighted vs π_true) ----------\n",
    "    print(\"\\n# (C) Policy comparisons (weighted vs π_true)\")\n",
    "    pi_from_Rtrue = policy_from_reward(P, R, gamma=gamma)\n",
    "    KLw0, TVw0, topw0 = weighted_policy_metrics(pi_true, pi_from_Rtrue, w)\n",
    "    print(f\"  policy(R_true):       KL={KLw0:.6f}  TV={TVw0:.6f}  top1={topw0:.3f}\")\n",
    "    results[\"policy\"][\"R_true\"] = {\"KL\": KLw0, \"TV\": TVw0, \"top1\": topw0}\n",
    "\n",
    "    KLb, TVb, topb = weighted_policy_metrics(pi_true, pi_hat, w)\n",
    "    print(f\"  baseline pi_hat:      KL={KLb:.6f}  TV={TVb:.6f}  top1={topb:.3f}\")\n",
    "    results[\"policy\"][\"pi_hat\"] = {\"KL\": KLb, \"TV\": TVb, \"top1\": topb}\n",
    "\n",
    "    if r_hat is not None:\n",
    "        pi_from_r_hat = policy_from_reward(P, r_hat, gamma=gamma)\n",
    "        KLh, TVh, toph = weighted_policy_metrics(pi_true, pi_from_r_hat, w)\n",
    "        print(f\"  policy(r_hat):        KL={KLh:.6f}  TV={TVh:.6f}  top1={toph:.3f}\")\n",
    "        results[\"policy\"][\"r_hat\"] = {\"KL\": KLh, \"TV\": TVh, \"top1\": toph}\n",
    "\n",
    "    if r_star is not None:\n",
    "        pi_rstar = policy_from_reward(P, r_star, gamma=gamma)\n",
    "        KLs, TVs, tops = weighted_policy_metrics(pi_true, pi_rstar, w)\n",
    "        print(f\"  policy(r_star):       KL={KLs:.6f}  TV={TVs:.6f}  top1={tops:.3f}\")\n",
    "        results[\"policy\"][\"r_star\"] = {\"KL\": KLs, \"TV\": TVs, \"top1\": tops}\n",
    "\n",
    "    if r_star_oracle is not None:\n",
    "        pi_roracle = policy_from_reward(P, r_star_oracle, gamma=gamma)\n",
    "        KLo, TVo, topo = weighted_policy_metrics(pi_true, pi_roracle, w)\n",
    "        print(f\"  policy(r_star_orc):   KL={KLo:.6f}  TV={TVo:.6f}  top1={topo:.3f}\")\n",
    "        results[\"policy\"][\"r_star_oracle\"] = {\"KL\": KLo, \"TV\": TVo, \"top1\": topo}\n",
    "\n",
    "    return results\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "id": "61fa6138-8a41-4e2c-a8df-846ec9df3ba6",
   "metadata": {},
   "outputs": [],
   "source": [
    "# =========================\n",
    "# NN policy head using the SAME ψ(s,a)\n",
    "# =========================\n",
    "import numpy as np\n",
    "import torch as th\n",
    "import torch.nn as nn\n",
    "import math\n",
    "class _PolicyMLP(nn.Module):\n",
    "    # Scalar logit per (s,a) feature; softmax across actions per state.\n",
    "    def __init__(self, D, hidden=(128,64)):\n",
    "        super().__init__()\n",
    "        layers = []\n",
    "        in_d = D\n",
    "        for h in hidden:\n",
    "            layers += [nn.Linear(in_d, h), nn.ReLU()]\n",
    "            in_d = h\n",
    "        layers += [nn.Linear(in_d, 1)]\n",
    "        self.net = nn.Sequential(*layers)\n",
    "    def forward(self, Phi):  # Phi: [S,A,D] -> logits [S,A]\n",
    "        S, A, D = Phi.shape\n",
    "        z = self.net(Phi.view(S*A, D)).view(S, A)\n",
    "        return z\n",
    "\n",
    "def train_policy_mlp_with_phi(\n",
    "    Phi, S_iid, A_iid, *,\n",
    "    epochs=20, lr=5e-3, batch_size=4096,\n",
    "    device=\"cpu\", dtype=th.float64, verbose=True, seed=0,\n",
    "    hidden=(128,64), weight_decay=0.0\n",
    "):\n",
    "    th.manual_seed(seed); np.random.seed(seed)\n",
    "    S, A, D = Phi.shape\n",
    "    X = th.from_numpy(Phi).to(device=device, dtype=dtype)  # [S,A,D]\n",
    "\n",
    "    model = _PolicyMLP(D, hidden=hidden).to(device=device, dtype=dtype)\n",
    "    opt = th.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)\n",
    "\n",
    "    idx_S = th.from_numpy(np.asarray(S_iid)).to(device)\n",
    "    idx_A = th.from_numpy(np.asarray(A_iid)).to(device)\n",
    "\n",
    "    N = idx_S.numel()\n",
    "    steps_per_epoch = math.ceil(N / batch_size)\n",
    "\n",
    "    for ep in range(epochs):\n",
    "        perm = th.randperm(N, device=device)\n",
    "        total = 0.0\n",
    "        for k in range(steps_per_epoch):\n",
    "            sl = perm[k*batch_size : (k+1)*batch_size]\n",
    "            s_b = idx_S[sl]; a_b = idx_A[sl]\n",
    "\n",
    "            logits = model(X)  # [S,A]\n",
    "            m = logits.max(dim=1, keepdim=True).values\n",
    "            pi = th.exp(logits - m) / th.clamp(th.exp(logits - m).sum(dim=1, keepdim=True), min=1e-12)\n",
    "\n",
    "            logp = th.log(pi[s_b, a_b] + 1e-12)\n",
    "            loss = -logp.mean()\n",
    "\n",
    "            opt.zero_grad()\n",
    "            loss.backward()\n",
    "            opt.step()\n",
    "            total += loss.item()\n",
    "        if verbose:\n",
    "            print(f\"[policy-mlp] epoch {ep+1:02d}/{epochs}  nll={total/steps_per_epoch:.4f}\")\n",
    "\n",
    "    with th.no_grad():\n",
    "        logits = model(X)\n",
    "        m = logits.max(dim=1, keepdim=True).values\n",
    "        pi = th.exp(logits - m) / th.clamp(th.exp(logits - m).sum(dim=1, keepdim=True), min=1e-12)\n",
    "    return pi.cpu().numpy()\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d7da29e9-841a-43e5-aa0b-fc3f33bf8e5d",
   "metadata": {
    "jp-MarkdownHeadingCollapsed": true
   },
   "source": [
    "# Easy sims"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "b4c6e8db-47d5-4464-9ffb-3dd8c77443a7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# ============================================================\n",
    "# Easy gridworld MDP (continuing; discounted)\n",
    "# ============================================================\n",
    "\n",
    "def build_torus_gridworld(nrow=4, ncol=4, slip=0.05, seed=0):\n",
    "    \"\"\"\n",
    "    Small torus grid:\n",
    "      - States are (y,x) on an nrow x ncol grid, wrapped at edges.\n",
    "      - Actions: 0=Up, 1=Right, 2=Down, 3=Left.\n",
    "      - With prob 1-2*slip: go intended; with prob slip each: turn left/right.\n",
    "    Returns P [S,A,S] and a uniform d0 over all states.\n",
    "    \"\"\"\n",
    "    rng = np.random.default_rng(seed)\n",
    "    S = nrow * ncol; A = 4\n",
    "    def enc(y, x): return (y % nrow) * ncol + (x % ncol)\n",
    "\n",
    "    # action -> (dy, dx)\n",
    "    moves = {0:(-1,0), 1:(0,1), 2:(1,0), 3:(0,-1)}\n",
    "    left  = {0:3, 1:0, 2:1, 3:2}\n",
    "    right = {0:1, 1:2, 2:3, 3:0}\n",
    "\n",
    "    P = np.zeros((S, A, S), dtype=np.float64)\n",
    "    for y in range(nrow):\n",
    "        for x in range(ncol):\n",
    "            s = enc(y, x)\n",
    "            for a in range(A):\n",
    "                probs = [(a, 1 - 2*slip), (left[a], slip), (right[a], slip)]\n",
    "                for a2, pa in probs:\n",
    "                    dy, dx = moves[a2]\n",
    "                    sp = enc(y + dy, x + dx)\n",
    "                    P[s, a, sp] += pa\n",
    "\n",
    "    # start/reset: uniform over all states\n",
    "    d0 = np.full(S, 1.0 / S)\n",
    "    return P, d0, (nrow, ncol)\n",
    "\n",
    "# ============================================================\n",
    "# Simple feature maps + true reward\n",
    "# ============================================================\n",
    "\n",
    "def make_phi_state(nrow, ncol):\n",
    "    def phi_s(s, a):\n",
    "        y, x = divmod(s, ncol)\n",
    "        gx = x / max(1, ncol-1)\n",
    "        gy = y / max(1, nrow-1)\n",
    "        return np.array([1.0, gx, gy, gx*gy], float)  # 4-dim (bias + coords)\n",
    "    return phi_s\n",
    "\n",
    "def make_phi_state_action(nrow, ncol, A=4):\n",
    "    phi_s = make_phi_state(nrow, ncol)\n",
    "    I = np.eye(A)\n",
    "    def phi_sa(s, a):\n",
    "        return np.concatenate([phi_s(s, a), I[a]], axis=0)  # 4 + A dims\n",
    "    return phi_sa\n",
    "\n",
    "def build_true_reward(P, shape, kind=\"sa\", w_scale=1.0, action_bias=0.2):\n",
    "    \"\"\"\n",
    "    kind:\n",
    "      - \"state\": r(s,a) = w·phi_s(s)  (same across actions; *easiest* to recover up to potential)\n",
    "      - \"sa\":     r(s,a) = w·phi_sa(s,a) (identifiable under chosen parameterization)\n",
    "    \"\"\"\n",
    "    S, A, _ = P.shape\n",
    "    nrow, ncol = shape\n",
    "    if kind == \"state\":\n",
    "        phi_s = make_phi_state(nrow, ncol)\n",
    "        d = len(phi_s(0))\n",
    "        w = np.array([0.0, 1.0, 0.7, 0.3], float) * w_scale\n",
    "        r_s = np.stack([phi_s(s) @ w for s in range(S)], axis=0)  # [S]\n",
    "        R = np.repeat(r_s[:, None], A, axis=1)                    # [S,A]\n",
    "    elif kind == \"sa\":\n",
    "        phi_sa = make_phi_state_action(nrow, ncol, A=A)\n",
    "        d = len(phi_sa(0, 0))\n",
    "        w = np.zeros(d, float)\n",
    "        # state part\n",
    "        w[:4] = np.array([0.0, 1.0, 0.7, 0.3]) * w_scale\n",
    "        # tiny action preference on \"Right\" (a=1)\n",
    "        w[4+1] = action_bias\n",
    "        R = np.zeros((S, A), float)\n",
    "        for s in range(S):\n",
    "            for a in range(A):\n",
    "                R[s, a] = phi_sa(s, a) @ w\n",
    "    else:\n",
    "        raise ValueError(\"kind must be 'state' or 'sa'\")\n",
    "    return R\n",
    "\n",
    "# ============================================================\n",
    "# Discounted i.i.d. sampler (geometric resets)\n",
    "# ============================================================\n",
    "\n",
    "def sample_discounted_iid(P, pi, gamma=0.97, n_steps=100_000, seed=1, start_dist=None):\n",
    "    \"\"\"\n",
    "    Simulates a continuing chain with geometric resets:\n",
    "      with prob (1-gamma) we reset s ~ start_dist; else s_{t+1} ~ P(s_t, a_t, ·), a_t ~ pi(s_t).\n",
    "    Returns arrays of states and actions of length n_steps.\n",
    "    \"\"\"\n",
    "    rng = np.random.default_rng(seed)\n",
    "    S, A, _ = P.shape\n",
    "    if start_dist is None:\n",
    "        start_dist = np.full(S, 1.0 / S)\n",
    "    cum_start = np.cumsum(start_dist); cum_start[-1] = 1.0\n",
    "    cum_P  = np.cumsum(P, axis=2);    cum_P[:, :, -1] = 1.0\n",
    "    cum_pi = np.cumsum(pi, axis=1);   cum_pi[:,  -1] = 1.0\n",
    "\n",
    "    S_out, A_out = np.empty(n_steps, int), np.empty(n_steps, int)\n",
    "    # initial state\n",
    "    s = int(np.searchsorted(cum_start, rng.random()))\n",
    "    for t in range(n_steps):\n",
    "        if rng.random() < (1 - gamma):\n",
    "            s = int(np.searchsorted(cum_start, rng.random()))\n",
    "        a  = int(np.searchsorted(cum_pi[s], rng.random()))\n",
    "        sp = int(np.searchsorted(cum_P[s, a], rng.random()))\n",
    "        S_out[t], A_out[t] = s, a\n",
    "        s = sp\n",
    "    return S_out, A_out\n",
    "\n",
    "# ============================================================\n",
    "# One-call convenience builder\n",
    "# ============================================================\n",
    "\n",
    "def make_easy_maxent_irl_sim(\n",
    "    nrow=4, ncol=4, slip=0.05,\n",
    "    gamma=0.97,\n",
    "    reward_kind=\"sa\",     # \"state\" or \"sa\"\n",
    "    steps=200_000,\n",
    "    seed=0\n",
    "):\n",
    "    \"\"\"\n",
    "    Builds a tiny MDP and softmax(temperature=1) expert, then generates i.i.d. discounted samples.\n",
    "    Use (P, R_true, d0, pi_true, S_iid, A_iid) for MaxEnt IRL.\n",
    "    \"\"\"\n",
    "    P, d0, shape = build_torus_gridworld(nrow=nrow, ncol=ncol, slip=slip, seed=seed)\n",
    "    R_true = build_true_reward(P, shape, kind=reward_kind)\n",
    "    pi_true = policy_from_reward(P, R_true, gamma=gamma)\n",
    "    S_iid, A_iid = sample_discounted_iid(P, pi_true, gamma=gamma, n_steps=steps, seed=seed+1, start_dist=d0)\n",
    "    return P, R_true, d0, pi_true, S_iid, A_iid"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b49eac73-6e00-4afc-a59b-d9257b49c995",
   "metadata": {},
   "source": [
    "# Identified sims"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "id": "3e7c5806-f401-4620-90cb-20e8b81c4700",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "# ============================================================\n",
    "# Identifiable MaxEnt IRL Simulator\n",
    "#   -> true reward lies in the MaxEnt features\n",
    "#   -> standardized policy_from_reward + compute_v_star_from_r\n",
    "# ============================================================\n",
    "\n",
    "def make_identifiable_maxent_irl_sim(\n",
    "    feature_builder,\n",
    "    nrow=8, ncol=8, A=5,\n",
    "    gamma=0.97,\n",
    "    slip=0.08,\n",
    "    steps=250_000,\n",
    "    seed=123,\n",
    "):\n",
    "    rng = np.random.default_rng(seed)\n",
    "    S = nrow * ncol\n",
    "\n",
    "    # -------- world: walls + portal + slip --------\n",
    "    def enc(y, x): return (y % nrow) * ncol + (x % ncol)\n",
    "    moves4 = {0:(-1,0), 1:(0,1), 2:(1,0), 3:(0,-1)}\n",
    "    left, right = {0:3,1:0,2:1,3:2}, {0:1,1:2,2:3,3:0}\n",
    "\n",
    "    wall = np.zeros((nrow,ncol), bool)\n",
    "    for y,x in [(3,3),(3,4),(4,4),(5,4)]: wall[y,x] = True\n",
    "    portal_pairs = [((1,1),(6,6))]\n",
    "\n",
    "    P = np.zeros((S,A,S), float)\n",
    "    for y in range(nrow):\n",
    "        for x in range(ncol):\n",
    "            s = enc(y,x)\n",
    "            for a in range(A):\n",
    "                if a == 4:  # No-op\n",
    "                    probs = [(4, 1.0)]\n",
    "                else:\n",
    "                    probs = [(a, 1-2*slip), (left[a], slip), (right[a], slip)]\n",
    "                for a2, pa in probs:\n",
    "                    if a2 == 4:\n",
    "                        sp_y, sp_x = y, x\n",
    "                    else:\n",
    "                        dy,dx = moves4[a2]\n",
    "                        ny,nx = y+dy, x+dx\n",
    "                        if 0 <= ny < nrow and 0 <= nx < ncol and not wall[ny,nx]:\n",
    "                            sp_y, sp_x = ny, nx\n",
    "                        else:\n",
    "                            sp_y, sp_x = y, x\n",
    "                    sp = enc(sp_y, sp_x)\n",
    "                    P[s,a,sp] += pa\n",
    "\n",
    "                # portal teleport\n",
    "                for (y1,x1),(y2,x2) in portal_pairs:\n",
    "                    p1, p2 = enc(y1,x1), enc(y2,x2)\n",
    "                    if P[s,a,p1] > 0:\n",
    "                        m = P[s,a,p1]; P[s,a,p1]=0; P[s,a,p2]+=m\n",
    "                    if P[s,a,p2] > 0:\n",
    "                        m = P[s,a,p2]; P[s,a,p2]=0; P[s,a,p1]+=m\n",
    "\n",
    "            # normalize row\n",
    "            z = P[s].sum(axis=1, keepdims=True); z[z<=0]=1.0\n",
    "            P[s] /= z\n",
    "\n",
    "    d0 = np.full(S, 1.0/S)\n",
    "    masks = {\"wall\": wall, \"portal_pairs\": portal_pairs}\n",
    "\n",
    "    # -------- build features (same as MaxEnt training) --------\n",
    "    phi_fn = feature_builder(P, nrow, ncol, A=A, masks=masks)\n",
    "    D_probe = phi_fn(0,0).shape[0]\n",
    "    phi = np.zeros((S, A, D_probe), np.float64)\n",
    "    for s in range(S):\n",
    "        for a in range(A):\n",
    "            phi[s,a] = phi_fn(s,a)\n",
    "\n",
    "    # -------- θ_true and reward --------\n",
    "    rng = np.random.default_rng(seed+7)\n",
    "    theta_true = rng.normal(0, 0.5, size=D_probe).astype(np.float64)\n",
    "    if D_probe >= A:  # optional action bias\n",
    "        theta_true[-A:] += np.array([0.0, 0.25, 0.0, 0.0, -0.05])[:A]\n",
    "\n",
    "    R_true = np.einsum(\"sad,d->sa\", phi, theta_true)\n",
    "    R_true -= R_true.mean(axis=1, keepdims=True)  # gauge: per-state mean zero\n",
    "\n",
    "    # -------- expert policy via standardized solver --------\n",
    "    pi_true = policy_from_reward(P, R_true, gamma=gamma)\n",
    "\n",
    "    # -------- discounted i.i.d. demos --------\n",
    "    def sample_discounted_iid(P, pi, gamma, n_steps, seed, start_dist):\n",
    "        rng = np.random.default_rng(seed)\n",
    "        S,A,_ = P.shape\n",
    "        cum_P  = np.cumsum(P, axis=2);  cum_P[:,:,-1] = 1.0\n",
    "        cum_pi = np.cumsum(pi, axis=1); cum_pi[:,-1]  = 1.0\n",
    "        cum_d0 = np.cumsum(start_dist.copy()); cum_d0[-1] = 1.0\n",
    "        S_out, A_out = np.empty(n_steps, int), np.empty(n_steps, int)\n",
    "        s = int(np.searchsorted(cum_d0, rng.random()))\n",
    "        for t in range(n_steps):\n",
    "            if rng.random() < (1-gamma):\n",
    "                s = int(np.searchsorted(cum_d0, rng.random()))\n",
    "            a  = int(np.searchsorted(cum_pi[s], rng.random()))\n",
    "            sp = int(np.searchsorted(cum_P[s,a], rng.random()))\n",
    "            S_out[t], A_out[t] = s, a\n",
    "            s = sp\n",
    "        return S_out, A_out\n",
    "\n",
    "    S_iid, A_iid = sample_discounted_iid(P, pi_true, gamma, steps, seed+1, d0)\n",
    "\n",
    "    return P, R_true, d0, pi_true, S_iid, A_iid, masks, phi, theta_true\n",
    "\n",
    "\n",
    "# 0) Setup\n",
    "GAMMA = 0.97\n",
    "nrow, ncol, A = 8, 8, 5\n",
    "\n",
    "def make_phi_tabular_linear_S_plus_A(nrow, ncol, A=5, *, include_bias=True, dtype=np.float64):\n",
    "    \"\"\"φ(s,a) = [ (optional bias) | 1_{state=s} | 1_{action=a} ], no interactions.\"\"\"\n",
    "    S = nrow * ncol\n",
    "    D_bias = 1 if include_bias else 0\n",
    "    D = D_bias + S + A\n",
    "    I_S = np.eye(S, dtype=dtype)\n",
    "    I_A = np.eye(A, dtype=dtype)\n",
    "\n",
    "    def phi_sa(s, a):\n",
    "        x = np.zeros(D, dtype=dtype)\n",
    "        off = 0\n",
    "        if include_bias:\n",
    "            x[0] = 1.0\n",
    "            off = 1\n",
    "        x[off + s] = 1.0           # state one-hot\n",
    "        x[off + S + a] = 1.0       # action one-hot\n",
    "        return x\n",
    "\n",
    "    phi_sa.dim = D\n",
    "    return phi_sa\n",
    "\n",
    "    \n",
    "\n",
    "# Choose the SAME feature builder for sim + training\n",
    "def feature_builder(P, nrow, ncol, A, masks):\n",
    "    return make_phi_hard_plus_sf(\n",
    "        P, nrow, ncol, A=A, masks=masks,\n",
    "        include_indicators=True,\n",
    "        include_deltas=True,\n",
    "        include_successor=True,\n",
    "        include_axay=True,\n",
    "        include_rbf_action=True,\n",
    "        n_rbf=20, n_rbf_action=16, rbf_scale=0.9\n",
    "    )\n",
    "\n",
    "P, R_true, d0, pi_true, S_iid, A_iid, masks, phi, theta_true = make_identifiable_maxent_irl_sim(\n",
    "    feature_builder, nrow=nrow, ncol=ncol, A=A, gamma=GAMMA, slip=0.08, steps=250_000, seed=123\n",
    ")\n",
    "\n",
    "# Build phi_fn(s,a) once:\n",
    "# TRUE features: linear in S and A, no interactions\n",
    "def feature_builder_true(P, nrow, ncol, A, masks):\n",
    "    return make_phi_tabular_linear_S_plus_A(nrow, ncol, A=A, include_bias=True)\n",
    " \n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "68c08b52-0104-462a-bf5c-0b3813c42a99",
   "metadata": {},
   "source": [
    "# identified"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "id": "ab895c33-350e-4245-82f2-92d2908d342a",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "step    0  ll: -1.609438  H(pi): 1.609  VIres:9.87e-11(772)  ||μ_emp-μ̂||_2:0.1328  _inf:0.0885\n",
      "step   25  ll: -1.564740  H(pi): 1.549  VIres:9.85e-11(776)  ||μ_emp-μ̂||_2:0.0182  _inf:0.0156\n",
      "step   50  ll: -1.564421  H(pi): 1.550  VIres:9.83e-11(776)  ||μ_emp-μ̂||_2:0.0171  _inf:0.0146\n",
      "step   75  ll: -1.564418  H(pi): 1.550  VIres:9.83e-11(776)  ||μ_emp-μ̂||_2:0.0171  _inf:0.0146\n",
      "step  100  ll: -1.564418  H(pi): 1.550  VIres:9.83e-11(776)  ||μ_emp-μ̂||_2:0.0171  _inf:0.0146\n",
      "step  125  ll: -1.564418  H(pi): 1.550  VIres:9.83e-11(776)  ||μ_emp-μ̂||_2:0.0171  _inf:0.0146\n",
      "step  150  ll: -1.564418  H(pi): 1.550  VIres:9.83e-11(776)  ||μ_emp-μ̂||_2:0.0171  _inf:0.0146\n",
      "step  175  ll: -1.564418  H(pi): 1.550  VIres:9.83e-11(776)  ||μ_emp-μ̂||_2:0.0171  _inf:0.0146\n",
      "step  200  ll: -1.564418  H(pi): 1.550  VIres:9.83e-11(776)  ||μ_emp-μ̂||_2:0.0171  _inf:0.0146\n",
      "[EARLY STOP] No LL improvement for 6 checks.\n",
      "\n",
      "==== Rep 1/10  (sim_seed=1000, train_seed=2000) ====\n",
      "\n",
      "# (B) Gauge-invariant Q-diff metrics (Q(s,a)-Q(s,a0))\n",
      "  MaxEnt          Qdiff-NRMSE=0.289874  Qdiff-Corr=0.988469\n",
      "                 → After affine calibration on Q-diffs:\n",
      "  MaxEnt          Qdiff-NRMSE[aff]=0.151422  Qdiff-Corr[aff]=0.988469  (alpha=0.810780, beta=0.074999)\n",
      "  r_star          Qdiff-NRMSE=0.020281  Qdiff-Corr=0.999958\n",
      "                 → After affine calibration on Q-diffs:\n",
      "  r_star          Qdiff-NRMSE[aff]=0.009182  Qdiff-Corr[aff]=0.999958  (alpha=0.993942, beta=0.007473)\n",
      "  r_star_oracle   Qdiff-NRMSE=0.000000  Qdiff-Corr=1.000000\n",
      "                 → After affine calibration on Q-diffs:\n",
      "  r_star_oracle   Qdiff-NRMSE[aff]=0.000000  Qdiff-Corr[aff]=1.000000  (alpha=1.000000, beta=0.000000)\n",
      "\n",
      "# (C) Policy comparisons (weighted vs π_true)\n",
      "  policy(R_true):       KL=0.000000  TV=0.000000  top1=1.000\n",
      "  baseline pi_hat:      KL=0.000009  TV=0.001950  top1=1.000\n",
      "  policy(r_hat):        KL=0.001788  TV=0.020366  top1=0.977\n",
      "  policy(r_star):       KL=0.000009  TV=0.001950  top1=1.000\n",
      "  policy(r_star_orc):   KL=-0.000000  TV=0.000000  top1=1.000\n",
      "step    0  ll: -1.609438  H(pi): 1.609  VIres:9.87e-11(772)  ||μ_emp-μ̂||_2:0.2371  _inf:0.1362\n",
      "step   25  ll: -1.458452  H(pi): 1.419  VIres:9.81e-11(779)  ||μ_emp-μ̂||_2:0.0278  _inf:0.0187\n"
     ]
    },
    {
     "ename": "KeyboardInterrupt",
     "evalue": "",
     "output_type": "error",
     "traceback": [
      "\u001b[31m---------------------------------------------------------------------------\u001b[39m",
      "\u001b[31mKeyboardInterrupt\u001b[39m                         Traceback (most recent call last)",
      "\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[23]\u001b[39m\u001b[32m, line 147\u001b[39m\n\u001b[32m    142\u001b[39m             \u001b[38;5;28mprint\u001b[39m(\u001b[33mf\u001b[39m\u001b[33m\"\u001b[39m\u001b[33m  \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mk\u001b[38;5;132;01m:\u001b[39;00m\u001b[33m45s\u001b[39m\u001b[38;5;132;01m}\u001b[39;00m\u001b[33m  mean=\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mm\u001b[38;5;132;01m:\u001b[39;00m\u001b[33m.6f\u001b[39m\u001b[38;5;132;01m}\u001b[39;00m\u001b[33m  se=\u001b[39m\u001b[38;5;132;01m{\u001b[39;00ms\u001b[38;5;132;01m:\u001b[39;00m\u001b[33m.6f\u001b[39m\u001b[38;5;132;01m}\u001b[39;00m\u001b[33m\"\u001b[39m)\n\u001b[32m    144\u001b[39m     \u001b[38;5;28;01mreturn\u001b[39;00m {\u001b[33m\"\u001b[39m\u001b[33mby_rep\u001b[39m\u001b[33m\"\u001b[39m: by_rep_flat, \u001b[33m\"\u001b[39m\u001b[33msummary\u001b[39m\u001b[33m\"\u001b[39m: {\u001b[33m\"\u001b[39m\u001b[33mmean\u001b[39m\u001b[33m\"\u001b[39m: mean_dict, \u001b[33m\"\u001b[39m\u001b[33mse\u001b[39m\u001b[33m\"\u001b[39m: se_dict}}\n\u001b[32m--> \u001b[39m\u001b[32m147\u001b[39m out = \u001b[43mrun_identifiable_experiment_reps\u001b[49m\u001b[43m(\u001b[49m\n\u001b[32m    148\u001b[39m \u001b[43m    \u001b[49m\u001b[43mn_reps\u001b[49m\u001b[43m=\u001b[49m\u001b[32;43m10\u001b[39;49m\u001b[43m,\u001b[49m\n\u001b[32m    149\u001b[39m \u001b[43m    \u001b[49m\u001b[43mnrow\u001b[49m\u001b[43m=\u001b[49m\u001b[32;43m8\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mncol\u001b[49m\u001b[43m=\u001b[49m\u001b[32;43m8\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mA\u001b[49m\u001b[43m=\u001b[49m\u001b[32;43m5\u001b[39;49m\u001b[43m,\u001b[49m\n\u001b[32m    150\u001b[39m \u001b[43m    \u001b[49m\u001b[43mgamma\u001b[49m\u001b[43m=\u001b[49m\u001b[43mGAMMA\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m    151\u001b[39m \u001b[43m    
\u001b[49m\u001b[43msteps\u001b[49m\u001b[43m=\u001b[49m\u001b[32;43m250_000\u001b[39;49m\u001b[43m,\u001b[49m\n\u001b[32m    152\u001b[39m \u001b[43m    \u001b[49m\u001b[43mslip\u001b[49m\u001b[43m=\u001b[49m\u001b[32;43m0.08\u001b[39;49m\u001b[43m,\u001b[49m\n\u001b[32m    153\u001b[39m \u001b[43m    \u001b[49m\u001b[43mdevice\u001b[49m\u001b[43m=\u001b[49m\u001b[33;43m\"\u001b[39;49m\u001b[33;43mcuda\u001b[39;49m\u001b[33;43m\"\u001b[39;49m\u001b[43m \u001b[49m\u001b[38;5;28;43;01mif\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mhasattr\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43mth\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[33;43m\"\u001b[39;49m\u001b[33;43mcuda\u001b[39;49m\u001b[33;43m\"\u001b[39;49m\u001b[43m)\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;129;43;01mand\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[43mth\u001b[49m\u001b[43m.\u001b[49m\u001b[43mcuda\u001b[49m\u001b[43m.\u001b[49m\u001b[43mis_available\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\u001b[43m)\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43;01melse\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[33;43m\"\u001b[39;49m\u001b[33;43mcpu\u001b[39;49m\u001b[33;43m\"\u001b[39;49m\u001b[43m,\u001b[49m\n\u001b[32m    154\u001b[39m \u001b[43m    \u001b[49m\u001b[43mlr\u001b[49m\u001b[43m=\u001b[49m\u001b[32;43m0.05\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43msteps_train\u001b[49m\u001b[43m=\u001b[49m\u001b[32;43m700\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43miters_vi\u001b[49m\u001b[43m=\u001b[49m\u001b[32;43m1200\u001b[39;49m\u001b[43m,\u001b[49m\n\u001b[32m    155\u001b[39m \u001b[43m    \u001b[49m\u001b[43mverbose\u001b[49m\u001b[43m=\u001b[49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m  \u001b[49m\u001b[38;5;66;43;03m# set False to suppress per-rep prints\u001b[39;49;00m\n\u001b[32m    156\u001b[39m \u001b[43m)\u001b[49m\n",
      "\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[23]\u001b[39m\u001b[32m, line 93\u001b[39m, in \u001b[36mrun_identifiable_experiment_reps\u001b[39m\u001b[34m(n_reps, base_seed_sim, base_seed_train, nrow, ncol, A, gamma, steps, slip, device, lr, steps_train, iters_vi, verbose)\u001b[39m\n\u001b[32m     90\u001b[39m \u001b[38;5;28;01mexcept\u001b[39;00m \u001b[38;5;167;01mException\u001b[39;00m:\n\u001b[32m     91\u001b[39m     \u001b[38;5;28;01mpass\u001b[39;00m\n\u001b[32m---> \u001b[39m\u001b[32m93\u001b[39m theta_hat, hist, phi_train = \u001b[43mfit_maxent_irl\u001b[49m\u001b[43m(\u001b[49m\n\u001b[32m     94\u001b[39m \u001b[43m    \u001b[49m\u001b[43mP\u001b[49m\u001b[43m=\u001b[49m\u001b[43mP\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mS_iid\u001b[49m\u001b[43m=\u001b[49m\u001b[43mS_iid\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mA_iid\u001b[49m\u001b[43m=\u001b[49m\u001b[43mA_iid\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mfeature_fn\u001b[49m\u001b[43m=\u001b[49m\u001b[43mphi_fn_train\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m     95\u001b[39m \u001b[43m    \u001b[49m\u001b[43mnrow\u001b[49m\u001b[43m=\u001b[49m\u001b[43mnrow\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mncol\u001b[49m\u001b[43m=\u001b[49m\u001b[43mncol\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mA\u001b[49m\u001b[43m=\u001b[49m\u001b[43mA\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mgamma\u001b[49m\u001b[43m=\u001b[49m\u001b[43mgamma\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m     96\u001b[39m \u001b[43m    \u001b[49m\u001b[43mlr\u001b[49m\u001b[43m=\u001b[49m\u001b[43mlr\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43msteps\u001b[49m\u001b[43m=\u001b[49m\u001b[43msteps_train\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43miters\u001b[49m\u001b[43m=\u001b[49m\u001b[43miters_vi\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m     97\u001b[39m \u001b[43m    
\u001b[49m\u001b[43mdtype\u001b[49m\u001b[43m=\u001b[49m\u001b[43mth\u001b[49m\u001b[43m.\u001b[49m\u001b[43mfloat64\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdevice\u001b[49m\u001b[43m=\u001b[49m\u001b[43mdevice\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m     98\u001b[39m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m    100\u001b[39m \u001b[38;5;66;03m# 2) Learned reward (same gauge)\u001b[39;00m\n\u001b[32m    101\u001b[39m r_hat = np.einsum(\u001b[33m\"\u001b[39m\u001b[33msad,d->sa\u001b[39m\u001b[33m\"\u001b[39m, phi_train, theta_hat)\n",
      "\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[17]\u001b[39m\u001b[32m, line 114\u001b[39m, in \u001b[36mfit_maxent_irl\u001b[39m\u001b[34m(P, S_iid, A_iid, feature_fn, nrow, ncol, A, gamma, lr, steps, iters, dtype, device, weight_decay, check_every, tol_vi, tol_mu_L2, plateau_patience)\u001b[39m\n\u001b[32m    111\u001b[39m \u001b[38;5;66;03m# reward model: r(s,a) = φ(s,a)·θ\u001b[39;00m\n\u001b[32m    112\u001b[39m r = (phi @ theta)  \u001b[38;5;66;03m# [S,A]\u001b[39;00m\n\u001b[32m--> \u001b[39m\u001b[32m114\u001b[39m Q, V, pi, vi_resid, vi_used = \u001b[43mvi\u001b[49m\u001b[43m(\u001b[49m\u001b[43mr\u001b[49m\u001b[43m)\u001b[49m  \u001b[38;5;66;03m# soft VI\u001b[39;00m\n\u001b[32m    116\u001b[39m \u001b[38;5;66;03m# NLL over demos\u001b[39;00m\n\u001b[32m    117\u001b[39m logp = th.log(th.clamp(pi, \u001b[38;5;28mmin\u001b[39m=\u001b[32m1e-12\u001b[39m))\n",
      "\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[17]\u001b[39m\u001b[32m, line 39\u001b[39m, in \u001b[36mDiffSoftVI.__call__\u001b[39m\u001b[34m(self, r)\u001b[39m\n\u001b[32m     37\u001b[39m U  = r + g * EV\n\u001b[32m     38\u001b[39m m  = U.max(dim=\u001b[32m1\u001b[39m, keepdim=\u001b[38;5;28;01mTrue\u001b[39;00m).values\n\u001b[32m---> \u001b[39m\u001b[32m39\u001b[39m V_new = (m + th.log(\u001b[43mth\u001b[49m\u001b[43m.\u001b[49m\u001b[43mexp\u001b[49m\u001b[43m(\u001b[49m\u001b[43mU\u001b[49m\u001b[43m \u001b[49m\u001b[43m-\u001b[49m\u001b[43m \u001b[49m\u001b[43mm\u001b[49m\u001b[43m)\u001b[49m.sum(dim=\u001b[32m1\u001b[39m, keepdim=\u001b[38;5;28;01mTrue\u001b[39;00m) + eps)).squeeze(\u001b[32m1\u001b[39m)\n\u001b[32m     41\u001b[39m vi_resid = th.max(th.abs(V_new - V)).item()\n\u001b[32m     42\u001b[39m vi_used  = k + \u001b[32m1\u001b[39m\n",
      "\u001b[31mKeyboardInterrupt\u001b[39m: "
     ]
    }
   ],
   "source": [
    "import numpy as np\n",
    "import torch as th\n",
    "from math import sqrt\n",
    "\n",
    "# -------- utilities to aggregate over repetitions --------\n",
    "def _flatten_metrics(d, prefix=\"\"):\n",
    "    \"\"\"Flatten nested dict of scalars into { 'a/b/c': value }.\"\"\"\n",
    "    flat = {}\n",
    "    for k, v in d.items():\n",
    "        key = f\"{prefix}/{k}\" if prefix else str(k)\n",
    "        if isinstance(v, dict):\n",
    "            flat.update(_flatten_metrics(v, key))\n",
    "        else:\n",
    "            # only keep numeric scalars\n",
    "            if isinstance(v, (int, float, np.floating)):\n",
    "                flat[key] = float(v)\n",
    "    return flat\n",
    "\n",
    "def _agg_mean_se(list_of_dicts):\n",
    "    \"\"\"Compute mean and SE over a list of flat dicts with identical keys.\"\"\"\n",
    "    if not list_of_dicts:\n",
    "        return {}, {}\n",
    "    keys = sorted(list_of_dicts[0].keys())\n",
    "    arr = np.array([[d[k] for k in keys] for d in list_of_dicts], dtype=float)\n",
    "    mean = arr.mean(axis=0)\n",
    "    # use sample std with ddof=1 when n>1; otherwise se=nan\n",
    "    if arr.shape[0] > 1:\n",
    "        se = arr.std(axis=0, ddof=1) / sqrt(arr.shape[0])\n",
    "    else:\n",
    "        se = np.full_like(mean, np.nan)\n",
    "    mean_dict = {k: m for k, m in zip(keys, mean)}\n",
    "    se_dict   = {k: s for k, s in zip(keys, se)}\n",
    "    return mean_dict, se_dict\n",
    "\n",
    "# -------- main multi-rep runner --------\n",
    "def run_identifiable_experiment_reps(\n",
    "    n_reps=10,\n",
    "    base_seed_sim=1000,     # for make_identifiable_maxent_irl_sim TRUE-features\n",
    "    base_seed_train=2000,   # if you want to vary any training randomness\n",
    "    nrow=8, ncol=8, A=5,\n",
    "    gamma=0.97,\n",
    "    steps=250_000,\n",
    "    slip=0.08,\n",
    "    device=None,            # 'cuda' or 'cpu'; default picks automatically\n",
    "    lr=0.05, steps_train=700, iters_vi=1200,\n",
    "    verbose=True\n",
    "):\n",
    "    \"\"\"\n",
    "    Runs the 'identifiable' setting n_reps times:\n",
    "      - sim built with TRUE features (tabular state + action),\n",
    "      - training uses the same TRUE feature map (identifiable),\n",
    "      - evaluates with run_comparison_weighted (which must return a dict).\n",
    "    Returns:\n",
    "      {\n",
    "        'by_rep': [flattened_metrics_per_rep, ...],\n",
    "        'summary': {'mean': {...}, 'se': {...}}\n",
    "      }\n",
    "    \"\"\"\n",
    "    if device is None:\n",
    "        device = \"cuda\" if (hasattr(th, \"cuda\") and th.cuda.is_available()) else \"cpu\"\n",
    "\n",
    "    # TRUE features for both sim and training (identifiable)\n",
    "    def feature_builder_true(P_local, nrow_local, ncol_local, A=None, masks=None, **_):\n",
    "        # same TRUE feature map used for both sim and training (identifiable)\n",
    "        return make_phi_tabular_linear_S_plus_A(\n",
    "            nrow_local, ncol_local, A=A if A is not None else 5, include_bias=True\n",
    "        )\n",
    "\n",
    "\n",
    "    by_rep_flat = []\n",
    "\n",
    "    for rep in range(n_reps):\n",
    "        sim_seed  = base_seed_sim  + rep\n",
    "        train_seed = base_seed_train + rep  # available if your training routine uses PRNG\n",
    "\n",
    "        # --- Build sim with TRUE features (identifiable setting)\n",
    "        P, R_true, d0, pi_true, S_iid, A_iid, masks, phi_true, theta_true = make_identifiable_maxent_irl_sim(\n",
    "            feature_builder_true,\n",
    "            nrow=nrow, ncol=ncol, A=A,\n",
    "            gamma=gamma, slip=slip, steps=steps, seed=sim_seed\n",
    "        )\n",
    "\n",
    "        # --- Training map: SAME TRUE features (identifiable)\n",
    "        phi_fn_train = feature_builder_true(P, nrow, ncol, A, masks)\n",
    "\n",
    "        # If your fit uses randomness, make it reproducible per-rep\n",
    "        try:\n",
    "            th.manual_seed(train_seed)\n",
    "            np.random.seed(train_seed % (2**32 - 1))\n",
    "        except Exception:\n",
    "            pass\n",
    "\n",
    "        theta_hat, hist, phi_train = fit_maxent_irl(\n",
    "            P=P, S_iid=S_iid, A_iid=A_iid, feature_fn=phi_fn_train,\n",
    "            nrow=nrow, ncol=ncol, A=A, gamma=gamma,\n",
    "            lr=lr, steps=steps_train, iters=iters_vi,\n",
    "            dtype=th.float64, device=device,\n",
    "        )\n",
    "\n",
    "        # 2) Learned reward (same gauge)\n",
    "        r_hat = np.einsum(\"sad,d->sa\", phi_train, theta_hat)\n",
    "\n",
    "        # 3) Baseline + metrics inputs\n",
    "        S_local = P.shape[0]\n",
    "        obs_space, act_space = _Space(S_local), _Space(A)\n",
    "        demos = Demos(S_iid, A_iid)\n",
    "        _, _, clf = train_tabular_logreg_with_logreward(demos, obs_space, act_space, temperature=1.0)\n",
    "        pi_hat = clf.predict_proba(np.eye(S_local, dtype=np.float32))\n",
    "        w_emp = np.bincount(S_iid, minlength=S_local).astype(float)\n",
    "        w_emp /= max(1, w_emp.sum())\n",
    "\n",
    "        # Solve normalized MaxEnt fixed point (baseline + oracle)\n",
    "        c_star, r_star, v_star, info = solve_normalized_maxent(P, pi_hat, gamma=gamma)\n",
    "        _, r_star_oracle, _, _ = solve_normalized_maxent(P, pi_true, gamma=gamma)\n",
    "\n",
    "        # Evaluate and collect metrics; keep printing per-rep if verbose\n",
    "        if verbose:\n",
    "            print(f\"\\n==== Rep {rep+1}/{n_reps}  (sim_seed={sim_seed}, train_seed={train_seed}) ====\")\n",
    "        res = run_comparison_weighted(\n",
    "            P=P, pi_true=pi_true, pi_hat=pi_hat,\n",
    "            R=R_true, r_hat=r_hat, r_star=r_star, r_star_oracle=r_star_oracle,\n",
    "            gamma=gamma, w=w_emp\n",
    "        )\n",
    "\n",
    "        # Flatten and store\n",
    "        by_rep_flat.append(_flatten_metrics(res))\n",
    "\n",
    "    # Aggregate: mean + SE across reps\n",
    "    mean_dict, se_dict = _agg_mean_se(by_rep_flat)\n",
    "\n",
    "    # Pretty print a compact summary grouped by sections\n",
    "    print(\"\\n================ AGGREGATE OVER REPS ================\")\n",
    "    # group keys by first path segment\n",
    "    grouped = {}\n",
    "    for k in mean_dict.keys():\n",
    "        head = k.split(\"/\")[0]\n",
    "        grouped.setdefault(head, []).append(k)\n",
    "    for head, keys in grouped.items():\n",
    "        print(f\"\\n# {head.upper()}\")\n",
    "        for k in sorted(keys):\n",
    "            m = mean_dict[k]; s = se_dict[k]\n",
    "            print(f\"  {k:45s}  mean={m:.6f}  se={s:.6f}\")\n",
    "\n",
    "    return {\"by_rep\": by_rep_flat, \"summary\": {\"mean\": mean_dict, \"se\": se_dict}}\n",
    "\n",
    "\n",
    "out = run_identifiable_experiment_reps(\n",
    "    n_reps=10,\n",
    "    nrow=8, ncol=8, A=5,\n",
    "    gamma=GAMMA,\n",
    "    steps=250_000,\n",
    "    slip=0.08,\n",
    "    device=\"cuda\" if (hasattr(th, \"cuda\") and th.cuda.is_available()) else \"cpu\",\n",
    "    lr=0.05, steps_train=700, iters_vi=1200,\n",
    "    verbose=True  # set False to suppress per-rep prints\n",
    ")\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d08af2a4-3b64-4d1b-8fa3-8ade2e746be4",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "id": "632ec297-25cb-4cae-8415-6d498e93469f",
   "metadata": {},
   "source": [
    "# hard"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "id": "460a0619-858c-4a8e-ba93-fa1111260874",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "# ============================================================\n",
    "# Identifiable MaxEnt IRL Simulator\n",
    "#   -> true reward lies in the MaxEnt features\n",
    "#   -> standardized policy_from_reward + compute_v_star_from_r\n",
    "# ============================================================\n",
    "\n",
    "def _build_Phi(S, A, psi_fn=None):\n",
    "    \"\"\"\n",
    "    Build feature tensor Phi with shape [S, A, D] where Phi[s, a] = psi_fn(s, a).\n",
    "\n",
    "    If psi_fn is None, default to concatenated one-hot features:\n",
    "      psi(s,a) = [ one_hot_state(s) | one_hot_action(a) ]  (length D = S + A)\n",
    "    \"\"\"\n",
    "    if psi_fn is None:\n",
    "        D = S + A\n",
    "        Phi = np.zeros((S, A, D), dtype=np.float32)\n",
    "        for s in range(S):\n",
    "            for a in range(A):\n",
    "                Phi[s, a, s] = 1.0          # state one-hot\n",
    "                Phi[s, a, S + a] = 1.0      # action one-hot\n",
    "        return Phi\n",
    "    else:\n",
    "        sample = np.asarray(psi_fn(0, 0))\n",
    "        if sample.ndim != 1:\n",
    "            raise ValueError(\"psi_fn(s,a) must return a 1D feature vector.\")\n",
    "        D = sample.shape[0]\n",
    "        Phi = np.zeros((S, A, D), dtype=np.float32)\n",
    "        for s in range(S):\n",
    "            for a in range(A):\n",
    "                Phi[s, a] = psi_fn(s, a)\n",
    "        return Phi\n",
    "\n",
    "\n",
    "\n",
    "def make_identifiable_maxent_irl_sim(\n",
    "    feature_builder,\n",
    "    nrow=8, ncol=8, A=5,\n",
    "    gamma=0.97,\n",
    "    slip=0.08,\n",
    "    steps=250_000,\n",
    "    seed=123,\n",
    "    nonlinear_reward=False\n",
    "):\n",
    "    rng = np.random.default_rng(seed)\n",
    "    S = nrow * ncol\n",
    "\n",
    "    # -------- world: walls + portal + slip --------\n",
    "    def enc(y, x): return (y % nrow) * ncol + (x % ncol)\n",
    "    moves4 = {0:(-1,0), 1:(0,1), 2:(1,0), 3:(0,-1)}\n",
    "    left, right = {0:3,1:0,2:1,3:2}, {0:1,1:2,2:3,3:0}\n",
    "\n",
    "    wall = np.zeros((nrow,ncol), bool)\n",
    "    for y,x in [(3,3),(3,4),(4,4),(5,4)]: wall[y,x] = True\n",
    "    portal_pairs = [((1,1),(6,6))]\n",
    "\n",
    "    P = np.zeros((S,A,S), float)\n",
    "    for y in range(nrow):\n",
    "        for x in range(ncol):\n",
    "            s = enc(y,x)\n",
    "            for a in range(A):\n",
    "                if a == 4:  # No-op\n",
    "                    probs = [(4, 1.0)]\n",
    "                else:\n",
    "                    probs = [(a, 1-2*slip), (left[a], slip), (right[a], slip)]\n",
    "                for a2, pa in probs:\n",
    "                    if a2 == 4:\n",
    "                        sp_y, sp_x = y, x\n",
    "                    else:\n",
    "                        dy,dx = moves4[a2]\n",
    "                        ny,nx = y+dy, x+dx\n",
    "                        if 0 <= ny < nrow and 0 <= nx < ncol and not wall[ny,nx]:\n",
    "                            sp_y, sp_x = ny, nx\n",
    "                        else:\n",
    "                            sp_y, sp_x = y, x\n",
    "                    sp = enc(sp_y, sp_x)\n",
    "                    P[s,a,sp] += pa\n",
    "\n",
    "                # portal teleport\n",
    "                for (y1,x1),(y2,x2) in portal_pairs:\n",
    "                    p1, p2 = enc(y1,x1), enc(y2,x2)\n",
    "                    if P[s,a,p1] > 0:\n",
    "                        m = P[s,a,p1]; P[s,a,p1]=0; P[s,a,p2]+=m\n",
    "                    if P[s,a,p2] > 0:\n",
    "                        m = P[s,a,p2]; P[s,a,p2]=0; P[s,a,p1]+=m\n",
    "\n",
    "            # normalize row\n",
    "            z = P[s].sum(axis=1, keepdims=True); z[z<=0]=1.0\n",
    "            P[s] /= z\n",
    "\n",
    "    d0 = np.full(S, 1.0/S)\n",
    "    masks = {\"wall\": wall, \"portal_pairs\": portal_pairs}\n",
    "\n",
    "    # -------- build features (same as MaxEnt training) --------\n",
    "    phi_fn = feature_builder(P, nrow, ncol, A=A, masks=masks)\n",
    "    D_probe = phi_fn(0,0).shape[0]\n",
    "    phi = np.zeros((S, A, D_probe), np.float64)\n",
    "    for s in range(S):\n",
    "        for a in range(A):\n",
    "            phi[s,a] = phi_fn(s,a)\n",
    "\n",
    "    # -------- TRUE reward (linear or nonlinear) --------\n",
    "    rng = np.random.default_rng(seed + 7)\n",
    "    theta_true = None  # ensure defined for nonlinear branch\n",
    "    if not nonlinear_reward:\n",
    "        # original identifiable: linear in features\n",
    "        theta_true = rng.normal(0, 0.5, size=D_probe).astype(np.float64)\n",
    "        if D_probe >= A:  # optional action bias\n",
    "            theta_true[-A:] += np.array([0.0, 0.25, 0.0, 0.0, -0.05])[:A]\n",
    "        R_true = np.einsum(\"sad,d->sa\", phi, theta_true)\n",
    "    else:\n",
    "        # small MLP to induce nonlinear SA interactions\n",
    "        H1 = max(32, min(128, (D_probe // 2) or 32))\n",
    "        H2 = 32\n",
    "        W1 = rng.normal(0, 1.0/np.sqrt(D_probe), size=(D_probe, H1))\n",
    "        b1 = rng.normal(0, 0.05, size=(H1,))\n",
    "        W2 = rng.normal(0, 1.0/np.sqrt(H1), size=(H1, H2))\n",
    "        b2 = rng.normal(0, 0.05, size=(H2,))\n",
    "        w3 = rng.normal(0, 1.0/np.sqrt(H2), size=(H2,))\n",
    "        b3 = float(rng.normal(0, 0.05))\n",
    "        def _mlp_forward(x):\n",
    "            h1 = np.tanh(x @ W1 + b1)\n",
    "            h2 = np.maximum(0.0, h1 @ W2 + b2)\n",
    "            return float(h2 @ w3 + b3)\n",
    "        R_true = np.empty((S, A), dtype=np.float64)\n",
    "        for s in range(S):\n",
    "            for a in range(A):\n",
    "                R_true[s, a] = _mlp_forward(phi[s, a])\n",
    "        if A <= 8:\n",
    "            R_true[:, 1] += 0.10\n",
    "            R_true[:, 4 % A] -= 0.03\n",
    "\n",
    "    # gauge: per-state mean zero\n",
    "    R_true -= R_true.mean(axis=1, keepdims=True)\n",
    "\n",
    "    # -------- expert policy via standardized solver --------\n",
    "    pi_true = policy_from_reward(P, R_true, gamma=gamma)\n",
    "\n",
    "    # -------- discounted i.i.d. demos --------\n",
    "    def sample_discounted_iid(P, pi, gamma, n_steps, seed, start_dist):\n",
    "        rng = np.random.default_rng(seed)\n",
    "        S,A,_ = P.shape\n",
    "        cum_P  = np.cumsum(P, axis=2);  cum_P[:,:,-1] = 1.0\n",
    "        cum_pi = np.cumsum(pi, axis=1); cum_pi[:,-1]  = 1.0\n",
    "        cum_d0 = np.cumsum(start_dist.copy()); cum_d0[-1] = 1.0\n",
    "        S_out, A_out = np.empty(n_steps, int), np.empty(n_steps, int)\n",
    "        s = int(np.searchsorted(cum_d0, rng.random()))\n",
    "        for t in range(n_steps):\n",
    "            if rng.random() < (1-gamma):\n",
    "                s = int(np.searchsorted(cum_d0, rng.random()))\n",
    "            a  = int(np.searchsorted(cum_pi[s], rng.random()))\n",
    "            sp = int(np.searchsorted(cum_P[s,a], rng.random()))\n",
    "            S_out[t], A_out[t] = s, a\n",
    "            s = sp\n",
    "        return S_out, A_out\n",
    "\n",
    "    S_iid, A_iid = sample_discounted_iid(P, pi_true, gamma, steps, seed+1, d0)\n",
    "\n",
    "    return P, R_true, d0, pi_true, S_iid, A_iid, masks, phi, theta_true\n",
    "\n",
    "# 0) Setup\n",
    "GAMMA = 0.97\n",
    "nrow, ncol, A = 8, 8, 5\n",
    "\n",
    "def make_phi_tabular_linear_S_plus_A(nrow, ncol, A=5, *, include_bias=True, dtype=np.float64):\n",
    "    \"\"\"φ(s,a) = [ (optional bias) | 1_{state=s} | 1_{action=a} ], no interactions.\"\"\"\n",
    "    S = nrow * ncol\n",
    "    D_bias = 1 if include_bias else 0\n",
    "    D = D_bias + S + A\n",
    "    I_S = np.eye(S, dtype=dtype)\n",
    "    I_A = np.eye(A, dtype=dtype)\n",
    "\n",
    "    def phi_sa(s, a):\n",
    "        x = np.zeros(D, dtype=dtype)\n",
    "        off = 0\n",
    "        if include_bias:\n",
    "            x[0] = 1.0\n",
    "            off = 1\n",
    "        x[off + s] = 1.0           # state one-hot\n",
    "        x[off + S + a] = 1.0       # action one-hot\n",
    "        return x\n",
    "\n",
    "    phi_sa.dim = D\n",
    "    return phi_sa\n",
    "\n",
    "    \n",
    "\n",
    "# Choose the SAME feature builder for sim + training\n",
    "def feature_builder(P, nrow, ncol, A, masks):\n",
    "    return make_phi_hard_plus_sf(\n",
    "        P, nrow, ncol, A=A, masks=masks,\n",
    "        include_indicators=True,\n",
    "        include_deltas=True,\n",
    "        include_successor=True,\n",
    "        include_axay=True,\n",
    "        include_rbf_action=True,\n",
    "        n_rbf=20, n_rbf_action=16, rbf_scale=0.9\n",
    "    )\n",
    " \n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9c7d188a-d111-4954-9b50-c4c0073677b5",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "step    0  ll: -1.609438  H(pi): 1.609  VIres:9.87e-11(772)  ||μ_emp-μ̂||_2:0.0757  _inf:0.0568\n",
      "step   25  ll: -1.583791  H(pi): 1.586  VIres:9.72e-11(774)  ||μ_emp-μ̂||_2:0.0056  _inf:0.0045\n",
      "step   50  ll: -1.583459  H(pi): 1.588  VIres:9.73e-11(774)  ||μ_emp-μ̂||_2:0.0016  _inf:0.0010\n",
      "step   75  ll: -1.583453  H(pi): 1.588  VIres:9.73e-11(774)  ||μ_emp-μ̂||_2:0.0015  _inf:0.0009\n",
      "step  100  ll: -1.583453  H(pi): 1.588  VIres:9.73e-11(774)  ||μ_emp-μ̂||_2:0.0015  _inf:0.0009\n",
      "step  125  ll: -1.583453  H(pi): 1.588  VIres:9.73e-11(774)  ||μ_emp-μ̂||_2:0.0015  _inf:0.0009\n",
      "step  150  ll: -1.583453  H(pi): 1.588  VIres:9.73e-11(774)  ||μ_emp-μ̂||_2:0.0015  _inf:0.0009\n",
      "step  175  ll: -1.583453  H(pi): 1.588  VIres:9.73e-11(774)  ||μ_emp-μ̂||_2:0.0015  _inf:0.0009\n",
      "step  200  ll: -1.583453  H(pi): 1.588  VIres:9.73e-11(774)  ||μ_emp-μ̂||_2:0.0015  _inf:0.0009\n",
      "[EARLY STOP] No LL improvement for 6 checks.\n",
      "\n",
      "==== Rep 1/10  (sim_seed=42, train_seed=0) ====\n",
      "\n",
      "# (B) Gauge-invariant Q-diff metrics (Q(s,a)-Q(s,a0))\n",
      "  MaxEnt          Qdiff-NRMSE=0.818381  Qdiff-Corr=0.588154\n",
      "                 → After affine calibration on Q-diffs:\n",
      "  MaxEnt          Qdiff-NRMSE[aff]=0.808749  Qdiff-Corr[aff]=0.588154  (alpha=0.900942, beta=-0.065905)\n",
      "  r_star          Qdiff-NRMSE=0.250979  Qdiff-Corr=0.977407\n",
      "                 → After affine calibration on Q-diffs:\n",
      "  r_star          Qdiff-NRMSE[aff]=0.211368  Qdiff-Corr[aff]=0.977407  (alpha=1.039241, beta=0.070091)\n",
      "  r_star_oracle   Qdiff-NRMSE=0.000000  Qdiff-Corr=1.000000\n",
      "                 → After affine calibration on Q-diffs:\n",
      "  r_star_oracle   Qdiff-NRMSE[aff]=0.000000  Qdiff-Corr[aff]=1.000000  (alpha=1.000000, beta=0.000000)\n",
      "\n",
      "# (C) Policy comparisons (weighted vs π_true)\n",
      "  policy(R_true):       KL=0.000000  TV=0.000000  top1=1.000\n",
      "  baseline pi_hat:      KL=0.001488  TV=0.022062  top1=0.818\n",
      "  policy(r_hat):        KL=0.029661  TV=0.094488  top1=0.532\n",
      "  policy(r_star):       KL=0.001488  TV=0.022062  top1=0.818\n",
      "  policy(r_star_orc):   KL=-0.000000  TV=0.000000  top1=1.000\n",
      "step    0  ll: -1.609438  H(pi): 1.609  VIres:9.87e-11(772)  ||μ_emp-μ̂||_2:0.0391  _inf:0.0259\n",
      "step   25  ll: -1.598905  H(pi): 1.600  VIres:9.83e-11(772)  ||μ_emp-μ̂||_2:0.0040  _inf:0.0031\n"
     ]
    }
   ],
   "source": [
    "import numpy as np\n",
    "from math import sqrt\n",
    "S = P.shape[0]\n",
    "\n",
    "# Make sure DEVICE / DTYPE exist\n",
    "DEVICE = \"cuda\" if (hasattr(th, \"cuda\") and th.cuda.is_available()) else \"cpu\"\n",
    "DTYPE  = th.float32\n",
    "\n",
    "# ---- tiny helpers ----\n",
    "def _flatten_metrics(d, prefix=\"\"):\n",
    "    flat = {}\n",
    "    for k, v in d.items():\n",
    "        key = f\"{prefix}/{k}\" if prefix else str(k)\n",
    "        if isinstance(v, dict):\n",
    "            flat.update(_flatten_metrics(v, key))\n",
    "        else:\n",
    "            if isinstance(v, (int, float, np.floating)):\n",
    "                flat[key] = float(v)\n",
    "    return flat\n",
    "\n",
    "def _agg_mean_se(list_of_dicts):\n",
    "    if not list_of_dicts:\n",
    "        return {}, {}\n",
    "    keys = sorted(list_of_dicts[0].keys())\n",
    "    arr = np.array([[d[k] for k in keys] for d in list_of_dicts], dtype=float)\n",
    "    mean = arr.mean(axis=0)\n",
    "    se = arr.std(axis=0, ddof=1) / sqrt(arr.shape[0]) if arr.shape[0] > 1 else np.full_like(mean, np.nan)\n",
    "    return {k: m for k, m in zip(keys, mean)}, {k: s for k, s in zip(keys, se)}\n",
    "\n",
    "# ---- minimal multi-rep wrapper (uses your code as-is inside the loop) ----\n",
    "def run_reps_mlp_misspec(\n",
    "    n_reps=10,\n",
    "    base_seed_sim=42,      # simulator seed (nonlinear truth)\n",
    "    base_seed_train=0,     # NN seed\n",
    "    verbose=True\n",
    "):\n",
    "    by_rep = []\n",
    "\n",
    "    for rep in range(n_reps):\n",
    "        sim_seed   = base_seed_sim   + rep\n",
    "        train_seed = base_seed_train + rep\n",
    "\n",
    "        # --- build simulator with nonlinear truth (your function, unchanged) ---\n",
    "        P, R_true, d0, pi_true, S_iid, A_iid, masks, phi_sim, theta_true = make_identifiable_maxent_irl_sim(\n",
    "            feature_builder, nrow=nrow, ncol=ncol, A=A, gamma=GAMMA,\n",
    "            slip=0.08, steps=250_000, seed=sim_seed, nonlinear_reward=True\n",
    "        )\n",
    "\n",
    "        # --- training features: linear (misspecified), unchanged ---\n",
    "        phi_fn_train = feature_builder_true(P, nrow, ncol, A, masks)\n",
    "        theta_hat, hist, phi_train = fit_maxent_irl(\n",
    "            P=P, S_iid=S_iid, A_iid=A_iid, feature_fn=phi_fn_train,\n",
    "            nrow=nrow, ncol=ncol, A=A, gamma=GAMMA,\n",
    "            lr=0.05, steps=700, iters=1200,\n",
    "            dtype=th.float64, device=\"cuda\" if th.cuda.is_available() else \"cpu\",\n",
    "        )\n",
    "        r_hat = np.einsum(\"sad,d->sa\", phi_train, theta_hat)\n",
    "\n",
    "        # --- policy NN (uses your exact calls; only seed is varied) ---\n",
    "        S = P.shape[0]\n",
    "        obs_space, act_space = _Space(S), _Space(A)\n",
    "        demos = Demos(S_iid, A_iid)\n",
    "\n",
    "        Phi = _build_Phi(S, A, phi_fn_train)  # assumes these are defined in your env\n",
    "        pi_hat = train_policy_mlp_with_phi(\n",
    "            Phi, S_iid, A_iid,\n",
    "            epochs=20, lr=5e-3, batch_size=4096,\n",
    "            device=DEVICE, dtype=DTYPE, verbose=False, seed=train_seed,\n",
    "            hidden=(256,128), weight_decay=1e-5\n",
    "        )\n",
    "\n",
    "        # weights + normalized MaxEnt solves (unchanged)\n",
    "        w_emp = np.bincount(S_iid, minlength=S).astype(float); w_emp /= max(1, w_emp.sum())\n",
    "        c_star, r_star, v_star, info = solve_normalized_maxent(P, pi_hat, gamma=GAMMA)\n",
    "        _, r_star_oracle, _, _ = solve_normalized_maxent(P, pi_true, gamma=GAMMA)\n",
    "\n",
    "        # collect metrics (run_comparison_weighted prints & returns a dict)\n",
    "        if verbose:\n",
    "            print(f\"\\n==== Rep {rep+1}/{n_reps}  (sim_seed={sim_seed}, train_seed={train_seed}) ====\")\n",
    "        res = run_comparison_weighted(\n",
    "            P=P, pi_true=pi_true, pi_hat=pi_hat,\n",
    "            R=R_true, r_hat=r_hat, r_star=r_star, r_star_oracle=r_star_oracle,\n",
    "            gamma=GAMMA, w=w_emp\n",
    "        )\n",
    "        by_rep.append(_flatten_metrics(res))\n",
    "\n",
    "    mean, se = _agg_mean_se(by_rep)\n",
    "\n",
    "    # optional compact printout\n",
    "    if verbose:\n",
    "        print(\"\\n================ MEAN ± SE over reps ================\")\n",
    "        for k in sorted(mean.keys()):\n",
    "            print(f\"{k:60s}  mean={mean[k]:.6f}  se={se[k]:.6f}\")\n",
    "\n",
    "    return {\"by_rep\": by_rep, \"summary\": {\"mean\": mean, \"se\": se}}\n",
    "\n",
    "# ---- run it ----\n",
    "out = run_reps_mlp_misspec(n_reps=10, base_seed_sim=42, base_seed_train=0, verbose=True)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5237c2bb-0e70-4996-a314-b3598c916fb0",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1737f22e-1a2b-4300-bea2-8974d918e848",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "134f5137-6757-4ea7-857c-b8c12fd5e519",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "36821a23-e7a0-427c-9534-56e83405876a",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python (myenv3)",
   "language": "python",
   "name": "myenv3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.13.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
