import logging
import os
import random
from collections import defaultdict
from pathlib import Path
from typing import Iterable, List, Optional, Sequence, Tuple

from llm_mcts.file_logging import get_logging_dir
from llm_mcts.llm_generation_interface import GenerationRequest, GenerationResult, Model
from llm_mcts.models.openai_api import OpenAIAPIModel

logger = logging.getLogger(__name__)


class AggregatedModel(Model):
    """
    Dispatch a batch of requests across several models and collect the results.

    Each request is assigned to exactly one model: round-robin over ``models``
    when ``model_prob`` is ``None``, otherwise by weighted random sampling with
    ``model_prob`` as the weights.

    Attributes:
        models: The wrapped models, in dispatch order.
        model_name: The wrapped models' names joined with ``"_"``.
        model_prob: Optional sampling weights, one per entry in ``models``.
        logging_dir: Directory used for logging; created on construction.
    """

    def __init__(
        self,
        models: List[Model],
        model_prob: Optional[list[float]] = None,
        logging_dir: Optional[Path] = None,
    ) -> None:
        """
        Args:
            models: Models to aggregate over.
            model_prob: Optional weights for random assignment of requests to
                models. When ``None``, requests are assigned round-robin.
            logging_dir: Directory for logs. When ``None``, the directory
                returned by ``get_logging_dir`` for the current process id is
                used.
        """
        self.models = models
        self.model_name = "_".join(model.model_name for model in models)
        self.model_prob = model_prob

        self.logging_dir = (
            get_logging_dir(os.getpid()) if logging_dir is None else logging_dir
        )
        # exist_ok avoids a check-then-create race when several processes share
        # the directory; parents=True allows a nested, not-yet-created path.
        self.logging_dir.mkdir(parents=True, exist_ok=True)

    def generate(
        self, requests: Sequence[GenerationRequest]
    ) -> Iterable[GenerationResult]:
        """Return only the results of ``generate_with_llm_names``."""
        results, _ = self.generate_with_llm_names(requests)
        return results

    def generate_with_llm_names(
        self, requests: Sequence[GenerationRequest]
    ) -> Tuple[Iterable[GenerationResult], List[str]]:
        """
        Generate results for ``requests`` and report which model handled each.

        NOTE: We do not support different requests as of now; We expect a
        duplicated or single request. Results are returned grouped by model (in
        the order of ``self.models``), not in the order of ``requests``.

        Args:
            requests: The requests to distribute over the wrapped models.

        Returns:
            A ``(results, llm_names)`` pair. ``llm_names`` holds one model name
            per dispatched request, grouped in the same model order as
            ``results``.

        Raises:
            ValueError: If there are requests to dispatch but no models.
        """
        num_requests = len(requests)
        num_models = len(self.models)
        if num_requests and not num_models:
            raise ValueError("AggregatedModel has no models to dispatch requests to")

        # Assign a model *index* to every request (weighted random or
        # round-robin). Indices rather than model objects are used so that
        # models need not be hashable and a model listed twice is treated as
        # two independent slots instead of being generated twice.
        if self.model_prob is not None:
            assigned_indices = random.choices(
                range(num_models), weights=self.model_prob, k=num_requests
            )
        else:
            assigned_indices = [i % num_models for i in range(num_requests)]

        if self.model_prob is None and num_requests < num_models:
            logger.warning(
                "Warning: Round-robin assignment with %d requests and %d models "
                "will only use the first %d models, leaving %d models unused.",
                num_requests,
                num_models,
                num_requests,
                num_models - num_requests,
            )

        # Group requests per model slot, preserving the order of self.models.
        requests_per_model: List[List[GenerationRequest]] = [[] for _ in self.models]
        for request, model_index in zip(requests, assigned_indices):
            requests_per_model[model_index].append(request)

        results: List[GenerationResult] = []
        llm_names: List[str] = []
        for model, model_requests in zip(self.models, requests_per_model):
            if not model_requests:
                # Nothing was assigned to this model; do not call it at all.
                continue

            llm_names += [model.model_name] * len(model_requests)
            if isinstance(model, OpenAIAPIModel) and not model.model.startswith("o1"):
                # Requests are expected to be duplicates (see NOTE above), so
                # send a single request and ask for several samples instead.
                results.extend(
                    model.generate(
                        model_requests[:1], num_samples=len(model_requests)
                    )
                )
            else:
                # o1 doesn't support n > 1, and other model types are not
                # given num_samples here: issue one request per sample.
                results.extend(model.generate(model_requests))

        return results, llm_names
