{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Configuration Showcase\n",
    "\n",
    "This notebook showcases how the different configurations available in\n",
    "`Trainax` can be depicted schematically."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import jax.numpy as jnp\n",
    "\n",
    "import trainax as tx"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Supervised\n",
    "\n",
    "A supervised configuration is special because all data can be pre-computed.\n",
    "Neither a `ref_stepper` nor a `residuum_fn` is needed on the fly (and hence\n",
    "neither has to be differentiable)."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### One-Step supervised\n",
    "\n",
    "<p align=\"center\">\n",
    "    <img src=\"XXXX-4\" width=\"400\">\n",
    "</p>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Supervised(\n",
       "  num_rollout_steps=1,\n",
       "  time_level_loss=MSELoss(batch_reduction=<function mean>),\n",
       "  cut_bptt=False,\n",
       "  cut_bptt_every=1,\n",
       "  time_level_weights=f32[1]\n",
       ")"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# The default is one-step supervised learning\n",
    "tx.configuration.Supervised()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Two-Step supervised (rollout) Training\n",
    "\n",
    "<p align=\"center\">\n",
    "    <img src=\"XXXX-4\" width=\"400\">\n",
    "</p>\n",
    "\n",
    "We roll out the neural emulator for two autoregressive steps. Its parameters are\n",
    "shared between the two predictions. Similarly, the `ref_stepper` is used to\n",
    "create the reference trajectory; the loss is aggregated as a sum over the two\n",
    "time levels."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Supervised(\n",
       "  num_rollout_steps=2,\n",
       "  time_level_loss=MSELoss(batch_reduction=<function mean>),\n",
       "  cut_bptt=False,\n",
       "  cut_bptt_every=1,\n",
       "  time_level_weights=f32[2]\n",
       ")"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "tx.configuration.Supervised(num_rollout_steps=2)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Three-Step supervised (rollout) Training\n",
    "\n",
    "<p align=\"center\">\n",
    "    <img src=\"XXXX-4\" width=\"400\">\n",
    "</p>\n",
    "\n",
    "Same idea as above, but with an additional rollout step."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Supervised(\n",
       "  num_rollout_steps=3,\n",
       "  time_level_loss=MSELoss(batch_reduction=<function mean>),\n",
       "  cut_bptt=False,\n",
       "  cut_bptt_every=1,\n",
       "  time_level_weights=f32[3]\n",
       ")"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "tx.configuration.Supervised(num_rollout_steps=3)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Three-Step supervised (rollout) Training with loss only at final state\n",
    "\n",
    "<p align=\"center\">\n",
    "    <img src=\"XXXX-4\" width=\"400\">\n",
    "</p>\n",
    "\n",
    "The loss is only taken from the last step. Essentially, this corresponds to\n",
    "weighting the time levels with $[0, 0, 1]$, respectively. (More weighting\n",
    "options are possible, of course.)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Supervised(\n",
       "  num_rollout_steps=3,\n",
       "  time_level_loss=MSELoss(batch_reduction=<function mean>),\n",
       "  cut_bptt=False,\n",
       "  cut_bptt_every=1,\n",
       "  time_level_weights=f32[3]\n",
       ")"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "tx.configuration.Supervised(\n",
    "    num_rollout_steps=3, time_level_weights=jnp.array([0.0, 0.0, 1.0])\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Three-Step supervised (rollout) Training with no backpropagation through time\n",
    "\n",
    "(Displays the primal evaluation together with the cotangent flow; a grey dashed\n",
    "line indicates a cut gradient.)\n",
    "\n",
    "<p align=\"center\">\n",
    "    <img src=\"XXXX-5\" width=\"400\">\n",
    "</p>\n",
    "\n",
    "This interrupts the backward gradient flow through the autoregressive network\n",
    "execution. Gradients can still flow into the parameter space."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Supervised(\n",
       "  num_rollout_steps=3,\n",
       "  time_level_loss=MSELoss(batch_reduction=<function mean>),\n",
       "  cut_bptt=True,\n",
       "  cut_bptt_every=1,\n",
       "  time_level_weights=f32[3]\n",
       ")"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "tx.configuration.Supervised(num_rollout_steps=3, cut_bptt=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Four-Step supervised (rollout) Training with sparse backpropagation through time\n",
    "\n",
    "<p align=\"center\">\n",
    "    <img src=\"XXXX-5\" width=\"700\">\n",
    "</p>\n",
    "\n",
    "Only every second backpropagation step is allowed to flow through the network."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Supervised(\n",
       "  num_rollout_steps=4,\n",
       "  time_level_loss=MSELoss(batch_reduction=<function mean>),\n",
       "  cut_bptt=True,\n",
       "  cut_bptt_every=2,\n",
       "  time_level_weights=f32[4]\n",
       ")"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "tx.configuration.Supervised(num_rollout_steps=4, cut_bptt=True, cut_bptt_every=2)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Diverted Chain"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Two-Steps with branch length one\n",
    "\n",
    "<p align=\"center\">\n",
    "    <img src=\"XXXX-5\" width=\"500\">\n",
    "</p>\n",
    "\n",
    "The `ref_stepper` is not run autoregressively for two steps from the initial\n",
    "condition but rather for one step, branching off from the main chain created by\n",
    "the emulator."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "DivertedChainBranchOne(\n",
       "  num_rollout_steps=2,\n",
       "  time_level_loss=MSELoss(batch_reduction=<function mean>),\n",
       "  cut_bptt=False,\n",
       "  cut_bptt_every=1,\n",
       "  cut_div_chain=False,\n",
       "  time_level_weights=f32[2]\n",
       ")"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
    ],
    "source": [
     "# `num_rollout_steps` refers to the number of autoregressive steps performed by\n",
    "# the neural emulator\n",
    "tx.configuration.DivertedChainBranchOne(num_rollout_steps=2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "DivertedChain(\n",
       "  num_rollout_steps=2,\n",
       "  num_branch_steps=1,\n",
       "  time_level_loss=MSELoss(batch_reduction=<function mean>),\n",
       "  cut_bptt=False,\n",
       "  cut_bptt_every=1,\n",
       "  cut_div_chain=False,\n",
       "  time_level_weights=f32[2],\n",
       "  branch_level_weights=f32[1]\n",
       ")"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Alternatively, the general interface can be used\n",
    "tx.configuration.DivertedChain(num_rollout_steps=2, num_branch_steps=1)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Three-steps with branch length one\n",
    "\n",
    "<p align=\"center\">\n",
    "    <img src=\"XXXX-5\" width=\"600\">\n",
    "</p>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "DivertedChainBranchOne(\n",
       "  num_rollout_steps=3,\n",
       "  time_level_loss=MSELoss(batch_reduction=<function mean>),\n",
       "  cut_bptt=False,\n",
       "  cut_bptt_every=1,\n",
       "  cut_div_chain=False,\n",
       "  time_level_weights=f32[3]\n",
       ")"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "tx.configuration.DivertedChainBranchOne(num_rollout_steps=3)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Four-steps with branch length one\n",
    "\n",
    "<p align=\"center\">\n",
    "    <img src=\"XXXX-5\" width=\"700\">\n",
    "</p>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "DivertedChainBranchOne(\n",
       "  num_rollout_steps=4,\n",
       "  time_level_loss=MSELoss(batch_reduction=<function mean>),\n",
       "  cut_bptt=False,\n",
       "  cut_bptt_every=1,\n",
       "  cut_div_chain=False,\n",
       "  time_level_weights=f32[4]\n",
       ")"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "tx.configuration.DivertedChainBranchOne(num_rollout_steps=4)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Three-steps with branch length two\n",
    "\n",
    "<p align=\"center\">\n",
    "    <img src=\"XXXX-5\" width=\"600\">\n",
    "</p>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "DivertedChain(\n",
       "  num_rollout_steps=3,\n",
       "  num_branch_steps=2,\n",
       "  time_level_loss=MSELoss(batch_reduction=<function mean>),\n",
       "  cut_bptt=False,\n",
       "  cut_bptt_every=1,\n",
       "  cut_div_chain=False,\n",
       "  time_level_weights=f32[3],\n",
       "  branch_level_weights=f32[2]\n",
       ")"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Can only be done with the general interface\n",
    "tx.configuration.DivertedChain(num_rollout_steps=3, num_branch_steps=2)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Two-Steps with no differentiable physics\n",
    "\n",
    "<p align=\"center\">\n",
    "    <img src=\"XXXX-5\" width=\"500\">\n",
    "</p>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "DivertedChainBranchOne(\n",
       "  num_rollout_steps=2,\n",
       "  time_level_loss=MSELoss(batch_reduction=<function mean>),\n",
       "  cut_bptt=False,\n",
       "  cut_bptt_every=1,\n",
       "  cut_div_chain=True,\n",
       "  time_level_weights=f32[2]\n",
       ")"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "tx.configuration.DivertedChainBranchOne(\n",
    "    num_rollout_steps=2,\n",
    "    cut_div_chain=True,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Two-Steps with no backpropagation through time\n",
    "\n",
    "<p align=\"center\">\n",
    "    <img src=\"XXXX-5\" width=\"500\">\n",
    "</p>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "DivertedChainBranchOne(\n",
       "  num_rollout_steps=2,\n",
       "  time_level_loss=MSELoss(batch_reduction=<function mean>),\n",
       "  cut_bptt=True,\n",
       "  cut_bptt_every=1,\n",
       "  cut_div_chain=False,\n",
       "  time_level_weights=f32[2]\n",
       ")"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "tx.configuration.DivertedChainBranchOne(\n",
    "    num_rollout_steps=2,\n",
    "    cut_bptt=True,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Two-Steps with no backpropagation through time and no differentiable physics\n",
    "\n",
    "<p align=\"center\">\n",
    "    <img src=\"XXXX-5\" width=\"500\">\n",
    "</p>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "DivertedChainBranchOne(\n",
       "  num_rollout_steps=2,\n",
       "  time_level_loss=MSELoss(batch_reduction=<function mean>),\n",
       "  cut_bptt=True,\n",
       "  cut_bptt_every=1,\n",
       "  cut_div_chain=True,\n",
       "  time_level_weights=f32[2]\n",
       ")"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "tx.configuration.DivertedChainBranchOne(\n",
    "    num_rollout_steps=2,\n",
    "    cut_bptt=True,\n",
    "    cut_div_chain=True,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Mix-Chain\n",
    "\n",
    "So far, `Trainax` only supports \"post-physics\" mixing, meaning that the main\n",
    "chain is built by first performing a specified number of autoregressive network\n",
    "steps, and then a specified number of `ref_stepper` steps.\n",
    "\n",
    "The reference trajectory is always built by autoregressively unrolling the\n",
    "`ref_stepper`."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### One-Step Network with one Step Physics\n",
    "\n",
    "<p align=\"center\">\n",
    "    <img src=\"XXXX-5\" width=\"500\">\n",
    "</p>\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "MixChainPostPhysics(\n",
       "  num_rollout_steps=1,\n",
       "  time_level_loss=MSELoss(batch_reduction=<function mean>),\n",
       "  num_post_physics_steps=1,\n",
       "  cut_bptt=False,\n",
       "  cut_bptt_every=1,\n",
       "  time_level_weights=f32[2]\n",
       ")"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "tx.configuration.MixChainPostPhysics(\n",
    "    num_rollout_steps=1,\n",
    "    num_post_physics_steps=1,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### One-Step Network with one step physics and loss only at final state\n",
    "\n",
    "<p align=\"center\">\n",
    "    <img src=\"XXXX-5\" width=\"500\">\n",
    "</p>\n",
    "\n",
    "Similar to the supervised setting, this is achieved by choosing proper\n",
    "`time_level_weights`. For `MixChainPostPhysics` the `time_level_weights` refer\n",
    "to the entire main chain, i.e., the trajectory created by the former network\n",
    "steps and the latter physics steps."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "MixChainPostPhysics(\n",
       "  num_rollout_steps=1,\n",
       "  time_level_loss=MSELoss(batch_reduction=<function mean>),\n",
       "  num_post_physics_steps=1,\n",
       "  cut_bptt=False,\n",
       "  cut_bptt_every=1,\n",
       "  time_level_weights=f32[2]\n",
       ")"
      ]
     },
     "execution_count": 17,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "tx.configuration.MixChainPostPhysics(\n",
    "    num_rollout_steps=1,\n",
    "    num_post_physics_steps=1,\n",
    "    time_level_weights=jnp.array([0.0, 1.0]),\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Two-Step Network with one step physics\n",
    "\n",
    "<p align=\"center\">\n",
    "    <img src=\"XXXX-5\" width=\"600\">\n",
    "</p>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "MixChainPostPhysics(\n",
       "  num_rollout_steps=1,\n",
       "  time_level_loss=MSELoss(batch_reduction=<function mean>),\n",
       "  num_post_physics_steps=2,\n",
       "  cut_bptt=False,\n",
       "  cut_bptt_every=1,\n",
       "  time_level_weights=f32[3]\n",
       ")"
      ]
     },
     "execution_count": 18,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "tx.configuration.MixChainPostPhysics(\n",
    "    num_rollout_steps=1,\n",
    "    num_post_physics_steps=2,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Two-Step Network with one step physics and no backpropagation through time\n",
    "\n",
    "<p align=\"center\">\n",
    "    <img src=\"XXXX-5\" width=\"600\">\n",
    "</p>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "MixChainPostPhysics(\n",
       "  num_rollout_steps=1,\n",
       "  time_level_loss=MSELoss(batch_reduction=<function mean>),\n",
       "  num_post_physics_steps=2,\n",
       "  cut_bptt=True,\n",
       "  cut_bptt_every=1,\n",
       "  time_level_weights=f32[3]\n",
       ")"
      ]
     },
     "execution_count": 19,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "tx.configuration.MixChainPostPhysics(\n",
    "    num_rollout_steps=1,\n",
    "    num_post_physics_steps=2,\n",
    "    cut_bptt=True,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Residuum\n",
    "\n",
    "Instead of having a `ref_stepper` that can be unrolled autoregressively, these\n",
    "configurations rely on a `residuum_fn` that defines a condition based on two\n",
    "consecutive time levels."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### One-Step Residuum\n",
    "\n",
    "<p align=\"center\">\n",
    "    <img src=\"XXXX-5\" width=\"350\">\n",
    "</p>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Residuum(\n",
       "  num_rollout_steps=1,\n",
       "  time_level_loss=MSELoss(batch_reduction=<function mean>),\n",
       "  cut_bptt=False,\n",
       "  cut_bptt_every=1,\n",
       "  cut_prev=False,\n",
       "  cut_next=False,\n",
       "  time_level_weights=f32[1]\n",
       ")"
      ]
     },
     "execution_count": 20,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "tx.configuration.Residuum(\n",
    "    num_rollout_steps=1,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Two Steps Residuum Training\n",
    "\n",
    "<p align=\"center\">\n",
    "    <img src=\"XXXX-5\" width=\"450\">\n",
    "</p>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Residuum(\n",
       "  num_rollout_steps=2,\n",
       "  time_level_loss=MSELoss(batch_reduction=<function mean>),\n",
       "  cut_bptt=False,\n",
       "  cut_bptt_every=1,\n",
       "  cut_prev=False,\n",
       "  cut_next=False,\n",
       "  time_level_weights=f32[2]\n",
       ")"
      ]
     },
     "execution_count": 21,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "tx.configuration.Residuum(\n",
    "    num_rollout_steps=2,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Three Steps Residuum Training\n",
    "\n",
    "<p align=\"center\">\n",
    "    <img src=\"XXXX-5\" width=\"550\">\n",
    "</p>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Residuum(\n",
       "  num_rollout_steps=3,\n",
       "  time_level_loss=MSELoss(batch_reduction=<function mean>),\n",
       "  cut_bptt=False,\n",
       "  cut_bptt_every=1,\n",
       "  cut_prev=False,\n",
       "  cut_next=False,\n",
       "  time_level_weights=f32[3]\n",
       ")"
      ]
     },
     "execution_count": 22,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "tx.configuration.Residuum(\n",
    "    num_rollout_steps=3,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Three Steps Residuum with no backpropagation through time\n",
    "\n",
    "<p align=\"center\">\n",
    "    <img src=\"XXXX-5\" width=\"550\">\n",
    "</p>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Residuum(\n",
       "  num_rollout_steps=3,\n",
       "  time_level_loss=MSELoss(batch_reduction=<function mean>),\n",
       "  cut_bptt=True,\n",
       "  cut_bptt_every=1,\n",
       "  cut_prev=False,\n",
       "  cut_next=False,\n",
       "  time_level_weights=f32[3]\n",
       ")"
      ]
     },
     "execution_count": 23,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "tx.configuration.Residuum(num_rollout_steps=3, cut_bptt=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Other Residuum Options\n",
    "\n",
    "It is possible to cut the `prev` and `next` contributions to the `residuum_fn`\n",
    "(see the `cut_prev` and `cut_next` fields in the outputs above)."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Teacher Forcing\n",
    "\n",
    "Resets the main chain with information from the autoregressive reference chain.\n",
    "It is essentially the opposite of diverted chain learning.\n",
    "\n",
    "This is similar to selecting minibatches over entire trajectories. However,\n",
    "this setup guarantees that, within one gradient update, multiple consecutive\n",
    "time levels are considered without the network having to roll out\n",
    "autoregressively."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Three Steps teacher forcing with reset every step\n",
    "\n",
    "<p align=\"center\">\n",
    "    <img src=\"XXXX-5\" width=\"550\">\n",
    "</p>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [],
   "source": [
    "# TODO Implementation"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Four Steps teacher forcing with reset every second step\n",
    "\n",
    "<p align=\"center\">\n",
    "    <img src=\"XXXX-5\" width=\"750\">\n",
    "</p>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [],
   "source": [
    "# TODO implementation"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Four Steps teacher forcing with reset every second step and no backpropagation through time\n",
    "\n",
    "<p align=\"center\">\n",
    "    <img src=\"XXXX-5\" width=\"750\">\n",
    "</p>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [],
   "source": [
    "# TODO implementation"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## How about correction learning?\n",
    "\n",
    "All the above-mentioned setups are also usable for correction learning, i.e.,\n",
    "when the emulator is not a pure network but has some (differentiable)\n",
    "(coarse) solver component. For example, in the case of sequential correction:\n",
    "\n",
    "<p align=\"center\">\n",
    "    <img src=\"XXXX-5\" width=\"350\">\n",
    "</p>\n",
    "\n",
    "See [this website](XXXX-5) for potential corrector layouts and options to cut\n",
    "gradients within them.\n",
    "\n",
    "All these layouts are **not** provided by `Trainax`. This is just to showcase\n",
    "that the configurations of `Trainax` can be used in a more general context."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "base",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
