from pathlib import Path

import tiktoken
import tokenizers


def load_model_tokenizers_baseline(name):
    """Load the ``tokenizers`` tokenizer stored for model *name*.

    The definition is read from ``models/<name>/tokenizer.json``, where
    ``models`` lives one level above the directory containing this file.

    Args:
        name: Directory name of the model under ``models``.

    Returns:
        The object produced by ``tokenizers.Tokenizer.from_file``.

    Raises:
        FileNotFoundError: If the ``tokenizer.json`` file does not exist.
    """
    path = Path(__file__).parent.parent / "models" / name / "tokenizer.json"
    # Explicit check rather than ``assert``: assertions are removed under
    # ``python -O``, and the message now names the path that was looked up.
    if not path.exists():
        raise FileNotFoundError(
            f"tokenizer file for model {name!r} not found: {path}"
        )
    return tokenizers.Tokenizer.from_file(str(path))


def load_model_tokenizers_inc_bpe(name):
    """Load the tokenizer for *name* with its model swapped to IncrementalBpe.

    The baseline tokenizer is loaded first; its ``model`` is then replaced by
    ``tokenizers.models.IncrementalBpe.from_bpe`` applied to the loaded model.

    Args:
        name: Model name, forwarded to ``load_model_tokenizers_baseline``.

    Returns:
        The same tokenizer object, now carrying the converted model.
    """
    tokenizer = load_model_tokenizers_baseline(name)
    converted = tokenizers.models.IncrementalBpe.from_bpe(tokenizer.model)
    tokenizer.model = converted
    return tokenizer


def load_model_tokenizers_eager_bpe(name):
    """Load the tokenizer for *name* with its model swapped to EagerBpe.

    The baseline tokenizer is loaded first; its ``model`` is then replaced by
    ``tokenizers.models.EagerBpe.from_bpe`` applied to the loaded model.

    Args:
        name: Model name, forwarded to ``load_model_tokenizers_baseline``.

    Returns:
        The same tokenizer object, now carrying the converted model.
    """
    tokenizer = load_model_tokenizers_baseline(name)
    converted = tokenizers.models.EagerBpe.from_bpe(tokenizer.model)
    tokenizer.model = converted
    return tokenizer


def load_model_tokenizers_no_cache_baseline(name):
    """Load the baseline tokenizer for *name* and resize its model cache to 0.

    Args:
        name: Model name, forwarded to ``load_model_tokenizers_baseline``.

    Returns:
        The loaded tokenizer after ``_resize_cache(0)`` was called on its
        model.
    """
    tokenizer = load_model_tokenizers_baseline(name)
    # NOTE(review): a capacity of 0 presumably disables the model's cache so
    # results are measured without it -- confirm against the bindings.
    model = tokenizer.model
    model._resize_cache(0)
    return tokenizer


def load_model_tokenizers_no_cache_inc_bpe(name):
    """Load the IncrementalBpe tokenizer for *name* and resize its cache to 0.

    Args:
        name: Model name, forwarded to ``load_model_tokenizers_inc_bpe``.

    Returns:
        The loaded tokenizer after ``_resize_cache(0)`` was called on its
        model.
    """
    tokenizer = load_model_tokenizers_inc_bpe(name)
    # NOTE(review): a capacity of 0 presumably disables the model's cache --
    # confirm against the bindings.
    model = tokenizer.model
    model._resize_cache(0)
    return tokenizer


def load_model_tokenizers_no_cache_eager_bpe(name):
    """Load the EagerBpe tokenizer for *name* and resize its cache to 0.

    Args:
        name: Model name, forwarded to ``load_model_tokenizers_eager_bpe``.

    Returns:
        The loaded tokenizer after ``_resize_cache(0)`` was called on its
        model.
    """
    tokenizer = load_model_tokenizers_eager_bpe(name)
    # NOTE(review): a capacity of 0 presumably disables the model's cache --
    # confirm against the bindings.
    model = tokenizer.model
    model._resize_cache(0)
    return tokenizer


def func_tokenizers(t):
    """Wrap tokenizer *t* as a plain ``text -> ids`` callable.

    Args:
        t: Object exposing ``encode(text)`` whose result has an ``ids``
            attribute.

    Returns:
        A one-argument function returning ``t.encode(s).ids`` for its input.
    """

    def encode_ids(s):
        encoding = t.encode(s)
        return encoding.ids

    return encode_ids


def load_model_tiktoken_baseline(name):
    """Return the ``_core_bpe`` object of the tiktoken encoding *name*.

    Args:
        name: Encoding name passed to ``tiktoken.get_encoding``.

    Returns:
        The private ``_core_bpe`` attribute of the looked-up encoding.
    """
    encoding = tiktoken.get_encoding(name)
    return encoding._core_bpe


def load_model_tiktoken_inc_bpe(name):
    """Build a fresh core BPE for encoding *name* and return its inc-BPE form.

    A new instance of the encoding's ``_core_bpe`` class is constructed from
    the encoding's mergeable ranks, special tokens and pattern string, and
    ``as_inc_bpe()`` is called on that new instance.

    Args:
        name: Encoding name passed to ``tiktoken.get_encoding``.

    Returns:
        The result of ``as_inc_bpe()`` on the newly built core BPE.
    """
    encoding = tiktoken.get_encoding(name)
    core_cls = type(encoding._core_bpe)
    # Construct a separate instance instead of reusing ``encoding._core_bpe``.
    core = core_cls(
        encoding._mergeable_ranks,
        encoding._special_tokens,
        encoding._pat_str,
    )
    return core.as_inc_bpe()


def load_model_tiktoken_eager_bpe(name):
    """Build a fresh core BPE for encoding *name* and return its eager-BPE form.

    A new instance of the encoding's ``_core_bpe`` class is constructed from
    the encoding's mergeable ranks, special tokens and pattern string, and
    ``as_eager_bpe()`` is called on that new instance.

    Args:
        name: Encoding name passed to ``tiktoken.get_encoding``.

    Returns:
        The result of ``as_eager_bpe()`` on the newly built core BPE.
    """
    encoding = tiktoken.get_encoding(name)
    core_cls = type(encoding._core_bpe)
    # Construct a separate instance instead of reusing ``encoding._core_bpe``.
    core = core_cls(
        encoding._mergeable_ranks,
        encoding._special_tokens,
        encoding._pat_str,
    )
    return core.as_eager_bpe()


def func_tiktoken(t):
    """Return the ``encode_ordinary`` callable of *t*.

    Args:
        t: Object exposing an ``encode_ordinary`` attribute.

    Returns:
        ``t.encode_ordinary``, unchanged.
    """
    encode = t.encode_ordinary
    return encode
