from itertools import chain

from .base import BaseRunner

class SimpleRunner(BaseRunner):
    """Runner that delegates generation to a single model in one call.

    The model's per-beam output ids are flattened into one id list per beam
    before being returned to the caller.
    """

    def __init__(self, model, metrics):
        """Store the collaborators used by :meth:`run`.

        Args:
            model: Object exposing an awaitable
                ``run(prompt_ids, max_length, end_id, request_id, turn_id)``
                that returns a mapping with an ``"output_ids"`` key and an
                optional ``"output_logits"`` key.
            metrics: Metrics collaborator. It is not used directly in this
                class; presumably the inherited ``process_metrics_step``
                consumes it (verify against ``BaseRunner``).
        """
        # NOTE(review): BaseRunner.__init__ is not called here; confirm the
        # base class does not depend on its own initialization.
        self.model = model
        self.metrics = metrics
        # Not read anywhere in this class; presumably used by BaseRunner or
        # external callers. Verify before removing.
        self.prompt_ar = []

    async def run(self, prompt_ids, max_length, end_id, request_id, turn_id):
        """Run the model once and flatten its per-beam output ids.

        Args:
            prompt_ids: Prompt token ids, forwarded unchanged to ``model.run``.
            max_length: Generation length limit, forwarded to ``model.run``.
            end_id: End-of-sequence id, forwarded to ``model.run``.
            request_id: Request identifier, forwarded to ``model.run`` and to
                metrics processing.
            turn_id: Turn identifier, forwarded to ``model.run`` and to
                metrics processing.

        Returns:
            A dict with two keys:

            * ``"output_ids"``: a list with one flat list per entry of the
              model's ``"output_ids"`` (one per beam). Each flat list is that
              beam's id chunks concatenated in order.
            * ``"output_logits"``: the model's ``"output_logits"`` value, or
              ``None`` if the model output has no such key.

        Raises:
            KeyError: If the model output lacks an ``"output_ids"`` key.
        """
        model_output = await self.model.run(
            prompt_ids, max_length, end_id, request_id, turn_id
        )
        # process_metrics_step is not defined in this class; presumably it is
        # inherited from BaseRunner.
        self.process_metrics_step(model_output, request_id, turn_id)

        # Each beam is an iterable of id chunks (presumably one chunk per
        # decoding iteration; confirm). Concatenate the chunks into a single
        # id list per beam.
        flattened_output_ids = [
            list(chain.from_iterable(beam_output))
            for beam_output in model_output["output_ids"]
        ]

        return {
            "output_ids": flattened_output_ids,
            "output_logits": model_output.get("output_logits"),
        }