from typing import Dict
from src.loggers.setup_logging import setup_logging
from src.reliability_eval.common.models.scores import QuestionAggregateScores
from src.reliability_eval.evaluator.model_evaluator import ModelGenerationEvaluator
from src.reliability_eval.pipeline.context import EvaluationContext
from src.reliability_eval.pipeline.processor.batch import Batch

# Module-level logger obtained from the project's setup_logging() helper;
# BatchProcessor uses it for its debug output.
logger = setup_logging()

class BatchProcessor:
    """Evaluate batches one at a time through a single shared evaluator.

    The evaluator is created once, at construction time, from the model,
    tokenizer, model identifier and device carried by the evaluation
    context. Every call to ``process_batch`` reuses that evaluator.

    Attributes:
        context: The evaluation context passed to the constructor.
        evaluator: The ``ModelGenerationEvaluator`` built from ``context``.
    """

    def __init__(self, context: EvaluationContext):
        """Store ``context`` and build the evaluator from its settings.

        Args:
            context: Supplies ``model``, ``tokenizer``, ``model_identifier``
                and ``device`` for the evaluator. Its ``generation_config``
                and ``evaluation_config`` are read later, on each call to
                ``process_batch``.
        """
        self.context = context

        # The identifier is handed over in its str() form as ``model_name``.
        evaluator_kwargs = {
            "model": context.model,
            "tokenizer": context.tokenizer,
            "model_name": str(context.model_identifier),
            "device": context.device,
        }
        self.evaluator = ModelGenerationEvaluator(**evaluator_kwargs)

    def process_batch(self, batch: Batch) -> Dict[str, QuestionAggregateScores]:
        """Run the evaluator over one batch and return what it produces.

        Args:
            batch: Provides ``queries`` and ``answers``; the answers are
                forwarded to the evaluator as ``true_answers``.

        Returns:
            The result of ``evaluate_batch`` for this batch, unchanged.
        """
        logger.debug(f"Processing batch with {len(batch.queries)} items")

        settings = self.context
        scores = self.evaluator.evaluate_batch(
            queries=batch.queries,
            true_answers=batch.answers,
            generation_experiment_config=settings.generation_config,
            evaluation_pipeline_dict=settings.evaluation_config,
        )
        return scores
