{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n"
     ]
    }
   ],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "%matplotlib inline\n",
    "# Select the visible GPU through environment variables; this is done before\n",
    "# torch is imported so the settings are in place when CUDA is first used.\n",
    "import os\n",
    "os.environ[\"CUDA_DEVICE_ORDER\"]=\"PCI_BUS_ID\"   # see issue #152\n",
    "os.environ[\"CUDA_VISIBLE_DEVICES\"]=\"1\"\n",
    "\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torchvision\n",
    "import torchvision.transforms as transforms\n",
    "\n",
    "# Check for GPU availability\n",
    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "\n",
    "# Settings shared by every dataset / loader below (previously repeated inline).\n",
    "NORM_MEAN = (0.4914, 0.4822, 0.4465)  # per-channel mean passed to Normalize\n",
    "NORM_STD = (0.2023, 0.1994, 0.2010)   # per-channel std passed to Normalize\n",
    "BATCH_SIZE = 500\n",
    "rootdir = os.path.expanduser('~/datasets')\n",
    "\n",
    "# 1. Data augmentations and normalization\n",
    "# Training pipeline: random flip + padded random crop, then tensor conversion and normalization.\n",
    "transform_train = transforms.Compose([\n",
    "    transforms.RandomHorizontalFlip(),\n",
    "    transforms.RandomCrop(32, padding=4),\n",
    "    transforms.ToTensor(),\n",
    "    transforms.Normalize(NORM_MEAN, NORM_STD),\n",
    "])\n",
    "\n",
    "# Evaluation pipeline: tensor conversion and normalization only (no augmentation).\n",
    "transform_test = transforms.Compose([\n",
    "    transforms.ToTensor(),\n",
    "    transforms.Normalize(NORM_MEAN, NORM_STD),\n",
    "])\n",
    "\n",
    "# Augmented, shuffled training split: used for the optimization steps.\n",
    "aug_trainset = torchvision.datasets.CIFAR10(root=rootdir, train=True, download=True, transform=transform_train)\n",
    "aug_trainloader = torch.utils.data.DataLoader(aug_trainset, batch_size=BATCH_SIZE, shuffle=True, num_workers=0)\n",
    "\n",
    "# Same training split with the evaluation transform and no shuffling: used to measure train accuracy.\n",
    "trainset = torchvision.datasets.CIFAR10(root=rootdir, train=True, download=True, transform=transform_test)\n",
    "trainloader = torch.utils.data.DataLoader(trainset, batch_size=BATCH_SIZE, shuffle=False, num_workers=0)\n",
    "\n",
    "# Held-out test split.\n",
    "testset = torchvision.datasets.CIFAR10(root=rootdir, train=False, download=True, transform=transform_test)\n",
    "testloader = torch.utils.data.DataLoader(testset, batch_size=BATCH_SIZE, shuffle=False, num_workers=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "from mlp import MLP, KronMLP, accuracy_func, loss_fn"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "numparams 18470602\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "8add31cc40ea4ac6a40caf6f4eee1923",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/21 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 0, train loss 1.6043516397476196, test acc 0.424\n",
      "Epoch 0, train acc 0.426\n",
      "Epoch 1, train loss 1.4843229055404663, test acc 0.461\n",
      "Epoch 2, train loss 1.4813092947006226, test acc 0.472\n",
      "Epoch 3, train loss 1.4532357454299927, test acc 0.493\n",
      "Epoch 4, train loss 1.4214383363723755, test acc 0.492\n",
      "Epoch 5, train loss 1.341141939163208, test acc 0.495\n",
      "Epoch 5, train acc 0.513\n",
      "Epoch 6, train loss 1.216399908065796, test acc 0.518\n",
      "Epoch 7, train loss 1.2673360109329224, test acc 0.516\n",
      "Epoch 8, train loss 1.1792267560958862, test acc 0.521\n",
      "Epoch 9, train loss 1.249699592590332, test acc 0.532\n"
     ]
    }
   ],
   "source": [
    "import jax\n",
    "import jax.random as jr\n",
    "import optax\n",
    "import jax.numpy as jnp\n",
    "import serket as sk\n",
    "from tqdm.auto import tqdm\n",
    "\n",
    "# Layer widths passed to MLP; the first entry (3*1024) matches the flattened\n",
    "# image vectors produced by `reshape(x.shape[0], -1)` in the loops below.\n",
    "#sizes = (3*1024,1024,1024,10)\n",
    "sizes = (3*1024,3024,3024,10)\n",
    "# NOTE: binding `nn` here shadows the `torch.nn` alias imported in the first cell.\n",
    "nn = MLP(sizes, jr.PRNGKey(42))\n",
    "#nn = KronMLP((3*1024,100000,100000,4*1024,10), jr.PRNGKey(42))\n",
    "# Mask the model tree; `eval_acc_mb` undoes this with `sk.tree_unmask` before calling the model.\n",
    "nn = sk.tree_mask(nn)\n",
    "# Count parameters leaf by leaf instead of concatenating every array into one\n",
    "# large temporary buffer just to read its length.\n",
    "print(\"numparams\", sum(p.size for p in jax.tree_util.tree_leaves(nn)))\n",
    "# 2) initialize the optimizer state\n",
    "optim = optax.adam(3e-4)\n",
    "optim_state = optim.init(nn)\n",
    "\n",
    "\n",
    "@jax.jit\n",
    "def train_step(nn, optim_state, x, y):\n",
    "    \"\"\"Run one optimizer update on the minibatch `(x, y)`.\n",
    "\n",
    "    `loss_fn(nn, x, y)` yields `(grads, (loss, logits))`. The gradients are\n",
    "    turned into updates by `optim.update` and applied to `nn` with\n",
    "    `optax.apply_updates`.\n",
    "\n",
    "    Returns:\n",
    "        The updated model, the updated optimizer state, and `(loss, logits)`.\n",
    "    \"\"\"\n",
    "    grads, (loss, logits) = loss_fn(nn, x, y)\n",
    "    param_updates, next_optim_state = optim.update(grads, optim_state)\n",
    "    updated_nn = optax.apply_updates(nn, param_updates)\n",
    "    return updated_nn, next_optim_state, (loss, logits)\n",
    "\n",
    "@jax.jit\n",
    "def eval_acc_mb(nn, x, y):\n",
    "    \"\"\"Return `accuracy_func(logits, y)` for a single minibatch.\n",
    "\n",
    "    The model is unmasked with `sk.tree_unmask` and mapped over the leading\n",
    "    axis of `x` with `jax.vmap` to obtain the logits.\n",
    "    \"\"\"\n",
    "    predict_batch = jax.vmap(sk.tree_unmask(nn))\n",
    "    return accuracy_func(predict_batch(x), y)\n",
    "\n",
    "def eval_acc(nn, dataloader):\n",
    "    \"\"\"Average `eval_acc_mb` over every batch yielded by `dataloader`.\n",
    "\n",
    "    Each batch is converted from torch tensors to numpy arrays, and the\n",
    "    inputs are flattened to one row per sample before evaluation.\n",
    "\n",
    "    NOTE: the result is an unweighted mean of the per-batch values, so\n",
    "    batches of different sizes contribute equally.\n",
    "    \"\"\"\n",
    "    def as_numpy(x, y):\n",
    "        # torch tensors -> numpy; inputs reshaped to (batch, -1)\n",
    "        return x.cpu().data.numpy().reshape(x.shape[0], -1), y.cpu().data.numpy()\n",
    "\n",
    "    batch_accs = [eval_acc_mb(nn, *as_numpy(x, y)) for x, y in dataloader]\n",
    "    return jnp.mean(jnp.array(batch_accs))\n",
    "\n",
    "num_epochs = 20\n",
    "# NOTE: range(num_epochs + 1) makes num_epochs + 1 passes over the data (epochs 0..num_epochs).\n",
    "for e in tqdm(range(num_epochs + 1)):\n",
    "    for x, y in aug_trainloader:\n",
    "        # torch tensors -> numpy; images flattened to one vector per sample\n",
    "        x, y = x.cpu().data.numpy().reshape(x.shape[0], -1), y.cpu().data.numpy()\n",
    "        nn, optim_state, (loss, logits) = train_step(nn, optim_state, x, y)\n",
    "    # Reported every epoch; `loss` is the value from the last minibatch of the epoch.\n",
    "    print(f\"Epoch {e}, train loss {loss}, test acc {eval_acc(nn, testloader):.3f}\")\n",
    "    # Every 5th epoch, also report accuracy on the non-augmented training loader.\n",
    "    if e % 5 == 0:\n",
    "        print(f\"Epoch {e}, train acc {eval_acc(nn, trainloader):.3f}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "17930"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(32, 56)"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Shape of the first leaf in the flattened model pytree.\n",
    "jax.tree_util.tree_flatten(nn)[0][0].shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Experiment notes (numbers are presumably test accuracy in % - confirm):\n",
    "# kron 41 vs mlp 49 for 10 epochs with the same size and hyperparameters (3*1024,1024,1024,10)\n",
    "# kron 46 with (3*1024,40000,40000,1024,10)\n",
    "# kron 52 with (3*1024,40000,40000,1024,10) and k=5\n",
    "# kron 55 with (3*1024,100000,100000,4*1024,10) and k=5"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 47,
   "metadata": {},
   "outputs": [],
   "source": [
    "import jax.nn.initializers as ji\n",
    "# Draw a (100, 1) sample from the He-uniform initializer with a fixed PRNG key.\n",
    "he_uniform_init = ji.he_uniform()\n",
    "b = he_uniform_init(jr.PRNGKey(42), (100, 1))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 49,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(Array(0.23839533, dtype=float32), Array(-0.23060192, dtype=float32))"
      ]
     },
     "execution_count": 49,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Largest and smallest value of the sampled initializer array.\n",
    "b.max(), b.min()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "mlp",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
