"""
Execution Trace System for Real Causal Learning
Captures actual execution traces from plan execution to learn causal relationships
"""

import logging
import random
import re
import time
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Tuple

logger = logging.getLogger(__name__)

@dataclass()
class ExecutionStep:
    """Record of one attempted plan action and its observed effect on state.

    Field order defines the positional constructor signature; do not reorder.
    """

    action: str  # the plan step text as it appeared in the plan
    pre_state: Dict[str, bool]  # predicate -> truth value before the step
    post_state: Dict[str, bool]  # predicate -> truth value after the step
    success: bool  # whether the simulated step succeeded
    duration: float  # elapsed seconds spent simulating this step
    error_message: Optional[str] = field(default=None)  # failure reason, if any
    timestamp: float = field(default_factory=time.time)  # creation time, epoch seconds

@dataclass()
class ExecutionTrace:
    """Full record of one simulated plan run.

    Field order defines the positional constructor signature; do not reorder.
    """

    # --- what was attempted ---
    task_description: str  # task text supplied by the caller
    plan: List[str]  # plan steps in the order they were to be executed
    steps: List[ExecutionStep]  # one record per attempted step (may stop early on failure)

    # --- overall outcome ---
    total_success: bool  # True only if every attempted step succeeded
    total_duration: float  # elapsed seconds for the whole simulation

    # --- world state at the boundaries ---
    initial_state: Dict[str, bool]  # predicate -> truth value before the first step
    final_state: Dict[str, bool]  # predicate -> truth value after the last attempted step
    
class PDDLExecutionSimulator:
    """Simulate PDDL plan execution to produce step-by-step execution traces.

    The domain's ``(:action ...)`` schemas and the problem's ``(:init ...)``
    facts are read with a small s-expression parser. The parser strips ``;``
    comments and lower-cases the text, because PDDL is case-insensitive.

    Each plan step is grounded by binding the schema's ``:parameters`` to the
    step's arguments. The grounded preconditions are then checked against the
    current state, and the grounded effects are applied. Delete effects are
    applied before add effects, following STRIPS/PDDL semantics.

    Supported subset: conjunctions (``and``) of atoms and negated atoms
    (``not``). Other constructs (``or``, ``forall``, ``when``, equality,
    numeric fluents, ...) are skipped with a DEBUG log, so they are neither
    checked nor applied.

    State keys are atoms rendered as space-separated lower-case tokens, e.g.
    ``"on a b"``. Missing keys are treated as False (closed world). Actions
    not defined in the domain fall back to keyword heuristics with a
    configurable random failure chance.
    """

    # Matches "(", ")" or a run of other non-whitespace characters.
    _TOKEN_RE = re.compile(r"[()]|[^\s()]+")
    # PDDL line comments run from ';' to the end of the line.
    _COMMENT_RE = re.compile(r";[^\n]*")
    # First parenthesised group of a plan step, e.g. "0.0: (pick-up a) [1]".
    _PLAN_STEP_RE = re.compile(r"\(([^()]*)\)")
    # Heads that can look like flat atoms but are not plain predicates.
    _UNSUPPORTED_HEADS = frozenset({
        "=", "or", "imply", "exists", "forall", "when",
        "increase", "decrease", "assign", "scale-up", "scale-down",
    })
    # Extra predicates added as False to every parsed initial state unless
    # the problem's :init asserts them.
    _DEFAULT_FALSE_PREDICATES = (
        'hand-empty', 'goal-achieved', 'task-complete',
        'robot-at-goal', 'package-delivered', 'food-prepared',
    )

    def __init__(self, failure_rate: float = 0.1, rng: Optional[Any] = None):
        """Create a simulator.

        Args:
            failure_rate: Probability in ``[0, 1]`` that an action *not*
                defined in the domain fails. The default 0.1 matches the
                previous fixed 90% success rate.
            rng: Object providing ``random() -> float`` in ``[0, 1)``. Pass
                ``random.Random(seed)`` for reproducible traces. Defaults to
                the global :mod:`random` module.

        Raises:
            ValueError: If ``failure_rate`` is outside ``[0, 1]``.
        """
        if not 0.0 <= failure_rate <= 1.0:
            raise ValueError(f"failure_rate must be in [0, 1], got {failure_rate!r}")
        self.state_predicates = set()
        self.action_effects = {}  # action -> (preconditions, effects)
        self.failure_rate = failure_rate
        self._rng = rng if rng is not None else random

    def simulate_plan_execution(self, domain_pddl: str, problem_pddl: str,
                                plan: List[str], task_description: str) -> "ExecutionTrace":
        """Simulate ``plan`` step by step and return the resulting trace.

        Execution stops at the first failing step. The trace then contains
        steps up to and including the failure, and ``total_success`` is
        False.

        Args:
            domain_pddl: PDDL domain text containing ``(:action ...)`` schemas.
            problem_pddl: PDDL problem text containing an ``(:init ...)`` section.
            plan: Plan steps such as ``"pick-up a"`` or ``"(pick-up a)"``.
            task_description: Free text copied into the trace.

        Returns:
            An ``ExecutionTrace`` with one ``ExecutionStep`` per attempted step.
        """
        logger.info("Simulating execution of %d actions", len(plan))

        initial_state = self._parse_initial_state(problem_pddl)
        action_definitions = self._parse_action_effects(domain_pddl)

        execution_steps = []
        current_state = dict(initial_state)
        total_success = True
        # perf_counter is monotonic, so durations are immune to clock adjustments.
        start_time = time.perf_counter()

        for action in plan:
            step_start = time.perf_counter()
            pre_state = dict(current_state)

            success, new_state, error = self._simulate_action(
                action, current_state, action_definitions
            )

            execution_steps.append(ExecutionStep(
                action=action,
                pre_state=pre_state,
                post_state=new_state,
                success=success,
                duration=time.perf_counter() - step_start,
                error_message=error,
            ))

            current_state = new_state
            if not success:
                total_success = False
                logger.warning("Action failed: %s - %s", action, error)
                break  # a failed step invalidates the remainder of the plan

        trace = ExecutionTrace(
            task_description=task_description,
            plan=plan,
            steps=execution_steps,
            total_success=total_success,
            total_duration=time.perf_counter() - start_time,
            initial_state=initial_state,
            final_state=current_state,
        )

        logger.info("Execution simulation completed: %d steps, success=%s",
                    len(execution_steps), total_success)
        return trace

    # ------------------------------------------------------------------
    # PDDL text parsing helpers
    # ------------------------------------------------------------------

    @classmethod
    def _parse_sexprs(cls, text: str) -> List[Any]:
        """Parse PDDL text into nested lists of lower-cased tokens.

        Unbalanced input is tolerated: stray ``)`` are ignored, and lists
        left open are closed at end of input.

        Returns:
            The list of top-level expressions.
        """
        text = cls._COMMENT_RE.sub("", text).lower()
        root: List[Any] = []
        stack = [root]
        for token in cls._TOKEN_RE.findall(text):
            if token == "(":
                child: List[Any] = []
                stack[-1].append(child)
                stack.append(child)
            elif token == ")":
                if len(stack) > 1:
                    stack.pop()
            else:
                stack[-1].append(token)
        return root

    @classmethod
    def _find_sections(cls, exprs: List[Any], head: str) -> List[List[Any]]:
        """Return every list (at any depth) whose first token is ``head``.

        Matched lists are not searched further.
        """
        found: List[List[Any]] = []
        for expr in exprs:
            if not isinstance(expr, list) or not expr:
                continue
            if expr[0] == head:
                found.append(expr)
            else:
                found.extend(cls._find_sections(expr, head))
        return found

    @staticmethod
    def _atom_key(expr: Any) -> Optional[str]:
        """Render a flat atom such as ``['on', '?x', 'b']`` as ``'on ?x b'``.

        Returns None for anything that is not a non-empty list of tokens.
        """
        if isinstance(expr, list) and expr and all(isinstance(tok, str) for tok in expr):
            return " ".join(expr)
        return None

    @staticmethod
    def _keyword_args(items: List[Any]) -> Dict[str, Any]:
        """Map ``:keyword value`` pairs in ``items`` to ``{':keyword': value}``."""
        return {
            key: value
            for key, value in zip(items, items[1:])
            if isinstance(key, str) and key.startswith(":")
        }

    @classmethod
    def _collect_literals(cls, expr: Any, out: List[Tuple[str, bool]],
                          positive: bool = True) -> None:
        """Append ``(atom, value)`` literals found in ``expr`` to ``out``.

        Handles ``and`` and ``not``. Unsupported constructs are skipped
        with a DEBUG log.
        """
        if not isinstance(expr, list) or not expr:
            return
        head = expr[0]
        if head == "and":
            if not positive:
                # not (a and b) is a disjunction; it cannot be a literal list.
                logger.debug("Skipping negated conjunction: %r", expr)
                return
            for sub in expr[1:]:
                cls._collect_literals(sub, out, positive)
        elif head == "not":
            if len(expr) == 2:
                cls._collect_literals(expr[1], out, not positive)
            else:
                logger.debug("Skipping malformed negation: %r", expr)
        else:
            key = cls._atom_key(expr)
            if key is None or head in cls._UNSUPPORTED_HEADS:
                logger.debug("Skipping unsupported PDDL expression: %r", expr)
                return
            out.append((key, positive))

    @classmethod
    def _literals_from_text(cls, text: str) -> List[Tuple[str, bool]]:
        """Parse a condition/effect string into ``(atom, value)`` literals.

        Accepts either a full expression such as ``"(and (a ?x) (b))"`` or
        the inside of one, e.g. ``"and (a ?x) (b)"`` or ``"a ?x"``.
        """
        exprs = cls._parse_sexprs(text)
        if exprs and isinstance(exprs[0], str):
            # The text is the inside of a list; re-wrap it as one expression.
            exprs = [exprs]
        literals: List[Tuple[str, bool]] = []
        for expr in exprs:
            cls._collect_literals(expr, literals)
        return literals

    def _parse_initial_state(self, problem_pddl: str) -> Dict[str, bool]:
        """Return the problem's ``:init`` facts as ``{atom: True}``.

        Non-atomic entries, such as ``(= (total-cost) 0)``, are skipped. Each
        name in ``_DEFAULT_FALSE_PREDICATES`` that ``:init`` does not assert
        is added as False. Returns ``{}`` when there is no ``(:init ...)``
        section.
        """
        init_sections = self._find_sections(self._parse_sexprs(problem_pddl), ":init")
        if not init_sections:
            return {}

        initial_state: Dict[str, bool] = {}
        for fact in init_sections[0][1:]:
            key = self._atom_key(fact)
            if key is None or fact[0] in self._UNSUPPORTED_HEADS:
                logger.debug("Skipping unsupported :init entry: %r", fact)
                continue
            initial_state[key] = True

        for predicate in self._DEFAULT_FALSE_PREDICATES:
            initial_state.setdefault(predicate, False)

        logger.debug("Parsed initial state with %d predicates", len(initial_state))
        return initial_state

    def _parse_action_effects(self, domain_pddl: str) -> Dict[str, Dict[str, Any]]:
        """Parse every ``(:action ...)`` schema in the domain.

        Returns:
            ``{name: {'parameters': [...], 'preconditions': [...], 'effects': [...]}}``.
            ``parameters`` lists the schema's ``?variables`` in declaration
            order, with types dropped. The other two entries are
            ``(atom, value)`` pairs that still contain those variables.
        """
        action_definitions: Dict[str, Dict[str, Any]] = {}
        for section in self._find_sections(self._parse_sexprs(domain_pddl), ":action"):
            if len(section) < 2 or not isinstance(section[1], str):
                logger.debug("Skipping malformed action: %r", section)
                continue
            fields = self._keyword_args(section[2:])

            raw_params = fields.get(":parameters")
            parameters = (
                [tok for tok in raw_params if isinstance(tok, str) and tok.startswith("?")]
                if isinstance(raw_params, list) else []
            )
            preconditions: List[Tuple[str, bool]] = []
            effects: List[Tuple[str, bool]] = []
            self._collect_literals(fields.get(":precondition"), preconditions)
            self._collect_literals(fields.get(":effect"), effects)

            action_definitions[section[1]] = {
                'parameters': parameters,
                'preconditions': preconditions,
                'effects': effects,
            }

        logger.debug("Parsed %d action definitions", len(action_definitions))
        return action_definitions

    def _parse_conditions(self, condition_str: str) -> List[Tuple[str, bool]]:
        """Parse PDDL conditions into ``(predicate, required_value)`` pairs."""
        return self._literals_from_text(condition_str)

    def _parse_effects(self, effect_str: str) -> List[Tuple[str, bool]]:
        """Parse PDDL effects into ``(predicate, new_value)`` pairs."""
        return self._literals_from_text(effect_str)

    # ------------------------------------------------------------------
    # Action simulation
    # ------------------------------------------------------------------

    @classmethod
    def _split_plan_step(cls, action: str) -> List[str]:
        """Tokenise a plan step into ``[name, arg1, ...]``, lower-cased.

        Accepts ``"pick-up a"``, ``"(PICK-UP A)"`` and planner-style lines
        such as ``"0.000: (pick-up a) [1.000]"``. In the last case only the
        first parenthesised group is used.
        """
        text = action.strip().lower()
        match = cls._PLAN_STEP_RE.search(text)
        if match:
            text = match.group(1)
        return text.split()

    @staticmethod
    def _ground(predicate: str, binding: Dict[str, str]) -> str:
        """Replace ``?variables`` in ``predicate`` with their bound objects."""
        return " ".join(binding.get(tok, tok) for tok in predicate.split())

    def _simulate_action(self, action: str, current_state: Dict[str, bool],
                         action_definitions: Dict[str, Dict[str, Any]]) -> Tuple[bool, Dict[str, bool], Optional[str]]:
        """Simulate one plan step.

        Returns:
            ``(success, new_state, error)``. On failure ``new_state`` is an
            unchanged copy of ``current_state`` and ``error`` explains why.
            ``current_state`` itself is never mutated.
        """
        tokens = self._split_plan_step(action)
        action_name = tokens[0] if tokens else ""

        action_def = action_definitions.get(action_name)
        if action_def is None:
            return self._simulate_unknown_action(action, current_state)

        arguments = tokens[1:]
        parameters = action_def.get('parameters')
        if parameters is None:
            # Definitions without parameter info are matched literally.
            binding = {}
        elif len(parameters) != len(arguments):
            return (False, dict(current_state),
                    f"Arity mismatch: {action_name} expects {len(parameters)} "
                    f"argument(s), got {len(arguments)}")
        else:
            binding = dict(zip(parameters, arguments))

        for predicate, required_value in action_def.get('preconditions', []):
            grounded = self._ground(predicate, binding)
            current_value = current_state.get(grounded, False)
            if current_value != required_value:
                error = f"Precondition failed: {grounded} should be {required_value}, got {current_value}"
                return False, dict(current_state), error

        grounded_effects = [(self._ground(p, binding), v)
                            for p, v in action_def.get('effects', [])]
        new_state = dict(current_state)
        # PDDL semantics: apply deletes, then adds, so an atom that is both
        # deleted and added ends up true.
        for predicate, value in grounded_effects:
            if not value:
                new_state[predicate] = False
        for predicate, value in grounded_effects:
            if value:
                new_state[predicate] = True

        return True, new_state, None

    def _simulate_unknown_action(self, action: str, current_state: Dict[str, bool]) -> Tuple[bool, Dict[str, bool], Optional[str]]:
        """Simulate an action missing from the domain using name heuristics.

        Only the action *name* is inspected, so object names such as
        ``pickup-truck`` do not trigger a heuristic. The step then fails with
        probability ``self.failure_rate``; on failure the state is returned
        unchanged (as a copy).
        """
        tokens = self._split_plan_step(action)
        name = tokens[0] if tokens else ""
        args = tokens[1:]
        new_state = dict(current_state)

        if 'pick' in name or 'unstack' in name:
            # unstack ?x ?y lifts ?x, so the first argument is what is held.
            new_state['hand-empty'] = False
            new_state[f"holding-{args[0] if args else 'object'}"] = True
        elif 'stack' in name or 'place' in name:
            new_state['hand-empty'] = True
            if len(args) >= 2:
                new_state[f'on-{args[-2]}-{args[-1]}'] = True
        elif 'move' in name:
            if len(args) >= 2:
                new_state[f'at-{args[0]}-{args[-1]}'] = True
        elif 'unload' in name:  # must precede 'load', which it contains
            new_state['package-loaded'] = False
        elif 'load' in name:
            new_state['package-loaded'] = True
        elif 'drive' in name:
            new_state['vehicle-moved'] = True

        if self._rng.random() < self.failure_rate:
            return False, dict(current_state), "Simulated action failure"

        return True, new_state, None

class CausalLearningIntegrator:
    """Run simulated plan executions and feed the traces to a causal memory."""

    # Confidence attached to each step record handed to the causal memory.
    SUCCESS_CONFIDENCE = 0.9
    FAILURE_CONFIDENCE = 0.3

    def __init__(self, simulator: "PDDLExecutionSimulator",
                 max_history: Optional[int] = None):
        """
        Args:
            simulator: Object providing ``simulate_plan_execution(domain_pddl,
                problem_pddl, plan, task_description)``.
            max_history: If given, keep only the most recent ``max_history``
                traces in ``execution_history``. None (the default) keeps all.

        Raises:
            ValueError: If ``max_history`` is given and is less than 1.
        """
        if max_history is not None and max_history < 1:
            raise ValueError(f"max_history must be >= 1, got {max_history!r}")
        self.simulator = simulator
        self.max_history = max_history
        self.execution_history = []

    def _record(self, trace) -> None:
        """Append ``trace`` to the history, dropping the oldest beyond ``max_history``."""
        self.execution_history.append(trace)
        if self.max_history is not None and len(self.execution_history) > self.max_history:
            del self.execution_history[:-self.max_history]

    def execute_and_learn(self, domain_pddl: str, problem_pddl: str,
                          plan: List[str], task_description: str,
                          causal_memory) -> "ExecutionTrace":
        """Simulate ``plan``, record the trace, and pass it to ``causal_memory``.

        ``causal_memory`` is expected to expose a callable
        ``store_causal_experience(task_description=..., plan=..., outcome=...,
        execution_trace=...)``. If it does not, the trace is still recorded
        but nothing is learned. A warning is logged in that case, or a DEBUG
        message when ``causal_memory`` is None.

        Returns:
            The trace produced by the simulator.
        """
        trace = self.simulator.simulate_plan_execution(
            domain_pddl, problem_pddl, plan, task_description
        )
        self._record(trace)

        store = getattr(causal_memory, 'store_causal_experience', None)
        if not callable(store):
            if causal_memory is None:
                logger.debug("No causal memory supplied; skipping causal learning")
            else:
                logger.warning("Causal memory %s has no callable store_causal_experience(); "
                               "skipping causal learning", type(causal_memory).__name__)
            return trace

        # One record per attempted step, in the format the causal memory expects.
        execution_trace_format = [
            {
                'action': step.action,
                'pre_state': step.pre_state,
                'post_state': step.post_state,
                'success': step.success,
                'confidence': self.SUCCESS_CONFIDENCE if step.success else self.FAILURE_CONFIDENCE,
            }
            for step in trace.steps
        ]

        store(
            task_description=task_description,
            plan=plan,
            outcome='success' if trace.total_success else 'failure',
            execution_trace=execution_trace_format,
        )

        # Execution may stop early on failure, so report executed steps as
        # well as the plan length.
        logger.info("Learned causal relationships from %d executed step(s) of a %d-action plan",
                    len(trace.steps), len(plan))
        return trace

    def get_execution_statistics(self) -> Dict[str, Any]:
        """Summarise the recorded traces.

        Returns:
            ``{}`` when nothing has been recorded. Otherwise a dict with
            ``total_executions``, ``success_rate``, ``avg_plan_length`` and
            ``avg_execution_time``.
        """
        history = self.execution_history
        if not history:
            return {}

        total_executions = len(history)
        successful_executions = sum(1 for trace in history if trace.total_success)

        return {
            'total_executions': total_executions,
            'success_rate': successful_executions / total_executions,
            'avg_plan_length': sum(len(trace.plan) for trace in history) / total_executions,
            'avg_execution_time': sum(trace.total_duration for trace in history) / total_executions,
        }