import os,json
from typing import Any, Dict, List, Optional, Sequence, Tuple
import random
import torch
import argparse
from data_utils import get_task
from ifgain import logp_y_given_userprompt, resolve_model_id


from transformers import AutoTokenizer, AutoModelForCausalLM
from vllm import LLM, SamplingParams


# Process-wide cache of (tokenizer, llm) pairs, keyed by a backend string.
_CACHE: Dict[str, Tuple[object, object]] = {}
# Cache directory for downloaded tokenizer files; override with $MODEL_CACHE_DIR.
_CACHE_DIR = os.getenv("MODEL_CACHE_DIR", "./model_cache")


# Short model names (as listed in the --model help) -> full model ids.
MODEL_ID_MAP = {
    "qwen3-4b": "Qwen/Qwen3-4B-Instruct-2507",
    "llama3.1-8b": "meta-llama/Llama-3.1-8B-Instruct",
}

# Full model id -> path handed to the model loader. Each entry can be pointed
# at a local checkpoint through its environment variable; when the variable is
# unset, the model id itself is used.
LOCAL_LLM_PATH = {
    hub_id: os.getenv(env_var, hub_id)
    for hub_id, env_var in (
        ("Qwen/Qwen3-4B-Instruct-2507", "QWEN3_4B_PATH"),
        ("meta-llama/Llama-3.1-8B-Instruct", "LLAMA3_1_8B_PATH"),
    )
}

def sample_seed_set(candidate_indices: List[int], K: int, rng: random.Random) -> List[int]:
    """Draw an initial demo set of ``K`` indices from ``candidate_indices``.

    When the pool holds at least ``K`` items the draw is without replacement.
    Otherwise it falls back to ``K`` independent draws with replacement, so the
    result can then contain repeated indices.
    """
    pool_too_small = len(candidate_indices) < K
    if not pool_too_small:
        return rng.sample(candidate_indices, K)
    # Not enough distinct candidates: allow repeats so the set still has size K.
    picks: List[int] = []
    for _ in range(K):
        picks.append(rng.choice(candidate_indices))
    return picks

def make_swap(S0: List[int], candidate_indices: List[int], rng: random.Random) -> Optional[List[int]]:
    """Return a copy of ``S0`` with one position replaced by an index not in ``S0``.

    The position to replace is drawn uniformly from ``range(len(S0))``. The
    replacement is drawn uniformly from the members of ``candidate_indices``
    that do not occur in ``S0``. ``S0`` itself is not modified.

    Returns ``None`` when ``S0`` is empty or when every candidate already
    occurs in ``S0``, i.e. when no swap is possible.
    """
    in_set = set(S0)
    outside = [i for i in candidate_indices if i not in in_set]
    if not outside or not S0:
        return None
    # Draw the slot directly instead of picking a value and locating it with
    # list.index(). When S0 holds repeated indices (e.g. sample_seed_set's
    # with-replacement fallback), index() always hit the first occurrence, so
    # later duplicate slots could never be swapped. In CPython, randrange(n) and
    # choice() on an n-item list draw the same way, so sets without duplicates
    # are swapped exactly as before for a given RNG state.
    j = rng.randrange(len(S0))
    e_plus = rng.choice(outside)
    S_new = S0.copy()
    S_new[j] = e_plus
    return S_new

def make_drop(S0: List[int], rng: random.Random) -> Optional[List[int]]:
    """Remove one uniformly chosen position from ``S0`` and return the result.

    Returns a new list one element shorter than ``S0`` (``S0`` is left
    untouched), or ``None`` when ``S0`` has at most one element.
    """
    size = len(S0)
    if size > 1:
        cut = rng.randrange(size)
        return S0[:cut] + S0[cut + 1:]
    return None


# -------------------------
# Main collection loop
# -------------------------
# python collect_traindata.py --model "llama3.1-8b" --tasks gsm8k gpqa
def main():
    """Score random (question, instruction, demo-set) triples and save them as rows.

    For each task in ``--tasks`` this repeatedly samples:

    - a validation question;
    - an instruction from ``./data/{task}/{task}_instrs.json``;
    - a seed set of ``K`` training demos, with ``K`` drawn from ``--Ks``.

    It then optionally derives perturbed sets from the seed via the ``swap`` /
    ``drop`` moves and scores each set with ``logp_y_given_userprompt``. Every
    scored set costs one unit of ``--budget``, and the budget applies per task.
    Rows are written to ``./data/{task}/{model}_data.json`` as lists
    ``[q_id, T_id, K, S_idx, score, n_token]``.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--model", type=str, default="llama3.1-8b", help="LLM name [qwen3-4b,llama3.1-8b]")
    ap.add_argument("--tasks", type=str, nargs="+", default=["gsm8k"], help="task names: gsm8k gpqa fp xsum date salient")
    ap.add_argument("--budget", type=int, default=500, help="#LLM evaluations (calls that compute a reward)")
    ap.add_argument("--Ks", type=str, default="1,2,4,8", help="choice of #ICL samples")
    ap.add_argument("--moves", type=str, default="swap,drop", help="comma-separated moves among {swap,drop}")
    ap.add_argument("--seed", type=int, default=42)

    args = ap.parse_args()

    Ks = [int(x) for x in args.Ks.split(",") if x.strip()]
    moves = [m.strip() for m in args.moves.split(",") if m.strip()]
    # Validate up front. An empty or negative --Ks would otherwise only fail
    # deep inside the sampling loop, and a typo in --moves silently disabled
    # that move.
    if not Ks or any(k < 0 for k in Ks):
        ap.error(f"--Ks must be a non-empty comma-separated list of non-negative ints, got {args.Ks!r}")
    unknown_moves = sorted(set(moves) - {"swap", "drop"})
    if unknown_moves:
        ap.error(f"unknown --moves {unknown_moves}; choose from swap,drop")

    rng = random.Random(args.seed)

    backend_id = resolve_model_id(args.model)

    # The tokenizer/LLM pair does not depend on the task, so it is loaded once
    # here rather than looked up on every task iteration. _CACHE is module
    # level, so a repeated main() call in the same process reuses it.
    key = f"{backend_id}|vllm"
    tok, llm = _CACHE.get(key, (None, None))
    if tok is None:
        # NOTE(review): the tokenizer is loaded from backend_id while vLLM loads
        # from LOCAL_LLM_PATH; confirm both resolve to the same checkpoint.
        tok = AutoTokenizer.from_pretrained(
            backend_id,
            trust_remote_code=True,
            cache_dir=_CACHE_DIR,
        )
        if tok.pad_token_id is None:
            tok.pad_token = tok.eos_token

        llm = LLM(
            # Models without a LOCAL_LLM_PATH entry fall back to their id
            # instead of crashing with KeyError.
            model=LOCAL_LLM_PATH.get(backend_id, backend_id),
            max_model_len=8192,
            tensor_parallel_size=1,
            gpu_memory_utilization=0.85,
        )
        _CACHE[key] = (tok, llm)

    for task in args.tasks:
        out_dir = f"./data/{task}"
        demoset = get_task(task, 'train')
        devset = get_task(task, 'val')
        ldev = len(devset)
        ldemo = len(demoset)
        instpath = f"./data/{task}/{task}_instrs.json"
        with open(instpath, "r", encoding="utf-8") as f:
            instructions = json.load(f)
        # rng.randrange(0) would fail with an opaque "empty range" error below.
        if args.budget > 0 and (ldev == 0 or len(instructions) == 0):
            raise ValueError(
                f"task {task!r}: need a non-empty 'val' split and instruction list "
                f"(got {ldev} questions, {len(instructions)} instructions)"
            )

        # Candidate pool for demos: every training example. If you have a
        # retriever, replace this with retrieved top-L indices. It is loop
        # invariant, so it is built once per task.
        C_idx = list(range(ldemo))

        out_rows = []
        # Each scored (q, T, S) triple counts as one "call" for budget
        # accounting. Every iteration scores the seed set plus any perturbed
        # sets until the budget is consumed.
        calls_used = 0
        while calls_used < args.budget:
            q_id = rng.randrange(ldev)
            q = devset[q_id]
            T_id = rng.randrange(len(instructions))
            T = instructions[T_id]
            K = rng.choice(Ks)

            S0_idx = sample_seed_set(C_idx, K, rng)

            S_candidates: List[Tuple[str, List[int]]] = [("seed", S0_idx)]
            if "swap" in moves:
                s_swap = make_swap(S0_idx, C_idx, rng)
                if s_swap is not None:
                    S_candidates.append(("swap", s_swap))
            if "drop" in moves:
                s_drop = make_drop(S0_idx, rng)
                if s_drop is not None:
                    S_candidates.append(("drop", s_drop))

            # Score candidates in order until the budget runs out. The seed set
            # always comes first, so every iteration uses at least one call.
            # NOTE(review): the move tag is not stored, and K is the seed-set
            # size, so a "drop" row has len(S_idx) == K - 1. Confirm downstream
            # consumers expect this.
            for _tag, S_idx in S_candidates:
                if calls_used >= args.budget:
                    break
                score, n_token = logp_y_given_userprompt(
                    tok, llm, task, T, q, examples=[demoset[i] for i in S_idx]
                )
                out_rows.append([q_id, T_id, K, S_idx, score, n_token])
                calls_used += 1

        json_path = os.path.join(out_dir, f"{args.model}_data.json")
        print(f"Writing {len(out_rows)} rows to {json_path}")
        with open(json_path, "w", encoding="utf-8") as f:
            json.dump(out_rows, f, ensure_ascii=False)


if __name__ == "__main__":
    main()