{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "\n",
    "## 2-Layer MLP on MNIST \n",
    "### Settings: network width $C=512$, batch_size=128, Adam optimizer, learning_rate=1e-3, weight_decay=1e-2, 10 epochs per run, 10 runs on a TPUv4-8.\n",
    "|Activation    | Test Accuracy    |\n",
    "|--------------|----------------|\n",
    "|ReLU          | 0.9399 ± 0.0019|\n",
    "|CoLU        | **0.9489** ± 0.0025|\n",
    "\n",
    "## ResNet-56 on CIFAR10\n",
    "### Settings: batch_size=128, Adam optimizer, learning_rate=1e-3, weight_decay=1e-2, 180 epochs per run, 10 runs on a TPUv4-8.\n",
    "|Activation    | Test Accuracy    |Train Loss    |\n",
    "|--------------|----------------|----------------|\n",
    "|ReLU          | 0.9065 ± 0.0100|0.005132 ± 0.001461|\n",
    "|CoLU  | **0.9101** ± 0.0039| 0.003244 ± 0.000185|\n",
    "\n",
    "## GPT2 \n",
    "### Settings: Shakespeare's Plays dataset. Transformer parameters: block_size = 64, embed_size = 256, num_heads = 8, head_size = 32, num_layers = 6. Batch_size=512, Adam optimizer, learning_rate=1e-4, weight_decay=1e-2, 20k steps on a TPUv4-8. Jax random seed=42.\n",
    "|Activation    | Test Loss      | Train Loss     |\n",
    "|--------------|----------------|----------------|\n",
    "|ReLU          |     1.482      | 1.256          |\n",
    "|CoLU          | **1.481**      | 1.263          |\n",
    "\n",
    "Generated samples are attached in the PDF.\n",
    "\n",
    "## Diffusion Model\n",
    "### Settings: Unconditional generation on CIFAR10 without using VAE latent. The UNet structure follows the Latent Diffusion Model with cross-attention replaced by self-attention for unconditional generation, with channel sizes of the blocks lowered to (64,128,256,512), and 1 ResNet layer per downscaling block. Batch_size=128, Adam optimizer, learning_rate=1e-4, weight_decay=1e-2, 50k steps on a TPUv4-8. Jax random seed=42.\n",
    "|Activation    | Train Loss     |\n",
    "|--------------|----------------|\n",
    "|ReLU              | 0.1606          |\n",
    "|CoLU    | **0.1593**     |\n",
    "\n",
    "CoLU's cone dimension is $S=4$, with neither soft projection nor axis sharing for fair comparisons with standard ReLU.\n"
   ]
  },
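  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a reading aid, the hard cone projection can be written out explicitly. This is a sketch transcribed from the `colu` code in this notebook (with `share_axis=False`), not a formula quoted from elsewhere; the symbols $y_g$, $x_g$ and $m_g$ are labels introduced here for illustration.\n",
    "\n",
    "$$m_g = \\operatorname{clip}\\left(\\frac{y_g}{\\lVert x_g \\rVert_2 + \\epsilon}, 0, 1\\right), \\qquad \\mathrm{CoLU}(y_g, x_g) = (y_g, m_g x_g), \\qquad g = 1, \\dots, G$$\n",
    "\n",
    "The $C = GS$ input channels are split into $G$ axis scalars $y_g$ (the first $G$ channels) and $G$ cone sections $x_g \\in \\mathbb{R}^{S-1}$ (the remaining channels), so $C=512$ with $S=4$ gives $G=128$ cones. In the same code, the soft variant replaces the clip with $\\sigma\\left(y_g / (\\lVert x_g \\rVert_2 + \\epsilon) - 1/2\\right)$, and axis sharing uses a single axis $y$ for all groups, so that $C = 1 + G(S-1)$."
   ]
  },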
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Activation Functions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import functools\n",
    "import jax\n",
    "import jax.numpy as jnp\n",
    "import flax.linen as nn\n",
    "from typing import Optional\n",
    "\n",
    "@functools.partial(jax.jit, static_argnames=['channel_axis','variant','eps','num_groups','dim','share_axis'])\n",
    "def colu(input: jnp.ndarray,\n",
    "         channel_axis: int = -1,\n",
    "         variant: str = \"hard\",\n",
    "         eps: float = 1e-7,\n",
    "         num_groups: Optional[int] = None,\n",
    "         dim: Optional[int] = 4,\n",
    "         share_axis: bool = False\n",
    "         ):\n",
    "    \"\"\"project the input x onto the axes dimension\"\"\"\n",
    "    \"\"\"G=number of cones, S=dim of cones\"\"\"\n",
    "    \"\"\"output dimension = S = axes + cone sections = [len=(G or 1)] + G * [len=(S-1)]\"\"\"\n",
    "    \"\"\"jnp.moveaxis is avoided to optimize speed on TPU\"\"\"\n",
    "    shape = input.shape\n",
    "    if len(shape) == 0:\n",
    "        return input # edge case\n",
    "    assert (dim is not None) ^ (num_groups is not None) # specify one of both, infer the other\n",
    "\n",
    "    if share_axis:\n",
    "        if dim is None:\n",
    "            assert (shape[channel_axis] - 1) % num_groups == 0\n",
    "            dim = (shape[channel_axis] - 1) // num_groups + 1\n",
    "        if num_groups is None:\n",
    "            assert (shape[channel_axis] - 1) % (dim - 1) == 0\n",
    "            num_groups = (shape[channel_axis] - 1) // (dim - 1)\n",
    "    else:\n",
    "        if dim is None:\n",
    "            assert shape[channel_axis] % num_groups == 0\n",
    "            dim = shape[channel_axis] // num_groups\n",
    "        if num_groups is None:\n",
    "            assert shape[channel_axis] % dim == 0\n",
    "            num_groups = shape[channel_axis] // dim\n",
    "\n",
    "    if dim == 2: # pointwise case\n",
    "        return nn.silu(input) if variant == \"soft\" else nn.relu(input)\n",
    "\n",
    "    # y = axes, x = cone sections\n",
    "    if share_axis:\n",
    "        y, x = jnp.split(input, [1], axis=channel_axis)\n",
    "    else:\n",
    "        y, x = jnp.split(input, [num_groups], axis=channel_axis)\n",
    "\n",
    "    assert channel_axis < 0, \"channel_axis must be negative\" # Comply with broadcasting on first dimensions\n",
    "    x_old_shape = x.shape\n",
    "    y_old_shape = y.shape\n",
    "    x_shape = x.shape[:channel_axis] + (num_groups, dim - 1) # NG(S-1)\n",
    "    if share_axis:\n",
    "        y_shape = y.shape[:channel_axis] + (1, 1) # N11\n",
    "    else:\n",
    "        y_shape = y.shape[:channel_axis] + (num_groups, 1) # NG1\n",
    "    if channel_axis < -1:\n",
    "        x_shape += x.shape[(channel_axis+1):] # NGSHW if channel_axis = -3\n",
    "        y_shape += y.shape[(channel_axis+1):] # NG1HW\n",
    "    x = x.reshape(x_shape)\n",
    "    y = y.reshape(y_shape)\n",
    "\n",
    "    xn = jnp.linalg.norm(x,axis=channel_axis,keepdims=True) # NG1HW\n",
    "\n",
    "    mask = y / (xn + eps) # NG1HW\n",
    "    if variant == \"sqrt\":\n",
    "        mask = jnp.sqrt(mask)\n",
    "    elif variant == \"log\":\n",
    "        mask = jnp.log(jnp.max(mask,0)+1)\n",
    "    elif variant == \"soft\":\n",
    "        mask = nn.sigmoid(mask - .5)\n",
    "    elif variant == \"hard\":\n",
    "        mask = mask.clip(0,1)\n",
    "    else:\n",
    "        raise NotImplementedError(\"variant must be soft or hard.\")\n",
    "\n",
    "    x = mask * x # NGSHW\n",
    "    x = x.reshape(x_old_shape)\n",
    "    y = y.reshape(y_old_shape)\n",
    "    output = jnp.concatenate([y,x],axis=channel_axis)\n",
    "\n",
    "    return output\n",
    "\n",
    "@functools.partial(jax.jit, static_argnames=['scaling','eps'])\n",
    "def rcolu_(x, scaling=\"constant\",eps=1e-8):\n",
    "    \"\"\"x = w + v, v || e\"\"\"\n",
    "    C = x.shape[-1]\n",
    "    # e = jnp.ones(C) / jnp.sqrt(C)\n",
    "    vn = jnp.sum(x,axis=-1,keepdims=True) / jnp.sqrt(C) # dot(x, e)\n",
    "    v = jnp.repeat(vn,C,axis=-1) / jnp.sqrt(C) # outer(v, e)\n",
    "    w = x - v\n",
    "    wn = jnp.linalg.norm(w, x=-1, keepdims=True)\n",
    "    if scaling == 'constant':\n",
    "        m = jnp.maximum(vn, 0.) / (wn + eps)\n",
    "        m = jnp.minimum(m, 1.)\n",
    "    else:\n",
    "        m = nn.sigmoid(vn - .5)\n",
    "    w_ = w * m # project onto cone\n",
    "    x = v + w_\n",
    "\n",
    "    return x\n",
    "\n",
    "@functools.partial(jax.jit, static_argnames=['dim','num_groups','axis','scaling','eps'])\n",
    "def rcolu(x,\n",
    "          dim=4,\n",
    "          num_groups=None,\n",
    "          scaling='constant',\n",
    "          axis=-1,\n",
    "          eps=1e-7\n",
    "          ):\n",
    "    \"\"\"dim=S, num_groups=S\"\"\"\n",
    "    if len(x.shape) == 0:\n",
    "        return x\n",
    "    assert (dim is not None) ^ (num_groups is not None) # specify one of both\n",
    "    shape = x.shape\n",
    "    if dim is None:\n",
    "        assert shape[-1] % num_groups == 0\n",
    "        dim = shape[-1] // num_groups\n",
    "    if num_groups is None:\n",
    "        assert shape[-1] % dim == 0\n",
    "        num_groups = shape[-1] // dim\n",
    "    if axis != -1:\n",
    "        x = jnp.moveaxis(x, axis, -1)\n",
    "    new_shape = x.shape[:-1] + (num_groups, dim)\n",
    "    x = x.reshape(new_shape)\n",
    "    x = rcolu_(x,scaling,eps)\n",
    "    x = x.reshape(shape)\n",
    "    if axis != -1:\n",
    "        x = jnp.moveaxis(x, -1, axis)\n",
    "    return x\n",
    "\n",
    "# # some test\n",
    "# x = jnp.zeros(6).at[0].set(1)\n",
    "# y = rcolu(x,dim=3)\n",
    "# y.shape\n",
    "# # some assertion\n",
    "# y1 = jnp.sum(y,axis=-1,keepdims=True) / jnp.sqrt(3) # dot(y, e)\n",
    "# jnp.linalg.norm(y) / y1"
   ]
  },
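  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The `rcolu_` docstring above abbreviates the decomposition as `x = w + v, v || e`. Spelled out for one group $x \\in \\mathbb{R}^{S}$, as a sketch transcribed from that code (the inner-product notation and the name $m$ are introduced here for illustration):\n",
    "\n",
    "$$e = \\frac{1}{\\sqrt{S}} \\mathbf{1}, \\qquad v = \\langle x, e \\rangle e, \\qquad w = x - v, \\qquad m = \\min\\left(\\frac{\\max(\\langle x, e \\rangle, 0)}{\\lVert w \\rVert_2 + \\epsilon}, 1\\right), \\qquad \\mathrm{RCoLU}(x) = v + m w$$\n",
    "\n",
    "This is the `scaling='constant'` branch. For any other value of `scaling`, the code uses $m = \\sigma(\\langle x, e \\rangle - 1/2)$ instead."
   ]
  },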
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# MNIST"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import jax\n",
    "import jax.numpy as jnp\n",
    "import numpy as np\n",
    "import flax\n",
    "import flax.linen as nn\n",
    "import optax\n",
    "import tensorflow_datasets as tfds\n",
    "from flax.training.train_state import TrainState\n",
    "from flax.training.common_utils import shard\n",
    "import os\n",
    "os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'\n",
    "# import sys\n",
    "# sys.path.append('.')\n",
    "from tqdm.auto import tqdm\n",
    "import time\n",
    "import functools\n",
    "import wandb\n",
    "from jax_smi import initialise_tracking\n",
    "initialise_tracking()\n",
    "\n",
    "def cross_entropy_loss(logits, labels):\n",
    "    one_hot_labels = jax.nn.one_hot(labels, num_classes=10)\n",
    "    return optax.softmax_cross_entropy(logits, one_hot_labels).mean()\n",
    "\n",
    "def accuracy(logits, labels):\n",
    "    predictions = jnp.argmax(logits, axis=-1)\n",
    "    return jnp.mean(predictions == labels)\n",
    "\n",
    "@jax.pmap\n",
    "def train_step(state, batch):\n",
    "    def loss_fn(params):\n",
    "        logits = state.apply_fn({'params': params}, batch['image'])\n",
    "        loss = cross_entropy_loss(logits, batch['label'])\n",
    "        return loss, logits\n",
    "\n",
    "    grad_fn = jax.value_and_grad(loss_fn, has_aux=True)\n",
    "    (loss, logits), grads = grad_fn(state.params)\n",
    "    state = state.apply_gradients(grads=grads)\n",
    "    acc = accuracy(logits, batch['label'])\n",
    "    return state, loss, acc\n",
    "\n",
    "@jax.pmap\n",
    "def eval_step(state, batch):\n",
    "    logits = state.apply_fn({'params': state.params}, batch['image'])\n",
    "    loss = cross_entropy_loss(logits, batch['label'])\n",
    "    acc = accuracy(logits, batch['label'])\n",
    "    return loss, acc\n",
    "\n",
    "def prepare_data(data_name):\n",
    "    ds_builder = tfds.builder(data_name)\n",
    "    ds_builder.download_and_prepare()\n",
    "    train_ds = tfds.as_numpy(tfds.load(data_name, split='train', batch_size=128, shuffle_files=True))\n",
    "    test_ds = tfds.as_numpy(tfds.load(data_name, split='test', batch_size=128))\n",
    "    return train_ds, test_ds\n",
    "\n",
    "def create_train_state(model, rng, learning_rate, weight_decay=.01):\n",
    "    params = model.init(rng, jnp.ones([1, 28, 28, 1]))['params']\n",
    "\n",
    "    def weight_decay_mask(params):\n",
    "        return {k: 'bias' not in k and 'scale' not in k for k in params.keys()}\n",
    "\n",
    "    tx = optax.adamw(learning_rate=learning_rate, weight_decay=weight_decay, mask=weight_decay_mask(params))\n",
    "    return TrainState.create(apply_fn=model.apply, params=params, tx=tx)\n",
    "\n",
    "def train_and_evaluate(model, num_epochs, learning_rate):\n",
    "    rng = jax.random.PRNGKey(int(time.time()))\n",
    "    # rng = jax.random.PRNGKey(0)\n",
    "    rngs = {'params': rng}\n",
    "    train_ds, test_ds = prepare_data('mnist')\n",
    "    state = create_train_state(model, rngs, learning_rate)\n",
    "    # print(jax.tree.map(lambda x:x.shape,state.params))\n",
    "\n",
    "    # Replicate state across devices\n",
    "    state = jax.device_put_replicated(state, jax.local_devices())\n",
    "\n",
    "    # bar = tqdm(range(num_epochs),leave=0)\n",
    "    bar = range(num_epochs)\n",
    "    # pbar = tqdm(range(len(train_ds)),leave=1)\n",
    "    best_test_acc = list()\n",
    "    for _ in bar:\n",
    "        # pbar.reset()\n",
    "        # Training loop\n",
    "        for batch in train_ds:\n",
    "            # pbar.update(1)\n",
    "            batch = shard(batch)\n",
    "            state, train_loss, train_acc = train_step(state, batch)\n",
    "            train_loss, train_acc = jax.device_get(flax.jax_utils.unreplicate((train_loss, train_acc)))\n",
    "            # pbar.set_postfix(dict(train_loss=train_loss,train_acc=train_acc))\n",
    "\n",
    "        # Evaluation loop\n",
    "        test_loss, test_acc = 0, 0\n",
    "        for batch in test_ds:\n",
    "            batch = shard(batch)\n",
    "            loss, acc = eval_step(state, batch)\n",
    "            test_loss += loss.mean()\n",
    "            test_acc += acc.mean()\n",
    "\n",
    "        test_loss /= len(test_ds)\n",
    "        test_acc /= len(test_ds)\n",
    "        best_test_acc.append(test_acc.item())\n",
    "\n",
    "        # bar.set_postfix({'Test Loss':test_loss,'Test Acc':test_acc})\n",
    "\n",
    "    # pbar.close()\n",
    "    # bar.close()\n",
    "    best_test_acc = max(best_test_acc)\n",
    "    print('Train Loss',train_loss,'Test Loss',test_loss,'Test Acc',test_acc,'Best Test Acc',best_test_acc)\n",
    "    return best_test_acc, train_loss\n",
    "\n",
    "METHOD = 'extrapolate'\n",
    "# METHOD = 'split'\n",
    "DIM = 49\n",
    "# validate different activation functions\n",
    "def validate_activation(fn,C=512,method='default',num_epochs=10):\n",
    "    class MNISTModel(nn.Module):\n",
    "        @nn.compact\n",
    "        def __call__(self, x):\n",
    "            if method == 'extrapolate-aggregate':\n",
    "                x = jnp.repeat(x.reshape(-1, 1, 28 * 28),DIM,axis=1)\n",
    "                x = nn.Conv(kernel_size=(3,),padding=\"CIRCULAR\",features=C)(x)\n",
    "                x = jnp.mean(x,axis=1,keepdims=False)\n",
    "            elif method == 'extrapolate':\n",
    "                x = jnp.repeat(x.reshape(-1, 1, 28 * 28),DIM,axis=1)\n",
    "                x = nn.Conv(kernel_size=(3,),padding=\"CIRCULAR\",features=C//DIM)(x)\n",
    "                x = x.reshape(-1, C)\n",
    "            elif method == 'split':\n",
    "                x = x.reshape(-1, DIM, 28 * 28//DIM)\n",
    "                x = nn.Conv(kernel_size=(3,),padding=\"CIRCULAR\",features=C//DIM)(x)\n",
    "                x = x.reshape(-1, C)\n",
    "            elif method == 'split-dense':\n",
    "                x = x.reshape(-1, DIM, 28 * 28//DIM)\n",
    "                x = nn.Dense(features=C//DIM)(x)\n",
    "                x = x.reshape(-1, C)\n",
    "            elif method == \"id\":\n",
    "                x = x.reshape(-1, 28 * 28)\n",
    "            else:\n",
    "                x = x.reshape(-1, 28 * 28)  # Flatten the input\n",
    "                x = nn.Dense(features=C)(x)\n",
    "            x = fn(x)\n",
    "            # if CONV3D:\n",
    "            #     x = nn.Conv(kernel_size=(3,),padding=\"CIRCULAR\",features=10)(x)\n",
    "            #     x = jnp.mean(x,axis=1,keepdims=False)\n",
    "            # else:\n",
    "            x = nn.Dense(features=10)(x)\n",
    "            return x\n",
    "\n",
    "    model = MNISTModel()\n",
    "    best_test_acc = []\n",
    "    train_loss = []\n",
    "    for _ in range(3):\n",
    "        acc, loss = train_and_evaluate(model, num_epochs=num_epochs, learning_rate=1e-3)\n",
    "        best_test_acc.append(acc)\n",
    "        train_loss.append(loss)\n",
    "        \n",
    "\n",
    "    print(f'Test Accuracy: {np.mean(best_test_acc)} ± {np.std(best_test_acc)}')\n",
    "    print(f'Train Loss: {np.mean(train_loss)} ± {np.std(train_loss)}')\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train Loss -0.0 Test Loss 1.7510715 Test Acc 0.95450944 Best Test Acc 0.956289529800415\n",
      "Train Loss 4.9668824e-08 Test Loss 1.5883534 Test Acc 0.9496637 Best Test Acc 0.9564872980117798\n",
      "Train Loss -0.0 Test Loss 1.7090598 Test Acc 0.9536194 Best Test Acc 0.9599485397338867\n",
      "Test Accuracy: 0.9575751225153605 ± 0.0016802003920671886\n",
      "Train Loss: 1.6556274573531482e-08 ± 2.341410798578636e-08\n"
     ]
    }
   ],
   "source": [
    "validate_activation(nn.relu)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train Loss -0.0 Test Loss 12.320416 Test Acc 0.9615308 Best Test Acc 0.9631130695343018\n",
      "Train Loss -0.0 Test Loss 13.119491 Test Acc 0.9610363 Best Test Acc 0.9654865264892578\n",
      "Train Loss -0.0 Test Loss 12.393973 Test Acc 0.96123415 Best Test Acc 0.9647942781448364\n",
      "Test Accuracy: 0.9644646247227987 ± 0.0009966035698823015\n",
      "Train Loss: 0.0 ± 0.0\n"
     ]
    }
   ],
   "source": [
    "validate_activation(colu)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train Loss -0.0 Test Loss 18.212978 Test Acc 0.9589596 Best Test Acc 0.9594540596008301\n",
      "Train Loss -0.0 Test Loss 16.731853 Test Acc 0.95826733 Best Test Acc 0.9642997980117798\n",
      "Train Loss -0.0 Test Loss 18.379313 Test Acc 0.9590585 Best Test Acc 0.9614319205284119\n",
      "Test Accuracy: 0.961728592713674 ± 0.001989356005674444\n",
      "Train Loss: 0.0 ± 0.0\n"
     ]
    }
   ],
   "source": [
    "validate_activation(rcolu)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Test Loss 0.4819425 Test Acc 0.9277096 Best Test Acc 0.9277095794677734\n",
      "Test Loss 0.47985438 Test Acc 0.9186115 Best Test Acc 0.9289951920509338\n",
      "Test Loss 0.43648016 Test Acc 0.92325944 Best Test Acc 0.9259295463562012\n",
      "0.9275447726249695 ± 0.0012569584593996436\n"
     ]
    }
   ],
   "source": [
    "colu4=functools.partial(colu,num_groups=4,dim=None)\n",
    "validate_activation(colu4)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train Loss 0.00059222995 Test Loss 0.65389204 Test Acc 0.9252373 Best Test Acc 0.9283029437065125\n",
      "Train Loss 0.005728652 Test Loss 0.68661606 Test Acc 0.9225672 Best Test Acc 0.9286985397338867\n",
      "Train Loss 0.00043734856 Test Loss 0.6848236 Test Acc 0.9267207 Best Test Acc 0.9292919039726257\n",
      "Test Accuracy: 0.9287644624710083 ± 0.00040642338961556984\n",
      "Train Loss: 0.002252743346616626 ± 0.002458651550114155\n"
     ]
    }
   ],
   "source": [
    "scolu=functools.partial(colu,share_axis=True,num_groups=4,dim=None)\n",
    "validate_activation(scolu,C=513,method='default')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train Loss -0.0 Test Loss 2.6070163 Test Acc 0.9620253 Best Test Acc 0.9634097814559937\n",
      "Train Loss -0.0 Test Loss 2.3222847 Test Acc 0.96350867 Best Test Acc 0.9666731953620911\n",
      "Train Loss -0.0 Test Loss 2.495196 Test Acc 0.960443 Best Test Acc 0.9653875827789307\n",
      "Test Accuracy: 0.9651568531990051 ± 0.0013422356188878696\n",
      "Train Loss: 0.0 ± 0.0\n"
     ]
    }
   ],
   "source": [
    "scolu=functools.partial(colu,share_axis=True,variant='soft')\n",
    "validate_activation(scolu,C=511,method='default',num_epochs=50)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train Loss -0.0 Test Loss 2.272504 Test Acc 0.9606408 Best Test Acc 0.962915301322937\n",
      "Train Loss -0.0 Test Loss 2.2598255 Test Acc 0.9639042 Best Test Acc 0.9639042019844055\n",
      "Train Loss -0.0 Test Loss 2.3103473 Test Acc 0.9590585 Best Test Acc 0.9628164172172546\n",
      "Test Accuracy: 0.9632119735081991 ± 0.0004911413333875091\n",
      "Train Loss: 0.0 ± 0.0\n"
     ]
    }
   ],
   "source": [
    "scolu=functools.partial(colu,share_axis=True,variant='soft')\n",
    "validate_activation(scolu,C=385,method='default',num_epochs=50)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train Loss 5.3826485 Test Loss 7.214199 Test Acc 0.9444224 Best Test Acc 0.9444224238395691\n",
      "Train Loss -0.0 Test Loss 8.155147 Test Acc 0.9409612 Best Test Acc 0.9409611821174622\n",
      "Train Loss 0.36790058 Test Loss 7.405613 Test Acc 0.94254345 Best Test Acc 0.9452135562896729\n",
      "Test Accuracy: 0.943532387415568 ± 0.0018465815537830235\n",
      "Train Loss: 1.9168496131896973 ± 2.4552879333496094\n"
     ]
    }
   ],
   "source": [
    "validate_activation(rcolu,C=384)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train Loss -0.0 Test Loss 16.218866 Test Acc 0.9536194 Best Test Acc 0.9576740264892578\n",
      "Train Loss -0.0 Test Loss 16.243488 Test Acc 0.9515427 Best Test Acc 0.9570806622505188\n",
      "Train Loss -0.0 Test Loss 16.407934 Test Acc 0.95401496 Best Test Acc 0.9589595794677734\n",
      "Test Accuracy: 0.9579047560691833 ± 0.0007842234297330661\n",
      "Train Loss: 0.0 ± 0.0\n"
     ]
    }
   ],
   "source": [
    "scolu=functools.partial(colu,variant='soft')\n",
    "validate_activation(scolu,C=512,method='default',num_epochs=50)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Test Loss 17.766293 Test Acc 0.9137658 Best Test Acc 0.9198970794677734\n",
      "Test Loss 19.66462 Test Acc 0.90486544 Best Test Acc 0.914853572845459\n",
      "Test Loss 17.440067 Test Acc 0.9134691 Best Test Acc 0.9134690761566162\n",
      "0.9160732428232828 ± 0.0027623061359626994\n"
     ]
    }
   ],
   "source": [
    "scolu=functools.partial(rcolu,scaling='soft')\n",
    "validate_activation(scolu,C=512,method='default')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Test Loss 0.9899926 Test Acc 0.9398734 Best Test Acc 0.9414556622505188\n",
      "Test Loss 0.9778322 Test Acc 0.932852 Best Test Acc 0.9367088079452515\n",
      "Test Loss 0.8928901 Test Acc 0.94115895 Best Test Acc 0.9411589503288269\n",
      "0.9397744735081991 ± 0.002171134649440334\n"
     ]
    }
   ],
   "source": [
    "validate_activation(nn.silu,C=512,method='default')"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
