from pathlib import Path

from .claude import Claude, ClaudeModelType
from .GPT import GPT, GPTModelType
from .hf_classifier import HFClassifier, HFClassifierType
from .hf_completion_model import HFCompletionModel, HFModelType
from .model import ChatModel, Classifier, CompletionModel, ModelType

# Directory that contains this module file.
parent_dir = Path(__file__).parent


def chat_model_from_config(model_str: str, config: dict[str, str]) -> ChatModel:
    """
    Build a chat model from a model identifier string and apply a config to it.

    The identifier is resolved against every chat model type that
    ``create_chat_model`` can build (GPT first, then Claude), so both
    providers can be selected by string.

    Args:
        model_str: Value identifying the model, as accepted by
            ``GPTModelType(...)`` or ``ClaudeModelType(...)``.
        config: Configuration passed unchanged to the model's ``set_config``.

    Returns:
        The created chat model, after ``set_config(config)`` has been called.

    Raises:
        ValueError: If ``model_str`` matches none of the known chat model types.
    """
    # GPT is tried first so strings that resolved before this change still
    # resolve to the same type.
    # NOTE(review): assumes ClaudeModelType, like GPTModelType, raises
    # ValueError for an unknown value (standard enum.Enum behavior) - confirm.
    for type_cls in (GPTModelType, ClaudeModelType):
        try:
            model_type = type_cls(model_str)
        except ValueError:
            continue
        break
    else:
        raise ValueError(f"ERROR: Unsupported chat model {model_str!r}")

    chat_model = create_chat_model(model_type)
    chat_model.set_config(config)
    return chat_model


def create_chat_model(model_type: ModelType, *args, **kwargs) -> ChatModel:
    """
    Instantiate the chat model wrapper that matches ``model_type``.

    A ``GPTModelType`` yields a ``GPT`` instance and a ``ClaudeModelType``
    yields a ``Claude`` instance. ``model_type`` and any extra positional or
    keyword arguments are forwarded to the chosen constructor.

    Raises:
        ValueError: If ``model_type`` is neither of the supported types.
    """
    # (model type class, wrapper class) pairs, checked in order.
    backends = (
        (GPTModelType, GPT),
        (ClaudeModelType, Claude),
    )
    for type_cls, wrapper_cls in backends:
        if isinstance(model_type, type_cls):
            return wrapper_cls(model_type, *args, **kwargs)
    raise ValueError(f"ERROR: Unsupported provider {model_type}")


def create_completions_model(model_type: ModelType, *args, **kwargs) -> CompletionModel:
    """
    Instantiate a completion model for ``model_type``.

    Only ``HFModelType`` is supported; it yields an ``HFCompletionModel``
    built from ``model_type`` plus any extra positional/keyword arguments.

    Raises:
        ValueError: If ``model_type`` is not an ``HFModelType``.
    """
    supported = isinstance(model_type, HFModelType)
    if not supported:
        raise ValueError(f"ERROR: Unsupported provider {model_type}")
    return HFCompletionModel(model_type, *args, **kwargs)


def create_classifier(model_type: ModelType, *args, **kwargs) -> Classifier:
    """
    Instantiate a classifier for ``model_type``.

    Only ``HFClassifierType`` is supported; it yields an ``HFClassifier``
    built from ``model_type`` plus any extra positional/keyword arguments.

    Raises:
        ValueError: If ``model_type`` is not an ``HFClassifierType``.
    """
    supported = isinstance(model_type, HFClassifierType)
    if not supported:
        raise ValueError(f"ERROR: Unsupported provider {model_type}")
    return HFClassifier(model_type, *args, **kwargs)
