{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ig13UZIqZ_ij",
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "ig13UZIqZ_ij",
    "outputId": "d45cbe64-33e8-466c-a9d0-966e4516ef5a"
   },
   "outputs": [],
   "source": [
    "# Download the Chesseract dataset archive and unpack it into ../data/Chesseract/.\n",
    "# `mkdir -p` creates ../data first (wget -O needs it) and does not fail on re-runs;\n",
    "# `unzip -o` overwrites existing files instead of prompting, which a notebook cell cannot answer.\n",
    "!mkdir -p ../data/Chesseract\n",
    "!wget https://data.ncl.ac.uk/ndownloader/articles/24118743/versions/2 -O ../data/Chesseract.zip\n",
    "!unzip -o ../data/Chesseract.zip -d ../data/Chesseract/\n",
    "!rm ../data/Chesseract.zip"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fc3d556b",
   "metadata": {
    "id": "fc3d556b"
   },
   "outputs": [],
   "source": [
    "import torch\n",
    "from torch import nn\n",
    "from torch.utils.data import DataLoader, random_split, TensorDataset\n",
    "from torchvision import datasets, transforms\n",
    "import os\n",
    "import random\n",
    "import numpy as np\n",
    "\n",
    "# cuBLAS reads this when its handle is first created, so set it before any CUDA work.\n",
    "# Needed for deterministic cuBLAS results on GPU; harmless on CPU-only machines.\n",
    "os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8'  # or ':16:8'\n",
    "\n",
    "SEED = 0\n",
    "torch.manual_seed(SEED)\n",
    "# NOTE: the running interpreter's hash seed is fixed at startup, so this only\n",
    "# affects Python subprocesses launched from this kernel.\n",
    "os.environ['PYTHONHASHSEED'] = str(SEED)\n",
    "random.seed(SEED)\n",
    "np.random.seed(SEED)\n",
    "# Seeded generator for the DataLoaders; derived from SEED so a single knob controls every RNG.\n",
    "g = torch.Generator()\n",
    "g.manual_seed(SEED)\n",
    "\n",
    "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
    "if torch.cuda.is_available():\n",
    "    torch.cuda.manual_seed(SEED)\n",
    "    torch.cuda.manual_seed_all(SEED)\n",
    "    # Trade GPU speed for run-to-run reproducibility.\n",
    "    torch.backends.cudnn.enabled = False\n",
    "    torch.backends.cudnn.benchmark = False\n",
    "    torch.backends.cudnn.deterministic = True\n",
    "\n",
    "\n",
    "################################\n",
    "#     RESTART     RUNTIME      #\n",
    "################################\n",
    "from ssonn.model.utils import *\n",
    "from ssonn.metrics.nonlinearity_metrics import *\n",
    "from ssonn.metrics.edge_finder import *\n",
    "from ssonn.metrics.train_metrics import *\n",
    "from ssonn.train.train import *"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "591d520f",
   "metadata": {
    "id": "591d520f"
   },
   "outputs": [],
   "source": [
    "class SimpleFCN(nn.Module):\n",
    "    \"\"\"Single linear layer mapping flattened inputs to output logits.\n",
    "\n",
    "    Args:\n",
    "        input_size: number of input features (default 12 * 8 * 8 = 768).\n",
    "        output_size: number of outputs (default 10).\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, input_size=12 * 8 * 8, output_size=10):\n",
    "        super().__init__()\n",
    "        # The attribute name `fc0` is referenced by the hyperparameter thresholds and the sparse conversion.\n",
    "        self.fc0 = nn.Linear(input_size, output_size)\n",
    "\n",
    "    def forward(self, x):\n",
    "        logits = self.fc0(x)\n",
    "        return logits"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "75c8a0e8",
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 87
    },
    "id": "75c8a0e8",
    "outputId": "fb458d18-a0fb-4997-dc53-ca2082925e6a"
   },
   "outputs": [],
   "source": [
    "# Training / expansion configuration. Passed whole to train_sparse_recursive and logged as\n",
    "# the W&B run config; the semantics of most keys are defined in the ssonn package.\n",
    "hyperparams = dict(\n",
    "    num_epochs=32,\n",
    "    batch_size=256,  # DataLoader batch size\n",
    "    edge_importance_metric=AbsGradientEdgeMetric(nn.CrossEntropyLoss()),\n",
    "    edge_score_aggregation=\"mean\",\n",
    "    # Per-layer thresholds; keys match the layer name `fc0` in SimpleFCN.\n",
    "    expansion_thresholds={\"fc0\": 0.05},\n",
    "    pruning_thresholds={\"fc0\": 0.1},\n",
    "    plateau_threshold=0.5,\n",
    "    min_epochs_between_expansions=24,\n",
    "    plateau_window_size=5,\n",
    "    learning_rate=2e-4,  # Adam learning rate\n",
    "    prune_after_epochs=4,\n",
    "    task_type=\"classification\",  # also passed to eval_one_epoch\n",
    "    start_fully_connected=False,\n",
    "    max_new_edges_per_expansion=3000,\n",
    "    weight_decay=1e-3,  # Adam weight decay\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "92087925",
   "metadata": {
    "id": "92087925"
   },
   "outputs": [],
   "source": [
    "DATA_DIR = \"../data/Chesseract\"\n",
    "INPUT_SIZE = 12 * 8 * 8  # matches SimpleFCN's default input_size\n",
    "\n",
    "\n",
    "def load_split(split):\n",
    "    \"\"\"Load `{split}_x.npy` / `{split}_y.npy` from DATA_DIR as a TensorDataset.\n",
    "\n",
    "    Features are cast to float32 and flattened to (N, INPUT_SIZE); labels are cast to int64.\n",
    "    \"\"\"\n",
    "    features = torch.tensor(np.load(f\"{DATA_DIR}/{split}_x.npy\")).float().view(-1, INPUT_SIZE)\n",
    "    labels = torch.tensor(np.load(f\"{DATA_DIR}/{split}_y.npy\")).long()\n",
    "    return TensorDataset(features, labels)\n",
    "\n",
    "\n",
    "def seed_worker(worker_id):\n",
    "    \"\"\"Seed numpy and `random` inside a DataLoader worker from torch's per-worker seed.\"\"\"\n",
    "    worker_seed = torch.initial_seed() % 2**32\n",
    "    np.random.seed(worker_seed)\n",
    "    random.seed(worker_seed)\n",
    "\n",
    "\n",
    "def make_loader(dataset, shuffle):\n",
    "    \"\"\"Wrap `dataset` in a DataLoader using the configured batch size.\n",
    "\n",
    "    `worker_init_fn` must be a callable, not the (None) result of calling np.random.seed.\n",
    "    Only the shuffling loader gets the seeded generator `g`, so iterating the evaluation\n",
    "    loaders cannot draw from it and perturb the training order.\n",
    "    \"\"\"\n",
    "    return DataLoader(\n",
    "        dataset,\n",
    "        batch_size=hyperparams['batch_size'],\n",
    "        shuffle=shuffle,\n",
    "        num_workers=0,\n",
    "        worker_init_fn=seed_worker,\n",
    "        generator=g if shuffle else None,\n",
    "    )\n",
    "\n",
    "\n",
    "train_dataset = load_split(\"train\")\n",
    "valid_dataset = load_split(\"valid\")\n",
    "test_dataset = load_split(\"test\")\n",
    "\n",
    "train_loader = make_loader(train_dataset, shuffle=True)\n",
    "val_loader = make_loader(valid_dataset, shuffle=False)\n",
    "test_loader = make_loader(test_dataset, shuffle=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "5bbd7a97",
   "metadata": {
    "id": "5bbd7a97"
   },
   "outputs": [],
   "source": [
    "# Build the dense model, then convert its fc0 layer with ssonn's convert_dense_to_sparse_network.\n",
    "# NOTE(review): presumably returns a network whose listed layers are sparse and placed on `device`; confirm.\n",
    "model = SimpleFCN()\n",
    "sparse_model = convert_dense_to_sparse_network(\n",
    "    model,\n",
    "    device=device,\n",
    "    layers=[model.fc0],\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f0265bbb",
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "f0265bbb",
    "outputId": "f3d39aa7-e3e2-4b1a-d756-2efd18511df7"
   },
   "outputs": [],
   "source": [
    "# Authenticate with Weights & Biases before wandb.init() in the next cell.\n",
    "# NOTE(review): never paste an API key here; wandb.login() presumably uses WANDB_API_KEY\n",
    "# from the environment when it is set and otherwise prompts; confirm.\n",
    "import wandb\n",
    "\n",
    "wandb.login()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fb1f2251",
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 173
    },
    "id": "fb1f2251",
    "outputId": "ef9c1237-1ff8-46c0-f62f-6b5e06b895aa"
   },
   "outputs": [],
   "source": [
    "# Finish any W&B run still active (e.g. from a previous execution of this cell) before starting a new one.\n",
    "wandb.finish()\n",
    "# Descriptive run name built from the key hyperparameters so runs are distinguishable in W&B.\n",
    "run_name = (\n",
    "    f\"simplefcn-lr{hyperparams['learning_rate']}\"\n",
    "    f\"-wd{hyperparams['weight_decay']}-bs{hyperparams['batch_size']}\"\n",
    ")\n",
    "run = wandb.init(\n",
    "    project=\"self-expanding-nets-chesseract\",\n",
    "    name=run_name,\n",
    "    config=hyperparams\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "842732f3",
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "842732f3",
    "outputId": "982f667e-0fe8-4326-ff94-0a2c8d2db0fe"
   },
   "outputs": [],
   "source": [
    "import time\n",
    "\n",
    "# perf_counter is monotonic, so the measured duration is not skewed by system clock adjustments.\n",
    "start_time = time.perf_counter()\n",
    "\n",
    "criterion = torch.nn.CrossEntropyLoss()\n",
    "optimizer = torch.optim.Adam(sparse_model.parameters(), lr=hyperparams['learning_rate'], weight_decay=hyperparams['weight_decay'])\n",
    "# NOTE(review): the optimizer only sees the parameters that exist now; if expansion adds\n",
    "# parameters, train_sparse_recursive must register them with it; confirm.\n",
    "# NOTE(review): train_loader is passed twice; presumably one is the training data and the\n",
    "# other the data used to score edges; confirm against train_sparse_recursive's signature.\n",
    "train_sparse_recursive(sparse_model, train_loader, train_loader, val_loader, criterion, optimizer, hyperparams, device)\n",
    "\n",
    "print(\"--- %s seconds ---\" % (time.perf_counter() - start_time))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e62163bf-edb9-4614-b946-1be52b5756f3",
   "metadata": {
    "id": "oxc6xNbrRho8"
   },
   "outputs": [],
   "source": [
    "# Final evaluation on the held-out test split. The first return value gets a real name because\n",
    "# assigning to `_` stops IPython from updating its `_` output-history variable.\n",
    "# NOTE(review): presumably eval_one_epoch returns (loss, accuracy); confirm.\n",
    "test_loss, accuracy = eval_one_epoch(sparse_model, criterion, test_loader, hyperparams['task_type'], device)\n",
    "params = get_params_amount(sparse_model)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cb119093",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Test accuracy (presumably a fraction; the results table multiplies it by 100) and parameter count.\n",
    "(accuracy, params)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d4296418",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Save the whole trained module object (pickled), not just its state_dict.\n",
    "# NOTE(review): reloading therefore needs the ssonn classes importable and, on recent\n",
    "# PyTorch versions, torch.load(..., weights_only=False); only load files you trust.\n",
    "torch.save(obj=sparse_model, f='chesseract.pt')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0660e24f",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "\n",
    "data = {\n",
    "    'Model': ['Ours', 'ResNet-18', 'AlexNet', 'VGG16', 'ConvNext', 'MNASNet', 'DenseNet', 'ResNeXt', 'PC-DARTS', 'DrNAS', 'Bonsai-Net', 'DARTS', 'Bonsai', 'Random'],\n",
    "    # 'Random' is presumably chance level for the 10 output classes.\n",
    "    'Accuracy (%)': [accuracy * 100, 57.83, 57.45, 55.69, 52.74, 56.26, 59.60, 55.15, 57.20, 58.24, 60.76, 59.16, 68.83, 10],\n",
    "    # With None entries, pandas stores this numeric column as float64 (None becomes NaN).\n",
    "    'Parameters': [params, 11_689_512, 61_100_840, 138_357_544, 88_591_464, 4_383_312, 28_681_000, 25_028_904, None, None, None, None, None, None]\n",
    "}\n",
    "\n",
    "table = pd.DataFrame(data)\n",
    "\n",
    "def format_with_commas(x):\n",
    "    \"\"\"Format a parameter count with thousands separators, e.g. 11689512.0 to '11,689,512'.\n",
    "\n",
    "    Casts to int so float64 values do not render with a trailing '.0', and shows\n",
    "    missing counts (NaN) as 'n/a' instead of 'nan'.\n",
    "    \"\"\"\n",
    "    if pd.isna(x):\n",
    "        return \"n/a\"\n",
    "    return f\"{int(x):,}\"\n",
    "\n",
    "styled_table = (table.style\n",
    "               .format({'Accuracy (%)': '{:.2f}',\n",
    "                       'Parameters': format_with_commas})\n",
    "               .set_properties(**{'text-align': 'center'})\n",
    "               .set_table_styles([\n",
    "                   {'selector': 'th', 'props': [('text-align', 'center')]},\n",
    "                   {'selector': 'caption', 'props': [('font-size', '1.1em')]}\n",
    "               ])\n",
    "               .hide(axis='index'))\n",
    "styled_table"
   ]
  }
 ],
 "metadata": {
  "accelerator": "GPU",
  "colab": {
   "gpuType": "T4",
   "provenance": []
  },
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
