from __future__ import annotations
from typing import Dict, Any, List, DefaultDict, Tuple
from collections import defaultdict
from .common import detect_time_unit, to_ms, normalize_op_name


def _walk_nodes(root: Dict[str, Any], unit: str) -> List[Dict[str, Any]]:
    """Flatten a step tree into one timing record per node.

    Args:
        root: Tree node dict. Each node may carry ``start``/``end`` timestamps,
            ``name``, ``layer``, ``is_error`` and a ``children`` list (missing or
            ``None`` children are treated as empty).
        unit: Time-unit tag forwarded to ``to_ms`` to convert raw timestamp
            differences (presumably the value from ``detect_time_unit``).

    Returns:
        A list of dicts with keys ``name``, ``layer``, ``inc_ms`` (inclusive
        duration, clamped at 0), ``self_ms`` (inclusive minus the children's
        clamped inclusive durations, clamped at 0), ``is_error``, and
        ``start_ms``/``end_ms`` relative to the root's ``start``. Records are
        emitted children-before-parents (reverse of an iterative pre-order DFS).
    """
    # Iterative pre-order DFS: every parent is appended to ``order`` before any
    # of its descendants, so walking ``order`` backwards sees children first.
    stack = [root]
    order: List[Dict[str, Any]] = []
    while stack:
        n = stack.pop()
        order.append(n)
        stack.extend(n.get("children", []) or [])

    # Clamped inclusive duration per node, keyed by identity (dicts are unhashable).
    inc_by_id: Dict[int, float] = {}
    nodes: List[Dict[str, Any]] = []
    start0 = float(root.get("start", 0.0))

    for n in reversed(order):
        start = float(n.get("start", 0.0))
        end = float(n.get("end", 0.0))
        # Clamp before storing: a malformed child (e.g. missing "end") must not
        # contribute a negative duration that inflates its parent's self time.
        inc_ms = max(0.0, to_ms(end - start, unit))
        ch_sum = sum(inc_by_id.get(id(ch), 0.0) for ch in n.get("children", []) or [])
        inc_by_id[id(n)] = inc_ms
        nodes.append({
            "name": normalize_op_name(n.get("name", "unknown")),
            "layer": n.get("layer", "unknown"),
            "inc_ms": inc_ms,
            "self_ms": max(0.0, inc_ms - ch_sum),
            "is_error": bool(n.get("is_error", False)),
            "start_ms": to_ms(start - start0, unit),
            "end_ms": to_ms(end - start0, unit),
        })
    return nodes

def collect_nodes(step_tree: Dict[str, Any]) -> List[Dict[str, Any]]:
    """Return the flattened per-node timing records for *step_tree*.

    The time unit is chosen by ``detect_time_unit`` from the root's ``end``
    value (0.0 when absent) and then applied to every node in the tree.
    """
    root_end = float(step_tree.get("end", 0.0))
    detected_unit = detect_time_unit(root_end)
    return _walk_nodes(step_tree, detected_unit)

def collect_step_op_samples(step_tree: Dict[str, Any], metric: str = "inc") -> Dict[str, List[float]]:
    """Group per-node durations by normalized op name.

    Args:
        step_tree: Root node of the step tree.
        metric: ``"inc"`` selects inclusive time (``inc_ms``); any other value
            selects self time (``self_ms``).

    Returns:
        A plain dict mapping each op name to the list of its duration samples,
        in the order the nodes were collected.
    """
    records = collect_nodes(step_tree)
    field = "self_ms" if metric != "inc" else "inc_ms"
    samples: DefaultDict[str, List[float]] = defaultdict(list)
    for record in records:
        op_name = record["name"]
        samples[op_name].append(float(record[field]))
    return dict(samples)

def collect_parent_pairs(step_tree: Dict[str, Any]) -> List[Tuple[str, str]]:
    """List ``(child_name, parent_name)`` edges of the step tree.

    Both names pass through ``normalize_op_name`` (default ``"unknown"``).
    The root has no parent and so contributes no pair. Edges come out in
    iterative DFS order (last child of a node is explored first).
    """
    result: List[Tuple[str, str]] = []
    pending: List[Tuple[Dict[str, Any], Dict[str, Any] | None]] = [(step_tree, None)]

    while pending:
        current, owner = pending.pop()
        child_name = normalize_op_name(current.get("name", "unknown"))
        if owner is not None:
            owner_name = normalize_op_name(owner.get("name", "unknown"))
            result.append((child_name, owner_name))
        pending.extend((kid, current) for kid in current.get("children", []) or [])

    return result

def summarize_step_time(step_tree: Dict[str, Any]) -> Dict[str, Any]:
    """Summarize a step as its root duration and total node count.

    Returns:
        ``{"dur_ms": ..., "node_count": ...}`` where ``dur_ms`` is the root's
        ``end - start`` converted via ``to_ms`` (not clamped) and
        ``node_count`` is the number of records produced by ``_walk_nodes``.
    """
    root_end = float(step_tree.get("end", 0.0))
    unit = detect_time_unit(root_end)
    root_start = float(step_tree.get("start", 0.0))
    duration = to_ms(root_end - root_start, unit)
    flattened = _walk_nodes(step_tree, unit)
    return {"dur_ms": duration, "node_count": len(flattened)}

def extract_step_labels(step_tree: Dict[str, Any]) -> Dict[str, Any]:
    """Derive error labels for a step and for each op name within it.

    Returns:
        ``{"step_error": bool, "per_name": [...]}`` where ``step_error`` is True
        if any node is flagged ``is_error``, and ``per_name`` holds one entry
        per normalized op name with ``name``, ``count``, ``error_count`` and
        ``error`` (``error_count > 0``). Entries are ordered erroring names
        first, then by descending error count, descending count, and name.
    """
    nodes = collect_nodes(step_tree)
    # Uses the module-level ``defaultdict``; the old function-local re-import was redundant.
    per_name_map: DefaultDict[str, Dict[str, int]] = defaultdict(lambda: {"count": 0, "error_count": 0})
    step_error = False
    for n in nodes:
        entry = per_name_map[n["name"]]
        entry["count"] += 1
        if n.get("is_error"):
            entry["error_count"] += 1
            step_error = True

    per_name = [
        {
            "name": name,
            "count": v["count"],
            "error_count": v["error_count"],
            "error": v["error_count"] > 0,
        }
        for name, v in per_name_map.items()
    ]
    # Sorting by -error_count already places erroring names (error_count > 0)
    # ahead of clean ones, so a separate "error" key component is unnecessary.
    per_name.sort(key=lambda x: (-x["error_count"], -x["count"], x["name"]))
    return {"step_error": step_error, "per_name": per_name}
