"""Toy CE-Graph refinement demo (offline / deterministic).

This script is meant to validate that the failure-driven refinement loop is
wired correctly, without depending on any model APIs.

It constructs a small workflow, simulates executions to produce failures, then
applies CE-Graph to refine the workflow using a conservative operator library.
"""

from __future__ import annotations

# Allow running from repo root without installing the package.
import os
import sys

REPO_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", ".."))
if REPO_ROOT not in sys.path:
    sys.path.insert(0, REPO_ROOT)

import random
from typing import List, Tuple

from verl.ce_graph import (
    Workflow,
    Node,
    ExecutionTrace,
    CounterexamplePool,
    FailureClustering,
)
from verl.ce_graph.operators import default_operator_library
from verl.ce_graph.refine import refine_workflow


def build_initial_workflow() -> Workflow:
    """Assemble the starting workflow: a three-node chain llm -> tool -> llm."""
    # (node id, node type, config) for each node, in insertion order.
    node_specs = (
        ("n0", "llm", {"prompt_prefix": ""}),
        ("n1", "tool", {"max_retries": 1}),
        ("n2", "llm", {"prompt_prefix": ""}),
    )
    # Directed edges linking the nodes into a linear chain.
    chain = (("n0", "n1"), ("n1", "n2"))

    workflow = Workflow(metadata={"name": "toy", "version": "W0"})
    for node_id, node_type, config in node_specs:
        workflow.add_node(Node(node_id, node_type, config))
    for src, dst in chain:
        workflow.add_edge(src, dst)
    return workflow


def toy_runner(workflow: Workflow) -> Tuple[float, List[ExecutionTrace]]:
    """Simulate 200 executions of ``workflow`` and report the success rate.

    Every call draws from a fresh ``random.Random(0)``, so the sequence of
    random draws is the same on each call; only the failure thresholds,
    which are derived from the workflow's nodes, vary between calls.

    Heuristic behavior:
      - If there's a verifier node, fewer invalid_json failures.
      - If tool retries increased, fewer tool timeouts.
      - If constraints prefix added, slightly fewer semantic mismatches.

    Args:
        workflow: Workflow whose ``nodes`` and ``metadata`` are inspected.

    Returns:
        A ``(score, traces)`` pair: ``score`` is the fraction of simulated
        instances that succeeded, and ``traces`` contains one
        ``ExecutionTrace`` per instance, in instance order.
    """

    has_verifier = any(n.node_type == "verifier" for n in workflow.nodes)
    # The trailing [1] keeps max() well-defined when there are no tool nodes.
    tool_retries = max([n.config.get("max_retries", 1) for n in workflow.nodes if n.node_type == "tool"] + [1])
    constraints = any(
        n.node_type == "llm" and "Follow the tool schema" in n.config.get("prompt_prefix", "")
        for n in workflow.nodes
    )

    # Failure probabilities depend only on the workflow, so compute them (and
    # the cumulative cutoffs they imply) once instead of on every iteration.
    p_invalid = 0.18 * (0.35 if has_verifier else 1.0)
    p_tool = 0.22 * (0.55 if tool_retries >= 2 else 1.0)
    p_sem = 0.25 * (0.80 if constraints else 1.0)
    invalid_cutoff = p_invalid
    tool_cutoff = p_invalid + p_tool
    sem_cutoff = p_invalid + p_tool + p_sem

    version = workflow.metadata.get("version", "W")

    traces: List[ExecutionTrace] = []
    successes = 0
    total = 200
    rng = random.Random(0)  # fixed seed keeps the simulation deterministic

    for i in range(total):
        r = rng.random()
        t = ExecutionTrace(instance_id=str(i), workflow_version=version)

        # A single uniform draw selects the outcome by cumulative bucket.
        if r < invalid_cutoff:
            t.is_success = False
            t.node_outputs.append({"invalid_json": True})
            t.error = "invalid_json"
            t.final_output = "{bad json"
        elif r < tool_cutoff:
            t.is_success = False
            t.node_outputs.append({"tool_error": "timeout"})
            t.error = "tool_timeout"
            t.final_output = ""
        elif r < sem_cutoff:
            t.is_success = False
            t.error = "mismatch"
            t.final_output = "wrong answer"
        else:
            t.is_success = True
            t.final_output = "correct"
            successes += 1

        traces.append(t)

    score = successes / total
    return score, traces


def main() -> None:
    """Refine the toy workflow with CE-Graph and print the result and reports."""
    refined, history = refine_workflow(
        workflow=build_initial_workflow(),
        runner=toy_runner,
        operator_lib=default_operator_library(),
        pool=CounterexamplePool(capacity=5000),
        clustering=FailureClustering(max_modes=6, min_cluster_size=10),
        max_iters=3,
        topk_ops=4,
    )

    print("Final workflow:")
    print(refined.to_dict())
    print("\nReports:")
    for report in history:
        # One space-separated summary line per refinement iteration.
        fields = [
            f"iter={report.iteration}",
            f"accepted={report.accepted}",
            f"op={report.chosen_operator}",
            f"cand={report.candidate_score:.3f}",
            f"base={report.base_score:.3f}",
            f"keywords={report.notes.get('keywords','')}",
        ]
        print(" ".join(fields))


# Run the demo only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
