import os
# GPU selection is done before vllm is imported further down in this file,
# so the device mask is in place before any CUDA initialization.
# tp_size is the tensor-parallel degree passed to vLLM; here it matches the
# 8 devices listed (0,1,2,3,5,6,7,9).
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2,3,5,6,7,9"
tp_size = 8
# Alternative 2-GPU configuration (device list and tp_size must change together):
# os.environ["CUDA_VISIBLE_DEVICES"] = "4,6"
# tp_size = 2
import multiprocessing as mp
# Force the "spawn" start method so child processes (presumably vLLM's tensor-
# parallel workers) start from a fresh interpreter instead of forking a parent
# that may already hold CUDA state.
if mp.get_start_method(allow_none=True) != "spawn":
    mp.set_start_method("spawn", force=True)
import json
import time
from tqdm import tqdm
from datasets import load_dataset
from transformers import AutoTokenizer
from vllm import LLM, SamplingParams

# ========== CONFIG ==========
# input_file: JSONL of prior predictions, read line by line and indexed by
#             dataset row; each record carries a "pred" field.
# output_file: JSONL of judge verdicts; opened in append mode so an interrupted
#              run can resume where it stopped.
file_path = "./data/HSO/"
input_file = f"{file_path}l1.jsonl"
output_file = f"{file_path}l1_l70.jsonl"
judge_model_name = "meta-llama/Llama-3.3-70B-Instruct"

# Alternative judge models (swap into judge_model_name):
# "Qwen/Qwen2.5-0.5B-Instruct"
# "Qwen/Qwen2.5-7B-Instruct" "mistralai/Mistral-7B-Instruct-v0.3" "meta-llama/Llama-3.1-8B-Instruct"
# "Qwen/Qwen2.5-32B-Instruct" "mistralai/Mixtral-8x7B-Instruct-v0.1" "meta-llama/Llama-3.3-70B-Instruct"
# Independent sampled voting is no longer used. None means greedy decoding
# (temperature 0), i.e. the decision is the argmax of the model's raw logits.
TEMP = None

# =============================
# Prompts
# =============================
# System prompt for the judge. It states the hate-speech criteria, tells the
# model to treat the prior prediction as an unreliable hint, and requires a
# single-digit answer ("1" = hate speech, "0" = not hate speech). The answer is
# produced by a one-token generation and parsed by batch_vote_with_vllm.
# NOTE: this text is sent to the model verbatim; any edit changes model behavior.
TASK = """
You are a hate speech judge.

<BEGIN TASK>
You will be given:
- A tweet.
- A prior prediction ("hate speech" or "not hate speech") as non-authoritative, supplementary information that may be wrong.

Your task:
- Carefully read and analyze the tweet.
- Classify it as hate speech ONLY IF ALL of the following are true:
    • The language targets a protected group (race, ethnicity, religion, gender, sexual orientation, disability, nationality).
    • The tone is derogatory or hostile.
    • The intent is to humiliate, insult, or demean members of that group.
- If the tweet is merely offensive without group targeting (e.g., profanity, lewdness, or personal attacks without group targeting), it is NOT hate speech.
- Treat the prior prediction only as a secondary, unreliable hint. If it conflicts with your own analysis of the tweet, ignore it.

Output rules:
- Output exactly one digit:
    • 1 = hate speech
    • 0 = not hate speech
- Output nothing else: no explanation, no text, no punctuation, no newlines.
<END TASK>
"""

# One judge per system prompt. Each one's verdicts are stored under
# "judge1", "judge2", ... in that order.
SYSTEM_PROMPTS = [TASK]

def kill_children():
    """Terminate all descendant processes of the current process (best effort).

    Calls ``terminate()`` on every descendant, waits up to 3 seconds, then
    ``kill()``s whatever is still alive. If psutil is unavailable or anything
    else fails, it falls back to ``pkill`` (SIGTERM, then SIGKILL after 2 s) on
    processes whose command line matches "VllmWorkerProcess".
    """
    try:
        import psutil

        me = psutil.Process(os.getpid())
        # Take a single snapshot so the terminate phase and the wait/kill phase
        # act on exactly the same set of processes.
        children = me.children(recursive=True)
        for child in children:
            try:
                child.terminate()
            except psutil.NoSuchProcess:
                # Exited between listing and signalling; nothing left to do.
                pass
        _, alive = psutil.wait_procs(children, timeout=3)
        for proc in alive:
            try:
                proc.kill()
            except psutil.NoSuchProcess:
                pass
    except Exception:
        # Deliberate best-effort fallback (e.g. psutil not installed).
        os.system("pkill -15 -f VllmWorkerProcess; sleep 2; pkill -9 -f VllmWorkerProcess")

def build_prompt(item, pred_int):
    """Format the user-turn text for one tweet plus its prior prediction.

    Args:
        item: dataset row; only its "tweet" field is read.
        pred_int: prior prediction. 1 is rendered as "hate speech"; any other
            value is rendered as "not hate speech".

    Returns:
        The user message string sent to the judge.
    """
    if pred_int == 1:
        label_text = "hate speech"
    else:
        label_text = "not hate speech"
    tweet = item["tweet"]
    return 'User Comment: "{}", Existing prediction: {}.'.format(tweet, label_text)

def get_digit_token_ids(tokenizer):
    """Return the token ids of the strings "0" and "1" as ``(id0, id1)``.

    Raises:
        RuntimeError: if either digit does not encode to exactly one token
            (without special tokens) under ``tokenizer``.
    """
    zero_ids, one_ids = (
        tokenizer.encode(digit, add_special_tokens=False) for digit in ("0", "1")
    )
    if len(zero_ids) != 1 or len(one_ids) != 1:
        raise RuntimeError("Digit tokens '0' or '1' are not single-token for this tokenizer.")
    (zero_id,), (one_id,) = zero_ids, one_ids
    return zero_id, one_id

def load_llm_vllm(model_name, tp_size=4):
    """Build a vLLM engine for ``model_name`` and return ``(llm, tokenizer)``.

    Args:
        model_name: model id or path handed to ``vllm.LLM``.
        tp_size: tensor-parallel degree (number of GPUs the model is sharded over).

    Returns:
        The ``LLM`` instance and the tokenizer from ``llm.get_tokenizer()``.
    """
    from vllm import LLM

    engine_kwargs = dict(
        tensor_parallel_size=tp_size,
        dtype="half",  # float16 weights/activations
        trust_remote_code=True,
        gpu_memory_utilization=0.8,  # fraction of each GPU's memory vLLM may reserve
        enforce_eager=True,  # no CUDA graph capture
        # NOTE(review): 5096 is an unusual context length (4096 / 5120?) — confirm intent.
        max_model_len=5096,
        kv_cache_dtype="fp8",  # quantized KV cache to save memory
        enable_prefix_caching=False,
    )
    engine = LLM(model=model_name, **engine_kwargs)
    return engine, engine.get_tokenizer()


def batch_vote_with_vllm(llm, tokenizer, system_prompt, user_prompts, id0, id1):
    """Get one 0/1 verdict from the judge model for each user prompt.

    Each prompt is wrapped in a system+user chat via the tokenizer's chat
    template, and the model generates a single token. Greedy decoding is used
    when TEMP is None; otherwise TEMP is the sampling temperature.

    Args:
        llm: vLLM engine.
        tokenizer: tokenizer providing ``apply_chat_template``.
        system_prompt: system message shared by the whole batch.
        user_prompts: user messages, one per item.
        id0, id1: token ids of "0"/"1". Currently unused: the verdict is parsed
            from the generated text instead.

    Returns:
        list[int], one vote per prompt in input order. A vote is 1 if the
        stripped generation starts with "1", otherwise 0 (including
        unparseable output).
    """
    chat_texts = [
        tokenizer.apply_chat_template(
            [
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": prompt},
            ],
            tokenize=False,
            add_generation_prompt=True,
        )
        for prompt in user_prompts
    ]

    temperature = 0.0 if TEMP is None else float(TEMP)
    params = SamplingParams(max_tokens=1, temperature=temperature, top_p=1.0)
    results = llm.generate(chat_texts, params)

    # "0" and any non-digit output both count as 0, so only a leading "1" matters.
    return [1 if res.outputs[0].text.strip().startswith("1") else 0 for res in results]

# ========== LOAD ==========
llm, judge_tokenizer = load_llm_vllm(judge_model_name, tp_size)
# Fails fast if "0"/"1" are not single tokens for this tokenizer; the ids are
# forwarded to batch_vote_with_vllm.
id0, id1 = get_digit_token_ids(judge_tokenizer)

# Prior predictions: one JSON object per line, indexed by position (predictions[i]
# is paired with dataset row i in the main loop). Blank lines, such as extra
# newlines at EOF, are skipped instead of crashing json.loads.
with open(input_file, "r", encoding="utf-8") as f_in:
    predictions = [json.loads(line) for line in f_in if line.strip()]

# Fixed seed so row indices are stable across runs (resume relies on "idx").
# NOTE(review): presumably input_file was produced over the same seed-42
# shuffle, so that predictions[i] refers to d[i] — confirm.
d = load_dataset("tdavidson/hate_speech_offensive")["train"].shuffle(seed=42)


def _idx_from_jsonl_line(raw):
    """Return the integer "idx" of one JSONL line (bytes), or None if unusable.

    Returns None for:
        - blank lines;
        - lines that are not valid JSON, e.g. a record truncated by an
          interrupted write;
        - JSON values that are not objects;
        - records without an integer "idx".
    """
    raw = raw.strip()
    if not raw:
        return None
    try:
        record = json.loads(raw)
    except ValueError:  # covers JSONDecodeError and UnicodeDecodeError
        return None
    if not isinstance(record, dict):
        return None
    idx = record.get("idx")
    # bool is an int subclass but never a valid row index.
    if isinstance(idx, int) and not isinstance(idx, bool):
        return idx
    return None


def read_last_idx(jsonl_path):
    """Return the "idx" of the last usable record in a JSONL file, or -1.

    The file is scanned backwards from EOF in fixed-size chunks, so only its
    tail is read. Trailing blank lines and a truncated final record are
    skipped. A file holding a single record, with or without a trailing
    newline, is handled.

    Returns -1 when:
        - the file does not exist;
        - it contains no record with an integer "idx";
        - it cannot be read.
    """
    if not os.path.exists(jsonl_path):
        return -1
    chunk_size = 8192
    try:
        with open(jsonl_path, "rb") as f:
            pos = f.seek(0, os.SEEK_END)
            head = b""  # start of a line whose beginning has not been read yet
            while pos > 0:
                step = min(chunk_size, pos)
                pos -= step
                f.seek(pos)
                pieces = (f.read(step) + head).split(b"\n")
                if pos > 0:
                    # The first piece may begin mid-line; finish it next pass.
                    head = pieces.pop(0)
                for raw in reversed(pieces):
                    idx = _idx_from_jsonl_line(raw)
                    if idx is not None:
                        return idx
    except OSError:
        return -1
    return -1

# Resume point: the first dataset index NOT yet present in output_file.
# Records are appended in idx order, so the last record holds the highest idx
# already written; start right after it. read_last_idx returns -1 when there is
# nothing to resume from, so +1 yields 0. (The previous code restarted AT the
# last written idx, duplicating that record on every resume.)
last_idx = read_last_idx(output_file) + 1

# ========== MAIN ==========
pbar = None
try:
    # Results are appended to output_file after every batch, so an interrupted
    # run loses at most one batch and can resume from the file.
    BATCH_SIZE = 16
    pbar = tqdm(total=len(d), initial=last_idx, desc="Processing")

    for start in range(last_idx, len(d), BATCH_SIZE):
        end = min(start + BATCH_SIZE, len(d))
        batch_items = []  # (idx, gt, pred) for rows sent to the judge
        user_prompts = []  # judge inputs, aligned with batch_items
        pending_results = []  # output records for this batch

        for idx in range(start, end):
            item = d[idx]
            # Dataset class 0 becomes 1 (hate speech); every other class becomes 0.
            gt = 1 if item["class"] == 0 else 0
            pred = predictions[idx]["pred"]

            # gt is always 0/1 here; the check on it is kept only as a guard.
            if pred in (0, 1) and gt in (0, 1):
                batch_items.append((idx, gt, pred))
                user_prompts.append(build_prompt(item, pred))
            else:
                # Invalid prior prediction: write a placeholder so every idx
                # still appears in the output.
                result_data = {"idx": idx, "gt": None, "pred": None}
                for j in range(1, len(SYSTEM_PROMPTS) + 1):
                    result_data[f"judge{j}"] = None
                pending_results.append(result_data)

        if batch_items:
            # all_votes[j][k] is the verdict of system prompt j on batch item k.
            all_votes = [
                batch_vote_with_vllm(
                    llm, judge_tokenizer, system_prompt, user_prompts, id0, id1
                )
                for system_prompt in SYSTEM_PROMPTS
            ]
            for k, (idx, gt, pred) in enumerate(batch_items):
                result_data = {"idx": idx, "gt": gt, "pred": pred}
                for j, votes in enumerate(all_votes, start=1):
                    result_data[f"judge{j}"] = votes[k]
                pending_results.append(result_data)

        # Keep records in idx order so the file's last line always carries the
        # highest idx processed; resuming reads that line.
        pending_results.sort(key=lambda r: r["idx"])
        with open(output_file, "a", encoding="utf-8") as fout:
            fout.write(
                "".join(json.dumps(r, ensure_ascii=False) + "\n" for r in pending_results)
            )
        pbar.update(end - start)
except KeyboardInterrupt:
    print("KeyboardInterrupt: cleaning up…")
finally:
    # Close the progress bar on every exit path, not only on success.
    if pbar is not None:
        pbar.close()
    kill_children()
    time.sleep(1)