"""
LLM Fine-tuning Runner Implementation

This module provides a specialized runner for LLM fine-tuning that executes
LLaMA-Factory configuration files generated by the coder.
"""

from ftagent.app.finetune.llm.conf import FT_RD_SETTING
from ftagent.components.coder.CoSTEER import CoSTEER
from ftagent.components.coder.CoSTEER.evaluators import (
    CoSTEERMultiEvaluator,
    CoSTEERSingleFeedback,
)
from ftagent.components.coder.CoSTEER.evolving_strategy import (
    MultiProcessEvolvingStrategy,
)
from ftagent.components.coder.CoSTEER.knowledge_management import (
    CoSTEERQueriedKnowledge,
)
from ftagent.components.coder.finetune.conf import (
    FT_YAML_FILE_NAME,
    FTCoderCoSTEERSettings,
)
from ftagent.components.coder.finetune.eval import FTDataEvaluator
from ftagent.core.experiment import FBWorkspace, Task
from ftagent.core.scenario import Scenario
from ftagent.log import ftagent_logger as logger
from ftagent.scenarios.finetune.train.eval import FTRunnerEvaluator


class FTRunnerSettings(FTCoderCoSTEERSettings):
    """LLM fine-tuning runner settings.

    Inherits every field from ``FTCoderCoSTEERSettings`` and only overrides the
    environment-variable prefix. Presumably this lets the runner's options be set
    independently of the coder's; confirm the coder's prefix differs.
    """

    class Config:
        # Prefix for environment variables that populate these settings
        # (assumes pydantic-settings ``env_prefix`` semantics — TODO confirm).
        env_prefix = "LLM_FT_Runner_"


class FTRunnerEvolvingStrategy(MultiProcessEvolvingStrategy):
    """Pass-through evolving strategy for the LLM fine-tuning runner.

    The coder already writes the full training YAML, and its validator checks
    that YAML with a micro-batch. The runner therefore executes the file
    unchanged and never proposes code edits of its own.
    """

    def implement_one_task(
        self,
        target_task: Task,
        queried_knowledge: CoSTEERQueriedKnowledge | None = None,
        workspace: FBWorkspace | None = None,
        prev_task_feedback: CoSTEERSingleFeedback | None = None,
    ) -> dict[str, str]:
        """Propose no file changes; only report a missing training config.

        Args:
            target_task: Task being implemented (unused).
            queried_knowledge: Retrieved knowledge (unused).
            workspace: Workspace expected to contain the coder's YAML config.
            prev_task_feedback: Feedback from the previous attempt (unused).

        Returns:
            Always an empty dict, indicating that no files are changed.
        """
        # TODO: detect error during training automatically, and fix it here
        config_present = bool(workspace) and FT_YAML_FILE_NAME in workspace.file_dict
        if not config_present:
            logger.error(f"No {FT_YAML_FILE_NAME} found in workspace")
        return {}


class LLMFinetuneRunner(CoSTEER):
    """LLM fine-tuning runner that executes LLaMA-Factory configurations.

    This wraps the generic CoSTEER loop with two pieces:

    * a single ``FTRunnerEvaluator``, which validates the training run;
    * the pass-through ``FTRunnerEvolvingStrategy``, so the YAML config
      produced by the coder is executed as-is.
    """

    def __init__(
        self,
        scen: Scenario,
        *args,
        **kwargs,
    ) -> None:
        """Build the evaluator, settings and evolving strategy, then init CoSTEER.

        Args:
            scen: Fine-tuning scenario shared by the evaluator and the strategy.
            *args: Forwarded positionally to ``CoSTEER.__init__``.
            **kwargs: Forwarded to ``CoSTEER.__init__``. These must not repeat the
                keywords set here (``settings``, ``eva``, ``es``,
                ``evolving_version``, ``scen``, ``max_loop``,
                ``stop_eval_chain_on_fail``), or Python raises ``TypeError``.
        """
        eval_l = [
            FTRunnerEvaluator(scen=scen),  # Training validation
        ]

        eva = CoSTEERMultiEvaluator(single_evaluator=eval_l, scen=scen)
        settings = FTRunnerSettings()

        # Runner-specific strategy: reuses the coder's full training YAML unchanged.
        es = FTRunnerEvolvingStrategy(scen=scen, settings=settings, improve_mode=True)

        super().__init__(
            *args,
            settings=settings,
            eva=eva,
            es=es,
            evolving_version=2,
            scen=scen,
            # Fall back to a single loop when FT_RD_SETTING has no runner_max_loop.
            max_loop=getattr(FT_RD_SETTING, "runner_max_loop", 1),
            stop_eval_chain_on_fail=True,  # fine-tuning involves partial implementation.
            **kwargs,
        )

    def develop(self, exp):
        """Run full-dataset LLaMA-Factory fine-tuning for ``exp``.

        The coder's full training config is executed without modification. The
        standard CoSTEER develop loop performs execution and basic validation
        through ``FTRunnerEvaluator``. Benchmark evaluation should run as a
        separate step after training.

        Returns:
            The experiment returned by ``CoSTEER.develop``.
        """
        logger.info("Starting full dataset LLM fine-tuning with LLaMA-Factory")
        return super().develop(exp)

    def get_develop_max_seconds(self) -> int | None:
        """Return the develop time budget in seconds.

        The budget is the scenario's ``real_full_timeout()`` scaled by the runner
        settings' ``max_seconds_multiplier``, truncated to an int.
        """
        return int(self.scen.real_full_timeout() * self.settings.max_seconds_multiplier)

    def compare_and_pick_fb(self, base_fb, new_fb) -> bool:
        """Decide whether the new feedback should replace the current best one.

        Only the first entry of each multi-feedback is compared, because this
        runner registers a single evaluator. Scores come from that entry's
        ``score`` attribute; a missing or None score loses to any real score.
        A truthy ``self.scen.metric_direction`` means higher scores are better.

        Args:
            base_fb: Current best multi-feedback, or None if there is none yet.
            new_fb: Multi-feedback from the latest iteration.

        Returns:
            True to adopt ``new_fb``; False to keep ``base_fb``. Equal scores
            keep ``base_fb`` in both metric directions.
        """
        if base_fb is None:
            return True
        if new_fb is None:
            # Nothing new to adopt; avoid indexing into None.
            return False

        base_score = getattr(base_fb[0], "score", None)
        new_score = getattr(new_fb[0], "score", None)

        if new_score is None:
            return False
        if base_score is None:
            return True
        # A tie is not an improvement. Without this check,
        # ``(new > base) == metric_direction`` is True on a tie when lower is
        # better, so ties would be handled differently per direction.
        if new_score == base_score:
            return False
        return (new_score > base_score) == self.scen.metric_direction
