{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "b8Fqts3mkuZX"
      },
      "source": [
        "## Imports"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "OrDC_2t_kuZZ"
      },
      "outputs": [],
      "source": [
        "import torch\n",
        "from torch import nn\n",
        "from torch.utils.data import DataLoader, random_split, TensorDataset\n",
        "from torchvision import datasets, transforms\n",
        "import os\n",
        "import random\n",
        "import numpy as np\n",
        "\n",
        "SEED = 0\n",
        "torch.manual_seed(SEED)\n",
        "os.environ['PYTHONHASHSEED'] = str(SEED)\n",
        "random.seed(SEED)\n",
        "np.random.seed(SEED)\n",
        "g = torch.Generator()\n",
        "g.manual_seed(0)\n",
        "\n",
        "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
        "if torch.cuda.is_available():\n",
        "    torch.cuda.manual_seed(SEED)\n",
        "    torch.cuda.manual_seed_all(SEED)\n",
        "    torch.backends.cudnn.deterministic = True\n",
        "    torch.backends.cudnn.benchmark = False\n",
        "    torch.use_deterministic_algorithms(True)\n",
        "    os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8'  # or ':16:8'\n",
        "\n",
        "\n",
        "################################\n",
        "#     RESTART     RUNTIME      #\n",
        "################################\n",
        "from ssonn.model.utils import *\n",
        "from ssonn.metrics.nonlinearity_metrics import *\n",
        "from ssonn.metrics.edge_finder import *\n",
        "from ssonn.metrics.train_metrics import *\n",
        "from ssonn.train.train import *"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "o4573XdtkuZa"
      },
      "outputs": [],
      "source": [
        "SEED = 8642\n",
        "torch.manual_seed(SEED)\n",
        "torch.cuda.manual_seed(SEED)\n",
        "import random\n",
        "random.seed(SEED)\n",
        "import numpy as np\n",
        "np.random.seed(SEED)\n",
        "\n",
        "torch.backends.cudnn.deterministic = True\n",
        "torch.backends.cudnn.benchmark = False\n",
        "\n",
        "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
        "device"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Zay3-k1OkuZa"
      },
      "source": [
        "## Data"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "vpe4zqz7kuZa"
      },
      "outputs": [],
      "source": [
        "transform = transforms.Compose([\n",
        "    transforms.ToTensor(),\n",
        "    transforms.Lambda(lambda x: x.view(-1))\n",
        "])\n",
        "\n",
        "dataset = datasets.FashionMNIST(root='./data', train=True,\n",
        "                                download=True, transform=transform)\n",
        "\n",
        "test_dataset = datasets.FashionMNIST(root='./data', train=False,\n",
        "                                     download=True, transform=transform)\n",
        "\n",
        "train_dataset, val_dataset = random_split(dataset, [0.8, 0.2])\n",
        "\n",
        "batch_size = 64\n",
        "\n",
        "train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)\n",
        "val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)\n",
        "test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "T4lmQ-WYkuZb"
      },
      "source": [
        "## Model"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "KR5c6TwwkuZb"
      },
      "outputs": [],
      "source": [
        "class SimpleFCN(nn.Module):\n",
        "    def __init__(self, input_size=28 * 28, hidden_size=16, output_size=10):\n",
        "        super(SimpleFCN, self).__init__()\n",
        "        self.fc0 = nn.Linear(input_size, output_size)\n",
        "\n",
        "    def forward(self, x):\n",
        "        x = self.fc0(x)\n",
        "        return x"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "7eKijkvYkuZb"
      },
      "outputs": [],
      "source": [
        "model = SimpleFCN()\n",
        "sparse_model = convert_dense_to_sparse_network(model, layers=[model.fc0], device=device)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "jQKmEa4KkuZb"
      },
      "source": [
        "## Train"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "fx_IuRlDWf89"
      },
      "outputs": [],
      "source": [
        "hyperparams = {\n",
        "    \"num_epochs\": 40,\n",
        "    \"batch_size\": 256,\n",
        "    \"edge_importance_metric\": AbsGradientEdgeMetric(nn.CrossEntropyLoss()),\n",
        "    \"edge_score_aggregation\": \"mean\",\n",
        "    \"expansion_thresholds\": {\"fc0\": 0.7},\n",
        "    \"pruning_thresholds\": {\"fc0\": 0.1},\n",
        "    \"plateau_threshold\": 0.05,\n",
        "    \"min_epochs_between_expansions\": 8,\n",
        "    \"plateau_window_size\": 5,\n",
        "    \"learning_rate\": 2e-4,\n",
        "    \"prune_after_epochs\": 2,\n",
        "    \"task_type\": \"classification\",\n",
        "    \"start_fully_connected\": False,\n",
        "    \"max_new_edges_per_expansion\": None,\n",
        "    \"weight_decay\": 0,\n",
        "}"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "q0I5aUlikuZb"
      },
      "outputs": [],
      "source": [
        "name = \", \".join(\n",
        "    f\"{key}: {value.__class__.__name__ if key == 'metric' else value}\"\n",
        "    for key, value in hyperparams.items()\n",
        ")\n",
        "\n",
        "name"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "7AtKfxKHkuZc"
      },
      "outputs": [],
      "source": [
        "import wandb\n",
        "\n",
        "wandb.login()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "W1cS0CTWkuZc"
      },
      "outputs": [],
      "source": [
        "wandb.finish()\n",
        "run = wandb.init(\n",
        "    project=\"self-expanding-nets\",\n",
        "    name=f\"name\",\n",
        ")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "SdIxhi842Rlo"
      },
      "outputs": [],
      "source": [
        "criterion = nn.CrossEntropyLoss()\n",
        "optimizer = optim.AdamW(sparse_model.parameters(), lr=hyperparams['learning_rate'], weight_decay=hyperparams['weight_decay'])"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "-twcrw7iG9aT"
      },
      "outputs": [],
      "source": [
        "train_sparse_recursive(sparse_model, train_loader, train_loader, val_loader, criterion, optimizer, hyperparams, device)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "torch.save(sparse_model, 'fashionmnist.pt')"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "ArfS3uplHTPk"
      },
      "outputs": [],
      "source": [
        "_, accuracy = eval_one_epoch(sparse_model, criterion, test_loader, hyperparams['task_type'], device)\n",
        "params = get_params_amount(sparse_model)\n",
        "accuracy, params"
      ]
    }
  ],
  "metadata": {
    "accelerator": "GPU",
    "colab": {
      "gpuType": "T4",
      "provenance": []
    },
    "kernelspec": {
      "display_name": ".venv",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.12.2"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
