<!DOCTYPE html>
<html class="writer-html5" lang="en" >
<head>
  <meta charset="utf-8" /><meta name="generator" content="Docutils 0.17.1: http://docutils.sourceforge.net/" />

  <meta name="viewport" content="width=device-width, initial-scale=1.0" />
  <title>FishLeg package &mdash; FishLeg 1.0 documentation</title>
      <link rel="stylesheet" href="_static/pygments.css" type="text/css" />
      <link rel="stylesheet" href="_static/css/theme.css" type="text/css" />
  <!--[if lt IE 9]>
    <script src="_static/js/html5shiv.min.js"></script>
  <![endif]-->
  
        <script data-url_root="./" id="documentation_options" src="_static/documentation_options.js"></script>
        <script src="_static/jquery.js"></script>
        <script src="_static/underscore.js"></script>
        <script src="_static/_sphinx_javascript_frameworks_compat.js"></script>
        <script src="_static/doctools.js"></script>
        <script src="_static/sphinx_highlight.js"></script>
        <script async="async" src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
    <script src="_static/js/theme.js"></script>
    <link rel="index" title="Index" href="genindex.html" />
    <link rel="search" title="Search" href="search.html" />
    <link rel="prev" title="optim" href="modules.html" /> 
</head>

<body class="wy-body-for-nav"> 
  <div class="wy-grid-for-nav">
    <nav data-toggle="wy-nav-shift" class="wy-nav-side">
      <div class="wy-side-scroll">
        <div class="wy-side-nav-search" >
            <a href="index.html" class="icon icon-home"> FishLeg
          </a>
<div role="search">
  <form id="rtd-search-form" class="wy-form" action="search.html" method="get">
    <input type="text" name="q" placeholder="Search docs" />
    <input type="hidden" name="check_keywords" value="yes" />
    <input type="hidden" name="area" value="default" />
  </form>
</div>
        </div><div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="Navigation menu">
              <p class="caption" role="heading"><span class="caption-text">Contents:</span></p>
<ul class="current">
<li class="toctree-l1 current"><a class="reference internal" href="modules.html">optim</a><ul class="current">
<li class="toctree-l2 current"><a class="current reference internal" href="#">FishLeg package</a><ul>
<li class="toctree-l3"><a class="reference internal" href="#submodules">Submodules</a></li>
<li class="toctree-l3"><a class="reference internal" href="#module-FishLeg.fishleg">FishLeg.fishleg module</a><ul>
<li class="toctree-l4"><a class="reference internal" href="#FishLeg.fishleg.FishLeg"><code class="docutils literal notranslate"><span class="pre">FishLeg</span></code></a></li>
</ul>
</li>
<li class="toctree-l3"><a class="reference internal" href="#module-FishLeg.fishleg_layers">FishLeg.fishleg_layers module</a><ul>
<li class="toctree-l4"><a class="reference internal" href="#FishLeg.fishleg_layers.FishBatchNorm2d"><code class="docutils literal notranslate"><span class="pre">FishBatchNorm2d</span></code></a></li>
<li class="toctree-l4"><a class="reference internal" href="#FishLeg.fishleg_layers.FishConv2d"><code class="docutils literal notranslate"><span class="pre">FishConv2d</span></code></a></li>
<li class="toctree-l4"><a class="reference internal" href="#FishLeg.fishleg_layers.FishLayerNorm"><code class="docutils literal notranslate"><span class="pre">FishLayerNorm</span></code></a></li>
<li class="toctree-l4"><a class="reference internal" href="#FishLeg.fishleg_layers.FishLinear"><code class="docutils literal notranslate"><span class="pre">FishLinear</span></code></a></li>
</ul>
</li>
<li class="toctree-l3"><a class="reference internal" href="#module-FishLeg.fishleg_likelihood">FishLeg.fishleg_likelihood module</a><ul>
<li class="toctree-l4"><a class="reference internal" href="#FishLeg.fishleg_likelihood.BernoulliLikelihood"><code class="docutils literal notranslate"><span class="pre">BernoulliLikelihood</span></code></a></li>
<li class="toctree-l4"><a class="reference internal" href="#FishLeg.fishleg_likelihood.FishLikelihood"><code class="docutils literal notranslate"><span class="pre">FishLikelihood</span></code></a></li>
<li class="toctree-l4"><a class="reference internal" href="#FishLeg.fishleg_likelihood.FixedGaussianLikelihood"><code class="docutils literal notranslate"><span class="pre">FixedGaussianLikelihood</span></code></a></li>
<li class="toctree-l4"><a class="reference internal" href="#FishLeg.fishleg_likelihood.GaussianLikelihood"><code class="docutils literal notranslate"><span class="pre">GaussianLikelihood</span></code></a></li>
<li class="toctree-l4"><a class="reference internal" href="#FishLeg.fishleg_likelihood.SoftMaxLikelihood"><code class="docutils literal notranslate"><span class="pre">SoftMaxLikelihood</span></code></a></li>
</ul>
</li>
<li class="toctree-l3"><a class="reference internal" href="#module-FishLeg.utils">FishLeg.utils module</a><ul>
<li class="toctree-l4"><a class="reference internal" href="#FishLeg.utils.get_named_layers_by_regex"><code class="docutils literal notranslate"><span class="pre">get_named_layers_by_regex()</span></code></a></li>
<li class="toctree-l4"><a class="reference internal" href="#FishLeg.utils.recursive_getattr"><code class="docutils literal notranslate"><span class="pre">recursive_getattr()</span></code></a></li>
<li class="toctree-l4"><a class="reference internal" href="#FishLeg.utils.recursive_setattr"><code class="docutils literal notranslate"><span class="pre">recursive_setattr()</span></code></a></li>
<li class="toctree-l4"><a class="reference internal" href="#FishLeg.utils.update_dict"><code class="docutils literal notranslate"><span class="pre">update_dict()</span></code></a></li>
</ul>
</li>
<li class="toctree-l3"><a class="reference internal" href="#module-FishLeg">Module contents</a></li>
</ul>
</li>
</ul>
</li>
</ul>

        </div>
      </div>
    </nav>

    <section data-toggle="wy-nav-shift" class="wy-nav-content-wrap"><nav class="wy-nav-top" aria-label="Mobile navigation menu" >
          <i data-toggle="wy-nav-top" class="fa fa-bars"></i>
          <a href="index.html">FishLeg</a>
      </nav>

      <div class="wy-nav-content">
        <div class="rst-content">
          <div role="navigation" aria-label="Page navigation">
  <ul class="wy-breadcrumbs">
      <li><a href="index.html" class="icon icon-home"></a></li>
          <li class="breadcrumb-item"><a href="modules.html">optim</a></li>
      <li class="breadcrumb-item active">FishLeg package</li>
      <li class="wy-breadcrumbs-aside">
            <a href="_sources/FishLeg.rst.txt" rel="nofollow"> View page source</a>
      </li>
  </ul>
  <hr/>
</div>
          <div role="main" class="document" itemscope="itemscope" itemtype="http://schema.org/Article">
           <div itemprop="articleBody">
             
  <section id="fishleg-package">
<h1>FishLeg package<a class="headerlink" href="#fishleg-package" title="Permalink to this heading"></a></h1>
<section id="submodules">
<h2>Submodules<a class="headerlink" href="#submodules" title="Permalink to this heading"></a></h2>
</section>
<section id="module-FishLeg.fishleg">
<span id="fishleg-fishleg-module"></span><h2>FishLeg.fishleg module<a class="headerlink" href="#module-FishLeg.fishleg" title="Permalink to this heading"></a></h2>
<dl class="py class">
<dt class="sig sig-object py" id="FishLeg.fishleg.FishLeg">
<em class="property"><span class="pre">class</span><span class="w"> </span></em><span class="sig-prename descclassname"><span class="pre">FishLeg.fishleg.</span></span><span class="sig-name descname"><span class="pre">FishLeg</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">model</span></span><span class="p"><span class="pre">:</span></span><span class="w"> </span><span class="n"><span class="pre">Module</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">draw</span></span><span class="p"><span class="pre">:</span></span><span class="w"> </span><span class="n"><span class="pre">Callable</span><span class="p"><span class="pre">[</span></span><span class="p"><span class="pre">[</span></span><span class="pre">Module</span><span class="p"><span class="pre">,</span></span><span class="w"> </span><span class="pre">Tensor</span><span class="p"><span class="pre">]</span></span><span class="p"><span class="pre">,</span></span><span class="w"> </span><span class="pre">Tuple</span><span class="p"><span class="pre">[</span></span><span class="pre">Tensor</span><span class="p"><span class="pre">,</span></span><span class="w"> </span><span class="pre">Tensor</span><span class="p"><span class="pre">]</span></span><span class="p"><span class="pre">]</span></span></span></em>, <em class="sig-param"><span class="n"><span class="pre">nll</span></span><span class="p"><span class="pre">:</span></span><span class="w"> </span><span class="n"><span class="pre">Callable</span><span class="p"><span class="pre">[</span></span><span class="p"><span class="pre">[</span></span><span class="pre">Module</span><span class="p"><span class="pre">,</span></span><span class="w"> </span><span class="pre">Tuple</span><span class="p"><span class="pre">[</span></span><span class="pre">Tensor</span><span class="p"><span class="pre">,</span></span><span class="w"> </span><span class="pre">Tensor</span><span class="p"><span 
class="pre">]</span></span><span class="p"><span class="pre">]</span></span><span class="p"><span class="pre">,</span></span><span class="w"> </span><span class="pre">Tensor</span><span class="p"><span class="pre">]</span></span></span></em>, <em class="sig-param"><span class="n"><span class="pre">aux_dataloader</span></span><span class="p"><span class="pre">:</span></span><span class="w"> </span><span class="n"><span class="pre">DataLoader</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">likelihood</span></span><span class="p"><span class="pre">:</span></span><span class="w"> </span><span class="n"><span class="pre">Optional</span><span class="p"><span class="pre">[</span></span><a class="reference internal" href="#FishLeg.fishleg_likelihood.FishLikelihood" title="FishLeg.fishleg_likelihood.FishLikelihood"><span class="pre">FishLikelihood</span></a><span class="p"><span class="pre">]</span></span></span><span class="w"> </span><span class="o"><span class="pre">=</span></span><span class="w"> </span><span class="default_value"><span class="pre">None</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">fish_lr</span></span><span class="p"><span class="pre">:</span></span><span class="w"> </span><span class="n"><span class="pre">float</span></span><span class="w"> </span><span class="o"><span class="pre">=</span></span><span class="w"> </span><span class="default_value"><span class="pre">0.05</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">damping</span></span><span class="p"><span class="pre">:</span></span><span class="w"> </span><span class="n"><span class="pre">float</span></span><span class="w"> </span><span class="o"><span class="pre">=</span></span><span class="w"> </span><span class="default_value"><span class="pre">0.5</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">weight_decay</span></span><span class="p"><span class="pre">:</span></span><span 
class="w"> </span><span class="n"><span class="pre">float</span></span><span class="w"> </span><span class="o"><span class="pre">=</span></span><span class="w"> </span><span class="default_value"><span class="pre">1e-05</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">beta</span></span><span class="p"><span class="pre">:</span></span><span class="w"> </span><span class="n"><span class="pre">float</span></span><span class="w"> </span><span class="o"><span class="pre">=</span></span><span class="w"> </span><span class="default_value"><span class="pre">0.9</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">update_aux_every</span></span><span class="p"><span class="pre">:</span></span><span class="w"> </span><span class="n"><span class="pre">int</span></span><span class="w"> </span><span class="o"><span class="pre">=</span></span><span class="w"> </span><span class="default_value"><span class="pre">10</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">aux_lr</span></span><span class="p"><span class="pre">:</span></span><span class="w"> </span><span class="n"><span class="pre">float</span></span><span class="w"> </span><span class="o"><span class="pre">=</span></span><span class="w"> </span><span class="default_value"><span class="pre">0.0001</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">aux_betas</span></span><span class="p"><span class="pre">:</span></span><span class="w"> </span><span class="n"><span class="pre">Tuple</span><span class="p"><span class="pre">[</span></span><span class="pre">float</span><span class="p"><span class="pre">,</span></span><span class="w"> </span><span class="pre">float</span><span class="p"><span class="pre">]</span></span></span><span class="w"> </span><span class="o"><span class="pre">=</span></span><span class="w"> </span><span class="default_value"><span class="pre">(0.9,</span> <span class="pre">0.999)</span></span></em>, <em 
class="sig-param"><span class="n"><span class="pre">aux_eps</span></span><span class="p"><span class="pre">:</span></span><span class="w"> </span><span class="n"><span class="pre">float</span></span><span class="w"> </span><span class="o"><span class="pre">=</span></span><span class="w"> </span><span class="default_value"><span class="pre">1e-08</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">num_steps</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">None</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">batch_speedup</span></span><span class="p"><span class="pre">:</span></span><span class="w"> </span><span class="n"><span class="pre">bool</span></span><span class="w"> </span><span class="o"><span class="pre">=</span></span><span class="w"> </span><span class="default_value"><span class="pre">False</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">full</span></span><span class="p"><span class="pre">:</span></span><span class="w"> </span><span class="n"><span class="pre">bool</span></span><span class="w"> </span><span class="o"><span class="pre">=</span></span><span class="w"> </span><span class="default_value"><span class="pre">True</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">normalization</span></span><span class="p"><span class="pre">:</span></span><span class="w"> </span><span class="n"><span class="pre">bool</span></span><span class="w"> </span><span class="o"><span class="pre">=</span></span><span class="w"> </span><span class="default_value"><span class="pre">False</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">fine_tune</span></span><span class="p"><span class="pre">:</span></span><span class="w"> </span><span class="n"><span class="pre">bool</span></span><span class="w"> </span><span class="o"><span class="pre">=</span></span><span class="w"> </span><span 
class="default_value"><span class="pre">False</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">module_names</span></span><span class="p"><span class="pre">:</span></span><span class="w"> </span><span class="n"><span class="pre">List</span><span class="p"><span class="pre">[</span></span><span class="pre">str</span><span class="p"><span class="pre">]</span></span></span><span class="w"> </span><span class="o"><span class="pre">=</span></span><span class="w"> </span><span class="default_value"><span class="pre">[]</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">skip_names</span></span><span class="p"><span class="pre">:</span></span><span class="w"> </span><span class="n"><span class="pre">List</span><span class="p"><span class="pre">[</span></span><span class="pre">str</span><span class="p"><span class="pre">]</span></span></span><span class="w"> </span><span class="o"><span class="pre">=</span></span><span class="w"> </span><span class="default_value"><span class="pre">[]</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">initialization</span></span><span class="p"><span class="pre">:</span></span><span class="w"> </span><span class="n"><span class="pre">str</span></span><span class="w"> </span><span class="o"><span class="pre">=</span></span><span class="w"> </span><span class="default_value"><span class="pre">'uniform'</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">scale</span></span><span class="p"><span class="pre">:</span></span><span class="w"> </span><span class="n"><span class="pre">float</span></span><span class="w"> </span><span class="o"><span class="pre">=</span></span><span class="w"> </span><span class="default_value"><span class="pre">1.0</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">warmup</span></span><span class="p"><span class="pre">:</span></span><span class="w"> </span><span class="n"><span 
class="pre">int</span></span><span class="w"> </span><span class="o"><span class="pre">=</span></span><span class="w"> </span><span class="default_value"><span class="pre">0</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">warmup_data</span></span><span class="p"><span class="pre">:</span></span><span class="w"> </span><span class="n"><span class="pre">Optional</span><span class="p"><span class="pre">[</span></span><span class="pre">DataLoader</span><span class="p"><span class="pre">]</span></span></span><span class="w"> </span><span class="o"><span class="pre">=</span></span><span class="w"> </span><span class="default_value"><span class="pre">None</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">warmup_loss</span></span><span class="p"><span class="pre">:</span></span><span class="w"> </span><span class="n"><span class="pre">Optional</span><span class="p"><span class="pre">[</span></span><span class="pre">Callable</span><span class="p"><span class="pre">]</span></span></span><span class="w"> </span><span class="o"><span class="pre">=</span></span><span class="w"> </span><span class="default_value"><span class="pre">None</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">device</span></span><span class="p"><span class="pre">:</span></span><span class="w"> </span><span class="n"><span class="pre">str</span></span><span class="w"> </span><span class="o"><span class="pre">=</span></span><span class="w"> </span><span class="default_value"><span class="pre">'cpu'</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">config</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">None</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">verbose</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">False</span></span></em><span 
class="sig-paren">)</span><a class="headerlink" href="#FishLeg.fishleg.FishLeg" title="Permalink to this definition"></a></dt>
<dd><p>Bases: <code class="xref py py-class docutils literal notranslate"><span class="pre">Optimizer</span></code></p>
<p>Implement FishLeg algorithm.</p>
<p>As described in <a class="reference external" href="https://openreview.net/forum?id=c9lAOPvQHS">https://openreview.net/forum?id=c9lAOPvQHS</a>.</p>
<dl class="field-list simple">
<dt class="field-odd">Parameters</dt>
<dd class="field-odd"><ul class="simple">
<li><p><strong>model</strong> (<em>torch.nn.Module</em>) – a pytorch neural network module,
can be nested in a tree structure</p></li>
<li><p><strong>draw</strong> (<em>Callable</em><em>[</em><em>[</em><em>nn.Module</em><em>, </em><em>torch.Tensor</em><em>]</em><em>, </em><em>Tuple</em><em>[</em><em>torch.Tensor</em><em>, </em><em>torch.Tensor</em><em>]</em><em>]</em>) – Sampling function that takes a model <span class="math notranslate nohighlight">\(f\)</span> and input data <span class="math notranslate nohighlight">\(\mathbf X\)</span>,
and returns <span class="math notranslate nohighlight">\((\mathbf X, \mathbf y)\)</span>,
where <span class="math notranslate nohighlight">\(\mathbf y\)</span> is sampled from
the conditional distribution <span class="math notranslate nohighlight">\(p(\mathbf y|f(\mathbf X))\)</span></p></li>
<li><p><strong>nll</strong> (<em>Callable</em><em>[</em><em>[</em><em>nn.Module</em><em>, </em><em>Tuple</em><em>[</em><em>torch.Tensor</em><em>, </em><em>torch.Tensor</em><em>]</em><em>]</em><em>, </em><em>torch.Tensor</em><em>]</em>) – A function that takes a model and data, and evaluates the negative
log-likelihood.</p></li>
<li><p><strong>aux_dataloader</strong> (<em>torch.utils.data.DataLoader</em>) – A function that takes a batch size as input and outputs a dataset
of the corresponding size.</p></li>
</ul>
</dd>
</dl>
<dl class="field-list simple">
<dt class="field-odd">Parameters</dt>
<dd class="field-odd"><ul class="simple">
<li><p><strong>likelihood</strong> (<em>FishLikelihood</em>) – a FishLeg likelihood, with Qv method if
any parameters are learnable.</p></li>
</ul>
</dd>
</dl>
<dl class="field-list simple">
<dt class="field-odd">Parameters</dt>
<dd class="field-odd"><ul class="simple">
<li><p><strong>fish_lr</strong> (<em>float</em>) – Learning rate
for the parameters of the input model using FishLeg (default: 5e-2)</p></li>
<li><p><strong>damping</strong> (<em>float</em>) – Static damping applied to Fisher matrix, <span class="math notranslate nohighlight">\(\gamma\)</span>,
for stability when FIM becomes near-singular. (default: 5e-1)</p></li>
<li><p><strong>weight_decay</strong> (<em>float</em>) – L2 penalty on weights (default: 1e-5)</p></li>
<li><p><strong>beta</strong> (<em>float</em>) – coefficient for running averages of gradient (default: 0.9)</p></li>
<li><p><strong>update_aux_every</strong> (<em>int</em>) – Number of iterations after which an auxiliary
update is executed; if negative, then run -update_aux_every auxiliary
updates in each outer iteration. (default: 10)</p></li>
<li><p><strong>aux_lr</strong> (<em>float</em>) – learning rate for the auxiliary parameters,
using Adam (default: 1e-4)</p></li>
<li><p><strong>aux_betas</strong> (<em>Tuple</em><em>[</em><em>float</em><em>, </em><em>float</em><em>]</em>) – Coefficients used for computing
running averages of gradient and its square for auxiliary parameters
(default: (0.9, 0.999))</p></li>
<li><p><strong>aux_eps</strong> (<em>float</em>) – Term added to the denominator to improve
numerical stability for auxiliary parameters (default: 1e-8)</p></li>
<li><p><strong>batch_speedup</strong> (<em>bool</em>) – Whether to use speed-up Qv product (default: False)</p></li>
<li><p><strong>full</strong> (<em>bool</em>) – Whether to use full inner and outer diagonal rescaling
for block Kronecker approximation of Q. (default: True)</p></li>
<li><p><strong>normalization</strong> (<em>bool</em>) – Whether to use normalization on gradients when calculating
the auxiliary loss, this is important to learn about curvature even when
gradients are small (default: False)</p></li>
<li><p><strong>fine_tune</strong> (<em>bool</em>) – Whether to use Fisher as preconditioner of pretrained tasks,
and fine-tune on a downstream task. If True, Q will be fixed and
continual learning will be performed (default: False)</p></li>
<li><p><strong>module_names</strong> (<em>List</em>) – A list of names of the modules to be optimized/pruned by FishLeg.
(default: [], meaning all modules are optimized/pruned by FishLeg)</p></li>
<li><p><strong>initialization</strong> (<em>string</em>) – Initialization of weights (default: uniform)</p></li>
<li><p><strong>warmup</strong> (<em>int</em>) – If warmup is zero, the default SGD warmup will be used, where Q is
initialized as a scaled identity matrix. If warmup is positive, the diagonal
of Q will be initialized as <span class="math notranslate nohighlight">\(\frac{1}{g^2 + \gamma}\)</span>; and in this case,
warmup_data and warmup_loss should be provided for sampling of gradients.</p></li>
<li><p><strong>scale</strong> (<em>float</em>) – Helps specify the initial scale of the inverse Fisher Information matrix
approximation. If using SGD warmup, we suggest <span class="math notranslate nohighlight">\(\eta=\gamma^{-1}\)</span>. If
warmup is positive, scale should be 1. (default: 1)</p></li>
<li><p><strong>device</strong> (<em>str</em>) – The device where calculations will be performed using PyTorch Tensors.</p></li>
</ul>
</dd>
</dl>
<dl>
<dt>Example:</dt><dd><div class="doctest highlight-default notranslate"><div class="highlight"><pre><span></span><span class="gp">&gt;&gt;&gt; </span><span class="n">aux_loader</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">utils</span><span class="o">.</span><span class="n">data</span><span class="o">.</span><span class="n">DataLoader</span><span class="p">(</span><span class="n">train_data</span><span class="p">,</span> <span class="n">shuffle</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span> <span class="n">batch_size</span><span class="o">=</span><span class="mi">100</span><span class="p">)</span>
<span class="gp">&gt;&gt;&gt; </span><span class="n">train_loader</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">utils</span><span class="o">.</span><span class="n">data</span><span class="o">.</span><span class="n">DataLoader</span><span class="p">(</span><span class="n">train_data</span><span class="p">,</span> <span class="n">shuffle</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span> <span class="n">batch_size</span><span class="o">=</span><span class="mi">100</span><span class="p">)</span>
<span class="go">&gt;&gt;&gt;</span>
<span class="gp">&gt;&gt;&gt; </span><span class="n">likelihood</span> <span class="o">=</span> <span class="n">FixedGaussianLikelihood</span><span class="p">(</span><span class="n">sigma</span><span class="o">=</span><span class="mf">1.0</span><span class="p">)</span>
<span class="go">&gt;&gt;&gt;</span>
<span class="gp">&gt;&gt;&gt; </span><span class="k">def</span> <span class="nf">nll</span><span class="p">(</span><span class="n">model</span><span class="p">,</span> <span class="n">data_x</span><span class="p">,</span> <span class="n">data_y</span><span class="p">):</span>
<span class="gp">&gt;&gt;&gt; </span>    <span class="n">pred_y</span> <span class="o">=</span> <span class="n">model</span><span class="o">.</span><span class="n">forward</span><span class="p">(</span><span class="n">data_x</span><span class="p">)</span>
<span class="gp">&gt;&gt;&gt; </span>    <span class="k">return</span> <span class="n">likelihood</span><span class="o">.</span><span class="n">nll</span><span class="p">(</span><span class="n">data_y</span><span class="p">,</span> <span class="n">pred_y</span><span class="p">)</span>
<span class="go">&gt;&gt;&gt;</span>
<span class="gp">&gt;&gt;&gt; </span><span class="k">def</span> <span class="nf">draw</span><span class="p">(</span><span class="n">model</span><span class="p">,</span> <span class="n">data_x</span><span class="p">):</span>
<span class="gp">&gt;&gt;&gt; </span>    <span class="n">pred_y</span> <span class="o">=</span> <span class="n">model</span><span class="o">.</span><span class="n">forward</span><span class="p">(</span><span class="n">data_x</span><span class="p">)</span>
<span class="gp">&gt;&gt;&gt; </span>    <span class="k">return</span> <span class="n">likelihood</span><span class="o">.</span><span class="n">draw</span><span class="p">(</span><span class="n">pred_y</span><span class="p">)</span>
<span class="go">&gt;&gt;&gt;</span>
<span class="gp">&gt;&gt;&gt; </span><span class="n">model</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Sequential</span><span class="p">(</span>
<span class="gp">&gt;&gt;&gt; </span>    <span class="n">nn</span><span class="o">.</span><span class="n">Linear</span><span class="p">(</span><span class="mi">2</span><span class="p">,</span> <span class="mi">5</span><span class="p">),</span>
<span class="gp">&gt;&gt;&gt; </span>    <span class="n">nn</span><span class="o">.</span><span class="n">ReLU</span><span class="p">(),</span>
<span class="gp">&gt;&gt;&gt; </span>    <span class="n">nn</span><span class="o">.</span><span class="n">Linear</span><span class="p">(</span><span class="mi">5</span><span class="p">,</span> <span class="mi">1</span><span class="p">),</span>
<span class="gp">&gt;&gt;&gt; </span><span class="p">)</span>
<span class="go">&gt;&gt;&gt;</span>
<span class="gp">&gt;&gt;&gt; </span><span class="n">opt</span> <span class="o">=</span> <span class="n">FishLeg</span><span class="p">(</span>
<span class="gp">&gt;&gt;&gt; </span>    <span class="n">model</span><span class="p">,</span>
<span class="gp">&gt;&gt;&gt; </span>    <span class="n">draw</span><span class="p">,</span>
<span class="gp">&gt;&gt;&gt; </span>    <span class="n">nll</span><span class="p">,</span>
<span class="gp">&gt;&gt;&gt; </span>    <span class="n">aux_loader</span>
<span class="gp">&gt;&gt;&gt; </span><span class="p">)</span>
<span class="go">&gt;&gt;&gt;</span>
<span class="gp">&gt;&gt;&gt; </span><span class="k">for</span> <span class="n">data_x</span><span class="p">,</span> <span class="n">data_y</span> <span class="ow">in</span> <span class="n">dataloader</span><span class="p">:</span>
<span class="gp">&gt;&gt;&gt; </span>    <span class="n">opt</span><span class="o">.</span><span class="n">zero_grad</span><span class="p">()</span>
<span class="gp">&gt;&gt;&gt; </span>    <span class="n">pred_y</span> <span class="o">=</span> <span class="n">model</span><span class="p">(</span><span class="n">data_x</span><span class="p">)</span>
<span class="gp">&gt;&gt;&gt; </span>    <span class="n">loss</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">MSELoss</span><span class="p">()(</span><span class="n">data_y</span><span class="p">,</span> <span class="n">pred_y</span><span class="p">)</span>
<span class="gp">&gt;&gt;&gt; </span>    <span class="n">loss</span><span class="o">.</span><span class="n">backward</span><span class="p">()</span>
<span class="gp">&gt;&gt;&gt; </span>    <span class="n">opt</span><span class="o">.</span><span class="n">step</span><span class="p">()</span>
<span class="gp">&gt;&gt;&gt; </span>    <span class="k">if</span> <span class="n">iteration</span> <span class="o">%</span> <span class="mi">10</span> <span class="o">==</span> <span class="mi">0</span><span class="p">:</span>
<span class="gp">&gt;&gt;&gt; </span>        <span class="nb">print</span><span class="p">(</span><span class="n">loss</span><span class="o">.</span><span class="n">detach</span><span class="p">())</span>
</pre></div>
</div>
</dd>
</dl>
<dl class="py method">
<dt class="sig sig-object py" id="FishLeg.fishleg.FishLeg.init_model_aux">
<span class="sig-name descname"><span class="pre">init_model_aux</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">model</span></span><span class="p"><span class="pre">:</span></span><span class="w"> </span><span class="n"><span class="pre">Module</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">module_names</span></span><span class="p"><span class="pre">:</span></span><span class="w"> </span><span class="n"><span class="pre">List</span><span class="p"><span class="pre">[</span></span><span class="pre">str</span><span class="p"><span class="pre">]</span></span></span></em>, <em class="sig-param"><span class="n"><span class="pre">skip_names</span></span><span class="p"><span class="pre">:</span></span><span class="w"> </span><span class="n"><span class="pre">List</span><span class="p"><span class="pre">[</span></span><span class="pre">str</span><span class="p"><span class="pre">]</span></span></span></em>, <em class="sig-param"><span class="n"><span class="pre">config</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">None</span></span></em><span class="sig-paren">)</span> <span class="sig-return"><span class="sig-return-icon">&#x2192;</span> <span class="sig-return-typehint"><span class="pre">Union</span><span class="p"><span class="pre">[</span></span><span class="pre">Module</span><span class="p"><span class="pre">,</span></span><span class="w"> </span><span class="pre">List</span><span class="p"><span class="pre">]</span></span></span></span><a class="headerlink" href="#FishLeg.fishleg.FishLeg.init_model_aux" title="Permalink to this definition"></a></dt>
<dd><p>Given a model to optimize, its parameters can be divided into:</p>
<ol class="arabic simple">
<li><p>those fixed as pre-trained.</p></li>
<li><p>those to be optimized using FishLeg.</p></li>
</ol>
<p>Replace modules in the second group with FishLeg modules.</p>
<dl class="simple">
<dt>Args:</dt><dd><dl class="simple">
<dt>model (<code class="xref py py-class docutils literal notranslate"><span class="pre">torch.nn.Module</span></code>, required):</dt><dd><p>A model containing modules to replace with FishLeg modules
containing extra functionality related to the FishLeg algorithm.</p>
</dd>
</dl>
</dd>
<dt>Returns:</dt><dd><p><code class="xref py py-class docutils literal notranslate"><span class="pre">torch.nn.Module</span></code>, the replaced model.</p>
</dd>
</dl>
</dd></dl>

<dl class="py method">
<dt class="sig sig-object py" id="FishLeg.fishleg.FishLeg.pretrain_fish">
<span class="sig-name descname"><span class="pre">pretrain_fish</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">dataloader</span></span><span class="p"><span class="pre">:</span></span><span class="w"> </span><span class="n"><span class="pre">DataLoader</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">loss</span></span><span class="p"><span class="pre">:</span></span><span class="w"> </span><span class="n"><span class="pre">Callable</span><span class="p"><span class="pre">[</span></span><span class="p"><span class="pre">[</span></span><span class="pre">Module</span><span class="p"><span class="pre">,</span></span><span class="w"> </span><span class="pre">Tuple</span><span class="p"><span class="pre">[</span></span><span class="pre">Tensor</span><span class="p"><span class="pre">,</span></span><span class="w"> </span><span class="pre">Tensor</span><span class="p"><span class="pre">]</span></span><span class="p"><span class="pre">]</span></span><span class="p"><span class="pre">,</span></span><span class="w"> </span><span class="pre">Tensor</span><span class="p"><span class="pre">]</span></span></span></em>, <em class="sig-param"><span class="n"><span class="pre">iterations</span></span><span class="p"><span class="pre">:</span></span><span class="w"> </span><span class="n"><span class="pre">int</span></span><span class="w"> </span><span class="o"><span class="pre">=</span></span><span class="w"> </span><span class="default_value"><span class="pre">10000</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">difference</span></span><span class="p"><span class="pre">:</span></span><span class="w"> </span><span class="n"><span class="pre">bool</span></span><span class="w"> </span><span class="o"><span class="pre">=</span></span><span class="w"> </span><span class="default_value"><span class="pre">False</span></span></em>, <em class="sig-param"><span class="n"><span 
class="pre">verbose</span></span><span class="p"><span class="pre">:</span></span><span class="w"> </span><span class="n"><span class="pre">bool</span></span><span class="w"> </span><span class="o"><span class="pre">=</span></span><span class="w"> </span><span class="default_value"><span class="pre">False</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">testloader</span></span><span class="p"><span class="pre">:</span></span><span class="w"> </span><span class="n"><span class="pre">Optional</span><span class="p"><span class="pre">[</span></span><span class="pre">DataLoader</span><span class="p"><span class="pre">]</span></span></span><span class="w"> </span><span class="o"><span class="pre">=</span></span><span class="w"> </span><span class="default_value"><span class="pre">None</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">batch_size</span></span><span class="p"><span class="pre">:</span></span><span class="w"> </span><span class="n"><span class="pre">int</span></span><span class="w"> </span><span class="o"><span class="pre">=</span></span><span class="w"> </span><span class="default_value"><span class="pre">500</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">fisher</span></span><span class="p"><span class="pre">:</span></span><span class="w"> </span><span class="n"><span class="pre">bool</span></span><span class="w"> </span><span class="o"><span class="pre">=</span></span><span class="w"> </span><span class="default_value"><span class="pre">True</span></span></em><span class="sig-paren">)</span> <span class="sig-return"><span class="sig-return-icon">&#x2192;</span> <span class="sig-return-typehint"><span class="pre">List</span></span></span><a class="headerlink" href="#FishLeg.fishleg.FishLeg.pretrain_fish" title="Permalink to this definition"></a></dt>
<dd></dd></dl>

<dl class="py method">
<dt class="sig sig-object py" id="FishLeg.fishleg.FishLeg.step">
<span class="sig-name descname"><span class="pre">step</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">closure</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">None</span></span></em><span class="sig-paren">)</span> <span class="sig-return"><span class="sig-return-icon">&#x2192;</span> <span class="sig-return-typehint"><span class="pre">None</span></span></span><a class="headerlink" href="#FishLeg.fishleg.FishLeg.step" title="Permalink to this definition"></a></dt>
<dd><p>Performs a single optimization step of FishLeg.</p>
</dd></dl>

<dl class="py method">
<dt class="sig sig-object py" id="FishLeg.fishleg.FishLeg.update_aux">
<span class="sig-name descname"><span class="pre">update_aux</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">train</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">True</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">fisher</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">True</span></span></em><span class="sig-paren">)</span> <span class="sig-return"><span class="sig-return-icon">&#x2192;</span> <span class="sig-return-typehint"><span class="pre">None</span></span></span><a class="headerlink" href="#FishLeg.fishleg.FishLeg.update_aux" title="Permalink to this definition"></a></dt>
<dd><p>Performs a single auxiliary parameter update
using Adam, by minimizing the following objective:</p>
<div class="math notranslate nohighlight">
\[nll(model, \theta + \epsilon Q(\lambda)g) + nll(model, \theta - \epsilon Q(\lambda)g) - 2\epsilon^2g^T Q(\lambda)g\]</div>
<p>where <span class="math notranslate nohighlight">\(\theta\)</span> denotes the parameters of the model and <span class="math notranslate nohighlight">\(\lambda\)</span> denotes the
auxiliary parameters.</p>
</dd></dl>

<dl class="py method">
<dt class="sig sig-object py" id="FishLeg.fishleg.FishLeg.warmup_aux">
<span class="sig-name descname"><span class="pre">warmup_aux</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">dataloader</span></span><span class="p"><span class="pre">:</span></span><span class="w"> </span><span class="n"><span class="pre">Optional</span><span class="p"><span class="pre">[</span></span><span class="pre">DataLoader</span><span class="p"><span class="pre">]</span></span></span><span class="w"> </span><span class="o"><span class="pre">=</span></span><span class="w"> </span><span class="default_value"><span class="pre">None</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">loss</span></span><span class="p"><span class="pre">:</span></span><span class="w"> </span><span class="n"><span class="pre">Optional</span><span class="p"><span class="pre">[</span></span><span class="pre">Callable</span><span class="p"><span class="pre">[</span></span><span class="p"><span class="pre">[</span></span><span class="pre">Module</span><span class="p"><span class="pre">,</span></span><span class="w"> </span><span class="pre">Tuple</span><span class="p"><span class="pre">[</span></span><span class="pre">Tensor</span><span class="p"><span class="pre">,</span></span><span class="w"> </span><span class="pre">Tensor</span><span class="p"><span class="pre">]</span></span><span class="p"><span class="pre">]</span></span><span class="p"><span class="pre">,</span></span><span class="w"> </span><span class="pre">Tensor</span><span class="p"><span class="pre">]</span></span><span class="p"><span class="pre">]</span></span></span><span class="w"> </span><span class="o"><span class="pre">=</span></span><span class="w"> </span><span class="default_value"><span class="pre">None</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">scale</span></span><span class="p"><span class="pre">:</span></span><span class="w"> </span><span class="n"><span class="pre">float</span></span><span 
class="w"> </span><span class="o"><span class="pre">=</span></span><span class="w"> </span><span class="default_value"><span class="pre">1.0</span></span></em><span class="sig-paren">)</span> <span class="sig-return"><span class="sig-return-icon">&#x2192;</span> <span class="sig-return-typehint"><span class="pre">None</span></span></span><a class="headerlink" href="#FishLeg.fishleg.FishLeg.warmup_aux" title="Permalink to this definition"></a></dt>
<dd><p>Warm up auxiliary parameters:
if warmup is larger than zero, follow approximate Adam;
if warmup is zero, follow SGD.</p>
</dd></dl>

</dd></dl>

</section>
<section id="module-FishLeg.fishleg_layers">
<span id="fishleg-fishleg-layers-module"></span><h2>FishLeg.fishleg_layers module<a class="headerlink" href="#module-FishLeg.fishleg_layers" title="Permalink to this heading"></a></h2>
<dl class="py class">
<dt class="sig sig-object py" id="FishLeg.fishleg_layers.FishBatchNorm2d">
<em class="property"><span class="pre">class</span><span class="w"> </span></em><span class="sig-prename descclassname"><span class="pre">FishLeg.fishleg_layers.</span></span><span class="sig-name descname"><span class="pre">FishBatchNorm2d</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">num_features</span></span><span class="p"><span class="pre">:</span></span><span class="w"> </span><span class="n"><span class="pre">int</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">eps</span></span><span class="p"><span class="pre">:</span></span><span class="w"> </span><span class="n"><span class="pre">float</span></span><span class="w"> </span><span class="o"><span class="pre">=</span></span><span class="w"> </span><span class="default_value"><span class="pre">1e-05</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">momentum</span></span><span class="p"><span class="pre">:</span></span><span class="w"> </span><span class="n"><span class="pre">float</span></span><span class="w"> </span><span class="o"><span class="pre">=</span></span><span class="w"> </span><span class="default_value"><span class="pre">0.1</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">affine</span></span><span class="p"><span class="pre">:</span></span><span class="w"> </span><span class="n"><span class="pre">bool</span></span><span class="w"> </span><span class="o"><span class="pre">=</span></span><span class="w"> </span><span class="default_value"><span class="pre">True</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">track_running_stats</span></span><span class="p"><span class="pre">:</span></span><span class="w"> </span><span class="n"><span class="pre">bool</span></span><span class="w"> </span><span class="o"><span class="pre">=</span></span><span class="w"> </span><span class="default_value"><span class="pre">True</span></span></em>, <em 
class="sig-param"><span class="n"><span class="pre">init_scale</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">1.0</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">device</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">None</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">dtype</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">None</span></span></em><span class="sig-paren">)</span><a class="headerlink" href="#FishLeg.fishleg_layers.FishBatchNorm2d" title="Permalink to this definition"></a></dt>
<dd><p>Bases: <code class="xref py py-class docutils literal notranslate"><span class="pre">BatchNorm2d</span></code>, <code class="xref py py-class docutils literal notranslate"><span class="pre">FishModule</span></code></p>
<dl class="py method">
<dt class="sig sig-object py" id="FishLeg.fishleg_layers.FishBatchNorm2d.Qv">
<span class="sig-name descname"><span class="pre">Qv</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">v</span></span><span class="p"><span class="pre">:</span></span><span class="w"> </span><span class="n"><span class="pre">Tuple</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">full</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">False</span></span></em><span class="sig-paren">)</span><a class="headerlink" href="#FishLeg.fishleg_layers.FishBatchNorm2d.Qv" title="Permalink to this definition"></a></dt>
<dd><p><span class="math notranslate nohighlight">\(Q(\lambda)\)</span> is a positive definite matrix which will effectively
estimate the inverse damped Fisher Information Matrix. Appropriate choices
for <span class="math notranslate nohighlight">\(Q\)</span> should take into account the architecture of the model/module.
It is usually parameterized as a positive definite Kronecker-factored
block-diagonal matrix, with block sizes reflecting the layer structure of
the neural networks.</p>
<dl>
<dt>Args:</dt><dd><dl class="simple">
<dt>aux: (Dict, required): auxiliary parameters,</dt><dd><p><span class="math notranslate nohighlight">\(\lambda\)</span>, a dictionary with keys, the name
of the auxiliary parameters, and values, the auxiliary parameters
of the module. These auxiliary parameters will form <span class="math notranslate nohighlight">\(Q(\lambda)\)</span>.</p>
</dd>
<dt>v: (Tuple[Tensor, …], required): Values of the original parameters,</dt><dd><p>in an order that aligns with <cite>self.order</cite>, to multiply with
<span class="math notranslate nohighlight">\(Q(\lambda)\)</span>.</p>
</dd>
</dl>
<p>full: (bool, optional), whether to use full inner and outer re-scaling</p>
</dd>
<dt>Returns:</dt><dd><dl class="simple">
<dt>Tuple[Tensor, …]: The calculated <span class="math notranslate nohighlight">\(Q(\lambda)v\)</span> products,</dt><dd><p>in the same order as <cite>self.order</cite>.</p>
</dd>
</dl>
</dd>
</dl>
</dd></dl>

<dl class="py attribute">
<dt class="sig sig-object py" id="FishLeg.fishleg_layers.FishBatchNorm2d.affine">
<span class="sig-name descname"><span class="pre">affine</span></span><em class="property"><span class="p"><span class="pre">:</span></span><span class="w"> </span><span class="pre">bool</span></em><a class="headerlink" href="#FishLeg.fishleg_layers.FishBatchNorm2d.affine" title="Permalink to this definition"></a></dt>
<dd></dd></dl>

<dl class="py method">
<dt class="sig sig-object py" id="FishLeg.fishleg_layers.FishBatchNorm2d.diagQ">
<span class="sig-name descname"><span class="pre">diagQ</span></span><span class="sig-paren">(</span><span class="sig-paren">)</span><a class="headerlink" href="#FishLeg.fishleg_layers.FishBatchNorm2d.diagQ" title="Permalink to this definition"></a></dt>
<dd></dd></dl>

<dl class="py attribute">
<dt class="sig sig-object py" id="FishLeg.fishleg_layers.FishBatchNorm2d.eps">
<span class="sig-name descname"><span class="pre">eps</span></span><em class="property"><span class="p"><span class="pre">:</span></span><span class="w"> </span><span class="pre">float</span></em><a class="headerlink" href="#FishLeg.fishleg_layers.FishBatchNorm2d.eps" title="Permalink to this definition"></a></dt>
<dd></dd></dl>

<dl class="py attribute">
<dt class="sig sig-object py" id="FishLeg.fishleg_layers.FishBatchNorm2d.momentum">
<span class="sig-name descname"><span class="pre">momentum</span></span><em class="property"><span class="p"><span class="pre">:</span></span><span class="w"> </span><span class="pre">float</span></em><a class="headerlink" href="#FishLeg.fishleg_layers.FishBatchNorm2d.momentum" title="Permalink to this definition"></a></dt>
<dd></dd></dl>

<dl class="py attribute">
<dt class="sig sig-object py" id="FishLeg.fishleg_layers.FishBatchNorm2d.num_features">
<span class="sig-name descname"><span class="pre">num_features</span></span><em class="property"><span class="p"><span class="pre">:</span></span><span class="w"> </span><span class="pre">int</span></em><a class="headerlink" href="#FishLeg.fishleg_layers.FishBatchNorm2d.num_features" title="Permalink to this definition"></a></dt>
<dd></dd></dl>

<dl class="py attribute">
<dt class="sig sig-object py" id="FishLeg.fishleg_layers.FishBatchNorm2d.track_running_stats">
<span class="sig-name descname"><span class="pre">track_running_stats</span></span><em class="property"><span class="p"><span class="pre">:</span></span><span class="w"> </span><span class="pre">bool</span></em><a class="headerlink" href="#FishLeg.fishleg_layers.FishBatchNorm2d.track_running_stats" title="Permalink to this definition"></a></dt>
<dd></dd></dl>

</dd></dl>

<dl class="py class">
<dt class="sig sig-object py" id="FishLeg.fishleg_layers.FishConv2d">
<em class="property"><span class="pre">class</span><span class="w"> </span></em><span class="sig-prename descclassname"><span class="pre">FishLeg.fishleg_layers.</span></span><span class="sig-name descname"><span class="pre">FishConv2d</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">in_channels</span></span><span class="p"><span class="pre">:</span></span><span class="w"> </span><span class="n"><span class="pre">int</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">out_channels</span></span><span class="p"><span class="pre">:</span></span><span class="w"> </span><span class="n"><span class="pre">int</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">kernel_size</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">stride</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">1</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">padding</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">0</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">dilation</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">1</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">groups</span></span><span class="p"><span class="pre">:</span></span><span class="w"> </span><span class="n"><span class="pre">int</span></span><span class="w"> </span><span class="o"><span class="pre">=</span></span><span class="w"> </span><span class="default_value"><span class="pre">1</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">bias</span></span><span class="p"><span class="pre">:</span></span><span class="w"> </span><span class="n"><span class="pre">bool</span></span><span class="w"> </span><span class="o"><span 
class="pre">=</span></span><span class="w"> </span><span class="default_value"><span class="pre">True</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">padding_mode</span></span><span class="p"><span class="pre">:</span></span><span class="w"> </span><span class="n"><span class="pre">str</span></span><span class="w"> </span><span class="o"><span class="pre">=</span></span><span class="w"> </span><span class="default_value"><span class="pre">'zeros'</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">device</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">None</span></span></em><span class="sig-paren">)</span><a class="headerlink" href="#FishLeg.fishleg_layers.FishConv2d" title="Permalink to this definition"></a></dt>
<dd><p>Bases: <code class="xref py py-class docutils literal notranslate"><span class="pre">Conv2d</span></code>, <code class="xref py py-class docutils literal notranslate"><span class="pre">FishModule</span></code></p>
<dl class="py method">
<dt class="sig sig-object py" id="FishLeg.fishleg_layers.FishConv2d.Qv">
<span class="sig-name descname"><span class="pre">Qv</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">v</span></span><span class="p"><span class="pre">:</span></span><span class="w"> </span><span class="n"><span class="pre">Tuple</span><span class="p"><span class="pre">[</span></span><span class="pre">Tensor</span><span class="p"><span class="pre">,</span></span><span class="w"> </span><span class="pre">Optional</span><span class="p"><span class="pre">[</span></span><span class="pre">Tensor</span><span class="p"><span class="pre">]</span></span><span class="p"><span class="pre">]</span></span></span></em>, <em class="sig-param"><span class="n"><span class="pre">full</span></span><span class="p"><span class="pre">:</span></span><span class="w"> </span><span class="n"><span class="pre">bool</span></span><span class="w"> </span><span class="o"><span class="pre">=</span></span><span class="w"> </span><span class="default_value"><span class="pre">False</span></span></em><span class="sig-paren">)</span> <span class="sig-return"><span class="sig-return-icon">&#x2192;</span> <span class="sig-return-typehint"><span class="pre">Tuple</span><span class="p"><span class="pre">[</span></span><span class="pre">Tensor</span><span class="p"><span class="pre">,</span></span><span class="w"> </span><span class="pre">Optional</span><span class="p"><span class="pre">[</span></span><span class="pre">Tensor</span><span class="p"><span class="pre">]</span></span><span class="p"><span class="pre">]</span></span></span></span><a class="headerlink" href="#FishLeg.fishleg_layers.FishConv2d.Qv" title="Permalink to this definition"></a></dt>
<dd><p>Inspired by KFAC’s conv2D layer by Grosse and Martens: Kronecker product of sizes (out_channels ⊗  (in_channels_eff * k_size))</p>
</dd></dl>

<dl class="py attribute">
<dt class="sig sig-object py" id="FishLeg.fishleg_layers.FishConv2d.bias">
<span class="sig-name descname"><span class="pre">bias</span></span><em class="property"><span class="p"><span class="pre">:</span></span><span class="w"> </span><span class="pre">Optional</span><span class="p"><span class="pre">[</span></span><span class="pre">Tensor</span><span class="p"><span class="pre">]</span></span></em><a class="headerlink" href="#FishLeg.fishleg_layers.FishConv2d.bias" title="Permalink to this definition"></a></dt>
<dd></dd></dl>

<dl class="py method">
<dt class="sig sig-object py" id="FishLeg.fishleg_layers.FishConv2d.diagQ">
<span class="sig-name descname"><span class="pre">diagQ</span></span><span class="sig-paren">(</span><span class="sig-paren">)</span> <span class="sig-return"><span class="sig-return-icon">&#x2192;</span> <span class="sig-return-typehint"><span class="pre">Tensor</span></span></span><a class="headerlink" href="#FishLeg.fishleg_layers.FishConv2d.diagQ" title="Permalink to this definition"></a></dt>
<dd><p>Similar maths to the Linear layer</p>
</dd></dl>

<dl class="py attribute">
<dt class="sig sig-object py" id="FishLeg.fishleg_layers.FishConv2d.dilation">
<span class="sig-name descname"><span class="pre">dilation</span></span><em class="property"><span class="p"><span class="pre">:</span></span><span class="w"> </span><span class="pre">Tuple</span><span class="p"><span class="pre">[</span></span><span class="pre">int</span><span class="p"><span class="pre">,</span></span><span class="w"> </span><span class="p"><span class="pre">...</span></span><span class="p"><span class="pre">]</span></span></em><a class="headerlink" href="#FishLeg.fishleg_layers.FishConv2d.dilation" title="Permalink to this definition"></a></dt>
<dd></dd></dl>

<dl class="py attribute">
<dt class="sig sig-object py" id="FishLeg.fishleg_layers.FishConv2d.groups">
<span class="sig-name descname"><span class="pre">groups</span></span><em class="property"><span class="p"><span class="pre">:</span></span><span class="w"> </span><span class="pre">int</span></em><a class="headerlink" href="#FishLeg.fishleg_layers.FishConv2d.groups" title="Permalink to this definition"></a></dt>
<dd></dd></dl>

<dl class="py attribute">
<dt class="sig sig-object py" id="FishLeg.fishleg_layers.FishConv2d.in_channels">
<span class="sig-name descname"><span class="pre">in_channels</span></span><em class="property"><span class="p"><span class="pre">:</span></span><span class="w"> </span><span class="pre">int</span></em><a class="headerlink" href="#FishLeg.fishleg_layers.FishConv2d.in_channels" title="Permalink to this definition"></a></dt>
<dd></dd></dl>

<dl class="py attribute">
<dt class="sig sig-object py" id="FishLeg.fishleg_layers.FishConv2d.kernel_size">
<span class="sig-name descname"><span class="pre">kernel_size</span></span><em class="property"><span class="p"><span class="pre">:</span></span><span class="w"> </span><span class="pre">Tuple</span><span class="p"><span class="pre">[</span></span><span class="pre">int</span><span class="p"><span class="pre">,</span></span><span class="w"> </span><span class="p"><span class="pre">...</span></span><span class="p"><span class="pre">]</span></span></em><a class="headerlink" href="#FishLeg.fishleg_layers.FishConv2d.kernel_size" title="Permalink to this definition"></a></dt>
<dd></dd></dl>

<dl class="py attribute">
<dt class="sig sig-object py" id="FishLeg.fishleg_layers.FishConv2d.out_channels">
<span class="sig-name descname"><span class="pre">out_channels</span></span><em class="property"><span class="p"><span class="pre">:</span></span><span class="w"> </span><span class="pre">int</span></em><a class="headerlink" href="#FishLeg.fishleg_layers.FishConv2d.out_channels" title="Permalink to this definition"></a></dt>
<dd></dd></dl>

<dl class="py attribute">
<dt class="sig sig-object py" id="FishLeg.fishleg_layers.FishConv2d.output_padding">
<span class="sig-name descname"><span class="pre">output_padding</span></span><em class="property"><span class="p"><span class="pre">:</span></span><span class="w"> </span><span class="pre">Tuple</span><span class="p"><span class="pre">[</span></span><span class="pre">int</span><span class="p"><span class="pre">,</span></span><span class="w"> </span><span class="p"><span class="pre">...</span></span><span class="p"><span class="pre">]</span></span></em><a class="headerlink" href="#FishLeg.fishleg_layers.FishConv2d.output_padding" title="Permalink to this definition"></a></dt>
<dd></dd></dl>

<dl class="py attribute">
<dt class="sig sig-object py" id="FishLeg.fishleg_layers.FishConv2d.padding">
<span class="sig-name descname"><span class="pre">padding</span></span><em class="property"><span class="p"><span class="pre">:</span></span><span class="w"> </span><span class="pre">Union</span><span class="p"><span class="pre">[</span></span><span class="pre">str</span><span class="p"><span class="pre">,</span></span><span class="w"> </span><span class="pre">Tuple</span><span class="p"><span class="pre">[</span></span><span class="pre">int</span><span class="p"><span class="pre">,</span></span><span class="w"> </span><span class="p"><span class="pre">...</span></span><span class="p"><span class="pre">]</span></span><span class="p"><span class="pre">]</span></span></em><a class="headerlink" href="#FishLeg.fishleg_layers.FishConv2d.padding" title="Permalink to this definition"></a></dt>
<dd></dd></dl>

<dl class="py attribute">
<dt class="sig sig-object py" id="FishLeg.fishleg_layers.FishConv2d.padding_mode">
<span class="sig-name descname"><span class="pre">padding_mode</span></span><em class="property"><span class="p"><span class="pre">:</span></span><span class="w"> </span><span class="pre">str</span></em><a class="headerlink" href="#FishLeg.fishleg_layers.FishConv2d.padding_mode" title="Permalink to this definition"></a></dt>
<dd></dd></dl>

<dl class="py attribute">
<dt class="sig sig-object py" id="FishLeg.fishleg_layers.FishConv2d.stride">
<span class="sig-name descname"><span class="pre">stride</span></span><em class="property"><span class="p"><span class="pre">:</span></span><span class="w"> </span><span class="pre">Tuple</span><span class="p"><span class="pre">[</span></span><span class="pre">int</span><span class="p"><span class="pre">,</span></span><span class="w"> </span><span class="p"><span class="pre">...</span></span><span class="p"><span class="pre">]</span></span></em><a class="headerlink" href="#FishLeg.fishleg_layers.FishConv2d.stride" title="Permalink to this definition"></a></dt>
<dd></dd></dl>

<dl class="py attribute">
<dt class="sig sig-object py" id="FishLeg.fishleg_layers.FishConv2d.transposed">
<span class="sig-name descname"><span class="pre">transposed</span></span><em class="property"><span class="p"><span class="pre">:</span></span><span class="w"> </span><span class="pre">bool</span></em><a class="headerlink" href="#FishLeg.fishleg_layers.FishConv2d.transposed" title="Permalink to this definition"></a></dt>
<dd></dd></dl>

<dl class="py method">
<dt class="sig sig-object py" id="FishLeg.fishleg_layers.FishConv2d.warmup">
<span class="sig-name descname"><span class="pre">warmup</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">v</span></span><span class="p"><span class="pre">:</span></span><span class="w"> </span><span class="n"><span class="pre">Optional</span><span class="p"><span class="pre">[</span></span><span class="pre">Tuple</span><span class="p"><span class="pre">[</span></span><span class="pre">Tensor</span><span class="p"><span class="pre">,</span></span><span class="w"> </span><span class="pre">Tensor</span><span class="p"><span class="pre">]</span></span><span class="p"><span class="pre">]</span></span></span><span class="w"> </span><span class="o"><span class="pre">=</span></span><span class="w"> </span><span class="default_value"><span class="pre">None</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">init_scale</span></span><span class="p"><span class="pre">:</span></span><span class="w"> </span><span class="n"><span class="pre">float</span></span><span class="w"> </span><span class="o"><span class="pre">=</span></span><span class="w"> </span><span class="default_value"><span class="pre">1.0</span></span></em><span class="sig-paren">)</span> <span class="sig-return"><span class="sig-return-icon">&#x2192;</span> <span class="sig-return-typehint"><span class="pre">None</span></span></span><a class="headerlink" href="#FishLeg.fishleg_layers.FishConv2d.warmup" title="Permalink to this definition"></a></dt>
<dd></dd></dl>

<dl class="py attribute">
<dt class="sig sig-object py" id="FishLeg.fishleg_layers.FishConv2d.weight">
<span class="sig-name descname"><span class="pre">weight</span></span><em class="property"><span class="p"><span class="pre">:</span></span><span class="w"> </span><span class="pre">Tensor</span></em><a class="headerlink" href="#FishLeg.fishleg_layers.FishConv2d.weight" title="Permalink to this definition"></a></dt>
<dd></dd></dl>

</dd></dl>

<dl class="py class">
<dt class="sig sig-object py" id="FishLeg.fishleg_layers.FishLayerNorm">
<em class="property"><span class="pre">class</span><span class="w"> </span></em><span class="sig-prename descclassname"><span class="pre">FishLeg.fishleg_layers.</span></span><span class="sig-name descname"><span class="pre">FishLayerNorm</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">normalized_shape</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">eps</span></span><span class="p"><span class="pre">:</span></span><span class="w"> </span><span class="n"><span class="pre">float</span></span><span class="w"> </span><span class="o"><span class="pre">=</span></span><span class="w"> </span><span class="default_value"><span class="pre">1e-05</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">elementwise_affine</span></span><span class="p"><span class="pre">:</span></span><span class="w"> </span><span class="n"><span class="pre">bool</span></span><span class="w"> </span><span class="o"><span class="pre">=</span></span><span class="w"> </span><span class="default_value"><span class="pre">True</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">init_scale</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">1.0</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">device</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">None</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">dtype</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">None</span></span></em><span class="sig-paren">)</span><a class="headerlink" href="#FishLeg.fishleg_layers.FishLayerNorm" title="Permalink to this definition"></a></dt>
<dd><p>Bases: <code class="xref py py-class docutils literal notranslate"><span class="pre">LayerNorm</span></code>, <code class="xref py py-class docutils literal notranslate"><span class="pre">FishModule</span></code></p>
<dl class="py method">
<dt class="sig sig-object py" id="FishLeg.fishleg_layers.FishLayerNorm.Qv">
<span class="sig-name descname"><span class="pre">Qv</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">v</span></span><span class="p"><span class="pre">:</span></span><span class="w"> </span><span class="n"><span class="pre">Tuple</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">full</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">False</span></span></em><span class="sig-paren">)</span><a class="headerlink" href="#FishLeg.fishleg_layers.FishLayerNorm.Qv" title="Permalink to this definition"></a></dt>
<dd><p><span class="math notranslate nohighlight">\(Q(\lambda)\)</span> is a positive definite matrix which will effectively
estimate the inverse damped Fisher Information Matrix. Appropriate choices
for <span class="math notranslate nohighlight">\(Q\)</span> should take into account the architecture of the model/module.
It is usually parameterized as a positive definite Kronecker-factored
block-diagonal matrix, with block sizes reflecting the layer structure of
the neural networks.</p>
<dl>
<dt>Args:</dt><dd><dl class="simple">
<dt>aux: (Dict, required): auxiliary parameters,</dt><dd><p><span class="math notranslate nohighlight">\(\lambda\)</span>, a dictionary whose keys are the names
of the auxiliary parameters and whose values are the auxiliary parameters
of the module. These auxiliary parameters will form <span class="math notranslate nohighlight">\(Q(\lambda)\)</span>.</p>
</dd>
<dt>v: (Tuple[Tensor, …], required): Values of the original parameters,</dt><dd><p>in an order that aligns with <cite>self.order</cite>, to multiply with
<span class="math notranslate nohighlight">\(Q(\lambda)\)</span>.</p>
</dd>
</dl>
<p>full: (bool, optional), whether to use full inner and outer re-scaling</p>
</dd>
<dt>Returns:</dt><dd><dl class="simple">
<dt>Tuple[Tensor, …]: The calculated <span class="math notranslate nohighlight">\(Q(\lambda)v\)</span> products,</dt><dd><p>in the same order as <cite>self.order</cite>.</p>
</dd>
</dl>
</dd>
</dl>
</dd></dl>

<dl class="py method">
<dt class="sig sig-object py" id="FishLeg.fishleg_layers.FishLayerNorm.diagQ">
<span class="sig-name descname"><span class="pre">diagQ</span></span><span class="sig-paren">(</span><span class="sig-paren">)</span><a class="headerlink" href="#FishLeg.fishleg_layers.FishLayerNorm.diagQ" title="Permalink to this definition"></a></dt>
<dd></dd></dl>

<dl class="py attribute">
<dt class="sig sig-object py" id="FishLeg.fishleg_layers.FishLayerNorm.elementwise_affine">
<span class="sig-name descname"><span class="pre">elementwise_affine</span></span><em class="property"><span class="p"><span class="pre">:</span></span><span class="w"> </span><span class="pre">bool</span></em><a class="headerlink" href="#FishLeg.fishleg_layers.FishLayerNorm.elementwise_affine" title="Permalink to this definition"></a></dt>
<dd></dd></dl>

<dl class="py attribute">
<dt class="sig sig-object py" id="FishLeg.fishleg_layers.FishLayerNorm.eps">
<span class="sig-name descname"><span class="pre">eps</span></span><em class="property"><span class="p"><span class="pre">:</span></span><span class="w"> </span><span class="pre">float</span></em><a class="headerlink" href="#FishLeg.fishleg_layers.FishLayerNorm.eps" title="Permalink to this definition"></a></dt>
<dd></dd></dl>

<dl class="py attribute">
<dt class="sig sig-object py" id="FishLeg.fishleg_layers.FishLayerNorm.normalized_shape">
<span class="sig-name descname"><span class="pre">normalized_shape</span></span><em class="property"><span class="p"><span class="pre">:</span></span><span class="w"> </span><span class="pre">Tuple</span><span class="p"><span class="pre">[</span></span><span class="pre">int</span><span class="p"><span class="pre">,</span></span><span class="w"> </span><span class="p"><span class="pre">...</span></span><span class="p"><span class="pre">]</span></span></em><a class="headerlink" href="#FishLeg.fishleg_layers.FishLayerNorm.normalized_shape" title="Permalink to this definition"></a></dt>
<dd></dd></dl>

</dd></dl>

<dl class="py class">
<dt class="sig sig-object py" id="FishLeg.fishleg_layers.FishLinear">
<em class="property"><span class="pre">class</span><span class="w"> </span></em><span class="sig-prename descclassname"><span class="pre">FishLeg.fishleg_layers.</span></span><span class="sig-name descname"><span class="pre">FishLinear</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">in_features</span></span><span class="p"><span class="pre">:</span></span><span class="w"> </span><span class="n"><span class="pre">int</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">out_features</span></span><span class="p"><span class="pre">:</span></span><span class="w"> </span><span class="n"><span class="pre">int</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">bias</span></span><span class="p"><span class="pre">:</span></span><span class="w"> </span><span class="n"><span class="pre">bool</span></span><span class="w"> </span><span class="o"><span class="pre">=</span></span><span class="w"> </span><span class="default_value"><span class="pre">True</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">device</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">None</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">dtype</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">None</span></span></em><span class="sig-paren">)</span><a class="headerlink" href="#FishLeg.fishleg_layers.FishLinear" title="Permalink to this definition"></a></dt>
<dd><p>Bases: <code class="xref py py-class docutils literal notranslate"><span class="pre">Linear</span></code>, <code class="xref py py-class docutils literal notranslate"><span class="pre">FishModule</span></code></p>
<dl class="py method">
<dt class="sig sig-object py" id="FishLeg.fishleg_layers.FishLinear.Qg">
<span class="sig-name descname"><span class="pre">Qg</span></span><span class="sig-paren">(</span><span class="sig-paren">)</span> <span class="sig-return"><span class="sig-return-icon">&#x2192;</span> <span class="sig-return-typehint"><span class="pre">Tuple</span><span class="p"><span class="pre">[</span></span><span class="pre">Tensor</span><span class="p"><span class="pre">,</span></span><span class="w"> </span><span class="pre">Tensor</span><span class="p"><span class="pre">]</span></span></span></span><a class="headerlink" href="#FishLeg.fishleg_layers.FishLinear.Qg" title="Permalink to this definition"></a></dt>
<dd><p>Speed up Qg product, when batch size is smaller than parameter size.
By chain rule:</p>
<div class="math notranslate nohighlight">
\[DW_i = g_i\hat{a}^T_{i-1}\]</div>
<p>where <span class="math notranslate nohighlight">\(DW_i\)</span> is the gradient of the parameters of the ith layer, <span class="math notranslate nohighlight">\(g_i\)</span> is the
gradient w.r.t. the output of the ith layer and <span class="math notranslate nohighlight">\(\hat{a}_{i-1}\)</span> is the input to the ith layer,
i.e. the output of the (i-1)th layer.</p>
</dd></dl>

<dl class="py method">
<dt class="sig sig-object py" id="FishLeg.fishleg_layers.FishLinear.Qv">
<span class="sig-name descname"><span class="pre">Qv</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">v</span></span><span class="p"><span class="pre">:</span></span><span class="w"> </span><span class="n"><span class="pre">Tuple</span><span class="p"><span class="pre">[</span></span><span class="pre">Tensor</span><span class="p"><span class="pre">,</span></span><span class="w"> </span><span class="pre">Tensor</span><span class="p"><span class="pre">]</span></span></span></em>, <em class="sig-param"><span class="n"><span class="pre">full</span></span><span class="p"><span class="pre">:</span></span><span class="w"> </span><span class="n"><span class="pre">bool</span></span><span class="w"> </span><span class="o"><span class="pre">=</span></span><span class="w"> </span><span class="default_value"><span class="pre">False</span></span></em><span class="sig-paren">)</span> <span class="sig-return"><span class="sig-return-icon">&#x2192;</span> <span class="sig-return-typehint"><span class="pre">Tuple</span><span class="p"><span class="pre">[</span></span><span class="pre">Tensor</span><span class="p"><span class="pre">,</span></span><span class="w"> </span><span class="pre">Tensor</span><span class="p"><span class="pre">]</span></span></span></span><a class="headerlink" href="#FishLeg.fishleg_layers.FishLinear.Qv" title="Permalink to this definition"></a></dt>
<dd><p>For fully-connected layers, the default structure of <span class="math notranslate nohighlight">\(Q\)</span> as a
block-diagonal matrix is,</p>
<div class="math notranslate nohighlight">
\[Q_l = (R_lR_l^T \otimes L_lL_l^T)\]</div>
<p>where <span class="math notranslate nohighlight">\(l\)</span> denotes the l-th layer. The matrix <span class="math notranslate nohighlight">\(R_l\)</span> has size
<span class="math notranslate nohighlight">\((N_{l-1} + 1) \times (N_{l-1} + 1)\)</span> while the matrix <span class="math notranslate nohighlight">\(L_l\)</span> has
size <span class="math notranslate nohighlight">\(N_l \times N_l\)</span>. The auxiliary parameters <span class="math notranslate nohighlight">\(\lambda\)</span>
are represented by the matrices <span class="math notranslate nohighlight">\(L_l, R_l\)</span>. A Kronecker form that
introduces full inner and outer diagonal rescaling structure is,</p>
<div class="math notranslate nohighlight">
\[Q_l = A_l(L_l \otimes R_l^T) D_l^2 (L_l^T \otimes R_l) A_l\]</div>
<p>where <span class="math notranslate nohighlight">\(A_l\)</span> and <span class="math notranslate nohighlight">\(D_l\)</span> are two diagonal matrices of the
appropriate size.</p>
</dd></dl>

<dl class="py method">
<dt class="sig sig-object py" id="FishLeg.fishleg_layers.FishLinear.diagQ">
<span class="sig-name descname"><span class="pre">diagQ</span></span><span class="sig-paren">(</span><span class="sig-paren">)</span> <span class="sig-return"><span class="sig-return-icon">&#x2192;</span> <span class="sig-return-typehint"><span class="pre">Tuple</span></span></span><a class="headerlink" href="#FishLeg.fishleg_layers.FishLinear.diagQ" title="Permalink to this definition"></a></dt>
<dd><p>The Q matrix defines the inverse fisher approximation as below:</p>
<div class="math notranslate nohighlight">
\[Q_l = (R_lR_l^T \otimes L_lL_l^T)\]</div>
<p>where <span class="math notranslate nohighlight">\(l\)</span> denotes the l-th layer. The matrix <span class="math notranslate nohighlight">\(R_l\)</span> has size
<span class="math notranslate nohighlight">\((N_{l-1} + 1) \times (N_{l-1} + 1)\)</span> while the matrix <span class="math notranslate nohighlight">\(L_l\)</span> has
size <span class="math notranslate nohighlight">\(N_l \times N_l\)</span>. The auxiliary parameters <span class="math notranslate nohighlight">\(\lambda\)</span>
are represented by the matrices <span class="math notranslate nohighlight">\(L_l, R_l\)</span>.</p>
<p>The diagonal of this matrix is therefore calculated by</p>
<div class="math notranslate nohighlight">
\[\text{diag}(Q_l) = \text{diag}(R_l R_l^T) \otimes \text{diag}(L_l L_l^T)\]</div>
<p>where <span class="math notranslate nohighlight">\(\text{diag}\)</span> involves summing over the columns of the matrix, and <span class="math notranslate nohighlight">\(\otimes\)</span> remains
the Kronecker product.</p>
</dd></dl>

<dl class="py attribute">
<dt class="sig sig-object py" id="FishLeg.fishleg_layers.FishLinear.in_features">
<span class="sig-name descname"><span class="pre">in_features</span></span><em class="property"><span class="p"><span class="pre">:</span></span><span class="w"> </span><span class="pre">int</span></em><a class="headerlink" href="#FishLeg.fishleg_layers.FishLinear.in_features" title="Permalink to this definition"></a></dt>
<dd></dd></dl>

<dl class="py attribute">
<dt class="sig sig-object py" id="FishLeg.fishleg_layers.FishLinear.out_features">
<span class="sig-name descname"><span class="pre">out_features</span></span><em class="property"><span class="p"><span class="pre">:</span></span><span class="w"> </span><span class="pre">int</span></em><a class="headerlink" href="#FishLeg.fishleg_layers.FishLinear.out_features" title="Permalink to this definition"></a></dt>
<dd></dd></dl>

<dl class="py method">
<dt class="sig sig-object py" id="FishLeg.fishleg_layers.FishLinear.save_layer_grad_output">
<span class="sig-name descname"><span class="pre">save_layer_grad_output</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">grad_output</span></span><span class="p"><span class="pre">:</span></span><span class="w"> </span><span class="n"><span class="pre">Tuple</span><span class="p"><span class="pre">[</span></span><span class="pre">Tensor</span><span class="p"><span class="pre">,</span></span><span class="w"> </span><span class="p"><span class="pre">...</span></span><span class="p"><span class="pre">]</span></span></span></em><span class="sig-paren">)</span> <span class="sig-return"><span class="sig-return-icon">&#x2192;</span> <span class="sig-return-typehint"><span class="pre">None</span></span></span><a class="headerlink" href="#FishLeg.fishleg_layers.FishLinear.save_layer_grad_output" title="Permalink to this definition"></a></dt>
<dd></dd></dl>

<dl class="py method">
<dt class="sig sig-object py" id="FishLeg.fishleg_layers.FishLinear.save_layer_input">
<span class="sig-name descname"><span class="pre">save_layer_input</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">input_</span></span><span class="p"><span class="pre">:</span></span><span class="w"> </span><span class="n"><span class="pre">List</span><span class="p"><span class="pre">[</span></span><span class="pre">Tensor</span><span class="p"><span class="pre">]</span></span></span></em><span class="sig-paren">)</span> <span class="sig-return"><span class="sig-return-icon">&#x2192;</span> <span class="sig-return-typehint"><span class="pre">None</span></span></span><a class="headerlink" href="#FishLeg.fishleg_layers.FishLinear.save_layer_input" title="Permalink to this definition"></a></dt>
<dd></dd></dl>

<dl class="py method">
<dt class="sig sig-object py" id="FishLeg.fishleg_layers.FishLinear.warmup">
<span class="sig-name descname"><span class="pre">warmup</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">v</span></span><span class="p"><span class="pre">:</span></span><span class="w"> </span><span class="n"><span class="pre">Optional</span><span class="p"><span class="pre">[</span></span><span class="pre">Tuple</span><span class="p"><span class="pre">[</span></span><span class="pre">Tensor</span><span class="p"><span class="pre">,</span></span><span class="w"> </span><span class="pre">Tensor</span><span class="p"><span class="pre">]</span></span><span class="p"><span class="pre">]</span></span></span><span class="w"> </span><span class="o"><span class="pre">=</span></span><span class="w"> </span><span class="default_value"><span class="pre">None</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">batch_speedup</span></span><span class="p"><span class="pre">:</span></span><span class="w"> </span><span class="n"><span class="pre">bool</span></span><span class="w"> </span><span class="o"><span class="pre">=</span></span><span class="w"> </span><span class="default_value"><span class="pre">False</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">init_scale</span></span><span class="p"><span class="pre">:</span></span><span class="w"> </span><span class="n"><span class="pre">float</span></span><span class="w"> </span><span class="o"><span class="pre">=</span></span><span class="w"> </span><span class="default_value"><span class="pre">1.0</span></span></em><span class="sig-paren">)</span> <span class="sig-return"><span class="sig-return-icon">&#x2192;</span> <span class="sig-return-typehint"><span class="pre">None</span></span></span><a class="headerlink" href="#FishLeg.fishleg_layers.FishLinear.warmup" title="Permalink to this definition"></a></dt>
<dd></dd></dl>

<dl class="py attribute">
<dt class="sig sig-object py" id="FishLeg.fishleg_layers.FishLinear.weight">
<span class="sig-name descname"><span class="pre">weight</span></span><em class="property"><span class="p"><span class="pre">:</span></span><span class="w"> </span><span class="pre">Tensor</span></em><a class="headerlink" href="#FishLeg.fishleg_layers.FishLinear.weight" title="Permalink to this definition"></a></dt>
<dd></dd></dl>

</dd></dl>

</section>
<section id="module-FishLeg.fishleg_likelihood">
<span id="fishleg-fishleg-likelihood-module"></span><h2>FishLeg.fishleg_likelihood module<a class="headerlink" href="#module-FishLeg.fishleg_likelihood" title="Permalink to this heading"></a></h2>
<dl class="py class">
<dt class="sig sig-object py" id="FishLeg.fishleg_likelihood.BernoulliLikelihood">
<em class="property"><span class="pre">class</span><span class="w"> </span></em><span class="sig-prename descclassname"><span class="pre">FishLeg.fishleg_likelihood.</span></span><span class="sig-name descname"><span class="pre">BernoulliLikelihood</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">device</span></span><span class="p"><span class="pre">:</span></span><span class="w"> </span><span class="n"><span class="pre">str</span></span><span class="w"> </span><span class="o"><span class="pre">=</span></span><span class="w"> </span><span class="default_value"><span class="pre">'cpu'</span></span></em><span class="sig-paren">)</span><a class="headerlink" href="#FishLeg.fishleg_likelihood.BernoulliLikelihood" title="Permalink to this definition"></a></dt>
<dd><p>Bases: <a class="reference internal" href="#FishLeg.fishleg_likelihood.FishLikelihood" title="FishLeg.fishleg_likelihood.FishLikelihood"><code class="xref py py-class docutils literal notranslate"><span class="pre">FishLikelihood</span></code></a></p>
<p>The Bernoulli likelihood used for classification.
Using the standard Normal CDF <span class="math notranslate nohighlight">\(\Phi(x)\)</span> and the identity
<span class="math notranslate nohighlight">\(\Phi(-x) = 1-\Phi(x)\)</span>, we can write the likelihood as:</p>
<div class="math notranslate nohighlight">
\[p(y|f(x))=\Phi(yf(x))\]</div>
<dl class="py method">
<dt class="sig sig-object py" id="FishLeg.fishleg_likelihood.BernoulliLikelihood.draw">
<span class="sig-name descname"><span class="pre">draw</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">preds</span></span><span class="p"><span class="pre">:</span></span><span class="w"> </span><span class="n"><span class="pre">Tensor</span></span></em><span class="sig-paren">)</span> <span class="sig-return"><span class="sig-return-icon">&#x2192;</span> <span class="sig-return-typehint"><span class="pre">Tensor</span></span></span><a class="headerlink" href="#FishLeg.fishleg_likelihood.BernoulliLikelihood.draw" title="Permalink to this definition"></a></dt>
<dd><p>Draw samples from the conditional distribution
<span class="math notranslate nohighlight">\(p(\mathbf y|f(\mathbf x))\)</span></p>
<dl class="field-list simple">
<dt class="field-odd">Parameters</dt>
<dd class="field-odd"><p><strong>preds</strong> (<em>torch.Tensor</em>) – Predictions from model <span class="math notranslate nohighlight">\(f(\mathbf x)\)</span></p>
</dd>
</dl>
</dd></dl>

<dl class="py method">
<dt class="sig sig-object py" id="FishLeg.fishleg_likelihood.BernoulliLikelihood.nll">
<span class="sig-name descname"><span class="pre">nll</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">preds</span></span><span class="p"><span class="pre">:</span></span><span class="w"> </span><span class="n"><span class="pre">Tensor</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">observations</span></span><span class="p"><span class="pre">:</span></span><span class="w"> </span><span class="n"><span class="pre">Tensor</span></span></em><span class="sig-paren">)</span> <span class="sig-return"><span class="sig-return-icon">&#x2192;</span> <span class="sig-return-typehint"><span class="pre">Tensor</span></span></span><a class="headerlink" href="#FishLeg.fishleg_likelihood.BernoulliLikelihood.nll" title="Permalink to this definition"></a></dt>
<dd><p>Computes the negative log-likelihood
<span class="math notranslate nohighlight">\(\ell(\theta, \mathcal D)=-\log p(\mathbf y|f(\mathbf x))\)</span></p>
<dl class="field-list simple">
<dt class="field-odd">Parameters</dt>
<dd class="field-odd"><ul class="simple">
<li><p><strong>observations</strong> (<em>torch.Tensor</em>) – Values of <span class="math notranslate nohighlight">\(y\)</span>.</p></li>
<li><p><strong>preds</strong> (<em>torch.Tensor</em>) – Predictions from model <span class="math notranslate nohighlight">\(f(\mathbf x)\)</span></p></li>
</ul>
</dd>
<dt class="field-even">Return type</dt>
<dd class="field-even"><p><cite>torch.Tensor</cite></p>
</dd>
</dl>
</dd></dl>

</dd></dl>

<dl class="py class">
<dt class="sig sig-object py" id="FishLeg.fishleg_likelihood.FishLikelihood">
<em class="property"><span class="pre">class</span><span class="w"> </span></em><span class="sig-prename descclassname"><span class="pre">FishLeg.fishleg_likelihood.</span></span><span class="sig-name descname"><span class="pre">FishLikelihood</span></span><a class="headerlink" href="#FishLeg.fishleg_likelihood.FishLikelihood" title="Permalink to this definition"></a></dt>
<dd><p>Bases: <code class="xref py py-class docutils literal notranslate"><span class="pre">object</span></code></p>
<p>A Likelihood in FishLeg specifies a probabilistic model, which describes
the mapping from latent function values
<span class="math notranslate nohighlight">\(f(\mathbf X)\)</span> to observed labels <span class="math notranslate nohighlight">\(y\)</span>.</p>
<p>For example, in the case of regression, 
a Gaussian likelihood can be chosen, as</p>
<div class="math notranslate nohighlight">
\[y(\mathbf x) = f(\mathbf x) + \epsilon, \:\:\:\: \epsilon \sim N(0,\sigma^{2}_{n} \mathbf I)\]</div>
<p>As for the case of classification, 
a Bernoulli distribution can be chosen</p>
<div class="math notranslate nohighlight">
\[\begin{split}y(\mathbf x) = \begin{cases}
    1 &amp; \text{w/ probability} \:\: \sigma(f(\mathbf x)) \\
    0 &amp; \text{w/ probability} \:\: 1-\sigma(f(\mathbf x))
\end{cases}\end{split}\]</div>
<dl class="py method">
<dt class="sig sig-object py" id="FishLeg.fishleg_likelihood.FishLikelihood.draw">
<em class="property"><span class="pre">abstract</span><span class="w"> </span></em><span class="sig-name descname"><span class="pre">draw</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">preds</span></span></em>, <em class="sig-param"><span class="o"><span class="pre">**</span></span><span class="n"><span class="pre">kwargs</span></span></em><span class="sig-paren">)</span><a class="headerlink" href="#FishLeg.fishleg_likelihood.FishLikelihood.draw" title="Permalink to this definition"></a></dt>
<dd><p>Draw samples from the conditional distribution
<span class="math notranslate nohighlight">\(p(\mathbf y|f(\mathbf x))\)</span></p>
<dl class="field-list simple">
<dt class="field-odd">Parameters</dt>
<dd class="field-odd"><p><strong>preds</strong> (<em>torch.Tensor</em>) – Predictions from model <span class="math notranslate nohighlight">\(f(\mathbf x)\)</span></p>
</dd>
</dl>
</dd></dl>

<dl class="py method">
<dt class="sig sig-object py" id="FishLeg.fishleg_likelihood.FishLikelihood.get_parameters">
<span class="sig-name descname"><span class="pre">get_parameters</span></span><span class="sig-paren">(</span><span class="sig-paren">)</span> <span class="sig-return"><span class="sig-return-icon">&#x2192;</span> <span class="sig-return-typehint"><span class="pre">List</span></span></span><a class="headerlink" href="#FishLeg.fishleg_likelihood.FishLikelihood.get_parameters" title="Permalink to this definition"></a></dt>
<dd><p>Return a list of learnable parameters.</p>
</dd></dl>

<dl class="py method">
<dt class="sig sig-object py" id="FishLeg.fishleg_likelihood.FishLikelihood.nll">
<em class="property"><span class="pre">abstract</span><span class="w"> </span></em><span class="sig-name descname"><span class="pre">nll</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">preds</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">observations</span></span></em>, <em class="sig-param"><span class="o"><span class="pre">**</span></span><span class="n"><span class="pre">kwargs</span></span></em><span class="sig-paren">)</span><a class="headerlink" href="#FishLeg.fishleg_likelihood.FishLikelihood.nll" title="Permalink to this definition"></a></dt>
<dd><p>Computes the negative log-likelihood
<span class="math notranslate nohighlight">\(\ell(\theta, \mathcal D)=-\log p(\mathbf y|f(\mathbf x))\)</span></p>
<dl class="field-list simple">
<dt class="field-odd">Parameters</dt>
<dd class="field-odd"><ul class="simple">
<li><p><strong>observations</strong> (<em>torch.Tensor</em>) – Values of <span class="math notranslate nohighlight">\(y\)</span>.</p></li>
<li><p><strong>preds</strong> (<em>torch.Tensor</em>) – Predictions from model <span class="math notranslate nohighlight">\(f(\mathbf x)\)</span></p></li>
</ul>
</dd>
<dt class="field-even">Return type</dt>
<dd class="field-even"><p><cite>torch.Tensor</cite></p>
</dd>
</dl>
</dd></dl>

</dd></dl>

<dl class="py class">
<dt class="sig sig-object py" id="FishLeg.fishleg_likelihood.FixedGaussianLikelihood">
<em class="property"><span class="pre">class</span><span class="w"> </span></em><span class="sig-prename descclassname"><span class="pre">FishLeg.fishleg_likelihood.</span></span><span class="sig-name descname"><span class="pre">FixedGaussianLikelihood</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">sigma</span></span><span class="p"><span class="pre">:</span></span><span class="w"> </span><span class="n"><span class="pre">Tensor</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">device</span></span><span class="p"><span class="pre">:</span></span><span class="w"> </span><span class="n"><span class="pre">str</span></span><span class="w"> </span><span class="o"><span class="pre">=</span></span><span class="w"> </span><span class="default_value"><span class="pre">'cpu'</span></span></em><span class="sig-paren">)</span><a class="headerlink" href="#FishLeg.fishleg_likelihood.FixedGaussianLikelihood" title="Permalink to this definition"></a></dt>
<dd><p>Bases: <a class="reference internal" href="#FishLeg.fishleg_likelihood.FishLikelihood" title="FishLeg.fishleg_likelihood.FishLikelihood"><code class="xref py py-class docutils literal notranslate"><span class="pre">FishLikelihood</span></code></a></p>
<p>The standard likelihood for regression,
but assuming fixed heteroscedastic noise.</p>
<div class="math notranslate nohighlight">
\[y(x) = f(x) + \epsilon, \:\:\:\: \epsilon \sim N(0,\sigma^{2})\]</div>
<dl class="field-list simple">
<dt class="field-odd">Parameters</dt>
<dd class="field-odd"><p><strong>sigma</strong> (<em>torch.Tensor</em>) – Known observation
standard deviation for each example.</p>
</dd>
</dl>
<dl class="py method">
<dt class="sig sig-object py" id="FishLeg.fishleg_likelihood.FixedGaussianLikelihood.draw">
<span class="sig-name descname"><span class="pre">draw</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">preds</span></span><span class="p"><span class="pre">:</span></span><span class="w"> </span><span class="n"><span class="pre">Tensor</span></span></em><span class="sig-paren">)</span> <span class="sig-return"><span class="sig-return-icon">&#x2192;</span> <span class="sig-return-typehint"><span class="pre">Tensor</span></span></span><a class="headerlink" href="#FishLeg.fishleg_likelihood.FixedGaussianLikelihood.draw" title="Permalink to this definition"></a></dt>
<dd><p>Draw samples from the conditional distribution
<span class="math notranslate nohighlight">\(p(\mathbf y|f(\mathbf x))\)</span></p>
<dl class="field-list simple">
<dt class="field-odd">Parameters</dt>
<dd class="field-odd"><p><strong>preds</strong> (<em>torch.Tensor</em>) – Predictions from model <span class="math notranslate nohighlight">\(f(\mathbf x)\)</span></p>
</dd>
</dl>
</dd></dl>

<dl class="py property">
<dt class="sig sig-object py" id="FishLeg.fishleg_likelihood.FixedGaussianLikelihood.get_variance">
<em class="property"><span class="pre">property</span><span class="w"> </span></em><span class="sig-name descname"><span class="pre">get_variance</span></span><em class="property"><span class="p"><span class="pre">:</span></span><span class="w"> </span><span class="pre">Tensor</span></em><a class="headerlink" href="#FishLeg.fishleg_likelihood.FixedGaussianLikelihood.get_variance" title="Permalink to this definition"></a></dt>
<dd></dd></dl>

<dl class="py method">
<dt class="sig sig-object py" id="FishLeg.fishleg_likelihood.FixedGaussianLikelihood.nll">
<span class="sig-name descname"><span class="pre">nll</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">preds</span></span><span class="p"><span class="pre">:</span></span><span class="w"> </span><span class="n"><span class="pre">Tensor</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">observations</span></span><span class="p"><span class="pre">:</span></span><span class="w"> </span><span class="n"><span class="pre">Tensor</span></span></em><span class="sig-paren">)</span> <span class="sig-return"><span class="sig-return-icon">&#x2192;</span> <span class="sig-return-typehint"><span class="pre">Tensor</span></span></span><a class="headerlink" href="#FishLeg.fishleg_likelihood.FixedGaussianLikelihood.nll" title="Permalink to this definition"></a></dt>
<dd><p>Computes the negative log-likelihood
<span class="math notranslate nohighlight">\(\ell(\theta, \mathcal D)=-\log p(\mathbf y|f(\mathbf x))\)</span></p>
<dl class="field-list simple">
<dt class="field-odd">Parameters</dt>
<dd class="field-odd"><ul class="simple">
<li><p><strong>observations</strong> (<em>torch.Tensor</em>) – Values of <span class="math notranslate nohighlight">\(y\)</span>.</p></li>
<li><p><strong>preds</strong> (<em>torch.Tensor</em>) – Predictions from model <span class="math notranslate nohighlight">\(f(\mathbf x)\)</span></p></li>
</ul>
</dd>
<dt class="field-even">Return type</dt>
<dd class="field-even"><p><cite>torch.Tensor</cite></p>
</dd>
</dl>
</dd></dl>

</dd></dl>

<dl class="py class">
<dt class="sig sig-object py" id="FishLeg.fishleg_likelihood.GaussianLikelihood">
<em class="property"><span class="pre">class</span><span class="w"> </span></em><span class="sig-prename descclassname"><span class="pre">FishLeg.fishleg_likelihood.</span></span><span class="sig-name descname"><span class="pre">GaussianLikelihood</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">sigma</span></span><span class="p"><span class="pre">:</span></span><span class="w"> </span><span class="n"><span class="pre">Tensor</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">device</span></span><span class="p"><span class="pre">:</span></span><span class="w"> </span><span class="n"><span class="pre">str</span></span><span class="w"> </span><span class="o"><span class="pre">=</span></span><span class="w"> </span><span class="default_value"><span class="pre">'cpu'</span></span></em><span class="sig-paren">)</span><a class="headerlink" href="#FishLeg.fishleg_likelihood.GaussianLikelihood" title="Permalink to this definition"></a></dt>
<dd><p>Bases: <a class="reference internal" href="#FishLeg.fishleg_likelihood.FishLikelihood" title="FishLeg.fishleg_likelihood.FishLikelihood"><code class="xref py py-class docutils literal notranslate"><span class="pre">FishLikelihood</span></code></a></p>
<p>The standard likelihood for regression,
with heteroscedastic noise whose standard deviation is learned during training.</p>
<div class="math notranslate nohighlight">
\[y = f(x) + \epsilon, \:\:\:\: \epsilon \sim N(0,\sigma^{2}), \:\:\:\: \text{i.e. } p(y | f(x)) = N(y;\, f(x), \sigma^{2})\]</div>
<dl class="field-list simple">
<dt class="field-odd">Parameters</dt>
<dd class="field-odd"><p><strong>sigma</strong> (<em>torch.Tensor</em>) – standard deviation for each example;
also to be learned during training.</p>
</dd>
</dl>
<dl class="py method">
<dt class="sig sig-object py" id="FishLeg.fishleg_likelihood.GaussianLikelihood.Qv">
<span class="sig-name descname"><span class="pre">Qv</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">v</span></span></em><span class="sig-paren">)</span> <span class="sig-return"><span class="sig-return-icon">&#x2192;</span> <span class="sig-return-typehint"><span class="pre">List</span></span></span><a class="headerlink" href="#FishLeg.fishleg_likelihood.GaussianLikelihood.Qv" title="Permalink to this definition"></a></dt>
<dd></dd></dl>

<dl class="py method">
<dt class="sig sig-object py" id="FishLeg.fishleg_likelihood.GaussianLikelihood.draw">
<span class="sig-name descname"><span class="pre">draw</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">preds</span></span><span class="p"><span class="pre">:</span></span><span class="w"> </span><span class="n"><span class="pre">Tensor</span></span></em><span class="sig-paren">)</span> <span class="sig-return"><span class="sig-return-icon">&#x2192;</span> <span class="sig-return-typehint"><span class="pre">Tensor</span></span></span><a class="headerlink" href="#FishLeg.fishleg_likelihood.GaussianLikelihood.draw" title="Permalink to this definition"></a></dt>
<dd><p>Draw samples from the conditional distribution
<span class="math notranslate nohighlight">\(p(\mathbf y|f(\mathbf x))\)</span></p>
<dl class="field-list simple">
<dt class="field-odd">Parameters</dt>
<dd class="field-odd"><p><strong>preds</strong> (<em>torch.Tensor</em>) – Predictions from model <span class="math notranslate nohighlight">\(f(\mathbf x)\)</span></p>
</dd>
</dl>
</dd></dl>

<dl class="py method">
<dt class="sig sig-object py" id="FishLeg.fishleg_likelihood.GaussianLikelihood.get_aux_parameters">
<span class="sig-name descname"><span class="pre">get_aux_parameters</span></span><span class="sig-paren">(</span><span class="sig-paren">)</span> <span class="sig-return"><span class="sig-return-icon">&#x2192;</span> <span class="sig-return-typehint"><span class="pre">List</span></span></span><a class="headerlink" href="#FishLeg.fishleg_likelihood.GaussianLikelihood.get_aux_parameters" title="Permalink to this definition"></a></dt>
<dd></dd></dl>

<dl class="py method">
<dt class="sig sig-object py" id="FishLeg.fishleg_likelihood.GaussianLikelihood.get_parameters">
<span class="sig-name descname"><span class="pre">get_parameters</span></span><span class="sig-paren">(</span><span class="sig-paren">)</span> <span class="sig-return"><span class="sig-return-icon">&#x2192;</span> <span class="sig-return-typehint"><span class="pre">List</span></span></span><a class="headerlink" href="#FishLeg.fishleg_likelihood.GaussianLikelihood.get_parameters" title="Permalink to this definition"></a></dt>
<dd><p>Return a list of learnable parameters.</p>
</dd></dl>

<dl class="py method">
<dt class="sig sig-object py" id="FishLeg.fishleg_likelihood.GaussianLikelihood.init_aux">
<span class="sig-name descname"><span class="pre">init_aux</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">init_scale</span></span></em><span class="sig-paren">)</span> <span class="sig-return"><span class="sig-return-icon">&#x2192;</span> <span class="sig-return-typehint"><span class="pre">None</span></span></span><a class="headerlink" href="#FishLeg.fishleg_likelihood.GaussianLikelihood.init_aux" title="Permalink to this definition"></a></dt>
<dd></dd></dl>

<dl class="py method">
<dt class="sig sig-object py" id="FishLeg.fishleg_likelihood.GaussianLikelihood.nll">
<span class="sig-name descname"><span class="pre">nll</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">preds</span></span><span class="p"><span class="pre">:</span></span><span class="w"> </span><span class="n"><span class="pre">Tensor</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">observations</span></span><span class="p"><span class="pre">:</span></span><span class="w"> </span><span class="n"><span class="pre">Tensor</span></span></em><span class="sig-paren">)</span> <span class="sig-return"><span class="sig-return-icon">&#x2192;</span> <span class="sig-return-typehint"><span class="pre">Tensor</span></span></span><a class="headerlink" href="#FishLeg.fishleg_likelihood.GaussianLikelihood.nll" title="Permalink to this definition"></a></dt>
<dd><p>Computes the negative log-likelihood
<span class="math notranslate nohighlight">\(\ell(\theta, \mathcal D)=-\log p(\mathbf y|f(\mathbf x))\)</span></p>
<dl class="field-list simple">
<dt class="field-odd">Parameters</dt>
<dd class="field-odd"><ul class="simple">
<li><p><strong>observations</strong> (<em>torch.Tensor</em>) – Values of <span class="math notranslate nohighlight">\(y\)</span>.</p></li>
<li><p><strong>preds</strong> (<em>torch.Tensor</em>) – Predictions from model <span class="math notranslate nohighlight">\(f(\mathbf x)\)</span></p></li>
</ul>
</dd>
<dt class="field-even">Return type</dt>
<dd class="field-even"><p><cite>torch.Tensor</cite></p>
</dd>
</dl>
</dd></dl>

</dd></dl>

<dl class="py class">
<dt class="sig sig-object py" id="FishLeg.fishleg_likelihood.SoftMaxLikelihood">
<em class="property"><span class="pre">class</span><span class="w"> </span></em><span class="sig-prename descclassname"><span class="pre">FishLeg.fishleg_likelihood.</span></span><span class="sig-name descname"><span class="pre">SoftMaxLikelihood</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">device</span></span><span class="p"><span class="pre">:</span></span><span class="w"> </span><span class="n"><span class="pre">str</span></span><span class="w"> </span><span class="o"><span class="pre">=</span></span><span class="w"> </span><span class="default_value"><span class="pre">'cpu'</span></span></em><span class="sig-paren">)</span><a class="headerlink" href="#FishLeg.fishleg_likelihood.SoftMaxLikelihood" title="Permalink to this definition"></a></dt>
<dd><p>Bases: <a class="reference internal" href="#FishLeg.fishleg_likelihood.FishLikelihood" title="FishLeg.fishleg_likelihood.FishLikelihood"><code class="xref py py-class docutils literal notranslate"><span class="pre">FishLikelihood</span></code></a></p>
<dl class="py method">
<dt class="sig sig-object py" id="FishLeg.fishleg_likelihood.SoftMaxLikelihood.draw">
<span class="sig-name descname"><span class="pre">draw</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">preds</span></span><span class="p"><span class="pre">:</span></span><span class="w"> </span><span class="n"><span class="pre">Tensor</span></span></em><span class="sig-paren">)</span> <span class="sig-return"><span class="sig-return-icon">&#x2192;</span> <span class="sig-return-typehint"><span class="pre">Tensor</span></span></span><a class="headerlink" href="#FishLeg.fishleg_likelihood.SoftMaxLikelihood.draw" title="Permalink to this definition"></a></dt>
<dd><p>Draw samples from the conditional distribution
<span class="math notranslate nohighlight">\(p(\mathbf y|f(\mathbf x))\)</span></p>
<dl class="field-list simple">
<dt class="field-odd">Parameters</dt>
<dd class="field-odd"><p><strong>preds</strong> (<em>torch.Tensor</em>) – Predictions from model <span class="math notranslate nohighlight">\(f(\mathbf x)\)</span></p>
</dd>
</dl>
</dd></dl>

<dl class="py method">
<dt class="sig sig-object py" id="FishLeg.fishleg_likelihood.SoftMaxLikelihood.nll">
<span class="sig-name descname"><span class="pre">nll</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">preds</span></span><span class="p"><span class="pre">:</span></span><span class="w"> </span><span class="n"><span class="pre">Tensor</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">observations</span></span><span class="p"><span class="pre">:</span></span><span class="w"> </span><span class="n"><span class="pre">Tensor</span></span></em><span class="sig-paren">)</span> <span class="sig-return"><span class="sig-return-icon">&#x2192;</span> <span class="sig-return-typehint"><span class="pre">Tensor</span></span></span><a class="headerlink" href="#FishLeg.fishleg_likelihood.SoftMaxLikelihood.nll" title="Permalink to this definition"></a></dt>
<dd><p>Computes the negative log-likelihood
<span class="math notranslate nohighlight">\(\ell(\theta, \mathcal D)=-\log p(\mathbf y|f(\mathbf x))\)</span></p>
<dl class="field-list simple">
<dt class="field-odd">Parameters</dt>
<dd class="field-odd"><ul class="simple">
<li><p><strong>observations</strong> (<em>torch.Tensor</em>) – Values of <span class="math notranslate nohighlight">\(y\)</span>.</p></li>
<li><p><strong>preds</strong> (<em>torch.Tensor</em>) – Predictions from model <span class="math notranslate nohighlight">\(f(\mathbf x)\)</span></p></li>
</ul>
</dd>
<dt class="field-even">Return type</dt>
<dd class="field-even"><p><cite>torch.Tensor</cite></p>
</dd>
</dl>
</dd></dl>

</dd></dl>

</section>
<section id="module-FishLeg.utils">
<span id="fishleg-utils-module"></span><h2>FishLeg.utils module<a class="headerlink" href="#module-FishLeg.utils" title="Permalink to this heading"></a></h2>
<dl class="py function">
<dt class="sig sig-object py" id="FishLeg.utils.get_named_layers_by_regex">
<span class="sig-prename descclassname"><span class="pre">FishLeg.utils.</span></span><span class="sig-name descname"><span class="pre">get_named_layers_by_regex</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">module</span></span><span class="p"><span class="pre">:</span></span><span class="w"> </span><span class="n"><span class="pre">Module</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">param_names</span></span><span class="p"><span class="pre">:</span></span><span class="w"> </span><span class="n"><span class="pre">List</span><span class="p"><span class="pre">[</span></span><span class="pre">str</span><span class="p"><span class="pre">]</span></span></span></em>, <em class="sig-param"><span class="n"><span class="pre">params_strict</span></span><span class="p"><span class="pre">:</span></span><span class="w"> </span><span class="n"><span class="pre">bool</span></span><span class="w"> </span><span class="o"><span class="pre">=</span></span><span class="w"> </span><span class="default_value"><span class="pre">False</span></span></em><span class="sig-paren">)</span> <span class="sig-return"><span class="sig-return-icon">&#x2192;</span> <span class="sig-return-typehint"><span class="pre">List</span><span class="p"><span class="pre">[</span></span><span class="pre">NamedLayerParam</span><span class="p"><span class="pre">]</span></span></span></span><a class="headerlink" href="#FishLeg.utils.get_named_layers_by_regex" title="Permalink to this definition"></a></dt>
<dd><dl class="field-list simple">
<dt class="field-odd">Parameters</dt>
<dd class="field-odd"><ul class="simple">
<li><p><strong>module</strong> – the module to get the matching layers and params from</p></li>
<li><p><strong>param_names</strong> – a list of names or regex patterns to match with full parameter
paths. Regex patterns must be specified with the prefix ‘re:’</p></li>
<li><p><strong>params_strict</strong> – if True, this function will raise an exception if a
parameter is not found to match every name or regex in param_names</p></li>
</ul>
</dd>
<dt class="field-even">Returns</dt>
<dd class="field-even"><p>a list of NamedLayerParam tuples whose full parameter names in the given
module match one of the given regex patterns or parameter names</p>
</dd>
</dl>
</dd></dl>

<dl class="py function">
<dt class="sig sig-object py" id="FishLeg.utils.recursive_getattr">
<span class="sig-prename descclassname"><span class="pre">FishLeg.utils.</span></span><span class="sig-name descname"><span class="pre">recursive_getattr</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">obj</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">attr</span></span></em><span class="sig-paren">)</span><a class="headerlink" href="#FishLeg.utils.recursive_getattr" title="Permalink to this definition"></a></dt>
<dd></dd></dl>

<dl class="py function">
<dt class="sig sig-object py" id="FishLeg.utils.recursive_setattr">
<span class="sig-prename descclassname"><span class="pre">FishLeg.utils.</span></span><span class="sig-name descname"><span class="pre">recursive_setattr</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">obj</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">attr</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">value</span></span></em><span class="sig-paren">)</span><a class="headerlink" href="#FishLeg.utils.recursive_setattr" title="Permalink to this definition"></a></dt>
<dd></dd></dl>

<dl class="py function">
<dt class="sig sig-object py" id="FishLeg.utils.update_dict">
<span class="sig-prename descclassname"><span class="pre">FishLeg.utils.</span></span><span class="sig-name descname"><span class="pre">update_dict</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">replace</span></span><span class="p"><span class="pre">:</span></span><span class="w"> </span><span class="n"><span class="pre">Module</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">module</span></span><span class="p"><span class="pre">:</span></span><span class="w"> </span><span class="n"><span class="pre">Module</span></span></em><span class="sig-paren">)</span> <span class="sig-return"><span class="sig-return-icon">&#x2192;</span> <span class="sig-return-typehint"><span class="pre">Module</span></span></span><a class="headerlink" href="#FishLeg.utils.update_dict" title="Permalink to this definition"></a></dt>
<dd></dd></dl>

</section>
<section id="module-FishLeg">
<span id="module-contents"></span><h2>Module contents<a class="headerlink" href="#module-FishLeg" title="Permalink to this heading"></a></h2>
</section>
</section>


           </div>
          </div>
          <footer><div class="rst-footer-buttons" role="navigation" aria-label="Footer">
        <a href="modules.html" class="btn btn-neutral float-left" title="optim" accesskey="p" rel="prev"><span class="fa fa-arrow-circle-left" aria-hidden="true"></span> Previous</a>
    </div>

  <hr/>

  <div role="contentinfo">
    <p>&#169; Copyright 2023, MTK.</p>
  </div>

  Built with <a href="https://www.sphinx-doc.org/">Sphinx</a> using a
    <a href="https://github.com/readthedocs/sphinx_rtd_theme">theme</a>
    provided by <a href="https://readthedocs.org">Read the Docs</a>.
   

</footer>
        </div>
      </div>
    </section>
  </div>
  <script>
      // Passing a function to jQuery() registers it as a document-ready
      // callback, so the body below runs once the DOM has been loaded.
      jQuery(function () {
          // Enable the SphinxRtdTheme navigation handling.
          // NOTE(review): the `true` argument presumably turns on sticky
          // navigation; confirm against _static/js/theme.js.
          SphinxRtdTheme.Navigation.enable(true);
      });
  </script> 

</body>
</html>