import types
import torch
import torch.nn as nn
import copy, os, random
from typing import List
class Chicken(nn.Module):
"""
An incremental learning nn.Module wrapper.
"""
# optional cache so we don't recreate the same subclass over and over
_cls_cache: dict[type[nn.Module], type[nn.Module]] = {}
# -------- object construction --------
def __new__(cls, model: nn.Module, *args, **kwargs):
if cls is Chicken: # only when user calls Chicken(…)
base = type(model)
# reuse cached subclass if it exists
Wrapped = cls._cls_cache.get(base)
if Wrapped is None:
Wrapped = types.new_class(
f"Conti{base.__name__}", # e.g. ContiVisionTransformer
(Chicken, base), # MRO: Chicken → base model
{}
)
cls._cls_cache[base] = Wrapped
# allocate instance of the *new* subclass
inst = super().__new__(Wrapped)
# copy every weight / buffer / attribute
inst.__dict__.update(model.__dict__)
return inst
# if somebody subclasses Chicken explicitly, honour normal behaviour
return super().__new__(cls)
def __init__(
self,
model,
device: str = "cpu",
init_val: float = 0.1,
max_mult: float = 1.0,
matching_texts: List[str] = ("layernorm", "bias", "embeddings", "layrnorm", "layer_norm"),
rank=None, # optional truncation
):
"""
Parameters
----------
model: torch.nn.Module, required
device: str, optional
Device on which the decomposed parameters and mask parameters are created (default "cpu").
init_val: float, optional
Scale of the random initialization of each mask parameter (default 0.1).
max_mult: float, optional
Maximum value the activated mask can take; activated values lie in [0, max_mult] (default 1.0).
matching_texts: List[str], optional
Substrings of parameter names to exclude from decomposition and reconstruction (default ("layernorm", "bias", "embeddings", "layrnorm", "layer_norm")).
rank: int, optional
Truncation rank for the SVD; None keeps all singular values (default None).
Examples
--------
>>> from transformers import ViTModel
>>> model = ViTModel.from_pretrained('google/vit-base-patch16-224-in21k')
>>> model = Chicken(model, device="cuda", init_val=0.05, max_mult=1.0)
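A possible single training step is sketched below; head, criterion, images and labels
are illustrative placeholders and the exact recipe may differ:
>>> model.add_class(["cat", "dog"])
True
>>> model.set_mask(model.latest_mask_idx)
True
>>> model.set_train()
>>> optimizer = torch.optim.Adam(model.get_trainable_parameters(), lr=1e-3)
>>> model.apply_policy_to_model() # write the composed weights into the live model
>>> loss = criterion(head(model(pixel_values=images).last_hidden_state[:, 0]), labels)
>>> loss.backward() # fills dL/dW on the base weights
>>> model.update_backward() # propagates dL/dW onto the mask parameters
>>> optimizer.step()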
"""
# super().__init__() # DON'T: dynamic subclass already has base attrs
if not isinstance(model, nn.Module):
raise TypeError("model is not a torch.nn.Module")
self.init_val = float(init_val)
self.matching_texts = tuple(matching_texts)
self.device = device
self.max_mult = float(max_mult)
self.rank = rank # None = full SVD
# snapshot of base params (on current device/dtype)
self.base_params = copy.deepcopy(model.state_dict())
# precompute decomposition only for 2D weights we intend to adapt
self.decomposed_params = self.decompose(self.base_params, self.matching_texts, self.rank, self.device)
self.class_policy_map = {}
# register mask containers so they move with .to()
self.learnable_params = nn.ModuleDict() # key: str(mask_idx) -> ParameterDict
self._mask_param_lists = {} # mask_idx (int) -> list[Parameter]
self.num_params = 0
self.enable_mask = []
self.new_mask_idx = 0
self.selected_mask = -1 # no mask selected
# ---------- properties / helpers ----------
@property
def class_map(self):
"""
Returns a string mapping each mask index to the classes associated with it.
Returns
-------
str
Examples
--------
>>> print(model.class_map)
CLASS MAP
1: cat, dog, horse, cow
2: mouse, lion
"""
# inverse map
inverse_map = {}
for name in self.class_policy_map:
mask_idx = self.class_policy_map[name]
if mask_idx not in inverse_map:
inverse_map[mask_idx] = []
inverse_map[mask_idx].append(name)
string = "CLASS MAP\n"
string += "------------------\n"
for mask_idx in inverse_map:
string += f"{mask_idx}: {', '.join(inverse_map[mask_idx])}\n"
string += "------------------\n"
return string
@property
def latest_mask_idx(self):
"""
Returns the latest mask index.
Returns
-------
int
"""
return self.new_mask_idx - 1
@staticmethod
def decompose(base_params, skip_match_texts=(), rank=None, device="cpu"):
"""Precompute the SVD of each 2D weight whose name does not match skip_match_texts."""
decomposed_params = {}
for k, v in base_params.items():
if any(text in k for text in skip_match_texts):
continue # skip this param
if v.ndim != 2:
continue # only 2D weight matrices are decomposed
W = v.detach().to(device=device, dtype=torch.float32)
# U: [m,r], S: [r], Vh: [r,n], r = min(m,n)
U, S, Vh = torch.linalg.svd(W, full_matrices=False)
if rank is not None:
# optional truncation: keep only the top-rank singular triplets
U, S, Vh = U[:, :rank], S[:rank], Vh[:rank, :]
decomposed_params[f"{k}::U"] = U
decomposed_params[f"{k}::S"] = S
decomposed_params[f"{k}::Vh"] = Vh
return decomposed_params
def add_mask(self):
"""
Call this to add a new mask (creates a new mask vector per decomposed matrix)
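Examples
--------
A minimal sketch; the new mask becomes the latest mask:
>>> model.add_mask()
True
>>> model.set_mask(model.latest_mask_idx)
True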
"""
mask_params = nn.ParameterDict()
for k, v in self.base_params.items():
if any(text in k for text in self.matching_texts):
continue
S = self.decomposed_params.get(f"{k}::S")
if S is None:
continue
# init small random so sigmoid ≈ 0.5 with small variance
m = nn.Parameter(torch.randn_like(S, dtype=torch.float32, device=self.device) * self.init_val)
mask_params[k.replace('.', '__')] = m
self.num_params += m.numel()
key = str(self.new_mask_idx)
self.learnable_params[key] = mask_params
self._mask_param_lists[self.new_mask_idx] = list(mask_params.parameters())
self.enable_mask.append(True)
self.new_mask_idx += 1
return True
def add_class(self, class_names: List[str]):
"""
Call this to add a new set of classes (creates a new mask vector per decomposed matrix)
Parameters
----------
class_names: List[str], required
A list of class names
Returns
-------
bool
True if the classes were added successfully, False otherwise.
Examples
--------
>>> model.add_class(["cat", "dog"])
True
"""
for name in class_names:
self.class_policy_map[name] = self.new_mask_idx
# the per-class mask is identical to a plain new mask; reuse add_mask()
return self.add_mask()
def set_mask(self, mask_idx: int = 0):
"""
Set the selected mask
Parameters
----------
mask_idx: int, optional
Set the selected mask to the mask_idx (default 0)
Returns
-------
boolean
True if selected mask set successfully
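Examples
--------
Assuming at least one mask has already been added:
>>> model.set_mask(0)
True
>>> model.set_mask(-1) # special value: fall back to the unmasked base weights
True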
"""
if mask_idx == -1: # special: base weights
self.selected_mask = -1
return True
try:
self.enable_mask[mask_idx]
except IndexError:
raise IndexError("the mask number is out of range")
self.selected_mask = mask_idx
return True
def get_mask(self, mask_idx: int = -1):
"""
Returns the learnable parameters of the selected mask.
Parameters
----------
mask_idx: int, optional
The mask index; if not specified, the latest mask is returned (default -1).
Returns
-------
nn.ParameterDict
The learnable mask parameters of the selected mask.
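Examples
--------
Assuming at least one mask has already been added:
>>> masks = model.get_mask() # latest mask
>>> masks = model.get_mask(0) # first mask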
"""
try:
self.enable_mask[mask_idx]
except IndexError:
raise IndexError("the mask number is out of range")
# default to the latest mask if not specified
if mask_idx == -1:
mask_idx = self.latest_mask_idx
# return the ParameterDict for transparency
return self.learnable_params[str(mask_idx)]
def get_trainable_parameters(self, mask_idx=None):
"""Return the learnable mask parameters for mask_idx (defaults to the selected mask)."""
if mask_idx is None:
mask_idx = self.selected_mask
if mask_idx == -1:
return [] # nothing to train when using base weights
return self._mask_param_lists[mask_idx]
def save_weights(self, path: str):
"""
Save the mask weights to the path
Parameters
----------
path: str, required
Destination path for the mask checkpoint; should be a .pt file.
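Examples
--------
The file name below is illustrative:
>>> model.save_weights("chicken_masks.pt")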
"""
payload = {
"learnable_params": {
idx: {n: p.detach().cpu() for n, p in self.learnable_params[idx].items()}
for idx in self.learnable_params.keys()
},
"enable_mask": self.enable_mask,
"new_mask_idx": self.new_mask_idx,
"class_policy_map": self.class_policy_map,
"rank": self.rank,
"matching_texts": self.matching_texts,
"init_val": self.init_val,
"max_mult": self.max_mult,
}
torch.save(payload, path)
def load_weights(self, path: str):
"""
Load the mask
Parameters
----------
path: str, required
Path to the .pt file containing the saved masks.
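Examples
--------
The file name below is illustrative and should match the one used in save_weights:
>>> model.load_weights("chicken_masks.pt")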
"""
info = torch.load(path, map_location=self.device)
self.learnable_params = nn.ModuleDict()
self._mask_param_lists.clear()
self.enable_mask = list(info["enable_mask"])
self.new_mask_idx = int(info["new_mask_idx"])
self.class_policy_map = dict(info["class_policy_map"])
for idx, d in info["learnable_params"].items():
pd = nn.ParameterDict({n: nn.Parameter(t.to(self.device)) for n, t in d.items()})
self.learnable_params[idx] = pd
self._mask_param_lists[int(idx)] = list(pd.parameters())
# select a mask (defaults to mask 0)
self.set_mask()
def activate_mask(self, p, mask_idx):
"""Map raw mask parameters to multiplicative factors in [0, max_mult]; disabled masks yield all ones."""
if mask_idx == -1:
return torch.ones_like(p, dtype=torch.float32)
if not self.enable_mask[mask_idx]:
return torch.ones_like(p, dtype=torch.float32)
return torch.sigmoid(p).to(torch.float32) * self.max_mult
def forward(self, *args, **kwargs):
return super().forward(*args, **kwargs)
def compose_new_params(self, param_name, mask_idx):
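# Reconstruct W' = U @ diag(S * m~) @ Vh with m~ = sigmoid(mask) * max_mult,
# then rescale by sum(S) / sum(S * m~) so the total singular mass is preserved.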
U = self.decomposed_params[f"{param_name}::U"] # [m,r]
S = self.decomposed_params[f"{param_name}::S"] # [r]
Vh = self.decomposed_params[f"{param_name}::Vh"] # [r,n]
mparam = self.learnable_params[str(mask_idx)][param_name.replace('.', '__')] # [r]
mm = self.activate_mask(mparam, mask_idx) # [r]
S_scaled = S * mm
eps = torch.finfo(S.dtype).eps
scale = (S.sum() / (S_scaled.sum() + eps))
Wp = torch.einsum('mr,r->mr', U, S_scaled)
Wp = torch.einsum('mr,rn->mn', Wp, Vh) * scale
return Wp
def toggle_mask(self, mask_value: bool = True, mask_idx: int = None):
"""
Turn a mask on or off and re-apply the composed weights to the model.
Parameters
----------
mask_value: bool, optional
Whether the mask should be enabled (default True).
mask_idx: int, optional
If None, use the latest mask index (default None).
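Examples
--------
Assuming mask 1 has already been added; disabling it re-applies the weights with that mask turned off:
>>> model.toggle_mask(False, mask_idx=1)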
"""
if mask_idx is None:
mask_idx = self.latest_mask_idx
self.enable_mask[mask_idx] = mask_value
self.apply_policy_to_model(mask_idx)
def update_backward(self, mask_idx: int = None):
"""
Backpropagate through the learnable mask parameters using VJP.
Requires that loss.backward() has populated dL/dW on base weights.
Parameters
----------
mask_idx: int, optional
If None, use the mask selected via set_mask (default None).
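Examples
--------
A sketch of one optimization step; loss and optimizer are placeholders created elsewhere
(see the class-level example):
>>> model.set_train()
>>> loss.backward() # fills dL/dW on the base weights
>>> model.update_backward() # propagates dL/dW onto the selected mask parameters
>>> optimizer.step()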
"""
if mask_idx is None:
mask_idx = self.selected_mask
keys = [k for k in self.base_params
if all(text not in k for text in self.matching_texts)
and self.decomposed_params.get(f"{k}::S") is not None]
if not keys:
return
last_key = keys[-1]
for k in keys:
g = self.get_parameter(k).grad
if g is None:
raise RuntimeError(f"No grad for {k}; call set_train() and loss.backward() first.")
self.compose_new_params(k, mask_idx).backward(g, retain_graph=(k is not last_key))
def set_train(self, mask_idx: int = None):
"""
Set the learnable parameters to training mode.
Parameters
----------
mask_idx: int, optional
If None, use the mask index selected via set_mask (default None).
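Examples
--------
Typically called before building an optimizer over the mask parameters (the learning rate is illustrative):
>>> model.set_train()
>>> optimizer = torch.optim.Adam(model.get_trainable_parameters(), lr=1e-3)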
"""
if mask_idx is None:
mask_idx = self.selected_mask
# 1) freeze everything
for _, p in self.named_parameters():
p.requires_grad_(False)
# 2) enable grads on base weights we compose (so dL/dW is computed)
for k in self.base_params:
if any(s in k for s in self.matching_texts):
continue
if self.decomposed_params.get(f"{k}::S") is None:
continue
p = self.get_parameter(k)
p.requires_grad_(True)
p.retain_grad() # keep grad around for VJP
# 3) ensure masks are trainable
for p in self.get_trainable_parameters(mask_idx):
p.requires_grad_(True)
def apply_policy_to_model(self, mask_idx: int = None):
"""
Compose & write weights into the live model (fast in-place copy).
Parameters
----------
mask_idx: int, optional
Index of the mask to apply to the model; if None, use the mask selected via set_mask (default None).
Examples
--------
>>> model.apply_policy_to_model(1)
"""
if mask_idx is None:
mask_idx = self.selected_mask
with torch.no_grad():
for k, base in self.base_params.items():
param = self.get_parameter(k)
if any(skip in k for skip in self.matching_texts):
param.copy_(base.to(param.dtype).to(self.device))
continue
if self.decomposed_params.get(f"{k}::S") is None:
param.copy_(base.to(param.dtype).to(self.device))
continue
Wp = self.compose_new_params(k, mask_idx).to(param.dtype).to(self.device)
param.copy_(Wp)