{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In the following we call log-odds $M(z)$ as margin."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "af81847f"
   },
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.optim as optim\n",
    "from torch.utils.data import Dataset, DataLoader\n",
    "import torchvision\n",
    "import torchvision.transforms as transforms\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "from tqdm import tqdm\n",
    "import math\n",
    "import numpy as np\n",
    "import random\n",
    "\n",
    "from dataclasses import dataclass\n",
    "from typing import Optional, Dict, Any, Tuple"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "mxHtvUumJvTH"
   },
   "outputs": [],
   "source": [
    "@dataclass\n",
    "class modelConfig:\n",
    "  d_embd: int = 128\n",
    "\n",
    "@dataclass\n",
    "class dataConfig:\n",
    "  V: int = 4\n",
    "  m: int = 128\n",
    "  groups: int = 8\n",
    "  min_support: int = 3\n",
    "  max_support: int = 6\n",
    "\n",
    "\n",
    "DATAconfig = dataConfig()\n",
    "MODELconfig = modelConfig()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "cellView": "form",
    "id": "aTaeY5N1LwfN"
   },
   "outputs": [],
   "source": [
    "@dataclass\n",
    "class SUCConfig:\n",
    "    V: int                 # vocabulary size\n",
    "    m: int                 # number of distinct contexts\n",
    "    G: int                 # number of concept groups (must divide V)\n",
    "    epsilon: float = 0.05  # smoothing mass on the complement\n",
    "    pi_uniform_over_groups: bool = False  # else uniform over contexts\n",
    "    device: str = \"cpu\"\n",
    "    dtype: torch.dtype = torch.float32\n",
    "\n",
    "def build_smoothed_uniform_concept_dataset(\n",
    "    cfg: SUCConfig,\n",
    "    *,\n",
    "    return_groups: bool = True,\n",
    ") -> Tuple[torch.Tensor, torch.Tensor, Dict[str, Any]]:\n",
    "    \"\"\"\n",
    "    Build the smoothed-uniform-concept dataset in PyTorch:\n",
    "\n",
    "      - Tokens partitioned into G disjoint groups of size s = V/G\n",
    "        group g: {g*s, ..., g*s + s - 1}\n",
    "      - Contexts partitioned as evenly as possible across the same G concepts\n",
    "      - For context j with concept g(j), set:\n",
    "            p_j(z) = (1 - eps)/s          if token z in group g(j)\n",
    "                     eps / ((G - 1) * s)  otherwise\n",
    "        (If G == 1 -> uniform over all tokens.)\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    P  : (V, m) float tensor (columns sum to 1)\n",
    "    pi : (m,)    float tensor (context frequencies)\n",
    "    meta : dict  with 'token_groups' (V,), 'context_groups' (m,), 's' (int)\n",
    "    \"\"\"\n",
    "    V, m, G = int(cfg.V), int(cfg.m), int(cfg.G)\n",
    "    eps = float(cfg.epsilon)\n",
    "    device, dtype = cfg.device, cfg.dtype\n",
    "\n",
    "    assert V >= 1 and m >= 1 and G >= 1\n",
    "    assert V % G == 0, \"G must divide V so that s = V/G is an integer.\"\n",
    "    s = V // G\n",
    "    assert 0.0 <= eps < 1.0\n",
    "\n",
    "    token_groups = torch.arange(V, device=device) // s  # (V,)\n",
    "\n",
    "    base = m // G\n",
    "    rem = m % G\n",
    "    context_groups = []\n",
    "    for g in range(G):\n",
    "        context_groups += [g] * (base + (1 if g < rem else 0))\n",
    "    context_groups = torch.tensor(context_groups, device=device)  # (m,)\n",
    "\n",
    "    P = torch.empty(V, m, device=device, dtype=dtype)\n",
    "\n",
    "    if G == 1:\n",
    "        P.fill_(1.0 / V)\n",
    "    else:\n",
    "        P.fill_(eps / ((G - 1) * s))\n",
    "        for j in range(m):\n",
    "            g = context_groups[j].item()\n",
    "            in_group = (token_groups == g)\n",
    "            P[in_group, j] = (1.0 - eps) / s\n",
    "\n",
    "    P /= P.sum(dim=0, keepdim=True)\n",
    "\n",
    "    if cfg.pi_uniform_over_groups:\n",
    "        pi = torch.zeros(m, device=device, dtype=dtype)\n",
    "        idx = 0\n",
    "        for g in range(G):\n",
    "            cnt = base + (1 if g < rem else 0)\n",
    "            if cnt > 0:\n",
    "                pi[idx:idx+cnt] = 1.0 / G / cnt\n",
    "            idx += cnt\n",
    "    else:\n",
    "        pi = torch.full((m,), 1.0 / m, device=device, dtype=dtype)\n",
    "\n",
    "    meta_infos = {\n",
    "        \"token_groups\": token_groups,   # (V,)\n",
    "        \"context_groups\": context_groups,  # (m,)\n",
    "        \"s\": s,\n",
    "        \"G\": G,\n",
    "    } if return_groups else {}\n",
    "\n",
    "    return P, pi, meta_infos\n",
    "\n",
    "\n",
    "\n",
    "cfg_data = SUCConfig(V=9, m=10, G=3)\n",
    "P, pi, meta = build_smoothed_uniform_concept_dataset(cfg_data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "-5Tl985vJ6yQ"
   },
   "outputs": [],
   "source": [
    "@dataclass\n",
    "class modelConfig:\n",
    "  d_embd: int\n",
    "\n",
    "\n",
    "class logBilinearModel(nn.Module):\n",
    "  \"\"\" Defined as a container for the log-bilinear model parameters \"\"\"\n",
    "  def __init__(self, config_model, config_data):\n",
    "    super(logBilinearModel, self).__init__()\n",
    "\n",
    "    self.config_model = config_model\n",
    "    self.config_data = config_data\n",
    "\n",
    "    self.H = nn.Parameter(nn.Parameter(torch.randn(config_model.d_embd, config_data.m)/math.sqrt(config_model.d_embd)))\n",
    "    self.W = nn.Parameter(nn.Parameter(torch.randn(config_data.V, config_model.d_embd)/math.sqrt(config_data.V)))\n",
    "\n",
    "  def forward(self, context_idx):\n",
    "    \"\"\"Only used for inference and sampling, not for training\"\"\"\n",
    "    inputs = torch.zeros((context_idx.shape[0], self.config_data.m))\n",
    "    inputs[context_idx] = 1\n",
    "    return self.W @ self.H @ inputs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "Kz3OVVdBPZ6D"
   },
   "outputs": [],
   "source": [
    "def cross_entropy_softlabels(Q, P, pi):\n",
    "    return -(pi * torch.sum(P * torch.log(Q), dim=0)).sum()\n",
    "\n",
    "def empirical_entropy(P, pi):\n",
    "    return -(pi * torch.sum(P * torch.log(P), dim=0)).sum()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "cellView": "form",
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "DAnbsWuOOttM",
    "outputId": "2bcb5442-7b84-43fa-972f-7c781a8d02aa"
   },
   "outputs": [],
   "source": [
    "@dataclass\n",
    "class hyperparametersConfig:\n",
    "    max_iters: int = 5000\n",
    "    learning_rate: float = 0.1\n",
    "\n",
    "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
    "cfg_model = modelConfig(d_embd = cfg_data.V)\n",
    "model = logBilinearModel(cfg_model, cfg_data).to(device)\n",
    "\n",
    "config_hparams = hyperparametersConfig()\n",
    "optimizer = torch.optim.SGD(model.parameters(), lr=config_hparams.learning_rate)\n",
    "\n",
    "P, pi = P.to(device), pi.to(device)\n",
    "entropy = empirical_entropy(P, pi)\n",
    "print(f\"entropy: {entropy}\")\n",
    "\n",
    "pbar = tqdm(range(config_hparams.max_iters), desc=\"training\", leave=True)\n",
    "losses = list()\n",
    "for iter in pbar:\n",
    "\n",
    "    L = model.W @ model.H\n",
    "    Q = torch.softmax(L, dim=0)\n",
    "    loss = cross_entropy_softlabels(Q, P, pi)\n",
    "\n",
    "    optimizer.zero_grad(set_to_none=True)\n",
    "    loss.backward()\n",
    "    optimizer.step()\n",
    "\n",
    "\n",
    "    pbar.set_description(f\"training, loss: {loss:.4f}\")\n",
    "    losses.append(loss.item())\n",
    "    if torch.abs(loss - entropy) <= 1e-4:\n",
    "        break"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "5RLzOICKVtpQ",
    "outputId": "a8da880e-bc92-4fd6-f90b-4e71675c790b"
   },
   "outputs": [],
   "source": [
    "L = model.W @ model.H\n",
    "Q = torch.softmax(L, dim=0)\n",
    "error = 0.005\n",
    "print((torch.abs(Q - P) <= error).all())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "zUXLgl5CWVAS"
   },
   "outputs": [],
   "source": [
    "def get_concepts_and_randoms_and_steering(meta, cfg, concept_idx = 0, n=1):\n",
    "  \"\"\" steering using unique contexts \"\"\"\n",
    "  J = torch.where(meta[\"context_groups\"] == concept_idx)[0].tolist()\n",
    "  K = torch.where(meta[\"context_groups\"] != concept_idx)[0].tolist()\n",
    "  concept_sampled_list = random.sample(J, n)\n",
    "  random_sampled_list = random.sample(K, n)\n",
    "\n",
    "  concept_one_hot = torch.nn.functional.one_hot(torch.tensor(concept_sampled_list), num_classes=cfg.m)\n",
    "  random_one_hot = torch.nn.functional.one_hot(torch.tensor(random_sampled_list), num_classes=cfg.m)\n",
    "  steering_vector = 1/n * (torch.sum(concept_one_hot, dim = 0) - torch.sum(random_one_hot, dim = 0))\n",
    "\n",
    "  return concept_sampled_list, random_sampled_list, steering_vector\n",
    "\n",
    "concept_idx = 0\n",
    "concept_sampled_list, random_sampled_list, steering_vector = get_concepts_and_randoms_and_steering(meta, cfg_data, concept_idx=concept_idx, n=2)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "nsrS2GWfH8dh",
    "outputId": "47e6683d-f0d9-452a-e3f3-634fadad0532"
   },
   "outputs": [],
   "source": [
    "steering_vector"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "ubVpvuL1zu38"
   },
   "outputs": [],
   "source": [
    "@torch.no_grad()\n",
    "def get_margins(q, C, R):\n",
    "  n = len(C)\n",
    "  margins = torch.prod(q[:, C], dim=1) / torch.prod(q[:, R], dim=1)\n",
    "  margins = 1/n * torch.log(margins)\n",
    "  return margins\n",
    "\n",
    "@torch.no_grad()\n",
    "def max_min_margins_diff(M):\n",
    "  D = M.unsqueeze(1) - M.unsqueeze(0)\n",
    "  D.fill_diagonal_(float('-inf'))\n",
    "  maxMargin, _ = D.max(dim=1)\n",
    "  D.fill_diagonal_(float('inf'))\n",
    "  minMargin, _ = D.min(dim=1)\n",
    "\n",
    "  return minMargin, maxMargin\n",
    "\n",
    "@torch.no_grad()\n",
    "def compute_mean_margin_from_q(M: torch.Tensor, q: torch.Tensor) -> torch.Tensor:\n",
    "    V, m = q.shape\n",
    "    M_col = M.view(V, 1)                           # (V,1)\n",
    "    sumM_per_j = (M_col * q).sum(dim=0, keepdims=True)  # (1,m) = ∑_u M(u) q_{u,j}\n",
    "    Mq = M_col * q                                  # (V,m) = M(z) q_{z,j}\n",
    "    denom = (1.0 - q).clamp_min(1e-45)              # (V,m) = 1 - q_{z,j}\n",
    "    E_M_comp = (sumM_per_j - Mq) / denom            # = E_{u≠z}[M(u)] under p_{j,z}\n",
    "    mu = M_col - E_M_comp                           # μ_{z,j} = M(z) - E[M(u)]\n",
    "    return mu\n",
    "\n",
    "alpha =1\n",
    "margins = get_margins(Q, concept_sampled_list, random_sampled_list)\n",
    "minMargin, maxMargin = max_min_margins_diff(margins)\n",
    "meanMargin = compute_mean_margin_from_q(margins, Q)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "9rBU5RcYKyq4"
   },
   "outputs": [],
   "source": [
    "@torch.no_grad()\n",
    "def actual_gap(model, steering_vector: torch.Tensor, alpha: float) -> torch.Tensor:\n",
    "    L = model.W @ model.H\n",
    "\n",
    "    V, m = L.shape\n",
    "    Q_0 = torch.softmax(L, dim=0)\n",
    "    D = torch.eye(m)\n",
    "    Q_1 = torch.softmax(L @ (D + alpha * steering_vector.unsqueeze(-1)), dim=0)\n",
    "    return Q_1 - Q_0  # (V,m)\n",
    "\n",
    "gaps = actual_gap(model, steering_vector, alpha)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "cellView": "form",
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 1000
    },
    "id": "oQ_dbMo8LQao",
    "outputId": "49eacde9-c5dd-4d30-f253-b215fcb4eb0d"
   },
   "outputs": [],
   "source": [
    "@torch.no_grad()\n",
    "def mean_variance_margins_after_steering(model, steering_vector: torch.Tensor, margins: torch.Tensor, alpha: float) -> torch.Tensor:\n",
    "    L = model.W @ model.H\n",
    "\n",
    "    V, m = L.shape\n",
    "    D = torch.eye(m)\n",
    "    Q_steered = torch.softmax(L @ (D + alpha * steering_vector.unsqueeze(-1)), dim=0)\n",
    "\n",
    "    mean = torch.sum(Q_steered * margins.unsqueeze(-1), dim = 0)\n",
    "    var = torch.sum(Q_steered * (margins.unsqueeze(-1)- mean)**2, dim = 0)\n",
    "    return mean, var\n",
    "\n",
    "mean, var = mean_variance_margins_after_steering(model, steering_vector, margins, alpha)\n",
    "\n",
    "\n",
    "\n",
    "alphas = np.linspace(0, 10, 1000)\n",
    "means = list()\n",
    "vars = list()\n",
    "for alpha in alphas:\n",
    "  mean, var = mean_variance_margins_after_steering(model, steering_vector, margins, alpha)\n",
    "  means.append(mean)\n",
    "  vars.append(var)\n",
    "means = torch.stack(means)\n",
    "vars = torch.stack(vars)\n",
    "for context_idx in range(cfg_data.m):\n",
    "\n",
    "  plt.plot(alphas, means[:, context_idx], color = \"blue\", alpha = 0.5)\n",
    "  plt.plot(alphas, vars[:, context_idx], color = \"red\", alpha = 0.5)\n",
    "\n",
    "  plt.xlabel(\"alpha\")\n",
    "  plt.ylabel(\"gap\")\n",
    "\n",
    "  plt.title(f\"Actual steering gaps for context {context_idx+1}, contains concept: {meta['context_groups'][context_idx] == concept_idx}\")\n",
    "  plt.grid()\n",
    "  plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "Do_NaiW-OtRe"
   },
   "outputs": [],
   "source": [
    "@torch.no_grad()\n",
    "def actual_gap_crossentropy(model, steering_vector: torch.Tensor, alpha: float) -> torch.Tensor:\n",
    "    L = model.W @ model.H\n",
    "\n",
    "    V, m = L.shape\n",
    "    Q_0 = torch.softmax(L, dim=0)\n",
    "    D = torch.eye(m)\n",
    "    Q_1 = torch.softmax(L @ (D + alpha * steering_vector.unsqueeze(-1)), dim=0)\n",
    "    loss_0 = cross_entropy_softlabels(Q_0, P, pi)\n",
    "    loss_1 = cross_entropy_softlabels(Q_1, P, pi)\n",
    "    return loss_1 - loss_0  # (V,m)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "LVpKD7RsXnaJ"
   },
   "outputs": [],
   "source": [
    "def actual_gap_converged_model(P, steering_vector: torch.Tensor, alpha: float) -> torch.Tensor:\n",
    "    L = torch.log(P)\n",
    "\n",
    "    V, m = L.shape\n",
    "    Q_0 = torch.softmax(L, dim=0)\n",
    "    D = torch.eye(m)\n",
    "    Q_1 = torch.softmax(L @ (D + alpha * steering_vector.unsqueeze(-1)), dim=0)\n",
    "    return Q_1 - Q_0  # (V,m)\n",
    "\n",
    "gaps = actual_gap_converged_model(P, steering_vector, alpha)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "ucJHZhxqu0rj"
   },
   "outputs": [],
   "source": [
    "var_margins = torch.sum(P*(margins.unsqueeze(-1) - (margins.unsqueeze(-1) * P).sum(dim=0))**2, dim = 0)\n",
    "coef = torch.sum(var_margins * pi)/2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "cellView": "form",
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 1000
    },
    "id": "ydF-mK7BKNxg",
    "outputId": "36e7dd98-4c32-4594-8af1-08ca948d3451"
   },
   "outputs": [],
   "source": [
    "alphas = np.linspace(-5, 5, 1000)\n",
    "for context_idx in range(cfg_data.m):\n",
    "  gaps_1 = list()\n",
    "  for alpha in alphas:\n",
    "    gaps_1.append(actual_gap(model, steering_vector, alpha)[:, context_idx])\n",
    "  gaps_1 = torch.stack(gaps_1)\n",
    "\n",
    "  means_concept_delta = torch.mean(gaps_1[:, :3], dim = 1)\n",
    "  means_nonconcept_delta = torch.mean(gaps_1[:, 3:], dim = 1)\n",
    "\n",
    "  plt.plot(alphas, means_concept_delta, color = \"blue\", alpha = 0.8, label = r\"concept $\\mathcal{C}$\", linewidth = 2.5)\n",
    "\n",
    "  plt.plot(alphas, means_nonconcept_delta, color = \"red\", alpha = 0.8, label = r\"off-concept $\\mathcal{C}^\\complement$\", linewidth = 2.5)\n",
    "\n",
    "\n",
    "  plt.xlabel(r\"$\\alpha$\", fontsize=15)\n",
    "  plt.ylabel(r\"$\\Delta p(\\cdot \\mid j, \\alpha)$\", fontsize=15)\n",
    "  min_y = gaps.min()\n",
    "  max_y = gaps.max()\n",
    "  eps = 0.1\n",
    "  plt.legend()\n",
    "  plt.grid()\n",
    "  plt.savefig(f\"mean_behavior_{context_idx}.pdf\")\n",
    "  plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 392
    },
    "id": "kCsXEZH3Omhc",
    "outputId": "1fb88b2d-44bf-471c-9742-cc41009286b4"
   },
   "outputs": [],
   "source": [
    "alpha_range = 5\n",
    "alphas = np.linspace(-alpha_range, alpha_range, num = 1000)\n",
    "\n",
    "gaps = list()\n",
    "for alpha in alphas:\n",
    "  gaps.append(actual_gap_crossentropy(model, steering_vector, alpha))\n",
    "gaps = torch.stack(gaps)\n",
    "\n",
    "\n",
    "\n",
    "fig, ax = plt.subplots(1,1, figsize=(6.5, 3.5))\n",
    "\n",
    "ax.plot(alphas, gaps, color = \"blue\", linewidth=4)\n",
    "ax.plot(alphas, coef*alphas**2, linestyle = \"dotted\", color = 'black', linewidth=4)\n",
    "plt.xlim(-1,1)\n",
    "plt.ylim(0,1)\n",
    "\n",
    "plt.xlabel(r\"$\\alpha$\", fontsize=20)\n",
    "plt.ylabel(r\"$\\Delta \\mathrm{CE}(\\alpha)$\", fontsize=20)\n",
    "plt.grid()\n",
    "fig.tight_layout()\n",
    "plt.savefig(\"cross_entropy_convex.pdf\", bbox_inches=\"tight\")\n",
    "plt.show()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "cellView": "form",
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 1000
    },
    "id": "Mj0cNPsc4exx",
    "outputId": "80b07332-7261-4df0-ab93-b57bbf470391"
   },
   "outputs": [],
   "source": [
    "alphas = np.linspace(-5, 5, 1000)\n",
    "for context_idx in range(cfg_data.m):\n",
    "  gaps_2 = list()\n",
    "  for alpha in alphas:\n",
    "    gaps_2.append(actual_gap(model, steering_vector, alpha)[:, context_idx])\n",
    "  gaps_2 = torch.stack(gaps_2)\n",
    "\n",
    "  means_concept_delta = torch.mean(gaps_2[:, -3:], dim = 1)\n",
    "\n",
    "  plt.plot(alphas, means_concept_delta, color = \"blue\", alpha = 0.8, label = r\"concept $\\mathcal{C}$\", linewidth = 2.5)\n",
    "\n",
    "\n",
    "\n",
    "  plt.xlabel(r\"$\\alpha$\", fontsize=15)\n",
    "  plt.ylabel(r\"$\\Delta p(\\cdot \\mid j, \\alpha)$\", fontsize=15)\n",
    "  min_y = gaps.min()\n",
    "  max_y = gaps.max()\n",
    "  eps = 0.1\n",
    "  plt.legend()\n",
    "  plt.grid()\n",
    "  plt.savefig(f\"mean_behavior_{context_idx}.pdf\")\n",
    "  plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "cellView": "form",
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 1000
    },
    "id": "LbEHpvRmFRaF",
    "outputId": "efa6aa81-6fa4-47ad-d2e3-d0b3e16b9dd7"
   },
   "outputs": [],
   "source": [
    "alphas = np.linspace(-5, 5, 1000)\n",
    "for context_idx in range(cfg_data.m):\n",
    "  gaps_3 = list()\n",
    "  for alpha in alphas:\n",
    "    gaps_3.append(actual_gap(model, steering_vector, alpha)[:, context_idx])\n",
    "  gaps_3 = torch.stack(gaps_3)\n",
    "\n",
    "  means_concept_delta = torch.mean(gaps_3[:, 3:6], dim = 1)\n",
    "\n",
    "  plt.plot(alphas, means_concept_delta, color = \"blue\", alpha = 0.8, label = r\"concept $\\mathcal{C}$\", linewidth = 2.5)\n",
    "\n",
    "\n",
    "\n",
    "  plt.xlabel(r\"$\\alpha$\", fontsize=15)\n",
    "  plt.ylabel(r\"$\\Delta p(\\cdot \\mid j, \\alpha)$\", fontsize=15)\n",
    "  min_y = gaps.min()\n",
    "  max_y = gaps.max()\n",
    "  eps = 0.1\n",
    "  plt.legend()\n",
    "  plt.grid()\n",
    "  plt.savefig(f\"mean_behavior_{context_idx}.pdf\")\n",
    "  plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "cellView": "form",
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 463
    },
    "id": "GEl0qJk6FfyA",
    "outputId": "dfdf7770-fab7-4c76-faa2-aa6c099b761d"
   },
   "outputs": [],
   "source": [
    "import matplotlib.patheffects as PathEffects\n",
    "\n",
    "alphas = np.linspace(-5, 5, 1000)\n",
    "\n",
    "means_1 = torch.mean(gaps_1[:, :3], dim = 1)\n",
    "means_2 = torch.mean(gaps_2[:, -3:], dim = 1)\n",
    "means_3 = torch.mean(gaps_3[:, 3:6], dim = 1)\n",
    "\n",
    "plt.plot(alphas, means_1, color = \"blue\", alpha = 0.8, linewidth = 3.5)\n",
    "plt.plot(alphas, means_2, color = \"orange\", alpha = 0.8, linewidth = 3.5, linestyle = 'dotted')\n",
    "plt.plot(alphas, means_3, color = \"violet\", alpha = 0.8, linewidth = 3.5, linestyle = '--')\n",
    "plt.axhline(y=0.0, color='black', linestyle='-', alpha = 0.9, linewidth=1)\n",
    "\n",
    "shift = 2\n",
    "txt1 = plt.text(shift, means_1[-1].item()-0.05, r\"$\\Delta p(\\mathcal{T} \\mid \\alpha)$\", color=\"blue\", fontsize=14, ha='left', va='center')\n",
    "txt1.set_path_effects([PathEffects.withStroke(linewidth=4, foreground='white')])\n",
    "\n",
    "txt2 = plt.text(shift, means_2[-1].item()+0.05, r\"$\\Delta p(\\mathcal{C}' \\mid \\alpha)$\", color=\"darkgoldenrod\", fontsize=14, ha='left', va='center')\n",
    "txt2.set_path_effects([PathEffects.withStroke(linewidth=4, foreground='white')])\n",
    "\n",
    "txt3 = plt.text(shift, means_3[-1].item()+0.05, r\"$\\Delta p(\\mathcal{C} \\mid \\alpha)$\", color=\"purple\", fontsize=14, ha='left', va='center')\n",
    "txt3.set_path_effects([PathEffects.withStroke(linewidth=4, foreground='white')])\n",
    "\n",
    "plt.xlabel(r\"$\\alpha$\", fontsize=20)\n",
    "min_y = gaps.min()\n",
    "max_y = gaps.max()\n",
    "eps = 0.1\n",
    "plt.grid()\n",
    "plt.savefig(f\"mean_behavior_all.pdf\")\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "cellView": "form",
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 211
    },
    "id": "N03rVxqoIZAd",
    "outputId": "63056f4f-fc66-4a88-f448-ba12cb0216d7"
   },
   "outputs": [],
   "source": [
    "import matplotlib.patheffects as PathEffects\n",
    "\n",
    "alphas = np.linspace(0, 60, 1000)\n",
    "\n",
    "for context_idx in range(cfg_data.m):\n",
    "  gaps = list()\n",
    "  bounds = list()\n",
    "\n",
    "  for alpha in alphas:\n",
    "    gaps.append(actual_gap(model, steering_vector, alpha)[:, context_idx])\n",
    "\n",
    "  gaps = torch.stack(gaps)\n",
    "\n",
    "  concept_labeled = False\n",
    "  non_concept_labeled = False\n",
    "  for i in range(cfg_data.V):\n",
    "\n",
    "    if concept_idx == meta[\"token_groups\"][i].item():\n",
    "\n",
    "      label_text = r\"target\" if not concept_labeled else \"_nolegend_\"\n",
    "      plt.plot(alphas, gaps[:, i], color = \"blue\", alpha = 0.8, label=label_text, linewidth = 2.5)\n",
    "      plt.ylabel(\"$\\\\Delta_{\" + f\"{context_idx+1}\" + \",z}(\\\\alpha)$\")\n",
    "      concept_labeled = True\n",
    "    else:\n",
    "\n",
    "      label_text = r\"off-target\" if not non_concept_labeled else \"_nolegend_\"\n",
    "      plt.plot(alphas, gaps[:, i], color = \"orange\", alpha = 0.8, label=label_text, linewidth = 2.5)\n",
    "      plt.ylabel(\"$\\\\Delta p(\\\\alpha)$\", fontsize=20)\n",
    "      non_concept_labeled = True\n",
    "\n",
    "  alpha_vline_pos_list = [2.4]\n",
    "\n",
    "\n",
    "  ymin, ymax = plt.gca().get_ylim()\n",
    "\n",
    "  y_mid = ymin + 0.1\n",
    "\n",
    "  for i, pos in enumerate(alpha_vline_pos_list):\n",
    "      # Plot the vertical line\n",
    "      plt.axvline(x=pos, color='black', linestyle='--', linewidth=2.)\n",
    "\n",
    "      label = \"$\\\\alpha_{(1,\" + f\"{i+1}\" + \")}$ \"\n",
    "\n",
    "      text_obj = plt.text(pos, y_mid, label, color='black',\n",
    "                          ha='center', va='center', fontsize=15)\n",
    "\n",
    "      text_obj.set_path_effects([PathEffects.withStroke(linewidth=4, foreground='white')])\n",
    "\n",
    "\n",
    "  plt.xlabel(\"$\\\\alpha$\", fontsize=20)\n",
    "\n",
    "  plt.legend(fontsize=15)\n",
    "  plt.axhline(y=0.0, color='black', linestyle='-')\n",
    "  plt.grid()\n",
    "  plt.tight_layout()\n",
    "  plt.savefig(f\"not_pure_context_{context_idx+1}_contains_concept_{meta['context_groups'][context_idx] == concept_idx}_pos.pdf\")\n",
    "  plt.show()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "ml-torch",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}
