from typing import Tuple
import torch
import torch.nn as nn
import copy
from torch.optim import Optimizer, Adam
from .fishleg_layers import FISH_LAYERS
class FishLeg(Optimizer):
"""Implement FishLeg algorithm.
:param torch.nn.Module model: a pytorch neural network module,
can be nested in a tree structure
:param float lr: learning rate,
for the parameters of the input model using FishLeg (default: 1e-2)
:param float eps: a small scalar, to evaluate the auxiliary loss
in the direction of gradient of model parameters (default: 1e-4)
:param int aux_K: number of sample to evaluate the entropy (default: 5)
:param int update_aux_every: number of iteration after which an auxiliary
update is executed, if negative, then run -update_aux_every auxiliary
updates in each outer iteration. (default: -3)
:param float aux_lr: learning rate for the auxiliary parameters,
using Adam (default: 1e-3)
:param Tuple[float, float] aux_betas: coefficients used for computing
running averages of gradient and its square for auxiliary parameters
(default: (0.9, 0.999))
:param float aux_eps: term added to the denominator to improve
numerical stability for auxiliary parameters (default: 1e-8)
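
    Example:
        >>> # Minimal usage sketch. It assumes ``MyModel`` defines the
        >>> # ``nll`` (negative log-likelihood) and ``sample`` methods that
        >>> # FishLeg relies on; ``MyModel`` and ``dataloader`` are
        >>> # illustrative placeholders, not part of this module.
        >>> model = MyModel()
        >>> opt = FishLeg(model, lr=1e-2, aux_lr=1e-3)
        >>> for data in dataloader:
        ...     opt.zero_grad()
        ...     loss = model.nll(data)
        ...     loss.backward()
        ...     opt.step()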
"""
def __init__(
self,
model: nn.Module,
lr: float = 1e-2,
eps: float = 1e-4,
aux_K: int = 5,
update_aux_every: int = -3,
aux_lr: float = 1e-3,
aux_betas: Tuple[float, float] = (0.9, 0.999),
aux_eps: float = 1e-8
) -> None:
        # Plain copies of the model are kept for the +/- perturbations used in
        # the auxiliary loss; they are taken before FishLeg modules are swapped
        # in, so they carry no auxiliary parameters themselves.
        self.plus_model = copy.deepcopy(model)
        self.minus_model = copy.deepcopy(model)
        self.model = self.init_model_aux(model)
# partition by modules
self.aux_param = [
param
for name, param in self.model.named_parameters()
if "fishleg_aux" in name
]
param_groups = []
        for module_name, module in self.model.named_modules():
            if hasattr(module, "fishleg_aux"):
                # Read parameters off the module object itself; indexing
                # self.model._modules with a dotted name would fail for
                # nested modules.
                params = {
                    name: param
                    for name, param in module.named_parameters()
                    if "fishleg_aux" not in name
                }
g = {
"params": [params[name] for name in module.order],
"aux_params": {
name: param
for name, param in module.named_parameters()
if "fishleg_aux" in name
},
"Qv": module.Qv,
"order": module.order,
"name": module_name,
}
param_groups.append(g)
# TODO: add param_group for modules without aux
defaults = dict(lr=lr)
super(FishLeg, self).__init__(param_groups, defaults)
self.aux_opt = Adam(self.aux_param, lr=aux_lr, betas=aux_betas, eps=aux_eps)
self.eps = eps
self.aux_K = aux_K
self.update_aux_every = update_aux_every
self.aux_lr = aux_lr
self.aux_betas = aux_betas
self.aux_eps = aux_eps
self.step_t = 0
    def init_model_aux(self, model: nn.Module) -> nn.Module:
        """Given a model to optimize, its parameters can be divided into

        #. those that stay fixed, as pre-trained;
        #. those to be optimized with FishLeg.

        Modules in the second group are replaced with FishLeg modules.

        Args:
            model (:class:`torch.nn.Module`, required):
                A model whose supported modules are replaced by FishLeg
                modules carrying the extra functionality required by the
                FishLeg algorithm.

        Returns:
            :class:`torch.nn.Module`, the model with modules replaced.
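
        Example:
            >>> # Illustrative sketch only: it assumes ``"linear"`` is a
            >>> # registered key of ``FISH_LAYERS``, so the ``nn.Linear``
            >>> # below is swapped for the corresponding FishLeg layer.
            >>> model = nn.Sequential(nn.Linear(10, 5))
            >>> model = optimizer.init_model_aux(model)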
"""
        for name, module in model.named_modules():
            layer_key = type(module).__name__.lower()
            if layer_key in FISH_LAYERS:
                replace = FISH_LAYERS[layer_key](
                    module.in_features,
                    module.out_features,
                    module.bias is not None,
                )
                replace = update_dict(replace, module)
                # Walk down to the parent of ``name`` so that nested
                # (dotted) module names are handled correctly.
                parent = model
                *path, leaf = name.split(".")
                for part in path:
                    parent = parent._modules[part]
                parent._modules[leaf] = replace
        # TODO: Error checking to verify that the model exposes the auxiliary
        # functionality (e.g. ``nll`` and ``sample``) that FishLeg requires.
        return model
    def update_aux(self) -> None:
        """Performs a single auxiliary parameter update
        using Adam, by minimizing the following objective:

        .. math::
            nll(model, \\theta + \\epsilon Q(\\lambda)g) + nll(model, \\theta - \\epsilon Q(\\lambda)g) - 2\\epsilon^2 g^T Q(\\lambda)g

        where :math:`\\theta` are the parameters of the model, :math:`\\lambda` are the
        auxiliary parameters, and :math:`g` is the gradient of the model parameters.
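
        Note: the implementation below minimizes this objective divided by
        :math:`\\epsilon^2`, which leaves the minimizer unchanged:

        .. math::
            \\frac{1}{\\epsilon^2}\\left[nll(model, \\theta + \\epsilon Q(\\lambda)g) + nll(model, \\theta - \\epsilon Q(\\lambda)g)\\right] - 2 g^T Q(\\lambda)g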
"""
        self.aux_opt.zero_grad()
        with torch.no_grad():
            data = self.model.sample(self.aux_K)

        aux_loss = 0.0
        for group in self.param_groups:
            name = group["name"]
            grad = [p.grad.data for p in group["params"]]
            qg = group["Qv"](group["aux_params"], grad)
            for g, d_p, para_name in zip(grad, qg, group["order"]):
                param_plus = self.plus_model.get_submodule(name)._parameters[para_name]
                param_plus = param_plus.detach()
                param_minus = self.minus_model.get_submodule(name)._parameters[para_name]
                param_minus = param_minus.detach()

                # Perturb the two model copies in place by +/- eps * Q(lambda)g.
                param_plus.add_(d_p, alpha=self.eps)
                param_minus.add_(d_p, alpha=-self.eps)

                aux_loss -= 2 * torch.sum(g * d_p)

        h_plus = self.plus_model.nll(data)
        h_minus = self.minus_model.nll(data)
        aux_loss += (h_plus + h_minus) / (self.eps**2)

        aux_loss.backward()
        self.aux_opt.step()

        # Undo the perturbations: reset both model copies to the current
        # model parameters (``name`` is re-read for each group).
        for group in self.param_groups:
            name = group["name"]
            for p, para_name in zip(group["params"], group["order"]):
                self.plus_model.get_submodule(name)._parameters[para_name].data = p.data
                self.minus_model.get_submodule(name)._parameters[para_name].data = p.data
    def step(self) -> None:
        """Performs a single optimization step of FishLeg."""
self.step_t += 1
        # Positive update_aux_every runs one auxiliary update every
        # update_aux_every outer steps; negative runs -update_aux_every
        # auxiliary updates on every outer step.
        if self.update_aux_every > 0:
            if self.step_t % self.update_aux_every == 0:
                self.update_aux()
        elif self.update_aux_every < 0:
            for _ in range(-self.update_aux_every):
                self.update_aux()
for group in self.param_groups:
lr = group["lr"]
order = group["order"]
name = group["name"]
if "aux_params" in group.keys():
grad = grad = [p.grad.data for p in group["params"]]
qg = group["Qv"](group["aux_params"], grad)
for p, d_p, para_name in zip(group["params"], qg, order):
p.data.add_(d_p, alpha=-lr)
self.plus_model._modules[name]._parameters[para_name].data = p.data
self.minus_model._modules[name]._parameters[para_name].data = p.data
def update_dict(replace: nn.Module, module: nn.Module) -> nn.Module:
    """Copies the pretrained parameters of ``module`` into the matching
    entries of the state dict of its FishLeg replacement ``replace``.
    """
    replace_dict = replace.state_dict()
    pretrained_dict = {
        k: v for k, v in module.state_dict().items() if k in replace_dict
    }
    replace_dict.update(pretrained_dict)
    replace.load_state_dict(replace_dict)
    return replace
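
# Minimal usage sketch for ``update_dict`` (illustrative only: it assumes
# "linear" is a registered key of ``FISH_LAYERS`` whose constructor takes
# ``(in_features, out_features, bias)``):
#
#     linear = nn.Linear(4, 2)
#     fish_linear = FISH_LAYERS["linear"](4, 2, True)
#     fish_linear = update_dict(fish_linear, linear)  # copy pretrained weight/bias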