from abc import ABC, abstractmethod
from typing import Literal, cast

# Model identifiers, declared per family and combined into ``ModelT`` below.

# open source models
_LlamaT = Literal[
    "llama-3-1b",
    "llama-3-3b",
    "llama-3-8b",
    "llama-3-70b",
    "llama-3-405b",
]
_OlmoT = Literal[
    "olmo-2-1b",
    "olmo-2-7b",
    "olmo-2-13b",
    "olmo-2-32b",
]
_DeepSeekT = Literal[
    "deepseek-v3",
    "deepseek-r1",  # reasoning
]
_GeminiT = Literal[
    "gemini-2.5-flash",  # reasoning
]
_OpenAIT = Literal[
    "gpt-4o",
    "o4-mini",  # reasoning
]

# Nested ``Literal`` parameters are flattened, so ``ModelT`` is a single flat
# literal type listing every name above, in declaration order.
ModelT = Literal[_LlamaT, _OlmoT, _DeepSeekT, _GeminiT, _OpenAIT]


class TextGenerationModel(ABC):
    @abstractmethod
    def generate_continuation(self, text: str, max_new_tokens: int | None = None, stop_seq: str | None = None) -> str:
        raise NotImplementedError


def get_model(model_name: ModelT, use_nnsight: bool = False) -> TextGenerationModel:
    match model_name:
        case (
            "llama-3-1b"
            | "llama-3-3b"
            | "llama-3-8b"
            | "llama-3-70b"
            | "olmo-2-1b"
            | "olmo-2-7b"
            | "olmo-2-13b"
            | "olmo-2-32b"
        ):
            from .transformers import NNsightModel, TransformersModel, TransformersModelT

            model_name = cast(TransformersModelT, model_name)
            if use_nnsight:
                return NNsightModel(model_name=model_name)
            return TransformersModel(model_name=model_name)
        case (
            "llama-3-405b"
            | "deepseek-v3"
            | "deepseek-r1"
            | "gpt-4o"
            | "o4-mini"
            | "gemini-2.5-flash"
            | "gemini-2.5-flash-thinking"
        ):
            from .api import APIModel, APIModelT

            model_name = cast(APIModelT, model_name)
            return APIModel(model_name=model_name)


# Names exported by a star-import of this module.
__all__ = ["TextGenerationModel", "get_model", "ModelT"]
