from __future__ import annotations
from typing import Dict, Any, List, Tuple
import os, json, math

from .common import normpath, get_step_idx_from_path
from .tree_collect import (
    collect_step_op_samples,
    summarize_step_time,  
    extract_step_labels,
    collect_parent_pairs,
)
from .cluster import ClusterConfig, fit_pid_clusters, SoftClusterer


def iter_tasks(root_dir: str):
    """Yield ``(mode, task_path)`` for each ``task_*`` entry below *root_dir*.

    The ``vllm_online`` sub-directory is scanned first (mode ``"online"``),
    then ``vllm_offline`` (mode ``"offline"``). A sub-directory that does not
    exist is skipped. Within each one, entries are visited in sorted name
    order and only names beginning with ``task_`` are yielded.
    """
    subdir_modes = (("vllm_online", "online"), ("vllm_offline", "offline"))
    for subdir, mode in subdir_modes:
        mode_root = os.path.join(root_dir, subdir)
        if os.path.isdir(mode_root):
            for entry in sorted(os.listdir(mode_root)):
                if not entry.startswith("task_"):
                    continue
                yield mode, os.path.join(mode_root, entry)


def _load_json(p: str) -> Any:
    """Return the decoded contents of the UTF-8 JSON file at path *p*."""
    with open(p, encoding="utf-8") as handle:
        payload = json.load(handle)
    return payload


def _save_json(p: str, data: dict):
    """Write *data* to path *p* as indented, UTF-8 encoded JSON.

    Missing parent directories are created first. Non-ASCII characters are
    written as-is (``ensure_ascii=False``).

    Args:
        p: Destination file path. May be a bare file name.
        data: JSON-serialisable object to write.
    """
    parent = os.path.dirname(p)
    # A bare file name has an empty dirname, and os.makedirs("") raises
    # FileNotFoundError, so only create directories when there is one.
    if parent:
        os.makedirs(parent, exist_ok=True)
    with open(p, "w", encoding="utf-8") as f:
        json.dump(data, f, ensure_ascii=False, indent=2)


def _resolve_step_path(task_dir: str, raw_path: str) -> str:
    """Translate a step path recorded in the metadata into a usable path.

    Resolution order:

    1. If the normalised path is absolute and exists on disk, return it as is.
    2. If it contains ``"output"`` and a ``"task_"`` marker, drop everything
       up to and including the first path component that starts at that
       marker, and join the remainder onto *task_dir*.
    3. Otherwise join the normalised path onto *task_dir*.
    """
    normalized = normpath(raw_path)

    if os.path.isabs(normalized) and os.path.exists(normalized):
        return normalized

    marker_at = normalized.find("task_") if "output" in normalized else -1
    if marker_at == -1:
        return os.path.join(task_dir, normalized)

    tail = normalized[marker_at:]
    pieces = tail.split(os.sep)
    # ``tail`` begins with "task_", so pieces[0] is always the task component;
    # strip it whenever something follows it.
    if len(pieces) >= 2:
        tail = os.path.join(*pieces[1:])
    return os.path.join(task_dir, tail)


def _list_all_step_paths_for_kind(root_dir: str, kind: str) -> List[str]:
    """Collect every existing step-file path of the given *kind* under *root_dir*.

    For each task yielded by ``iter_tasks`` the task's ``meta_data.json`` is
    read; tasks whose metadata file is missing or cannot be loaded are
    skipped. ``kind == "vertical"`` selects the ``process_vertical_info``
    entries of each process, any other value selects
    ``process_horizontal_info``. The first element of each entry is resolved
    with ``_resolve_step_path`` and kept only if the resolved path exists.
    """
    if kind == "vertical":
        info_key = "process_vertical_info"
    else:
        info_key = "process_horizontal_info"

    found: List[str] = []
    for _mode, task_dir in iter_tasks(root_dir):
        meta_file = os.path.join(task_dir, "meta_data.json")
        if not os.path.exists(meta_file):
            continue
        try:
            meta = _load_json(meta_file)
        except Exception:
            continue

        for proc in meta.get("all_process_data", []) or []:
            entries = proc.get(info_key) or []
            resolved = (_resolve_step_path(task_dir, entry[0]) for entry in entries)
            found.extend(candidate for candidate in resolved if os.path.exists(candidate))
    return found


def _collect_global_samples(
    root_dir: str,
    *,
    kind: str,             # "vertical" | "horizontal"
    metric: str,           # "self" | "inc"
    rebuild_cache: bool,
) -> Dict[str, List[float]]:
    """Gather per-op samples across all step files of *kind* under *root_dir*.

    Results are cached in ``cache_op_samples_{metric}_{kind}.json`` inside
    *root_dir*. When *rebuild_cache* is false and that file exists and loads,
    its contents are returned directly. Otherwise every step file is loaded,
    ``collect_step_op_samples`` is run on it, and the per-op value lists are
    concatenated. Step files that fail to load or process are skipped, and a
    failure to write the cache is ignored.
    """
    assert kind in ("vertical", "horizontal")
    cache_file = os.path.join(root_dir, f"cache_op_samples_{metric}_{kind}.json")

    if not rebuild_cache:
        if os.path.exists(cache_file):
            try:
                return _load_json(cache_file)
            except Exception:
                # Unreadable cache: fall through and rebuild it.
                pass

    merged: Dict[str, List[float]] = {}
    for step_path in _list_all_step_paths_for_kind(root_dir, kind):
        try:
            per_step = collect_step_op_samples(_load_json(step_path), metric=metric)
            for op_name, values in per_step.items():
                if op_name not in merged:
                    merged[op_name] = []
                merged[op_name].extend(values)
        except Exception:
            continue

    # Caching is best-effort; the in-memory result is returned regardless.
    try:
        _save_json(cache_file, merged)
    except Exception:
        pass
    return merged


def _fit_global_models(
    root_dir: str,
    *,
    kind: str,
    cluster_method: str,
    k: int,
    metric: str,
    log_transform: bool,
    center_mass_alpha: float,
    rebuild_cache: bool,
) -> Dict[str, SoftClusterer]:
    """Fit clustering models on the global samples of *kind*.

    Builds a ``ClusterConfig`` from the clustering arguments, collects the
    global per-op samples via ``_collect_global_samples``, and returns the
    result of ``fit_pid_clusters`` on them.
    """
    cluster_settings = dict(
        method=cluster_method,
        k=k,
        metric=metric,
        log_transform=log_transform,
        center_mass_alpha=center_mass_alpha,
    )
    # The config is built before sample collection, as in the original order.
    cfg = ClusterConfig(**cluster_settings)
    return fit_pid_clusters(
        _collect_global_samples(
            root_dir,
            kind=kind,
            metric=metric,
            rebuild_cache=rebuild_cache,
        ),
        cfg,
    )


def _features_from_probs(
    step_tree: Dict[str, Any],
    models: Dict[str, SoftClusterer],
    *,
    metric: str,
    log_transform: bool,
    kind: str,  # "vertical" | "horizontal"
) -> Dict[str, Any]:
    """Score every op of one step tree against the fitted models.

    For each op name returned by ``collect_step_op_samples`` the sample values
    are (optionally ``log1p``-transformed and) passed to the op's model via
    ``prob_center``. Ops without a model get a constant probability of 0.5
    per sample. The op's score is the mean of its lowest probabilities: the
    bottom 10% of samples for ``"vertical"``, the bottom 25% for
    ``"horizontal"``, and always at least one sample.

    Args:
        step_tree: Decoded step tree.
        models: Mapping from op name to fitted model.
        metric: Forwarded to ``collect_step_op_samples``.
        log_transform: Apply ``numpy.log1p`` to the samples before scoring.
        kind: ``"vertical"`` or ``"horizontal"``.

    Returns:
        ``{"kind", "ops_count", "ops_prob"}`` where ``ops_prob`` is a list of
        ``{"name", "prob_center", "parent"}`` dicts sorted by ascending
        ``prob_center``. ``parent`` is the most frequent parent name reported
        by ``collect_parent_pairs`` for that op, or ``None``.

    Raises:
        NotImplementedError: If *kind* is not a supported value and at least
            one op has samples to score.
    """
    from collections import Counter, defaultdict

    import numpy as np

    samples = collect_step_op_samples(step_tree, metric=metric)  # {op: [values]}
    pair_list = collect_parent_pairs(step_tree)  # [(child_name, parent_name), ...]

    parent_counter: Dict[str, Counter] = defaultdict(Counter)
    for child, parent in pair_list:
        parent_counter[child][parent] += 1

    # Fraction of the lowest-probability samples averaged into the op score.
    tail_fraction = {"vertical": 0.1, "horizontal": 0.25}.get(kind)

    ops_prob = []
    for name, arr in samples.items():
        X = np.array(arr, dtype=float).reshape(-1, 1)
        # Skip ops with no samples before touching the model, so one empty
        # sample list cannot fail the whole step if the clusterer rejects
        # zero-row input (not confirmed here for SoftClusterer).
        if X.shape[0] == 0:
            continue
        if log_transform:
            X = np.log1p(X)

        model = models.get(name)
        if model is None:
            # No model for this op: neutral probability for every sample.
            prob = np.full((len(arr),), 0.5, dtype=float)
        else:
            prob = model.prob_center(X)

        n = len(prob)
        if n <= 0:
            continue

        if tail_fraction is None:
            raise NotImplementedError
        kq = max(1, math.floor(n * tail_fraction))

        # Mean of the kq smallest probabilities.
        prob_center = float(np.sort(prob)[:kq].mean())

        pc = parent_counter.get(name)
        top_parent = pc.most_common(1)[0][0] if pc else None

        ops_prob.append({"name": name, "prob_center": prob_center, "parent": top_parent})

    ops_prob.sort(key=lambda x: x["prob_center"])

    return {"kind": kind, "ops_count": len(ops_prob), "ops_prob": ops_prob}


def _collect_step_records(
    task_dir: str,
    pid: Any,
    entries: List[Any],
    *,
    kind: str,
    models: Dict[str, SoftClusterer],
    index_rows: List[Dict[str, Any]],
    metric: str,
    log_transform: bool,
    write_labels: bool,
    output_style: str,
    include_labels_in_meta: bool,
) -> List[Dict[str, Any]]:
    """Build one record per ``(raw_path, label)`` entry of a single process.

    Shared implementation for the vertical and horizontal passes. *kind*
    selects the feature mode and the ``v``/``h`` tag used in output file
    names. For any ``output_style`` other than ``"per_task"``, the feature
    (and optionally label) files are written next to the step file and the
    record is also appended to *index_rows*, which is mutated in place.
    A step that fails to load or process yields an ``{"error": ...}`` record.
    """
    tag = "v" if kind == "vertical" else "h"
    records: List[Dict[str, Any]] = []

    for path_label in entries:
        raw, meta_lbl = path_label
        step_path = _resolve_step_path(task_dir, raw)
        step_idx = get_step_idx_from_path(step_path)
        try:
            tree = _load_json(step_path)
            feat = _features_from_probs(
                tree, models=models, metric=metric, log_transform=log_transform, kind=kind
            )
            rec = {"step_idx": step_idx, "path": step_path}
            if include_labels_in_meta:
                rec["label_anomaly"] = bool(meta_lbl)

            if output_style == "per_task":
                rec["features"] = feat
            else:
                step_dir = os.path.dirname(step_path)
                feat_path = os.path.join(step_dir, f"features_{tag}_step_{step_idx}.json")
                _save_json(feat_path, feat)
                rec["feature_path"] = feat_path

                if write_labels:
                    step_labels = extract_step_labels(tree)
                    label_obj = {
                        "step_idx": step_idx,
                        "path": step_path,
                        "label_step_anomaly": bool(meta_lbl),
                        "per_name": step_labels["per_name"],
                    }
                    label_path = os.path.join(step_dir, f"label_{tag}_step_{step_idx}.json")
                    _save_json(label_path, label_obj)
                    rec["label_path"] = label_path

                # Snapshot of the record, tagged with the owning pid.
                index_rows.append({"pid": pid, **rec})

            records.append(rec)
        except Exception as e:
            records.append({"step_idx": step_idx, "path": step_path, "error": str(e)})

    return records


def _attach_step_features_with_models(
    task_dir: str,
    *,
    v_models: Dict[str, SoftClusterer],
    h_models: Dict[str, SoftClusterer],
    metric: str,
    log_transform: bool,
    write_labels: bool,
    output_style: str,
    include_labels_in_meta: bool,
) -> Dict[str, Any]:
    """Compute step features for every process listed in a task's metadata.

    Reads ``meta_data.json`` from *task_dir* and, for each process, handles
    its ``process_vertical_info`` entries with *v_models* and then its
    ``process_horizontal_info`` entries with *h_models*.

    Args:
        task_dir: Task directory containing ``meta_data.json``.
        v_models: Models used for vertical steps.
        h_models: Models used for horizontal steps.
        metric: Forwarded to feature extraction.
        log_transform: Forwarded to feature extraction.
        write_labels: When writing per-step files, also write a label file.
        output_style: ``"per_task"`` embeds features in the returned records;
            any other value writes one feature file per step. The index file
            ``features_index.json`` is written only for ``"per_step"``.
        include_labels_in_meta: Add ``label_anomaly`` to every record.

    Returns:
        ``{"processes": [...]}`` with one entry per process holding its pid,
        role flags, and step records sorted by ``step_idx``. For
        ``"per_step"`` it also contains ``features_index_path``.
    """
    meta_path = os.path.join(task_dir, "meta_data.json")
    meta = _load_json(meta_path)
    # ``or []`` also covers an explicit null, matching how
    # _list_all_step_paths_for_kind reads the same key.
    processes = meta.get("all_process_data", []) or []
    out_procs: List[Dict[str, Any]] = []

    index = {"vertical": [], "horizontal": []}

    shared = dict(
        metric=metric,
        log_transform=log_transform,
        write_labels=write_labels,
        output_style=output_style,
        include_labels_in_meta=include_labels_in_meta,
    )

    for proc in processes:
        pid = proc.get("pid")
        is_master = bool(proc.get("is_master", False))
        is_worker = bool(proc.get("is_worker", False))

        v_steps = _collect_step_records(
            task_dir,
            pid,
            proc.get("process_vertical_info", []) or [],
            kind="vertical",
            models=v_models,
            index_rows=index["vertical"],
            **shared,
        )
        h_steps = _collect_step_records(
            task_dir,
            pid,
            proc.get("process_horizontal_info", []) or [],
            kind="horizontal",
            models=h_models,
            index_rows=index["horizontal"],
            **shared,
        )

        out_procs.append({
            "pid": pid,
            "role": {"is_master": is_master, "is_worker": is_worker},
            "vertical_steps": sorted(v_steps, key=lambda x: x.get("step_idx", -1)),
            "horizontal_steps": sorted(h_steps, key=lambda x: x.get("step_idx", -1)),
        })

    pack = {"processes": out_procs}
    if output_style == "per_step":
        idx_path = os.path.join(task_dir, "features_index.json")
        _save_json(idx_path, index)
        pack["features_index_path"] = idx_path
    return pack


def build_dataset(
    root_dir: str,
    *,
    cluster_method: str = "gmm",
    k_vertical: int = 2,
    k_horizontal: int = 2,
    metric: str = "self",
    log_transform: bool = True,
    rebuild_cache: bool = True,
    write_labels: bool = False,
    output_style: str = "per_step",
    include_labels_in_meta: bool = True,
    center_mass_alpha_v: float = 0.7,
    center_mass_alpha_h: float = 0.7,
) -> List[str]:
    """Fit global models and write a feature summary file for every task.

    Vertical models are fitted first, then horizontal models, each on the
    samples gathered from all tasks under *root_dir*. Every task yielded by
    ``iter_tasks`` is then processed, and its result (plus ``mode`` and
    ``task_id``) is saved into the task directory as
    ``features_task_meta.json`` when *output_style* is ``"per_step"``, or
    ``features_task.json`` otherwise.

    Returns:
        The paths of the summary files written, in task iteration order.
    """
    fit_common = dict(
        cluster_method=cluster_method,
        metric=metric,
        log_transform=log_transform,
        rebuild_cache=rebuild_cache,
    )
    v_models = _fit_global_models(
        root_dir,
        kind="vertical",
        k=k_vertical,
        center_mass_alpha=center_mass_alpha_v,
        **fit_common,
    )
    h_models = _fit_global_models(
        root_dir,
        kind="horizontal",
        k=k_horizontal,
        center_mass_alpha=center_mass_alpha_h,
        **fit_common,
    )

    # The summary file name depends only on the output style.
    if output_style == "per_step":
        summary_name = "features_task_meta.json"
    else:
        summary_name = "features_task.json"

    written = []
    for mode, task_dir in iter_tasks(root_dir):
        pack = _attach_step_features_with_models(
            task_dir,
            v_models=v_models,
            h_models=h_models,
            metric=metric,
            log_transform=log_transform,
            write_labels=write_labels,
            output_style=output_style,
            include_labels_in_meta=include_labels_in_meta,
        )
        pack.update(mode=mode, task_id=os.path.basename(task_dir))

        summary_path = os.path.join(task_dir, summary_name)
        _save_json(summary_path, pack)
        written.append(summary_path)

    return written
