from __future__ import annotations
import os, time, warnings, logging, json, sys, subprocess, shlex
from pathlib import Path
import typer
from typing import List, Dict, Any

from .cfg import load_cfg, summary, ModelConfig
from .analysis.runner import run_task as run_analysis
from .repair.runner import run_task as run_repair
from .refactor.runner import run_task as run_refactor
from .transform.runner import run_task as run_transform

# Quiet noisy deps.
# setdefault() only supplies a value when the variable is not already set, so
# anything exported by the caller's environment wins over these defaults.
# NOTE(review): presumably these silence gRPC / glog / Abseil logging coming
# from transitive dependencies — confirm against the libraries actually used.
os.environ.setdefault("GRPC_VERBOSITY", "ERROR")
os.environ.setdefault("GRPC_CPP_ENABLE_LOGGING", "0")
os.environ.setdefault("GLOG_minloglevel", "3")
os.environ.setdefault("ABSL_LOG_SEVERITY", "fatal")
# Ignore UserWarnings whose message matches "Field name ... shadows an attribute".
warnings.filterwarnings("ignore", message=r".*Field name .* shadows an attribute.*", category=UserWarning)

from .reporting.html import render_html
from .reporting.csv import write_task_csvs
from .utils.logging_utils import setup_logging, get_logger, debug_on
from .reporting.aggregate import write_all_aggregates
# Typer application object; the @app.command() functions in this module
# register themselves on it. add_completion=False is passed to Typer.
app = typer.Typer(add_completion=False)

# ---------------- helpers ----------------

def _split_into_shards(items: List[str], n: int) -> List[List[str]]:
    """Distribute *items* round-robin over at most *n* shards.

    Item ``i`` goes to shard ``i % count``, so the relative order of items
    inside each shard is preserved.

    The shard count is capped at ``len(items)`` so that no empty shard is
    ever produced when there is at least one item; callers can therefore use
    ``len(result)`` as the real number of workers. An empty *items* list
    yields a single empty shard (``[[]]``), the same shape as ``n == 1``.

    Args:
        items: Case ids to distribute.
        n: Requested number of shards; anything below 1 is treated as 1.

    Returns:
        A list of new lists (the input list itself is never returned).
    """
    count = max(1, min(int(n), len(items)))
    # Extended slicing with a step is exactly the round-robin assignment.
    return [items[start::count] for start in range(count)]

def _discover_cases_ids(task_dir: Path) -> List[str]:
    """Return the ``id`` of every case that ``discover_cases`` finds for *task_dir*."""
    # Imported at call time rather than at module import.
    from .registry import discover_cases

    return [case.id for case in discover_cases(task_dir)]

def _apply_select(ids: List[str], select: str | None) -> List[str]:
    """Filter *ids* through ``select_cases`` using the *select* expression.

    With an empty or missing *select* the input list is returned unchanged.
    Otherwise *select* is written to ``CIRBENCH_SELECT`` (and left there) and
    the ids are wrapped in minimal objects exposing only an ``id`` attribute
    before being handed to ``select_cases``.
    """
    if not select:
        return ids

    # select_cases is not given the expression directly; presumably it reads
    # CIRBENCH_SELECT, so the variable is set before the module is imported.
    os.environ["CIRBENCH_SELECT"] = select
    from .utils.case_select import select_cases

    class _IdHolder:
        """Stand-in for a case object carrying nothing but its id."""

        def __init__(self, case_id):
            self.id = case_id

    kept = select_cases([_IdHolder(case_id) for case_id in ids])
    return [holder.id for holder in kept]

def _model_override(cfg, model: str | None, logger):
    """Replace ``cfg.models`` with the single model named by ``--model``.

    Args:
        cfg: Loaded configuration; its ``models`` attribute is overwritten
            in place when an override is given.
        model: ``PROVIDER:NAME`` string, or None/empty for "no override".
        logger: Logger used for the debug line (only touched when
            ``debug_on()`` is true).

    Returns:
        The same *cfg* object.

    Raises:
        SystemExit: if *model* is not of the form ``PROVIDER:NAME`` with a
            non-empty provider and a non-empty name.
    """
    if not model:
        return cfg

    provider, sep, name = model.partition(":")
    provider = provider.strip().lower()
    name = name.strip()
    # A missing colon, ":name" and "provider:" are all rejected; an empty
    # kind or name would otherwise end up in the ModelConfig.
    if not sep or not provider or not name:
        raise SystemExit("--model must be PROVIDER:NAME (e.g., gemini:gemini-2.5-flash)")

    # Inherit params from a configured model, most specific match first:
    #   1. same provider and same name,
    #   2. same provider,
    #   3. same name (the provider is then taken from that entry).
    base_params: Dict[str, Any] = {}
    try:
        candidates = list(getattr(cfg, "models", []) or [])
        wanted = name.lower()

        def _kind_of(entry) -> str:
            return str(getattr(entry, "kind", "")).lower()

        def _name_of(entry) -> str:
            return str(getattr(entry, "name", "")).lower()

        chosen = next(
            (m for m in candidates if _kind_of(m) == provider and _name_of(m) == wanted),
            None,
        )
        if chosen is None:
            chosen = next((m for m in candidates if _kind_of(m) == provider), None)
        if chosen is None:
            chosen = next((m for m in candidates if _name_of(m) == wanted), None)
            if chosen is not None:
                provider = str(getattr(chosen, "kind", provider) or provider).lower()
        if chosen is not None:
            base_params = dict(getattr(chosen, "params", {}) or {})
    except Exception:
        # Deliberately best-effort: a malformed models section must not block
        # an explicit override, so fall back to no inherited params.
        base_params = {}

    base_params["model"] = name
    cfg.models = [ModelConfig(kind=provider, name=name, params=base_params)]

    if debug_on():
        logger.info(f"[cli.model_override] provider={provider} name={name} param_keys={sorted(base_params)}")
    return cfg

def _model_slug(cfg) -> str:
    """Build a lower-case ``kind.name`` slug from the first model in *cfg*.

    The name falls back to ``params["model"]`` when the model has no name.
    Every run of characters outside ``a-z 0-9 . _ -`` becomes a single ``-``
    and leading/trailing dashes are stripped. ``"model"`` is returned when
    nothing usable is found or when reading the attributes fails.
    """
    import re as _re

    def _no_lookup(*_):
        # Used when ``params`` has no ``get`` method.
        return ""

    try:
        first = cfg.models[0] if getattr(cfg, "models", None) else None
        kind = (getattr(first, "kind", "") or "").lower()
        label = getattr(first, "name", "")
        if not label:
            lookup = getattr(getattr(first, "params", {}), "get", _no_lookup)
            label = lookup("model")
        name = (label or "").lower()
    except Exception:
        # Any failure discards both parts, including a kind already computed.
        kind, name = "", ""

    if name:
        raw = f"{kind}.{name}"
    else:
        raw = kind or "model"

    slug = _re.sub(r'[^a-z0-9._-]+', '-', raw).strip('-')
    return slug if slug else "model"

def _run_shard_inproc(task: str, cfg_path: Path, proj_root: Path, shard_ids: List[str]):
    """Run one shard of *task* inside the current process.

    The config is loaded afresh from *cfg_path* and handed to the runner for
    *task*. While the runner executes, ``CIRBENCH_SELECT`` holds the
    comma-joined *shard_ids*; its previous value (or absence) is restored
    afterwards, including when the runner raises.

    Args:
        task: One of ``analysis``, ``repair``, ``refactor``, ``transform``.
        cfg_path: Path passed to ``load_cfg``.
        proj_root: Project root passed to the task runner.
        shard_ids: Case ids this shard is responsible for.

    Raises:
        SystemExit: if *task* is not one of the known tasks.
    """
    # Reject an unknown task before the environment is modified or the
    # config is loaded, so a typo leaves no half-applied state behind.
    if task not in ("analysis", "repair", "refactor", "transform"):
        raise SystemExit("unknown task")

    # Keep shard coordinates already present in the environment; default to
    # a single shard otherwise.
    os.environ["CIRBENCH_SHARDS"] = os.getenv("CIRBENCH_SHARDS") or "1"
    os.environ["CIRBENCH_SHARD_ID"] = os.getenv("CIRBENCH_SHARD_ID") or "1"
    os.environ["CIRBENCH_TASK"] = task

    prev_sel = os.getenv("CIRBENCH_SELECT")
    os.environ["CIRBENCH_SELECT"] = ",".join(shard_ids)
    try:
        c = load_cfg(cfg_path)
        if task == "analysis":
            run_analysis(c, proj_root)
        elif task == "repair":
            run_repair(c, proj_root)
        elif task == "refactor":
            run_refactor(c, proj_root)
        else:
            run_transform(c, proj_root)
    finally:
        # Restore the caller's selection even if loading or running failed.
        if prev_sel is not None:
            os.environ["CIRBENCH_SELECT"] = prev_sel
        else:
            os.environ.pop("CIRBENCH_SELECT", None)

# ---------------- commands ----------------

@app.command()
def doctor(cfg: Path = typer.Option(..., exists=True)):
    # Load the config, initialise logging in its log directory and print the
    # config summary. Documented with comments rather than a docstring because
    # Typer would show a docstring as the command's --help text.
    conf = load_cfg(cfg)
    log_dir = Path(conf.logging.dir)
    setup_logging(log_dir)
    print(summary(conf))

@app.command("list")
def cmd_list(
    task: str = typer.Option("analysis"),
    cfg: Path = typer.Option(..., exists=True),
    root: Path = typer.Option(Path(".")),
    select: str | None = typer.Option(None, help="Filter cases by id list / glob / re:regex"),
):
    # Print "<id> - <case directory name>" for every case of a task, followed
    # by a total. Documented with comments rather than a docstring because
    # Typer would show a docstring as the command's --help text.
    from .registry import discover_cases
    from .utils.case_select import select_cases

    conf = load_cfg(cfg)
    setup_logging(Path(conf.logging.dir))

    found = discover_cases(root / f"cirbench/{task}")
    if select:
        # The expression is published through the environment before
        # filtering; it is left set afterwards.
        os.environ["CIRBENCH_SELECT"] = select
        found = select_cases(found)

    for case in found:
        print(case.id, "-", case.root.name)
    print(f"Total: {len(found)}")

def _checked_mode(option: str, value: str | None, allowed: tuple) -> str | None:
    """Return *value* stripped and lower-cased, or None when it is empty/unset.

    Raises:
        SystemExit: if the normalized value is not one of *allowed*.
    """
    if not value:
        return None
    mode = value.strip().lower()
    if mode not in allowed:
        raise SystemExit(f"{option} supports only {' | '.join(allowed)}")
    return mode


@app.command()
def run(
    task: str = typer.Option("analysis"),
    cfg: Path = typer.Option(..., exists=True),
    root: Path = typer.Option(Path(".")),
    debug: bool = typer.Option(False, help="Print per-prompt debug lines"),
    debug_full: bool = typer.Option(False, help="Print full model outputs (no ellipsis)"),
    select: str | None = typer.Option(None, help="Run only selected cases: id1,id2 | glob | re:regex"),
    repair_mode: str | None = typer.Option(None, help="Repair mode: normal | hard"),
    refactor_mode: str | None = typer.Option(None, help="Refactor mode: normal | reverse"),
    transform_mode: str | None = typer.Option(None, help="Transform mode: normal | copilot"),
    model: str | None = typer.Option(None, help="Override model as PROVIDER:NAME"),
    use_api: bool = typer.Option(False, "--use-api", help="Use API backend for the selected model (sets CIRBENCH_USE_API=1)"),
    api_pricing: Path | None = typer.Option(None, "--api-pricing", help="YAML pricing table path"),
    concurrency: int = typer.Option(1, "--concurrency", min=1),
    report: bool = typer.Option(True, "--report/--no-report", help="Set CIRBENCH_REPORT_RUN=1 for the run (disable with --no-report)"),
    resume: bool = typer.Option(False, "--resume"),
    retry_model: int = typer.Option(0, "--retry-model", min=0, max=10),
    prompt_only: bool = typer.Option(False, "--prompt-only", help="Set CIRBENCH_PROMPT_ONLY=1; cannot be combined with --from-files"),
    from_files: Path | None = typer.Option(None, "--from-files", help="Path exported as CIRBENCH_FROM_FILES; cannot be combined with --prompt-only"),
    from_kind: str = typer.Option("resp", "--from-kind", help="resp | pred (exported as CIRBENCH_FROM_KIND)"),
):
    """Run the cases of one task, in-process or across worker subprocesses.

    Options are exported as CIRBENCH_* environment variables. With more than
    one shard, each shard is run by a child `run` invocation.
    """
    c = load_cfg(cfg)
    if prompt_only and from_files:
        raise SystemExit("--prompt-only & --from-files cannot be used together")
    fk = (from_kind or "resp").strip().lower()
    if fk not in ("resp", "pred"):
        raise SystemExit("--from-kind only support resp | pred")
    logger = setup_logging(
        log_dir=Path(c.logging.dir),
        level=logging.DEBUG if debug else getattr(logging, c.logging.level),
        log_json=c.logging.json_mode,
        console_rich=c.logging.console_rich,
        log_name="cirbench",
    )

    # ---- export the options through the environment ----
    if debug:
        os.environ["CIRBENCH_DEBUG"] = "1"
    if debug_full:
        os.environ["CIRBENCH_DEBUG_FULL"] = "1"
    os.environ["CIRBENCH_USE_API"] = "1" if use_api else "0"
    os.environ["CIRBENCH_TASK"] = task
    if report:
        os.environ["CIRBENCH_REPORT_RUN"] = "1"
    else:
        # --no-report must also win over a value inherited from the caller.
        os.environ.pop("CIRBENCH_REPORT_RUN", None)
    if resume:
        os.environ["CIRBENCH_RESUME"] = "1"
    if api_pricing:
        os.environ["CIRBENCH_API_PRICING"] = str(api_pricing)
    os.environ["CIRBENCH_MODEL_RETRIES"] = str(int(retry_model))
    if prompt_only:
        os.environ["CIRBENCH_PROMPT_ONLY"] = "1"
    if from_files:
        os.environ["CIRBENCH_FROM_FILES"] = str(from_files.resolve())
    os.environ["CIRBENCH_FROM_KIND"] = fk
    try:
        # Best-effort mirror of the flags onto the config object.
        # NOTE(review): presumably guarded because the config type may reject
        # attribute assignment — confirm.
        c.prompt_only = bool(prompt_only) or bool(getattr(c, "prompt_only", False))
        c.from_files = from_files or getattr(c, "from_files", None)
        c.from_kind = fk or getattr(c, "from_kind", "resp")
    except Exception:
        pass

    for option, value, env_key, allowed in (
        ("--repair-mode", repair_mode, "CIRBENCH_REPAIR_MODE", ("normal", "hard")),
        ("--refactor-mode", refactor_mode, "CIRBENCH_REFACTOR_MODE", ("normal", "reverse")),
        ("--transform-mode", transform_mode, "CIRBENCH_TRANSFORM_MODE", ("normal", "copilot")),
    ):
        mode = _checked_mode(option, value, allowed)
        if mode is not None:
            os.environ[env_key] = mode

    # ---- model override and run identity ----
    c = _model_override(c, model, logger)
    try:
        m0 = c.models[0] if getattr(c, "models", None) else None
        if m0:
            kind = (getattr(m0, "kind", None) or (m0.get("kind") if isinstance(m0, dict) else "") or "").strip()
            name = (getattr(m0, "name", None) or (m0.get("name") if isinstance(m0, dict) else "") or "").strip()
            if kind and name:
                os.environ["CIRBENCH_MODEL_KIND"] = kind
                os.environ["CIRBENCH_MODEL_NAME"] = name
    except Exception:
        # Best-effort: these two variables are simply left unset on failure.
        pass

    # The slug is always taken from the (possibly overridden) config so that
    # it is sanitized: a raw --model value may contain "/" or spaces, which
    # must not leak into the runs/<run_id> directory name.
    model_slug = _model_slug(c)
    run_id = os.getenv("CIRBENCH_RUN_ID")
    if not run_id:
        ts = time.strftime("%Y-%m-%dT%H-%M-%SZ", time.gmtime())
        run_id = f"{ts}__{model_slug}"
    os.environ["CIRBENCH_RUN_ID"] = run_id
    os.environ["CIRBENCH_RUN_MODEL"] = model_slug

    # ---- case discovery and sharding ----
    task_dir = root / ("cirbench/" + task)
    ids = _apply_select(_discover_cases_ids(task_dir), select)
    if not ids:
        print("[run] No cases.")
        return

    # Drop empty shards so the count below is the real number of workers and
    # a lone shard runs in-process instead of in a single subprocess.
    shards = [shard for shard in _split_into_shards(ids, concurrency) if shard]
    run_dir = root / "runs" / run_id
    run_dir.mkdir(parents=True, exist_ok=True)

    if debug_on():
        # cfg may define no model at all; do not index blindly.
        models = getattr(c, "models", None)
        model_label = f"{models[0].kind}:{models[0].name}" if models else "<none>"
        logger.info(f"[cli.run] task={task} ids={ids} shards={len(shards)} run_id={run_id} model={model_label}")

    print(f"[run] Task={task}  Cases={len(ids)}  Concurrency(requested)={concurrency}  Shards(actual)={len(shards)}  RunID={run_id}")

    if len(shards) == 1:
        # NOTE(review): workers are started with --concurrency 1 and therefore
        # take this branch too, which resets the CIRBENCH_SHARDS/SHARD_ID
        # values the parent placed in their environment — confirm intended.
        os.environ["CIRBENCH_SHARDS"] = "1"
        os.environ["CIRBENCH_SHARD_ID"] = "1"
        # NOTE(review): the config *path* is passed, not `c`, and the helper
        # loads it again; in-memory changes made to `c` above only matter if
        # the runners re-derive them from the CIRBENCH_* variables — confirm.
        _run_shard_inproc(task, cfg, root, shards[0])
        print("[run] Finished.")
        return

    # ---- multi-process path ----
    # Arguments shared by every worker; only --select differs per shard.
    worker_args = [
        "--task", task, "--cfg", str(cfg), "--root", str(root),
        "--concurrency", "1",
        "--report" if report else "--no-report",
        "--from-kind", fk,
        "--retry-model", str(retry_model),
    ]
    if debug:
        worker_args.append("--debug")
    if debug_full:
        worker_args.append("--debug-full")
    if repair_mode:
        worker_args += ["--repair-mode", repair_mode]
    if refactor_mode:
        worker_args += ["--refactor-mode", refactor_mode]
    if transform_mode:
        worker_args += ["--transform-mode", transform_mode]
    if model:
        worker_args += ["--model", model]
    if use_api:
        worker_args.append("--use-api")
    if api_pricing:
        worker_args += ["--api-pricing", str(api_pricing)]
    if resume:
        worker_args.append("--resume")
    if prompt_only:
        worker_args.append("--prompt-only")
    if from_files:
        worker_args += ["--from-files", str(from_files)]

    procs = []
    rc = 0
    try:
        for idx, shard in enumerate(shards, 1):
            child_env = os.environ.copy()
            child_env["CIRBENCH_SHARDS"] = str(len(shards))
            child_env["CIRBENCH_SHARD_ID"] = str(idx)
            child_env["CIRBENCH_TASK"] = task
            argv = [sys.executable, "-m", "cirbench.cli", "run", *worker_args,
                    "--select", ",".join(shard)]
            print(f"[run] spawn worker#{idx}: {len(shard)} cases")
            procs.append(subprocess.Popen(argv, env=child_env))

        for i, p in enumerate(procs, 1):
            r = p.wait()
            if r != 0:
                print(f"[run] worker#{i} exited with code {r}")
                rc = rc or r  # report the first failing exit code
    finally:
        # Workers are still alive here only if spawning or waiting was cut
        # short (e.g. Ctrl-C or a failed Popen); do not leave them running.
        for p in procs:
            if p.poll() is None:
                p.terminate()

    if rc != 0:
        raise SystemExit(rc)
    print("[run] All workers finished.")

@app.command()
def report(
    cfg: Path = typer.Option(..., exists=True),
    root: Path = typer.Option(Path(".")),
    run_id: str | None = typer.Option(None),
    out_html: Path | None = typer.Option(None),
):
    """Write the per-task CSVs of a run and, if templates exist, an HTML report."""
    c = load_cfg(cfg)
    # Run id precedence: --run-id, then CIRBENCH_RUN_ID, then a UTC timestamp.
    run_id = run_id or os.getenv("CIRBENCH_RUN_ID") or time.strftime("%Y-%m-%dT%H-%M-%SZ", time.gmtime())
    run_dir = root / "runs" / run_id
    run_dir.mkdir(parents=True, exist_ok=True)
    model = (c.models[0].model_dump() if c.models else {"name": "rule"})
    write_task_csvs(run_dir)

    tpl_dir = root / "env" / "templates"
    if tpl_dir.exists():
        # --out-html was accepted but ignored before; it now selects the
        # output file, defaulting to <run_dir>/report.html.
        html_path = out_html or run_dir / "report.html"
        html_path.parent.mkdir(parents=True, exist_ok=True)
        # NOTE(review): the previous code passed this command function itself
        # as render_html's first argument and never used `model`, which cannot
        # be the intended payload. The dict below is a minimal stand-in built
        # from what is known here — confirm the schema render_html expects.
        report_data = {"run_id": run_id, "model": model}
        render_html(report_data, tpl_dir, html_path)
    print(f"Report written to: {run_dir}")

@app.command()
def aggregate(
    cfg: Path = typer.Option(..., exists=True),
    root: Path = typer.Option(Path(".")),
    run_id: str | None = typer.Option(None, help="use current time for CIRBENCH_RUN_ID"),
):
    # Write the aggregate tables for one run directory. Documented with
    # comments rather than a docstring because Typer would show a docstring
    # as the command's --help text.

    # The config is loaded but its contents are not used in this command.
    load_cfg(cfg)

    # Run id precedence: --run-id, then CIRBENCH_RUN_ID, then a UTC timestamp.
    rid = run_id
    if not rid:
        rid = os.getenv("CIRBENCH_RUN_ID")
    if not rid:
        rid = time.strftime("%Y-%m-%dT%H-%M-%SZ", time.gmtime())

    target = root / "runs" / rid
    target.mkdir(parents=True, exist_ok=True)

    write_all_aggregates(target)
    print(f"[aggregate] Aggregated CSVs written under: {target / 'tables'}")
    
def main():
    """Entry point: invoke the Typer application."""
    app()


# Allow running this module directly as a script.
if __name__ == "__main__":
    main()