FishLeg package

Submodules

FishLeg.fishleg module

class FishLeg.fishleg.FishLeg(model: Module, draw: Callable[[Module, Tensor], Tuple[Tensor, Tensor]], nll: Callable[[Module, Tuple[Tensor, Tensor]], Tensor], dataloader: Callable[[], Tuple[Tensor, Tensor]], lr: float = 0.01, eps: float = 0.0001, weight_decay: float = 1e-05, beta: float = 0.9, update_aux_every: int = -3, aux_lr: float = 0.001, aux_betas: Tuple[float, float] = (0.9, 0.999), aux_eps: float = 1e-08, damping: float = 1e-05, pre_aux_training: int = 10, differentiable: bool = False, sgd_lr: float = 0.01)

Bases: Optimizer

Implements the FishLeg algorithm.

Parameters
  • model (torch.nn.Module) – a PyTorch neural network module, which can be nested in a tree structure

  • draw (Callable[[nn.Module, torch.Tensor], Tuple[torch.Tensor, torch.Tensor]]) – Sampling function that takes a model \(f\) and input data \(\mathbf X\), and returns \((\mathbf X, \mathbf y)\), where \(\mathbf y\) is sampled from the conditional distribution \(p(\mathbf y|f(\mathbf X))\)

  • nll (Callable[[nn.Module, Tuple[torch.Tensor, torch.Tensor]], torch.Tensor]) – A function that takes a model and data, and evaluates the negative log-likelihood.

  • dataloader (Callable[[int], Tuple[torch.Tensor, torch.Tensor]]) – A function that takes a batch size as input and outputs a dataset of the corresponding size.

  • lr (float) – learning rate for the parameters of the input model that are optimized with FishLeg (default: 1e-2)

  • eps (float) – a small scalar used to evaluate the auxiliary loss along the direction of the gradient of the model parameters (default: 1e-4)

  • update_aux_every (int) – number of iterations after which an auxiliary update is executed; if negative, -update_aux_every auxiliary updates are run in each outer iteration (default: -3)

  • aux_lr (float) – learning rate for the auxiliary parameters, using Adam (default: 1e-3)

  • aux_betas (Tuple[float, float]) – coefficients used for computing running averages of gradient and its square for auxiliary parameters (default: (0.9, 0.999))

  • aux_eps (float) – term added to the denominator to improve numerical stability for auxiliary parameters (default: 1e-8)

  • pre_aux_training (int) – number of auxiliary updates to make before any update of the original parameters. This is intended to approximate the correct Fisher Information matrix at initialization, which is especially important when fine-tuning pretrained models (default: 10)

  • differentiable (bool) – whether the fused implementation (CUDA only) is used

  • sgd_lr (float) –

    helps specify the initial scale, \(\eta\), of the inverse Fisher Information matrix approximation. The scale is chosen so that

    \[- \eta_{init} Q(\lambda) grad = - \eta_{sgd} grad\]

    holds at the beginning of the optimization, where \(\eta_{init}=\eta_{sgd}/\eta_{fl}\) (default: 1e-2)
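The interplay of update_aux_every, lr and sgd_lr can be illustrated with a small pure-Python sketch. The helpers aux_updates_at and initial_scale are hypothetical (not part of FishLeg). The sketch assumes that a positive update_aux_every fires on iterations divisible by it (including iteration 0) and reads \(\eta_{fl}\) as the FishLeg learning rate lr; both are assumptions, not confirmed behavior.

```python
def aux_updates_at(iteration: int, update_aux_every: int) -> int:
    """Hypothetical helper: number of auxiliary updates run at outer `iteration`.

    Positive update_aux_every: one update every `update_aux_every` iterations
    (assumed to fire when iteration % update_aux_every == 0).
    Negative update_aux_every: -update_aux_every updates in every iteration.
    """
    if update_aux_every == 0:
        raise ValueError("update_aux_every must be non-zero")
    if update_aux_every < 0:
        return -update_aux_every
    return 1 if iteration % update_aux_every == 0 else 0


def initial_scale(sgd_lr: float, lr: float) -> float:
    """Hypothetical helper: eta_init = eta_sgd / eta_fl, assuming eta_fl is `lr`."""
    return sgd_lr / lr


# With the defaults (update_aux_every=-3, lr=1e-2, sgd_lr=1e-2):
print(aux_updates_at(7, -3))      # 3 auxiliary updates per outer iteration
print(initial_scale(1e-2, 1e-2))  # 1.0
```

With the default values, the initial scale is 1, so the first FishLeg steps start at the same size as SGD steps with learning rate sgd_lr (under the stated reading of \(\eta_{fl}\)).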

Example:
>>> import torch
>>> import torch.nn as nn
>>> from FishLeg.fishleg import FishLeg
>>> from FishLeg.fishleg_likelihood import FixedGaussianLikelihood
>>>
>>> auxloader = torch.utils.data.DataLoader(train_data, shuffle=True, batch_size=100)
>>> trainloader = torch.utils.data.DataLoader(train_data, shuffle=True, batch_size=100)
>>>
>>> likelihood = FixedGaussianLikelihood(sigma_fixed=1.0)
>>>
>>> def nll(model, data):
...     data_x, data_y = data
...     pred_y = model(data_x)
...     return likelihood.nll(data_y, pred_y)
>>>
>>> def draw(model, data_x):
...     pred_y = model(data_x)
...     return (data_x, likelihood.draw(pred_y))
>>>
>>> def dataloader():
...     data_x, _ = next(iter(auxloader))
...     return data_x
>>>
>>> model = nn.Sequential(
...     nn.Linear(2, 5),
...     nn.ReLU(),
...     nn.Linear(5, 1),
... )
>>>
>>> opt = FishLeg(
...     model,
...     draw,
...     nll,
...     dataloader,
... )
>>>
>>> for iteration in range(100):
...     data_x, data_y = next(iter(trainloader))
...     opt.zero_grad()
...     pred_y = model(data_x)
...     # nn.MSELoss expects (input, target)
...     loss = nn.MSELoss()(pred_y, data_y)
...     loss.backward()
...     opt.step()
...     if iteration % 10 == 0:
...         print(loss.detach())

init_model_aux(model: Module) Module

Given a model to optimize, its parameters can be divided into

  1. those fixed as pre-trained;

  2. those to be optimized using FishLeg.

This method replaces the modules in the second group with FishLeg modules.

Args:
model (torch.nn.Module, required):

A model containing modules to be replaced with FishLeg modules, which carry extra functionality related to the FishLeg algorithm.

Returns:

torch.nn.Module: the model with its modules replaced.

step() None

Performs a single optimization step of FishLeg.

update_aux() None

Performs a single auxiliary parameter update using Adam, by minimizing the following objective:

\[nll(model, \theta + \epsilon Q(\lambda)g) + nll(model, \theta - \epsilon Q(\lambda)g) - 2\epsilon^2g^T Q(\lambda)g\]

where \(\theta\) denotes the model parameters, \(\lambda\) the auxiliary parameters, and \(g\) the gradient of the model parameters.
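To see why minimizing this objective pushes \(Q(\lambda)\) toward an inverse curvature, consider a toy quadratic \(nll(\theta) = \tfrac12\theta^T A\theta\). Then \(nll(\theta+d)+nll(\theta-d) = 2\,nll(\theta) + d^TAd\) exactly, so the objective equals \(2\,nll(\theta) + \epsilon^2 g^TQAQg - 2\epsilon^2 g^TQg\); writing \(u = Qg\), this is minimized when \(u = A^{-1}g\), for example by \(Q = A^{-1}\). The pure-Python sketch below evaluates the objective as written above; the quadratic, the matrices and the name aux_objective are illustrative assumptions, not FishLeg code.

```python
def matvec(M, v):
    """Matrix-vector product for lists of lists."""
    return [sum(m * x for m, x in zip(row, v)) for row in M]


def dot(u, v):
    return sum(a * b for a, b in zip(u, v))


def aux_objective(nll, theta, Q, g, eps):
    """nll(theta + eps*Q g) + nll(theta - eps*Q g) - 2*eps**2 * g^T Q g."""
    d = [eps * x for x in matvec(Q, g)]  # step eps * Q(lambda) g
    plus = [t + x for t, x in zip(theta, d)]
    minus = [t - x for t, x in zip(theta, d)]
    return nll(plus) + nll(minus) - 2 * eps**2 * dot(g, matvec(Q, g))


# Toy quadratic nll(theta) = 0.5 * theta^T A theta with A = diag(2, 4) (illustrative).
A = [[2.0, 0.0], [0.0, 4.0]]
A_inv = [[0.5, 0.0], [0.0, 0.25]]
identity = [[1.0, 0.0], [0.0, 1.0]]


def toy_nll(theta):
    return 0.5 * dot(theta, matvec(A, theta))


theta = [1.0, 1.0]
g = matvec(A, theta)  # gradient of toy_nll at theta: [2.0, 4.0]
eps = 0.1

print(aux_objective(toy_nll, theta, identity, g, eps))  # about 6.32
print(aux_objective(toy_nll, theta, A_inv, g, eps))     # about 5.94, lower
```

The objective only involves evaluations of nll, so no Hessian or Fisher matrix is ever formed; the curvature enters through the two perturbed evaluations.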

FishLeg.fishleg.update_dict(replace: Module, module: Module) Module

FishLeg.fishleg_layers module

class FishLeg.fishleg_layers.FishLinear(in_features: int, out_features: int, bias: bool = True, init_scale: float = 1.0, device=None, dtype=None)

Bases: Linear, FishModule

Qv(v: Tuple[Tensor, Tensor]) Tuple[Tensor, Tensor]

For fully-connected layers, the default structure of \(Q\) is a block-diagonal matrix whose blocks are

\[Q_l = (R_lR_l^T \otimes L_lL_l^T)\]

where \(l\) indexes the layer. The matrix \(R_l\) has size \((N_{l-1} + 1) \times (N_{l-1} + 1)\), while the matrix \(L_l\) has size \(N_l \times N_l\). The auxiliary parameters \(\lambda\) are represented by the matrices \(L_l, R_l\).
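Applying a Kronecker-factored block never requires forming \(Q_l\) explicitly: by the identity \((B \otimes A)\,\mathrm{vec}(X) = \mathrm{vec}(A X B^T)\) with column-stacking vec, and since \(R_lR_l^T\) is symmetric, \(Q_l\,\mathrm{vec}(W) = \mathrm{vec}(L_lL_l^T\, W\, R_lR_l^T)\). The pure-Python sketch below checks this on tiny matrices. It assumes the bias is appended to the weight as an extra last column (so \(W\) is \(N_l \times (N_{l-1}+1)\)); the function names are hypothetical and this is not FishLinear's actual implementation.

```python
def transpose(A):
    return [list(col) for col in zip(*A)]


def matmul(A, B):
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*B)] for row in A]


def kron(A, B):
    """Kronecker product A (x) B for lists of lists."""
    return [[a * b for a in row_a for b in row_b] for row_a in A for row_b in B]


def vec(X):
    """Column-stacking vectorization."""
    return [X[i][j] for j in range(len(X[0])) for i in range(len(X))]


def kron_qv(L, R, weight, bias):
    """Hypothetical Qv for one linear layer: returns (weight, bias) parts of Q_l v."""
    W = [row + [b] for row, b in zip(weight, bias)]  # assumed layout: bias as last column
    out = matmul(matmul(matmul(L, transpose(L)), W), matmul(R, transpose(R)))
    return [row[:-1] for row in out], [row[-1] for row in out]


def dense_qv(L, R, weight, bias):
    """Same product via the explicit (R R^T) (x) (L L^T) matrix, for checking only."""
    W = [row + [b] for row, b in zip(weight, bias)]
    Q = kron(matmul(R, transpose(R)), matmul(L, transpose(L)))
    return [sum(q * x for q, x in zip(row, vec(W))) for row in Q]


L = [[1.0, 0.0], [2.0, 1.0]]                              # N_l = 2
R = [[1.0, 0.0, 0.0], [1.0, 1.0, 0.0], [0.0, 2.0, 1.0]]   # N_{l-1} + 1 = 3
weight, bias = [[1.0, 2.0], [4.0, 5.0]], [3.0, 6.0]

w_out, b_out = kron_qv(L, R, weight, bias)
full = [wr + [b] for wr, b in zip(w_out, b_out)]
print(vec(full) == dense_qv(L, R, weight, bias))  # True
```

The factored route costs two small matrix products per layer instead of a product with an \(N_l(N_{l-1}+1)\)-dimensional square matrix, which is the usual motivation for Kronecker-factored structure.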

in_features: int
out_features: int
weight: Tensor
class FishLeg.fishleg_layers.FishModule(*args: Any, **kwargs: Any)

Bases: Module

Base class for all neural network modules in FishLeg. It is used to

  1. initialize the auxiliary parameters, \(\lambda\), and the form of \(Q(\lambda)\);

  2. specify efficient computation of \(Q(\lambda)v\) products.

Parameters
  • fishleg_aux (torch.nn.ParameterDict) –

    auxiliary parameters with their initialization, including an additional parameter, the scale \(\eta\), chosen so that

    \[- \eta_{init} Q(\lambda) grad = - \eta_{sgd} grad\]

    holds at the beginning of the optimization

  • order (List) – the order of the names of the original parameters

abstract Qv(aux: Dict, v: Tuple[Tensor, ...]) Tuple[Tensor, ...]

\(Q(\lambda)\) is a positive definite matrix which will effectively estimate the inverse damped Fisher Information Matrix. Appropriate choices for \(Q\) should take into account the architecture of the model/module. It is usually parameterized as a positive definite Kronecker-factored block-diagonal matrix, with block sizes reflecting the layer structure of the neural networks.

Args:
aux (Dict, required): auxiliary parameters, \(\lambda\), given as a dictionary whose keys are the names of the auxiliary parameters and whose values are the auxiliary parameters of the module. These auxiliary parameters form \(Q(\lambda)\).

v (Tuple[Tensor, …], required): values of the original parameters, in an order that aligns with self.order, to multiply with \(Q(\lambda)\).

Returns:
Tuple[Tensor, …]: the computed \(Q(\lambda)v\) products, in the same order as self.order.
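The Qv contract (auxiliary parameters in, products returned in the order of self.order) can be sketched without PyTorch. The class below is hypothetical: it uses plain lists instead of tensors and a diagonal \(Q(\lambda) = \eta\,\mathrm{diag}(d^2)\), which is positive definite whenever \(\eta > 0\) and no entry of \(d\) is zero. It is not one of FishLeg's Kronecker-factored modules.

```python
from typing import Any, Dict, List, Sequence, Tuple


class DiagonalQSketch:
    """Hypothetical module with a diagonal Q(lambda) = scale * diag(d**2)."""

    def __init__(self, order: List[str]):
        # Names of the original parameters, e.g. ["weight", "bias"].
        self.order = order

    def Qv(self, aux: Dict[str, Any], v: Tuple[Sequence[float], ...]) -> Tuple[List[float], ...]:
        scale = aux["scale"]  # the extra scale parameter eta
        out = []
        for name, v_p in zip(self.order, v):
            d = aux[name]  # auxiliary parameters attached to this original parameter
            # Squaring d keeps each diagonal entry non-negative.
            out.append([scale * d_i * d_i * v_i for d_i, v_i in zip(d, v_p)])
        return tuple(out)  # same order as self.order


module = DiagonalQSketch(order=["weight", "bias"])
aux = {"scale": 2.0, "weight": [1.0, 2.0], "bias": [3.0]}
print(module.Qv(aux, ([1.0, 1.0], [1.0])))  # ([2.0, 8.0], [18.0])
```

Parameterizing through squares (as the Kronecker form does through \(LL^T\) and \(RR^T\)) keeps \(Q(\lambda)\) positive semi-definite for any value of the auxiliary parameters, so Adam can update them without constraints.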

cuda(device) None

Moves all model parameters and buffers to the GPU.

This also makes associated parameters and buffers different objects, so it should be called before constructing the optimizer if the module will live on the GPU while being optimized.

Note

This method modifies the module in-place.

Args:
device (int, optional): if specified, all parameters will be copied to that device

Returns:

Module: self

property name: str
training: bool
FishLeg.fishleg_layers.get_zero_grad_hook(mask)

FishLeg.fishleg_likelihood module

class FishLeg.fishleg_likelihood.BernoulliLikelihood

Bases: FishLikelihood

The Bernoulli likelihood used for classification. Using the standard Normal CDF \(\Phi(x)\) and the identity \(\Phi(-x) = 1-\Phi(x)\), we can write the likelihood for labels \(y \in \{-1, +1\}\) as:

\[p(y|f(x))=\Phi(yf(x))\]
draw(preds: Tensor) Tensor

Draw samples from the conditional distribution \(p(\mathbf y|f(\mathbf x))\)

Parameters

preds (torch.Tensor) – Predictions from model \(f(\mathbf x)\)

nll(observations: Tensor, preds: Tensor) Tensor

Computes the negative log-likelihood \(\ell(\theta, \mathcal D)=-\log p(\mathbf y|f(\mathbf x))\)

Parameters
  • observations (torch.Tensor) – Values of \(y\).

  • preds (torch.Tensor) – Predictions from model \(f(\mathbf x)\)

Return type

torch.Tensor
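A pure-Python sketch of this probit likelihood follows. It assumes labels in \(\{-1, +1\}\), predictions given as plain floats rather than tensors, and a negative log-likelihood summed over examples; the function names are hypothetical, and FishLeg's own BernoulliLikelihood works on torch.Tensor.

```python
import math
import random


def normal_cdf(x: float) -> float:
    """Standard Normal CDF Phi(x)."""
    return 0.5 * (1.0 + math.erf(x / math.sqrt(2.0)))


def bernoulli_nll(observations, preds):
    """Summed -log Phi(y f(x)) for labels y in {-1, +1} (summing is an assumption)."""
    return -sum(math.log(normal_cdf(y * f)) for y, f in zip(observations, preds))


def bernoulli_draw(preds, rng=random):
    """Sample y = +1 with probability Phi(f(x)), otherwise y = -1."""
    return [1 if rng.random() < normal_cdf(f) else -1 for f in preds]


print(bernoulli_nll([1], [0.0]))       # log 2, about 0.6931
print(bernoulli_draw([10.0, -10.0]))   # [1, -1]
```

For large negative \(yf(x)\), \(\Phi\) underflows to zero and the logarithm fails; a production version would use a dedicated log-CDF routine instead.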

class FishLeg.fishleg_likelihood.FishLikelihood

Bases: object

A likelihood in FishLeg specifies a probabilistic model describing the mapping from latent function values \(f(\mathbf X)\) to observed labels \(y\).

For example, for regression a Gaussian likelihood can be chosen:

\[y(\mathbf x) = f(\mathbf x) + \epsilon, \:\:\:\: \epsilon \sim N(0,\sigma^{2}_{n} \mathbf I)\]

For classification, a Bernoulli distribution can be chosen:

\[\begin{split}y(\mathbf x) = \begin{cases} 1 & \text{w/ probability} \:\: \sigma(f(\mathbf x)) \\ 0 & \text{w/ probability} \:\: 1-\sigma(f(\mathbf x)) \end{cases}\end{split}\]
abstract draw(preds, **kwargs)

Draw samples from the conditional distribution \(p(\mathbf y|f(\mathbf x))\)

Parameters

preds (torch.Tensor) – Predictions from model \(f(\mathbf x)\)

abstract nll(observations, preds, **kwargs)

Computes the negative log-likelihood \(\ell(\theta, \mathcal D)=-\log p(\mathbf y|f(\mathbf x))\)

Parameters
  • observations (torch.Tensor) – Values of \(y\).

  • preds (torch.Tensor) – Predictions from model \(f(\mathbf x)\)

Return type

torch.Tensor
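The two abstract methods can be illustrated with the sigmoid Bernoulli case from the class description. The classes below are hypothetical stand-ins that mirror the FishLikelihood interface on plain lists of floats (not tensors); they assume labels in \(\{0, 1\}\) and a negative log-likelihood summed over examples.

```python
import math
import random
from abc import ABC, abstractmethod


class LikelihoodSketch(ABC):
    """Hypothetical stand-in mirroring the FishLikelihood interface on plain lists."""

    @abstractmethod
    def draw(self, preds, **kwargs):
        """Sample y ~ p(y | f(x)) for each prediction."""

    @abstractmethod
    def nll(self, observations, preds, **kwargs):
        """Return -log p(y | f(x)) for the given observations."""


class LogisticBernoulliSketch(LikelihoodSketch):
    """Bernoulli likelihood with sigmoid link and labels in {0, 1}."""

    def draw(self, preds, rng=random):
        return [1 if rng.random() < 1.0 / (1.0 + math.exp(-f)) else 0 for f in preds]

    def nll(self, observations, preds):
        # -log sigma(f) = log(1 + exp(-f)); -log(1 - sigma(f)) = log(1 + exp(f))
        return sum(math.log1p(math.exp(-f)) if y == 1 else math.log1p(math.exp(f))
                   for y, f in zip(observations, preds))


lik = LogisticBernoulliSketch()
print(lik.nll([1, 0], [0.0, 0.0]))  # 2 * log 2, about 1.3863
```

Keeping draw and nll together in one object is what lets the optimizer sample labels from the model's own predictive distribution (through draw) and score them (through nll) with a consistent probabilistic model.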

class FishLeg.fishleg_likelihood.FixedGaussianLikelihood(sigma)

Bases: FishLikelihood

The standard likelihood for regression, assuming fixed, known heteroscedastic noise:

\[y = f(x) + \epsilon, \:\:\:\: \epsilon \sim N(0,\sigma^{2})\]
Parameters

sigma_fixed (torch.Tensor) – Known observation standard deviation for each example.

draw(preds: Tensor) Tensor

Draw samples from the conditional distribution \(p(\mathbf y|f(\mathbf x))\)

Parameters

preds (torch.Tensor) – Predictions from model \(f(\mathbf x)\)

property get_variance
nll(observations: Tensor, preds: Tensor) Tensor

Computes the negative log-likelihood \(\ell(\theta, \mathcal D)=-\log p(\mathbf y|f(\mathbf x))\)

Parameters
  • observations (torch.Tensor) – Values of \(y\).

  • preds (torch.Tensor) – Predictions from model \(f(\mathbf x)\)

Return type

torch.Tensor
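A pure-Python sketch of the two methods, assuming sigma is a list of known per-example standard deviations, plain floats instead of tensors, and a summed (not averaged) negative log-likelihood; the function names are hypothetical.

```python
import math
import random


def gaussian_nll(observations, preds, sigma):
    """Summed Gaussian NLL with known per-example standard deviations `sigma`."""
    return sum(
        0.5 * math.log(2.0 * math.pi * s * s) + (y - f) ** 2 / (2.0 * s * s)
        for y, f, s in zip(observations, preds, sigma)
    )


def gaussian_draw(preds, sigma, rng=random):
    """Sample y = f(x) + eps with eps ~ N(0, sigma^2), one sigma per example."""
    return [rng.gauss(f, s) for f, s in zip(preds, sigma)]


print(gaussian_nll([1.0], [1.0], [1.0]))  # 0.5 * log(2 * pi), about 0.9189
```

Because sigma is fixed, the log term is a constant and the gradient of this NLL with respect to the predictions is the residual scaled by \(1/\sigma^2\), i.e. a weighted squared-error loss.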

class FishLeg.fishleg_likelihood.SoftMaxLikelihood

Bases: FishLikelihood

draw(preds: Tensor) Tensor

Draw samples from the conditional distribution \(p(\mathbf y|f(\mathbf x))\)

Parameters

preds (torch.Tensor) – Predictions from model \(f(\mathbf x)\)

nll(observations: Tensor, preds: Tensor) Tensor

Computes the negative log-likelihood \(\ell(\theta, \mathcal D)=-\log p(\mathbf y|f(\mathbf x))\)

Parameters
  • observations (torch.Tensor) – Values of \(y\).

  • preds (torch.Tensor) – Predictions from model \(f(\mathbf x)\)

Return type

torch.Tensor
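SoftMaxLikelihood carries no description here. Assuming it is the usual categorical likelihood with a softmax link over unnormalized logits and integer class labels (an assumption), its two methods can be sketched in pure Python; the function names are hypothetical, inputs are plain lists, and the NLL is summed over examples.

```python
import math
import random


def log_softmax(logits):
    """Numerically stable log-softmax via the log-sum-exp trick."""
    m = max(logits)
    lse = m + math.log(sum(math.exp(z - m) for z in logits))
    return [z - lse for z in logits]


def softmax_nll(observations, preds):
    """Summed -log softmax(f(x))[y] for integer class labels y."""
    return -sum(log_softmax(logits)[y] for y, logits in zip(observations, preds))


def softmax_draw(preds, rng=random):
    """Sample one class index per example from softmax(f(x))."""
    samples = []
    for logits in preds:
        probs = [math.exp(lp) for lp in log_softmax(logits)]
        samples.append(rng.choices(range(len(probs)), weights=probs)[0])
    return samples


print(softmax_nll([0], [[0.0, 0.0]]))  # log 2, about 0.6931
```

Subtracting the maximum logit before exponentiating keeps the computation finite even for very large logits.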

FishLeg.utils module

FishLeg.utils.recursive_getattr(obj, attr)
FishLeg.utils.recursive_setattr(obj, attr, value)
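These utilities are listed without descriptions. Judging by their names, they presumably get and set attributes along a dotted path, such as the nested submodule names yielded by torch.nn.Module.named_modules(). The following is a sketch of that common pattern under that assumption, not FishLeg's source; plain SimpleNamespace objects stand in for modules.

```python
import functools
from types import SimpleNamespace


def recursive_getattr(obj, attr):
    """Follow a dotted path such as "block.linear.weight" with repeated getattr."""
    return functools.reduce(getattr, attr.split("."), obj)


def recursive_setattr(obj, attr, value):
    """Set the attribute at the end of a dotted path on its parent object."""
    parent_path, _, name = attr.rpartition(".")
    parent = recursive_getattr(obj, parent_path) if parent_path else obj
    setattr(parent, name, value)


model = SimpleNamespace(block=SimpleNamespace(linear=SimpleNamespace(weight=1.0)))
recursive_setattr(model, "block.linear.weight", 2.0)
print(recursive_getattr(model, "block.linear.weight"))  # 2.0
```

Helpers of this kind are what a module-replacement routine such as init_model_aux would need in order to swap a nested submodule in place by its dotted name.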

Module contents