{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "75675eda-8c49-41dd-9cd0-8aed95f68e6c",
   "metadata": {},
   "source": [
    "### Output Masking\n",
    "Computes and saves the performance of a network with a given number of neurons masked out, once for each of several measurement metrics. The results may be used for pruning the network."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d7b1bb0f-feb9-4a1a-98f0-219d3ff39353",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Import Dict and Any\n",
    "from typing import Dict, Any\n",
    "\n",
    "import importlib\n",
    "import json\n",
    "from pathlib import Path\n",
    "import matplotlib.pyplot as plt\n",
    "import torch\n",
    "from tqdm import tqdm\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns\n",
    "import utils_output_masking as uom\n",
    "\n",
    "from library import baseline_configs\n",
    "from library import configs\n",
    "from library import load_datasets\n",
    "from library import metrics\n",
    "from library import misc\n",
    "from library import model_io\n",
    "from library import models\n",
    "from library import results_json\n",
    "from library import train\n",
    "\n",
    "from difflogic import GroupSum, LogicLayer\n",
    "\n",
    "\n",
    "# Reload the local library modules so that source edits made after the first import\n",
    "# take effect in this kernel without a restart.\n",
    "importlib.reload(load_datasets)\n",
    "importlib.reload(baseline_configs)\n",
    "importlib.reload(configs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "593baf70-c461-4a0f-af35-fe766dc876e1",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Name of the saved model; also used as the prefix of every result file written below.\n",
    "MODEL_NAME = \"fmnist_baseline_final_0\"\n",
    "# Directory that receives one JSON result file per metric.\n",
    "OUTPUT_DIR = Path(\"output_masking\")\n",
    "\n",
    "# load_model's return value is indexed here as (network, config).\n",
    "model = model_io.load_model(model_path=\"./models/\", model_name=MODEL_NAME)\n",
    "network = model[0]\n",
    "config = model[1]\n",
    "train_loader, validation_loader, test_loader, bin_loader, test_bin_loader = load_datasets.load_dataset(config=config)\n",
    "\n",
    "# Per-neuron statistics are gathered on the test set, then turned into the metrics listed in `labels`.\n",
    "results = uom.analyze_neuron_importance(network, test_loader, config)\n",
    "results = uom.calculate_metrics(results)\n",
    "\n",
    "num_classes = results['num_classes']\n",
    "\n",
    "labels = ['f1_score', 'random_score', 'precision', 'recall', 'sensitivity', 'specificity', 'accuracy', 'mcc']\n",
    "\n",
    "# Create the output directory up front so the first write cannot fail on a fresh checkout.\n",
    "OUTPUT_DIR.mkdir(parents=True, exist_ok=True)\n",
    "\n",
    "for label in labels:\n",
    "    test_accuracies_train, test_accuracies_eval, val_accuracies_eval, val_accuracies_train, n_values = uom.run(results, label, network, num_classes, test_loader, validation_loader, config)\n",
    "\n",
    "    # NOTE(review): the original file name had an empty placeholder before `_{label}` and\n",
    "    # 'model_name' was just f\"_{label}\"; the model name is assumed to be the intended prefix.\n",
    "    run_name = f\"{MODEL_NAME}_{label}\"\n",
    "    data = {\n",
    "        'model_name': run_name,\n",
    "        'test_accuracies_train': test_accuracies_train,\n",
    "        'test_accuracies_eval': test_accuracies_eval,\n",
    "        'val_accuracies_eval': val_accuracies_eval,\n",
    "        'val_accuracies_train': val_accuracies_train,\n",
    "        'n_values': n_values\n",
    "    }\n",
    "    print(data)\n",
    "\n",
    "    # Persist the results of this metric; the dict built above is what gets written.\n",
    "    with open(OUTPUT_DIR / f\"{run_name}.json\", \"w\") as f:\n",
    "        json.dump(data, f, indent=4)  # 'indent=4' makes it pretty"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
