from __future__ import annotations

from dataclasses import dataclass
from typing import Callable, Dict, List, Optional, Sequence, Tuple

import time

from .workflow import Workflow
from .counterexamples import CounterexamplePool, ExecutionTrace
from .failures import FailureClustering, FailureMode
from .operators import OperatorLibrary


@dataclass
class RefinementReport:
    """Record of a single round of the refinement loop in `refine_workflow`.

    One report is appended per round, including the round on which the loop
    stops early because no failure mode or no applicable operator was found.
    """

    # 1-based index of the refinement round.
    iteration: int
    # Baseline (incumbent workflow) score reported for this round.
    base_score: float
    # Score of the best candidate evaluated this round; set equal to
    # `base_score` when the round ended before any candidate was run.
    candidate_score: float
    # Whether the round's best candidate was adopted as the working workflow.
    accepted: bool
    # Failure mode targeted this round; None when clustering returned no modes.
    chosen_mode: Optional[FailureMode]
    # Name of the operator that produced the best candidate; None when no
    # operator was applied.
    chosen_operator: Optional[str]
    # Free-form string metadata, e.g. a "reason" for an early stop or the
    # comma-joined "keywords" of the chosen failure mode.
    notes: Dict[str, str]


def refine_workflow(
    workflow: Workflow,
    runner: Callable[[Workflow], Tuple[float, List[ExecutionTrace]]],
    operator_lib: OperatorLibrary,
    pool: Optional[CounterexamplePool] = None,
    clustering: Optional[FailureClustering] = None,
    max_iters: int = 5,
    topk_ops: int = 5,
    accept_if_improves: bool = True,
) -> Tuple[Workflow, List[RefinementReport]]:
    """Failure-driven workflow refinement loop.

    Each round clusters the traces held in `pool`, picks the largest failure
    mode, asks `operator_lib` for candidate edits, evaluates every candidate
    with `runner`, and optionally adopts the best-scoring one.

    Args:
        workflow: initial workflow W_0.
        runner: a function that executes `workflow` on an evaluation set and returns
            (score, traces). `score` should be higher-is-better. It is called once
            for the initial workflow and once per candidate edit.
        operator_lib: library O of operator-constrained edits.
        pool: counterexample pool. If None, one with capacity 5000 is created.
            A caller-supplied pool is always used as-is (even when empty) and is
            extended in place with the traces returned by `runner`.
        clustering: failure clustering model. If None, default is used.
        max_iters: maximum number of refinement rounds.
        topk_ops: number of operator candidates to try per chosen failure mode.
        accept_if_improves: if True, accept only if candidate_score > base_score;
            if False, the best candidate of each round is always accepted.

    Returns:
        (best_workflow, reports) where `reports` holds one `RefinementReport` per
        round that was started. In each report `base_score` is the incumbent's
        score *before* that round's accept/reject decision.
    """

    # Compare against None explicitly: a caller-supplied pool that is currently
    # empty may be falsy, and `pool or ...` would silently replace it.
    if pool is None:
        pool = CounterexamplePool(capacity=5000)
    if clustering is None:
        clustering = FailureClustering()
    reports: List[RefinementReport] = []

    # Evaluate the initial workflow.
    base_score, traces = runner(workflow)
    pool.extend(traces)

    for it in range(1, max_iters + 1):
        modes, _labels = clustering.cluster(pool.all())
        if not modes:
            reports.append(
                RefinementReport(
                    iteration=it,
                    base_score=base_score,
                    candidate_score=base_score,
                    accepted=False,
                    chosen_mode=None,
                    chosen_operator=None,
                    notes={"reason": "no_failure_modes"},
                )
            )
            break

        # Choose the dominant failure mode (largest cluster). `max` returns the
        # first maximal element, matching a stable descending sort's first item.
        chosen_mode = max(modes, key=lambda m: m.size)

        # Propose edits under operator constraints.
        ops = operator_lib.propose(workflow, chosen_mode, top_k=topk_ops)
        if not ops:
            reports.append(
                RefinementReport(
                    iteration=it,
                    base_score=base_score,
                    candidate_score=base_score,
                    accepted=False,
                    chosen_mode=chosen_mode,
                    chosen_operator=None,
                    notes={"reason": "no_applicable_operator"},
                )
            )
            # Nothing can change the workflow, so further rounds would repeat
            # the same outcome.
            break

        best_cand = None
        best_cand_score = float("-inf")
        best_op_name = None
        best_cand_traces: List[ExecutionTrace] = []

        for op in ops:
            cand = op.apply(workflow, chosen_mode)
            cand_score, cand_traces = runner(cand)
            if cand_score > best_cand_score:
                best_cand = cand
                best_cand_score = cand_score
                best_op_name = op.name
                best_cand_traces = cand_traces

        # A candidate can only be accepted if one was actually selected; every
        # score being NaN or -inf leaves `best_cand` as None.
        accepted = best_cand is not None and (
            best_cand_score > base_score or not accept_if_improves
        )

        # Remember the incumbent's score before it is (possibly) overwritten so
        # the report shows the real before/after pair.
        prev_score = base_score
        if accepted:
            workflow = best_cand
            base_score = best_cand_score

        # The best candidate's traces feed the next round's clustering whether
        # or not the candidate was accepted.
        pool.extend(best_cand_traces)

        reports.append(
            RefinementReport(
                iteration=it,
                base_score=prev_score,
                candidate_score=best_cand_score,
                accepted=accepted,
                chosen_mode=chosen_mode,
                chosen_operator=best_op_name,
                notes={"keywords": ",".join(chosen_mode.keywords)},
            )
        )

    return workflow, reports
