import json
import os
import argparse
from datasets import load_dataset
import logging

# --- Configuration ---
# Configure the root logger once at import time: INFO level and above, with
# each record rendered as "<timestamp> - <level name> - <message>".
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

def _append_prompt(combined_prompts, instruction, goal, source_dataset, target_model=None):
    """Append one record in the unified schema; its id is its index in the list."""
    combined_prompts.append({
        "id": len(combined_prompts),
        "instruction": instruction,
        "goal": goal,
        "source_dataset": source_dataset,
        "target_model": target_model,
    })


def _add_hf_attacker_prompts(combined_prompts, target_model, limit):
    """Add up to `limit` prompts per attack type from 'flydust/SafeDecoding-Attackers'."""
    try:
        logging.info("加载 'flydust/SafeDecoding-Attackers' 数据集...")
        # NOTE: trust_remote_code=True lets the dataset repository run its own
        # loading script on this machine; kept exactly as the original call.
        hf_attackers = load_dataset('flydust/SafeDecoding-Attackers', split="train", trust_remote_code=True)

        for source_name in ("GCG", "AutoDAN", "PAIR", "DeepInception"):
            logging.info(f"正在处理: {source_name}")
            # The filter runs immediately, so closing over the loop variable is safe.
            filtered_dataset = hf_attackers.filter(lambda x: x['source'] == source_name)
            count = 0
            for item in filtered_dataset:
                if count >= limit:
                    break
                item_target = item.get('target-model')
                # Rows without a target model are kept for every target_model.
                if target_model and item_target is not None and item_target != target_model:
                    continue
                _append_prompt(combined_prompts, item['prompt'], item['goal'], source_name, item_target)
                count += 1
            logging.info(f"添加了 {count} 条来自 {source_name} 的数据。")
    except Exception as e:
        # Best-effort boundary around the third-party loader: a failure here
        # must not prevent the remaining sources from being merged.
        logging.error(f"加载或处理 'flydust/SafeDecoding-Attackers' 时出错: {e}")


def _add_sap30_prompts(combined_prompts, base_path, limit):
    """Add up to `limit` prompts in total from `<base_path>/<category>/generated_cases.json`."""
    logging.info("正在处理: SAP30")
    logging.info(f"SAP30 数据集尝试从以下路径加载: {base_path}")
    added = 0

    if not os.path.exists(base_path):
        logging.error(f"错误：SAP30 基础路径不存在: {base_path}。")
    elif not os.path.isdir(base_path):
        logging.error(f"错误：SAP30 基础路径不是一个目录: {base_path}。")
    else:
        # os.listdir() order is arbitrary; sort so that the prompts which fall
        # under the cap are the same on every run and every machine.
        category_folders = sorted(
            name for name in os.listdir(base_path)
            if os.path.isdir(os.path.join(base_path, name))
        )
        for category_folder in category_folders:
            if added >= limit:
                break
            json_file_path = os.path.join(base_path, category_folder, 'generated_cases.json')
            if not os.path.exists(json_file_path):
                continue

            try:
                with open(json_file_path, 'r', encoding='utf-8') as file:
                    category_prompts = json.load(file)
            except (OSError, ValueError) as e:
                # ValueError covers both JSON and text-decoding errors.
                logging.error(f"处理 SAP30 文件出错: {e}")
                continue

            if not isinstance(category_prompts, list):
                continue

            for item in category_prompts:
                if added >= limit:
                    break
                if not isinstance(item, dict):
                    continue
                instruction = item.get("Attack Prompt")
                if instruction is None:
                    continue
                # A missing or falsy explanation is normalised to "".
                _append_prompt(combined_prompts, instruction, item.get("Explanation") or "", "SAP30")
                added += 1
    logging.info(f"最终添加了 {added} 条来自 SAP30 的数据。")


def _add_sql_prompts(combined_prompts, sql_file_path, limit):
    """Add up to `limit` prompts from the JSON list stored at `sql_file_path`."""
    logging.info(f"正在处理: SQL (来自 {sql_file_path})")
    if not os.path.exists(sql_file_path):
        logging.warning(f"SQL 数据文件未找到: {sql_file_path}")
        return

    try:
        with open(sql_file_path, 'r', encoding='utf-8') as file:
            sql_prompts = json.load(file)
    except (OSError, ValueError) as e:
        logging.error(f"读取或处理 SQL 文件时出错: {e}")
        return

    if not isinstance(sql_prompts, list):
        logging.error(f"读取或处理 SQL 文件时出错: 顶层 JSON 不是列表 ({type(sql_prompts).__name__})")
        return

    added = 0
    for item in sql_prompts:
        if added >= limit:
            break
        # Skip malformed entries instead of aborting the whole file, matching
        # how SAP30 entries are handled.
        if not isinstance(item, dict):
            continue
        _append_prompt(combined_prompts, item.get("prompt", ""), item.get("goal", ""), "SQL")
        added += 1
    logging.info(f"添加了 {added} 条来自 SQL 的数据。")


def _add_just_eval_prompts(combined_prompts, limit):
    """Add up to `limit` instructions from 're-align/just-eval-instruct'."""
    logging.info("正在处理: Just-Eval")
    try:
        # NOTE: trust_remote_code=True lets the dataset repository run its own
        # loading script on this machine; kept exactly as the original call.
        just_eval_dataset = load_dataset('re-align/just-eval-instruct', split="test", trust_remote_code=True)
        added = 0
        for item in just_eval_dataset:
            if added >= limit:
                break
            _append_prompt(combined_prompts, item.get("instruction", ""), "", "Just-Eval")
            added += 1
        logging.info(f"添加了 {added} 条来自 Just-Eval 的数据。")
    except Exception as e:
        # Best-effort boundary around the third-party loader.
        logging.error(f"加载 Just-Eval 数据集时出错: {e}")


def create_combined_dataset(target_model=None, sql_file_path=None,
                            output_file_path='../datasets/combined_attack_prompts.json',
                            *, max_per_source=20, max_just_eval=120,
                            sap30_base_path='../datasets/SAP30'):
    """Merge several prompt datasets into one JSON file with a unified schema.

    Sources are processed in a fixed order: the four attack types from the
    'flydust/SafeDecoding-Attackers' Hugging Face dataset (GCG, AutoDAN, PAIR,
    DeepInception), the local SAP30 folders, the optional SQL JSON file, and
    finally 're-align/just-eval-instruct'. A failure in one source is logged
    and does not stop the others. Every record has the keys ``id``,
    ``instruction``, ``goal``, ``source_dataset`` and ``target_model``; ids are
    consecutive integers starting at 0.

    Args:
        target_model (str or None): When truthy, Hugging Face attacker rows
            whose 'target-model' is set to a different value are dropped.
        sql_file_path (str or None): JSON file holding a list of objects with
            "prompt" and "goal" keys. The SQL source is skipped when falsy.
        output_file_path (str): Destination JSON file. Its parent directory is
            created when the path has one.
        max_per_source (int): Cap applied to each attack type, to SAP30 as a
            whole, and to SQL.
        max_just_eval (int): Cap applied to Just-Eval.
        sap30_base_path (str): Directory containing one sub-folder per SAP30
            category. Relative paths are resolved against the current working
            directory.

    Returns:
        list[dict]: The merged records, in the order they were written.

    Raises:
        OSError: If the output directory cannot be created.
    """
    combined_prompts = []

    logging.info("开始合并数据集...")
    if target_model:
        logging.info(f"将根据 target-model='{target_model}' 进行过滤。")
    if sql_file_path:
        logging.info(f"将从 '{sql_file_path}' 加载 SQL 数据。")

    _add_hf_attacker_prompts(combined_prompts, target_model, max_per_source)
    _add_sap30_prompts(combined_prompts, sap30_base_path, max_per_source)
    if sql_file_path:
        _add_sql_prompts(combined_prompts, sql_file_path, max_per_source)
    _add_just_eval_prompts(combined_prompts, max_just_eval)

    logging.info(f"总共合并了 {len(combined_prompts)} 条数据。")
    # dirname() is "" for a bare file name, and os.makedirs("") raises
    # FileNotFoundError, so only create the directory when there is one.
    output_dir = os.path.dirname(output_file_path)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)
    try:
        with open(output_file_path, 'w', encoding='utf-8') as f:
            json.dump(combined_prompts, f, indent=4, ensure_ascii=False)
        logging.info(f"合并的数据集已成功保存到: {output_file_path}")
    except (OSError, TypeError, ValueError) as e:
        logging.error(f"保存合并数据集时出错: {e}")

    return combined_prompts


if __name__ == '__main__':
    # Command-line entry point: every option is a string flag with a default,
    # so they are declared as (flag, default, help) rows and registered in a loop.
    # The default paths are relative (they replaced earlier absolute paths).
    cli_options = (
        ('--target_model', 'guanaco', "要过滤的目标模型名称，例如 'llama2' 或 'falcon'。"),
        ('--sql_file', '../datasets/SQL/guanaco-13b-merged_label1_SR_ATTACK_False.json', "SQL 数据集 JSON 路径。"),
        ('--output_file', '../datasets/combined_attack_prompts.json', "合并后数据保存路径。"),
    )

    cli = argparse.ArgumentParser(description="合并多个攻击数据集，并根据目标模型进行过滤。")
    for flag, default_value, help_text in cli_options:
        cli.add_argument(flag, type=str, default=default_value, help=help_text)

    options = cli.parse_args()

    create_combined_dataset(
        target_model=options.target_model,
        sql_file_path=options.sql_file,
        output_file_path=options.output_file,
    )
