#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
配置管理器
统一管理pipeline的所有配置参数
"""

import copy
import logging
import os
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple

import yaml

logger = logging.getLogger(__name__)

@dataclass
class DatabaseConfig:
    """Neo4j connection settings.

    The password is excluded from ``repr()`` so that printing or logging a
    DatabaseConfig does not leak the credential. It still participates in
    equality comparison.
    """
    neo4j_uri: str
    neo4j_user: str
    # repr=False: keep the secret out of print()/logging output
    neo4j_password: str = field(repr=False)

@dataclass
class ModelConfig:
    """Settings for one model endpoint, built from an entry of the main config's ``models`` section."""
    # Model identifier
    model_name: str
    # URL of the server hosting the model
    server_url: str
    # Generation length limit
    max_tokens: int
    # Sampling parameters
    temperature: float
    top_p: float

@dataclass
class DomainConfig:
    """Settings for a single domain, loaded from ``domains/<name>.yaml``."""
    # Domain identity
    name: str
    description: str
    # Test-parameter settings
    test_params: Dict[str, Any]
    wikidata_seeds: Dict[str, Any]
    retrieval_params: Dict[str, Any]
    # Values typed as Any to accommodate a newer nested structure
    harm_categories: Dict[str, Any]
    # Data-filtering thresholds
    filtering_thresholds: Dict[str, Any]
    # Iterative-attack settings
    attack_config: Dict[str, Any]
    output_paths: Dict[str, str]

@dataclass
class PipelineConfig:
    """Main-config view of the whole pipeline (domain settings are kept separately)."""
    # Contents of the main config's `global` section
    global_config: Dict[str, Any]
    database: DatabaseConfig
    # Keyed by model type
    models: Dict[str, ModelConfig]
    # Per-stage sections, passed through as raw dicts
    stage1: Dict[str, Any]
    stage2: Dict[str, Any]
    output: Dict[str, Any]
    performance: Dict[str, Any]
    experiment_tracking: Dict[str, Any]

class ConfigManager:
    """Load and expose the pipeline configuration.

    The main config is read once, at construction time, from
    ``<config_dir>/pipeline_config.yaml`` and cached in ``self.main_config``.
    Domain configs are read on demand from ``<config_dir>/domains/<name>.yaml``.
    """

    # Top-level sections of the main config that get_complete_config() reads;
    # validate_config() requires all of them.
    REQUIRED_MAIN_KEYS = (
        'global', 'database', 'models', 'stage1', 'stage2',
        'output', 'performance', 'experiment_tracking',
    )

    # Entries of the main config's `models` section that get_complete_config()
    # turns into ModelConfig objects (dict order follows this tuple).
    MODEL_TYPES = ('harmful_prompt_generator', 'toxicity_evaluator', 'implicit')

    def __init__(self, config_dir: Optional[str] = None):
        """Create the manager and load the main config.

        Args:
            config_dir: Directory holding the config files (str or path-like).
                Defaults to ``configs`` under the directory two levels above
                the one containing this file.

        Raises:
            FileNotFoundError: If ``pipeline_config.yaml`` does not exist.
            yaml.YAMLError: If the main config cannot be parsed.
        """
        if config_dir is None:
            config_dir = Path(__file__).parent.parent.parent / "configs"

        self.config_dir = Path(config_dir)
        self.main_config_path = self.config_dir / "pipeline_config.yaml"
        self.domains_config_dir = self.config_dir / "domains"

        # Make sure the config directories exist
        self.config_dir.mkdir(parents=True, exist_ok=True)
        self.domains_config_dir.mkdir(parents=True, exist_ok=True)

        self._load_main_config()

    @staticmethod
    def _read_yaml(path: Path) -> Any:
        """Parse the YAML file at ``path``; an empty file yields ``{}``.

        ``yaml.safe_load`` returns ``None`` for an empty document, which would
        otherwise only fail later with an opaque "'NoneType' object is not
        subscriptable" error.

        Raises:
            FileNotFoundError: If ``path`` does not exist.
            yaml.YAMLError: If the file is not valid YAML.
        """
        with open(path, 'r', encoding='utf-8') as f:
            data = yaml.safe_load(f)
        return {} if data is None else data

    def _load_main_config(self) -> None:
        """Read the main config file into ``self.main_config``.

        Raises:
            FileNotFoundError: If the main config file does not exist.
            yaml.YAMLError: If the main config file cannot be parsed.
        """
        try:
            self.main_config: Dict[str, Any] = self._read_yaml(self.main_config_path)
        except FileNotFoundError:
            logger.error(f"主配置文件不存在: {self.main_config_path}")
            raise
        except yaml.YAMLError as e:
            logger.error(f"解析主配置文件失败: {e}")
            raise
        logger.info(f"成功加载主配置文件: {self.main_config_path}")

    def load_domain_config(self, domain_name: str) -> DomainConfig:
        """Load the config for one domain.

        Args:
            domain_name: Domain name, i.e. the file stem under ``domains/``
                (e.g. medicine, finance, education).

        Returns:
            The parsed DomainConfig. ``test_params``, ``filtering_thresholds``
            and ``attack_config`` default to ``{}`` when absent.

        Raises:
            FileNotFoundError: If ``domains/<domain_name>.yaml`` does not exist.
            yaml.YAMLError: If the file cannot be parsed.
            KeyError: If a required key is missing (``domain.name``,
                ``domain.description``, ``wikidata_seeds``,
                ``retrieval_params``, ``harm_categories``, ``output_paths``).
        """
        domain_config_path = self.domains_config_dir / f"{domain_name}.yaml"

        try:
            domain_data = self._read_yaml(domain_config_path)
        except FileNotFoundError:
            logger.error(f"领域配置文件不存在: {domain_config_path}")
            raise
        except yaml.YAMLError as e:
            logger.error(f"解析领域配置文件失败: {e}")
            raise
        logger.info(f"成功加载领域配置文件: {domain_config_path}")

        try:
            return DomainConfig(
                name=domain_data['domain']['name'],
                description=domain_data['domain']['description'],
                test_params=domain_data.get('test_params', {}),
                wikidata_seeds=domain_data['wikidata_seeds'],
                retrieval_params=domain_data['retrieval_params'],
                harm_categories=domain_data['harm_categories'],
                filtering_thresholds=domain_data.get('filtering_thresholds', {}),
                attack_config=domain_data.get('attack_config', {}),
                output_paths=domain_data['output_paths'],
            )
        except KeyError as e:
            # Log which key is missing, then re-raise so the caller's handling is unchanged
            logger.error(f"领域配置缺少必需键: {e}")
            raise

    def get_database_config(self) -> DatabaseConfig:
        """Build a DatabaseConfig from the main config's ``database.neo4j`` section.

        Raises:
            KeyError: If the section, or any of ``uri``/``user``/``password``, is missing.
        """
        db_config = self.main_config['database']['neo4j']
        return DatabaseConfig(
            neo4j_uri=db_config['uri'],
            neo4j_user=db_config['user'],
            neo4j_password=db_config['password'],
        )

    def get_model_config(self, model_type: str) -> ModelConfig:
        """Build a ModelConfig from ``models.<model_type>`` in the main config.

        Args:
            model_type: Key under ``models``
                (harmful_prompt_generator, toxicity_evaluator, implicit).

        Raises:
            KeyError: If the model entry or one of its fields is missing.
        """
        model_config = self.main_config['models'][model_type]
        return ModelConfig(
            model_name=model_config['model_name'],
            server_url=model_config['server_url'],
            max_tokens=model_config['max_tokens'],
            temperature=model_config['temperature'],
            top_p=model_config['top_p'],
        )

    def get_stage1_config(self) -> Dict[str, Any]:
        """Return the raw ``stage1`` section of the main config."""
        return self.main_config['stage1']

    def get_stage2_config(self) -> Dict[str, Any]:
        """Return the raw ``stage2`` section of the main config."""
        return self.main_config['stage2']

    def get_output_config(self) -> Dict[str, Any]:
        """Return the raw ``output`` section of the main config."""
        return self.main_config['output']

    def get_performance_config(self) -> Dict[str, Any]:
        """Return the raw ``performance`` section of the main config."""
        return self.main_config['performance']

    def get_complete_config(self, domain_name: str) -> Tuple[PipelineConfig, DomainConfig]:
        """Assemble the full pipeline config together with a domain config.

        Args:
            domain_name: Domain to load (see load_domain_config).

        Returns:
            A ``(PipelineConfig, DomainConfig)`` tuple. The two are not
            merged: the PipelineConfig holds only main-config values.

        Raises:
            FileNotFoundError, yaml.YAMLError: From loading the domain file.
            KeyError: If a required main-config or domain-config key is missing.
        """
        domain_config = self.load_domain_config(domain_name)

        merged_config = PipelineConfig(
            global_config=self.main_config['global'],
            database=self.get_database_config(),
            models={model_type: self.get_model_config(model_type)
                    for model_type in self.MODEL_TYPES},
            stage1=self.get_stage1_config(),
            stage2=self.get_stage2_config(),
            output=self.get_output_config(),
            performance=self.get_performance_config(),
            experiment_tracking=self.main_config['experiment_tracking'],
        )

        return merged_config, domain_config

    def update_config(self, config_updates: Dict[str, Any]) -> None:
        """Merge ``config_updates`` into the main config and save it to disk.

        Nested dicts are merged recursively; any other value replaces the
        existing one. Key order from the file is preserved. ``self.main_config``
        is only replaced after the file has been written successfully.

        Args:
            config_updates: Partial config to merge into the main config.

        Raises:
            yaml.YAMLError: (a RepresenterError) if an update contains a value
                ``yaml.safe_dump`` cannot represent, e.g. a Path or a custom
                object. Nothing is written in that case.
            OSError: If the config file cannot be written.
        """
        def deep_update(base_dict: Dict[str, Any], update_dict: Dict[str, Any]) -> None:
            for key, value in update_dict.items():
                if key in base_dict and isinstance(base_dict[key], dict) and isinstance(value, dict):
                    deep_update(base_dict[key], value)
                else:
                    # Copy so later mutation of the caller's object cannot leak in
                    base_dict[key] = copy.deepcopy(value)

        updated = copy.deepcopy(self.main_config)
        deep_update(updated, config_updates)

        # Serialize BEFORE touching the file. Opening with 'w' truncates it, so a
        # dump error mid-way would otherwise destroy the config. safe_dump also
        # refuses python-specific tags that _load_main_config's safe_load could
        # not read back.
        text = yaml.safe_dump(updated, default_flow_style=False, allow_unicode=True, sort_keys=False)
        self.main_config_path.write_text(text, encoding='utf-8')

        self.main_config = updated
        logger.info("配置文件已更新")

    def list_available_domains(self) -> List[str]:
        """Return the names (file stems) of all ``*.yaml`` files in the domains directory.

        The result is sorted so it does not depend on filesystem enumeration order.
        """
        return sorted(path.stem for path in self.domains_config_dir.glob("*.yaml"))

    def validate_config(self, domain_name: str) -> bool:
        """Check that the main config and a domain config are complete.

        Args:
            domain_name: Domain whose config is loaded and checked.

        Returns:
            True if every section in REQUIRED_MAIN_KEYS is present and the
            domain config loads with a non-empty name; False otherwise. Errors
            are logged rather than raised.
        """
        try:
            missing = [key for key in self.REQUIRED_MAIN_KEYS if key not in self.main_config]
            if missing:
                for key in missing:
                    logger.error(f"主配置缺少必需键: {key}")
                return False

            domain_config = self.load_domain_config(domain_name)
            if not domain_config.name:
                logger.error("领域配置缺少name字段")
                return False

            logger.info("配置验证通过")
            return True

        except Exception as e:
            # Boundary method: report failure as False instead of propagating
            logger.error(f"配置验证失败: {e}")
            return False

def get_config_manager() -> ConfigManager:
    """Return the shared ConfigManager, creating it on the first call.

    The instance is cached as the ``instance`` attribute on this function, so
    every later call returns the same object.
    """
    try:
        return get_config_manager.instance
    except AttributeError:
        # First call: build the default manager and cache it
        get_config_manager.instance = ConfigManager()
        return get_config_manager.instance

if __name__ == "__main__":
    # Manual smoke test: load the main config, then the first available domain.
    config_manager = ConfigManager()

    # Main config
    print("测试主配置加载...")
    db_config = config_manager.get_database_config()
    # Print only the non-secret fields: formatting the whole object may expose
    # the Neo4j password on stdout.
    print(f"数据库配置: uri={db_config.neo4j_uri}, user={db_config.neo4j_user}")

    # Domain configs
    print("\n测试领域配置加载...")
    available_domains = config_manager.list_available_domains()
    print(f"可用领域: {available_domains}")

    if available_domains:
        domain_name = available_domains[0]
        print(f"\n加载 {domain_name} 领域配置...")
        domain_config = config_manager.load_domain_config(domain_name)
        print(f"领域配置: {domain_config.name} - {domain_config.description}")

        # get_complete_config returns a (pipeline config, domain config) pair
        print(f"\n获取 {domain_name} 完整配置...")
        pipeline_config, domain_config = config_manager.get_complete_config(domain_name)
        print("Pipeline配置获取成功")

        # Validation
        print(f"\n验证 {domain_name} 配置...")
        is_valid = config_manager.validate_config(domain_name)
        print(f"配置验证结果: {is_valid}")
