from src.generator.base_generator import BaseGenerator
from src.generator.claude_generator import ClaudeGenerator
from src.generator.llama_generator import LlamaGenerator
from src.generator.openai_compat_generator import (
    FireworksGenerator,
    GeminiGenerator,
    OllamaGenerator,
)
from src.generator.openai_generator import OpenAIGenerator


# Names exported by a star-import of this module. The concrete generator
# classes imported above are deliberately absent from this list; only the
# factory function and the base class are listed.
__all__ = [
    "generator_factory",
    "BaseGenerator",
]


def generator_factory(generator_name: str, **kwargs) -> BaseGenerator:
    """Build the generator backend that corresponds to ``generator_name``.

    Rules are tried in this order and the first match wins:

    1. One of the fixed OpenAI model names -> ``OpenAIGenerator``.
    2. One of the ``*-local`` Llama aliases -> ``LlamaGenerator``, constructed
       with the mapped ``meta-llama/...`` model name.
    3. ``llama3.2:1b`` or ``llama3.2:3b`` -> ``OllamaGenerator``.
    4. A name starting with ``accounts/fireworks/``, ``gemini-`` or
       ``claude-`` -> ``FireworksGenerator``, ``GeminiGenerator`` or
       ``ClaudeGenerator`` respectively, constructed with the name as given.

    Args:
        generator_name: Identifier of the model/backend to instantiate.
        **kwargs: Forwarded unchanged to the selected generator's constructor.

    Returns:
        The newly constructed generator object.

    Raises:
        ValueError: If ``generator_name`` matches none of the rules above.
    """
    openai_models = ("gpt-3.5-turbo", "gpt-4o-mini", "gpt-4o", "o1-preview")
    local_llama_models = (
        ("llama3.2:1b-local", "meta-llama/Llama-3.2-1B-Instruct"),
        ("llama3.2:3b-local", "meta-llama/Llama-3.2-3B-Instruct"),
        ("llama3.1:8b-local", "meta-llama/Llama-3.1-8B-Instruct"),
        ("llama3.3:70b-local", "meta-llama/Llama-3.3-70B-Instruct"),
    )

    # Exact-match OpenAI names; the constructor receives the known literal.
    for model in openai_models:
        if generator_name == model:
            return OpenAIGenerator(model, **kwargs)

    # Local Llama aliases are translated to their full model names.
    for alias, repo_id in local_llama_models:
        if generator_name == alias:
            return LlamaGenerator(model_name=repo_id, **kwargs)

    # Ollama-served models keep the caller's name as-is.
    if generator_name in ("llama3.2:1b", "llama3.2:3b"):
        return OllamaGenerator(generator_name, **kwargs)

    # Prefix-routed providers: the whole name is handed to the backend.
    if generator_name.startswith("accounts/fireworks/"):
        return FireworksGenerator(generator_name, **kwargs)
    if generator_name.startswith("gemini-"):
        return GeminiGenerator(generator_name, **kwargs)
    if generator_name.startswith("claude-"):
        return ClaudeGenerator(generator_name, **kwargs)

    raise ValueError(f"Unknown generator_name: {generator_name}")
