Source code for chicken.model

import types
import torch
import torch.nn as nn
import copy, os, random

from typing import List

class Chicken(nn.Module):
    """
    An incremental learning nn.Module wrapper.
    """

    # optional cache so we don't recreate the same subclass over and over
    _cls_cache: dict[type[nn.Module], type[nn.Module]] = {}

    # -------- object construction --------
    def __new__(cls, model: nn.Module, *args, **kwargs):
        if cls is Chicken:  # only when user calls Chicken(…)
            base = type(model)
            # reuse cached subclass if it exists
            Wrapped = cls._cls_cache.get(base)
            if Wrapped is None:
                Wrapped = types.new_class(
                    f"Conti{base.__name__}",  # e.g. ContiVisionTransformer
                    (Chicken, base),          # MRO: Chicken → base model
                    {},
                )
                cls._cls_cache[base] = Wrapped
            # allocate instance of the *new* subclass
            inst = super().__new__(Wrapped)
            # copy every weight / buffer / attribute
            inst.__dict__.update(model.__dict__)
            return inst
        # if somebody subclasses Chicken explicitly, honour normal behaviour
        return super().__new__(cls)
    def __init__(
        self,
        model,
        device: str = "cpu",
        init_val: float = 0.1,
        max_mult: float = 1.0,
        matching_texts: List[str] = ("layernorm", "bias", "embeddings", "layrnorm", "layer_norm"),
        rank=None,  # optional truncation
    ):
        """
        Parameters
        ----------
        model: torch.nn.Module, required
            The base model to wrap.
        device: string, optional
            Device on which the decompositions and masks are kept (default "cpu").
        init_val: float, optional
            Maximum initial value of the mask, mask ~ U[0, init_val] (default 0.1).
        max_mult: float, optional
            Maximum possible value the mask can take, [0, max_mult] (default 1.0).
        matching_texts: List[str], optional
            Substrings of parameter names that should skip the decomposition and
            reconstruction (default ("layernorm", "bias", "embeddings", "layrnorm", "layer_norm")).
        rank: int, optional
            Optional truncation rank for the SVD; None keeps the full decomposition (default None).

        Examples
        --------
        >>> from transformers import ViTModel
        >>> model = ViTModel.from_pretrained('google/vit-base-patch16-224-in21k')
        >>> model = Chicken(model, device="cuda", init_val=0.05, max_mult=1.0)
        """
        # super().__init__()  # DON'T: dynamic subclass already has base attrs
        if not isinstance(model, nn.Module):
            raise TypeError("model is not a torch.nn.Module")

        self.init_val = float(init_val)
        self.matching_texts = tuple(matching_texts)
        self.device = device
        self.max_mult = float(max_mult)
        self.rank = rank  # None = full SVD

        # snapshot of base params (on current device/dtype)
        self.base_params = copy.deepcopy(model.state_dict())
        # precompute decomposition only for 2D weights we intend to adapt
        self.decomposed_params = self.decompose(
            self.base_params, self.matching_texts, self.rank, self.device
        )

        self.class_policy_map = {}

        # register mask containers so they move with .to()
        self.learnable_params = nn.ModuleDict()  # key: str(mask_idx) -> ParameterDict
        self._mask_param_lists = {}              # mask_idx (int) -> list[Parameter]

        self.num_params = 0
        self.enable_mask = []
        self.new_mask_idx = 0
        self.selected_mask = -1  # no mask selected
    # ---------- properties / helpers ----------
    @property
    def class_map(self):
        """
        Returns a string listing each mask index and the classes associated with it.

        Returns:
            string

        Examples
        --------
        >>> print(model.class_map)
        CLASS MAP
        1: cat, dog, horse, cow
        2: mouse, lion
        """
        # invert class -> mask into mask -> [classes]
        inverse_map = {}
        for name in self.class_policy_map:
            mask_idx = self.class_policy_map[name]
            if mask_idx not in inverse_map:
                inverse_map[mask_idx] = []
            inverse_map[mask_idx].append(name)

        string = "CLASS MAP\n"
        string += "------------------\n"
        for mask_idx in inverse_map:
            string += f"{mask_idx}: {', '.join(inverse_map[mask_idx])}\n"
        string += "------------------\n"
        return string

    @property
    def latest_mask_idx(self):
        """
        Returns the latest mask index.

        Returns
        -------
        int
        """
        return self.new_mask_idx - 1

    @staticmethod
    def decompose(base_params, skip_match_texts=(), rank=None, device="cpu"):
        decomposed_params = {}
        for k, v in base_params.items():
            if any(text in k for text in skip_match_texts):
                continue  # skip this param
            W = v.detach().to(device=device, dtype=torch.float32)
            if W.ndim != 2:
                continue  # only 2D weights are decomposed
            # U: [m,r], S: [r], Vh: [r,n], r = min(m,n)
            U, S, Vh = torch.linalg.svd(W, full_matrices=False)
            if rank is not None:
                # optional truncation to the top-``rank`` singular triplets
                U, S, Vh = U[:, :rank], S[:rank], Vh[:rank, :]
            decomposed_params[f"{k}::U"] = U
            decomposed_params[f"{k}::S"] = S
            decomposed_params[f"{k}::Vh"] = Vh
        return decomposed_params

    def add_mask(self):
        """
        Call this to add a new mask (creates a new mask vector per decomposed matrix).
        """
        mask_params = nn.ParameterDict()
        for k, v in self.base_params.items():
            if any(text in k for text in self.matching_texts):
                continue
            S = self.decomposed_params.get(f"{k}::S")
            if S is None:
                continue
            # init small random so sigmoid ≈ 0.5 with small variance
            m = nn.Parameter(
                torch.randn_like(S, dtype=torch.float32, device=self.device) * self.init_val
            )
            mask_params[k.replace('.', '__')] = m
            self.num_params += m.numel()

        key = str(self.new_mask_idx)
        self.learnable_params[key] = mask_params
        self._mask_param_lists[self.new_mask_idx] = list(mask_params.parameters())
        self.enable_mask.append(True)
        self.new_mask_idx += 1
        return True
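    # Illustrative sketch (not part of the original module): what ``decompose`` produces for a
    # single 2D weight, and that the factors reconstruct it. The toy shape is an assumption.
    #
    #     >>> W = torch.randn(8, 4)
    #     >>> parts = Chicken.decompose({"layer.weight": W})
    #     >>> U, S, Vh = parts["layer.weight::U"], parts["layer.weight::S"], parts["layer.weight::Vh"]
    #     >>> torch.allclose(U @ torch.diag(S) @ Vh, W, atol=1e-5)
    #     True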
    def add_class(self, class_names: List[str]):
        """
        Call this to add a new set of classes (creates a new mask vector per decomposed matrix).

        Parameters
        ----------
        class_names: List[str], required
            A list of class names.

        Returns
        -------
        bool
            True if the classes were added successfully, False otherwise.

        Examples
        --------
        >>> model.add_class(["cat", "dog"])
        True
        """
        # map every new class to the mask that is about to be created
        for name in class_names:
            self.class_policy_map[name] = self.new_mask_idx
        # the mask creation itself is identical to add_mask()
        return self.add_mask()
    def set_mask(self, mask_idx: int = 0):
        """
        Set the selected mask.

        Parameters
        ----------
        mask_idx: int, optional
            Set the selected mask to mask_idx (default 0).

        Returns
        -------
        bool
            True if the selected mask was set successfully.
        """
        if mask_idx == -1:  # special: base weights
            self.selected_mask = -1
            return True
        try:
            self.enable_mask[mask_idx]
        except IndexError:
            raise IndexError("the mask number is out of range")
        self.selected_mask = mask_idx
        return True
    def get_mask(self, mask_idx: int = -1):
        """
        Returns the state dictionary of the selected mask.

        Parameters
        ----------
        mask_idx: int, optional
            The mask index; if not specified, the latest mask is returned (default -1).

        Returns
        -------
        dict
            state_dict: a state dict of the selected mask.
        """
        try:
            self.enable_mask[mask_idx]
        except IndexError:
            raise IndexError("the mask number is out of range")
        # default to the latest mask if not specified
        if mask_idx == -1:
            mask_idx = self.latest_mask_idx
        # return the ParameterDict for transparency
        return self.learnable_params[str(mask_idx)]
    def get_trainable_parameters(self, mask_idx=None):
        """
        Return the list of mask parameters for ``mask_idx`` (defaults to the selected mask).
        """
        if mask_idx is None:
            mask_idx = self.selected_mask
        if mask_idx == -1:
            return []  # nothing to train when using base weights
        return self._mask_param_lists[mask_idx]
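    # Illustrative sketch (not part of the original module): the returned list can be handed
    # straight to an optimizer once a mask exists and is selected; the optimizer and learning
    # rate are assumptions.
    #
    #     >>> model.add_mask()
    #     True
    #     >>> model.set_mask(0)
    #     True
    #     >>> opt = torch.optim.AdamW(model.get_trainable_parameters(), lr=1e-3)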
    def save_weights(self, path: str):
        """
        Save the mask weights to the path.

        Parameters
        ----------
        path: str, required
            Location where the masks should be saved; should be a .pt file.
        """
        payload = {
            "learnable_params": {
                idx: {n: p.detach().cpu() for n, p in self.learnable_params[idx].items()}
                for idx in self.learnable_params.keys()
            },
            "enable_mask": self.enable_mask,
            "new_mask_idx": self.new_mask_idx,
            "class_policy_map": self.class_policy_map,
            "rank": self.rank,
            "matching_texts": self.matching_texts,
            "init_val": self.init_val,
            "max_mult": self.max_mult,
        }
        torch.save(payload, path)
    def load_weights(self, path: str):
        """
        Load the mask weights from the path.

        Parameters
        ----------
        path: str, required
            Location of the .pt file containing the masks.
        """
        info = torch.load(path, map_location=self.device)
        self.learnable_params = nn.ModuleDict()
        self._mask_param_lists.clear()
        self.enable_mask = list(info["enable_mask"])
        self.new_mask_idx = int(info["new_mask_idx"])
        self.class_policy_map = dict(info["class_policy_map"])
        for idx, d in info["learnable_params"].items():
            pd = nn.ParameterDict({n: nn.Parameter(t.to(self.device)) for n, t in d.items()})
            self.learnable_params[idx] = pd
            self._mask_param_lists[int(idx)] = list(pd.parameters())
        # choose a mask
        self.set_mask()
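    # Illustrative round trip (not part of the original module); the file name is an assumption.
    #
    #     >>> model.add_mask()
    #     True
    #     >>> model.save_weights("masks.pt")
    #     >>> model.load_weights("masks.pt")   # rebuilds learnable_params, then selects mask 0 via set_mask()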
    def activate_mask(self, p, mask_idx):
        # map the raw mask parameter to a multiplier in [0, max_mult];
        # disabled masks (and the base-weights case) become all-ones
        if mask_idx == -1:
            return torch.ones_like(p, dtype=torch.float32)
        if not self.enable_mask[mask_idx]:
            return torch.ones_like(p, dtype=torch.float32)
        return torch.sigmoid(p).to(torch.float32) * self.max_mult

    def forward(self, *args, **kwargs):
        return super().forward(*args, **kwargs)

    def compose_new_params(self, param_name, mask_idx):
        # rebuild a weight matrix from its SVD factors with the mask applied to the singular values
        U = self.decomposed_params[f"{param_name}::U"]    # [m,r]
        S = self.decomposed_params[f"{param_name}::S"]    # [r]
        Vh = self.decomposed_params[f"{param_name}::Vh"]  # [r,n]
        mparam = self.learnable_params[str(mask_idx)][param_name.replace('.', '__')]  # [r]
        mm = self.activate_mask(mparam, mask_idx)         # [r]
        S_scaled = S * mm
        eps = torch.finfo(S.dtype).eps
        scale = (S.sum() / (S_scaled.sum() + eps))
        Wp = torch.einsum('mr,r->mr', U, S_scaled)
        Wp = torch.einsum('mr,rn->mn', Wp, Vh) * scale
        return Wp
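    # Written out, the composition above is (a restatement of the code, not new behaviour):
    #
    #     m'  = sigmoid(m) * max_mult                          # per-singular-value multiplier in [0, max_mult]
    #     W'  = U @ diag(S * m') @ Vh * (sum(S) / sum(S * m'))
    #
    # The trailing ratio rescales W' so the total singular mass matches the base weight.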
    def toggle_mask(self, mask_value: bool = True, mask_idx: int = None):
        """
        Turn the mask on or off.

        Parameters
        ----------
        mask_value: bool, optional
            Whether the mask should be on or off (default True).
        mask_idx: int, optional
            If None, the last mask index is used (default None).
        """
        if mask_idx is None:
            mask_idx = self.latest_mask_idx
        self.enable_mask[mask_idx] = mask_value
        self.apply_policy_to_model(mask_idx)
    def update_backward(self, mask_idx: int = None):
        """
        Backpropagate through the learnable mask parameters using a VJP.
        Requires that loss.backward() has populated dL/dW on the base weights.

        Parameters
        ----------
        mask_idx: int, optional
            If None, use the selected mask from set_mask (default None).
        """
        if mask_idx is None:
            mask_idx = self.selected_mask
        keys = [
            k for k in self.base_params
            if all(text not in k for text in self.matching_texts)
            and self.decomposed_params.get(f"{k}::S") is not None
        ]
        if not keys:
            return
        last_key = keys[-1]
        for k in keys:
            g = self.get_parameter(k).grad
            if g is None:
                raise RuntimeError(f"No grad for {k}; call set_train() and loss.backward() first.")
            self.compose_new_params(k, mask_idx).backward(g, retain_graph=(k is not last_key))
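    # The step ordering this method assumes, summarised (``loss`` and ``opt`` are placeholders,
    # not part of the original module):
    #
    #     >>> model.set_train()        # base weights keep their grads, mask params become trainable
    #     >>> loss.backward()          # fills dL/dW on the composed base weights
    #     >>> model.update_backward()  # re-composes each weight and routes dL/dW onto the mask params
    #     >>> opt.step()               # opt was built from model.get_trainable_parameters()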
    def set_train(self, mask_idx: int = None):
        """
        Set the learnable parameters to training mode.

        Parameters
        ----------
        mask_idx: int, optional
            If None, use the mask index from set_mask.
        """
        if mask_idx is None:
            mask_idx = self.selected_mask
        # 1) freeze everything
        for _, p in self.named_parameters():
            p.requires_grad_(False)
        # 2) enable grads on base weights we compose (so dL/dW is computed)
        for k in self.base_params:
            if any(s in k for s in self.matching_texts):
                continue
            if self.decomposed_params.get(f"{k}::S") is None:
                continue
            p = self.get_parameter(k)
            p.requires_grad_(True)
            p.retain_grad()  # keep grad around for VJP
        # 3) ensure masks are trainable
        for p in self.get_trainable_parameters(mask_idx):
            p.requires_grad_(True)
    def apply_policy_to_model(self, mask_idx: int = None):
        """
        Compose & write weights into the live model (fast in-place copy).

        Parameters
        ----------
        mask_idx: int, optional
            Index of the mask that should be applied to the model; if None, the mask
            chosen via set_mask is used.

        Examples
        --------
        >>> model.apply_policy_to_model(1)
        """
        if mask_idx is None:
            mask_idx = self.selected_mask
        with torch.no_grad():
            for k, base in self.base_params.items():
                param = self.get_parameter(k)
                if any(skip in k for skip in self.matching_texts):
                    param.copy_(base.to(param.dtype).to(self.device))
                    continue
                if self.decomposed_params.get(f"{k}::S") is None:
                    param.copy_(base.to(param.dtype).to(self.device))
                    continue
                Wp = self.compose_new_params(k, mask_idx).to(param.dtype).to(self.device)
                param.copy_(Wp)
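
# Illustrative end-to-end sketch (not part of the original module). It follows the ViT example
# from ``__init__``; the dummy batch, the stand-in loss and the file name are assumptions.
#
#     >>> from transformers import ViTModel
#     >>> vit = ViTModel.from_pretrained('google/vit-base-patch16-224-in21k')
#     >>> model = Chicken(vit)
#     >>> model.add_class(["cat", "dog"])              # creates mask 0 and maps both classes to it
#     True
#     >>> model.set_mask(0)
#     True
#     >>> model.set_train()
#     >>> opt = torch.optim.AdamW(model.get_trainable_parameters(), lr=1e-3)
#     >>> pixel_values = torch.randn(2, 3, 224, 224)   # dummy batch
#     >>> model.apply_policy_to_model()                # write the masked composition into the live weights
#     >>> loss = model(pixel_values=pixel_values).pooler_output.pow(2).mean()  # stand-in loss; use a real task head
#     >>> loss.backward()
#     >>> model.update_backward()                      # push dL/dW onto the mask parameters
#     >>> opt.step(); opt.zero_grad()
#     >>> model.save_weights("chicken_masks.pt")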