import asyncio
from typing import Any

from ftagent.app.finetune.llm.conf import LLMFinetunePropSetting
from ftagent.components.coder.finetune.conf import get_ft_env
from ftagent.components.workflow.rd_loop import RDLoop
from ftagent.core.conf import RD_AGENT_SETTINGS
from ftagent.core.exception import CoderError
from ftagent.core.proposal import HypothesisFeedback
from ftagent.log import ftagent_logger as logger
from ftagent.scenarios.finetune.proposal.trace import FTTrace


class LLMFinetuneRDLoop(RDLoop):
    """RDLoop specialisation that drives LLM fine-tuning experiments.

    This class defines the ``direct_exp_gen``, ``coding``, ``feedback`` and
    ``record`` steps. The ``running`` output that ``feedback`` and ``record``
    look up is not produced by any method defined here (presumably it comes
    from a step inherited from ``RDLoop`` -- TODO confirm).
    """

    # Errors of these types do not abort the loop ...
    skip_loop_error = (CoderError,)
    # ... instead the loop skips ahead and resumes at this step.
    skip_loop_error_stepname = "feedback"
    # No error type is configured for withdrawing a loop.
    withdraw_loop_error = ()

    def __init__(self, PROP_SETTING: LLMFinetunePropSetting):
        """Keep the fine-tuning settings, run the base initialiser, then swap in an FTTrace.

        Args:
            PROP_SETTING: fine-tuning settings; its ``dataset`` and
                ``base_model`` attributes are copied onto the instance.
        """
        # Fine-tuning specific configuration is stored before the base class runs.
        self.ft_rd_setting = PROP_SETTING
        self.model = PROP_SETTING.base_model
        self.dataset = PROP_SETTING.dataset

        # Standard RDLoop set-up.
        super().__init__(PROP_SETTING)

        # Swap the trace built by the base class for an FTTrace (SOTA tracking),
        # reusing the scenario object of the trace being replaced.
        scenario = self.trace.scen
        self.trace = FTTrace(scen=scenario)

    async def direct_exp_gen(self, prev_out: dict[str, Any]):
        """Ask the hypothesis generator for a new experiment, log it and return it."""
        experiment = await self.hypothesis_gen.async_gen(self.trace, self)
        logger.log_object(experiment.hypothesis, tag="hypothesis")
        logger.log_object(experiment.sub_tasks, tag="experiment generation")
        return experiment

    def coding(self, prev_out: dict[str, Any]):
        """Pass the generated experiment to the coder and return the developed experiment."""
        developed = self.coder.develop(prev_out["direct_exp_gen"])
        logger.log_object(developed.sub_workspace_list, tag="coder result")
        return developed

    def feedback(self, prev_out: dict[str, Any]):
        """Build feedback for the experiment; the summarizer is called on every invocation.

        The experiment is the first truthy output among the ``running``,
        ``coding`` and ``direct_exp_gen`` steps. Any exception stored under
        ``self.EXCEPTION_KEY`` is forwarded to the summarizer (``None`` if absent).
        """
        # Prefer the output of the furthest step that produced something.
        experiment = None
        for step_name in ("running", "coding", "direct_exp_gen"):
            experiment = prev_out.get(step_name)
            if experiment:
                break

        caught = prev_out.get(self.EXCEPTION_KEY, None)
        result = self.summarizer.generate_feedback(experiment, self.trace, exception=caught)

        logger.log_object(result, tag="feedback")
        return result

    def record(self, prev_out: dict[str, Any]):
        """Store the ``(experiment, feedback)`` pair in the trace under the current loop index."""
        # Same lookup order as in ``feedback``: running, then coding, then generation.
        experiment = None
        for step_name in ("running", "coding", "direct_exp_gen"):
            experiment = prev_out.get(step_name)
            if experiment:
                break

        entry = (experiment, prev_out["feedback"])
        self.trace.sync_dag_parent_and_hist(entry, prev_out[self.LOOP_IDX_KEY])
