import json
import time
import pandas as pd
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer
from sklearn.metrics import accuracy_score, classification_report
import torch
import os
import re

# Checkpoint path/identifier handed to both from_pretrained calls below.
# NOTE(review): "xxx" looks like a placeholder — set it before running.
model_path = "xxx" 
tokenizer = AutoTokenizer.from_pretrained(model_path)
# device_map="auto" leaves device placement to the library; the inference loop
# later moves its inputs to model.device before generating.
model = AutoModelForCausalLM.from_pretrained(model_path, device_map="auto")
# Evaluation mode: this script only runs generation, it never trains.
model.eval()

# Evaluation set. The "Utterance" and "Emotion" columns are read from it below.
# NOTE(review): "xxx.csv" looks like a placeholder path — confirm.
data = pd.read_csv("xxx.csv")  

# Emotion name -> integer label id. Ids are assigned by position, so the order
# of the names below is significant (neutral=0 ... disgust=6).
emotion_mapping = dict(zip(
    ("neutral", "joy", "sadness", "surprise", "anger", "fear", "disgust"),
    range(7),
))
# Integer label id -> emotion name. Iterating the dict yields the names in
# insertion order, which is exactly id order.
reverse_emotion_mapping = dict(enumerate(emotion_mapping))

def build_prompt(text: str) -> str:
    """Build the chain-of-thought emotion-classification prompt for *text*.

    The prompt consists of four sections separated by blank lines: the task
    instruction (including the required ``'Emotion: <label>'`` answer format),
    the numbered emotion definitions, the quoted input text, and an opening
    "Reasoning:" line for the model to continue from.

    Args:
        text: The utterance to classify; inserted verbatim between double quotes.

    Returns:
        The complete prompt string.
    """
    # One entry per label id; joined with " | " into a single definitions line.
    label_glossary = " | ".join((
        "0=neutral: no clear emotional cues",
        "1=joy: features like positive lexicon, uplifting emojis, achievement expressions",
        "2=sadness: contains loss/grief elements, negative event descriptions",
        "3=surprise: unexpected events or cognitive dissonance",
        "4=anger: aggressive language, confrontational rhetoric",
        "5=fear: threat-related content, anxiety indicators",
        "6=disgust: expressions of revulsion, descriptions of unpleasant events",
    ))

    sections = (
        "Perform step-by-step reasoning to identify the emotion in the given text. "
        "After your reasoning, output the final emotion label in the exact format: 'Emotion: <label>'.",
        "Emotion Definitions:\n" + label_glossary,
        'Text: "' + text + '"',
        "Reasoning: Let's think step by step. First, I need to analyze the emotional cues...",
    )
    return "\n\n".join(sections)

def save_results(results, output_file):
    """Write *results* to *output_file* as pretty-printed UTF-8 JSON.

    The data is first written to ``output_file + '.temp'`` and then moved over
    the destination with ``os.replace``, which overwrites an existing file in
    a single step. The previous remove-then-rename sequence left a window in
    which neither the old nor the new results file existed, so an interruption
    at that point lost the checkpoint.

    Args:
        results: JSON-serialisable object (here, a list of result dicts).
        output_file: Destination path of the JSON file.
    """
    temp_file = output_file + '.temp'
    with open(temp_file, 'w', encoding='utf-8') as f:
        json.dump(results, f, ensure_ascii=False, indent=2)
    # Overwrites output_file if it already exists; no separate delete needed.
    os.replace(temp_file, output_file)

# Checkpoint file: results are written here and reloaded on restart.
output_file = "xxx.json"

# Inputs and integer ground-truth labels, in dataset row order.
texts = data["Utterance"].astype(str).tolist()
true_labels = [emotion_mapping[name.lower()] for name in data["Emotion"].tolist()]

# Resume support: start from previously saved results when the file exists.
results = []
if os.path.exists(output_file):
    with open(output_file, 'r', encoding='utf-8') as f:
        results = json.load(f)
# Row indices that already have a result entry (empty on a fresh run).
processed_indices = {entry['index'] for entry in results}

# Running totals updated by the inference loop.
total_correct = total_tokens = 0

# Patterns are compiled once instead of on every iteration.
# Preferred answer form: "Emotion: <digit or word>" (case-insensitive).
_EMOTION_TAG_RE = re.compile(r"Emotion:\s*([0-6]|\w+)", re.IGNORECASE)
# Fallback: the first standalone digit 0-6 anywhere in the reply.
_BARE_LABEL_RE = re.compile(r"\b[0-6]\b")


def _parse_prediction(reply):
    """Extract the predicted integer label id from a model reply.

    Resolution order:
      1. the first "Emotion: X" occurrence, where X is either a digit string
         (converted with ``int``) or an emotion name looked up in
         ``emotion_mapping``;
      2. otherwise the first standalone digit 0-6 in the reply;
      3. otherwise 0 (neutral).

    An unrecognised emotion name also falls back to 0.
    """
    tagged = _EMOTION_TAG_RE.search(reply)
    if tagged:
        answer = tagged.group(1).lower()
        if answer.isdigit():
            return int(answer)
        # Look the name up by key. The previous code compared the integer ids
        # against the name string, which never matched, so every textual
        # answer was silently turned into 0.
        return emotion_mapping.get(answer, 0)
    bare = _BARE_LABEL_RE.search(reply)
    return int(bare.group(0)) if bare else 0


# Run generation for every row that has no saved result yet.
for index, text in enumerate(tqdm(texts)):
    if index in processed_indices:
        continue

    prompt = build_prompt(text)
    true_label = true_labels[index]

    try:
        inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=16384,
                do_sample=False,
                pad_token_id=tokenizer.eos_token_id
            )

        # Keep only the newly generated tokens. Removing the prompt with
        # str.replace() breaks whenever decoding does not reproduce the prompt
        # text exactly; the label parser would then read digits from the
        # prompt's own emotion definitions.
        prompt_length = inputs["input_ids"].shape[1]
        reply = tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True).strip()

        # Token count covers the text after the last "Reasoning:" marker, or
        # the whole reply when the marker is absent.
        reasoning_part = reply.split("Reasoning:")[-1] if "Reasoning:" in reply else reply
        token_count = len(tokenizer.encode(reasoning_part))

        pred_label = _parse_prediction(reply)

        is_correct = (pred_label == true_label)
        if is_correct:
            total_correct += 1
        total_tokens += token_count

        results.append({
            "index": index,
            "text": text,
            "true_label": reverse_emotion_mapping.get(true_label, "unknown"),
            "predicted_label": reverse_emotion_mapping.get(pred_label, "unknown"),
            "is_correct": is_correct,
            "full_response": reply,
            "token_count": token_count,
            "prompt": prompt
        })

        # Checkpoint periodically so an interrupted run can resume.
        if index % 5 == 0:
            save_results(results, output_file)

    except Exception as e:
        # Best-effort: record the failure for this row and keep going.
        results.append({
            "index": index,
            "text": text,
            "true_label": reverse_emotion_mapping.get(true_label, "unknown"),
            "predicted_label": "error",
            "is_correct": False,
            "full_response": str(e),
            "token_count": 0,
            "prompt": prompt
        })
        time.sleep(2)

# Final checkpoint: the loop only saves on every fifth index.
save_results(results, output_file)

# Align every result with its ground truth through the stored row index
# rather than list position, so resumed runs stay consistent.
ordered_results = sorted(results, key=lambda item: item['index'])
evaluated_true = [true_labels[item['index']] for item in ordered_results]

# Derive the summary metrics from the result records themselves. The running
# counters only cover rows processed in the current run, so they under-count
# after a resume.
num_evaluated = len(ordered_results)
accuracy = sum(1 for item in ordered_results if item['is_correct']) / num_evaluated if num_evaluated else 0
avg_tokens = sum(item['token_count'] for item in ordered_results) / num_evaluated if num_evaluated else 0
print(f"Accuracy: {accuracy:.4f}")
print(f"Average reasoning tokens per sample: {avg_tokens:.2f}")

# Map stored label names back to integer ids. Names that are not emotions
# ("error", "unknown") fall back to 0. The previous lookup compared ids with
# names and therefore turned every prediction into 0.
pred_labels = [emotion_mapping.get(item['predicted_label'], 0) for item in ordered_results]
print(classification_report(evaluated_true, pred_labels, digits=4))

# Per-row export of the predictions.
df = pd.DataFrame({
    "text": [texts[item['index']] for item in ordered_results],
    "true_label": [reverse_emotion_mapping.get(label, "unknown") for label in evaluated_true],
    "pred_label": [item['predicted_label'] for item in ordered_results],
    "is_correct": [item['is_correct'] for item in ordered_results],
    "token_count": [item['token_count'] for item in ordered_results]
})
df.to_csv("full_predictions_emotion_ppo_model.csv", index=False, encoding="utf-8-sig")