Module pastalib.pasta

PASTA Implementation

Expand source code
"""PASTA Implementation"""
import torch
import abc, json
from contextlib import contextmanager
from functools import partial
from pathlib import Path
from typing import Any, Literal, Optional, Sequence, cast, overload, Tuple

import transformers 
from pastalib.utils import tokenizer_utils
from pastalib.utils.typing import (
    Model,
    Dataset,
    Device,
    ModelInput,
    ModelOutput,
    StrSequence,
    Tokenizer,
    TokenizerOffsetMapping,
)


class PASTA(abc.ABC):

    ATTN_MODULE_NAME = {
        "gptj": "transformer.h.{}.attn",
        "llama": "model.layers.{}.self_attn",
    }
    ATTENTION_MASK_ARGIDX = {
        "gptj": 2, 
        "llama": 1, 
    }

    def __init__(
        self, 
        model: Model, 
        tokenizer: Tokenizer, 
        head_config: dict|list|None = None, 
        alpha: float = 0.01, 
        scale_position: str = "exclude", 
    ):
        self.model = model
        self.tokenizer = tokenizer
        self.setup_model(model)

        self.alpha = alpha
        self.scale_position = scale_position
        self.setup_head_config(head_config)

        assert self.scale_position in ['include', 'exclude']
        assert self.alpha > 0

    def setup_model(self, model):
        if isinstance(model, transformers.LlamaForCausalLM):
            self.model_name = "llama"
            self.num_attn_head = model.config.num_attention_heads
        elif isinstance(model, transformers.GPTJForCausalLM):
            self.model_name = "gptj"
            self.num_attn_head = model.config.n_head
        else:
            raise ValueError("Unimplemented Model Type.")

    def setup_head_config(self, head_config):
        if isinstance(head_config, dict):
            self.head_config = {int(k):v for k,v in head_config.items()} 
            self.all_layers_idx = [int(key) for key in head_config]
        elif isinstance(head_config, list):
            self.all_layers_idx = [int(v) for v in head_config]
            self.head_config = {
                idx:list(range(self.num_attn_head)) for idx in self.all_layers_idx
            }
        else:
            raise ValueError(f"Incorrect head config: {head_config}")
    
    def _maybe_batch(self, text: str | StrSequence) -> StrSequence:
        """Batch the text if it is not already batched."""
        if isinstance(text, str):
            return [text]
        return text

    def token_ranges_from_batch(
        self,
        strings: str | StrSequence,
        substrings: str | StrSequence,
        offsets_mapping: Sequence[TokenizerOffsetMapping],
        occurrence: int = 0,
    ) -> torch.Tensor:
        """Return shape (batch_size, 2) tensor of token ranges for (str, substr) pairs."""
        strings = self._maybe_batch(strings)
        substrings = self._maybe_batch(substrings)
        if len(strings) != len(substrings):
            raise ValueError(
                f"got {len(strings)} strings but only {len(substrings)} substrings"
            )
        return torch.tensor(
            [
                tokenizer_utils.find_token_range(
                    string, substring, offset_mapping=offset_mapping, occurrence=occurrence
                )
                for string, substring, offset_mapping in zip(
                    strings, substrings, offsets_mapping
                )
            ]
        )

    def edit_attention_mask(
        self, 
        module: torch.nn.Module, 
        input_args: tuple,
        input_kwargs: dict, 
        head_idx: list[int],
        token_range: torch.Tensor, 
        input_len: int, 
    ):
        if "attention_mask" in input_kwargs:
            attention_mask = input_kwargs['attention_mask'].clone()
        elif input_args is not None:
            arg_idx = self.ATTENTION_MASK_ARGIDX[self.model_name]
            attention_mask = input_args[arg_idx].clone()
        else:
            raise ValueError(f"Not found attention masks in {str(module)}")
        
        bsz, head_dim, tgt_len, src_len = attention_mask.size()
        dtype, device = attention_mask.dtype, attention_mask.device
        if head_dim != self.num_attn_head:
            attention_mask = attention_mask.expand(
                bsz, self.num_attn_head, tgt_len, src_len
            ).clone()
        scale_constant = torch.Tensor([self.alpha]).to(dtype).to(device).log()
        
        for bi, (ti,tj) in enumerate(token_range.tolist()):
            if self.scale_position == "include":
                attention_mask[bi, head_idx, :, ti:tj] += scale_constant
            else:
                attention_mask[bi, head_idx, :, :ti] += scale_constant
                attention_mask[bi, head_idx, :, tj:input_len] += scale_constant
        
        if self.model_name == "llama":
            # LlamaAttention checks that the mask's size() is (bsz, 1, tgt_len, src_len);
            # spoof size() so the per-head mask passes that check while keeping its real shape.
            attention_mask.old_size = attention_mask.size
            attention_mask.size = lambda: (bsz, 1, tgt_len, src_len)
        
        if "attention_mask" in input_kwargs:
            input_kwargs['attention_mask'] = attention_mask 
            return input_args, input_kwargs
        else:
            # Rebuild the positional args with the edited mask in place of the original one.
            return (*input_args[:arg_idx], attention_mask, *input_args[arg_idx+1:]), input_kwargs


    @contextmanager
    def apply_steering(
        self, 
        model: Model, 
        strings: list, 
        substrings: list, 
        model_input: ModelInput, 
        offsets_mapping: Sequence[TokenizerOffsetMapping], 
    ):
        token_range = self.token_ranges_from_batch(
            strings, substrings, offsets_mapping
        )

        registered_hooks = []
        for layer_idx in self.all_layers_idx:
            name = self.ATTN_MODULE_NAME[self.model_name].format(layer_idx)
            module = model.get_submodule(name)
            hook_func = partial(
                self.edit_attention_mask, 
                head_idx = self.head_config[layer_idx],
                token_range = token_range, 
                input_len = model_input['input_ids'].size(-1)
            )
            registered_hook = module.register_forward_pre_hook(hook_func, with_kwargs=True)
            registered_hooks.append(registered_hook)
        try:
            yield model
        except Exception as error:
            raise error
        finally:
            for registered_hook in registered_hooks:
                registered_hook.remove()
    

    def inputs_from_batch(
        self, 
        text: str | StrSequence,
        tokenizer: Tokenizer|None = None,
        device: Optional[Device] = None,
    ) -> tuple[ModelInput, Sequence[TokenizerOffsetMapping]]:
        """Precompute model inputs."""
        if tokenizer is None:
            tokenizer = self.tokenizer
        with tokenizer_utils.set_padding_side(tokenizer, padding_side="left"):
            inputs = tokenizer(
                text,
                return_tensors="pt",
                truncation=True,
                padding="longest",
                return_offsets_mapping=True,
            )
            offset_mapping = inputs.pop("offset_mapping")
        if device is not None:
            inputs = inputs.to(device)
        return inputs, offset_mapping

    @classmethod
    def load_head_config(cls, file:str|Path):
        with open(file, "r") as f:
            head_config = json.load(f)
        return head_config 

Classes

class PASTA (model: transformers.models.gptj.modeling_gptj.GPTJForCausalLM | transformers.models.llama.modeling_llama.LlamaForCausalLM | transformers.models.gpt2.modeling_gpt2.GPT2LMHeadModel | transformers.models.gpt_neox.modeling_gpt_neox.GPTNeoXForCausalLM, tokenizer: transformers.tokenization_utils_fast.PreTrainedTokenizerFast, head_config: dict | list | None = None, alpha: float = 0.01, scale_position: str = 'exclude')

Steers the attention of a GPT-J or Llama causal LM over user-specified spans: for the configured heads (head_config), log(alpha) is added to the attention mask either on the highlighted token range (scale_position="include") or on all other input positions (scale_position="exclude"), rescaling the corresponding attention weights by alpha. The edits are applied through forward pre-hooks on the attention modules.
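
A minimal usage sketch (illustrative, not part of the module): the checkpoint name and prompt are placeholders, and it assumes a transformers version whose Llama/GPT-J attention modules receive an additive 4-D float attention mask (eager attention), which is what edit_attention_mask expects.

import transformers
from pastalib.pasta import PASTA

name = "huggyllama/llama-7b"  # any LlamaForCausalLM or GPTJForCausalLM checkpoint
model = transformers.AutoModelForCausalLM.from_pretrained(name)
tokenizer = transformers.AutoTokenizer.from_pretrained(name)

# Steer every head of layers 8 and 9 (list form of head_config).
pasta = PASTA(model, tokenizer, head_config=[8, 9], alpha=0.01, scale_position="exclude")

texts = ["Mary is a nurse but she used to be a doctor. State her current job. Answer:"]
emphasized = ["State her current job."]

inputs, offset_mapping = pasta.inputs_from_batch(texts, device=model.device)
with pasta.apply_steering(model, texts, emphasized, inputs, offset_mapping) as steered:
    output_ids = steered.generate(**inputs, max_new_tokens=16)
print(tokenizer.batch_decode(output_ids, skip_special_tokens=True))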

Expand source code
class PASTA(abc.ABC):

    ATTN_MODULE_NAME = {
        "gptj": "transformer.h.{}.attn",
        "llama": "model.layers.{}.self_attn",
    }
    ATTENTION_MASK_ARGIDX = {
        "gptj": 2, 
        "llama": 1, 
    }

    def __init__(
        self, 
        model: Model, 
        tokenizer: Tokenizer, 
        head_config: dict|list|None = None, 
        alpha: float = 0.01, 
        scale_position: str = "exclude", 
    ):
        self.model = model
        self.tokenizer = tokenizer
        self.setup_model(model)

        self.alpha = alpha
        self.scale_position = scale_position
        self.setup_head_config(head_config)

        assert self.scale_position in ['include', 'exclude']
        assert self.alpha > 0

    def setup_model(self, model):
        if isinstance(model, transformers.LlamaForCausalLM):
            self.model_name = "llama"
            self.num_attn_head = model.config.num_attention_heads
        elif isinstance(model, transformers.GPTJForCausalLM):
            self.model_name = "gptj"
            self.num_attn_head = model.config.n_head
        else:
            raise ValueError("Unimplemented Model Type.")

    def setup_head_config(self, head_config):
        if isinstance(head_config, dict):
            self.head_config = {int(k):v for k,v in head_config.items()} 
            self.all_layers_idx = [int(key) for key in head_config]
        elif isinstance(head_config, list):
            self.all_layers_idx = [int(v) for v in head_config]
            self.head_config = {
                idx:list(range(self.num_attn_head)) for idx in self.all_layers_idx
            }
        else:
            raise ValueError(f"Incorrect head config: {head_config}")
    
    def _maybe_batch(self, text: str | StrSequence) -> StrSequence:
        """Batch the text if it is not already batched."""
        if isinstance(text, str):
            return [text]
        return text

    def token_ranges_from_batch(
        self,
        strings: str | StrSequence,
        substrings: str | StrSequence,
        offsets_mapping: Sequence[TokenizerOffsetMapping],
        occurrence: int = 0,
    ) -> torch.Tensor:
        """Return shape (batch_size, 2) tensor of token ranges for (str, substr) pairs."""
        strings = self._maybe_batch(strings)
        substrings = self._maybe_batch(substrings)
        if len(strings) != len(substrings):
            raise ValueError(
                f"got {len(strings)} strings but only {len(substrings)} substrings"
            )
        return torch.tensor(
            [
                tokenizer_utils.find_token_range(
                    string, substring, offset_mapping=offset_mapping, occurrence=occurrence
                )
                for string, substring, offset_mapping in zip(
                    strings, substrings, offsets_mapping
                )
            ]
        )

    def edit_attention_mask(
        self, 
        module: torch.nn.Module, 
        input_args: tuple,
        input_kwargs: dict, 
        head_idx: list[int],
        token_range: torch.Tensor, 
        input_len: int, 
    ):
        if "attention_mask" in input_kwargs:
            attention_mask = input_kwargs['attention_mask'].clone()
        elif input_args is not None:
            arg_idx = self.ATTENTION_MASK_ARGIDX[self.model_name]
            attention_mask = input_args[arg_idx].clone()
        else:
            raise ValueError(f"Not found attention masks in {str(module)}")
        
        bsz, head_dim, tgt_len, src_len = attention_mask.size()
        dtype, device = attention_mask.dtype, attention_mask.device
        if head_dim != self.num_attn_head:
            attention_mask = attention_mask.expand(
                bsz, self.num_attn_head, tgt_len, src_len
            ).clone()
        scale_constant = torch.Tensor([self.alpha]).to(dtype).to(device).log()
        
        for bi, (ti,tj) in enumerate(token_range.tolist()):
            if self.scale_position == "include":
                attention_mask[bi, head_idx, :, ti:tj] += scale_constant
            else:
                attention_mask[bi, head_idx, :, :ti] += scale_constant
                attention_mask[bi, head_idx, :, tj:input_len] += scale_constant
        
        if self.model_name == "llama":
            # LlamaAttention checks that the mask's size() is (bsz, 1, tgt_len, src_len);
            # spoof size() so the per-head mask passes that check while keeping its real shape.
            attention_mask.old_size = attention_mask.size
            attention_mask.size = lambda: (bsz, 1, tgt_len, src_len)
        
        if "attention_mask" in input_kwargs:
            input_kwargs['attention_mask'] = attention_mask 
            return input_args, input_kwargs
        else:
            # Rebuild the positional args with the edited mask in place of the original one.
            return (*input_args[:arg_idx], attention_mask, *input_args[arg_idx+1:]), input_kwargs


    @contextmanager
    def apply_steering(
        self, 
        model: Model, 
        strings: list, 
        substrings: list, 
        model_input: ModelInput, 
        offsets_mapping: Sequence[TokenizerOffsetMapping], 
    ):
        token_range = self.token_ranges_from_batch(
            strings, substrings, offsets_mapping
        )

        registered_hooks = []
        for layer_idx in self.all_layers_idx:
            name = self.ATTN_MODULE_NAME[self.model_name].format(layer_idx)
            module = model.get_submodule(name)
            hook_func = partial(
                self.edit_attention_mask, 
                head_idx = self.head_config[layer_idx],
                token_range = token_range, 
                input_len = model_input['input_ids'].size(-1)
            )
            registered_hook = module.register_forward_pre_hook(hook_func, with_kwargs=True)
            registered_hooks.append(registered_hook)
        try:
            yield model
        except Exception as error:
            raise error
        finally:
            for registered_hook in registered_hooks:
                registered_hook.remove()
    

    def inputs_from_batch(
        self, 
        text: str | StrSequence,
        tokenizer: Tokenizer|None = None,
        device: Optional[Device] = None,
    ) -> tuple[ModelInput, Sequence[TokenizerOffsetMapping]]:
        """Precompute model inputs."""
        if tokenizer is None:
            tokenizer = self.tokenizer
        with tokenizer_utils.set_padding_side(tokenizer, padding_side="left"):
            inputs = tokenizer(
                text,
                return_tensors="pt",
                truncation=True,
                padding="longest",
                return_offsets_mapping=True,
            )
            offset_mapping = inputs.pop("offset_mapping")
        if device is not None:
            inputs = inputs.to(device)
        return inputs, offset_mapping

    @classmethod
    def load_head_config(cls, file:str|Path):
        with open(file, "r") as f:
            head_config = json.load(f)
        return head_config 

Ancestors

  • abc.ABC

Class variables

var ATTENTION_MASK_ARGIDX
var ATTN_MODULE_NAME

Static methods

def load_head_config(file: str | pathlib.Path)
Expand source code
@classmethod
def load_head_config(cls, file:str|Path):
    with open(file, "r") as f:
        head_config = json.load(f)
    return head_config 
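
The file is expected to hold JSON in one of the two forms accepted by setup_head_config. A short sketch with a hypothetical heads.json (model and tokenizer as in the class-level example above):

# heads.json: dict form, mapping layer index -> list of head indices to steer
#   {"8": [0, 5, 11], "10": [3]}
head_config = PASTA.load_head_config("heads.json")
pasta = PASTA(model, tokenizer, head_config=head_config)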

Methods

def apply_steering(self, model: transformers.models.gptj.modeling_gptj.GPTJForCausalLM | transformers.models.llama.modeling_llama.LlamaForCausalLM | transformers.models.gpt2.modeling_gpt2.GPT2LMHeadModel | transformers.models.gpt_neox.modeling_gpt_neox.GPTNeoXForCausalLM, strings: list, substrings: list, model_input: transformers.tokenization_utils_base.BatchEncoding, offsets_mapping: Sequence[Sequence[tuple[int, int]]])
Expand source code
@contextmanager
def apply_steering(
    self, 
    model: Model, 
    strings: list, 
    substrings: list, 
    model_input: ModelInput, 
    offsets_mapping: Sequence[TokenizerOffsetMapping], 
):
    token_range = self.token_ranges_from_batch(
        strings, substrings, offsets_mapping
    )

    registered_hooks = []
    for layer_idx in self.all_layers_idx:
        name = self.ATTN_MODULE_NAME[self.model_name].format(layer_idx)
        module = model.get_submodule(name)
        hook_func = partial(
            self.edit_attention_mask, 
            head_idx = self.head_config[layer_idx],
            token_range = token_range, 
            input_len = model_input['input_ids'].size(-1)
        )
        registered_hook = module.register_forward_pre_hook(hook_func, with_kwargs=True)
        registered_hooks.append(registered_hook)
    try:
        yield model
    except Exception as error:
        raise error
    finally:
        for registered_hook in registered_hooks:
            registered_hook.remove()
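
Hooks are registered on entry and removed in the finally block, so the attention edit only applies to forward passes made inside the with block. A short sketch reusing the objects from the class-level example:

inputs, offsets = pasta.inputs_from_batch(texts, device=model.device)
with pasta.apply_steering(model, texts, emphasized, inputs, offsets) as steered:
    steered_logits = steered(**inputs).logits   # steered attention
plain_logits = model(**inputs).logits           # hooks removed, ordinary attention
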
def edit_attention_mask(self, module: torch.nn.modules.module.Module, input_args: tuple, input_kwargs: dict, head_idx: list[int], token_range: torch.Tensor, input_len: int)
Expand source code
def edit_attention_mask(
    self, 
    module: torch.nn.Module, 
    input_args: tuple,
    input_kwargs: dict, 
    head_idx: list[int],
    token_range: torch.Tensor, 
    input_len: int, 
):
    if "attention_mask" in input_kwargs:
        attention_mask = input_kwargs['attention_mask'].clone()
    elif input_args is not None:
        arg_idx = self.ATTENTION_MASK_ARGIDX[self.model_name]
        attention_mask = input_args[arg_idx].clone()
    else:
        raise ValueError(f"Not found attention masks in {str(module)}")
    
    bsz, head_dim, tgt_len, src_len = attention_mask.size()
    dtype, device = attention_mask.dtype, attention_mask.device
    if head_dim != self.num_attn_head:
        attention_mask = attention_mask.expand(
            bsz, self.num_attn_head, tgt_len, src_len
        ).clone()
    scale_constant = torch.Tensor([self.alpha]).to(dtype).to(device).log()
    
    for bi, (ti,tj) in enumerate(token_range.tolist()):
        if self.scale_position == "include":
            attention_mask[bi, head_idx, :, ti:tj] += scale_constant
        else:
            attention_mask[bi, head_idx, :, :ti] += scale_constant
            attention_mask[bi, head_idx, :, tj:input_len] += scale_constant
    
    if self.model_name == "llama":
        # LlamaAttention checks that the mask's size() is (bsz, 1, tgt_len, src_len);
        # spoof size() so the per-head mask passes that check while keeping its real shape.
        attention_mask.old_size = attention_mask.size
        attention_mask.size = lambda: (bsz, 1, tgt_len, src_len)
    
    if "attention_mask" in input_kwargs:
        input_kwargs['attention_mask'] = attention_mask 
        return input_args, input_kwargs
    else:
        # Rebuild the positional args with the edited mask in place of the original one.
        return (*input_args[:arg_idx], attention_mask, *input_args[arg_idx+1:]), input_kwargs
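
The edit is additive because attention weights are softmax(scores + mask): adding log(alpha) to selected mask entries multiplies the corresponding unnormalized weights by alpha before renormalization. A small numeric check of that identity (illustrative only):

import torch

scores = torch.tensor([1.0, 2.0, 3.0])              # raw scores for three key positions
alpha = 0.01
bias = torch.log(torch.tensor([alpha, 1.0, 1.0]))   # downweight position 0 only

steered = torch.softmax(scores + bias, dim=-1)
manual = torch.softmax(scores, dim=-1) * torch.tensor([alpha, 1.0, 1.0])
manual = manual / manual.sum()
print(torch.allclose(steered, manual))              # True
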
def inputs_from_batch(self, text: str | list[str] | tuple[str, ...], tokenizer: transformers.tokenization_utils_fast.PreTrainedTokenizerFast | None = None, device: Union[str, torch.device, ForwardRef(None)] = None) ‑> tuple[transformers.tokenization_utils_base.BatchEncoding, typing.Sequence[typing.Sequence[tuple[int, int]]]]

Precompute model inputs.

Expand source code
def inputs_from_batch(
    self, 
    text: str | StrSequence,
    tokenizer: Tokenizer|None = None,
    device: Optional[Device] = None,
) -> tuple[ModelInput, Sequence[TokenizerOffsetMapping]]:
    """Precompute model inputs."""
    if tokenizer is None:
        tokenizer = self.tokenizer
    with tokenizer_utils.set_padding_side(tokenizer, padding_side="left"):
        inputs = tokenizer(
            text,
            return_tensors="pt",
            truncation=True,
            padding="longest",
            return_offsets_mapping=True,
        )
        offset_mapping = inputs.pop("offset_mapping")
    if device is not None:
        inputs = inputs.to(device)
    return inputs, offset_mapping
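
A sketch of the returned pair (sentences are placeholders; pasta is the instance from the class-level example, and batched padding assumes the tokenizer has a pad token):

tokenizer.pad_token = tokenizer.pad_token or tokenizer.eos_token
inputs, offset_mapping = pasta.inputs_from_batch(
    ["Mary is a nurse.", "Translate the answer into French."]
)
inputs["input_ids"].shape   # (2, padded_len); sequences are left-padded
len(offset_mapping)         # 2: one list of (char_start, char_end) pairs per sequence
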
def setup_head_config(self, head_config)
Expand source code
def setup_head_config(self, head_config):
    if isinstance(head_config, dict):
        self.head_config = {int(k):v for k,v in head_config.items()} 
        self.all_layers_idx = [int(key) for key in head_config]
    elif isinstance(head_config, list):
        self.all_layers_idx = [int(v) for v in head_config]
        self.head_config = {
            idx:list(range(self.num_attn_head)) for idx in self.all_layers_idx
        }
    else:
        raise ValueError(f"Incorrect head config: {head_config}")
def setup_model(self, model)
Expand source code
def setup_model(self, model):
    if isinstance(model, transformers.LlamaForCausalLM):
        self.model_name = "llama"
        self.num_attn_head = model.config.num_attention_heads
    elif isinstance(model, transformers.GPTJForCausalLM):
        self.model_name = "gptj"
        self.num_attn_head = model.config.n_head
    else:
        raise ValueError("Unimplemented Model Type.")
def token_ranges_from_batch(self, strings: str | list[str] | tuple[str, ...], substrings: str | list[str] | tuple[str, ...], offsets_mapping: Sequence[Sequence[tuple[int, int]]], occurrence: int = 0) ‑> torch.Tensor

Return shape (batch_size, 2) tensor of token ranges for (str, substr) pairs.

Expand source code
def token_ranges_from_batch(
    self,
    strings: str | StrSequence,
    substrings: str | StrSequence,
    offsets_mapping: Sequence[TokenizerOffsetMapping],
    occurrence: int = 0,
) -> torch.Tensor:
    """Return shape (batch_size, 2) tensor of token ranges for (str, substr) pairs."""
    strings = self._maybe_batch(strings)
    substrings = self._maybe_batch(substrings)
    if len(strings) != len(substrings):
        raise ValueError(
            f"got {len(strings)} strings but only {len(substrings)} substrings"
        )
    return torch.tensor(
        [
            tokenizer_utils.find_token_range(
                string, substring, offset_mapping=offset_mapping, occurrence=occurrence
            )
            for string, substring, offset_mapping in zip(
                strings, substrings, offsets_mapping
            )
        ]
    )
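
A sketch of the span-to-token-range mapping (objects from the class-level example; the exact indices depend on the tokenizer):

texts = ["Mary is a nurse but she used to be a doctor."]
spans = ["a doctor"]
inputs, offsets = pasta.inputs_from_batch(texts)
ranges = pasta.token_ranges_from_batch(texts, spans, offsets)
ranges.shape   # torch.Size([1, 2]): the [start, end) token indices of "a doctor"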