{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4477dd5e",
   "metadata": {},
   "outputs": [],
   "source": [
    "from regression.feature_loader import test_feature_loader\n",
    "from regression.regression_utils import load_pkl\n",
    "from regression.losses import correlation_loss\n",
    "from regression.losses import brain_score_1 as brain_score\n",
    "from regression.load_meg_targets import load_meg_targets\n",
    "from regression.session_story_configs import subject_test_configs\n",
    "from regression.lm_embeddings.embeddings_store import MEGFeatureMapStore, EmbeddingsStore\n",
    "from regression.helmet_plot import HelmetPlot\n",
    "from regression.session_story_configs import subject_unique_configs\n",
    "import torch\n",
    "from regression.feature_loader import load_embedding_transform, test_feature_loader\n",
    "from regression.lm_embeddings.embeddings_store import SessionStoryEmbeddingsFeatureLoader\n",
    "from regression.regression_utils import load_pkl\n",
    "from regression.helpers import load_sensor_locations\n",
    "from tqdm import tqdm\n",
    "import re\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import textwrap\n",
    "from mpl_toolkits.axes_grid1.inset_locator import inset_axes\n",
    "import mne\n",
    "\n",
    "subject = \"D\"\n",
    "embeddings_loc = \"./embeddings\"\n",
    "dataset_loc = \"./data\"\n",
    "rank = 10\n",
    "embeddings_transform_cache_loc = \"./embeddings_transform_cache\"\n",
    "llm_features = { \"name\": \"llama2\",\"layer\": 3,\"context\": 20,\"pca\": 0.95,\"load\": True,\"delays\": 40 }\n",
    "context_llm = { \"name\": \"llama2\",\"layer\": 3,\"context\": 5,\"pca\": 0.95,\"load\": True,\"delays\": 40 }\n",
    "\n",
    "model_loc = f\"./runs/subject_{subject}_rank_sweep_single/rank_{rank}\"\n",
    "llm_store_loc = f\"{embeddings_loc}/embeddings_sweep/{llm_features['name']}/layer_{llm_features['layer']}_context_{llm_features['context_len']}\"\n",
    "meg_store_loc = llm_store_loc + \"/meg_store\"\n",
    "embeddings_store_loc = llm_store_loc + f\"/{llm_features['name']}/layer_{llm_features['layer']}\"\n",
    "helmet_positions_loc = dataset_loc + \"/locations.txt\"\n",
    "layer = llm_features[\"layer\"]\n",
    "context_len = llm_features[\"context_len\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e96dac3f",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_empirical_embeddings(subject, layer, context_size, llm_name):\n",
    "    unique_configs = subject_unique_configs(subject, dataset_loc=dataset_loc)\n",
    "    loader = SessionStoryEmbeddingsFeatureLoader(embeddings_loc + \"/meg_store\", f\"{embeddings_loc}/embeddings_sweep/{llm_name}/layer_{layer}_context_{context_size}\",\n",
    "                                                delays=40, mmap=True, use_cuda=True)\n",
    "    unique_configs = subject_unique_configs(subject, llm_name)\n",
    "    all_embeddings_lst = []\n",
    "    all_contexts_lst = []\n",
    "    for config in unique_configs:\n",
    "        story_embeddings = loader.embeddings_store.load_embeddings(config.story, mmap=False, as_feature_map=False)\n",
    "        story_contexts = loader.embeddings_store.load_contexts(config.story)\n",
    "        all_embeddings_lst.append(story_embeddings)\n",
    "        all_contexts_lst.extend(story_contexts)\n",
    "    all_embeddings = np.concat(all_embeddings_lst, axis = 0)\n",
    "    all_contexts = np.array(all_contexts_lst)\n",
    "    return all_embeddings, all_contexts\n",
    "\n",
    "def empirical_context_activations(subject, llm_features, normed_embedding_factors):\n",
    "    embeddings, contexts = get_empirical_embeddings(subject, llm_features[\"layer\"], llm_features[\"context\"], llm_features[\"name\"])\n",
    "    transform = load_embedding_transform(llm_features, embeddings_transform_cache_loc, add_test=False)\n",
    "    activations = transform(embeddings) @ normed_embedding_factors.T\n",
    "    rank_sorts = np.argsort(activations, axis=0)\n",
    "    sorted_contexts = np.array([contexts[rank_sorts[:,r]] for r in range(rank_sorts.shape[1])])\n",
    "    sorted_activations = np.array([activations[rank_sorts[:,r],r] for r in range(rank_sorts.shape[1])])\n",
    "    return sorted_activations, sorted_contexts\n",
    "\n",
    "def factor_weights(time_factors, space_factors, embedding_factors):\n",
    "    rank_weights = np.einsum(\"rt,re,rs->rtes\", time_factors, embedding_factors, space_factors)\n",
    "    out_weights = np.sum(rank_weights, axis=0).transpose((1,0,2))\n",
    "    dim_collapsed_weights = out_weights.reshape(-1, space_factors.shape[1]).T\n",
    "    return dim_collapsed_weights\n",
    "\n",
    "def predict_meg_from_factors(features, time_factors, space_factors, embedding_factors):\n",
    "    W = factor_weights(time_factors, space_factors, embedding_factors)\n",
    "    return np.matmul(features, W.T)\n",
    "\n",
    "def leave_one_factor_out_sort(features, loss_fn, time_factors, space_factors, embedding_factors, baseline = True):\n",
    "    if baseline:\n",
    "        factor_indices_to_use = list(range(time_factors.shape[0]))\n",
    "        W = factor_weights(time_factors[factor_indices_to_use], space_factors[factor_indices_to_use], embedding_factors[factor_indices_to_use])\n",
    "        predictions = np.matmul(features, W.T)\n",
    "        baseline = loss_fn(predictions)\n",
    "    else:\n",
    "        baseline = 0.0\n",
    "    rank = time_factors.shape[0]\n",
    "    influences = []\n",
    "    for r in range(rank):\n",
    "        factor_indices_to_use = list(range(0, r)) + list(range(r+1, rank))\n",
    "        W = factor_weights(time_factors[factor_indices_to_use], space_factors[factor_indices_to_use], embedding_factors[factor_indices_to_use])\n",
    "        predictions = np.matmul(features, W.T)\n",
    "        s = loss_fn(predictions)\n",
    "        influences.append(baseline - s)\n",
    "    rank_ordering = np.argsort(influences)[::-1]\n",
    "    return np.array(influences), rank_ordering\n",
    "\n",
    "def norm_sort(factor_norms):\n",
    "    rank_ordering = np.argsort(factor_norms)[::-1]\n",
    "    return rank_ordering\n",
    "\n",
    "def normalize_factors(time_factors, space_factors, embedding_factors, sort_order = None):\n",
    "    factor_norms = np.linalg.norm(time_factors, axis=1)*np.linalg.norm(space_factors, axis=1)*np.linalg.norm(embedding_factors, axis=1)\n",
    "    normed_time_factors = time_factors/np.linalg.norm(time_factors, axis=1)[:, None]\n",
    "    normed_space_factors = space_factors/np.linalg.norm(space_factors, axis=1)[:, None]\n",
    "    normed_embedding_factors = embedding_factors/np.linalg.norm(embedding_factors, axis=1)[:,None]\n",
    "    if not sort_order is None:\n",
    "       factor_norms = factor_weights[sort_order]\n",
    "       normed_time_factors = normed_time_factors[sort_order]\n",
    "       normed_space_factors = normed_space_factors[sort_order]\n",
    "       normed_embedding_factors = normed_embedding_factors[sort_order]\n",
    "    return factor_norms, normed_time_factors, normed_space_factors, normed_embedding_factors\n",
    "\n",
    "def context_activations(subject, base_llm_features, context_llm_features, normed_embedding_factors):\n",
    "    embeddings, contexts = get_empirical_embeddings(subject, context_llm_features[\"layer\"], context_llm_features[\"context\"], context_llm_features[\"name\"])\n",
    "    transform = load_embedding_transform(base_llm_features, embeddings_transform_cache_loc, add_test=False)\n",
    "    activations = transform(embeddings) @ normed_embedding_factors.T\n",
    "    #for rank in range(activations.shape[1]):\n",
    "    rank_sorts = np.argsort(activations, axis=0)\n",
    "    #sorted_activations = activations[rank_sorts]\n",
    "    #sorted_contexts = contexts[rank_sorts]\n",
    "    sorted_contexts = np.array([contexts[rank_sorts[:,r]] for r in range(rank_sorts.shape[1])])\n",
    "    sorted_activations = np.array([activations[rank_sorts[:,r],r] for r in range(rank_sorts.shape[1])])\n",
    "    return sorted_activations, sorted_contexts\n",
    "\n",
    "def load_mean_std(llm_features, llm_load_save_loc):\n",
    "    add_test = False\n",
    "    mean_std_save_loc = llm_load_save_loc + \"/\" +f\"MEANSTD_{llm_features['name']}_{llm_features['layer']}_{llm_features['context']}_with_test_{add_test}_delay_{llm_features['delays']}.pkl\"\n",
    "    return load_pkl(mean_std_save_loc)\n",
    "\n",
    "def load_pca(llm_features, llm_load_save_loc):\n",
    "    add_test = False\n",
    "    pca_save_loc = llm_load_save_loc + \"/\" +f\"PCA_{llm_features['name']}_{llm_features['layer']}_{llm_features['context']}_pca{llm_features['pca']}_with_test_{add_test}_delay_{llm_features['delays']}.pkl\"\n",
    "    return load_pkl(pca_save_loc)\n",
    "\n",
    "def load_test_context_features(llm_features, mean, std, pca_weights):\n",
    "    loader = SessionStoryEmbeddingsFeatureLoader(meg_store_loc,\n",
    "                                                 f\"{embeddings_loc}/embeddings_sweep/{llm_features['name']}/layer_{llm_features['layer']}_context_{llm_features['context']}\",\n",
    "                                                delays=40, mmap=True, use_cuda=True)\n",
    "    single_test_feature_lazy = loader.load_configs([subject_test_configs(subject, dataset_loc)[0]], -mean, 1/std, pca_weights)[0]\n",
    "    single_test_feature = single_test_feature_lazy[np.arange(0, len(single_test_feature_lazy))]\n",
    "    return single_test_feature"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c4fc1055",
   "metadata": {},
   "source": [
    "makes the llm features for layer use"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9c424bfd",
   "metadata": {},
   "outputs": [],
   "source": [
    "test_configs = subject_test_configs(subject, dataset_loc=dataset_loc)\n",
    "meg = np.stack(load_meg_targets(test_configs),  axis=0)\n",
    "meg_test_target = np.mean(meg, axis=0)\n",
    "score = brain_score(meg, 0.03)\n",
    "\n",
    "_, test_features = test_feature_loader(llm_features, lm_feature_map_loc=llm_store_loc, \n",
    "                    subject = subject, controls = [], delays = [], \n",
    "                    force_load=True, load_as_control=False)\n",
    "single_test_feature = test_features[0]\n",
    "meg_store = MEGFeatureMapStore(meg_store_loc)\n",
    "helmet_plotter = HelmetPlot(helmet_positions_loc)\n",
    "embeddings_transform = load_embedding_transform(llm_features, embeddings_transform_cache_loc, add_test=False, use_torch = False)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a93420cf",
   "metadata": {},
   "source": [
    "Loads the model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "78d63591",
   "metadata": {},
   "outputs": [],
   "source": [
    "model = torch.load(model_loc + \"/last_model.pt\", weights_only=False)\n",
    "\n",
    "unsorted_raw_time_factors = model.time_factors.cpu().detach().clone().numpy()\n",
    "unsorted_raw_space_factors = model.space_factors.cpu().detach().clone().numpy()\n",
    "unsorted_raw_embedding_factors = model.embedding_factors.cpu().detach().clone().numpy()\n",
    "\n",
    "unsorted_influence, sorting_order = leave_one_factor_out_sort(single_test_feature, score,\n",
    "                                                     unsorted_raw_time_factors, unsorted_raw_space_factors, unsorted_raw_embedding_factors, \n",
    "                                                     baseline = True)\n",
    "print(sorting_order)\n",
    "influence = unsorted_influence[sorting_order]\n",
    "factor_norms, time_factors, space_factors, embedding_factors = normalize_factors(unsorted_raw_time_factors, unsorted_raw_space_factors, unsorted_raw_embedding_factors)\n",
    "\n",
    "rank = model.time_factors.shape[0]\n",
    "\n",
    "embedding_dim = model.embedding_dim\n",
    "print(embedding_dim)\n",
    "space_dim = model.channel_dim\n",
    "time_dim = model.delay_dim\n",
    "\n",
    "#delay_bias = model.delay_bias.clone().detach().cpu().numpy()\n",
    "\n",
    "print(f\"Factor Importance: {factor_norms}\")\n",
    "print(f\"Leave One Out Influence: {influence}\")\n",
    "\n",
    "fig, ax = plt.subplots()\n",
    "for i in range(rank):\n",
    "    ax.plot(np.arange(0,40)*20, np.flip(time_factors[i,:]), label = f\"factor {i+1}\")\n",
    "ax.legend(loc=\"right\", bbox_to_anchor=(1.3, 0.5))\n",
    "ax.set_xlabel(\"Time (ms)\")\n",
    "ax.set_ylabel(\"Normed Units\")\n",
    "ax.set_title(f\"Rank {rank} Timecourses\")\n",
    "print(f\"Brain Score {score(model.numpy_forward(single_test_feature))}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bc866594",
   "metadata": {},
   "outputs": [],
   "source": [
    "def single_factor_meg_projection_predict(features, time_factor, embedding_factor):\n",
    "    weight = time_factor[None,:]*embedding_factor[:,None]\n",
    "    collapsed_weight = weight.reshape(-1)\n",
    "    meg_spatial_collapsed = np.matmul(features, collapsed_weight)\n",
    "    return meg_spatial_collapsed\n",
    "\n",
    "#meg_collapsed = single_factor_meg_projection_predict(single_test_feature, time_factors[0], embedding_factors[0])\n",
    "#plt.plot(meg_collapsed[:300])\n",
    "bins_to_use = 300\n",
    "contexts = EmbeddingsStore(embeddings_store_loc).load_contexts(test_configs[0].story)\n",
    "word_times = MEGFeatureMapStore(meg_store_loc).load_meg_map(test_configs[0])\n",
    "word_timing_indices = [x for x in np.where(word_times != -1)[0] if x < bins_to_use]\n",
    "words = [x.split(\" \")[-1] for x in contexts]\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6fff291a",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "def plot_factors_and_raster(\n",
    "    features: np.ndarray,\n",
    "    time_factors: list[np.ndarray],\n",
    "    embedding_factors: list[np.ndarray],\n",
    "    contexts: list[str],\n",
    "    word_times: np.ndarray,\n",
    "    bins_to_use: int = 300,\n",
    "    bin_width_sec: float = 0.020,   # 20 ms per bin\n",
    "    raster_ratio: float = 0.5,      # raster height = 0.5 × time‐series height\n",
    "    font_scale: float = 1.0,        # global font size multiplier\n",
    "    factor_norms=None,\n",
    "    factors_to_plot = None \n",
    "):\n",
    "    \"\"\"\n",
    "    Top:    MEG projections for each factor (legend outside, 2 cols)\n",
    "    Bottom: raster of word‐events with time (s) on top and words on bottom.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    raster_ratio : float\n",
    "        The relative height of the raster row versus the time‐series row.\n",
    "        (e.g. 0.5 makes the raster half as tall as the time‐series plot).\n",
    "    font_scale : float\n",
    "        Multiplier for all font sizes in the figure.\n",
    "    \"\"\"\n",
    "    factors_to_plot_indices = [i - 1 for i in factors_to_plot]\n",
    "    time_factors = time_factors[factors_to_plot_indices]\n",
    "    embedding_factors = embedding_factors[factors_to_plot_indices]\n",
    "    labels = [f\"Factor {i}\" for i in factors_to_plot]\n",
    "    k = len(time_factors)\n",
    "    words = [ctx.split()[-1] for ctx in contexts]\n",
    "\n",
    "    # collapse MEG features per factor\n",
    "    collapsed = []\n",
    "    for tf, ef, norm in zip(time_factors, embedding_factors, factor_norms):\n",
    "        w   = tf[None, :] * ef[:, None]   # (1,L)*(D,1)->(D,L)\n",
    "        cw  = w.ravel()                   # -> (D*L,)\n",
    "        ts  = features.dot(cw)[:bins_to_use]\n",
    "        ts = norm*ts\n",
    "        collapsed.append(ts)\n",
    "    collapsed = np.stack(collapsed, 0)    # (k, bins_to_use)\n",
    "\n",
    "    # which bins have events\n",
    "    bins       = np.arange(bins_to_use)\n",
    "    event_bins = bins[word_times[:bins_to_use] != -1]\n",
    "\n",
    "    # compute x‐axis in seconds\n",
    "    secs = bins * bin_width_sec\n",
    "    # ticks every 1 s\n",
    "    tick_secs = np.arange(0, secs[-1]+1, 1.0)\n",
    "\n",
    "    # figure layout\n",
    "    fig, (ax_ts, ax_raster) = plt.subplots(\n",
    "        2, 1,\n",
    "        figsize=(12, 6),\n",
    "        sharex=True,\n",
    "        gridspec_kw={\n",
    "            'height_ratios': [1, raster_ratio],\n",
    "            'hspace': 0.3\n",
    "        }\n",
    "    )\n",
    "    plt.subplots_adjust(bottom=0.2, right=0.75, top=0.9)\n",
    "\n",
    "    # ── Top: factor time series ──\n",
    "    for i, ts in enumerate(collapsed):\n",
    "        ax_ts.plot(secs, ts, label=labels[i], linewidth=1)\n",
    "    ax_ts.set_ylabel(\"MEG Factor Weight\", fontsize=12*font_scale)\n",
    "    ax_ts.set_title(\"Rank 10 Factor Examples\", fontsize=14*font_scale)\n",
    "    ax_ts.grid(True, alpha=0.3)\n",
    "\n",
    "    # legend outside to the right, 2 columns\n",
    "    ax_ts.legend(\n",
    "        loc='upper left',\n",
    "        bbox_to_anchor=(1.02, 1),\n",
    "        ncol=1,\n",
    "        frameon=False,\n",
    "        fontsize=10*font_scale\n",
    "    )\n",
    "\n",
    "    # ── Bottom: raster ──\n",
    "    ax_raster.eventplot(\n",
    "        [event_bins * bin_width_sec],\n",
    "        colors='black',\n",
    "        lineoffsets=1,\n",
    "        linelengths=0.8\n",
    "    )\n",
    "    ax_raster.set_ylim(0.5, 1.5)\n",
    "    ax_raster.set_yticks([])\n",
    "\n",
    "    # time axis on top\n",
    "    ax_raster.xaxis.set_label_position('top')\n",
    "    ax_raster.xaxis.tick_top()\n",
    "    ax_raster.set_xlabel(\"Time (s)\", fontsize=12*font_scale)\n",
    "    ax_raster.set_xticks(tick_secs)\n",
    "    ax_raster.set_xticklabels([f\"{t:.1f}\" for t in tick_secs], fontsize=10*font_scale)\n",
    "\n",
    "    # annotate words on bottom\n",
    "    for t in event_bins:\n",
    "        w = words[word_times[t]]\n",
    "        ax_raster.text(\n",
    "            t*bin_width_sec, -0.1, w,\n",
    "            rotation=90, va='top', ha='center',\n",
    "            fontsize=12*font_scale,\n",
    "            transform=ax_raster.get_xaxis_transform(),\n",
    "            clip_on=False\n",
    "        )\n",
    "\n",
    "    plt.show()\n",
    "    return fig\n",
    "\n",
    "fig = plot_factors_and_raster(\n",
    "    single_test_feature,\n",
    "    time_factors,\n",
    "    embedding_factors,\n",
    "    contexts,\n",
    "    word_times,\n",
    "    bins_to_use=200,\n",
    "    raster_ratio=0.2,\n",
    "    font_scale=1.3,\n",
    "    factor_norms=factor_norms,\n",
    "    factors_to_plot=[1,2,4,7]\n",
    ")\n",
    "fig = plot_factors_and_raster(\n",
    "    single_test_feature,\n",
    "    time_factors,\n",
    "    embedding_factors,\n",
    "    contexts,\n",
    "    word_times,\n",
    "    bins_to_use=200,\n",
    "    raster_ratio=0.2,\n",
    "    font_scale=1.3,\n",
    "    factor_norms=factor_norms,\n",
    "    factors_to_plot=list(range(1, 11))\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a9d6dc22",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "\n",
    "positions = load_sensor_locations(helmet_positions_loc, partial_sensors=False)\n",
    "base_llm = llm_features\n",
    "mean,std = load_mean_std(base_llm, embeddings_transform_cache_loc)\n",
    "pca_weights = load_pca(base_llm, embeddings_transform_cache_loc)\n",
    "context_test_features = load_test_context_features(context_llm, mean, std, pca_weights)\n",
    "predicted_meg = model.numpy_forward(context_test_features)\n",
    "s = score(predicted_meg)\n",
    "activation, context = context_activations(subject, base_llm, context_llm, embedding_factors)\n",
    "print(context.shape)\n",
    "least_activating_contexts = context[:,:5]\n",
    "most_activating_contexts = context[:,-5:]\n",
    "\n",
    "def remove_non_chars(str_arr):\n",
    "    pat = re.compile(r\"[^A-Za-z]+\")\n",
    "    vsub = np.vectorize(lambda s: pat.sub(\"\", s))\n",
    "    return vsub(str_arr)\n",
    "\n",
    "def plot_with_inset_colorbar(\n",
    "    time_series,\n",
    "    spatial_maps,\n",
    "    embeddings,\n",
    "    positions_2d,\n",
    "    vlim=(-0.1, 0.1),\n",
    "    sentences_pos_matrix=None,    # list of lists or strings, shape (k × n_pos)\n",
    "    sentences_neg_matrix=None,    # list of lists or strings, shape (k × n_neg)\n",
    "    activating_sentence_shift=0.4, # how far into left margin the row‐titles go,\n",
    "    factor_indices = None,\n",
    "    fontsize = 14,\n",
    "    axis_fontsize = 14\n",
    "):\n",
    "    \"\"\"\n",
    "    Rows:\n",
    "      0: Time series per factor\n",
    "      1: MEG topomap per factor (with inset colorbar)\n",
    "      2: Embedding bar per factor\n",
    "      3: Positive activating samples (optional)\n",
    "      4: Negative activating samples (optional)\n",
    "\n",
    "    Sentences are right-justified, no wrapping.\n",
    "    Row titles flush-left and perfectly centered vertically.\n",
    "    \"\"\"\n",
    "    k = len(time_series)\n",
    "    T = time_series[0].shape[0]\n",
    "    D = embeddings[0].shape[0]\n",
    "\n",
    "    if factor_indices is None:\n",
    "        col_labels = [f\"Component {i+1}\" for i in range(k)]\n",
    "    else:\n",
    "        col_labels = [f\"Component {i}\" for i in factor_indices]\n",
    "    add_pos = sentences_pos_matrix is not None\n",
    "    add_neg = sentences_neg_matrix is not None\n",
    "    n_extra = int(add_pos) + int(add_neg)\n",
    "    nrows = 3 + n_extra\n",
    "\n",
    "    fig, axes = plt.subplots(\n",
    "        nrows=nrows,\n",
    "        ncols=k,\n",
    "        figsize=(4*k, 1.8*nrows),\n",
    "        constrained_layout=True,\n",
    "        sharey=\"row\",\n",
    "        gridspec_kw={\"hspace\": 0.2}\n",
    "    )\n",
    "\n",
    "    # ── Rows 0–2 ──\n",
    "    axes_ts, axes_sp, axes_em = axes[0], axes[1], axes[2]\n",
    "    im = None\n",
    "\n",
    "    for i in range(k):\n",
    "        # Row 0: time series\n",
    "        ax = axes_ts[i]\n",
    "        ax.plot(20*np.arange(T), np.flip(time_series[i]), 'k')\n",
    "        ax.annotate(\n",
    "            col_labels[i],\n",
    "            xy=(0.5, 1.5), xycoords='axes fraction',\n",
    "            ha='center', fontsize=fontsize, fontweight='bold'\n",
    "        )\n",
    "        if i == 0:\n",
    "            ax.set_ylabel(\"Time\\nFactor\", fontsize=fontsize, fontweight='bold')\n",
    "        ax.set_xlabel(\"Time (ms)\", fontsize=axis_fontsize)\n",
    "\n",
    "        # Row 1: topomap\n",
    "        ax = axes_sp[i]\n",
    "        im, _ = mne.viz.plot_topomap(\n",
    "            spatial_maps[i], positions_2d, axes=ax,\n",
    "            show=False, cmap=\"RdBu_r\",\n",
    "            outlines=\"head\", sphere=0.1, vlim=vlim\n",
    "        )\n",
    "        #ax.set_ylabel(\"Helmet\\nFactor\", fontsize=fontsize)\n",
    "        ax.set_xticks([]); ax.set_yticks([])\n",
    "\n",
    "        # Row 2: embedding bar\n",
    "        ax = axes_em[i]\n",
    "        ax.bar(np.arange(D), embeddings[i], width=1, color='gray')\n",
    "        if i == 0:\n",
    "            ax.set_ylabel(\"Embedding\\nFactor\", fontsize=fontsize, fontweight='bold')\n",
    "        ax.set_xlabel(\"Dim\", fontsize=fontsize)\n",
    "\n",
    "    # ── Inset colorbar on first topomap ──\n",
    "    first_topo_ax = axes_sp[0]\n",
    "    cax = inset_axes(\n",
    "        first_topo_ax,\n",
    "        width=\"4%\", height=\"80%\",\n",
    "        loc=\"center left\",\n",
    "        bbox_to_anchor=(-0.95, 0, 1, 1),\n",
    "        bbox_transform=first_topo_ax.transAxes,\n",
    "        borderpad=0\n",
    "    )\n",
    "    cbar = fig.colorbar(im, cax=cax, orientation='vertical')\n",
    "    cbar.ax.yaxis.set_label_position('left')\n",
    "    cbar.ax.yaxis.set_ticks_position('left')\n",
    "    cbar.set_label(\n",
    "        \"Channel\\nFactor\",\n",
    "        rotation=90, va='bottom', labelpad=5, fontsize=fontsize,\n",
    "        fontweight='bold'\n",
    "    )\n",
    "    cbar.ax.tick_params(labelsize=8)\n",
    "\n",
    "    # ── Rows 3 & 4: Positive/Negative activating samples ──\n",
    "    current_row = 3\n",
    "    axes_pos = axes[current_row] if add_pos else None\n",
    "    axes_neg = (\n",
    "        axes[current_row+1] if (add_pos and add_neg)\n",
    "        else axes[current_row] if (not add_pos and add_neg)\n",
    "        else None\n",
    "    )\n",
    "\n",
    "    if add_pos:\n",
    "        for ax, sents in zip(axes_pos, sentences_pos_matrix):\n",
    "            ax.axis('off')\n",
    "            # space them evenly\n",
    "            y_coords = np.linspace(0.95, 0.05, len(sents))\n",
    "            for y, sent in zip(y_coords, sents):\n",
    "                # if passed a list/tuple, join into one string\n",
    "                txt = \" \".join(sent) if isinstance(sent, (list, tuple)) else str(sent)\n",
    "                txt = re.sub(r\"\\s*-\\s*\", \"-\", txt)\n",
    "                ax.text(\n",
    "                    0.98, y, txt,\n",
    "                    ha='right', va='center',\n",
    "                    color='red', fontsize=fontsize,\n",
    "                    transform=ax.transAxes\n",
    "                )\n",
    "        current_row += 1\n",
    "\n",
    "    if add_neg:\n",
    "        for ax, sents in zip(axes_neg, sentences_neg_matrix):\n",
    "            ax.axis('off')\n",
    "            y_coords = np.linspace(0.95, 0.05, len(sents))\n",
    "            for y, sent in zip(y_coords, sents):\n",
    "                txt = \" \".join(sent) if isinstance(sent, (list, tuple)) else str(sent)\n",
    "                txt = re.sub(r\"\\s*-\\s*\", \"-\", txt)\n",
    "                ax.text(\n",
    "                    0.98, y, txt,\n",
    "                    ha='right', va='center',\n",
    "                    color='blue', fontsize=fontsize,\n",
    "                    transform=ax.transAxes\n",
    "                )\n",
    "\n",
    "    # ── Row-titles flush-left, centered vertically ──\n",
    "    left_edge = axes_ts[0].get_position().x0\n",
    "    label_x = left_edge - activating_sentence_shift\n",
    "\n",
    "    if axes_pos is not None:\n",
    "        mids = [ax.get_position().y0 + ax.get_position().height/2 for ax in axes_pos]\n",
    "        fig.text(\n",
    "            label_x, np.mean(mids)-0.1,\n",
    "            \"Positive\\nActivating\\nContexts\",\n",
    "            va='center', ha='center',\n",
    "            rotation=90,\n",
    "            fontsize=axis_fontsize, fontweight='bold', color='red'\n",
    "        )\n",
    "\n",
    "    if axes_neg is not None:\n",
    "        mids = [ax.get_position().y0 + ax.get_position().height/2 for ax in axes_neg]\n",
    "        fig.text(\n",
    "            label_x, np.mean(mids) - 0.1,\n",
    "            \"Negative\\nActivating\\nContexts\",\n",
    "            va='center', ha='center',\n",
    "            rotation=90,\n",
    "            fontsize=axis_fontsize, fontweight='bold', color='blue'\n",
    "        )\n",
    "\n",
    "    return fig\n",
    "\n",
    "\n",
    "least_activating_contexts = least_activating_contexts\n",
    "most_activating_contexts = most_activating_contexts\n",
    "print(most_activating_contexts.shape)\n",
    "print(least_activating_contexts.shape)\n",
    "#print(\"found contexts\")\n",
    "#print(least_activating_contexts)\n",
    "#print(most_activating_contexts)\n",
    "#most_activating_contexts = np.array([[\"and I went down. And\", \"howdy..\" ,\"test\",\"how about a long sentence\",\"cat\" ]]*8)\n",
    "#print(most_activating_contexts.shape)\n",
    "#least_activating_contexts = np.array([[\"bye bye bye\"]*5]*8)\n",
    "#print(most_activating_contexts.shape)\n",
    "#print(least_activating_contexts.shape)\n",
    "\n",
    "selection = [1,2,4,7]\n",
    "indexing = [x-1 for x in selection]\n",
    "plot_with_inset_colorbar(\n",
    "    time_factors[indexing],\n",
    "    space_factors[indexing],\n",
    "    embedding_factors[indexing],\n",
    "    positions,\n",
    "    vlim=(-0.1, 0.1),\n",
    "    sentences_pos_matrix=most_activating_contexts[indexing],\n",
    "    sentences_neg_matrix=least_activating_contexts[indexing], \n",
    "    activating_sentence_shift=0.11,\n",
    "    factor_indices = selection,\n",
    "    fontsize=15,\n",
    "    axis_fontsize=15\n",
    "    )\n",
    "\n",
    "least_activating_contexts = least_activating_contexts\n",
    "most_activating_contexts = most_activating_contexts\n",
    "print(most_activating_contexts.shape)\n",
    "print(least_activating_contexts.shape)\n",
    "print(\"found contexts\")\n",
    "#print(least_activating_contexts)\n",
    "#print(most_activating_contexts)\n",
    "#most_activating_contexts = np.array([[\"and I went down. And\", \"howdy..\" ,\"test\",\"how about a long sentence\",\"cat\" ]]*8)\n",
    "#print(most_activating_contexts.shape)\n",
    "#least_activating_contexts = np.array([[\"bye bye bye\"]*5]*8)\n",
    "#print(most_activating_contexts.shape)\n",
    "#print(least_activating_contexts.shape)\n",
    "selection = list(range(1, 11))\n",
    "indexing = [x-1 for x in selection]\n",
    "plot_with_inset_colorbar(\n",
    "    time_factors[indexing],\n",
    "    space_factors[indexing],\n",
    "    embedding_factors[indexing],\n",
    "    positions,\n",
    "    vlim=(-0.1, 0.1),\n",
    "    sentences_pos_matrix=most_activating_contexts[indexing],\n",
    "    sentences_neg_matrix=least_activating_contexts[indexing], \n",
    "    activating_sentence_shift=0.11,\n",
    "    factor_indices = selection)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b59d7e4d",
   "metadata": {},
   "outputs": [],
   "source": [
    "fig, ax = plt.subplots(figsize=(6, 4))\n",
    "indices = np.arange(1, len(influence)+1)      # 1…k\n",
    "ax.bar(indices, influence, color='C2', edgecolor='k')\n",
    "ax.set_xlabel(\"Factor\", fontsize=16)\n",
    "ax.set_ylabel(\"$CC_{norm}$ Drop\", fontsize=16)\n",
    "ax.set_title(\"Factors Leave One Out $CC_{norm}$ Drop\", fontsize=16)\n",
    "ax.set_xticks(indices)\n",
    "ax.set_xticklabels([f\"{i}\" for i in indices])\n",
    "ax.set_ylim(influence.min()*1.1, influence.max() * 1.1)  # add 10% headroom\n",
    "plt.tight_layout()\n",
    "plt.show()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b92701e8",
   "metadata": {},
   "outputs": [],
   "source": [
    "fig, ax = plt.subplots(figsize=(6, 4))\n",
    "indices = np.arange(1, len(factor_norms)+1)      # 1…k\n",
    "ax.bar(indices, factor_norms, color='blue', edgecolor='k')\n",
    "ax.set_xlabel(\"Factor\", fontsize=12)\n",
    "ax.set_ylabel(\"Scale\", fontsize=12)\n",
    "ax.set_title(\"Factors Component Scaling\", fontsize=14)\n",
    "ax.set_xticks(indices)\n",
    "ax.set_xticklabels([f\"{i}\" for i in indices])\n",
    "ax.set_ylim(0.0, factor_norms.max() * 1.1)  # add 10% headroom\n",
    "plt.tight_layout()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a028379e",
   "metadata": {},
   "source": [
    "get empirical embeddings"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "216b9930",
   "metadata": {},
   "outputs": [],
   "source": [
    "context_embeddings, context_contexts = get_empirical_embeddings(subject, layer, context_len, llm_features['name'])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c5a1906b",
   "metadata": {},
   "source": [
    "Gets the score of a model under a shorter context length"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8cb8c3fa",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sweep the context length: for each length, build the test features, score the model's\n",
    "# prediction and record the per-factor activations/contexts. Results are kept in parallel\n",
    "# lists indexed like `context_lens`.\n",
    "context_lens = list(range(1, 21))\n",
    "context_len_test_features = []\n",
    "scores = []\n",
    "context_llms = []\n",
    "activations = []\n",
    "contexts = []\n",
    "\n",
    "# Both loaders take only `base_llm` and the cache location, neither of which changes inside\n",
    "# the loop, so they are called once here instead of on every iteration.\n",
    "mean, std = load_mean_std(base_llm, embeddings_transform_cache_loc)\n",
    "pca_weights = load_pca(base_llm, embeddings_transform_cache_loc)\n",
    "\n",
    "for context_len in context_lens:\n",
    "    # same configuration for every step; only the context length varies\n",
    "    context_llm = { \"name\": \"llama2\",\"layer\": 3,\"context\": context_len,\"pca\": 0.95,\"load\": True,\"delays\": 40 }\n",
    "    context_test_features = load_test_context_features(context_llm, mean, std, pca_weights)\n",
    "    predicted_meg = model.numpy_forward(context_test_features)\n",
    "    s = score(predicted_meg)\n",
    "    activation, context = context_activations(subject, base_llm, context_llm, embedding_factors)\n",
    "    scores.append(s)\n",
    "    context_len_test_features.append(context_test_features)\n",
    "    context_llms.append(context_llm)\n",
    "    activations.append(activation)\n",
    "    contexts.append(context)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d743d1ab",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Bar chart of the score obtained at each swept context length.\n",
    "fig, ax = plt.subplots(figsize=(6, 4))\n",
    "# Place the bars at the context lengths themselves rather than at a 1..k position index, so the\n",
    "# axis labelled \"Context Length\" stays correct if `context_lens` stops being exactly 1..k.\n",
    "indices = np.asarray(context_lens)\n",
    "ax.bar(indices, scores, color='blue', edgecolor='k')\n",
    "ax.set_xlabel(\"Context Length (words)\", fontsize=12)\n",
    "ax.set_ylabel(\"$CC_{norm}$\", fontsize=12)\n",
    "ax.set_title(\"$CC_{norm}$ Over Context Length\", fontsize=14)\n",
    "ax.set_xticks(indices)\n",
    "ax.set_xticklabels([f\"{i}\" for i in indices])\n",
    "ax.set_ylim(0.0, max(scores) * 1.1)  # add 10% headroom\n",
    "plt.tight_layout()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "56efe9db",
   "metadata": {},
   "source": [
    "How do I get predictions for different embedding sizes when the PCA map and the mean/std normalisation have already been fitted?"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "77d4b87a",
   "metadata": {},
   "outputs": [],
   "source": [
    "def single_factor_meg_predict(single_test_feature, time_factor, space_factor, embedding_factor, meg_test_target):\n",
    "    \"\"\"Correlate one factor's predicted series with the MEG target projected onto that factor.\n",
    "\n",
    "    The time and embedding factors are multiplied into an outer-product weight, flattened with\n",
    "    the embedding axis varying slowest, and applied to `single_test_feature`. `meg_test_target`\n",
    "    is projected onto `space_factor`.\n",
    "\n",
    "    Returns:\n",
    "        (corr, prediction, projected_target), where corr is `correlation_loss` evaluated on the\n",
    "        two series reshaped to column vectors.\n",
    "    \"\"\"\n",
    "    # (embedding, time) outer product; its row-major flattening equals the transposed\n",
    "    # (time, embedding) product flattened, i.e. index = embedding * n_time + time\n",
    "    flat_weight = (embedding_factor[:, None] * time_factor[None, :]).reshape(-1)\n",
    "    projected_prediction = np.matmul(single_test_feature, flat_weight)\n",
    "    projected_target = np.dot(meg_test_target, space_factor)\n",
    "    corr = correlation_loss(projected_prediction[:, None], projected_target[:, None])\n",
    "    return corr, projected_prediction, projected_target\n",
    "\n",
    "# For every factor and every swept context length, record the projection correlation plus the\n",
    "# first/last activation and the first/last five contexts of that factor's row.\n",
    "# NOTE(review): index 0 / -1 are stored as \"least\" / \"most\" activating - presumably each row\n",
    "# returned by context_activations is sorted ascending; confirm against that helper.\n",
    "sweep_records = []\n",
    "for r in tqdm(range(rank)):\n",
    "    factor_records = []\n",
    "    for features, _llm_cfg, activation, context in zip(context_len_test_features, context_llms, activations, contexts):\n",
    "        corr, _, _ = single_factor_meg_predict(features, time_factors[r], space_factors[r], embedding_factors[r], meg_test_target)\n",
    "        factor_records.append({\n",
    "            \"corr\": corr,\n",
    "            \"least\": activation[r, 0],\n",
    "            \"most\": activation[r, -1],\n",
    "            \"least_contexts\": context[r, :5],\n",
    "            \"most_contexts\": context[r, -5:],\n",
    "        })\n",
    "    sweep_records.append(factor_records)\n",
    "\n",
    "# Re-pack the records into arrays whose first two axes are (factor, context length).\n",
    "all_corrs = np.array([[rec[\"corr\"] for rec in row] for row in sweep_records]).squeeze(-1)\n",
    "all_most_activating = np.array([[rec[\"most\"] for rec in row] for row in sweep_records])\n",
    "all_least_activating = np.array([[rec[\"least\"] for rec in row] for row in sweep_records])\n",
    "all_most_activating_contexts = np.array([[rec[\"most_contexts\"] for rec in row] for row in sweep_records])\n",
    "all_least_activating_contexts = np.array([[rec[\"least_contexts\"] for rec in row] for row in sweep_records])\n",
    "    #print(f\"Factor {r+1}\")\n",
    "    #print(\"____Least Activating___\")\n",
    "    #print(context[r,:5])\n",
    "    #print(\"____Most Activating___\")\n",
    "    #print(context[r,-5:])\n",
    "    \n",
    "    #context_xs = list(range(1,21))\n",
    "    #plt.figure()\n",
    "    #plt.title(f\"Factor {r+1} MEG Projection \\n Correlation Over Context Length\")\n",
    "    #plt.plot(context_xs, corrs, linestyle=\"None\", marker=\"o\")\n",
    "    #plt.xlabel(\"Context Length (words)\")\n",
    "    #plt.ylabel(\"MEG Projection Correlation\")\n",
    "    #plt.xticks(context_xs)\n",
    "\n",
    "    #plt.figure()\n",
    "    #plt.title(f\"Factor {r+1} Extreme Activations Over Context Length\")\n",
    "    #plt.plot(context_xs, most_activating, linestyle=\"None\", marker=\"o\", label = \"Highest Activation\")\n",
    "    #plt.plot(context_xs, least_activating, linestyle=\"None\", marker=\"o\", label = \"Lowest Activation\")\n",
    "    #plt.xlabel(\"Context Length (words)\")\n",
    "    #plt.ylabel(\"Embedding Activation\")\n",
    "    #plt.xticks(context_xs)\n",
    "    #plt.legend()\n",
    "    #print(activation[rank,:10])\n",
    "    #print(context[rank, :10])\n",
    "#plt.figure()\n",
    "#plt.plot(predicted[:200], label = \"prediction\")\n",
    "#plt.plot(true[:200]/100, label = \"true\")\n",
    "#plt.legend()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "785eeefd",
   "metadata": {},
   "outputs": [],
   "source": [
    "# For each factor, find the context length with the highest projection correlation and keep the\n",
    "# most/least activating contexts recorded at that length.\n",
    "factor_ids = np.arange(rank)\n",
    "best_context_indices = all_corrs.argmax(axis=1)\n",
    "most_activating_contexts = all_most_activating_contexts[factor_ids, best_context_indices]\n",
    "least_activating_contexts = all_least_activating_contexts[factor_ids, best_context_indices]\n",
    "for selected in (most_activating_contexts, least_activating_contexts, all_corrs):\n",
    "    print(selected.shape)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "51764d80",
   "metadata": {},
   "outputs": [],
   "source": [
    "def _pad_limits(lo, hi, frac=0.1):\n",
    "    \"\"\"Widen (lo, hi) outwards by `frac` of each bound's magnitude so no data point is cropped.\"\"\"\n",
    "    return lo - frac * abs(lo), hi + frac * abs(hi)\n",
    "\n",
    "\n",
    "def _draw_sentence_row(axes, sentence_lists, color, max_chars, pad_between):\n",
    "    \"\"\"Write each factor's sentences as wrapped, right-aligned text on its own frameless axis.\"\"\"\n",
    "    for ax, sentences in zip(axes, sentence_lists):\n",
    "        ax.axis(\"off\")\n",
    "        y_blocks = len(sentences)\n",
    "        if y_blocks == 0:\n",
    "            continue  # nothing to write; also avoids dividing by zero below\n",
    "        total_gap = pad_between * (y_blocks - 1)\n",
    "        block_h = (0.90 - total_gap) / y_blocks\n",
    "        for s_idx, sent in enumerate(sentences):\n",
    "            if isinstance(sent, (list, tuple)):\n",
    "                sent = \" \".join(sent)\n",
    "            lines = textwrap.wrap(sent, width=max_chars)\n",
    "            top = 0.95 - s_idx * (block_h + pad_between)\n",
    "            # spread the wrapped lines evenly inside this sentence's block\n",
    "            for j, line in enumerate(lines):\n",
    "                y = top - (j + 1) * (block_h / (len(lines) + 1))\n",
    "                ax.text(0.98, y, line,\n",
    "                        ha=\"right\", va=\"center\",\n",
    "                        color=color, fontsize=12,\n",
    "                        transform=ax.transAxes)\n",
    "\n",
    "\n",
    "def plot_sentence_factors(\n",
    "    sentences_pos_list,\n",
    "    sentences_neg_list,\n",
    "    factors=None,\n",
    "    correlation_matrix=None,        # shape (k, L)\n",
    "    positive_series_matrix=None,    # shape (k, L)\n",
    "    negative_series_matrix=None,    # shape (k, L)\n",
    "    context_lengths=None,           # length L\n",
    "    max_chars=40,\n",
    "    figsize_per_factor=(4, 10),\n",
    "    left_margin=0.12,\n",
    "    right_margin=0.98,\n",
    "    top_margin=0.95,\n",
    "    bottom_margin=0.05,\n",
    "    hspace=0.3,\n",
    "    max_xticks=6,\n",
    "    pad_between=0.02\n",
    "):\n",
    "    \"\"\"Plot per-factor curves above each factor's positive and negative sentences.\n",
    "\n",
    "    One column per factor. Rows, top to bottom (a curve row is skipped when its data is missing):\n",
    "      1. positive_series_matrix (red circles) and negative_series_matrix (blue circles)\n",
    "      2. correlation_matrix (black circles)\n",
    "      3. positive sentences (red text)\n",
    "      4. negative sentences (blue text)\n",
    "\n",
    "    Args:\n",
    "        sentences_pos_list: k sequences of sentences; a sentence is a str, or a list/tuple of\n",
    "            words that is joined with spaces.\n",
    "        sentences_neg_list: same layout as sentences_pos_list; must also have k entries.\n",
    "        factors: k column titles, drawn on the topmost curve row. Defaults to\n",
    "            \"Factor 1\" .. \"Factor k\".\n",
    "        correlation_matrix: array of shape (k, L), or None to omit the correlation row.\n",
    "        positive_series_matrix: array of shape (k, L); the series row needs both series.\n",
    "        negative_series_matrix: array of shape (k, L); the series row needs both series.\n",
    "        context_lengths: the L x-values; required for either curve row.\n",
    "        max_chars: wrap width for the sentence text.\n",
    "        figsize_per_factor: (width per factor, total figure height).\n",
    "        left_margin, right_margin, top_margin, bottom_margin, hspace: grid layout, passed to\n",
    "            fig.add_gridspec.\n",
    "        max_xticks: upper bound on the number of x tick marks on the curve rows.\n",
    "        pad_between: vertical gap between sentence blocks, in axes fractions.\n",
    "\n",
    "    Returns:\n",
    "        The matplotlib Figure (it is also shown via plt.show()).\n",
    "\n",
    "    Raises:\n",
    "        AssertionError: if the list lengths or matrix shapes are inconsistent.\n",
    "    \"\"\"\n",
    "    k = len(sentences_pos_list)\n",
    "    assert k == len(sentences_neg_list), \"must match # of pos/neg sentence lists\"\n",
    "    if factors is None:\n",
    "        factors = [f\"Factor {i+1}\" for i in range(k)]\n",
    "    assert len(factors) == k\n",
    "\n",
    "    has_series = (positive_series_matrix is not None) and (negative_series_matrix is not None) and (context_lengths is not None)\n",
    "    has_corr = (correlation_matrix is not None) and (context_lengths is not None)\n",
    "\n",
    "    # validate shapes\n",
    "    if has_series:\n",
    "        assert positive_series_matrix.shape == negative_series_matrix.shape\n",
    "        if has_corr:\n",
    "            assert positive_series_matrix.shape == correlation_matrix.shape\n",
    "    if has_corr:\n",
    "        assert correlation_matrix.shape[1] == len(context_lengths)\n",
    "\n",
    "    # Thinned x ticks, shared by both curve rows. Computed up front so the correlation row\n",
    "    # also works when no series row is drawn.\n",
    "    xticks = []\n",
    "    if context_lengths is not None:\n",
    "        n_lengths = len(context_lengths)\n",
    "        if n_lengths <= max_xticks:\n",
    "            idxs = np.arange(n_lengths)\n",
    "        else:\n",
    "            idxs = np.linspace(0, n_lengths - 1, max_xticks, dtype=int)\n",
    "        xticks = [context_lengths[i] for i in idxs]\n",
    "\n",
    "    # total rows = two sentence rows + one per available curve row\n",
    "    nrows = 2 + has_corr + has_series\n",
    "\n",
    "    fig = plt.figure(\n",
    "        figsize=(figsize_per_factor[0]*k, figsize_per_factor[1]),\n",
    "        constrained_layout=False\n",
    "    )\n",
    "    gs = fig.add_gridspec(\n",
    "        nrows=nrows, ncols=k,\n",
    "        left=left_margin, right=right_margin,\n",
    "        top=top_margin, bottom=bottom_margin,\n",
    "        hspace=hspace\n",
    "    )\n",
    "\n",
    "    row = 0\n",
    "    # -- Series row: positive (red) and negative (blue) curves --\n",
    "    if has_series:\n",
    "        # one y range for every column, covering both series\n",
    "        vmin = min(positive_series_matrix.min(), negative_series_matrix.min())\n",
    "        vmax = max(positive_series_matrix.max(), negative_series_matrix.max())\n",
    "        y_limits = _pad_limits(vmin, vmax)\n",
    "\n",
    "        axes_series = []\n",
    "        for i in range(k):\n",
    "            ax = fig.add_subplot(gs[row, i]) if i==0 else fig.add_subplot(gs[row, i], sharey=axes_series[0])\n",
    "            axes_series.append(ax)\n",
    "            ax.plot(context_lengths,\n",
    "                    positive_series_matrix[i],\n",
    "                    marker='o', color='red', lw=1.5,\n",
    "                    label='Most positive' if i==0 else None)\n",
    "            ax.plot(context_lengths,\n",
    "                    negative_series_matrix[i],\n",
    "                    marker='o', color='blue', lw=1.5,\n",
    "                    label='Most negative' if i==0 else None)\n",
    "            ax.set_xticks(xticks)\n",
    "            ax.set_xticklabels([])\n",
    "            if i==0:\n",
    "                ax.set_ylabel(\"Embedding Activation Maxima\")\n",
    "                ax.legend(loc='upper left', fontsize=8)\n",
    "            else:\n",
    "                plt.setp(ax.get_yticklabels(), visible=False)\n",
    "            ax.set_ylim(*y_limits)\n",
    "            ax.grid(True, linestyle='--', alpha=0.3)\n",
    "            ax.set_title(factors[i])\n",
    "        row += 1\n",
    "\n",
    "    # -- Correlation row (black curve) --\n",
    "    if has_corr:\n",
    "        # Pad outwards from the data range: scaling a positive minimum by 1.1 would raise the\n",
    "        # lower limit above the smallest correlation and crop it.\n",
    "        corr_limits = _pad_limits(correlation_matrix.min(), correlation_matrix.max())\n",
    "        axes_corr = []\n",
    "        for i in range(k):\n",
    "            ax = fig.add_subplot(gs[row, i]) if i==0 else fig.add_subplot(gs[row, i], sharey=axes_corr[0])\n",
    "            axes_corr.append(ax)\n",
    "            ax.plot(context_lengths,\n",
    "                    correlation_matrix[i],\n",
    "                    marker='o', color='black', lw=1.5,\n",
    "                    label='Correlation' if i==0 else None)\n",
    "            ax.set_xticks(xticks)\n",
    "            ax.set_xticklabels(xticks)\n",
    "            if i==0:\n",
    "                ax.set_ylabel(\"MEG Factor Projection Correlation \")\n",
    "            else:\n",
    "                plt.setp(ax.get_yticklabels(), visible=False)\n",
    "            ax.set_ylim(*corr_limits)\n",
    "            # Every column carries its own x label. No extra figure-level label is drawn: one\n",
    "            # placed at bottom_margin/2 would repeat this text underneath the sentence rows.\n",
    "            ax.set_xlabel(\"Context Length (words)\")\n",
    "            ax.grid(True, linestyle='--', alpha=0.3)\n",
    "            if not has_series:\n",
    "                ax.set_title(factors[i])  # topmost curve row carries the column titles\n",
    "        row += 1\n",
    "\n",
    "    # -- Positive sentences --\n",
    "    axes_pos = [fig.add_subplot(gs[row, i]) for i in range(k)]\n",
    "    _draw_sentence_row(axes_pos, sentences_pos_list, \"red\", max_chars, pad_between)\n",
    "    row += 1\n",
    "\n",
    "    # -- Negative sentences --\n",
    "    axes_neg = [fig.add_subplot(gs[row, i]) for i in range(k)]\n",
    "    _draw_sentence_row(axes_neg, sentences_neg_list, \"blue\", max_chars, pad_between)\n",
    "\n",
    "    # -- Rotated labels at the left edge, vertically centred on each sentence row --\n",
    "    pos_bbox = axes_pos[0].get_position()\n",
    "    neg_bbox = axes_neg[0].get_position()\n",
    "    y_center_pos = pos_bbox.y0 + pos_bbox.height/2\n",
    "    y_center_neg = neg_bbox.y0 + neg_bbox.height/2\n",
    "\n",
    "    fig.text(left_margin*0.8, y_center_pos, \"Positive\",\n",
    "             va=\"center\", ha=\"left\",\n",
    "             fontsize=16, color=\"red\", weight=\"bold\", rotation=90)\n",
    "    fig.text(left_margin*0.8, y_center_neg, \"Negative\",\n",
    "             va=\"center\", ha=\"left\",\n",
    "             fontsize=16, color=\"blue\", weight=\"bold\", rotation=90)\n",
    "\n",
    "    plt.show()\n",
    "    return fig\n",
    "\n",
    "# One title per plotted factor: a hard-coded range(6) fails the function's len(factors) == k\n",
    "# check whenever the number of factors is not 6 (`rank` is 10 in the configuration cell).\n",
    "# The trailing semicolon stops the returned Figure from being rendered a second time as the\n",
    "# cell's value, since plot_sentence_factors already calls plt.show().\n",
    "plot_sentence_factors(\n",
    "    most_activating_contexts,\n",
    "    least_activating_contexts,\n",
    "    factors=[f\"Factor {i+1}\" for i in range(len(most_activating_contexts))],\n",
    "    correlation_matrix=all_corrs,\n",
    "    positive_series_matrix=all_most_activating,\n",
    "    negative_series_matrix=all_least_activating,\n",
    "    context_lengths=context_lens,\n",
    "    max_chars=38,\n",
    "    figsize_per_factor=(5, 13),\n",
    "    pad_between=0.0,\n",
    "    hspace=0.1,\n",
    ");"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "env",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
