FishLeg package
Submodules
FishLeg.fishleg module
- class FishLeg.fishleg.FishLeg(model: Module, draw: Callable[[Module, Tensor], Tuple[Tensor, Tensor]], nll: Callable[[Module, Tuple[Tensor, Tensor]], Tensor], dataloader: Callable[[], Tuple[Tensor, Tensor]], lr: float = 0.01, eps: float = 0.0001, weight_decay: float = 1e-05, beta: float = 0.9, update_aux_every: int = -3, aux_lr: float = 0.001, aux_betas: Tuple[float, float] = (0.9, 0.999), aux_eps: float = 1e-08, damping: float = 1e-05, pre_aux_training: int = 10, differentiable: bool = False, sgd_lr: float = 0.01)
Bases:
Optimizer
Implements the FishLeg algorithm.
- Parameters
model (torch.nn.Module) – a pytorch neural network module, can be nested in a tree structure
draw (Callable[[nn.Module, torch.Tensor], Tuple[torch.Tensor, torch.Tensor]]) – Sampling function that takes a model \(f\) and input data \(\mathbf X\), and returns \((\mathbf X, \mathbf y)\), where \(\mathbf y\) is sampled from the conditional distribution \(p(\mathbf y|f(\mathbf X))\)
nll (Callable[[nn.Module, Tuple[torch.Tensor, torch.Tensor]], torch.Tensor]) – A function that takes a model and data, and evaluates the negative log-likelihood.
dataloader (Callable[[int], Tuple[torch.Tensor, torch.Tensor]]) – A function that takes a batch size as input and outputs a dataset of the corresponding size.
lr (float) – learning rate for the parameters of the input model, optimized using FishLeg (default: 1e-2)
eps (float) – a small scalar used to evaluate the auxiliary loss in the direction of the gradient of the model parameters (default: 1e-4)
update_aux_every (int) – number of iterations after which an auxiliary update is executed; if negative, -update_aux_every auxiliary updates are run in each outer iteration. (default: -3)
aux_lr (float) – learning rate for the auxiliary parameters, using Adam (default: 1e-3)
aux_betas (Tuple[float, float]) – coefficients used for computing running averages of gradient and its square for auxiliary parameters (default: (0.9, 0.999))
aux_eps (float) – term added to the denominator to improve numerical stability for auxiliary parameters (default: 1e-8)
pre_aux_training (int) – number of auxiliary updates to make before any update of the original parameters. This process is intended to approximate the correct Fisher Information matrix during initialization, which is especially important for fine-tuning pretrained models
differentiable (bool) – whether the fused implementation (CUDA only) is used
sgd_lr (float) – helps specify the initial scale of the inverse Fisher Information matrix approximation, \(\eta\) (see the sketch below). Make sure that
\[- \eta_{init} Q(\lambda) grad = - \eta_{sgd} grad\]
holds at the beginning of the optimization, where \(\eta_{init}=\eta_{sgd}/\eta_{fl}\).
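For a concrete sense of this relation, here is a tiny illustration; the variable names are purely for exposition and are not optimizer attributes:
>>> sgd_lr = 0.01              # eta_sgd: learning rate a plain SGD run would use
>>> fishleg_lr = 0.01          # eta_fl: the `lr` passed to FishLeg
>>> init_scale = sgd_lr / fishleg_lr   # eta_init = eta_sgd / eta_fl
>>> init_scale
1.0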
- Example:
>>> import torch
>>> import torch.nn as nn
>>> from FishLeg.fishleg import FishLeg
>>> from FishLeg.fishleg_likelihood import FixedGaussianLikelihood
>>>
>>> auxloader = torch.utils.data.DataLoader(train_data, shuffle=True, batch_size=100)
>>> trainloader = torch.utils.data.DataLoader(train_data, shuffle=True, batch_size=100)
>>>
>>> likelihood = FixedGaussianLikelihood(sigma_fixed=1.0)
>>>
>>> def nll(model, data):
>>>     data_x, data_y = data
>>>     pred_y = model.forward(data_x)
>>>     return likelihood.nll(data_y, pred_y)
>>>
>>> def draw(model, data_x):
>>>     pred_y = model.forward(data_x)
>>>     return (data_x, likelihood.draw(pred_y))
>>>
>>> def dataloader():
>>>     data_x, _ = next(iter(auxloader))
>>>     return data_x
>>>
>>> model = nn.Sequential(
>>>     nn.Linear(2, 5),
>>>     nn.ReLU(),
>>>     nn.Linear(5, 1),
>>> )
>>>
>>> opt = FishLeg(
>>>     model,
>>>     draw,
>>>     nll,
>>>     dataloader
>>> )
>>>
>>> for iteration in range(100):
>>>     data_x, data_y = next(iter(trainloader))
>>>     opt.zero_grad()
>>>     pred_y = model(data_x)
>>>     loss = nn.MSELoss()(data_y, pred_y)
>>>     loss.backward()
>>>     opt.step()
>>>     if iteration % 10 == 0:
>>>         print(loss.detach())
- init_model_aux(model: Module) Module
Given a model to optimize, its parameters can be divided into:
those kept fixed as pre-trained;
those to be optimized using FishLeg.
Modules in the second group are replaced with FishLeg modules.
- Args:
- model (torch.nn.Module, required): A model containing modules to be replaced with FishLeg modules that add extra functionality related to the FishLeg algorithm.
- Returns:
torch.nn.Module, the replaced model.
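A minimal sketch of the kind of replacement this performs, assuming for illustration that only nn.Linear layers are swapped for FishLinear (the helper name is hypothetical and the real implementation may handle more module types):
>>> import torch.nn as nn
>>> from FishLeg.fishleg_layers import FishLinear
>>> from FishLeg.utils import recursive_setattr
>>>
>>> def replace_linear_layers(model):
>>>     # Hypothetical helper: swap every plain nn.Linear for a FishLinear layer
>>>     # carrying the auxiliary parameters that define Q(lambda).
>>>     for name, module in list(model.named_modules()):
>>>         if name and isinstance(module, nn.Linear) and not isinstance(module, FishLinear):
>>>             fish = FishLinear(module.in_features, module.out_features,
>>>                               bias=module.bias is not None)
>>>             fish.weight.data.copy_(module.weight.data)
>>>             if module.bias is not None:
>>>                 fish.bias.data.copy_(module.bias.data)
>>>             recursive_setattr(model, name, fish)
>>>     return model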
- step() None
Performs a single optimization step of FishLeg.
- update_aux() None
Performs a single auxiliary parameter update using Adam, by minimizing the following objective:
\[nll(model, \theta + \epsilon Q(\lambda)g) + nll(model, \theta - \epsilon Q(\lambda)g) - 2\epsilon^2g^T Q(\lambda)g\]
where \(\theta\) are the parameters of the model and \(\lambda\) are the auxiliary parameters.
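A rough sketch of this objective for a single flattened parameter vector; here nll_fn(theta) stands for evaluating the model's negative log-likelihood on the sampled data with its parameters set to theta, and Qv is a placeholder for the optimizer's internal Q(lambda) product (none of these names are part of the public API):
>>> import torch
>>>
>>> def aux_objective(nll_fn, theta, g, Qv, eps=1e-4):
>>>     # theta: flattened model parameters;  g: gradient of the nll at theta
>>>     # Qv(v): computes Q(lambda) @ v;      eps: the optimizer's `eps` argument
>>>     qg = Qv(g)
>>>     return (nll_fn(theta + eps * qg)
>>>             + nll_fn(theta - eps * qg)
>>>             - 2 * eps ** 2 * torch.dot(g, qg))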
- FishLeg.fishleg.update_dict(replace: Module, module: Module) Module
FishLeg.fishleg_layers module
- class FishLeg.fishleg_layers.FishLinear(in_features: int, out_features: int, bias: bool = True, init_scale: float = 1.0, device=None, dtype=None)
Bases:
Linear, FishModule
- Qv(v: Tuple[Tensor, Tensor]) Tuple[Tensor, Tensor]
For fully-connected layers, the default structure of \(Q\) as a block-diagonal matrix is
\[Q_l = (R_lR_l^T \otimes L_lL_l^T)\]
where \(l\) denotes the l-th layer. The matrix \(R_l\) has size \((N_{l-1} + 1) \times (N_{l-1} + 1)\) while the matrix \(L_l\) has size \(N_l \times N_l\). The auxiliary parameters \(\lambda\) are represented by the matrices \(L_l, R_l\). See the sketch below the class attributes.
- in_features: int
- out_features: int
- weight: Tensor
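Using the Kronecker identity \((A \otimes B)\,\mathrm{vec}(V) = \mathrm{vec}(B V A^\top)\), the product \(Q_l v\) can be formed without materialising \(Q_l\). A minimal sketch, assuming the weight and bias are concatenated into a single \(N_l \times (N_{l-1}+1)\) block (the exact shapes and vectorisation convention used by FishLinear may differ):
>>> import torch
>>>
>>> def kron_factored_Qv(L, R, V):
>>>     # Q_l = (R R^T) kron (L L^T) applied to vec(V) equals  L L^T V R R^T
>>>     return (L @ (L.T @ V)) @ (R @ R.T)
>>>
>>> N_out, N_in = 5, 3
>>> L = torch.eye(N_out)              # left factor, size N_l x N_l
>>> R = torch.eye(N_in + 1)           # right factor, size (N_{l-1}+1) x (N_{l-1}+1)
>>> V = torch.randn(N_out, N_in + 1)  # weight and bias concatenated
>>> kron_factored_Qv(L, R, V).shape
torch.Size([5, 4])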
- class FishLeg.fishleg_layers.FishModule(*args: Any, **kwargs: Any)
Bases:
Module
Base class for all neural network modules in FishLeg to
Initialize auxiliary parameters, \(\lambda\) and its forms, \(Q(\lambda)\).
Specify quick calculation of \(Q(\lambda)v\) products.
- Parameters
fishleg_aux (torch.nn.ParameterDict) –
auxiliary parameters with their initialization, including an additional parameter, scale, \(\eta\). Make sure that
\[- \eta_{init} Q(\lambda) grad = - \eta_{sgd} grad\]
holds at the beginning of the optimization
order (List) – specifies the name order of the original parameters
- abstract Qv(aux: Dict, v: Tuple[Tensor, ...]) Tuple[Tensor, ...]
\(Q(\lambda)\) is a positive definite matrix which will effectively estimate the inverse damped Fisher Information Matrix. Appropriate choices for \(Q\) should take into account the architecture of the model/module. It is usually parameterized as a positive definite Kronecker-factored block-diagonal matrix, with block sizes reflecting the layer structure of the neural networks.
- Args:
- aux: (Dict, required): auxiliary parameters,
\(\lambda\); a dictionary whose keys are the names of the auxiliary parameters and whose values are the auxiliary parameters of the module. These auxiliary parameters form \(Q(\lambda)\).
- v: (Tuple[Tensor, …], required): Values of the original parameters,
in an order that aligns with self.order, to multiply with \(Q(\lambda)\).
- Returns:
- Tuple[Tensor, …]: The calculated \(Q(\lambda)v\) products,
in the same order as self.order.
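As a minimal illustration of this contract, consider a hypothetical module with a simple diagonal \(Q(\lambda)\); real FishLeg layers use richer Kronecker-factored forms, and the exact constructor requirements of FishModule may differ:
>>> import torch
>>> import torch.nn as nn
>>> from FishLeg.fishleg_layers import FishModule
>>>
>>> class DiagonalFishBias(FishModule):
>>>     # Hypothetical module whose Q(lambda) is a simple diagonal matrix.
>>>     def __init__(self, n):
>>>         super().__init__()
>>>         self.bias = nn.Parameter(torch.zeros(n))
>>>         self.order = ["bias"]
>>>         self.fishleg_aux = nn.ParameterDict(
>>>             {"scale": nn.Parameter(torch.ones(n))}
>>>         )
>>>     def forward(self, x):
>>>         return x + self.bias
>>>     def Qv(self, aux, v):
>>>         # Diagonal Q(lambda): scale each tensor in v elementwise.
>>>         return tuple(aux["scale"] ** 2 * t for t in v)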
- cuda(device) None
Moves all model parameters and buffers to the GPU.
This also makes associated parameters and buffers different objects. So it should be called before constructing optimizer if the module will live on GPU while being optimized.
Note
This method modifies the module in-place.
- Args:
- device (int, optional): if specified, all parameters will be
copied to that device
- Returns:
Module: self
- property name: str
- training: bool
- FishLeg.fishleg_layers.get_zero_grad_hook(mask)
FishLeg.fishleg_likelihood module
- class FishLeg.fishleg_likelihood.BernoulliLikelihood
Bases:
FishLikelihood
The Bernoulli likelihood used for classification. Using the standard Normal CDF \(\Phi(x)\) and the identity \(\Phi(-x) = 1-\Phi(x)\), we can write the likelihood as:
\[p(y|f(x))=\Phi(yf(x))\]
- draw(preds: Tensor) Tensor
Draw samples from the conditional distribution \(p(\mathbf y|f(\mathbf x))\)
- Parameters
preds (torch.Tensor) – Predictions from model \(f(\mathbf x)\)
- nll(observations: Tensor, preds: Tensor) Tensor
Computes the negative log-likelihood \(\ell(\theta, \mathcal D)=-\log p(\mathbf y|f(\mathbf x))\)
- Parameters
observations (torch.Tensor) – Values of \(y\).
preds (torch.Tensor) – Predictions from model \(f(\mathbf x)\)
- Return type
torch.Tensor
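A minimal sketch of what these two methods might compute, using the probit form above with labels \(y \in \{-1, 1\}\) (illustrative only; the shipped implementation may differ in details such as label encoding or reduction):
>>> import torch
>>>
>>> def bernoulli_nll(observations, preds):
>>>     # -log Phi(y * f(x)), with Phi the standard Normal CDF
>>>     normal = torch.distributions.Normal(0.0, 1.0)
>>>     return -normal.cdf(observations * preds).log().sum()
>>>
>>> def bernoulli_draw(preds):
>>>     # Sample y = +1 with probability Phi(f(x)), else y = -1
>>>     normal = torch.distributions.Normal(0.0, 1.0)
>>>     p = normal.cdf(preds)
>>>     return torch.where(torch.rand_like(p) < p,
>>>                        torch.ones_like(p), -torch.ones_like(p))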
- class FishLeg.fishleg_likelihood.FishLikelihood
Bases:
object
A likelihood in FishLeg specifies a probabilistic model that defines the mapping from latent function values \(f(\mathbf X)\) to observed labels \(y\).
For example, in the case of regression, a Gaussian likelihood can be chosen, as
\[y(\mathbf x) = f(\mathbf x) + \epsilon, \:\:\:\: \epsilon \sim N(0,\sigma^{2}_{n} \mathbf I)\]
As for the case of classification, a Bernoulli distribution can be chosen:
\[\begin{split}y(\mathbf x) = \begin{cases} 1 & \text{w/ probability} \:\: \sigma(f(\mathbf x)) \\ 0 & \text{w/ probability} \:\: 1-\sigma(f(\mathbf x)) \end{cases}\end{split}\]
- abstract draw(preds, **kwargs)
Draw samples from the conditional distribution \(p(\mathbf y|f(\mathbf x))\)
- Parameters
preds (torch.Tensor) – Predictions from model \(f(\mathbf x)\)
- abstract nll(observations, preds, **kwargs)
Computes the negative log-likelihood \(\ell(\theta, \mathcal D)=-\log p(\mathbf y|f(\mathbf x))\)
- Parameters
observations (torch.Tensor) – Values of \(y\).
preds (torch.Tensor) – Predictions from model \(f(\mathbf x)\)
- Return type
torch.Tensor
- class FishLeg.fishleg_likelihood.FixedGaussianLikelihood(sigma)
Bases:
FishLikelihood
The standard likelihood for regression, assuming fixed, known observation noise:
\[y(\mathbf x) = f(\mathbf x) + \epsilon, \:\:\:\: \epsilon \sim N(0,\sigma^{2})\]
- Parameters
sigma_fixed (torch.Tensor) – Known observation standard deviation for each example.
- draw(preds: Tensor) Tensor
Draw samples from the conditional distribution \(p(\mathbf y|f(\mathbf x))\)
- Parameters
preds (torch.Tensor) – Predictions from model \(f(\mathbf x)\)
- property get_variance
- nll(observations: Tensor, preds: Tensor) Tensor
Computes the negative log-likelihood \(\ell(\theta, \mathcal D)=-\log p(\mathbf y|f(\mathbf x))\)
- Parameters
observations (torch.Tensor) – Values of \(y\).
preds (torch.Tensor) – Predictions from model \(f(\mathbf x)\)
- Return type
torch.Tensor
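A minimal sketch of these two methods under the fixed-noise model above (illustrative only; the reduction and broadcasting used by the shipped implementation may differ):
>>> import math
>>> import torch
>>>
>>> sigma = 1.0   # known, fixed observation standard deviation
>>>
>>> def gaussian_nll(observations, preds):
>>>     # -log N(y; f(x), sigma^2), summed over observations
>>>     return (0.5 * ((observations - preds) / sigma) ** 2
>>>             + math.log(sigma) + 0.5 * math.log(2 * math.pi)).sum()
>>>
>>> def gaussian_draw(preds):
>>>     # y = f(x) + eps, eps ~ N(0, sigma^2)
>>>     return preds + sigma * torch.randn_like(preds)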
- class FishLeg.fishleg_likelihood.SoftMaxLikelihood
Bases:
FishLikelihood
- draw(preds: Tensor) Tensor
Draw samples from the conditional distribution \(p(\mathbf y|f(\mathbf x))\)
- Parameters
preds (torch.Tensor) – Predictions from model \(f(\mathbf x)\)
- nll(observations: Tensor, preds: Tensor) Tensor
Computes the negative log-likelihood \(\ell(\theta, \mathcal D)=-\log p(\mathbf y|f(\mathbf x))\)
- Parameters
observations (torch.Tensor) – Values of \(y\).
preds (torch.Tensor) – Predictions from model \(f(\mathbf x)\)
- Return type
torch.Tensor
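SoftMaxLikelihood carries no description on this page; as a rough sketch of the usual softmax likelihood for multi-class classification (an assumption about what this class implements, not a confirmed description):
>>> import torch
>>> import torch.nn.functional as F
>>>
>>> def softmax_nll(observations, preds):
>>>     # Cross-entropy of softmax(preds), with integer class labels as observations
>>>     return F.cross_entropy(preds, observations, reduction="sum")
>>>
>>> def softmax_draw(preds):
>>>     # Sample class labels from the categorical distribution softmax(preds)
>>>     return torch.distributions.Categorical(logits=preds).sample()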
FishLeg.utils module
- FishLeg.utils.recursive_getattr(obj, attr)
- FishLeg.utils.recursive_setattr(obj, attr, value)
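These two helpers are undocumented here; a plausible sketch of what such dotted-path accessors typically do (an assumption, not necessarily the shipped implementation):
>>> import functools
>>>
>>> def recursive_getattr(obj, attr):
>>>     # getattr through a dotted path, e.g. recursive_getattr(model, "encoder.0.weight")
>>>     return functools.reduce(getattr, attr.split("."), obj)
>>>
>>> def recursive_setattr(obj, attr, value):
>>>     # setattr on the last element of a dotted path
>>>     parent, _, name = attr.rpartition(".")
>>>     setattr(recursive_getattr(obj, parent) if parent else obj, name, value)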