"""Backward routing controller for TextResNet."""

from __future__ import annotations

from typing import Any, Dict, List, TypedDict

from core.utils.textresnet import STOP_GRADIENT, BackwardResponse, GradientSignal, merge_context


class TextResNetBackpropagator:
    """Route gradient signals through system components in reverse order.

    Each supported task maps to an ordered chain of component names. ``run``
    walks that chain in order, asking each component to turn the incoming
    textual gradient into a local fix plus an upstream gradient that is handed
    to the next component in the chain.
    """

    # Task name -> component names in the order ``run`` visits them.
    CHAINS = {
        "bigcodebench": ["final_code_generator", "unit_test_generator", "code_generator"],
        "hotpotqa": ["answer_generator", "hint_generator", "info_extractor", "question_rewriter"],
    }

    def __init__(self, system, task: str):
        """Bind the propagator to ``system`` for the given ``task``.

        Args:
            system: Object exposing a ``components`` mapping with a ``.get``
                method (name -> component); it is only read inside ``run``.
            task: Key into ``CHAINS`` selecting the propagation order.

        Raises:
            ValueError: If ``task`` is not a key of ``CHAINS``.
        """
        if task not in self.CHAINS:
            raise ValueError(f"Unsupported task '{task}'.")
        self.system = system
        self.task = task
        # Copy so per-instance edits to ``self.chain`` (e.g. removing a
        # component for an ablation) cannot leak into the shared class-level
        # ``CHAINS`` and thereby into every other instance.
        self.chain: List[str] = list(self.CHAINS[task])

    @staticmethod
    def _is_objective_satisfied(feedback: str) -> bool:
        """Heuristic: detect 'no-op' feedback so STOP_GRADIENT is safe.

        Returns True when ``feedback`` is None/empty/whitespace, contains
        "no feedback" or "answer correct" (case-insensitive), or equals
        "success" after stripping and lower-casing; False otherwise.
        """
        text = (feedback or "").strip().lower()
        return (
            not text
            or "no feedback" in text
            or "answer correct" in text
            or text == "success"
        )

    def run(self, signal: GradientSignal, traj: Dict) -> Dict[str, Dict[str, Any]]:
        """Propagate ``signal`` through the configured chain.

        Components that are missing from ``self.system.components`` or lack a
        ``backward`` attribute are skipped. When a component answers with
        ``STOP_GRADIENT``, propagation stops, with one exception: if that
        component produced no (non-blank) local fix and the incoming feedback
        does not look resolved, the unchanged signal is handed on to the next
        component instead.

        Args:
            signal: Gradient signal given to the first component in the chain.
            traj: Trajectory dict. Each component receives the full ``traj``
                plus ``traj.get(component_name, {})``.

        Returns:
            Mapping ``component_name -> record`` for every component whose
            ``backward`` was called. Each record has the keys
            ``incoming_feedback``, ``local_fix``, ``upstream_grad``,
            ``component_traj`` and ``backward_context`` (the response's
            ``debug`` attribute, or None if it has none).
        """
        responses: Dict[str, Dict[str, Any]] = {}
        current_signal = signal

        for component_name in self.chain:
            component = self.system.components.get(component_name)
            if component is None or not hasattr(component, "backward"):
                continue

            component_traj = traj.get(component_name, {})
            response: BackwardResponse = component.backward(current_signal, traj, component_traj)
            responses[component_name] = {
                "incoming_feedback": current_signal.feedback,
                "local_fix": response.local_fix,
                "upstream_grad": response.upstream_grad,
                "component_traj": component_traj,
                "backward_context": getattr(response, "debug", None),
            }

            if response.upstream_grad == STOP_GRADIENT:
                has_local_fix = bool((response.local_fix or "").strip())
                if not has_local_fix and not self._is_objective_satisfied(current_signal.feedback):
                    # Nothing was fixed here and the problem still looks open:
                    # forward the same signal to the next component.
                    continue
                break

            # Accumulate context so the next component sees which component
            # produced this gradient and what local fix it proposed.
            merged = merge_context(
                current_signal.context, {"source": component_name, "previous": response.local_fix}
            )
            current_signal = GradientSignal(feedback=response.upstream_grad, context=merged)

        return responses
