#!/usr/bin/env python
"""
Step 6: 数据集最终整理
====================

整理所有前面步骤的数据，生成适合上传到Hugging Face的最终数据集格式。

数据集格式:
1. 大分类是根据领域 (domain)
2. id (从1开始)
3. 节点名 (entity)
4. 节点的wikidata uri (wikidata_uri)
5. 原始prompt (original_prompt)
6. 隐蔽化后prompt (stealth_prompt)
7. 是否隐蔽化成功 (stealth_success)
8. 类别 (category)

使用方法:
    python step6_dataset_finalization.py --domain medicine
    python step6_dataset_finalization.py --domain medicine --output-format csv

"""

import argparse
import sys
import time
import json
import csv
from pathlib import Path
from typing import Dict, List, Any, Optional

# Make the project root importable so that `src.*` packages resolve when this
# file is run directly as a script. SCRIPT_DIR is this script's directory and
# PROJECT_ROOT is its parent.
SCRIPT_DIR = Path(__file__).parent
PROJECT_ROOT = SCRIPT_DIR.parent
# Inserted at index 0 so the project root is searched before other sys.path entries.
sys.path.insert(0, str(PROJECT_ROOT))

from src.utils.logger_utils import get_logger

class DatasetFinalizer:
    """Assemble Step 5 stealth-rewriting output into the final release dataset.

    Every Step 5 record is kept and classified into one of three situations:

    1. ``success``            -- stealth rewriting succeeded;
    2. ``failed_with_prompt`` -- rewriting failed but a stealth prompt exists;
    3. ``failed_no_prompt``   -- no usable stealth prompt was produced.

    Formatted records are written as JSON and/or CSV under
    ``SCRIPT_DIR/outputs/<domain>`` together with a JSON statistics report.
    """

    # Canonical prefix for every emitted Wikidata URI.
    WIKIDATA_ENTITY_PREFIX = "http://www.wikidata.org/entity/"

    # Number of sample records logged by process_dataset() in verbose mode.
    VERBOSE_SAMPLE_COUNT = 5

    # Maps each stealth situation (see _classify_stealth) to its counter in self.stats.
    _STATUS_STAT_KEYS = {
        "success": "successful_stealth",
        "failed_with_prompt": "failed_with_stealth",
        "failed_no_prompt": "failed_no_stealth",
    }

    def __init__(self, domain: str = "medicine", output_format: str = "json", verbose: bool = False):
        """Create a finalizer for one domain.

        Args:
            domain: Domain name; selects ``SCRIPT_DIR/outputs/<domain>`` as the
                working directory, which is created if missing.
            output_format: ``"json"`` or ``"csv"`` (case-insensitive); any other
                value makes save_dataset() write both formats.
            verbose: Log a few sample records while processing.
        """
        self.domain = domain
        self.output_format = output_format.lower()
        self.verbose = verbose
        self.logger = get_logger(self.__class__.__name__)

        self.output_dir = SCRIPT_DIR / "outputs" / self.domain
        self.output_dir.mkdir(parents=True, exist_ok=True)

        # Aggregate statistics; reset and filled by process_dataset().
        self.stats = self._new_stats()

    # ------------------------------------------------------------------
    # Shared helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _new_stats() -> Dict[str, Any]:
        """Return a fresh, zeroed statistics dictionary."""
        return {
            "total_processed": 0,
            "successful_stealth": 0,    # situation 1: stealth rewriting succeeded
            "failed_with_stealth": 0,   # situation 2: failed, but a stealth prompt exists
            "failed_no_stealth": 0,     # situation 3: no stealth prompt at all
            "domains": {},
            "categories": {},
        }

    @staticmethod
    def _classify_stealth(record: Dict) -> str:
        """Classify a raw or formatted record into one of the three stealth situations.

        Success is read from either ``success`` or ``stealth_success``. A
        stealth prompt only counts when it is a string with non-whitespace
        content; any other value is treated as "no prompt".

        Returns:
            ``"success"``, ``"failed_with_prompt"`` or ``"failed_no_prompt"``.
        """
        if record.get("success", False) or record.get("stealth_success", False):
            return "success"
        prompt = record.get("stealth_prompt")
        if isinstance(prompt, str) and prompt.strip():
            return "failed_with_prompt"
        return "failed_no_prompt"

    @staticmethod
    def _metadata_of(record: Dict) -> Dict:
        """Return ``record["metadata"]`` if it is a dict, else an empty dict."""
        metadata = record.get("metadata")
        return metadata if isinstance(metadata, dict) else {}

    @staticmethod
    def _lookup(record: Dict, metadata: Dict, key: str, default: str = "unknown") -> Any:
        """Read ``record[key]``, falling back to ``metadata[key]`` and then *default*.

        A top-level value of ``None`` or *default* counts as missing, so the
        metadata fallback still applies.
        """
        value = record.get(key)
        if value is None or value == default:
            value = metadata.get(key)
        return default if value is None else value

    @staticmethod
    def _first_float(candidates: List[Any]) -> Optional[float]:
        """Return the first ``mapping[key]`` in *candidates* that converts to float.

        Args:
            candidates: ``(mapping, key)`` pairs tried in order. Non-dict
                mappings, missing keys, ``None`` and unconvertible values are
                skipped instead of ending the search.

        Returns:
            The converted value, or ``None`` if no candidate is usable.
        """
        for mapping, key in candidates:
            if not isinstance(mapping, dict):
                continue
            value = mapping.get(key)
            if value is None:
                continue
            try:
                return float(value)
            except (TypeError, ValueError):
                continue
        return None

    # ------------------------------------------------------------------
    # Loading
    # ------------------------------------------------------------------

    def load_step5_data(self) -> List[Dict]:
        """Load Step 5 output, including failed stealth attempts.

        Candidate files in ``self.output_dir`` are tried in priority order:
        complete datasets first, success-only files last. The first file that
        exists and holds a JSON list is used. Non-object entries in that list
        are dropped with a warning.

        Returns:
            The loaded record dicts, or ``[]`` if no candidate file is usable.
        """
        self.logger.info("📂 加载Step5隐蔽化数据（包括失败案例）...")

        file_priority = [
            # Complete datasets (successful and failed records).
            f"step5_parallel_dataset_{self.domain}.json",
            f"step5_stealth_dataset_{self.domain}.json",
            f"step5_attack_results_{self.domain}.json",
            f"step5_parallel_results_{self.domain}.json",

            # Fallback: files that only contain successful records.
            f"step5_successful_stealth_dataset_{self.domain}.json",
            f"step5_parallel_successful_{self.domain}.json",
        ]

        for filename in file_priority:
            file_path = self.output_dir / filename
            if not file_path.exists():
                continue

            try:
                with open(file_path, 'r', encoding='utf-8') as f:
                    data = json.load(f)
            except (OSError, ValueError) as e:
                # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
                self.logger.warning(f"  ⚠️  加载 {filename} 失败: {e}")
                continue

            if not isinstance(data, list):
                self.logger.warning(f"  ⚠️  加载 {filename} 失败: 顶层JSON应为列表，实际为 {type(data).__name__}")
                continue

            records = [item for item in data if isinstance(item, dict)]
            if len(records) != len(data):
                self.logger.warning(f"  ⚠️  {filename} 中有 {len(data) - len(records)} 条非对象记录被跳过")

            counts = {status: 0 for status in self._STATUS_STAT_KEYS}
            for record in records:
                counts[self._classify_stealth(record)] += 1

            self.logger.info(f"  ✅ 从 {filename} 加载了 {len(records)} 条记录")
            self.logger.info(f"     🎯 情况1-隐蔽化成功: {counts['success']} 条")
            self.logger.info(f"     📝 情况2-隐蔽化失败但有prompt: {counts['failed_with_prompt']} 条")
            self.logger.info(f"     ❌ 情况3-完全没有隐蔽化prompt: {counts['failed_no_prompt']} 条")

            # Warn when a success-only file was picked: failed cases are then missing.
            if "successful" in filename and counts["success"] == len(records):
                self.logger.warning("  ⚠️  当前文件只包含成功案例，可能缺少失败数据")
                self.logger.warning("  💡 建议检查是否有包含完整数据的文件")

            return records

        self.logger.error("  ❌ 没有找到任何Step5输出文件")
        self.logger.error("  💡 请先运行Step5生成隐蔽化数据")
        self.logger.error("  📋 期望的文件名:")
        for filename in file_priority:
            self.logger.error(f"     - {filename}")
        return []

    # ------------------------------------------------------------------
    # Formatting
    # ------------------------------------------------------------------

    def format_dataset_record(self, record: Dict, record_id: int) -> Dict:
        """Convert one Step 5 record into the final dataset schema.

        entity, wikidata_uri, category and the original prompt fall back to
        ``record["metadata"]`` (when it is a dict) if missing at top level.

        Args:
            record: Raw Step 5 record.
            record_id: Id to assign; process_dataset() numbers records from 1.

        Returns:
            Dict with keys id, domain, entity, wikidata_uri, original_prompt,
            stealth_prompt, stealth_success, category,
            toxicity_score_original, similarity_score_stealth,
            ppl_score_stealth, winning_path, iterations and timestamp.
            ``stealth_prompt`` is ``None`` only when no usable prompt exists.
        """
        metadata = self._metadata_of(record)

        entity = self._lookup(record, metadata, "entity")
        category = self._lookup(record, metadata, "category")
        wikidata_uri = self._format_wikidata_uri(
            self._lookup(record, metadata, "wikidata_uri"),
            record.get('node_id', 'unknown'),
        )

        original_prompt = record.get("original_prompt")
        if not original_prompt:
            original_prompt = metadata.get("prompt") or metadata.get("original_prompt") or ""

        raw_stealth_prompt = record.get("stealth_prompt")
        status = self._classify_stealth(record)
        if status == "success":
            # Situation 1: keep the prompt; use "" if none was recorded.
            stealth_prompt = raw_stealth_prompt if raw_stealth_prompt else ""
        elif status == "failed_with_prompt":
            # Situation 2: keep the prompt, but the record is marked as failed.
            stealth_prompt = raw_stealth_prompt
        else:
            # Situation 3: no usable stealth prompt.
            stealth_prompt = None

        return {
            "id": record_id,
            "domain": self.domain,
            "entity": entity,
            "wikidata_uri": wikidata_uri,
            "original_prompt": original_prompt,
            "stealth_prompt": stealth_prompt,  # null for situation 3, string otherwise
            "stealth_success": status == "success",
            "category": category,

            # Property of the original prompt.
            "toxicity_score_original": self._extract_toxicity_score(record),

            # Stealth-prompt scores are kept whenever a stealth prompt exists,
            # regardless of success.
            "similarity_score_stealth": self._extract_similarity_score(record, stealth_prompt),
            "ppl_score_stealth": self._extract_ppl_score(record, stealth_prompt),

            # Stealth-process information.
            "winning_path": record.get("winning_path", "unknown"),
            "iterations": self._extract_iterations(record),
            "timestamp": record.get("timestamp", time.strftime("%Y-%m-%d %H:%M:%S")),
        }

    def _format_wikidata_uri(self, wikidata_uri: str, node_id: str) -> str:
        """Normalise a Wikidata reference to ``http://www.wikidata.org/entity/Q<n>``.

        *wikidata_uri* is tried first, then *node_id*. Each may be an entity or
        wiki URL (http or https), a ``Q<n>`` id or a bare number.
        NOTE(review): the node_id fallback assumes node ids are Wikidata
        numbers, as the original code did; confirm against Step 5.

        Returns:
            The normalised URI, or ``"unknown"`` when neither value holds a
            valid id (instead of fabricating something like ``.../Qunknown``).
        """
        for candidate in (wikidata_uri, node_id):
            qid = self._extract_qid(candidate)
            if qid is not None:
                return f"{self.WIKIDATA_ENTITY_PREFIX}{qid}"
        return "unknown"

    @staticmethod
    def _extract_qid(value: Any) -> Optional[str]:
        """Return ``"Q<digits>"`` parsed from *value*, or None if it is not a Wikidata id.

        Accepts e.g. ``Q42``, ``42``, ``https://www.wikidata.org/wiki/Q42``
        and ``http://www.wikidata.org/entity/Q42``.
        """
        if value is None:
            return None
        text = str(value).strip()
        if "wikidata.org/" in text:
            # Keep only the last path segment of a Wikidata URL.
            text = text.rstrip("/").rsplit("/", 1)[-1]
        if text[:1] in ("Q", "q"):
            text = text[1:]
        if text and all("0" <= ch <= "9" for ch in text):
            return "Q" + text
        return None

    def _extract_toxicity_score(self, record: Dict) -> float:
        """Return the original prompt's toxicity score, or 0.0 if none is found.

        Lookup order: record.toxicity_evaluation.toxicity_score,
        metadata.toxicity_evaluation.toxicity_score, record.toxicity_score,
        metadata.toxicity_score. Invalid values are skipped.
        """
        metadata = self._metadata_of(record)
        score = self._first_float([
            (record.get("toxicity_evaluation"), "toxicity_score"),
            (metadata.get("toxicity_evaluation"), "toxicity_score"),
            (record, "toxicity_score"),
            (metadata, "toxicity_score"),
        ])
        return 0.0 if score is None else score

    def _extract_iterations(self, record: Dict) -> int:
        """Return the number of stealth iterations recorded for *record*.

        The first positive integer among record.iterations,
        record.total_iterations, metadata.iterations and
        metadata.total_iterations wins. Without one, a successful record or a
        known winning_path implies at least one iteration; otherwise 0.
        """
        metadata = self._metadata_of(record)
        for source, key in ((record, "iterations"), (record, "total_iterations"),
                            (metadata, "iterations"), (metadata, "total_iterations")):
            value = source.get(key)
            if value is None:
                continue
            try:
                count = int(value)
            except (TypeError, ValueError):
                continue
            if count > 0:
                return count

        if self._classify_stealth(record) == "success":
            return 1
        if record.get("winning_path") not in (None, "", "unknown"):
            return 1
        return 0

    def _extract_stealth_metric(self, record: Dict, stealth_prompt: Optional[str], name: str) -> float:
        """Return a stealth-prompt metric, or 0.0 without a stealth prompt or value.

        Lookup order: record["final_<name>"], record[name], then the same keys
        in metadata. Invalid values are skipped.
        """
        if not stealth_prompt:
            return 0.0
        metadata = self._metadata_of(record)
        final_key = f"final_{name}"
        score = self._first_float([
            (record, final_key), (record, name),
            (metadata, final_key), (metadata, name),
        ])
        return 0.0 if score is None else score

    def _extract_similarity_score(self, record: Dict, stealth_prompt: Optional[str]) -> float:
        """Return the stealth prompt's similarity score (0.0 if unavailable)."""
        return self._extract_stealth_metric(record, stealth_prompt, "similarity")

    def _extract_ppl_score(self, record: Dict, stealth_prompt: Optional[str]) -> float:
        """Return the stealth prompt's perplexity score (0.0 if unavailable)."""
        return self._extract_stealth_metric(record, stealth_prompt, "ppl")

    # ------------------------------------------------------------------
    # Processing and statistics
    # ------------------------------------------------------------------

    def process_dataset(self) -> List[Dict]:
        """Load Step 5 data and format every record; statistics are recomputed.

        Records that fail to format are logged and skipped. Ids stay contiguous
        (1, 2, 3, ...) over the records that were kept.

        Returns:
            The formatted records, or ``[]`` if nothing could be loaded.
        """
        self.logger.info("🔄 处理数据集，生成最终格式...")

        # Start from clean counters so repeated calls do not double-count.
        self.stats = self._new_stats()

        raw_data = self.load_step5_data()
        if not raw_data:
            return []

        final_dataset: List[Dict] = []
        for record in raw_data:
            record_id = len(final_dataset) + 1
            try:
                formatted_record = self.format_dataset_record(record, record_id)
            except Exception as e:  # best effort: one bad record must not abort the run
                self.logger.warning(f"  ⚠️  处理记录失败: {e}")
                continue

            final_dataset.append(formatted_record)
            self._update_stats(formatted_record)

            if self.verbose and record_id <= self.VERBOSE_SAMPLE_COUNT:
                self.logger.info(f"  样例 {record_id}: {formatted_record['entity']} -> {formatted_record['category']}")

        self.logger.info(f"  ✅ 成功处理 {len(final_dataset)} 条记录")
        return final_dataset

    def _update_stats(self, record: Dict):
        """Add one formatted record to ``self.stats``.

        Domains and categories are counted under their ``str()`` form, so
        non-string categories cannot break the dict keys or the sorted summary.
        """
        self.stats["total_processed"] += 1
        self.stats[self._STATUS_STAT_KEYS[self._classify_stealth(record)]] += 1

        for stat_key, value in (("domains", record["domain"]), ("categories", record["category"])):
            bucket = self.stats[stat_key]
            label = str(value)
            bucket[label] = bucket.get(label, 0) + 1

    def _summary_rates(self) -> Dict[str, float]:
        """Compute success and attempt rates from ``self.stats``; 0.0 when a denominator is 0."""
        total = self.stats["total_processed"]
        success = self.stats["successful_stealth"]
        attempted = success + self.stats["failed_with_stealth"]
        return {
            "stealth_success_rate": success / total if total else 0.0,
            "stealth_attempt_rate": attempted / total if total else 0.0,
            "conditional_success_rate": success / attempted if attempted else 0.0,
        }

    # ------------------------------------------------------------------
    # Saving and reporting
    # ------------------------------------------------------------------

    @staticmethod
    def _write_json(dataset: List[Dict], path: Path) -> None:
        """Write *dataset* as pretty-printed UTF-8 JSON (non-ASCII kept as-is)."""
        with open(path, 'w', encoding='utf-8') as f:
            json.dump(dataset, f, indent=2, ensure_ascii=False)

    @staticmethod
    def _write_csv(dataset: List[Dict], path: Path) -> None:
        """Write *dataset* as UTF-8 CSV; columns are the union of keys in first-seen order."""
        fieldnames = list(dict.fromkeys(key for row in dataset for key in row))
        with open(path, 'w', newline='', encoding='utf-8') as f:
            writer = csv.DictWriter(f, fieldnames=fieldnames)
            writer.writeheader()
            writer.writerows(dataset)

    def save_dataset(self, dataset: List[Dict]) -> bool:
        """Write *dataset* in the configured format(s), then the statistics report.

        ``output_format`` "json" or "csv" writes that single format. Any other
        value (the CLI passes "both") writes both files with the same timestamp.

        Returns:
            True on success; False if *dataset* is empty or writing failed.
        """
        if not dataset:
            self.logger.error("❌ 没有数据可保存")
            return False

        try:
            timestamp = time.strftime("%Y%m%d_%H%M%S")
            json_filename = f"final_huggingface_dataset_{self.domain}_{timestamp}.json"
            csv_filename = f"final_huggingface_dataset_{self.domain}_{timestamp}.csv"

            if self.output_format == "json":
                self._write_json(dataset, self.output_dir / json_filename)
                self.logger.info(f"📄 JSON数据集已保存: {json_filename}")
            elif self.output_format == "csv":
                self._write_csv(dataset, self.output_dir / csv_filename)
                self.logger.info(f"📊 CSV数据集已保存: {csv_filename}")
            else:
                self._write_json(dataset, self.output_dir / json_filename)
                self._write_csv(dataset, self.output_dir / csv_filename)
                self.logger.info(f"📄 数据集已保存: {json_filename}, {csv_filename}")

            self._save_stats_report(timestamp)
            return True

        except (OSError, TypeError, ValueError, csv.Error) as e:
            self.logger.error(f"❌ 保存数据集失败: {e}")
            return False

    def _save_stats_report(self, timestamp: str):
        """Write ``step6_finalization_report_<domain>_<timestamp>.json`` (best effort).

        Failures are logged as warnings and never propagate.
        """
        try:
            rates = self._summary_rates()
            report = {
                "timestamp": time.strftime("%Y-%m-%d %H:%M:%S"),
                "domain": self.domain,
                "statistics": self.stats,
                "summary": {
                    "total_records": self.stats["total_processed"],
                    "stealth_success_rate": rates["stealth_success_rate"],
                    "stealth_attempt_rate": rates["stealth_attempt_rate"],
                    "conditional_success_rate": rates["conditional_success_rate"],
                    "unique_categories": len(self.stats["categories"]),
                    "situation_breakdown": {
                        "success": self.stats["successful_stealth"],
                        "failed_with_prompt": self.stats["failed_with_stealth"],
                        "failed_no_prompt": self.stats["failed_no_stealth"],
                    },
                },
            }

            report_filename = f"step6_finalization_report_{self.domain}_{timestamp}.json"
            self._write_json(report, self.output_dir / report_filename)
            self.logger.info(f"📊 统计报告已保存: {report_filename}")

        except Exception as e:  # best effort: the dataset itself is already saved
            self.logger.warning(f"⚠️  保存统计报告失败: {e}")

    def print_summary(self):
        """Log the situation counts, success rates and per-category distribution."""
        total = self.stats['total_processed']

        self.logger.info("\n" + "="*60)
        self.logger.info(f"📊 Step6 数据集整理完成 - {self.domain.upper()}领域")
        self.logger.info("="*60)

        self.logger.info("📋 数据集统计（三种情况）:")
        self.logger.info(f"   总记录数: {total}")
        self.logger.info(f"   🎯 情况1-隐蔽化成功: {self.stats['successful_stealth']}")
        self.logger.info(f"   📝 情况2-隐蔽化失败但有prompt: {self.stats['failed_with_stealth']}")
        self.logger.info(f"   ❌ 情况3-完全没有隐蔽化prompt: {self.stats['failed_no_stealth']}")

        if total > 0:
            rates = self._summary_rates()
            self.logger.info("")
            self.logger.info("📊 成功率统计:")
            self.logger.info(f"   隐蔽化成功率: {rates['stealth_success_rate']:.1%}")
            self.logger.info(f"   隐蔽化尝试率: {rates['stealth_attempt_rate']:.1%} (至少生成了隐蔽化prompt)")

            # Success rate among records that produced a stealth prompt at all.
            if self.stats['successful_stealth'] + self.stats['failed_with_stealth'] > 0:
                self.logger.info(f"   在有隐蔽化尝试的记录中成功率: {rates['conditional_success_rate']:.1%}")

        self.logger.info("\n📈 按类别分布:")
        for category, count in sorted(self.stats['categories'].items()):
            percentage = (count / total * 100) if total > 0 else 0
            self.logger.info(f"   {category}: {count} ({percentage:.1f}%)")

        self.logger.info("\n💡 数据集说明:")
        self.logger.info("   ✅ 包含所有原始prompts（保证数据完整性）")
        self.logger.info("   🎯 情况1: stealth_success=true, stealth_prompt有值 - 隐蔽化成功")
        self.logger.info("   📝 情况2: stealth_success=false, stealth_prompt有值 - 隐蔽化失败但有prompt")
        self.logger.info("   ❌ 情况3: stealth_success=false, stealth_prompt=null - 完全没有隐蔽化prompt")
        self.logger.info("   🔍 所有记录都保留original_prompt用于研究")

    def run_finalization(self) -> bool:
        """Run the full Step 6 pipeline: process, save, then log a summary.

        Returns:
            True if a non-empty dataset was produced and saved, else False.
        """
        self.logger.info(f"\n🎯 开始Step6数据集最终整理 (领域: {self.domain}, 格式: {self.output_format.upper()})")
        self.logger.info("="*60)

        try:
            final_dataset = self.process_dataset()
            if not final_dataset:
                self.logger.error("❌ 没有数据可处理")
                return False

            if not self.save_dataset(final_dataset):
                self.logger.error("❌ 保存数据集失败")
                return False

            self.print_summary()

            self.logger.info("\n🎉 Step6数据集整理完成！")
            self.logger.info("💡 数据集现在可以上传到Hugging Face了")
            return True

        except Exception as e:  # top-level boundary: report and signal failure to the CLI
            self.logger.error(f"❌ Step6整理失败: {e}")
            return False

def main():
    """Command-line entry point: parse options, run Step 6, exit 0 on success or 1 on failure."""
    cli = argparse.ArgumentParser(description="Step 6: 数据集最终整理")

    # (flag, add_argument keyword arguments), registered in this order.
    option_specs = [
        ('--domain', {'type': str, 'default': 'medicine', 'help': '领域名称'}),
        ('--output-format', {'type': str, 'default': 'both',
                             'choices': ['json', 'csv', 'both'], 'help': '输出格式'}),
        ('--verbose', {'action': 'store_true', 'help': '详细输出模式'}),
    ]
    for flag, spec in option_specs:
        cli.add_argument(flag, **spec)

    options = cli.parse_args()

    runner = DatasetFinalizer(
        domain=options.domain,
        output_format=options.output_format,
        verbose=options.verbose,
    )

    # run_finalization() reports success as a bool; map it onto a process exit status.
    finished_ok = runner.run_finalization()
    exit_status = 0 if finished_ok else 1
    sys.exit(exit_status)

if __name__ == "__main__":
    # Run the CLI only when this file is executed as a script, not when imported.
    main()