{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 51,
   "id": "1581a7c9-7507-4100-acac-ebcc040753ab",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.optim as optim\n",
    "import torch.nn.functional as F\n",
    "from torchvision import datasets, transforms\n",
    "import pandas as pd\n",
    "from torch.utils.data import DataLoader\n",
    "import numpy as np\n",
    "import math\n",
    "\n",
    "# Define the MLP model with 3 layers\n",
    "class MLP(nn.Module):\n",
    "    \"\"\"Three bias-free linear layers with explicit 1/sqrt(fan_in) scaling in forward().\n",
    "\n",
    "    Hidden weights are initialised N(0, 1) and the readout weights to exactly zero,\n",
    "    so the logits are 0 at initialisation. The logits are further divided by\n",
    "    gamma = gamma_0 * sqrt(width).\n",
    "\n",
    "    NOTE(review): no activation is applied between layers, so this is a deep\n",
    "    *linear* network. torch.nn.functional is imported as F but never used, which\n",
    "    may mean a nonlinearity (e.g. F.relu) was intended -- confirm.\n",
    "\n",
    "    Args:\n",
    "        width: number of units in each of the two hidden layers.\n",
    "        gamma_0: base output-scale hyperparameter (effective scale is\n",
    "            gamma_0 * sqrt(width)).\n",
    "        in_dim: flattened input size per sample; defaults to 28*28.\n",
    "        n_classes: number of output logits; defaults to 10.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, width, gamma_0, in_dim=28 * 28, n_classes=10):\n",
    "        super().__init__()\n",
    "        self.D = in_dim\n",
    "        self.width = width\n",
    "        self.gamma = gamma_0 * math.sqrt(width)\n",
    "        self.fc1 = nn.Linear(self.D, self.width, bias=False)\n",
    "        self.fc2 = nn.Linear(self.width, self.width, bias=False)\n",
    "        self.fc3 = nn.Linear(self.width, n_classes, bias=False)\n",
    "\n",
    "        # Unit-variance hidden weights; the 1/sqrt(fan_in) factor is applied in forward().\n",
    "        nn.init.normal_(self.fc1.weight, mean=0.0, std=1.0)\n",
    "        nn.init.normal_(self.fc2.weight, mean=0.0, std=1.0)\n",
    "        # Readout starts at exactly zero, so the logits are 0 at initialisation.\n",
    "        nn.init.zeros_(self.fc3.weight)\n",
    "\n",
    "    def forward(self, x):\n",
    "        \"\"\"Flatten each sample to D features and return (batch, n_classes) logits.\"\"\"\n",
    "        # reshape (not view) also accepts non-contiguous inputs.\n",
    "        x = x.reshape(-1, self.D)\n",
    "        x = self.fc1(x) / math.sqrt(self.D)\n",
    "        x = self.fc2(x) / math.sqrt(self.width)\n",
    "        x = self.fc3(x) / math.sqrt(self.width)\n",
    "        return x / self.gamma\n",
    "\n",
    "\n",
    "# Training loop\n",
    "def train(model, train_loader, criterion, optimizer, epoch):\n",
    "    \"\"\"Train `model` for one epoch and return the sample-weighted mean training loss.\n",
    "\n",
    "    Each batch is moved to the device of the model's first parameter, so the same\n",
    "    function serves CPU and CUDA models.\n",
    "\n",
    "    Args:\n",
    "        model: nn.Module to optimise; it is put in train mode.\n",
    "        train_loader: iterable of (data, target) batches.\n",
    "        criterion: loss called as criterion(output, target); assumed to return the\n",
    "            batch mean (the default reduction of nn.CrossEntropyLoss).\n",
    "        optimizer: optimiser over model's parameters.\n",
    "        epoch: current epoch number; accepted for API compatibility, unused.\n",
    "\n",
    "    Returns:\n",
    "        Mean per-sample loss over the epoch as a Python float.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if train_loader yields no samples.\n",
    "    \"\"\"\n",
    "    device = next(model.parameters()).device\n",
    "    model.train()\n",
    "    total_loss = 0.0\n",
    "    n_samples = 0\n",
    "    for data, target in train_loader:\n",
    "        data = data.to(device, non_blocking=True)\n",
    "        target = target.to(device, non_blocking=True)\n",
    "        optimizer.zero_grad()\n",
    "        loss = criterion(model(data), target)\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "        # Weight by batch size so a short final batch does not skew the mean.\n",
    "        batch_n = data.size(0)\n",
    "        total_loss += loss.item() * batch_n\n",
    "        n_samples += batch_n\n",
    "    if n_samples == 0:\n",
    "        raise ValueError('train_loader yielded no samples')\n",
    "    return total_loss / n_samples"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7bfb14a9-871e-4973-a963-266acd14c5ba",
   "metadata": {},
   "outputs": [],
   "source": [
    "width = 128\n",
    "eta_0 = 0.0001\n",
    "gamma_0 = 0.0001\n",
    "batch_size = 64\n",
    "# NOTE(review): eta scales with width only; confirm whether gamma_0 should also\n",
    "# enter the learning-rate scaling for the intended parameterisation.\n",
    "eta = eta_0 * width\n",
    "epochs = 5\n",
    "seed = 0\n",
    "\n",
    "# Seed torch's global RNG so weight init and data shuffling are reproducible.\n",
    "torch.manual_seed(seed)\n",
    "# Fall back to CPU so the notebook also runs on machines without a GPU.\n",
    "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
    "# Pinned host memory only helps host-to-GPU copies.\n",
    "pin_memory = device.type == 'cuda'\n",
    "\n",
    "# Define transformations for the dataset\n",
    "transform = transforms.Compose([\n",
    "    transforms.ToTensor(),\n",
    "    transforms.Normalize((0.1307,), (0.3081,))\n",
    "])\n",
    "\n",
    "# Download and load the MNIST dataset\n",
    "train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)\n",
    "test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)\n",
    "train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, pin_memory=pin_memory)\n",
    "test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, pin_memory=pin_memory)\n",
    "\n",
    "# Instantiate the model, loss function, and optimizer\n",
    "model = MLP(width, gamma_0).to(device)\n",
    "criterion = nn.CrossEntropyLoss()\n",
    "optimizer = optim.SGD(model.parameters(), lr=eta)\n",
    "\n",
    "records = []\n",
    "\n",
    "# Run the training and testing\n",
    "for epoch in range(1, epochs + 1):\n",
    "    model.train()\n",
    "    train_loss_sum, n_train = 0.0, 0\n",
    "    for data, target in train_loader:\n",
    "        data = data.to(device, non_blocking=True)\n",
    "        target = target.to(device, non_blocking=True)\n",
    "        optimizer.zero_grad()\n",
    "        output = model(data)\n",
    "        loss = criterion(output, target)\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "        # Weight by batch size so a short final batch does not skew the mean.\n",
    "        train_loss_sum += loss.item() * data.size(0)\n",
    "        n_train += data.size(0)\n",
    "\n",
    "    # Evaluate on the held-out test split after every epoch.\n",
    "    model.eval()\n",
    "    test_loss_sum, n_correct, n_test = 0.0, 0, 0\n",
    "    with torch.no_grad():\n",
    "        for data, target in test_loader:\n",
    "            data = data.to(device, non_blocking=True)\n",
    "            target = target.to(device, non_blocking=True)\n",
    "            output = model(data)\n",
    "            test_loss_sum += criterion(output, target).item() * data.size(0)\n",
    "            n_correct += (output.argmax(dim=1) == target).sum().item()\n",
    "            n_test += data.size(0)\n",
    "\n",
    "    records.append({\n",
    "        'width': width,\n",
    "        'gamma_0': gamma_0,\n",
    "        'eta_0': eta_0,\n",
    "        'epoch': epoch,\n",
    "        'train_loss': train_loss_sum / n_train,\n",
    "        'test_loss': test_loss_sum / n_test,\n",
    "        'test_acc': n_correct / n_test,\n",
    "    })\n",
    "\n",
    "results = pd.DataFrame(records)\n",
    "results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "59769503-bc4a-4568-95dd-0bbfcbed5063",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.14"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
