{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "4ba61398",
   "metadata": {},
   "source": [
    "## In this notebook, we compare encoding performance across different feature spaces"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2a0251de",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Setup: make the project's local modules importable and load all dependencies.\n",
    "import sys\n",
    "import os\n",
    "\n",
    "# Append the notebook's parent directory to sys.path so the local modules imported below\n",
    "# (config, dataset, encoding_utils) resolve; presumably they live one level up.\n",
    "parent_dir = os.path.abspath(os.path.join(os.getcwd(), '..'))\n",
    "if parent_dir not in sys.path:\n",
    "    sys.path.append(parent_dir)\n",
    "\n",
    "# NOTE(review): `os` is imported twice, and several imports below (torch, Dataset/DataLoader,\n",
    "# pad_sequence, nn/F, pytorch_lightning, BERT/Whisper classes, torchaudio, RidgeCV,\n",
    "# DATASET_FULL_TRIALS_ZSCORE, getDatasetLoaders_V3, plot_channels_grid_fdr) are not\n",
    "# referenced anywhere else in this notebook.\n",
    "import os \n",
    "import torch\n",
    "from torch.utils.data import Dataset, DataLoader\n",
    "from torch.nn.utils.rnn import pad_sequence\n",
    "import pickle\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import tqdm\n",
    "import torch.nn as nn\n",
    "import pytorch_lightning as pl\n",
    "import torch.nn.functional as F\n",
    "from transformers import BertTokenizer, BertModel\n",
    "from os.path import join as opj\n",
    "from himalaya.ridge import RidgeCV\n",
    "from himalaya.backend import set_backend\n",
    "from config import DATASET_FULL_TRIALS_ZSCORE\n",
    "from dataset import getDatasetLoaders_V3\n",
    "from encoding_utils import plot_channels_grid_fdr\n",
    "from transformers import WhisperForConditionalGeneration, WhisperTokenizer, AutoProcessor\n",
    "import torchaudio\n",
    "# Select himalaya's torch_cuda backend.\n",
    "# NOTE(review): presumably requires a CUDA-enabled torch install; himalaya (RidgeCV) is not\n",
    "# used later in this notebook, so confirm this is still needed.\n",
    "set_backend(\"torch_cuda\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6b4eb4b9",
   "metadata": {},
   "outputs": [],
   "source": [
    "base_output_dirs = [\"encoding_mfcc\", \"encoding_synthetic_mfcc\", \"encoding_semantic\",\"encoding_speech\"]\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b749d4c7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load, per feature space, the real encoding correlations and the null distribution\n",
    "# saved by the encoding runs. The `with` blocks make sure every file handle is closed.\n",
    "# NOTE: pickle.load can execute arbitrary code; only load files produced by our own pipeline.\n",
    "encoding_data = {}\n",
    "\n",
    "for base_output_dir in tqdm.tqdm(base_output_dirs):\n",
    "    with open(opj(base_output_dir, \"time_windows_SM_corrs.pkl\"), \"rb\") as f:\n",
    "        corrs = pickle.load(f)\n",
    "    with open(opj(base_output_dir, \"time_windows_SM_null_dist.pkl\"), \"rb\") as f:\n",
    "        null_dist = pickle.load(f)\n",
    "    null_array = np.array(null_dist)\n",
    "\n",
    "    encoding_data[base_output_dir] = {\n",
    "        \"corrs\": corrs,\n",
    "        \"null_array\": null_array,\n",
    "    }"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "680d9174",
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "from scipy.ndimage import gaussian_filter1d\n",
    "from statsmodels.stats.multitest import multipletests\n",
    "import seaborn as sns\n",
    "\n",
    "def plot_channels_grid_fdr_compare_feature_spaces(channels_range, encoding_data, feature_labels, alpha_level=0.05):\n",
    "    \"\"\"\n",
    "    Compare encoding performance for multiple feature spaces across channels.\n",
    "    Each subplot shows the real encoding curve + null for one channel.\n",
    "    \"\"\"\n",
    "    n_plots = len(channels_range)\n",
    "    grid_size = int(np.ceil(np.sqrt(n_plots)))\n",
    "    fig, axes = plt.subplots(4, 4, figsize=(20, 16), dpi=150)\n",
    "    fig.suptitle(\"Encoding Performance Across Feature Spaces\", fontsize=16, fontweight=\"bold\")\n",
    "\n",
    "    n_time_points = next(iter(encoding_data.values()))[\"corrs\"].shape[0]\n",
    "    n_permutations = next(iter(encoding_data.values()))[\"null_array\"].shape[0]\n",
    "\n",
    "    colors = [\"tab:orange\", \"tab:green\", \"tab:blue\", \"tab:purple\"]  # up to 4 spaces\n",
    "\n",
    "    for i, ch_idx in enumerate(channels_range):\n",
    "        row = i // 4\n",
    "        col = i % 4\n",
    "        ax = axes[row, col]\n",
    "\n",
    "        for j, (feature_name, data) in enumerate(encoding_data.items()):\n",
    "            real_corrs = data[\"corrs\"][:, ch_idx]\n",
    "            smoothed = gaussian_filter1d(real_corrs, sigma=2)\n",
    "\n",
    "            null = data[\"null_array\"][:, :, ch_idx]\n",
    "            null_mean = null.mean(axis=1)\n",
    "            null_std = null.std(axis=1)\n",
    "\n",
    "            color = colors[j % len(colors)]\n",
    "\n",
    "            # Plot real\n",
    "            ax.plot(smoothed, label=f\"{feature_labels[feature_name]}\", color=color, linewidth=2)\n",
    "            ax.plot(real_corrs, alpha=0.3, color=color)\n",
    "\n",
    "            # Plot null band\n",
    "            ax.plot(null_mean, color=color, linestyle=\"--\", alpha=0.6)\n",
    "            ax.fill_between(np.arange(n_time_points),\n",
    "                            null_mean - null_std,\n",
    "                            null_mean + null_std,\n",
    "                            color=color, alpha=0.1)\n",
    "\n",
    "        ax.set_title(f\"Ch {ch_idx}\", fontsize=9)\n",
    "        ax.axvline(x=50, color=\"red\", linestyle=\"--\", alpha=0.5)\n",
    "        ax.set_ylim(-0.3, 0.5)\n",
    "        ax.set_xticks([0, 50, 100, 150, n_time_points - 1])\n",
    "        ax.set_yticks([-0.3, 0, 0.3])\n",
    "        sns.despine(ax=ax)\n",
    "\n",
    "    # Turn off unused subplots\n",
    "    for j in range(n_plots, 16):\n",
    "        row = j // 4\n",
    "        col = j % 4\n",
    "        axes[row, col].axis('off')\n",
    "\n",
    "    handles, labels = axes[0, 0].get_legend_handles_labels()\n",
    "    fig.legend(handles, labels, loc=\"upper right\")\n",
    "    plt.tight_layout()\n",
    "    plt.show()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5dfb18fc",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Overlay all feature spaces for the 16 channels starting at 192.\n",
    "plot_channels_grid_fdr_compare_feature_spaces(\n",
    "    encoding_data=encoding_data,\n",
    "    feature_labels=dict(\n",
    "        encoding_mfcc=\"MFCC\",\n",
    "        encoding_synthetic_mfcc=\"Synthetic MFCC\",\n",
    "        encoding_semantic=\"Semantic\",\n",
    "        encoding_speech=\"Speech\",\n",
    "    ),\n",
    "    channels_range=range(192, 192 + 16),\n",
    ")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c1f5e29f",
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_compare_embeddings_with_stats(\n",
    "    encoding_data,\n",
    "    feature_labels,\n",
    "    channels_range,\n",
    "    alpha_level=0.05,\n",
    "    sig_fraction_threshold=0.2,\n",
    "    sigma=2,\n",
    "    ncols=4\n",
    "):\n",
    "    \"\"\"\n",
    "    Compare encoding performances with statistical testing.\n",
    "    For each timepoint, mark significant encoding with a dot in the color of the best-correlated embedding.\n",
    "    \"\"\"\n",
    "    import numpy as np\n",
    "    import matplotlib.pyplot as plt\n",
    "    from scipy.ndimage import gaussian_filter1d\n",
    "    from statsmodels.stats.multitest import multipletests\n",
    "    import seaborn as sns\n",
    "\n",
    "    n_time_points = next(iter(encoding_data.values()))[\"corrs\"].shape[0]\n",
    "    n_permutations = next(iter(encoding_data.values()))[\"null_array\"].shape[0]\n",
    "    feature_names = list(encoding_data.keys())\n",
    "    n_embeddings = len(feature_names)\n",
    "\n",
    "    n_channels = len(channels_range)\n",
    "    nrows = int(np.ceil(n_channels / ncols))\n",
    "    fig, axes = plt.subplots(nrows, ncols, figsize=(5 * ncols, 4 * nrows), dpi=150)\n",
    "    fig.suptitle(\"Encoding Performance Across Feature Spaces\", fontsize=18, fontweight=\"bold\")\n",
    "\n",
    "    if isinstance(axes, np.ndarray):\n",
    "        axes = axes.flatten()\n",
    "    else:\n",
    "        axes = [axes]\n",
    "\n",
    "    color_palette = sns.color_palette(\"tab10\", n_embeddings)\n",
    "    color_map = {feature: color_palette[i] for i, feature in enumerate(feature_names)}\n",
    "\n",
    "    sig_channels = {feature: [] for feature in feature_names}\n",
    "\n",
    "    for i, ch_idx in enumerate(channels_range):\n",
    "        ax = axes[i]\n",
    "        pvals_matrix = np.zeros((n_embeddings, n_time_points))\n",
    "        corr_matrix = np.zeros((n_embeddings, n_time_points))\n",
    "\n",
    "        for j, feature in enumerate(feature_names):\n",
    "            data = encoding_data[feature]\n",
    "            corrs = data[\"corrs\"][:, ch_idx]\n",
    "            null_vals = data[\"null_array\"][:, :, ch_idx]\n",
    "\n",
    "            corr_matrix[j] = corrs\n",
    "\n",
    "            # Permutation p-values (one-tailed)\n",
    "            greater_counts = (null_vals >= corrs[:, None]).sum(axis=1)\n",
    "            pvals = greater_counts / n_permutations\n",
    "            pvals_matrix[j] = pvals\n",
    "\n",
    "        # FDR correction across all embedding × timepoint\n",
    "        pvals_flat = pvals_matrix.flatten()\n",
    "        reject_flags, pvals_corrected_flat, _, _ = multipletests(pvals_flat, alpha=alpha_level, method='fdr_bh')\n",
    "        pvals_corrected = pvals_corrected_flat.reshape(n_embeddings, n_time_points)\n",
    "        sig_mask = reject_flags.reshape(n_embeddings, n_time_points)\n",
    "\n",
    "        # Track significant channels\n",
    "        for j, feature in enumerate(feature_names):\n",
    "            if sig_mask[j].mean() >= sig_fraction_threshold:\n",
    "                sig_channels[feature].append(ch_idx)\n",
    "\n",
    "        # Plot all curves\n",
    "        for j, feature in enumerate(feature_names):\n",
    "            color = color_map[feature]\n",
    "            corrs = corr_matrix[j]\n",
    "            smoothed = gaussian_filter1d(corrs, sigma=sigma)\n",
    "            ax.plot(corrs, alpha=0.3, color=color)\n",
    "            ax.plot(smoothed, label=feature_labels[feature], color=color, linewidth=2)\n",
    "\n",
    "            # Plot null mean ± std\n",
    "            null_vals = encoding_data[feature][\"null_array\"][:, :, ch_idx]\n",
    "            null_mean = null_vals.mean(axis=1)\n",
    "            null_std = null_vals.std(axis=1)\n",
    "            ax.fill_between(np.arange(n_time_points), null_mean - null_std, null_mean + null_std,\n",
    "                            color=color, alpha=0.1)\n",
    "\n",
    "        # Find the best (most correlated) significant embedding per timepoint\n",
    "        best_colors = []\n",
    "        for t in range(n_time_points):\n",
    "            sig_at_t = sig_mask[:, t]\n",
    "            if np.any(sig_at_t):\n",
    "                best_idx = np.argmax(corr_matrix[:, t] * sig_at_t)  # argmax only over significant ones\n",
    "                best_colors.append(color_palette[best_idx])\n",
    "            else:\n",
    "                best_colors.append(None)\n",
    "\n",
    "        # Plot small dots for significant timepoints\n",
    "        for t, color in enumerate(best_colors):\n",
    "            if color is not None:\n",
    "                ax.plot(t, 0.42, marker='o', color=color, markersize=3, alpha=0.8)\n",
    "\n",
    "        ax.axvline(x=50, color=\"red\", linestyle=\"--\", alpha=0.5)\n",
    "        ax.set_ylim(-0.3, 0.45)\n",
    "        ax.set_title(f\"Ch {ch_idx}\", fontsize=10)\n",
    "        ax.set_xticks([0, 50, 100, 150, n_time_points - 1])\n",
    "        ax.set_yticks([-0.3, 0, 0.3])\n",
    "        sns.despine(ax=ax)\n",
    "\n",
    "    # Turn off unused axes\n",
    "    for j in range(n_channels, len(axes)):\n",
    "        axes[j].axis('off')\n",
    "\n",
    "    handles, labels = axes[0].get_legend_handles_labels()\n",
    "    fig.legend(handles, labels, loc=\"upper right\")\n",
    "    plt.tight_layout()\n",
    "    plt.show()\n",
    "\n",
    "    return sig_channels\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a26d3eb3",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run the permutation statistics + plots in chunks of 16 channels and merge the\n",
    "# per-feature lists of significant channels.\n",
    "labels_by_feature = {\n",
    "    \"encoding_mfcc\": \"MFCC\",\n",
    "    \"encoding_synthetic_mfcc\": \"Synthetic MFCC\",\n",
    "    \"encoding_semantic\": \"Semantic\",\n",
    "    \"encoding_speech\": \"Speech\"\n",
    "}\n",
    "# The channel count is read from the data (corrs is indexed [time, channel]) instead of\n",
    "# a hard-coded 256; assumes every feature space covers the same channels.\n",
    "n_channels_total = next(iter(encoding_data.values()))[\"corrs\"].shape[1]\n",
    "chunk_size = 16\n",
    "\n",
    "all_sig_chs = {feature: [] for feature in encoding_data.keys()}\n",
    "\n",
    "for start in range(0, n_channels_total, chunk_size):\n",
    "    end = min(start + chunk_size, n_channels_total)\n",
    "    print(f\"Processing channels {start} to {end-1}\")\n",
    "\n",
    "    sig_chs = plot_compare_embeddings_with_stats(\n",
    "        encoding_data=encoding_data,\n",
    "        feature_labels=labels_by_feature,\n",
    "        channels_range=range(start, end),\n",
    "        alpha_level=0.05,\n",
    "        sig_fraction_threshold=0.5,\n",
    "        sigma=2\n",
    "    )\n",
    "\n",
    "    # Merge results\n",
    "    for feature, channels in sig_chs.items():\n",
    "        all_sig_chs[feature].extend(channels)\n",
    "\n",
    "# Chunks are disjoint, so this only normalizes each list to sorted, de-duplicated form.\n",
    "all_sig_chs = {feature: sorted(set(chs)) for feature, chs in all_sig_chs.items()}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "51611943",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Persist the dict (feature key -> list of significant channels) to disk.\n",
    "with open(\"significant_channels.pkl\", mode=\"wb\") as f:\n",
    "    pickle.dump(all_sig_chs, f, None)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6ae622ef",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Area 6v electrode layout: each entry is the channel index drawn at that grid position\n",
    "# (leading zeros removed from the original listing).\n",
    "area6v_superior = [\n",
    "    [62, 51, 43, 35, 94, 87, 79, 78],\n",
    "    [60, 53, 41, 33, 95, 86, 77, 76],\n",
    "    [63, 54, 47, 44, 93, 84, 75, 74],\n",
    "    [58, 55, 48, 40, 92, 85, 73, 72],\n",
    "    [59, 45, 46, 38, 91, 82, 71, 70],\n",
    "    [61, 49, 42, 36, 90, 83, 69, 68],\n",
    "    [56, 52, 39, 34, 89, 81, 67, 66],\n",
    "    [57, 50, 37, 32, 88, 80, 65, 64],\n",
    "]\n",
    "\n",
    "area6v_inferior = [\n",
    "    [125, 126, 112, 103, 31, 28, 11, 8],\n",
    "    [123, 124, 110, 102, 29, 26, 9, 5],\n",
    "    [121, 122, 109, 101, 27, 19, 18, 4],\n",
    "    [119, 120, 108, 100, 25, 15, 12, 3],\n",
    "    [117, 118, 107, 99, 23, 13, 10, 2],\n",
    "    [115, 116, 106, 97, 21, 20, 7, 0],\n",
    "    [113, 114, 105, 98, 17, 24, 14, 6],\n",
    "    [127, 111, 104, 96, 30, 22, 16, 1],\n",
    "]\n",
    "\n",
    "# Combine superior and inferior parts into one 16 x 8 grid.\n",
    "area6v_full = np.array(area6v_superior + area6v_inferior)\n",
    "\n",
    "# Sanity check: every channel 0-127 must appear exactly once in the layout.\n",
    "assert sorted(area6v_full.ravel().tolist()) == list(range(128)), \"area 6v layout is not a permutation of 0-127\"\n",
    "\n",
    "# Create topological plot\n",
    "fig, ax = plt.subplots(figsize=(10, 8))\n",
    "\n",
    "for (row_idx, col_idx), ch in np.ndenumerate(area6v_full):\n",
    "    ax.scatter(col_idx, -row_idx, s=100, c=\"lightgray\", edgecolor=\"k\")\n",
    "    ax.text(col_idx, -row_idx, str(ch), ha=\"center\", va=\"center\", fontsize=8)\n",
    "\n",
    "ax.set_aspect(\"equal\")\n",
    "ax.set_title(\"Area 6v Channel Topology (0–127, Spike Counts)\", fontsize=14)\n",
    "ax.set_xticks([])\n",
    "ax.set_yticks([])\n",
    "ax.grid(False)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8ce8f2e5",
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import seaborn as sns\n",
    "\n",
    "def plot_winner_map(area6v_layout, winner_dict, feature_labels, cmap=None, title=\"Winner Embedding per Channel\"):\n",
    "    \"\"\"\n",
    "    Plot the layout of Area 6v with each channel colored by its winner embedding.\n",
    "\n",
    "    Args:\n",
    "        area6v_layout (np.ndarray): 2D array of channel numbers; entry [r, c] is drawn at\n",
    "            x=c, y=-r and labelled with its channel number.\n",
    "        winner_dict (dict): channel_id -> embedding_key (e.g., \"encoding_mfcc\").\n",
    "            Channels missing from the dict are drawn light gray (\"Not Significant\").\n",
    "        feature_labels (dict): embedding_key -> legend label; also sets the legend order.\n",
    "        cmap (dict, optional): embedding_key -> color. Defaults to the seaborn \"tab10\"\n",
    "            palette assigned in feature_labels order.\n",
    "        title (str): axes title.\n",
    "    \"\"\"\n",
    "    if cmap is None:\n",
    "        color_palette = sns.color_palette(\"tab10\", len(feature_labels))\n",
    "        cmap = {key: color_palette[i] for i, key in enumerate(feature_labels.keys())}\n",
    "\n",
    "    fig, ax = plt.subplots(figsize=(10, 8))\n",
    "\n",
    "    for row_idx, row in enumerate(area6v_layout):\n",
    "        for col_idx, ch in enumerate(row):\n",
    "            embedding = winner_dict.get(ch, None)\n",
    "            if embedding is None:\n",
    "                color = \"lightgray\"\n",
    "            else:\n",
    "                color = cmap[embedding]\n",
    "            # color= (not c=): a single RGB tuple passed as c= is ambiguous for scatter\n",
    "            # and makes matplotlib emit a warning for every point.\n",
    "            ax.scatter(col_idx, -row_idx, s=180, color=color, edgecolor=\"black\", linewidth=0.6)\n",
    "            ax.text(col_idx, -row_idx, str(ch), ha=\"center\", va=\"center\", fontsize=8, color=\"white\" if embedding else \"black\")\n",
    "\n",
    "    # Legend: one empty scatter per embedding plus the \"Not Significant\" entry.\n",
    "    for key, label in feature_labels.items():\n",
    "        ax.scatter([], [], color=cmap[key], s=100, label=label, edgecolor=\"black\")\n",
    "    ax.scatter([], [], color=\"lightgray\", s=100, label=\"Not Significant\", edgecolor=\"black\")\n",
    "\n",
    "    ax.set_title(title, fontsize=14)\n",
    "    ax.set_aspect(\"equal\")\n",
    "    ax.set_xticks([])\n",
    "    ax.set_yticks([])\n",
    "    ax.legend(loc=\"upper right\", bbox_to_anchor=(1.25, 1.0))\n",
    "    plt.tight_layout()\n",
    "    plt.show()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2b2d2e27",
   "metadata": {},
   "outputs": [],
   "source": [
    "def compute_winner_dict(all_sig_chs, encoding_data):\n",
    "    \"\"\"\n",
    "    Determine the best embedding per channel based on max mean correlation.\n",
    "    \n",
    "    Args:\n",
    "        all_sig_chs (dict): embedding_key -> list of significantly encoded channels.\n",
    "        encoding_data (dict): embedding_key -> dict with \"corrs\" array.\n",
    "    \n",
    "    Returns:\n",
    "        dict: channel -> embedding_key with best correlation\n",
    "    \"\"\"\n",
    "    candidate_channels = set(ch for ch_list in all_sig_chs.values() for ch in ch_list)\n",
    "    winner_dict = {}\n",
    "\n",
    "    for ch in candidate_channels:\n",
    "        best_embedding = None\n",
    "        best_score = -np.inf\n",
    "        for emb_key in all_sig_chs:\n",
    "            if ch in all_sig_chs[emb_key]:\n",
    "                mean_corr = np.mean(encoding_data[emb_key][\"corrs\"][:, ch])\n",
    "                if mean_corr > best_score:\n",
    "                    best_score = mean_corr\n",
    "                    best_embedding = emb_key\n",
    "        if best_embedding is not None:\n",
    "            winner_dict[ch] = best_embedding\n",
    "\n",
    "    return winner_dict\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fe0e1148",
   "metadata": {},
   "outputs": [],
   "source": [
    "# One winning feature space per significant channel.\n",
    "winner_dict = compute_winner_dict(all_sig_chs=all_sig_chs, encoding_data=encoding_data)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4064451a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# winner_dict"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "85eb96f2",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Area 6v layout for spike counts (channels 0-127).\n",
    "area6v_full = np.array(area6v_superior + area6v_inferior)\n",
    "\n",
    "# Color each channel by its winning feature space (gray = not significant).\n",
    "plot_winner_map(\n",
    "    area6v_layout=area6v_full,\n",
    "    winner_dict=winner_dict,\n",
    "    feature_labels=dict(\n",
    "        encoding_mfcc=\"MFCC\",\n",
    "        encoding_synthetic_mfcc=\"Synthetic MFCC\",\n",
    "        encoding_semantic=\"Semantic\",\n",
    "        encoding_speech=\"Speech\",\n",
    "    ),\n",
    ")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3a6b4dd1",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Same area 6v layout (built from the 0-127 grid).\n",
    "area6v_full = np.array(area6v_superior + area6v_inferior)\n",
    "\n",
    "# Shift by 128 to map the layout onto channels 128-255 and color by winning feature space.\n",
    "plot_winner_map(\n",
    "    area6v_layout=area6v_full + 128,\n",
    "    winner_dict=winner_dict,\n",
    "    feature_labels=dict(\n",
    "        encoding_mfcc=\"MFCC\",\n",
    "        encoding_synthetic_mfcc=\"Synthetic MFCC\",\n",
    "        encoding_semantic=\"Semantic\",\n",
    "        encoding_speech=\"Speech\",\n",
    "    ),\n",
    ")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "21fa601e",
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import seaborn as sns\n",
    "import matplotlib as mpl\n",
    "\n",
    "def plot_winner_map_high_quality(area6v_layout, winner_dict, feature_labels, title=\"Winner Feature per Channel\", annotate=False):\n",
    "    \"\"\"\n",
    "    High-quality topological plot of the winning feature space per channel.\n",
    "\n",
    "    Channels found in winner_dict are filled with their feature space's color (seaborn\n",
    "    \"colorblind\" palette, assigned in feature_labels order); all other channels are drawn\n",
    "    light gray and listed as \"Not Significant\" in the legend.\n",
    "\n",
    "    The font/size settings are applied with mpl.rc_context so they only affect this\n",
    "    figure; updating mpl.rcParams directly would restyle every later plot as well.\n",
    "\n",
    "    Args:\n",
    "        area6v_layout: 2D array-like of channel numbers; entry [r, c] is drawn at x=c, y=-r.\n",
    "        winner_dict (dict): channel -> feature key.\n",
    "        feature_labels (dict): feature key -> legend label.\n",
    "        title (str): axes title.\n",
    "        annotate (bool): if True, write each channel number on its marker.\n",
    "    \"\"\"\n",
    "    # Nature-style font and sizes, scoped to this figure only.\n",
    "    style = {\n",
    "        \"font.family\": \"Arial\",\n",
    "        \"font.size\": 10,\n",
    "        \"axes.titlesize\": 12,\n",
    "        \"axes.labelsize\": 10,\n",
    "        \"legend.fontsize\": 10\n",
    "    }\n",
    "\n",
    "    with mpl.rc_context(style):\n",
    "        # Use CUD-friendly color palette\n",
    "        color_palette = sns.color_palette(\"colorblind\", len(feature_labels))\n",
    "        cmap = {key: color_palette[i] for i, key in enumerate(feature_labels)}\n",
    "\n",
    "        fig, ax = plt.subplots(figsize=(6, 5))\n",
    "\n",
    "        for row_idx, row in enumerate(area6v_layout):\n",
    "            for col_idx, ch in enumerate(row):\n",
    "                embedding = winner_dict.get(ch, None)\n",
    "                if embedding is None:\n",
    "                    color, edge, lw = \"#e0e0e0\", \"#b0b0b0\", 0.5  # light gray = not significant\n",
    "                else:\n",
    "                    color, edge, lw = cmap[embedding], \"black\", 0.6\n",
    "                ax.scatter(col_idx, -row_idx, s=120, color=color, edgecolor=edge, linewidth=lw, zorder=2)\n",
    "                if annotate:\n",
    "                    ax.text(col_idx, -row_idx, str(ch), ha=\"center\", va=\"center\", fontsize=7, zorder=3)\n",
    "\n",
    "        # Custom legend built from proxy markers (scalable and clean).\n",
    "        handles = [\n",
    "            plt.Line2D([0], [0], marker='o', color='w', label=label,\n",
    "                       markerfacecolor=cmap[key], markersize=8, markeredgecolor='black')\n",
    "            for key, label in feature_labels.items()\n",
    "        ]\n",
    "        handles.append(plt.Line2D([0], [0], marker='o', color='w', label=\"Not Significant\",\n",
    "                                  markerfacecolor=\"#e0e0e0\", markersize=8, markeredgecolor=\"#b0b0b0\"))\n",
    "\n",
    "        ax.legend(handles=handles, bbox_to_anchor=(1.05, 1), loc='upper left', borderaxespad=0., frameon=False)\n",
    "\n",
    "        ax.set_title(title, fontsize=12, pad=10)\n",
    "        ax.set_aspect(\"equal\")\n",
    "        ax.set_xticks([])\n",
    "        ax.set_yticks([])\n",
    "        sns.despine(ax=ax, left=True, bottom=True)\n",
    "        plt.tight_layout()\n",
    "        plt.show()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "60a06a2c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Publication-style winner map for the layout shifted by 128 (channels 128-255).\n",
    "# NOTE(review): the title says \"Spike Counts\", but this notebook labels channels 0-127 as\n",
    "# spike counts and the original comment hinted at an area6v_spikepower layout; confirm.\n",
    "plot_winner_map_high_quality(\n",
    "    area6v_full + 128,\n",
    "    winner_dict,\n",
    "    {\n",
    "        \"encoding_mfcc\": \"MFCC\",\n",
    "        \"encoding_synthetic_mfcc\": \"Synthetic MFCC\",\n",
    "        \"encoding_semantic\": \"Semantic\",\n",
    "        \"encoding_speech\": \"Speech\",\n",
    "    },\n",
    "    title=\"Winner Embedding per Channel (Spike Counts)\",\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "79ea5533",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Inspect up to 16 MFCC-significant channels with index >= 128. Using >= keeps channel 128,\n",
    "# the first channel of the `area6v_full + 128` layout, in the selection.\n",
    "mfcc_upper_channels = [ch for ch in all_sig_chs[\"encoding_mfcc\"] if ch >= 128][:16]\n",
    "\n",
    "if mfcc_upper_channels:\n",
    "    sig_chs = plot_compare_embeddings_with_stats(\n",
    "        encoding_data=encoding_data,\n",
    "        feature_labels={\n",
    "            \"encoding_mfcc\": \"MFCC\",\n",
    "            \"encoding_synthetic_mfcc\": \"Synthetic MFCC\",\n",
    "            \"encoding_semantic\": \"Semantic\",\n",
    "            \"encoding_speech\": \"Speech\"\n",
    "        },\n",
    "        channels_range=mfcc_upper_channels,\n",
    "        alpha_level=0.05,\n",
    "        sig_fraction_threshold=0.5,\n",
    "        sigma=2\n",
    "    )\n",
    "else:\n",
    "    # Nothing to plot; avoid building an empty subplot grid.\n",
    "    sig_chs = {feature: [] for feature in encoding_data}\n",
    "    print(\"No MFCC-significant channels >= 128 to plot.\")\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "evo",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.18"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
