{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "379514c5-a323-46c1-8528-321ec83807b4",
   "metadata": {},
   "source": [
    "### Analyse Output Neuron Distribution\n",
    "Designed to run in the root directory, with utils_output_masking.py as well in the root directory.\n",
    "This script was designed to be run in the root folder of the project.\n",
    "It requires two different already trained models and their configuration files."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c10e52cf-5fb1-495c-a853-0d2ce2cca2dd",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Import Dict and Any\n",
    "import os\n",
    "import sys\n",
    "import os\n",
    "\n",
    "# Add parent directory to sys.path\n",
    "sys.path.append(os.getcwd())  # or the absolute path to the parent of \"library\"\n",
    "from typing import Dict, Any\n",
    "\n",
    "import importlib\n",
    "import matplotlib.pyplot as plt\n",
    "import torch\n",
    "from tqdm import tqdm\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns\n",
    "\n",
    "from library import baseline_configs\n",
    "from library import configs\n",
    "from library import load_datasets\n",
    "from library import metrics\n",
    "from library import misc\n",
    "from library import model_io\n",
    "from library import models\n",
    "from library import results_json\n",
    "from library import train\n",
    "\n",
    "from difflogic import GroupSum, LogicLayer\n",
    "\n",
    "\n",
    "importlib.reload(load_datasets)\n",
    "importlib.reload(baseline_configs)\n",
    "importlib.reload(configs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b7d0b844-b26f-42de-b5ac-d2379f7de716",
   "metadata": {},
   "outputs": [],
   "source": [
    "# E.g. two different tau values\n",
    "network, config = model_io.load_model(model_path=\"./models/\", model_name=\"mnist_different_tau_final_100\")\n",
    "network, config = model_io.load_model(model_path=\"./models/\", model_name=\"mnist_different_tau_final_1\")\n",
    "\n",
    "train_loader, validation_loader, test_loader, bin_loader, test_bin_loader = load_datasets.load_dataset(config=config)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6c00c22d-a005-4d76-8c2c-9c26c8e538d0",
   "metadata": {},
   "outputs": [],
   "source": [
    "def analyze_neuron_importance(\n",
    "    model: torch.nn.Module, \n",
    "    test_loader: torch.utils.data.DataLoader,\n",
    "    config: configs.DifflogicConfig\n",
    ") -> Dict[str, Any]:\n",
    "    \"\"\"\n",
    "    Analyzes the importance of individual neurons in the final layer of a DiffLogic model.\n",
    "    \n",
    "    Args:\n",
    "        model: Trained DiffLogic model\n",
    "        test_loader: DataLoader with test data\n",
    "        config: DiffLogic configuration\n",
    "        \n",
    "    Returns:\n",
    "        Dictionary containing analysis results\n",
    "    \"\"\"\n",
    "    device = config.model_config.device\n",
    "    model.eval()\n",
    "    \n",
    "    # Extract the final logic layer before GroupSum\n",
    "    final_layer = None\n",
    "    for i, layer in enumerate(model):\n",
    "        if isinstance(layer, GroupSum):\n",
    "            final_layer = model[i-1]\n",
    "            break\n",
    "    \n",
    "    assert final_layer is not None, \"Could not find final logic layer\"\n",
    "    \n",
    "    # Get number of classes and neurons per class\n",
    "    num_classes = model[-1].k\n",
    "    neurons_per_class = final_layer.out_dim // num_classes\n",
    "    \n",
    "    # Collect activations and contribution scores\n",
    "    neuron_activations = torch.zeros(final_layer.out_dim, device=device)\n",
    "    neuron_contributions = torch.zeros(final_layer.out_dim, device=device)\n",
    "    neuron_accuracy = torch.zeros(final_layer.out_dim, device=device)\n",
    "    neuron_false_positives = torch.zeros(final_layer.out_dim, device=device)\n",
    "    neuron_false_negatives = torch.zeros(final_layer.out_dim, device=device)\n",
    "    neuron_true_positives = torch.zeros(final_layer.out_dim, device=device)\n",
    "    neuron_true_negatives = torch.zeros(final_layer.out_dim, device=device)\n",
    "\n",
    "\n",
    "    correct_counts = torch.zeros(final_layer.out_dim, device=device)\n",
    "    samples_per_class = torch.zeros(num_classes, device=device)\n",
    "    samples_total = len(test_loader.dataset)\n",
    "    \n",
    "    with torch.no_grad():\n",
    "        for x, y in test_loader:\n",
    "            x = x.to(device)\n",
    "            y = y.to(device)\n",
    "\n",
    "            # Threshold the input to the model\n",
    "            # x = (x > 0.5).float()\n",
    "            \n",
    "            # Get intermediate activations before GroupSum\n",
    "            intermediate = model[:-1](x) > 0.5\n",
    "            \n",
    "            # Count activations by neuron\n",
    "            neuron_activations += intermediate.sum(dim=0)\n",
    "            \n",
    "            # For each correctly classified sample, credit the neurons that contributed\n",
    "            outputs = model(x)\n",
    "            predictions = outputs.argmax(dim=1)\n",
    "            correct_mask = (predictions == y)\n",
    "\n",
    "            for i, (is_correct, true_class, activation) in enumerate(zip(correct_mask, y, intermediate)):\n",
    "                # Credit neurons in the class group\n",
    "                class_start = true_class * neurons_per_class\n",
    "                class_end = (true_class + 1) * neurons_per_class\n",
    "                if is_correct:\n",
    "                    \n",
    "                    # Credit neurons that were active for correct predictions\n",
    "                    neuron_contributions[class_start:class_end] += activation[class_start:class_end]\n",
    "                    \n",
    "                    # Count correct activations per neuron\n",
    "                    correct_counts[class_start:class_end] += activation[class_start:class_end]\n",
    "\n",
    "                samples_per_class[true_class] += 1\n",
    "\n",
    "                # Calculate metrics\n",
    "                ground_truth_mask = torch.zeros(final_layer.out_dim, dtype=torch.bool, device=device)\n",
    "                ground_truth_mask[class_start:class_end] = True\n",
    "\n",
    "                neuron_accuracy += (activation == ground_truth_mask).float()\n",
    "                neuron_false_positives += (activation & ~ground_truth_mask).float()\n",
    "                neuron_false_negatives += (~activation & ground_truth_mask).float()\n",
    "                neuron_true_positives += (activation & ground_truth_mask).float()\n",
    "                neuron_true_negatives += (~activation & ~ground_truth_mask).float()\n",
    "\n",
    "    # Normalize neuron contributions\n",
    "    neuron_accuracy /= samples_total\n",
    "    neuron_false_positives /= samples_total\n",
    "    neuron_false_negatives /= samples_total\n",
    "    neuron_true_positives /= samples_total\n",
    "    neuron_true_negatives /= samples_total\n",
    "\n",
    "    \n",
    "    # Calculate per-neuron metrics\n",
    "    activation_rate = neuron_activations / len(test_loader.dataset)\n",
    "    correct_rate = correct_counts / samples_per_class.repeat_interleave(neurons_per_class)\n",
    "    \n",
    "    # Analyze neuron importance distribution\n",
    "    results = {\n",
    "        'activation_rate': activation_rate.cpu().numpy(),\n",
    "        'correct_rate': correct_rate.cpu().numpy(),\n",
    "        'contribution_scores': neuron_contributions.cpu().numpy(),\n",
    "        'neurons_per_class': neurons_per_class,\n",
    "        'num_classes': num_classes,\n",
    "        'accuracy': neuron_accuracy.cpu().numpy(),\n",
    "        'false_positives': neuron_false_positives.cpu().numpy(),\n",
    "        'false_negatives': neuron_false_negatives.cpu().numpy(),\n",
    "        'true_positives': neuron_true_positives.cpu().numpy(),\n",
    "        'true_negatives': neuron_true_negatives.cpu().numpy(),\n",
    "        'samples_per_class': samples_per_class.cpu().numpy(),\n",
    "        'samples_total': samples_total,\n",
    "    }\n",
    "    \n",
    "    # Additional analysis\n",
    "    for class_idx in range(num_classes):\n",
    "        class_start = class_idx * neurons_per_class\n",
    "        class_end = (class_idx + 1) * neurons_per_class\n",
    "        class_scores = neuron_contributions[class_start:class_end].cpu().numpy()\n",
    "        \n",
    "        # Sort neurons by contribution score\n",
    "        sorted_indices = np.argsort(-class_scores)\n",
    "        cumulative_contribution = np.cumsum(class_scores[sorted_indices]) / np.sum(class_scores)\n",
    "        \n",
    "        # Find what percentage of neurons account for 80% of the contribution\n",
    "        threshold_80pct = np.searchsorted(cumulative_contribution, 0.8) + 1\n",
    "        results[f'neurons_for_80pct_class_{class_idx}'] = threshold_80pct\n",
    "        results[f'pct_neurons_for_80pct_class_{class_idx}'] = threshold_80pct / neurons_per_class * 100\n",
    "    \n",
    "    return results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ba2ec4bf-d6f3-4ac6-885e-76222bbaa64a",
   "metadata": {},
   "outputs": [],
   "source": [
    "results = analyze_neuron_importance(network, test_loader, config)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8ba52340-45f5-4bbd-b669-0226b190e1a4",
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_neuron_importance(results: Dict[str, Any]) -> None:\n",
    "    \"\"\"\n",
    "    Plots the analysis of neuron importance.\n",
    "    \n",
    "    Args:\n",
    "        results: Dictionary containing analysis results\n",
    "    \"\"\"\n",
    "    \n",
    "    num_classes = results['num_classes']\n",
    "    neurons_per_class = results['neurons_per_class']\n",
    "    \n",
    "    # Create a figure with multiple subplots\n",
    "    fig, axes = plt.subplots(2, 2, figsize=(8, 8))\n",
    "    \n",
    "    # Plot 1: Distribution of neuron activation rates\n",
    "    sns.histplot(results['activation_rate'], ax=axes[0, 0], kde=True)#, bins=10)\n",
    "    axes[0, 0].set_title('Distribution of Neuron Activation Rates')\n",
    "    axes[0, 0].set_xlabel('Activation Rate')\n",
    "    axes[0, 0].set_ylabel('Count')\n",
    "    \n",
    "    # Plot 2: Distribution of neuron correct rates\n",
    "    sns.histplot(results['correct_rate'], ax=axes[0, 1], kde=True)\n",
    "    axes[0, 1].set_title('Distribution of Neuron Correct Rates')\n",
    "    axes[0, 1].set_xlabel('Correct Rate')\n",
    "    axes[0, 1].set_ylabel('Count')\n",
    "    \n",
    "    # Plot 3: Cumulative contribution for each class\n",
    "    for class_idx in range(num_classes):\n",
    "        class_start = class_idx * neurons_per_class\n",
    "        class_end = (class_idx + 1) * neurons_per_class\n",
    "        class_scores = results['contribution_scores'][class_start:class_end]\n",
    "        \n",
    "        # Sort neurons by contribution score\n",
    "        sorted_indices = np.argsort(-class_scores)\n",
    "        cumulative_contribution = np.cumsum(class_scores[sorted_indices]) / np.sum(class_scores)\n",
    "        \n",
    "        # Plot cumulative contribution curve\n",
    "        x_axis = np.arange(1, neurons_per_class + 1) / neurons_per_class\n",
    "        axes[1, 0].plot(x_axis, cumulative_contribution, label=f'Class {class_idx}')\n",
    "    \n",
    "    axes[1, 0].axhline(y=0.8, color='r', linestyle='--', label='80% Contribution')\n",
    "    axes[1, 0].set_title('Cumulative Contribution by Neuron')\n",
    "    axes[1, 0].set_xlabel('Fraction of Neurons (Sorted by Contribution)')\n",
    "    axes[1, 0].set_ylabel('Cumulative Contribution')\n",
    "    axes[1, 0].legend()\n",
    "    \n",
    "    # Plot 4: Bar chart showing percentage of neurons needed for 80% contribution\n",
    "    class_indices = np.arange(num_classes)\n",
    "    pct_values = [results[f'pct_neurons_for_80pct_class_{i}'] for i in class_indices]\n",
    "    \n",
    "    axes[1, 1].bar(class_indices, pct_values)\n",
    "    axes[1, 1].set_title('% of Neurons Needed for 80% Contribution')\n",
    "    axes[1, 1].set_xlabel('Class')\n",
    "    axes[1, 1].set_ylabel('% of Neurons')\n",
    "    \n",
    "    plt.tight_layout()\n",
    "    plt.savefig('neuron_importance_analysis.png')\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1a5b0da6-bf95-47ae-8cc1-c1aec01769d9",
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns\n",
    "import numpy as np\n",
    "from typing import Dict, Any\n",
    "\n",
    "def plot_neuron_activation_rate(results: Dict[str, Any]) -> None:\n",
    "    \"\"\"\n",
    "    Plots the distribution of neuron activation rates as a percentage of total neurons,\n",
    "    with a fixed y-axis from 0 to 15%.\n",
    "    \n",
    "    Args:\n",
    "        results: Dictionary containing analysis results, including 'activation_rate'.\n",
    "    \"\"\"\n",
    "    data = np.array(results['activation_rate'])\n",
    "    total_neurons = len(data)\n",
    "\n",
    "    # Compute histogram\n",
    "    counts, bins = np.histogram(data, bins=50, range=(0, 1))\n",
    "    percentages = 100 * counts / total_neurons  # Convert to percent\n",
    "\n",
    "    # Plot\n",
    "    plt.figure(figsize=(5, 4))\n",
    "    bin_centers = 0.5 * (bins[:-1] + bins[1:])\n",
    "    plt.bar(bin_centers, percentages, width=(bins[1] - bins[0]), alpha=0.6, color='C0', edgecolor='black')\n",
    "\n",
    "    # Optional: KDE for smooth overlay\n",
    "    # sns.kdeplot(data, bw_adjust=1, color='blue', linewidth=2)\n",
    "\n",
    "    plt.xlabel('Activation Rate', fontsize=22)\n",
    "    plt.ylabel('% of Neurons', fontsize=22)\n",
    "    plt.xticks(np.arange(0, 1.01, 0.5))\n",
    "    plt.yticks(np.arange(0, 16, 5))\n",
    "    plt.ylim(0, 12)  # Fixed y-axis limit\n",
    "    plt.tick_params(axis='both', which='major', labelsize=16)\n",
    "    plt.grid(True, linestyle='--', alpha=0.5)\n",
    "    plt.tight_layout()\n",
    "    plt.savefig(\"mnist_activation_rate_50.pdf\")\n",
    "    plt.show()\n",
    "\n",
    "\n",
    "plot_neuron_activation_rate(results)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e6a61feb-fe77-444c-968d-4be3c8606b5c",
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "from typing import Dict, Any\n",
    "\n",
    "def plot_two_activation_distributions(results1: Dict[str, Any],\n",
    "                                      results2: Dict[str, Any],\n",
    "                                      label1: str = \"Model 1\",\n",
    "                                      label2: str = \"Model 2\") -> None:\n",
    "    \"\"\"\n",
    "    Plots the distribution of neuron activation rates for two sets of results.\n",
    "    \n",
    "    Args:\n",
    "        results1: First dictionary containing 'activation_rate'.\n",
    "        results2: Second dictionary containing 'activation_rate'.\n",
    "        label1: Label for the first distribution.\n",
    "        label2: Label for the second distribution.\n",
    "    \"\"\"\n",
    "    data1 = np.array(results1['activation_rate'])\n",
    "    data2 = np.array(results2['activation_rate'])\n",
    "\n",
    "    total_neurons1 = len(data1)\n",
    "    total_neurons2 = len(data2)\n",
    "\n",
    "    bins = np.linspace(0, 1, 51)  # 50 bins\n",
    "    bin_centers = 0.5 * (bins[:-1] + bins[1:])\n",
    "    width = bins[1] - bins[0]\n",
    "\n",
    "    # Compute histograms\n",
    "    counts1, _ = np.histogram(data1, bins=bins)\n",
    "    counts2, _ = np.histogram(data2, bins=bins)\n",
    "    percentages1 = 100 * counts1 / total_neurons1\n",
    "    percentages2 = 100 * counts2 / total_neurons2\n",
    "    ###############\n",
    "    df = pd.DataFrame({\n",
    "        \"bin_center\": bin_centers,\n",
    "        \"Model1\": percentages1,\n",
    "        \"Model2\": percentages2\n",
    "    })\n",
    "\n",
    "    df.to_csv(\"activation_distributions.csv\", index=False)\n",
    "    ###############\n",
    "\n",
    "    # Plot\n",
    "    plt.figure(figsize=(6, 4))\n",
    "    plt.bar(bin_centers - width/4, percentages1, width=width/2, alpha=0.6, color='C0', edgecolor='black', label=label1)\n",
    "    plt.bar(bin_centers + width/4, percentages2, width=width/2, alpha=0.6, color='C1', edgecolor='black', label=label2)\n",
    "\n",
    "    # Optional KDE overlay\n",
    "    # sns.kdeplot(data1, bw_adjust=1, color='C0', linewidth=2)\n",
    "    # sns.kdeplot(data2, bw_adjust=1, color='C1', linewidth=2)\n",
    "\n",
    "    plt.xlabel('Activation Rate', fontsize=22)\n",
    "    plt.ylabel('% of Neurons', fontsize=22)\n",
    "    plt.xticks(np.arange(0, 1.01, 0.5))\n",
    "    plt.yticks(np.arange(0, 10.1, 5))\n",
    "    plt.ylim(0, 10)\n",
    "    plt.tick_params(axis='both', which='major', labelsize=16)\n",
    "    plt.grid(True, linestyle='--', alpha=0.5)\n",
    "    plt.legend(fontsize=16)\n",
    "    plt.tight_layout()\n",
    "    plt.savefig(\"comparison_activation_rate.pdf\")\n",
    "    plt.show()\n",
    "\n",
    "# results1 = {'activation_rate': np.random.beta(2, 5, size=1000)}\n",
    "# results2 = {'activation_rate': np.random.beta(5, 2, size=1000)}\n",
    "# Requires two different results to run\n",
    "plot_two_activation_distributions(results, results2, label1=r\"$\\tau=100$\", label2=r\"$\\tau=1$\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7775c236-da18-4e3a-9dd1-5f629e58b340",
   "metadata": {},
   "outputs": [],
   "source": [
    "plot_neuron_importance(results)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "47d65bc1-6b39-46df-85e1-b21bdeded596",
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_classification_metrics(results: Dict[str, Any]) -> None:\n",
    "    # Plot histograms of classification metrics (accuracy, false positives, etc.)\n",
    "    metrics_to_plot = ['accuracy', 'false_positives', 'false_negatives', 'true_positives', 'true_negatives']\n",
    "    num_metrics = len(metrics_to_plot)\n",
    "    fig, axes = plt.subplots(1, num_metrics, figsize=(4 * num_metrics, 4))\n",
    "    for i, metric in enumerate(metrics_to_plot):\n",
    "        sns.histplot(results[metric], ax=axes[i], kde=True ,bins=100)\n",
    "        axes[i].set_title(f'Distribution of {metric.replace(\"_\", \" \").title()}')\n",
    "        axes[i].set_xlabel(metric.replace(\"_\", \" \").title())\n",
    "        axes[i].set_ylabel('Count')\n",
    "    fig.tight_layout()\n",
    "    plt.show()\n",
    "print(results['accuracy'])\n",
    "plot_classification_metrics(results)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
