import logging

from transformers import PreTrainedModel


# Module-level logger, named after this module via ``__name__``.
logger = logging.getLogger(__name__)


def get_model_max_length(
    model: PreTrainedModel, *, default_max_length: int = 1024
) -> int:
    """Return the maximum sequence length configured for ``model``.

    The configuration for the length can be stored under different names
    depending on the model, so a few names we have encountered are tried
    in order (``n_positions``, ``max_position_embeddings``,
    ``seq_length``) and the first truthy value wins.

    Args:
        model: Model whose ``config`` attribute is inspected.
        default_max_length: Value returned when none of the known settings
            is present on the config, or all of them are falsy (for
            example ``None`` or ``0``).

    Returns:
        The first truthy length setting found on ``model.config``,
        otherwise ``default_max_length``.

    Raises:
        AttributeError: If ``model`` has no ``config`` attribute.
        AssertionError: If ``model.config`` is ``None``.
    """
    model_config = model.config
    if model_config is None:
        # Raised explicitly rather than through an `assert` statement so
        # the check is not stripped when Python runs with `-O`; the
        # exception type is kept so existing callers see no difference.
        raise AssertionError("model.config must not be None")

    for length_setting in (
        "n_positions",
        "max_position_embeddings",
        "seq_length",
    ):
        max_length = getattr(model_config, length_setting, None)
        if max_length:
            logger.info("Found max length: %s", max_length)
            return max_length

    logger.info("Using default max length: %s", default_max_length)
    return default_max_length


def get_tokenizer_type(model_id: str) -> str:
    """Return the tokenizer type name for a model identifier.

    The type is chosen by matching the start of ``model_id`` against a
    fixed table of known prefixes (the match is case-sensitive).

    Args:
        model_id: Model identifier, e.g. ``"pythia-70m"`` or
            ``"EleutherAI/pythia-70m"``.

    Returns:
        One of ``"pythia"``, ``"llama2"``, ``"gpt2"``, ``"phi"`` or
        ``"opt"``.

    Raises:
        ValueError: If ``model_id`` starts with none of the known prefixes.
    """
    # Each entry pairs the accepted prefixes with the tokenizer type they
    # map to; `str.startswith` accepts a tuple of alternatives.
    known_prefixes = (
        (("pythia", "EleutherAI/pythia"), "pythia"),
        (("Llama-2",), "llama2"),
        (("gpt2",), "gpt2"),
        (("phi",), "phi"),
        (("opt",), "opt"),
    )
    for prefixes, tokenizer_type in known_prefixes:
        if model_id.startswith(prefixes):
            return tokenizer_type
    raise ValueError(f"Unknown model type: {model_id}")
