from pathlib import Path
from typing import List, Optional

from llm_mcts.llm_generation_interface import Model
from llm_mcts.models.aggregated_model import AggregatedModel
from llm_mcts.models.claude_api import PRICING as CLAUDE_PRICING
from llm_mcts.models.claude_api import ClaudeBedrockAPIModel
from llm_mcts.models.gemini_api import PRICING as GEMINI_PRICING
from llm_mcts.models.gemini_api import GeminiAPIModel
from llm_mcts.models.openai_api import PRICING as OPENAI_PRICING
from llm_mcts.models.openai_api import OpenAIAPIModel


def parse_model_name(value: str) -> List[str]:
    """Split a comma-separated list of model names.

    Surrounding whitespace is stripped from each name so that ``"a, b"`` and
    ``"a,b"`` are equivalent. This matches the float-list parsers, where
    ``float()`` already ignores surrounding whitespace.

    Args:
        value: Comma-separated model identifiers, e.g. ``"gpt-4o,gemini-pro"``.

    Returns:
        The model names in their original order.
    """
    return [name.strip() for name in value.split(",")]


def parse_model_temperature(value: str) -> List[float]:
    """Parse a comma-separated list of sampling temperatures.

    Args:
        value: Comma-separated numbers, e.g. ``"0.7"`` or ``"0.2,1.0"``.

    Returns:
        One float per comma-separated entry, in order.

    Raises:
        ValueError: If any entry is not a valid float literal.
    """
    return [float(token) for token in value.split(",")]


def parse_model_prob(value: Optional[str]) -> List[float] | None:
    if value is not None:
        return list(map(float, value.split(",")))
    else:
        return None


def build_model(
    model_names: str,
    model_probs: Optional[str],
    temperatures: str,
    logging_dir: Optional[Path],
) -> Model:
    """Build an ``AggregatedModel`` from comma-separated model specifications.

    Args:
        model_names: Comma-separated model identifiers. Each identifier must be
            a key of the Claude, Gemini, or OpenAI pricing table. The first
            table containing it (checked in that order) selects the backend
            class.
        model_probs: Optional comma-separated floats, one per model. The parsed
            list (or ``None``) is forwarded to ``AggregatedModel`` as
            ``model_prob``.
        temperatures: Comma-separated floats. Pass either a single value shared
            by every model or exactly one value per model.
        logging_dir: Passed unchanged to every backend model and to the
            aggregate model.

    Returns:
        An ``AggregatedModel`` wrapping one backend model per name, in the
        order the names were given.

    Raises:
        ValueError: If the number of temperatures is neither 1 nor the number
            of models, if probabilities are given but their count differs
            from the number of models, or if a name is not found in any
            pricing table.
    """
    # Parse into separately named locals instead of rebinding the str-typed
    # parameters, so the annotations stay truthful throughout the body.
    names = parse_model_name(model_names)
    probs = parse_model_prob(model_probs)
    temps = parse_model_temperature(temperatures)

    # str.split always yields at least one element, so len(temps) >= 1.
    # A single temperature is broadcast to every model.
    if len(temps) == 1:
        temps = temps * len(names)
    elif len(temps) != len(names):
        raise ValueError("Number of temperatures must match number of models or be 1")

    # Fail early with a clear message rather than forwarding a mismatched
    # probability vector to AggregatedModel.
    if probs is not None and len(probs) != len(names):
        raise ValueError("Number of model probabilities must match number of models")

    individual_models = []
    for name, temperature in zip(names, temps):
        # The pricing tables double as the registry of supported models.
        if name in CLAUDE_PRICING:
            model_cls = ClaudeBedrockAPIModel
        elif name in GEMINI_PRICING:
            model_cls = GeminiAPIModel
        elif name in OPENAI_PRICING:
            model_cls = OpenAIAPIModel
        else:
            raise ValueError(f"Unsupported model {name}")
        individual_models.append(
            model_cls(model=name, temperature=temperature, logging_dir=logging_dir)
        )

    return AggregatedModel(
        models=individual_models,
        model_prob=probs,
        logging_dir=logging_dir,
    )
