{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e5a20134-2ee3-4f08-b1ef-92da02af28a1",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import pandas as pd\n",
    "import seaborn as sns\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "import pickle\n",
    "import json"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b1ae471c-9131-4af7-9483-1b4c1d3f4de4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Qwen2.5-1.5B tokenizer; passed to `feature.show(...)` in later cells to render examples as text.\n",
    "from transformers import AutoTokenizer\n",
    "tokenizer = AutoTokenizer.from_pretrained(\"Qwen/Qwen2.5-1.5B\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b2421993-371e-466f-9446-68302413b8f0",
   "metadata": {},
   "outputs": [],
   "source": [
    "def post_act_to_triplet(post_act_idx: int, m: int, n: int):\n",
    "    elements_per_head = m * n\n",
    "    head_idx = post_act_idx // elements_per_head\n",
    "    position_in_head = post_act_idx % elements_per_head\n",
    "    row_idx = position_in_head // n\n",
    "    col_idx = position_in_head % n\n",
    "    return head_idx, row_idx, col_idx\n",
    "\n",
    "def triplet_to_pre_act(head_idx: int, \n",
    "                      row_idx: int,\n",
    "                      col_idx: int,\n",
    "                      m: int, n: int):\n",
    "    features_per_head = m + n\n",
    "    base_idx = head_idx * features_per_head\n",
    "    row_pre_act = base_idx + row_idx\n",
    "    column_pre_act = base_idx + m + col_idx\n",
    "    return row_pre_act, column_pre_act\n",
    "\n",
    "def columns_for_row(head: int, row: int, m: int, n: int):\n",
    "    assert row < m, f\"Row index should be less than m = {m}\"\n",
    "    base = head * (m + n) + m\n",
    "    return [(col, base + col) for col in range(n)]\n",
    "\n",
    "def rows_for_column(head: int, column: int, m: int, n: int):\n",
    "    assert column < n, f\"Column index should be less than n = {n}\"\n",
    "    base = head * (m + n)\n",
    "    return [(row, base + row) for row in range(m)]\n",
    "\n",
    "def pre_to_post(head: int, row: int, column: int, m: int, n: int):\n",
    "    return head * (m * n) + row * n + column\n",
    "\n",
    "def nanmean(values):\n",
    "    return np.nanmean(np.array(list(values), dtype=float))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "498a7424",
   "metadata": {},
   "outputs": [],
   "source": [
    "def _load_pickle(path):\n",
    "    \"\"\"Unpickle `path`, closing the file handle afterwards.\"\"\"\n",
    "    # NOTE(review): pickle.load can execute arbitrary code; only load trusted files.\n",
    "    with open(path, 'rb') as f:\n",
    "        return pickle.load(f)\n",
    "\n",
    "\n",
    "pre = _load_pickle('pre.pkl')\n",
    "post = _load_pickle('post.pkl')\n",
    "# NOTE(review): `topk` is used when the dataframes are built below but is never\n",
    "# loaded in this notebook; TODO load it here (its pickle path is not shown).\n",
    "# NOTE(review): a later cell reads 'results/post.pkl' instead; confirm which path is current."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e62888ad-db97-4afd-81a2-a334b1efd3c8",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import numpy as np\n",
    "\n",
    "def get_mean_score(scores):\n",
    "    det, fuz = scores.get('detection'), scores.get('fuzzing')\n",
    "    if det is None and fuz is None: return None\n",
    "    return fuz if det is None else det if fuz is None else np.round((fuz + det) / 2, 2)\n",
    "\n",
    "def process_feature(feature):\n",
    "    \"\"\"Flatten a feature's statistics and latest interpretation into one table row.\n",
    "\n",
    "    Calls `feature.compute_token_statistics()` (mutating `feature`) when the\n",
    "    token entropy has not been computed yet.\n",
    "    \"\"\"\n",
    "    if feature.statistics.token_entropy is None:\n",
    "        feature.compute_token_statistics()\n",
    "\n",
    "    # Read statistics only after the optional recomputation above\n",
    "    stats = feature.statistics\n",
    "    latest = feature.interpretations[-1]\n",
    "    scores = latest.score\n",
    "    return {\n",
    "        'Feature': feature.index,\n",
    "        'Examples': len(feature.positive_examples),\n",
    "        'Mean activation': stats.mean,\n",
    "        'Token entropy': stats.token_entropy,\n",
    "        'Frequency': stats.frequency,\n",
    "        'Multitoken ratio': stats.multitoken_ratio,\n",
    "        'Interpretation': latest.value,\n",
    "        'Detection score': scores.get('detection'),\n",
    "        'Fuzzing score': scores.get('fuzzing'),\n",
    "        'Mean score': get_mean_score(scores),\n",
    "    }\n",
    "\n",
    "def get_parent_properties(parent, prefix):\n",
    "    if parent is None:\n",
    "        return {f'{prefix} {k}': None for k in [\n",
    "            'Feature', 'Examples', 'Mean activation', 'Token entropy',\n",
    "            'Frequency', 'Multitoken ratio', 'Interpretation',\n",
    "            'Detection score', 'Fuzzing score', 'Mean score'\n",
    "        ]}\n",
    "    \n",
    "    interpretation = parent.interpretations[-1]\n",
    "    stat = parent.statistics\n",
    "    return {\n",
    "        f'{prefix} Feature': parent.index,\n",
    "        f'{prefix} Examples': len(parent.positive_examples),\n",
    "        f'{prefix} Mean activation': stat.mean,\n",
    "        f'{prefix} Token entropy': stat.token_entropy,\n",
    "        f'{prefix} Frequency': stat.frequency,\n",
    "        f'{prefix} Multitoken ratio': stat.multitoken_ratio,\n",
    "        f'{prefix} Interpretation': interpretation.value,\n",
    "        f'{prefix} Detection score': interpretation.score.get('detection'),\n",
    "        f'{prefix} Fuzzing score': interpretation.score.get('fuzzing'),\n",
    "        f'{prefix} Mean score': get_mean_score(interpretation.score),\n",
    "    }\n",
    "\n",
    "m, n = 4, 4\n",
    "results = {'topk': [], 'post': [], 'pre': []}\n",
    "\n",
    "# NOTE(review): `topk` is not loaded anywhere in this notebook; TODO confirm its source.\n",
    "# Order matters: process_feature fills in missing token statistics in place.\n",
    "for name, features in (('topk', topk), ('post', post), ('pre', pre)):\n",
    "    for feature in features.values():\n",
    "        record = process_feature(feature)\n",
    "\n",
    "        if name == 'post':\n",
    "            # Locate the two pre-activation parents of this (head, row, column) cell\n",
    "            h, r, c = post_act_to_triplet(feature.index, m, n)\n",
    "            row_pre_idx, col_pre_idx = triplet_to_pre_act(h, r, c, m, n)\n",
    "            record.update({'H': h, 'R': r, 'C': c})\n",
    "            record.update(get_parent_properties(pre.get(row_pre_idx), 'Row'))\n",
    "            record.update(get_parent_properties(pre.get(col_pre_idx), 'Col'))\n",
    "        elif name == 'pre':\n",
    "            record['H'] = feature.index // (m + n)\n",
    "\n",
    "        results[name].append(record)\n",
    "\n",
    "# One dataframe per SAE, each also written to CSV\n",
    "topk_df, post_df, pre_df = (pd.DataFrame(results[key]) for key in ('topk', 'post', 'pre'))\n",
    "for frame, csv_path in ((topk_df, 'topk_features.csv'),\n",
    "                        (post_df, 'post_features.csv'),\n",
    "                        (pre_df, 'pre_features.csv')):\n",
    "    frame.to_csv(csv_path, index=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ef78ef8f-c9ab-4eb2-9ac9-c44b095a7a6f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Properties whose pairwise correlations are compared across SAE variants\n",
    "corr_cols = ['Mean score', 'Detection score', 'Fuzzing score',\n",
    "             'Mean activation', 'Token entropy', 'Multitoken ratio', 'Frequency']\n",
    "sae_frames = [('TopK', topk_df), ('KronSAE Post', post_df), ('KronSAE Pre', pre_df)]\n",
    "corr_methods = ['pearson', 'spearman']\n",
    "\n",
    "# Grid: one row per correlation method, one column per SAE\n",
    "fig, axes = plt.subplots(2, 3, figsize=(12, 8), sharex=True, sharey=True, dpi=250)\n",
    "\n",
    "for sae_pos, (sae_type, data) in enumerate(sae_frames):\n",
    "    sae_data = data[corr_cols]\n",
    "    for method_pos, method in enumerate(corr_methods):\n",
    "        ax = axes[method_pos, sae_pos]\n",
    "        sns.heatmap(sae_data.corr(method=method), annot=True, fmt='.2f', cmap='coolwarm',\n",
    "                    vmin=-1, vmax=1, ax=ax, cbar=False)\n",
    "        ax.set_title(f'{sae_type} - {method.capitalize()}')\n",
    "\n",
    "plt.tight_layout()\n",
    "for ext in ('png', 'pdf'):\n",
    "    plt.savefig(f'results/correlation.{ext}')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8914eac6-cef5-455a-940a-2216e28f79a5",
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "from scipy.stats import pearsonr\n",
    "\n",
    "def plot_feature_parent_correlations(post_df, properties=None):\n",
    "    \"\"\"\n",
    "    Plot bar charts comparing correlations with different parent relationships for each property.\n",
    "    \n",
    "    Args:\n",
    "        post_df: DataFrame containing post-activation features with parent properties\n",
    "        properties: Ordered list of properties to analyze (plotted in this exact order)\n",
    "    \n",
    "    Returns:\n",
    "        matplotlib.figure.Figure: The generated figure\n",
    "    \"\"\"\n",
    "    if properties is None:\n",
    "        properties = [\n",
    "            'Mean score',\n",
    "            'Token entropy',\n",
    "            'Multitoken ratio',\n",
    "            'Mean activation', \n",
    "            'Frequency',\n",
    "        ]\n",
    "    \n",
    "    # Compute correlations for all relationships\n",
    "    corr_data = []\n",
    "    for prop in properties:  # Maintain original order\n",
    "        # Get relevant data\n",
    "        feat = post_df[prop].dropna()\n",
    "        row = post_df[f'Row {prop}'].dropna()\n",
    "        col = post_df[f'Col {prop}'].dropna()\n",
    "        \n",
    "        # Find common valid indices\n",
    "        common_idx = feat.index.intersection(row.index).intersection(col.index)\n",
    "        if len(common_idx) < 2:\n",
    "            continue\n",
    "        \n",
    "        # Calculate derived values\n",
    "        feat_vals = feat.loc[common_idx]\n",
    "        row_vals = row.loc[common_idx]\n",
    "        col_vals = col.loc[common_idx]\n",
    "        mean_vals = (row_vals + col_vals) / 2\n",
    "        prod_vals = row_vals * col_vals\n",
    "        prod_sqrt = np.sqrt(row_vals * col_vals)\n",
    "        \n",
    "        # Compute correlations\n",
    "        correlations = {\n",
    "            'Property': prop,\n",
    "            'Row': pearsonr(feat_vals, row_vals)[0],\n",
    "            'Column': pearsonr(feat_vals, col_vals)[0],\n",
    "            'Mean': pearsonr(feat_vals, mean_vals)[0],\n",
    "            'Product': pearsonr(feat_vals, prod_vals)[0],\n",
    "            'mAND': pearsonr(feat_vals, prod_sqrt)[0],\n",
    "        }\n",
    "        corr_data.append(correlations)\n",
    "    \n",
    "    # Create DataFrame and maintain property order\n",
    "    corr_df = pd.DataFrame(corr_data)\n",
    "    corr_df['Property'] = pd.Categorical(corr_df['Property'], categories=properties, ordered=True)\n",
    "    corr_df = corr_df.sort_values('Property')\n",
    "    corr_df = corr_df.melt(id_vars='Property', var_name='Type', value_name='Correlation')\n",
    "\n",
    "    # Create plot\n",
    "    n_props = len(corr_df['Property'].unique())\n",
    "    fig, axs = plt.subplots(1, n_props, figsize=(12, 2.5), sharey=True, sharex=True, dpi=250)    \n",
    "    \n",
    "    # Plot each property in specified order\n",
    "    for i, prop in enumerate(properties):\n",
    "        if i >= n_props:  # In case some properties were skipped\n",
    "            continue\n",
    "            \n",
    "        ax = axs[i]\n",
    "        group = corr_df[corr_df['Property'] == prop]\n",
    "        \n",
    "        sns.barplot(data=group, x='Type', y='Correlation', hue='Type', legend=i==0,\n",
    "                    palette='Set2', ax=ax, order=['Row', 'Column', 'Mean', 'Product', 'mAND'])\n",
    "\n",
    "        # Add plot decorations\n",
    "        ax.set_title(prop, fontsize=12)\n",
    "        ax.set_xlabel('')\n",
    "        ax.set_ylabel('Pearson r' if i == 0 else '')\n",
    "        ax.set_xticklabels('')\n",
    "        ax.axhline(0, color='black', linewidth=0.8)\n",
    "        ax.set_ylim(0.2, 1)\n",
    "        ax.grid(axis='y', alpha=0.35)\n",
    "        \n",
    "        # Add correlation values as text labels\n",
    "        for p in ax.patches:\n",
    "            ax.annotate(f\"{p.get_height():.2f}\", \n",
    "                       (p.get_x() + p.get_width() / 2., p.get_height()), \n",
    "                       ha='center', va='center', \n",
    "                       xytext=(0, 10), \n",
    "                       textcoords='offset points',\n",
    "                       fontsize=10)\n",
    "\n",
    "    axs[0].legend(loc='upper left', ncol=1)\n",
    "    plt.tight_layout()\n",
    "    return fig\n",
    "\n",
    "fig = plot_feature_parent_correlations(post_df)\n",
    "for ext in ('png', 'pdf'):\n",
    "    fig.savefig(f'results/correlation with parents.{ext}')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "454e8530-68b9-47f3-a4fb-276883d1d411",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Compare the distribution of each feature property across the SAE variants\n",
    "properties = [\n",
    "    'Detection score', 'Fuzzing score', 'Mean activation',\n",
    "    'Token entropy', 'Multitoken ratio', 'Frequency'\n",
    "]\n",
    "sae_order = ['TopK', 'KronSAE Pre', 'KronSAE Post']\n",
    "log_x_props = {'Frequency', 'Mean activation'}\n",
    "\n",
    "fig, axes = plt.subplots(nrows=2, ncols=3, figsize=(12, 4), sharex=False, sharey=True, dpi=250)\n",
    "axes = axes.flatten()\n",
    "\n",
    "# Long-format table: the available property columns of every SAE, tagged by SAE name\n",
    "combined_data = pd.concat(\n",
    "    [frame[[p for p in properties if p in frame.columns]].assign(SAE=label)\n",
    "     for frame, label in zip([topk_df, pre_df, post_df], sae_order)],\n",
    "    ignore_index=True,\n",
    ")\n",
    "\n",
    "for ax, prop in zip(axes, properties):\n",
    "    if prop not in combined_data.columns:\n",
    "        continue  # property missing from every SAE table\n",
    "    sns.boxplot(\n",
    "        data=combined_data, y=\"SAE\", x=prop, ax=ax, hue=\"SAE\", palette='Set2',\n",
    "        showfliers=False, orient='horizontal', order=sae_order,\n",
    "    )\n",
    "    ax.grid(alpha=0.5, axis='x')\n",
    "    ax.set_title(prop)\n",
    "    ax.set_ylabel('')\n",
    "    ax.set_xlabel('')\n",
    "    if prop in log_x_props:\n",
    "        ax.set_xscale('log')\n",
    "\n",
    "plt.tight_layout()\n",
    "for ext in ('png', 'pdf'):\n",
    "    plt.savefig(f'results/distribution.{ext}')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3039428d-4cf0-457c-94d2-da65997c1ae3",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import pandas as pd\n",
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns\n",
    "from matplotlib.lines import Line2D\n",
    "\n",
    "# Per-property density histograms for each SAE, with dashed lines at the medians\n",
    "properties = [\n",
    "    'Detection score', 'Fuzzing score', 'Mean activation',\n",
    "    'Token entropy', 'Multitoken ratio', 'Frequency',\n",
    "]\n",
    "LOG_SCALED_PROPS = {'Frequency', 'Mean activation'}\n",
    "HUE_ORDER = ['TopK', 'KronSAE Pre', 'KronSAE Post']\n",
    "palette = {'TopK': '#66c2a5', 'KronSAE Pre': '#fc8d62', 'KronSAE Post': '#8da0cb'}\n",
    "models = [\n",
    "    ('TopK', topk_df.reset_index(drop=True)),\n",
    "    ('KronSAE Pre', pre_df.reset_index(drop=True)),\n",
    "    ('KronSAE Post', post_df.reset_index(drop=True)),\n",
    "]\n",
    "\n",
    "\n",
    "def _median_bin_density(values, bins, log_scale=False):\n",
    "    \"\"\"Height of the density-histogram bar containing the median of `values`.\n",
    "\n",
    "    `bins` are bin edges in data coordinates. Meant to match seaborn's\n",
    "    stat='density': on a log-scaled axis seaborn bins and normalises in log10\n",
    "    space, so bar widths are measured in log10 units there.\n",
    "    \"\"\"\n",
    "    values = np.asarray(values, dtype=float)\n",
    "    bins = np.asarray(bins, dtype=float)\n",
    "    counts, _ = np.histogram(values, bins=bins)\n",
    "    widths = np.diff(np.log10(bins) if log_scale else bins)\n",
    "    density = counts / (len(values) * widths)\n",
    "    bin_index = np.searchsorted(bins, np.median(values), side='right') - 1\n",
    "    return density[np.clip(bin_index, 0, len(density) - 1)]\n",
    "\n",
    "\n",
    "def _drawn_bin_edges(ax):\n",
    "    \"\"\"Bin edges of the first filled histogram polygon on `ax` (empty array if none).\"\"\"\n",
    "    for collection in ax.collections:\n",
    "        paths = collection.get_paths() if hasattr(collection, 'get_paths') else []\n",
    "        if paths:\n",
    "            return np.unique(paths[0].vertices[:, 0])\n",
    "    return np.array([])\n",
    "\n",
    "\n",
    "def _fallback_bin_edges(values, log_scale, n_edges=25):\n",
    "    \"\"\"Evenly (or log-evenly) spaced edges spanning `values`, used if none can be read off the plot.\"\"\"\n",
    "    values = values.dropna()\n",
    "    if log_scale:\n",
    "        values = values[values > 0]  # log-spaced edges need strictly positive values\n",
    "        if len(values) == 0:\n",
    "            return np.array([])\n",
    "        return np.logspace(np.log10(values.min()), np.log10(values.max()), n_edges)\n",
    "    return np.linspace(values.min(), values.max(), n_edges)\n",
    "\n",
    "\n",
    "fig, axs = plt.subplots(2, 3, figsize=(15, 5), dpi=250, layout='constrained')\n",
    "axs = axs.flatten()\n",
    "legend_handles = []\n",
    "\n",
    "for i, (ax, prop) in enumerate(zip(axs, properties)):\n",
    "    log_scale = prop in LOG_SCALED_PROPS\n",
    "    combined_prop_df = pd.concat(\n",
    "        [df[[prop]].rename(columns={prop: 'Value'}).assign(Model=model_name)\n",
    "         for model_name, df in models if prop in df.columns],\n",
    "        ignore_index=True,\n",
    "    )\n",
    "\n",
    "    plot = sns.histplot(\n",
    "        data=combined_prop_df, x='Value', hue='Model', ax=ax, palette=palette,\n",
    "        fill=True, log_scale=log_scale, bins=30, linewidth=1.5, element='step',\n",
    "        common_norm=False, stat='density', alpha=0.075, hue_order=HUE_ORDER,\n",
    "        legend=i == 0,\n",
    "    )\n",
    "    ax.set(xlabel='', ylabel='', title=prop)\n",
    "    ax.grid(True, alpha=0.3)\n",
    "    if i == 0:\n",
    "        legend_handles = list(plot.legend_.legend_handles)\n",
    "\n",
    "    bins = _drawn_bin_edges(ax)\n",
    "    if len(bins) == 0:\n",
    "        bins = _fallback_bin_edges(combined_prop_df['Value'], log_scale)\n",
    "    if len(bins) < 2:\n",
    "        continue  # no usable bins, so no median markers for this panel\n",
    "\n",
    "    # Each median line reaches the height of the histogram bar that contains the\n",
    "    # median, computed in the same (log10 or linear) space seaborn normalised in,\n",
    "    # rather than from hand-tuned per-model heights.\n",
    "    y_max_axis = ax.get_ylim()[1]\n",
    "    for model_name in HUE_ORDER:\n",
    "        model_data = combined_prop_df.loc[combined_prop_df['Model'] == model_name, 'Value'].dropna()\n",
    "        if len(model_data) == 0:\n",
    "            continue\n",
    "        y_value = _median_bin_density(model_data, bins, log_scale)\n",
    "        ax.axvline(\n",
    "            x=model_data.median(),\n",
    "            ymax=min(y_value / y_max_axis, 1.0) if y_max_axis > 0 else 0,\n",
    "            color=palette[model_name],\n",
    "            linestyle='--',\n",
    "            linewidth=2,\n",
    "            alpha=1,\n",
    "            zorder=3,  # draw above the histogram\n",
    "        )\n",
    "\n",
    "median_line_handle = Line2D([0], [0], color='black', linestyle='--', linewidth=1.5, label='Median')\n",
    "axs[0].legend(\n",
    "    handles=legend_handles + [median_line_handle],\n",
    "    labels=[*HUE_ORDER, 'Median'],  # same plain-ASCII names as the plotted hues\n",
    "    loc='upper left',\n",
    "    ncol=1,\n",
    "    frameon=True,\n",
    ")\n",
    "\n",
    "plt.savefig('results/distribution hist.png', bbox_inches='tight', dpi=250)\n",
    "plt.savefig('results/distribution hist.pdf', bbox_inches='tight', dpi=250)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b6422c89-9cd0-4fb5-9e29-cdefa58361f9",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Per post-activation feature: its interpretation and scores next to those of its\n",
    "# row and column parents, written to results/interpretations.json.\n",
    "# Keys are str(index) so the in-memory dict matches the JSON file (JSON object keys\n",
    "# are always strings) and the `result[str(idx)]` lookups in later cells.\n",
    "result = {}\n",
    "\n",
    "for index, post_feature in post.items():\n",
    "    h, r, c = post_act_to_triplet(index, m, n)\n",
    "    row_pre_idx, col_pre_idx = triplet_to_pre_act(h, r, c, m, n)\n",
    "\n",
    "    row_interp = pre[row_pre_idx].interpretations[-1]\n",
    "    col_interp = pre[col_pre_idx].interpretations[-1]\n",
    "    post_interp = post_feature.interpretations[-1]\n",
    "\n",
    "    # .get -> None for a missing score (as in process_feature); nanmean treats None as NaN\n",
    "    row_det, row_fuz = row_interp.score.get('detection'), row_interp.score.get('fuzzing')\n",
    "    col_det, col_fuz = col_interp.score.get('detection'), col_interp.score.get('fuzzing')\n",
    "    post_det, post_fuz = post_interp.score.get('detection'), post_interp.score.get('fuzzing')\n",
    "\n",
    "    mean_det = nanmean([row_det, col_det, post_det])\n",
    "    mean_fuz = nanmean([row_fuz, col_fuz, post_fuz])\n",
    "\n",
    "    result[str(index)] = {\n",
    "        'h': h, 'r': r, 'c': c,\n",
    "        'interpretation': post_interp.value,\n",
    "        'det': post_det, 'fuz': post_fuz,\n",
    "        'mean score': nanmean([post_det, post_fuz]),\n",
    "        'row': {\n",
    "            'interpretation': row_interp.value,\n",
    "            'det': row_det, 'fuz': row_fuz,\n",
    "            'mean score': nanmean([row_det, row_fuz])\n",
    "        },\n",
    "        'col': {\n",
    "            'interpretation': col_interp.value,\n",
    "            'det': col_det, 'fuz': col_fuz,\n",
    "            'mean score': nanmean([col_det, col_fuz])\n",
    "        },\n",
    "        'mean total': nanmean([mean_det, mean_fuz]),\n",
    "        'mean det': mean_det, 'mean fuz': mean_fuz,\n",
    "    }\n",
    "\n",
    "with open('results/interpretations.json', 'w') as f:\n",
    "    json.dump(result, f, indent=4)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ef097530-3049-475b-aed2-90dcc886d348",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Print the interpretations of every row and column factor of one random head\n",
    "# NOTE(review): 192 is presumably the number of heads; confirm (e.g. len(pre) // (m + n)).\n",
    "head = np.random.choice(192, 1)[0]\n",
    "base = head * (m + n)\n",
    "for row in range(m):\n",
    "    i = base + row\n",
    "    interp = pre[i].interpretations[-1]\n",
    "    print(f'(h{head}, r{row}) - ({nanmean(interp.score.values()):.2f}): {interp.value}')\n",
    "for col in range(n):\n",
    "    i = base + m + col\n",
    "    interp = pre[i].interpretations[-1]\n",
    "    print(f'(h{head}, c{col}) - ({nanmean(interp.score.values()):.2f}): {interp.value}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6adebd29-44db-447f-93b2-4343c2176ec6",
   "metadata": {},
   "outputs": [],
   "source": [
    "m, n = 4, 4\n",
    "head, row = 23, 3  # (18, 3) was also inspected\n",
    "# Row factor first, then each column factor followed by the post-activation they form\n",
    "row_idx = triplet_to_pre_act(head, row, 0, m, n)[0]\n",
    "interp = pre[row_idx].interpretations[-1]\n",
    "print(f'(h{head}, r{row}) - ({nanmean(interp.score.values()):.2f}) - {interp.value}\\n')\n",
    "for col, col_idx in columns_for_row(head, row, m, n):\n",
    "    interp = pre[col_idx].interpretations[-1]\n",
    "    print(f'(h{head}, c{col}) - ({nanmean(interp.score.values()):.2f}) - {interp.value}')\n",
    "    postact_idx = pre_to_post(head, row, col, m, n)\n",
    "    interp = post[postact_idx].interpretations[-1]\n",
    "    print(f'(h{head}, r{row}, c{col}) - ({nanmean(interp.score.values()):.2f}) - {interp.value}\\n')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8de41477",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Post-activation index of (head 23, row 3, column 2)\n",
    "pre_to_post(head=23, row=3, column=2, m=4, n=4)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "89b0b2f1",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Pick a random existing key: the feature dict need not be keyed 0..len(post)-1\n",
    "post[list(post)[np.random.randint(len(post))]]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "35f7d115",
   "metadata": {},
   "outputs": [],
   "source": [
    "# (row, column) pre-activation indices for head 23, row 3, column 0\n",
    "triplet_to_pre_act(head_idx=23, row_idx=3, col_idx=0, m=4, n=4)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cc8a5a48",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Bind `row` explicitly (pre-activation row factor of head 23, row 3) instead of\n",
    "# relying on a later cell having assigned it\n",
    "row = pre[triplet_to_pre_act(23, 3, 0, 4, 4)[0]]\n",
    "print(row.show(\n",
    "    tokenizer=tokenizer,\n",
    "    context_window=15,\n",
    "    max_example_length=100,\n",
    "    examples_per_quantile=35\n",
    "))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d5f81df1",
   "metadata": {},
   "outputs": [],
   "source": [
    "# (row, column) pre-activation indices for head 23, row 3, column 1\n",
    "triplet_to_pre_act(head_idx=23, row_idx=3, col_idx=1, m=4, n=4)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8d5beaa4",
   "metadata": {},
   "outputs": [],
   "source": [
    "r, c = triplet_to_pre_act(23, 3, 1, 4, 4)\n",
    "row, col = pre[r], pre[c]\n",
    "\n",
    "\n",
    "def _activation_values(feature):\n",
    "    \"\"\"All activation values of a feature's positive examples, concatenated in order.\"\"\"\n",
    "    return [value for ex in feature.positive_examples for value in ex.activation_values]\n",
    "\n",
    "\n",
    "row_acts = _activation_values(row)\n",
    "col_acts = _activation_values(col)\n",
    "# Pre-activation 191 = head 23, column 3 (23 * (4 + 4) + 4 + 3), shown for comparison\n",
    "third = _activation_values(pre[191])\n",
    "\n",
    "hist_style = dict(alpha=0.0, bins=50, stat='density', element='step', fill=True, lw=1.5, log_scale=True)\n",
    "fig = plt.figure(figsize=(6, 4), dpi=250)\n",
    "for values, label in ((row_acts, r), (col_acts, c), (third, \"191\")):\n",
    "    sns.histplot(values, label=label, **hist_style)\n",
    "\n",
    "plt.grid(axis='y', alpha=0.35)\n",
    "plt.legend(loc='upper left')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7f6d0e40",
   "metadata": {},
   "outputs": [],
   "source": [
    "def _result_entry(idx):\n",
    "    \"\"\"Look up post-activation feature `idx` in `result`.\n",
    "\n",
    "    `result` is keyed by str(idx) after reloading from JSON; fall back to the\n",
    "    raw index for a dict built in this session with integer keys.\n",
    "    \"\"\"\n",
    "    key = str(idx)\n",
    "    return result[key] if key in result else result[idx]\n",
    "\n",
    "\n",
    "def return_rows_and_cols(head: int, m=4, n=4):\n",
    "    \"\"\"Interpretations and mean scores of the row and column parents of `head`.\n",
    "\n",
    "    Returns (row_interp, col_interp, row_mean, col_mean) with exactly m row\n",
    "    entries and n column entries, ordered by index. Row r is read from the\n",
    "    post-activation (r, 0), column c from (0, c).\n",
    "    \"\"\"\n",
    "    # Index parents by position rather than de-duplicating by interpretation text:\n",
    "    # two parents with identical text would otherwise collapse into one entry and\n",
    "    # misalign the table built by return_df. pre_to_post also honours m and n,\n",
    "    # unlike a fixed head * 16 offset.\n",
    "    row_parents = [_result_entry(pre_to_post(head, r, 0, m, n))['row'] for r in range(m)]\n",
    "    col_parents = [_result_entry(pre_to_post(head, 0, c, m, n))['col'] for c in range(n)]\n",
    "    row_interp = [parent['interpretation'] for parent in row_parents]\n",
    "    col_interp = [parent['interpretation'] for parent in col_parents]\n",
    "    row_mean = [parent['mean score'] for parent in row_parents]\n",
    "    col_mean = [parent['mean score'] for parent in col_parents]\n",
    "    return row_interp, col_interp, row_mean, col_mean\n",
    "\n",
    "\n",
    "def return_row_interactions(head: int, row: int, m=4, n=4):\n",
    "    \"\"\"Interpretations and mean scores of the post-activations (head, row, c) for c = 0..n-1.\"\"\"\n",
    "    interpretations, scores = [], []\n",
    "    for col_idx, _ in columns_for_row(head, row, m, n):\n",
    "        target_ftr = _result_entry(pre_to_post(head, row, col_idx, m, n))\n",
    "        interpretations.append(target_ftr['interpretation'])\n",
    "        scores.append(target_ftr['mean score'])\n",
    "    return interpretations, scores\n",
    "\n",
    "\n",
    "import pandas as pd\n",
    "\n",
    "pd.set_option('display.max_colwidth', 500)\n",
    "\n",
    "def return_df(head, m=4, n=4):\n",
    "    \"\"\"Build the (m + 1) x (n + 1) interpretation grid of `head` and a matching score grid.\n",
    "\n",
    "    Label -1 holds the parent factors (column parents along the top row, row\n",
    "    parents down the first column); cell (r, c) holds post-activation (head, r, c).\n",
    "    \"\"\"\n",
    "    row_interp, col_interp, row_mean, col_mean = return_rows_and_cols(head, m, n)\n",
    "    col_labels = [-1, *range(n)]\n",
    "    row_labels = [-1, *range(m)]\n",
    "\n",
    "    interp_rows = [[' ', *col_interp]]\n",
    "    score_rows = [[0, *col_mean]]\n",
    "    for row_idx in range(m):\n",
    "        interpretations, scores = return_row_interactions(head, row_idx, m, n)\n",
    "        interp_rows.append([row_interp[row_idx], *interpretations])\n",
    "        score_rows.append([row_mean[row_idx], *scores])\n",
    "\n",
    "    # Build each frame once instead of concatenating row by row\n",
    "    df = pd.DataFrame(interp_rows, index=row_labels, columns=col_labels)\n",
    "    scores_df = pd.DataFrame(score_rows, index=row_labels, columns=col_labels)\n",
    "    return df, scores_df\n",
    "\n",
    "def color_cells(s, cmap):\n",
    "    \"\"\"\n",
    "    Map a scalar value `s` (already in [0, 1]) to a text color from the given colormap.\n",
    "    \"\"\"\n",
    "    # Get the color from the colormap\n",
    "    color = cmap(s)\n",
    "    # Convert the RGBA color to a hex string\n",
    "    hex_color = \"#{:02x}{:02x}{:02x}\".format(\n",
    "        int(color[0] * 255), int(color[1] * 255), int(color[2] * 255)\n",
    "    )\n",
    "    return f\"color: {hex_color}; font-weight: bold;\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ab234c17",
   "metadata": {},
   "outputs": [],
   "source": [
    "df, scores_df = return_df(177)\n",
    "cmap = sns.color_palette(\"viridis\", as_cmap=True)\n",
    "# DataFrame.applymap is deprecated since pandas 2.1 in favour of DataFrame.map\n",
    "elementwise = scores_df.map if hasattr(scores_df, 'map') else scores_df.applymap\n",
    "# Color each interpretation cell by the mean score at the same position\n",
    "df.style.apply(\n",
    "    lambda _: elementwise(lambda s: color_cells(s, cmap)),\n",
    "    axis=None\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8faa7085",
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE(review): repeats the tokenizer-loading cell near the top of the notebook; re-running it reloads the tokenizer.\n",
    "from transformers import AutoTokenizer\n",
    "tokenizer = AutoTokenizer.from_pretrained('Qwen/Qwen2.5-1.5B')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4acf4c2b",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pickle\n",
    "\n",
    "# NOTE(review): the first loading cell reads 'post.pkl' (no 'results/' prefix); confirm which path is current.\n",
    "# NOTE(review): pickle.load can execute arbitrary code; only load trusted files.\n",
    "with open('results/post.pkl', 'rb') as f:\n",
    "    features = pickle.load(f)\n",
    "\n",
    "# Freshly unpickled features may lack token statistics (process_feature handles\n",
    "# the same case); fill them in so the sort key is never None.\n",
    "for ftr in features.values():\n",
    "    if ftr.statistics.token_entropy is None:\n",
    "        ftr.compute_token_statistics()\n",
    "features = sorted(features.values(), key=lambda x: x.statistics.token_entropy)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f121361f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Features with more than 35 positive examples, keeping the token-entropy order\n",
    "features_with_many_examples = list(filter(lambda ftr: len(ftr.positive_examples) > 35, features))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "70b40d81",
   "metadata": {},
   "outputs": [],
   "source": [
    "index = 0\n",
    "feature = features_with_many_examples[index]\n",
    "\n",
    "# Last interpretation attached to the feature (presumably the latest run) and its score.\n",
    "interpretation = feature.interpretations[-1]\n",
    "print(interpretation.score)\n",
    "\n",
    "show_kwargs = dict(tokenizer=tokenizer, context_window=15, max_example_length=100, examples_per_quantile=7)\n",
    "print(feature.show(**show_kwargs))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "719605e5",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Number of positive examples of post-activation feature 2178.\n",
    "# NOTE(review): assumes `post` (feature objects indexed by post-activation id) was loaded in an earlier cell -- confirm.\n",
    "len(post[2178].positive_examples)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a4a22b8e",
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "\n",
    "# Interpretation records, presumably keyed by post-activation index as a string.\n",
    "# Read as UTF-8 explicitly so the result does not depend on the platform's default encoding.\n",
    "with open('results/interpretations.json', 'r', encoding='utf-8') as f:\n",
    "    result = json.load(f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "27c4a3fa",
   "metadata": {},
   "outputs": [],
   "source": [
    "m, n = 4, 4\n",
    "head, row = 177, 2  # another example worth inspecting: head 18, row 3\n",
    "\n",
    "# Row pre-activation interpretation, then for every column of the head the column\n",
    "# pre-activation and the (row, column) post-activation interpretations.\n",
    "row_interp = pre[head * (m + n) + row].interpretations[-1]\n",
    "print(f\"(h{head}, r{row}) - ({nanmean(row_interp.score.values()):.2f}) - {row_interp.value}\\n\")\n",
    "for col, col_pre_idx in columns_for_row(head, row, m, n):\n",
    "    col_interp = pre[col_pre_idx].interpretations[-1]\n",
    "    print(f\"(h{head}, c{col}) - ({nanmean(col_interp.score.values()):.2f}) - {col_interp.value}\")\n",
    "    cell_interp = post[pre_to_post(head, row, col, m, n)].interpretations[-1]\n",
    "    print(f\"(h{head}, r{row}, c{col}) - ({nanmean(cell_interp.score.values()):.2f}) - {cell_interp.value}\\n\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0ec76243",
   "metadata": {},
   "outputs": [],
   "source": [
    "m, n = 4, 4\n",
    "head, row = 23, 2  # another example worth inspecting: head 18, row 3\n",
    "\n",
    "# Row pre-activation interpretation, then for every column of the head the column\n",
    "# pre-activation and the (row, column) post-activation interpretations.\n",
    "row_interp = pre[head * (m + n) + row].interpretations[-1]\n",
    "print(f\"(h{head}, r{row}) - ({nanmean(row_interp.score.values()):.2f}) - {row_interp.value}\\n\")\n",
    "for col, col_pre_idx in columns_for_row(head, row, m, n):\n",
    "    col_interp = pre[col_pre_idx].interpretations[-1]\n",
    "    print(f\"(h{head}, c{col}) - ({nanmean(col_interp.score.values()):.2f}) - {col_interp.value}\")\n",
    "    cell_interp = post[pre_to_post(head, row, col, m, n)].interpretations[-1]\n",
    "    print(f\"(h{head}, r{row}, c{col}) - ({nanmean(cell_interp.score.values()):.2f}) - {cell_interp.value}\\n\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6b818f9a",
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "import numpy as np\n",
    "\n",
    "# NOTE(review): an earlier cell reads 'results/interpretations.json'; confirm which file is intended.\n",
    "# Read as UTF-8 explicitly so the result does not depend on the platform's default encoding.\n",
    "with open('interpretations.json', 'r', encoding='utf-8') as f:\n",
    "    result = json.load(f)\n",
    "\n",
    "def print_row_interactions(head, row, m=4, n=4):\n",
    "    \"\"\"Print a row's interpretation, then each column and (row, column) interaction of `head`.\n",
    "\n",
    "    Records are looked up in the global `result` dict by post-activation index (as a\n",
    "    string); each record carries 'mean score' / 'interpretation' plus 'row' and 'col'\n",
    "    parent records. `m`/`n` give the per-head grid size (default 4 x 4, the m4 n4 SAE).\n",
    "    \"\"\"\n",
    "    row_parent = result[str(pre_to_post(head, row, 0, m, n))]['row']\n",
    "    print(f\"(h{head}, r{row}) - ({row_parent['mean score']:.2f}) - {row_parent['interpretation']}\\n\")\n",
    "    for col, _ in columns_for_row(head, row, m, n):\n",
    "        target_ftr = result[str(pre_to_post(head, row, col, m, n))]\n",
    "        col_parent = target_ftr['col']\n",
    "\n",
    "        print(f\"(h{head}, c{col}) - ({col_parent['mean score']:.2f}) - {col_parent['interpretation']}\")\n",
    "        print(f\"(h{head}, r{row}, c{col}) - ({target_ftr['mean score']:.2f}) - {target_ftr['interpretation']}\\n\")\n",
    "\n",
    "def print_col_interactions(head, col, m=4, n=4):\n",
    "    \"\"\"Print a column's interpretation, then each row and (row, column) interaction of `head`.\n",
    "\n",
    "    Records are looked up in the global `result` dict by post-activation index (as a\n",
    "    string); each record carries 'mean score' / 'interpretation' plus 'row' and 'col'\n",
    "    parent records. `m`/`n` give the per-head grid size (default 4 x 4, the m4 n4 SAE).\n",
    "    \"\"\"\n",
    "    col_parent = result[str(pre_to_post(head, 0, col, m, n))]['col']\n",
    "    print(f\"(h{head}, c{col}) - ({col_parent['mean score']:.2f}) - {col_parent['interpretation']}\\n\")\n",
    "\n",
    "    for row_idx, _ in rows_for_column(head, col, m, n):\n",
    "        target_ftr = result[str(pre_to_post(head, row_idx, col, m, n))]\n",
    "        row_parent = target_ftr['row']\n",
    "\n",
    "        print(f\"(h{head}, r{row_idx}) - ({row_parent['mean score']:.2f}) - {row_parent['interpretation']}\")\n",
    "        print(f\"(h{head}, r{row_idx}, c{col}) - ({target_ftr['mean score']:.2f}) - {target_ftr['interpretation']}\\n\")\n",
    "\n",
    "def print_rows_and_cols(head, m=4, n=4):\n",
    "    \"\"\"Print the distinct row and column interpretations of `head` (rows first).\n",
    "\n",
    "    The head's post-activations are the m * n consecutive keys starting at head * m * n\n",
    "    in the global `result` dict. `m`/`n` are explicit (default 4 x 4) so the start index\n",
    "    and the span always agree and never depend on notebook-global `m`/`n`.\n",
    "    \"\"\"\n",
    "    base = head * m * n\n",
    "\n",
    "    # dicts keep insertion order, so these act as ordered sets of lines to print.\n",
    "    rows, cols = {}, {}\n",
    "    for idx in range(base, base + m * n):\n",
    "        target_ftr = result[str(idx)]\n",
    "        row_parent, col_parent = target_ftr['row'], target_ftr['col']\n",
    "\n",
    "        rows[f\"(h{target_ftr['h']}, r{target_ftr['r']}) - ({row_parent['mean score']:.2f}): {row_parent['interpretation']}\"] = None\n",
    "        cols[f\"(h{target_ftr['h']}, c{target_ftr['c']}) - ({col_parent['mean score']:.2f}): {col_parent['interpretation']}\"] = None\n",
    "\n",
    "    for row_line in rows:\n",
    "        print(row_line)\n",
    "    print()\n",
    "    for col_line in cols:\n",
    "        print(col_line)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f140974f",
   "metadata": {},
   "outputs": [],
   "source": [
    "print_row_interactions(head=23, row=3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "934ca688",
   "metadata": {},
   "outputs": [],
   "source": [
    "m, n = 4, 4\n",
    "head, row = 23, 3  # (head, row) to inspect\n",
    "\n",
    "# Row pre-activation interpretation, then for every column of the head the column\n",
    "# pre-activation and the (row, column) post-activation interpretations.\n",
    "row_idx = head * (m + n) + row\n",
    "interp = pre[row_idx].interpretations[-1]\n",
    "print(f\"(h{head}, r{row}) - ({nanmean(interp.score.values()):.2f}) - {interp.value}\\n\")\n",
    "for col, col_idx in columns_for_row(head, row, m, n):\n",
    "    interp = pre[col_idx].interpretations[-1]\n",
    "    print(f\"(h{head}, c{col}) - ({nanmean(interp.score.values()):.2f}) - {interp.value}\")\n",
    "    postact_idx = pre_to_post(head, row, col, m, n)\n",
    "    interp = post[postact_idx].interpretations[-1]\n",
    "    print(f\"(h{head}, r{row}, c{col}) - ({nanmean(interp.score.values()):.2f}) - {interp.value}\\n\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ee0014e1-a3c6-4c5e-b6cd-a40cdb9480f5",
   "metadata": {},
   "outputs": [],
   "source": [
    "m, n = 4, 4\n",
    "head, col = 4, 3\n",
    "\n",
    "# Column pre-activation interpretation, then for every row of the head the row\n",
    "# pre-activation and the (row, column) post-activation interpretations.\n",
    "# np.nanmean keeps one missing (NaN) score from turning the whole mean into NaN.\n",
    "col_idx = head * (m + n) + m + col\n",
    "interp = pre[col_idx].interpretations[-1]\n",
    "print(f\"(h{head}, c{col}) - ({np.nanmean(list(interp.score.values())):.2f}): {interp.value}\\n\")\n",
    "for row, row_idx in rows_for_column(head, col, m, n):\n",
    "    interp = pre[row_idx].interpretations[-1]\n",
    "    print(f\"(h{head}, r{row}) ({np.nanmean(list(interp.score.values())):.2f}): {interp.value}\")\n",
    "    postact_idx = pre_to_post(head, row, col, m, n)\n",
    "    interp = post[postact_idx].interpretations[-1]\n",
    "    print(f\"(h{head}, r{row}, c{col}) - ({np.nanmean(list(interp.score.values())):.2f}) - {postact_idx}: {interp.value}\\n\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1eafda5d-1c3f-48cc-9219-1b3e62e56086",
   "metadata": {},
   "outputs": [],
   "source": [
    "fig, axs = plt.subplots(6, 1, figsize=(20, 10), dpi=250, sharex=True)\n",
    "\n",
    "# One box plot per head (x axis) for each feature property; two properties use a log y-scale.\n",
    "properties = [\"Detection score\", \"Fuzzing score\", \"Mean activation\",\n",
    "              \"Token entropy\", \"Multitoken ratio\", \"Frequency\"]\n",
    "log_scaled = {'Frequency', 'Mean activation'}\n",
    "for ax, prop in zip(axs, properties):\n",
    "    sns.boxplot(post_df, x=\"H\", y=prop, hue=\"H\", legend=False, palette='Set2', showfliers=False, ax=ax)\n",
    "    ax.grid(axis='y', alpha=0.35)\n",
    "    ax.set_xticklabels('')\n",
    "    if prop in log_scaled:\n",
    "        ax.set_yscale('log')\n",
    "\n",
    "plt.tight_layout()\n",
    "plt.savefig('results/heads comparison.png')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a55c6ade-407d-4aaa-95bf-a57be7d4b78d",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "28413a8c",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5a30b0c2",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9fab9874",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from src.utils import load_trained_sae\n",
    "from sklearn.manifold import TSNE\n",
    "\n",
    "def load_embs(name, root='trained'):\n",
    "    \"\"\"Load the encoder and decoder weight tensors of a trained SAE.\n",
    "\n",
    "    Reads `{root}/{name}/sae.pt` onto the CPU and returns `(W_enc, W_dec)`, both detached.\n",
    "    NOTE(review): callers transpose W_enc before use, suggesting it is stored as\n",
    "    (d_model, n_features) while W_dec is (n_features, d_model) -- confirm.\n",
    "    torch.load unpickles the file, so only load checkpoints you trust.\n",
    "    \"\"\"\n",
    "    weights = torch.load(f'{root}/{name}/sae.pt', map_location='cpu')\n",
    "    return weights['W_enc'].detach(), weights['W_dec'].detach()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a7a41b84",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f786025a",
   "metadata": {},
   "outputs": [],
   "source": [
    "def renyi_entropy(embs, alpha = 0.99):\n",
    "    \"\"\"Matrix-based Renyi entropy of order `alpha` of the rows of `embs`.\n",
    "\n",
    "    Uses the eigenvalues of the Gram matrix K = embs @ embs.T normalised by trace(K).\n",
    "    K is symmetric PSD, so `eigvalsh` is used (always real) and tiny negative round-off\n",
    "    eigenvalues are clipped to 0 -- a fractional power of them would otherwise be NaN.\n",
    "    For alpha == 1 the Shannon (von Neumann) limit is returned instead of 0 / 0.\n",
    "    \"\"\"\n",
    "    embs = np.asarray(embs, dtype=np.float64)\n",
    "    K = embs @ embs.T  # [F, F]\n",
    "    probs = np.clip(np.linalg.eigvalsh(K), 0.0, None) / np.trace(K)\n",
    "    if np.isclose(alpha, 1.0):\n",
    "        probs = probs[probs > 0]\n",
    "        return -np.sum(probs * np.log(probs))\n",
    "    summa = np.float_power(probs, alpha).sum()\n",
    "    return np.log(summa) / (1 - alpha)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4672e0e3",
   "metadata": {},
   "outputs": [],
   "source": [
    "def renyi_entropy_svd(embs, alpha=0.99):\n",
    "    \"\"\"Renyi entropy of order `alpha` of the rows of `embs`, computed via SVD.\n",
    "\n",
    "    The squared singular values of `embs` are the eigenvalues of embs @ embs.T, so this\n",
    "    matches the Gram-matrix formulation without forming the F x F matrix. Only the\n",
    "    singular values are needed (compute_uv=False). For alpha == 1 the Shannon limit is\n",
    "    returned instead of 0 / 0.\n",
    "    \"\"\"\n",
    "    singular_values = np.linalg.svd(embs, compute_uv=False)\n",
    "    eigvals = singular_values**2\n",
    "    normalized_eigvals = eigvals / np.sum(eigvals)\n",
    "\n",
    "    if np.isclose(alpha, 1.0):\n",
    "        nonzero = normalized_eigvals[normalized_eigvals > 0]\n",
    "        return -np.sum(nonzero * np.log(nonzero))\n",
    "\n",
    "    summa = np.sum(normalized_eigvals**alpha)\n",
    "    return np.log(summa) / (1 - alpha)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9ceccebb",
   "metadata": {},
   "outputs": [],
   "source": [
    "def compute_within_group(embs, m, n, alpha=0.99):\n",
    "    \"\"\"Renyi entropy of the embeddings inside every head, row and column group.\n",
    "\n",
    "    Row `i` of `embs` is treated as post-activation feature `i`; it is assigned to its\n",
    "    head and to the row / column pre-activation it is built from. Returns a dict\n",
    "    {'Head': [...], 'Row': [...], 'Column': [...]} holding one entropy per group, in\n",
    "    order of first appearance. `alpha` is forwarded to renyi_entropy_svd.\n",
    "    \"\"\"\n",
    "    records = []\n",
    "    for i in range(len(embs)):\n",
    "        h, r, c = post_act_to_triplet(i, m, n)\n",
    "        row_idx, col_idx = triplet_to_pre_act(h, r, c, m, n)\n",
    "        records.append({\"Index\": i, \"Head\": h, \"Row\": row_idx, \"Column\": col_idx})\n",
    "    membership = pd.DataFrame(records)\n",
    "\n",
    "    result = {}\n",
    "    for group_name in ('Head', 'Row', 'Column'):\n",
    "        # groupby(sort=False) keeps first-appearance order (like .unique()) without building\n",
    "        # a full boolean mask per group, which costs O(groups x features).\n",
    "        result[group_name] = [\n",
    "            renyi_entropy_svd(embs[indices.to_numpy()], alpha=alpha)\n",
    "            for _, indices in membership.groupby(group_name, sort=False)['Index']\n",
    "        ]\n",
    "    return result\n",
    "\n",
    "from prettytable import PrettyTable\n",
    "table = PrettyTable(['SAE', \"Encoder\", \"Decoder\", \"Heads\", \"Rows\", \"Columns\"])\n",
    "\n",
    "# One record per SAE: global encoder/decoder entropy plus mean/std of the decoder\n",
    "# entropy within heads, rows and columns (not defined for the unstructured TopK SAE).\n",
    "data = []\n",
    "\n",
    "for name in ['topk', 'm4 n4', 'm4 n8', 'm8 n8', 'm8 n16', 'm16 n16']:\n",
    "    enc, dec = load_embs(name)\n",
    "    enc_entropy = np.round(renyi_entropy_svd(enc.T), 3)\n",
    "    dec_entropy = np.round(renyi_entropy_svd(dec), 3)\n",
    "\n",
    "    if name == \"topk\":\n",
    "        head_stats = row_stats = col_stats = (None, None)\n",
    "        group_cells = ['', '', '']\n",
    "    else:\n",
    "        m, n = [int(el[1:]) for el in name.split(' ')]\n",
    "        within = compute_within_group(dec, m, n)\n",
    "        head_stats, row_stats, col_stats = [\n",
    "            (np.mean(within[key]), np.std(within[key])) for key in ('Head', 'Row', 'Column')\n",
    "        ]\n",
    "        group_cells = [f\"{mu:.3f} +- {sd:.3f}\" for mu, sd in (head_stats, row_stats, col_stats)]\n",
    "\n",
    "    data.append({\n",
    "        'SAE': name,\n",
    "        'Encoder entropy': enc_entropy,\n",
    "        'Decoder entropy': dec_entropy,\n",
    "        'Within head mean': head_stats[0],\n",
    "        'Within head std': head_stats[1],\n",
    "        'Within row mean': row_stats[0],\n",
    "        'Within row std': row_stats[1],\n",
    "        'Within col mean': col_stats[0],\n",
    "        'Within col std': col_stats[1]\n",
    "    })\n",
    "    table.add_row([name, enc_entropy, dec_entropy, *group_cells])\n",
    "\n",
    "data = pd.DataFrame(data)\n",
    "print(table)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3e9db00d",
   "metadata": {},
   "outputs": [],
   "source": [
    "enc, dec = load_embs(name=\"m4 n4\")  # encoder / decoder weights of the m4 n4 SAE"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6ca0f138",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Imported here because the cell that otherwise imports UMAP comes later in the notebook.\n",
    "from umap import UMAP\n",
    "\n",
    "N_EMBEDDED = 3072  # first 3072 decoder rows (= 192 heads of m * n = 16 features)\n",
    "reducer = UMAP(n_components=2, n_neighbors=15, min_dist=0.05, metric='cosine', random_state=42)\n",
    "vectors_2d = reducer.fit_transform(dec[:N_EMBEDDED])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "14f868a2",
   "metadata": {},
   "outputs": [],
   "source": [
    "m = n = 4  # grid size of the m4 n4 SAE embedded above"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bb19b662",
   "metadata": {},
   "outputs": [],
   "source": [
    "HEAD_TO_PLOT = 136\n",
    "\n",
    "# A head's m * n post-activations occupy one contiguous index range\n",
    "# (head = index // (m * n)), so iterate over it instead of scanning all 3072 rows.\n",
    "head_start = HEAD_TO_PLOT * m * n\n",
    "df = []\n",
    "for j in range(head_start, head_start + m * n):\n",
    "    h, r, c = post_act_to_triplet(j, m, n)\n",
    "    row_idx, col_idx = triplet_to_pre_act(h, r, c, m, n)\n",
    "    df.append({\n",
    "        \"Index\": j,\n",
    "        \"x\": vectors_2d[j, 0],\n",
    "        \"y\": vectors_2d[j, 1],\n",
    "        \"Head\": h,\n",
    "        \"Row\": row_idx,\n",
    "        \"Column\": col_idx\n",
    "    })\n",
    "df = pd.DataFrame(df)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "325f89a0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# 2-D coordinates and row / column pre-activation membership of the selected head's features.\n",
    "df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "32a42e13",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Defined here so the cell does not depend on `palette` from a later cell.\n",
    "palette = 'tab10'\n",
    "\n",
    "fig, axs = plt.subplots(1, 2, figsize=(4, 2), dpi=250)\n",
    "\n",
    "# Left panel: points coloured by row pre-activation; right panel: by column pre-activation.\n",
    "for ax, hue, title in zip(axs, ['Row', 'Column'], ['By row', 'By column']):\n",
    "    sns.scatterplot(\n",
    "        df, x='x', y='y', hue=hue, palette=palette,\n",
    "        s=10,           # Marker size\n",
    "        alpha=0.7,      # Transparency\n",
    "        edgecolor='none',\n",
    "        legend=False,\n",
    "        ax=ax\n",
    "    )\n",
    "    ax.set_xlabel('')\n",
    "    ax.set_ylabel('')\n",
    "    ax.set_xticklabels('')\n",
    "    ax.set_yticklabels('')\n",
    "    ax.set_title(title)\n",
    "\n",
    "plt.tight_layout()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3c9ee27d",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from sklearn.manifold import TSNE\n",
    "from sklearn.decomposition import PCA\n",
    "import seaborn as sns\n",
    "import matplotlib.pyplot as plt\n",
    "import matplotlib.colors as mcolors\n",
    "from umap import UMAP\n",
    "from prettytable import PrettyTable\n",
    "\n",
    "import numpy as np\n",
    "from sklearn.metrics import silhouette_score, adjusted_rand_score, calinski_harabasz_score\n",
    "\n",
    "# P1 / P2 / P3 = grouping by Head / Row / Column; SS = silhouette, CH = Calinski-Harabasz.\n",
    "table = PrettyTable(['Name', 'P1 SS', 'P1 CH', 'P2 SS', 'P3 SS', 'P2 CH', 'P3 CH'])\n",
    "\n",
    "# --- Dimensionality reduction of decoder rows (UMAP; t-SNE / PCA are alternatives) ---\n",
    "# reducer = TSNE(n_components=2, perplexity=50, random_state=42)\n",
    "# reducer = PCA(n_components=2)\n",
    "reducer = UMAP(n_components=2, n_neighbors=15, min_dist=0.05, metric='cosine', random_state=42)\n",
    "\n",
    "fig, axs = plt.subplots(3, 5, figsize=(12, 7), dpi=250)\n",
    "\n",
    "palette='tab10'\n",
    "\n",
    "# Marker size per SAE (figure column); larger heads contribute more points.\n",
    "s = {\n",
    "    0: 45,\n",
    "    1: 35,\n",
    "    2: 25,\n",
    "    3: 17,\n",
    "    4: 10\n",
    "}\n",
    "\n",
    "# Figure row r colours points by GROUPINGS[r] (P1, P2, P3 in the table).\n",
    "GROUPINGS = ['Head', 'Row', 'Column']\n",
    "\n",
    "n_groups = 4  # number of heads embedded per SAE\n",
    "for i, name in enumerate(['m4 n4', 'm4 n8', 'm8 n8', 'm8 n16', 'm16 n16']):\n",
    "\n",
    "    m, n = [int(el[1:]) for el in name.split(' ')]\n",
    "\n",
    "    enc, dec = load_embs(name)\n",
    "    n_vectors = (m * n) * n_groups\n",
    "    vectors_2d = reducer.fit_transform(dec[:n_vectors])\n",
    "\n",
    "    df = []\n",
    "    for j in range(n_vectors):\n",
    "        h, r, c = post_act_to_triplet(j, m, n)\n",
    "        row_idx, col_idx = triplet_to_pre_act(h, r, c, m, n)\n",
    "        df.append({\n",
    "            \"Index\": j,\n",
    "            \"x\": vectors_2d[j, 0],\n",
    "            \"y\": vectors_2d[j, 1],\n",
    "            \"Head\": h,\n",
    "            \"Row\": row_idx,\n",
    "            \"Column\": col_idx\n",
    "        })\n",
    "    df = pd.DataFrame(df)\n",
    "\n",
    "    silhouettes, ch_scores = [], []\n",
    "    for grid_row, column_name in enumerate(GROUPINGS):\n",
    "        sns.scatterplot(\n",
    "            df, x='x', y='y', hue=column_name, palette=palette,\n",
    "            s=s[i],           # Marker size\n",
    "            alpha=0.7,      # Transparency\n",
    "            edgecolor='none',\n",
    "            legend=False,\n",
    "            ax=axs[grid_row, i]\n",
    "        )\n",
    "        labels = df[column_name].to_numpy()\n",
    "        silhouettes.append(f'{silhouette_score(vectors_2d, labels):.3f}')\n",
    "        ch_scores.append(f'{calinski_harabasz_score(vectors_2d, labels):.3f}')\n",
    "\n",
    "    # Column order: P1 SS, P1 CH, P2 SS, P3 SS, P2 CH, P3 CH.\n",
    "    table.add_row([name, silhouettes[0], ch_scores[0], silhouettes[1], silhouettes[2], ch_scores[1], ch_scores[2]])\n",
    "\n",
    "print(table)\n",
    "\n",
    "for ax in axs.flatten():\n",
    "\n",
    "    ax.grid(alpha=0.2)\n",
    "\n",
    "    ax.set_yticklabels('')\n",
    "    ax.set_xticklabels('')\n",
    "\n",
    "    ax.set_ylabel('')\n",
    "    ax.set_xlabel('')\n",
    "\n",
    "axs[0, 0].set_title(r'$m=4$, $n=4$, $\\text{EV} = 0.845$', fontsize=12)\n",
    "axs[0, 1].set_title(r'$m=4$, $n=8$, $\\text{EV} = 0.824$', fontsize=12)\n",
    "axs[0, 2].set_title(r'$m=8$, $n=8$, $\\text{EV} = 0.815$', fontsize=12)\n",
    "axs[0, 3].set_title(r'$m=8$, $n=16$, $\\text{EV} = 0.812$', fontsize=12)\n",
    "axs[0, 4].set_title(r'$m=16$, $n=16$, $\\text{EV} = 0.799$', fontsize=12)\n",
    "\n",
    "axs[0, 0].set_ylabel('By Head', fontsize=14)\n",
    "axs[1, 0].set_ylabel('By Row', fontsize=14)\n",
    "axs[2, 0].set_ylabel('By Column', fontsize=14)\n",
    "\n",
    "plt.tight_layout()\n",
    "plt.savefig('results/embeddings 3 x 5 same groups.png')\n",
    "plt.savefig('results/embeddings 3 x 5 same groups.pdf')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ce70a918",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from sklearn.manifold import TSNE\n",
    "from sklearn.decomposition import PCA\n",
    "import seaborn as sns\n",
    "import matplotlib.pyplot as plt\n",
    "import matplotlib.colors as mcolors\n",
    "from umap import UMAP\n",
    "\n",
    "# --- Dimensionality reduction of decoder rows (UMAP; t-SNE / PCA are alternatives) ---\n",
    "# reducer = TSNE(n_components=2, perplexity=50, random_state=42)\n",
    "# reducer = PCA(n_components=2)\n",
    "reducer = UMAP(n_components=2, n_neighbors=15, min_dist=0.05, metric='cosine', random_state=42)\n",
    "\n",
    "fig, axs = plt.subplots(2, 3, figsize=(12, 8), dpi=250)\n",
    "axs = axs.flatten()\n",
    "\n",
    "palette='tab10'\n",
    "\n",
    "s = 15\n",
    "\n",
    "# Every SAE embeds the same number of decoder rows so the panels are comparable.\n",
    "n_vectors = 1024\n",
    "\n",
    "# SAE name -> panel title, in panel order.\n",
    "sae_titles = {\n",
    "    'topk': 'TopK',\n",
    "    'm4 n4': 'M=4, N=4',\n",
    "    'm4 n8': 'M=4, N=8',\n",
    "    'm8 n8': 'M=8, N=8',\n",
    "    'm8 n16': 'M=8, N=16',\n",
    "    'm16 n16': 'M=16, N=16',\n",
    "}\n",
    "\n",
    "for i, name in enumerate(sae_titles):\n",
    "\n",
    "    # TopK has no head structure: every feature is treated as its own 1 x 1 head.\n",
    "    if name == 'topk':\n",
    "        m, n = 1, 1\n",
    "    else:\n",
    "        m, n = [int(el[1:]) for el in name.split(' ')]\n",
    "\n",
    "    enc, dec = load_embs(name)\n",
    "    vectors_2d = reducer.fit_transform(dec[:n_vectors])\n",
    "\n",
    "    df = []\n",
    "    for j in range(n_vectors):\n",
    "        df.append({\n",
    "            \"Index\": j,\n",
    "            \"x\": vectors_2d[j, 0],\n",
    "            \"y\": vectors_2d[j, 1],\n",
    "            \"Head\": j // (m * n)\n",
    "        })\n",
    "    df = pd.DataFrame(df)\n",
    "\n",
    "    sns.scatterplot(\n",
    "        df, x='x', y='y', hue='Head', palette=palette,\n",
    "        s=s,           # Marker size\n",
    "        alpha=0.7,      # Transparency\n",
    "        edgecolor='none',\n",
    "        legend=False,\n",
    "        ax=axs[i]\n",
    "    )\n",
    "\n",
    "for ax, title in zip(axs, sae_titles.values()):\n",
    "\n",
    "    ax.grid(alpha=0.2)\n",
    "\n",
    "    ax.set_yticklabels('')\n",
    "    ax.set_xticklabels('')\n",
    "\n",
    "    ax.set_ylabel('')\n",
    "    ax.set_xlabel('')\n",
    "\n",
    "    ax.set_title(title)\n",
    "\n",
    "plt.tight_layout()\n",
    "# NOTE(review): the figure is a 2 x 3 grid although the file name says \"1 x 5\".\n",
    "plt.savefig('results/embeddings 1 x 5 with topk.png')\n",
    "plt.savefig('results/embeddings 1 x 5 with topk.pdf')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "10503807",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from sklearn.manifold import TSNE\n",
    "from sklearn.decomposition import PCA\n",
    "import seaborn as sns\n",
    "import matplotlib.pyplot as plt\n",
    "import matplotlib.colors as mcolors\n",
    "from umap import UMAP\n",
    "\n",
    "# --- Dimensionality reduction of decoder rows (UMAP; t-SNE / PCA are alternatives) ---\n",
    "# reducer = TSNE(n_components=2, perplexity=50, random_state=42)\n",
    "# reducer = PCA(n_components=2)\n",
    "reducer = UMAP(n_components=2, n_neighbors=15, min_dist=0.05, metric='cosine', random_state=42)\n",
    "\n",
    "fig, axs = plt.subplots(2, 3, figsize=(9, 6), dpi=250)\n",
    "\n",
    "palette = 'tab10'\n",
    "n_groups = 4  # number of heads embedded per SAE\n",
    "GROUPINGS = ['Head', 'Row', 'Column']  # one figure column per grouping\n",
    "for i, name in enumerate(['m8 n8', 'm16 n16']):\n",
    "\n",
    "    m, n = [int(el[1:]) for el in name.split(' ')]\n",
    "\n",
    "    enc, dec = load_embs(name)\n",
    "    n_vectors = (m * n) * n_groups\n",
    "    vectors_2d = reducer.fit_transform(dec[:n_vectors])\n",
    "\n",
    "    df = []\n",
    "    for j in range(n_vectors):\n",
    "        h, r, c = post_act_to_triplet(j, m, n)\n",
    "        row_idx, col_idx = triplet_to_pre_act(h, r, c, m, n)\n",
    "        df.append({\n",
    "            \"Index\": j,\n",
    "            \"x\": vectors_2d[j, 0],\n",
    "            \"y\": vectors_2d[j, 1],\n",
    "            \"Head\": h,\n",
    "            \"Row\": row_idx,\n",
    "            \"Column\": col_idx\n",
    "        })\n",
    "    df = pd.DataFrame(df)\n",
    "\n",
    "    s = 10 if i == 1 else 15  # the 16 x 16 SAE has 4x more points, so smaller markers\n",
    "\n",
    "    for grid_col, column_name in enumerate(GROUPINGS):\n",
    "        sns.scatterplot(\n",
    "            df, x='x', y='y', hue=column_name, palette=palette,\n",
    "            s=s,           # Marker size\n",
    "            alpha=0.7,      # Transparency\n",
    "            edgecolor='none',\n",
    "            legend=False,\n",
    "            ax=axs[i, grid_col]\n",
    "        )\n",
    "\n",
    "for ax in axs.flatten():\n",
    "\n",
    "    ax.grid(alpha=0.2)\n",
    "\n",
    "    ax.set_yticklabels('')\n",
    "    ax.set_xticklabels('')\n",
    "\n",
    "    ax.set_ylabel('')\n",
    "    ax.set_xlabel('')\n",
    "\n",
    "axs[0, 0].set_ylabel('M=8, N=8')\n",
    "axs[1, 0].set_ylabel('M=16, N=16')\n",
    "\n",
    "axs[0, 0].set_title('By Head')\n",
    "axs[0, 1].set_title('By Row')\n",
    "axs[0, 2].set_title('By Column')\n",
    "\n",
    "plt.tight_layout()\n",
    "plt.savefig('results/embeddings 2 x 3.png')\n",
    "plt.savefig('results/embeddings 2 x 3.pdf')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7b0a921a-91c8-48d5-9fe3-48f9adee9e09",
   "metadata": {},
   "outputs": [],
   "source": [
    "def post_act_to_triplet(post_act_idx: int, m: int, n: int):\n",
    "    \"\"\"Split a flat post-activation index into (head, row, column).\n",
    "\n",
    "    Post-activations are laid out head after head; inside a head the m x n grid\n",
    "    is stored row-major.\n",
    "    \"\"\"\n",
    "    head_idx, offset = divmod(post_act_idx, m * n)\n",
    "    row_idx, col_idx = divmod(offset, n)\n",
    "    return head_idx, row_idx, col_idx\n",
    "\n",
    "def triplet_to_pre_act(head_idx: int,\n",
    "                       row_idx: int,\n",
    "                       col_idx: int,\n",
    "                       m: int, n: int):\n",
    "    \"\"\"Return the (row, column) pre-activation indices feeding cell (row_idx, col_idx) of a head.\n",
    "\n",
    "    Each head owns m + n consecutive pre-activations: its m rows first, then its n columns.\n",
    "    \"\"\"\n",
    "    head_offset = head_idx * (m + n)\n",
    "    return head_offset + row_idx, head_offset + m + col_idx\n",
    "\n",
    "def columns_for_row(head: int, row: int, m: int, n: int):\n",
    "    \"\"\"List (column, column pre-activation index) pairs for every column of `head`.\n",
    "\n",
    "    `row` is only validated (must be < m): every row of a head meets the same n columns.\n",
    "    \"\"\"\n",
    "    assert row < m, f\"Row index should be less than m = {m}\"\n",
    "    first_col = head * (m + n) + m\n",
    "    return list(zip(range(n), range(first_col, first_col + n)))\n",
    "\n",
    "def rows_for_column(head: int, column: int, m: int, n: int):\n",
    "    \"\"\"List (row, row pre-activation index) pairs for every row of `head`.\n",
    "\n",
    "    `column` is only validated (must be < n): every column of a head meets the same m rows.\n",
    "    \"\"\"\n",
    "    assert column < n, f\"Column index should be less than n = {n}\"\n",
    "    first_row = head * (m + n)\n",
    "    return list(zip(range(m), range(first_row, first_row + m)))\n",
    "\n",
    "def pre_to_post(head: int, row: int, column: int, m: int, n: int):\n",
    "    \"\"\"Flat post-activation index of cell (row, column) of `head` (inverse of post_act_to_triplet).\"\"\"\n",
    "    return (head * m + row) * n + column"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "eee42cd9-3028-449d-a7fd-ec1238470ad1",
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE(review): an earlier cell reads 'results/interpretations.json'; confirm which file is intended.\n",
    "# Read as UTF-8 explicitly so the result does not depend on the platform's default encoding.\n",
    "with open('interpretations.json', 'r', encoding='utf-8') as f:\n",
    "    result = json.load(f)\n",
    "\n",
    "def print_row_interactions(head, row, m=4, n=4):\n",
    "    \"\"\"Print a row's interpretation, then each column and (row, column) interaction of `head`.\n",
    "\n",
    "    Records are looked up in the global `result` dict by post-activation index (as a\n",
    "    string); each record carries 'mean score' / 'interpretation' plus 'row' and 'col'\n",
    "    parent records. `m`/`n` give the per-head grid size (default 4 x 4, the m4 n4 SAE).\n",
    "    \"\"\"\n",
    "    row_parent = result[str(pre_to_post(head, row, 0, m, n))]['row']\n",
    "    print(f\"(h{head}, r{row}) - ({row_parent['mean score']:.2f}) - {row_parent['interpretation']}\\n\")\n",
    "    for col, _ in columns_for_row(head, row, m, n):\n",
    "        target_ftr = result[str(pre_to_post(head, row, col, m, n))]\n",
    "        col_parent = target_ftr['col']\n",
    "\n",
    "        print(f\"(h{head}, c{col}) - ({col_parent['mean score']:.2f}) - {col_parent['interpretation']}\")\n",
    "        print(f\"(h{head}, r{row}, c{col}) - ({target_ftr['mean score']:.2f}) - {target_ftr['interpretation']}\\n\")\n",
    "\n",
    "def print_col_interactions(head, col, m=4, n=4):\n",
    "    \"\"\"Print a column's interpretation, then each row and (row, column) interaction of `head`.\n",
    "\n",
    "    Records are looked up in the global `result` dict by post-activation index (as a\n",
    "    string); each record carries 'mean score' / 'interpretation' plus 'row' and 'col'\n",
    "    parent records. `m`/`n` give the per-head grid size (default 4 x 4, the m4 n4 SAE).\n",
    "    \"\"\"\n",
    "    col_parent = result[str(pre_to_post(head, 0, col, m, n))]['col']\n",
    "    print(f\"(h{head}, c{col}) - ({col_parent['mean score']:.2f}) - {col_parent['interpretation']}\\n\")\n",
    "\n",
    "    for row_idx, _ in rows_for_column(head, col, m, n):\n",
    "        target_ftr = result[str(pre_to_post(head, row_idx, col, m, n))]\n",
    "        row_parent = target_ftr['row']\n",
    "\n",
    "        print(f\"(h{head}, r{row_idx}) - ({row_parent['mean score']:.2f}) - {row_parent['interpretation']}\")\n",
    "        print(f\"(h{head}, r{row_idx}, c{col}) - ({target_ftr['mean score']:.2f}) - {target_ftr['interpretation']}\\n\")\n",
    "\n",
    "def print_rows_and_cols(head, m=4, n=4):\n",
    "    \"\"\"Print the distinct row and column interpretations of `head` (rows first).\n",
    "\n",
    "    The head's post-activations are the m * n consecutive keys starting at head * m * n\n",
    "    in the global `result` dict. `m`/`n` are explicit (default 4 x 4) so the start index\n",
    "    and the span always agree and never depend on notebook-global `m`/`n`.\n",
    "    \"\"\"\n",
    "    base = head * m * n\n",
    "\n",
    "    # dicts keep insertion order, so these act as ordered sets of lines to print.\n",
    "    rows, cols = {}, {}\n",
    "    for idx in range(base, base + m * n):\n",
    "        target_ftr = result[str(idx)]\n",
    "        row_parent, col_parent = target_ftr['row'], target_ftr['col']\n",
    "\n",
    "        rows[f\"(h{target_ftr['h']}, r{target_ftr['r']}) - ({row_parent['mean score']:.2f}): {row_parent['interpretation']}\"] = None\n",
    "        cols[f\"(h{target_ftr['h']}, c{target_ftr['c']}) - ({col_parent['mean score']:.2f}): {col_parent['interpretation']}\"] = None\n",
    "\n",
    "    for row_line in rows:\n",
    "        print(row_line)\n",
    "    print()\n",
    "    for col_line in cols:\n",
    "        print(col_line)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.17"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
