from src.generator.base_generator import BaseGenerator
from src.generator.claude_generator import ClaudeGenerator
from src.generator.openai_compat_generator import (
    FireworksGenerator,
    GeminiGenerator,
    OllamaGenerator,
)
from src.generator.openai_generator import OpenAIGenerator


# Public surface of this module: the factory function and the base class
# (re-exported from src.generator.base_generator) that it is annotated to return.
__all__ = ["generator_factory", "BaseGenerator"]


def generator_factory(generator_name: str, **kwargs) -> BaseGenerator:
    """Build the generator backend that serves ``generator_name``.

    The backend is selected from the model name. Rules are tried in this
    order, and the first match wins:

    1. name starts with ``"gpt"``                 -> ``OpenAIGenerator``
    2. name starts with ``"gemini-"``             -> ``GeminiGenerator``
    3. name starts with ``"anthropic"``           -> ``ClaudeGenerator``
    4. name is one of the listed Ollama model tags -> ``OllamaGenerator``
    5. name starts with ``"accounts/fireworks/"`` -> ``FireworksGenerator``

    Args:
        generator_name: Model identifier. It is passed unchanged as the first
            positional argument to the selected generator class.
        **kwargs: Forwarded verbatim to that class's constructor.

    Returns:
        A newly constructed instance of the matching generator class.

    Raises:
        ValueError: If ``generator_name`` matches none of the rules above.
    """
    # Built per call (not at module level) so the generator classes are looked
    # up at call time, exactly as the equivalent if/elif chain would do.
    leading_prefix_rules = (
        ("gpt", OpenAIGenerator),
        ("gemini-", GeminiGenerator),
        ("anthropic", ClaudeGenerator),
    )
    for prefix, generator_cls in leading_prefix_rules:
        if generator_name.startswith(prefix):
            return generator_cls(generator_name, **kwargs)

    # Ollama models are matched by exact tag, not by prefix.
    ollama_tags = (
        "llama3.3:70b-instruct-q4_K_M",
        "mistral:7b-instruct-v0.3-q4_0",
        "gemma:7b-instruct-v1.1-q4_0",
    )
    if generator_name in ollama_tags:
        return OllamaGenerator(generator_name, **kwargs)

    if generator_name.startswith("accounts/fireworks/"):
        return FireworksGenerator(generator_name, **kwargs)

    raise ValueError(f"Unknown generator_name: {generator_name}")
