{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "8e8a0d0b-a55f-496e-9a28-8b674725e1e2",
   "metadata": {},
   "source": [
    "### Post-Trained FFN\n",
    "Designed to be run from the root folder of the repository.\n",
    "\n",
    "This notebook freezes a pretrained model and trains an additional feed-forward network (FFN) on the output of its last LogicLayer."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "134d21d4-1d6d-453b-b098-425df0e1ba81",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from library import load_datasets\n",
    "from library import metrics\n",
    "from library import misc\n",
    "from library import model_io\n",
    "from library import models\n",
    "from library import results_json\n",
    "from library import train\n",
    "from library import baseline_configs\n",
    "from library import configs\n",
    "import matplotlib.pyplot as plt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ef1ebd0e-8a55-41c6-b78a-2437221beee9",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the saved model; element 0 is used as the network and element 1 as its config.\n",
    "model = model_io.load_model(model_path=\"./models/\", model_name=\"mnist_baseline_final_0\")\n",
    "network, config = model[0], model[1]\n",
    "\n",
    "# Build the data loaders from the loaded config.\n",
    "loaders = load_datasets.load_dataset(config=config)\n",
    "train_loader, validation_loader, test_loader, bin_loader, test_bin_loader = loaders\n",
    "\n",
    "# Test that the model is loaded correctly: print the accuracy on the validation and\n",
    "# test loaders, each first with train_mode=True and then with train_mode=False.\n",
    "for eval_loader in (validation_loader, test_loader):\n",
    "    for mode_flag in (True, False):\n",
    "        print(metrics.get_accuracy(network, eval_loader, config=config, train_mode=mode_flag))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "434f289b-7701-4b69-813b-6c3eff009d31",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "\n",
    "class SecondStageModel(nn.Module):\n",
    "    \"\"\"Feed-forward classifier head: ``Linear -> ReLU`` per hidden size, then a final ``Linear``.\n",
    "\n",
    "    Args:\n",
    "        input_size: Size of the last dimension of the input features.\n",
    "        hidden_sizes: Widths of the hidden layers, in order. Any iterable of ints is\n",
    "            accepted (list, tuple, ...); an empty one yields a single linear layer.\n",
    "        output_size: Number of output logits (no activation is applied to them).\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, input_size=64000, hidden_sizes=(1024,), output_size=10):\n",
    "        super().__init__()\n",
    "        # Build a fresh list so tuples and other iterables work and the caller's\n",
    "        # (or the default) sequence is never shared with this instance.\n",
    "        sizes = [input_size, *hidden_sizes]\n",
    "        layers = []\n",
    "        for in_features, out_features in zip(sizes[:-1], sizes[1:]):\n",
    "            layers.append(nn.Linear(in_features, out_features))\n",
    "            layers.append(nn.ReLU())\n",
    "        layers.append(nn.Linear(sizes[-1], output_size))\n",
    "        self.net = nn.Sequential(*layers)\n",
    "\n",
    "    def forward(self, x):\n",
    "        \"\"\"Return the logits computed by the stacked layers for ``x``.\"\"\"\n",
    "        return self.net(x)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a141f192-0bd1-41bf-b575-c2917939c59f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Use the GPU when one is available, otherwise fall back to the CPU.\n",
    "if torch.cuda.is_available():\n",
    "    device = torch.device(\"cuda\")\n",
    "else:\n",
    "    device = torch.device(\"cpu\")\n",
    "\n",
    "def accuracy(y_pred: torch.Tensor, y: torch.Tensor) -> float:\n",
    "    \"\"\"Return the fraction of predictions whose arg-max over the last axis equals ``y``.\"\"\"\n",
    "    predicted_labels = y_pred.argmax(-1)\n",
    "    is_correct = predicted_labels == y\n",
    "    return is_correct.to(torch.float32).mean().item()\n",
    "\n",
    "# Hyperparameters of the second-stage training run.\n",
    "FEATURE_DIM = 64000   # must equal the last dimension of the extracted features\n",
    "NUM_CLASSES = 47      # NOTE(review): the loaded model is named mnist_* (10 classes); confirm 47 is intended\n",
    "LEARNING_RATE = 1e-4\n",
    "NUM_EPOCHS = 100\n",
    "\n",
    "pretrained_model = network\n",
    "pretrained_model.to(device)\n",
    "pretrained_model.eval()  # evaluation mode only; this alone does not freeze the weights\n",
    "for param in pretrained_model.parameters():\n",
    "    param.requires_grad = False  # this is what freezes the pretrained weights\n",
    "\n",
    "# Everything except the final module acts as the frozen feature extractor.\n",
    "# Sliced once here instead of once per batch.\n",
    "feature_extractor = pretrained_model[:-1]\n",
    "\n",
    "second_stage_model = SecondStageModel(input_size=FEATURE_DIM, output_size=NUM_CLASSES).to(device)\n",
    "\n",
    "optimizer = torch.optim.Adam(second_stage_model.parameters(), lr=LEARNING_RATE)\n",
    "criterion = nn.CrossEntropyLoss()\n",
    "\n",
    "\n",
    "def run_epoch(loader, training):\n",
    "    \"\"\"Run one pass over ``loader`` and return ``(mean_loss, accuracy)``.\n",
    "\n",
    "    Both values are averaged over the samples actually yielded by ``loader``, so they\n",
    "    stay correct when the loader covers only part of its dataset or has a smaller\n",
    "    last batch.\n",
    "\n",
    "    Args:\n",
    "        loader: Iterable yielding ``(inputs, labels)`` batches.\n",
    "        training: If True, the second-stage model is put in train mode and updated\n",
    "            with the optimizer; if False, it is put in eval mode and only evaluated.\n",
    "\n",
    "    Returns:\n",
    "        Tuple ``(mean_loss, accuracy)``; both are NaN when ``loader`` yields nothing.\n",
    "    \"\"\"\n",
    "    second_stage_model.train(training)\n",
    "    loss_sum = 0.0\n",
    "    correct_sum = 0.0\n",
    "    n_samples = 0\n",
    "\n",
    "    for inputs, labels in loader:\n",
    "        inputs, labels = inputs.to(device), labels.to(device)\n",
    "\n",
    "        # The feature extractor is frozen, so never build a graph through it.\n",
    "        with torch.no_grad():\n",
    "            features = feature_extractor(inputs)\n",
    "\n",
    "        # Only track gradients for the second stage while training.\n",
    "        with torch.set_grad_enabled(training):\n",
    "            outputs = second_stage_model(features)\n",
    "            loss = criterion(outputs, labels)\n",
    "\n",
    "        if training:\n",
    "            optimizer.zero_grad()\n",
    "            loss.backward()\n",
    "            optimizer.step()\n",
    "\n",
    "        # criterion returns the batch mean, so weight it by the batch size to get a\n",
    "        # per-sample average over the whole epoch.\n",
    "        batch_size = len(outputs)\n",
    "        loss_sum += loss.item() * batch_size\n",
    "        correct_sum += accuracy(outputs, labels) * batch_size\n",
    "        n_samples += batch_size\n",
    "\n",
    "    if n_samples == 0:\n",
    "        return float('nan'), float('nan')\n",
    "    return loss_sum / n_samples, correct_sum / n_samples\n",
    "\n",
    "\n",
    "for epoch in range(NUM_EPOCHS):\n",
    "    # === Training ===\n",
    "    epoch_loss, epoch_acc = run_epoch(train_loader, training=True)\n",
    "\n",
    "    # === Validation ===\n",
    "    val_loss, val_acc = run_epoch(validation_loader, training=False)\n",
    "\n",
    "    print(f'Epoch {epoch} - Train Loss: {epoch_loss:.4f}, Train Acc: {epoch_acc:.1%} | '\n",
    "          f'Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.1%}')"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
