"""
Tau2Gym Environment Implementation

This module provides a Gymnasium-compatible environment that wraps tau2-bench's
AgentGymEnv, preserving all original prompts, agent context, and evaluation logic
while integrating with UserRL's training framework.

Key Design Principles:
1. Preserve tau2-bench's original prompts and agent context
2. Maintain identical evaluation logic for fair comparison
3. Support proper train/test splits for RL training
4. Minimize modifications to tau2-bench's core functionality
"""

import gymnasium as gym
from gymnasium import spaces
import numpy as np
import random
import asyncio
import concurrent.futures
from typing import Dict, Any, Tuple, Optional, List

# Import tau2-bench's gym interface directly
try:
    from tau2.gym.gym_agent import AgentGymEnv, register_gym_agent, TAU_BENCH_ENV_ID
    from tau2.data_model.simulation import Task
    from tau2.registry import registry
    TAU2_AVAILABLE = True
except ImportError:
    TAU2_AVAILABLE = False
    print("Warning: tau2-bench not found. Please install tau2-bench first.")

from ..config import Tau2GymConfig, get_default_config


class Tau2Env(gym.Env):
    """
    Tau2Gym Environment - A wrapper around tau2-bench's AgentGymEnv for UserRL.

    This environment directly uses tau2-bench's AgentGymEnv to ensure identical
    behavior, prompts, and evaluation logic. It adapts the tau2-bench interface
    to UserRL's expected format while preserving all tau2-bench functionality.

    The environment supports:
    - Multiple domains (retail, airline, telecom)
    - Train/test splits for proper RL evaluation
    - Both solo and interactive modes
    - Original tau2-bench prompts and evaluation criteria
    """

    metadata = {"render_modes": ["human"]}

    def __init__(self, config: Optional[Tau2GymConfig] = None):
        """
        Initialize the Tau2Gym Environment.

        Validates the configuration, quiets tau2-bench's loguru output,
        registers tau2-bench's gym environments, initialises the per-episode
        bookkeeping, seeds the RNGs when a seed is configured, loads the task
        IDs, and declares the text action/observation spaces.

        Args:
            config: Tau2GymConfig instance with all configuration settings.
                When omitted (or falsy), ``get_default_config()`` is used.

        Raises:
            ImportError: If tau2-bench could not be imported when this module
                was loaded.
        """
        super().__init__()

        if not TAU2_AVAILABLE:
            raise ImportError(
                "tau2-bench is not installed. Please install it first:\n"
                "cd tau2-bench && pip install -e ."
            )

        # Use provided config or default
        self.config = config or get_default_config()
        self.config.validate()

        # Suppress tau2-bench's verbose logging to reduce noise during training.
        # loguru is optional from this module's point of view, so a missing
        # package is simply skipped.
        try:
            from loguru import logger
        except ImportError:
            pass
        else:
            import sys

            # NOTE: loguru's logger is process-global, so this replaces every
            # handler previously configured in the process.
            logger.remove()
            # Sink to stderr so WARNING and above stay visible; a no-op
            # callable sink would silently discard those records as well.
            logger.add(sys.stderr, level="WARNING")

        # Register tau2-bench gym environments
        register_gym_agent()

        # Environment state
        self.tau2_env: Optional[AgentGymEnv] = None
        self.current_task_id: Optional[str] = None
        self.current_task: Optional[Task] = None
        self.step_count = 0
        self.episode_complete = False
        self.action_history: List[str] = []
        self.total_reward = 0.0
        self.task_ids: List[str] = []
        self.current_task_index = 0

        # Thread pool executor for async execution, created lazily by
        # reset_async()/step_async() so purely synchronous users never
        # spawn a worker thread.
        self._executor: Optional[concurrent.futures.ThreadPoolExecutor] = None

        # Set random seed if provided
        if self.config.seed is not None:
            self.seed(self.config.seed)

        # Load available task IDs for the domain and split
        self._load_task_ids()

        # Action space: text actions (tau2-bench uses string actions)
        self.action_space = spaces.Text(max_length=4096)

        # Observation space: text observation from tau2-bench
        # tau2-bench returns conversation history as a string
        self.observation_space = spaces.Text(max_length=16384)

    def _load_task_ids(self):
        """Load available task IDs for the configured domain and split.

        For the "train" and "test" splits the IDs come from the registry's
        task-splits loader; for any other split every task returned by the
        registry's tasks loader is used. A non-empty ``config.data_source``
        (a single ID string or a list of IDs) then replaces that selection.

        This is best-effort: any failure is printed together with its
        traceback and leaves ``self.task_ids`` empty, which later makes
        ``_select_task_id`` raise.
        """
        try:
            # Use tau2-bench registry to get task IDs
            if self.config.task_split in ("train", "test"):
                splits_loader = registry.get_task_splits_loader(self.config.domain)
                splits = splits_loader()
                # Copy so later changes never touch the registry's own list.
                task_ids = list(splits.get(self.config.task_split, []))
            else:  # base split or other - use all tasks
                tasks_loader = registry.get_tasks_loader(self.config.domain)
                # Tasks are Task objects with an 'id' attribute
                task_ids = [task.id for task in tasks_loader()]

            # An explicit data_source overrides whatever the split produced.
            data_source = self.config.data_source
            if data_source:
                if isinstance(data_source, str):
                    task_ids = [data_source]
                elif isinstance(data_source, list):
                    # Copy rather than alias the config's list.
                    task_ids = list(data_source)

            self.task_ids = task_ids

            # Report after the override so the count matches what will be used.
            if self.config.verbose:
                print(f"Loaded {len(self.task_ids)} tasks for {self.config.domain} domain, {self.config.task_split} split")

        except Exception as e:
            print(f"Error loading task IDs: {e}")
            import traceback
            traceback.print_exc()
            self.task_ids = []

    def _select_task_id(self) -> str:
        """Select a task ID based on data_mode configuration."""
        if not self.task_ids:
            raise ValueError("No task IDs available. Check your domain and split configuration.")

        if self.config.data_mode == "random":
            return random.choice(self.task_ids)
        elif self.config.data_mode == "sequential":
            task_id = self.task_ids[self.current_task_index % len(self.task_ids)]
            self.current_task_index += 1
            return task_id
        elif self.config.data_mode == "single":
            return self.task_ids[0]
        else:
            return self.task_ids[0]

    def _create_tau2_env(self, task_id: str) -> "AgentGymEnv":
        """
        Create a tau2-bench AgentGymEnv instance.

        This directly uses tau2-bench's gym environment to ensure
        identical prompts, agent context, and evaluation logic. The domain,
        solo mode and user-LLM settings are taken from ``self.config``.

        The return annotation is a string on purpose: annotations in a ``def``
        signature are evaluated when the class body runs, and ``AgentGymEnv``
        is undefined when the optional tau2-bench import at the top of the
        module failed.

        Args:
            task_id: Task ID to load

        Returns:
            The environment object produced by ``gym.make`` for
            ``TAU_BENCH_ENV_ID``.

        Raises:
            RuntimeError: If environment construction fails for any reason;
                the original exception is attached as ``__cause__``.
        """
        try:
            return gym.make(
                TAU_BENCH_ENV_ID,
                domain=self.config.domain,
                task_id=task_id,
                solo_mode=self.config.solo_mode,
                user_llm=self.config.user_llm,
                user_llm_args=self.config.user_llm_args,
            )
        except Exception as e:
            # Chain explicitly so the underlying traceback is preserved.
            raise RuntimeError(f"Failed to create tau2-bench environment: {e}") from e

    def reset(self, seed: Optional[int] = None, options: Optional[Dict] = None) -> Tuple[str, Dict[str, Any]]:
        """Reset the environment to start a new episode.

        Args:
            seed: Random seed for reproducibility
            options: Additional options (not used currently)

        Returns:
            Tuple of (observation, info) where observation is a string
            containing the conversation history and info contains metadata
        """
        if seed is not None:
            self.seed(seed)

        # Select task ID
        self.current_task_id = self._select_task_id()

        # Create tau2-bench environment
        self.tau2_env = self._create_tau2_env(self.current_task_id)

        # Reset tau2-bench environment
        observation, info = self.tau2_env.reset()

        # Reset episode state
        self.step_count = 0
        self.episode_complete = False
        self.action_history = []
        self.total_reward = 0.0

        # Extract task information from info
        self.current_task = info.get('simulation_run', {})

        if self.config.verbose:
            print(f"="*60)
            print(f"NEW EPISODE - Tau2Gym ({self.config.domain})")
            print(f"="*60)
            print(f"Task ID: {self.current_task_id}")
            print(f"Task Split: {self.config.task_split}")
            print(f"Solo Mode: {self.config.solo_mode}")
            print(f"Initial Observation:\n{observation}")
            print(f"="*60)

        # Create UserRL-compatible info dictionary
        userrl_info = {
            "task_id": self.current_task_id,
            "domain": self.config.domain,
            "task_split": self.config.task_split,
            "solo_mode": self.config.solo_mode,
            "tau2_info": info,  # Preserve original tau2 info
            "policy": info.get('policy', ''),
            "tools": info.get('tools', []),
        }

        return observation, userrl_info

    async def reset_async(self, seed: Optional[int] = None, options: Optional[Dict] = None) -> Tuple[str, Dict[str, Any]]:
        """Reset the environment to start a new episode (async version).

        This method uses a thread pool executor to run the synchronous reset() method
        without blocking the event loop.

        Args:
            seed: Random seed for reproducibility
            options: Additional options (not used currently)

        Returns:
            Tuple of (observation, info) where observation is a string
            containing the conversation history and info contains metadata
        """
        # Use thread pool executor to run synchronous reset() without blocking event loop
        loop = asyncio.get_event_loop()
        
        # Create executor if it doesn't exist
        if self._executor is None:
            self._executor = concurrent.futures.ThreadPoolExecutor(max_workers=1)
        
        # Run the synchronous reset method in the executor
        result = await loop.run_in_executor(self._executor, self.reset, seed, options)
        return result

    def step(self, action: str) -> Tuple[str, float, bool, bool, Dict[str, Any]]:
        """Execute one step in the environment.

        Args:
            action: Action string (either message to user or tool call)

        Returns:
            Tuple of (observation, reward, terminated, truncated, info)
        """
        if self.episode_complete:
            raise ValueError("Episode is complete. Call reset() to start a new episode.")

        self.step_count += 1
        action_str = str(action).strip()

        # Add action to history
        self.action_history.append(action_str)

        if self.config.verbose:
            print(f"\n--- Step {self.step_count} ---")
            print(f"Action: {action_str[:200]}...")

        try:
            # Call tau2-bench's step method directly
            # This ensures we use tau2-bench's original action parsing,
            # prompt generation, and evaluation logic
            observation, reward, terminated, truncated, info = self.tau2_env.step(action_str)

            # Update total reward
            self.total_reward += reward

            # Check truncation based on max_steps
            if self.step_count >= self.config.max_steps:
                truncated = True

            # Apply reward scaling and penalties
            scaled_reward = reward * self.config.reward_scale
            if self.config.step_penalty > 0:
                scaled_reward -= self.config.step_penalty

            if self.config.normalize_rewards and scaled_reward != 0:
                scaled_reward = max(0.0, min(1.0, scaled_reward))

            if terminated or truncated:
                self.episode_complete = True

            if self.config.verbose:
                print(f"Observation: {observation[:200]}...")
                print(f"Reward: {reward:.4f} (scaled: {scaled_reward:.4f})")
                print(f"Terminated: {terminated}, Truncated: {truncated}")

            # Create UserRL-compatible info dictionary
            userrl_info = {
                "raw_action": action_str,
                "task_id": self.current_task_id,
                "domain": self.config.domain,
                "task_split": self.config.task_split,
                "step_count": self.step_count,
                "total_reward": self.total_reward,
                "action_history": self.action_history.copy(),
                "tau2_info": info,  # Preserve original tau2 info
            }

            return observation, scaled_reward, terminated, truncated, userrl_info

        except Exception as e:
            if self.config.verbose:
                print(f"Error in step: {e}")
            import traceback
            traceback.print_exc()

            # Return error state
            error_obs = f"Error executing action: {str(e)}"
            error_info = {
                "error": str(e),
                "task_id": self.current_task_id,
                "step_count": self.step_count,
            }
            return error_obs, 0.0, False, False, error_info

    async def step_async(self, action: str) -> Tuple[str, float, bool, bool, Dict[str, Any]]:
        """Execute one step in the environment (async version).

        Note: tau2-bench's AgentGymEnv doesn't have native async support,
        so this just calls the sync version for interface compatibility.

        Args:
            action: Action string

        Returns:
            Tuple of (observation, reward, terminated, truncated, info)
        """
        # Use thread pool executor to run synchronous step() without blocking event loop
        loop = asyncio.get_event_loop()
        
        # Create executor if it doesn't exist
        if self._executor is None:
            self._executor = concurrent.futures.ThreadPoolExecutor(max_workers=1)
        
        # Run the synchronous step method in the executor
        # This prevents blocking the event loop during LLM API calls
        result = await loop.run_in_executor(self._executor, self.step, action)
        return result

    def render(self, mode="human"):
        """Render the current state of the environment."""
        if not self.tau2_env:
            print("No environment loaded. Call reset() first.")
            return

        print("\n" + "="*60)
        print(f"TAU2 GYM SESSION")
        print("="*60)
        print(f"Domain: {self.config.domain}")
        print(f"Task ID: {self.current_task_id}")
        print(f"Task Split: {self.config.task_split}")
        print(f"Steps taken: {self.step_count}/{self.config.max_steps}")
        print(f"Total reward: {self.total_reward:.4f}")
        print(f"Episode complete: {self.episode_complete}")

        if self.action_history:
            print(f"\n📝 Recent Actions (last 3):")
            for i, action in enumerate(self.action_history[-3:], 1):
                print(f"  {i}. {action[:100]}...")

        print("="*60 + "\n")

    def close(self):
        """Clean up resources."""
        if self.tau2_env:
            self.tau2_env.close()
            self.tau2_env = None

        # Clean up thread pool executor
        if self._executor is not None:
            self._executor.shutdown(wait=True)
            self._executor = None

    def seed(self, seed: Optional[int] = None):
        """Set random seed for reproducibility."""
        if seed is not None:
            np.random.seed(seed)
            random.seed(seed)
        return [seed]
