import argparse
import asyncio
from dataclasses import asdict
from pathlib import Path

import torch
from monarch.actor import this_host

from actors.behaviour_score_actor import BehaviourActor, BehaviourConfig
from actors.steering_plot_actor import model_slug


def discover_jobs(steer_dir: Path, models: list[str]) -> list[tuple[str, str, str]]:
    """Enumerate (model, concept) pairs that have steering vectors on disk.

    For every entry of ``models`` (in the order given) the directory
    ``steer_dir / model_slug(model)`` is inspected; models without such a
    directory are skipped. Its sub-directories are visited in sorted order
    and kept only if they contain at least one ``layer_*.pt`` file.

    Returns:
        A list of ``(model_name, concept_slug, concept_label)`` tuples, where
        the slug is the sub-directory name and the label is the slug with
        underscores replaced by spaces.
    """
    found: list[tuple[str, str, str]] = []
    for model_name in models:
        model_root = steer_dir / model_slug(model_name)
        if not model_root.exists():
            continue

        concept_dirs = [entry for entry in model_root.iterdir() if entry.is_dir()]
        concept_dirs.sort()

        # A concept directory without any per-layer vector file is not a job.
        found.extend(
            (model_name, entry.name, entry.name.replace("_", " "))
            for entry in concept_dirs
            if any(entry.glob("layer_*.pt"))
        )
    return found


def _resolve_layers(layers):
    """Normalise the ``--layers`` value into the list handed to the actor.

    The sentinel ``[None]`` (the CLI default) is passed through unchanged;
    any other value is converted element-wise with ``int``.

    Raises:
        ValueError: if an entry cannot be converted to an integer.
    """
    if layers == [None]:
        return [None]
    try:
        return [int(i) for i in layers]
    except (TypeError, ValueError) as err:
        raise ValueError(
            f"--layers must be integer layer indices, got {layers!r}"
        ) from err


def _build_config(args) -> "BehaviourConfig":
    """Build the ``BehaviourConfig`` from the parsed CLI namespace."""
    # NOTE(review): parse_args also defines --max_new_tokens,
    # --judge_system_prompt, --judge_question_template,
    # --judge_max_prompt_length, --judge_batch_size and
    # --judge_max_completion_chars, but none of them is forwarded here, so
    # they currently have no effect. Confirm whether BehaviourConfig has
    # matching fields and wire them through if so.
    return BehaviourConfig(
        seed=args.seed,
        generator_dtype=args.generator_dtype,
        judge_dtype=args.judge_dtype,
        judge_model_name=args.judge_model,
        alpha_start=args.alpha_start,
        alpha_end=args.alpha_end,
        alpha_steps=args.alpha_steps,
        normalize=args.normalize,
        apply_last_token_only=args.apply_last_token_only,
        n_samples_per_context=args.n_samples_per_context,
        gen_context_batch_size=args.gen_context_batch_size,
        max_prompt_length=args.max_prompt_length,
        temperature=args.temperature,
        top_k=args.top_k,
        top_p=args.top_p,
        generator_prompt_suffix=args.generator_prompt_suffix,
        # The CLI flag is negative ("no chat template"); the config is positive.
        judge_use_chat_template=not args.judge_no_chat_template,
        progress_every=args.progress_every,
    )


def _report_result(rank: int, res) -> None:
    """Print a one-line summary of the result returned by a finished job."""
    if isinstance(res, dict) and res.get("ok"):
        # Only the first non-empty "file" entry is shown.
        msg = next(
            (f for rinfo in (res.get("results") or []) if (f := rinfo.get("file"))),
            "(no files)",
        )
        print(f"[gpu {rank}] finished -> {msg}", flush=True)
    else:
        print(f"[gpu {rank}] unexpected result: {res}", flush=True)


async def main_async(args):
    """Run every discovered (model, concept) job across the visible GPUs.

    Jobs are discovered with ``discover_jobs`` and dispatched to
    ``BehaviourActor`` workers, one process per used GPU. Each rank runs one
    job at a time; when a job finishes, the next pending job is started on
    the same rank.

    Args:
        args: Parsed CLI namespace (see ``parse_args``).

    Raises:
        ValueError: if ``args.layers`` contains a non-integer entry.
        RuntimeError: if ``--steer_dir`` or ``--contexts_file`` is missing, no
            jobs are discovered, or no CUDA device is visible.
        Exception: the first exception raised by a job is logged and
            re-raised; remaining in-flight jobs are cancelled.
    """
    # Validate purely local input first, so a malformed --layers fails before
    # any process or actor is spawned (previously it only failed inside the
    # first dispatched job).
    block_idx_to_steer = _resolve_layers(args.layers)

    cfg = _build_config(args)

    steer_dir = Path(args.steer_dir)
    out_dir = Path(args.out_dir)
    contexts_file = Path(args.contexts_file)

    if not steer_dir.exists():
        raise RuntimeError(f"--steer_dir '{steer_dir}' does not exist")
    if not contexts_file.exists():
        raise RuntimeError(f"--contexts_file '{contexts_file}' does not exist")

    models = list(args.models)
    jobs = discover_jobs(steer_dir, models)
    if not jobs:
        raise RuntimeError(
            f"No (model, concept) pairs discovered under {steer_dir} for given models."
        )

    visible = torch.cuda.device_count()
    if visible < 1:
        raise RuntimeError("No CUDA devices visible.")

    # Never spawn more workers than there are jobs; honour --max_gpus if set.
    use_gpus = min(visible, len(jobs))
    if args.max_gpus and args.max_gpus > 0:
        use_gpus = min(use_gpus, args.max_gpus)

    mesh = this_host().spawn_procs(per_host={args.dim: use_gpus})
    print(mesh.to_table(), flush=True)

    workers = mesh.spawn("behaviour", BehaviourActor)

    def actor_for(rank: int):
        return workers.slice(**{args.dim: rank})

    async def run_one(rank: int, model_name: str, concept_slug: str, concept_label: str):
        return await actor_for(rank).compute_behaviour_curves.call_one(
            model_name=model_name,
            concept_slug=concept_slug,
            concept_label=concept_label,
            # Fresh copies per call, as in the original per-call construction.
            block_idx_to_steer=list(block_idx_to_steer),
            contexts_file=str(contexts_file),
            steer_dir=str(steer_dir),
            save_dir=str(out_dir),
            layer_path=args.layer_path,
            cfg_dict=asdict(cfg),
            rank_hint=rank,
        )

    pending_jobs = iter(jobs)
    in_flight: dict[asyncio.Task, int] = {}

    def launch_next(rank: int) -> None:
        """Start the next pending job on ``rank``, if any job is left."""
        job = next(pending_jobs, None)
        if job is None:
            return
        m, slug, label = job
        print(f"→ [gpu {rank}] start model='{m}' concept='{label}' (slug={slug})", flush=True)
        in_flight[asyncio.create_task(run_one(rank, m, slug, label))] = rank

    try:
        # use_gpus is already capped at len(jobs), so every rank gets a job.
        for r in range(use_gpus):
            launch_next(r)

        while in_flight:
            done, _ = await asyncio.wait(in_flight.keys(), return_when=asyncio.FIRST_COMPLETED)
            for t in done:
                rank = in_flight.pop(t)
                try:
                    _report_result(rank, await t)
                except Exception as e:
                    print(f"[gpu {rank}] EXCEPTION: {e}", flush=True)
                    raise

                # The rank is free again: hand it the next job.
                launch_next(rank)
    finally:
        # On failure (or cancellation) do not leave sibling jobs running
        # unattended: cancel them and drain their outcomes so no exception is
        # left unretrieved. On normal completion in_flight is empty.
        if in_flight:
            for t in in_flight:
                t.cancel()
            await asyncio.gather(*in_flight, return_exceptions=True)


def parse_args(argv=None):
    """Parse the command-line options of the behaviour-curve driver.

    Args:
        argv: Optional list of argument strings. When ``None`` (the default)
            argparse reads ``sys.argv[1:]``, which is the original behaviour.

    Returns:
        The populated ``argparse.Namespace``.
    """
    p = argparse.ArgumentParser()

    # Models.
    p.add_argument("--models", nargs="+", required=True, help="HF model ids/paths for which steer vectors exist.")
    p.add_argument("--judge_model", required=True, help="HF model id/path for the binary judge model.")

    # Input / output locations.
    p.add_argument("--steer_dir", default="steering_vectors", help="Root: model_slug/concept_slug/layer_*.pt")
    p.add_argument("--contexts_file", default="contexts.jsonl", help="JSONL contexts (negatives + per-concept positives)")
    p.add_argument("--out_dir", default="behaviour_data", help="Where to write .npz behaviour curve files")

    # Which layers to steer. type=int makes argparse reject malformed indices
    # immediately; the non-string default [None] is left untouched by argparse.
    p.add_argument(
        "--layers",
        nargs="+",
        type=int,
        default=[None],
        help="Layer indices (e.g., 5 10 15) or all if kept to None.",
    )
    p.add_argument(
        "--layer_path",
        default=None,
        help="Override path to transformer block list (e.g., 'model.layers').",
    )

    p.add_argument("--generator_dtype", default="float32", choices=["float32", "float16", "bfloat16"])
    p.add_argument("--judge_dtype", default="bfloat16", choices=["float32", "float16", "bfloat16"])

    # Steering-strength sweep.
    p.add_argument("--alpha_start", type=float, default=-20.0)
    p.add_argument("--alpha_end", type=float, default=20.0)
    p.add_argument("--alpha_steps", type=int, default=41)

    p.add_argument("--normalize", action="store_true", default=False)
    p.add_argument("--apply_last_token_only", action="store_true", default=False)

    # Generation settings.
    p.add_argument("--n_samples_per_context", type=int, default=8)
    p.add_argument("--gen_context_batch_size", type=int, default=8)
    p.add_argument("--max_prompt_length", type=int, default=512)
    p.add_argument("--max_new_tokens", type=int, default=200)
    p.add_argument("--temperature", type=float, default=0.9)
    p.add_argument("--top_p", type=float, default=0.95)
    p.add_argument("--top_k", type=int, default=50)
    p.add_argument(
        "--generator_prompt_suffix",
        default="",
        help="Optional suffix appended to every context prompt before generation.",
    )

    # Judge settings.
    p.add_argument("--judge_no_chat_template", action="store_true", default=False)
    p.add_argument("--judge_system_prompt", default="You are a strict binary classifier. Output exactly one character: 1 or 0. No spaces, no newline, no explanation.")
    # NOTE(review): the default below contains a literal backslash followed by
    # "n", not a newline character. Confirm whether the consumer unescapes it
    # or whether real newlines were intended.
    p.add_argument(
        "--judge_question_template",
        default=(
            "Text:\\n{completion}\\n\\n"
            "Does the text above express the concept {concept} (any word linked to the concept is sufficient for a positive reply)? Reply with 1 for yes and 0 for no.\\n"
            "Answer:"
        ),
        help="Template with placeholders {concept} and {completion}.",
    )
    p.add_argument("--judge_max_prompt_length", type=int, default=4000)
    p.add_argument("--judge_batch_size", type=int, default=64)
    p.add_argument("--judge_max_completion_chars", type=int, default=4000)

    # Process-mesh layout.
    p.add_argument("--dim", default="gpu", help="Mesh dimension name (use 'gpu' if your env uses that).")
    p.add_argument("--max_gpus", type=int, default=0, help="Limit number of GPUs (0=auto)")

    p.add_argument("--seed", type=int, default=0)
    p.add_argument("--progress_every", type=int, default=1)

    return p.parse_args(argv)


if __name__ == "__main__":
    # Script entry point: parse the CLI, then drive the async scheduler to completion.
    asyncio.run(main_async(parse_args()))
