{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "75ce02a9",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "from regression.feature_loader import test_feature_loader\n",
    "from regression.regression_utils import load_pkl\n",
    "from regression.losses import correlation_loss\n",
    "from regression.losses import brain_score_1 as brain_score\n",
    "from regression.load_meg_targets import load_meg_targets\n",
    "from regression.session_story_configs import subject_test_configs\n",
    "from regression.lm_embeddings.embeddings_store import MEGFeatureMapStore\n",
    "from regression.helmet_plot import HelmetPlot\n",
    "from regression.session_story_configs import subject_unique_configs\n",
    "import torch\n",
    "from regression.feature_loader import load_embedding_transform, test_feature_loader, control_subtracted_test_meg\n",
    "from regression.lm_embeddings.embeddings_store import SessionStoryEmbeddingsFeatureLoader\n",
    "from regression.regression_utils import load_pkl\n",
    "from regression.regression_closed_form import block_gpu_multiply\n",
    "import mne\n",
    "from src.helpers import load_sensor_locations\n",
    "import matplotlib.gridspec as gridspec\n",
    "import textwrap\n",
    "from tqdm import tqdm\n",
    "from neurips_plotter_test import plot_with_inset_colorbar\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import textwrap\n",
    "from mpl_toolkits.axes_grid1.inset_locator import inset_axes\n",
    "import mne\n",
    "import matplotlib as mpl\n",
    "\n",
    "import re\n",
    "\n",
    "subject = \"D\"\n",
    "embeddings_loc = \"./embeddings\"\n",
    "dataset_loc = \"./data\"\n",
    "rank = 10\n",
    "embeddings_transform_cache_loc = \"./embeddings_transform_cache\"\n",
    "llm_features = { \"name\": \"llama2\",\"layer\": 3,\"context\": 20,\"pca\": 0.95,\"load\": True,\"delays\": 40 }\n",
    "context_llm = { \"name\": \"llama2\",\"layer\": 3,\"context\": 5,\"pca\": 0.95,\"load\": True,\"delays\": 40 }\n",
    "controls_folder = \"./runs/controls\"\n",
    "controls_subtracted_folder = \"./runs/controls_subtracted\"\n",
    "controls_index = 3\n",
    "control_sets = [[\"spectrogram\"], [\"word_onset\"], [\"sentence_start\"], [\"spectrogram\",\"word_onset\",\"sentence_start\"]]\n",
    "save_locs = [\"spectrogram_subtracted\", \"word_onset_subtracted\", \"sentence_start_subtracted\", \"all_controls_subtracted\"]\n",
    "delay_sets = [[15], [40], [40], [15, 40, 40]]\n",
    "\n",
    "original_model_loc = f\"./runs/subject_{subject}_rank_sweep_single/rank_{rank}\"\n",
    "llm_store_loc = f\"{embeddings_loc}/embeddings_sweep/{llm_features['name']}/layer_{llm_features['layer']}_context_{llm_features['context_len']}\"\n",
    "meg_store_loc = llm_store_loc + \"/meg_store\"\n",
    "embeddings_store_loc = llm_store_loc + f\"/{llm_features['name']}/layer_{llm_features['layer']}\"\n",
    "helmet_positions_loc = dataset_loc + \"/locations.txt\"\n",
    "layer = llm_features[\"layer\"]\n",
    "context_len = llm_features[\"context_len\"]\n",
    "llm_name = llm_features[\"name\"]\n",
    "control_set = control_sets[controls_index]\n",
    "save_loc = save_locs[controls_index]\n",
    "delay_set = delay_sets[controls_index]\n",
    "control_model_loc = controls_folder + \"/\" + save_loc + \".pkl\"\n",
    "control_subtracted_loc = f\"./runs/subject_{subject}_rank_sweep_single/rank_{rank}/control_subtracted_models/{save_loc}\"\n",
    "control_loc = f\"{controls_subtracted_folder}/subject_{subject}_rank_{rank}_{save_loc}_subtracted\"\n",
    "helmet_positions = load_sensor_locations(helmet_positions_loc)\n",
    "test_stories = subject_test_configs(subject, dataset_loc)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4d962578",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_empirical_embeddings(subject, layer, context_size, llm_name):\n",
    "    unique_configs = subject_unique_configs(subject, dataset_loc=dataset_loc)\n",
    "    loader = SessionStoryEmbeddingsFeatureLoader(embeddings_loc + \"/meg_store\", f\"{embeddings_loc}/embeddings_sweep/{llm_name}/layer_{layer}_context_{context_size}\",\n",
    "                                                delays=40, mmap=True, use_cuda=True)\n",
    "    unique_configs = subject_unique_configs(subject, llm_name)\n",
    "    all_embeddings_lst = []\n",
    "    all_contexts_lst = []\n",
    "    for config in unique_configs:\n",
    "        story_embeddings = loader.embeddings_store.load_embeddings(config.story, mmap=False, as_feature_map=False)\n",
    "        story_contexts = loader.embeddings_store.load_contexts(config.story)\n",
    "        all_embeddings_lst.append(story_embeddings)\n",
    "        all_contexts_lst.extend(story_contexts)\n",
    "    all_embeddings = np.concat(all_embeddings_lst, axis = 0)\n",
    "    all_contexts = np.array(all_contexts_lst)\n",
    "    return all_embeddings, all_contexts\n",
    "\n",
    "def empirical_context_activations(subject, llm_features, normed_embedding_factors):\n",
    "    embeddings, contexts = get_empirical_embeddings(subject, llm_features[\"layer\"], llm_features[\"context\"], llm_features[\"name\"])\n",
    "    transform = load_embedding_transform(llm_features, embeddings_transform_cache_loc, add_test=False)\n",
    "    activations = transform(embeddings) @ normed_embedding_factors.T\n",
    "    rank_sorts = np.argsort(activations, axis=0)\n",
    "    sorted_contexts = np.array([contexts[rank_sorts[:,r]] for r in range(rank_sorts.shape[1])])\n",
    "    sorted_activations = np.array([activations[rank_sorts[:,r],r] for r in range(rank_sorts.shape[1])])\n",
    "    return sorted_activations, sorted_contexts\n",
    "\n",
    "def factor_weights(time_factors, space_factors, embedding_factors):\n",
    "    rank_weights = np.einsum(\"rt,re,rs->rtes\", time_factors, embedding_factors, space_factors)\n",
    "    out_weights = np.sum(rank_weights, axis=0).transpose((1,0,2))\n",
    "    dim_collapsed_weights = out_weights.reshape(-1, space_factors.shape[1]).T\n",
    "    return dim_collapsed_weights\n",
    "\n",
    "def predict_meg_from_factors(features, time_factors, space_factors, embedding_factors):\n",
    "    W = factor_weights(time_factors, space_factors, embedding_factors)\n",
    "    return np.matmul(features, W.T)\n",
    "\n",
    "def leave_one_factor_out_sort(features, loss_fn, time_factors, space_factors, embedding_factors, baseline = True):\n",
    "    if baseline:\n",
    "        factor_indices_to_use = list(range(time_factors.shape[0]))\n",
    "        W = factor_weights(time_factors[factor_indices_to_use], space_factors[factor_indices_to_use], embedding_factors[factor_indices_to_use])\n",
    "        predictions = np.matmul(features, W.T)\n",
    "        baseline = loss_fn(predictions)\n",
    "    else:\n",
    "        baseline = 0.0\n",
    "    rank = time_factors.shape[0]\n",
    "    influences = []\n",
    "    for r in range(rank):\n",
    "        factor_indices_to_use = list(range(0, r)) + list(range(r+1, rank))\n",
    "        W = factor_weights(time_factors[factor_indices_to_use], space_factors[factor_indices_to_use], embedding_factors[factor_indices_to_use])\n",
    "        predictions = np.matmul(features, W.T)\n",
    "        s = loss_fn(predictions)\n",
    "        influences.append(baseline - s)\n",
    "    rank_ordering = np.argsort(influences)[::-1]\n",
    "    return np.array(influences), rank_ordering\n",
    "\n",
    "def norm_sort(factor_norms):\n",
    "    rank_ordering = np.argsort(factor_norms)[::-1]\n",
    "    return rank_ordering\n",
    "\n",
    "def normalize_factors(time_factors, space_factors, embedding_factors, sort_order = None):\n",
    "    factor_norms = np.linalg.norm(time_factors, axis=1)*np.linalg.norm(space_factors, axis=1)*np.linalg.norm(embedding_factors, axis=1)\n",
    "    normed_time_factors = time_factors/np.linalg.norm(time_factors, axis=1)[:, None]\n",
    "    normed_space_factors = space_factors/np.linalg.norm(space_factors, axis=1)[:, None]\n",
    "    normed_embedding_factors = embedding_factors/np.linalg.norm(embedding_factors, axis=1)[:,None]\n",
    "    if not sort_order is None:\n",
    "       factor_norms = factor_weights[sort_order]\n",
    "       normed_time_factors = normed_time_factors[sort_order]\n",
    "       normed_space_factors = normed_space_factors[sort_order]\n",
    "       normed_embedding_factors = normed_embedding_factors[sort_order]\n",
    "    return factor_norms, normed_time_factors, normed_space_factors, normed_embedding_factors\n",
    "\n",
    "def context_activations(subject, base_llm_features, context_llm_features, normed_embedding_factors):\n",
    "    embeddings, contexts = get_empirical_embeddings(subject, context_llm_features[\"layer\"], context_llm_features[\"context\"], context_llm_features[\"name\"])\n",
    "    transform = load_embedding_transform(base_llm_features, embeddings_transform_cache_loc, add_test=False)\n",
    "    activations = transform(embeddings) @ normed_embedding_factors.T\n",
    "    #for rank in range(activations.shape[1]):\n",
    "    rank_sorts = np.argsort(activations, axis=0)\n",
    "    #sorted_activations = activations[rank_sorts]\n",
    "    #sorted_contexts = contexts[rank_sorts]\n",
    "    sorted_contexts = np.array([contexts[rank_sorts[:,r]] for r in range(rank_sorts.shape[1])])\n",
    "    sorted_activations = np.array([activations[rank_sorts[:,r],r] for r in range(rank_sorts.shape[1])])\n",
    "    return sorted_activations, sorted_contexts\n",
    "\n",
    "def load_mean_std(llm_features, llm_load_save_loc):\n",
    "    add_test = False\n",
    "    mean_std_save_loc = llm_load_save_loc + \"/\" +f\"MEANSTD_{llm_features['name']}_{llm_features['layer']}_{llm_features['context']}_with_test_{add_test}_delay_{llm_features['delays']}.pkl\"\n",
    "    return load_pkl(mean_std_save_loc)\n",
    "\n",
    "def load_pca(llm_features, llm_load_save_loc):\n",
    "    add_test = False\n",
    "    pca_save_loc = llm_load_save_loc + \"/\" +f\"PCA_{llm_features['name']}_{llm_features['layer']}_{llm_features['context']}_pca{llm_features['pca']}_with_test_{add_test}_delay_{llm_features['delays']}.pkl\"\n",
    "    return load_pkl(pca_save_loc)\n",
    "\n",
    "def load_test_context_features(llm_features, mean, std, pca_weights):\n",
    "    loader = SessionStoryEmbeddingsFeatureLoader(meg_store_loc,\n",
    "                                                 f\"{embeddings_loc}/embeddings_sweep/{llm_features['name']}/layer_{llm_features['layer']}_context_{llm_features['context']}\",\n",
    "                                                delays=40, mmap=True, use_cuda=True)\n",
    "    single_test_feature_lazy = loader.load_configs([subject_test_configs(subject, dataset_loc)[0]], -mean, 1/std, pca_weights)[0]\n",
    "    single_test_feature = single_test_feature_lazy[np.arange(0, len(single_test_feature_lazy))]\n",
    "    return single_test_feature"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "db00b4e4",
   "metadata": {},
   "outputs": [],
   "source": [
    "unnormalized_meg = control_subtracted_test_meg(control_weight_loc=controls_load_loc, control_features=control_set, \n",
    "                                       subject=subject, delays=delay_set,\n",
    "                                       llm_features=llm_features)\n",
    "#renormalize the meg after control subtraction\n",
    "meg = []\n",
    "for story_meg in unnormalized_meg:\n",
    "    story_normalized_meg = (story_meg - np.mean(story_meg, axis=0)[None,:])/(np.std(story_meg, axis=0)[None,:])\n",
    "    meg.append(story_normalized_meg)\n",
    "feature_size, regression_features = test_feature_loader(llm_features, lm_feature_map_loc=embeddings_loc, \n",
    "                                                            subject = subject, controls = [], delays = [], \n",
    "                                                            force_load=True, load_as_control=False)\n",
    "helmet_plotter = HelmetPlot(helmet_positions_loc)\n",
    "embeddings_transform = load_embedding_transform(llm_features, embeddings_transform_cache_loc, add_test=False, use_torch = False)\n",
    "meg_targets = np.stack(meg)\n",
    "single_test_feature = regression_features[0]\n",
    "score = brain_score(meg_targets, 0.03)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "32129e67",
   "metadata": {},
   "source": [
    "control subtracted model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fdd0f4fd",
   "metadata": {},
   "outputs": [],
   "source": [
    "rank = 10\n",
    "model = torch.load(control_subtracted_loc + \"/last_model.pt\", weights_only=False)\n",
    "\n",
    "unsorted_raw_time_factors = model.time_factors.cpu().detach().clone().numpy()\n",
    "unsorted_raw_space_factors = model.space_factors.cpu().detach().clone().numpy()\n",
    "unsorted_raw_embedding_factors = model.embedding_factors.cpu().detach().clone().numpy()\n",
    "unsorted_influence, sorting_order = leave_one_factor_out_sort(single_test_feature, score,\n",
    "                                                     unsorted_raw_time_factors, unsorted_raw_space_factors, unsorted_raw_embedding_factors, \n",
    "                                                     baseline = True)\n",
    "influence = unsorted_influence[sorting_order]\n",
    "factor_norms, time_factors, space_factors, embedding_factors = normalize_factors(unsorted_raw_time_factors, unsorted_raw_space_factors, unsorted_raw_embedding_factors)\n",
    "rank = model.time_factors.shape[0]\n",
    "\n",
    "embedding_dim = model.embedding_dim\n",
    "space_dim = model.channel_dim\n",
    "time_dim = model.delay_dim\n",
    "\n",
    "#delay_bias = model.delay_bias.clone().detach().cpu().numpy()\n",
    "\n",
    "print(f\"Factor Importance: {factor_norms}\")\n",
    "print(f\"Leave One Out Influence: {influence}\")\n",
    "\n",
    "fig, ax = plt.subplots()\n",
    "for i in range(rank):\n",
    "    ax.plot(np.arange(0,40)*20, np.flip(time_factors[i,:]), label = f\"factor {i+1}\")\n",
    "ax.legend(loc=\"right\", bbox_to_anchor=(1.3, 0.5))\n",
    "ax.set_xlabel(\"Time (ms)\")\n",
    "ax.set_ylabel(\"Normed Units\")\n",
    "ax.set_title(f\"Rank {rank} Timecourses\")\n",
    "print(f\"Brain Score {score(model.numpy_forward(single_test_feature))}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8ee9b3cb",
   "metadata": {},
   "source": [
    "uncontrolled model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a912ccb5",
   "metadata": {},
   "outputs": [],
   "source": [
    "rank = 10\n",
    "pre_model = torch.load(original_model_loc + \"/last_model.pt\", weights_only=False)\n",
    "\n",
    "pre_unsorted_raw_time_factors = pre_model.time_factors.cpu().detach().clone().numpy()\n",
    "pre_unsorted_raw_space_factors = pre_model.space_factors.cpu().detach().clone().numpy()\n",
    "pre_unsorted_raw_embedding_factors = pre_model.embedding_factors.cpu().detach().clone().numpy()\n",
    "\n",
    "pre_score = brain_score(np.stack(load_meg_targets(test_stories), axis=0), 0.03)\n",
    "\n",
    "pre_unsorted_influence, presorting_order = leave_one_factor_out_sort(single_test_feature, pre_score,\n",
    "                                                     pre_unsorted_raw_time_factors, pre_unsorted_raw_space_factors, pre_unsorted_raw_embedding_factors, \n",
    "                                                     baseline = True)\n",
    "pre_influence = pre_unsorted_influence[sorting_order]\n",
    "pre_factor_norms, pre_time_factors, pre_space_factors, pre_embedding_factors = normalize_factors(pre_unsorted_raw_time_factors, pre_unsorted_raw_space_factors, pre_unsorted_raw_embedding_factors)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3de669b9",
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.title(\"Power in Time Factors\", fontsize=20)\n",
    "controlled_time_power = np.sum(np.square(np.flip(time_factors)*np.square(factor_norms[:,None])), axis=0)\n",
    "uncontrolled_time_power = np.sum(np.square(np.flip(pre_time_factors))*np.square(pre_factor_norms[:,None]), axis=0)\n",
    "\n",
    "plt.plot(20*np.arange(0, 40), controlled_time_power/np.linalg.norm(controlled_time_power), label=\"Controlled\")\n",
    "plt.plot(20*np.arange(0, 40), uncontrolled_time_power/np.linalg.norm(uncontrolled_time_power), label = \"Uncontrolled\")\n",
    "plt.xlabel(\"Time (ms)\", fontsize=20)\n",
    "plt.ylabel(\"Power\", fontsize=20)\n",
    "plt.legend()\n",
    "\n",
    "#plt.title(\"Rank 10 Space Factor Power\")\n",
    "#helmet_plotter.plot(np.mean(np.square(space_factors), axis=0), vlim=None, cmap=None, title=\"Control Subtracted\")\n",
    "#helmet_plotter.plot(np.mean(np.square(pre_space_factors), axis=0), vlim=None, cmap=None, title=\"Pre-Controlled\")\n",
    "#helmet_plotter.plot(np.mean(np.square(space_factors)) - np.mean(np.square(pre_space_factors), axis=0), vlim=None, cmap=None, title=\"Difference\")\n",
    "\n",
    "def plot_two_topomaps(\n",
    "    topo1: np.ndarray,\n",
    "    topo2: np.ndarray,\n",
    "    positions: np.ndarray,\n",
    "    sphere: float,\n",
    "    cmap: str = \"RdBu_r\",\n",
    "    figsize: tuple = (8, 4),\n",
    "    title1: str = \"Map 1\",\n",
    "    title2: str = \"Map 2\",\n",
    "    cbar_label: str = \"\"\n",
    "):\n",
    "    \"\"\"\n",
    "    Plots two topomaps side by side, sharing a single colorbar.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    topo1, topo2 : (n_sensors,) arrays\n",
    "        Data for each topomap.\n",
    "    positions : (n_sensors, 2) array\n",
    "        2D sensor coordinates.\n",
    "    sphere : float\n",
    "        Sphere size for the head outline.\n",
    "    cmap : str\n",
    "        Matplotlib colormap.\n",
    "    figsize : tuple\n",
    "        Figure size.\n",
    "    title1, title2 : str\n",
    "        Subplot titles.\n",
    "    cbar_label : str\n",
    "        Label for the shared colorbar.\n",
    "    \"\"\"\n",
    "    # determine common color scale\n",
    "    all_vals = np.concatenate([topo1, topo2])\n",
    "    vmin, vmax = all_vals.min(), all_vals.max()\n",
    "\n",
    "    fig, axes = plt.subplots(\n",
    "        nrows=1, ncols=2,\n",
    "        figsize=figsize,\n",
    "        constrained_layout=True\n",
    "    )\n",
    "        # make room at top for suptitle and at left for cba\n",
    "\n",
    "    # first topomap\n",
    "    im, _ = mne.viz.plot_topomap(\n",
    "        topo1, positions,\n",
    "        axes=axes[0],\n",
    "        show=False,\n",
    "        cmap=cmap, vlim=(vmin, vmax),\n",
    "        outlines=\"head\", sphere=sphere\n",
    "    )\n",
    "    axes[0].set_title(title1, fontsize=20)\n",
    "    axes[0].axis(\"off\")\n",
    "\n",
    "    # second topomap\n",
    "    m2, _ = mne.viz.plot_topomap(\n",
    "        topo2, positions,\n",
    "        axes=axes[1],\n",
    "        show=False,\n",
    "        cmap=cmap, vlim=(vmin, vmax),\n",
    "        outlines=\"head\", sphere=sphere\n",
    "    )\n",
    "    axes[1].set_title(title2, fontsize=20)\n",
    "    axes[1].axis(\"off\")\n",
    "\n",
    "    # shared colorbar on the right\n",
    "    cbar = fig.colorbar(\n",
    "        im,\n",
    "        ax=axes.tolist(),\n",
    "        orientation=\"vertical\",\n",
    "        fraction=0.046,\n",
    "        pad=0.04,\n",
    "        location=\"left\",\n",
    "    )\n",
    "    # move ticks & label to left side\n",
    "    cbar.ax.yaxis.set_label_position('left')\n",
    "    cbar.ax.yaxis.set_ticks_position('left')\n",
    "    if cbar_label:\n",
    "        cbar.set_label(cbar_label, fontsize=20)\n",
    "\n",
    "    plt.show()\n",
    "    return fig\n",
    "\n",
    "mpl.rcParams['axes.titlesize'] = 20\n",
    "\n",
    "controlled_channel_power = np.sum(np.square(space_factors*np.square(factor_norms[:,None])), axis=0)\n",
    "uncontrolled_channel_power = np.sum(np.square(pre_space_factors)*np.square(pre_factor_norms[:,None]), axis=0)\n",
    "\n",
    "plot_two_topomaps(uncontrolled_channel_power/np.linalg.norm(uncontrolled_channel_power) , controlled_channel_power/np.linalg.norm(controlled_channel_power) ,\n",
    "                  positions=helmet_positions, sphere=47,cmap=\"Reds\", title2=\"Controlled\", title1=\"Uncontrolled\",\n",
    "                  cbar_label=\"Power\")\n",
    "\n",
    "helmet_plotter.plot(uncontrolled_channel_power/np.linalg.norm(uncontrolled_channel_power),cmap=\"Reds\", title=f\"Power in Channel Factors\", vlim=None)\n",
    "#plt.plot(20*np.arange(0, 40), np.mean(np.square(space_factors), axis=0), label=\"Control Subtracted\")\n",
    "#plt.plot(20*np.arange(0, 40),np.mean(np.square(np.flip(pre_time_factors)), axis=0), label = \"Pre-Controlled\")\n",
    "#plt.xlabel(\"Time (ms)\")\n",
    "#plt.ylabel(\"Power\")\n",
    "#plt.legend()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b0f7a6ca",
   "metadata": {},
   "outputs": [],
   "source": [
    "base_llm = { \"name\": \"llama2\",\"layer\": 3,\"context\": 20,\"pca\": 0.95,\"load\": True,\"delays\": 40 }\n",
    "context_len = 5\n",
    "mean,std = load_mean_std(base_llm, embeddings_transform_cache_loc)\n",
    "pca_weights = load_pca(base_llm, embeddings_transform_cache_loc)\n",
    "context_test_features = load_test_context_features(context_llm, mean, std, pca_weights)\n",
    "predicted_meg = model.numpy_forward(context_test_features)\n",
    "s = score(predicted_meg)\n",
    "activation, context = context_activations(subject, base_llm, context_llm, embedding_factors)\n",
    "print(context.shape)\n",
    "least_activating_contexts = context[:,:5]\n",
    "most_activating_contexts = context[:,-5:]\n",
    "\n",
    "def remove_non_chars(str_arr):\n",
    "    pat = re.compile(r\"[^A-Za-z]+\")\n",
    "    vsub = np.vectorize(lambda s: pat.sub(\"\", s))\n",
    "    return vsub(str_arr)\n",
    "\n",
    "def plot_with_inset_colorbar(\n",
    "    time_series,\n",
    "    spatial_maps,\n",
    "    embeddings,\n",
    "    positions_2d,\n",
    "    vlim=(-0.1, 0.1),\n",
    "    sentences_pos_matrix=None,    # list of lists or strings, shape (k × n_pos)\n",
    "    sentences_neg_matrix=None,    # list of lists or strings, shape (k × n_neg)\n",
    "    activating_sentence_shift=0.4, # how far into left margin the row‐titles go,\n",
    "    factor_indices = None,\n",
    "    fontsize = 14,\n",
    "    axis_fontsize = 14\n",
    "):\n",
    "    \"\"\"\n",
    "    Rows:\n",
    "      0: Time series per factor\n",
    "      1: MEG topomap per factor (with inset colorbar)\n",
    "      2: Embedding bar per factor\n",
    "      3: Positive activating samples (optional)\n",
    "      4: Negative activating samples (optional)\n",
    "\n",
    "    Sentences are right-justified, no wrapping.\n",
    "    Row titles flush-left and perfectly centered vertically.\n",
    "    \"\"\"\n",
    "    k = len(time_series)\n",
    "    T = time_series[0].shape[0]\n",
    "    D = embeddings[0].shape[0]\n",
    "\n",
    "    if factor_indices is None:\n",
    "        col_labels = [f\"Component {i+1}\" for i in range(k)]\n",
    "    else:\n",
    "        col_labels = [f\"Component {i}\" for i in factor_indices]\n",
    "    add_pos = sentences_pos_matrix is not None\n",
    "    add_neg = sentences_neg_matrix is not None\n",
    "    n_extra = int(add_pos) + int(add_neg)\n",
    "    nrows = 3 + n_extra\n",
    "\n",
    "    fig, axes = plt.subplots(\n",
    "        nrows=nrows,\n",
    "        ncols=k,\n",
    "        figsize=(4*k, 1.8*nrows),\n",
    "        constrained_layout=True,\n",
    "        sharey=\"row\",\n",
    "        gridspec_kw={\"hspace\": 0.2}\n",
    "    )\n",
    "\n",
    "    # ── Rows 0–2 ──\n",
    "    axes_ts, axes_sp, axes_em = axes[0], axes[1], axes[2]\n",
    "    im = None\n",
    "\n",
    "    for i in range(k):\n",
    "        # Row 0: time series\n",
    "        ax = axes_ts[i]\n",
    "        ax.plot(20*np.arange(T), np.flip(time_series[i]), 'k')\n",
    "        ax.annotate(\n",
    "            col_labels[i],\n",
    "            xy=(0.5, 1.5), xycoords='axes fraction',\n",
    "            ha='center', fontsize=fontsize, fontweight='bold'\n",
    "        )\n",
    "        if i == 0:\n",
    "            ax.set_ylabel(\"Time\\nFactor\", fontsize=axis_fontsize, fontweight='bold')\n",
    "        ax.set_xlabel(\"Time (ms)\", fontsize=axis_fontsize)\n",
    "\n",
    "        # Row 1: topomap\n",
    "        ax = axes_sp[i]\n",
    "        im, _ = mne.viz.plot_topomap(\n",
    "            spatial_maps[i], positions_2d, axes=ax,\n",
    "            show=False, cmap=\"RdBu_r\",\n",
    "            outlines=\"head\", sphere=0.1, vlim=vlim\n",
    "        )\n",
    "        ax.set_xticks([]); ax.set_yticks([])\n",
    "\n",
    "        # Row 2: embedding bar\n",
    "        ax = axes_em[i]\n",
    "        ax.bar(np.arange(D), embeddings[i], width=1, color='gray')\n",
    "        if i == 0:\n",
    "            ax.set_ylabel(\"Embedding\\nFactor\", fontsize=axis_fontsize, fontweight='bold')\n",
    "        ax.set_xlabel(\"Dim\", fontsize=fontsize)\n",
    "\n",
    "    # ── Inset colorbar on first topomap ──\n",
    "    first_topo_ax = axes_sp[0]\n",
    "    cax = inset_axes(\n",
    "        first_topo_ax,\n",
    "        width=\"4%\", height=\"80%\",\n",
    "        loc=\"center left\",\n",
    "        bbox_to_anchor=(-0.95, 0, 1, 1),\n",
    "        bbox_transform=first_topo_ax.transAxes,\n",
    "        borderpad=0\n",
    "    )\n",
    "    cbar = fig.colorbar(im, cax=cax, orientation='vertical')\n",
    "    cbar.ax.yaxis.set_label_position('left')\n",
    "    cbar.ax.yaxis.set_ticks_position('left')\n",
    "    cbar.set_label(\n",
    "        \"Channel\\nFactor\",\n",
    "        rotation=90, va='bottom', labelpad=5, fontsize=axis_fontsize,\n",
    "        fontweight='bold'\n",
    "    )\n",
    "    cbar.ax.tick_params(labelsize=8)\n",
    "\n",
    "    # ── Rows 3 & 4: Positive/Negative activating samples ──\n",
    "    current_row = 3\n",
    "    axes_pos = axes[current_row] if add_pos else None\n",
    "    axes_neg = (\n",
    "        axes[current_row+1] if (add_pos and add_neg)\n",
    "        else axes[current_row] if (not add_pos and add_neg)\n",
    "        else None\n",
    "    )\n",
    "\n",
    "    if add_pos:\n",
    "        for ax, sents in zip(axes_pos, sentences_pos_matrix):\n",
    "            ax.axis('off')\n",
    "            # space them evenly\n",
    "            y_coords = np.linspace(0.95, 0.05, len(sents))\n",
    "            for y, sent in zip(y_coords, sents):\n",
    "                # if passed a list/tuple, join into one string\n",
    "                txt = \" \".join(sent) if isinstance(sent, (list, tuple)) else str(sent)\n",
    "                txt = re.sub(r\"\\s*-\\s*\", \"-\", txt)\n",
    "                ax.text(\n",
    "                    0.98, y, txt,\n",
    "                    ha='right', va='center',\n",
    "                    color='red', fontsize=fontsize,\n",
    "                    transform=ax.transAxes\n",
    "                )\n",
    "        current_row += 1\n",
    "\n",
    "    if add_neg:\n",
    "        for ax, sents in zip(axes_neg, sentences_neg_matrix):\n",
    "            ax.axis('off')\n",
    "            y_coords = np.linspace(0.95, 0.05, len(sents))\n",
    "            for y, sent in zip(y_coords, sents):\n",
    "                txt = \" \".join(sent) if isinstance(sent, (list, tuple)) else str(sent)\n",
    "                txt = re.sub(r\"\\s*-\\s*\", \"-\", txt)\n",
    "                ax.text(\n",
    "                    0.98, y, txt,\n",
    "                    ha='right', va='center',\n",
    "                    color='blue', fontsize=fontsize,\n",
    "                    transform=ax.transAxes\n",
    "                )\n",
    "\n",
    "    # ── Row-titles flush-left, centered vertically ──\n",
    "    left_edge = axes_ts[0].get_position().x0\n",
    "    label_x = left_edge - activating_sentence_shift\n",
    "\n",
    "    if axes_pos is not None:\n",
    "        mids = [ax.get_position().y0 + ax.get_position().height/2 for ax in axes_pos]\n",
    "        fig.text(\n",
    "            label_x, np.mean(mids)-0.1,\n",
    "            \"Positive\\nActivating\\nContexts\",\n",
    "            va='center', ha='center',\n",
    "            rotation=90,\n",
    "            fontsize=axis_fontsize, fontweight='bold', color='red'\n",
    "        )\n",
    "\n",
    "    if axes_neg is not None:\n",
    "        mids = [ax.get_position().y0 + ax.get_position().height/2 for ax in axes_neg]\n",
    "        fig.text(\n",
    "            label_x, np.mean(mids) - 0.1,\n",
    "            \"Negative\\nActivating\\nContexts\",\n",
    "            va='center', ha='center',\n",
    "            rotation=90,\n",
    "            fontsize=axis_fontsize, fontweight='bold', color='blue'\n",
    "        )\n",
    "\n",
    "    return fig\n",
    "def plot_text_only(\n",
    "    sentences_pos_matrix=None,    # list of lists or strings, shape (k × n_pos)\n",
    "    sentences_neg_matrix=None,    # list of lists or strings, shape (k × n_neg)\n",
    "    factor_indices=None,\n",
    "    activating_sentence_shift=0.4,\n",
    "    fontsize=14,\n",
    "    axis_fontsize=14\n",
    "):\n",
    "    \"\"\"\n",
    "    Only the text rows (positive/negative contexts) for each factor.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    sentences_pos_matrix : list[list[str]] or None\n",
    "    sentences_neg_matrix : list[list[str]] or None\n",
    "    factor_indices       : list[int] or None → used to label columns\n",
    "    activating_sentence_shift : float\n",
    "      how far left the row-titles should go\n",
    "    fontsize, axis_fontsize : float\n",
    "      text sizes\n",
    "    \"\"\"\n",
    "    # how many factors (columns)?\n",
    "    if sentences_pos_matrix is not None:\n",
    "        k = len(sentences_pos_matrix)\n",
    "    elif sentences_neg_matrix is not None:\n",
    "        k = len(sentences_neg_matrix)\n",
    "    else:\n",
    "        raise ValueError(\"Must pass at least one of sentences_pos_matrix or sentences_neg_matrix\")\n",
    "\n",
    "    # column labels\n",
    "    if factor_indices is None:\n",
    "        col_labels = [f\"Component {i+1}\" for i in range(k)]\n",
    "    else:\n",
    "        col_labels = [f\"Component {i}\" for i in factor_indices]\n",
    "\n",
    "    add_pos = sentences_pos_matrix is not None\n",
    "    add_neg = sentences_neg_matrix is not None\n",
    "    n_extra = int(add_pos) + int(add_neg)\n",
    "    nrows = n_extra\n",
    "\n",
    "    # build figure\n",
    "    fig, axes = plt.subplots(\n",
    "        nrows=nrows, ncols=k,\n",
    "        figsize=(4*k, 1.8*nrows),\n",
    "        constrained_layout=True,\n",
    "        sharey=\"row\",\n",
    "        gridspec_kw={\"hspace\": 0.2}\n",
    "    )\n",
    "    # if only one row, make axes a list of list\n",
    "    if nrows == 1:\n",
    "        axes = np.expand_dims(axes, 0)\n",
    "\n",
    "    current_row = 0\n",
    "\n",
    "    # Positive row\n",
    "    if add_pos:\n",
    "        for i, ax in enumerate(axes[current_row]):\n",
    "            ax.axis('off')\n",
    "            # add column title above the text\n",
    "            ax.set_title(col_labels[i], fontsize=axis_fontsize, pad=20, fontweight='bold')\n",
    "            sents = sentences_pos_matrix[i]\n",
    "            y_coords = np.linspace(0.95, 0.05, len(sents))\n",
    "            for y, sent in zip(y_coords, sents):\n",
    "                txt = \" \".join(sent) if isinstance(sent, (list, tuple)) else str(sent)\n",
    "                txt = re.sub(r\"\\s*-\\s*\", \"-\", txt)\n",
    "                ax.text(\n",
    "                    0.98, y, txt,\n",
    "                    ha='right', va='center',\n",
    "                    color='red', fontsize=fontsize,\n",
    "                    transform=ax.transAxes\n",
    "                )\n",
    "        current_row += 1\n",
    "\n",
    "    # Negative row\n",
    "    if add_neg:\n",
    "        for i, ax in enumerate(axes[current_row]):\n",
    "            ax.axis('off')\n",
    "            # no title on negative row\n",
    "            sents = sentences_neg_matrix[i]\n",
    "            y_coords = np.linspace(0.95, 0.05, len(sents))\n",
    "            for y, sent in zip(y_coords, sents):\n",
    "                txt = \" \".join(sent) if isinstance(sent, (list, tuple)) else str(sent)\n",
    "                txt = re.sub(r\"\\s*-\\s*\", \"-\", txt)\n",
    "                ax.text(\n",
    "                    0.98, y, txt,\n",
    "                    ha='right', va='center',\n",
    "                    color='blue', fontsize=fontsize,\n",
    "                    transform=ax.transAxes\n",
    "                )\n",
    "\n",
    "    # flush-left row labels\n",
    "    left_edge = axes[0,0].get_position().x0\n",
    "    label_x = left_edge - activating_sentence_shift\n",
    "\n",
    "    if add_pos:\n",
    "        mids = [ax.get_position().y0 + ax.get_position().height/2 for ax in axes[0]]\n",
    "        fig.text(\n",
    "            label_x, np.mean(mids),\n",
    "            \"Positive\\nActivating\\nContexts\",\n",
    "            va='center', ha='center',\n",
    "            rotation=90,\n",
    "            fontsize=axis_fontsize, fontweight='bold', color='red'\n",
    "        )\n",
    "    if add_neg:\n",
    "        row_idx = 1 if add_pos else 0\n",
    "        mids = [ax.get_position().y0 + ax.get_position().height/2 for ax in axes[row_idx]]\n",
    "        fig.text(\n",
    "            label_x, np.mean(mids),\n",
    "            \"Negative\\nActivating\\nContexts\",\n",
    "            va='center', ha='center',\n",
    "            rotation=90,\n",
    "            fontsize=axis_fontsize, fontweight='bold', color='blue'\n",
    "        )\n",
    "\n",
    "    return fig\n",
    "\n",
    "# Report the shapes of the context arrays before slicing them for the plots.\n",
    "print(most_activating_contexts.shape)\n",
    "print(least_activating_contexts.shape)\n",
    "\n",
    "# 1-based component numbers to show; forwarded as factor_indices, which\n",
    "# plot_text_only uses for its Component N column titles.\n",
    "selection = list(range(1, 11))\n",
    "# Matching 0-based positions used to index the factor and context arrays.\n",
    "indexing = [component - 1 for component in selection]\n",
    "\n",
    "plot_with_inset_colorbar(\n",
    "    time_factors[indexing],\n",
    "    space_factors[indexing],\n",
    "    embedding_factors[indexing],\n",
    "    helmet_positions,\n",
    "    vlim=(-0.1, 0.1),\n",
    "    sentences_pos_matrix=most_activating_contexts[indexing],\n",
    "    sentences_neg_matrix=least_activating_contexts[indexing],\n",
    "    activating_sentence_shift=0.11,\n",
    "    factor_indices=selection,\n",
    "    fontsize=14.5,\n",
    "    axis_fontsize=15,\n",
    ")\n",
    "\n",
    "# Text-only figure built from the same selected sentences; it needs a larger\n",
    "# shift value for the row labels than the call above.\n",
    "plot_text_only(\n",
    "    sentences_pos_matrix=most_activating_contexts[indexing],\n",
    "    sentences_neg_matrix=least_activating_contexts[indexing],\n",
    "    activating_sentence_shift=0.19,\n",
    "    factor_indices=selection,\n",
    "    fontsize=14.5,\n",
    "    axis_fontsize=15,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "234d30f9",
   "metadata": {},
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
