{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from torch import nn\n",
    "from torch.utils.data import DataLoader, random_split, TensorDataset\n",
    "from torchvision import datasets, transforms\n",
    "import os\n",
    "import random\n",
    "import numpy as np\n",
    "\n",
    "# Seed every RNG the notebook touches so that Restart & Run All is repeatable.\n",
    "SEED = 0\n",
    "# NOTE: PYTHONHASHSEED is read at interpreter start-up, so setting it here only\n",
    "# affects child processes launched after this point.\n",
    "os.environ['PYTHONHASHSEED'] = str(SEED)\n",
    "random.seed(SEED)\n",
    "np.random.seed(SEED)\n",
    "torch.manual_seed(SEED)\n",
    "\n",
    "# Dedicated generator for APIs that accept an explicit one; seeded from SEED\n",
    "# (not a literal) so changing SEED changes every random stream together.\n",
    "g = torch.Generator()\n",
    "g.manual_seed(SEED)\n",
    "\n",
    "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
    "if torch.cuda.is_available():\n",
    "    # cuBLAS needs this workspace setting for deterministic results; export it\n",
    "    # before deterministic mode is enabled and before any CUDA kernel runs.\n",
    "    os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8'  # or ':16:8'\n",
    "    torch.cuda.manual_seed(SEED)\n",
    "    torch.cuda.manual_seed_all(SEED)\n",
    "    torch.backends.cudnn.deterministic = True\n",
    "    torch.backends.cudnn.benchmark = False\n",
    "    torch.use_deterministic_algorithms(True)\n",
    "\n",
    "\n",
    "################################\n",
    "#     RESTART     RUNTIME      #\n",
    "################################\n",
    "from ssonn.model.utils import *\n",
    "from ssonn.metrics.nonlinearity_metrics import *\n",
    "from ssonn.metrics.edge_finder import *\n",
    "from ssonn.metrics.train_metrics import *\n",
    "from ssonn.train.train import *"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "def generate_teacher_data(teacher, num_samples=1000):\n",
    "    \"\"\"Sample standard-normal inputs and label them with the teacher network.\n",
    "\n",
    "    Args:\n",
    "        teacher: module exposing an ``fc1`` linear layer (its ``in_features``\n",
    "            sets the input width) and callable on a batch of inputs.\n",
    "        num_samples: number of input rows to draw.\n",
    "\n",
    "    Returns:\n",
    "        Tuple ``(X, y)`` where ``X`` has shape ``(num_samples, in_features)``\n",
    "        and ``y = teacher(X)`` is computed without gradient tracking.\n",
    "    \"\"\"\n",
    "    first_weight = teacher.fc1.weight\n",
    "    # Draw the inputs with the teacher's own dtype/device so the forward pass\n",
    "    # also works when the teacher lives on a GPU or uses a non-default dtype.\n",
    "    X = torch.randn(\n",
    "        num_samples,\n",
    "        teacher.fc1.in_features,\n",
    "        dtype=first_weight.dtype,\n",
    "        device=first_weight.device,\n",
    "    )\n",
    "    with torch.no_grad():\n",
    "        y = teacher(X)\n",
    "    return X, y"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Teacher(nn.Module):\n",
    "    \"\"\"Two-layer ReLU network whose weights and biases are all drawn from U(-1, 1).\"\"\"\n",
    "\n",
    "    def __init__(self, input_size, hidden_size, output_size):\n",
    "        super().__init__()\n",
    "        self.fc1 = nn.Linear(input_size, hidden_size)\n",
    "        self.act = nn.ReLU()\n",
    "        self.fc2 = nn.Linear(hidden_size, output_size)\n",
    "        self.init_weights()\n",
    "\n",
    "    def init_weights(self):\n",
    "        \"\"\"Overwrite every parameter in place with uniform samples between -1 and 1.\"\"\"\n",
    "        for tensor in self.parameters():\n",
    "            nn.init.uniform_(tensor, -1, 1)\n",
    "\n",
    "    def forward(self, x):\n",
    "        \"\"\"Return ``fc2(relu(fc1(x)))``.\"\"\"\n",
    "        hidden = self.act(self.fc1(x))\n",
    "        return self.fc2(hidden)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Problem dimensions: input width, hidden width, and output width.\n",
    "# `hidden_teacher` is used for the teacher and is also reused as the\n",
    "# student's hidden size when the student MLP is built.\n",
    "input_size = 100\n",
    "hidden_teacher = 50\n",
    "output_size = 10"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build the randomly initialised teacher and label a batch of Gaussian inputs\n",
    "# with it (generate_teacher_data draws 1000 samples by default).\n",
    "teacher = Teacher(input_size, hidden_teacher, output_size)\n",
    "X, y = generate_teacher_data(teacher)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sanity check: display the shapes of the generated inputs and targets.\n",
    "X.shape, y.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "dataset = TensorDataset(X, y)\n",
    "\n",
    "# Shuffle with the dedicated seeded generator `g` so the batch order does not\n",
    "# depend on how much of the global RNG stream earlier cells have consumed.\n",
    "train_loader = DataLoader(dataset, batch_size=128, shuffle=True, generator=g)\n",
    "# NOTE(review): validation iterates over the same samples as training, so the\n",
    "# validation loss measures fit to the training set rather than generalisation.\n",
    "val_loader = DataLoader(dataset, batch_size=128, shuffle=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "class MLP(nn.Module):\n",
    "    \"\"\"Single-hidden-layer perceptron: Linear -> ReLU -> Linear.\"\"\"\n",
    "\n",
    "    def __init__(self, input_size=28 * 28, hidden_size=16, output_size=10):\n",
    "        super().__init__()\n",
    "        self.fc0 = nn.Linear(input_size, hidden_size)\n",
    "        self.fc1 = nn.Linear(hidden_size, output_size)\n",
    "        self.act = nn.ReLU()\n",
    "\n",
    "    def forward(self, x):\n",
    "        \"\"\"Pass a batch ``x`` through ``fc0``, the ReLU, then ``fc1``.\"\"\"\n",
    "        hidden = self.fc0(x)\n",
    "        hidden = self.act(hidden)\n",
    "        return self.fc1(hidden)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "# The student is a dense MLP with the same input, hidden and output widths as the teacher.\n",
    "model = MLP(input_size=input_size, hidden_size=hidden_teacher, output_size=output_size)\n",
    "# Hand both linear layers to the project's converter (brought in by the ssonn star imports).\n",
    "# NOTE(review): presumably returns a sparse counterpart of `model` placed on `device` - confirm.\n",
    "sparse_model = convert_dense_to_sparse_network(model, layers=[model.fc0, model.fc1], device=device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Settings consumed by the training routine and logged as the wandb run config.\n",
    "hyperparams = {\n",
    "    \"num_epochs\": 512,\n",
    "    # Score edges with the same MSE objective the student is trained on. The task is\n",
    "    # regression with real-valued targets, so a cross-entropy loss (which treats\n",
    "    # float targets as class probabilities) does not match this setup.\n",
    "    \"edge_importance_metric\": AbsGradientEdgeMetric(nn.MSELoss()),\n",
    "    \"edge_score_aggregation\": \"mean\",\n",
    "    \"expansion_thresholds\": {\"fc0\": 0.55},\n",
    "    \"pruning_thresholds\": {\"fc0\": 0.2},\n",
    "    \"plateau_threshold\": 1,\n",
    "    \"min_epochs_between_expansions\": 64,\n",
    "    \"plateau_window_size\": 5,\n",
    "    \"learning_rate\": 0.01,\n",
    "    \"prune_after_epochs\": 2,\n",
    "    \"task_type\": \"regression\",\n",
    "    \"start_fully_connected\": False,\n",
    "    \"max_new_edges_per_expansion\": 900,\n",
    "    \"weight_decay\": 0\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import wandb\n",
    "\n",
    "wandb.login()\n",
    "# Finish any run that is still open (e.g. from re-running this cell) before starting a new one.\n",
    "wandb.finish()\n",
    "run = wandb.init(\n",
    "    project=\"self-expanding-nets-teacher-student\",\n",
    "    name=\"name\",\n",
    "    config=hyperparams,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import time\n",
    "\n",
    "# perf_counter is a monotonic, high-resolution clock, so the measured duration\n",
    "# cannot be skewed by wall-clock adjustments during a long training run.\n",
    "start_time = time.perf_counter()\n",
    "\n",
    "criterion = nn.MSELoss()\n",
    "# Reach the optimizer through `torch.optim` explicitly: a bare `optim` is never\n",
    "# imported in this notebook and only resolves if a star import happens to leak it.\n",
    "optimizer = torch.optim.Adam(\n",
    "    sparse_model.parameters(),\n",
    "    lr=hyperparams['learning_rate'],\n",
    "    weight_decay=hyperparams['weight_decay'],\n",
    ")\n",
    "# NOTE(review): `val_loader` is passed twice - presumably as both the validation and\n",
    "# the test loader; confirm against train_sparse_recursive's signature.\n",
    "train_sparse_recursive(sparse_model, train_loader, val_loader, val_loader, criterion, optimizer, hyperparams, device)\n",
    "\n",
    "print(\"--- %s seconds ---\" % (time.perf_counter() - start_time))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "ssonn",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
