Module pastalib.pasta
PASTA Implementation
"""PASTA Implementation"""
import torch
import abc, json
from contextlib import contextmanager
from functools import partial
from pathlib import Path
from typing import Any, Literal, Optional, Sequence, cast, overload, Tuple
import transformers
from pastalib.utils import tokenizer_utils
from pastalib.utils.typing import (
    Model,
    Dataset,
    Device,
    ModelInput,
    ModelOutput,
    StrSequence,
    Tokenizer,
    TokenizerOffsetMapping,
)


class PASTA(abc.ABC):
    # Dotted path (per model family) of the attention module inside each decoder layer.
    ATTN_MODULE_NAME = {
        "gptj": "transformer.h.{}.attn",
        "llama": "model.layers.{}.self_attn",
    }
    # Positional index of the attention mask among the attention module's forward args.
    ATTENTION_MASK_ARGIDX = {
        "gptj": 2,
        "llama": 1,
    }

    def __init__(
        self,
        model: Model,
        tokenizer: Tokenizer,
        head_config: dict | list | None = None,
        alpha: float = 0.01,
        scale_position: str = "exclude",
    ):
        self.model = model
        self.tokenizer = tokenizer
        self.setup_model(model)
        self.alpha = alpha
        self.scale_position = scale_position
        self.setup_head_config(head_config)

        assert self.scale_position in ["include", "exclude"]
        assert self.alpha > 0

    def setup_model(self, model):
        # Record the model family and its number of attention heads.
        if isinstance(model, transformers.LlamaForCausalLM):
            self.model_name = "llama"
            self.num_attn_head = model.config.num_attention_heads
        elif isinstance(model, transformers.GPTJForCausalLM):
            self.model_name = "gptj"
            self.num_attn_head = model.config.n_head
        else:
            raise ValueError("Unimplemented Model Type.")

    def setup_head_config(self, head_config):
        # Normalize the head config into a {layer_idx: [head_idx, ...]} mapping.
        if isinstance(head_config, dict):
            # Dict: keys are layer indices, values are the heads to steer in that layer.
            self.head_config = {int(k): v for k, v in head_config.items()}
            self.all_layers_idx = [int(key) for key in head_config]
        elif isinstance(head_config, list):
            # List: layer indices only; steer every head in those layers.
            self.all_layers_idx = [int(v) for v in head_config]
            self.head_config = {
                idx: list(range(self.num_attn_head)) for idx in self.all_layers_idx
            }
        else:
            raise ValueError(f"Incorrect head config: {head_config}")

    def _maybe_batch(self, text: str | StrSequence) -> StrSequence:
        """Batch the text if it is not already batched."""
        if isinstance(text, str):
            return [text]
        return text

    def token_ranges_from_batch(
        self,
        strings: str | StrSequence,
        substrings: str | StrSequence,
        offsets_mapping: Sequence[TokenizerOffsetMapping],
        occurrence: int = 0,
    ) -> torch.Tensor:
        """Return shape (batch_size, 2) tensor of token ranges for (str, substr) pairs."""
        strings = self._maybe_batch(strings)
        substrings = self._maybe_batch(substrings)
        if len(strings) != len(substrings):
            raise ValueError(
                f"got {len(strings)} strings but only {len(substrings)} substrings"
            )
        return torch.tensor(
            [
                tokenizer_utils.find_token_range(
                    string, substring, offset_mapping=offset_mapping, occurrence=occurrence
                )
                for string, substring, offset_mapping in zip(
                    strings, substrings, offsets_mapping
                )
            ]
        )

    def edit_attention_mask(
        self,
        module: torch.nn.Module,
        input_args: tuple,
        input_kwargs: dict,
        head_idx: list[int],
        token_range: torch.Tensor,
        input_len: int,
    ):
        # Forward pre-hook: add log(alpha) to the additive attention mask of the
        # selected heads so their attention to the chosen positions is rescaled.
        if "attention_mask" in input_kwargs:
            attention_mask = input_kwargs["attention_mask"].clone()
        elif input_args is not None:
            arg_idx = self.ATTENTION_MASK_ARGIDX[self.model_name]
            attention_mask = input_args[arg_idx].clone()
        else:
            raise ValueError(f"No attention mask found in {str(module)}")
        bsz, head_dim, tgt_len, src_len = attention_mask.size()
        dtype, device = attention_mask.dtype, attention_mask.device

        # Broadcast the mask over heads so that individual heads can be edited.
        if head_dim != self.num_attn_head:
            attention_mask = attention_mask.expand(
                bsz, self.num_attn_head, tgt_len, src_len
            ).clone()

        scale_constant = torch.Tensor([self.alpha]).to(dtype).to(device).log()
        for bi, (ti, tj) in enumerate(token_range.tolist()):
            if self.scale_position == "include":
                # Downweight the highlighted span itself.
                attention_mask[bi, head_idx, :, ti:tj] += scale_constant
            else:
                # Downweight everything in the prompt outside the highlighted span.
                attention_mask[bi, head_idx, :, :ti] += scale_constant
                attention_mask[bi, head_idx, :, tj:input_len] += scale_constant

        if self.model_name == "llama":
            # LLaMA inspects mask.size(); report the original (bsz, 1, tgt, src) shape.
            attention_mask.old_size = attention_mask.size
            attention_mask.size = lambda: (bsz, 1, tgt_len, src_len)

        if "attention_mask" in input_kwargs:
            input_kwargs["attention_mask"] = attention_mask
            return input_args, input_kwargs
        else:
            # Replace the positional attention-mask argument in place.
            return (*input_args[:arg_idx], attention_mask, *input_args[arg_idx + 1:]), input_kwargs

    @contextmanager
    def apply_steering(
        self,
        model: Model,
        strings: list,
        substrings: list,
        model_input: ModelInput,
        offsets_mapping: Sequence[TokenizerOffsetMapping],
    ):
        # Register steering hooks on the configured attention modules and remove
        # them again when the context exits.
        token_range = self.token_ranges_from_batch(
            strings, substrings, offsets_mapping
        )
        registered_hooks = []
        for layer_idx in self.all_layers_idx:
            name = self.ATTN_MODULE_NAME[self.model_name].format(layer_idx)
            module = model.get_submodule(name)
            hook_func = partial(
                self.edit_attention_mask,
                head_idx=self.head_config[layer_idx],
                token_range=token_range,
                input_len=model_input["input_ids"].size(-1),
            )
            registered_hook = module.register_forward_pre_hook(hook_func, with_kwargs=True)
            registered_hooks.append(registered_hook)
        try:
            yield model
        except Exception as error:
            raise error
        finally:
            for registered_hook in registered_hooks:
                registered_hook.remove()

    def inputs_from_batch(
        self,
        text: str | StrSequence,
        tokenizer: Tokenizer | None = None,
        device: Optional[Device] = None,
    ) -> tuple[ModelInput, Sequence[TokenizerOffsetMapping]]:
        """Precompute model inputs."""
        if tokenizer is None:
            tokenizer = self.tokenizer
        with tokenizer_utils.set_padding_side(tokenizer, padding_side="left"):
            inputs = tokenizer(
                text,
                return_tensors="pt",
                truncation=True,
                padding="longest",
                return_offsets_mapping=True,
            )
        offset_mapping = inputs.pop("offset_mapping")
        if device is not None:
            inputs = inputs.to(device)
        return inputs, offset_mapping

    @classmethod
    def load_head_config(cls, file: str | Path):
        # Read a {layer_idx: [head_idx, ...]} mapping (or a list of layers) from JSON.
        with open(file, "r") as f:
            head_config = json.load(f)
        return head_config
Classes
class PASTA (model: transformers.models.gptj.modeling_gptj.GPTJForCausalLM | transformers.models.llama.modeling_llama.LlamaForCausalLM | transformers.models.gpt2.modeling_gpt2.GPT2LMHeadModel | transformers.models.gpt_neox.modeling_gpt_neox.GPTNeoXForCausalLM, tokenizer: transformers.tokenization_utils_fast.PreTrainedTokenizerFast, head_config: dict | list | None = None, alpha: float = 0.01, scale_position: str = 'exclude')
Post-hoc attention steering for supported Hugging Face causal language models: PASTA registers forward pre-hooks on the configured attention modules and adds log(alpha) to the attention mask of the selected heads, rescaling how much those heads attend to (or away from) a user-highlighted span of the prompt.
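A minimal usage sketch. The checkpoint name and the layer/head indices below are illustrative placeholders, not part of the library; any LLaMA- or GPT-J-family model accepted by setup_model() works.

    from transformers import AutoModelForCausalLM, AutoTokenizer
    from pastalib.pasta import PASTA

    # Hypothetical checkpoint name; replace with your own model.
    model = AutoModelForCausalLM.from_pretrained("huggyllama/llama-7b")
    tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b")

    # Dict form: layer index -> heads to steer in that layer.
    # A plain list of layer indices would instead steer every head in those layers.
    head_config = {"3": [2, 5, 11], "8": [0, 7]}

    pasta = PASTA(
        model=model,
        tokenizer=tokenizer,
        head_config=head_config,
        alpha=0.01,                # log(alpha) is added to the attention mask
        scale_position="exclude",  # downweight everything outside the highlighted span
    )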
Ancestors
- abc.ABC
Class variables
var ATTENTION_MASK_ARGIDX
var ATTN_MODULE_NAME
Static methods
def load_head_config(file: str | pathlib.Path)
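A sketch of the expected JSON layout (the file path and indices are hypothetical, and model/tokenizer are reused from the example above): either a mapping from layer index to head indices, or simply a list of layer indices.

    # config/head_config.json (hypothetical file):
    #   {"3": [2, 5, 11], "8": [0, 7]}
    head_config = PASTA.load_head_config("config/head_config.json")
    pasta = PASTA(model, tokenizer, head_config=head_config)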
Methods
def apply_steering(self, model: transformers.models.gptj.modeling_gptj.GPTJForCausalLM | transformers.models.llama.modeling_llama.LlamaForCausalLM | transformers.models.gpt2.modeling_gpt2.GPT2LMHeadModel | transformers.models.gpt_neox.modeling_gpt_neox.GPTNeoXForCausalLM, strings: list, substrings: list, model_input: transformers.tokenization_utils_base.BatchEncoding, offsets_mapping: Sequence[Sequence[tuple[int, int]]])
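An end-to-end sketch that reuses the pasta instance from the example above; the prompt and the emphasized substring are made up for illustration. Hooks are registered on entry and removed on exit, so generation outside the with-block is unaffected.

    prompt = ["Mary is a doctor. John is a lawyer. What does Mary do?"]
    emphasized = ["Mary is a doctor."]  # substring of the prompt to steer attention toward

    inputs, offset_mapping = pasta.inputs_from_batch(prompt, device=model.device)
    with pasta.apply_steering(
        model=model,
        strings=prompt,
        substrings=emphasized,
        model_input=inputs,
        offsets_mapping=offset_mapping,
    ) as steered_model:
        outputs = steered_model.generate(**inputs, max_new_tokens=32)
    print(tokenizer.decode(outputs[0], skip_special_tokens=True))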
def edit_attention_mask(self, module: torch.nn.modules.module.Module, input_args: tuple, input_kwargs: dict, head_idx: list[int], token_range: torch.Tensor, input_len: int)
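The hook works because the 4-D attention mask is added to the raw attention scores before the softmax, so adding log(alpha) at a key position scales that position's post-softmax weight by roughly alpha (the remaining weights renormalize). A self-contained numerical sketch of this effect:

    import math
    import torch

    alpha = 0.01
    scores = torch.tensor([2.0, 1.0, 0.5])            # toy pre-softmax attention scores
    mask = torch.tensor([0.0, math.log(alpha), 0.0])  # add log(alpha) at key position 1
    print(torch.softmax(scores, dim=-1))              # baseline attention weights
    print(torch.softmax(scores + mask, dim=-1))       # position 1 shrinks by ~alpha; others renormalize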
def inputs_from_batch(self, text: str | list[str] | tuple[str, ...], tokenizer: transformers.tokenization_utils_fast.PreTrainedTokenizerFast | None = None, device: Union[str, torch.device, None] = None) -> tuple[transformers.tokenization_utils_base.BatchEncoding, typing.Sequence[typing.Sequence[tuple[int, int]]]]
Precompute model inputs.
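A typical call (the prompt text is illustrative). Tokenization uses left padding so that every prompt ends at the same position.

    prompts = ["The capital of France is"]
    inputs, offset_mapping = pasta.inputs_from_batch(prompts, device=model.device)
    # inputs: BatchEncoding with left-padded input_ids / attention_mask tensors
    # offset_mapping: per-example character span of each token, later consumed by
    #                 token_ranges_from_batch() / apply_steering()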
def setup_head_config(self, head_config)
def setup_model(self, model)
def token_ranges_from_batch(self, strings: str | list[str] | tuple[str, ...], substrings: str | list[str] | tuple[str, ...], offsets_mapping: Sequence[Sequence[tuple[int, int]]], occurrence: int = 0) -> torch.Tensor
Return shape (batch_size, 2) tensor of token ranges for (str, substr) pairs.
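For example (the strings are illustrative, reusing the pasta instance from above), each row is a half-open [start, end) pair of token indices, matching the ti:tj slicing used in edit_attention_mask():

    text = ["Mary is a doctor. John is a lawyer."]
    span = ["John is a lawyer."]
    inputs, offset_mapping = pasta.inputs_from_batch(text)
    ranges = pasta.token_ranges_from_batch(text, span, offset_mapping)
    # ranges.shape == (1, 2): token indices delimiting the span within each sequence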