"""
Decentralized Causal Memory Networks (DCMN) Implementation
Extends the Neuro-Symbolic Planner with causal knowledge graphs and multi-agent validation.

Key Components:
1. CausalKnowledgeAsset - Content-addressed (SHA-256-hashed) causal claims
2. CausalGraphMemory - Graph-based memory with causal relationships
3. MultiAgentValidator - Reputation-weighted consensus validation
4. DomainParanet - Self-organizing domain expert networks
5. GraphNeuralRetriever - GNN-based causal knowledge retrieval
6. DecentralizedRAG - Retrieval across local memory and domain paranets
"""

import hashlib
import json
import time
from typing import List, Dict, Tuple, Optional, Set, Any
from dataclasses import dataclass, field
from collections import defaultdict, deque
from enum import Enum
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GATConv, global_mean_pool
from torch_geometric.data import Data, Batch
import rdflib
from rdflib import Graph, Literal, RDF, URIRef, Namespace
import pickle
import logging

# Configure root logging at import time so this module's INFO messages are visible.
# NOTE(review): basicConfig in a library module configures the process-wide root
# logger (it is a no-op if the root logger already has handlers); consider moving
# this to the application entry point.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# RDF namespaces used when exporting knowledge assets:
# DCMN for this project's own vocabulary, PROV for W3C PROV-O provenance terms.
DCMN = Namespace("http://dcmn.ai/ontology/")
PROV = Namespace("http://www.w3.org/ns/prov#")

class CausalRelationType(Enum):
    """Types of causal relationships in planning.

    Member order is significant: CausalGraphMemory.discover_causal_links maps
    the classifier's output index i to ``list(CausalRelationType)[i]``, so
    reordering members changes the meaning of a trained classifier.
    """
    ENABLES = "enables"  # action a makes action b possible
    REQUIRES = "requires"  # action a needs condition c to hold
    PRODUCES = "produces"  # action a creates effect e
    PREVENTS = "prevents"  # action a stops action b from happening
    MODIFIES = "modifies"  # action a changes state s

@dataclass
class CausalTriple:
    """A directed causal claim ``subject --predicate--> object``.

    The hash uses only the subject, the relation's string value and the object.
    The dataclass-generated ``__eq__`` still compares every field, including
    ``confidence`` and ``evidence_count``.
    """
    subject: str  # e.g., "move(block_a, table)"
    predicate: CausalRelationType
    object: str  # e.g., "clear(block_a)"
    confidence: float = 1.0
    evidence_count: int = 1

    def to_rdf(self) -> Tuple[URIRef, URIRef, URIRef]:
        """Return the claim as three URIRefs in the DCMN namespace.

        Spaces in subject and object become underscores; the predicate uses the
        enum member's string value as its local name.
        """
        def in_dcmn(local_name: str) -> URIRef:
            return URIRef(DCMN + local_name)

        subject_ref = in_dcmn(self.subject.replace(" ", "_"))
        predicate_ref = in_dcmn(self.predicate.value)
        object_ref = in_dcmn(self.object.replace(" ", "_"))
        return subject_ref, predicate_ref, object_ref

    def __hash__(self):
        identity_key = (self.subject, self.predicate.value, self.object)
        return hash(identity_key)

@dataclass
class CausalKnowledgeAsset:
    """
    Content-addressed causal claim that can be owned, shared, and validated.

    When ``asset_id`` is empty it is replaced by a SHA-256 digest of the asset's
    triples, plan trace, domain, task, creator and creation time (see
    ``generate_hash``). The digest identifies content; it is not a signature.
    """
    asset_id: str
    causal_triples: List[CausalTriple]
    plan_trace: List[str]  # the sequence of actions that taught us this
    domain: str
    task_description: str
    success_rate: float
    creation_time: float
    creator_id: str
    validators: Set[str] = field(default_factory=set)
    validation_scores: Dict[str, float] = field(default_factory=dict)
    version: int = 1
    parent_asset_id: Optional[str] = None

    def __post_init__(self):
        # An empty id means "derive it from the content".
        if not self.asset_id:
            self.asset_id = self.generate_hash()

    def generate_hash(self) -> str:
        """Return the SHA-256 hex digest of the asset's canonical JSON content."""
        content = {
            'triples': [(t.subject, t.predicate.value, t.object) for t in self.causal_triples],
            'plan_trace': self.plan_trace,
            'domain': self.domain,
            'task': self.task_description,
            'creator': self.creator_id,
            'time': self.creation_time
        }
        content_str = json.dumps(content, sort_keys=True)
        return hashlib.sha256(content_str.encode()).hexdigest()

    def add_validation(self, validator_id: str, score: float):
        """Record (or overwrite) the score given by ``validator_id``."""
        self.validators.add(validator_id)
        self.validation_scores[validator_id] = score

    def get_consensus_score(self) -> float:
        """Return the unweighted mean validation score.

        Falls back to ``success_rate`` when no validations have been recorded.
        """
        if not self.validation_scores:
            return self.success_rate
        return float(np.mean(list(self.validation_scores.values())))

    @staticmethod
    def _triple_node_id(triple: CausalTriple) -> str:
        """Stable identifier for a triple's RDF node.

        Builtin ``hash()`` of strings is salted per process (PYTHONHASHSEED), so
        it cannot name a node that other runs or agents must agree on. A SHA-256
        digest of (subject, relation value, object) is identical everywhere.
        """
        key = json.dumps([triple.subject, triple.predicate.value, triple.object])
        return hashlib.sha256(key.encode()).hexdigest()

    def to_rdf_graph(self) -> Graph:
        """Convert the asset to an rdflib Graph.

        The graph holds the asset's metadata, each causal triple as a direct
        statement, and one reified statement node per triple. That node is linked
        from the asset via ``DCMN.contains``, carries the triple's confidence, and
        points back to its subject/predicate/object.
        """
        g = Graph()
        asset_uri = URIRef(DCMN + f"asset/{self.asset_id}")

        # basic info about this knowledge asset
        g.add((asset_uri, RDF.type, DCMN.CausalKnowledgeAsset))
        g.add((asset_uri, DCMN.domain, Literal(self.domain)))
        g.add((asset_uri, DCMN.task, Literal(self.task_description)))
        g.add((asset_uri, DCMN.successRate, Literal(self.success_rate)))
        g.add((asset_uri, PROV.wasGeneratedBy, URIRef(DCMN + f"agent/{self.creator_id}")))
        # NOTE(review): creation_time is a float timestamp; PROV-O expects xsd:dateTime here.
        g.add((asset_uri, PROV.generatedAtTime, Literal(self.creation_time)))

        # add the cause-effect relationships we learned
        for triple in self.causal_triples:
            subj, pred, obj = triple.to_rdf()
            g.add((subj, pred, obj))
            # Reified statement node, identified stably across processes.
            triple_node = URIRef(DCMN + f"triple/{self._triple_node_id(triple)}")
            g.add((asset_uri, DCMN.contains, triple_node))
            g.add((triple_node, RDF.type, RDF.Statement))
            g.add((triple_node, RDF.subject, subj))
            g.add((triple_node, RDF.predicate, pred))
            g.add((triple_node, RDF.object, obj))
            g.add((triple_node, DCMN.confidence, Literal(triple.confidence)))

        return g

class CausalGraphMemory(nn.Module):
    """
    Graph-based memory system using Graph Neural Networks for causal knowledge storage and retrieval.

    Stored assets live in ``assets`` (asset_id -> asset). Their RDF exports are
    merged into ``causal_graph``, and ``triple_index`` maps each CausalTriple to
    the ids of the assets that contain it.
    """

    def __init__(self, embedding_dim=256, hidden_dim=128, num_heads=4):
        super().__init__()
        self.embedding_dim = embedding_dim
        # Kept on the instance because discover_causal_links sizes its inputs from
        # it. Previously it was never assigned, so that method raised AttributeError.
        self.hidden_dim = hidden_dim
        self.num_heads = num_heads
        self.assets: Dict[str, CausalKnowledgeAsset] = {}
        self.causal_graph = Graph()  # merged RDF export of every stored asset
        self.triple_index: Dict[CausalTriple, List[str]] = defaultdict(list)  # triple -> asset ids

        # Entity features (embedding_dim) are projected to hidden_dim and then pass
        # through three GAT layers. The first two concatenate heads (width
        # hidden_dim * num_heads); the last averages heads back to hidden_dim.
        self.node_encoder = nn.Linear(embedding_dim, hidden_dim)
        self.gat_layers = nn.ModuleList([
            GATConv(hidden_dim if i == 0 else hidden_dim * num_heads,
                    hidden_dim,
                    heads=num_heads,
                    dropout=0.2,
                    concat=True if i < 2 else False)
            for i in range(3)
        ])
        # Scores a concatenated pair of hidden_dim embeddings; sigmoid output in (0, 1).
        self.edge_predictor = nn.Sequential(
            nn.Linear(hidden_dim * 2, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1),
            nn.Sigmoid()
        )

        # Softmax over CausalRelationType members for a concatenated embedding pair.
        self.relation_classifier = nn.Sequential(
            nn.Linear(hidden_dim * 2, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, len(CausalRelationType)),
            nn.Softmax(dim=-1)
        )

    def store_asset(self, asset: CausalKnowledgeAsset):
        """Store a causal knowledge asset in the dict, the RDF graph and the triple index."""
        self.assets[asset.asset_id] = asset

        # merge the asset's RDF export into the shared knowledge graph
        asset_graph = asset.to_rdf_graph()
        self.causal_graph += asset_graph

        # update the index so the asset can be found by triple later
        for triple in asset.causal_triples:
            self.triple_index[triple].append(asset.asset_id)

        logger.info(f"Stored causal asset {asset.asset_id} with {len(asset.causal_triples)} triples")

    def _entity_embedding(self, entity: str) -> torch.Tensor:
        """Deterministic placeholder feature vector (length embedding_dim) for an entity.

        It is a one-hot at a slot derived from SHA-256 of the name, plus Gaussian
        noise (std 0.1) from a generator seeded by the same digest. The same entity
        therefore gets the same vector in every process, unlike builtin ``hash()``,
        which is salted per process. In practice, use a text encoder like BERT.
        """
        digest = int.from_bytes(hashlib.sha256(entity.encode("utf-8")).digest()[:7], "big")
        generator = torch.Generator().manual_seed(digest)
        embedding = torch.zeros(self.embedding_dim)
        embedding[digest % self.embedding_dim] = 1.0
        embedding += torch.randn(self.embedding_dim, generator=generator) * 0.1
        return embedding

    def retrieve_by_task(self, task_description: str, task_embedding: torch.Tensor, k: int = 5) -> List[CausalKnowledgeAsset]:
        """Return up to ``k`` stored assets ranked by GNN-based relevance to the task.

        Args:
            task_description: currently unused; kept for interface compatibility.
            task_embedding: 1-D query vector. It is compared with pooled node
                embeddings, and padding/truncation is applied if the widths differ.
            k: maximum number of assets to return.

        Returns:
            Assets sorted by descending relevance. The list is empty when no assets
            or no edges are stored.
        """
        if not self.assets:
            return []

        # Sorted so node indices do not depend on set iteration order.
        entities = sorted({
            name
            for asset in self.assets.values()
            for triple in asset.causal_triples
            for name in (triple.subject, triple.object)
        })
        node_to_idx = {entity: i for i, entity in enumerate(entities)}

        # One directed edge per stored triple (subject -> object).
        edges = [
            [node_to_idx[triple.subject], node_to_idx[triple.object]]
            for asset in self.assets.values()
            for triple in asset.causal_triples
        ]
        if not edges:
            return []

        x = torch.stack([self._entity_embedding(entity) for entity in entities])
        edge_index = torch.tensor(edges, dtype=torch.long).t().contiguous()

        # Retrieval is pure inference: disable GAT attention dropout and autograd so
        # scores are reproducible, then restore the caller's train/eval mode.
        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                h = self.node_encoder(x)
                for i, gat_layer in enumerate(self.gat_layers):
                    h = gat_layer(h, edge_index)
                    if i < len(self.gat_layers) - 1:  # no ReLU after the last layer
                        h = F.relu(h)

                asset_scores = [
                    (self._calculate_relevance(asset, h, node_to_idx, task_embedding), asset)
                    for asset in self.assets.values()
                ]
        finally:
            self.train(was_training)

        asset_scores.sort(key=lambda pair: pair[0], reverse=True)
        return [asset for _, asset in asset_scores[:k]]

    def _calculate_relevance(self, asset: CausalKnowledgeAsset,
                             node_embeddings: torch.Tensor,
                             node_to_idx: Dict[str, int],
                             task_embedding: torch.Tensor) -> float:
        """Cosine similarity of the asset's pooled entity embeddings to the task, times consensus.

        Returns 0.0 when none of the asset's entities appear in ``node_to_idx``.
        """
        asset_embeddings = [
            node_embeddings[node_to_idx[entity]]
            for triple in asset.causal_triples
            for entity in (triple.subject, triple.object)
            if entity in node_to_idx
        ]
        if not asset_embeddings:
            return 0.0

        # Mean-pool the asset's entity embeddings.
        asset_embedding = torch.stack(asset_embeddings).mean(dim=0)

        # Align widths with the task embedding: zero-pad (matching dtype/device) or truncate.
        target_dim = task_embedding.size(0)
        current_dim = asset_embedding.size(0)
        if current_dim < target_dim:
            asset_embedding = torch.cat([asset_embedding, asset_embedding.new_zeros(target_dim - current_dim)])
        elif current_dim > target_dim:
            asset_embedding = asset_embedding[:target_dim]

        similarity = F.cosine_similarity(asset_embedding.unsqueeze(0),
                                         task_embedding.unsqueeze(0)).item()

        # Weight by consensus score
        return float(similarity * asset.get_consensus_score())

    def find_similar_tasks(self, current_task: str, similarity_threshold: float = 0.7) -> Optional[CausalKnowledgeAsset]:
        """
        Return the stored asset whose task description best matches ``current_task``.

        The score is the word-level Jaccard similarity multiplied by the asset's
        success_rate and consensus score. Only scores strictly above
        ``similarity_threshold`` qualify; returns None if nothing qualifies.
        """
        if not self.assets:
            return None

        current_words = set(current_task.lower().split())

        best_asset = None
        best_similarity = 0.0

        for asset in self.assets.values():
            asset_words = set(asset.task_description.lower().split())

            if not asset_words or not current_words:
                continue

            # Jaccard similarity
            intersection = len(current_words & asset_words)
            union = len(current_words | asset_words)
            similarity = intersection / union if union > 0 else 0

            weighted_similarity = similarity * asset.success_rate * asset.get_consensus_score()

            if weighted_similarity > best_similarity and weighted_similarity > similarity_threshold:
                best_similarity = weighted_similarity
                best_asset = asset

        if best_asset:
            logger.info(f"Found similar task: '{best_asset.task_description}' (similarity: {best_similarity:.2f})")
            return best_asset

        return None

    def get_successful_patterns(self, domain: Optional[str] = None) -> List[CausalKnowledgeAsset]:
        """
        Return assets with success_rate > 0.8 and consensus > 0.7, best first.

        If ``domain`` is given, only assets from that domain are considered. The
        ranking key is success_rate * consensus score.
        """
        successful_assets = [
            asset for asset in self.assets.values()
            if not (domain and asset.domain != domain)
            and asset.success_rate > 0.8 and asset.get_consensus_score() > 0.7
        ]

        successful_assets.sort(
            key=lambda a: a.success_rate * a.get_consensus_score(),
            reverse=True
        )

        return successful_assets

    def discover_causal_links(self, action1: str, action2: str) -> Tuple[CausalRelationType, float]:
        """Predict the relation type between two actions with ``relation_classifier``.

        NOTE: placeholder. The action strings are not encoded yet; random
        hidden_dim vectors stand in for their embeddings, so the result is random.

        Returns:
            (most probable relation type, its softmax probability)
        """
        embedding1 = torch.randn(self.hidden_dim)
        embedding2 = torch.randn(self.hidden_dim)

        with torch.no_grad():
            relation_probs = self.relation_classifier(torch.cat([embedding1, embedding2]))
        confidence = relation_probs.max().item()
        relation_idx = relation_probs.argmax().item()
        relation_type = list(CausalRelationType)[relation_idx]

        return relation_type, confidence

class MultiAgentValidator:
    """
    Consensus-based validation system for causal knowledge.

    Each agent has a reputation (default 1.0, kept within [0.1, 2.0]). The
    reputation weights the agent's score in the consensus and is adjusted after
    every validation according to how close the agent's score was to the consensus.
    """

    def __init__(self, min_validators: int = 3, consensus_threshold: float = 0.7):
        self.min_validators = min_validators
        self.consensus_threshold = consensus_threshold
        self.agent_reputation: Dict[str, float] = defaultdict(lambda: 1.0)
        self.validation_history: List[Dict] = []

    def validate_asset(self, asset: CausalKnowledgeAsset,
                      validators: List[Tuple[str, float]]) -> Tuple[bool, float]:
        """
        Validate a causal asset through multi-agent consensus.

        Args:
            asset: The causal knowledge asset to validate. Every (agent, score) pair
                is recorded on it via ``add_validation``.
            validators: (agent_id, validation_score) pairs. Any iterable is accepted.

        Returns:
            (is_valid, consensus_score). consensus_score is the reputation-weighted
            mean score, and is_valid is ``consensus_score >= consensus_threshold``.
            Returns (False, 0.0) without recording anything when fewer than
            ``min_validators`` pairs are supplied.
        """
        # Private snapshot: the pairs are iterated several times and kept in
        # validation_history, so later mutation of the caller's list must not
        # rewrite recorded history.
        validators = list(validators)
        if len(validators) < self.min_validators:
            return False, 0.0

        # Reputation-weighted mean of the scores.
        weighted_sum = 0.0
        total_weight = 0.0
        for agent_id, score in validators:
            reputation = self.agent_reputation[agent_id]
            weighted_sum += score * reputation
            total_weight += reputation

        consensus_score = weighted_sum / total_weight if total_weight > 0 else 0.0

        # Update asset with validations
        for agent_id, score in validators:
            asset.add_validation(agent_id, score)

        # Record validation event
        self.validation_history.append({
            'asset_id': asset.asset_id,
            'validators': validators,
            'consensus_score': consensus_score,
            'timestamp': time.time()
        })

        # Reputations change only after this round's consensus has been computed.
        self._update_reputations(validators, consensus_score)

        return consensus_score >= self.consensus_threshold, consensus_score

    def _update_reputations(self, validators: List[Tuple[str, float]], consensus: float):
        """Adjust reputations by alignment with the consensus.

        delta = 0.1 * (1 - |score - consensus|) - 0.05. An agent exactly on the
        consensus gains 0.05; one off by 1.0 loses 0.05.
        """
        for agent_id, score in validators:
            alignment = 1 - abs(score - consensus)
            reputation_delta = 0.1 * alignment - 0.05

            # Clamp to [0.1, 2.0] so no agent's weight vanishes or dominates.
            self.agent_reputation[agent_id] = max(0.1, min(2.0,
                self.agent_reputation[agent_id] + reputation_delta))

class DomainParanet:
    """
    Self-organizing network of agents specializing in one planning domain.

    Agents join only with expertise >= ``min_expertise``. Expertise is then the
    mean of the agent's last 10 performance values, and agents whose expertise
    drops below the threshold are removed.
    """

    def __init__(self, domain: str, min_expertise: float = 0.7):
        self.domain = domain
        self.min_expertise = min_expertise
        self.members: Dict[str, float] = {}  # agent_id -> expertise_score
        self.shared_knowledge: List[CausalKnowledgeAsset] = []
        self.performance_history: Dict[str, List[float]] = defaultdict(list)

    def add_member(self, agent_id: str, initial_expertise: float = 0.5) -> bool:
        """Admit ``agent_id`` if ``initial_expertise >= min_expertise``.

        Returns True if the agent was admitted. With the default arguments
        (0.5 vs. 0.7) the agent is rejected. Rejections are logged rather than
        ignored silently.
        """
        if initial_expertise >= self.min_expertise:
            self.members[agent_id] = initial_expertise
            logger.info(f"Agent {agent_id} joined {self.domain} paranet")
            return True
        logger.info(
            f"Agent {agent_id} not admitted to {self.domain} paranet: "
            f"expertise {initial_expertise} below minimum {self.min_expertise}"
        )
        return False

    def share_knowledge(self, asset: CausalKnowledgeAsset, contributor_id: str) -> bool:
        """Share ``asset`` if the contributor is a member and the asset is in this domain.

        A successful share records a performance of 1.0 for the contributor.
        Returns True if the asset was shared.
        """
        if contributor_id not in self.members or asset.domain != self.domain:
            return False
        self.shared_knowledge.append(asset)
        self.update_expertise(contributor_id, 1.0)
        return True

    def update_expertise(self, agent_id: str, performance: float):
        """Record ``performance`` and set expertise to the mean of the last 10 values.

        Non-members are ignored. A member whose new expertise falls below
        ``min_expertise`` is removed; its performance history is kept.
        """
        if agent_id in self.members:
            self.performance_history[agent_id].append(performance)
            recent_performance = self.performance_history[agent_id][-10:]
            # Store a plain float rather than a numpy scalar.
            new_expertise = float(np.mean(recent_performance))
            self.members[agent_id] = new_expertise

            if new_expertise < self.min_expertise:
                del self.members[agent_id]
                logger.info(f"Agent {agent_id} removed from {self.domain} paranet")

    def get_experts(self, k: int = 3) -> List[str]:
        """Return up to ``k`` member ids, highest expertise first."""
        sorted_members = sorted(self.members.items(), key=lambda x: x[1], reverse=True)
        return [agent_id for agent_id, _ in sorted_members[:k]]

    def collective_validation(self, asset: CausalKnowledgeAsset) -> List[Tuple[str, float]]:
        """Return (expert_id, score) pairs from up to 5 top experts.

        NOTE: simulated. Scores are drawn from Beta(8, 2) and ``asset`` is not
        inspected; query the actual agents in practice.
        """
        validations = []
        for expert_id in self.get_experts(5):
            score = np.random.beta(8, 2)  # biased towards high scores
            validations.append((expert_id, score))
        return validations

class GraphNeuralRetriever:
    """
    GNN-based retrieval system for causal knowledge graphs.

    Wraps a CausalGraphMemory. A query is mapped to a placeholder embedding,
    projected by ``query_encoder`` to the memory's hidden width, and ranked by
    ``CausalGraphMemory.retrieve_by_task``.
    """

    def __init__(self, memory: CausalGraphMemory):
        self.memory = memory
        # Sized from the memory instead of a hard-coded 256 -> 128, so a memory
        # built with non-default dimensions gets a query of matching width. With
        # the defaults this is still 256 -> 128 -> 128.
        self.query_dim = memory.embedding_dim
        hidden_dim = memory.node_encoder.out_features
        self.query_encoder = nn.Sequential(
            nn.Linear(self.query_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim)
        )

    def retrieve(self, query: str, k: int = 5) -> List[CausalKnowledgeAsset]:
        """Return up to ``k`` assets relevant to ``query`` using the memory's GNN retrieval."""
        # Placeholder encoding (use a proper text encoder in practice). It is seeded
        # from SHA-256 of the query, so repeated queries map to the same vector and
        # therefore return the same ranking.
        seed = int.from_bytes(hashlib.sha256(query.encode("utf-8")).digest()[:7], "big")
        generator = torch.Generator().manual_seed(seed)
        with torch.no_grad():
            query_embedding = torch.randn(self.query_dim, generator=generator)
            query_encoded = self.query_encoder(query_embedding)

        return self.memory.retrieve_by_task(query, query_encoded, k)

    def find_causal_paths(self, start_state: str, goal_state: str,
                          max_paths: int = 5) -> List[List[CausalTriple]]:
        """Find up to ``max_paths`` causal paths from ``start_state`` to ``goal_state``.

        Runs a breadth-first search over subject -> object edges of all stored
        triples, so shorter paths come first. Intermediate nodes are visited at
        most once, which keeps the search linear in graph size. The goal is
        exempt, so each distinct predecessor of the goal can contribute a path.
        (Previously the goal was marked visited, so at most one path was ever
        returned.) When several triples connect the same pair of nodes, the last
        one seen is used. Returns [] if start == goal or the goal is unreachable.
        """
        paths: List[List[CausalTriple]] = []

        # Deduplicated adjacency lists plus the triple for each (subject, object) edge.
        adjacency: Dict[str, List[str]] = defaultdict(list)
        edge_triple: Dict[Tuple[str, str], CausalTriple] = {}
        for asset in self.memory.assets.values():
            for triple in asset.causal_triples:
                edge = (triple.subject, triple.object)
                if edge not in edge_triple:
                    adjacency[triple.subject].append(triple.object)
                edge_triple[edge] = triple

        queue = deque([(start_state, [start_state])])
        visited = {start_state}

        while queue and len(paths) < max_paths:
            current, path = queue.popleft()

            if current == goal_state:
                # Never expand past the goal; a single-node path carries no triple.
                if len(path) > 1:
                    paths.append([edge_triple[(a, b)] for a, b in zip(path, path[1:])])
                continue

            for neighbor in adjacency.get(current, ()):
                if neighbor == goal_state:
                    queue.append((neighbor, path + [neighbor]))
                elif neighbor not in visited:
                    visited.add(neighbor)
                    queue.append((neighbor, path + [neighbor]))

        return paths

class DecentralizedRAG:
    """
    Decentralized Retrieval-Augmented Generation for causal knowledge.

    Combines results from an optional local GraphNeuralRetriever with
    substring matches against the ``shared_knowledge`` of domain paranets.
    """

    def __init__(self, paranets: Dict[str, DomainParanet]):
        self.federation_peers: List[str] = []
        self.retriever = None
        self.paranets = paranets

    def set_retriever(self, retriever: GraphNeuralRetriever):
        """Attach the retriever used for local (memory-backed) lookups."""
        self.retriever = retriever

    def federated_query(self, query: str, domain: str = None) -> List[CausalKnowledgeAsset]:
        """Collect assets matching ``query``, deduplicated by id, highest consensus first.

        Only the named paranet is searched when ``domain`` is one of the known
        paranets; otherwise every paranet is searched. A paranet asset matches
        when the query occurs, case-insensitively, inside its task description.
        """
        collected = []

        if self.retriever:
            collected.extend(self.retriever.retrieve(query))

        if domain and domain in self.paranets:
            paranets_to_search = [self.paranets[domain]]
        else:
            paranets_to_search = self.paranets.values()

        lowered_query = query.lower()
        for paranet in paranets_to_search:
            collected.extend(
                candidate for candidate in paranet.shared_knowledge
                if lowered_query in candidate.task_description.lower()
            )

        by_id = {}
        for candidate in collected:
            by_id[candidate.asset_id] = candidate

        return sorted(by_id.values(), key=lambda a: a.get_consensus_score(), reverse=True)

    def augment_plan_with_causal_knowledge(self, plan: List[str], domain: str) -> List[str]:
        """Interleave each plan action with up to two causal insights as comment lines.

        Insights come from the highest-consensus match for the action, and only
        when that consensus exceeds 0.8.
        """
        augmented = []

        for step in plan:
            augmented.append(step)

            matches = self.federated_query(step, domain)
            if not matches:
                continue

            best_match = matches[0]
            if best_match.get_consensus_score() > 0.8:
                augmented.extend(
                    f"# {t.subject} {t.predicate.value} {t.object}"
                    for t in best_match.causal_triples[:2]
                )

        return augmented

# Integration function to add DCMN to existing planner
def integrate_dcmn_with_planner(planner):
    """
    Integrate DCMN components with an existing NeuroSymbolicPlanner.

    Mutates ``planner`` in place. It sets the attributes causal_memory,
    causal_validator, paranets, retriever and drag, and binds these methods:
    store_causal_experience, _infer_domain, _extract_goal_from_task and
    _learn_causal_relationships_from_execution.

    Args:
        planner: Instance of NeuroSymbolicPlanner

    Returns:
        The same planner object, extended with DCMN capabilities
    """
    # Replace simple neural memory with causal graph memory
    planner.causal_memory = CausalGraphMemory()

    # Multi-agent validator (named causal_validator to avoid a name conflict on the planner)
    planner.causal_validator = MultiAgentValidator()

    # One paranet per known domain. _infer_domain can also return 'general', which has none.
    planner.paranets = {
        'logistics': DomainParanet('logistics'),
        'blocks': DomainParanet('blocks'),
        'cooking': DomainParanet('cooking'),
        'household': DomainParanet('household'),
        'medical': DomainParanet('medical'),
        'business': DomainParanet('business')
    }

    # Add retriever and dRAG
    planner.retriever = GraphNeuralRetriever(planner.causal_memory)
    planner.drag = DecentralizedRAG(planner.paranets)
    planner.drag.set_retriever(planner.retriever)

    def _merge_duplicate_triples(triples):
        """Collapse triples sharing (subject, predicate, object) into one.

        The first occurrence is kept. Its evidence_count accumulates the
        duplicates' counts and its confidence becomes the maximum seen.
        """
        merged = {}
        for triple in triples:
            key = (triple.subject, triple.predicate, triple.object)
            kept = merged.get(key)
            if kept is None:
                merged[key] = triple
            else:
                kept.evidence_count += triple.evidence_count
                kept.confidence = max(kept.confidence, triple.confidence)
        return list(merged.values())

    def store_causal_experience(self, task_description, plan, outcome, execution_trace):
        """Store a planning experience as a causal knowledge asset.

        Args:
            task_description: free-text task, used for domain and goal inference.
            plan: sequence of action strings.
            outcome: 'success' yields success_rate 1.0; anything else yields 0.0.
            execution_trace: list of step dicts with optional keys 'action',
                'pre_state', 'post_state', 'confidence' and 'success'. May be
                None or empty.

        Returns:
            The created (and stored) CausalKnowledgeAsset.
        """
        # A missing trace is handled like an empty one. Before, None raised
        # TypeError in the loop below before reaching the "no trace" branch.
        execution_trace = execution_trace or []
        causal_triples = []

        # Action-to-state: conditions each traced action made true.
        for trace_step in execution_trace:
            current_action = trace_step.get('action', '')
            pre_state = trace_step.get('pre_state', {})
            post_state = trace_step.get('post_state', {})

            if not current_action:
                continue

            for state_key, state_value in post_state.items():
                # Only positive state changes: truthy afterwards and not already
                # truthy before. Pre-existing conditions were not produced by this action.
                if state_value and not pre_state.get(state_key):
                    causal_triples.append(CausalTriple(
                        subject=current_action,
                        predicate=CausalRelationType.PRODUCES,
                        object=state_key,
                        confidence=trace_step.get('confidence', 0.8)
                    ))

        # Action-to-action: each plan step enables the next one.
        for current_action, next_action in zip(plan, plan[1:]):
            causal_triples.append(CausalTriple(
                subject=current_action,
                predicate=CausalRelationType.ENABLES,
                object=next_action,
                confidence=0.9
            ))

        # For single-action plans with nothing learned so far, link the action to the task goal.
        if len(plan) == 1 and not causal_triples:
            causal_triples.append(CausalTriple(
                subject=plan[0],
                predicate=CausalRelationType.PRODUCES,
                object=self._extract_goal_from_task(task_description),
                confidence=0.9
            ))

        # Learn causal relationships from the execution trace (if available)
        if execution_trace:
            causal_triples.extend(
                self._learn_causal_relationships_from_execution(task_description, plan, execution_trace)
            )
        else:
            logger.info("No execution trace available - using basic plan structure analysis")

        # The structural passes above and the trace-learning pass can emit the same
        # claim, so fold repeats into one triple with accumulated evidence.
        causal_triples = _merge_duplicate_triples(causal_triples)

        domain = self._infer_domain(task_description)

        asset = CausalKnowledgeAsset(
            asset_id="",  # Will be auto-generated
            causal_triples=causal_triples,
            plan_trace=plan,
            domain=domain,
            task_description=task_description,
            success_rate=1.0 if outcome == 'success' else 0.0,
            creation_time=time.time(),
            creator_id="planner_agent_1"
        )

        # The asset is stored before validation, regardless of the validation result.
        self.causal_memory.store_asset(asset)

        # Validate with domain experts; share only if consensus passes.
        if domain in self.paranets:
            validations = self.paranets[domain].collective_validation(asset)
            is_valid, _consensus = self.causal_validator.validate_asset(asset, validations)

            if is_valid:
                self.paranets[domain].share_knowledge(asset, "planner_agent_1")

        return asset

    planner.store_causal_experience = store_causal_experience.__get__(planner, planner.__class__)

    def _infer_domain(self, task_description):
        """Map a task description to a domain by keyword. Checks run in order; the first match wins."""
        task_lower = task_description.lower()
        if any(word in task_lower for word in ['block', 'stack', 'tower']):
            return 'blocks'
        elif any(word in task_lower for word in ['deliver', 'truck', 'package']):
            return 'logistics'
        elif any(word in task_lower for word in ['cook', 'heat', 'food']):
            return 'cooking'
        elif any(word in task_lower for word in ['clean', 'vacuum', 'room']):
            return 'household'
        elif any(word in task_lower for word in ['surgery', 'medical', 'patient']):
            return 'medical'
        elif any(word in task_lower for word in ['business', 'company', 'startup']):
            return 'business'
        else:
            return 'general'

    planner._infer_domain = _infer_domain.__get__(planner, planner.__class__)

    def _extract_goal_from_task(self, task_description):
        """Map a task description to a symbolic goal-state name by substring match (first match wins)."""
        task_lower = task_description.lower()

        if 'pick up' in task_lower:
            return 'object-picked-up'
        elif 'turn on' in task_lower:
            return 'light-activated'
        elif 'stack' in task_lower:
            return 'blocks-stacked'
        elif 'deliver' in task_lower:
            return 'package-delivered'
        elif 'prepare' in task_lower or 'make' in task_lower:
            return 'food-prepared'
        elif 'wedding' in task_lower or 'reception' in task_lower:
            return 'event-completed'
        elif 'evacuation' in task_lower or 'emergency' in task_lower:
            return 'emergency-resolved'
        elif 'build' in task_lower:
            return 'structure-built'
        elif 'coordinate' in task_lower:
            return 'coordination-achieved'
        else:
            return 'goal-achieved'

    def _learn_causal_relationships_from_execution(self, task_description, plan, execution_trace):
        """Derive causal triples from execution trace steps.

        Each step with a non-empty 'action' yields:
          * REQUIRES for every condition truthy in 'pre_state';
          * MODIFIES (was truthy) or PRODUCES (was falsy) for every pre_state
            condition whose value differs in 'post_state';
          * PRODUCES for every condition absent from 'pre_state' but truthy in
            'post_state';
          * ENABLES towards the next step's action, if any.
        Confidence is lowered when the step's 'success' flag is falsy.
        ``task_description`` and ``plan`` are currently unused.
        """
        causal_triples = []

        if not execution_trace:
            logger.warning("No execution trace provided - cannot learn causal relationships")
            return causal_triples

        try:
            for i, trace_step in enumerate(execution_trace):
                action = trace_step.get('action', '')
                pre_state = trace_step.get('pre_state', {})
                post_state = trace_step.get('post_state', {})
                success = trace_step.get('success', True)

                if not action:
                    continue

                # Preconditions: conditions that held before the action.
                for condition, value in pre_state.items():
                    if value:
                        causal_triples.append(CausalTriple(
                            subject=action,
                            predicate=CausalRelationType.REQUIRES,
                            object=condition,
                            confidence=0.8 if success else 0.3
                        ))

                # Effects on conditions known before the action.
                for condition, pre_value in pre_state.items():
                    post_value = post_state.get(condition, pre_value)
                    if pre_value != post_value:
                        causal_triples.append(CausalTriple(
                            subject=action,
                            predicate=CausalRelationType.MODIFIES if pre_value else CausalRelationType.PRODUCES,
                            object=condition,
                            confidence=0.9 if success else 0.4
                        ))

                # Effects that first appear after the action. Previously only keys in
                # pre_state were examined, so these were never learned.
                for condition, post_value in post_state.items():
                    if condition not in pre_state and post_value:
                        causal_triples.append(CausalTriple(
                            subject=action,
                            predicate=CausalRelationType.PRODUCES,
                            object=condition,
                            confidence=0.9 if success else 0.4
                        ))

                # Sequential causality: this action enables the next traced action.
                if i < len(execution_trace) - 1:
                    next_action = execution_trace[i + 1].get('action', '')
                    if next_action:
                        causal_triples.append(CausalTriple(
                            subject=action,
                            predicate=CausalRelationType.ENABLES,
                            object=next_action,
                            confidence=0.7 if success else 0.2
                        ))

            logger.info(f"Learned {len(causal_triples)} causal relationships from execution trace")

        except Exception as e:
            # Best-effort: a malformed step should not abort storing the experience;
            # return whatever was learned before the failure.
            logger.warning(f"Causal learning failed: {e}")

        return causal_triples

    planner._extract_goal_from_task = _extract_goal_from_task.__get__(planner, planner.__class__)
    planner._learn_causal_relationships_from_execution = _learn_causal_relationships_from_execution.__get__(planner, planner.__class__)

    logger.info("DCMN integration complete - planner now has causal memory capabilities")

    return planner