import os
import time
from typing import List, Dict, Any, Optional
from concurrent.futures import ThreadPoolExecutor, as_completed

from llm.prompts import PER_STEP_SYSTEM, build_per_step_user
from llm.openai import chat_complete
from llm.utils import (
    read_json, write_json, iter_feature_files,
    pred_path_for_features, label_path_for_features,
    extract_json_block, filter_pred_names_to_step, augment_with_parents
)

def _as_bool(value: Any) -> bool:
    """Coerce a model-provided flag to ``bool``.

    ``bool("false")`` is ``True``, so string answers are matched against an
    explicit set of truthy spellings instead of relying on truthiness.
    """
    if isinstance(value, str):
        return value.strip().lower() in {"true", "1", "yes", "y"}
    return bool(value)


def predict_one(features_path: str, *, model: str, temperature: float, max_tokens: int, verbose: bool) -> dict:
    """Run the per-step LLM classifier on one features file.

    Nothing is written to disk here; callers decide whether to persist.

    Args:
        features_path: Path to the features JSON to classify.
        model: Model name forwarded to ``chat_complete``.
        temperature: Sampling temperature forwarded to ``chat_complete``.
        max_tokens: Completion token limit forwarded to ``chat_complete``.
        verbose: When true, print the elapsed time after predicting.

    Returns:
        A prediction record with keys ``features_path``, ``model``,
        ``params``, ``created_at``, ``prediction`` (the post-processed
        answer), ``llm_raw`` (the model's answer before name filtering and
        parent augmentation) and, when a readable label file exists for the
        features, ``label``.
    """
    t0 = time.perf_counter()
    features = read_json(features_path)

    txt = chat_complete(
        model=model,
        system=PER_STEP_SYSTEM,
        user=build_per_step_user(features),
        temperature=temperature,
        max_tokens=max_tokens,
        extra={"response_format": {"type": "json_object"}},
    )

    # extract_json_block presumably yields a dict or None; anything else
    # (e.g. a top-level JSON list) is treated as an empty answer.
    raw_obj = extract_json_block(txt)
    if not isinstance(raw_obj, dict):
        raw_obj = {}

    step_anom = _as_bool(raw_obj.get("step_anomaly", False))

    # Normalise the name list: JSON null -> [], bare string -> [string].
    raw_names = raw_obj.get("op_error_names") or []
    if isinstance(raw_names, str):
        raw_names = [raw_names]
    elif not isinstance(raw_names, (list, tuple)):
        raw_names = []
    raw_names = list(raw_names)

    # Post-process with llm.utils helpers (presumably: drop names not in this
    # step, then add parent names). Pass a copy so raw_names stays untouched.
    op_err = filter_pred_names_to_step(features, list(raw_names))
    op_err = augment_with_parents(features, op_err)

    pred = {
        "features_path": features_path,
        "model": model,
        "params": {"temperature": temperature, "max_tokens": max_tokens},
        "created_at": time.strftime("%Y-%m-%d %H:%M:%S"),
        "prediction": {
            "step_anomaly": step_anom,
            "op_error_names": op_err,
        },
        # The model's answer before post-processing, kept for auditing.
        "llm_raw": {
            "step_anomaly": step_anom,
            "op_error_names": raw_names,
        },
    }

    label_path = label_path_for_features(features_path)
    if os.path.exists(label_path):
        try:
            label = read_json(label_path)
            pred["label"] = {
                "step_error": bool(label.get("label_step_anomaly", False)),
                "op_error_names": [x["name"] for x in label.get("per_name", []) if x.get("error")],
            }
        except Exception as e:
            # Labels are optional: a malformed label file must not discard the
            # prediction, but it should not go unnoticed either.
            print(f"[WARN] Could not read label {label_path}: {e}")

    if verbose:
        print(f"[OK] Predicted {features_path} in {time.perf_counter() - t0:.2f}s")
    return pred


def run_single(features_path: str, *, model: str, temperature: float, max_tokens: int, verbose: bool, save: bool = True) -> dict:
    """Predict a single features file and optionally persist the result.

    Args:
        features_path: Path to the features JSON to classify.
        model, temperature, max_tokens, verbose: Forwarded to ``predict_one``.
        save: When true (default), write the record to
            ``pred_path_for_features(features_path)``.

    Returns:
        The prediction record produced by ``predict_one``.
    """
    result = predict_one(features_path, model=model, temperature=temperature,
                         max_tokens=max_tokens, verbose=verbose)
    if not save:
        return result
    write_json(pred_path_for_features(features_path), result)
    return result


def run_batch(
    root_dir: str,
    *,
    model: str,
    temperature: float,
    max_tokens: int,
    recompute: bool,
    verbose: bool,
    concurrent: bool = False,
    workers: int = 4,
) -> int:
    """Predict every features file under ``root_dir`` and save the results.

    Files whose prediction output already exists are skipped unless
    ``recompute`` is set. A failure on one file is printed as ``[ERR]`` and
    does not stop the batch.

    Args:
        root_dir: Directory passed to ``iter_feature_files``.
        model, temperature, max_tokens, verbose: Forwarded to prediction.
        recompute: Re-predict even when an output file already exists.
        concurrent: Use a thread pool (the LLM calls are I/O-bound).
        workers: Thread-pool size; values <= 1 fall back to sequential.

    Returns:
        The number of prediction files successfully written.
    """
    # iter_feature_files may return a lazy iterator; materialise it so it can
    # be both iterated and measured with len().
    all_features: List[str] = list(iter_feature_files(root_dir))
    todo: List[str] = []
    for fpath in all_features:
        outp = pred_path_for_features(fpath)
        if not recompute and os.path.exists(outp):
            if verbose:
                print(f"[SKIP] Exists: {outp}")
            continue
        todo.append(fpath)

    if verbose:
        print(f"[INFO] total_features={len(all_features)}, to_predict={len(todo)}, concurrent={concurrent}, workers={workers}")

    def _job(path: str) -> Optional[str]:
        """Predict and save one file; return its path on success, else None."""
        try:
            run_single(
                path,
                model=model,
                temperature=temperature,
                max_tokens=max_tokens,
                verbose=verbose,
                save=True,
            )
        except Exception as e:
            # Batch boundary: report and keep going with the remaining files.
            print(f"[ERR] {path} -> {e}")
            return None
        return path

    if not concurrent or int(workers) <= 1:
        wrote = 0
        for fpath in todo:
            if _job(fpath) is not None:
                wrote += 1
        return wrote

    wrote = 0
    with ThreadPoolExecutor(max_workers=int(workers)) as ex:
        futures = [ex.submit(_job, p) for p in todo]
        for fut in as_completed(futures):
            if fut.result() is not None:
                wrote += 1
    return wrote
