<!DOCTYPE html>
<html class="writer-html5" lang="en" >
<head>
  <meta charset="utf-8" /><meta name="generator" content="Docutils 0.17.1: http://docutils.sourceforge.net/" />

  <meta name="viewport" content="width=device-width, initial-scale=1.0" />
  <title>FishLeg package &mdash; FishLeg 1.0 documentation</title>
      <link rel="stylesheet" href="_static/pygments.css" type="text/css" />
      <link rel="stylesheet" href="_static/css/theme.css" type="text/css" />
  <!--[if lt IE 9]>
    <script src="_static/js/html5shiv.min.js"></script>
  <![endif]-->
  
        <script data-url_root="./" id="documentation_options" src="_static/documentation_options.js"></script>
        <script src="_static/jquery.js"></script>
        <script src="_static/underscore.js"></script>
        <script src="_static/_sphinx_javascript_frameworks_compat.js"></script>
        <script src="_static/doctools.js"></script>
        <script src="_static/sphinx_highlight.js"></script>
        <script async="async" src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
    <script src="_static/js/theme.js"></script>
    <link rel="index" title="Index" href="genindex.html" />
    <link rel="search" title="Search" href="search.html" />
    <link rel="prev" title="optim" href="modules.html" /> 
</head>

<body class="wy-body-for-nav"> 
  <div class="wy-grid-for-nav">
    <nav data-toggle="wy-nav-shift" class="wy-nav-side">
      <div class="wy-side-scroll">
        <div class="wy-side-nav-search" >
            <a href="index.html" class="icon icon-home"> FishLeg
          </a>
<div role="search">
  <form id="rtd-search-form" class="wy-form" action="search.html" method="get">
    <input type="text" name="q" placeholder="Search docs" />
    <input type="hidden" name="check_keywords" value="yes" />
    <input type="hidden" name="area" value="default" />
  </form>
</div>
        </div><div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="Navigation menu">
              <p class="caption" role="heading"><span class="caption-text">Contents:</span></p>
<ul class="current">
<li class="toctree-l1 current"><a class="reference internal" href="modules.html">optim</a><ul class="current">
<li class="toctree-l2 current"><a class="current reference internal" href="#">FishLeg package</a><ul>
<li class="toctree-l3"><a class="reference internal" href="#submodules">Submodules</a></li>
<li class="toctree-l3"><a class="reference internal" href="#module-FishLeg.fishleg">FishLeg.fishleg module</a><ul>
<li class="toctree-l4"><a class="reference internal" href="#FishLeg.fishleg.FishLeg"><code class="docutils literal notranslate"><span class="pre">FishLeg</span></code></a></li>
<li class="toctree-l4"><a class="reference internal" href="#FishLeg.fishleg.update_dict"><code class="docutils literal notranslate"><span class="pre">update_dict()</span></code></a></li>
</ul>
</li>
<li class="toctree-l3"><a class="reference internal" href="#module-FishLeg.fishleg_layers">FishLeg.fishleg_layers module</a><ul>
<li class="toctree-l4"><a class="reference internal" href="#FishLeg.fishleg_layers.FishLinear"><code class="docutils literal notranslate"><span class="pre">FishLinear</span></code></a></li>
<li class="toctree-l4"><a class="reference internal" href="#FishLeg.fishleg_layers.FishModule"><code class="docutils literal notranslate"><span class="pre">FishModule</span></code></a></li>
<li class="toctree-l4"><a class="reference internal" href="#FishLeg.fishleg_layers.get_zero_grad_hook"><code class="docutils literal notranslate"><span class="pre">get_zero_grad_hook()</span></code></a></li>
</ul>
</li>
<li class="toctree-l3"><a class="reference internal" href="#module-FishLeg.fishleg_likelihood">FishLeg.fishleg_likelihood module</a><ul>
<li class="toctree-l4"><a class="reference internal" href="#FishLeg.fishleg_likelihood.BernoulliLikelihood"><code class="docutils literal notranslate"><span class="pre">BernoulliLikelihood</span></code></a></li>
<li class="toctree-l4"><a class="reference internal" href="#FishLeg.fishleg_likelihood.FishLikelihood"><code class="docutils literal notranslate"><span class="pre">FishLikelihood</span></code></a></li>
<li class="toctree-l4"><a class="reference internal" href="#FishLeg.fishleg_likelihood.FixedGaussianLikelihood"><code class="docutils literal notranslate"><span class="pre">FixedGaussianLikelihood</span></code></a></li>
<li class="toctree-l4"><a class="reference internal" href="#FishLeg.fishleg_likelihood.SoftMaxLikelihood"><code class="docutils literal notranslate"><span class="pre">SoftMaxLikelihood</span></code></a></li>
</ul>
</li>
<li class="toctree-l3"><a class="reference internal" href="#module-FishLeg.utils">FishLeg.utils module</a><ul>
<li class="toctree-l4"><a class="reference internal" href="#FishLeg.utils.recursive_getattr"><code class="docutils literal notranslate"><span class="pre">recursive_getattr()</span></code></a></li>
<li class="toctree-l4"><a class="reference internal" href="#FishLeg.utils.recursive_setattr"><code class="docutils literal notranslate"><span class="pre">recursive_setattr()</span></code></a></li>
</ul>
</li>
<li class="toctree-l3"><a class="reference internal" href="#module-FishLeg">Module contents</a></li>
</ul>
</li>
</ul>
</li>
</ul>

        </div>
      </div>
    </nav>

    <section data-toggle="wy-nav-shift" class="wy-nav-content-wrap"><nav class="wy-nav-top" aria-label="Mobile navigation menu" >
          <i data-toggle="wy-nav-top" class="fa fa-bars"></i>
          <a href="index.html">FishLeg</a>
      </nav>

      <div class="wy-nav-content">
        <div class="rst-content">
          <div role="navigation" aria-label="Page navigation">
  <ul class="wy-breadcrumbs">
      <li><a href="index.html" class="icon icon-home"></a></li>
          <li class="breadcrumb-item"><a href="modules.html">optim</a></li>
      <li class="breadcrumb-item active">FishLeg package</li>
      <li class="wy-breadcrumbs-aside">
            <a href="_sources/FishLeg.rst.txt" rel="nofollow"> View page source</a>
      </li>
  </ul>
  <hr/>
</div>
          <div role="main" class="document" itemscope="itemscope" itemtype="http://schema.org/Article">
           <div itemprop="articleBody">
             
  <section id="fishleg-package">
<h1>FishLeg package<a class="headerlink" href="#fishleg-package" title="Permalink to this heading"></a></h1>
<section id="submodules">
<h2>Submodules<a class="headerlink" href="#submodules" title="Permalink to this heading"></a></h2>
</section>
<section id="module-FishLeg.fishleg">
<span id="fishleg-fishleg-module"></span><h2>FishLeg.fishleg module<a class="headerlink" href="#module-FishLeg.fishleg" title="Permalink to this heading"></a></h2>
<dl class="py class">
<dt class="sig sig-object py" id="FishLeg.fishleg.FishLeg">
<em class="property"><span class="pre">class</span><span class="w"> </span></em><span class="sig-prename descclassname"><span class="pre">FishLeg.fishleg.</span></span><span class="sig-name descname"><span class="pre">FishLeg</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">model</span></span><span class="p"><span class="pre">:</span></span><span class="w"> </span><span class="n"><span class="pre">Module</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">draw</span></span><span class="p"><span class="pre">:</span></span><span class="w"> </span><span class="n"><span class="pre">Callable</span><span class="p"><span class="pre">[</span></span><span class="p"><span class="pre">[</span></span><span class="pre">Module</span><span class="p"><span class="pre">,</span></span><span class="w"> </span><span class="pre">Tensor</span><span class="p"><span class="pre">]</span></span><span class="p"><span class="pre">,</span></span><span class="w"> </span><span class="pre">Tuple</span><span class="p"><span class="pre">[</span></span><span class="pre">Tensor</span><span class="p"><span class="pre">,</span></span><span class="w"> </span><span class="pre">Tensor</span><span class="p"><span class="pre">]</span></span><span class="p"><span class="pre">]</span></span></span></em>, <em class="sig-param"><span class="n"><span class="pre">nll</span></span><span class="p"><span class="pre">:</span></span><span class="w"> </span><span class="n"><span class="pre">Callable</span><span class="p"><span class="pre">[</span></span><span class="p"><span class="pre">[</span></span><span class="pre">Module</span><span class="p"><span class="pre">,</span></span><span class="w"> </span><span class="pre">Tuple</span><span class="p"><span class="pre">[</span></span><span class="pre">Tensor</span><span class="p"><span class="pre">,</span></span><span class="w"> </span><span class="pre">Tensor</span><span class="p"><span 
class="pre">]</span></span><span class="p"><span class="pre">]</span></span><span class="p"><span class="pre">,</span></span><span class="w"> </span><span class="pre">Tensor</span><span class="p"><span class="pre">]</span></span></span></em>, <em class="sig-param"><span class="n"><span class="pre">dataloader</span></span><span class="p"><span class="pre">:</span></span><span class="w"> </span><span class="n"><span class="pre">Callable</span><span class="p"><span class="pre">[</span></span><span class="p"><span class="pre">[</span></span><span class="p"><span class="pre">]</span></span><span class="p"><span class="pre">,</span></span><span class="w"> </span><span class="pre">Tuple</span><span class="p"><span class="pre">[</span></span><span class="pre">Tensor</span><span class="p"><span class="pre">,</span></span><span class="w"> </span><span class="pre">Tensor</span><span class="p"><span class="pre">]</span></span><span class="p"><span class="pre">]</span></span></span></em>, <em class="sig-param"><span class="n"><span class="pre">lr</span></span><span class="p"><span class="pre">:</span></span><span class="w"> </span><span class="n"><span class="pre">float</span></span><span class="w"> </span><span class="o"><span class="pre">=</span></span><span class="w"> </span><span class="default_value"><span class="pre">0.01</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">eps</span></span><span class="p"><span class="pre">:</span></span><span class="w"> </span><span class="n"><span class="pre">float</span></span><span class="w"> </span><span class="o"><span class="pre">=</span></span><span class="w"> </span><span class="default_value"><span class="pre">0.0001</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">weight_decay</span></span><span class="p"><span class="pre">:</span></span><span class="w"> </span><span class="n"><span class="pre">float</span></span><span class="w"> </span><span class="o"><span 
class="pre">=</span></span><span class="w"> </span><span class="default_value"><span class="pre">1e-05</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">beta</span></span><span class="p"><span class="pre">:</span></span><span class="w"> </span><span class="n"><span class="pre">float</span></span><span class="w"> </span><span class="o"><span class="pre">=</span></span><span class="w"> </span><span class="default_value"><span class="pre">0.9</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">update_aux_every</span></span><span class="p"><span class="pre">:</span></span><span class="w"> </span><span class="n"><span class="pre">int</span></span><span class="w"> </span><span class="o"><span class="pre">=</span></span><span class="w"> </span><span class="default_value"><span class="pre">-3</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">aux_lr</span></span><span class="p"><span class="pre">:</span></span><span class="w"> </span><span class="n"><span class="pre">float</span></span><span class="w"> </span><span class="o"><span class="pre">=</span></span><span class="w"> </span><span class="default_value"><span class="pre">0.001</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">aux_betas</span></span><span class="p"><span class="pre">:</span></span><span class="w"> </span><span class="n"><span class="pre">Tuple</span><span class="p"><span class="pre">[</span></span><span class="pre">float</span><span class="p"><span class="pre">,</span></span><span class="w"> </span><span class="pre">float</span><span class="p"><span class="pre">]</span></span></span><span class="w"> </span><span class="o"><span class="pre">=</span></span><span class="w"> </span><span class="default_value"><span class="pre">(0.9,</span> <span class="pre">0.999)</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">aux_eps</span></span><span class="p"><span 
class="pre">:</span></span><span class="w"> </span><span class="n"><span class="pre">float</span></span><span class="w"> </span><span class="o"><span class="pre">=</span></span><span class="w"> </span><span class="default_value"><span class="pre">1e-08</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">damping</span></span><span class="p"><span class="pre">:</span></span><span class="w"> </span><span class="n"><span class="pre">float</span></span><span class="w"> </span><span class="o"><span class="pre">=</span></span><span class="w"> </span><span class="default_value"><span class="pre">1e-05</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">pre_aux_training</span></span><span class="p"><span class="pre">:</span></span><span class="w"> </span><span class="n"><span class="pre">int</span></span><span class="w"> </span><span class="o"><span class="pre">=</span></span><span class="w"> </span><span class="default_value"><span class="pre">10</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">differentiable</span></span><span class="p"><span class="pre">:</span></span><span class="w"> </span><span class="n"><span class="pre">bool</span></span><span class="w"> </span><span class="o"><span class="pre">=</span></span><span class="w"> </span><span class="default_value"><span class="pre">False</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">sgd_lr</span></span><span class="p"><span class="pre">:</span></span><span class="w"> </span><span class="n"><span class="pre">float</span></span><span class="w"> </span><span class="o"><span class="pre">=</span></span><span class="w"> </span><span class="default_value"><span class="pre">0.01</span></span></em><span class="sig-paren">)</span><a class="headerlink" href="#FishLeg.fishleg.FishLeg" title="Permalink to this definition"></a></dt>
<dd><p>Bases: <code class="xref py py-class docutils literal notranslate"><span class="pre">Optimizer</span></code></p>
<p>Implement FishLeg algorithm.</p>
<dl class="field-list simple">
<dt class="field-odd">Parameters</dt>
<dd class="field-odd"><ul class="simple">
<li><p><strong>model</strong> (<em>torch.nn.Module</em>) – a pytorch neural network module,
can be nested in a tree structure</p></li>
<li><p><strong>draw</strong> (<em>Callable</em><em>[</em><em>[</em><em>nn.Module</em><em>, </em><em>torch.Tensor</em><em>]</em><em>, </em><em>Tuple</em><em>[</em><em>torch.Tensor</em><em>, </em><em>torch.Tensor</em><em>]</em><em>]</em>) – Sampling function that takes a model <span class="math notranslate nohighlight">\(f\)</span> and input data <span class="math notranslate nohighlight">\(\mathbf X\)</span>, 
and returns <span class="math notranslate nohighlight">\((\mathbf X, \mathbf y)\)</span>, 
where <span class="math notranslate nohighlight">\(\mathbf y\)</span> is sampled from 
the conditional distribution <span class="math notranslate nohighlight">\(p(\mathbf y|f(\mathbf X))\)</span></p></li>
<li><p><strong>nll</strong> (<em>Callable</em><em>[</em><em>[</em><em>nn.Module</em><em>, </em><em>Tuple</em><em>[</em><em>torch.Tensor</em><em>, </em><em>torch.Tensor</em><em>]</em><em>]</em><em>, </em><em>torch.Tensor</em><em>]</em>) – A function that takes a model and data, and evaluates the negative
log-likelihood.</p></li>
<li><p><strong>dataloader</strong> (<em>Callable</em><em>[</em><em>[</em><em>int</em><em>]</em><em>, </em><em>Tuple</em><em>[</em><em>torch.Tensor</em><em>, </em><em>torch.Tensor</em><em>]</em><em>]</em>) – A function that takes a batch size as input and outputs a dataset
of the corresponding size.</p></li>
<li><p><strong>lr</strong> (<em>float</em>) – learning rate,
for the parameters of the input model using FishLeg (default: 1e-2)</p></li>
<li><p><strong>eps</strong> (<em>float</em>) – a small scalar, to evaluate the auxiliary loss
in the direction of gradient of model parameters (default: 1e-4)</p></li>
<li><p><strong>update_aux_every</strong> (<em>int</em>) – number of iterations after which an auxiliary
update is executed; if negative, then run -update_aux_every auxiliary
updates in each outer iteration. (default: -3)</p></li>
<li><p><strong>aux_lr</strong> (<em>float</em>) – learning rate for the auxiliary parameters,
using Adam (default: 1e-3)</p></li>
<li><p><strong>aux_betas</strong> (<em>Tuple</em><em>[</em><em>float</em><em>, </em><em>float</em><em>]</em>) – coefficients used for computing
running averages of gradient and its square for auxiliary parameters
(default: (0.9, 0.999))</p></li>
<li><p><strong>aux_eps</strong> (<em>float</em>) – term added to the denominator to improve
numerical stability for auxiliary parameters (default: 1e-8)</p></li>
<li><p><strong>pre_aux_training</strong> (<em>int</em>) – number of auxiliary updates to make before
any update of the original parameter. This process intends to approximate
the correct Fisher Information matrix during initialization,
which is especially important for fine-tuning of models with pretraining</p></li>
<li><p><strong>differentiable</strong> (<em>bool</em>) – whether the fused implementation (CUDA only) is used</p></li>
<li><p><strong>sgd_lr</strong> (<em>float</em>) – helps specify the initial scale of the inverse Fisher Information matrix
approximation, <span class="math notranslate nohighlight">\(\eta\)</span>. Make sure that</p>
<div class="math notranslate nohighlight">
\[- \eta_{init} Q(\lambda) grad = - \eta_{sgd} grad\]</div>
<p>holds at the beginning of the optimization.
Here <span class="math notranslate nohighlight">\(\eta_{init}=\eta_{sgd}/\eta_{fl}\)</span>.</p>
</li>
</ul>
</dd>
</dl>
<dl>
<dt>Example:</dt><dd><div class="doctest highlight-default notranslate"><div class="highlight"><pre><span></span><span class="gp">&gt;&gt;&gt; </span><span class="n">auxloader</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">utils</span><span class="o">.</span><span class="n">data</span><span class="o">.</span><span class="n">DataLoader</span><span class="p">(</span><span class="n">train_data</span><span class="p">,</span> <span class="n">shuffle</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span> <span class="n">batch_size</span><span class="o">=</span><span class="mi">100</span><span class="p">)</span>
<span class="gp">&gt;&gt;&gt; </span><span class="n">trainloader</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">utils</span><span class="o">.</span><span class="n">data</span><span class="o">.</span><span class="n">DataLoader</span><span class="p">(</span><span class="n">train_data</span><span class="p">,</span> <span class="n">shuffle</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span> <span class="n">batch_size</span><span class="o">=</span><span class="mi">100</span><span class="p">)</span>
<span class="go">&gt;&gt;&gt;</span>
<span class="gp">&gt;&gt;&gt; </span><span class="n">likelihood</span> <span class="o">=</span> <span class="n">FixedGaussianLikelihood</span><span class="p">(</span><span class="n">sigma_fixed</span><span class="o">=</span><span class="mf">1.0</span><span class="p">)</span>
<span class="go">&gt;&gt;&gt;</span>
<span class="gp">&gt;&gt;&gt; </span><span class="k">def</span> <span class="nf">nll</span><span class="p">(</span><span class="n">model</span><span class="p">,</span> <span class="n">data</span><span class="p">):</span>
<span class="gp">&gt;&gt;&gt; </span>    <span class="n">data_x</span><span class="p">,</span> <span class="n">data_y</span> <span class="o">=</span> <span class="n">data</span>
<span class="gp">&gt;&gt;&gt; </span>    <span class="n">pred_y</span> <span class="o">=</span> <span class="n">model</span><span class="o">.</span><span class="n">forward</span><span class="p">(</span><span class="n">data_x</span><span class="p">)</span>
<span class="gp">&gt;&gt;&gt; </span>    <span class="k">return</span> <span class="n">likelihood</span><span class="o">.</span><span class="n">nll</span><span class="p">(</span><span class="n">data_y</span><span class="p">,</span> <span class="n">pred_y</span><span class="p">)</span>
<span class="go">&gt;&gt;&gt;</span>
<span class="gp">&gt;&gt;&gt; </span><span class="k">def</span> <span class="nf">draw</span><span class="p">(</span><span class="n">model</span><span class="p">,</span> <span class="n">data_x</span><span class="p">):</span>
<span class="gp">&gt;&gt;&gt; </span>    <span class="n">pred_y</span> <span class="o">=</span> <span class="n">model</span><span class="o">.</span><span class="n">forward</span><span class="p">(</span><span class="n">data_x</span><span class="p">)</span>
<span class="gp">&gt;&gt;&gt; </span>    <span class="k">return</span> <span class="p">(</span><span class="n">data_x</span><span class="p">,</span> <span class="n">likelihood</span><span class="o">.</span><span class="n">draw</span><span class="p">(</span><span class="n">pred_y</span><span class="p">))</span>
<span class="go">&gt;&gt;&gt;</span>
<span class="gp">&gt;&gt;&gt; </span><span class="k">def</span> <span class="nf">dataloader</span><span class="p">():</span>
<span class="gp">&gt;&gt;&gt; </span>    <span class="n">data_x</span><span class="p">,</span> <span class="n">_</span> <span class="o">=</span> <span class="nb">next</span><span class="p">(</span><span class="nb">iter</span><span class="p">(</span><span class="n">auxloader</span><span class="p">))</span>
<span class="gp">&gt;&gt;&gt; </span>    <span class="k">return</span> <span class="n">data_x</span>
<span class="go">&gt;&gt;&gt;</span>
<span class="gp">&gt;&gt;&gt; </span><span class="n">model</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Sequential</span><span class="p">(</span>
<span class="gp">&gt;&gt;&gt; </span>    <span class="n">nn</span><span class="o">.</span><span class="n">Linear</span><span class="p">(</span><span class="mi">2</span><span class="p">,</span> <span class="mi">5</span><span class="p">),</span>
<span class="gp">&gt;&gt;&gt; </span>    <span class="n">nn</span><span class="o">.</span><span class="n">ReLU</span><span class="p">(),</span>
<span class="gp">&gt;&gt;&gt; </span>    <span class="n">nn</span><span class="o">.</span><span class="n">Linear</span><span class="p">(</span><span class="mi">5</span><span class="p">,</span> <span class="mi">1</span><span class="p">),</span>
<span class="gp">&gt;&gt;&gt; </span><span class="p">)</span>
<span class="go">&gt;&gt;&gt;</span>
<span class="gp">&gt;&gt;&gt; </span><span class="n">opt</span> <span class="o">=</span> <span class="n">FishLeg</span><span class="p">(</span>
<span class="gp">&gt;&gt;&gt; </span>    <span class="n">model</span><span class="p">,</span>
<span class="gp">&gt;&gt;&gt; </span>    <span class="n">draw</span><span class="p">,</span>
<span class="gp">&gt;&gt;&gt; </span>    <span class="n">nll</span><span class="p">,</span>
<span class="gp">&gt;&gt;&gt; </span>    <span class="n">dataloader</span>
<span class="gp">&gt;&gt;&gt; </span><span class="p">)</span>
<span class="go">&gt;&gt;&gt;</span>
<span class="gp">&gt;&gt;&gt; </span><span class="k">for</span> <span class="n">iteration</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="mi">100</span><span class="p">):</span>
<span class="gp">&gt;&gt;&gt; </span>    <span class="n">data_x</span><span class="p">,</span> <span class="n">data_y</span> <span class="o">=</span> <span class="nb">next</span><span class="p">(</span><span class="nb">iter</span><span class="p">(</span><span class="n">trainloader</span><span class="p">))</span>
<span class="gp">&gt;&gt;&gt; </span>    <span class="n">opt</span><span class="o">.</span><span class="n">zero_grad</span><span class="p">()</span>
<span class="gp">&gt;&gt;&gt; </span>    <span class="n">pred_y</span> <span class="o">=</span> <span class="n">model</span><span class="p">(</span><span class="n">data_x</span><span class="p">)</span>
<span class="gp">&gt;&gt;&gt; </span>    <span class="n">loss</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">MSELoss</span><span class="p">()(</span><span class="n">data_y</span><span class="p">,</span> <span class="n">pred_y</span><span class="p">)</span>
<span class="gp">&gt;&gt;&gt; </span>    <span class="n">loss</span><span class="o">.</span><span class="n">backward</span><span class="p">()</span>
<span class="gp">&gt;&gt;&gt; </span>    <span class="n">opt</span><span class="o">.</span><span class="n">step</span><span class="p">()</span>
<span class="gp">&gt;&gt;&gt; </span>    <span class="k">if</span> <span class="n">iteration</span> <span class="o">%</span> <span class="mi">10</span> <span class="o">==</span> <span class="mi">0</span><span class="p">:</span>
<span class="gp">&gt;&gt;&gt; </span>        <span class="nb">print</span><span class="p">(</span><span class="n">loss</span><span class="o">.</span><span class="n">detach</span><span class="p">())</span>
</pre></div>
</div>
</dd>
</dl>
<dl class="py method">
<dt class="sig sig-object py" id="FishLeg.fishleg.FishLeg.init_model_aux">
<span class="sig-name descname"><span class="pre">init_model_aux</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">model</span></span><span class="p"><span class="pre">:</span></span><span class="w"> </span><span class="n"><span class="pre">Module</span></span></em><span class="sig-paren">)</span> <span class="sig-return"><span class="sig-return-icon">&#x2192;</span> <span class="sig-return-typehint"><span class="pre">Module</span></span></span><a class="headerlink" href="#FishLeg.fishleg.FishLeg.init_model_aux" title="Permalink to this definition"></a></dt>
<dd><p>Given a model to optimize, its parameters can be divided into</p>
<ol class="arabic simple">
<li><p>those fixed as pre-trained.</p></li>
<li><p>those required to optimize using FishLeg.</p></li>
</ol>
<p>Replace modules in the second group with FishLeg modules.</p>
<dl class="simple">
<dt>Args:</dt><dd><dl class="simple">
<dt>model (<code class="xref py py-class docutils literal notranslate"><span class="pre">torch.nn.Module</span></code>, required):</dt><dd><p>A model containing modules to replace with FishLeg modules
containing extra functionality related to FishLeg algorithm.</p>
</dd>
</dl>
</dd>
<dt>Returns:</dt><dd><p><code class="xref py py-class docutils literal notranslate"><span class="pre">torch.nn.Module</span></code>, the replaced model.</p>
</dd>
</dl>
</dd></dl>

<dl class="py method">
<dt class="sig sig-object py" id="FishLeg.fishleg.FishLeg.step">
<span class="sig-name descname"><span class="pre">step</span></span><span class="sig-paren">(</span><span class="sig-paren">)</span> <span class="sig-return"><span class="sig-return-icon">&#x2192;</span> <span class="sig-return-typehint"><span class="pre">None</span></span></span><a class="headerlink" href="#FishLeg.fishleg.FishLeg.step" title="Permalink to this definition"></a></dt>
<dd><p>Performs a single optimization step of FishLeg.</p>
</dd></dl>

<dl class="py method">
<dt class="sig sig-object py" id="FishLeg.fishleg.FishLeg.update_aux">
<span class="sig-name descname"><span class="pre">update_aux</span></span><span class="sig-paren">(</span><span class="sig-paren">)</span> <span class="sig-return"><span class="sig-return-icon">&#x2192;</span> <span class="sig-return-typehint"><span class="pre">None</span></span></span><a class="headerlink" href="#FishLeg.fishleg.FishLeg.update_aux" title="Permalink to this definition"></a></dt>
<dd><p>Performs a single auxiliary parameter update
using Adam, by minimizing the following objective:</p>
<div class="math notranslate nohighlight">
\[nll(model, \theta + \epsilon Q(\lambda)g) + nll(model, \theta - \epsilon Q(\lambda)g) - 2\epsilon^2g^T Q(\lambda)g\]</div>
<p>where <span class="math notranslate nohighlight">\(\theta\)</span> are the parameters of the model and <span class="math notranslate nohighlight">\(\lambda\)</span> are the
auxiliary parameters.</p>
</dd></dl>

</dd></dl>

<dl class="py function">
<dt class="sig sig-object py" id="FishLeg.fishleg.update_dict">
<span class="sig-prename descclassname"><span class="pre">FishLeg.fishleg.</span></span><span class="sig-name descname"><span class="pre">update_dict</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">replace</span></span><span class="p"><span class="pre">:</span></span><span class="w"> </span><span class="n"><span class="pre">Module</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">module</span></span><span class="p"><span class="pre">:</span></span><span class="w"> </span><span class="n"><span class="pre">Module</span></span></em><span class="sig-paren">)</span> <span class="sig-return"><span class="sig-return-icon">&#x2192;</span> <span class="sig-return-typehint"><span class="pre">Module</span></span></span><a class="headerlink" href="#FishLeg.fishleg.update_dict" title="Permalink to this definition"></a></dt>
<dd></dd></dl>

</section>
<section id="module-FishLeg.fishleg_layers">
<span id="fishleg-fishleg-layers-module"></span><h2>FishLeg.fishleg_layers module<a class="headerlink" href="#module-FishLeg.fishleg_layers" title="Permalink to this heading"></a></h2>
<dl class="py class">
<dt class="sig sig-object py" id="FishLeg.fishleg_layers.FishLinear">
<em class="property"><span class="pre">class</span><span class="w"> </span></em><span class="sig-prename descclassname"><span class="pre">FishLeg.fishleg_layers.</span></span><span class="sig-name descname"><span class="pre">FishLinear</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">in_features</span></span><span class="p"><span class="pre">:</span></span><span class="w"> </span><span class="n"><span class="pre">int</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">out_features</span></span><span class="p"><span class="pre">:</span></span><span class="w"> </span><span class="n"><span class="pre">int</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">bias</span></span><span class="p"><span class="pre">:</span></span><span class="w"> </span><span class="n"><span class="pre">bool</span></span><span class="w"> </span><span class="o"><span class="pre">=</span></span><span class="w"> </span><span class="default_value"><span class="pre">True</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">init_scale</span></span><span class="p"><span class="pre">:</span></span><span class="w"> </span><span class="n"><span class="pre">float</span></span><span class="w"> </span><span class="o"><span class="pre">=</span></span><span class="w"> </span><span class="default_value"><span class="pre">1.0</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">device</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">None</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">dtype</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">None</span></span></em><span class="sig-paren">)</span><a class="headerlink" href="#FishLeg.fishleg_layers.FishLinear" title="Permalink to this definition"></a></dt>
<dd><p>Bases: <code class="xref py py-class docutils literal notranslate"><span class="pre">Linear</span></code>, <a class="reference internal" href="#FishLeg.fishleg_layers.FishModule" title="FishLeg.fishleg_layers.FishModule"><code class="xref py py-class docutils literal notranslate"><span class="pre">FishModule</span></code></a></p>
<dl class="py method">
<dt class="sig sig-object py" id="FishLeg.fishleg_layers.FishLinear.Qv">
<span class="sig-name descname"><span class="pre">Qv</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">v</span></span><span class="p"><span class="pre">:</span></span><span class="w"> </span><span class="n"><span class="pre">Tuple</span><span class="p"><span class="pre">[</span></span><span class="pre">Tensor</span><span class="p"><span class="pre">,</span></span><span class="w"> </span><span class="pre">Tensor</span><span class="p"><span class="pre">]</span></span></span></em><span class="sig-paren">)</span> <span class="sig-return"><span class="sig-return-icon">&#x2192;</span> <span class="sig-return-typehint"><span class="pre">Tuple</span><span class="p"><span class="pre">[</span></span><span class="pre">Tensor</span><span class="p"><span class="pre">,</span></span><span class="w"> </span><span class="pre">Tensor</span><span class="p"><span class="pre">]</span></span></span></span><a class="headerlink" href="#FishLeg.fishleg_layers.FishLinear.Qv" title="Permalink to this definition"></a></dt>
<dd><p>For fully-connected layers, the default structure of <span class="math notranslate nohighlight">\(Q\)</span> as a
block-diagonal matrix is,</p>
<div class="math notranslate nohighlight">
\[Q_l = (R_lR_l^T \otimes L_lL_l^T)\]</div>
<p>where <span class="math notranslate nohighlight">\(l\)</span> denotes the l-th layer. The matrix <span class="math notranslate nohighlight">\(R_l\)</span> has size 
<span class="math notranslate nohighlight">\((N_{l-1} + 1) \times (N_{l-1} + 1)\)</span> while the matrix <span class="math notranslate nohighlight">\(L_l\)</span> has 
size <span class="math notranslate nohighlight">\(N_l \times N_l\)</span>. The auxiliary parameters <span class="math notranslate nohighlight">\(\lambda\)</span> 
are represented by the matrices <span class="math notranslate nohighlight">\(L_l, R_l\)</span>.</p>
</dd></dl>

<dl class="py attribute">
<dt class="sig sig-object py" id="FishLeg.fishleg_layers.FishLinear.in_features">
<span class="sig-name descname"><span class="pre">in_features</span></span><em class="property"><span class="p"><span class="pre">:</span></span><span class="w"> </span><span class="pre">int</span></em><a class="headerlink" href="#FishLeg.fishleg_layers.FishLinear.in_features" title="Permalink to this definition"></a></dt>
<dd></dd></dl>

<dl class="py attribute">
<dt class="sig sig-object py" id="FishLeg.fishleg_layers.FishLinear.out_features">
<span class="sig-name descname"><span class="pre">out_features</span></span><em class="property"><span class="p"><span class="pre">:</span></span><span class="w"> </span><span class="pre">int</span></em><a class="headerlink" href="#FishLeg.fishleg_layers.FishLinear.out_features" title="Permalink to this definition"></a></dt>
<dd></dd></dl>

<dl class="py attribute">
<dt class="sig sig-object py" id="FishLeg.fishleg_layers.FishLinear.weight">
<span class="sig-name descname"><span class="pre">weight</span></span><em class="property"><span class="p"><span class="pre">:</span></span><span class="w"> </span><span class="pre">Tensor</span></em><a class="headerlink" href="#FishLeg.fishleg_layers.FishLinear.weight" title="Permalink to this definition"></a></dt>
<dd></dd></dl>

</dd></dl>

<dl class="py class">
<dt class="sig sig-object py" id="FishLeg.fishleg_layers.FishModule">
<em class="property"><span class="pre">class</span><span class="w"> </span></em><span class="sig-prename descclassname"><span class="pre">FishLeg.fishleg_layers.</span></span><span class="sig-name descname"><span class="pre">FishModule</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="o"><span class="pre">*</span></span><span class="n"><span class="pre">args</span></span><span class="p"><span class="pre">:</span></span><span class="w"> </span><span class="n"><span class="pre">Any</span></span></em>, <em class="sig-param"><span class="o"><span class="pre">**</span></span><span class="n"><span class="pre">kwargs</span></span><span class="p"><span class="pre">:</span></span><span class="w"> </span><span class="n"><span class="pre">Any</span></span></em><span class="sig-paren">)</span><a class="headerlink" href="#FishLeg.fishleg_layers.FishModule" title="Permalink to this definition"></a></dt>
<dd><p>Bases: <code class="xref py py-class docutils literal notranslate"><span class="pre">Module</span></code></p>
<p>Base class for all neural network modules in FishLeg to</p>
<ol class="arabic simple">
<li><p>Initialize auxiliary parameters, <span class="math notranslate nohighlight">\(\lambda\)</span> and its forms, <span class="math notranslate nohighlight">\(Q(\lambda)\)</span>.</p></li>
<li><p>Specify quick calculation of <span class="math notranslate nohighlight">\(Q(\lambda)v\)</span> products.</p></li>
</ol>
<dl class="field-list simple">
<dt class="field-odd">Parameters</dt>
<dd class="field-odd"><ul class="simple">
<li><p><strong>fishleg_aux</strong> (<em>torch.nn.ParameterDict</em>) – <p>auxiliary parameters 
with their initialization, including an additional parameter, scale, 
<span class="math notranslate nohighlight">\(\eta\)</span>. Make sure that</p>
<div class="math notranslate nohighlight">
\[- \eta_{init} Q(\lambda) grad = - \eta_{sgd} grad\]</div>
<p>holds at the beginning of the optimization.</p>
</p></li>
<li><p><strong>order</strong> (<em>List</em>) – specifies the order of the names of the original parameters</p></li>
</ul>
</dd>
</dl>
<dl class="py method">
<dt class="sig sig-object py" id="FishLeg.fishleg_layers.FishModule.Qv">
<em class="property"><span class="pre">abstract</span><span class="w"> </span></em><span class="sig-name descname"><span class="pre">Qv</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">aux</span></span><span class="p"><span class="pre">:</span></span><span class="w"> </span><span class="n"><span class="pre">Dict</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">v</span></span><span class="p"><span class="pre">:</span></span><span class="w"> </span><span class="n"><span class="pre">Tuple</span><span class="p"><span class="pre">[</span></span><span class="pre">Tensor</span><span class="p"><span class="pre">,</span></span><span class="w"> </span><span class="p"><span class="pre">...</span></span><span class="p"><span class="pre">]</span></span></span></em><span class="sig-paren">)</span> <span class="sig-return"><span class="sig-return-icon">&#x2192;</span> <span class="sig-return-typehint"><span class="pre">Tuple</span><span class="p"><span class="pre">[</span></span><span class="pre">Tensor</span><span class="p"><span class="pre">,</span></span><span class="w"> </span><span class="p"><span class="pre">...</span></span><span class="p"><span class="pre">]</span></span></span></span><a class="headerlink" href="#FishLeg.fishleg_layers.FishModule.Qv" title="Permalink to this definition"></a></dt>
<dd><p><span class="math notranslate nohighlight">\(Q(\lambda)\)</span> is a positive definite matrix which will effectively 
estimate the inverse damped Fisher Information Matrix. Appropriate choices 
for <span class="math notranslate nohighlight">\(Q\)</span> should take into account the architecture of the model/module.
It is usually parameterized as a positive definite Kronecker-factored 
block-diagonal matrix, with block sizes reflecting the layer structure of 
the neural networks.</p>
<dl class="simple">
<dt>Args:</dt><dd><dl class="simple">
<dt>aux: (Dict, required): auxiliary parameters,</dt><dd><p><span class="math notranslate nohighlight">\(\lambda\)</span>, a dictionary with keys, the name 
of the auxiliary parameters, and values, the auxiliary parameters 
of the module. These auxiliary parameters will form <span class="math notranslate nohighlight">\(Q(\lambda)\)</span>.</p>
</dd>
<dt>v: (Tuple[Tensor, …], required): Values of the original parameters, </dt><dd><p>in an order that aligns with <cite>self.order</cite>, to multiply with 
<span class="math notranslate nohighlight">\(Q(\lambda)\)</span>.</p>
</dd>
</dl>
</dd>
<dt>Returns:</dt><dd><dl class="simple">
<dt>Tuple[Tensor, …]: The calculated <span class="math notranslate nohighlight">\(Q(\lambda)v\)</span> products, </dt><dd><p>in the same order as <cite>self.order</cite>.</p>
</dd>
</dl>
</dd>
</dl>
</dd></dl>

<dl class="py method">
<dt class="sig sig-object py" id="FishLeg.fishleg_layers.FishModule.cuda">
<span class="sig-name descname"><span class="pre">cuda</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">device</span></span></em><span class="sig-paren">)</span> <span class="sig-return"><span class="sig-return-icon">&#x2192;</span> <span class="sig-return-typehint"><span class="pre">None</span></span></span><a class="headerlink" href="#FishLeg.fishleg_layers.FishModule.cuda" title="Permalink to this definition"></a></dt>
<dd><p>Moves all model parameters and buffers to the GPU.</p>
<p>This also makes associated parameters and buffers different objects. So
it should be called before constructing optimizer if the module will
live on GPU while being optimized.</p>
<div class="admonition note">
<p class="admonition-title">Note</p>
<p>This method modifies the module in-place.</p>
</div>
<dl class="simple">
<dt>Args:</dt><dd><dl class="simple">
<dt>device (int, optional): if specified, all parameters will be</dt><dd><p>copied to that device</p>
</dd>
</dl>
</dd>
<dt>Returns:</dt><dd><p>Module: self</p>
</dd>
</dl>
</dd></dl>

<dl class="py property">
<dt class="sig sig-object py" id="FishLeg.fishleg_layers.FishModule.name">
<em class="property"><span class="pre">property</span><span class="w"> </span></em><span class="sig-name descname"><span class="pre">name</span></span><em class="property"><span class="p"><span class="pre">:</span></span><span class="w"> </span><span class="pre">str</span></em><a class="headerlink" href="#FishLeg.fishleg_layers.FishModule.name" title="Permalink to this definition"></a></dt>
<dd></dd></dl>

<dl class="py attribute">
<dt class="sig sig-object py" id="FishLeg.fishleg_layers.FishModule.training">
<span class="sig-name descname"><span class="pre">training</span></span><em class="property"><span class="p"><span class="pre">:</span></span><span class="w"> </span><span class="pre">bool</span></em><a class="headerlink" href="#FishLeg.fishleg_layers.FishModule.training" title="Permalink to this definition"></a></dt>
<dd></dd></dl>

</dd></dl>

<dl class="py function">
<dt class="sig sig-object py" id="FishLeg.fishleg_layers.get_zero_grad_hook">
<span class="sig-prename descclassname"><span class="pre">FishLeg.fishleg_layers.</span></span><span class="sig-name descname"><span class="pre">get_zero_grad_hook</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">mask</span></span></em><span class="sig-paren">)</span><a class="headerlink" href="#FishLeg.fishleg_layers.get_zero_grad_hook" title="Permalink to this definition"></a></dt>
<dd></dd></dl>

</section>
<section id="module-FishLeg.fishleg_likelihood">
<span id="fishleg-fishleg-likelihood-module"></span><h2>FishLeg.fishleg_likelihood module<a class="headerlink" href="#module-FishLeg.fishleg_likelihood" title="Permalink to this heading"></a></h2>
<dl class="py class">
<dt class="sig sig-object py" id="FishLeg.fishleg_likelihood.BernoulliLikelihood">
<em class="property"><span class="pre">class</span><span class="w"> </span></em><span class="sig-prename descclassname"><span class="pre">FishLeg.fishleg_likelihood.</span></span><span class="sig-name descname"><span class="pre">BernoulliLikelihood</span></span><a class="headerlink" href="#FishLeg.fishleg_likelihood.BernoulliLikelihood" title="Permalink to this definition"></a></dt>
<dd><p>Bases: <a class="reference internal" href="#FishLeg.fishleg_likelihood.FishLikelihood" title="FishLeg.fishleg_likelihood.FishLikelihood"><code class="xref py py-class docutils literal notranslate"><span class="pre">FishLikelihood</span></code></a></p>
<p>The Bernoulli likelihood used for classification. 
Using the standard Normal CDF <span class="math notranslate nohighlight">\(\Phi(x)\)</span> and the identity
<span class="math notranslate nohighlight">\(\Phi(-x) = 1-\Phi(x)\)</span>, we can write the likelihood as:</p>
<div class="math notranslate nohighlight">
\[p(y|f(x))=\Phi(yf(x))\]</div>
<dl class="py method">
<dt class="sig sig-object py" id="FishLeg.fishleg_likelihood.BernoulliLikelihood.draw">
<span class="sig-name descname"><span class="pre">draw</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">preds</span></span><span class="p"><span class="pre">:</span></span><span class="w"> </span><span class="n"><span class="pre">Tensor</span></span></em><span class="sig-paren">)</span> <span class="sig-return"><span class="sig-return-icon">&#x2192;</span> <span class="sig-return-typehint"><span class="pre">Tensor</span></span></span><a class="headerlink" href="#FishLeg.fishleg_likelihood.BernoulliLikelihood.draw" title="Permalink to this definition"></a></dt>
<dd><p>Draw samples from the conditional distribution 
<span class="math notranslate nohighlight">\(p(\mathbf y|f(\mathbf x))\)</span></p>
<dl class="field-list simple">
<dt class="field-odd">Parameters</dt>
<dd class="field-odd"><p><strong>preds</strong> (<em>torch.Tensor</em>) – Predictions from model <span class="math notranslate nohighlight">\(f(\mathbf x)\)</span></p>
</dd>
</dl>
</dd></dl>

<dl class="py method">
<dt class="sig sig-object py" id="FishLeg.fishleg_likelihood.BernoulliLikelihood.nll">
<span class="sig-name descname"><span class="pre">nll</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">observations</span></span><span class="p"><span class="pre">:</span></span><span class="w"> </span><span class="n"><span class="pre">Tensor</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">preds</span></span><span class="p"><span class="pre">:</span></span><span class="w"> </span><span class="n"><span class="pre">Tensor</span></span></em><span class="sig-paren">)</span> <span class="sig-return"><span class="sig-return-icon">&#x2192;</span> <span class="sig-return-typehint"><span class="pre">Tensor</span></span></span><a class="headerlink" href="#FishLeg.fishleg_likelihood.BernoulliLikelihood.nll" title="Permalink to this definition"></a></dt>
<dd><p>Computes the negative log-likelihood 
<span class="math notranslate nohighlight">\(\ell(\theta, \mathcal D)=-\log p(\mathbf y|f(\mathbf x))\)</span></p>
<dl class="field-list simple">
<dt class="field-odd">Parameters</dt>
<dd class="field-odd"><ul class="simple">
<li><p><strong>observations</strong> (<em>torch.Tensor</em>) – Values of <span class="math notranslate nohighlight">\(y\)</span>.</p></li>
<li><p><strong>preds</strong> (<em>torch.Tensor</em>) – Predictions from model <span class="math notranslate nohighlight">\(f(\mathbf x)\)</span></p></li>
</ul>
</dd>
<dt class="field-even">Return type</dt>
<dd class="field-even"><p><cite>torch.Tensor</cite></p>
</dd>
</dl>
</dd></dl>

</dd></dl>

<dl class="py class">
<dt class="sig sig-object py" id="FishLeg.fishleg_likelihood.FishLikelihood">
<em class="property"><span class="pre">class</span><span class="w"> </span></em><span class="sig-prename descclassname"><span class="pre">FishLeg.fishleg_likelihood.</span></span><span class="sig-name descname"><span class="pre">FishLikelihood</span></span><a class="headerlink" href="#FishLeg.fishleg_likelihood.FishLikelihood" title="Permalink to this definition"></a></dt>
<dd><p>Bases: <code class="xref py py-class docutils literal notranslate"><span class="pre">object</span></code></p>
<p>A Likelihood in FishLeg specifies a probabilistic model, which defines
the mapping from latent function values 
<span class="math notranslate nohighlight">\(f(\mathbf X)\)</span> to observed labels <span class="math notranslate nohighlight">\(y\)</span>.</p>
<p>For example, in the case of regression, 
a Gaussian likelihood can be chosen, as</p>
<div class="math notranslate nohighlight">
\[y(\mathbf x) = f(\mathbf x) + \epsilon, \:\:\:\: \epsilon \sim N(0,\sigma^{2}_{n} \mathbf I)\]</div>
<p>As for the case of classification, 
a Bernoulli distribution can be chosen</p>
<div class="math notranslate nohighlight">
\[\begin{split}y(\mathbf x) = \begin{cases}
    1 &amp; \text{w/ probability} \:\: \sigma(f(\mathbf x)) \\
    0 &amp; \text{w/ probability} \:\: 1-\sigma(f(\mathbf x))
\end{cases}\end{split}\]</div>
<dl class="py method">
<dt class="sig sig-object py" id="FishLeg.fishleg_likelihood.FishLikelihood.draw">
<em class="property"><span class="pre">abstract</span><span class="w"> </span></em><span class="sig-name descname"><span class="pre">draw</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">preds</span></span></em>, <em class="sig-param"><span class="o"><span class="pre">**</span></span><span class="n"><span class="pre">kwargs</span></span></em><span class="sig-paren">)</span><a class="headerlink" href="#FishLeg.fishleg_likelihood.FishLikelihood.draw" title="Permalink to this definition"></a></dt>
<dd><p>Draw samples from the conditional distribution 
<span class="math notranslate nohighlight">\(p(\mathbf y|f(\mathbf x))\)</span></p>
<dl class="field-list simple">
<dt class="field-odd">Parameters</dt>
<dd class="field-odd"><p><strong>preds</strong> (<em>torch.Tensor</em>) – Predictions from model <span class="math notranslate nohighlight">\(f(\mathbf x)\)</span></p>
</dd>
</dl>
</dd></dl>

<dl class="py method">
<dt class="sig sig-object py" id="FishLeg.fishleg_likelihood.FishLikelihood.nll">
<em class="property"><span class="pre">abstract</span><span class="w"> </span></em><span class="sig-name descname"><span class="pre">nll</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">observations</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">preds</span></span></em>, <em class="sig-param"><span class="o"><span class="pre">**</span></span><span class="n"><span class="pre">kwargs</span></span></em><span class="sig-paren">)</span><a class="headerlink" href="#FishLeg.fishleg_likelihood.FishLikelihood.nll" title="Permalink to this definition"></a></dt>
<dd><p>Computes the negative log-likelihood 
<span class="math notranslate nohighlight">\(\ell(\theta, \mathcal D)=-\log p(\mathbf y|f(\mathbf x))\)</span></p>
<dl class="field-list simple">
<dt class="field-odd">Parameters</dt>
<dd class="field-odd"><ul class="simple">
<li><p><strong>observations</strong> (<em>torch.Tensor</em>) – Values of <span class="math notranslate nohighlight">\(y\)</span>.</p></li>
<li><p><strong>preds</strong> (<em>torch.Tensor</em>) – Predictions from model <span class="math notranslate nohighlight">\(f(\mathbf x)\)</span></p></li>
</ul>
</dd>
<dt class="field-even">Return type</dt>
<dd class="field-even"><p><cite>torch.Tensor</cite></p>
</dd>
</dl>
</dd></dl>

</dd></dl>

<dl class="py class">
<dt class="sig sig-object py" id="FishLeg.fishleg_likelihood.FixedGaussianLikelihood">
<em class="property"><span class="pre">class</span><span class="w"> </span></em><span class="sig-prename descclassname"><span class="pre">FishLeg.fishleg_likelihood.</span></span><span class="sig-name descname"><span class="pre">FixedGaussianLikelihood</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">sigma</span></span></em><span class="sig-paren">)</span><a class="headerlink" href="#FishLeg.fishleg_likelihood.FixedGaussianLikelihood" title="Permalink to this definition"></a></dt>
<dd><p>Bases: <a class="reference internal" href="#FishLeg.fishleg_likelihood.FishLikelihood" title="FishLeg.fishleg_likelihood.FishLikelihood"><code class="xref py py-class docutils literal notranslate"><span class="pre">FishLikelihood</span></code></a></p>
<p>The standard likelihood for regression, 
but assuming fixed heteroscedastic noise.</p>
<div class="math notranslate nohighlight">
\[y = f(x) + \epsilon, \:\:\:\: \epsilon \sim N(0,\sigma^{2})\]</div>
<dl class="field-list simple">
<dt class="field-odd">Parameters</dt>
<dd class="field-odd"><p><strong>sigma</strong> (<em>torch.Tensor</em>) – Known observation 
standard deviation for each example.</p>
</dd>
</dl>
<dl class="py method">
<dt class="sig sig-object py" id="FishLeg.fishleg_likelihood.FixedGaussianLikelihood.draw">
<span class="sig-name descname"><span class="pre">draw</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">preds</span></span><span class="p"><span class="pre">:</span></span><span class="w"> </span><span class="n"><span class="pre">Tensor</span></span></em><span class="sig-paren">)</span> <span class="sig-return"><span class="sig-return-icon">&#x2192;</span> <span class="sig-return-typehint"><span class="pre">Tensor</span></span></span><a class="headerlink" href="#FishLeg.fishleg_likelihood.FixedGaussianLikelihood.draw" title="Permalink to this definition"></a></dt>
<dd><p>Draw samples from the conditional distribution 
<span class="math notranslate nohighlight">\(p(\mathbf y|f(\mathbf x))\)</span></p>
<dl class="field-list simple">
<dt class="field-odd">Parameters</dt>
<dd class="field-odd"><p><strong>preds</strong> (<em>torch.Tensor</em>) – Predictions from model <span class="math notranslate nohighlight">\(f(\mathbf x)\)</span></p>
</dd>
</dl>
</dd></dl>

<dl class="py property">
<dt class="sig sig-object py" id="FishLeg.fishleg_likelihood.FixedGaussianLikelihood.get_variance">
<em class="property"><span class="pre">property</span><span class="w"> </span></em><span class="sig-name descname"><span class="pre">get_variance</span></span><a class="headerlink" href="#FishLeg.fishleg_likelihood.FixedGaussianLikelihood.get_variance" title="Permalink to this definition"></a></dt>
<dd></dd></dl>

<dl class="py method">
<dt class="sig sig-object py" id="FishLeg.fishleg_likelihood.FixedGaussianLikelihood.nll">
<span class="sig-name descname"><span class="pre">nll</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">observations</span></span><span class="p"><span class="pre">:</span></span><span class="w"> </span><span class="n"><span class="pre">Tensor</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">preds</span></span><span class="p"><span class="pre">:</span></span><span class="w"> </span><span class="n"><span class="pre">Tensor</span></span></em><span class="sig-paren">)</span> <span class="sig-return"><span class="sig-return-icon">&#x2192;</span> <span class="sig-return-typehint"><span class="pre">Tensor</span></span></span><a class="headerlink" href="#FishLeg.fishleg_likelihood.FixedGaussianLikelihood.nll" title="Permalink to this definition"></a></dt>
<dd><p>Computes the negative log-likelihood 
<span class="math notranslate nohighlight">\(\ell(\theta, \mathcal D)=-\log p(\mathbf y|f(\mathbf x))\)</span></p>
<dl class="field-list simple">
<dt class="field-odd">Parameters</dt>
<dd class="field-odd"><ul class="simple">
<li><p><strong>observations</strong> (<em>torch.Tensor</em>) – Values of <span class="math notranslate nohighlight">\(y\)</span>.</p></li>
<li><p><strong>preds</strong> (<em>torch.Tensor</em>) – Predictions from model <span class="math notranslate nohighlight">\(f(\mathbf x)\)</span></p></li>
</ul>
</dd>
<dt class="field-even">Return type</dt>
<dd class="field-even"><p><cite>torch.Tensor</cite></p>
</dd>
</dl>
</dd></dl>

</dd></dl>

<dl class="py class">
<dt class="sig sig-object py" id="FishLeg.fishleg_likelihood.SoftMaxLikelihood">
<em class="property"><span class="pre">class</span><span class="w"> </span></em><span class="sig-prename descclassname"><span class="pre">FishLeg.fishleg_likelihood.</span></span><span class="sig-name descname"><span class="pre">SoftMaxLikelihood</span></span><a class="headerlink" href="#FishLeg.fishleg_likelihood.SoftMaxLikelihood" title="Permalink to this definition"></a></dt>
<dd><p>Bases: <a class="reference internal" href="#FishLeg.fishleg_likelihood.FishLikelihood" title="FishLeg.fishleg_likelihood.FishLikelihood"><code class="xref py py-class docutils literal notranslate"><span class="pre">FishLikelihood</span></code></a></p>
<dl class="py method">
<dt class="sig sig-object py" id="FishLeg.fishleg_likelihood.SoftMaxLikelihood.draw">
<span class="sig-name descname"><span class="pre">draw</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">preds</span></span><span class="p"><span class="pre">:</span></span><span class="w"> </span><span class="n"><span class="pre">Tensor</span></span></em><span class="sig-paren">)</span> <span class="sig-return"><span class="sig-return-icon">&#x2192;</span> <span class="sig-return-typehint"><span class="pre">Tensor</span></span></span><a class="headerlink" href="#FishLeg.fishleg_likelihood.SoftMaxLikelihood.draw" title="Permalink to this definition"></a></dt>
<dd><p>Draw samples from the conditional distribution 
<span class="math notranslate nohighlight">\(p(\mathbf y|f(\mathbf x))\)</span></p>
<dl class="field-list simple">
<dt class="field-odd">Parameters</dt>
<dd class="field-odd"><p><strong>preds</strong> (<em>torch.Tensor</em>) – Predictions from model <span class="math notranslate nohighlight">\(f(\mathbf x)\)</span></p>
</dd>
</dl>
</dd></dl>

<dl class="py method">
<dt class="sig sig-object py" id="FishLeg.fishleg_likelihood.SoftMaxLikelihood.nll">
<span class="sig-name descname"><span class="pre">nll</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">observations</span></span><span class="p"><span class="pre">:</span></span><span class="w"> </span><span class="n"><span class="pre">Tensor</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">preds</span></span><span class="p"><span class="pre">:</span></span><span class="w"> </span><span class="n"><span class="pre">Tensor</span></span></em><span class="sig-paren">)</span> <span class="sig-return"><span class="sig-return-icon">&#x2192;</span> <span class="sig-return-typehint"><span class="pre">Tensor</span></span></span><a class="headerlink" href="#FishLeg.fishleg_likelihood.SoftMaxLikelihood.nll" title="Permalink to this definition"></a></dt>
<dd><p>Computes the negative log-likelihood 
<span class="math notranslate nohighlight">\(\ell(\theta, \mathcal D)=-\log p(\mathbf y|f(\mathbf x))\)</span></p>
<dl class="field-list simple">
<dt class="field-odd">Parameters</dt>
<dd class="field-odd"><ul class="simple">
<li><p><strong>observations</strong> (<em>torch.Tensor</em>) – Values of <span class="math notranslate nohighlight">\(y\)</span>.</p></li>
<li><p><strong>preds</strong> (<em>torch.Tensor</em>) – Predictions from model <span class="math notranslate nohighlight">\(f(\mathbf x)\)</span></p></li>
</ul>
</dd>
<dt class="field-even">Return type</dt>
<dd class="field-even"><p><cite>torch.Tensor</cite></p>
</dd>
</dl>
</dd></dl>

</dd></dl>

</section>
<section id="module-FishLeg.utils">
<span id="fishleg-utils-module"></span><h2>FishLeg.utils module<a class="headerlink" href="#module-FishLeg.utils" title="Permalink to this heading"></a></h2>
<dl class="py function">
<dt class="sig sig-object py" id="FishLeg.utils.recursive_getattr">
<span class="sig-prename descclassname"><span class="pre">FishLeg.utils.</span></span><span class="sig-name descname"><span class="pre">recursive_getattr</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">obj</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">attr</span></span></em><span class="sig-paren">)</span><a class="headerlink" href="#FishLeg.utils.recursive_getattr" title="Permalink to this definition"></a></dt>
<dd></dd></dl>

<dl class="py function">
<dt class="sig sig-object py" id="FishLeg.utils.recursive_setattr">
<span class="sig-prename descclassname"><span class="pre">FishLeg.utils.</span></span><span class="sig-name descname"><span class="pre">recursive_setattr</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">obj</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">attr</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">value</span></span></em><span class="sig-paren">)</span><a class="headerlink" href="#FishLeg.utils.recursive_setattr" title="Permalink to this definition"></a></dt>
<dd></dd></dl>

</section>
<section id="module-FishLeg">
<span id="module-contents"></span><h2>Module contents<a class="headerlink" href="#module-FishLeg" title="Permalink to this heading"></a></h2>
</section>
</section>


           </div>
          </div>
          <footer><div class="rst-footer-buttons" role="navigation" aria-label="Footer">
        <a href="modules.html" class="btn btn-neutral float-left" title="optim" accesskey="p" rel="prev"><span class="fa fa-arrow-circle-left" aria-hidden="true"></span> Previous</a>
    </div>

  <hr/>

  <div role="contentinfo">
    <p>&#169; Copyright 2023, MTK.</p>
  </div>

  Built with <a href="https://www.sphinx-doc.org/">Sphinx</a> using a
    <a href="https://github.com/readthedocs/sphinx_rtd_theme">theme</a>
    provided by <a href="https://readthedocs.org">Read the Docs</a>.
   

</footer>
        </div>
      </div>
    </section>
  </div>
  <script>
      // Once the DOM is ready, turn on the theme's navigation behaviour.
      // NOTE(review): SphinxRtdTheme is presumably provided by the theme
      // script loaded in <head>; the meaning of the `true` flag is defined
      // there and is not visible in this page — confirm against that script.
      jQuery(document).ready(function () {
          var navigation = SphinxRtdTheme.Navigation;
          navigation.enable(true);
      });
  </script> 

</body>
</html>