#!/usr/bin/env python3
"""
Quick test to verify basic DCMN functionality is working
"""

import os
import sys
from pathlib import Path

# Add this script's own directory to the import path
sys.path.append(str(Path(__file__).parent))

from neuro_symbolic_planner import NeuroSymbolicPlanner
from dcmn_causal_memory import integrate_dcmn_with_planner

def test_basic_functionality():
    """Run a smoke test of the DCMN-integrated planner on two simple tasks.

    The Groq API key is read from the ``GROQ_API_KEY`` environment variable.
    No fallback key is embedded in the source: a credential committed to a
    file is readable by everyone who can read the file, so a missing
    variable is reported and the test fails instead.

    Each task is planned with ``max_iterations=2``. A task that raises is
    recorded as a failure and the remaining tasks still run.

    Returns:
        bool: True when at least 80% of the tasks report success; False when
        the API key is missing, initialization raises, or too few tasks
        succeed.
    """

    print("=== Testing Basic DCMN Functionality ===")

    # Credentials come from the environment only; never add a literal default.
    api_key = os.getenv("GROQ_API_KEY")
    if not api_key:
        print("❌ GROQ_API_KEY is not set; export it before running this test")
        return False

    # Only the two set-up calls are guarded here, so the "initialization
    # failed" message is printed for set-up errors and nothing else.
    try:
        print("\n🔧 Initializing DCMN System...")
        planner = NeuroSymbolicPlanner(api_key)
        planner = integrate_dcmn_with_planner(planner)
        print("   ✓ System initialized")
    except Exception as e:
        print(f"❌ System initialization failed: {e}")
        return False

    # Test simple tasks
    simple_tasks = [
        "Pick up the red block",
        "Put down the blue block",
    ]

    results = []

    for i, task in enumerate(simple_tasks, 1):
        print(f"\n--- Test {i}: {task} ---")

        # Broad catch is deliberate: one failing task must not stop the
        # remaining tasks from running in this smoke test.
        try:
            result = planner.plan_from_natural_language(task, max_iterations=2)

            print(f"Success: {result.success}")
            if result.success:
                print(f"Plan: {result.plan}")
                print(f"Time: {result.total_time:.2f}s")
                print(f"Confidence: {result.confidence_score:.2f}")

            results.append({
                'task': task,
                'success': result.success,
                'plan_length': len(result.plan) if result.plan else 0,
            })

        except Exception as e:
            print(f"Error: {e}")
            results.append({
                'task': task,
                'success': False,
                'error': str(e),
            })

    # Summary
    print("\n=== SUMMARY ===")
    successful = sum(1 for r in results if r.get('success', False))
    print(f"Successful tests: {successful}/{len(simple_tasks)}")

    # Pass when at least 80% of the tasks succeeded.
    if successful >= len(simple_tasks) * 0.8:
        print("✅ BASIC FUNCTIONALITY: WORKING")
        return True

    print("⚠️ BASIC FUNCTIONALITY: NEEDS ATTENTION")
    return False

if __name__ == "__main__":
    # Exit non-zero on failure so shells and CI can detect a failed run.
    sys.exit(0 if test_basic_functionality() else 1)