from __future__ import annotations

import math
from typing import Any, Dict, List, Optional


def safe_float(value: Any) -> Optional[float]:
    """Convert *value* to a finite ``float``, or return ``None``.

    ``None`` is returned when ``float()`` rejects the value (``TypeError`` /
    ``ValueError``, e.g. ``None`` or a non-numeric string), when the value is
    too large to represent as a float (``OverflowError``, e.g. ``10**400``),
    and when the result is NaN or +/-inf.
    """
    try:
        parsed = float(value)
    except (TypeError, ValueError, OverflowError):
        # OverflowError: huge ints / Fractions cannot be converted to float.
        return None
    # float() already returns an exact float, so no second conversion is needed.
    return parsed if math.isfinite(parsed) else None


def safe_int(value: Any, default: int = 0) -> int:
    """Convert *value* with ``int()``, falling back to ``int(default)``.

    Floats are truncated toward zero (``3.9 -> 3``). Strings must be integer
    literals (``"3.5"`` falls back). NaN (``ValueError``) and +/-inf
    (``OverflowError``) also fall back instead of raising.
    """
    try:
        return int(value)
    except (TypeError, ValueError, OverflowError):
        return int(default)


def metric_threshold(spec: Optional[Dict[str, Any]]) -> Optional[float]:
    """Return the strictly positive ``spec["metric"]["threshold"]``, else ``None``.

    ``None`` is returned when *spec* or its ``metric`` entry is not a dict,
    or when the threshold is missing, not convertible by ``safe_float``, or
    not greater than zero.
    """
    metric = spec.get("metric") if isinstance(spec, dict) else None
    if isinstance(metric, dict):
        candidate = safe_float(metric.get("threshold"))
        if candidate is not None and candidate > 0:
            return candidate
    return None


def piece_interval_index(spec: Optional[Dict[str, Any]]) -> Dict[str, Dict[str, float]]:
    """Index the intervals of ``spec["domain"]["pieces"]`` by ``str(piece_id)``.

    Each accepted piece maps to ``{"start": s, "end": e, "width": e - s}``
    (width clamped at 0.0). A piece is skipped when it is not a dict, has no
    ``piece_id``, has a non-dict ``interval``, has a bound that ``safe_float``
    rejects, or has ``end < start``. A later piece with the same stringified
    id overwrites an earlier one. Returns ``{}`` when the path is missing or
    malformed.
    """
    domain = spec.get("domain") if isinstance(spec, dict) else None
    pieces = domain.get("pieces") if isinstance(domain, dict) else None
    if not isinstance(pieces, list):
        return {}

    index: Dict[str, Dict[str, float]] = {}
    for entry in pieces:
        if not isinstance(entry, dict):
            continue
        ident = entry.get("piece_id")
        bounds = entry.get("interval")
        if ident is None or not isinstance(bounds, dict):
            continue
        lo = safe_float(bounds.get("start"))
        hi = safe_float(bounds.get("end"))
        if lo is None or hi is None or hi < lo:
            continue
        index[str(ident)] = {"start": lo, "end": hi, "width": max(0.0, hi - lo)}
    return index


def _status_term(status: Optional[str]) -> float:
    token = str(status or "").strip().lower()
    if token in ("failed", "stopped"):
        return 0.5
    if token in ("running", "pending", "queued"):
        return 0.2
    return 0.0


def compute_piece_pressure(
    *,
    error: Optional[float],
    threshold: Optional[float],
    stagnation_count: int,
    health_score: float,
    status: Optional[str],
    throughput_tps: float = 0.0,
    last_rel_improve: float = 0.0,
    width: Optional[float] = None,
) -> Dict[str, float | None]:
    """Compute pressure scores used to rank a piece.

    Returns a dict with:

    * ``score``: a weighted sum of an error term, stagnation, inverse health,
      inverse throughput, lack of recent improvement, log interval width and
      a status bonus (see ``_status_term``). Higher means more pressure.
    * ``split_score``: a variant that weights only error, stagnation,
      inverse health and width.
    * ``error_ratio``: ``error / threshold`` floored at ``1e-12`` when a usable
      threshold is given; otherwise ``None``.

    Inputs are clamped: health to at least 0.05, and throughput, relative
    improvement and stagnation to at least 0. A missing ``error`` uses a fixed
    error term of 1.0. A ``threshold`` that is ``None``, zero, negative or NaN
    is ignored and the raw error is used instead, so no division by zero and
    no sign flip occur.
    """
    health = max(0.05, float(health_score))
    throughput = max(0.0, float(throughput_tps))
    rel_improve = max(0.0, float(last_rel_improve))
    stagnation = max(0, int(stagnation_count))

    error_ratio: Optional[float] = None
    if error is None:
        error_term = 1.0
    elif threshold is not None and threshold > 0:
        error_ratio = max(error / threshold, 1e-12)
        error_term = math.log1p(error_ratio)
    else:
        # No usable threshold: dividing by 0 would raise, and a negative one
        # would flip the sign and make every error look "under threshold".
        error_term = math.log1p(max(error, 1e-12))

    throughput_term = 1.0 / (1.0 + throughput)
    # Saturates at 0 once rel_improve >= 0.02 (0.02 * 50 == 1).
    improvement_term = 1.0 - min(1.0, rel_improve * 50.0)
    width_term = math.log1p(max(float(width or 0.0), 0.0))
    health_term = 1.0 / health
    status_term = _status_term(status)

    score = (
        error_term
        + 0.35 * float(stagnation)
        + 0.25 * health_term
        + 0.15 * throughput_term
        + 0.15 * improvement_term
        + 0.05 * width_term
        + status_term
    )
    split_score = error_term + 0.45 * float(stagnation) + 0.30 * health_term + 0.10 * width_term
    return {
        "score": float(score),
        "split_score": float(split_score),
        "error_ratio": float(error_ratio) if error_ratio is not None else None,
    }


def _suggested_action(
    *,
    error_ratio: Optional[float],
    stagnation_count: int,
    throughput_tps: float,
    width: Optional[float],
    min_piece_width: float,
) -> str:
    if error_ratio is not None and error_ratio <= 1.0:
        return "verify"
    if width is not None and width >= max(min_piece_width * 2.0, 1e-12) and stagnation_count >= 20:
        return "split"
    if stagnation_count >= 10 and throughput_tps <= 0.05:
        return "adjust_sampling"
    return "continue_search"


def build_repartition_focus(
    *,
    spec: Optional[Dict[str, Any]],
    rows: List[Dict[str, Any]],
    worker_plan: Optional[Dict[str, Any]] = None,
    limit: int = 8,
) -> Dict[str, Any]:
    """Rank pieces by pressure and select the ones to focus repartitioning on.

    Args:
        spec: Problem spec. It supplies the metric threshold
            (``metric.threshold``), piece intervals (``domain.pieces``) and
            ``stop_criteria.min_piece_width`` (default ``1e-6``). Missing or
            malformed sections are treated as absent.
        rows: Per-piece status rows. Fields read: ``piece_id``, ``error``
            (falling back to ``best_error``), ``stagnation_count``,
            ``health_score``, ``throughput_tps``, ``last_rel_improve``,
            ``status`` and ``task_tag``. Non-dict rows are skipped.
        worker_plan: Optional plan whose ``per_piece`` dict is looked up by
            stringified piece id to get a worker count.
        limit: Maximum number of ranked pieces returned; values below 1 are
            treated as 1.

    Returns:
        A dict with ``primary_piece_id`` / ``primary_task_tag`` of the
        top-ranked piece (``None`` when nothing was ranked),
        ``focus_piece_ids`` (ids of the returned pieces, id-less rows omitted)
        and ``ranked_pieces``. Pieces are ordered by ``score`` descending,
        then ``split_score`` descending, then ``error`` descending, then
        ``piece_id`` ascending.
    """
    threshold = metric_threshold(spec)
    interval_map = piece_interval_index(spec)

    # A non-dict stop_criteria (e.g. a list) is treated as absent rather than
    # crashing on ``.get``.
    stop_criteria = spec.get("stop_criteria") if isinstance(spec, dict) else None
    if not isinstance(stop_criteria, dict):
        stop_criteria = {}
    # ``or`` maps a missing, unparsable or zero width to the default.
    min_piece_width = safe_float(stop_criteria.get("min_piece_width")) or 1e-6

    per_piece_workers = worker_plan.get("per_piece") if isinstance(worker_plan, dict) else None
    if not isinstance(per_piece_workers, dict):
        per_piece_workers = {}

    ranked: List[Dict[str, Any]] = []
    for row in rows or ():
        if not isinstance(row, dict):
            continue
        piece_id = row.get("piece_id")
        piece_key = str(piece_id) if piece_id is not None else None
        # Rows without an id get no interval, so they can never match a piece
        # whose id happens to be "".
        interval = interval_map.get(piece_key) if piece_key is not None else None
        width = interval.get("width") if isinstance(interval, dict) else None

        error = safe_float(row.get("error"))
        if error is None:
            error = safe_float(row.get("best_error"))
        stagnation_count = max(0, safe_int(row.get("stagnation_count"), 0))
        health_score = safe_float(row.get("health_score"))
        throughput_tps = safe_float(row.get("throughput_tps")) or 0.0
        last_rel_improve = safe_float(row.get("last_rel_improve")) or 0.0

        metrics = compute_piece_pressure(
            error=error,
            threshold=threshold,
            stagnation_count=stagnation_count,
            # A missing health defaults to 1.0 for scoring; the record keeps None.
            health_score=health_score if health_score is not None else 1.0,
            status=str(row.get("status") or ""),
            throughput_tps=throughput_tps,
            last_rel_improve=last_rel_improve,
            width=width,
        )
        error_ratio = metrics.get("error_ratio")
        ranked.append(
            {
                "piece_id": piece_key,
                "task_tag": row.get("task_tag"),
                "status": row.get("status"),
                "error": error,
                "error_ratio": error_ratio,
                "stagnation_count": stagnation_count,
                "health_score": health_score,
                "throughput_tps": throughput_tps,
                "last_rel_improve": last_rel_improve,
                "score": metrics.get("score"),
                "split_score": metrics.get("split_score"),
                "interval": None
                if not isinstance(interval, dict)
                else {"start": interval["start"], "end": interval["end"]},
                "width": width,
                "worker_num": safe_int(per_piece_workers.get(piece_key), 0) if piece_key else 0,
                "suggested_action": _suggested_action(
                    error_ratio=error_ratio,
                    stagnation_count=stagnation_count,
                    throughput_tps=throughput_tps,
                    width=width,
                    min_piece_width=min_piece_width,
                ),
            }
        )

    ranked.sort(
        key=lambda item: (
            -float(item.get("score") or 0.0),
            -float(item.get("split_score") or 0.0),
            -float(item.get("error") or 0.0),
            str(item.get("piece_id") or ""),
        )
    )
    limited = ranked[: max(1, int(limit))] if ranked else []
    primary = limited[0] if limited else {}
    return {
        "primary_piece_id": primary.get("piece_id"),
        "primary_task_tag": primary.get("task_tag"),
        "focus_piece_ids": [str(item["piece_id"]) for item in limited if item.get("piece_id") is not None],
        "ranked_pieces": limited,
    }
