{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "60f46d29",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import transformers\n",
    "from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig\n",
    "import lqr_utils_seq as lqr\n",
    "from functools import partial\n",
    "from datasets import load_dataset\n",
    "import random\n",
    "import pickle\n",
    "import time\n",
    "from steering import LQRSteering\n",
    "from data_scripts_and_utils.data_handling import ContrastiveBuilder\n",
    "import yaml\n",
    "import random\n",
    "import json\n",
    "from linearization import compute_lin_err as OLcompute_lin_err\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import tox_data_script as utils\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ba12a46b",
   "metadata": {},
   "outputs": [],
   "source": [
    "with open('config/config.yaml', 'r') as f:\n",
    "    config_data = yaml.safe_load(f)\n",
    "PICKLE_JAR = config_data[\"environment\"][\"pickle_jar\"]\n",
    "\n",
    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "print(f\"device: {device}\")\n",
    "\n",
    "\n",
    "def load_model(model_name, quant=False):\n",
    "\n",
    "    if quant:\n",
    "        quant_config = BitsAndBytesConfig(\n",
    "            load_in_4bit=True,          # or load_in_8bit=True\n",
    "            # load_in_8bit=True,\n",
    "            bnb_4bit_compute_dtype=torch.float16,\n",
    "            bnb_4bit_quant_type=\"nf4\",  # best for LLMs\n",
    "            bnb_4bit_use_double_quant=True,\n",
    "        )\n",
    "        model = AutoModelForCausalLM.from_pretrained(\n",
    "            model_name, quantization_config=quant_config, dtype=torch.float32, device_map=\"auto\")\n",
    "        tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side=\"left\")\n",
    "        tokenizer.pad_token = tokenizer.eos_token\n",
    "        tokenizer.pad_token_id = tokenizer.eos_token_id\n",
    "    else: \n",
    "        model = AutoModelForCausalLM.from_pretrained(\n",
    "            model_name).to(device)\n",
    "        tokenizer = AutoTokenizer.from_pretrained(model_name)\n",
    "\n",
    "    return model, tokenizer\n",
    "        \n",
    "\n",
    "def get_safe_prompts():\n",
    "    dataset = load_dataset(\"tatsu-lab/alpaca\")\n",
    "    return dataset['train'][:][\"instruction\"]\n",
    "\n",
    "def load_file(filename):\n",
    "    try:\n",
    "        with open(PICKLE_JAR + filename + \".pkl\", \"rb\") as f:\n",
    "            return pickle.load(f)\n",
    "    except FileNotFoundError:\n",
    "        return None"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0c084104",
   "metadata": {},
   "outputs": [],
   "source": [
    "model_name = \"google/gemma-2-2b\"\n",
    "# model_name = \"Qwen/Qwen2.5-3B\"\n",
    "# model_name = \"meta-llama/Meta-Llama-3-8B\"\n",
    "\n",
    "# model_name = \"google/gemma-2-9b-it\"\n",
    "# model_name = \"Qwen/Qwen2.5-14B-Instruct\"\n",
    "\n",
    "\n",
    "model, tokenizer = load_model(model_name, quant=True)\n",
    "\n",
    "dataguy = ContrastiveBuilder(model, tokenizer)\n",
    "toxic_prompts = utils.get_tox_prompts(0.8, 0.9)\n",
    "\n",
    "acts, As = dataguy.collect_acts_and_jacs(prompts=toxic_prompts, num_samples=10, filename=\"qwen-acts-jacs\")\n",
    "\n",
    "print(As.shape)\n",
    "\n",
    "\n",
    "\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6ba89736",
   "metadata": {},
   "outputs": [],
   "source": [
    "num_layers = 20  # how many layers you want to plot\n",
    "\n",
    "fig, axes = plt.subplots(1, num_layers, figsize=(4*num_layers, 4))  # 1 row, num_layers columns\n",
    "from scipy.linalg import subspace_angles\n",
    "k=150\n",
    "for i in range(0, num_layers):\n",
    "    A_np = As[0][i].detach().cpu().numpy()\n",
    "    B_np = As[1][i].detach().cpu().numpy()\n",
    "    C_np = As[2][i].detach().cpu().numpy()\n",
    "\n",
    "    U0 = np.linalg.svd(A_np, full_matrices=False)[0][:, :k]\n",
    "    U1 = np.linalg.svd(B_np, full_matrices=False)[0][:, :k]\n",
    "    U2 = np.linalg.svd(C_np, full_matrices=False)[0][:, :k]\n",
    "\n",
    "    angles = subspace_angles(U0, U1)\n",
    "    angles1 = subspace_angles(U2, U1)\n",
    "    angles2 = subspace_angles(U0, U2)\n",
    "\n",
    "\n",
    "    # n=10\n",
    "    # print(f\"angles 0: {np.mean(np.cos(angles))}\")\n",
    "    # print(f\"angles 1: {np.mean(np.cos(angles1))}\")\n",
    "    # print(f\"angles 2: {np.mean(np.cos(angles2))}\")\n",
    "# \n",
    "    n=10\n",
    "    print(f\"layer: {i}\")\n",
    "    print(f\"angles 0: {np.mean(np.cos(np.partition(angles,n)[:n]))}\")\n",
    "    print(f\"angles 1: {np.mean(np.cos(np.partition(angles1,n)[:n]))}\")\n",
    "    print(f\"angles 2: {np.mean(np.cos(np.partition(angles2,n)[:n]))}\\n\")\n",
    "\n",
    "    # Compute singular values\n",
    "    sA = np.linalg.svd(A_np, compute_uv=False)\n",
    "    sB = np.linalg.svd(B_np, compute_uv=False)\n",
    "    sC = np.linalg.svd(C_np, compute_uv=False)\n",
    "    max_s = max(sA.max(), sB.max(), sC.max())\n",
    "\n",
    "    ax = axes[i-10]  # current subplot\n",
    "    ax.semilogy(sA / max_s, label=\"0\")\n",
    "    ax.semilogy(sB / max_s, label=\"1\")\n",
    "    ax.semilogy(sC / max_s, label=\"2\")\n",
    "    ax.set_title(f\"Layer {i}\")\n",
    "    ax.set_xlabel(\"Mode index\")\n",
    "    ax.set_ylabel(\"Normalized singular value\")\n",
    "    ax.legend()\n",
    "\n",
    "\n",
    "plt.tight_layout()\n",
    "plt.savefig(\"qwen-layer-0-semilogy.png\")\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d7827497",
   "metadata": {},
   "outputs": [],
   "source": [
    "# import numpy as np\n",
    "# from scipy.linalg import subspace_angles\n",
    "\n",
    "# As: list of 3 lists of PyTorch tensors, one per trajectory/set\n",
    "# num_layers: number of layers\n",
    "# Define search ranges\n",
    "\n",
    "num_layers = 5\n",
    "\n",
    "k_range = range(5, 201, 5)  # singular vectors to consider\n",
    "n_range = range(10, 50, 10)     # number of smallest angles to average\n",
    "\n",
    "best_score = -np.inf\n",
    "best_params = None\n",
    "\n",
    "# Iterate over layers\n",
    "for i in range(num_layers):\n",
    "    print(f\"Layer {i}\")\n",
    "    \n",
    "    A_np = As[0][i].detach().cpu().numpy()\n",
    "    B_np = As[1][i].detach().cpu().numpy()\n",
    "    C_np = As[2][i].detach().cpu().numpy()\n",
    "    \n",
    "    # Precompute full SVD for speed\n",
    "    U0_full = np.linalg.svd(A_np, full_matrices=False)[0]\n",
    "    U1_full = np.linalg.svd(B_np, full_matrices=False)[0]\n",
    "    U2_full = np.linalg.svd(C_np, full_matrices=False)[0]\n",
    "\n",
    "    layer_best_score = -np.inf\n",
    "    layer_best_params = None\n",
    "\n",
    "    # Sweep over k and n\n",
    "    for k in k_range:\n",
    "        if k > U0_full.shape[1] or k > U1_full.shape[1] or k > U2_full.shape[1]:\n",
    "            continue  # skip if k too large for this layer\n",
    "\n",
    "        U0 = U0_full[:, :k]\n",
    "        U1 = U1_full[:, :k]\n",
    "        U2 = U2_full[:, :k]\n",
    "\n",
    "        angles_01 = subspace_angles(U0, U1)\n",
    "        angles_12 = subspace_angles(U1, U2)\n",
    "        angles_02 = subspace_angles(U0, U2)\n",
    "\n",
    "        for n in n_range:\n",
    "            if n >= k:\n",
    "                continue  # can't take more angles than vectors\n",
    "\n",
    "            score = (\n",
    "                np.mean(np.cos(np.partition(angles_01, n)[:n])) +\n",
    "                np.mean(np.cos(np.partition(angles_12, n)[:n])) +\n",
    "                np.mean(np.cos(np.partition(angles_02, n)[:n]))\n",
    "            ) / 3  # average across pairs\n",
    "\n",
    "            # maximize score; break ties by larger k, then larger n\n",
    "            if score > layer_best_score or (score == layer_best_score and (k, n) > layer_best_params):\n",
    "                layer_best_score = score\n",
    "                layer_best_params = (k, n)\n",
    "\n",
    "    best_score = max(best_score, layer_best_score)\n",
    "    print(f\"Best score for layer {i}: {layer_best_score:.4f} at k={layer_best_params[0]}, n={layer_best_params[1]}\\n\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ae8eac1f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# import numpy as np\n",
    "# from scipy.linalg import subspace_angles\n",
    "\n",
    "# As: list of 3 lists of PyTorch tensors, one per trajectory/set\n",
    "# num_layers: number of layers\n",
    "# Set k = max number of singular vectors to consider (top modes)\n",
    "k_max = 200\n",
    "\n",
    "for i in range(num_layers):\n",
    "    print(f\"Layer {i}\")\n",
    "    \n",
    "    # Convert tensors to numpy\n",
    "    A_np = As[0][i].detach().cpu().numpy()\n",
    "    B_np = As[1][i].detach().cpu().numpy()\n",
    "    C_np = As[2][i].detach().cpu().numpy()\n",
    "\n",
    "    A_m = (A_np + B_np + C_np) / 3\n",
    "\n",
    "    # Full SVD\n",
    "    UA, sA, _ = np.linalg.svd(A_np, full_matrices=False)\n",
    "    UB, sB, _ = np.linalg.svd(B_np, full_matrices=False)\n",
    "    UC, sC, _ = np.linalg.svd(C_np, full_matrices=False)\n",
    "    UAm, sAm, _ = np.linalg.svd(A_m, full_matrices=False)\n",
    "\n",
    "    # Keep top k modes\n",
    "    k = min(k_max, UA.shape[1], UB.shape[1], UC.shape[1], UAm.shape[1])\n",
    "    UA = UA[:, :k]\n",
    "    UB = UB[:, :k]\n",
    "    UC = UC[:, :k]\n",
    "    UAm = UAm[:, :k]\n",
    "\n",
    "    sA = sA[:k]\n",
    "    sB = sB[:k]\n",
    "    sC = sC[:k]\n",
    "    sAm = sAm[:k]\n",
    "\n",
    "    # Compute subspace angles\n",
    "    angles_01 = subspace_angles(UA, UB)\n",
    "    angles_12 = subspace_angles(UB, UC)\n",
    "    angles_02 = subspace_angles(UA, UC)\n",
    "    \n",
    "    angles_m0 = subspace_angles(UAm, UA)\n",
    "    angles_m1 = subspace_angles(UAm, UB)\n",
    "    angles_m2 = subspace_angles(UAm, UC)\n",
    "\n",
    "    # Energy-weighted cosine similarity\n",
    "    # Normalize singular values to sum=1 to get relative energy\n",
    "    wA = sA / sA.sum()\n",
    "    wB = sB / sB.sum()\n",
    "    wC = sC / sC.sum()\n",
    "    wAm = sAm / sAm.sum()\n",
    "\n",
    "    def weighted_cos(angles, wX, wY):\n",
    "        # Take average of cosines weighted by geometric mean of mode energies\n",
    "        # Use min length in case kX != kY\n",
    "        k = min(len(wX), len(wY), len(angles))\n",
    "        weights = np.sqrt(wX[:k] * wY[:k])\n",
    "        return np.sum(weights * np.cos(angles[:k])) / np.sum(weights)\n",
    "\n",
    "    score_01 = weighted_cos(angles_01, wA, wB)\n",
    "    score_12 = weighted_cos(angles_12, wB, wC)\n",
    "    score_02 = weighted_cos(angles_02, wA, wC)\n",
    "    \n",
    "    score_m0 = weighted_cos(angles_m0, wAm, wA)\n",
    "    score_m1 = weighted_cos(angles_m1, wAm, wB)\n",
    "    score_m2 = weighted_cos(angles_m2, wAm, wC)\n",
    "\n",
    "    print(f\"Energy-weighted similarity (0-1): {score_01:.4f}\")\n",
    "    print(f\"Energy-weighted similarity (1-2): {score_12:.4f}\")\n",
    "    print(f\"Energy-weighted similarity (0-2): {score_02:.4f}\\n\")\n",
    "\n",
    "    print(f\"Energy-weighted similarity (m-0): {score_m0:.4f}\")\n",
    "    print(f\"Energy-weighted similarity (m-1): {score_m1:.4f}\")\n",
    "    print(f\"Energy-weighted similarity (m-2): {score_m2:.4f}\\n\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9743db2f",
   "metadata": {},
   "outputs": [],
   "source": [
    "k_max = 100\n",
    "print(k_max)\n",
    "\n",
    "num_matrices = len(As)  # number of trajectory linearizations\n",
    "num_layers = len(As[0])  # assuming all matrices have the same number of layers\n",
    "\n",
    "for i in range(num_layers):\n",
    "    print(f\"Layer {i}\")\n",
    "    \n",
    "    # Convert all tensors to numpy\n",
    "    mats = [A[i].detach().cpu().numpy() for A in As]\n",
    "    \n",
    "    # Compute mean Jacobian\n",
    "    A_m = sum(mats) / num_matrices\n",
    "\n",
    "    # Full SVDs\n",
    "    svd_results = [np.linalg.svd(mat, full_matrices=False) for mat in mats]\n",
    "    U_list = [svd[0] for svd in svd_results]\n",
    "    s_list = [svd[1] for svd in svd_results]\n",
    "\n",
    "\n",
    "    U_m, s_m, _ = np.linalg.svd(A_m, full_matrices=False)\n",
    "\n",
    "    # Determine k (top modes)\n",
    "    k = min([U.shape[1] for U in U_list] + [U_m.shape[1], k_max])\n",
    "    U_list = [U[:, :k] for U in U_list]\n",
    "    s_list = [s[:k] for s in s_list]\n",
    "    U_m = U_m[:, :k]\n",
    "    s_m = s_m[:k]\n",
    "\n",
    "    print(U_list[0].shape)\n",
    "\n",
    "    # Normalize singular values to get relative energy\n",
    "    w_list = [s / s.sum() for s in s_list]\n",
    "    w_m = s_m / s_m.sum()\n",
    "\n",
    "    # Compute subspace angles and energy-weighted cosine similarity\n",
    "    def weighted_cos(angles, wX, wY):\n",
    "        k = min(len(angles), len(wX), len(wY))\n",
    "        weights = np.sqrt(wX[:k] * wY[:k])\n",
    "        return np.sum(weights * np.cos(angles[:k])) / np.sum(weights)\n",
    "\n",
    "    # Trajectory-to-trajectory similarities\n",
    "    # for idx1 in range(num_matrices):\n",
    "    #     for idx2 in range(idx1 + 1, num_matrices):\n",
    "    #         angles = subspace_angles(U_list[idx1], U_list[idx2])\n",
    "    #         score = weighted_cos(angles, w_list[idx1], w_list[idx2])\n",
    "    #         print(f\"Energy-weighted similarity ({idx1}-{idx2}): {score:.4f}\")\n",
    "\n",
    "    # Mean vs each trajectory\n",
    "    for idx, (U, w) in enumerate(zip(U_list, w_list)):\n",
    "        angles_m = subspace_angles(U_m, U)\n",
    "        score_m = weighted_cos(angles_m, w_m, w)\n",
    "        print(f\"Energy-weighted similarity (m-{idx}): {score_m:.4f}\")\n",
    "\n",
    "    print(\"\\n\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5bb05bc3",
   "metadata": {},
   "outputs": [],
   "source": [
    "import re\n",
    "\n",
    "# Path to your text file\n",
    "filename = 'gemma-2-2b-100'\n",
    "file_path = 'jacobian_alignment/' + filename + '.txt'\n",
    "\n",
    "# List to hold similarities per layer\n",
    "similarities = []\n",
    "\n",
    "with open(file_path, 'r') as f:\n",
    "    layer_sims = []\n",
    "    for line in f:\n",
    "        line = line.strip()\n",
    "        if line.startswith(\"Layer\"):\n",
    "            # If this is not the first layer, append the previous layer's data\n",
    "            if layer_sims:\n",
    "                similarities.append(layer_sims)\n",
    "                layer_sims = []\n",
    "        elif \"Energy-weighted similarity\" in line:\n",
    "            # Extract the numeric value using regex\n",
    "            match = re.search(r\": ([0-9.]+)\", line)\n",
    "            if match:\n",
    "                value = float(match.group(1))\n",
    "                layer_sims.append(value)\n",
    "    # Append last layer\n",
    "    if layer_sims:\n",
    "        similarities.append(layer_sims)\n",
    "\n",
    "# Convert to numpy array if you want\n",
    "import numpy as np\n",
    "sim_array = np.array(similarities)  # shape: (num_layers, num_jacobians)\n",
    "\n",
    "print(sim_array.shape)\n",
    "print(sim_array[:2])  # print first 2 layers as check\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e962a64d",
   "metadata": {},
   "outputs": [],
   "source": [
    "def random_orthonormal(n, k, rng=None):\n",
    "    \"\"\"\n",
    "    Generate an n x k matrix with orthonormal columns,\n",
    "    uniformly random (Haar measure on the Stiefel manifold).\n",
    "    \"\"\"\n",
    "    if rng is None:\n",
    "        rng = np.random.default_rng()\n",
    "\n",
    "    # X = rng.standard_normal((n, k))\n",
    "    # Q, R = np.linalg.qr(X)\n",
    "\n",
    "    # Fix sign ambiguity for reproducibility / true Haar\n",
    "    # signs = np.sign(np.diag(R))\n",
    "    # Q *= signs\n",
    "    X = np.random.randn(n, k) \n",
    "    X = X / np.linalg.norm(X, axis=0, keepdims=True)\n",
    "    return X\n",
    "    # return Q\n",
    "\n",
    "\n",
    "# for i in range(100):\n",
    "#     n = As.shape[-1]\n",
    "#     M1 = random_orthonormal(n, k_max)\n",
    "#     M2 = random_orthonormal(n, k_max)\n",
    "def weighted_cos(angles, wX, wY):\n",
    "    k = min(len(angles), len(wX), len(wY))\n",
    "    weights = np.sqrt(wX[:k] * wY[:k])\n",
    "    # print(weights.shape)\n",
    "    return np.sum(weights * np.cos(angles[:k])) / np.sum(weights)\n",
    "\n",
    "_, s_ref, _ = np.linalg.svd(A_m, full_matrices=False)\n",
    "s_ref = s_ref[:k]\n",
    "w_energy = s_ref / s_ref.sum()\n",
    "# w_energy = np.ones(A_m.shape[-1])/ A_m.shape[-1]\n",
    "\n",
    "S_rand_energy = []\n",
    "\n",
    "num_trials = 100\n",
    "n = As.shape[-1]\n",
    "print(w_energy)\n",
    "k=1\n",
    "from sklearn.metrics.pairwise import cosine_similarity \n",
    "for _ in range(num_trials):\n",
    "    M1 = random_orthonormal(n, k)\n",
    "    M2 = random_orthonormal(n, k)\n",
    "\n",
    "\n",
    "    # print(np.trace(M1.T @ M1))\n",
    "\n",
    "    angles = subspace_angles(M1, M2)\n",
    "    print(f\"iter{_} \\nsubspace\")\n",
    "    print(np.cos(angles.flatten().mean()))\n",
    "    print(\"\\ncosine sim\")\n",
    "    print(cosine_similarity(M1.T, M2.T).mean())\n",
    "    # score = weighted_cos(angles, w_energy, w_energy)\n",
    "    # S_rand_energy.append(score)\n",
    "\n",
    "S_rand_energy = np.array(S_rand_energy)\n",
    "\n",
    "print(f\"Energy-matched ⟨S_rand⟩ = {S_rand_energy.mean():.4f} ± {S_rand_energy.std():.4f}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "479308df",
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "\n",
    "# sim_array: shape (num_layers, num_jacobians)\n",
    "num_layers = sim_array.shape[0]\n",
    "\n",
    "# Compute statistics per layer\n",
    "# sim_array = (sim_array - S_rand_energy.mean()) / (1-S_rand_energy.mean())\n",
    "layer_mean = np.mean(sim_array, axis=1)\n",
    "layer_min = np.min(sim_array, axis=1)\n",
    "layer_max = np.max(sim_array, axis=1)\n",
    "\n",
    "# Plot\n",
    "plt.figure(figsize=(12,6))\n",
    "plt.plot(range(num_layers), layer_mean, label='Mean similarity', color='blue', linewidth=2)\n",
    "plt.fill_between(range(num_layers), layer_min, layer_max, color='blue', alpha=0.2, label='Min-Max range')\n",
    "plt.axhline(y=0.1551, color=\"r\", linestyle=\"--\", label=\"Random baseline\")\n",
    "plt.xlim(0,25)\n",
    "\n",
    "plt.xlabel(\"Layer\")\n",
    "plt.ylabel(\"Energy-weighted similarity with mean Jacobian\")\n",
    "plt.title(\"Layer-wise Similarity Across 10 Jacobians\")\n",
    "plt.legend()\n",
    "plt.grid(True)\n",
    "plt.tight_layout()\n",
    "plt.savefig(\"jacobian_alignment/\" + filename + \".png\")\n",
    "plt.show()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e6308ad2",
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "\n",
    "# sim_array: shape (num_layers, num_jacobians)\n",
    "num_layers = sim_array.shape[0]\n",
    "\n",
    "# Compute statistics per layer\n",
    "sim_array_normed = (sim_array - S_rand_energy.mean()) / (1-S_rand_energy.mean())\n",
    "layer_mean = np.mean(sim_array_normed, axis=1)\n",
    "print(layer_mean)\n",
    "layer_min = np.min(sim_array_normed, axis=1)\n",
    "layer_max = np.max(sim_array_normed, axis=1)\n",
    "\n",
    "# Plot\n",
    "plt.figure(figsize=(12,6))\n",
    "plt.plot(range(num_layers), layer_mean, label='Mean similarity', color='blue', linewidth=2)\n",
    "plt.fill_between(range(num_layers), layer_min, layer_max, color='blue', alpha=0.2, label='Min-Max range')\n",
    "\n",
    "plt.axhline(y=0.08923, color=\"r\", linestyle=\"--\", label=\"Lorenz baseline\")\n",
    "\n",
    "plt.xlim(0,25)\n",
    "plt.ylim(0,1)\n",
    "plt.xlabel(\"Layer\")\n",
    "plt.ylabel(\"Energy-weighted similarity with mean Jacobian\")\n",
    "plt.title(\"Layer-wise Similarity Across 10 Jacobians\")\n",
    "plt.legend()\n",
    "plt.grid(True)\n",
    "plt.tight_layout()\n",
    "plt.savefig(\"jacobian_alignment/\" + filename + \".png\")\n",
    "plt.show()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a37cc7fc",
   "metadata": {},
   "outputs": [],
   "source": [
    "# ----------------------------\n",
    "# Lorenz-96 system\n",
    "# ----------------------------\n",
    "from scipy.integrate import solve_ivp\n",
    "def lorenz96(t, x, F=8.0):\n",
    "    N = len(x)\n",
    "    dx = np.zeros_like(x)\n",
    "    for i in range(N):\n",
    "        dx[i] = (x[(i+1) % N] - x[i-2]) * x[i-1] - x[i] + F\n",
    "    return dx\n",
    "\n",
    "\n",
    "def lorenz96_jacobian(x):\n",
    "    \"\"\"\n",
    "    Analytic Jacobian of Lorenz-96 at state x\n",
    "    \"\"\"\n",
    "    N = len(x)\n",
    "    J = np.zeros((N, N))\n",
    "\n",
    "    for i in range(N):\n",
    "        J[i, i] = -1\n",
    "        J[i, (i-1) % N] = x[(i+1) % N] - x[i-2]\n",
    "        J[i, (i+1) % N] = x[i-1]\n",
    "        J[i, i-2] = -x[i-1]\n",
    "\n",
    "    return J\n",
    "\n",
    "N = 500          # dimension (scale this up)\n",
    "F = 8.0          # chaotic regime\n",
    "T = 40.0\n",
    "dt = 0.05\n",
    "\n",
    "rng = np.random.default_rng(0)\n",
    "x0 = rng.standard_normal(N)\n",
    "\n",
    "t_eval = np.arange(0, T, dt)\n",
    "\n",
    "sol = solve_ivp(\n",
    "    lorenz96,\n",
    "    (0, T),\n",
    "    x0,\n",
    "    t_eval=t_eval,\n",
    "    args=(F,),\n",
    "    rtol=1e-8,\n",
    "    atol=1e-8\n",
    ")\n",
    "\n",
    "trajectory = sol.y.T\n",
    "\n",
    "Js = np.array([lorenz96_jacobian(x) for x in trajectory])\n",
    "\n",
    "# Mean Jacobian\n",
    "J_mean = Js.mean(axis=0)\n",
    "\n",
    "k = 50  # number of modes (try 20, 50, 100)\n",
    "\n",
    "def svd_modes(J, k):\n",
    "    U, s, _ = np.linalg.svd(J, full_matrices=False)\n",
    "    return U[:, :k], s[:k]\n",
    "\n",
    "def weighted_cos(angles, wX, wY):\n",
    "    k = min(len(angles), len(wX), len(wY))\n",
    "    w = np.sqrt(wX[:k] * wY[:k])\n",
    "    return np.sum(w * np.cos(angles[:k])) / np.sum(w)\n",
    "\n",
    "U_m, s_m = svd_modes(J_mean, k)\n",
    "w_energy = s_m / s_m.sum()\n",
    "\n",
    "scores = []\n",
    "\n",
    "for J in Js[::10]:  # subsample to reduce correlation\n",
    "    U, s = svd_modes(J, k)\n",
    "    angles = subspace_angles(U_m, U)\n",
    "    score = weighted_cos(angles, w_energy, w_energy)\n",
    "    scores.append(score)\n",
    "\n",
    "scores = np.array(scores)\n",
    "\n",
    "print(f\"Lorenz-96 chaotic baseline:\")\n",
    "print(f\"mean similarity = {scores.mean():.3f}\")\n",
    "print(f\"std            = {scores.std():.3f}\")\n",
    "print(f\"min / max      = {scores.min():.3f} / {scores.max():.3f}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "23c6aceb",
   "metadata": {},
   "outputs": [],
   "source": [
    "n=N\n",
    "k=k\n",
    "S_rand_lor = []\n",
    "for _ in range(num_trials):\n",
    "    M1 = random_orthonormal(n, k)\n",
    "    M2 = random_orthonormal(n, k)\n",
    "\n",
    "    angles = subspace_angles(M1, M2)\n",
    "    score = weighted_cos(angles, w_energy, w_energy)\n",
    "    S_rand_lor.append(score)\n",
    "\n",
    "S_rand_lor = np.array(S_rand_lor)\n",
    "\n",
    "print(f\"Energy-matched ⟨S_rand⟩ = {S_rand_lor.mean():.4f} ± {S_rand_lor.std():.4f}\")\n",
    "\n",
    "print(f\"normalized for random chance: {(scores.mean() - S_rand_lor.mean()) / (1-S_rand_lor.mean())}\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": ".venv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
