<!doctype html>
<html lang="en">
<head>
<meta charset="utf-8">
<meta name="viewport" content="width=device-width, initial-scale=1, minimum-scale=1" />
<meta name="generator" content="pdoc 0.9.2" />
<title>laplace API documentation</title>
<meta name="description" content="The laplace package facilitates the application of Laplace approximations for entire neural networks or just their last layer." />
<link rel="preload stylesheet" as="style" href="https://cdnjs.cloudflare.com/ajax/libs/10up-sanitize.css/11.0.1/sanitize.min.css" integrity="sha256-PK9q560IAAa6WVRRh76LtCaI8pjTJ2z11v0miyNNjrs=" crossorigin>
<link rel="preload stylesheet" as="style" href="https://cdnjs.cloudflare.com/ajax/libs/10up-sanitize.css/11.0.1/typography.min.css" integrity="sha256-7l/o7C8jubJiy74VsKTidCy1yBkRtiUGbVkYBylBqUg=" crossorigin>
<link rel="stylesheet preload" as="style" href="https://cdnjs.cloudflare.com/ajax/libs/highlight.js/10.1.1/styles/github.min.css" crossorigin>
<style>:root{--highlight-color:#fe9}.flex{display:flex !important}body{line-height:1.5em}#content{padding:20px}#sidebar{padding:30px;overflow:hidden}#sidebar > *:last-child{margin-bottom:2cm}.http-server-breadcrumbs{font-size:130%;margin:0 0 15px 0}#footer{font-size:.75em;padding:5px 30px;border-top:1px solid #ddd;text-align:right}#footer p{margin:0 0 0 1em;display:inline-block}#footer p:last-child{margin-right:30px}h1,h2,h3,h4,h5{font-weight:300}h1{font-size:2.5em;line-height:1.1em}h2{font-size:1.75em;margin:1em 0 .50em 0}h3{font-size:1.4em;margin:25px 0 10px 0}h4{margin:0;font-size:105%}h1:target,h2:target,h3:target,h4:target,h5:target,h6:target{background:var(--highlight-color);padding:.2em 0}a{color:#058;text-decoration:none;transition:color .3s ease-in-out}a:hover{color:#e82}.title code{font-weight:bold}h2[id^="header-"]{margin-top:2em}.ident{color:#900}pre code{background:#f8f8f8;font-size:.8em;line-height:1.4em}code{background:#f2f2f1;padding:1px 4px;overflow-wrap:break-word}h1 code{background:transparent}pre{background:#f8f8f8;border:0;border-top:1px solid #ccc;border-bottom:1px solid #ccc;margin:1em 0;padding:1ex}#http-server-module-list{display:flex;flex-flow:column}#http-server-module-list div{display:flex}#http-server-module-list dt{min-width:10%}#http-server-module-list p{margin-top:0}.toc ul,#index{list-style-type:none;margin:0;padding:0}#index code{background:transparent}#index h3{border-bottom:1px solid #ddd}#index ul{padding:0}#index h4{margin-top:.6em;font-weight:bold}@media (min-width:200ex){#index .two-column{column-count:2}}@media (min-width:300ex){#index .two-column{column-count:3}}dl{margin-bottom:2em}dl dl:last-child{margin-bottom:4em}dd{margin:0 0 1em 3em}#header-classes + dl > dd{margin-bottom:3em}dd dd{margin-left:2em}dd p{margin:10px 0}.name{background:#eee;font-weight:bold;font-size:.85em;padding:5px 10px;display:inline-block;min-width:40%}.name:hover{background:#e0e0e0}dt:target .name{background:var(--highlight-color)}.name > 
span:first-child{white-space:nowrap}.name.class > span:nth-child(2){margin-left:.4em}.inherited{color:#999;border-left:5px solid #eee;padding-left:1em}.inheritance em{font-style:normal;font-weight:bold}.desc h2{font-weight:400;font-size:1.25em}.desc h3{font-size:1em}.desc dt code{background:inherit}.source summary,.git-link-div{color:#666;text-align:right;font-weight:400;font-size:.8em;text-transform:uppercase}.source summary > *{white-space:nowrap;cursor:pointer}.git-link{color:inherit;margin-left:1em}.source pre{max-height:500px;overflow:auto;margin:0}.source pre code{font-size:12px;overflow:visible}.hlist{list-style:none}.hlist li{display:inline}.hlist li:after{content:',\2002'}.hlist li:last-child:after{content:none}.hlist .hlist{display:inline;padding-left:1em}img{max-width:100%}td{padding:0 .5em}.admonition{padding:.1em .5em;margin-bottom:1em}.admonition-title{font-weight:bold}.admonition.note,.admonition.info,.admonition.important{background:#aef}.admonition.todo,.admonition.versionadded,.admonition.tip,.admonition.hint{background:#dfd}.admonition.warning,.admonition.versionchanged,.admonition.deprecated{background:#fd4}.admonition.error,.admonition.danger,.admonition.caution{background:lightpink}</style>
<style media="screen and (min-width: 700px)">@media screen and (min-width:700px){#sidebar{width:30%;height:100vh;overflow:auto;position:sticky;top:0}#content{width:70%;max-width:100ch;padding:3em 4em;border-left:1px solid #ddd}pre code{font-size:1em}.item .name{font-size:1em}main{display:flex;flex-direction:row-reverse;justify-content:flex-end}.toc ul ul,#index ul{padding-left:1.5em}.toc > ul > li{margin-top:.5em}}</style>
<style media="print">@media print{#sidebar h1{page-break-before:always}.source{display:none}}@media print{*{background:transparent !important;color:#000 !important;box-shadow:none !important;text-shadow:none !important}a[href]:after{content:" (" attr(href) ")";font-size:90%}a[href][title]:after{content:none}abbr[title]:after{content:" (" attr(title) ")"}.ir a:after,a[href^="javascript:"]:after,a[href^="#"]:after{content:""}pre,blockquote{border:1px solid #999;page-break-inside:avoid}thead{display:table-header-group}tr,img{page-break-inside:avoid}img{max-width:100% !important}@page{margin:0.5cm}p,h2,h3{orphans:3;widows:3}h1,h2,h3,h4,h5,h6{page-break-after:avoid}}</style>
<script async src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.7/latest.js?config=TeX-AMS_CHTML" integrity="sha256-kZafAc6mZvK3W3v1pHOcUix30OHQN6pU/NO2oFkqZVw=" crossorigin></script>
<script defer src="https://cdnjs.cloudflare.com/ajax/libs/highlight.js/10.1.1/highlight.min.js" integrity="sha256-Uv3H6lx7dJmRfRvH8TH6kJD1TSK1aFcwgx+mdg3epi8=" crossorigin></script>
<script>window.addEventListener('DOMContentLoaded', () => hljs.initHighlighting())</script>
</head>
<body>
<main>
<article id="content">
<header>
<h1 class="title">Package <code>laplace</code></h1>
</header>
<section id="section-intro">
<div align="center">
<img src="https://raw.githubusercontent.com/AlexImmer/Laplace/main/logo/laplace_logo.png" alt="Laplace" width="300"/>
</div>
<p><a href="https://travis-ci.com/AlexImmer/Laplace"><img alt="Main" src="https://travis-ci.com/AlexImmer/Laplace.svg?token=rpuRxEjQS6cCZi7ptL9y&amp;branch=main"></a></p>
<p>The laplace package facilitates the application of Laplace approximations for entire neural networks or just their last layer.
The package enables posterior approximations, marginal-likelihood estimation, and various posterior predictive computations.
The library documentation is available at <a href="https://aleximmer.github.io/Laplace">https://aleximmer.github.io/Laplace</a>.</p>
<p>There is also a corresponding paper, <a href="https://arxiv.org/abs/2106.14806"><em>Laplace Redux — Effortless Bayesian Deep Learning</em></a>, which introduces the library, provides an introduction to the Laplace approximation, reviews its use in deep learning, and empirically demonstrates its versatility and competitiveness. Please consider referring to the paper when using our library:</p>
<pre><code class="language-bibtex">@article{daxberger2021laplace,
  title={Laplace Redux--Effortless Bayesian Deep Learning},
  author={Daxberger, Erik and Kristiadi, Agustinus and Immer, Alexander
          and Eschenhagen, Runa and Bauer, Matthias and Hennig, Philipp},
  journal={arXiv preprint arXiv:2106.14806},
  year={2021}
}
</code></pre>
<h2 id="setup">Setup</h2>
<p>We assume <code>python3.8</code> since the package was developed with that version.
To install laplace with <code>pip</code>, run the following:</p>
<pre><code class="language-bash">pip install laplace-torch
</code></pre>
<p>For development purposes, clone the repository and then install:</p>
<pre><code class="language-bash"># or after cloning the repository for development
pip install -e .
# run tests
pip install -e .[tests]
pytest tests/
</code></pre>
<h2 id="structure">Structure</h2>
<p>The laplace package consists of two main components:</p>
<ol>
<li>The subclasses of <a href="https://github.com/AlexImmer/Laplace/blob/main/laplace/baselaplace.py"><code>laplace.BaseLaplace</code></a> that implement different sparsity structures: different subsets of weights (<code>'all'</code> and <code>'last_layer'</code>) and different structures of the Hessian approximation (<code>'full'</code>, <code>'kron'</code>, and <code>'diag'</code>). This results in six currently available options: <code><a title="laplace.FullLaplace" href="#laplace.FullLaplace">FullLaplace</a></code>, <code><a title="laplace.KronLaplace" href="#laplace.KronLaplace">KronLaplace</a></code>, <code><a title="laplace.DiagLaplace" href="#laplace.DiagLaplace">DiagLaplace</a></code>, and the corresponding last-layer variations <code><a title="laplace.FullLLLaplace" href="#laplace.FullLLLaplace">FullLLLaplace</a></code>, <code><a title="laplace.KronLLLaplace" href="#laplace.KronLLLaplace">KronLLLaplace</a></code>,
and <code><a title="laplace.DiagLLLaplace" href="#laplace.DiagLLLaplace">DiagLLLaplace</a></code>, which are all subclasses of <a href="https://github.com/AlexImmer/Laplace/blob/main/laplace/lllaplace.py"><code>laplace.LLLaplace</code></a>. All of these can be conveniently accessed via the <a href="https://github.com/AlexImmer/Laplace/blob/main/laplace/laplace.py"><code>laplace.Laplace</code></a> function.</li>
<li>The backends in <a href="https://github.com/AlexImmer/Laplace/blob/main/laplace/curvature/"><code>laplace.curvature</code></a> which provide access to Hessian approximations of
the corresponding sparsity structures, for example, the diagonal GGN.</li>
</ol>
<p>Additionally, the package provides utilities for
decomposing a neural network into feature extractor and last layer for <code><a title="laplace.LLLaplace" href="#laplace.LLLaplace">LLLaplace</a></code> subclasses (<a href="https://github.com/AlexImmer/Laplace/blob/main/laplace/feature_extractor.py"><code>laplace.feature_extractor</code></a>)
and
effectively dealing with Kronecker factors (<a href="https://github.com/AlexImmer/Laplace/blob/main/laplace/matrix.py"><code>laplace.matrix</code></a>).</p>
<h2 id="extendability">Extendability</h2>
<p>To extend the laplace package, new <code><a title="laplace.BaseLaplace" href="#laplace.BaseLaplace">BaseLaplace</a></code> subclasses can be designed, for example,
a block-diagonal structure or subset-of-weights Laplace.
Alternatively, extending or integrating backends (subclasses of <a href="https://github.com/AlexImmer/Laplace/blob/main/laplace/curvature/curvature.py"><code>curvature.CurvatureInterface</code></a>) makes it possible to provide different Hessian
approximations to the Laplace approximations.
For example, currently the <a href="https://github.com/AlexImmer/Laplace/blob/main/laplace/curvature/backpack.py"><code>curvature.BackPackInterface</code></a> based on <a href="https://github.com/f-dangel/backpack/">BackPACK</a> and <a href="https://github.com/AlexImmer/Laplace/blob/main/laplace/curvature/asdl.py"><code>curvature.AsdlInterface</code></a> based on <a href="https://github.com/kazukiosawa/asdfghjkl">ASDL</a> are available.
The <code><a title="laplace.curvature.AsdlInterface" href="curvature/index.html#laplace.curvature.AsdlInterface">AsdlInterface</a></code> provides a Kronecker factored empirical Fisher while the <code><a title="laplace.curvature.BackPackInterface" href="curvature/index.html#laplace.curvature.BackPackInterface">BackPackInterface</a></code>
does not, and only the <code><a title="laplace.curvature.BackPackInterface" href="curvature/index.html#laplace.curvature.BackPackInterface">BackPackInterface</a></code> provides access to Hessian approximations
for a regression (MSELoss) loss function.</p>
<h2 id="example-usage">Example usage</h2>
<h3 id="post-hoc-prior-precision-tuning-of-last-layer-la"><em>Post-hoc</em> prior precision tuning of last-layer LA</h3>
<p>In the following example, a pre-trained model is loaded,
then the Laplace approximation is fit to the training data,
and the prior precision is optimized with cross-validation <code>'CV'</code>.
After that, the resulting LA is used for prediction with
the <code>'probit'</code> predictive for classification.</p>
<pre><code class="language-python">from laplace import Laplace

# pre-trained model
model = load_map_model()  

# User-specified LA flavor
la = Laplace(model, 'classification',
             subset_of_weights='all',
             hessian_structure='diag')
la.fit(train_loader)
la.optimize_prior_precision(method='CV', val_loader=val_loader)

# User-specified predictive approx.
pred = la(x, link_approx='probit')
</code></pre>
<h3 id="differentiating-the-log-marginal-likelihood-wrt-hyperparameters">Differentiating the log marginal likelihood w.r.t. hyperparameters</h3>
<p>The marginal likelihood can be used for model selection and is differentiable
for continuous hyperparameters like the prior precision or observation noise.
Here, we fit the library default, KFAC last-layer LA and differentiate
the log marginal likelihood.</p>
<pre><code class="language-python">from laplace import Laplace

# Un- or pre-trained model
model = load_model()  

# Default to recommended last-layer KFAC LA:
la = Laplace(model, likelihood='regression')
la.fit(train_loader)

# ML w.r.t. prior precision and observation noise
ml = la.log_marginal_likelihood(prior_prec, obs_noise)
ml.backward()
</code></pre>
<h2 id="documentation">Documentation</h2>
<p>The documentation is available <a href="https://aleximmer.github.io/Laplace">here</a> or can be generated and/or viewed locally:</p>
<pre><code class="language-bash"># assuming the repository was cloned
pip install -e .[docs]
# create docs and write to html
bash update_docs.sh
# .. or serve the docs directly
pdoc --http 0.0.0.0:8080 laplace --template-dir template
</code></pre>
<h2 id="references">References</h2>
<p>This package relies on various improvements to the Laplace approximation for neural networks, which was originally due to MacKay [1].</p>
<ul>
<li>[1] MacKay, DJC. <a href="https://authors.library.caltech.edu/13793/"><em>A Practical Bayesian Framework for Backpropagation Networks</em></a>. Neural Computation 1992.</li>
<li>[2] Gibbs, M. N. <a href="https://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.147.1130&amp;rep=rep1&amp;type=pdf"><em>Bayesian Gaussian Processes for Regression and Classification</em></a>. PhD Thesis 1997.</li>
<li>[3] Snoek, J., Rippel, O., Swersky, K., Kiros, R., Satish, N., Sundaram, N., Patwary, M., Prabhat, M., Adams, R. <a href="https://arxiv.org/abs/1502.05700"><em>Scalable Bayesian Optimization Using Deep Neural Networks</em></a>. ICML 2015.</li>
<li>[4] Ritter, H., Botev, A., Barber, D. <a href="https://openreview.net/forum?id=Skdvd2xAZ"><em>A Scalable Laplace Approximation for Neural Networks</em></a>. ICLR 2018.</li>
<li>[5] Foong, A. Y., Li, Y., Hernández-Lobato, J. M., Turner, R. E. <a href="https://arxiv.org/abs/1906.11537"><em>'In-Between' Uncertainty in Bayesian Neural Networks</em></a>. ICML UDL Workshop 2019.</li>
<li>[6] Khan, M. E., Immer, A., Abedi, E., Korzepa, M. <a href="https://arxiv.org/abs/1906.01930"><em>Approximate Inference Turns Deep Networks into Gaussian Processes</em></a>. NeurIPS 2019.</li>
<li>[7] Kristiadi, A., Hein, M., Hennig, P. <a href="https://arxiv.org/abs/2002.10118"><em>Being Bayesian, Even Just a Bit, Fixes Overconfidence in ReLU Networks</em></a>. ICML 2020.</li>
<li>[8] Immer, A., Korzepa, M., Bauer, M. <a href="https://arxiv.org/abs/2008.08400"><em>Improving predictions of Bayesian neural nets via local linearization</em></a>. AISTATS 2021.</li>
<li>[9] Immer, A., Bauer, M., Fortuin, V., Rätsch, G., Khan, EM. <a href="https://arxiv.org/abs/2104.04975"><em>Scalable Marginal Likelihood Estimation for Model Selection in Deep Learning</em></a>. ICML 2021.</li>
</ul>
<h2 id="full-example-post-hoc-optimization-of-the-marginal-likelihood-and-prediction">Full example: <em>post-hoc</em> optimization of the marginal likelihood and prediction</h2>
<h4 id="sinusoidal-toy-data">Sinusoidal toy data</h4>
<p>We show how the marginal likelihood can be used after training a MAP network on a simple sinusoidal regression task.
Subsequently, we use the optimized LA for prediction, which provides uncertainty on top of the MAP prediction.
First, we set up the training data for the problem with observation noise \(\sigma=0.3\):</p>
<pre><code class="language-python">import numpy as np
import torch
from torch.utils.data import DataLoader, TensorDataset

from laplace import Laplace

n_epochs = 1000
batch_size = 150  # full batch
true_sigma_noise = 0.3

# create simple sinusoid data set
X_train = (torch.rand(150) * 8).unsqueeze(-1)
y_train = torch.sin(X_train) + torch.randn_like(X_train) * true_sigma_noise
train_loader = DataLoader(TensorDataset(X_train, y_train), batch_size=batch_size)
X_test = torch.linspace(-5, 13, 500).unsqueeze(-1)  # +-5 on top of the training X-range
</code></pre>
<h4 id="training-a-map">Training a MAP</h4>
<p>We now use <code>pytorch</code> to train a neural network with a single hidden layer and Tanh activation.
This is standard so nothing new here, yet:</p>
<pre><code class="language-python"># create and train MAP model
model = torch.nn.Sequential(torch.nn.Linear(1, 50),
                            torch.nn.Tanh(),
                            torch.nn.Linear(50, 1))

criterion = torch.nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), weight_decay=5e-4, lr=1e-2)
for i in range(n_epochs):
    for X, y in train_loader:
        optimizer.zero_grad()
        loss = criterion(model(X), y)
        loss.backward()
        optimizer.step()
</code></pre>
<h4 id="fitting-and-optimizing-the-laplace-approximation-using-empirical-bayes">Fitting and optimizing the Laplace approximation using empirical Bayes</h4>
<p>With the MAP-trained model at hand, we can estimate the prior precision and observation noise
using empirical Bayes after training.
The <code><a title="laplace.Laplace" href="#laplace.Laplace">Laplace()</a></code> function is called to construct a LA for <code>'regression'</code> with <code>'all'</code> weights.
By default, <code><a title="laplace.Laplace" href="#laplace.Laplace">Laplace()</a></code> returns a Kronecker-factored LA; in this example we explicitly request the <code>'full'</code> Hessian structure instead.
We fit the LA to the training data and initialize <code>log_prior</code> and <code>log_sigma</code>.
Using Adam, we minimize the negative log marginal likelihood for <code>n_epochs</code>.</p>
<pre><code class="language-python">la = Laplace(model, 'regression', subset_of_weights='all', hessian_structure='full')
la.fit(train_loader)
log_prior, log_sigma = torch.ones(1, requires_grad=True), torch.ones(1, requires_grad=True)
hyper_optimizer = torch.optim.Adam([log_prior, log_sigma], lr=1e-1)
for i in range(n_epochs):
    hyper_optimizer.zero_grad()
    neg_marglik = - la.log_marginal_likelihood(log_prior.exp(), log_sigma.exp())
    neg_marglik.backward()
    hyper_optimizer.step()
</code></pre>
<p>The obtained observation noise is close to the ground truth with a value of \(\sigma \approx 0.28\)
without the need for any validation data.
The resulting prior precision is \(\delta \approx 0.18\).</p>
<h4 id="bayesian-predictive">Bayesian predictive</h4>
<p>Lastly, we compare the MAP prediction to the obtained LA prediction.
For LA, we have a closed-form predictive distribution on the output \(f\) which is a Gaussian
\(\mathcal{N}(f(x;\theta_{MAP}), \mathbb{V}[f] + \sigma^2)\):</p>
<pre><code class="language-python">x = X_test.flatten().cpu().numpy()
f_mu, f_var = la(X_test)
f_mu = f_mu.squeeze().detach().cpu().numpy()
f_sigma = f_var.squeeze().sqrt().cpu().numpy()
pred_std = np.sqrt(f_sigma**2 + la.sigma_noise.item()**2)
</code></pre>
<p>In comparison to the MAP, the predictive shows useful uncertainties:</p>
<pre><code class="language-python">import matplotlib.pyplot as plt
fig, (ax1, ax2) = plt.subplots(nrows=1, ncols=2, sharey=True,
                               figsize=(4.5, 2.8))
ax1.set_title('MAP')
ax1.scatter(X_train.flatten(), y_train.flatten(), alpha=0.7, color='tab:orange')
ax1.plot(x, f_mu, color='black', label='$f_{MAP}$')
ax1.legend()

ax2.set_title('LA')
ax2.scatter(X_train.flatten(), y_train.flatten(), alpha=0.7, color='tab:orange')
ax2.plot(x, f_mu, label='$\mathbb{E}[f]$')
ax2.fill_between(x, f_mu-pred_std*2, f_mu+pred_std*2, 
                 alpha=0.3, color='tab:blue', label='$2\sqrt{\mathbb{V}\,[f]}$')
ax2.legend()
ax1.set_ylim([-4, 6])
ax1.set_xlim([x.min(), x.max()])
ax2.set_xlim([x.min(), x.max()])
ax1.set_ylabel('$y$')
ax1.set_xlabel('$x$')
ax2.set_xlabel('$x$')
plt.tight_layout()
plt.show()
</code></pre>
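<p><img alt="Side-by-side plots of the MAP prediction and the LA predictive mean with its uncertainty band on the sinusoidal toy data" src="regression_example.png"></p>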
</section>
<section>
<h2 class="section-title" id="header-submodules">Sub-modules</h2>
<dl>
<dt><code class="name"><a title="laplace.baselaplace" href="baselaplace.html">laplace.baselaplace</a></code></dt>
<dd>
<div class="desc"></div>
</dd>
<dt><code class="name"><a title="laplace.curvature" href="curvature/index.html">laplace.curvature</a></code></dt>
<dd>
<div class="desc"></div>
</dd>
<dt><code class="name"><a title="laplace.feature_extractor" href="feature_extractor.html">laplace.feature_extractor</a></code></dt>
<dd>
<div class="desc"></div>
</dd>
<dt><code class="name"><a title="laplace.laplace" href="laplace.html">laplace.laplace</a></code></dt>
<dd>
<div class="desc"></div>
</dd>
<dt><code class="name"><a title="laplace.lllaplace" href="lllaplace.html">laplace.lllaplace</a></code></dt>
<dd>
<div class="desc"></div>
</dd>
<dt><code class="name"><a title="laplace.matrix" href="matrix.html">laplace.matrix</a></code></dt>
<dd>
<div class="desc"></div>
</dd>
<dt><code class="name"><a title="laplace.utils" href="utils.html">laplace.utils</a></code></dt>
<dd>
<div class="desc"></div>
</dd>
</dl>
</section>
<section>
</section>
<section>
<h2 class="section-title" id="header-functions">Functions</h2>
<dl>
<dt id="laplace.Laplace"><code class="name flex">
<span>def <span class="ident">Laplace</span></span>(<span>model, likelihood, subset_of_weights='last_layer', hessian_structure='kron', *args, **kwargs)</span>
</code></dt>
<dd>
<div class="desc"><p>Simplified Laplace access using strings instead of different classes.</p>
<h2 id="parameters">Parameters</h2>
<dl>
<dt><strong><code>model</code></strong> :&ensp;<code>torch.nn.Module</code></dt>
<dd>&nbsp;</dd>
<dt><strong><code>likelihood</code></strong> :&ensp;<code>{'classification', 'regression'}</code></dt>
<dd>&nbsp;</dd>
<dt><strong><code>subset_of_weights</code></strong> :&ensp;<code>{'last_layer', 'all'}</code>, default=<code>'last_layer'</code></dt>
<dd>subset of weights to consider for inference</dd>
<dt><strong><code>hessian_structure</code></strong> :&ensp;<code>{'diag', 'kron', 'full'}</code>, default=<code>'kron'</code></dt>
<dd>structure of the Hessian approximation</dd>
</dl>
<h2 id="returns">Returns</h2>
<dl>
<dt><strong><code>laplace</code></strong> :&ensp;<code><a title="laplace.BaseLaplace" href="#laplace.BaseLaplace">BaseLaplace</a></code></dt>
<dd>chosen subclass of BaseLaplace instantiated with additional arguments</dd>
</dl></div>
</dd>
</dl>
</section>
<section>
<h2 class="section-title" id="header-classes">Classes</h2>
<dl>
<dt id="laplace.BaseLaplace"><code class="flex name class">
<span>class <span class="ident">BaseLaplace</span></span>
<span>(</span><span>model, likelihood, sigma_noise=1.0, prior_precision=1.0, prior_mean=0.0, temperature=1.0, backend=laplace.curvature.backpack.BackPackGGN, backend_kwargs=None)</span>
</code></dt>
<dd>
<div class="desc"><p>Baseclass for all Laplace approximations in this library.
Subclasses need to specify how the Hessian approximation is initialized,
how to add up curvature over training data, how to sample from the
Laplace approximation, and how to compute the functional variance.</p>
<p>A Laplace approximation is represented by a MAP which is given by the
<code>model</code> parameter and a posterior precision or covariance specifying
a Gaussian distribution <span><span class="MathJax_Preview">\mathcal{N}(\theta_{MAP}, P^{-1})</span><script type="math/tex">\mathcal{N}(\theta_{MAP}, P^{-1})</script></span>.
The goal of this class is to compute the posterior precision <span><span class="MathJax_Preview">P</span><script type="math/tex">P</script></span>
which sums as
<span><span class="MathJax_Preview">
P = \sum_{n=1}^N \nabla^2_\theta \log p(\mathcal{D}_n \mid \theta)
\vert_{\theta_{MAP}} + \nabla^2_\theta \log p(\theta) \vert_{\theta_{MAP}}.
</span><script type="math/tex; mode=display">
P = \sum_{n=1}^N \nabla^2_\theta \log p(\mathcal{D}_n \mid \theta)
\vert_{\theta_{MAP}} + \nabla^2_\theta \log p(\theta) \vert_{\theta_{MAP}}.
</script></span>
Every subclass implements different approximations to the log likelihood Hessians,
for example, a diagonal one. The prior is assumed to be Gaussian and therefore we have
a simple form for <span><span class="MathJax_Preview">\nabla^2_\theta \log p(\theta) \vert_{\theta_{MAP}} = P_0 </span><script type="math/tex">\nabla^2_\theta \log p(\theta) \vert_{\theta_{MAP}} = P_0 </script></span>.
In particular, we assume a scalar, layer-wise, or diagonal prior precision so that in
all cases <span><span class="MathJax_Preview">P_0 = \textrm{diag}(p_0)</span><script type="math/tex">P_0 = \textrm{diag}(p_0)</script></span> and the structure of <span><span class="MathJax_Preview">p_0</span><script type="math/tex">p_0</script></span> can be varied.</p>
<h2 id="parameters">Parameters</h2>
<dl>
<dt><strong><code>model</code></strong> :&ensp;<code>torch.nn.Module</code></dt>
<dd>&nbsp;</dd>
<dt><strong><code>likelihood</code></strong> :&ensp;<code>{'classification', 'regression'}</code></dt>
<dd>determines the log likelihood Hessian approximation</dd>
<dt><strong><code>sigma_noise</code></strong> :&ensp;<code>torch.Tensor</code> or <code>float</code>, default=<code>1</code></dt>
<dd>observation noise for the regression setting; must be 1 for classification</dd>
<dt><strong><code>prior_precision</code></strong> :&ensp;<code>torch.Tensor</code> or <code>float</code>, default=<code>1</code></dt>
<dd>prior precision of a Gaussian prior (= weight decay);
can be scalar, per-layer, or diagonal in the most general case</dd>
<dt><strong><code>prior_mean</code></strong> :&ensp;<code>torch.Tensor</code> or <code>float</code>, default=<code>0</code></dt>
<dd>prior mean of a Gaussian prior, useful for continual learning</dd>
<dt><strong><code>temperature</code></strong> :&ensp;<code>float</code>, default=<code>1</code></dt>
<dd>temperature of the likelihood; lower temperature leads to more
concentrated posterior and vice versa.</dd>
<dt><strong><code>backend</code></strong> :&ensp;<code>subclasses</code> of <code><a title="laplace.curvature.CurvatureInterface" href="curvature/index.html#laplace.curvature.CurvatureInterface">CurvatureInterface</a></code></dt>
<dd>backend for access to curvature/Hessian approximations</dd>
<dt><strong><code>backend_kwargs</code></strong> :&ensp;<code>dict</code>, default=<code>None</code></dt>
<dd>arguments passed to the backend on initialization, for example to
set the number of MC samples for stochastic approximations.</dd>
</dl></div>
<h3>Ancestors</h3>
<ul class="hlist">
<li>abc.ABC</li>
</ul>
<h3>Subclasses</h3>
<ul class="hlist">
<li><a title="laplace.baselaplace.DiagLaplace" href="baselaplace.html#laplace.baselaplace.DiagLaplace">DiagLaplace</a></li>
<li><a title="laplace.baselaplace.FullLaplace" href="baselaplace.html#laplace.baselaplace.FullLaplace">FullLaplace</a></li>
<li><a title="laplace.baselaplace.KronLaplace" href="baselaplace.html#laplace.baselaplace.KronLaplace">KronLaplace</a></li>
<li>laplace.lllaplace.LLLaplace</li>
</ul>
<h3>Instance variables</h3>
<dl>
<dt id="laplace.BaseLaplace.backend"><code class="name">var <span class="ident">backend</span></code></dt>
<dd>
<div class="desc"></div>
</dd>
<dt id="laplace.BaseLaplace.log_likelihood"><code class="name">var <span class="ident">log_likelihood</span></code></dt>
<dd>
<div class="desc"><p>Compute log likelihood on the training data after <code>.fit()</code> has been called.
The log likelihood is computed on-demand based on the loss and, for example,
the observation noise which makes it differentiable in the latter for
iterative updates.</p>
<h2 id="returns">Returns</h2>
<dl>
<dt><strong><code>log_likelihood</code></strong> :&ensp;<code>torch.Tensor</code></dt>
<dd>&nbsp;</dd>
</dl></div>
</dd>
<dt id="laplace.BaseLaplace.scatter"><code class="name">var <span class="ident">scatter</span></code></dt>
<dd>
<div class="desc"><p>Computes the <em>scatter</em>, a term of the log marginal likelihood that
corresponds to L2 regularization:
<code>scatter</code> = <span><span class="MathJax_Preview">(\theta_{MAP} - \mu_0)^{T} P_0 (\theta_{MAP} - \mu_0) </span><script type="math/tex">(\theta_{MAP} - \mu_0)^{T} P_0 (\theta_{MAP} - \mu_0) </script></span>.</p>
<h2 id="returns">Returns</h2>
<dl>
<dt><strong><code>scatter</code></strong> :&ensp;<code>torch.Tensor</code></dt>
<dd>&nbsp;</dd>
</dl></div>
</dd>
<dt id="laplace.BaseLaplace.log_det_prior_precision"><code class="name">var <span class="ident">log_det_prior_precision</span></code></dt>
<dd>
<div class="desc"><p>Compute log determinant of the prior precision
<span><span class="MathJax_Preview">\log \det P_0</span><script type="math/tex">\log \det P_0</script></span></p>
<h2 id="returns">Returns</h2>
<dl>
<dt><strong><code>log_det</code></strong> :&ensp;<code>torch.Tensor</code></dt>
<dd>&nbsp;</dd>
</dl></div>
</dd>
<dt id="laplace.BaseLaplace.log_det_posterior_precision"><code class="name">var <span class="ident">log_det_posterior_precision</span></code></dt>
<dd>
<div class="desc"><p>Compute log determinant of the posterior precision
<span><span class="MathJax_Preview">\log \det P</span><script type="math/tex">\log \det P</script></span> which depends on the subclass-specific structure
used for the Hessian approximation.</p>
<h2 id="returns">Returns</h2>
<dl>
<dt><strong><code>log_det</code></strong> :&ensp;<code>torch.Tensor</code></dt>
<dd>&nbsp;</dd>
</dl></div>
</dd>
<dt id="laplace.BaseLaplace.log_det_ratio"><code class="name">var <span class="ident">log_det_ratio</span></code></dt>
<dd>
<div class="desc"><p>Compute the log determinant ratio, a part of the log marginal likelihood.
<span><span class="MathJax_Preview">
\log \frac{\det P}{\det P_0} = \log \det P - \log \det P_0
</span><script type="math/tex; mode=display">
\log \frac{\det P}{\det P_0} = \log \det P - \log \det P_0
</script></span></p>
<h2 id="returns">Returns</h2>
<dl>
<dt><strong><code>log_det_ratio</code></strong> :&ensp;<code>torch.Tensor</code></dt>
<dd>&nbsp;</dd>
</dl></div>
</dd>
<dt id="laplace.BaseLaplace.prior_precision_diag"><code class="name">var <span class="ident">prior_precision_diag</span></code></dt>
<dd>
<div class="desc"><p>Obtain the diagonal prior precision <span><span class="MathJax_Preview">p_0</span><script type="math/tex">p_0</script></span> constructed from either
a scalar, layer-wise, or diagonal prior precision.</p>
<h2 id="returns">Returns</h2>
<dl>
<dt><strong><code>prior_precision_diag</code></strong> :&ensp;<code>torch.Tensor</code></dt>
<dd>&nbsp;</dd>
</dl></div>
</dd>
<dt id="laplace.BaseLaplace.prior_mean"><code class="name">var <span class="ident">prior_mean</span></code></dt>
<dd>
<div class="desc"></div>
</dd>
<dt id="laplace.BaseLaplace.prior_precision"><code class="name">var <span class="ident">prior_precision</span></code></dt>
<dd>
<div class="desc"></div>
</dd>
<dt id="laplace.BaseLaplace.sigma_noise"><code class="name">var <span class="ident">sigma_noise</span></code></dt>
<dd>
<div class="desc"></div>
</dd>
<dt id="laplace.BaseLaplace.posterior_precision"><code class="name">var <span class="ident">posterior_precision</span></code></dt>
<dd>
<div class="desc"><p>Compute or return the posterior precision <span><span class="MathJax_Preview">P</span><script type="math/tex">P</script></span>.</p>
<h2 id="returns">Returns</h2>
<dl>
<dt><strong><code>posterior_prec</code></strong> :&ensp;<code>torch.Tensor</code></dt>
<dd>&nbsp;</dd>
</dl></div>
</dd>
</dl>
<h3>Methods</h3>
<dl>
<dt id="laplace.BaseLaplace.fit"><code class="name flex">
<span>def <span class="ident">fit</span></span>(<span>self, train_loader)</span>
</code></dt>
<dd>
<div class="desc"><p>Fit the local Laplace approximation at the parameters of the model.</p>
<h2 id="parameters">Parameters</h2>
<dl>
<dt><strong><code>train_loader</code></strong> :&ensp;<code>torch.utils.data.DataLoader</code></dt>
<dd>each iterate is a training batch (X, y);
<code>train_loader.dataset</code> needs to be set to access <span><span class="MathJax_Preview">N</span><script type="math/tex">N</script></span>, size of the data set</dd>
</dl></div>
</dd>
<dt id="laplace.BaseLaplace.log_marginal_likelihood"><code class="name flex">
<span>def <span class="ident">log_marginal_likelihood</span></span>(<span>self, prior_precision=None, sigma_noise=None)</span>
</code></dt>
<dd>
<div class="desc"><p>Compute the Laplace approximation to the log marginal likelihood subject
to specific Hessian approximations that subclasses implement.
Requires that the Laplace approximation has been fit before.
The resulting torch.Tensor is differentiable in <code>prior_precision</code> and
<code>sigma_noise</code> if these have gradients enabled.
By passing <code>prior_precision</code> or <code>sigma_noise</code>, the current value is
overwritten. This is useful for iterating on the log marginal likelihood.</p>
<h2 id="parameters">Parameters</h2>
<dl>
<dt><strong><code>prior_precision</code></strong> :&ensp;<code>torch.Tensor</code>, optional</dt>
<dd>prior precision, if it should be changed from the current <code>prior_precision</code> value</dd>
<dt><strong><code>sigma_noise</code></strong> :&ensp;<code>torch.Tensor</code> or <code>float</code>, optional</dt>
<dd>observation noise standard deviation, if it should be changed from the current <code>sigma_noise</code> value</dd>
</dl>
<h2 id="returns">Returns</h2>
<dl>
<dt><strong><code>log_marglik</code></strong> :&ensp;<code>torch.Tensor</code></dt>
<dd>&nbsp;</dd>
</dl></div>
</dd>
<dt id="laplace.BaseLaplace.predictive"><code class="name flex">
<span>def <span class="ident">predictive</span></span>(<span>self, x, pred_type='glm', link_approx='mc', n_samples=100)</span>
</code></dt>
<dd>
<div class="desc"></div>
</dd>
<dt id="laplace.BaseLaplace.predictive_samples"><code class="name flex">
<span>def <span class="ident">predictive_samples</span></span>(<span>self, x, pred_type='glm', n_samples=100)</span>
</code></dt>
<dd>
<div class="desc"><p>Sample from the posterior predictive on input data <code>x</code>.
Can be used, for example, for Thompson sampling.</p>
<h2 id="parameters">Parameters</h2>
<dl>
<dt><strong><code>x</code></strong> :&ensp;<code>torch.Tensor</code></dt>
<dd>input data <code>(batch_size, input_shape)</code></dd>
<dt><strong><code>pred_type</code></strong> :&ensp;<code>{'glm', 'nn'}</code>, default=<code>'glm'</code></dt>
<dd>type of posterior predictive, linearized GLM predictive or neural
network sampling predictive. The GLM predictive is consistent with
the curvature approximations used here.</dd>
<dt><strong><code>n_samples</code></strong> :&ensp;<code>int</code>, default=<code>100</code></dt>
<dd>number of samples</dd>
</dl>
<h2 id="returns">Returns</h2>
<dl>
<dt><strong><code>samples</code></strong> :&ensp;<code>torch.Tensor</code></dt>
<dd>samples <code>(n_samples, batch_size, output_shape)</code></dd>
</dl></div>
</dd>
<dt id="laplace.BaseLaplace.functional_variance"><code class="name flex">
<span>def <span class="ident">functional_variance</span></span>(<span>self, Jacs)</span>
</code></dt>
<dd>
<div class="desc"><p>Compute functional variance for the <code>'glm'</code> predictive:
<code>f_var[i] = Jacs[i] @ P.inv() @ Jacs[i].T</code>, which is an output x output
predictive covariance matrix.
Mathematically, we have for a single Jacobian
<span><span class="MathJax_Preview">\mathcal{J} = \nabla_\theta f(x;\theta)\vert_{\theta_{MAP}}</span><script type="math/tex">\mathcal{J} = \nabla_\theta f(x;\theta)\vert_{\theta_{MAP}}</script></span>
the output covariance matrix
<span><span class="MathJax_Preview"> \mathcal{J} P^{-1} \mathcal{J}^T </span><script type="math/tex"> \mathcal{J} P^{-1} \mathcal{J}^T </script></span>.</p>
<h2 id="parameters">Parameters</h2>
<dl>
<dt><strong><code>Jacs</code></strong> :&ensp;<code>torch.Tensor</code></dt>
<dd>Jacobians of model output wrt parameters
<code>(batch, outputs, parameters)</code></dd>
</dl>
<h2 id="returns">Returns</h2>
<dl>
<dt><strong><code>f_var</code></strong> :&ensp;<code>torch.Tensor</code></dt>
<dd>output covariance <code>(batch, outputs, outputs)</code></dd>
</dl></div>
</dd>
<dt id="laplace.BaseLaplace.sample"><code class="name flex">
<span>def <span class="ident">sample</span></span>(<span>self, n_samples=100)</span>
</code></dt>
<dd>
<div class="desc"><p>Sample from the Laplace posterior approximation, i.e.,
<span><span class="MathJax_Preview"> \theta \sim \mathcal{N}(\theta_{MAP}, P^{-1})</span><script type="math/tex"> \theta \sim \mathcal{N}(\theta_{MAP}, P^{-1})</script></span>.</p>
<h2 id="parameters">Parameters</h2>
<dl>
<dt><strong><code>n_samples</code></strong> :&ensp;<code>int</code>, default=<code>100</code></dt>
<dd>number of samples</dd>
</dl></div>
</dd>
<dt id="laplace.BaseLaplace.optimize_prior_precision"><code class="name flex">
<span>def <span class="ident">optimize_prior_precision</span></span>(<span>self, method='marglik', n_steps=100, lr=0.1, init_prior_prec=1.0, val_loader=None, loss=&lt;function get_nll&gt;, log_prior_prec_min=-4, log_prior_prec_max=4, grid_size=100, pred_type='glm', link_approx='probit', n_samples=100, verbose=False)</span>
</code></dt>
<dd>
<div class="desc"><p>Optimize the prior precision post-hoc using the <code>method</code>
specified by the user.</p>
<h2 id="parameters">Parameters</h2>
<dl>
<dt><strong><code>method</code></strong> :&ensp;<code>{'marglik', 'CV'}</code>, default=<code>'marglik'</code></dt>
<dd>specifies how the prior precision should be optimized.</dd>
<dt><strong><code>n_steps</code></strong> :&ensp;<code>int</code>, default=<code>100</code></dt>
<dd>the number of gradient descent steps to take.</dd>
<dt><strong><code>lr</code></strong> :&ensp;<code>float</code>, default=<code>1e-1</code></dt>
<dd>the learning rate to use for gradient descent.</dd>
<dt><strong><code>init_prior_prec</code></strong> :&ensp;<code>float</code>, default=<code>1.0</code></dt>
<dd>initial prior precision before the first optimization step.</dd>
<dt><strong><code>val_loader</code></strong> :&ensp;<code>torch.utils.data.DataLoader</code>, default=<code>None</code></dt>
<dd>DataLoader for the validation set; each iterate is a validation batch (X, y).</dd>
<dt><strong><code>loss</code></strong> :&ensp;<code>callable</code>, default=<code>get_nll</code></dt>
<dd>loss function to use for CV.</dd>
<dt><strong><code>log_prior_prec_min</code></strong> :&ensp;<code>float</code>, default=<code>-4</code></dt>
<dd>lower bound of gridsearch interval for CV.</dd>
<dt><strong><code>log_prior_prec_max</code></strong> :&ensp;<code>float</code>, default=<code>4</code></dt>
<dd>upper bound of gridsearch interval for CV.</dd>
<dt><strong><code>grid_size</code></strong> :&ensp;<code>int</code>, default=<code>100</code></dt>
<dd>number of values to consider inside the gridsearch interval for CV.</dd>
<dt><strong><code>pred_type</code></strong> :&ensp;<code>{'glm', 'nn'}</code>, default=<code>'glm'</code></dt>
<dd>type of posterior predictive, linearized GLM predictive or neural
network sampling predictive. The GLM predictive is consistent with
the curvature approximations used here.</dd>
<dt><strong><code>link_approx</code></strong> :&ensp;<code>{'mc', 'probit', 'bridge'}</code>, default=<code>'probit'</code></dt>
<dd>how to approximate the classification link function for the <code>'glm'</code>.
For <code>pred_type='nn'</code>, only <code>'mc'</code> is possible.</dd>
<dt><strong><code>n_samples</code></strong> :&ensp;<code>int</code>, default=<code>100</code></dt>
<dd>number of samples for <code>link_approx='mc'</code>.</dd>
<dt><strong><code>verbose</code></strong> :&ensp;<code>bool</code>, default=<code>False</code></dt>
<dd>if true, the optimized prior precision will be printed
(can be a large tensor if the prior has a diagonal covariance).</dd>
</dl></div>
</dd>
</dl>
</dd>
<dt id="laplace.FullLaplace"><code class="flex name class">
<span>class <span class="ident">FullLaplace</span></span>
<span>(</span><span>model, likelihood, sigma_noise=1.0, prior_precision=1.0, prior_mean=0.0, temperature=1.0, backend=laplace.curvature.backpack.BackPackGGN, backend_kwargs=None)</span>
</code></dt>
<dd>
<div class="desc"><p>Laplace approximation with full, i.e., dense, log likelihood Hessian approximation
and hence posterior precision. Based on the chosen <code>backend</code> parameter, the full
approximation can be, for example, a generalized Gauss-Newton matrix.
Mathematically, we have <span><span class="MathJax_Preview">P \in \mathbb{R}^{P \times P}</span><script type="math/tex">P \in \mathbb{R}^{P \times P}</script></span>.
See <code><a title="laplace.BaseLaplace" href="#laplace.BaseLaplace">BaseLaplace</a></code> for the full interface.</p></div>
<h3>Ancestors</h3>
<ul class="hlist">
<li><a title="laplace.baselaplace.BaseLaplace" href="baselaplace.html#laplace.baselaplace.BaseLaplace">BaseLaplace</a></li>
<li>abc.ABC</li>
</ul>
<h3>Subclasses</h3>
<ul class="hlist">
<li><a title="laplace.lllaplace.FullLLLaplace" href="lllaplace.html#laplace.lllaplace.FullLLLaplace">FullLLLaplace</a></li>
</ul>
<h3>Instance variables</h3>
<dl>
<dt id="laplace.FullLaplace.posterior_scale"><code class="name">var <span class="ident">posterior_scale</span></code></dt>
<dd>
<div class="desc"><p>Posterior scale (square root of the covariance), i.e.,
<span><span class="MathJax_Preview">P^{-\frac{1}{2}}</span><script type="math/tex">P^{-\frac{1}{2}}</script></span>.</p>
<h2 id="returns">Returns</h2>
<dl>
<dt><strong><code>scale</code></strong> :&ensp;<code>torch.tensor</code></dt>
<dd><code>(parameters, parameters)</code></dd>
</dl></div>
</dd>
<dt id="laplace.FullLaplace.posterior_covariance"><code class="name">var <span class="ident">posterior_covariance</span></code></dt>
<dd>
<div class="desc"><p>Posterior covariance, i.e., <span><span class="MathJax_Preview">P^{-1}</span><script type="math/tex">P^{-1}</script></span>.</p>
<h2 id="returns">Returns</h2>
<dl>
<dt><strong><code>covariance</code></strong> :&ensp;<code>torch.tensor</code></dt>
<dd><code>(parameters, parameters)</code></dd>
</dl></div>
</dd>
<dt id="laplace.FullLaplace.posterior_precision"><code class="name">var <span class="ident">posterior_precision</span></code></dt>
<dd>
<div class="desc"><p>Posterior precision <span><span class="MathJax_Preview">P</span><script type="math/tex">P</script></span>.</p>
<h2 id="returns">Returns</h2>
<dl>
<dt><strong><code>precision</code></strong> :&ensp;<code>torch.tensor</code></dt>
<dd><code>(parameters, parameters)</code></dd>
</dl></div>
</dd>
</dl>
<h3>Inherited members</h3>
<ul class="hlist">
<li><code><b><a title="laplace.baselaplace.BaseLaplace" href="baselaplace.html#laplace.baselaplace.BaseLaplace">BaseLaplace</a></b></code>:
<ul class="hlist">
<li><code><a title="laplace.baselaplace.BaseLaplace.fit" href="baselaplace.html#laplace.baselaplace.BaseLaplace.fit">fit</a></code></li>
<li><code><a title="laplace.baselaplace.BaseLaplace.functional_variance" href="baselaplace.html#laplace.baselaplace.BaseLaplace.functional_variance">functional_variance</a></code></li>
<li><code><a title="laplace.baselaplace.BaseLaplace.log_det_posterior_precision" href="baselaplace.html#laplace.baselaplace.BaseLaplace.log_det_posterior_precision">log_det_posterior_precision</a></code></li>
<li><code><a title="laplace.baselaplace.BaseLaplace.log_det_prior_precision" href="baselaplace.html#laplace.baselaplace.BaseLaplace.log_det_prior_precision">log_det_prior_precision</a></code></li>
<li><code><a title="laplace.baselaplace.BaseLaplace.log_det_ratio" href="baselaplace.html#laplace.baselaplace.BaseLaplace.log_det_ratio">log_det_ratio</a></code></li>
<li><code><a title="laplace.baselaplace.BaseLaplace.log_likelihood" href="baselaplace.html#laplace.baselaplace.BaseLaplace.log_likelihood">log_likelihood</a></code></li>
<li><code><a title="laplace.baselaplace.BaseLaplace.log_marginal_likelihood" href="baselaplace.html#laplace.baselaplace.BaseLaplace.log_marginal_likelihood">log_marginal_likelihood</a></code></li>
<li><code><a title="laplace.baselaplace.BaseLaplace.optimize_prior_precision" href="baselaplace.html#laplace.baselaplace.BaseLaplace.optimize_prior_precision">optimize_prior_precision</a></code></li>
<li><code><a title="laplace.baselaplace.BaseLaplace.predictive_samples" href="baselaplace.html#laplace.baselaplace.BaseLaplace.predictive_samples">predictive_samples</a></code></li>
<li><code><a title="laplace.baselaplace.BaseLaplace.prior_precision_diag" href="baselaplace.html#laplace.baselaplace.BaseLaplace.prior_precision_diag">prior_precision_diag</a></code></li>
<li><code><a title="laplace.baselaplace.BaseLaplace.sample" href="baselaplace.html#laplace.baselaplace.BaseLaplace.sample">sample</a></code></li>
<li><code><a title="laplace.baselaplace.BaseLaplace.scatter" href="baselaplace.html#laplace.baselaplace.BaseLaplace.scatter">scatter</a></code></li>
</ul>
</li>
</ul>
</dd>
<dt id="laplace.KronLaplace"><code class="flex name class">
<span>class <span class="ident">KronLaplace</span></span>
<span>(</span><span>model, likelihood, sigma_noise=1.0, prior_precision=1.0, prior_mean=0.0, temperature=1.0, backend=laplace.curvature.backpack.BackPackGGN, damping=False, **backend_kwargs)</span>
</code></dt>
<dd>
<div class="desc"><p>Laplace approximation with Kronecker factored log likelihood Hessian approximation
and hence posterior precision.
Mathematically, we have for each parameter group, e.g., torch.nn.Module,
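that <span><span class="MathJax_Preview">P \approx Q \otimes H</span><script type="math/tex">P \approx Q \otimes H</script></span>.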
See <code><a title="laplace.BaseLaplace" href="#laplace.BaseLaplace">BaseLaplace</a></code> for the full interface and see
<code><a title="laplace.matrix.Kron" href="matrix.html#laplace.matrix.Kron">Kron</a></code> and <code><a title="laplace.matrix.KronDecomposed" href="matrix.html#laplace.matrix.KronDecomposed">KronDecomposed</a></code> for the structure of
the Kronecker factors. <code>Kron</code> is used to aggregate factors by summing up and
<code>KronDecomposed</code> is used to add the prior, a Hessian factor (e.g. temperature),
and computing posterior covariances, marginal likelihood, etc.
Damping can be enabled by setting <code>damping=True</code>.</p></div>
<h3>Ancestors</h3>
<ul class="hlist">
<li><a title="laplace.baselaplace.BaseLaplace" href="baselaplace.html#laplace.baselaplace.BaseLaplace">BaseLaplace</a></li>
<li>abc.ABC</li>
</ul>
<h3>Subclasses</h3>
<ul class="hlist">
<li><a title="laplace.lllaplace.KronLLLaplace" href="lllaplace.html#laplace.lllaplace.KronLLLaplace">KronLLLaplace</a></li>
</ul>
<h3>Instance variables</h3>
<dl>
<dt id="laplace.KronLaplace.posterior_precision"><code class="name">var <span class="ident">posterior_precision</span></code></dt>
<dd>
<div class="desc"><p>Kronecker factored Posterior precision <span><span class="MathJax_Preview">P</span><script type="math/tex">P</script></span>.</p>
<h2 id="returns">Returns</h2>
<dl>
<dt><strong><code>precision</code></strong> :&ensp;<code><a title="laplace.matrix.KronDecomposed" href="matrix.html#laplace.matrix.KronDecomposed">KronDecomposed</a></code></dt>
<dd>&nbsp;</dd>
</dl></div>
</dd>
<dt id="laplace.KronLaplace.prior_precision"><code class="name">var <span class="ident">prior_precision</span></code></dt>
<dd>
<div class="desc"></div>
</dd>
</dl>
<h3>Inherited members</h3>
<ul class="hlist">
<li><code><b><a title="laplace.baselaplace.BaseLaplace" href="baselaplace.html#laplace.baselaplace.BaseLaplace">BaseLaplace</a></b></code>:
<ul class="hlist">
<li><code><a title="laplace.baselaplace.BaseLaplace.fit" href="baselaplace.html#laplace.baselaplace.BaseLaplace.fit">fit</a></code></li>
<li><code><a title="laplace.baselaplace.BaseLaplace.functional_variance" href="baselaplace.html#laplace.baselaplace.BaseLaplace.functional_variance">functional_variance</a></code></li>
<li><code><a title="laplace.baselaplace.BaseLaplace.log_det_posterior_precision" href="baselaplace.html#laplace.baselaplace.BaseLaplace.log_det_posterior_precision">log_det_posterior_precision</a></code></li>
<li><code><a title="laplace.baselaplace.BaseLaplace.log_det_prior_precision" href="baselaplace.html#laplace.baselaplace.BaseLaplace.log_det_prior_precision">log_det_prior_precision</a></code></li>
<li><code><a title="laplace.baselaplace.BaseLaplace.log_det_ratio" href="baselaplace.html#laplace.baselaplace.BaseLaplace.log_det_ratio">log_det_ratio</a></code></li>
<li><code><a title="laplace.baselaplace.BaseLaplace.log_likelihood" href="baselaplace.html#laplace.baselaplace.BaseLaplace.log_likelihood">log_likelihood</a></code></li>
<li><code><a title="laplace.baselaplace.BaseLaplace.log_marginal_likelihood" href="baselaplace.html#laplace.baselaplace.BaseLaplace.log_marginal_likelihood">log_marginal_likelihood</a></code></li>
<li><code><a title="laplace.baselaplace.BaseLaplace.optimize_prior_precision" href="baselaplace.html#laplace.baselaplace.BaseLaplace.optimize_prior_precision">optimize_prior_precision</a></code></li>
<li><code><a title="laplace.baselaplace.BaseLaplace.predictive_samples" href="baselaplace.html#laplace.baselaplace.BaseLaplace.predictive_samples">predictive_samples</a></code></li>
<li><code><a title="laplace.baselaplace.BaseLaplace.prior_precision_diag" href="baselaplace.html#laplace.baselaplace.BaseLaplace.prior_precision_diag">prior_precision_diag</a></code></li>
<li><code><a title="laplace.baselaplace.BaseLaplace.sample" href="baselaplace.html#laplace.baselaplace.BaseLaplace.sample">sample</a></code></li>
<li><code><a title="laplace.baselaplace.BaseLaplace.scatter" href="baselaplace.html#laplace.baselaplace.BaseLaplace.scatter">scatter</a></code></li>
</ul>
</li>
</ul>
</dd>
<dt id="laplace.DiagLaplace"><code class="flex name class">
<span>class <span class="ident">DiagLaplace</span></span>
<span>(</span><span>model, likelihood, sigma_noise=1.0, prior_precision=1.0, prior_mean=0.0, temperature=1.0, backend=laplace.curvature.backpack.BackPackGGN, backend_kwargs=None)</span>
</code></dt>
<dd>
<div class="desc"><p>Laplace approximation with diagonal log likelihood Hessian approximation
and hence posterior precision.
Mathematically, we have <span><span class="MathJax_Preview">P \approx \textrm{diag}(P)</span><script type="math/tex">P \approx \textrm{diag}(P)</script></span>.
See <code><a title="laplace.BaseLaplace" href="#laplace.BaseLaplace">BaseLaplace</a></code> for the full interface.</p></div>
<h3>Ancestors</h3>
<ul class="hlist">
<li><a title="laplace.baselaplace.BaseLaplace" href="baselaplace.html#laplace.baselaplace.BaseLaplace">BaseLaplace</a></li>
<li>abc.ABC</li>
</ul>
<h3>Subclasses</h3>
<ul class="hlist">
<li><a title="laplace.lllaplace.DiagLLLaplace" href="lllaplace.html#laplace.lllaplace.DiagLLLaplace">DiagLLLaplace</a></li>
</ul>
<h3>Instance variables</h3>
<dl>
<dt id="laplace.DiagLaplace.posterior_precision"><code class="name">var <span class="ident">posterior_precision</span></code></dt>
<dd>
<div class="desc"><p>Diagonal posterior precision <span><span class="MathJax_Preview">p</span><script type="math/tex">p</script></span>.</p>
<h2 id="returns">Returns</h2>
<dl>
<dt><strong><code>precision</code></strong> :&ensp;<code>torch.tensor</code></dt>
<dd><code>(parameters)</code></dd>
</dl></div>
</dd>
<dt id="laplace.DiagLaplace.posterior_scale"><code class="name">var <span class="ident">posterior_scale</span></code></dt>
<dd>
<div class="desc"><p>Diagonal posterior scale <span><span class="MathJax_Preview">\sqrt{p^{-1}}</span><script type="math/tex">\sqrt{p^{-1}}</script></span>.</p>
<h2 id="returns">Returns</h2>
<dl>
<dt><strong><code>scale</code></strong> :&ensp;<code>torch.tensor</code></dt>
<dd><code>(parameters)</code></dd>
</dl></div>
</dd>
<dt id="laplace.DiagLaplace.posterior_variance"><code class="name">var <span class="ident">posterior_variance</span></code></dt>
<dd>
<div class="desc"><p>Diagonal posterior variance <span><span class="MathJax_Preview">p^{-1}</span><script type="math/tex">p^{-1}</script></span>.</p>
<h2 id="returns">Returns</h2>
<dl>
<dt><strong><code>variance</code></strong> :&ensp;<code>torch.tensor</code></dt>
<dd><code>(parameters)</code></dd>
</dl></div>
</dd>
</dl>
<h3>Inherited members</h3>
<ul class="hlist">
<li><code><b><a title="laplace.baselaplace.BaseLaplace" href="baselaplace.html#laplace.baselaplace.BaseLaplace">BaseLaplace</a></b></code>:
<ul class="hlist">
<li><code><a title="laplace.baselaplace.BaseLaplace.fit" href="baselaplace.html#laplace.baselaplace.BaseLaplace.fit">fit</a></code></li>
<li><code><a title="laplace.baselaplace.BaseLaplace.functional_variance" href="baselaplace.html#laplace.baselaplace.BaseLaplace.functional_variance">functional_variance</a></code></li>
<li><code><a title="laplace.baselaplace.BaseLaplace.log_det_posterior_precision" href="baselaplace.html#laplace.baselaplace.BaseLaplace.log_det_posterior_precision">log_det_posterior_precision</a></code></li>
<li><code><a title="laplace.baselaplace.BaseLaplace.log_det_prior_precision" href="baselaplace.html#laplace.baselaplace.BaseLaplace.log_det_prior_precision">log_det_prior_precision</a></code></li>
<li><code><a title="laplace.baselaplace.BaseLaplace.log_det_ratio" href="baselaplace.html#laplace.baselaplace.BaseLaplace.log_det_ratio">log_det_ratio</a></code></li>
<li><code><a title="laplace.baselaplace.BaseLaplace.log_likelihood" href="baselaplace.html#laplace.baselaplace.BaseLaplace.log_likelihood">log_likelihood</a></code></li>
<li><code><a title="laplace.baselaplace.BaseLaplace.log_marginal_likelihood" href="baselaplace.html#laplace.baselaplace.BaseLaplace.log_marginal_likelihood">log_marginal_likelihood</a></code></li>
<li><code><a title="laplace.baselaplace.BaseLaplace.optimize_prior_precision" href="baselaplace.html#laplace.baselaplace.BaseLaplace.optimize_prior_precision">optimize_prior_precision</a></code></li>
<li><code><a title="laplace.baselaplace.BaseLaplace.predictive_samples" href="baselaplace.html#laplace.baselaplace.BaseLaplace.predictive_samples">predictive_samples</a></code></li>
<li><code><a title="laplace.baselaplace.BaseLaplace.prior_precision_diag" href="baselaplace.html#laplace.baselaplace.BaseLaplace.prior_precision_diag">prior_precision_diag</a></code></li>
<li><code><a title="laplace.baselaplace.BaseLaplace.sample" href="baselaplace.html#laplace.baselaplace.BaseLaplace.sample">sample</a></code></li>
<li><code><a title="laplace.baselaplace.BaseLaplace.scatter" href="baselaplace.html#laplace.baselaplace.BaseLaplace.scatter">scatter</a></code></li>
</ul>
</li>
</ul>
</dd>
<dt id="laplace.LLLaplace"><code class="flex name class">
<span>class <span class="ident">LLLaplace</span></span>
<span>(</span><span>model, likelihood, sigma_noise=1.0, prior_precision=1.0, prior_mean=0.0, temperature=1.0, backend=laplace.curvature.backpack.BackPackGGN, last_layer_name=None, backend_kwargs=None)</span>
</code></dt>
<dd>
<div class="desc"><p>Baseclass for all last-layer Laplace approximations in this library.
Subclasses specify the structure of the Hessian approximation.
See <code><a title="laplace.BaseLaplace" href="#laplace.BaseLaplace">BaseLaplace</a></code> for the full interface.</p>
<p>A Laplace approximation is represented by a MAP which is given by the
<code>model</code> parameter and a posterior precision or covariance specifying
a Gaussian distribution <span><span class="MathJax_Preview">\mathcal{N}(\theta_{MAP}, P^{-1})</span><script type="math/tex">\mathcal{N}(\theta_{MAP}, P^{-1})</script></span>.
Here, only the parameters of the last layer of the neural network
are treated probabilistically.
The goal of this class is to compute the posterior precision <span><span class="MathJax_Preview">P</span><script type="math/tex">P</script></span>
which is given by the sum
<span><span class="MathJax_Preview">
P = \sum_{n=1}^N \nabla^2_\theta \log p(\mathcal{D}_n \mid \theta)
\vert_{\theta_{MAP}} + \nabla^2_\theta \log p(\theta) \vert_{\theta_{MAP}}.
</span><script type="math/tex; mode=display">
P = \sum_{n=1}^N \nabla^2_\theta \log p(\mathcal{D}_n \mid \theta)
\vert_{\theta_{MAP}} + \nabla^2_\theta \log p(\theta) \vert_{\theta_{MAP}}.
</script></span>
Every subclass implements different approximations to the log likelihood Hessians,
for example, a diagonal one. The prior is assumed to be Gaussian and therefore we have
a simple form for <span><span class="MathJax_Preview">\nabla^2_\theta \log p(\theta) \vert_{\theta_{MAP}} = P_0 </span><script type="math/tex">\nabla^2_\theta \log p(\theta) \vert_{\theta_{MAP}} = P_0 </script></span>.
In particular, we assume a scalar or diagonal prior precision so that in
all cases <span><span class="MathJax_Preview">P_0 = \textrm{diag}(p_0)</span><script type="math/tex">P_0 = \textrm{diag}(p_0)</script></span> and the structure of <span><span class="MathJax_Preview">p_0</span><script type="math/tex">p_0</script></span> can be varied.</p>
<h2 id="parameters">Parameters</h2>
<dl>
<dt><strong><code>model</code></strong> :&ensp;<code>torch.nn.Module</code> or <code><a title="laplace.feature_extractor.FeatureExtractor" href="feature_extractor.html#laplace.feature_extractor.FeatureExtractor">FeatureExtractor</a></code></dt>
<dd>&nbsp;</dd>
<dt><strong><code>likelihood</code></strong> :&ensp;<code>{'classification', 'regression'}</code></dt>
<dd>determines the log likelihood Hessian approximation</dd>
<dt><strong><code>sigma_noise</code></strong> :&ensp;<code>torch.Tensor</code> or <code>float</code>, default=<code>1</code></dt>
<dd>observation noise for the regression setting; must be 1 for classification</dd>
<dt><strong><code>prior_precision</code></strong> :&ensp;<code>torch.Tensor</code> or <code>float</code>, default=<code>1</code></dt>
<dd>prior precision of a Gaussian prior (= weight decay);
can be scalar, per-layer, or diagonal in the most general case</dd>
<dt><strong><code>prior_mean</code></strong> :&ensp;<code>torch.Tensor</code> or <code>float</code>, default=<code>0</code></dt>
<dd>prior mean of a Gaussian prior, useful for continual learning</dd>
<dt><strong><code>temperature</code></strong> :&ensp;<code>float</code>, default=<code>1</code></dt>
<dd>temperature of the likelihood; lower temperature leads to more
concentrated posterior and vice versa.</dd>
<dt><strong><code>backend</code></strong> :&ensp;<code>subclasses</code> of <code><a title="laplace.curvature.CurvatureInterface" href="curvature/index.html#laplace.curvature.CurvatureInterface">CurvatureInterface</a></code></dt>
<dd>backend for access to curvature/Hessian approximations</dd>
<dt><strong><code>last_layer_name</code></strong> :&ensp;<code>str</code>, default=<code>None</code></dt>
<dd>name of the model's last layer; if <code>None</code>, it will be determined automatically</dd>
<dt><strong><code>backend_kwargs</code></strong> :&ensp;<code>dict</code>, default=<code>None</code></dt>
<dd>arguments passed to the backend on initialization, for example to
set the number of MC samples for stochastic approximations.</dd>
</dl></div>
<h3>Ancestors</h3>
<ul class="hlist">
<li><a title="laplace.baselaplace.BaseLaplace" href="baselaplace.html#laplace.baselaplace.BaseLaplace">BaseLaplace</a></li>
<li>abc.ABC</li>
</ul>
<h3>Subclasses</h3>
<ul class="hlist">
<li><a title="laplace.lllaplace.DiagLLLaplace" href="lllaplace.html#laplace.lllaplace.DiagLLLaplace">DiagLLLaplace</a></li>
<li><a title="laplace.lllaplace.FullLLLaplace" href="lllaplace.html#laplace.lllaplace.FullLLLaplace">FullLLLaplace</a></li>
<li><a title="laplace.lllaplace.KronLLLaplace" href="lllaplace.html#laplace.lllaplace.KronLLLaplace">KronLLLaplace</a></li>
</ul>
<h3>Instance variables</h3>
<dl>
<dt id="laplace.LLLaplace.prior_precision_diag"><code class="name">var <span class="ident">prior_precision_diag</span></code></dt>
<dd>
<div class="desc"><p>Obtain the diagonal prior precision <span><span class="MathJax_Preview">p_0</span><script type="math/tex">p_0</script></span> constructed from either
a scalar or diagonal prior precision.</p>
<h2 id="returns">Returns</h2>
<dl>
<dt><strong><code>prior_precision_diag</code></strong> :&ensp;<code>torch.Tensor</code></dt>
<dd>&nbsp;</dd>
</dl></div>
</dd>
</dl>
<h3>Inherited members</h3>
<ul class="hlist">
<li><code><b><a title="laplace.baselaplace.BaseLaplace" href="baselaplace.html#laplace.baselaplace.BaseLaplace">BaseLaplace</a></b></code>:
<ul class="hlist">
<li><code><a title="laplace.baselaplace.BaseLaplace.fit" href="baselaplace.html#laplace.baselaplace.BaseLaplace.fit">fit</a></code></li>
<li><code><a title="laplace.baselaplace.BaseLaplace.functional_variance" href="baselaplace.html#laplace.baselaplace.BaseLaplace.functional_variance">functional_variance</a></code></li>
<li><code><a title="laplace.baselaplace.BaseLaplace.log_det_posterior_precision" href="baselaplace.html#laplace.baselaplace.BaseLaplace.log_det_posterior_precision">log_det_posterior_precision</a></code></li>
<li><code><a title="laplace.baselaplace.BaseLaplace.log_det_prior_precision" href="baselaplace.html#laplace.baselaplace.BaseLaplace.log_det_prior_precision">log_det_prior_precision</a></code></li>
<li><code><a title="laplace.baselaplace.BaseLaplace.log_det_ratio" href="baselaplace.html#laplace.baselaplace.BaseLaplace.log_det_ratio">log_det_ratio</a></code></li>
<li><code><a title="laplace.baselaplace.BaseLaplace.log_likelihood" href="baselaplace.html#laplace.baselaplace.BaseLaplace.log_likelihood">log_likelihood</a></code></li>
<li><code><a title="laplace.baselaplace.BaseLaplace.log_marginal_likelihood" href="baselaplace.html#laplace.baselaplace.BaseLaplace.log_marginal_likelihood">log_marginal_likelihood</a></code></li>
<li><code><a title="laplace.baselaplace.BaseLaplace.optimize_prior_precision" href="baselaplace.html#laplace.baselaplace.BaseLaplace.optimize_prior_precision">optimize_prior_precision</a></code></li>
<li><code><a title="laplace.baselaplace.BaseLaplace.posterior_precision" href="baselaplace.html#laplace.baselaplace.BaseLaplace.posterior_precision">posterior_precision</a></code></li>
<li><code><a title="laplace.baselaplace.BaseLaplace.predictive_samples" href="baselaplace.html#laplace.baselaplace.BaseLaplace.predictive_samples">predictive_samples</a></code></li>
<li><code><a title="laplace.baselaplace.BaseLaplace.sample" href="baselaplace.html#laplace.baselaplace.BaseLaplace.sample">sample</a></code></li>
<li><code><a title="laplace.baselaplace.BaseLaplace.scatter" href="baselaplace.html#laplace.baselaplace.BaseLaplace.scatter">scatter</a></code></li>
</ul>
</li>
</ul>
</dd>
<dt id="laplace.FullLLLaplace"><code class="flex name class">
<span>class <span class="ident">FullLLLaplace</span></span>
<span>(</span><span>model, likelihood, sigma_noise=1.0, prior_precision=1.0, prior_mean=0.0, temperature=1.0, backend=laplace.curvature.backpack.BackPackGGN, last_layer_name=None, backend_kwargs=None)</span>
</code></dt>
<dd>
<div class="desc"><p>Last-layer Laplace approximation with full, i.e., dense, log likelihood Hessian approximation
and hence posterior precision. Based on the chosen <code>backend</code> parameter, the full
approximation can be, for example, a generalized Gauss-Newton matrix.
Mathematically, we have <span><span class="MathJax_Preview">P \in \mathbb{R}^{P \times P}</span><script type="math/tex">P \in \mathbb{R}^{P \times P}</script></span>.
See <code><a title="laplace.FullLaplace" href="#laplace.FullLaplace">FullLaplace</a></code>, <code><a title="laplace.LLLaplace" href="#laplace.LLLaplace">LLLaplace</a></code>, and <code><a title="laplace.BaseLaplace" href="#laplace.BaseLaplace">BaseLaplace</a></code> for the full interface.</p></div>
<h3>Ancestors</h3>
<ul class="hlist">
<li>laplace.lllaplace.LLLaplace</li>
<li><a title="laplace.baselaplace.FullLaplace" href="baselaplace.html#laplace.baselaplace.FullLaplace">FullLaplace</a></li>
<li><a title="laplace.baselaplace.BaseLaplace" href="baselaplace.html#laplace.baselaplace.BaseLaplace">BaseLaplace</a></li>
<li>abc.ABC</li>
</ul>
<h3>Inherited members</h3>
<ul class="hlist">
<li><code><b><a title="laplace.baselaplace.FullLaplace" href="baselaplace.html#laplace.baselaplace.FullLaplace">FullLaplace</a></b></code>:
<ul class="hlist">
<li><code><a title="laplace.baselaplace.FullLaplace.fit" href="baselaplace.html#laplace.baselaplace.BaseLaplace.fit">fit</a></code></li>
<li><code><a title="laplace.baselaplace.FullLaplace.functional_variance" href="baselaplace.html#laplace.baselaplace.BaseLaplace.functional_variance">functional_variance</a></code></li>
<li><code><a title="laplace.baselaplace.FullLaplace.log_det_posterior_precision" href="baselaplace.html#laplace.baselaplace.BaseLaplace.log_det_posterior_precision">log_det_posterior_precision</a></code></li>
<li><code><a title="laplace.baselaplace.FullLaplace.log_det_prior_precision" href="baselaplace.html#laplace.baselaplace.BaseLaplace.log_det_prior_precision">log_det_prior_precision</a></code></li>
<li><code><a title="laplace.baselaplace.FullLaplace.log_det_ratio" href="baselaplace.html#laplace.baselaplace.BaseLaplace.log_det_ratio">log_det_ratio</a></code></li>
<li><code><a title="laplace.baselaplace.FullLaplace.log_likelihood" href="baselaplace.html#laplace.baselaplace.BaseLaplace.log_likelihood">log_likelihood</a></code></li>
<li><code><a title="laplace.baselaplace.FullLaplace.log_marginal_likelihood" href="baselaplace.html#laplace.baselaplace.BaseLaplace.log_marginal_likelihood">log_marginal_likelihood</a></code></li>
<li><code><a title="laplace.baselaplace.FullLaplace.optimize_prior_precision" href="baselaplace.html#laplace.baselaplace.BaseLaplace.optimize_prior_precision">optimize_prior_precision</a></code></li>
<li><code><a title="laplace.baselaplace.FullLaplace.posterior_covariance" href="baselaplace.html#laplace.baselaplace.FullLaplace.posterior_covariance">posterior_covariance</a></code></li>
<li><code><a title="laplace.baselaplace.FullLaplace.posterior_precision" href="baselaplace.html#laplace.baselaplace.FullLaplace.posterior_precision">posterior_precision</a></code></li>
<li><code><a title="laplace.baselaplace.FullLaplace.posterior_scale" href="baselaplace.html#laplace.baselaplace.FullLaplace.posterior_scale">posterior_scale</a></code></li>
<li><code><a title="laplace.baselaplace.FullLaplace.predictive_samples" href="baselaplace.html#laplace.baselaplace.BaseLaplace.predictive_samples">predictive_samples</a></code></li>
<li><code><a title="laplace.baselaplace.FullLaplace.prior_precision_diag" href="baselaplace.html#laplace.baselaplace.BaseLaplace.prior_precision_diag">prior_precision_diag</a></code></li>
<li><code><a title="laplace.baselaplace.FullLaplace.sample" href="baselaplace.html#laplace.baselaplace.BaseLaplace.sample">sample</a></code></li>
<li><code><a title="laplace.baselaplace.FullLaplace.scatter" href="baselaplace.html#laplace.baselaplace.BaseLaplace.scatter">scatter</a></code></li>
</ul>
</li>
</ul>
</dd>
<dt id="laplace.KronLLLaplace"><code class="flex name class">
<span>class <span class="ident">KronLLLaplace</span></span>
<span>(</span><span>model, likelihood, sigma_noise=1.0, prior_precision=1.0, prior_mean=0.0, temperature=1.0, backend=laplace.curvature.backpack.BackPackGGN, last_layer_name=None, damping=False, **backend_kwargs)</span>
</code></dt>
<dd>
<div class="desc"><p>Last-layer Laplace approximation with Kronecker factored log likelihood Hessian approximation
and hence posterior precision.
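Mathematically, we have for the last parameter group, i.e., <code>torch.nn.Linear</code>,
that <span><span class="MathJax_Preview">P \approx Q \otimes H</span><script type="math/tex">P \approx Q \otimes H</script></span>.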
See <code><a title="laplace.KronLaplace" href="#laplace.KronLaplace">KronLaplace</a></code>, <code><a title="laplace.LLLaplace" href="#laplace.LLLaplace">LLLaplace</a></code>, and <code><a title="laplace.BaseLaplace" href="#laplace.BaseLaplace">BaseLaplace</a></code> for the full interface and see
<code><a title="laplace.matrix.Kron" href="matrix.html#laplace.matrix.Kron">Kron</a></code> and <code><a title="laplace.matrix.KronDecomposed" href="matrix.html#laplace.matrix.KronDecomposed">KronDecomposed</a></code> for the structure of
the Kronecker factors. <code>Kron</code> is used to aggregate factors by summing them up, and
<code>KronDecomposed</code> is used to add the prior and a Hessian factor (e.g. temperature)
and to compute posterior covariances, the marginal likelihood, etc.
Use of <code>damping</code> is possible by initializing or setting <code>damping=True</code>.</p></div>
<h3>Ancestors</h3>
<ul class="hlist">
<li>laplace.lllaplace.LLLaplace</li>
<li><a title="laplace.baselaplace.KronLaplace" href="baselaplace.html#laplace.baselaplace.KronLaplace">KronLaplace</a></li>
<li><a title="laplace.baselaplace.BaseLaplace" href="baselaplace.html#laplace.baselaplace.BaseLaplace">BaseLaplace</a></li>
<li>abc.ABC</li>
</ul>
<h3>Inherited members</h3>
<ul class="hlist">
<li><code><b><a title="laplace.baselaplace.KronLaplace" href="baselaplace.html#laplace.baselaplace.KronLaplace">KronLaplace</a></b></code>:
<ul class="hlist">
<li><code><a title="laplace.baselaplace.KronLaplace.fit" href="baselaplace.html#laplace.baselaplace.BaseLaplace.fit">fit</a></code></li>
<li><code><a title="laplace.baselaplace.KronLaplace.functional_variance" href="baselaplace.html#laplace.baselaplace.BaseLaplace.functional_variance">functional_variance</a></code></li>
<li><code><a title="laplace.baselaplace.KronLaplace.log_det_posterior_precision" href="baselaplace.html#laplace.baselaplace.BaseLaplace.log_det_posterior_precision">log_det_posterior_precision</a></code></li>
<li><code><a title="laplace.baselaplace.KronLaplace.log_det_prior_precision" href="baselaplace.html#laplace.baselaplace.BaseLaplace.log_det_prior_precision">log_det_prior_precision</a></code></li>
<li><code><a title="laplace.baselaplace.KronLaplace.log_det_ratio" href="baselaplace.html#laplace.baselaplace.BaseLaplace.log_det_ratio">log_det_ratio</a></code></li>
<li><code><a title="laplace.baselaplace.KronLaplace.log_likelihood" href="baselaplace.html#laplace.baselaplace.BaseLaplace.log_likelihood">log_likelihood</a></code></li>
<li><code><a title="laplace.baselaplace.KronLaplace.log_marginal_likelihood" href="baselaplace.html#laplace.baselaplace.BaseLaplace.log_marginal_likelihood">log_marginal_likelihood</a></code></li>
<li><code><a title="laplace.baselaplace.KronLaplace.optimize_prior_precision" href="baselaplace.html#laplace.baselaplace.BaseLaplace.optimize_prior_precision">optimize_prior_precision</a></code></li>
<li><code><a title="laplace.baselaplace.KronLaplace.posterior_precision" href="baselaplace.html#laplace.baselaplace.KronLaplace.posterior_precision">posterior_precision</a></code></li>
<li><code><a title="laplace.baselaplace.KronLaplace.predictive_samples" href="baselaplace.html#laplace.baselaplace.BaseLaplace.predictive_samples">predictive_samples</a></code></li>
<li><code><a title="laplace.baselaplace.KronLaplace.prior_precision_diag" href="baselaplace.html#laplace.baselaplace.BaseLaplace.prior_precision_diag">prior_precision_diag</a></code></li>
<li><code><a title="laplace.baselaplace.KronLaplace.sample" href="baselaplace.html#laplace.baselaplace.BaseLaplace.sample">sample</a></code></li>
<li><code><a title="laplace.baselaplace.KronLaplace.scatter" href="baselaplace.html#laplace.baselaplace.BaseLaplace.scatter">scatter</a></code></li>
</ul>
</li>
</ul>
</dd>
<dt id="laplace.DiagLLLaplace"><code class="flex name class">
<span>class <span class="ident">DiagLLLaplace</span></span>
<span>(</span><span>model, likelihood, sigma_noise=1.0, prior_precision=1.0, prior_mean=0.0, temperature=1.0, backend=laplace.curvature.backpack.BackPackGGN, last_layer_name=None, backend_kwargs=None)</span>
</code></dt>
<dd>
<div class="desc"><p>Last-layer Laplace approximation with diagonal log likelihood Hessian approximation
and hence posterior precision.
Mathematically, we have <span><span class="MathJax_Preview">P \approx \textrm{diag}(P)</span><script type="math/tex">P \approx \textrm{diag}(P)</script></span>.
See <code><a title="laplace.DiagLaplace" href="#laplace.DiagLaplace">DiagLaplace</a></code>, <code><a title="laplace.LLLaplace" href="#laplace.LLLaplace">LLLaplace</a></code>, and <code><a title="laplace.BaseLaplace" href="#laplace.BaseLaplace">BaseLaplace</a></code> for the full interface.</p></div>
<h3>Ancestors</h3>
<ul class="hlist">
<li>laplace.lllaplace.LLLaplace</li>
<li><a title="laplace.baselaplace.DiagLaplace" href="baselaplace.html#laplace.baselaplace.DiagLaplace">DiagLaplace</a></li>
<li><a title="laplace.baselaplace.BaseLaplace" href="baselaplace.html#laplace.baselaplace.BaseLaplace">BaseLaplace</a></li>
<li>abc.ABC</li>
</ul>
<h3>Inherited members</h3>
<ul class="hlist">
<li><code><b><a title="laplace.baselaplace.DiagLaplace" href="baselaplace.html#laplace.baselaplace.DiagLaplace">DiagLaplace</a></b></code>:
<ul class="hlist">
<li><code><a title="laplace.baselaplace.DiagLaplace.fit" href="baselaplace.html#laplace.baselaplace.BaseLaplace.fit">fit</a></code></li>
<li><code><a title="laplace.baselaplace.DiagLaplace.functional_variance" href="baselaplace.html#laplace.baselaplace.BaseLaplace.functional_variance">functional_variance</a></code></li>
<li><code><a title="laplace.baselaplace.DiagLaplace.log_det_posterior_precision" href="baselaplace.html#laplace.baselaplace.BaseLaplace.log_det_posterior_precision">log_det_posterior_precision</a></code></li>
<li><code><a title="laplace.baselaplace.DiagLaplace.log_det_prior_precision" href="baselaplace.html#laplace.baselaplace.BaseLaplace.log_det_prior_precision">log_det_prior_precision</a></code></li>
<li><code><a title="laplace.baselaplace.DiagLaplace.log_det_ratio" href="baselaplace.html#laplace.baselaplace.BaseLaplace.log_det_ratio">log_det_ratio</a></code></li>
<li><code><a title="laplace.baselaplace.DiagLaplace.log_likelihood" href="baselaplace.html#laplace.baselaplace.BaseLaplace.log_likelihood">log_likelihood</a></code></li>
<li><code><a title="laplace.baselaplace.DiagLaplace.log_marginal_likelihood" href="baselaplace.html#laplace.baselaplace.BaseLaplace.log_marginal_likelihood">log_marginal_likelihood</a></code></li>
<li><code><a title="laplace.baselaplace.DiagLaplace.optimize_prior_precision" href="baselaplace.html#laplace.baselaplace.BaseLaplace.optimize_prior_precision">optimize_prior_precision</a></code></li>
<li><code><a title="laplace.baselaplace.DiagLaplace.posterior_precision" href="baselaplace.html#laplace.baselaplace.DiagLaplace.posterior_precision">posterior_precision</a></code></li>
<li><code><a title="laplace.baselaplace.DiagLaplace.posterior_scale" href="baselaplace.html#laplace.baselaplace.DiagLaplace.posterior_scale">posterior_scale</a></code></li>
<li><code><a title="laplace.baselaplace.DiagLaplace.posterior_variance" href="baselaplace.html#laplace.baselaplace.DiagLaplace.posterior_variance">posterior_variance</a></code></li>
<li><code><a title="laplace.baselaplace.DiagLaplace.predictive_samples" href="baselaplace.html#laplace.baselaplace.BaseLaplace.predictive_samples">predictive_samples</a></code></li>
<li><code><a title="laplace.baselaplace.DiagLaplace.prior_precision_diag" href="baselaplace.html#laplace.baselaplace.BaseLaplace.prior_precision_diag">prior_precision_diag</a></code></li>
<li><code><a title="laplace.baselaplace.DiagLaplace.sample" href="baselaplace.html#laplace.baselaplace.BaseLaplace.sample">sample</a></code></li>
<li><code><a title="laplace.baselaplace.DiagLaplace.scatter" href="baselaplace.html#laplace.baselaplace.BaseLaplace.scatter">scatter</a></code></li>
</ul>
</li>
</ul>
</dd>
</dl>
</section>
</article>
<nav id="sidebar">
<h1>Index</h1>
<div class="toc">
<ul>
<li><a href="#setup">Setup</a></li>
<li><a href="#structure">Structure</a></li>
<li><a href="#extendability">Extendability</a></li>
<li><a href="#example-usage">Example usage</a><ul>
<li><a href="#post-hoc-prior-precision-tuning-of-last-layer-la">Post-hoc prior precision tuning of last-layer LA</a></li>
<li><a href="#differentiating-the-log-marginal-likelihood-wrt-hyperparameters">Differentiating the log marginal likelihood w.r.t. hyperparameters</a></li>
</ul>
</li>
<li><a href="#documentation">Documentation</a></li>
<li><a href="#references">References</a></li>
<li><a href="#full-example-post-hoc-optimization-of-the-marginal-likelihood-and-prediction">Full example: post-hoc optimization of the marginal likelihood and prediction</a><ul>
<li><a href="#sinusoidal-toy-data">Sinusoidal toy data</a></li>
<li><a href="#training-a-map">Training a MAP</a></li>
<li><a href="#fitting-and-optimizing-the-laplace-approximation-using-empirical-bayes">Fitting and optimizing the Laplace approximation using empirical Bayes</a></li>
<li><a href="#bayesian-predictive">Bayesian predictive</a></li>
</ul>
</li>
</ul>
</div>
<ul id="index">
<li><h3><a href="#header-submodules">Sub-modules</a></h3>
<ul>
<li><code><a title="laplace.baselaplace" href="baselaplace.html">laplace.baselaplace</a></code></li>
<li><code><a title="laplace.curvature" href="curvature/index.html">laplace.curvature</a></code></li>
<li><code><a title="laplace.feature_extractor" href="feature_extractor.html">laplace.feature_extractor</a></code></li>
<li><code><a title="laplace.laplace" href="laplace.html">laplace.laplace</a></code></li>
<li><code><a title="laplace.lllaplace" href="lllaplace.html">laplace.lllaplace</a></code></li>
<li><code><a title="laplace.matrix" href="matrix.html">laplace.matrix</a></code></li>
<li><code><a title="laplace.utils" href="utils.html">laplace.utils</a></code></li>
</ul>
</li>
<li><h3><a href="#header-functions">Functions</a></h3>
<ul class="">
<li><code><a title="laplace.Laplace" href="#laplace.Laplace">Laplace</a></code></li>
</ul>
</li>
<li><h3><a href="#header-classes">Classes</a></h3>
<ul>
<li>
<h4><code><a title="laplace.BaseLaplace" href="#laplace.BaseLaplace">BaseLaplace</a></code></h4>
<ul class="">
<li><code><a title="laplace.BaseLaplace.fit" href="#laplace.BaseLaplace.fit">fit</a></code></li>
<li><code><a title="laplace.BaseLaplace.log_marginal_likelihood" href="#laplace.BaseLaplace.log_marginal_likelihood">log_marginal_likelihood</a></code></li>
<li><code><a title="laplace.BaseLaplace.predictive" href="#laplace.BaseLaplace.predictive">predictive</a></code></li>
<li><code><a title="laplace.BaseLaplace.predictive_samples" href="#laplace.BaseLaplace.predictive_samples">predictive_samples</a></code></li>
<li><code><a title="laplace.BaseLaplace.functional_variance" href="#laplace.BaseLaplace.functional_variance">functional_variance</a></code></li>
<li><code><a title="laplace.BaseLaplace.sample" href="#laplace.BaseLaplace.sample">sample</a></code></li>
<li><code><a title="laplace.BaseLaplace.optimize_prior_precision" href="#laplace.BaseLaplace.optimize_prior_precision">optimize_prior_precision</a></code></li>
</ul>
</li>
<li>
<h4><code><a title="laplace.FullLaplace" href="#laplace.FullLaplace">FullLaplace</a></code></h4>
</li>
<li>
<h4><code><a title="laplace.KronLaplace" href="#laplace.KronLaplace">KronLaplace</a></code></h4>
</li>
<li>
<h4><code><a title="laplace.DiagLaplace" href="#laplace.DiagLaplace">DiagLaplace</a></code></h4>
</li>
<li>
<h4><code><a title="laplace.LLLaplace" href="#laplace.LLLaplace">LLLaplace</a></code></h4>
</li>
<li>
<h4><code><a title="laplace.FullLLLaplace" href="#laplace.FullLLLaplace">FullLLLaplace</a></code></h4>
</li>
<li>
<h4><code><a title="laplace.KronLLLaplace" href="#laplace.KronLLLaplace">KronLLLaplace</a></code></h4>
</li>
<li>
<h4><code><a title="laplace.DiagLLLaplace" href="#laplace.DiagLLLaplace">DiagLLLaplace</a></code></h4>
</li>
</ul>
</li>
</ul>
</nav>
</main>
<footer id="footer">
<p>Generated by <a href="https://pdoc3.github.io/pdoc"><cite>pdoc</cite> 0.9.2</a>.</p>
</footer>
</body>
</html>