#!/usr/bin/env python3
"""
Probe which HF Inference API providers support a given model.

Usage:
    python scripts/probe_hf_providers.py deepseek-ai/DeepSeek-Prover-V2-671B
    python scripts/probe_hf_providers.py meta-llama/Llama-3.1-8B-Instruct --stop-first
    python scripts/probe_hf_providers.py moonshotai/Kimi-K2.6 --providers novita nscale
    python scripts/probe_hf_providers.py <model> --prompt "What is 1+1?" --max-tokens 16

Auth: HF_TOKEN env var (or HF_API_TOKEN / HUGGINGFACE_TOKEN as fallbacks).
"""
from __future__ import annotations

import argparse
import os
import sys
import time

# All providers the HF router knows about (as of mid-2025).
# Order matters: when --providers is not given, main() probes these front to
# back, so with --stop-first the earliest working entry ends the run.
# "auto" is treated as unpinned: the summary recommends any other working
# provider ahead of it.
_ALL_PROVIDERS = [
    "auto",
    "cerebras",
    "cohere",
    "fal-ai",
    "featherless-ai",
    "fireworks-ai",
    "groq",
    "hf-inference",
    "hyperbolic",
    "nebius",
    "novita",
    "nscale",
    "perplexity-ai",
    "replicate",
    "sambanova",
    "together",
]

# Default value of the --prompt option; sent as the single user message.
_TEST_PROMPT = "Reply with exactly one word: HELLO."


def _resolve_token() -> str | None:
    """Look up the Hugging Face API token in the environment.

    Checks ``HF_TOKEN``, then ``HF_API_TOKEN``, then ``HUGGINGFACE_TOKEN`` and
    returns the first non-empty value. When none of them is non-empty, the
    result is whatever the final lookup produced: ``None`` if the variable is
    unset, ``""`` if it is set but empty.
    """
    candidate = None
    for env_name in ("HF_TOKEN", "HF_API_TOKEN", "HUGGINGFACE_TOKEN"):
        candidate = os.environ.get(env_name)
        if candidate:
            # First non-empty variable wins; later ones are not consulted.
            break
    return candidate


def probe_provider(
    provider: str,
    model_id: str,
    prompt: str,
    max_tokens: int,
    token: str | None,
) -> tuple[bool, str, float]:
    """Send one chat-completion request through *provider* and report the outcome.

    Args:
        provider: Provider name handed to ``InferenceClient``.
        model_id: Model identifier to request.
        prompt: Text sent as the single user message.
        max_tokens: Cap on generated tokens for the reply.
        token: API token, or ``None`` to call without one.

    Returns:
        ``(ok, detail, elapsed_seconds)``. On success ``detail`` is the
        ``repr`` of the stripped reply text. On failure it is the exception
        class name followed by the message, collapsed onto a single line and
        truncated to 120 characters.

    Raises:
        SystemExit: With status 1 when ``huggingface_hub`` cannot be imported.
    """
    try:
        from huggingface_hub import InferenceClient
    except ImportError:
        print("ERROR: pip install huggingface_hub", file=sys.stderr)
        sys.exit(1)

    # perf_counter is monotonic, so the measured duration cannot be skewed by
    # system clock adjustments the way time.time() can.
    started = time.perf_counter()
    try:
        client = InferenceClient(provider=provider, token=token)
        resp = client.chat_completion(
            model=model_id,
            messages=[{"role": "user", "content": prompt}],
            max_tokens=max_tokens,
        )
        content = (resp.choices[0].message.content or "").strip()
    except Exception as exc:
        # Deliberately broad: any failure (auth, routing, unsupported model,
        # malformed response) just means "this provider did not work".
        elapsed = time.perf_counter() - started
        # Collapse newlines/tabs so the caller can print one line per provider.
        message = " ".join(str(exc).split())[:120]
        return False, f"{type(exc).__name__}: {message}", elapsed

    elapsed = time.perf_counter() - started
    return True, repr(content), elapsed


def main(argv: list[str] | None = None) -> None:
    """Parse command-line options, probe each provider, and print a summary.

    Args:
        argv: Argument list to parse instead of ``sys.argv[1:]``. ``None``
            (the default) keeps the normal command-line behaviour.

    Raises:
        SystemExit: Status 1 when no provider succeeds; argparse exits with
            status 2 on invalid arguments (including a non-positive
            ``--max-tokens``).
    """
    parser = argparse.ArgumentParser(
        description="Probe HF Inference API providers for a model.",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog=__doc__,
    )
    parser.add_argument("model_id", help="HuggingFace model ID, e.g. meta-llama/Llama-3.1-8B-Instruct")
    parser.add_argument(
        "--providers", nargs="+", metavar="PROVIDER",
        help=f"Providers to probe (default: all {len(_ALL_PROVIDERS)}). "
             f"Known: {', '.join(_ALL_PROVIDERS)}",
    )
    parser.add_argument("--prompt", default=_TEST_PROMPT, help="Test prompt to send")
    parser.add_argument("--max-tokens", type=int, default=16, metavar="N",
                        help="Max tokens for the test response (default: 16)")
    parser.add_argument("--stop-first", action="store_true",
                        help="Stop after the first successful provider")
    args = parser.parse_args(argv)

    # Reject a meaningless token budget once, instead of letting every
    # provider fail with its own error message.
    if args.max_tokens < 1:
        parser.error("--max-tokens must be a positive integer")

    providers = args.providers or _ALL_PROVIDERS
    token = _resolve_token()
    if not token:
        print("WARNING: No HF token found. Set HF_TOKEN for gated models.", file=sys.stderr)

    print(f"Model : {args.model_id}")
    print(f"Prompt: {args.prompt!r}")
    print(f"Probing {len(providers)} provider(s)...\n")

    ok_providers: list[str] = []

    # Provider column: at least 14 wide, widened to fit longer custom names
    # so the timing/detail columns stay aligned.
    col = max([14, *map(len, providers)])
    for provider in providers:
        ok, detail, elapsed = probe_provider(
            provider, args.model_id, args.prompt, args.max_tokens, token
        )
        tag = "OK  " if ok else "FAIL"
        print(f"  [{tag}] {provider:<{col}}  {elapsed:5.1f}s  {detail}")
        if not ok:
            continue
        ok_providers.append(provider)
        if args.stop_first:
            break

    print()
    if not ok_providers:
        print("No working providers found.")
        sys.exit(1)

    print(f"Working providers ({len(ok_providers)}): {', '.join(ok_providers)}")
    print("\nRecommended config snippet:")
    # prefer a pinned provider over 'auto' for reproducibility; fall back to
    # the first working entry when 'auto' is the only one that succeeded
    pinned = next((p for p in ok_providers if p != "auto"), ok_providers[0])
    print(f"  provider: {pinned}")


# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
