#!/usr/bin/env python3
"""
Curriculum-Guided Evolution Example Script

This script demonstrates how to use curriculum-guided evolutionary algorithms to optimize multi-agent systems. Curriculum learning helps the system learn complex tasks better by starting with simple tasks and gradually increasing the difficulty.
"""

import os
import json
import argparse
from curriculum_guided_evolution import CurriculumGuidedEvoOptimizer, CurriculumManager, DifficultyClassifier
from tools.unified_tool import ToolRegistry


def main():
    """Run one evolution experiment configured from command-line arguments.

    Parses the CLI options, builds a ``CurriculumGuidedEvoOptimizer``, runs
    either curriculum-guided evolution or plain evolution, and writes the
    results into ``--output_dir``:

    * ``curriculum_results.json`` - curriculum history (curriculum mode only)
    * ``best_individual.json``    - the best individual returned by the optimizer
    * ``best_flow.py``            - the ``code`` entry of the best individual

    Curriculum learning is enabled by default; pass ``--disable_curriculum``
    to take the plain-evolution path instead.
    """
    parser = argparse.ArgumentParser(description="Curriculum-Guided Evolution Example")
    parser.add_argument("--task", type=str, default="Solve mathematical word problems step by step",
                        help="Task description")
    parser.add_argument("--benchmark", type=str, default="gsm-8k", choices=["gsm-8k", "math"],
                        help="Benchmark dataset")
    parser.add_argument("--population_size", type=int, default=3, help="Population size")
    parser.add_argument("--generations", type=int, default=10, help="Evolution generations")
    parser.add_argument("--llm", type=str, default="gpt-4o-mini", help="LLM model to use")
    # A lone ``store_true`` flag whose default is True can never be switched
    # off, so the flag is paired with an explicit opt-out sharing the same dest.
    parser.add_argument("--enable_curriculum", dest="enable_curriculum", action="store_true",
                        default=True, help="Enable curriculum learning (default)")
    parser.add_argument("--disable_curriculum", dest="enable_curriculum", action="store_false",
                        help="Disable curriculum learning and run plain evolution")
    parser.add_argument("--output_dir", type=str, default="results", help="Output directory")

    args = parser.parse_args()

    print("🧬 Start curriculum-guided evolution algorithm")
    print("=" * 60)
    print(f"Task: {args.task}")
    print(f"Benchmark dataset: {args.benchmark}")
    print(f"Population size: {args.population_size}")
    print(f"Evolution generations: {args.generations}")
    print(f"Curriculum learning: {'Enabled' if args.enable_curriculum else 'Disabled'}")
    print("=" * 60)

    # Create the results directory
    os.makedirs(args.output_dir, exist_ok=True)

    # Set up the tool registry. Tools are optional: any ordinary failure
    # falls back to running without tools, but Ctrl-C / SystemExit are no
    # longer swallowed.
    tools = None
    try:
        tool_registry = ToolRegistry()
        if hasattr(tool_registry, "get_all_tools"):
            tools = tool_registry.get_all_tools()
    except Exception as exc:
        tools = None
        print(f"⚠️  Tool registry initialization failed, tools will not be used: {exc}")

    # Custom curriculum configuration
    # NOTE(review): the CLI benchmark name is "gsm-8k" while the key below is
    # "gsm8k" - confirm the optimizer maps one onto the other.
    curriculum_config = {
        "gsm8k": {
            "levels": ["easy", "medium", "hard"],
            "thresholds": [0.6, 0.75, 0.85],  # thresholds adjusted to be easier to reach
            "sample_sizes": [10, 20, 30],      # smaller sample counts to speed up training
            "min_generations": [2, 3, 4]       # fewer minimum training generations
        },
        "math": {
            "levels": ["Level 1", "Level 2", "Level 3", "Level 4", "Level 5"],
            "thresholds": [0.5, 0.6, 0.65, 0.7, 0.75],
            "sample_sizes": [5, 10, 15, 20, 25],
            "min_generations": [2, 2, 3, 3, 4],
            "subjects": ["prealgebra", "algebra", "geometry", "counting_and_probability"]
        }
    }

    # Initialise the curriculum-guided evolution optimizer
    optimizer = CurriculumGuidedEvoOptimizer(
        task_info=args.task,
        tools=tools,
        population_size=args.population_size,
        llm=args.llm,
        curriculum_config=curriculum_config,
        enable_curriculum=args.enable_curriculum
    )

    print("\n🎯 Start evolution process...")

    if args.enable_curriculum:
        # Curriculum-guided evolution
        best_individual, curriculum_history = optimizer.evolve_with_curriculum(
            generations=args.generations,
            benchmark_type=args.benchmark
        )

        # Analyse curriculum progress
        print("\n📊 Curriculum learning analysis:")
        optimizer.analyze_curriculum_progress()

        # Save the curriculum history
        history_path = os.path.join(args.output_dir, "curriculum_results.json")
        with open(history_path, "w", encoding="utf-8") as f:
            json.dump(curriculum_history, f, ensure_ascii=False, indent=2)

        print(f"\n💾 Curriculum history saved to: {history_path}")

    else:
        # Plain evolutionary algorithm
        best_individual = optimizer.evolve(
            generations=args.generations,
            benchmark_type=args.benchmark
        )

    # Show the best result. The fitness may be missing or non-numeric, in
    # which case the ``.4f`` format spec would raise, so fall back to text.
    fitness = best_individual.get("fitness")
    try:
        fitness_text = f"{fitness:.4f}"
    except (TypeError, ValueError):
        fitness_text = "N/A" if fitness is None else str(fitness)

    print("\n🏆 Evolution completed!")
    print(f"Best individual performance: {fitness_text}")
    print(f"Best individual description: {best_individual.get('plan', 'N/A')}")

    # Save the best individual
    individual_path = os.path.join(args.output_dir, "best_individual.json")
    with open(individual_path, "w", encoding="utf-8") as f:
        json.dump(best_individual, f, ensure_ascii=False, indent=2)
    print(f"💾 Best individual saved to: {individual_path}")

    # Save the best individual's code. A missing/non-text "code" entry must
    # not crash the script after a full (expensive) evolution run.
    flow_path = os.path.join(args.output_dir, "best_flow.py")
    flow_code = best_individual.get("code")
    if isinstance(flow_code, str):
        with open(flow_path, "w", encoding="utf-8") as f:
            f.write(flow_code)
        print(f"💾 Best flow code saved to: {flow_path}")
    else:
        print(f"⚠️  Best individual has no 'code' text, {flow_path} was not written")

    # When curriculum learning was enabled, show some statistics
    if args.enable_curriculum and hasattr(optimizer, 'curriculum_history'):
        # Only ``curriculum_history`` is checked above, so read the
        # transitions defensively instead of assuming the attribute exists.
        transitions = getattr(optimizer, "difficulty_transitions", None) or []

        print("\n📈 Curriculum learning statistics:")
        print(f"    Total curriculum stages: {len(optimizer.curriculum_history)}")
        print(f"    Difficulty transition count: {len(transitions)}")

        if transitions:
            print(f"    Final difficulty level: {transitions[-1]['new_level']}")


def compare_with_without_curriculum(output_dir="results"):
    """Compare curriculum-guided evolution against plain evolution.

    Runs two ``CurriculumGuidedEvoOptimizer`` instances with the same task,
    benchmark, generation count and population size - one with curriculum
    learning enabled (and a custom curriculum configuration), one with it
    disabled - prints both best fitness values and their difference, and
    stores the outcome as ``comparison_results.json``.

    Args:
        output_dir: Directory that receives ``comparison_results.json``.
            It is created if it does not exist. Defaults to ``"results"``.
    """
    print("\n🔬 Start comparing experiments: Curriculum learning vs traditional evolution")
    print("=" * 60)

    # Create the output directory before the two evolution runs so that a
    # missing or unwritable location is reported immediately rather than
    # after all the work is done.
    os.makedirs(output_dir, exist_ok=True)
    comparison_path = os.path.join(output_dir, "comparison_results.json")

    task = "Solve mathematical reasoning problems accurately"
    benchmark = "gsm-8k"
    generations = 8
    population_size = 3

    results = {}

    # Custom curriculum configuration (faster training)
    curriculum_config = {
        "gsm8k": {
            "levels": ["easy", "medium", "hard"],
            "thresholds": [0.6, 0.75, 0.85],
            "sample_sizes": [5, 10, 15],
            "min_generations": [2, 2, 3]
        }
    }

    # Run with curriculum learning
    print("\n🎯 Test curriculum-guided evolution...")
    curriculum_optimizer = CurriculumGuidedEvoOptimizer(
        task_info=task,
        population_size=population_size,
        curriculum_config=curriculum_config,
        enable_curriculum=True
    )

    best_curriculum, curriculum_history = curriculum_optimizer.evolve_with_curriculum(
        generations=generations,
        benchmark_type=benchmark
    )

    results["curriculum"] = {
        # ``or 0`` also covers a "fitness" key that is present but None,
        # which would otherwise break the formatting/subtraction below.
        "best_fitness": best_curriculum.get("fitness") or 0,
        "history": curriculum_history
    }

    # Run plain evolution
    print("\n🔄 Test traditional evolution algorithm...")
    traditional_optimizer = CurriculumGuidedEvoOptimizer(
        task_info=task,
        population_size=population_size,
        enable_curriculum=False
    )

    best_traditional = traditional_optimizer.evolve(
        generations=generations,
        benchmark_type=benchmark
    )

    results["traditional"] = {
        "best_fitness": best_traditional.get("fitness") or 0
    }

    # Compare the results
    curriculum_fitness = results["curriculum"]["best_fitness"]
    traditional_fitness = results["traditional"]["best_fitness"]
    improvement = curriculum_fitness - traditional_fitness

    print("\n📊 Comparison results:")
    print(f"    Curriculum learning best performance: {curriculum_fitness:.4f}")
    print(f"    Traditional evolution best performance: {traditional_fitness:.4f}")
    print(f"    Performance improvement: {improvement:.4f}")

    if improvement > 0:
        print("✅ Curriculum learning performs better!")
    elif improvement < 0:
        print("❌ Traditional evolution performs better!")
    else:
        print("⚖️  Both methods perform equally well!")

    # Save the comparison results
    with open(comparison_path, "w", encoding="utf-8") as f:
        json.dump(results, f, ensure_ascii=False, indent=2)

    print(f"💾 Comparison results saved to: {comparison_path}")


def classify_custom_dataset():
    """Interactively classify a dataset by difficulty and register it in a curriculum.

    Prompts on stdin for a dataset path, a domain and a sample size (each
    with a default), runs ``DifficultyClassifier.classify_dataset`` on it,
    and then adds the dataset to a fresh ``CurriculumManager`` and prints the
    manager's current difficulty entry for it.

    Returns:
        bool: ``True`` when classification and registration completed,
        ``False`` when any step inside the guarded section raised (the error
        is printed, not re-raised).
    """
    print("\n🤖 LLM-based Dataset Classification")
    print("=" * 60)

    dataset_path = input("Enter dataset path (default: Benchmark/gsm8k-test.jsonl): ").strip()
    if not dataset_path:
        dataset_path = "Benchmark/gsm8k-test.jsonl"

    domain = input("Enter domain (math/gsm8k/general, default: gsm8k): ").strip()
    if not domain:
        domain = "gsm8k"

    # Blank or non-integer input falls back to the default of 10.
    sample_size = input("Enter sample size (default: 10): ").strip()
    try:
        sample_size = int(sample_size) if sample_size else 10
    except ValueError:
        sample_size = 10

    # Initialize classifier
    classifier = DifficultyClassifier()

    # Extract dataset name (file name without directory or extension)
    dataset_name = os.path.splitext(os.path.basename(dataset_path))[0]

    results_dir = "results"

    try:
        # The classification report is written under ``results/``; create it
        # here so the function also works when it is called without the
        # script's ``__main__`` guard having run first.
        os.makedirs(results_dir, exist_ok=True)

        # Classify dataset
        results = classifier.classify_dataset(
            dataset_path=dataset_path,
            domain=domain,
            sample_size=sample_size,
            output_path=os.path.join(
                results_dir, f"difficulty_classification_{dataset_name}.json"
            )
        )

        # Create curriculum configuration. The returned configuration is not
        # consumed by this function; the call is kept as in the original
        # flow. NOTE(review): confirm whether it is needed for side effects.
        classifier.create_curriculum_config_from_classification(
            results, dataset_name
        )

        # Test with curriculum manager
        print("\n🧪 Testing with Curriculum Manager:")
        manager = CurriculumManager()
        manager.add_llm_classified_dataset(
            dataset_path=dataset_path,
            dataset_name=dataset_name,
            domain=domain,
            sample_size=sample_size,
            classifier=classifier
        )

        # Show current difficulty
        difficulty = manager.get_current_difficulty(dataset_name)
        print(f"✅ Successfully added {dataset_name} to curriculum")
        print(f"   Current level: {difficulty['level']}")
        print(f"   Sample size: {difficulty['sample_size']}")
        print(f"   Threshold: {difficulty['threshold']}")

        return True

    except Exception as e:
        # Interactive boundary: report the problem and signal failure to the
        # caller (the full-pipeline mode skips evolution on False).
        print(f"❌ Error during classification: {e}")
        return False


if __name__ == "__main__":
    # The comparison and classification modes write into ./results, so make
    # sure the directory is there before any mode runs.
    os.makedirs("results", exist_ok=True)

    # Interactive mode menu.
    for menu_line in (
        "Choose running mode:",
        "1. Single curriculum-guided evolution",
        "2. Compare curriculum learning vs traditional evolution",
        "3. LLM-based dataset difficulty classification",
        "4. Full pipeline: Classify dataset + Curriculum evolution",
    ):
        print(menu_line)

    selected_mode = input("Enter choice (1-4, default: 1): ").strip()

    if selected_mode == "4":
        print("🚀 Full Pipeline: Dataset Classification + Curriculum Evolution")
        print("=" * 70)

        # Step 1: classify the dataset; step 2 (evolution) only runs when
        # classification reported success.
        if not classify_custom_dataset():
            print("❌ Classification failed, skipping evolution step")
        else:
            print("\n🎯 Proceeding to curriculum-guided evolution...")
            main()
    elif selected_mode == "3":
        classify_custom_dataset()
    elif selected_mode == "2":
        compare_with_without_curriculum()
    else:
        # "1", blank input and anything unrecognised all run the default mode.
        main()