<!doctype html>
<html lang="en">
<head>
<meta charset="utf-8">
<meta name="viewport" content="width=device-width, initial-scale=1, minimum-scale=1" />
<meta name="generator" content="pdoc 0.9.2" />
<title>laplace.baselaplace API documentation</title>
<meta name="description" content="" />
<link rel="preload stylesheet" as="style" href="https://cdnjs.cloudflare.com/ajax/libs/10up-sanitize.css/11.0.1/sanitize.min.css" integrity="sha256-PK9q560IAAa6WVRRh76LtCaI8pjTJ2z11v0miyNNjrs=" crossorigin>
<link rel="preload stylesheet" as="style" href="https://cdnjs.cloudflare.com/ajax/libs/10up-sanitize.css/11.0.1/typography.min.css" integrity="sha256-7l/o7C8jubJiy74VsKTidCy1yBkRtiUGbVkYBylBqUg=" crossorigin>
<link rel="stylesheet preload" as="style" href="https://cdnjs.cloudflare.com/ajax/libs/highlight.js/10.1.1/styles/github.min.css" crossorigin>
<style>:root{--highlight-color:#fe9}.flex{display:flex !important}body{line-height:1.5em}#content{padding:20px}#sidebar{padding:30px;overflow:hidden}#sidebar > *:last-child{margin-bottom:2cm}.http-server-breadcrumbs{font-size:130%;margin:0 0 15px 0}#footer{font-size:.75em;padding:5px 30px;border-top:1px solid #ddd;text-align:right}#footer p{margin:0 0 0 1em;display:inline-block}#footer p:last-child{margin-right:30px}h1,h2,h3,h4,h5{font-weight:300}h1{font-size:2.5em;line-height:1.1em}h2{font-size:1.75em;margin:1em 0 .50em 0}h3{font-size:1.4em;margin:25px 0 10px 0}h4{margin:0;font-size:105%}h1:target,h2:target,h3:target,h4:target,h5:target,h6:target{background:var(--highlight-color);padding:.2em 0}a{color:#058;text-decoration:none;transition:color .3s ease-in-out}a:hover{color:#e82}.title code{font-weight:bold}h2[id^="header-"]{margin-top:2em}.ident{color:#900}pre code{background:#f8f8f8;font-size:.8em;line-height:1.4em}code{background:#f2f2f1;padding:1px 4px;overflow-wrap:break-word}h1 code{background:transparent}pre{background:#f8f8f8;border:0;border-top:1px solid #ccc;border-bottom:1px solid #ccc;margin:1em 0;padding:1ex}#http-server-module-list{display:flex;flex-flow:column}#http-server-module-list div{display:flex}#http-server-module-list dt{min-width:10%}#http-server-module-list p{margin-top:0}.toc ul,#index{list-style-type:none;margin:0;padding:0}#index code{background:transparent}#index h3{border-bottom:1px solid #ddd}#index ul{padding:0}#index h4{margin-top:.6em;font-weight:bold}@media (min-width:200ex){#index .two-column{column-count:2}}@media (min-width:300ex){#index .two-column{column-count:3}}dl{margin-bottom:2em}dl dl:last-child{margin-bottom:4em}dd{margin:0 0 1em 3em}#header-classes + dl > dd{margin-bottom:3em}dd dd{margin-left:2em}dd p{margin:10px 0}.name{background:#eee;font-weight:bold;font-size:.85em;padding:5px 10px;display:inline-block;min-width:40%}.name:hover{background:#e0e0e0}dt:target .name{background:var(--highlight-color)}.name > 
span:first-child{white-space:nowrap}.name.class > span:nth-child(2){margin-left:.4em}.inherited{color:#999;border-left:5px solid #eee;padding-left:1em}.inheritance em{font-style:normal;font-weight:bold}.desc h2{font-weight:400;font-size:1.25em}.desc h3{font-size:1em}.desc dt code{background:inherit}.source summary,.git-link-div{color:#666;text-align:right;font-weight:400;font-size:.8em;text-transform:uppercase}.source summary > *{white-space:nowrap;cursor:pointer}.git-link{color:inherit;margin-left:1em}.source pre{max-height:500px;overflow:auto;margin:0}.source pre code{font-size:12px;overflow:visible}.hlist{list-style:none}.hlist li{display:inline}.hlist li:after{content:',\2002'}.hlist li:last-child:after{content:none}.hlist .hlist{display:inline;padding-left:1em}img{max-width:100%}td{padding:0 .5em}.admonition{padding:.1em .5em;margin-bottom:1em}.admonition-title{font-weight:bold}.admonition.note,.admonition.info,.admonition.important{background:#aef}.admonition.todo,.admonition.versionadded,.admonition.tip,.admonition.hint{background:#dfd}.admonition.warning,.admonition.versionchanged,.admonition.deprecated{background:#fd4}.admonition.error,.admonition.danger,.admonition.caution{background:lightpink}</style>
<style media="screen and (min-width: 700px)">@media screen and (min-width:700px){#sidebar{width:30%;height:100vh;overflow:auto;position:sticky;top:0}#content{width:70%;max-width:100ch;padding:3em 4em;border-left:1px solid #ddd}pre code{font-size:1em}.item .name{font-size:1em}main{display:flex;flex-direction:row-reverse;justify-content:flex-end}.toc ul ul,#index ul{padding-left:1.5em}.toc > ul > li{margin-top:.5em}}</style>
<style media="print">@media print{#sidebar h1{page-break-before:always}.source{display:none}}@media print{*{background:transparent !important;color:#000 !important;box-shadow:none !important;text-shadow:none !important}a[href]:after{content:" (" attr(href) ")";font-size:90%}a[href][title]:after{content:none}abbr[title]:after{content:" (" attr(title) ")"}.ir a:after,a[href^="javascript:"]:after,a[href^="#"]:after{content:""}pre,blockquote{border:1px solid #999;page-break-inside:avoid}thead{display:table-header-group}tr,img{page-break-inside:avoid}img{max-width:100% !important}@page{margin:0.5cm}p,h2,h3{orphans:3;widows:3}h1,h2,h3,h4,h5,h6{page-break-after:avoid}}</style>
<script async src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.7/latest.js?config=TeX-AMS_CHTML" integrity="sha256-kZafAc6mZvK3W3v1pHOcUix30OHQN6pU/NO2oFkqZVw=" crossorigin></script>
<script defer src="https://cdnjs.cloudflare.com/ajax/libs/highlight.js/10.1.1/highlight.min.js" integrity="sha256-Uv3H6lx7dJmRfRvH8TH6kJD1TSK1aFcwgx+mdg3epi8=" crossorigin></script>
<script>window.addEventListener('DOMContentLoaded', () => hljs.initHighlighting())</script>
</head>
<body>
<main>
<article id="content">
<header>
<h1 class="title">Module <code>laplace.baselaplace</code></h1>
</header>
<section id="section-intro">
</section>
<section>
</section>
<section>
</section>
<section>
</section>
<section>
<h2 class="section-title" id="header-classes">Classes</h2>
<dl>
<dt id="laplace.baselaplace.BaseLaplace"><code class="flex name class">
<span>class <span class="ident">BaseLaplace</span></span>
<span>(</span><span>model, likelihood, sigma_noise=1.0, prior_precision=1.0, prior_mean=0.0, temperature=1.0, backend=laplace.curvature.backpack.BackPackGGN, backend_kwargs=None)</span>
</code></dt>
<dd>
<div class="desc"><p>Baseclass for all Laplace approximations in this library.
Subclasses need to specify how the Hessian approximation is initialized,
how to add up curvature over training data, how to sample from the
Laplace approximation, and how to compute the functional variance.</p>
<p>A Laplace approximation is represented by a MAP which is given by the
<code>model</code> parameter and a posterior precision or covariance specifying
a Gaussian distribution <span><span class="MathJax_Preview">\mathcal{N}(\theta_{MAP}, P^{-1})</span><script type="math/tex">\mathcal{N}(\theta_{MAP}, P^{-1})</script></span>.
The goal of this class is to compute the posterior precision <span><span class="MathJax_Preview">P</span><script type="math/tex">P</script></span>
which decomposes as
<span><span class="MathJax_Preview">
P = \sum_{n=1}^N \nabla^2_\theta \log p(\mathcal{D}_n \mid \theta)
\vert_{\theta_{MAP}} + \nabla^2_\theta \log p(\theta) \vert_{\theta_{MAP}}.
</span><script type="math/tex; mode=display">
P = \sum_{n=1}^N \nabla^2_\theta \log p(\mathcal{D}_n \mid \theta)
\vert_{\theta_{MAP}} + \nabla^2_\theta \log p(\theta) \vert_{\theta_{MAP}}.
</script></span>
Every subclass implements different approximations to the log likelihood Hessians,
for example, a diagonal one. The prior is assumed to be Gaussian and therefore we have
a simple form for <span><span class="MathJax_Preview">\nabla^2_\theta \log p(\theta) \vert_{\theta_{MAP}} = P_0 </span><script type="math/tex">\nabla^2_\theta \log p(\theta) \vert_{\theta_{MAP}} = P_0 </script></span>.
In particular, we assume a scalar, layer-wise, or diagonal prior precision so that in
all cases <span><span class="MathJax_Preview">P_0 = \textrm{diag}(p_0)</span><script type="math/tex">P_0 = \textrm{diag}(p_0)</script></span> and the structure of <span><span class="MathJax_Preview">p_0</span><script type="math/tex">p_0</script></span> can be varied.</p>
<h2 id="parameters">Parameters</h2>
<dl>
<dt><strong><code>model</code></strong> :&ensp;<code>torch.nn.Module</code></dt>
<dd>&nbsp;</dd>
<dt><strong><code>likelihood</code></strong> :&ensp;<code>{'classification', 'regression'}</code></dt>
<dd>determines the log likelihood Hessian approximation</dd>
<dt><strong><code>sigma_noise</code></strong> :&ensp;<code>torch.Tensor</code> or <code>float</code>, default=<code>1</code></dt>
<dd>observation noise for the regression setting; must be 1 for classification</dd>
<dt><strong><code>prior_precision</code></strong> :&ensp;<code>torch.Tensor</code> or <code>float</code>, default=<code>1</code></dt>
<dd>prior precision of a Gaussian prior (= weight decay);
can be scalar, per-layer, or diagonal in the most general case</dd>
<dt><strong><code>prior_mean</code></strong> :&ensp;<code>torch.Tensor</code> or <code>float</code>, default=<code>0</code></dt>
<dd>prior mean of a Gaussian prior, useful for continual learning</dd>
<dt><strong><code>temperature</code></strong> :&ensp;<code>float</code>, default=<code>1</code></dt>
<dd>temperature of the likelihood; lower temperature leads to more
concentrated posterior and vice versa.</dd>
<dt><strong><code>backend</code></strong> :&ensp;<code>subclasses</code> of <code><a title="laplace.curvature.CurvatureInterface" href="curvature/index.html#laplace.curvature.CurvatureInterface">CurvatureInterface</a></code></dt>
<dd>backend for access to curvature/Hessian approximations</dd>
<dt><strong><code>backend_kwargs</code></strong> :&ensp;<code>dict</code>, default=<code>None</code></dt>
<dd>arguments passed to the backend on initialization, for example to
set the number of MC samples for stochastic approximations.</dd>
</dl></div>
<h3>Ancestors</h3>
<ul class="hlist">
<li>abc.ABC</li>
</ul>
<h3>Subclasses</h3>
<ul class="hlist">
<li><a title="laplace.baselaplace.DiagLaplace" href="#laplace.baselaplace.DiagLaplace">DiagLaplace</a></li>
<li><a title="laplace.baselaplace.FullLaplace" href="#laplace.baselaplace.FullLaplace">FullLaplace</a></li>
<li><a title="laplace.baselaplace.KronLaplace" href="#laplace.baselaplace.KronLaplace">KronLaplace</a></li>
<li>laplace.lllaplace.LLLaplace</li>
</ul>
<h3>Instance variables</h3>
<dl>
<dt id="laplace.baselaplace.BaseLaplace.backend"><code class="name">var <span class="ident">backend</span></code></dt>
<dd>
<div class="desc"></div>
</dd>
<dt id="laplace.baselaplace.BaseLaplace.log_likelihood"><code class="name">var <span class="ident">log_likelihood</span></code></dt>
<dd>
<div class="desc"><p>Compute log likelihood on the training data after <code>.fit()</code> has been called.
The log likelihood is computed on-demand based on the loss and, for example,
the observation noise which makes it differentiable in the latter for
iterative updates.</p>
<h2 id="returns">Returns</h2>
<dl>
<dt><strong><code>log_likelihood</code></strong> :&ensp;<code>torch.Tensor</code></dt>
<dd>&nbsp;</dd>
</dl></div>
</dd>
<dt id="laplace.baselaplace.BaseLaplace.scatter"><code class="name">var <span class="ident">scatter</span></code></dt>
<dd>
<div class="desc"><p>Computes the <em>scatter</em>, a term of the log marginal likelihood that
corresponds to L2 regularization:
<code>scatter</code> = <span><span class="MathJax_Preview">(\theta_{MAP} - \mu_0)^{T} P_0 (\theta_{MAP} - \mu_0) </span><script type="math/tex">(\theta_{MAP} - \mu_0)^{T} P_0 (\theta_{MAP} - \mu_0) </script></span>.</p>
<h2 id="returns">Returns</h2>
<dl>
<dt><strong><code>scatter</code></strong> :&ensp;<code>torch.Tensor</code></dt>
<dd>&nbsp;</dd>
</dl></div>
</dd>
<dt id="laplace.baselaplace.BaseLaplace.log_det_prior_precision"><code class="name">var <span class="ident">log_det_prior_precision</span></code></dt>
<dd>
<div class="desc"><p>Compute log determinant of the prior precision
<span><span class="MathJax_Preview">\log \det P_0</span><script type="math/tex">\log \det P_0</script></span></p>
<h2 id="returns">Returns</h2>
<dl>
<dt><strong><code>log_det</code></strong> :&ensp;<code>torch.Tensor</code></dt>
<dd>&nbsp;</dd>
</dl></div>
</dd>
<dt id="laplace.baselaplace.BaseLaplace.log_det_posterior_precision"><code class="name">var <span class="ident">log_det_posterior_precision</span></code></dt>
<dd>
<div class="desc"><p>Compute log determinant of the posterior precision
<span><span class="MathJax_Preview">\log \det P</span><script type="math/tex">\log \det P</script></span>, which depends on the structure
each subclass uses for the Hessian approximation.</p>
<h2 id="returns">Returns</h2>
<dl>
<dt><strong><code>log_det</code></strong> :&ensp;<code>torch.Tensor</code></dt>
<dd>&nbsp;</dd>
</dl></div>
</dd>
<dt id="laplace.baselaplace.BaseLaplace.log_det_ratio"><code class="name">var <span class="ident">log_det_ratio</span></code></dt>
<dd>
<div class="desc"><p>Compute the log determinant ratio, a part of the log marginal likelihood.
<span><span class="MathJax_Preview">
\log \frac{\det P}{\det P_0} = \log \det P - \log \det P_0
</span><script type="math/tex; mode=display">
\log \frac{\det P}{\det P_0} = \log \det P - \log \det P_0
</script></span></p>
<h2 id="returns">Returns</h2>
<dl>
<dt><strong><code>log_det_ratio</code></strong> :&ensp;<code>torch.Tensor</code></dt>
<dd>&nbsp;</dd>
</dl></div>
</dd>
<dt id="laplace.baselaplace.BaseLaplace.prior_precision_diag"><code class="name">var <span class="ident">prior_precision_diag</span></code></dt>
<dd>
<div class="desc"><p>Obtain the diagonal prior precision <span><span class="MathJax_Preview">p_0</span><script type="math/tex">p_0</script></span> constructed from either
a scalar, layer-wise, or diagonal prior precision.</p>
<h2 id="returns">Returns</h2>
<dl>
<dt><strong><code>prior_precision_diag</code></strong> :&ensp;<code>torch.Tensor</code></dt>
<dd>&nbsp;</dd>
</dl></div>
</dd>
<dt id="laplace.baselaplace.BaseLaplace.prior_mean"><code class="name">var <span class="ident">prior_mean</span></code></dt>
<dd>
<div class="desc"></div>
</dd>
<dt id="laplace.baselaplace.BaseLaplace.prior_precision"><code class="name">var <span class="ident">prior_precision</span></code></dt>
<dd>
<div class="desc"></div>
</dd>
<dt id="laplace.baselaplace.BaseLaplace.sigma_noise"><code class="name">var <span class="ident">sigma_noise</span></code></dt>
<dd>
<div class="desc"></div>
</dd>
<dt id="laplace.baselaplace.BaseLaplace.posterior_precision"><code class="name">var <span class="ident">posterior_precision</span></code></dt>
<dd>
<div class="desc"><p>Compute or return the posterior precision <span><span class="MathJax_Preview">P</span><script type="math/tex">P</script></span>.</p>
<h2 id="returns">Returns</h2>
<dl>
<dt><strong><code>posterior_prec</code></strong> :&ensp;<code>torch.Tensor</code></dt>
<dd>&nbsp;</dd>
</dl></div>
</dd>
</dl>
<h3>Methods</h3>
<dl>
<dt id="laplace.baselaplace.BaseLaplace.fit"><code class="name flex">
<span>def <span class="ident">fit</span></span>(<span>self, train_loader)</span>
</code></dt>
<dd>
<div class="desc"><p>Fit the local Laplace approximation at the parameters of the model.</p>
<h2 id="parameters">Parameters</h2>
<dl>
<dt><strong><code>train_loader</code></strong> :&ensp;<code>torch.utils.data.DataLoader</code></dt>
<dd>each iterate is a training batch (X, y);
<code>train_loader.dataset</code> needs to be set to access <span><span class="MathJax_Preview">N</span><script type="math/tex">N</script></span>, size of the data set</dd>
</dl></div>
</dd>
<dt id="laplace.baselaplace.BaseLaplace.log_marginal_likelihood"><code class="name flex">
<span>def <span class="ident">log_marginal_likelihood</span></span>(<span>self, prior_precision=None, sigma_noise=None)</span>
</code></dt>
<dd>
<div class="desc"><p>Compute the Laplace approximation to the log marginal likelihood subject
to specific Hessian approximations that subclasses implement.
Requires that the Laplace approximation has been fit before.
The resulting torch.Tensor is differentiable in <code>prior_precision</code> and
<code>sigma_noise</code> if these have gradients enabled.
By passing <code>prior_precision</code> or <code>sigma_noise</code>, the current value is
overwritten. This is useful for iterating on the log marginal likelihood.</p>
<h2 id="parameters">Parameters</h2>
<dl>
<dt><strong><code>prior_precision</code></strong> :&ensp;<code>torch.Tensor</code>, optional</dt>
<dd>prior precision, if it should be changed from the current <code>prior_precision</code> value</dd>
<dt><strong><code>sigma_noise</code></strong> :&ensp;<code>torch.Tensor</code>, optional</dt>
<dd>observation noise standard deviation, if it should be changed from the current <code>sigma_noise</code> value</dd>
</dl>
<h2 id="returns">Returns</h2>
<dl>
<dt><strong><code>log_marglik</code></strong> :&ensp;<code>torch.Tensor</code></dt>
<dd>&nbsp;</dd>
</dl></div>
</dd>
<dt id="laplace.baselaplace.BaseLaplace.predictive"><code class="name flex">
<span>def <span class="ident">predictive</span></span>(<span>self, x, pred_type='glm', link_approx='mc', n_samples=100)</span>
</code></dt>
<dd>
<div class="desc"></div>
</dd>
<dt id="laplace.baselaplace.BaseLaplace.predictive_samples"><code class="name flex">
<span>def <span class="ident">predictive_samples</span></span>(<span>self, x, pred_type='glm', n_samples=100)</span>
</code></dt>
<dd>
<div class="desc"><p>Sample from the posterior predictive on input data <code>x</code>.
Can be used, for example, for Thompson sampling.</p>
<h2 id="parameters">Parameters</h2>
<dl>
<dt><strong><code>x</code></strong> :&ensp;<code>torch.Tensor</code></dt>
<dd>input data <code>(batch_size, input_shape)</code></dd>
<dt><strong><code>pred_type</code></strong> :&ensp;<code>{'glm', 'nn'}</code>, default=<code>'glm'</code></dt>
<dd>type of posterior predictive, linearized GLM predictive or neural
network sampling predictive. The GLM predictive is consistent with
the curvature approximations used here.</dd>
<dt><strong><code>n_samples</code></strong> :&ensp;<code>int</code></dt>
<dd>number of samples</dd>
</dl>
<h2 id="returns">Returns</h2>
<dl>
<dt><strong><code>samples</code></strong> :&ensp;<code>torch.Tensor</code></dt>
<dd>samples <code>(n_samples, batch_size, output_shape)</code></dd>
</dl></div>
</dd>
<dt id="laplace.baselaplace.BaseLaplace.functional_variance"><code class="name flex">
<span>def <span class="ident">functional_variance</span></span>(<span>self, Jacs)</span>
</code></dt>
<dd>
<div class="desc"><p>Compute functional variance for the <code>'glm'</code> predictive:
<code>f_var[i] = Jacs[i] @ P.inv() @ Jacs[i].T</code>, which is an output x output
predictive covariance matrix.
Mathematically, we have for a single Jacobian
<span><span class="MathJax_Preview">\mathcal{J} = \nabla_\theta f(x;\theta)\vert_{\theta_{MAP}}</span><script type="math/tex">\mathcal{J} = \nabla_\theta f(x;\theta)\vert_{\theta_{MAP}}</script></span>
the output covariance matrix
<span><span class="MathJax_Preview"> \mathcal{J} P^{-1} \mathcal{J}^T </span><script type="math/tex"> \mathcal{J} P^{-1} \mathcal{J}^T </script></span>.</p>
<h2 id="parameters">Parameters</h2>
<dl>
<dt><strong><code>Jacs</code></strong> :&ensp;<code>torch.Tensor</code></dt>
<dd>Jacobians of model output wrt parameters
<code>(batch, outputs, parameters)</code></dd>
</dl>
<h2 id="returns">Returns</h2>
<dl>
<dt><strong><code>f_var</code></strong> :&ensp;<code>torch.Tensor</code></dt>
<dd>output covariance <code>(batch, outputs, outputs)</code></dd>
</dl></div>
</dd>
<dt id="laplace.baselaplace.BaseLaplace.sample"><code class="name flex">
<span>def <span class="ident">sample</span></span>(<span>self, n_samples=100)</span>
</code></dt>
<dd>
<div class="desc"><p>Sample from the Laplace posterior approximation, i.e.,
<span><span class="MathJax_Preview"> \theta \sim \mathcal{N}(\theta_{MAP}, P^{-1})</span><script type="math/tex"> \theta \sim \mathcal{N}(\theta_{MAP}, P^{-1})</script></span>.</p>
<h2 id="parameters">Parameters</h2>
<dl>
<dt><strong><code>n_samples</code></strong> :&ensp;<code>int</code>, default=<code>100</code></dt>
<dd>number of samples</dd>
</dl></div>
</dd>
<dt id="laplace.baselaplace.BaseLaplace.optimize_prior_precision"><code class="name flex">
<span>def <span class="ident">optimize_prior_precision</span></span>(<span>self, method='marglik', n_steps=100, lr=0.1, init_prior_prec=1.0, val_loader=None, loss=&lt;function get_nll&gt;, log_prior_prec_min=-4, log_prior_prec_max=4, grid_size=100, pred_type='glm', link_approx='probit', n_samples=100, verbose=False)</span>
</code></dt>
<dd>
<div class="desc"><p>Optimize the prior precision post-hoc using the <code>method</code>
specified by the user.</p>
<h2 id="parameters">Parameters</h2>
<dl>
<dt><strong><code>method</code></strong> :&ensp;<code>{'marglik', 'CV'}</code>, default=<code>'marglik'</code></dt>
<dd>specifies how the prior precision should be optimized.</dd>
<dt><strong><code>n_steps</code></strong> :&ensp;<code>int</code>, default=<code>100</code></dt>
<dd>the number of gradient descent steps to take.</dd>
<dt><strong><code>lr</code></strong> :&ensp;<code>float</code>, default=<code>1e-1</code></dt>
<dd>the learning rate to use for gradient descent.</dd>
<dt><strong><code>init_prior_prec</code></strong> :&ensp;<code>float</code>, default=<code>1.0</code></dt>
<dd>initial prior precision before the first optimization step.</dd>
<dt><strong><code>val_loader</code></strong> :&ensp;<code>torch.utils.data.DataLoader</code>, default=<code>None</code></dt>
<dd>DataLoader for the validation set; each iterate is a validation batch (X, y).</dd>
<dt><strong><code>loss</code></strong> :&ensp;<code>callable</code>, default=<code>get_nll</code></dt>
<dd>loss function to use for CV.</dd>
<dt><strong><code>log_prior_prec_min</code></strong> :&ensp;<code>float</code>, default=<code>-4</code></dt>
<dd>lower bound of gridsearch interval for CV.</dd>
<dt><strong><code>log_prior_prec_max</code></strong> :&ensp;<code>float</code>, default=<code>4</code></dt>
<dd>upper bound of gridsearch interval for CV.</dd>
<dt><strong><code>grid_size</code></strong> :&ensp;<code>int</code>, default=<code>100</code></dt>
<dd>number of values to consider inside the gridsearch interval for CV.</dd>
<dt><strong><code>pred_type</code></strong> :&ensp;<code>{'glm', 'nn'}</code>, default=<code>'glm'</code></dt>
<dd>type of posterior predictive, linearized GLM predictive or neural
network sampling predictive. The GLM predictive is consistent with
the curvature approximations used here.</dd>
<dt><strong><code>link_approx</code></strong> :&ensp;<code>{'mc', 'probit', 'bridge'}</code>, default=<code>'probit'</code></dt>
<dd>how to approximate the classification link function for the <code>'glm'</code>.
For <code>pred_type='nn'</code>, only <code>'mc'</code> is possible.</dd>
<dt><strong><code>n_samples</code></strong> :&ensp;<code>int</code>, default=<code>100</code></dt>
<dd>number of samples for <code>link_approx='mc'</code>.</dd>
<dt><strong><code>verbose</code></strong> :&ensp;<code>bool</code>, default=<code>False</code></dt>
<dd>if true, the optimized prior precision will be printed
(can be a large tensor if the prior has a diagonal covariance).</dd>
</dl></div>
</dd>
</dl>
</dd>
<dt id="laplace.baselaplace.FullLaplace"><code class="flex name class">
<span>class <span class="ident">FullLaplace</span></span>
<span>(</span><span>model, likelihood, sigma_noise=1.0, prior_precision=1.0, prior_mean=0.0, temperature=1.0, backend=laplace.curvature.backpack.BackPackGGN, backend_kwargs=None)</span>
</code></dt>
<dd>
<div class="desc"><p>Laplace approximation with full, i.e., dense, log likelihood Hessian approximation
and hence posterior precision. Based on the chosen <code>backend</code> parameter, the full
approximation can be, for example, a generalized Gauss-Newton matrix.
Mathematically, we have <span><span class="MathJax_Preview">P \in \mathbb{R}^{P \times P}</span><script type="math/tex">P \in \mathbb{R}^{P \times P}</script></span>.
See <code><a title="laplace.baselaplace.BaseLaplace" href="#laplace.baselaplace.BaseLaplace">BaseLaplace</a></code> for the full interface.</p></div>
<h3>Ancestors</h3>
<ul class="hlist">
<li><a title="laplace.baselaplace.BaseLaplace" href="#laplace.baselaplace.BaseLaplace">BaseLaplace</a></li>
<li>abc.ABC</li>
</ul>
<h3>Subclasses</h3>
<ul class="hlist">
<li><a title="laplace.lllaplace.FullLLLaplace" href="lllaplace.html#laplace.lllaplace.FullLLLaplace">FullLLLaplace</a></li>
</ul>
<h3>Instance variables</h3>
<dl>
<dt id="laplace.baselaplace.FullLaplace.posterior_scale"><code class="name">var <span class="ident">posterior_scale</span></code></dt>
<dd>
<div class="desc"><p>Posterior scale (square root of the covariance), i.e.,
<span><span class="MathJax_Preview">P^{-\frac{1}{2}}</span><script type="math/tex">P^{-\frac{1}{2}}</script></span>.</p>
<h2 id="returns">Returns</h2>
<dl>
<dt><strong><code>scale</code></strong> :&ensp;<code>torch.Tensor</code></dt>
<dd><code>(parameters, parameters)</code></dd>
</dl></div>
</dd>
<dt id="laplace.baselaplace.FullLaplace.posterior_covariance"><code class="name">var <span class="ident">posterior_covariance</span></code></dt>
<dd>
<div class="desc"><p>Posterior covariance, i.e., <span><span class="MathJax_Preview">P^{-1}</span><script type="math/tex">P^{-1}</script></span>.</p>
<h2 id="returns">Returns</h2>
<dl>
<dt><strong><code>covariance</code></strong> :&ensp;<code>torch.Tensor</code></dt>
<dd><code>(parameters, parameters)</code></dd>
</dl></div>
</dd>
<dt id="laplace.baselaplace.FullLaplace.posterior_precision"><code class="name">var <span class="ident">posterior_precision</span></code></dt>
<dd>
<div class="desc"><p>Posterior precision <span><span class="MathJax_Preview">P</span><script type="math/tex">P</script></span>.</p>
<h2 id="returns">Returns</h2>
<dl>
<dt><strong><code>precision</code></strong> :&ensp;<code>torch.Tensor</code></dt>
<dd><code>(parameters, parameters)</code></dd>
</dl></div>
</dd>
</dl>
<h3>Inherited members</h3>
<ul class="hlist">
<li><code><b><a title="laplace.baselaplace.BaseLaplace" href="#laplace.baselaplace.BaseLaplace">BaseLaplace</a></b></code>:
<ul class="hlist">
<li><code><a title="laplace.baselaplace.BaseLaplace.fit" href="#laplace.baselaplace.BaseLaplace.fit">fit</a></code></li>
<li><code><a title="laplace.baselaplace.BaseLaplace.functional_variance" href="#laplace.baselaplace.BaseLaplace.functional_variance">functional_variance</a></code></li>
<li><code><a title="laplace.baselaplace.BaseLaplace.log_det_posterior_precision" href="#laplace.baselaplace.BaseLaplace.log_det_posterior_precision">log_det_posterior_precision</a></code></li>
<li><code><a title="laplace.baselaplace.BaseLaplace.log_det_prior_precision" href="#laplace.baselaplace.BaseLaplace.log_det_prior_precision">log_det_prior_precision</a></code></li>
<li><code><a title="laplace.baselaplace.BaseLaplace.log_det_ratio" href="#laplace.baselaplace.BaseLaplace.log_det_ratio">log_det_ratio</a></code></li>
<li><code><a title="laplace.baselaplace.BaseLaplace.log_likelihood" href="#laplace.baselaplace.BaseLaplace.log_likelihood">log_likelihood</a></code></li>
<li><code><a title="laplace.baselaplace.BaseLaplace.log_marginal_likelihood" href="#laplace.baselaplace.BaseLaplace.log_marginal_likelihood">log_marginal_likelihood</a></code></li>
<li><code><a title="laplace.baselaplace.BaseLaplace.optimize_prior_precision" href="#laplace.baselaplace.BaseLaplace.optimize_prior_precision">optimize_prior_precision</a></code></li>
<li><code><a title="laplace.baselaplace.BaseLaplace.predictive_samples" href="#laplace.baselaplace.BaseLaplace.predictive_samples">predictive_samples</a></code></li>
<li><code><a title="laplace.baselaplace.BaseLaplace.prior_precision_diag" href="#laplace.baselaplace.BaseLaplace.prior_precision_diag">prior_precision_diag</a></code></li>
<li><code><a title="laplace.baselaplace.BaseLaplace.sample" href="#laplace.baselaplace.BaseLaplace.sample">sample</a></code></li>
<li><code><a title="laplace.baselaplace.BaseLaplace.scatter" href="#laplace.baselaplace.BaseLaplace.scatter">scatter</a></code></li>
</ul>
</li>
</ul>
</dd>
<dt id="laplace.baselaplace.KronLaplace"><code class="flex name class">
<span>class <span class="ident">KronLaplace</span></span>
<span>(</span><span>model, likelihood, sigma_noise=1.0, prior_precision=1.0, prior_mean=0.0, temperature=1.0, backend=laplace.curvature.backpack.BackPackGGN, damping=False, **backend_kwargs)</span>
</code></dt>
<dd>
<div class="desc"><p>Laplace approximation with Kronecker factored log likelihood Hessian approximation
and hence posterior precision.
Mathematically, we have for each parameter group, e.g., torch.nn.Module,
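that <span><span class="MathJax_Preview">P \approx Q \otimes H</span><script type="math/tex">P \approx Q \otimes H</script></span>.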
See <code><a title="laplace.baselaplace.BaseLaplace" href="#laplace.baselaplace.BaseLaplace">BaseLaplace</a></code> for the full interface and see
<code><a title="laplace.matrix.Kron" href="matrix.html#laplace.matrix.Kron">Kron</a></code> and <code><a title="laplace.matrix.KronDecomposed" href="matrix.html#laplace.matrix.KronDecomposed">KronDecomposed</a></code> for the structure of
the Kronecker factors. <code>Kron</code> is used to aggregate factors by summing up and
<code>KronDecomposed</code> is used to add the prior, a Hessian factor (e.g. temperature),
and computing posterior covariances, marginal likelihood, etc.
Damping can be enabled by setting <code>damping=True</code>.</p></div>
<h3>Ancestors</h3>
<ul class="hlist">
<li><a title="laplace.baselaplace.BaseLaplace" href="#laplace.baselaplace.BaseLaplace">BaseLaplace</a></li>
<li>abc.ABC</li>
</ul>
<h3>Subclasses</h3>
<ul class="hlist">
<li><a title="laplace.lllaplace.KronLLLaplace" href="lllaplace.html#laplace.lllaplace.KronLLLaplace">KronLLLaplace</a></li>
</ul>
<h3>Instance variables</h3>
<dl>
<dt id="laplace.baselaplace.KronLaplace.posterior_precision"><code class="name">var <span class="ident">posterior_precision</span></code></dt>
<dd>
<div class="desc"><p>Kronecker factored posterior precision <span><span class="MathJax_Preview">P</span><script type="math/tex">P</script></span>.</p>
<h2 id="returns">Returns</h2>
<dl>
<dt><strong><code>precision</code></strong> :&ensp;<code><a title="laplace.matrix.KronDecomposed" href="matrix.html#laplace.matrix.KronDecomposed">KronDecomposed</a></code></dt>
<dd>&nbsp;</dd>
</dl></div>
</dd>
<dt id="laplace.baselaplace.KronLaplace.prior_precision"><code class="name">var <span class="ident">prior_precision</span></code></dt>
<dd>
<div class="desc"></div>
</dd>
</dl>
<h3>Inherited members</h3>
<ul class="hlist">
<li><code><b><a title="laplace.baselaplace.BaseLaplace" href="#laplace.baselaplace.BaseLaplace">BaseLaplace</a></b></code>:
<ul class="hlist">
<li><code><a title="laplace.baselaplace.BaseLaplace.fit" href="#laplace.baselaplace.BaseLaplace.fit">fit</a></code></li>
<li><code><a title="laplace.baselaplace.BaseLaplace.functional_variance" href="#laplace.baselaplace.BaseLaplace.functional_variance">functional_variance</a></code></li>
<li><code><a title="laplace.baselaplace.BaseLaplace.log_det_posterior_precision" href="#laplace.baselaplace.BaseLaplace.log_det_posterior_precision">log_det_posterior_precision</a></code></li>
<li><code><a title="laplace.baselaplace.BaseLaplace.log_det_prior_precision" href="#laplace.baselaplace.BaseLaplace.log_det_prior_precision">log_det_prior_precision</a></code></li>
<li><code><a title="laplace.baselaplace.BaseLaplace.log_det_ratio" href="#laplace.baselaplace.BaseLaplace.log_det_ratio">log_det_ratio</a></code></li>
<li><code><a title="laplace.baselaplace.BaseLaplace.log_likelihood" href="#laplace.baselaplace.BaseLaplace.log_likelihood">log_likelihood</a></code></li>
<li><code><a title="laplace.baselaplace.BaseLaplace.log_marginal_likelihood" href="#laplace.baselaplace.BaseLaplace.log_marginal_likelihood">log_marginal_likelihood</a></code></li>
<li><code><a title="laplace.baselaplace.BaseLaplace.optimize_prior_precision" href="#laplace.baselaplace.BaseLaplace.optimize_prior_precision">optimize_prior_precision</a></code></li>
<li><code><a title="laplace.baselaplace.BaseLaplace.predictive_samples" href="#laplace.baselaplace.BaseLaplace.predictive_samples">predictive_samples</a></code></li>
<li><code><a title="laplace.baselaplace.BaseLaplace.prior_precision_diag" href="#laplace.baselaplace.BaseLaplace.prior_precision_diag">prior_precision_diag</a></code></li>
<li><code><a title="laplace.baselaplace.BaseLaplace.sample" href="#laplace.baselaplace.BaseLaplace.sample">sample</a></code></li>
<li><code><a title="laplace.baselaplace.BaseLaplace.scatter" href="#laplace.baselaplace.BaseLaplace.scatter">scatter</a></code></li>
</ul>
</li>
</ul>
</dd>
<!-- DiagLaplace class entry: signature, description, ancestors, subclasses,
     instance variables and inherited members. Each "Returns" heading carries
     a member-qualified id so that no id is repeated within this entry. -->
<dt id="laplace.baselaplace.DiagLaplace"><code class="flex name class">
<span>class <span class="ident">DiagLaplace</span></span>
<span>(</span><span>model, likelihood, sigma_noise=1.0, prior_precision=1.0, prior_mean=0.0, temperature=1.0, backend=laplace.curvature.backpack.BackPackGGN, backend_kwargs=None)</span>
</code></dt>
<dd>
<div class="desc"><p>Laplace approximation with diagonal log likelihood Hessian approximation
and hence posterior precision.
Mathematically, we have <span><span class="MathJax_Preview">P \approx \textrm{diag}(P)</span><script type="math/tex">P \approx \textrm{diag}(P)</script></span>.
See <code><a title="laplace.baselaplace.BaseLaplace" href="#laplace.baselaplace.BaseLaplace">BaseLaplace</a></code> for the full interface.</p></div>
<h3>Ancestors</h3>
<ul class="hlist">
<li><a title="laplace.baselaplace.BaseLaplace" href="#laplace.baselaplace.BaseLaplace">BaseLaplace</a></li>
<li>abc.ABC</li>
</ul>
<h3>Subclasses</h3>
<ul class="hlist">
<li><a title="laplace.lllaplace.DiagLLLaplace" href="lllaplace.html#laplace.lllaplace.DiagLLLaplace">DiagLLLaplace</a></li>
</ul>
<h3>Instance variables</h3>
<!-- One dt/dd pair per instance variable; the returned value is labelled
     after the quantity that the description names. -->
<dl>
<dt id="laplace.baselaplace.DiagLaplace.posterior_precision"><code class="name">var <span class="ident">posterior_precision</span></code></dt>
<dd>
<div class="desc"><p>Diagonal posterior precision <span><span class="MathJax_Preview">p</span><script type="math/tex">p</script></span>.</p>
<h2 id="laplace.baselaplace.DiagLaplace.posterior_precision-returns">Returns</h2>
<dl>
<dt><strong><code>precision</code></strong> :&ensp;<code>torch.tensor</code></dt>
<dd><code>(parameters)</code></dd>
</dl></div>
</dd>
<dt id="laplace.baselaplace.DiagLaplace.posterior_scale"><code class="name">var <span class="ident">posterior_scale</span></code></dt>
<dd>
<div class="desc"><p>Diagonal posterior scale <span><span class="MathJax_Preview">\sqrt{p^{-1}}</span><script type="math/tex">\sqrt{p^{-1}}</script></span>.</p>
<h2 id="laplace.baselaplace.DiagLaplace.posterior_scale-returns">Returns</h2>
<dl>
<dt><strong><code>scale</code></strong> :&ensp;<code>torch.tensor</code></dt>
<dd><code>(parameters)</code></dd>
</dl></div>
</dd>
<dt id="laplace.baselaplace.DiagLaplace.posterior_variance"><code class="name">var <span class="ident">posterior_variance</span></code></dt>
<dd>
<div class="desc"><p>Diagonal posterior variance <span><span class="MathJax_Preview">p^{-1}</span><script type="math/tex">p^{-1}</script></span>.</p>
<h2 id="laplace.baselaplace.DiagLaplace.posterior_variance-returns">Returns</h2>
<dl>
<dt><strong><code>variance</code></strong> :&ensp;<code>torch.tensor</code></dt>
<dd><code>(parameters)</code></dd>
</dl></div>
</dd>
</dl>
<h3>Inherited members</h3>
<!-- Members inherited from BaseLaplace; every link is a same-page fragment
     pointing at the corresponding BaseLaplace member id. -->
<ul class="hlist">
<li><code><b><a title="laplace.baselaplace.BaseLaplace" href="#laplace.baselaplace.BaseLaplace">BaseLaplace</a></b></code>:
<ul class="hlist">
<li><code><a title="laplace.baselaplace.BaseLaplace.fit" href="#laplace.baselaplace.BaseLaplace.fit">fit</a></code></li>
<li><code><a title="laplace.baselaplace.BaseLaplace.functional_variance" href="#laplace.baselaplace.BaseLaplace.functional_variance">functional_variance</a></code></li>
<li><code><a title="laplace.baselaplace.BaseLaplace.log_det_posterior_precision" href="#laplace.baselaplace.BaseLaplace.log_det_posterior_precision">log_det_posterior_precision</a></code></li>
<li><code><a title="laplace.baselaplace.BaseLaplace.log_det_prior_precision" href="#laplace.baselaplace.BaseLaplace.log_det_prior_precision">log_det_prior_precision</a></code></li>
<li><code><a title="laplace.baselaplace.BaseLaplace.log_det_ratio" href="#laplace.baselaplace.BaseLaplace.log_det_ratio">log_det_ratio</a></code></li>
<li><code><a title="laplace.baselaplace.BaseLaplace.log_likelihood" href="#laplace.baselaplace.BaseLaplace.log_likelihood">log_likelihood</a></code></li>
<li><code><a title="laplace.baselaplace.BaseLaplace.log_marginal_likelihood" href="#laplace.baselaplace.BaseLaplace.log_marginal_likelihood">log_marginal_likelihood</a></code></li>
<li><code><a title="laplace.baselaplace.BaseLaplace.optimize_prior_precision" href="#laplace.baselaplace.BaseLaplace.optimize_prior_precision">optimize_prior_precision</a></code></li>
<li><code><a title="laplace.baselaplace.BaseLaplace.predictive_samples" href="#laplace.baselaplace.BaseLaplace.predictive_samples">predictive_samples</a></code></li>
<li><code><a title="laplace.baselaplace.BaseLaplace.prior_precision_diag" href="#laplace.baselaplace.BaseLaplace.prior_precision_diag">prior_precision_diag</a></code></li>
<li><code><a title="laplace.baselaplace.BaseLaplace.sample" href="#laplace.baselaplace.BaseLaplace.sample">sample</a></code></li>
<li><code><a title="laplace.baselaplace.BaseLaplace.scatter" href="#laplace.baselaplace.BaseLaplace.scatter">scatter</a></code></li>
</ul>
</li>
</ul>
</dd>
</dl>
</section>
</article>
<!-- Sidebar navigation: a link to the parent package followed by an index of
     the classes documented on this page. -->
<nav id="sidebar">
  <h1>Index</h1>
  <!-- Table-of-contents container; its list has no entries on this page. -->
  <div class="toc">
    <ul></ul>
  </div>
  <ul id="index">
    <li><h3>Super-module</h3>
      <ul>
        <li><code><a title="laplace" href="index.html">laplace</a></code></li>
      </ul>
    </li>
    <li><h3><a href="#header-classes">Classes</a></h3>
      <ul>
        <li>
          <h4><code><a title="laplace.baselaplace.BaseLaplace" href="#laplace.baselaplace.BaseLaplace">BaseLaplace</a></code></h4>
          <!-- Member links for BaseLaplace; the other classes below are
               listed by name only. -->
          <ul>
            <li><code><a title="laplace.baselaplace.BaseLaplace.fit" href="#laplace.baselaplace.BaseLaplace.fit">fit</a></code></li>
            <li><code><a title="laplace.baselaplace.BaseLaplace.log_marginal_likelihood" href="#laplace.baselaplace.BaseLaplace.log_marginal_likelihood">log_marginal_likelihood</a></code></li>
            <li><code><a title="laplace.baselaplace.BaseLaplace.predictive" href="#laplace.baselaplace.BaseLaplace.predictive">predictive</a></code></li>
            <li><code><a title="laplace.baselaplace.BaseLaplace.predictive_samples" href="#laplace.baselaplace.BaseLaplace.predictive_samples">predictive_samples</a></code></li>
            <li><code><a title="laplace.baselaplace.BaseLaplace.functional_variance" href="#laplace.baselaplace.BaseLaplace.functional_variance">functional_variance</a></code></li>
            <li><code><a title="laplace.baselaplace.BaseLaplace.sample" href="#laplace.baselaplace.BaseLaplace.sample">sample</a></code></li>
            <li><code><a title="laplace.baselaplace.BaseLaplace.optimize_prior_precision" href="#laplace.baselaplace.BaseLaplace.optimize_prior_precision">optimize_prior_precision</a></code></li>
          </ul>
        </li>
        <li>
          <h4><code><a title="laplace.baselaplace.FullLaplace" href="#laplace.baselaplace.FullLaplace">FullLaplace</a></code></h4>
        </li>
        <li>
          <h4><code><a title="laplace.baselaplace.KronLaplace" href="#laplace.baselaplace.KronLaplace">KronLaplace</a></code></h4>
        </li>
        <li>
          <h4><code><a title="laplace.baselaplace.DiagLaplace" href="#laplace.baselaplace.DiagLaplace">DiagLaplace</a></code></h4>
        </li>
      </ul>
    </li>
  </ul>
</nav>
</main>
<!-- Page footer: credits the documentation generator and links to its site. -->
<footer id="footer">
  <p>
    Generated by
    <a href="https://pdoc3.github.io/pdoc"><cite>pdoc</cite> 0.9.2</a>.
  </p>
</footer>
</body>
</html>