from src.batch_processor.base_batch_processor import BaseBatchProcessor
from src.batch_processor.claude_batch_processor import ClaudeBatchProcessor
from src.batch_processor.mock_batch_processor import (
    OpenAIMockBatchProcessor,
    GeminiMockBatchProcessor,
    ClaudeMockBatchProcessor,
    OllamaMockBatchProcessor,
    FireworksMockBatchProcessor,
)
from src.batch_processor.openai_batch_processor import OpenAIBatchProcessor
from src.batch_processor.openai_parallel_processor import OpenAIParallelProcessor


# Public names exported by this module: the factory entry point and the base class it returns.
__all__ = ["batch_processor_factory", "BaseBatchProcessor"]


def batch_processor_factory(generator_name: str, api_type: str, **kwargs) -> BaseBatchProcessor:
    # OpenAIGenerator
    if generator_name.startswith("gpt"):
        if api_type == "b":
            return OpenAIBatchProcessor(generator_name, **kwargs)
        elif api_type == "p":
            return OpenAIParallelProcessor(generator_name, **kwargs)
        elif api_type == "s":
            # return OpenAIParallelProcessor(generator_name, **kwargs)
            return OpenAIMockBatchProcessor(generator_name, **kwargs)
        else:
            raise ValueError(f"Unknown api_type: {api_type}")
    # GeminiGenerator
    elif generator_name.startswith("gemini"):
        return GeminiMockBatchProcessor(generator_name, **kwargs)
    # ClaudeGenerator
    elif generator_name.startswith("anthropic"):
        if api_type == "b":
            return ClaudeMockBatchProcessor(generator_name, **kwargs)  # tmp
        elif api_type in ["p", "s"]:
            return ClaudeMockBatchProcessor(generator_name, **kwargs)
        else:
            raise ValueError(f"Unknown api_type: {api_type}")
    # OllamaGenerator
    elif generator_name in (
        "llama3.3:70b-instruct-q4_K_M",
        "mistral:7b-instruct-v0.3-q4_0",
        "gemma:7b-instruct-v1.1-q4_0"
    ):
        return OllamaMockBatchProcessor(generator_name, **kwargs)
    # FireworksGenerator
    elif generator_name.startswith("accounts/fireworks/"):
        return FireworksMockBatchProcessor(generator_name, **kwargs)
    else:
        raise ValueError(f"Unknown generator_name: {generator_name}")
