#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Main entry point for Multi-Category Convolution-based NAS Search
支持单设备运行和多设备消融实验对比
"""

import argparse
import json
import logging
import os
import sys
from datetime import datetime
from typing import Optional

from pel_nas.search.nas_searcher import MultiCategoryNASSearcher
from pel_nas.core.config import DATA_CONFIG
from pel_nas.core.main_controller import MainController

def setup_logging(verbose: bool = False, output_dir: Optional[str] = None):
    """Configure root logging to stdout and to a ``search.log`` file.

    Args:
        verbose: Log at DEBUG level when True, otherwise at INFO level.
        output_dir: Directory that receives ``search.log``; it is created if
            missing. When omitted (or empty), ``outputs/search.log`` is used.

    Note:
        ``logging.basicConfig`` does nothing if the root logger already has
        handlers, so only the first configuration in a process takes effect.
    """
    level = logging.DEBUG if verbose else logging.INFO

    log_dir = output_dir or 'outputs'
    os.makedirs(log_dir, exist_ok=True)
    log_file = os.path.join(log_dir, 'search.log')

    logging.basicConfig(
        level=level,
        format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
        handlers=[
            logging.StreamHandler(sys.stdout),
            # Overwrite the log on each run. Force UTF-8 because the messages
            # logged by this tool contain emoji and CJK text, which the
            # platform default encoding (e.g. cp1252) cannot represent.
            logging.FileHandler(log_file, mode='w', encoding='utf-8'),
        ],
    )

def check_requirements():
    """Validate the runtime environment before a search is started.

    Hard requirements (the function returns False when they are not met):
      * the ``OPENAI_API_KEY`` environment variable is set;
      * ``DATA_CONFIG['metrics_csv']`` is configured and the file exists.

    Soft checks (only logged, the function keeps going):
      * the NAS-Bench-201 and HW-NAS-Bench files in the working directory;
      * the ``pel_nas/llm/prompts`` directory, which is created if absent;
      * the optional ``DATA_CONFIG['zero_cost_features']`` file.

    Returns:
        True when every hard requirement is satisfied, otherwise False.
    """
    log = logging.getLogger(__name__)

    if not os.getenv('OPENAI_API_KEY'):
        log.error("❌ OPENAI_API_KEY environment variable not set")
        log.info("💡 Please set it with: export OPENAI_API_KEY='your-api-key'")
        return False

    # (path, label used in the warning, consequence explained to the user)
    optional_data_files = (
        ('NAS-Bench-201-v1_1-096897.pth', 'NAS-Bench-201',
         "   The search will continue with available data"),
        ('HW-NAS-Bench/HW-NAS-Bench-v1_0.pickle', 'HW-NAS-Bench',
         "   The search will continue without hardware metrics"),
    )
    for data_path, label, consequence in optional_data_files:
        if os.path.exists(data_path):
            continue
        log.warning(f"⚠️  {label} file not found: {data_path}")
        log.info(consequence)

    prompts_dir = 'pel_nas/llm/prompts'
    if not os.path.exists(prompts_dir):
        log.warning("⚠️  Prompts directory not found")
        log.info("   Creating prompts directory...")
        os.makedirs(prompts_dir, exist_ok=True)

    metrics_path = DATA_CONFIG.get('metrics_csv')
    if not (metrics_path and os.path.exists(metrics_path)):
        log.error(f"❌ Metrics CSV missing: {metrics_path}")
        return False

    zero_cost_path = DATA_CONFIG.get('zero_cost_features')
    if zero_cost_path and not os.path.exists(zero_cost_path):
        log.warning(f"⚠️ Zero-cost feature file missing: {zero_cost_path}")
        log.warning("   Predictor mode will fall back to ground-truth accuracies")

    return True

def _build_arg_parser():
    """Return the command-line parser used by :func:`main`."""
    cli = argparse.ArgumentParser(description='Multi-Category Convolution-based Neural Architecture Search')

    # Run-mode selection: one device, one dataset across devices, or the ablation study.
    cli.add_argument('--mode', type=str, default='multi_device',
                     choices=['single', 'multi_device', 'ablation'],
                     help='运行模式: single=单设备, multi_device=单数据集全设备, ablation=多设备消融')
    cli.add_argument('--iterations', type=int, default=10,
                     help='Maximum number of search iterations (for all modes)')
    cli.add_argument('--verbose', action='store_true',
                     help='Enable verbose logging')
    cli.add_argument('--dataset', type=str, default='cifar10',
                     choices=['cifar10', 'cifar100', 'ImageNet16-120'],
                     help='Dataset to use')
    cli.add_argument('--hardware-device', type=str, default='fpga',
                     choices=['edgegpu', 'raspi4', 'edgetpu', 'pixel3', 'eyeriss', 'fpga'],
                     help='Hardware device to use (single mode only)')
    cli.add_argument('--output-dir', type=str, default=None,
                     help='Output directory (default: auto-generated timestamp)')
    cli.add_argument('--use-predictor', action='store_true',
                     help='Enable zero-cost predictor for accuracy estimation')
    cli.add_argument('--devices', type=str, default=None,
                     help='Comma-separated device list for multi_device mode (default: all devices)')
    return cli


def main():
    """Parse the CLI, validate the environment and dispatch to the chosen mode.

    Exits with status 1 if the requirements check fails or the selected mode
    raises an exception, and with status 0 when interrupted with Ctrl-C.
    """
    cli_args = _build_arg_parser().parse_args()

    setup_logging(cli_args.verbose, cli_args.output_dir)
    logger = logging.getLogger(__name__)

    if not check_requirements():
        logger.error("❌ Requirements check failed")
        sys.exit(1)

    try:
        mode_runners = {
            'single': run_single_device_mode,
            'multi_device': run_multi_device_mode,
            'ablation': run_ablation_mode,
        }
        runner = mode_runners.get(cli_args.mode)
        if runner is None:
            # argparse ``choices`` already restricts --mode; defensive guard only.
            logger.error(f"❌ 未知运行模式: {cli_args.mode}")
            sys.exit(1)
        runner(cli_args, logger)

    except KeyboardInterrupt:
        logger.info("\n⚠️ 用户中断")
        sys.exit(0)

    except Exception as err:
        logger.error(f"❌ 程序执行失败: {err}")
        import traceback
        # The full traceback is only visible when --verbose enables DEBUG.
        logger.debug("完整错误信息:")
        logger.debug(traceback.format_exc())
        sys.exit(1)

def run_single_device_mode(args, logger):
    """Run the multi-category search for a single (dataset, device) pair.

    Args:
        args: Parsed CLI namespace; reads ``use_predictor``, ``output_dir``,
            ``dataset``, ``hardware_device`` and ``iterations``.
        logger: Logger used for progress and the final summary.

    The output directory defaults to
    ``outputs/single_device_<dataset>_<device>_<timestamp>`` and is created
    before the search starts. After the search, the ``search_config``,
    ``category_statistics`` and ``llm_stats`` entries of the returned results
    are summarised in the log.
    """
    use_predictor = args.use_predictor

    # Create timestamped output directory if not provided
    if args.output_dir:
        output_dir = args.output_dir
    else:
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        output_dir = os.path.join("outputs", f"single_device_{args.dataset}_{args.hardware_device}_{timestamp}")
    # Make sure the directory exists before the searcher writes into it,
    # matching what run_multi_device_mode does for its per-device directories.
    os.makedirs(output_dir, exist_ok=True)

    logger.info("🚀 Starting Single Device Multi-Category NAS Search")
    logger.info(f"📁 Output directory: {output_dir}")
    logger.info(f"🔧 Configuration:")
    logger.info(f"   Mode: Single Device")
    logger.info(f"   Dataset: {args.dataset}")
    logger.info(f"   Hardware Device: {args.hardware_device}")
    logger.info(f"   Max Iterations: {args.iterations}")
    logger.info(f"   Use Predictor: {use_predictor}")
    logger.info("=" * 80)

    # Initialize and run searcher
    logger.info("🔧 Initializing multi-category searcher...")
    searcher = MultiCategoryNASSearcher(
        dataset=args.dataset,
        # Normalise kebab-case to snake_case; a no-op for the current
        # argparse choices, which contain no hyphens.
        hardware_device=args.hardware_device.replace('-', '_'),
        use_predictor=use_predictor
    )

    logger.info(f"🎯 Starting search with {args.iterations} iterations...")
    results = searcher.run_search(max_iterations=args.iterations, output_dir=output_dir)

    # Print summary
    logger.info("\n🎉 Single device search completed successfully!")
    logger.info("📊 Summary:")

    search_config = results['search_config']
    logger.info(f"   Total architectures generated: {search_config['total_architectures']}")
    logger.info(f"   Search iterations: {search_config['max_iterations']}")

    logger.info("\n📈 Final Category Results:")
    total_pareto_count = 0
    for category, stats in results['category_statistics'].items():
        pareto_count = stats['pareto_count']
        best_acc = stats['best_accuracy']
        avg_acc = stats['avg_accuracy']
        total_pareto_count += pareto_count

        logger.info(f"   {category:15s}: {pareto_count:2d} Pareto architectures, "
                    f"Best: {best_acc:.2f}%, Avg: {avg_acc:.2f}%")

    logger.info(f"\n🏆 Total Pareto Front Size: {total_pareto_count} architectures")

    # LLM statistics
    llm_stats = results['llm_stats']
    logger.info(f"\n🤖 LLM Client Statistics:")
    logger.info(f"   Total requests: {llm_stats['request_count']}")
    logger.info(f"   Success rate: {llm_stats['success_rate']:.1f}%")
    logger.info(f"   Categories processed: {llm_stats['categories_loaded']}")

    logger.info(f"\n💾 All results saved to {output_dir}")
    logger.info("✅ Single device search completed successfully!")


def run_multi_device_mode(args, logger):
    """Run the multi-category search for one dataset across several devices.

    Each device is searched independently into ``<output_dir>/<device>``. A
    failure on one device is logged and recorded, and the remaining devices
    still run. The outcome of every device (successful ``runs`` and
    ``failed`` entries) is written to ``<output_dir>/multi_device_summary.json``.

    Args:
        args: Parsed CLI namespace; reads ``devices``, ``output_dir``,
            ``dataset``, ``iterations`` and ``use_predictor``.
        logger: Logger used for progress reporting.

    Raises:
        SystemExit: With code 1 if the device list is empty or names an
            unsupported device.
    """
    default_devices = ['edgegpu', 'raspi4', 'edgetpu', 'pixel3', 'eyeriss', 'fpga']
    if args.devices:
        requested = [item.strip() for item in args.devices.split(',') if item.strip()]
        # Drop repeated entries (keeping first-seen order) so a device is not
        # searched twice into, and overwriting, the same output directory.
        devices = list(dict.fromkeys(requested))
    else:
        devices = list(default_devices)

    if not devices:
        logger.error("❌ No devices specified for multi_device mode")
        sys.exit(1)

    # Validate devices
    invalid_devices = [d for d in devices if d not in default_devices]
    if invalid_devices:
        logger.error(f"❌ Invalid device(s): {', '.join(invalid_devices)}")
        logger.info(f"   Supported devices: {', '.join(default_devices)}")
        sys.exit(1)

    # Create timestamped output directory if not provided
    if args.output_dir:
        output_dir = args.output_dir
    else:
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        output_dir = os.path.join("outputs", f"multi_device_{args.dataset}_{timestamp}")

    os.makedirs(output_dir, exist_ok=True)

    logger.info("🚀 Starting Multi-Device NAS Search")
    logger.info(f"📁 Output directory: {output_dir}")
    logger.info(f"🔧 Configuration:")
    logger.info(f"   Mode: Multi Device")
    logger.info(f"   Dataset: {args.dataset}")
    logger.info(f"   Devices: {', '.join(devices)}")
    logger.info(f"   Max Iterations: {args.iterations}")
    logger.info(f"   Use Predictor: {args.use_predictor}")
    logger.info("=" * 80)

    aggregated_results = {
        'dataset': args.dataset,
        'iterations': args.iterations,
        'use_predictor': args.use_predictor,
        'devices': devices,
        'runs': [],
        'failed': [],
    }

    for idx, device in enumerate(devices, 1):
        logger.info(f"\n{'='*80}")
        logger.info(f"📱 Device progress: {idx}/{len(devices)} - {device.upper()}")
        logger.info(f"{'='*80}")

        device_output_dir = os.path.join(output_dir, device)
        os.makedirs(device_output_dir, exist_ok=True)

        try:
            searcher = MultiCategoryNASSearcher(
                dataset=args.dataset,
                hardware_device=device,
                use_predictor=args.use_predictor,
            )

            device_results = searcher.run_search(
                max_iterations=args.iterations,
                output_dir=device_output_dir,
            )

            search_config = device_results.get('search_config', {})
            category_stats = device_results.get('category_statistics', {})
            llm_stats = device_results.get('llm_stats', {})

            total_pareto = sum(stat.get('pareto_count', 0) for stat in category_stats.values())

            aggregated_results['runs'].append({
                'device': device,
                'output_dir': device_output_dir,
                'search_config': search_config,
                'total_pareto': total_pareto,
                'llm_stats': llm_stats,
            })

            logger.info(f"✅ Device {device} completed: {search_config.get('total_architectures', 0)} architectures")
            logger.info(f"   Pareto front size: {total_pareto}")
            if llm_stats:
                logger.info(f"   LLM success rate: {llm_stats.get('success_rate', 0):.1f}%")

        except Exception as exc:
            # Best-effort per device: keep going with the remaining devices,
            # but record the failure so the summary does not silently omit it.
            logger.error(f"❌ Device {device} search failed: {exc}")
            logger.debug(f"Traceback for device {device}:", exc_info=True)
            aggregated_results['failed'].append({
                'device': device,
                'output_dir': device_output_dir,
                'error': str(exc),
            })

    summary_path = os.path.join(output_dir, 'multi_device_summary.json')
    with open(summary_path, 'w', encoding='utf-8') as handle:
        json.dump(aggregated_results, handle, indent=2)

    logger.info("\n🎉 Multi-device search completed!")
    logger.info(f"📊 Successful devices: {len(aggregated_results['runs'])}/{len(devices)}")
    if aggregated_results['failed']:
        failed_names = ', '.join(entry['device'] for entry in aggregated_results['failed'])
        logger.warning(f"⚠️ Failed devices: {failed_names}")
    logger.info(f"💾 Summary saved to {summary_path}")

def run_ablation_mode(args, logger):
    """Run the full multi-device ablation study via ``MainController``.

    Args:
        args: Parsed CLI namespace; reads ``output_dir``, ``dataset``,
            ``iterations`` and ``verbose``.
        logger: Logger used for the banner and completion messages.

    When ``args.output_dir`` is empty, the directory defaults to
    ``outputs/ablation_study_<dataset>_<timestamp>``. The controller is
    constructed with ``devices=None`` (its default device set) and the
    directory it reports back is logged at the end.
    """
    output_dir = args.output_dir
    if not output_dir:
        stamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        output_dir = os.path.join("outputs", f"ablation_study_{args.dataset}_{stamp}")

    banner_lines = (
        "🧪 Starting Multi-Device Ablation Study",
        f"📁 Output directory: {output_dir}",
        "🔧 Configuration:",
        "   Mode: Ablation Study",
        f"   Dataset: {args.dataset}",
        f"   Iterations per method: {args.iterations}",
        "   Devices: All 6 devices (edgegpu, raspi4, edgetpu, pixel3, eyeriss, fpga)",
        "=" * 80,
    )
    for banner_line in banner_lines:
        logger.info(banner_line)

    logger.info("🔧 Initializing ablation study controller...")
    ablation_controller = MainController(
        dataset=args.dataset,
        devices=None,               # None -> let the controller use its default device set
        verbose=args.verbose,
        iterations=args.iterations,
    )

    logger.info("🎯 Starting complete ablation study...")
    results_dir = ablation_controller.run_complete_ablation_study(output_dir)

    closing_lines = (
        "\n🎉 Ablation study completed successfully!",
        f"💾 All results saved to {results_dir}",
        "✅ Multi-device ablation study completed successfully!",
    )
    for closing_line in closing_lines:
        logger.info(closing_line)

if __name__ == '__main__':
    # Script entry point: parse CLI arguments and dispatch to the selected mode.
    main()
