{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import json\n",
    "import glob\n",
    "import torch\n",
    "import pickle\n",
    "import re\n",
    "import plotly.graph_objects as go\n",
    "import einops\n",
    "import pandas as pd\n",
    "from functools import partial\n",
    "from torch import Tensor\n",
    "from torchtyping import TensorType as TT\n",
    "from utils.visualization import imshow_p"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "MODEL_TO_VIEW = \"pythia-160m\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "filename = f\"backup_research/FINAL_figure_files/sr_over_time/data/{MODEL_TO_VIEW}/sample_results_dict.pt\"\n",
    "data = torch.load(filename)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "# sort the dictionary by the key\n",
    "data = dict(sorted(data.items(), key=lambda item: int(item[0])))\n",
    "data.keys()"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Single Checkpoint"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "CHECKPOINT = 143000"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "repair_score = 1 - (data[CHECKPOINT]['thresholded_cil'][0] / -data[CHECKPOINT]['thresholded_de'][0])\n",
    "#weighted_repair_score = (data[CHECKPOINT]['thresholded_de'][0] / -data[CHECKPOINT]['thresholded_cil'][0]) * data[CHECKPOINT]['thresholded_de'][0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "layer = 11\n",
    "for head in range(12):\n",
    "    print((data[143000]['thresholded_de'][0][layer, head], data[143000]['thresholded_cil'][0][layer, head], repair_score[layer, head]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from typing import Union\n",
    "from jaxtyping import Float\n",
    "import numpy as np\n",
    "import itertools\n",
    "\n",
    "def create_layered_scatter(\n",
    "    heads_x: Float[Tensor, \"layer head\"],\n",
    "    heads_y: Float[Tensor, \"layer head\"], \n",
    "    x_title: str, \n",
    "    y_title: str, \n",
    "    plot_title: str,\n",
    "    mlp_x: Union[Float[Tensor, \"layer\"], None] = None,\n",
    "    mlp_y: Union[Float[Tensor, \"layer\"], None] = None,\n",
    "    x_range: Union[list, None] = None,  # New parameter for x-range\n",
    "    y_range: Union[list, None] = None   # New parameter for y-range\n",
    "):\n",
    "    \"\"\"\n",
    "    This function now also accepts x_data and y_data for MLP layers and manual x- and y-ranges. \n",
    "    It plots properties of transformer heads and MLP layers with layered coloring and annotations.\n",
    "    Additionally, it plots a dotted line where the negative value of the y-axis equals the positive value of the x-axis.\n",
    "    \"\"\"\n",
    "    num_layers = 12\n",
    "    num_heads = 12\n",
    "    layer_colors = np.linspace(0, num_layers, num_layers, endpoint=False)\n",
    "    \n",
    "    # Annotations and colors for transformer heads\n",
    "    head_annotations = [f\"Layer {layer}, Head {head}\" for layer, head in itertools.product(range(num_layers), range(num_heads))]\n",
    "    head_marker_colors = [layer_colors[layer] for layer in range(num_layers) for _ in range(num_heads)]\n",
    "\n",
    "    # Prepare MLP data if provided\n",
    "    mlp_annotations = []\n",
    "    mlp_marker_colors = []\n",
    "    if mlp_x is not None and mlp_y is not None:\n",
    "        mlp_annotations = [f\"MLP Layer {layer}\" for layer in range(num_layers)]\n",
    "        mlp_marker_colors = [layer_colors[layer] for layer in range(num_layers)]\n",
    "\n",
    "    # Flatten data\n",
    "    heads_x = heads_x.flatten().cpu().numpy() if heads_x.ndim > 1 else heads_x.cpu().numpy()\n",
    "    heads_y = heads_y.flatten().cpu().numpy() if heads_y.ndim > 1 else heads_y.cpu().numpy()\n",
    "    if mlp_x is not None and mlp_y is not None:\n",
    "        mlp_x = mlp_x.flatten().cpu().numpy() if mlp_x.ndim > 1 else mlp_x.cpu().numpy()\n",
    "        mlp_y = mlp_y.flatten().cpu().numpy() if mlp_y.ndim > 1 else mlp_y.cpu().numpy()\n",
    "\n",
    "    # Create scatter plots\n",
    "    scatter_heads = go.Scatter(\n",
    "        x=heads_x,\n",
    "        y=heads_y,\n",
    "        text=head_annotations,\n",
    "        mode='markers',\n",
    "        marker=dict(\n",
    "            size=8,\n",
    "            opacity=0.8,\n",
    "            color=head_marker_colors,\n",
    "            colorscale='Viridis',\n",
    "            colorbar=dict(\n",
    "                title='Layer',\n",
    "                #tickvals=[0, num_layers - 1],\n",
    "                #ticktext=[0, 1,2,1,1,1,1,1,1,1,1,1,1,11,1,1,3,4,5,5,num_layers - 1],\n",
    "                orientation=\"h\"\n",
    "            ),\n",
    "            line=dict(width=0.5, color='DarkSlateGrey')\n",
    "        ),\n",
    "        name=\"Attention Heads\"\n",
    "    )\n",
    "\n",
    "    scatter_mlp = go.Scatter(\n",
    "        x=mlp_x,\n",
    "        y=mlp_y,\n",
    "        text=mlp_annotations,\n",
    "        mode='markers',\n",
    "        name='MLP Layers',\n",
    "        marker=dict(\n",
    "            size=10,\n",
    "            opacity=0.6,\n",
    "            color=mlp_marker_colors,\n",
    "            colorscale='Viridis',\n",
    "            symbol='diamond',\n",
    "            line=dict(width=1, color='Black')\n",
    "        )\n",
    "    ) if mlp_x is not None and mlp_y is not None else None\n",
    "\n",
    "    # Create the figure and add the traces\n",
    "    fig = go.Figure()\n",
    "    fig.add_trace(scatter_heads)\n",
    "    if scatter_mlp:\n",
    "        fig.add_trace(scatter_mlp)\n",
    "\n",
    "    # Add a dotted line where the negative y-value equals the positive x-value\n",
    "    if x_range and y_range:\n",
    "        # Ensuring the line covers the entire visible range by finding the min and max\n",
    "        line_range = [min(x_range[0], y_range[0]), max(x_range[1], y_range[1])]\n",
    "        fig.add_trace(go.Scatter(x=line_range, y=[-x for x in line_range], mode='lines', line=dict(color='grey', dash='dot'), name='Y=-X'))\n",
    "\n",
    "    # Update the layout with the manual x- and y-range\n",
    "    fig.update_layout(\n",
    "        title=f\"{plot_title}\",\n",
    "        title_x=0.5,\n",
    "        xaxis_title=x_title,\n",
    "        yaxis_title=y_title,\n",
    "        legend_title=\"Component\",\n",
    "        # do not show legend\n",
    "        showlegend=False,\n",
    "        width=500,\n",
    "        height=500,\n",
    "        #xaxis_range=x_range,\n",
    "        #yaxis_range=y_range\n",
    "    )\n",
    "\n",
    "    return fig\n",
    "fig = create_layered_scatter(\n",
    "    data[CHECKPOINT]['thresholded_de'][0], \n",
    "    data[CHECKPOINT]['thresholded_cil'][0], \n",
    "    \"Direct Effect of Component\", \n",
    "    \"Change in Logits Upon Ablation\", \n",
    "    f\"Self-Repair in {MODEL_TO_VIEW} at Checkpoint {CHECKPOINT}\",\n",
    "    x_range=[-0.1, 0.2],\n",
    "    y_range=[-0.4, 0.1]    \n",
    ")\n",
    "fig.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "imshow_p(\n",
    "    repair_score * 100,\n",
    "    title=\"Self-Repair Score\",\n",
    "    labels={\"x\": \"Head\", \"y\": \"Layer\", \"color\": \"Self-Repair Score\"},\n",
    "    coloraxis=dict(colorbar_ticksuffix = \"%\", cmin=-500, cmax=500),\n",
    "    border=True,\n",
    "    width=600,\n",
    "    margin={\"r\": 100, \"l\": 100}\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "data[CHECKPOINT]['thresholded_cil'][0].sum(), data[CHECKPOINT]['thresholded_de'][0].sum(), repair_score.sum()"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### All Checkpoints"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "subset_checkpoints = [512, 3000, 60000, 143000]\n",
    "subset_data = {checkpoint: data[checkpoint] for checkpoint in subset_checkpoints}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import plotly.graph_objects as go\n",
    "from plotly.subplots import make_subplots\n",
    "import numpy as np\n",
    "import torch\n",
    "\n",
    "from utils.visualization import convert_title_to_filename\n",
    "\n",
    "\n",
    "# Specify your desired x and y axis range\n",
    "x_range = [-0.1, 0.4]\n",
    "y_range = [-0.4, 0.1]\n",
    "\n",
    "# Determine the grid size (rows x columns)\n",
    "total_checkpoints = len(subset_data.keys())\n",
    "num_columns = 4  # Example: 3 columns in your grid\n",
    "num_rows = (total_checkpoints + num_columns - 1) // num_columns  # Calculate rows needed\n",
    "\n",
    "# Initialize the subplot figure with specified rows and columns\n",
    "fig = make_subplots(\n",
    "    rows=num_rows, cols=num_columns, \n",
    "    subplot_titles=[f\"Checkpoint {checkpoint}\" for checkpoint in data.keys()],\n",
    "    horizontal_spacing=0.05,  # Reduce horizontal spacing\n",
    "    vertical_spacing=0.05     # Reduce vertical spacing\n",
    ")\n",
    "\n",
    "subplot_index = 1  # Initialize subplot index\n",
    "\n",
    "for checkpoint in subset_data.keys():\n",
    "    # Calculate the current row and column position\n",
    "    row = (subplot_index - 1) // num_columns + 1\n",
    "    col = (subplot_index - 1) % num_columns + 1\n",
    "\n",
    "    # Generate the plot for the current checkpoint\n",
    "    current_fig = create_layered_scatter(\n",
    "        subset_data[checkpoint]['thresholded_de'][0], \n",
    "        subset_data[checkpoint]['thresholded_cil'][0], \n",
    "        \"Direct Effect of Component\", \n",
    "        \"Change in Logits Upon Ablation\", \n",
    "        f\"Self-Repair in {MODEL_TO_VIEW} at Checkpoint {checkpoint}\",\n",
    "    )\n",
    "\n",
    "    # Add each trace from the current figure to the subplot\n",
    "    for trace in current_fig.data:\n",
    "        fig.add_trace(trace, row=row, col=col)\n",
    "\n",
    "    # Explicitly set x and y ranges for this subplot\n",
    "    fig.update_xaxes(range=x_range, row=row, col=col)\n",
    "    fig.update_yaxes(range=y_range, row=row, col=col)\n",
    "    \n",
    "    # Add the dotted line where the negative value of the y-axis equals the positive value of the x-axis\n",
    "    line_range = [min(x_range[0], y_range[0]), max(x_range[1], y_range[1])]\n",
    "    fig.add_trace(\n",
    "        go.Scatter(x=line_range, y=[-x for x in line_range], mode='lines',\n",
    "                   line=dict(color='grey', dash='dot'), name='Y=-X'),\n",
    "        row=row, col=col\n",
    "    )\n",
    "\n",
    "    subplot_index += 1\n",
    "\n",
    "title = f\"Self-Repair for {MODEL_TO_VIEW}\"\n",
    "\n",
    "# Update layout with centered axis titles, increased legend-subplot spacing, and specified label size\n",
    "fig.update_layout(\n",
    "    height=400*num_rows+100, \n",
    "    width=350*num_columns, \n",
    "    title_text=title, \n",
    "    showlegend=False,\n",
    "    xaxis_title=\"Original Direct Effect\",  # Centered x-axis title\n",
    "    yaxis_title=\"Change in Logits\",  # Centered y-axis title\n",
    "    margin=dict(t=75),  # Increase top margin to provide space for the legend\n",
    "    xaxis_title_font=dict(size=16),  # Set x-axis label size\n",
    "    yaxis_title_font=dict(size=16)   # Set y-axis label size\n",
    ")\n",
    "\n",
    "filename = \"results/plots/\" + convert_title_to_filename(title) + \".pdf\"\n",
    "fig.write_image(filename, format='pdf', width=350*num_columns, height=400*num_rows+100, engine=\"kaleido\")\n",
    "\n",
    "fig.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def convert_title_to_filename(title: str):\n",
    "    # replace spaces with dashes, remove parentheses, and make lowercase\n",
    "    return title.replace(' ', '-').replace('(', '').replace(')', '').lower()\n",
    "\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import plotly.express as px\n",
    "from typing import Dict\n",
    "\n",
    "def plot_all_heads(\n",
    "        model_name: str,\n",
    "        checkpoint_dict: Dict[int, Dict[str, np.ndarray]], \n",
    "        plot_everything: bool = False, \n",
    "        top_k_per_checkpoint: int = 5, \n",
    "        top_k: int = 5\n",
    "    ) -> pd.DataFrame:\n",
    "    \"\"\"\n",
    "    Plot the head attributions across checkpoints, with the option to plot all heads or only the top ones.\n",
    "\n",
    "    Args:\n",
    "        model_name (str): Name of the model for title display.\n",
    "        checkpoint_dict (Dict[int, Dict[str, np.ndarray]]): A dictionary mapping checkpoints to a dictionary\n",
    "            that includes a key \"self_repair_score\" pointing to numpy arrays of head attributions.\n",
    "        plot_everything (bool, optional): If True, plots all heads without applying top_k filters. Defaults to False.\n",
    "        top_k_per_checkpoint (int, optional): The number of top heads to consider per checkpoint. Effective only if plot_everything is False.\n",
    "        top_k (int, optional): The number of overall top heads to plot. Effective only if plot_everything is False.\n",
    "\n",
    "    Returns:\n",
    "        pd.DataFrame: A DataFrame containing the plot data.\n",
    "    \"\"\"\n",
    "    plot_data = []\n",
    "\n",
    "    for checkpoint, data in checkpoint_dict.items():\n",
    "        array = data['self_repair_score'][0].numpy()\n",
    "        print(array.shape)\n",
    "        \n",
    "        if plot_everything:\n",
    "            indices = np.indices(array.shape)\n",
    "            selected_heads = [(layer, head) for layer, head in zip(indices[0].flatten(), indices[1].flatten())]\n",
    "        else:\n",
    "            # Use argpartition to get the indices of the top heads in the entire array\n",
    "            flat_indices = np.argpartition(array.flatten(), -top_k_per_checkpoint)[-top_k_per_checkpoint:]\n",
    "            # Convert flat indices to 2D indices\n",
    "            indices = np.unravel_index(flat_indices, array.shape)\n",
    "            selected_heads = [(layer, head) for layer, head in zip(indices[0], indices[1])]\n",
    "\n",
    "        for layer, head in selected_heads:\n",
    "            plot_data.append(\n",
    "                {\n",
    "                    'Checkpoint': checkpoint,\n",
    "                    'Layer-Head': f'Layer {layer}-Head {head}',\n",
    "                    'Layer': layer,\n",
    "                    'Head': head,\n",
    "                    'Value': float(array[layer, head])  # Ensure conversion to float\n",
    "                }\n",
    "            )\n",
    "\n",
    "    # Convert to DataFrame\n",
    "    df = pd.DataFrame(plot_data)\n",
    "\n",
    "    if not plot_everything:\n",
    "        # Ensure 'Value' is numeric for aggregation functions\n",
    "        df['Value'] = pd.to_numeric(df['Value'], errors='coerce')  # Converts non-numeric to NaN, can handle errors\n",
    "\n",
    "        # Calculate sum of values over all checkpoints for each head\n",
    "        summary_df = df.groupby(['Layer-Head', 'Layer', 'Head']).sum().reset_index()\n",
    "\n",
    "        # Label the top_k items in summary_df based on their sum\n",
    "        summary_df['Top K'] = summary_df['Layer-Head'].isin(df.groupby('Layer-Head').mean().nlargest(top_k, 'Value').index)\n",
    "\n",
    "        # Filter the DataFrame to include only the top_k heads across all checkpoints\n",
    "        df = df.merge(summary_df, on=['Layer-Head', 'Layer', 'Head'], how='inner').query('`Top K`')\n",
    "\n",
    "    # Step 3: Plot the data\n",
    "    fig = px.line(\n",
    "        df, \n",
    "        x='Checkpoint',  # Corrected column name for Checkpoint\n",
    "        y='Value',       # Assuming Value_x is the correct column for self-repair scores\n",
    "        color='Layer-Head', \n",
    "        # specify y_range\n",
    "        range_y=[-300, 300],\n",
    "        title=f'Self Repair Across Checkpoints (DE+CIL/DE) ({model_name})', \n",
    "        height=500,\n",
    "        labels={'Checkpoint': 'Checkpoint', 'Value': 'Self Repair Score'}  # Correct labels for axes\n",
    "    )\n",
    "    fig.show()\n",
    "\n",
    "    return df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "for checkpoint in data.keys():\n",
    "    data[checkpoint]['self_repair_score'] = (data[checkpoint]['thresholded_de'] + data[checkpoint]['thresholded_cil']) / data[checkpoint]['thresholded_de']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "df = plot_all_heads(MODEL_TO_VIEW, data, plot_everything=True, top_k_per_checkpoint=5, top_k=5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "data[143000]['self_repair_score'].shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "df.head(50)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
