#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
主控制器 - 统一管理消融实验和可视化
Main Controller - Unified management of ablation studies and visualization
"""

import os
import sys
import json
import logging
import argparse
from datetime import datetime
from typing import List, Dict, Optional
from experiments.ablation.ablation_studies import AblationStudyRunner, AblationResult
from pel_nas.visualization.unified_visualizer import UnifiedAblationVisualizer
from pel_nas.core.config import DATA_CONFIG

def setup_logging(verbose: bool = False, output_dir: Optional[str] = None):
    """Configure root logging to stream to stdout and write to a log file.

    Args:
        verbose: Log at DEBUG level when True, otherwise at INFO level.
        output_dir: Directory in which ``ablation_study.log`` is created
            (created if missing). When empty/None, ``outputs/`` is used.

    Note:
        ``logging.basicConfig`` does nothing if the root logger already has
        handlers; in that case the freshly opened file handler is closed so
        its file descriptor is not leaked.
    """
    level = logging.DEBUG if verbose else logging.INFO

    log_dir = output_dir if output_dir else 'outputs'
    os.makedirs(log_dir, exist_ok=True)
    log_file = os.path.join(log_dir, 'ablation_study.log')

    # Log messages contain Chinese text and emoji; pin UTF-8 so the file
    # handler does not fail on platforms whose default encoding is not UTF-8.
    file_handler = logging.FileHandler(log_file, mode='w', encoding='utf-8')

    logging.basicConfig(
        level=level,
        format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
        handlers=[
            logging.StreamHandler(sys.stdout),
            file_handler,
        ]
    )

    # basicConfig silently ignores `handlers` when logging is already set up;
    # release the unused file handle instead of leaking it.
    if file_handler not in logging.getLogger().handlers:
        file_handler.close()

class MainController:
    """Top-level orchestrator for ablation studies and their visualization.

    For each configured hardware device it runs ``AblationStudyRunner``,
    collects the returned ``AblationResult`` objects, writes them to a JSON
    file plus a plain-text summary, and renders a unified comparison figure
    through ``UnifiedAblationVisualizer``.
    """

    # Devices evaluated when the caller passes no (or an empty) device list.
    _DEFAULT_DEVICES = ('edgegpu', 'raspi4', 'edgetpu', 'pixel3', 'eyeriss', 'fpga')

    def __init__(self, dataset: str = 'cifar100', devices: Optional[List[str]] = None,
                 verbose: bool = False, iterations: int = 20):
        """Initialize the controller.

        Args:
            dataset: Dataset name.
            devices: Devices to evaluate; ``None`` or an empty list selects
                the default device set.
            verbose: Whether verbose logging was requested (stored only; not
                otherwise used by this class).
            iterations: Number of rounds per ablation method, forwarded to
                ``AblationStudyRunner``.

        Exits the process via ``_check_requirements`` if prerequisites are
        missing.
        """
        self.dataset = dataset
        self.devices = devices or list(self._DEFAULT_DEVICES)
        self.verbose = verbose
        self.iterations = iterations

        # Quick-test mode was removed; the number of architectures is driven
        # by `iterations`. The attribute is kept for backward compatibility.
        self.is_quick_test = False

        self.visualizer = UnifiedAblationVisualizer(dataset)

        # Fail fast before any expensive work if prerequisites are missing.
        self._check_requirements()

        logging.getLogger(__name__).info(f"🎛️ 主控制器初始化完成 - {dataset}")

    def _check_requirements(self):
        """Verify the environment variables and data files the study needs.

        Exits with status 1 when ``OPENAI_API_KEY`` is unset or the main
        metrics CSV (``DATA_CONFIG['metrics_csv']``) is missing. A configured
        but missing zero-cost feature file only produces a warning.
        """
        logger = logging.getLogger(__name__)

        # API key is mandatory: abort before starting any experiment.
        if not os.getenv('OPENAI_API_KEY'):
            logger.error("❌ OPENAI_API_KEY环境变量未设置")
            logger.info("💡 请设置: export OPENAI_API_KEY='your-api-key'")
            sys.exit(1)

        # Main hardware-metrics table is mandatory.
        metrics_csv = DATA_CONFIG.get('metrics_csv')
        if not metrics_csv or not os.path.exists(metrics_csv):
            logger.error(f"❌ 未找到主数据文件 nb201_hw_metrics.csv: {metrics_csv}")
            sys.exit(1)

        # Zero-cost features are optional; per the warning text, predictor
        # mode falls back to ground-truth accuracy without them.
        zc_features = DATA_CONFIG.get('zero_cost_features')
        if zc_features and not os.path.exists(zc_features):
            logger.warning(f"⚠️ 零成本特征文件缺失: {zc_features}")
            logger.warning("   预测器模式将自动退回真实精度")

    def run_single_device_ablation(self, device: str) -> List[AblationResult]:
        """Run every ablation method for one device.

        Args:
            device: Device name.

        Returns:
            The ablation results produced by ``AblationStudyRunner.run_all_ablations``.
        """
        logger = logging.getLogger(__name__)
        logger.info(f"🔬 开始设备 {device} 的消融实验...")

        runner = AblationStudyRunner(self.dataset, device, self.iterations)
        results = runner.run_all_ablations()

        logger.info(f"✅ 设备 {device} 消融实验完成，共 {len(results)} 个方法")
        return results

    def run_all_devices_ablation(self) -> Dict[str, List[AblationResult]]:
        """Run the ablation study on every configured device.

        A device whose run raises, or returns no results, is logged and
        skipped (best effort), so the returned mapping may be partial.

        Returns:
            Mapping of device name to its non-empty list of results.
        """
        import traceback

        logger = logging.getLogger(__name__)
        logger.info(f"🚀 开始所有设备的消融实验 - {self.dataset}")

        all_results = {}
        total = len(self.devices)

        for i, device in enumerate(self.devices, 1):
            logger.info(f"\n{'='*80}")
            logger.info(f"📱 设备进度: {i}/{len(self.devices)} - {device.upper()}")
            logger.info(f"{'='*80}")

            try:
                device_results = self.run_single_device_ablation(device)
            except Exception as e:
                logger.error(f"❌ 设备 {device} 消融实验失败: {e}")
                # Keep the traceback available in verbose runs.
                logger.debug(traceback.format_exc())
                continue

            if device_results:
                all_results[device] = device_results
                logger.info(f"✅ {device} 完成: {len(device_results)} 个方法成功")
            else:
                logger.warning(f"⚠️ {device} 没有成功的消融实验结果")

        logger.info("\n🎉 所有设备消融实验完成！")
        logger.info(f"📊 成功设备: {len(all_results)}/{total}")

        return all_results

    @staticmethod
    def _architecture_to_dict(arch) -> Dict:
        """Convert one architecture record into a JSON-friendly dict."""
        return {
            'arch_index': arch.arch_index,
            'arch_str': arch.arch_str,
            'flops': arch.flops,
            'params': arch.params,
            'accuracy': arch.accuracy,
            'latency': arch.latency,
            'energy': arch.energy,
            'conv_category': arch.conv_category,
            'conv_3x3_count': arch.conv_3x3_count,
            'conv_1x1_count': arch.conv_1x1_count,
            'is_valid': arch.is_valid,
        }

    @classmethod
    def _result_to_dict(cls, result) -> Dict:
        """Convert one ``AblationResult`` into a JSON-friendly dict."""
        return {
            'method_name': result.method_name,
            'dataset': result.dataset,
            'hardware_device': result.hardware_device,
            'architectures': [cls._architecture_to_dict(a) for a in result.architectures],
            'pareto_front': [cls._architecture_to_dict(a) for a in result.pareto_front],
            'execution_time': result.execution_time,
            'method_config': result.method_config,
            'total_architectures': len(result.architectures),
            'pareto_size': len(result.pareto_front),
        }

    @staticmethod
    def _json_default(obj):
        """``json`` fallback: unwrap numpy-style scalars/arrays via ``tolist()``.

        Anything else is rejected with the same ``TypeError`` json would raise.
        """
        to_list = getattr(obj, 'tolist', None)
        if callable(to_list):
            return to_list()
        raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")

    def save_ablation_results(self, all_results: Dict[str, List[AblationResult]], output_dir: str):
        """Save results as ``ablation_results_<dataset>.json`` and write a summary.

        Args:
            all_results: Mapping of device name to its ablation results.
            output_dir: Output directory (created if missing).
        """
        logger = logging.getLogger(__name__)

        os.makedirs(output_dir, exist_ok=True)

        serializable_results = {
            device: [self._result_to_dict(result) for result in device_results]
            for device, device_results in all_results.items()
        }

        payload = {
            'experiment_type': 'ablation_study',
            'dataset': self.dataset,
            'devices': list(all_results.keys()),
            'timestamp': datetime.now().isoformat(),
            'results': serializable_results,
        }

        # Serialize before opening the file so an encoding error cannot leave
        # a truncated, half-written JSON file behind.
        text = json.dumps(payload, indent=2, ensure_ascii=False, default=self._json_default)

        results_file = os.path.join(output_dir, f'ablation_results_{self.dataset}.json')
        with open(results_file, 'w', encoding='utf-8') as f:
            f.write(text)

        logger.info(f"💾 消融实验结果已保存: {os.path.basename(results_file)}")

        self._generate_summary_stats(all_results, output_dir)

    def _generate_summary_stats(self, all_results: Dict[str, List[AblationResult]], output_dir: str):
        """Write a human-readable summary to ``ablation_summary_<dataset>.txt``.

        Lists every method seen, per-device/per-method architecture and Pareto
        counts with execution times, and overall totals.
        """
        logger = logging.getLogger(__name__)

        stats_file = os.path.join(output_dir, f'ablation_summary_{self.dataset}.txt')

        all_methods = {
            result.method_name
            for device_results in all_results.values()
            for result in device_results
        }

        with open(stats_file, 'w', encoding='utf-8') as f:
            f.write(f"Ablation Study Summary - {self.dataset.upper()}\n")
            f.write("=" * 50 + "\n\n")

            f.write("EXPERIMENT OVERVIEW:\n")
            f.write(f"• Dataset: {self.dataset}\n")
            f.write(f"• Devices Tested: {len(all_results)}\n")
            f.write(f"• Timestamp: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n\n")

            f.write(f"METHODS COMPARED ({len(all_methods)}):\n")
            for i, method in enumerate(sorted(all_methods), 1):
                f.write(f"{i}. {method}\n")
            f.write("\n")

            f.write("DEVICE RESULTS:\n")
            f.write("-" * 30 + "\n")

            total_architectures = 0
            total_pareto = 0

            for device, device_results in all_results.items():
                f.write(f"\n{device.upper()}:\n")

                device_archs = 0
                device_pareto = 0

                for result in device_results:
                    method_archs = len(result.architectures)
                    method_pareto = len(result.pareto_front)
                    device_archs += method_archs
                    device_pareto += method_pareto

                    f.write(f"  {result.method_name}:\n")
                    f.write(f"    Architectures: {method_archs}\n")
                    f.write(f"    Pareto Front: {method_pareto}\n")
                    f.write(f"    Execution Time: {result.execution_time:.2f}s\n")

                f.write(f"  Total: {device_archs} architectures, {device_pareto} Pareto\n")
                total_architectures += device_archs
                total_pareto += device_pareto

            # Average over (method, device) pairs; guard against empty input
            # which previously raised ZeroDivisionError.
            denominator = len(all_methods) * len(all_results)
            average = total_architectures / denominator if denominator else 0.0

            f.write("\nOVERALL TOTALS:\n")
            f.write(f"• Total Architectures Generated: {total_architectures}\n")
            f.write(f"• Total Pareto Front Points: {total_pareto}\n")
            f.write(f"• Average Architectures per Method: {average:.1f}\n")

        logger.info(f"📄 统计摘要已保存: {os.path.basename(stats_file)}")

    def create_unified_visualization(self, all_results: Dict[str, List[AblationResult]], output_dir: str):
        """Render the unified comparison figure for all results.

        Args:
            all_results: Mapping of device name to its ablation results.
            output_dir: Output directory passed to the visualizer.

        Returns:
            ``(figure_path, metrics_results)`` as returned by
            ``UnifiedAblationVisualizer.create_unified_comparison_figure``.
        """
        logger = logging.getLogger(__name__)
        logger.info("🎨 创建统一消融实验可视化...")

        figure_path, metrics_results = self.visualizer.create_unified_comparison_figure(
            all_results, output_dir
        )

        logger.info("✅ 统一可视化完成")
        logger.info(f"📊 图片路径: {figure_path}")

        return figure_path, metrics_results

    def run_complete_ablation_study(self, output_dir: Optional[str] = None) -> str:
        """Run the full pipeline: all devices -> save results -> visualize.

        Any exception raised inside the pipeline is logged (traceback at
        DEBUG level) and swallowed; the output directory is returned either way.

        Args:
            output_dir: Output directory; defaults to a timestamped directory
                under ``outputs/``.

        Returns:
            The output directory path.
        """
        import traceback

        logger = logging.getLogger(__name__)

        if output_dir is None:
            timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
            output_dir = os.path.join("outputs", f"ablation_study_{self.dataset}_{timestamp}")

        os.makedirs(output_dir, exist_ok=True)

        logger.info("🚀 开始完整消融实验流程")
        logger.info(f"📁 输出目录: {output_dir}")
        logger.info(f"🔧 数据集: {self.dataset}")
        logger.info("=" * 100)

        try:
            # Step 1: run the ablation study on every device.
            logger.info("🔬 步骤 1: 运行消融实验...")
            all_results = self.run_all_devices_ablation()

            if not all_results:
                logger.error("❌ 没有成功的消融实验结果")
                return output_dir

            # Step 2: persist results and summary.
            logger.info("\n💾 步骤 2: 保存实验结果...")
            self.save_ablation_results(all_results, output_dir)

            # Step 3: render the unified comparison figure.
            logger.info("\n🎨 步骤 3: 创建统一可视化...")
            self.create_unified_visualization(all_results, output_dir)

            # Step 4: final report of produced artifacts.
            logger.info("\n" + "=" * 100)
            logger.info("🎉 完整消融实验流程完成！")
            logger.info(f"📁 所有结果保存在: {output_dir}")
            logger.info("📊 主要输出文件:")
            logger.info(f"   • 实验结果: ablation_results_{self.dataset}.json")
            logger.info(f"   • 统计摘要: ablation_summary_{self.dataset}.txt")
            logger.info(f"   • 对比图片: unified_ablation_comparison_{self.dataset}.png")
            logger.info(f"   • 指标报告: ablation_metrics_summary_{self.dataset}.json")
            logger.info("=" * 100)

        except Exception as e:
            logger.error(f"❌ 消融实验流程失败: {e}")
            logger.debug("完整错误信息:")
            logger.debug(traceback.format_exc())

        return output_dir

def main():
    """Command-line entry point for the ablation-study controller.

    Parses CLI arguments, configures logging, then runs either a single
    device (``--device``) or all default devices.

    Exit status: 2 for invalid arguments (argparse), 1 when an unexpected
    exception escapes or single-device mode yields no results, 0 otherwise
    (including Ctrl-C).
    """
    parser = argparse.ArgumentParser(description='消融实验主控制器')
    parser.add_argument('--dataset', type=str, default='cifar100',
                        choices=['cifar10', 'cifar100', 'ImageNet16-120'],
                        help='数据集名称')
    parser.add_argument('--iterations', type=int, default=20,
                        help='每个消融方法的轮次数量')
    parser.add_argument('--output-dir', type=str, default=None,
                        help='输出目录（默认自动生成）')
    parser.add_argument('--verbose', action='store_true',
                        help='启用详细日志')
    parser.add_argument('--device', type=str, default=None,
                        choices=['edgegpu', 'raspi4', 'edgetpu', 'pixel3', 'eyeriss', 'fpga'],
                        help='只运行指定设备（默认运行所有设备）')

    args = parser.parse_args()

    # Reject non-positive round counts up front rather than handing them to
    # the ablation runner.
    if args.iterations <= 0:
        parser.error('--iterations 必须是正整数')

    setup_logging(args.verbose, args.output_dir)

    logger = logging.getLogger(__name__)
    logger.info("🎛️ 消融实验主控制器启动")

    try:
        controller = MainController(
            args.dataset,
            devices=None,
            verbose=args.verbose,
            iterations=args.iterations,
        )

        if args.device:
            # Single-device mode.
            logger.info(f"🎯 单设备模式: {args.device}")

            timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
            output_dir = args.output_dir or f"outputs/ablation_{args.dataset}_{args.device}_{timestamp}"
            os.makedirs(output_dir, exist_ok=True)

            device_results = controller.run_single_device_ablation(args.device)
            if not device_results:
                logger.error("❌ 单设备消融实验失败")
                # Report failure to the shell instead of exiting with status 0.
                sys.exit(1)

            all_results = {args.device: device_results}
            controller.save_ablation_results(all_results, output_dir)
            controller.create_unified_visualization(all_results, output_dir)
            logger.info(f"✅ 单设备消融实验完成: {output_dir}")
        else:
            # All-devices mode.
            logger.info("🌐 全设备模式")
            output_dir = controller.run_complete_ablation_study(args.output_dir)
            logger.info(f"✅ 完整消融实验完成: {output_dir}")

    except KeyboardInterrupt:
        logger.info("\n⚠️ 用户中断")
        sys.exit(0)
    except Exception as e:
        import traceback
        logger.error(f"❌ 程序执行失败: {e}")
        logger.debug(traceback.format_exc())
        sys.exit(1)


if __name__ == "__main__":
    main()
