# Copyright 2025 The corr_faith Authors. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Methods for analyzing outputs from tokenizers used for LLMs."""

import dataclasses
import functools
from typing import Sequence
import bidict
import numpy as np
import pandas as pd
import scipy as sp
import transformers


@dataclasses.dataclass(frozen=True)
class TokenInfo:
  """How a target string is tokenized when it follows a preceding character.

  Instances are produced by `string_to_tokens`.
  """

  # The target string that was tokenized.
  target: str
  # ID of the first token of `target`.
  first_token_id: int
  # IDs of all tokens of `target`, starting with `first_token_id`.
  # `string_to_tokens` stores a list here (a slice of `tokenizer.encode`
  # output), so hashing an instance built that way raises TypeError.
  target_token_ids: Sequence[int]
  # The character assumed to precede `target`, e.g. space or newline.
  preceding_char: str
  # True if `first_token_id` itself encodes `preceding_char`; False if
  # `preceding_char` is tokenized as a separate token before it.
  token_includes_preceding_char: bool


def string_to_tokens(
    tokenizer: transformers.PreTrainedTokenizer,
    target: str,
    sacrificial_text: str = "to",
    preceding_char: str = " ",
) -> TokenInfo:
  """Finds the first token of the target string.

  Args:
    tokenizer: The tokenizer to use.
    target: The target string to find the first token of. Must be non-empty.
    sacrificial_text: A string that is prepended to the target string. This is
      necessary because e.g. the Llama 2 tokenizer adds a prepended token,
      SPIECE_UNDERLINE, at the start of the sequence. This token is sometimes
      merged with the first token in the sequence (e.g. the letter "J"), so we
      can't just remove it. Instead, we add sacrificial text before the target
      string.
      https://github.com/huggingface/transformers/blob/6f7d5db58c7c149c75642b5a4647b5cbc6c55643/src/transformers/models/llama/tokenization_llama.py#L248
    preceding_char: The character that is expected to precede the first token of
      the target string, e.g. space or newline.

  Returns:
    A TokenInfo object containing the ID of the first token of `target`, the
    IDs of all of `target`'s tokens (starting at that first token), and
    whether the first token also includes `preceding_char`.

  Raises:
    ValueError: If `target` is empty, or if `target` produces no tokens of its
      own after `preceding_char`.
    AssertionError: If the tokenizer violates an assumption this function
      relies on (e.g. it emits the BOS token, `preceding_char` does not encode
      to exactly one extra token after `sacrificial_text`, or
      `sacrificial_text` is tokenized differently when used as a prefix).
  """
  if not target:
    # Previously this surfaced as an opaque IndexError further down.
    raise ValueError("target must be a non-empty string.")

  encode = functools.partial(tokenizer.encode, add_special_tokens=False)
  # With special tokens disabled the empty string must encode to nothing;
  # otherwise the prefix slicing below would be off.
  assert len(encode("")) == 0  # pylint: disable=g-explicit-length-test
  bos_token_id = tokenizer.bos_token_id

  sacrificial_tokens = encode(sacrificial_text)
  n_sacrificial_tokens = len(sacrificial_tokens)
  # `preceding_char` must add exactly one token after the sacrificial text, so
  # we know which token ID represents it on its own.
  sacrifice_and_preceding_tokens = encode(sacrificial_text + preceding_char)
  assert len(sacrifice_and_preceding_tokens) == 1 + n_sacrificial_tokens, (
      f"Expected {sacrificial_text + preceding_char!r} to encode to"
      f" {1 + n_sacrificial_tokens} tokens, got"
      f" {sacrifice_and_preceding_tokens}."
  )
  preceding_token_id = sacrifice_and_preceding_tokens[-1]

  target_tokens_with_sacrifice_and_space = encode(
      sacrificial_text + preceding_char + target
  )
  assert bos_token_id not in target_tokens_with_sacrifice_and_space, (
      "BOS token found despite add_special_tokens=False."
  )
  # The sacrificial text must tokenize identically when used as a prefix, so
  # that its tokens can simply be sliced off.
  assert (
      target_tokens_with_sacrifice_and_space[:n_sacrificial_tokens]
      == sacrificial_tokens
  ), (
      f"Sacrificial text {sacrificial_text!r} tokenizes differently when"
      f" followed by {preceding_char + target!r}."
  )
  target_tokens_with_space = target_tokens_with_sacrifice_and_space[
      n_sacrificial_tokens:
  ]
  if not target_tokens_with_space:
    raise ValueError(f"Target {target!r} produced no tokens.")

  if target_tokens_with_space[0] == preceding_token_id:
    # `preceding_char` was tokenized on its own; the target starts at the next
    # token.
    first_token_idx = 1
    if len(target_tokens_with_space) <= first_token_idx:
      raise ValueError(
          f"Target {target!r} produced no tokens after {preceding_char!r}."
      )
    first_token = target_tokens_with_space[first_token_idx]
    # The first token must not render `preceding_char`, either at the start of
    # a sequence or after another token.
    assert preceding_char not in tokenizer.decode([first_token])
    assert preceding_char not in tokenizer.decode([first_token, first_token])
    token_includes_preceding_char = False
  else:
    # `preceding_char` was merged into the first token of the target.
    first_token_idx = 0
    first_token = target_tokens_with_space[first_token_idx]
    # If the preceding char is a space, some sentencepiece tokenizers don't
    # include it before the first token, but they do include it before later
    # tokens. Decoding the token twice checks the non-initial rendering.
    assert preceding_char in tokenizer.decode([first_token, first_token])
    token_includes_preceding_char = True
  return TokenInfo(
      target=target,
      first_token_id=first_token,
      target_token_ids=target_tokens_with_space[first_token_idx:],
      preceding_char=preceding_char,
      token_includes_preceding_char=token_includes_preceding_char,
  )


@functools.lru_cache(maxsize=1)
def get_vocab_tokens_in_order(
    tokenizer: transformers.PreTrainedTokenizer,
) -> Sequence[str]:
  """Return a tuple of the tokenizer's vocabulary tokens, ordered by token ID.

  Element `i` of the result is the token whose ID is `i`. Building the
  `bidict` raises if two tokens share an ID, and the lookup raises KeyError if
  the IDs are not exactly `0 .. len(vocab) - 1`. Only the most recent
  tokenizer's result is cached.
  """
  id_to_token = bidict.bidict(tokenizer.vocab).inverse
  return tuple(id_to_token[token_id] for token_id in range(len(id_to_token)))


def logit_vector_to_token_probs_df(
    tokenizer: transformers.PreTrainedTokenizer,
    logit_vector: np.ndarray,
) -> pd.DataFrame:
  """Return a DataFrame allowing easier analysis of token probabilities.

  Args:
    tokenizer: Tokenizer whose vocabulary labels the logits; entry `i` of
      `logit_vector` is labelled with the token whose ID is `i`.
    logit_vector: 1D array with one logit per token ID. It may be longer than
      the tokenizer's vocabulary; the extra entries are labelled
      "<OUT OF RANGE>".

  Returns:
    A DataFrame with columns `tokens`, `probs` (softmax of `logit_vector`) and
    `logits`, sorted by descending probability. Its index is the position in
    `logit_vector`, i.e. the token ID.

  Raises:
    ValueError: If `logit_vector` is not 1D, or has fewer entries than the
      tokenizer's vocabulary.
  """
  if logit_vector.ndim != 1:
    raise ValueError(
        f"logit_vector must be a 1D array, but has shape {logit_vector.shape}."
    )
  token_probs = sp.special.softmax(logit_vector)
  vocab_tokens = get_vocab_tokens_in_order(tokenizer)
  tokens_not_in_vocab = len(logit_vector) - len(vocab_tokens)
  if tokens_not_in_vocab < 0:
    # A negative count would silently pad with nothing and pandas would then
    # fail with an unhelpful length-mismatch error.
    raise ValueError(
        f"logit_vector has {len(logit_vector)} entries, fewer than the"
        f" tokenizer's vocabulary size of {len(vocab_tokens)}."
    )
  token_probs_df = pd.DataFrame(
      data=dict(
          tokens=vocab_tokens + ("<OUT OF RANGE>",) * tokens_not_in_vocab,
          probs=token_probs,
          logits=logit_vector,
      )
  )
  return token_probs_df.sort_values(by="probs", ascending=False)


def get_prefix_tokens_for_labels(
    tokenizer: transformers.PreTrainedTokenizer,
    label_options: Sequence[str],
) -> dict[int, str]:
  """Return all tokens that could be the first token of exactly one label.

  Every token in the vocabulary is decoded on its own (skipping special
  tokens), lowercased and stripped of leading whitespace. If the result is a
  non-empty prefix of exactly one label, the token is mapped to that label.

  Args:
    tokenizer: The tokenizer whose vocabulary is scanned.
    label_options: Candidate labels. Each must be all lowercase with no
      leading whitespace.

  Returns:
    A dict mapping token ID to the single label that token could begin.

  Raises:
    ValueError: If a label is not lowercase or has leading whitespace, or if
      some label has no token that unambiguously begins it.
  """
  for label in label_options:
    if label != label.lower().lstrip():
      raise ValueError(
          f"Label {label} must be all lowercase and stripped of leading"
          " whitespace."
      )
  token_to_label = {}
  for token in tokenizer.vocab.values():
    decoded_token = tokenizer.decode([token], skip_special_tokens=True)
    # Strip leading space.
    stripped_decoded_token = decoded_token.lower().lstrip()
    # An empty string (a skipped special token, or pure whitespace) is a
    # prefix of every label. It says nothing about the label, so ignore it.
    # Otherwise, with a single label option, it would be mapped to that label.
    if not stripped_decoded_token:
      continue
    possible_labels = [
        label
        for label in label_options
        if label.startswith(stripped_decoded_token)
    ]
    if len(possible_labels) == 1:
      token_to_label[token] = possible_labels[0]
  # Build the set once; `x in dict.values()` is a linear scan per label.
  found_labels = set(token_to_label.values())
  for label in label_options:
    if label not in found_labels:
      raise ValueError(f"No unambiguous token found for label {label}.")
  return token_to_label


def decode_with_spans(
    tokenizer: transformers.PreTrainedTokenizer,
    token_ids: Sequence[int],
    skip_special_tokens: bool = False,
) -> tuple[str, list[int]]:
  """Decode a sequence of token IDs, returning the decoded string and spans.

  Boundaries are found by decoding every prefix of `token_ids` and recording
  the length of each result. Tokens are not decoded one at a time, presumably
  because a token's rendered text can depend on its context (e.g. whether a
  leading space is shown). This costs `len(token_ids)` decode calls over
  growing prefixes, i.e. it is quadratic in the number of tokens.

  Args:
    tokenizer: Tokenizer used for decoding.
    token_ids: Token IDs to decode.
    skip_special_tokens: Passed through to `tokenizer.decode`.

  Returns:
    A `(decoded, token_boundaries)` tuple. `decoded` is the decoding of all of
    `token_ids`, or "" if it is empty. `token_boundaries` has
    `len(token_ids) + 1` entries: entry `i` is the length of the decoding of
    the first `i` tokens, so token `i` nominally spans
    `decoded[token_boundaries[i]:token_boundaries[i + 1]]`.
    NOTE(review): that reading assumes every prefix decoding is a prefix of
    `decoded`, which may not hold for all tokenizers (e.g. a partial
    multi-byte character); confirm for the tokenizer in use.
  """
  token_boundaries = [0]
  decoded = ""
  for prefix_len in range(1, len(token_ids) + 1):
    decoded = tokenizer.decode(
        token_ids[:prefix_len],
        skip_special_tokens=skip_special_tokens,
    )
    token_boundaries.append(len(decoded))
  return decoded, token_boundaries
