Source code for FishLeg.fishleg_layers

import numpy as np
import torch
import torch.nn as nn
from torch import Tensor
from torch.nn import ParameterDict, Parameter

from typing import Any, List, Dict, Tuple


class FishModel:
    """Mixin providing likelihood-based helpers for a model.

    The methods below rely on attributes that this class does not define
    itself and that the concrete model must supply:

    * ``forward(x)``: maps inputs to predictions.
    * ``likelihood``: object exposing ``nll(...)`` and ``sample(...)``.
    * ``data``: indexable whose first element holds the input samples
      (only needed by :meth:`sample`).
    * ``N``: exclusive upper bound for sampled indices into ``data[0]``
      (only needed by :meth:`sample`).
    """

    def nll(self, data):
        """Return the negative log-likelihood of one ``(inputs, targets)`` pair.

        Args:
            data: Two-element iterable ``(data_x, data_y)``.

        Returns:
            Whatever ``self.likelihood.nll`` returns for the predictions
            ``self.forward(data_x)`` and the targets ``data_y``.
        """
        inputs, targets = data
        predictions = self.forward(inputs)
        # The first argument of the likelihood call is always None here.
        # NOTE(review): presumably a slot for likelihood parameters; confirm.
        return self.likelihood.nll(None, predictions, targets)

    def sample(self, K):
        """Draw ``K`` stored inputs at random and sample targets for them.

        Indices are drawn with ``np.random.randint(0, self.N, K)``, i.e.
        uniformly and with replacement from ``[0, N)``.

        Args:
            K: Number of samples to draw.

        Returns:
            Tuple ``(data_x, sampled_y)`` where ``data_x`` are the selected
            inputs and ``sampled_y`` is the result of
            ``self.likelihood.sample`` for the model predictions on them.
        """
        indices = np.random.randint(0, self.N, K)
        inputs = self.data[0][indices]
        predictions = self.forward(inputs)
        return (inputs, self.likelihood.sample(None, predictions))
class FishModule(nn.Module):
    r"""Base class for all neural network modules in FishLeg.

    Subclasses are expected to

    #. Initialize auxiliary parameters, :math:`\lambda`, and its forms,
       :math:`Q(\lambda)`.
    #. Specify quick calculation of :math:`Q(\lambda)v` products.

    :param torch.nn.ParameterDict fishleg_aux: auxiliary parameters with
        their initialization, including an additional parameter, scale,
        :math:`\eta`. Make sure that

        .. math::
            - \eta Q(\lambda) grad = - \eta_{adam} grad

        holds at the beginning of the optimization.
    :param List order: specify a name order of original parameter
    """

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        """Initialise the module with empty auxiliary state.

        All positional and keyword arguments are forwarded unchanged to the
        next ``__init__`` in the MRO.
        """
        super().__init__(*args, **kwargs)
        # Empty containers; concrete layers overwrite both attributes.
        self.fishleg_aux = ParameterDict()
        # An actual (empty) list rather than the ``typing.List`` class object,
        # so code that iterates or indexes ``order`` behaves sensibly before a
        # subclass fills it in.
        self.order: List[str] = []

    @staticmethod
    def Qv(aux: Dict, v: Tuple[Tensor, ...]) -> Tuple[Tensor, ...]:
        r""":math:`Q(\lambda)` is a positive definite matrix which will
        effectively estimate the inverse damped Fisher Information Matrix.

        Appropriate choices for :math:`Q` should take into account the
        architecture of the model/module. It is usually parameterized as a
        positive definite Kronecker-factored block-diagonal matrix, with
        block sizes reflecting the layer structure of the neural networks.

        Args:
            aux: (Dict, required): auxiliary parameters, :math:`\lambda`,
                a dictionary with keys, the names of the auxiliary
                parameters, and values, the auxiliary parameters of the
                module. These auxiliary parameters will form
                :math:`Q(\lambda)`.
            v: (Tuple[Tensor, ...], required): Values of the original
                parameters, in an order that aligns with `self.order`, to
                multiply with :math:`Q(\lambda)`.

        Returns:
            Tuple[Tensor, ...]: The calculated :math:`Q(\lambda)v` products,
            in the same order as `self.order`.

        Raises:
            NotImplementedError: Always; subclasses must override this method.
        """
        raise NotImplementedError('Module is missing the required "Qv" function')
class FishLinear(nn.Linear, FishModule):
    r"""``torch.nn.Linear`` layer carrying FishLeg auxiliary parameters.

    The auxiliary parameters stored in ``fishleg_aux`` are

    * ``scale``: a single-element tensor initialised to one,
    * ``L``: an identity matrix of size ``in_features + 1``,
    * ``R``: an identity matrix of size ``out_features``.

    ``order`` is ``["weight", "bias"]``, the order in which :meth:`Qv`
    expects and returns the layer parameters.
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        bias: bool = True,
        device=None,
        dtype=None,
    ) -> None:
        """Build the linear layer and its auxiliary parameters.

        Args:
            in_features: Size of each input sample.
            out_features: Size of each output sample.
            bias: Forwarded to ``nn.Linear``.
            device: Device for the layer parameters and the auxiliary ones.
            dtype: Dtype for the layer parameters and the auxiliary ones.
        """
        super().__init__(
            in_features, out_features, bias, device=device, dtype=dtype
        )

        self._layer_name = "Linear"

        # Create the auxiliary parameters with the same device/dtype as the
        # layer's own parameters; otherwise the matrix products in ``Qv``
        # fail for a non-default dtype or device.
        factory_kwargs = {"device": device, "dtype": dtype}
        # NOTE(review): ``L`` is sized ``in_features + 1`` and ``order`` lists
        # "bias", so the layer appears to assume ``bias=True``; confirm the
        # intended behaviour for ``bias=False``.
        self.fishleg_aux = ParameterDict(
            {
                "scale": Parameter(torch.ones(size=(1,), **factory_kwargs)),
                "L": Parameter(torch.eye(in_features + 1, **factory_kwargs)),
                "R": Parameter(torch.eye(out_features, **factory_kwargs)),
            }
        )

        self.order = ["weight", "bias"]

    @property
    def name(self) -> str:
        """Layer-type label of this module (``"Linear"``)."""
        return self._layer_name

    @staticmethod
    def Qv(aux: Dict, v: Tuple[Tensor, Tensor]) -> Tuple[Tensor, Tensor]:
        r"""For fully-connected layers, the default structure of :math:`Q`
        as a block-diagonal matrix is,

        .. math::
            Q_l = (R_lR_l^T \otimes L_lL_l^T)

        where :math:`l` denotes the l-th layer. The matrix :math:`R_l` has
        size :math:`N_l \times N_l` (``out_features``) while the matrix
        :math:`L_l` has size :math:`(N_{l-1} + 1) \times (N_{l-1} + 1)`
        (``in_features + 1``). The auxiliary parameters :math:`\lambda` are
        represented by the matrices :math:`L_l, R_l`.

        Args:
            aux: Mapping that must contain the keys ``"fishleg_aux.L"``,
                ``"fishleg_aux.R"`` and ``"fishleg_aux.scale"``.
            v: Pair ``(weight_like, bias_like)`` with the shapes of the
                layer's weight ``(out_features, in_features)`` and bias
                ``(out_features,)``.

        Returns:
            Pair of tensors with the same shapes as the entries of ``v``.
        """
        L, R = aux["fishleg_aux.L"], aux["fishleg_aux.R"]
        # Append the bias as an extra column: (out_features, in_features + 1).
        u = torch.cat([v[0], v[1][:, None]], dim=-1)
        z = aux["fishleg_aux.scale"] * torch.linalg.multi_dot((R, R.T, u, L, L.T))
        # Split the result back into its weight part and its bias column.
        return (z[:, :-1], z[:, -1])

    def cuda(self, device=None) -> "FishLinear":
        """Move the layer, including ``fishleg_aux``, to a CUDA device.

        ``fishleg_aux`` is a ``ParameterDict`` assigned as an attribute, so
        it is a registered submodule and ``nn.Module.cuda`` already moves
        its parameters; no separate loop is required.

        Args:
            device: Optional CUDA device index, as for ``nn.Module.cuda``.

        Returns:
            The module itself, as returned by ``nn.Module.cuda``.
        """
        return super().cuda(device)
# Maps a layer-type key to the FishLeg layer class implementing it.
# Perhaps this would be better constructed inside the __init__.py file?
FISH_LAYERS = dict(linear=FishLinear)