# Copyright 2025 The corr_faith Authors. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

import functools
import re
import string
from typing import Sequence
from absl.testing import absltest
from absl.testing import parameterized
from corr_faith.experiments import classical_nlp
from corr_faith.experiments.models import tokenization
import numpy as np
import transformers

# Matches strings that are empty or consist only of whitespace characters.
ALL_WHITESPACE_REGEX = r"^\s*$"


# Hugging Face model ids whose tokenizers every parameterized test in
# `TokenizationTest` is run against: base and instruction-tuned variants from
# several model families. Tokenizers are loaded via `tokenizer_from_model`
# with `local_files_only=True`, so these need to be available locally.
REPRESENTATIVE_MODELS = [
    "allenai/OLMo-7B-0724-hf",
    "allenai/OLMo-7B-0724-Instruct-hf",
    "EleutherAI/pythia-1b",
    "mistralai/Mistral-7B-v0.3",
    "mistralai/Mistral-7B-Instruct-v0.3",
    "mistralai/Mistral-Nemo-Base-2407",
    "mistralai/Mistral-Nemo-Instruct-2407",
    "01-ai/Yi-6B",
    "01-ai/Yi-6B-Chat",
    "01-ai/Yi-1.5-6B",
    "01-ai/Yi-1.5-6B-Chat",
    "Qwen/Qwen1.5-0.5B",
    "Qwen/Qwen1.5-0.5B-Chat",
    "Qwen/Qwen2-0.5B",
    "Qwen/Qwen2-0.5B-Instruct",
    "Qwen/Qwen2.5-0.5B",
    "Qwen/Qwen2.5-0.5B-Instruct",
    "google/gemma-2b",
    "google/gemma-2b-it",
    "google/gemma-2-2b",
    "google/gemma-2-2b-it",
]


@functools.cache
def tokenizer_from_model(model_name: str) -> transformers.PreTrainedTokenizer:
  """Returns the tokenizer for `model_name`, loading it at most once.

  Loading uses `local_files_only=True`, so no download is attempted. Results
  are memoized per model name by `functools.cache`, so all parameterized test
  cases for one model share a single tokenizer instance.

  Args:
    model_name: Hugging Face model id (or local path) to load from.

  Returns:
    The tokenizer chosen by `transformers.AutoTokenizer`.
  """
  auto_tokenizer_cls = transformers.AutoTokenizer
  loaded_tokenizer = auto_tokenizer_cls.from_pretrained(
      model_name,
      local_files_only=True,
  )
  return loaded_tokenizer


@functools.cache
def get_words(**kwargs) -> Sequence[str]:
  """Returns all WordNet words, across every part of speech, as one tuple.

  Keyword arguments are forwarded unchanged to
  `classical_nlp.get_wordnet_words_by_part_of_speech`. Because of
  `functools.cache`, results are memoized per distinct set of keyword
  arguments, which therefore must be hashable.

  Returns:
    A tuple of the words for each part of speech, concatenated in the
    iteration order of the per-part-of-speech mapping.
  """
  by_pos = classical_nlp.get_wordnet_words_by_part_of_speech(**kwargs)
  flattened = []
  for words_for_one_pos in by_pos.values():
    flattened.extend(words_for_one_pos)
  return tuple(flattened)


class TokenizationTest(parameterized.TestCase):
  """Tests `tokenization` helpers against each of `REPRESENTATIVE_MODELS`."""

  @parameterized.named_parameters([
      dict(testcase_name=model, model_name=model)
      for model in REPRESENTATIVE_MODELS
  ])
  def test_tokenizer_assumptions(self, model_name: str):
    """Tests assumptions about tokenizers that are used in other tests."""
    tokenizer = tokenizer_from_model(model_name)
    encode = functools.partial(tokenizer.encode, add_special_tokens=False)
    # With add_special_tokens=False nothing (e.g. a BOS token) may be added,
    # so the empty string must encode to no tokens at all.
    self.assertEmpty(encode(""))
    short_encoding = encode("a")
    self.assertLen(short_encoding, 1)
    # `short_encoding` holds token ids, so compare against the BOS token *id*;
    # checking for the BOS token string could never fail. `bos_token_id` is
    # None for tokenizers without a BOS token, which trivially passes.
    self.assertNotIn(tokenizer.bos_token_id, short_encoding)
    # Test that return_offsets_mapping is implemented.
    # Note, return_offsets_mapping is only implemented for "fast" tokenizers:
    # https://huggingface.co/docs/transformers/en/internal/tokenization_utils#transformers.PreTrainedTokenizerBase.__call__.return_offsets_mapping
    encoding_and_offsets = tokenizer(
        "a", return_offsets_mapping=True, add_special_tokens=False
    )
    # The call-style API must agree with `encode`, and "a" spans chars [0, 1).
    self.assertSequenceEqual(short_encoding, encoding_and_offsets["input_ids"])
    self.assertSequenceEqual(encoding_and_offsets["offset_mapping"], [(0, 1)])

  @parameterized.product(
      (
          dict(
              string_to_search="JUDGMENT: entailment\nQUESTION:",
              present_target="entailment",
              absent_target="neutral",
              preceding_char=" ",
          ),
          dict(
              string_to_search="ANSWER: 1",
              present_target="1",
              absent_target="2",
              preceding_char=" ",
          ),
          dict(
              string_to_search="ANSWER: 2\nQUESTION:",
              present_target="2",
              absent_target="1",
              preceding_char=" ",
          ),
          dict(
              string_to_search="The fastest way to get to Rome is to fly.",
              present_target="Rome",
              absent_target="Paris",
              preceding_char=" ",
          ),
          dict(
              string_to_search=(
                  "\nFALSE SENTENCE: 1\nEXPLANATION: Horses do not lay eggs,"
                  " chickens do.</s>"
              ),
              present_target="FALSE SENTENCE:",
              absent_target="TRUE SENTENCE:",
              preceding_char="\n",
          ),
      ),
      model_name=REPRESENTATIVE_MODELS,
  )
  def test_find_first_token(
      self,
      model_name: str,
      string_to_search: str,
      present_target: str,
      absent_target: str,
      preceding_char: str,
  ):
    """Checks `string_to_tokens` first-token ids against an encoded string.

    The first token of `present_target` (when preceded by `preceding_char`)
    must occur in the encoding of `string_to_search`; the first token of
    `absent_target` must not.
    """
    tokenizer = tokenizer_from_model(model_name)
    encode = functools.partial(tokenizer.encode, add_special_tokens=False)

    present_token_info = tokenization.string_to_tokens(
        tokenizer,
        present_target,
        preceding_char=preceding_char,
    )
    encoded_string_to_search = encode(string_to_search)
    self.assertIn(
        present_token_info.first_token_id,
        encoded_string_to_search,
        msg=(
            f"First token {present_token_info.first_token_id} not found in"
            f" {string_to_search}"
        ),
    )
    absent_token_info = tokenization.string_to_tokens(
        tokenizer, absent_target, preceding_char=preceding_char
    )
    self.assertNotIn(
        absent_token_info.first_token_id,
        encoded_string_to_search,
        msg=(
            f"First token {absent_token_info.first_token_id} found unexpectedly"
            f" in {string_to_search}"
        ),
    )

  @parameterized.product(
      [
          dict(model_name=model, seed=idx)
          for idx, model in enumerate(REPRESENTATIVE_MODELS)
      ],
      sep_char=[" ", "\n"],
  )
  def test_string_to_tokens_with_random_words(
      self,
      model_name: str,
      seed: int,
      sep_char: str,
      sequences_per_tokenizer: int = 5,
      max_words_per_sequence: int = 10,
  ) -> None:
    """Checks `string_to_tokens` on random sequences of WordNet words.

    Words are joined by `sep_char`. Every word after the first is looked up
    with `sep_char` as its preceding character. Its first token id must
    occur in the encoded sequence. That token must decode to a
    non-whitespace piece of the target (optionally including `sep_char`).
    Decoding all of the target's token ids must contain the target.
    """
    tokenizer = tokenizer_from_model(model_name)
    encode = functools.partial(tokenizer.encode, add_special_tokens=False)
    rng = np.random.default_rng(seed)
    # Note that some non-alphanumeric words break tokenization assumptions,
    # e.g. ".22-calibre", due to its use of punctuation.
    words = get_words(only_alphanumeric=True)
    for _ in range(sequences_per_tokenizer):
      # Between 2 and max_words_per_sequence - 1 words (high is exclusive).
      sequence = rng.choice(words, size=rng.integers(2, max_words_per_sequence))
      string_to_search = sep_char.join(sequence)
      encoded_string_to_search = encode(string_to_search)
      # Skip the first word: only later words are preceded by `sep_char`.
      for target in sequence[1:]:
        # Sanity-check the input: a whitespace-only target would make the
        # whitespace check on its decoded first token below meaningless.
        self.assertNotRegex(target, ALL_WHITESPACE_REGEX)
        token_info = tokenization.string_to_tokens(
            tokenizer, target, preceding_char=sep_char
        )
        self.assertEqual(token_info.target, target)
        self.assertEqual(token_info.preceding_char, sep_char)
        token_id = token_info.first_token_id
        self.assertIn(
            token_id,
            encoded_string_to_search,
            msg=(
                f"First token {token_id} not found in"
                f" {encoded_string_to_search}"
            ),
        )
        decoded_first_token = tokenizer.decode([token_id])
        self.assertNotEmpty(decoded_first_token)
        # The first token shouldn't consist entirely of whitespace.
        self.assertNotRegex(decoded_first_token, ALL_WHITESPACE_REGEX)
        # The decoded first token must be a piece of the target, which also
        # covers the separator when the tokenizer merged it into the token.
        if token_info.token_includes_preceding_char:
          target_maybe_with_sep = sep_char + target
        else:
          target_maybe_with_sep = target
        self.assertIn(
            decoded_first_token,
            target_maybe_with_sep,
        )
        decoded_full_target = tokenizer.decode(token_info.target_token_ids)
        self.assertIn(target, decoded_full_target)

  @parameterized.named_parameters([
      dict(testcase_name=model, model_name=model, seed=idx)
      for idx, model in enumerate(REPRESENTATIVE_MODELS)
  ])
  def test_decode_with_spans(
      self,
      model_name: str,
      seed: int,
      sequences_per_tokenizer: int = 10,
      max_chars_per_sequence: int = 50,
      max_attempts: int = 100,
  ):
    """Checks `decode_with_spans` against the fast tokenizer's char spans.

    Random printable strings are normalized by one encode/decode round trip.
    A string is kept only if a second round trip leaves it unchanged. For
    each kept string, the decoded text must equal it, and the returned
    boundaries (one per token plus a final end offset) must match
    `token_to_chars`. The test stops after `sequences_per_tokenizer` kept
    strings or `max_attempts` tries, whichever comes first.
    """
    tokenizer = tokenizer_from_model(model_name)
    rng = np.random.default_rng(seed)
    n_sequences = 0
    for _ in range(max_attempts):
      # A random printable string of 0 to max_chars_per_sequence - 1 chars.
      sequence = "".join(
          rng.choice(
              list(string.printable), size=rng.integers(max_chars_per_sequence)
          )
      )
      # NOTE(review): unlike the other tests, `encode` and `tokenizer(...)`
      # here use the default add_special_tokens; for tokenizers that add
      # special tokens this may reject every candidate and leave the test
      # vacuous — confirm this is intended.
      sequence = tokenizer.decode(tokenizer.encode(sequence))
      if tokenizer.decode(tokenizer.encode(sequence)) != sequence:
        # Some strings are not fixed points, e.g. strings starting with a
        # space under sentencepiece tokenization; for example, " =/hcBa]p2e"
        # with the tokenizer for '01-ai/Yi-1.5-34B'. Or "8t" with the same
        # tokenizer. Encoding and decoding these strings results in the
        # tokenizer adding additional spaces on each pass.
        continue  # Skip non-fixed point sequences.
      encoding = tokenizer(sequence)
      tokens = encoding["input_ids"]
      decoded, token_boundaries = tokenization.decode_with_spans(
          tokenizer, tokens
      )
      self.assertEqual(decoded, sequence)
      # One start offset per token, plus the end offset of the last token.
      self.assertLen(token_boundaries, len(tokens) + 1)
      self.assertEqual(token_boundaries[0], 0)
      self.assertLen(sequence, token_boundaries[-1])
      for token_idx in range(len(tokens)):
        char_span = encoding.token_to_chars(token_idx)
        self.assertEqual(char_span.start, token_boundaries[token_idx])
        self.assertEqual(char_span.end, token_boundaries[token_idx + 1])

      n_sequences += 1
      if n_sequences >= sequences_per_tokenizer:
        return

    # Note: this output will only appear when running with python, not pytest.
    print(
        f"    Found {n_sequences} sequences after"
        f" {max_attempts} attempts for tokenizer {model_name}."
    )


if __name__ == "__main__":
  absltest.main()
