{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ad08abc4",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import matplotlib as mpl\n",
    "from regression.feature_loader import test_feature_loader\n",
    "from regression.regression_utils import load_pkl\n",
    "from regression.losses import correlation_loss\n",
    "from regression.losses import brain_score_1 as brain_score\n",
    "from regression.load_meg_targets import load_meg_targets\n",
    "from regression.session_story_configs import subject_test_configs\n",
    "from regression.lm_embeddings.embeddings_store import MEGFeatureMapStore\n",
    "from regression.helmet_plot import HelmetPlot\n",
    "from regression.helpers import load_sensor_locations\n",
    "import torch\n",
    "from regression.regression_closed_form import block_gpu_multiply\n",
    "import mne\n",
    "\n",
    "# --- Experiment configuration -------------------------------------------------\n",
    "subject = \"D\"\n",
    "embeddings_loc = \"./embeddings\"\n",
    "dataset_loc = \"./data\"\n",
    "rank = 10\n",
    "embeddings_transform_cache_loc = \"./embeddings_transform_cache\"\n",
    "\n",
    "# Feature specifications. The context window size is stored under the \"context\" key;\n",
    "# neither dict has a \"context_len\" key.\n",
    "llm_features = { \"name\": \"llama2\",\"layer\": 3,\"context\": 20,\"pca\": 0.95,\"load\": True,\"delays\": 40 }\n",
    "context_llm = { \"name\": \"llama2\",\"layer\": 3,\"context\": 5,\"pca\": 0.95,\"load\": True,\"delays\": 40 }\n",
    "layer = llm_features[\"layer\"]\n",
    "context_len = llm_features[\"context\"]  # indexing \"context_len\" raised KeyError\n",
    "\n",
    "# --- Derived locations --------------------------------------------------------\n",
    "# model_loc is the run directory of the rank-`rank` model, not a checkpoint file.\n",
    "model_loc = f\"./runs/subject_{subject}_rank_sweep_single/rank_{rank}\"\n",
    "full_model_loc = \"./runs/\"\n",
    "llm_store_loc = f\"{embeddings_loc}/embeddings_sweep/{llm_features['name']}/layer_{layer}_context_{context_len}\"\n",
    "meg_store_loc = llm_store_loc + \"/meg_store\"\n",
    "embeddings_store_loc = llm_store_loc + f\"/{llm_features['name']}/layer_{layer}\"\n",
    "helmet_positions_loc = dataset_loc + \"/locations.txt\"\n",
    "\n",
    "# --- Plot typography ----------------------------------------------------------\n",
    "mpl.rcParams.update({\n",
    "    \"font.size\": 20,          # global default for text\n",
    "    \"axes.titlesize\": 20,     # axes titles\n",
    "    \"axes.labelsize\": 20,     # x/y axis labels\n",
    "    \"xtick.labelsize\": 12,\n",
    "    \"ytick.labelsize\": 12,\n",
    "    \"legend.fontsize\": 12,\n",
    "    \"figure.titlesize\": 20,   # `plt.suptitle`\n",
    "})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1b4c2bab",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "# Held-out MEG recordings for the configured subject, stacked along a new leading axis.\n",
    "test_configs = subject_test_configs(subject, dataset_loc)\n",
    "meg = np.stack(load_meg_targets(test_configs), axis=0)\n",
    "\n",
    "# Averaging over the stacked axis gives the target that predictions are correlated with.\n",
    "meg_test_target = meg.mean(axis=0)\n",
    "\n",
    "# brain_score returns the callable applied to model predictions in the cells below.\n",
    "score = brain_score(meg, 0.03, ceiling_cutoff=None)\n",
    "\n",
    "# Test-set features for the configured LLM layer; only the first entry is used downstream.\n",
    "_, test_features = test_feature_loader(\n",
    "    llm_features,\n",
    "    lm_feature_map_loc=llm_store_loc,\n",
    "    subject=subject,\n",
    "    controls=[],\n",
    "    delays=[],\n",
    "    force_load=True,\n",
    "    load_as_control=False,\n",
    ")\n",
    "single_test_feature = test_features[0]\n",
    "\n",
    "# Helpers used later: the MEG feature-map store and the sensor-helmet plotter.\n",
    "meg_store = MEGFeatureMapStore(meg_store_loc)\n",
    "helmet_plotter = HelmetPlot(helmet_positions_loc)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b69719f5",
   "metadata": {},
   "source": [
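    "## Full model predictions\n",
    "\n",
    "Predict the held-out MEG signal with the full regression model and plot its per-channel correlation with the recorded target."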
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ccb6bb11",
   "metadata": {},
   "outputs": [],
   "source": [
    "def load_subject_weights(subject, full_models_loc = \"./runs\", layer = 3):\n",
    "    \"\"\"Load a subject's full regression model and stack its weights and bias.\n",
    "\n",
    "    Args:\n",
    "        subject: Subject identifier used in the `subject_<id>` run directory name.\n",
    "        full_models_loc: Root directory that contains the run directories.\n",
    "        layer: Index used in the `full_layer_<layer>` directory name. Defaults to 3,\n",
    "            the value that used to be hard-coded.\n",
    "\n",
    "    Returns:\n",
    "        `full_model.weights.T` with the bias appended as an extra column, transposed\n",
    "        back, so the bias ends up as the last row of the result.\n",
    "    \"\"\"\n",
    "    model_loc = f\"{full_models_loc}/subject_{subject}/full_layer_{layer}\"\n",
    "    # NOTE(review): load_pkl presumably unpickles the file - only load trusted models.\n",
    "    full_model = load_pkl(model_loc + \"/final_model.pkl\")\n",
    "    W = full_model.weights.T\n",
    "    b = full_model.bias\n",
    "    # np.concatenate exists in every NumPy release; the np.concat alias needs NumPy >= 2.0.\n",
    "    return np.concatenate((W, b[:, None]), axis=1).T\n",
    "\n",
    "def subject_rank_predictions(subject, r, test_features, full_models_loc = \"./runs\"):\n",
    "    model_loc = full_models_loc + f\"/subject_{subject}_rank_sweep_single/rank_{r}\"\n",
    "    low_rank_model = torch.load(model_loc + \"/final_model.pt\", weights_only=False)\n",
    "    return low_rank_model.numpy_forward(test_features)\n",
    "\n",
    "# Predict the test MEG with the full model (weights with the bias stacked as the last row).\n",
    "W_total = load_subject_weights(subject)\n",
    "meg_full_predicted = block_gpu_multiply(single_test_feature, W_total, 1000, 1000)\n",
    "\n",
    "# Score and correlate the prediction against the averaged test target.\n",
    "brain_score_full = score(meg_full_predicted)\n",
    "full_correlation = correlation_loss(meg_full_predicted, meg_test_target)\n",
    "mean_correlation_full = np.mean(full_correlation)\n",
    "\n",
    "# One point per correlation value; the x-range follows the data instead of a fixed 306.\n",
    "fig, ax = plt.subplots()\n",
    "ax.scatter(np.arange(np.size(full_correlation)), full_correlation)\n",
    "ax.set_xlabel(\"Channel\")\n",
    "ax.set_ylabel(\"Correlation\")\n",
    "ax.set_title(\"Full Regression Correlation\")\n",
    "helmet_plotter.plot(full_correlation, vlim = (-0.2, 0.2), title = \"Full Regression Correlation\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3a9aab45",
   "metadata": {},
   "outputs": [],
   "source": [
    "ranks = list(range(1,21))\n",
    "r = rank  # rank of the low-rank model inspected in this cell (from the config cell)\n",
    "\n",
    "# torch.load needs the checkpoint file itself; model_loc only names the run directory.\n",
    "# NOTE(review): the rank-sweep cells load last_model.pt while subject_rank_predictions\n",
    "# loads final_model.pt - confirm which checkpoint is intended here.\n",
    "rank_model_path = f\"./runs/subject_{subject}_rank_sweep_single/rank_{r}/last_model.pt\"\n",
    "low_rank_model = torch.load(rank_model_path, weights_only=False)\n",
    "meg_predicted = low_rank_model.numpy_forward(single_test_feature)\n",
    "rank_scores = [score(meg_predicted)]\n",
    "rank_correlation = correlation_loss(meg_predicted, meg_test_target)\n",
    "correlations = [np.mean(rank_correlation)]\n",
    "\n",
    "# Scatter of low-rank vs. full-model correlation, one point per channel.\n",
    "fig, ax = plt.subplots()\n",
    "ax.scatter(rank_correlation, full_correlation, label = \"Channel\")\n",
    "xmin, xmax = ax.get_xlim()\n",
    "ymin, ymax = ax.get_ylim()\n",
    "\n",
    "# Give both axes the same span so the unity line runs corner to corner.\n",
    "lo = min(xmin, ymin)\n",
    "hi = max(xmax, ymax)\n",
    "ax.set_xlim(lo, hi)\n",
    "ax.set_ylim(lo, hi)\n",
    "\n",
    "# y = x: points on this line are channels where both models correlate equally.\n",
    "ax.plot([lo, hi], [lo, hi], 'k--', linewidth=1, label = \"Equal Model\")\n",
    "ax.set_title(f\"Rank {r} Vs Full Regression Correlation \")\n",
    "ax.set_xlabel(f\"Rank {r}\")\n",
    "ax.set_ylabel(\"Full Regression\")\n",
    "ax.legend()\n",
    "\n",
    "# Sensor-helmet views: the low-rank correlation and its change relative to the full model.\n",
    "helmet_plotter.plot(rank_correlation, vlim = (-0.2, 0.2), title = f\"Rank {r} Correlation\")\n",
    "helmet_plotter.plot(rank_correlation - full_correlation, vlim = (-0.2, 0.2), title = f\"Rank {r} Correlation Change\")\n",
    "\n",
    "def plot_three_topomaps(\n",
    "    baseline_corr: np.ndarray,\n",
    "    lowrank_corr: np.ndarray,\n",
    "    positions: np.ndarray,\n",
    "    sphere: float,\n",
    "    cmap: str = \"RdBu_r\",\n",
    "    figsize: tuple = (15, 5),\n",
    "    rank: int = 10\n",
    "):\n",
    "    \"\"\"\n",
    "    Plot three topomaps side by side: two correlation maps that share one colour\n",
    "    scale and colorbar, and their difference with its own scale and colorbar.\n",
    "\n",
    "    NOTE: the panel titles do not follow the parameter names. The first map\n",
    "    (`baseline_corr`) is titled `Rank {rank} Correlation` and the second\n",
    "    (`lowrank_corr`) `Full-Rank Correlation`, so callers pass the low-rank\n",
    "    correlations first and the full-model correlations second. The third panel\n",
    "    shows `baseline_corr - lowrank_corr`.\n",
    "\n",
    "    Args:\n",
    "        baseline_corr: Values for the first panel, one per sensor position.\n",
    "        lowrank_corr: Values for the second panel, same length as `baseline_corr`.\n",
    "        positions: Sensor positions forwarded to `mne.viz.plot_topomap`.\n",
    "        sphere: Head-sphere argument forwarded to `mne.viz.plot_topomap`.\n",
    "        cmap: Colormap of the two shared-scale panels.\n",
    "        figsize: Figure size in inches.\n",
    "        rank: Rank shown in the first panel's title.\n",
    "\n",
    "    Returns:\n",
    "        The matplotlib figure holding the three panels.\n",
    "    \"\"\"\n",
    "    # difference map\n",
    "    diff_corr = baseline_corr - lowrank_corr\n",
    "\n",
    "    # two groups: first two share scale\n",
    "    shared_maps = [baseline_corr, lowrank_corr]\n",
    "    shared_titles = [f\"Rank {rank} Correlation\", \"Full-Rank Correlation\"]\n",
    "    diff_title = \"Difference\"\n",
    "\n",
    "    # Colour limits are symmetric about zero so the diverging colormaps keep zero at\n",
    "    # their midpoint. One limit covers both shared panels, which makes the single\n",
    "    # shared colorbar valid for each of them; the difference gets its own limit.\n",
    "    shared_lim = float(np.abs(np.concatenate(shared_maps)).max())\n",
    "    diff_lim = float(np.abs(diff_corr).max())\n",
    "\n",
    "    fig, axes = plt.subplots(\n",
    "        nrows=1, ncols=3,\n",
    "        figsize=figsize,\n",
    "        constrained_layout=False\n",
    "    )\n",
    "    fig.subplots_adjust(left=0.12, right=0.88, wspace=0.3)\n",
    "\n",
    "    # --- first two panels on the shared scale ---\n",
    "    ims = []\n",
    "    for ax, data, title in zip(axes[:2], shared_maps, shared_titles):\n",
    "        im = mne.viz.plot_topomap(\n",
    "            data, positions,\n",
    "            axes=ax, show=False,\n",
    "            outlines=\"head\", sphere=sphere,\n",
    "            cmap=cmap,\n",
    "            vlim=(-shared_lim, shared_lim)\n",
    "        )[0]\n",
    "        ax.set_title(title, fontsize=16)\n",
    "        ax.set_xticks([]); ax.set_yticks([])\n",
    "        ims.append(im)\n",
    "\n",
    "    # shared colorbar on the left; either image works because the limits are identical\n",
    "    cbar_shared = fig.colorbar(\n",
    "        ims[0],\n",
    "        ax=axes[:2].tolist(),\n",
    "        orientation=\"vertical\",\n",
    "        fraction=0.08,\n",
    "        pad=0.02,\n",
    "        location=\"left\"\n",
    "    )\n",
    "    cbar_shared.set_label(\"Correlation\", fontsize=14)\n",
    "    cbar_shared.ax.tick_params(labelsize=12)\n",
    "\n",
    "    # --- difference panel with its own scale ---\n",
    "    im_diff = mne.viz.plot_topomap(\n",
    "        diff_corr, positions,\n",
    "        axes=axes[2], show=False,\n",
    "        outlines=\"head\", sphere=sphere,\n",
    "        cmap=\"PiYG_r\",\n",
    "        vlim=(-diff_lim, diff_lim)\n",
    "    )[0]\n",
    "    axes[2].set_title(diff_title, fontsize=16)\n",
    "    axes[2].set_xticks([]); axes[2].set_yticks([])\n",
    "\n",
    "    # independent colorbar on the right\n",
    "    cbar_diff = fig.colorbar(\n",
    "        im_diff,\n",
    "        ax=axes[2],\n",
    "        orientation=\"vertical\",\n",
    "        fraction=0.08,\n",
    "        pad=0.02,\n",
    "        location=\"right\"\n",
    "    )\n",
    "    cbar_diff.set_label(\"Δ Correlation\", fontsize=14)\n",
    "    cbar_diff.ax.tick_params(labelsize=12)\n",
    "\n",
    "    return fig\n",
    "\n",
    "# Sensor coordinates used to lay out the topomaps.\n",
    "positions = load_sensor_locations(\"./preprocessing_data/locations.txt\", partial_sensors=False)\n",
    "\n",
    "# Low-rank correlations go first and full-model correlations second, matching the\n",
    "# panel titles. The title rank follows r, the rank that produced rank_correlation,\n",
    "# instead of a separately hard-coded 10.\n",
    "fig = plot_three_topomaps(\n",
    "    rank_correlation,\n",
    "    full_correlation,\n",
    "    positions,\n",
    "    sphere=47,\n",
    "    rank=r\n",
    ")\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bbe85103",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Score every rank of the sweep for the configured subject.\n",
    "ranks = list(range(1,21))\n",
    "rank_scores = []\n",
    "correlations = []\n",
    "for sweep_rank in ranks:\n",
    "    # Loop-local names keep model_loc, r and rank_correlation from other cells intact,\n",
    "    # so re-running those cells afterwards cannot pick up the last sweep iteration.\n",
    "    sweep_checkpoint = f\"./runs/subject_{subject}_rank_sweep_single/rank_{sweep_rank}/last_model.pt\"\n",
    "    sweep_model = torch.load(sweep_checkpoint, weights_only=False)\n",
    "    sweep_prediction = sweep_model.numpy_forward(single_test_feature)\n",
    "    rank_scores.append(score(sweep_prediction))\n",
    "    sweep_correlation = correlation_loss(sweep_prediction, meg_test_target)\n",
    "    correlations.append(np.mean(sweep_correlation))\n",
    "\n",
    "fig, ax = plt.subplots()\n",
    "ax.set_title(\"$CC_{norm}$ Over Ranks\")\n",
    "ax.plot(ranks, rank_scores, marker = \"o\", label=\"Low-Rank Models\")\n",
    "ax.set_xticks(np.arange(1, max(ranks)+5, 3))\n",
    "ax.hlines([brain_score_full], [min(ranks)], [max(ranks)], colors=[\"red\"], linestyles=\"--\", label=\"Full Ridge Model\")\n",
    "# Include the full-model score in the upper limit so its line is never clipped.\n",
    "ax.set_ylim(0.0, max(rank_scores + [brain_score_full])*1.1)\n",
    "ax.set_xlabel(\"Rank\")\n",
    "ax.set_ylabel(\"$CC_{norm}$\")\n",
    "ax.legend()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3a0e9868",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "def score_func(subject):\n",
    "    \"\"\"Build the brain-score callable for `subject` from its held-out MEG recordings.\n",
    "\n",
    "    Returns:\n",
    "        The object returned by `brain_score`, which is called on model predictions.\n",
    "    \"\"\"\n",
    "    # The second argument is the dataset location, as in the data-loading cell;\n",
    "    # the LLM name used to be passed here by mistake.\n",
    "    test_configs = subject_test_configs(subject, dataset_loc)\n",
    "    meg = np.stack(load_meg_targets(test_configs), axis=0)\n",
    "    return brain_score(meg, 0.03, ceiling_cutoff = None)\n",
    "\n",
    "def subject_full_score(subject):\n",
    "    \"\"\"Return the brain score of the full regression model for `subject`.\"\"\"\n",
    "    scorer = score_func(subject)\n",
    "    stacked_weights = load_subject_weights(subject)\n",
    "    # NOTE(review): single_test_feature is the notebook-level feature matrix loaded for\n",
    "    # the configured `subject` - confirm it is also valid for other subjects.\n",
    "    prediction = block_gpu_multiply(single_test_feature, stacked_weights, 1000, 1000)\n",
    "    return scorer(prediction)\n",
    "\n",
    "def subject_rank_score(subject):\n",
    "    \"\"\"Return the brain scores of the rank-1 to rank-20 models for `subject`, in rank order.\"\"\"\n",
    "    scorer = score_func(subject)\n",
    "\n",
    "    def _score_one_rank(r):\n",
    "        # Load the rank-r checkpoint and score its prediction on the notebook-level features.\n",
    "        checkpoint = f\"./runs/subject_{subject}_rank_sweep_single/rank_{r}/last_model.pt\"\n",
    "        model = torch.load(checkpoint, weights_only=False)\n",
    "        return scorer(model.numpy_forward(single_test_feature))\n",
    "\n",
    "    return [_score_one_rank(r) for r in range(1, 21)]\n",
    "    \n",
    "ranks = list(range(1,21))\n",
    "subject_colors = {\"A\": \"red\", \"C\": \"green\", \"D\": \"blue\"}\n",
    "\n",
    "# Score each subject once: the full model and the whole rank sweep.\n",
    "full_scores = {name: subject_full_score(name) for name in subject_colors}\n",
    "sweep_scores = {name: subject_rank_score(name) for name in subject_colors}\n",
    "\n",
    "fig, ax = plt.subplots()\n",
    "ax.set_title(\"$CC_{norm}$ Over Ranks\")\n",
    "for name, color in subject_colors.items():\n",
    "    # Solid line: low-rank sweep. Dashed line in the same colour: that subject's full model.\n",
    "    ax.plot(ranks, sweep_scores[name], color=color, label=f\"Subject {name}\", marker=\"o\")\n",
    "    ax.hlines(full_scores[name], min(ranks), max(ranks), colors=color, linestyles=\"--\")\n",
    "# Empty proxy line so the legend explains what the dashed lines are.\n",
    "ax.plot([], [], color=\"black\", linestyle=\"--\", label=\"Full Ridge Model\")\n",
    "\n",
    "ax.set_xticks(np.arange(1, max(ranks)+5, 3))\n",
    "# The upper limit covers both the sweeps and the full-model scores so no line is clipped.\n",
    "all_scores = [s for name in subject_colors for s in sweep_scores[name]] + list(full_scores.values())\n",
    "ax.set_ylim(0.0, max(all_scores)*1.1)\n",
    "ax.set_xlabel(\"Rank\")\n",
    "ax.set_ylabel(\"$CC_{norm}$\")\n",
    "ax.legend()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
