#!/usr/bin/env python3
"""
在 Agent-SafetyBench (ASB) 数据集上评估本地 Llama Guard 模型

Llama Guard 模型运行在本地8030端口，只输出 "safe" 或 "unsafe"
"""

import sys
import os
import json
from pathlib import Path
from concurrent.futures import ThreadPoolExecutor, as_completed
from threading import Lock
from tqdm import tqdm
from openai import OpenAI


def _build_llama_guard_prompt(user_input, agent_response=""):
    """Render one exchange as the plain-text prompt sent to Llama Guard.

    The prompt always has two lines, ``User: <user_input>`` and
    ``Agent: <response>``. A missing, empty or whitespace-only
    *agent_response* is replaced by the placeholder ``No response yet``.
    """
    if not agent_response or not agent_response.strip():
        agent_response = "No response yet"
    return f"User: {user_input}\nAgent: {agent_response}"


def _parse_llama_guard_output(raw_output):
    """Return True (safe) or False (unsafe) for a Llama Guard reply string.

    Matching is a case-insensitive substring search. "unsafe" is tested first
    because "safe" is a substring of it. A reply containing neither word is
    logged and treated as safe.
    """
    verdict = raw_output.lower()
    if "unsafe" in verdict:
        return False
    if "safe" in verdict:
        return True
    # Fail open: an unexpected format is reported but not counted as a violation.
    print(f"⚠️  Llama Guard输出格式异常: {raw_output}")
    return True


def call_llama_guard(user_input, agent_response="", api_url="http://127.0.0.1:8030/v1", model_name="llama-guard", api_key="EMPTY"):
    """
    Ask a local Llama Guard server (OpenAI-compatible API) to judge one exchange.

    Args:
        user_input: the user's message.
        agent_response: the agent's reply (optional). Empty or whitespace-only
            values are sent as the placeholder "No response yet".
        api_url: base URL of the OpenAI-compatible endpoint.
        model_name: model name to request.
        api_key: API key (local servers usually ignore it; "EMPTY" by default).

    Returns:
        tuple: (is_safe, raw_output)
        - is_safe: bool, True for "safe", False for "unsafe".
        - raw_output: str, the stripped model reply. On any exception it is
          "Error: <message>" and is_safe is True (fail-open, so one failed
          request never aborts an evaluation run).
    """
    try:
        # NOTE(review): the whole exchange is sent as a single "user" message;
        # depending on the server's chat template, Llama Guard may classify it
        # as user content rather than as an agent response — confirm intended.
        prompt = _build_llama_guard_prompt(user_input, agent_response)

        # Use the client as a context manager so its HTTP connection pool is
        # closed right after this request instead of whenever GC runs.
        with OpenAI(base_url=api_url, api_key=api_key) as client:
            response = client.chat.completions.create(
                model=model_name,
                messages=[{"role": "user", "content": prompt}],
                temperature=0.0,
                max_tokens=20,  # enough for "safe" or "unsafe\n<category codes>"
            )

        # `content` may be None in the OpenAI response schema; treat it as empty
        # text (which is then reported as an unexpected format) rather than
        # crashing on `.strip()`.
        raw_output = (response.choices[0].message.content or "").strip()
        return _parse_llama_guard_output(raw_output), raw_output

    except Exception as e:
        print(f"⚠️ 调用Llama Guard失败: {e}")
        import traceback
        traceback.print_exc()
        # Fail open on errors so a single bad request does not block the run.
        return True, f"Error: {str(e)}"


def _resolve_data_path(data_path):
    """Return *data_path* as an absolute path.

    Relative paths are resolved against the directory containing this script
    (not the current working directory).
    """
    if os.path.isabs(data_path):
        return data_path
    script_dir = os.path.dirname(os.path.abspath(__file__))
    # script_dir is absolute, so the joined result is already absolute.
    return os.path.join(script_dir, data_path)


def _extract_agent_input(sample):
    """Return the user-side text of a dataset sample.

    If the sample has a ``dialog``, this is the content of its last ``"user"``
    message ("" if there is none). Otherwise it is ``instruction``, falling
    back to ``prompt`` / ``query`` coerced to ``str``.
    """
    if "dialog" in sample:
        for msg in reversed(sample["dialog"]):
            if msg.get("role") == "user":
                return msg.get("content", "")
        return ""
    if "instruction" in sample:
        return sample["instruction"]
    return str(sample.get("prompt", sample.get("query", "")))


def _extract_agent_output(sample):
    """Return the agent-side text of a dataset sample ("" when none exists).

    If the sample has a ``dialog``, the last ``"assistant"`` message is used.
    Its tool calls are summarized as ``Tool calls: [<function names>]``;
    otherwise its ``content`` is returned as stored (which may be None).
    Without a dialog, ``response`` or ``output`` is used.
    """
    if "dialog" in sample:
        for msg in reversed(sample["dialog"]):
            if msg.get("role") != "assistant":
                continue
            if "tool_calls" in msg:
                tool_names = [tc.get('function', {}).get('name', '') for tc in msg['tool_calls']]
                return f"Tool calls: {tool_names}"
            return msg["content"] if "content" in msg else ""
        return ""
    if "response" in sample:
        return sample["response"]
    if "output" in sample:
        return sample["output"]
    return ""


def _summarize_results(results):
    """Return record counts for the evaluation results.

    Keys:
    - ``total``: number of records.
    - ``denied``: records with ``action_denied`` true.
    - ``safe``: records judged safe that carry no ``error`` key.
    - ``errors``: records carrying an ``error`` key.
    """
    return {
        "total": len(results),
        "denied": sum(1 for r in results if r.get('action_denied', False)),
        "safe": sum(1 for r in results if 'error' not in r and r.get('is_safe', False)),
        "errors": sum(1 for r in results if 'error' in r),
    }


def evaluate_llamaguard_on_asb(
    api_url: str = "http://127.0.0.1:8030/v1",
    model_name: str = "llama-guard",
    api_key: str = "EMPTY",
    data_path: str = "Agent-SafetyBench/data/released_data.json",
    output_dir: str = "./llamaguard_results_asb",
    limit: int = None,
    max_workers: int = 10,
):
    """
    Evaluate Llama Guard on the Agent-SafetyBench dataset.

    Each sample's user input and agent output are sent to Llama Guard in
    parallel threads. The per-sample records, sorted by dataset index, are
    written to ``<output_dir>/llamaguard_asb_results.json`` and a summary is
    printed.

    Args:
        api_url: Llama Guard API URL.
        model_name: model name.
        api_key: API key.
        data_path: ASB dataset path (relative paths resolve against this
            script's directory).
        output_dir: output directory (created if missing).
        limit: if truthy, only the first ``limit`` samples are processed.
        max_workers: maximum number of worker threads.

    Returns:
        list[dict]: the per-sample result records. A record carries an
        ``error`` key when processing or the Llama Guard request failed.

    Raises:
        FileNotFoundError: if the dataset file does not exist.
    """
    os.makedirs(output_dir, exist_ok=True)

    # Load the dataset.
    data_path = _resolve_data_path(data_path)

    print(f"加载 Agent-SafetyBench 数据集: {data_path}")
    if not os.path.exists(data_path):
        raise FileNotFoundError(f"数据文件不存在: {data_path}")

    with open(data_path, 'r', encoding='utf-8') as f:
        dataset = json.load(f)

    if limit:
        dataset = dataset[:limit]
        print(f"限制处理前 {limit} 个样本")

    print(f"✅ 加载了 {len(dataset)} 个样本")
    print(f"📡 Llama Guard API: {api_url}")
    print(f"🤖 模型名称: {model_name}")

    def process_sample(sample_data):
        """Judge one ``(index, sample)`` pair; runs on a worker thread."""
        i, sample = sample_data

        # Extraction stays outside the try: its failures propagate to the
        # collection loop below, which records them as thread errors.
        agent_input = _extract_agent_input(sample)
        agent_output = _extract_agent_output(sample)

        try:
            is_safe, raw_output = call_llama_guard(
                user_input=agent_input,
                agent_response=agent_output if agent_output else "No response yet",
                api_url=api_url,
                model_name=model_name,
                api_key=api_key
            )

            result_item = {
                "id": sample.get('id', f"sample_{i}"),
                "index": i,
                "input": agent_input,
                "output": agent_output,
                "action_denied": not is_safe,  # unsafe -> action_denied=True
                "is_safe": is_safe,
                "llama_guard_output": raw_output,
                "violation": "" if is_safe else "Unsafe content detected by Llama Guard",
            }
            # call_llama_guard swallows request failures and returns
            # (True, "Error: ..."). Tag those records as errors so an
            # unreachable server is not silently reported as "all safe".
            if isinstance(raw_output, str) and raw_output.startswith("Error: "):
                result_item["error"] = raw_output

            return result_item

        except Exception as e:
            print(f"[样本 {i+1}] ❌ 处理时出错: {e}")
            import traceback
            traceback.print_exc()
            return {
                "id": sample.get('id', f"sample_{i}"),
                "index": i,
                "input": agent_input,
                "error": str(e),
            }

    # Results are appended only on this (the calling) thread while draining
    # as_completed, so no lock is needed.
    results = []
    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        futures = {executor.submit(process_sample, item): item for item in enumerate(dataset)}

        for future in tqdm(as_completed(futures), total=len(futures), desc="处理样本"):
            try:
                results.append(future.result())
            except Exception as e:
                i, sample = futures[future]
                print(f"[样本 {i+1}] ❌ 线程执行出错: {e}")
                results.append({
                    "id": sample.get('id', f"sample_{i}"),
                    "index": i,
                    "input": sample.get("instruction", ""),
                    "error": str(e),
                })

    # Completion order is arbitrary; restore dataset order.
    results.sort(key=lambda x: x.get('index', 0))

    output_file = os.path.join(output_dir, "llamaguard_asb_results.json")
    with open(output_file, 'w', encoding='utf-8') as f:
        json.dump(results, f, indent=2, ensure_ascii=False)

    stats = _summarize_results(results)
    print(f"💾 结果已保存到: {output_file}")
    print(f"📊 总样本数: {stats['total']}")
    print(f"   ✅ safe: {stats['safe']}")
    print(f"   🚫 unsafe (action_denied): {stats['denied']}")
    print(f"   ❌ 错误: {stats['errors']}")

    return results


if __name__ == "__main__":
    import argparse

    cli = argparse.ArgumentParser(description='在 Agent-SafetyBench 数据集上评估本地 Llama Guard')

    # (flag, type, default, help) for each command-line option, in help order.
    cli_options = (
        ("--api_url", str, "http://127.0.0.1:8030/v1",
         "Llama Guard API URL (默认: http://127.0.0.1:8030/v1)"),
        ("--model_name", str, "llama-guard",
         "模型名称 (默认: llama-guard)"),
        ("--api_key", str, "EMPTY",
         "API Key (默认: EMPTY)"),
        ("--data_path", str, "Agent-SafetyBench/data/released_data.json",
         "ASB 数据集路径"),
        ("--output_dir", str, "./llamaguard_results_asb",
         "输出目录"),
        ("--limit", int, None,
         "限制处理的样本数量（用于测试）"),
        ("--max_workers", int, 10,
         "最大线程数（默认: 10）"),
    )
    for flag, value_type, default_value, help_text in cli_options:
        cli.add_argument(flag, type=value_type, default=default_value, help=help_text)

    # Each option's dest matches a keyword parameter of
    # evaluate_llamaguard_on_asb, so the parsed namespace is forwarded as-is.
    evaluate_llamaguard_on_asb(**vars(cli.parse_args()))
