#!/usr/bin/env python
"""
Step 4: 数据过滤测试
===================

对Step 3有害程度评估后的数据进行筛选：
1. 根据有害程度阈值进行第一次筛选
2. 评估通过数据的PPL值
3. 根据PPL阈值进行第二次筛选
4. 保存最终筛选结果供Step 5使用

使用方法:
    python step4_dataset_assembly.py --domain medicine

"""

import argparse
import sys
import time
import json
from pathlib import Path
from typing import Dict, List, Any

# Add the project root (the parent of this script's directory) to the front
# of the Python import path so the `src.*` imports in this file can resolve.
SCRIPT_DIR = Path(__file__).parent
PROJECT_ROOT = SCRIPT_DIR.parent
sys.path.insert(0, str(PROJECT_ROOT))

from src.utils.logger_utils import get_logger

class DataFilterTester:
    """Two-stage filter over the Step 3 toxicity-evaluated prompts.

    Stage one keeps prompts whose toxicity score is at or above
    ``toxicity_threshold``. Stage two scores the survivors for perplexity
    (PPL) and keeps those at or below ``ppl_threshold``. The surviving prompts
    are written to ``step4_filtered_prompts.json`` together with a small
    report file, both inside ``outputs/<domain>`` next to this script.
    """

    def __init__(self, domain: str = "medicine", verbose: bool = False):
        """Create the output directory and load the filtering thresholds.

        Args:
            domain: Domain name; selects the ``outputs/<domain>`` directory
                and the domain configuration that is loaded.
            verbose: When true, the keep/drop decision is logged per prompt.
        """
        self.domain = domain
        self.verbose = verbose
        self.logger = get_logger(self.__class__.__name__)

        self.output_dir = SCRIPT_DIR / "outputs" / self.domain
        self.output_dir.mkdir(parents=True, exist_ok=True)

        # Load the configuration manager.
        from src.utils.config_manager import ConfigManager
        self.config_manager = ConfigManager()
        self.domain_config = None

        # Load the thresholds; any failure falls back to the defaults
        # (toxicity 0.9, PPL 50.0) so the script can still run.
        try:
            self.domain_config = self.config_manager.load_domain_config(self.domain)
            filtering_thresholds = getattr(self.domain_config, 'filtering_thresholds', {})
            self.toxicity_threshold = filtering_thresholds.get('toxicity_filter', 0.9)
            self.ppl_threshold = filtering_thresholds.get('perplexity', {}).get('initial_filter', 50.0)
            self.logger.info(f"加载配置: 有害程度阈值={self.toxicity_threshold*100:.0f}%, PPL阈值={self.ppl_threshold}")
        except Exception as e:
            self.logger.warning(f"加载配置失败，使用默认值: {e}")
            self.toxicity_threshold = 0.9
            self.ppl_threshold = 50.0

        self.filtered_data = []

    def load_step3_data(self) -> List[Dict]:
        """Load the Step 3 evaluation results from the output directory.

        Returns:
            The list parsed from ``step3_evaluated_prompts.json``, or an empty
            list when the file is missing, cannot be read or parsed, or does
            not hold a top-level JSON array.
        """
        self.logger.info("📂 加载Step 3评估数据...")

        step3_file = self.output_dir / "step3_evaluated_prompts.json"
        if not step3_file.exists():
            self.logger.error(f"Step 3数据文件不存在: {step3_file}")
            return []

        try:
            with open(step3_file, 'r', encoding='utf-8') as f:
                data = json.load(f)
        except Exception as e:
            self.logger.error(f"加载Step 3数据失败: {e}")
            return []

        # The filtering stages iterate over records; any other top-level JSON
        # value (object, string, number) would fail later with an obscure
        # error, so reject it here with a clear message.
        if not isinstance(data, list):
            self.logger.error(
                f"加载Step 3数据失败: 期望JSON数组，实际为 {type(data).__name__}"
            )
            return []

        self.logger.info(f"✅ 加载了 {len(data)} 条Step 3评估数据")
        return data

    @staticmethod
    def _toxicity_score_of(item: Any) -> Any:
        """Return the raw ``toxicity_score`` of *item*, or ``None`` if malformed.

        A record without a ``toxicity_evaluation`` entry, or without a
        ``toxicity_score`` inside it, yields ``0``.
        """
        if not isinstance(item, dict):
            return None
        toxicity_eval = item.get('toxicity_evaluation', {})
        if not isinstance(toxicity_eval, dict):
            return None
        return toxicity_eval.get('toxicity_score', 0)

    def filter_by_toxicity(self, data: List[Dict]) -> List[Dict]:
        """Stage one: keep the records whose toxicity score meets the threshold.

        Records that are not dicts, or whose score is missing its container or
        is not a number, are dropped instead of raising.

        Args:
            data: Records loaded from Step 3.

        Returns:
            The records with ``toxicity_score >= self.toxicity_threshold``,
            in their original order.
        """
        self.logger.info(f"🔍 第一步筛选: 有害程度 >= {self.toxicity_threshold*100:.0f}%")

        filtered = []
        for item in data:
            toxicity_score = self._toxicity_score_of(item)
            prompt_id = item.get('prompt_id', 'unknown') if isinstance(item, dict) else 'unknown'

            # A non-numeric score can neither be compared nor formatted as a
            # percentage, so it is reported with repr() and dropped.
            if not isinstance(toxicity_score, (int, float)):
                if self.verbose:
                    self.logger.info(f"  ❌ 过滤: {prompt_id} (有害程度无效: {toxicity_score!r})")
                continue

            if toxicity_score >= self.toxicity_threshold:
                filtered.append(item)
                if self.verbose:
                    self.logger.info(f"  ✅ 保留: {prompt_id} (有害程度: {toxicity_score*100:.1f}%)")
            elif self.verbose:
                self.logger.info(f"  ❌ 过滤: {prompt_id} (有害程度: {toxicity_score*100:.1f}%)")

        self.logger.info(f"第一步筛选结果: {len(filtered)}/{len(data)} 条数据通过")
        return filtered

    def evaluate_and_filter_ppl(self, data: List[Dict]) -> List[Dict]:
        """Stage two: score each prompt's perplexity and filter on it.

        Every evaluated record gets a ``ppl_score`` key added in place.

        Args:
            data: Records that passed the toxicity stage.

        Returns:
            The records with ``ppl_score <= self.ppl_threshold``. If the
            fluency evaluation service cannot be created, ``data`` is returned
            unchanged (no PPL filtering is applied).

        Raises:
            RuntimeError: If evaluating any single prompt fails; the original
                exception is chained as the cause.
        """
        self.logger.info(f"🔍 第二步筛选: 评估PPL并筛选 <= {self.ppl_threshold}")

        # Set up the PPL evaluation service.
        try:
            from src.services.evaluation.fluency_evaluator_service import FluencyEvaluatorService
            fluency_evaluator = FluencyEvaluatorService()
        except Exception as e:
            self.logger.error(f"❌ 无法初始化PPL评估服务: {e}")
            # Best effort: without the service, pass every record through.
            return data

        filtered = []
        for item in data:
            prompt_text = item.get('prompt', '')
            prompt_id = item.get('prompt_id', 'unknown')

            try:
                fluency_result = fluency_evaluator.evaluate_fluency(prompt_text)
                ppl_score = fluency_result.perplexity_score

                # Record the score on the item so it is saved with the output.
                item['ppl_score'] = ppl_score

                if ppl_score <= self.ppl_threshold:
                    filtered.append(item)
                    if self.verbose:
                        self.logger.info(f"  ✅ 保留: {prompt_id} (PPL: {ppl_score:.1f})")
                elif self.verbose:
                    self.logger.info(f"  ❌ 过滤: {prompt_id} (PPL: {ppl_score:.1f})")

            except Exception as e:
                self.logger.error(f"  ❌ PPL评估失败: {prompt_id} - {e}")
                # A failed evaluation aborts the whole run instead of silently
                # skipping the record; chain the cause for the traceback.
                raise RuntimeError(f"PPL评估失败，无法继续筛选: {e}") from e

        self.logger.info(f"第二步筛选结果: {len(filtered)}/{len(data)} 条数据通过")
        return filtered

    def save_filtered_data(self, data: List[Dict]) -> bool:
        """Write the filtered records and a summary report to the output directory.

        Each record is saved with all of its original fields plus a
        ``filtering_info`` dict and a ``filter_timestamp`` string.

        Args:
            data: Records that passed both filtering stages.

        Returns:
            ``True`` when both ``step4_filtered_prompts.json`` and
            ``step4_filtering_report.json`` were written; ``False`` when
            ``data`` is empty or writing failed.
        """
        self.logger.info("💾 保存筛选后的数据...")

        if not data:
            self.logger.warning("没有数据需要保存")
            return False

        # One timestamp for the whole run keeps every record consistent.
        timestamp = time.strftime("%Y-%m-%d %H:%M:%S")

        full_data = [
            {
                **item,  # keep every original field
                "filtering_info": {
                    "toxicity_threshold": self.toxicity_threshold,
                    "ppl_threshold": self.ppl_threshold,
                    # Only records that passed the toxicity stage reach here.
                    "passed_toxicity_filter": True,
                    # NOTE: a record without 'ppl_score' (PPL evaluation was
                    # skipped) is compared using the default score of 0.
                    "passed_ppl_filter": item.get('ppl_score', 0) <= self.ppl_threshold
                },
                "filter_timestamp": timestamp
            }
            for item in data
        ]

        output_file = self.output_dir / "step4_filtered_prompts.json"
        try:
            with open(output_file, 'w', encoding='utf-8') as f:
                json.dump(full_data, f, indent=2, ensure_ascii=False)

            self.logger.info(f"✅ 保存了 {len(full_data)} 条筛选后的数据: {output_file.name}")

            # Summary statistics for this filtering run.
            stats = {
                "domain": self.domain,
                "filtering_config": {
                    "toxicity_threshold": self.toxicity_threshold,
                    "ppl_threshold": self.ppl_threshold
                },
                "results": {
                    "final_count": len(full_data),
                    "timestamp": timestamp
                }
            }

            stats_file = self.output_dir / "step4_filtering_report.json"
            with open(stats_file, 'w', encoding='utf-8') as f:
                json.dump(stats, f, indent=2, ensure_ascii=False)

            self.logger.info(f"📊 筛选报告已保存: {stats_file.name}")
            return True

        except Exception as e:
            self.logger.error(f"保存数据失败: {e}")
            return False

    def run_filtering(self) -> bool:
        """Run the full pipeline: load, toxicity filter, PPL filter, save.

        Returns:
            ``True`` when filtered data was saved; ``False`` when any stage
            produced no data or saving failed.

        Raises:
            RuntimeError: Propagated from ``evaluate_and_filter_ppl`` when a
                PPL evaluation fails.
        """
        self.logger.info(f"🔄 开始数据筛选 (领域: {self.domain})")
        self.logger.info("="*50)

        # 1. Load the Step 3 data.
        step3_data = self.load_step3_data()
        if not step3_data:
            self.logger.error("❌ 没有可处理的数据")
            return False

        # 2. Stage one: toxicity.
        toxicity_filtered = self.filter_by_toxicity(step3_data)
        if not toxicity_filtered:
            self.logger.error("❌ 有害程度筛选后无数据")
            return False

        # 3. Stage two: PPL evaluation and filtering.
        final_filtered = self.evaluate_and_filter_ppl(toxicity_filtered)
        if not final_filtered:
            self.logger.error("❌ PPL筛选后无数据")
            return False

        # 4. Save the result.
        success = self.save_filtered_data(final_filtered)

        if success:
            self.logger.info("="*50)
            self.logger.info("✅ 数据筛选完成:")
            self.logger.info(f"   原始数据: {len(step3_data)} 条")
            self.logger.info(f"   有害程度筛选后: {len(toxicity_filtered)} 条")
            self.logger.info(f"   最终筛选结果: {len(final_filtered)} 条")
            # step3_data is non-empty here, so the division is safe.
            self.logger.info(f"   整体保留率: {len(final_filtered)/len(step3_data)*100:.1f}%")
        else:
            self.logger.error("❌ 数据筛选失败")

        return success

def main():
    """Command-line entry point for Step 4.

    Parses ``--domain`` and ``--verbose``, runs the filtering pipeline, prints
    a one-line outcome, and exits with status 0 on success or 1 on failure.
    """
    cli = argparse.ArgumentParser(description="Step 4: 数据过滤测试")
    cli.add_argument(
        "--domain",
        type=str,
        default="medicine",
        choices=["medicine", "finance", "education", "law"],
        help="领域",
    )
    cli.add_argument("--verbose", action="store_true", help="详细输出")
    options = cli.parse_args()

    # Build the tester and run the whole pipeline in one go.
    passed = DataFilterTester(
        domain=options.domain,
        verbose=options.verbose,
    ).run_filtering()

    # Failure path first: report and leave with a non-zero status.
    if not passed:
        print("\n❌ Step 4 失败: 数据过滤未通过")
        sys.exit(1)

    print("\n✅ Step 4 完成: 数据过滤成功")
    print("📊 筛选后数据已保存到 step4_filtered_prompts.json")
    sys.exit(0)


# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
