{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import math\n",
    "import torch\n",
    "from torch.optim import Optimizer\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Lamb(Optimizer):\n",
    "    r\"\"\"Implements Lamb algorithm.\n",
    "    It has been proposed in `Large Batch Optimization for Deep Learning: Training BERT in 76 minutes`_.\n",
    "    Arguments:\n",
    "        params (iterable): iterable of parameters to optimize or dicts defining\n",
    "            parameter groups\n",
    "        lr (float, optional): learning rate (default: 1e-3)\n",
    "        betas (Tuple[float, float], optional): coefficients used for computing\n",
    "            running averages of gradient and its square (default: (0.9, 0.999))\n",
    "        eps (float, optional): term added to the denominator to improve\n",
    "            numerical stability (default: 1e-8)\n",
    "        weight_decay (float, optional): weight decay (L2 penalty) (default: 0)\n",
    "        adam (bool, optional): always use trust ratio = 1, which turns this into\n",
    "            Adam. Useful for comparison purposes.\n",
    "    .. _Large Batch Optimization for Deep Learning: Training BERT in 76 minutes:\n",
    "        https://arxiv.org/abs/1904.00962\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-6,\n",
    "                 weight_decay=0.01, bias_correction=False):\n",
    "        if not 0.0 <= lr:\n",
    "            raise ValueError(\"Invalid learning rate: {}\".format(lr))\n",
    "        if not 0.0 <= eps:\n",
    "            raise ValueError(\"Invalid epsilon value: {}\".format(eps))\n",
    "        if not 0.0 <= betas[0] < 1.0:\n",
    "            raise ValueError(\"Invalid beta parameter at index 0: {}\".format(betas[0]))\n",
    "        if not 0.0 <= betas[1] < 1.0:\n",
    "            raise ValueError(\"Invalid beta parameter at index 1: {}\".format(betas[1]))\n",
    "        if not 0.0 <= weight_decay:\n",
    "            raise ValueError(\"Invalid weight_decay value: {}\".format(weight_decay))\n",
    "        defaults = dict(lr=lr, betas=betas, eps=eps,\n",
    "                        weight_decay=weight_decay)\n",
    "        super().__init__(params, defaults)\n",
    "        self.bias_correction= bias_correction\n",
    "\n",
    "    @torch.no_grad()\n",
    "    def step(self, closure=None):\n",
    "        \"\"\"Performs a single optimization step.\n",
    "        \n",
    "        Arguments:\n",
    "            closure (callable, optional): A closure that reevaluates the model\n",
    "                and returns the loss.\n",
    "        \"\"\"\n",
    "        loss = None\n",
    "        if closure is not None:\n",
    "            with torch.enable_grad():\n",
    "                loss = closure()\n",
    "\n",
    "        for group in self.param_groups:\n",
    "            for p in group['params']:\n",
    "                if p.grad is None:\n",
    "                    continue\n",
    "                    \n",
    "                # Perform stepweight decay\n",
    "                p.mul_(1 - group['lr'] * group['weight_decay'])\n",
    "                \n",
    "                # Perform optimization step\n",
    "                grad = p.grad\n",
    "                if grad.is_sparse:\n",
    "                    raise RuntimeError('AdamW does not support sparse gradients')            \n",
    "                \n",
    "                state = self.state[p]\n",
    "    \n",
    "                # State initialization\n",
    "                if len(state) == 0:\n",
    "                    state['step'] = 0\n",
    "                    # Exponential moving average of gradient values\n",
    "                    state['exp_avg'] = torch.zeros_like(p, memory_format=torch.preserve_format)\n",
    "                    # Exponential moving average of squared gradient values\n",
    "                    state['exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format)\n",
    "\n",
    "                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']\n",
    "                beta1, beta2 = group['betas']\n",
    "\n",
    "                state['step'] += 1\n",
    "\n",
    "                # Paper v3 does not use debiasing.\n",
    "                if self.bias_correction:\n",
    "                    bias_correction1 = 1 - beta1 ** state['step']\n",
    "                    bias_correction2 = 1 - beta2 ** state['step']\n",
    "\n",
    "                # Decay the first and second moment running average coefficient\n",
    "                exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)\n",
    "                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)                    \n",
    "\n",
    "                \n",
    "                if self.bias_correction:\n",
    "                    num = exp_avg / bias_correction1\n",
    "                    denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(group['eps'])\n",
    "                    update = num / denom\n",
    "                else:\n",
    "                    num = exp_avg\n",
    "                    denom = exp_avg_sq.sqrt().add_(group['eps'])\n",
    "                    update = num / denom\n",
    "                    \n",
    "                    \n",
    "\n",
    "                w_norm = torch.norm(p)\n",
    "                g_norm = torch.norm(update)\n",
    "                trust_ratio = w_norm / g_norm if w_norm.item() != 0.0 or g_norm.item() != 0.0 else 1.0\n",
    "\n",
    "                p.add_(-group['lr'] * trust_ratio * update)\n",
    "\n",
    "        return loss"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "a = torch.randn(1, 100)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Show the tensor only when its norm is exactly zero.\n",
    "if bool(torch.norm(a).eq(0)):\n",
    "    print(a)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "False"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Convert the zero-norm comparison into a plain Python bool.\n",
    "torch.norm(a).eq(0).item()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor(1.)"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# L2 norm computed by hand, then clamped into [0, 1].\n",
    "torch.clamp(a.pow(2).sum().sqrt(), min=0.0, max=1.0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "2"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "b = 1 if False else 2\n",
    "b"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Stray expression `t` removed: the name is never defined in this notebook, so the cell raised NameError on Restart & Run All."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
