import tools
from transformers import AutoTokenizer
from tqdm import tqdm
import os
import argparse
import requests

# Group size used to derive the expected rank_id of each input row;
# rebound from --sample_num when the file runs as a script.
SAMPLE_NUM = None
# Directory-style name of the model whose outputs are being scored;
# rebound from --judged_model when the file runs as a script.
JUDGED_MODEL = None
# Local path of the reward model; in this file it is only used to load the
# tokenizer and to name the final "<model>.tag" marker file.
MODEL_PATH = f"{tools.machine_pather()}/models/models--nvidia--Llama-3.3-Nemotron-70B-Reward-Multilingual"

# Base URL of the HTTP server that sglang_generate_batch posts to
# (presumably an SGLang server hosting the reward model -- confirm deployment).
SG_URL = "http://localhost:30000"


def sglang_generate_batch(batch_token_ids, *, timeout=None):
    """Score a batch of token-id sequences via the ``/generate`` endpoint.

    Posts ``batch_token_ids`` to ``{SG_URL}/generate`` asking for a single
    greedy token and the log-prob entry of token id 0. For every returned
    sample the first field of
    ``meta_info['output_token_ids_logprobs'][0][0]`` is taken as the reward.
    NOTE(review): presumably the served reward model exposes its scalar score
    in that slot -- confirm against the serving setup.

    Args:
        batch_token_ids: Sequence of token-id lists, one per sample.
        timeout: Optional timeout (seconds, or a ``(connect, read)`` tuple)
            forwarded to ``requests.post``. ``None`` waits indefinitely,
            which matches the previous behaviour.

    Returns:
        A list with one reward per input sample, in input order. If anything
        goes wrong (HTTP error, malformed or mis-sized response), the failure
        is printed and a list of ``None`` of the same length is returned so
        that the caller can keep going.
    """
    expected = len(batch_token_ids)
    if expected == 0:
        # Nothing to score; skip the pointless round-trip to the server.
        return []

    try:
        resp = requests.post(
            f"{SG_URL}/generate",
            json={
                "input_ids": batch_token_ids,
                "sampling_params": {
                    "temperature": 0.0,
                    "max_new_tokens": 1,
                },
                "return_logprob": True,
                "token_ids_logprob": [0]
            },
            timeout=timeout,
        )
        resp.raise_for_status()
        data = resp.json()

        # The caller zips rewards with its inputs, so a response of the wrong
        # shape would silently misalign or drop scores.
        if not isinstance(data, list) or len(data) != expected:
            got = len(data) if isinstance(data, list) else type(data).__name__
            raise ValueError(
                f"expected a list of {expected} samples, got {got}")

        rewards = []
        for sample in data:
            entry = sample['meta_info']['output_token_ids_logprobs'][0][0]
            # Explicit check instead of `assert`, which is removed under -O:
            # the second field must be the requested token id (0).
            if entry[1] != 0:
                raise ValueError(
                    f"unexpected token id {entry[1]!r} in logprob entry")
            rewards.append(entry[0])
        return rewards
    except Exception as e:
        # Deliberate best-effort: one failed batch must not abort a long run.
        print(f"[sglang_generate_batch] FAIL: {e}")
        return [None] * expected


if __name__ == "__main__":
    # Script entry point: read the judge inputs for one
    # (traindataset, subset, judged_model) combination, turn each row into a
    # user/assistant chat, tokenize it with the reward model's tokenizer,
    # score it in batches through sglang_generate_batch and write one JSON
    # line per row to score.jsonl.
    allowed_judged_models = [
        "models--deepseek-ai--DeepSeek-R1-Distill-Llama-8B",
        "models--deepseek-ai--DeepSeek-R1-0528-Qwen3-8B",
        "models--Qwen--Qwen3-8B",
        "models--meta-llama--Llama-3.1-8B-Instruct"
    ]

    parser = argparse.ArgumentParser()
    parser.add_argument("--sample_num", type=int, default=20)
    parser.add_argument("--traindataset", type=str, required=True)
    # `choices` replaces the former assert so validation survives `python -O`.
    parser.add_argument("--judged_model", type=str, required=True,
                        choices=allowed_judged_models)
    parser.add_argument("--subset", type=str, required=True)
    parser.add_argument("--startpoint", type=int, default=0)
    args = parser.parse_args()

    if args.sample_num <= 0:
        # Used as a divisor when checking rank_id below.
        parser.error("--sample_num must be a positive integer")
    if args.startpoint < 0:
        # A negative value would silently slice from the end of the inputs.
        parser.error("--startpoint must be non-negative")

    SAMPLE_NUM = args.sample_num
    JUDGED_MODEL = args.judged_model
    subset = args.subset

    INPUT_FILE = f"{tools.machine_pather()}/works/DPO/judge/input/{args.traindataset}/{subset}/{JUDGED_MODEL}/judge_input.jsonl"
    raw_inputs = tools.read_jsonl(INPUT_FILE)

    prompts = []
    print("start to process the inputs")
    # enumerate(tqdm(...)) rather than tqdm(enumerate(...)) so the bar can
    # show a total when raw_inputs has a length.
    for index, each_json in enumerate(tqdm(raw_inputs)):
        # Rows must be ordered as <group>_<sample>_0 with SAMPLE_NUM samples
        # per group; fail loudly if the file is out of order.
        expected_rank_id = f"{index // SAMPLE_NUM}_{index % SAMPLE_NUM}_0"
        if each_json['rank_id'] != expected_rank_id:
            raise ValueError(
                f"row {index}: rank_id {each_json['rank_id']!r} does not "
                f"match expected {expected_rank_id!r}")
        prompts.append([
            {"role": "user", "content": each_json['assembly_question']},
            {"role": "assistant", "content": each_json["assembly_reasoning"]}
        ])

    tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)

    OUTPUT_DIR = f"{tools.machine_pather()}/works/DPO/judge/output/{args.traindataset}/{subset}/{JUDGED_MODEL}/coherence_score"
    os.makedirs(OUTPUT_DIR, exist_ok=True)
    print("output to >>>", OUTPUT_DIR)

    score_path = os.path.join(OUTPUT_DIR, "score.jsonl")

    if args.startpoint == 0:
        print("startpoint is 0, creating output files")
        # NOTE(review): check_existence is a project helper; presumably it
        # guards against clobbering an existing file -- confirm.
        tools.check_existence(score_path)
        # Create/truncate the score file for a fresh run.
        with open(score_path, 'w'):
            pass

    # A non-zero startpoint resumes an earlier run: skip rows already scored.
    texts = prompts[args.startpoint:]

    batch_size = 20*8
    failed = 0
    for i in tqdm(range(0, len(texts), batch_size)):
        batch_texts = texts[i:i+batch_size]
        batch_token_ids = tokenizer.apply_chat_template(
            batch_texts, tokenize=True, add_generation_prompt=False)
        rewards = sglang_generate_batch(batch_token_ids)

        # zip() would silently drop rows on a length mismatch and leave the
        # score file shorter than the input; stop instead.
        if len(rewards) != len(batch_texts):
            raise RuntimeError(
                f"batch at offset {i}: got {len(rewards)} rewards for "
                f"{len(batch_texts)} inputs")

        for msg, reward in zip(batch_texts, rewards):
            if reward is None:
                failed += 1
            tools.write_jsonl(
                {"score_name": 'coherence_score', "score": reward, "input": msg},
                score_path
            )

    if failed:
        # The tag files are still written below, as before; this only makes
        # failed batches visible.
        print(f"WARNING: {failed} samples were written with a null score")

    with open(os.path.join(OUTPUT_DIR, "success.tag"), "w"):
        pass
    with open(os.path.join(OUTPUT_DIR, MODEL_PATH.split("/")[-1] + ".tag"), "w"):
        pass
