import typer
from agenkit.agents import BaseAgent, BaseComponent
from datetime import datetime
from prompts import (
    SrGeneralizerPrompts,
)


class SrGeneralizerAgent(BaseAgent):
    """Agent that runs a single LLM "measure" step.

    The step formats the ``measure`` system/user prompts with the
    ``generalizer_output`` and ``original_abstract`` values taken from the
    caller-supplied context, sends them through ``self.llm_router`` and
    records the outcome in a shared workflow log.
    """

    def __init__(self, prompts: SrGeneralizerPrompts, agent_name: str, **kwargs):
        """Store the prompt bundle and the name used to namespace written keys.

        Args:
            prompts: Prompt bundle; ``run`` reads ``prompts.measure.system``
                and ``prompts.measure.user``.
            agent_name: Prefix for every workflow-log and context key this
                agent writes.
            **kwargs: Forwarded unchanged to ``BaseAgent.__init__``.
                NOTE(review): presumably this is what supplies
                ``model_name``, ``model_provider`` and ``llm_router`` used by
                ``run`` — confirm against BaseAgent.
        """
        super().__init__(**kwargs)
        self.agent_name = agent_name
        self.prompts = prompts

    @staticmethod
    def _unique_log_key(workflow_log: dict, base_key: str) -> str:
        """Return ``base_key`` if unused, else the first free ``base_key_N`` (N = 1, 2, ...)."""
        if base_key not in workflow_log:
            return base_key
        suffix = 1
        while f"{base_key}_{suffix}" in workflow_log:
            suffix += 1
        return f"{base_key}_{suffix}"

    async def run(self, workflow_log: dict, **kwargs) -> dict:
        """Execute the "measure" step and return the enriched context.

        Args:
            workflow_log: Mutable log shared across steps. A new entry is
                added under ``"<agent_name>.measure"``. If that key is
                already taken, ``_1``, ``_2``, ... is appended until the key
                is free, so earlier runs are never overwritten. The entry
                holds ``status`` (RUNNING -> COMPLETED/FAILED),
                ``start_time``, ``output``, ``error`` and finally
                ``end_time`` (local-time ISO-8601 strings).
            **kwargs: Step context. It must contain ``"generalizer_output"``
                and ``"original_abstract"``. The caller's mapping is copied,
                not mutated.

        Returns:
            A copy of ``kwargs`` extended with
            ``"<agent_name>.measure.output"`` and
            ``"<agent_name>.final.output"``. Both hold the LLM response.

        Raises:
            Exception: Any error raised while building or executing the step
                (including a missing context key). It is re-raised after the
                log entry is marked FAILED and ``end_time`` is stamped.
        """
        context = dict(kwargs)
        typer.echo("Executing step: measure...")

        measure_key = f"{self.agent_name}.measure.output"
        log_key = self._unique_log_key(workflow_log, f"{self.agent_name}.measure")
        workflow_log[log_key] = dict(
            status="RUNNING",
            start_time=datetime.now().isoformat(),
            output=None,
            error=None,
        )

        try:
            component = BaseComponent(
                model_name=self.model_name,
                model_provider=self.model_provider,
                system_prompt=self.prompts.measure.system,
                user_prompt_template=self.prompts.measure.user,
            )
            # Context lookups stay inside the try so a missing key is logged as FAILED.
            result = await component.execute(
                self.llm_router,
                prompt_format_kwargs=dict(
                    generalizer_output=context["generalizer_output"],
                    original_abstract=context["original_abstract"],
                ),
            )
            response = result["response"]
            context[measure_key] = response
            workflow_log[log_key].update(status="COMPLETED", output=response)
        except Exception as exc:
            reason = f"{exc.__class__.__name__}: {exc}"
            typer.secho(
                f"    ❌ Error during step 'measure': {reason}",
                fg=typer.colors.RED,
            )
            workflow_log[log_key].update(status="FAILED", error=reason)
            raise
        finally:
            # Stamp completion time on both the success and failure paths.
            workflow_log[log_key]["end_time"] = datetime.now().isoformat()

        # Single-step agent: the final output is simply the measure output.
        context[f"{self.agent_name}.final.output"] = context.get(measure_key)
        return context
