{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import math\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "\n",
    "def _get_cosine_schedule_with_min_lr_lambda(current_step, *, num_warmup_steps, num_training_steps, num_cycles, min_lr_ratio):\n",
    "    if current_step < num_warmup_steps:\n",
    "        return float(current_step) / float(max(1, num_warmup_steps))\n",
    "    progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps))\n",
    "    cosine_decay = 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress))\n",
    "    return min_lr_ratio + (1.0 - min_lr_ratio) * cosine_decay\n",
    "\n",
    "\n",
    "num_training_steps = 1000\n",
    "num_warmup_steps = 200\n",
    "min_lr_ratio = 0\n",
    "num_cycles = 0.5\n",
    "\n",
    "lr_values = [_get_cosine_schedule_with_min_lr_lambda(step, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps, num_cycles=num_cycles, min_lr_ratio=min_lr_ratio) for step in range(num_training_steps)]\n",
    "\n",
    "plt.plot(lr_values)\n",
    "plt.xlabel(\"Training steps\")\n",
    "plt.ylabel(\"Learning rate\")\n",
    "plt.title(\"Cosine schedule with min LR\")\n",
    "plt.show()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def _get_cosine_schedule_with_min_lr_lambda(current_step, *, num_warmup_steps, num_training_steps, num_cycles, min_lr_ratio):\n",
    "    if current_step < num_warmup_steps:\n",
    "        return float(current_step) / float(max(1, num_warmup_steps))\n",
    "    progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps))\n",
    "    cosine_decay = 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress))\n",
    "    return min_lr_ratio + (1.0 - min_lr_ratio) * cosine_decay\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from functools import partial\n",
    "from torch.optim.lr_scheduler import LambdaLR\n",
    "\n",
    "\n",
    "def get_cyclical_cosine_schedule_with_min_lr(optimizer, num_warmup_steps, cycle_length, num_cycles=0.5, min_lr_ratio=0.1, last_epoch=-1):\n",
    "    lr_lambda = partial(\n",
    "        _get_cyclical_cosine_schedule_with_min_lr_lambda,\n",
    "        num_warmup_steps=num_warmup_steps,\n",
    "        cycle_length=cycle_length,\n",
    "        num_cycles=num_cycles,\n",
    "        min_lr_ratio=min_lr_ratio,\n",
    "    )\n",
    "    return LambdaLR(optimizer, lr_lambda, last_epoch)\n",
    "\n",
    "\n",
    "def _get_cyclical_cosine_schedule_with_min_lr_lambda(current_step, *, num_warmup_steps, cycle_length, min_lr_ratio):\n",
    "    assert 0 < min_lr_ratio <= 1.0, \"min_lr_ratio must be in (0,1]\"\n",
    "\n",
    "    # compute where we are in the current cycle\n",
    "    cycle_step = current_step % cycle_length\n",
    "    \n",
    "    if cycle_step < num_warmup_steps:\n",
    "        return float(cycle_step) / float(max(1, num_warmup_steps))\n",
    "    \n",
    "    progress = float(cycle_step - num_warmup_steps) / float(max(1, cycle_length - num_warmup_steps))\n",
    "    cosine_decay = 0.5 * (1.0 + math.cos(math.pi * progress))\n",
    "    \n",
    "    return min_lr_ratio + (1.0 - min_lr_ratio) * cosine_decay\n",
    "\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import math\n",
    "from functools import partial\n",
    "\n",
    "num_training_steps = 10000\n",
    "num_warmup_steps = 1000\n",
    "cycle_length = 2000\n",
    "num_cycles=num_training_steps / cycle_length\n",
    "min_lr_ratio=0.1\n",
    "\n",
    "lr_lambda = partial(\n",
    "    _get_cyclical_cosine_schedule_with_min_lr_lambda,\n",
    "    num_warmup_steps=num_warmup_steps,\n",
    "    cycle_length=cycle_length,\n",
    "    min_lr_ratio=min_lr_ratio,\n",
    ")\n",
    "\n",
    "# Generate a range of step values\n",
    "steps = np.arange(0, num_training_steps, 100)\n",
    "\n",
    "# Compute learning rate for each step\n",
    "lr_values = [lr_lambda(step) for step in steps]\n",
    "\n",
    "# Plot the learning rate schedule\n",
    "plt.plot(steps, lr_values)\n",
    "plt.title('Cyclical Cosine Schedule with Min LR')\n",
    "plt.xlabel('Step')\n",
    "plt.ylabel('Learning Rate')\n",
    "plt.grid(True)\n",
    "plt.show()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from functools import partial\n",
    "from torch.optim.lr_scheduler import LambdaLR\n",
    "\n",
    "\n",
    "def get_cosine_schedule_with_restarts(optimizer, num_warmup_steps, cycle_length, num_cycles=0.5, min_lr_ratio=0.1, last_epoch=-1):\n",
    "    lr_lambda = partial(\n",
    "        _get_cosine_schedule_with_restarts_lambda,\n",
    "        num_warmup_steps=num_warmup_steps,\n",
    "        cycle_length=cycle_length,\n",
    "        num_cycles=num_cycles,\n",
    "        min_lr_ratio=min_lr_ratio,\n",
    "    )\n",
    "    return LambdaLR(optimizer, lr_lambda, last_epoch)\n",
    "\n",
    "\n",
    "def _get_cosine_schedule_with_restarts_lambda(current_step, *, num_warmup_steps, cycle_length, min_lr_ratio):\n",
    "    assert 0 < min_lr_ratio <= 1.0, \"min_lr_ratio must be in (0,1]\"\n",
    "\n",
    "    # compute where we are in the current cycle\n",
    "    cycle_step = current_step % cycle_length\n",
    "    \n",
    "    if cycle_step < num_warmup_steps:\n",
    "        return float(cycle_step) / float(max(1, num_warmup_steps))\n",
    "    \n",
    "    progress = float(cycle_step - num_warmup_steps) / float(max(1, cycle_length - num_warmup_steps))\n",
    "    cosine_decay = 0.5 * (1.0 + math.cos(math.pi * progress))\n",
    "    \n",
    "    return min_lr_ratio + (1.0 - min_lr_ratio) * cosine_decay\n",
    "\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import math\n",
    "from functools import partial\n",
    "\n",
    "num_training_steps = 10000\n",
    "num_warmup_steps = 1000\n",
    "cycle_length = 10000\n",
    "num_cycles=num_training_steps / cycle_length\n",
    "min_lr_ratio=0.1\n",
    "\n",
    "lr_lambda = partial(\n",
    "    _get_cyclical_cosine_schedule_with_min_lr_lambda,\n",
    "    num_warmup_steps=num_warmup_steps,\n",
    "    cycle_length=cycle_length,\n",
    "    min_lr_ratio=min_lr_ratio,\n",
    ")\n",
    "\n",
    "# Generate a range of step values\n",
    "steps = np.arange(0, num_training_steps, 100)\n",
    "\n",
    "# Compute learning rate for each step\n",
    "lr_values = [lr_lambda(step) for step in steps]\n",
    "\n",
    "# Plot the learning rate schedule\n",
    "plt.plot(steps, lr_values)\n",
    "plt.title('Cyclical Cosine Schedule with Min LR')\n",
    "plt.xlabel('Step')\n",
    "plt.ylabel('Learning Rate')\n",
    "plt.grid(True)\n",
    "plt.show()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import math\n",
    "\n",
    "# --warmup_steps 500 --restart_warmup_steps 400 --cycle_length 1000 --num_training_steps 10000 --min_lr_ratio 0.1\n",
    "def jagged_cosine_schedule(step, *, first_warmup_steps, restart_warmup_steps, restart_every, min_lr_ratio):\n",
    "    assert 0 < min_lr_ratio <= 1.0, \"min_lr_ratio must be in (0,1]\"\n",
    "    assert restart_every > 0, \"restart_every must be positive\"\n",
    "\n",
    "    if step < first_warmup_steps:\n",
    "        return float(step) / float(max(1, first_warmup_steps))\n",
    "\n",
    "    restart_step = step % restart_every\n",
    "    restart_number = step // restart_every\n",
    "\n",
    "    if restart_step < restart_warmup_steps:\n",
    "        # get expected lr multipler at the end of the warmup\n",
    "        end_of_warmup_progress = (\n",
    "            float(restart_number * restart_every) /\n",
    "            float(max(1, num_training_steps - first_warmup_steps))\n",
    "        )\n",
    "\n",
    "        _cosine_decay = 0.5 * (1.0 + math.cos(math.pi * end_of_warmup_progress))\n",
    "        warmup_lr_multiplier = min_lr_ratio + (1.0 - min_lr_ratio) * _cosine_decay\n",
    "    \n",
    "        return float(restart_step) / float(max(1, restart_warmup_steps)) * warmup_lr_multiplier\n",
    "\n",
    "    progress = float(step - first_warmup_steps) / float(max(1, num_training_steps - first_warmup_steps))\n",
    "    cosine_decay = 0.5 * (1.0 + math.cos(math.pi * progress))\n",
    "\n",
    "    return min_lr_ratio + (1.0 - min_lr_ratio) * cosine_decay\n",
    "\n",
    "\n",
    "# plot\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "\n",
    "num_training_steps = 10000\n",
    "first_warmup_steps = 1000\n",
    "restart_warmup_steps = 100\n",
    "restart_every = 2000\n",
    "min_lr_ratio=0.1\n",
    "\n",
    "# Generate a range of step values\n",
    "steps = np.arange(0, num_training_steps, 10)\n",
    "\n",
    "# Compute learning rate for each step\n",
    "lr_values = [\n",
    "    jagged_cosine_schedule(\n",
    "        step,\n",
    "        first_warmup_steps=first_warmup_steps,\n",
    "        restart_warmup_steps=restart_warmup_steps,\n",
    "        restart_every=restart_every,\n",
    "        min_lr_ratio=min_lr_ratio,\n",
    "    )\n",
    "    for step in steps\n",
    "]\n",
    "\n",
    "# Plot the learning rate schedule\n",
    "# figure\n",
    "fig = plt.figure(figsize=(5, 3), dpi=150)\n",
    "plt.plot(steps, lr_values)\n",
    "# plt.title('Jagged Cosine Schedule')\n",
    "plt.xlabel('Step')\n",
    "plt.ylabel('Learning Rate')\n",
    "plt.grid(True)\n",
    "\n",
    "# tight\n",
    "plt.tight_layout()\n",
    "plt.savefig(\"jagged_cosine_schedule.pdf\")\n",
    "plt.show()\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "base",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.9"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
