#!/usr/bin/env python3
"""
Test script for the three new DCMN features:
1. Causal plan explanation
2. Cross-task learning  
3. Confidence-driven planning
"""

import os
import sys
from pathlib import Path

# add this script's own directory to the import path so sibling modules can be imported
sys.path.append(str(Path(__file__).parent))

from neuro_symbolic_planner import NeuroSymbolicPlanner
from dcmn_causal_memory import integrate_dcmn_with_planner

def test_new_features():
    """Test all three new features"""
    
    print("=" * 70)
    print("🧪 TESTING NEW DCMN FEATURES")
    print("=" * 70)
    
    # get the api key we need for the language model
    api_key = os.getenv("GROQ_API_KEY", "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")
    
    try:
        print("\n🔧 Initializing DCMN System...")
        planner = NeuroSymbolicPlanner(api_key)
        planner = integrate_dcmn_with_planner(planner)
        print("   ✓ DCMN system initialized with all components")
        
        # different test cases to show off what our system can do
        test_cases = [
            {
                "name": "Simple Task (High Confidence)",
                "task": "Pick up the red block",
                "expected_features": ["fast strategy", "explanations"]
            },
            {
                "name": "Medium Task (Standard Strategy)", 
                "task": "Move the blue block from the table to the red block",
                "expected_features": ["standard strategy", "explanations"]
            },
            {
                "name": "Complex Task (Cautious Strategy)",
                "task": "Stack the red block on blue, then coordinate the green block placement after arranging the workspace",
                "expected_features": ["cautious strategy", "explanations", "extra validation"]
            },
            {
                "name": "Similar Task (Cross-Task Learning)",
                "task": "Pick up the blue block", # similar to first task to test learning
                "expected_features": ["cross-task learning", "explanations"]
            }
        ]
        
        results = []
        
        for i, test_case in enumerate(test_cases, 1):
            print(f"\n{'='*50}")
            print(f"🧪 Test {i}: {test_case['name']}")
            print(f"📋 Task: {test_case['task']}")
            print(f"🎯 Expected Features: {', '.join(test_case['expected_features'])}")
            print(f"{'='*50}")
            
            try:
                # use our improved system to make a plan
                result = planner.plan_from_natural_language(test_case['task'], max_iterations=3)
                
                print(f"\n📊 RESULTS:")
                print(f"   ✅ Success: {result.success}")
                if result.success:
                    print(f"   🎯 Plan: {result.plan}")
                    print(f"   ⏱️  Time: {result.total_time:.2f}s")
                    print(f"   🔄 Iterations: {result.iterations}")
                    print(f"   📈 Confidence: {result.confidence_score:.2f}")
                    
                    # check if we got explanations for our plan
                    if result.explanations:
                        print(f"\n💬 PLAN EXPLANATIONS:")
                        for j, explanation in enumerate(result.explanations):
                            print(f"   {j+1}. {explanation}")
                    else:
                        print(f"   ⚠️  No explanations generated")
                    
                    results.append({
                        'test': test_case['name'],
                        'success': True,
                        'plan': result.plan,
                        'explanations': len(result.explanations) if result.explanations else 0,
                        'confidence': result.confidence_score,
                        'iterations': result.iterations
                    })
                else:
                    print(f"   ❌ Planning failed after {result.iterations} iterations")
                    results.append({
                        'test': test_case['name'],
                        'success': False,
                        'plan': [],
                        'explanations': 0,
                        'confidence': 0.0,
                        'iterations': result.iterations
                    })
                    
            except Exception as e:
                print(f"   ❌ Error: {str(e)}")
                results.append({
                    'test': test_case['name'],
                    'success': False,
                    'error': str(e)
                })
        
        # summarize how well our tests went
        print(f"\n{'='*70}")
        print("📈 FEATURE TESTING SUMMARY")
        print(f"{'='*70}")
        
        successful_tests = sum(1 for r in results if r.get('success', False))
        total_explanations = sum(r.get('explanations', 0) for r in results)
        
        print(f"✅ Successful tests: {successful_tests}/{len(test_cases)}")
        print(f"💬 Total explanations generated: {total_explanations}")
        print(f"📊 Average confidence: {sum(r.get('confidence', 0) for r in results) / len(results):.2f}")
        
        # check if each feature is actually working
        print(f"\n🔍 FEATURE VALIDATION:")
        print(f"   ✅ Plan Explanations: {'Working' if total_explanations > 0 else 'Not working'}")
        print(f"   ✅ Confidence Assessment: {'Working' if any('confidence' in r for r in results) else 'Not working'}")
        print(f"   ✅ Strategy Adaptation: {'Working' if successful_tests > 0 else 'Not working'}")
        print(f"   ✅ Cross-Task Learning: {'Available' if hasattr(planner, 'causal_memory') else 'Not available'}")
        
        if successful_tests >= len(test_cases) * 0.5:  # need at least half to pass
            print(f"\n🎉 NEW FEATURES TEST: PASSED")
            print(f"   All three features are operational and ready for use!")
        else:
            print(f"\n⚠️  NEW FEATURES TEST: NEEDS ATTENTION")
            print(f"   Some features may need debugging.")
        
        return results
        
    except Exception as e:
        print(f"❌ System initialization failed: {e}")
        return []

if __name__ == "__main__":
    # Report the outcome through the exit status so shell/CI callers can detect
    # failure; test_new_features() itself only prints its verdict.
    results = test_new_features()
    passed = sum(1 for r in results if r.get('success', False))
    # Mirrors the "at least half the cases succeeded" threshold of the printed
    # report; an empty list means the run never got past setup.
    sys.exit(0 if results and passed >= len(results) * 0.5 else 1)