{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a292831f",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch as th\n",
    "import transformers\n",
    "from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig\n",
    "import lqr_utils_seq as lqr\n",
    "from functools import partial\n",
    "from datasets import load_dataset\n",
    "import random\n",
    "import pickle\n",
    "import time\n",
    "from steering import LQRSteering\n",
    "from data_scripts_and_utils.data_handling import ContrastiveBuilder\n",
    "import yaml\n",
    "import random\n",
    "import json\n",
    "from linearization import compute_lin_err as OLcompute_lin_err\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "\n",
    "with open('config/config.yaml', 'r') as f:\n",
    "    config_data = yaml.safe_load(f)\n",
    "PICKLE_JAR = config_data[\"environment\"][\"pickle_jar\"]\n",
    "\n",
    "device = th.device(\"cuda\" if th.cuda.is_available() else \"cpu\")\n",
    "print(f\"device: {device}\")\n",
    "\n",
    "\n",
    "def load_model(model_name, quant=False):\n",
    "\n",
    "    if quant:\n",
    "        quant_config = BitsAndBytesConfig(\n",
    "            load_in_4bit=True,          # or load_in_8bit=True\n",
    "            # load_in_8bit=True,\n",
    "            bnb_4bit_compute_dtype=th.float16,\n",
    "            bnb_4bit_quant_type=\"nf4\",  # best for LLMs\n",
    "            bnb_4bit_use_double_quant=True,\n",
    "        )\n",
    "        model = AutoModelForCausalLM.from_pretrained(\n",
    "            model_name, quantization_config=quant_config, dtype=th.float32, device_map=\"auto\")\n",
    "        tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side=\"left\")\n",
    "        tokenizer.pad_token = tokenizer.eos_token\n",
    "        tokenizer.pad_token_id = tokenizer.eos_token_id\n",
    "    else: \n",
    "        model = AutoModelForCausalLM.from_pretrained(\n",
    "            model_name).to(device)\n",
    "        tokenizer = AutoTokenizer.from_pretrained(model_name)\n",
    "\n",
    "    return model, tokenizer\n",
    "        \n",
    "\n",
    "def get_safe_prompts():\n",
    "    dataset = load_dataset(\"tatsu-lab/alpaca\")\n",
    "    return dataset['train'][:][\"instruction\"]\n",
    "\n",
    "def load_file(filename):\n",
    "    try:\n",
    "        with open(PICKLE_JAR + filename + \".pkl\", \"rb\") as f:\n",
    "            return pickle.load(f)\n",
    "    except FileNotFoundError:\n",
    "        return None\n",
    "\n",
    "def compute_lin_err(nom_acts, jacs, acts, var, cov, P):\n",
    "    l1norm_byvar_list = []\n",
    "    avg_l1norm_byvar_list = []\n",
    "    l2norm_byvar_list = []\n",
    "    l2norm_list = []\n",
    "    linfty_by_var_list = []\n",
    "    cos_sim_list = []\n",
    "    err_list = []\n",
    "    Pnorm_list = []\n",
    "    for i in range(nom_acts.shape[0]-1):\n",
    "        delta_xt = nom_acts[i] - acts[i]\n",
    "        Adelta_xt = jacs[i] @ delta_xt\n",
    "        delta_xtp1 = nom_acts[i+1] - acts[i+1]\n",
    "\n",
    "        err = Adelta_xt - delta_xtp1\n",
    "\n",
    "        l1byvar = th.sum(th.abs(err / var[i]))\n",
    "\n",
    "        l1norm_byvar_list.append((l1byvar).item())\n",
    "        avg_l1norm_byvar_list.append((l1byvar / err.shape[-1]).item())\n",
    "        l2norm_byvar_list.append(th.norm(err / var[i]).item())\n",
    "        cos_sim_list.append((th.dot(Adelta_xt, delta_xtp1)/(th.norm(Adelta_xt)*th.norm(delta_xtp1))).item()) # cosine similarity\n",
    "        l2norm_list.append(th.norm(err).item())\n",
    "        linfty_by_var_list.append(th.max(th.abs(err)).item())\n",
    "        Pnorm_list.append(th.sqrt(err.T @ P[i+1] @ err).item())\n",
    "\n",
    "    data = {\n",
    "        \"l1_by_var\": l1norm_byvar_list,\n",
    "        \"avg_l1_by_var\": avg_l1norm_byvar_list,\n",
    "        \"l2_by_var\": l2norm_byvar_list,\n",
    "        \"l2\": l2norm_list,\n",
    "        \"linfty_by_var_list\": linfty_by_var_list,\n",
    "        \"cos_sim_list\": cos_sim_list,\n",
    "        \"Pnorm_list\": Pnorm_list\n",
    "        # \"err_list\": err_list\n",
    "            }\n",
    "    return data\n",
    "\n",
    "\n",
    "def helper(X1, X2):\n",
    "    print(X1.shape)\n",
    "    print(X2.shape)\n",
    "    # for i in range(X1.shape[0]):\n",
    "\n",
    "def lin_err_CL(nom_acts, jacs, acts, U, var, diffs, cov, P):\n",
    "\n",
    "    l1norm_byvar_list = []\n",
    "    avg_l1norm_bydiffs_list = []\n",
    "    avg_l1norm_byvar_list = []\n",
    "    l2norm_byvar_list = []\n",
    "    l2norm_list = []\n",
    "    linfty_by_var_list = []\n",
    "    cos_sim_list = []\n",
    "    # err_list = []\n",
    "    l1_list = []\n",
    "    ratio_list = []\n",
    "    err_list = []\n",
    "    \n",
    "    Pnorm_list = []\n",
    "\n",
    "    for i in range(nom_acts.shape[0]-1):\n",
    "        delta_xt_cl = acts[i] - nom_acts[i]\n",
    "        delta_xt_cl = jacs[i] @ delta_xt_cl + U[i]\n",
    "\n",
    "        true_delta = acts[i+1] - nom_acts[i+1]\n",
    "        err = delta_xt_cl - true_delta\n",
    "        \n",
    "        \n",
    "        std = th.sqrt(var[i])\n",
    "\n",
    "        ratio_list.append((th.norm(err) / th.norm(acts[i] - nom_acts[i])).item())\n",
    "        # ratio_list.append((th.norm(err) / th.norm(true_delta)).item())\n",
    "\n",
    "\n",
    "        l1 = th.sum(th.abs(err))\n",
    "        l1_list.append(l1)\n",
    "\n",
    "        l1byvar = th.sum(th.abs(err / std))\n",
    "        l1bydiff = th.sum(th.abs(err / diffs[i+1]))\n",
    "\n",
    "        l1norm_byvar_list.append((l1byvar).item())\n",
    "        avg_l1norm_byvar_list.append((l1byvar / err.shape[-1]).item())\n",
    "        avg_l1norm_bydiffs_list.append((l1bydiff / err.shape[-1]).item())\n",
    "\n",
    "        l2norm_byvar_list.append(th.norm(err / std).item())\n",
    "        cos_sim_list.append((th.dot(delta_xt_cl, true_delta)/(th.norm(delta_xt_cl)*th.norm(true_delta))).item()) # cosine similarity\n",
    "        l2norm_list.append(th.norm(err).item())\n",
    "        err_list.append(err.tolist())\n",
    "\n",
    "\n",
    "        errbydiff = err / diffs[i+1]\n",
    "        # errbyvar = err / var[i]\n",
    "        Pnorm_list.append(th.sqrt(err.T @ P[i+1] @ err).item())\n",
    "\n",
    "        linfty_by_var_list.append(th.max(th.abs(err)/ std).item())\n",
    "\n",
    "    data = {\n",
    "        \"l1\": l1_list,\n",
    "        \"l1_by_var\": l1norm_byvar_list,\n",
    "        \"avg_l1_by_var\": avg_l1norm_byvar_list,\n",
    "        \"avg_l1_by_diffs\": avg_l1norm_bydiffs_list,\n",
    "        \"ratio\": ratio_list,\n",
    "        \"l2_by_var\": l2norm_byvar_list,\n",
    "        \"l2\": l2norm_list,\n",
    "        \"linfty_by_var_list\": linfty_by_var_list,\n",
    "        \"cos_sim_list\": cos_sim_list,\n",
    "        \"err_list\": err_list,\n",
    "        \"Pnorm_list\": Pnorm_list\n",
    "            }\n",
    "    return data\n",
    "\n",
    "\n",
    "def solve_backwards_lyapunov(A):\n",
    "    N = A.shape[0]\n",
    "    n = A.shape[1]\n",
    "\n",
    "    Q = th.eye(n, device=device)\n",
    "    P_T = th.eye(n, device=device)\n",
    "    \n",
    "    P = th.zeros((N+1, n, n),device=device)\n",
    "    P[N] = P_T\n",
    "\n",
    "    A_gpu = A.to(device)\n",
    "    for k in reversed(range(N)):\n",
    "        P[k] = A_gpu[k].T @ P[k+1] @ A_gpu[k] + Q\n",
    "\n",
    "    return P.detach().cpu()\n",
    "\n",
    "def op_norm(A, Pi, Pip1):\n",
    "    L = th.linalg.cholesky(Pi)\n",
    "\n",
    "    M = A.T @ Pip1 @ A\n",
    "\n",
    "    X = th.cholesky_solve(M,L)\n",
    "    lambda_max = th.linalg.eigvalsh(X).max()\n",
    "\n",
    "    return th.sqrt(lambda_max).item()\n",
    "\n",
    "\n",
    "def is_contracting(A, Pi, Pip1, gamma=1.0):\n",
    "    lhs = A.T @ Pip1 @ A\n",
    "    rhs = gamma**2 * Pi\n",
    "    eigs = th.linalg.eigvalsh(lhs-rhs)\n",
    "    return eigs.max() <= 0\n",
    "\n",
    "def bound_at_k(k, delta_x0, errs, A_cl, P):\n",
    "    phi_k_0 = th.eye(A_cl.shape[-1], device=device)\n",
    "    # phi_bar = 1\n",
    "\n",
    "    sigma = 0\n",
    "\n",
    "\n",
    "    for i in range(k-1):\n",
    "        phi_k_0 = A_cl[i] @ phi_k_0\n",
    "        # phi_bar *= op_norm(A_cl[i], P[i], P[i+1])\n",
    "\n",
    "    # delta_xk = delta_x0\n",
    "\n",
    "    t1 = op_norm(phi_k_0, P[0], P[k]) * th.sqrt(delta_x0.T @ P[0] @ delta_x0)\n",
    "    # t1 = phi_bar\n",
    "\n",
    "\n",
    "    for i in range(k):\n",
    "        phi_k_i = th.eye(A_cl.shape[-1], device=device)\n",
    "        # sigma += op_norm(A_cl[i], P[i], P[i+1])*th.sqrt(errs[i].T @ P[i] @ errs[i])\n",
    "        for j in range(i+1, k):\n",
    "            phi_k_i = A_cl[j] @ phi_k_i\n",
    "        \n",
    "\n",
    "        sigma += op_norm(phi_k_i, P[i+1], P[k])*errs[i]\n",
    "    \n",
    "\n",
    "    bound = t1 + sigma\n",
    "\n",
    "    return bound\n",
    "\n",
    "def P_norm(x, P):\n",
    "    return th.sqrt(x.T @ P @ x)\n",
    "\n",
    "def compute_closed_loop_lin_err_P(\n",
    "    nom_states,\n",
    "    states,\n",
    "    A_list,\n",
    "    B_list,\n",
    "    K_list,\n",
    "    P_list,\n",
    "    eps=1e-8\n",
    "):\n",
    "    r_Pnorm_list = []\n",
    "    dz_Pnorm_list = []\n",
    "    L_hat_list = []\n",
    "\n",
    "    for k in range(nom_states.shape[0] - 1):\n",
    "        bar_z_k = nom_states[k]\n",
    "        bar_z_kp1 = nom_states[k + 1]\n",
    "\n",
    "        z_k = states[k]\n",
    "        z_kp1 = states[k + 1]\n",
    "\n",
    "        delta_z_k = z_k - bar_z_k\n",
    "        delta_z_kp1 = z_kp1 - bar_z_kp1\n",
    "\n",
    "        dz_norm = P_norm(delta_z_k, P_list[k])\n",
    "\n",
    "        if dz_norm < eps:\n",
    "            r_Pnorm_list.append(0.0)\n",
    "            dz_Pnorm_list.append(0.0)\n",
    "            L_hat_list.append(0.0)\n",
    "            continue\n",
    "\n",
    "        Ahat_k = A_list[k] - B_list[k] @ K_list[k]\n",
    "\n",
    "        r_k = delta_z_kp1 - Ahat_k @ delta_z_k\n",
    "        r_norm = P_norm(r_k, P_list[k])\n",
    "\n",
    "        L_hat = 2.0 * r_norm / (dz_norm ** 2)\n",
    "\n",
    "        r_Pnorm_list.append(r_norm.item())\n",
    "        dz_Pnorm_list.append(dz_norm.item())\n",
    "        L_hat_list.append(L_hat.item())\n",
    "\n",
    "    return {\n",
    "        \"r_Pnorm\": r_Pnorm_list,\n",
    "        \"dz_Pnorm\": dz_Pnorm_list,\n",
    "        \"L_hat\": L_hat_list,\n",
    "    }\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b63bc8a1",
   "metadata": {},
   "outputs": [],
   "source": [
    "model_name = \"google/gemma-2-2b\"\n",
    "# model_name = \"Qwen/Qwen2.5-3B-Instruct\"\n",
    "# model_name = \"Qwen/Qwen2.5-3B\"\n",
    "# model_name = \"meta-llama/Meta-Llama-3-8B\"\n",
    "\n",
    "# model_name = \"google/gemma-2-9b-it\"\n",
    "# model_name = \"Qwen/Qwen2.5-14B-Instruct\"\n",
    "\n",
    "\n",
    "model, tokenizer = load_model(model_name, quant=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f86a49fe",
   "metadata": {},
   "outputs": [],
   "source": [
    "from datasets import load_dataset\n",
    "import random\n",
    "ds = load_dataset(\"llm-aes/writing-prompts\")[\"train\"]\n",
    "# prompts = get_safe_prompts()\n",
    "rand_prompts = [d['prompt'][-32:] for d in ds]\n",
    "random.shuffle(rand_prompts)\n",
    "print(rand_prompts[:10])\n",
    "rand_prompts = [\n",
    "    \"orbit\", \"cascade\", \"velvet\", \"quantum\", \"ember\",\n",
    "    \"lantern\", \"echo\", \"harvest\", \"rift\", \"mosaic\",\n",
    "    \"drift\", \"crimson\", \"signal\", \"hollow\", \"atlas\",\n",
    "    \"breeze\", \"glyph\", \"summit\", \"current\", \"nova\"\n",
    "]\n",
    "\n",
    "\n",
    "noms = []\n",
    "\n",
    "\n",
    "steer = LQRSteering(model, tokenizer, 10, 1, 10)\n",
    "\n",
    "p = 1\n",
    "i = 0\n",
    "while i < 5:    \n",
    "    ind = random.randint(0, len(rand_prompts)-1)\n",
    "    out = steer.track_tokens(rand_prompts[ind], rand_prompts[ind+1], 1)\n",
    "    X = steer.X[0][:,-1,:].detach().cpu()\n",
    "    X_cl = steer.X_cl[0][:,-1,:].detach().cpu()\n",
    "    if th.norm(X[0] - X_cl[0]).item() == 0:\n",
    "        print(\"yeehaw\")\n",
    "        # i = i\n",
    "        p = p+7\n",
    "        continue\n",
    "    else:\n",
    "        i = i+1\n",
    "    print(f\"diff: {th.norm(X[0] - X_cl[0]).item()}\")\n",
    "\n",
    "    A = steer.A.detach().cpu()\n",
    "    K = steer.K.detach().cpu()\n",
    "    U = steer.U.detach().cpu()\n",
    "    dat = {\n",
    "        \"X_nom\": X,\n",
    "        \"A_nom\": A,\n",
    "        \"X_cl\": X_cl,\n",
    "        \"K\": K,\n",
    "        \"U\": U,\n",
    "    }\n",
    "    noms.append(dat)\n",
    "    print(f\"Steered out: {out}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d343ad4e",
   "metadata": {},
   "outputs": [],
   "source": [
    "goof = noms[0]\n",
    "P_cpu = solve_backwards_lyapunov(goof[\"A_nom\"]-goof[\"K\"])\n",
    "B = th.eye(goof[\"A_nom\"].shape[-1])\n",
    "B = B.repeat(goof[\"A_nom\"].shape[0], 1, 1)\n",
    "out = compute_closed_loop_lin_err_P(goof[\"X_nom\"], goof[\"X_cl\"], goof[\"A_nom\"], B, goof[\"K\"], P_cpu)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b2607fcb",
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.loglog(out[\"dz_Pnorm\"], out[\"r_Pnorm\"], '.')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "93bef12f",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "dataguy = ContrastiveBuilder(model, tokenizer)\n",
    "\n",
    "acts_new = dataguy.collect_activations(prompts=rand_prompts, num_samples=10)\n",
    "\n",
    "cov_mat = [th.cov(acts_new[:,i,:].T) for i in range(1, acts_new.shape[1])]\n",
    "variances = [th.diag(th.cov(acts_new[:,i,:].T)) for i in range(1, acts_new.shape[1])]\n",
    "\n",
    "diffs = th.zeros_like(acts_new[0])\n",
    "for i in range(acts_new.shape[1]):\n",
    "    for j in range(acts_new.shape[2]):\n",
    "        min = th.min(acts_new[:,i,j])\n",
    "        max = th.max(acts_new[:,i,j])\n",
    "        diffs[i,j] = max-min\n",
    "\n",
    "# single_step_OL = OLcompute_lin_err(steer.X[0][:,-1,:].detach().cpu(), steer.A.detach().cpu(), acts_new[0], variances, diffs)\n",
    "\n",
    "dat=noms[0]\n",
    "X_nom_cpu = dat[\"X_nom\"]\n",
    "\n",
    "print(\"shapes\")\n",
    "print(X_nom_cpu.shape)\n",
    "print(X_cl.shape)\n",
    "print(\"end shapes\")\n",
    "\n",
    "\n",
    "CL_mats = (dat[\"A_nom\"] - dat[\"K\"]).to(device)\n",
    "\n",
    "P_cpu = solve_backwards_lyapunov(CL_mats)\n",
    "\n",
    "\n",
    "# err_sum = th.zeros(acts_new.shape[1]-1)\n",
    "err_max = th.zeros(acts_new.shape[1]-1)\n",
    "\n",
    "print(acts_new.mean(0).shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "62a1d9e1",
   "metadata": {},
   "outputs": [],
   "source": [
    "# NEW\n",
    "act_mean = acts_new.mean(0).to(device)\n",
    "# x_mean: (27, 2304)\n",
    "# P:      (2304, 2304)\n",
    "P = P_cpu.to(device)\n",
    "x_col = act_mean.unsqueeze(1)\n",
    "\n",
    "# (27, 1, 2304)\n",
    "xP = th.bmm(x_col, P)\n",
    "\n",
    "# (27, 1, 1) → (27,)\n",
    "x_Pnorm = th.sqrt(\n",
    "    th.bmm(xP, x_col.transpose(1, 2)).squeeze()\n",
    ")\n",
    "\n",
    "print(act_mean.shape)\n",
    "\n",
    "# xP = act_mean @ P                    # (27, 2304)\n",
    "# x_Pnorm = th.sqrt(\n",
    "#     th.sum(xP * act_mean, dim=1)    # row-wise x^T P x\n",
    "# )\n",
    "print(x_Pnorm.shape)\n",
    "X_nom = dat[\"X_nom\"].to(device)\n",
    "\n",
    "# normalization (unchanged)\n",
    "Pvar = [\n",
    "    th.sqrt(th.trace(P[i] @ th.cov(acts_new[:, i, :].T).to(device))).item()\n",
    "    for i in range(acts_new.shape[1])\n",
    "]\n",
    "\n",
    "all_bounds = []\n",
    "all_errs = []\n",
    "\n",
    "for rollout_idx in range(len(noms)):\n",
    "    dat = noms[rollout_idx]\n",
    "    X_cl = dat[\"X_cl\"].to(device)\n",
    "\n",
    "    # === NEW: closed-loop remainder computation ===\n",
    "    rem_data = compute_closed_loop_lin_err_P(\n",
    "        nom_states=X_nom.detach().cpu(),\n",
    "        states=X_cl.detach().cpu(),\n",
    "        A_list=dat[\"A_nom\"],      # same A used to build CL_mats\n",
    "        B_list=B,\n",
    "        K_list=dat[\"K\"],\n",
    "        P_list=P_cpu,\n",
    "    )\n",
    "\n",
    "    # This is ||r_k||_P\n",
    "    r_Pnorm = th.tensor(rem_data[\"r_Pnorm\"], device=device)\n",
    "\n",
    "    # This is empirical L_k = 2||r_k|| / ||δz_k||^2\n",
    "    L_hat = th.tensor(rem_data[\"L_hat\"], device=device)\n",
    "    err0 = X_cl[0] - X_nom[0]\n",
    "    err0P = th.sqrt(err0.T @ P[0] @ err0).item()\n",
    "\n",
    "    # bounds = [err0P / Pvar[0]]\n",
    "    bounds = [(err0P / x_Pnorm[0]).detach().cpu().item()]\n",
    "    true_errs = [bounds[0]]\n",
    "\n",
    "    print(err0P)\n",
    "\n",
    "    for k in range(1, CL_mats.shape[0] + 1):\n",
    "        print(f\"layer {k}\")\n",
    "\n",
    "        # induced P-norm of closed-loop linearization\n",
    "        print(f\"op norm: {op_norm(CL_mats[k-1], P[k-1], P[k])}\")\n",
    "\n",
    "        # === NEW: quadratic remainder contribution ===\n",
    "        # theorem uses (L_i / 2) * ||δz_i||^2\n",
    "        quad_errors = 0.5 * L_hat[:k] * (\n",
    "            th.tensor(rem_data[\"dz_Pnorm\"][:k], device=device) ** 2\n",
    "        )\n",
    "\n",
    "        bound = bound_at_k(\n",
    "            k,\n",
    "            err0,\n",
    "            quad_errors,   # <-- THIS is the key change\n",
    "            CL_mats,\n",
    "            P\n",
    "        )\n",
    "\n",
    "        print(f\"bound: {bound}\")\n",
    "\n",
    "        # true error\n",
    "        err = X_cl[k] - X_nom[k]\n",
    "        errP = th.sqrt(err.T @ P[k] @ err).item()\n",
    "        print(f\"true error: {errP}\")\n",
    "\n",
    "        # bounds.append(bound.cpu().item() / Pvar[k])\n",
    "        # true_errs.append(errP / Pvar[k])\n",
    "\n",
    "        # print(f\"normalization factor: {Pvar[k]}\")\n",
    "        bounds.append((bound.cpu().item() / x_Pnorm[k]).cpu().item())\n",
    "        true_errs.append((errP / x_Pnorm[k]).cpu().item())\n",
    "\n",
    "        print(f\"normalization factor: {x_Pnorm[k]}\")\n",
    "    all_bounds.append(bounds)\n",
    "    all_errs.append(true_errs)\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e1479301",
   "metadata": {},
   "outputs": [],
   "source": [
    "P = P_cpu.to(device)\n",
    "# all_variances = [th.diag(th.cov(acts_new[:,i,:].T)).to(device) for i in range(acts_new.shape[1])]\n",
    "# Pvar = [th.sqrt(v.T @ P[i] @ v).item() for i, v in enumerate(all_variances)]\n",
    "Pvar = [th.sqrt(th.trace(P[i] @ th.cov(acts_new[:,i,:].T).to(device))).item() for i in range(acts_new.shape[1])]\n",
    "\n",
    "all_bounds = []\n",
    "all_errs = []\n",
    "X_nom = dat[\"X_nom\"].to(device)\n",
    "\n",
    "for i in range(len(noms)):\n",
    "    dat=noms[i]\n",
    "    X_cl = dat[\"X_cl\"].to(device)\n",
    "\n",
    "    running_product = 1\n",
    "\n",
    "    err = (X_cl[0] - X_nom[0])\n",
    "    bounds = [th.sqrt(err.T @ P[0] @ err).item()/Pvar[0]]\n",
    "    true_errs = [bounds[0]]\n",
    "    print(bounds)\n",
    "    # for i in range(CL_mats.shape[0]):\n",
    "    for i in range(1,CL_mats.shape[0]+1):\n",
    "        print(f\"layer {i}\")\n",
    "        print(f\"op norm: {op_norm(CL_mats[i-1], P[i-1], P[i])}\")\n",
    "\n",
    "        # bound = bound_at_k(i, (acts_new[0][0]-X_nom[0]).to(device), errors, CL_mats.to(device), P.to(device))\n",
    "        bound = bound_at_k(i, (X_cl[0]-X_nom[0]), errors, CL_mats, P)\n",
    "        print(f\"bound: {bound}\")\n",
    "\n",
    "        err = (X_cl[i] - X_nom[i])\n",
    "        errP = th.sqrt(err.T @ P[i] @ err).item()\n",
    "\n",
    "        print(f\"true error: {errP}\")\n",
    "\n",
    "        bounds.append(bound.cpu() / Pvar[i])\n",
    "        true_errs.append(errP / Pvar[i])\n",
    "        print(f\"normalization factor: {Pvar[i]}\")\n",
    "        # bounds.append(bound.cpu() / errors[i-1])\n",
    "        # true_errs.append(errP / errors[i-1])\n",
    "        # print(f\"normalization factor: {errors[i-1]}\")\n",
    "    all_bounds.append(bounds)\n",
    "    all_errs.append(true_errs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "62aca611",
   "metadata": {},
   "outputs": [],
   "source": [
    "bounds = np.array(all_bounds)\n",
    "errs = np.array(all_errs)\n",
    "\n",
    "\n",
    "np.savez(\n",
    "    \"gemma_lin_data.npz\",\n",
    "    all_bounds=bounds,\n",
    "    all_errs=errs,\n",
    ")\n",
    "\n",
    "# data = np.load(\"gemma_lin_data.npz\")\n",
    "\n",
    "# bounds = data[\"all_bounds\"]\n",
    "# errs = data[\"all_errs\"]\n",
    "\n",
    "print(\"bounds shape\", bounds.shape)\n",
    "print(\"errs shape\", errs.shape)\n",
    "\n",
    "# Time axis\n",
    "x = np.arange(bounds.shape[1])\n",
    "\n",
    "# --- Select the bound with the largest initial value ---\n",
    "# initial value = value at t = 0\n",
    "idx = np.argmax(bounds[:, 0])\n",
    "worst_bound = bounds[idx]\n",
    "# worst_bound = np.max(bounds, axis=0)\n",
    "\n",
    "# --- Compute min/max envelope of errors ---\n",
    "err_min = errs.min(axis=0)\n",
    "err_max = errs.max(axis=0)\n",
    "\n",
    "# --- Plot ---\n",
    "plt.figure()\n",
    "\n",
    "for i, b in enumerate(bounds):\n",
    "    plt.plot(\n",
    "        x,\n",
    "        b,\n",
    "        color=\"black\",\n",
    "        alpha=0.3,\n",
    "        linewidth=1,\n",
    "        label=\"Bounds\" if i == 0 else None  # avoid legend spam\n",
    "    )\n",
    "\n",
    "# Error envelope\n",
    "plt.fill_between(\n",
    "    x,\n",
    "    err_min,\n",
    "    err_max,\n",
    "    color=\"red\",\n",
    "    alpha=0.3,\n",
    "    label=\"Tracking Error\"\n",
    ")\n",
    "\n",
    "# Optional: outline the envelope\n",
    "plt.plot(x, err_min, color=\"red\", linewidth=1, label=\"lower bound\")\n",
    "plt.plot(x, err_max, color=\"blue\", linewidth=1, label=\"upper bound\")\n",
    "\n",
    "# Worst-case bound\n",
    "plt.plot(\n",
    "    x,\n",
    "    worst_bound,\n",
    "    color=\"black\",\n",
    "    linewidth=2,\n",
    "    label=\"Worst-case bound\"\n",
    ")\n",
    "\n",
    "plt.xlabel(\"Model Layer\")\n",
    "plt.ylabel(\"Normalized Error (Lyapunov Norm)\")\n",
    "plt.title(\"Worst-Case Bound vs Error Envelope\")\n",
    "\n",
    "# plt.savefig(\"gemma2-2b-err-bound_range.png\")\n",
    "plt.show()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "base",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
