{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "0df83144",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import json\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "from matplotlib.ticker import MaxNLocator\n",
    "\n",
    "# --- Input/Output Paths ---\n",
    "# Modify these paths to match your project structure.\n",
    "INPUT_DIR = \"full_compute\"  # The directory containing your input JSON files.\n",
    "# Plots for <name>.json are written to <OUTPUT_DIR_BASE>/<name>/graphs/.\n",
    "OUTPUT_DIR_BASE = \"graphs\"  # The base directory where output graphs will be saved.\n",
    "\n",
    "# --- Plotting Customization ---\n",
    "# Easily change these parameters to customize all generated graphs.\n",
    "# KL panel: how many layers (taken in JSON order) to show from each category\n",
    "# whose name contains 'top', and from every other category, respectively.\n",
    "TOP_K_LAYERS = 3\n",
    "BOTTOM_K_LAYERS = 3\n",
    "# Delta-bias panel: how many layers to show per category, chosen by the largest\n",
    "# absolute difference between the two sides' delta_bias means.\n",
    "DELTA_BIAS_CHERRY_PICK_K = 3\n",
    "BAR_COLORS = ('#0077bb', \"#ee5233\")  # (first side, second side): dark blue, orange-red\n",
    "\n",
    "# --- Font sizes and tick control ---\n",
    "# Base size for legends; titles and axis labels are drawn a few points larger.\n",
    "BASE_FONT_SIZE = 20\n",
    "X_AXIS_TICK_FONT_SIZE = 20  # Font size for labels on the x-axis\n",
    "Y_AXIS_TICK_FONT_SIZE = 16  # Font size for labels on the y-axis\n",
    "# Passed to MaxNLocator(nbins=...), which caps the number of tick *intervals*,\n",
    "# so up to Y_AXIS_MAX_TICKS + 1 ticks may appear.\n",
    "Y_AXIS_MAX_TICKS = 5    # Maximum number of y-axis tick intervals"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "20dd9af2",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Scanning for JSON files in: full_compute\n",
      "\n",
      "Processing file: llama_1b.json\n",
      "--> Successfully saved plot to graphs/llama_1b/graphs/mlp_comparison.pdf\n",
      "--> Successfully saved plot to graphs/llama_1b/graphs/attn_comparison.pdf\n",
      "\n",
      "Processing file: gemma_4b.json\n",
      "--> Successfully saved plot to graphs/gemma_4b/graphs/mlp_comparison.pdf\n",
      "--> Successfully saved plot to graphs/gemma_4b/graphs/attn_comparison.pdf\n",
      "\n",
      "Processing file: llama_3b.json\n",
      "--> Successfully saved plot to graphs/llama_3b/graphs/mlp_comparison.pdf\n",
      "--> Successfully saved plot to graphs/llama_3b/graphs/attn_comparison.pdf\n",
      "\n",
      "Processing file: gemma_1b.json\n",
      "--> Successfully saved plot to graphs/gemma_1b/graphs/mlp_comparison.pdf\n",
      "--> Successfully saved plot to graphs/gemma_1b/graphs/attn_comparison.pdf\n",
      "\n",
      "Script finished.\n"
     ]
    }
   ],
   "source": [
    "\n",
    "# %%\n",
    "# =============================================================================\n",
    "# CELL 2: PLOTTING FUNCTION\n",
    "# =============================================================================\n",
    "def _legend_label(side_name):\n",
    "    \"\"\"Legend text for a side key, e.g. 'cpp_top' -> 'CPP'.\"\"\"\n",
    "    return side_name.upper().replace('_TOP', '')\n",
    "\n",
    "\n",
    "def _ordered_categories(filtered_data):\n",
    "    \"\"\"Return (category, layers) pairs with every non-'bottom' category first.\n",
    "\n",
    "    sorted() is stable, so JSON order is kept within each group.\n",
    "    \"\"\"\n",
    "    return sorted(filtered_data.items(), key=lambda item: 'bottom' in item[0])\n",
    "\n",
    "\n",
    "def _select_kl_layers(layers, cat_name):\n",
    "    \"\"\"First K (layer_name, layer_data) pairs of one category, in JSON order.\n",
    "\n",
    "    K is TOP_K_LAYERS when the category name contains 'top', else BOTTOM_K_LAYERS.\n",
    "    \"\"\"\n",
    "    k = TOP_K_LAYERS if 'top' in cat_name else BOTTOM_K_LAYERS\n",
    "    return list(layers.items())[:k]\n",
    "\n",
    "\n",
    "def _select_delta_bias_layers(layers, side_a_name, side_b_name):\n",
    "    \"\"\"The DELTA_BIAS_CHERRY_PICK_K (layer_name, layer_data) pairs of one category\n",
    "    whose delta_bias means differ most (in absolute value) between the two sides.\n",
    "\n",
    "    Ties keep JSON order because sorted() stays stable even with reverse=True.\n",
    "    \"\"\"\n",
    "    def mean_gap(item):\n",
    "        layer_data = item[1]\n",
    "        return abs(layer_data[side_a_name]['delta_bias']['mean']\n",
    "                   - layer_data[side_b_name]['delta_bias']['mean'])\n",
    "\n",
    "    return sorted(layers.items(), key=mean_gap, reverse=True)[:DELTA_BIAS_CHERRY_PICK_K]\n",
    "\n",
    "\n",
    "def _collect_series(selection, side_names, metric):\n",
    "    \"\"\"Flatten a per-category layer selection into bar-chart series.\n",
    "\n",
    "    Args:\n",
    "        selection (list): (category_name, [(layer_name, layer_data), ...]) pairs.\n",
    "        side_names (tuple): The two side keys read from each layer_data.\n",
    "        metric (str): Metric key, e.g. 'KL' or 'delta_bias'.\n",
    "\n",
    "    Returns:\n",
    "        tuple: (labels, means_a, stds_a, means_b, stds_b, num_top), where num_top\n",
    "        is the number of bars contributed by the category whose name contains 'top'.\n",
    "    \"\"\"\n",
    "    side_a_name, side_b_name = side_names\n",
    "    labels, means_a, stds_a, means_b, stds_b = [], [], [], [], []\n",
    "    num_top = 0\n",
    "    for cat_name, picked in selection:\n",
    "        if 'top' in cat_name:\n",
    "            num_top = len(picked)\n",
    "        for layer_name, layer_data in picked:\n",
    "            labels.append(layer_name)\n",
    "            means_a.append(layer_data[side_a_name][metric]['mean'])\n",
    "            stds_a.append(layer_data[side_a_name][metric]['std'])\n",
    "            means_b.append(layer_data[side_b_name][metric]['mean'])\n",
    "            stds_b.append(layer_data[side_b_name][metric]['std'])\n",
    "    return labels, means_a, stds_a, means_b, stds_b, num_top\n",
    "\n",
    "\n",
    "def _draw_grouped_bars(ax, series, side_names, title, ylabel, show_separator, width=0.35):\n",
    "    \"\"\"Draw side-by-side bars (first side left, second side right) with std error bars.\"\"\"\n",
    "    labels, means_a, stds_a, means_b, stds_b, num_top = series\n",
    "    side_a_name, side_b_name = side_names\n",
    "    x = np.arange(len(labels))\n",
    "    ax.bar(x - width / 2, means_a, width, label=_legend_label(side_a_name), yerr=stds_a, capsize=5, color=BAR_COLORS[0])\n",
    "    ax.bar(x + width / 2, means_b, width, label=_legend_label(side_b_name), yerr=stds_b, capsize=5, color=BAR_COLORS[1])\n",
    "    ax.set_title(title, fontsize=BASE_FONT_SIZE + 4)\n",
    "    ax.set_ylabel(ylabel, fontsize=BASE_FONT_SIZE + 2)\n",
    "    ax.set_xticks(x)\n",
    "    ax.set_xticklabels(labels, rotation=45, ha='right')\n",
    "    # Divider after the 'top' bars. Assumes the 'top' category is the first one\n",
    "    # plotted (non-'bottom' categories come first) -- TODO confirm if more than\n",
    "    # two categories can match a block type.\n",
    "    if show_separator and num_top > 0:\n",
    "        ax.axvline(num_top - 0.5, color='grey', linestyle='--')\n",
    "    ax.legend(fontsize=BASE_FONT_SIZE)\n",
    "    ax.grid(axis='y', linestyle='--', alpha=0.7)\n",
    "    ax.tick_params(axis='x', labelsize=X_AXIS_TICK_FONT_SIZE)\n",
    "    ax.tick_params(axis='y', labelsize=Y_AXIS_TICK_FONT_SIZE)\n",
    "    ax.yaxis.set_major_locator(MaxNLocator(nbins=Y_AXIS_MAX_TICKS))\n",
    "\n",
    "\n",
    "def _save_figure(fig, output_dir, filename):\n",
    "    \"\"\"Save fig to output_dir/filename and report the outcome on stdout.\n",
    "\n",
    "    Save errors are printed rather than raised so one failure does not stop a batch.\n",
    "    \"\"\"\n",
    "    try:\n",
    "        os.makedirs(output_dir, exist_ok=True)\n",
    "        full_path = os.path.join(output_dir, filename)\n",
    "        fig.tight_layout(rect=[0, 0.03, 1, 0.95])\n",
    "        fig.savefig(full_path, bbox_inches=\"tight\")\n",
    "        print(f\"--> Successfully saved plot to {full_path}\")\n",
    "    except Exception as e:\n",
    "        print(f\"--> Error saving plot: {e}\")\n",
    "\n",
    "\n",
    "def plot_block_type_results_customizable(experiment_data, block_type, side_names, main_title, output_dir):\n",
    "    \"\"\"\n",
    "    Save a two-panel comparison figure for one block type (e.g. 'mlp' or 'attn').\n",
    "\n",
    "    Left panel: KL mean/std for the first K layers of each matching category, in\n",
    "    JSON order. Right panel: delta_bias mean/std for the DELTA_BIAS_CHERRY_PICK_K\n",
    "    layers per category with the largest |first side - second side| mean difference.\n",
    "\n",
    "    Args:\n",
    "        experiment_data (dict): category name -> {layer name -> {side -> {metric -> {'mean', 'std'}}}}.\n",
    "        block_type (str): Substring selecting categories from experiment_data.\n",
    "        side_names (tuple): The two side keys to compare, e.g. ('cpp_top', 'python_top').\n",
    "        main_title (str): Figure title.\n",
    "        output_dir (str): Directory that receives '<block_type>_comparison.pdf'.\n",
    "\n",
    "    Raises:\n",
    "        KeyError: If a layer lacks a required side / metric / 'mean' / 'std' key.\n",
    "    \"\"\"\n",
    "    side_a_name, side_b_name = side_names\n",
    "\n",
    "    filtered_data = {k: v for k, v in experiment_data.items() if block_type in k}\n",
    "    if not filtered_data:\n",
    "        print(f\"--> Warning: No data found for '{block_type}'. Skipping.\")\n",
    "        return\n",
    "\n",
    "    # Read every value needed BEFORE creating the figure, so malformed data\n",
    "    # raises without leaving an unclosed matplotlib figure behind.\n",
    "    categories = _ordered_categories(filtered_data)\n",
    "    kl_series = _collect_series(\n",
    "        [(cat, _select_kl_layers(layers, cat)) for cat, layers in categories],\n",
    "        side_names, 'KL')\n",
    "    bias_series = _collect_series(\n",
    "        [(cat, _select_delta_bias_layers(layers, side_a_name, side_b_name)) for cat, layers in categories],\n",
    "        side_names, 'delta_bias')\n",
    "    show_separator = len(filtered_data) > 1\n",
    "\n",
    "    fig, (ax_kl, ax_bias) = plt.subplots(1, 2, figsize=(20, 8))\n",
    "    try:\n",
    "        fig.suptitle(main_title, fontsize=BASE_FONT_SIZE + 6, weight='bold')\n",
    "        _draw_grouped_bars(ax_kl, kl_series, side_names,\n",
    "                           'Comparison of KL Divergence', 'KL Divergence', show_separator)\n",
    "        _draw_grouped_bars(ax_bias, bias_series, side_names,\n",
    "                           f'Top {DELTA_BIAS_CHERRY_PICK_K} Layers by Delta Bias Difference', 'Delta Bias',\n",
    "                           show_separator)\n",
    "        _save_figure(fig, output_dir, f\"{block_type}_comparison.pdf\")\n",
    "    finally:\n",
    "        # Always release the figure; figures opened in a loop otherwise accumulate.\n",
    "        plt.close(fig)\n",
    "\n",
    "# %%\n",
    "# =============================================================================\n",
    "# CELL 3: MAIN EXECUTION LOGIC\n",
    "# =============================================================================\n",
    "def main(input_dir, output_dir_base, side_names=('cpp_top', 'python_top'), block_types=('mlp', 'attn')):\n",
    "    \"\"\"\n",
    "    Plot every JSON results file found directly inside input_dir.\n",
    "\n",
    "    Files are visited in sorted name order so repeated runs print and write in the\n",
    "    same order. Plots for '<stem>.json' go to '<output_dir_base>/<stem>/graphs/'.\n",
    "    Files that cannot be read, are not a JSON object, or lack expected keys are\n",
    "    reported and skipped so one bad file does not abort the whole batch.\n",
    "\n",
    "    Args:\n",
    "        input_dir (str): Directory scanned (non-recursively) for '*.json' files.\n",
    "        output_dir_base (str): Root directory for all generated plots.\n",
    "        side_names (tuple): The two side keys compared in every plot.\n",
    "        block_types (tuple): Block-type substrings; one figure is saved per entry.\n",
    "    \"\"\"\n",
    "    print(f\"Scanning for JSON files in: {input_dir}\")\n",
    "    if not os.path.isdir(input_dir):\n",
    "        print(f\"Error: Input directory '{input_dir}' not found.\")\n",
    "        return\n",
    "\n",
    "    # os.listdir() order is arbitrary; sort it for reproducible output.\n",
    "    for filename in sorted(os.listdir(input_dir)):\n",
    "        json_path = os.path.join(input_dir, filename)\n",
    "        if not filename.endswith('.json') or not os.path.isfile(json_path):\n",
    "            continue\n",
    "        print(f\"\\nProcessing file: {filename}\")\n",
    "\n",
    "        model_name = os.path.splitext(filename)[0]\n",
    "        output_dir_final = os.path.join(output_dir_base, model_name, 'graphs')\n",
    "\n",
    "        try:\n",
    "            # JSON text is UTF-8 (RFC 8259); don't rely on the locale default.\n",
    "            with open(json_path, 'r', encoding='utf-8') as f:\n",
    "                data = json.load(f)\n",
    "        except json.JSONDecodeError:\n",
    "            print(f\"--> Error: Could not decode JSON from {filename}. Skipping.\")\n",
    "            continue\n",
    "        except Exception as e:\n",
    "            print(f\"--> Error reading file {filename}: {e}. Skipping.\")\n",
    "            continue\n",
    "\n",
    "        # Only the first top-level entry is plotted and it must map to a dict.\n",
    "        # A non-object top level (e.g. a JSON array) is rejected here instead of\n",
    "        # crashing on data[...] below.\n",
    "        experiment_key = next(iter(data), None) if isinstance(data, dict) else None\n",
    "        if experiment_key is None or not isinstance(data[experiment_key], dict):\n",
    "            print(f\"--> Error: JSON file {filename} has an unexpected format. Skipping.\")\n",
    "            continue\n",
    "        if len(data) > 1:\n",
    "            print(f\"--> Warning: {filename} has {len(data)} top-level entries; only '{experiment_key}' is plotted.\")\n",
    "\n",
    "        experiment_data = data[experiment_key]\n",
    "\n",
    "        for block in block_types:\n",
    "            try:\n",
    "                plot_block_type_results_customizable(\n",
    "                    experiment_data=experiment_data,\n",
    "                    block_type=block,\n",
    "                    side_names=side_names,\n",
    "                    main_title=f'{model_name.upper()} - {block.upper()} Layer Analysis',\n",
    "                    output_dir=output_dir_final\n",
    "                )\n",
    "            except (KeyError, TypeError, AttributeError) as e:\n",
    "                # Schema problems in one file are reported, not allowed to abort the batch.\n",
    "                print(f\"--> Error: Malformed '{block}' data in {filename} ({e!r}). Skipping.\")\n",
    "\n",
    "    print(\"\\nScript finished.\")\n",
    "\n",
    "# --- Run the main function ---\n",
    "main(INPUT_DIR, OUTPUT_DIR_BASE)\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dfdd4d5e",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "causal",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
