#!/usr/bin/env python3
"""
Test causal learning fix with a simple task
"""

import os
import sys
from pathlib import Path

sys.path.append(str(Path(__file__).parent))

from neuro_symbolic_planner import NeuroSymbolicPlanner
from dcmn_causal_memory import integrate_dcmn_with_planner

def test_causal_fix():
    """Run one planning task end to end and report whether causal learning ran cleanly.

    Builds a ``NeuroSymbolicPlanner``, attaches DCMN causal memory via
    ``integrate_dcmn_with_planner``, plans a single simple task, and prints the
    planning result plus a short summary of what the causal memory holds.

    The Groq API key is read from the ``GROQ_API_KEY`` environment variable
    only; no fallback credential is embedded in source.

    Returns:
        True if planning succeeded without raising; False if the API key is
        missing, planning reported failure, or any exception was raised.
    """
    print("=== TESTING CAUSAL LEARNING FIX ===")

    # Credentials must come from the environment; never hardcode a fallback
    # key in source control.
    api_key = os.getenv("GROQ_API_KEY")
    if not api_key:
        print("❌ GROQ_API_KEY environment variable is not set - cannot run test")
        return False

    simple_task = "Pick up the red block"

    print(f"\nTesting Task: {simple_task}")
    print("=" * 50)

    try:
        planner = NeuroSymbolicPlanner(api_key)
        planner = integrate_dcmn_with_planner(planner)

        print("Looking for causal learning messages...")

        # Simple task that should still trigger causal learning.
        result = planner.plan_from_natural_language(simple_task, max_iterations=2)

        print("\nRESULT:")
        print(f"Success: {result.success}")
        if not result.success:
            print("\n❌ Planning failed - cannot test causal learning")
            return False

        print(f"Plan: {result.plan}")
        print(f"Confidence: {result.confidence_score:.2f}")

        # Report what, if anything, the causal memory has learned.
        assets = planner.causal_memory.assets
        print("\nCAUSAL MEMORY STATUS:")
        print(f"Total causal assets: {len(assets)}")

        if assets:
            asset = next(iter(assets.values()))
            print(f"First asset has {len(asset.causal_triples)} causal triples:")
            for i, triple in enumerate(asset.causal_triples[:3], start=1):
                print(f"  {i}. {triple.subject} → {triple.predicate.value} → {triple.object}")

        print("\n✅ CAUSAL LEARNING FIX: SUCCESS")
        return True

    except Exception as e:
        # Top-level boundary of a manual test script: report the error with a
        # traceback and signal failure instead of crashing.
        print(f"❌ Error: {e}")
        import traceback
        traceback.print_exc()
        return False

if __name__ == "__main__":
    # Propagate the outcome as the process exit status so shell/CI callers
    # can detect a failed run (the result was previously discarded).
    sys.exit(0 if test_causal_fix() else 1)