import json
import time
import pandas as pd
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer
from sklearn.metrics import accuracy_score, classification_report
import torch
import os
import re

# Placeholder checkpoint location; the tokenizer and model are loaded from it.
model_path = "xxx"
tokenizer = AutoTokenizer.from_pretrained(model_path)
# device_map="auto" lets transformers place the weights; eval() returns the
# module itself, so inference mode can be set in the same expression.
model = AutoModelForCausalLM.from_pretrained(model_path, device_map="auto").eval()

# Evaluation set (placeholder path); later code reads its "Text" and "Label" columns.
data = pd.read_csv("xxx.csv")

# Integer class code -> human-readable label stored in the result records.
humor_mapping = dict(enumerate(("not humor", "humor")))
# Inverse lookup: human-readable label -> integer class code.
reverse_humor_mapping = {label: code for code, label in humor_mapping.items()}

def build_prompt(text: str) -> str:
    """Return the step-by-step humor-classification prompt for ``text``.

    The prompt states the expected final-answer format (``Humor: <label>``),
    lists both class definitions on one line, quotes the input text, and ends
    with a primed "Reasoning:" line so generation continues the reasoning.
    """
    class_definitions = " | ".join((
        "0=not humor: no clear humor cues",
        "1=humor: contains features like wordplay/puns, exaggerated scenarios, "
        "unexpected twists, contextual incongruity, absurd juxtapositions, etc.",
    ))
    instructions = (
        "Perform step-by-step reasoning to identify the humor in the given text. "
        "After your reasoning, output the final humor label in the exact format: "
        "'Humor: <label>'."
    )
    sections = [
        instructions,
        "Humor Definitions:\n" + class_definitions,
        f'Text: "{text}"',
        "Reasoning: Let's think step by step. First, I need to analyze the "
        "context and linguistic cues...",
    ]
    # Sections are separated by a single blank line.
    return "\n\n".join(sections)

def save_results(results, output_file):
    """Write ``results`` to ``output_file`` as indented UTF-8 JSON.

    The data is first written to ``output_file + '.temp'`` and then moved
    over the destination with ``os.replace``, which overwrites an existing
    file in a single step (atomic on POSIX). The earlier remove-then-rename
    sequence left a window in which ``output_file`` did not exist, so a crash
    at that moment would lose the entire checkpoint.

    Args:
        results: JSON-serialisable object (in this script, a list of result dicts).
        output_file: destination path; the temp file is created next to it.
    """
    temp_file = output_file + '.temp'
    with open(temp_file, 'w', encoding='utf-8') as f:
        json.dump(results, f, ensure_ascii=False, indent=2)
        # Push the bytes to disk before the rename makes them visible.
        f.flush()
        os.fsync(f.fileno())
    os.replace(temp_file, output_file)

# Inputs and gold labels, aligned by position with the CSV rows.
texts = data["Text"].astype(str).tolist()
true_labels = data["Label"].astype(int).tolist()

results = []
output_file = "xxx.json"

# Resume from a previous checkpoint if one exists. Entries whose generation
# failed (predicted_humor == "error") are dropped so they get retried; keeping
# them would skip those rows forever and lock them in as incorrect.
if os.path.exists(output_file):
    with open(output_file, 'r', encoding='utf-8') as f:
        results = json.load(f)
    results = [item for item in results if item.get("predicted_humor") != "error"]
processed_indices = {item['index'] for item in results}

# Running totals for the current session only (rows restored from the
# checkpoint are not included).
total_correct = 0
total_tokens = 0

# Final-answer pattern "Humor: <label>". <label> may be the class name
# ("humor" / "not humor", also "non-humor", "no humor", "not_humor") or its
# numeric code ("1" / "0"), optionally wrapped in markdown/quote/bracket
# characters. The prompt's definitions present both forms, so both are accepted.
_HUMOR_ANSWER_RE = re.compile(
    r"\bhumor\s*:\s*[*'\"`\[<(]*\s*"
    r"(not[\s_-]+humor|non[\s_-]?humor|no[\s_-]+humor|humor|[01])\b",
    re.IGNORECASE,
)


def parse_humor_label(reply: str) -> int:
    """Return the 0/1 class code from the model's final ``Humor: <label>``.

    The LAST match wins, because the reasoning may mention the answer format
    before committing to a label. A reply with no recognisable label falls
    back to 0 ("not humor"), matching the previous default.
    """
    matches = _HUMOR_ANSWER_RE.findall(reply)
    if not matches:
        return 0
    return 1 if matches[-1].lower() in ("1", "humor") else 0


for index, text in enumerate(tqdm(texts)):
    if index in processed_indices:
        continue

    prompt = build_prompt(text)
    true_humor = true_labels[index]

    try:
        inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=16384,
                do_sample=False,
                pad_token_id=tokenizer.eos_token_id,
            )

        # Decode only the newly generated tokens. Removing the prompt with
        # str.replace breaks whenever decode() does not reproduce the prompt
        # exactly. The prompt's definitions ("not humor: no ...") would then be
        # parsed as the answer.
        prompt_length = inputs["input_ids"].shape[-1]
        reply = tokenizer.decode(
            outputs[0][prompt_length:], skip_special_tokens=True
        ).strip()

        reasoning_part = reply.split("Reasoning:")[-1]
        # add_special_tokens=False so BOS/EOS markers are not counted as reasoning.
        token_count = len(tokenizer.encode(reasoning_part, add_special_tokens=False))

        pred_label = parse_humor_label(reply)

        is_correct = (pred_label == true_humor)
        if is_correct:
            total_correct += 1
        total_tokens += token_count

        result_entry = {
            "index": index,
            "text": text,
            "true_humor": humor_mapping.get(true_humor, "unknown"),
            "predicted_humor": humor_mapping.get(pred_label, "unknown"),
            "is_correct": is_correct,
            "full_response": reply,
            "token_count": token_count,
            "prompt": prompt,
        }

        results.append(result_entry)

        # Periodic checkpoint so an interrupted run can resume.
        if index % 5 == 0:
            save_results(results, output_file)

    except Exception as e:
        # Record the failure instead of aborting the whole run, but surface
        # it so errors are not silently folded into the metrics.
        tqdm.write(f"[{index}] generation failed: {e!r}")
        results.append({
            "index": index,
            "text": text,
            "true_humor": humor_mapping.get(true_humor, "unknown"),
            "predicted_humor": "error",
            "is_correct": False,
            "full_response": str(e),
            "token_count": 0,
            "prompt": prompt,
        })
        time.sleep(2)

# Entries restored from a checkpoint come first and retried rows are appended
# later, so put the records back into dataset order before saving/reporting.
results.sort(key=lambda item: item["index"])
save_results(results, output_file)

# Align one record per dataset row. A missing index raises KeyError here
# instead of silently shifting predictions against the true labels.
results_by_index = {item["index"]: item for item in results}
ordered = [results_by_index[i] for i in range(len(texts))]

num_samples = len(ordered)
# Computed from the stored records rather than the in-loop counters, so rows
# restored from a previous (resumed) run are included. Errors count as
# incorrect with 0 tokens.
accuracy = sum(item["is_correct"] for item in ordered) / num_samples if num_samples else 0
avg_tokens = sum(item["token_count"] for item in ordered) / num_samples if num_samples else 0
print(f"Accuracy: {accuracy:.4f}")
print(f"Average reasoning tokens: {avg_tokens:.2f}")

# Map label names back to class codes via the exact inverse mapping. The
# previous substring test ("humor" in "not humor") marked every prediction as
# humor. Failed generations ("error") map to -1; with labels=[0, 1], each one
# counts as a miss for its true class.
pred_labels = [reverse_humor_mapping.get(item["predicted_humor"], -1) for item in ordered]
if num_samples:
    print(classification_report(
        true_labels,
        pred_labels,
        labels=[0, 1],
        target_names=[humor_mapping[0], humor_mapping[1]],
        digits=4,
    ))

df = pd.DataFrame({
    "text": texts,
    "true_label": [humor_mapping.get(label, "unknown") for label in true_labels],
    "pred_label": [item["predicted_humor"] for item in ordered],
    "is_correct": [item["is_correct"] for item in ordered],
    "token_count": [item["token_count"] for item in ordered],
})
df.to_csv("full_predictions_humor_ppo_model.csv", index=False, encoding="utf-8-sig")