#!/usr/bin/env python3
"""
PDDL Text2World Player - Automated PDDL Environment Interaction

This script provides automated interaction with PDDL environments using LLM agents.
It takes domain.pddl and problem.pddl files as input and outputs interaction results.

Usage:
    python player_text2world.py domain.pddl problem.pddl
"""

import argparse
import json
import os
import re
import shutil
import sys
import tempfile
import time
import warnings
from dataclasses import dataclass
from datetime import datetime
from pathlib import Path
from typing import Dict, List, Tuple, Any, Optional

# Suppress gym warnings: ignore any warning whose message matches the
# "Gym has been unmaintained" notice.
warnings.filterwarnings("ignore", message=".*Gym has been unmaintained.*")

# Add parent directory to path for imports: this appends the directory one
# level above the directory containing this script.
# NOTE(review): presumably that is where the `utils` package imported below
# lives - confirm against the repository layout.
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

# pddlgym is imported defensively. On ImportError an installation hint is
# printed and PDDL_AVAILABLE is set to False; PDDLPlayer.play_episode checks
# this flag and raises RuntimeError when it is False.
try:
    import pddlgym
    from pddlgym.core import PDDLEnv
    PDDL_AVAILABLE = True
except ImportError:
    print("Error: pddlgym not installed. Please install with: pip install pddlgym")
    PDDL_AVAILABLE = False

# The LLM helper is optional. When it cannot be imported, LLM_AVAILABLE is
# False and PDDLPlayer._select_action uses the keyword heuristic instead.
try:
    from utils.llm import call_llm
    LLM_AVAILABLE = True
except ImportError:
    LLM_AVAILABLE = False

@dataclass
class PDDLStep:
    """Record of one executed action within an episode.

    Instances are created by ``PDDLPlayer.play_episode`` after each
    successful ``env.step`` call and appended to ``PDDLTrajectory.steps``.
    """
    # Zero-based index of the step within the episode.
    step_number: int
    # Formatted text of the observation the action was chosen from.
    state_description: str
    # ``str()`` of the executed action.
    action_taken: Optional[str]
    # Formatted text of the observation returned by ``env.step``.
    observation: str
    # Reward returned by ``env.step`` for this action.
    reward: float
    # The ``done`` flag returned by ``env.step`` for this action.
    is_terminal: bool
    # Explanation returned by action selection; despite the name this is also
    # filled by the heuristic fallback (e.g. "Heuristic: selected first action").
    llm_reasoning: str
    # Seconds elapsed from the start of the loop iteration (before the
    # available actions are queried) until the step is recorded, so it
    # includes action-selection time.
    execution_time: float

@dataclass
class PDDLTrajectory:
    """Result of one episode: the inputs used, every recorded step and totals.

    ``PDDLPlayer.play_episode`` creates the instance with empty/zero values
    and fills in the totals once the episode loop has finished.
    """
    # Path of the domain file exactly as passed to ``play_episode``.
    domain_file: str
    # Path of the problem file exactly as passed to ``play_episode``.
    problem_file: str
    # Executed steps in order; only steps whose ``env.step`` call succeeded.
    steps: List[PDDLStep]
    # Sum of the rewards of all recorded steps.
    total_reward: float
    # True when the environment reported ``done`` and ``total_reward`` > 0.
    success: bool
    # Wall-clock seconds spent inside ``play_episode``.
    total_execution_time: float

class PDDLPlayer:
    """Plays one PDDL problem in a PDDLGym environment, one action at a time.

    Actions are chosen by an LLM when ``call_llm`` could be imported
    (``LLM_AVAILABLE``); otherwise, or whenever the LLM call or the parsing of
    its answer fails, a keyword heuristic picks the action.

    Annotations that name ``PDDLEnv`` / ``PDDLTrajectory`` are written as
    strings so that defining this class never evaluates them. ``PDDLEnv`` is
    an optional import, and an unquoted annotation would raise ``NameError``
    at import time when pddlgym is missing, hiding the ``RuntimeError`` that
    ``play_episode`` is meant to raise in that case.
    """

    def __init__(self, model_name: str = "gpt-4o-mini", max_steps: int = 100):
        """Store the configuration.

        Args:
            model_name: Model identifier forwarded to ``call_llm``.
            max_steps: Upper bound on the number of actions per episode.
        """
        self.model_name = model_name
        self.max_steps = max_steps

    def play_episode(self, domain_file: str, problem_file: str, verbose: bool = True) -> "PDDLTrajectory":
        """Play a complete PDDL episode.

        Args:
            domain_file: Path to the PDDL domain file.
            problem_file: Path to the PDDL problem file to solve.
            verbose: When True, progress messages are printed.

        Returns:
            The trajectory with every executed step, the accumulated reward,
            the success flag and the wall-clock execution time.

        Raises:
            RuntimeError: pddlgym could not be imported.
            FileNotFoundError: ``domain_file`` or ``problem_file`` is missing.
        """
        if not PDDL_AVAILABLE:
            raise RuntimeError("PDDLGym is not available. Please install pddlgym.")

        start_time = time.time()

        # Validate files
        if not os.path.exists(domain_file):
            raise FileNotFoundError(f"Domain file not found: {domain_file}")
        if not os.path.exists(problem_file):
            raise FileNotFoundError(f"Problem file not found: {problem_file}")

        trajectory = PDDLTrajectory(
            domain_file=domain_file,
            problem_file=problem_file,
            steps=[],
            total_reward=0.0,
            success=False,
            total_execution_time=0.0
        )

        # PDDLEnv takes a *directory* of problems and loads every "*.pddl"
        # file in it. Pointing it at the problem file's own directory would
        # also pick up sibling problems (and the domain file when it lives
        # next to the problem), so the requested problem would not be
        # guaranteed to be the one played. Copying it into a private directory
        # makes it the only candidate. The ".pddl" name is forced so the file
        # is found whatever its original extension.
        with tempfile.TemporaryDirectory(prefix="pddl_problem_") as problem_dir:
            shutil.copyfile(problem_file, os.path.join(problem_dir, "problem.pddl"))

            # Setup environment
            env = PDDLEnv(
                domain_file=domain_file,
                problem_dir=problem_dir,
                render=False
            )
            try:
                self._run_episode(env, trajectory, verbose)
            finally:
                # Best-effort release of whatever the environment holds.
                close_env = getattr(env, "close", None)
                if callable(close_env):
                    close_env()

        trajectory.total_execution_time = time.time() - start_time

        if verbose:
            print(f"Episode completed: {len(trajectory.steps)} steps, reward: {trajectory.total_reward:.2f}")

        return trajectory

    def _run_episode(self, env: Any, trajectory: "PDDLTrajectory", verbose: bool) -> None:
        """Run the interaction loop and record steps and totals on ``trajectory``."""
        # Reset environment
        observation = self._unwrap_reset(env.reset())

        if verbose:
            print(f"Starting PDDL episode: {Path(trajectory.domain_file).name} -> {Path(trajectory.problem_file).name}")

        step_number = 0
        total_reward = 0.0
        done = False

        while not done and step_number < self.max_steps:
            step_start_time = time.time()

            # Get available actions. They arrive as a set, so they are sorted
            # by their text to keep the numbering shown to the LLM (and the
            # heuristic's "first action") stable between runs.
            try:
                available_actions = sorted(
                    env.action_space.all_ground_literals(observation), key=str
                )
            except Exception as e:
                if verbose:
                    print(f"Error getting available actions: {e}")
                break

            if not available_actions:
                if verbose:
                    print("No available actions. Episode terminated.")
                break

            # Select action using LLM or heuristic
            action, reasoning = self._select_action(observation, available_actions, env)

            if action is None:
                break

            # Execute action
            try:
                next_observation, reward, done, truncated = self._unpack_step_result(env.step(action))

                total_reward += reward

                # Record step
                trajectory.steps.append(PDDLStep(
                    step_number=step_number,
                    state_description=self._format_observation(observation),
                    action_taken=str(action),
                    observation=self._format_observation(next_observation),
                    reward=reward,
                    is_terminal=done,
                    llm_reasoning=reasoning,
                    execution_time=time.time() - step_start_time
                ))

                if verbose:
                    print(f"Step {step_number}: {action} -> Reward: {reward}")

                observation = next_observation
                step_number += 1

            except Exception as e:
                if verbose:
                    print(f"Error executing action: {e}")
                break

            # A truncated episode must not be stepped any further, but it does
            # not count as reaching a terminal state.
            if truncated:
                break

        trajectory.total_reward = total_reward
        trajectory.success = done and total_reward > 0

    @staticmethod
    def _unwrap_reset(reset_result: Any) -> Any:
        """Return the observation from whatever ``env.reset()`` returned.

        A plain tuple is treated as ``(observation, info)``. A tuple that
        itself has a ``literals`` attribute is already a state (PDDLGym states
        are namedtuples) and is returned unchanged.
        """
        if isinstance(reset_result, tuple) and not hasattr(reset_result, 'literals'):
            return reset_result[0]
        return reset_result

    @staticmethod
    def _unpack_step_result(step_result: Any) -> Tuple[Any, Any, Any, Any]:
        """Normalize a 4- or 5-element step result.

        Returns:
            ``(observation, reward, done, truncated)``; ``truncated`` is False
            for the 4-element form, which has no such flag.
        """
        if len(step_result) == 4:
            next_observation, reward, done, _info = step_result
            return next_observation, reward, done, False
        next_observation, reward, done, truncated, _info = step_result
        return next_observation, reward, done, truncated

    def _select_action(self, observation: Any, available_actions: List[Any], env: "PDDLEnv") -> Tuple[Any, str]:
        """Select action using LLM or heuristic fallback.

        Returns:
            ``(action, reasoning)``; ``action`` is None when nothing can be chosen.
        """
        if LLM_AVAILABLE:
            return self._llm_select_action(observation, available_actions, env)
        return self._heuristic_select_action(available_actions)

    def _llm_select_action(self, observation: Any, available_actions: List[Any], env: "PDDLEnv") -> Tuple[Any, str]:
        """Select action using LLM.

        Any failure (LLM error, unparsable answer, index out of range) falls
        back to ``_heuristic_select_action``.
        """
        try:
            response = call_llm(
                text=self._build_prompt(observation, available_actions, env),
                system_prompt="You are a PDDL planning expert.",
                model=self.model_name,
                max_tokens=1000,
                temperature=0.1
            )
            choice = self._parse_llm_response(response, available_actions)
            if choice is not None:
                return choice
        except Exception as e:
            print(f"LLM selection failed: {e}")

        # Fallback to heuristic
        return self._heuristic_select_action(available_actions)

    def _build_prompt(self, observation: Any, available_actions: List[Any], env: Any) -> str:
        """Build the user prompt listing the state, the goal and the numbered actions."""
        state_desc = self._format_observation(observation)
        actions_desc = self._format_actions(available_actions)
        goal_info = self._describe_goal(observation, env)

        return f"""You are a PDDL planning agent. Select the best action.

CURRENT STATE:
{state_desc}

GOAL: {goal_info}

AVAILABLE ACTIONS:
{actions_desc}

Respond with:
ACTION: <number>
REASONING: <your reasoning>"""

    @staticmethod
    def _describe_goal(observation: Any, env: Any) -> str:
        """Return the goal as text, or "Unknown goal" when none can be found.

        The observation is checked first because PDDLGym states carry the
        problem goal as ``state.goal``; ``env.goal`` is kept as a fallback for
        environments that expose it there.
        """
        for source in (observation, env):
            goal = getattr(source, 'goal', None)
            if goal is not None:
                return str(goal)
        return "Unknown goal"

    @staticmethod
    def _parse_llm_response(response: str, available_actions: List[Any]) -> Optional[Tuple[Any, str]]:
        """Extract ``(action, reasoning)`` from the LLM answer.

        Returns:
            None when no ``ACTION: <number>`` is present or the number is not
            a valid index into ``available_actions``.
        """
        action_match = re.search(r'ACTION:\s*(\d+)', response, re.IGNORECASE)
        if not action_match:
            return None

        action_idx = int(action_match.group(1))
        if not 0 <= action_idx < len(available_actions):
            return None

        reasoning_match = re.search(r'REASONING:\s*(.+)', response, re.IGNORECASE | re.DOTALL)
        reasoning = reasoning_match.group(1).strip() if reasoning_match else "LLM selection"
        return available_actions[action_idx], reasoning

    def _heuristic_select_action(self, available_actions: List[Any]) -> Tuple[Any, str]:
        """Heuristic action selection.

        Keywords are tried in priority order; the first action whose text
        contains the keyword (case-insensitively) wins. Without any match the
        first available action is returned.
        """
        if not available_actions:
            return None, "No actions available"

        # Prefer actions with certain keywords
        preferred_keywords = ['pick', 'unlock', 'move', 'open']

        for keyword in preferred_keywords:
            for action in available_actions:
                if keyword in str(action).lower():
                    return action, f"Heuristic: selected action with '{keyword}'"

        return available_actions[0], "Heuristic: selected first action"

    def _format_observation(self, observation: Any) -> str:
        """Format observation for display.

        Objects with a ``literals`` attribute and list/tuple/set observations
        become one "  - item" line per element; anything else is ``str()``-ed.
        Unordered collections are sorted by their text so the output is stable.
        """
        if observation is None:
            return "No observation"

        try:
            if hasattr(observation, 'literals'):
                items = observation.literals
            elif isinstance(observation, (list, tuple, set)):
                items = observation
            else:
                return str(observation)

            if isinstance(items, (set, frozenset)):
                items = sorted(items, key=str)
            else:
                items = list(items)

            return "\n".join(f"  - {item}" for item in items) if items else "Empty state"
        except Exception:
            return "Error formatting observation"

    def _format_actions(self, actions: List[Any]) -> str:
        """Format available actions as "<index>: <action>" lines (zero-based)."""
        return "\n".join(f"{i}: {action}" for i, action in enumerate(actions))

def save_trajectory(trajectory: PDDLTrajectory, output_file: str):
    """Write a trajectory summary and its per-step records to a JSON file.

    The file is UTF-8 encoded, indented by two spaces and keeps non-ASCII
    characters unescaped. Per step only the number, action, reward, terminal
    flag, reasoning and timing are written; the textual state and observation
    of each step are not included.

    Args:
        trajectory: Episode result to serialize.
        output_file: Destination path; an existing file is overwritten.
    """
    def step_record(entry):
        # Note the key is 'reasoning', not the attribute name 'llm_reasoning'.
        return {
            'step_number': entry.step_number,
            'action_taken': entry.action_taken,
            'reward': entry.reward,
            'is_terminal': entry.is_terminal,
            'reasoning': entry.llm_reasoning,
            'execution_time': entry.execution_time
        }

    with open(output_file, 'w', encoding='utf-8') as handle:
        payload = {}
        payload['domain_file'] = trajectory.domain_file
        payload['problem_file'] = trajectory.problem_file
        payload['total_reward'] = trajectory.total_reward
        payload['success'] = trajectory.success
        payload['total_steps'] = len(trajectory.steps)
        payload['total_execution_time'] = trajectory.total_execution_time
        payload['steps'] = list(map(step_record, trajectory.steps))
        json.dump(payload, handle, indent=2, ensure_ascii=False)

def main():
    """Command-line entry point.

    Parses the arguments, plays one episode with ``PDDLPlayer``, optionally
    saves the trajectory via ``save_trajectory`` and prints a short summary.
    Any exception is reported as "Error: ..." and the process exits with
    status 1.
    """
    argument_specs = (
        (('domain_file',), {'help': 'Path to domain.pddl file'}),
        (('problem_file',), {'help': 'Path to problem.pddl file'}),
        (('--model',), {'default': 'gpt-4o-mini', 'help': 'LLM model to use'}),
        (('--max-steps',), {'type': int, 'default': 100, 'help': 'Maximum steps per episode'}),
        (('--output',), {'help': 'Output file for trajectory'}),
        (('--quiet',), {'action': 'store_true', 'help': 'Suppress verbose output'}),
    )

    cli = argparse.ArgumentParser(description='PDDL Text2World Player')
    for flags, settings in argument_specs:
        cli.add_argument(*flags, **settings)

    options = cli.parse_args()

    try:
        agent = PDDLPlayer(model_name=options.model, max_steps=options.max_steps)
        result = agent.play_episode(
            options.domain_file,
            options.problem_file,
            verbose=not options.quiet
        )

        # Saving is optional and happens before the summary is printed.
        if options.output:
            save_trajectory(result, options.output)
            print(f"Trajectory saved to: {options.output}")

        print("\nResults:")
        print(f"Success: {result.success}")
        print(f"Total reward: {result.total_reward}")
        print(f"Steps: {len(result.steps)}")
        print(f"Execution time: {result.total_execution_time:.2f}s")

    except Exception as error:
        print(f"Error: {error}")
        sys.exit(1)


if __name__ == "__main__":
    main()