Module bitsandbytes.optim.adam
Expand source code
import torch
from torch.optim import Optimizer
from bitsandbytes.optim.optimizer import Optimizer8bit, MockArgs
import bitsandbytes.functional as F
class Adam(Optimizer8bit):
    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
                 weight_decay=0, amsgrad=False, optim_bits=32, is_sparse=False, args=None, override_with_args=False):
        if not 0.0 <= lr:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if not 0.0 <= eps:
            raise ValueError("Invalid epsilon value: {}".format(eps))
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
        if not 0.0 <= weight_decay:
            raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
        defaults = dict(lr=lr, betas=betas, eps=eps,
                        weight_decay=weight_decay, amsgrad=amsgrad, is_sparse=is_sparse)
        super(Adam, self).__init__(params, defaults)

        if args is None:
            args = {}
            args['optim_bits'] = optim_bits
            args['adam8bits_offset'] = 1/512
            args['percentile_clipping'] = 100
            args['is_sparse'] = is_sparse
            self.args = MockArgs(args)
        else:
            self.args = args

        self.keep_32_bit = set()
    def set_state_bits(self, model, keep32type=[torch.nn.Embedding], keep32smaller=4096):
        # Keep 32-bit optimizer state for all parameters of the listed module types
        # and for any parameter tensor with fewer than `keep32smaller` elements.
        for name, module in model.named_modules():
            if any(isinstance(module, t) for t in keep32type):
                for p2 in module.parameters():
                    self.keep_32_bit.add(p2.data.storage().data_ptr())
        for p in model.parameters():
            if p.numel() < keep32smaller:
                self.keep_32_bit.add(p.data.storage().data_ptr())
    @torch.no_grad()
    def init_state(self, group, p_id, p):
        if self.args.optim_bits == 32:
            dtype = torch.float32
        elif self.args.optim_bits == 8:
            dtype = torch.uint8
        else:
            raise NotImplementedError(f'Unsupported number of Adam optimizer bits: {self.args.optim_bits}')

        state = self.state[p]
        state['step'] = 0

        if p.numel() % 4 != 0:
            raise ValueError(f'Parameter tensors need a number of elements that is a multiple of 4: {p.shape}')

        if dtype == torch.float32 or (dtype == torch.uint8 and p.numel() < 4096):
            # 32-bit Adam state; small tensors always fall back to 32-bit state.
            state['state1'] = torch.zeros_like(p, memory_format=torch.preserve_format, dtype=torch.float32, device=p.device)
            state['state2'] = torch.zeros_like(p, memory_format=torch.preserve_format, dtype=torch.float32, device=p.device)
        elif dtype == torch.uint8:
            # 8-bit state buffers for the two Adam moments (read by update_step),
            # plus 256-entry quantization tables and per-tensor max values.
            state['state1'] = torch.zeros_like(p, memory_format=torch.preserve_format, dtype=torch.uint8, device=p.device)
            state['state2'] = torch.zeros_like(p, memory_format=torch.preserve_format, dtype=torch.uint8, device=p.device)
            state['qtbl1'] = torch.zeros((256,), dtype=torch.float32, device=p.device)
            state['qtbl2'] = torch.zeros((256,), dtype=torch.float32, device=p.device)
            state['max1'] = torch.zeros((1,), dtype=torch.float32, device=p.device)
            state['max2'] = torch.zeros((1,), dtype=torch.float32, device=p.device)

        if self.args.percentile_clipping < 100:
            state['gnorm_vec'] = torch.zeros((100,), device=p.device)
    def get_config(self, p, group):
        config = {}
        config['betas'] = group['betas']
        config['eps'] = group['eps']
        config['weight_decay'] = group['weight_decay']
        config['lr'] = group['lr']
        config['is_sparse'] = self.args.is_sparse

        if id(p) in self.mng.p2config:
            config.update(self.mng.p2config[id(p)])
        return config
    @torch.no_grad()
    def update_step(self, group, p_id, p):
        state = self.state[p]
        grad = p.grad
        config = self.get_config(p, group)

        state['step'] += 1
        step = state['step']

        F.adam_update(grad, p, state['state1'], state['state2'], config['betas'][0], config['betas'][1],
                      config['eps'], config['weight_decay'], step, config['lr'],
                      is_sparse=config['is_sparse'])
    @torch.no_grad()
    def step(self, closure=None):
        """Performs a single optimization step.

        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        overflows = []
        for group in self.param_groups:
            for p_id, p in enumerate(group['params']):
                if p.grad is None:
                    continue
                state = self.state[p]
                if len(state) == 0:
                    self.init_state(group, p_id, p)

                self.update_step(group, p_id, p)

        return loss
class Adam32bit(Adam):
    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
                 weight_decay=0, amsgrad=False, args=None, override_with_args=False):
        # Pass args/override_with_args by keyword so they are not bound to
        # Adam's positional optim_bits/is_sparse parameters.
        super(Adam32bit, self).__init__(params, lr, betas, eps, weight_decay, amsgrad,
                                        optim_bits=32, args=args, override_with_args=override_with_args)
        self.args.optim_bits = 32
class Adam8bit(Adam):
    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
                 weight_decay=0, amsgrad=False, args=None, override_with_args=False):
        # Force 8-bit optimizer state; pass args/override_with_args by keyword so
        # they are not bound to Adam's positional optim_bits/is_sparse parameters.
        super(Adam8bit, self).__init__(params, lr, betas, eps, weight_decay, amsgrad,
                                       optim_bits=8, args=args, override_with_args=override_with_args)
        self.args.optim_bits = 8
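The three classes differ only in how optim_bits is set. A minimal construction sketch (my own example; note that init_state requires every parameter tensor to have a number of elements that is a multiple of 4):

import torch
from bitsandbytes.optim.adam import Adam, Adam8bit, Adam32bit

model = torch.nn.Linear(256, 256)

opt_a = Adam(model.parameters(), lr=1e-3, optim_bits=8)   # precision chosen explicitly
opt_b = Adam8bit(model.parameters(), lr=1e-3)             # forces optim_bits=8
opt_c = Adam32bit(model.parameters(), lr=1e-3)            # forces optim_bits=32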
Classes
class Adam (params, lr=0.001, betas=(0.9, 0.999), eps=1e-08, weight_decay=0, amsgrad=False, optim_bits=32, is_sparse=False, args=None, override_with_args=False)
-
Adam optimizer with configurable optimizer-state precision: optim_bits=32 keeps standard 32-bit Adam state, optim_bits=8 stores the state in quantized 8-bit form with per-tensor quantization tables.
Warning
Parameters need to be specified as collections that have a deterministic ordering that is consistent between runs. Examples of objects that don't satisfy those properties are sets and iterators over values of dictionaries.
Args
params (iterable)
- an iterable of torch.Tensor s or dict s. Specifies what Tensors should be optimized.
defaults (dict)
- a dict containing default values of optimization options (used when a parameter group doesn't specify them).
Ancestors
- bitsandbytes.optim.optimizer.Optimizer8bit
- torch.optim.optimizer.Optimizer
Subclasses
- Adam32bit
- Adam8bit
Methods
def get_config(self, p, group)
-
Builds the effective hyperparameter configuration (betas, eps, weight_decay, lr, is_sparse) for parameter p from its parameter group, then merges in any per-parameter overrides registered in self.mng.p2config.
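A small sketch of a per-parameter override, assuming the parameter manager mng inherited from Optimizer8bit exposes p2config as a dict keyed by id(p) whose values are dicts of hyperparameter overrides, exactly as get_config reads it (an illustration, not a documented API):

import torch
from bitsandbytes.optim.adam import Adam

model = torch.nn.Linear(128, 128)
opt = Adam(model.parameters(), lr=1e-3, optim_bits=8)

# Hypothetical override: give the bias its own learning rate.
# get_config() merges this dict on top of the group's settings.
opt.mng.p2config[id(model.bias)] = {'lr': 1e-4}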
def init_state(self, group, p_id, p)
-
Allocates optimizer state for parameter p on first use: a step counter plus either two 32-bit moment tensors (for optim_bits=32, or for any tensor with fewer than 4096 elements) or, for 8-bit state, quantized moment buffers with 256-entry quantization tables (qtbl1, qtbl2) and per-tensor max values (max1, max2). Raises if the number of elements of p is not a multiple of 4.
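A back-of-the-envelope sketch of the memory difference (my own estimate, assuming the 8-bit path stores the two Adam moments as uint8 buffers alongside the quantization tables):

def adam_state_bytes(numel, optim_bits=32):
    """Approximate Adam optimizer-state memory for one parameter tensor."""
    if optim_bits == 32 or numel < 4096:          # small tensors keep 32-bit state
        return 2 * numel * 4                      # state1 + state2 as float32
    # assumed 8-bit layout: two uint8 buffers + qtbl1/qtbl2 (256 floats each) + max1/max2
    return 2 * numel * 1 + 2 * 256 * 4 + 2 * 4

print(adam_state_bytes(10**9, 32) / 2**30)  # ~7.45 GiB for a 1B-element tensor
print(adam_state_bytes(10**9, 8) / 2**30)   # ~1.86 GiB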
def set_state_bits(self, model, keep32type=[torch.nn.Embedding], keep32smaller=4096)
-
Marks parameters that should keep 32-bit optimizer state: every parameter of a module whose type appears in keep32type, and every parameter tensor with fewer than keep32smaller elements. Matching parameters are recorded by storage pointer in self.keep_32_bit.
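A usage sketch with a toy model (my own example): the embedding weights are kept in 32-bit state via keep32type, and small tensors such as the Linear bias fall under keep32smaller:

import torch.nn as nn
from bitsandbytes.optim.adam import Adam

model = nn.Sequential(
    nn.Embedding(10000, 256),   # kept in 32-bit state via keep32type
    nn.Linear(256, 256),        # 256-element bias falls under keep32smaller
)

opt = Adam(model.parameters(), lr=1e-3, optim_bits=8)
opt.set_state_bits(model, keep32type=[nn.Embedding], keep32smaller=4096)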
def step(self, closure=None)
-
Performs a single optimization step.
Arguments
closure (callable, optional): A closure that reevaluates the model and returns the loss.
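A minimal training-loop sketch (my own example, with a dummy regression task): state for each parameter is initialized lazily the first time step() sees its gradient:

import torch
from bitsandbytes.optim.adam import Adam

model = torch.nn.Linear(64, 64)
opt = Adam(model.parameters(), lr=1e-3, optim_bits=32)
loss_fn = torch.nn.MSELoss()
x, y = torch.randn(16, 64), torch.randn(16, 64)

for _ in range(10):
    opt.zero_grad()
    loss = loss_fn(model(x), y)
    loss.backward()
    opt.step()   # init_state() runs once per parameter, then update_step()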
def update_step(self, group, p_id, p)
-
Performs one Adam update for parameter p: increments the step counter and calls bitsandbytes.functional.adam_update with the gradient, the parameter, the two state tensors, and the hyperparameters returned by get_config.
class Adam32bit (params, lr=0.001, betas=(0.9, 0.999), eps=1e-08, weight_decay=0, amsgrad=False, args=None, override_with_args=False)
-
32-bit variant of Adam: identical to Adam with optim_bits fixed to 32.
Warning
Parameters need to be specified as collections that have a deterministic ordering that is consistent between runs. Examples of objects that don't satisfy those properties are sets and iterators over values of dictionaries.
Args
params (iterable)
- an iterable of torch.Tensor s or dict s. Specifies what Tensors should be optimized.
defaults (dict)
- a dict containing default values of optimization options (used when a parameter group doesn't specify them).
Ancestors
- Adam
- bitsandbytes.optim.optimizer.Optimizer8bit
- torch.optim.optimizer.Optimizer
Inherited members
class Adam8bit (params, lr=0.001, betas=(0.9, 0.999), eps=1e-08, weight_decay=0, amsgrad=False, args=None, override_with_args=False)
-
8-bit variant of Adam: identical to Adam with optim_bits fixed to 8, so the Adam state is stored in quantized 8-bit form with per-tensor quantization tables.
Warning
Parameters need to be specified as collections that have a deterministic ordering that is consistent between runs. Examples of objects that don't satisfy those properties are sets and iterators over values of dictionaries.
Args
params (iterable)
- an iterable of torch.Tensor s or dict s. Specifies what Tensors should be optimized.
defaults (dict)
- a dict containing default values of optimization options (used when a parameter group doesn't specify them).
Ancestors
- Adam
- bitsandbytes.optim.optimizer.Optimizer8bit
- torch.optim.optimizer.Optimizer
Inherited members