import argparse
import asyncio
from dataclasses import asdict
from typing import List, Dict


import torch
from monarch.actor import this_host

from actors.prompts_actor import LLMActor, GenConfig


async def main_async(args):
    """Generate per-concept outputs by fanning concepts out over GPU actors.

    A process mesh with ``use_gpus`` entries along the ``args.dim`` dimension is
    spawned on this host, one ``LLMActor`` is spawned on it with
    ``args.model_generating_concept``, and concepts are dispatched to the
    actors so that each rank handles one concept at a time.

    Two flows exist, selected by the ``GenConfig`` flag read below:

    * flag set: a single pass in mode ``"both"``; ``args.models`` is ignored.
    * flag unset: one pass in mode ``"related"``, then for every entry of
      ``args.models`` each actor's ``reload_model`` endpoint is called and a
      pass in mode ``"unrelated"`` is run.

    Args:
        args: Namespace providing ``seed``, ``model_generating_concept``,
            ``models``, ``concepts``, ``out_dir``, ``dim`` and ``max_gpus``
            (see ``parse_args``).

    Raises:
        RuntimeError: If no CUDA device is visible.
        ValueError: If no concept is given, or if ``args.models`` is empty
            while the contrastive flag is not set.
        Exception: Any error raised by an actor call is printed and re-raised.
    """
    cfg = GenConfig(
        seed=args.seed,
        name_of_model_instruct=args.model_generating_concept,
    )
    # NOTE(review): the attribute is spelled "constrastive" while the messages
    # below talk about "--contrastive" (a flag parse_args does not define).
    # Confirm this matches the actual GenConfig field name.
    contrastive = cfg.constrastive

    concepts: List[str] = list(args.concepts)
    visible = torch.cuda.device_count()
    if visible < 1:
        raise RuntimeError("No CUDA devices visible.")
    if not concepts:
        # Without this guard use_gpus would be 0 and a zero-sized mesh requested.
        raise ValueError("--concepts must contain at least one concept.")
    # Validate before spawning processes/actors so a bad invocation fails fast.
    if not contrastive and not args.models:
        raise ValueError("--models must be provided when --contrastive is NOT set.")

    # Never use more ranks than there are concepts; optionally cap by --max_gpus.
    use_gpus = min(visible, len(concepts))
    if args.max_gpus and args.max_gpus > 0:
        use_gpus = min(use_gpus, args.max_gpus)

    mesh = this_host().spawn_procs(per_host={args.dim: use_gpus})
    print(mesh.to_table(), flush=True)

    llm = mesh.spawn("llm", LLMActor, args.model_generating_concept)

    def actor_for(rank: int):
        """Return the slice of the actor mesh at index ``rank`` along ``args.dim``."""
        return llm.slice(**{args.dim: rank})

    async def run_phase(mode: str, negative_model_tag=None):
        """Run one pass over all concepts for ``mode`` and return the results.

        Each rank processes one concept at a time; as soon as a rank finishes,
        the next unprocessed concept is started on that same rank. The model
        used is whatever the actors currently have loaded.

        Args:
            mode: Forwarded to the actor ("related", "unrelated" or "both").
            negative_model_tag: Forwarded unchanged to the actor.

        Returns:
            List of per-concept result mappings in completion order.

        Raises:
            Exception: The first actor failure, after it has been printed and
                the remaining in-flight tasks have been cancelled.
        """
        next_idx = 0
        in_flight: Dict[asyncio.Task, int] = {}
        results_all = []

        async def run_one(rank: int, concept: str):
            return await actor_for(rank).generate_for_concept.call_one(
                concept,
                asdict(cfg),
                args.out_dir,
                rank,
                mode,
                negative_model_tag,
            )

        def start_next(rank: int) -> None:
            """Start the next unprocessed concept on ``rank``, if any is left."""
            nonlocal next_idx
            if next_idx >= len(concepts):
                return
            concept = concepts[next_idx]
            next_idx += 1
            print(f"→ [mode={mode}] [gpu {rank}] started: '{concept}'", flush=True)
            in_flight[asyncio.create_task(run_one(rank, concept))] = rank

        # Prime every rank with one concept (use_gpus <= len(concepts)).
        for rank in range(use_gpus):
            start_next(rank)

        try:
            while in_flight:
                done, _pending = await asyncio.wait(
                    in_flight.keys(), return_when=asyncio.FIRST_COMPLETED
                )
                for task in done:
                    rank = in_flight.pop(task)
                    try:
                        res = task.result()  # task is finished; re-raises on failure
                        results_all.append(res)
                        print(
                            f"[mode={mode}] [gpu {res['rank']}] concept='{res['concept']}' "
                            f"related={res['related']} unrelated={res['unrelated']} -> {res['files']}",
                            flush=True,
                        )
                    except Exception as e:
                        print(f"[mode={mode}] [gpu {rank}] ERROR: {e}", flush=True)
                        raise

                    # The rank is free again: hand it the next concept.
                    start_next(rank)
        finally:
            # Only non-empty when leaving early (failure or cancellation):
            # cancel what is still running and drain it so no task is left
            # un-awaited behind the propagating exception.
            if in_flight:
                for task in in_flight:
                    task.cancel()
                await asyncio.gather(*in_flight, return_exceptions=True)
                in_flight.clear()

        return results_all

    if contrastive:
        if args.models:
            print(
                "NOTE: --models is ignored when --contrastive is enabled "
                "(everything is generated by --model_generating_concept).",
                flush=True,
            )

        print(
            f"=== CONTRASTIVE MODE: generating RELATED + NEGATIVE with single model: {args.model_generating_concept} ===",
            flush=True,
        )
        await run_phase(mode="both", negative_model_tag=None)
        return

    print(
        f"=== Generating RELATED examples with concept model: {args.model_generating_concept} ===",
        flush=True,
    )
    await run_phase(mode="related", negative_model_tag=None)

    for neg_model in args.models:
        print(
            f"=== Generating UNRELATED examples with negative model: {neg_model} ===",
            flush=True,
        )

        # Ask every rank to reload with the negative model before the pass.
        reload_futs = [
            actor_for(rank).reload_model.call_one(neg_model)
            for rank in range(use_gpus)
        ]
        await asyncio.gather(*reload_futs)

        await run_phase(mode="unrelated", negative_model_tag=neg_model)


def parse_args():
    """Parse the command-line options of this script from ``sys.argv``.

    Returns:
        argparse.Namespace with the attributes ``model_generating_concept``,
        ``models``, ``concepts``, ``out_dir``, ``seed``, ``dim`` and
        ``max_gpus``.
    """
    concept_model_help = (
        "HF model id or local path used to generate on-concept (related) JSONLs "
        "(e.g., google/gemma-3-1b-it). In --contrastive mode, this model also generates negatives."
    )
    negative_models_help = (
        "HF model ids or local paths used to generate off-concept (unrelated) JSONLs; "
        "space-separated list (e.g., Qwen/Qwen3-8B meta-llama/Llama-3-8B). "
        "Ignored when --contrastive is enabled."
    )

    # (flag, add_argument keyword options), in the order they are registered.
    option_specs = [
        ("--model_generating_concept", {"required": True, "help": concept_model_help}),
        ("--models", {"nargs": "*", "default": [], "help": negative_models_help}),
        ("--concepts", {"nargs": "+", "required": True, "help": "List of concepts"}),
        ("--out_dir", {"default": "prompts", "help": "Directory for JSONL outputs"}),
        ("--seed", {"type": int, "default": 0}),
        (
            "--dim",
            {
                "default": "gpu",
                "help": "Mesh dimension name (use 'gpu' if your env uses that)",
            },
        ),
        (
            "--max_gpus",
            {
                "type": int,
                "default": 0,
                "help": "Limit number of GPUs to use (0=auto)",
            },
        ),
    ]

    parser = argparse.ArgumentParser()
    for flag, options in option_specs:
        parser.add_argument(flag, **options)
    return parser.parse_args()


# Script entry point: parse the CLI options and drive the async pipeline to
# completion on a fresh event loop.
if __name__ == "__main__":
    args = parse_args()
    asyncio.run(main_async(args))
