import os
# ========== ENV SETUP ==========
os.environ["CUDA_VISIBLE_DEVICES"] = "4,6"
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:128"
# Other GPU selections kept for reference (swap into CUDA_VISIBLE_DEVICES above):
#   "6,8,9"
#   "0,1,2,3,4,5"
import json
from tqdm import tqdm
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer

# ========== SYSTEM PROMPTS ==========
# System message used for every classification request. It instructs the model
# to reply with a single digit: 1 for hate speech, 0 for anything else
# (offensive-but-not-hateful language counts as 0).
# NOTE: the literal begins and ends with a newline; it is passed on unchanged.
SYSTEM_PROMPT_CLASSIFIER = """
You are a hate speech classifier.

<BEGIN TASK>
You will be given:
- A tweet.

You must:
- Read and analyze the tweet in full.
- Classify it as hate speech ONLY IF ALL of the following are true:
• The language expresses hatred toward a targeted protected group.
• The language is intended to be derogatory.
• The language aims to humiliate or insult members of that group.
- Offensive language is not hate speech.

Output rules:
- Output exactly one digit for the tweet
• 1 = hate speech
• 0 = not hate speech (including offensive language)
- Do not output any explanation, text, punctuation, spaces, or newlines.
<END TASK>
"""

# ========== MODEL UTILS ==========
def load_llm(model_name):
    """Load a causal language model and its tokenizer.

    Args:
        model_name: Checkpoint identifier handed to ``from_pretrained``.

    Returns:
        A ``(model, tokenizer)`` tuple.
    """
    # Both dtype and device placement are left to the library ("auto").
    load_kwargs = {"torch_dtype": "auto", "device_map": "auto"}
    llm = AutoModelForCausalLM.from_pretrained(model_name, **load_kwargs)
    tok = AutoTokenizer.from_pretrained(model_name)
    return llm, tok

def LLM_inference(model, tokenizer, system_prompt, user_prompt):
    """Run one greedy chat completion and return the reply text.

    Args:
        model: Causal LM providing ``generate``, ``device`` and
            ``generation_config``.
        tokenizer: Tokenizer for *model*; must support ``apply_chat_template``.
        system_prompt: Content of the system message.
        user_prompt: Content of the user message.

    Returns:
        The newly generated text only (prompt tokens removed, special tokens
        skipped) with surrounding whitespace stripped. May be an empty string.
    """
    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": user_prompt},
    ]
    text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    # The rendered chat template already contains the special tokens it needs.
    # Tokenizing it with the default add_special_tokens=True would prepend a
    # second BOS for tokenizers that add one (e.g. Llama / Mistral families).
    model_inputs = tokenizer(
        [text], return_tensors="pt", add_special_tokens=False
    ).to(model.device)

    # Clear sampling-only settings inherited from the checkpoint's generation
    # config; decoding below is greedy (do_sample=False).
    # NOTE(review): presumably done to silence "unused sampling flag" warnings.
    model.generation_config.temperature = None
    model.generation_config.top_p = None
    model.generation_config.top_k = None

    output_ids = model.generate(
        **model_inputs,
        max_new_tokens=4,  # the caller only needs a very short answer
        pad_token_id=tokenizer.eos_token_id,
        do_sample=False,
    )
    # generate() returns prompt + continuation; keep only the continuation.
    prompt_len = model_inputs.input_ids.shape[-1]
    generated_ids = output_ids[:, prompt_len:]
    response_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()
    return response_text

# ========== INFERENCE FUNCTIONS ==========
def LLM_classifier(item):
    """Classify one dataset row with the module-level classifier model.

    Args:
        item: Mapping with a ``"tweet"`` entry holding the text to classify.

    Returns:
        The raw reply string produced by ``LLM_inference``.
    """
    user_prompt = "Comment: {}".format(item["tweet"])
    return LLM_inference(
        model=classifier_model,
        tokenizer=classifier_tokenizer,
        system_prompt=SYSTEM_PROMPT_CLASSIFIER,
        user_prompt=user_prompt,
    )

# ========== LOAD MODELS ==========
# "Qwen/Qwen2.5-0.5B-Instruct"
# "Qwen/Qwen2.5-7B-Instruct" "mistralai/Mistral-7B-Instruct-v0.3" "meta-llama/Llama-3.1-8B-Instruct"
# "Qwen/Qwen2.5-32B-Instruct" "mistralai/Mixtral-8x7B-Instruct-v0.1" "meta-llama/Llama-3.3-70B-Instruct"
# Checkpoint evaluated in this run.
classifier_model_name = "meta-llama/Llama-3.2-1B-Instruct"
# Results are appended, one JSON object per line, to file_path + file_name.
file_path, file_name = "./data/HSO/", "l1.jsonl"
output_file = f"{file_path}{file_name}"
# Load the model and tokenizer once; LLM_classifier reads these globals.
classifier_model, classifier_tokenizer = load_llm(classifier_model_name)

def _read_last_idx(path):
    """Return the ``idx`` of the last JSONL record in *path*, or 0.

    Only the final line of the file is inspected. 0 is returned when the file
    does not exist, is empty, or its last line cannot be parsed as a JSON
    object containing an ``"idx"`` key.

    NOTE(review): the main loop treats 0 as "nothing processed yet", so a file
    whose only record has idx 0 causes that row to be processed again.
    """
    if not os.path.exists(path):
        return 0

    # Binary mode is required for seeking relative to the end of the file.
    with open(path, "rb") as f:
        try:
            # Start just before the final byte (normally the trailing "\n")
            # and step backwards until the previous newline is found.
            f.seek(-2, os.SEEK_END)
            while f.read(1) != b"\n":
                f.seek(-2, os.SEEK_CUR)
        except OSError:
            # Seeking before the start: the file has at most one line.
            f.seek(0)
        last_line = f.readline()

    if not last_line.strip():
        return 0
    try:
        return json.loads(last_line.decode("utf-8"))["idx"]
    except (UnicodeDecodeError, json.JSONDecodeError, KeyError, TypeError):
        # Best effort: an unreadable tail restarts from the beginning, but say
        # so instead of failing silently.
        print(f"Warning: could not read 'idx' from the last line of {path}; starting from 0")
        return 0


last_idx = _read_last_idx(output_file)

# ========== LOAD DATASET ==========
# The fixed shuffle seed keeps row order identical between runs, which the
# idx-based resume check in the main loop depends on.
hso_splits = load_dataset("tdavidson/hate_speech_offensive")
d = hso_splits["train"].shuffle(seed=42)

def parse_vote(s: str):
    """Return the first ``0``/``1`` digit found in *s* as an int.

    Leading and trailing whitespace is stripped before scanning. Returns
    ``None`` when the text contains neither character.
    """
    cleaned = s.strip()
    # Lazily scan the characters and stop at the first binary digit.
    binary_digits = (int(ch) for ch in cleaned if ch == "0" or ch == "1")
    return next(binary_digits, None)

# ========== MAIN LOOP ==========
# Create the output directory up front so a missing folder fails (or is fixed)
# before any inference time is spent.
os.makedirs(os.path.dirname(output_file) or ".", exist_ok=True)

# Append mode: a resumed run extends the existing file. The handle stays open
# for the whole run and is flushed after every record, so an interruption
# loses at most the item that was in flight.
with open(output_file, "a", encoding="utf-8") as f:
    for idx, item in tqdm(enumerate(d), total=len(d)):
        # Skip rows already present in the output file from a previous run.
        if last_idx and idx <= last_idx:
            continue

        label = item["class"]
        # Class 0 is the positive (gt=1) class; every other class maps to 0.
        # NOTE(review): assumes class 0 means "hate speech" in this dataset —
        # confirm against the dataset card.
        gt = 1 if label == 0 else 0
        response_text = LLM_classifier(item)
        # Only the first character of the reply is considered. Slicing (rather
        # than indexing) keeps an empty reply from raising IndexError; it
        # yields pred=None instead.
        pred = parse_vote(response_text[:1])

        result_data = {
            "idx": idx,
            "gt": gt,
            "pred": pred,
        }
        f.write(json.dumps(result_data, ensure_ascii=False) + "\n")
        f.flush()