import gymnasium as gym
from gymnasium import spaces
import numpy as np
from typing import Dict, Any, Tuple, List
import os
import sys
from pathlib import Path
from openai import OpenAI, AsyncOpenAI
import openai
import asyncio
import multiprocessing

# Make the ``sweet_rl`` package importable; its ``code_utils`` module is
# imported lazily by the correctness-check code further down.
# NOTE(review): ``Path("")`` stringifies to "." so this currently prepends the
# working directory -- presumably a placeholder for the real sweet_rl checkout
# path; confirm.
_sweet_rl_path = Path("")
_sweet_rl_entry = str(_sweet_rl_path)
if _sweet_rl_entry not in sys.path:
    sys.path.insert(0, _sweet_rl_entry)

from ..config import ColBenchGymConfig, get_code_config


def _run_check_correctness_worker(ground_truth, agent_code, test_cases, result_queue):
    """Child-process entry point that scores ``agent_code`` with sweet_rl.

    Defined at module level so ``multiprocessing`` can pickle it as a
    ``Process`` target.

    Args:
        ground_truth: Reference solution forwarded to ``check_correctness``.
        agent_code: Code submitted by the agent.
        test_cases: Test cases forwarded to ``check_correctness``.
        result_queue: Object with a ``put`` method (a ``multiprocessing.Queue``
            in practice) that receives the outcome.

    On success ``("success", <check_correctness result>)`` is put on
    ``result_queue``; if an ``Exception`` is raised (including a failed
    import of sweet_rl) ``("error", <message>)`` is put instead. Non-Exception
    errors such as ``SystemExit`` propagate and put nothing.
    """
    try:
        from sweet_rl.utils.code_utils import check_correctness as _check

        score = _check(ground_truth, agent_code, test_cases)
        result_queue.put(("success", score))
    except Exception as exc:  # report any failure to the parent instead of dying silently
        result_queue.put(("error", str(exc)))


class ColBenchCodeEnv(gym.Env):
    """
    Gymnasium environment for ColBench Backend Programming tasks.

    This environment simulates collaborative code development where an agent
    interacts with a human collaborator (simulated via VLLM server) to understand
    requirements and produce Python code solutions.

    Each ``step`` takes one agent message. The episode ends when the message
    contains ``"I WANT TO ANSWER:"`` or ``config.max_steps`` is reached;
    otherwise the simulated human replies. Per-step rewards only carry the
    step penalty; the task reward is computed by :meth:`calculate_reward`.
    """

    metadata = {"render_modes": ["human"]}

    # Text after this marker in an agent message is its final code answer.
    _ANSWER_MARKER = "I WANT TO ANSWER:"
    # When present, only the text after its first occurrence (up to any second
    # occurrence) is treated as the agent's message.
    _OUTPUT_MARKER = "OUTPUT:"
    # Value of ``agent_answer`` while no final answer has been recorded.
    _NO_ANSWER = "No answer"
    # Returned instead of a human reply when the simulator cannot be queried.
    _NO_RESPONSE = "No response."

    # Human-simulator request policy.
    _SIMULATOR_MAX_RETRIES = 3
    _SIMULATOR_RETRY_DELAY = 1.0  # seconds; the wait after attempt k (0-based) is (k + 1) * delay
    _SIMULATOR_TIMEOUT = 30.0  # seconds per request
    _SIMULATOR_MAX_TOKENS = 4096

    # Wall-clock limit (seconds) for one check_correctness run in the child process.
    _CHECK_TIMEOUT_SECONDS = 120

    def __init__(
        self,
        config: ColBenchGymConfig = None,
        task: Dict[str, Any] = None,
        category: str = None,
        id: str = None
    ):
        """
        Initialize the ColBench Code Environment.

        Args:
            config: ColBenchGymConfig instance; ``get_code_config()`` is used
                when omitted. ``validate()`` is called on it either way.
            task: Default task used by ``reset()`` when ``options["task"]`` is
                not given. Either a dict (``problem_description``,
                ``ground_truth`` and ``test_cases`` keys are read) or any other
                value, which is ``str()``-ed and used as the problem description.
            category: Task category (optional, echoed in ``info``).
            id: Task ID (optional, echoed in ``info`` as ``task_id``).

        Raises:
            OSError: If a prompt template file cannot be read.
        """
        super().__init__()

        # Use provided config or default
        self.config = config or get_code_config()
        self.config.validate()

        # The prompts directory is a sibling of the directory holding this file.
        # Explicit UTF-8 so reading does not depend on the platform locale.
        prompt_dir = Path(__file__).parent.parent / "prompts"
        self.agent_prompt = (prompt_dir / "llm_agent_code_prompt.txt").read_text(encoding="utf-8")
        self.human_prompt = (prompt_dir / "human_simulator_code_prompt.txt").read_text(encoding="utf-8")

        # Clients for the environment simulator. A custom API key is used when
        # configured, otherwise "EMPTY" (sufficient for VLLM servers).
        api_key = self.config.env_api_key if self.config.env_api_key else "EMPTY"
        # Sync client, used by step()
        self.client = OpenAI(
            base_url=self.config.env_base_url,
            api_key=api_key
        )
        # Async client, used by step_async()
        self.async_client = AsyncOpenAI(
            base_url=self.config.env_base_url,
            api_key=api_key
        )

        # Task information
        self.initial_task = task
        self.category = category
        self.id = id

        # Environment state
        self.problem_description = ""
        self.ground_truth = ""
        self.test_cases = {}  # Test cases for code evaluation
        self.step_count = 0
        self.episode_complete = False
        self.dialogue_history = []  # List of {"role": "user"/"assistant", "content": str}
        self.agent_answer = self._NO_ANSWER

        # Action space: text actions
        self.action_space = spaces.Text(max_length=8192)

        # Observation space: dictionary containing dialogue state
        self.observation_space = spaces.Dict({
            "dialogue_history": spaces.Text(max_length=32768),
            "step_count": spaces.Box(low=0, high=self.config.max_steps, shape=(), dtype=np.int32),
            "episode_complete": spaces.Box(low=0, high=1, shape=(), dtype=np.bool_),
            "last_message": spaces.Text(max_length=8192)
        })

    # ------------------------------------------------------------------ #
    # Dialogue helpers
    # ------------------------------------------------------------------ #

    def _format_dialogue_history(self) -> str:
        """Render the dialogue as ``"role: content"`` entries separated by blank lines."""
        return "\n\n".join(
            f"{msg['role']}: {msg['content']}" for msg in self.dialogue_history
        ).strip()

    def _get_dialogue_messages(self) -> List[Dict[str, str]]:
        """Return a copy of the dialogue as chat-style ``{"role", "content"}`` dicts."""
        return [
            {"role": msg["role"], "content": msg["content"]}
            for msg in self.dialogue_history
        ]

    def reset(
        self,
        seed=None,
        options: Dict[str, Any] = None
    ) -> Tuple[Dict[str, Any], Dict[str, Any]]:
        """Reset the environment to start a new collaborative coding session.

        Args:
            seed: Forwarded to ``gym.Env.reset``.
            options: May contain ``"task"``, which takes precedence over the
                task given at construction time.

        Returns:
            ``(observation, info)``; the dialogue starts with the problem
            description as the first user message.

        Raises:
            ValueError: If no task is available.
        """
        super().reset(seed=seed)

        # Extract task information
        if options and "task" in options:
            task_data = options["task"]
        elif self.initial_task:
            task_data = self.initial_task
        else:
            raise ValueError("No task provided. Pass task in options or during init.")

        # Handle different task formats
        if isinstance(task_data, dict):
            self.problem_description = task_data.get("problem_description", "")
            self.ground_truth = task_data.get("ground_truth", "")
            self.test_cases = task_data.get("test_cases", {})
        else:
            self.problem_description = str(task_data)
            self.ground_truth = ""
            self.test_cases = {}

        # Reset episode state
        self.step_count = 0
        self.episode_complete = False
        self.agent_answer = self._NO_ANSWER
        self.dialogue_history = [{
            "role": "user",
            "content": self.problem_description
        }]

        observation = {
            "dialogue_history": self._format_dialogue_history(),
            "step_count": self.step_count,
            "episode_complete": self.episode_complete,
            "last_message": self.problem_description,
            "feedback": self.problem_description  # Add feedback field for interact_tool compatibility
        }

        info = {
            "problem_description": self.problem_description,
            "ground_truth": self.ground_truth,
            "task_id": self.id or "",
            "category": self.category or "",
            "dialogue_messages": self._get_dialogue_messages()
        }

        if self.config.verbose:
            print("🎯 New collaborative coding session")
            print(f"Problem: {self.problem_description}")

        return observation, info

    # ------------------------------------------------------------------ #
    # Human simulator
    # ------------------------------------------------------------------ #

    def _build_human_simulator_messages(self) -> List[Dict[str, str]]:
        """Build the chat request for the simulated human from the prompt template.

        The template receives the problem description, the ground truth (as
        ``hidden_information``) and the formatted dialogue so far.
        """
        return [
            {"role": "system", "content": "You are a helpful assistant."},
            {
                "role": "user",
                "content": self.human_prompt.format(
                    problem_description=self.problem_description,
                    hidden_information=self.ground_truth,
                    dialogue_history=self._format_dialogue_history()
                )
            }
        ]

    def _completion_text(self, completion: Any) -> str:
        """Return the first choice's message text, or the no-response sentinel.

        ``message.content`` may be ``None`` according to the OpenAI client's
        types, and callers slice the result as a string.
        """
        content = completion.choices[0].message.content
        return content if content is not None else self._NO_RESPONSE

    def _should_retry_simulator_call(self, error: Exception, attempt: int) -> bool:
        """Log a failed human-simulator request and decide whether to retry it.

        Args:
            error: Exception raised while building, sending or reading the request.
            attempt: 0-based index of the attempt that failed.

        Returns:
            True if another attempt should be made after the backoff delay.
        """
        max_retries = self._SIMULATOR_MAX_RETRIES
        retries_left = attempt < max_retries - 1
        verbose = self.config.verbose

        # openai.APIConnectionError (and its subclass APITimeoutError) derive
        # from openai.APIError, so they must be matched before the API branch.
        if isinstance(error, (openai.APIConnectionError, openai.APITimeoutError, ConnectionError)):
            if verbose:
                print(f"Connection error (attempt {attempt + 1}/{max_retries}): {error}")
                if not retries_left:
                    print(f"Failed to connect to {self.config.env_base_url} after {max_retries} attempts")
            return retries_left

        if isinstance(error, openai.APIError):
            if verbose:
                print(f"API error (attempt {attempt + 1}/{max_retries}): {error}")
            # A 400 rejects the request itself; resending the identical request
            # (same messages, temperature=0) would only fail again.
            if isinstance(error, openai.BadRequestError):
                return False
            return retries_left

        if verbose:
            print(f"Unexpected error (attempt {attempt + 1}/{max_retries}): {error}")
        return retries_left

    def _invoke_human_simulator(self) -> str:
        """Call the VLLM server to simulate human collaborator response (synchronous version).

        Returns:
            The simulator's reply, or ``"No response."`` when every attempt
            fails or the reply has no text content.
        """
        import time

        for attempt in range(self._SIMULATOR_MAX_RETRIES):
            try:
                completion = self.client.chat.completions.create(
                    model=self.config.env_model_name,
                    messages=self._build_human_simulator_messages(),
                    max_tokens=self._SIMULATOR_MAX_TOKENS,
                    temperature=0,
                    timeout=self._SIMULATOR_TIMEOUT
                )
                return self._completion_text(completion)
            except Exception as e:
                if not self._should_retry_simulator_call(e, attempt):
                    return self._NO_RESPONSE
                # Linear backoff: 1x, 2x, ... the base delay.
                time.sleep(self._SIMULATOR_RETRY_DELAY * (attempt + 1))

        return self._NO_RESPONSE

    async def _invoke_human_simulator_async(self) -> str:
        """Call the VLLM server to simulate human collaborator response (async version).

        Being a coroutine, it can be cancelled by ``asyncio.wait_for`` while a
        request hangs (the motivation given for it is avoiding NCCL timeouts).
        Cancellation is not caught here on Python 3.8+, where
        ``asyncio.CancelledError`` is not an ``Exception`` subclass.

        Returns:
            The simulator's reply, or ``"No response."`` when every attempt
            fails or the reply has no text content.
        """
        for attempt in range(self._SIMULATOR_MAX_RETRIES):
            try:
                completion = await self.async_client.chat.completions.create(
                    model=self.config.env_model_name,
                    messages=self._build_human_simulator_messages(),
                    max_tokens=self._SIMULATOR_MAX_TOKENS,
                    temperature=0,
                    timeout=self._SIMULATOR_TIMEOUT
                )
                return self._completion_text(completion)
            except Exception as e:
                if not self._should_retry_simulator_call(e, attempt):
                    return self._NO_RESPONSE
                # Linear backoff: 1x, 2x, ... the base delay.
                await asyncio.sleep(self._SIMULATOR_RETRY_DELAY * (attempt + 1))

        return self._NO_RESPONSE

    # ------------------------------------------------------------------ #
    # Stepping (shared by step and step_async)
    # ------------------------------------------------------------------ #

    def _begin_step(self, action: Any) -> Tuple[str, str]:
        """Record the agent's turn and detect the end of the episode.

        Returns:
            ``(agent_response, raw_response)``: the message appended to the
            dialogue (text after ``OUTPUT:`` when present) and the stripped raw
            action cut just before any second ``OUTPUT:`` marker.

        Raises:
            ValueError: If the episode has already finished.
        """
        if self.episode_complete:
            raise ValueError("Episode is complete. Call reset() to start a new episode.")

        self.step_count += 1
        agent_response = str(action).strip()
        raw_response = agent_response

        if self._OUTPUT_MARKER in agent_response:
            segments = agent_response.split(self._OUTPUT_MARKER)
            agent_response = segments[1]
            raw_response = self._OUTPUT_MARKER.join(segments[:2])

        wants_to_answer = self._ANSWER_MARKER in agent_response
        if wants_to_answer or self.step_count >= self.config.max_steps:
            self.episode_complete = True
            # Without an explicit answer (max_steps reached) the whole message
            # is taken as the answer.
            if wants_to_answer:
                self.agent_answer = agent_response.split(self._ANSWER_MARKER)[1]
            else:
                self.agent_answer = agent_response

        self.dialogue_history.append({
            "role": "assistant",
            "content": agent_response
        })
        return agent_response, raw_response

    def _record_human_response(self, human_response: str) -> str:
        """Truncate a simulator reply to the configured limit and append it as a user turn."""
        human_response = human_response[:self.config.human_response_char_limit]
        self.dialogue_history.append({
            "role": "user",
            "content": human_response
        })
        return human_response

    def _build_step_result(
        self, agent_response: str, raw_response: str, human_response: str
    ) -> Tuple[Dict[str, Any], float, bool, bool, Dict[str, Any]]:
        """Assemble the ``(observation, reward, terminated, truncated, info)`` tuple."""
        # The task reward is computed externally (calculate_reward); here only
        # the step penalty applies, and only while the episode continues.
        reward = 0.0
        if not self.episode_complete:
            reward -= self.config.step_penalty

        if self.config.normalize_rewards and reward != 0:
            reward = max(-1.0, min(1.0, reward))

        # Use last_message as feedback for consistency with other environments
        feedback_message = human_response if not self.episode_complete else agent_response
        observation = {
            "dialogue_history": self._format_dialogue_history(),
            "step_count": self.step_count,
            "episode_complete": self.episode_complete,
            "last_message": feedback_message,
            "feedback": feedback_message  # Add feedback field for interact_tool compatibility
        }

        info = {
            "problem_description": self.problem_description,
            "ground_truth": self.ground_truth,
            "agent_answer": self.agent_answer,
            "task_id": self.id or "",
            "category": self.category or "",
            "dialogue_messages": self._get_dialogue_messages(),
            "raw_agent_response": raw_response
        }

        terminated = self.episode_complete
        # True when max_steps ended the episode without an explicit answer;
        # terminated is also True in that case.
        truncated = self.step_count >= self.config.max_steps and self._ANSWER_MARKER not in agent_response

        if self.config.verbose:
            print(f"\n--- Step {self.step_count} ---")
            print(f"Agent: {agent_response[:200]}...")
            if not self.episode_complete:
                print(f"Human: {human_response[:200]}...")
            print(f"Done: {terminated or truncated}")

        return observation, reward, terminated, truncated, info

    def step(self, action: str) -> Tuple[Dict[str, Any], float, bool, bool, Dict[str, Any]]:
        """Execute one step in the environment (synchronous version).

        Args:
            action: The agent's message.

        Returns:
            Tuple of (observation, reward, terminated, truncated, info).

        Raises:
            ValueError: If the episode has already finished.
        """
        agent_response, raw_response = self._begin_step(action)

        human_response = ""
        if not self.episode_complete:
            human_response = self._record_human_response(self._invoke_human_simulator())

        return self._build_step_result(agent_response, raw_response, human_response)

    async def step_async(self, action: str) -> Tuple[Dict[str, Any], float, bool, bool, Dict[str, Any]]:
        """Execute one step in the environment (async version).

        Identical to :meth:`step` except that the simulator is queried through
        the async client, so the call can be cancelled by ``asyncio.wait_for``.

        Args:
            action: Action string

        Returns:
            Tuple of (observation, reward, terminated, truncated, info)

        Raises:
            ValueError: If the episode has already finished.
        """
        agent_response, raw_response = self._begin_step(action)

        human_response = ""
        if not self.episode_complete:
            human_response = self._record_human_response(await self._invoke_human_simulator_async())

        return self._build_step_result(agent_response, raw_response, human_response)

    def render(self, mode="human"):
        """Print a short summary of the current state (only ``mode="human"``)."""
        if mode == "human":
            print(f"\n=== ColBench Code Environment Step {self.step_count} ===")
            print(f"Problem: {self.problem_description}")
            print(f"Dialogue turns: {len(self.dialogue_history)}")
            if self.dialogue_history:
                last_msg = self.dialogue_history[-1]
                print(f"Last message ({last_msg['role']}): {last_msg['content'][:200]}...")
            print(f"Episode complete: {self.episode_complete}")

    # ------------------------------------------------------------------ #
    # Reward
    # ------------------------------------------------------------------ #

    def calculate_reward(self) -> float:
        """Calculate final reward based on code correctness.

        Evaluates the agent's final answer against the task's test cases with
        sweet_rl's ``check_correctness`` (run in a child process with a
        timeout), keeping the evaluation consistent with sweet_rl.

        Side effects: if ``step()`` did not record a final answer (episode not
        complete, or ``agent_answer`` still ``"No answer"``), an answer is first
        recovered from the dialogue history, which may set
        ``episode_complete`` and ``agent_answer``.

        Returns:
            float: The value returned by ``check_correctness`` (expected to be
            the fraction of passing test cases, 0.0 to 1.0), or 0.0 when there
            is no usable answer, no valid test cases, an evaluation error, or
            a timeout.
        """
        if not self.episode_complete or self.agent_answer == self._NO_ANSWER:
            self._recover_agent_answer()

            if self.agent_answer == self._NO_ANSWER or not self.episode_complete:
                # step_count == 0 means the rollout ended before any environment
                # step (e.g. generation failed); that is expected, so stay quiet.
                if self.step_count != 0:
                    self._log_missing_answer()
                return 0.0

        # agent_answer was already cut after "I WANT TO ANSWER:" by step()
        # or by the recovery above.
        agent_code = self.agent_answer.strip()
        if not agent_code:
            print("[DEBUG] ColBench calculate_reward: agent_code is empty after strip()")
            return 0.0

        if self.test_cases and len(self.test_cases) > 0:
            return self._score_with_test_cases(agent_code)

        # Fallback: test cases are expected to be present for sweet_rl tasks.
        print(f"[DEBUG] ColBench calculate_reward: No test cases available (test_cases={self.test_cases}), returning 0.0")
        return 0.0

    def _recover_agent_answer(self) -> None:
        """Best-effort recovery of a final answer when ``step()`` recorded none.

        Rollouts can stop (max steps reached, or the dialogue ends on an
        assistant turn) before the agent produces a formatted answer. This
        scans assistant messages from newest to oldest and takes the first
        that carries an ``I WANT TO ANSWER:`` marker, looks like code, or --
        once the episode is considered complete -- is longer than 10 chars.
        Mutates ``episode_complete`` and ``agent_answer``.
        """
        if self.step_count >= self.config.max_steps:
            print(f"[DEBUG] ColBench calculate_reward: step_count ({self.step_count}) >= max_steps ({self.config.max_steps}), forcing episode completion")
            self.episode_complete = True

        # A dialogue ending on an assistant turn had no human reply after the
        # agent's last message, i.e. the rollout stopped there.
        conversation_ended = bool(self.dialogue_history) and self.dialogue_history[-1].get("role") == "assistant"
        if conversation_ended:
            print(f"[DEBUG] ColBench calculate_reward: Conversation appears to have ended (last message from assistant, step_count={self.step_count}), marking as complete")
            self.episode_complete = True

        # Cheap heuristic for "this message contains code".
        code_hints = ("def ", "import ", "class ", "return ", "=")

        for msg in reversed(self.dialogue_history):
            if msg.get("role") != "assistant":
                continue
            last_response = msg.get("content", "").strip()
            if not last_response:
                continue

            if self._ANSWER_MARKER in last_response:
                self.agent_answer = last_response.split(self._ANSWER_MARKER)[1].strip()
                print(f"[DEBUG] ColBench calculate_reward: Extracted answer from incomplete episode, agent_answer_length={len(self.agent_answer)}")
                self.episode_complete = True
                return

            if len(last_response) > 20 and any(hint in last_response for hint in code_hints):
                self.agent_answer = last_response
                print(f"[DEBUG] ColBench calculate_reward: Using last agent response as answer (contains code patterns, no 'I WANT TO ANSWER:' found), agent_answer_length={len(self.agent_answer)}")
                self.episode_complete = True
                return

            if self.episode_complete and len(last_response) > 10:
                self.agent_answer = last_response
                print(f"[DEBUG] ColBench calculate_reward: Using last agent response as answer (episode complete but no code pattern, step_count={self.step_count}, conversation_ended={conversation_ended}), agent_answer_length={len(self.agent_answer)}")
                return

    def _log_missing_answer(self) -> None:
        """Log that no usable answer was found, rate-limited to the 1st, 11th, 21st, ... call."""
        self._no_answer_log_counter = getattr(self, "_no_answer_log_counter", 0) + 1
        if self._no_answer_log_counter % 10 == 1:
            print(f"[DEBUG] ColBench calculate_reward: No valid answer found (episode_complete={self.episode_complete}, step_count={self.step_count}, max_steps={self.config.max_steps}, agent_answer='{self.agent_answer[:50]}'), returning 0.0")

    def _score_with_test_cases(self, agent_code: str) -> float:
        """Score ``agent_code`` against the non-None test cases; 0.0 on any failure."""
        try:
            # Parquet serialization may add None values for schema consistency;
            # those are not real test cases.
            valid_test_cases = {k: v for k, v in self.test_cases.items() if v is not None}

            print(f"[DEBUG] ColBench calculate_reward: total_test_cases={len(self.test_cases)}, valid_test_cases={len(valid_test_cases)}")

            if not valid_test_cases:
                print("[DEBUG] ColBench calculate_reward: All test cases are None after filtering, returning 0.0")
                return 0.0

            # Imported here only so a missing sweet_rl install fails fast,
            # before a child process is spawned; the worker imports it itself.
            from sweet_rl.utils.code_utils import check_correctness  # noqa: F401

            print(f"[DEBUG] ColBench calculate_reward: calling check_correctness with agent_code length={len(agent_code)}, num_test_cases={len(valid_test_cases)}")

            return self._run_check_correctness_with_timeout(agent_code, valid_test_cases)
        except Exception as e:
            print(f"[DEBUG] ColBench calculate_reward: Error evaluating code with test cases: {e}")
            import traceback
            traceback.print_exc()
            # Return 0.0 on error (conservative)
            return 0.0

    def _run_check_correctness_with_timeout(self, agent_code: str, test_cases: Dict[Any, Any]) -> float:
        """Run ``check_correctness`` in a child process, killing it on timeout.

        sweet_rl's ``check_correctness`` relies on ``signal.alarm``, which may
        not work in a multithreaded host process, so it runs in its own
        process with an external wall-clock limit.

        Returns:
            ``float`` of the worker's result, or 0.0 on timeout, on a reported
            error, or when the child exits without reporting.
        """
        import queue

        timeout_seconds = self._CHECK_TIMEOUT_SECONDS
        result_queue = multiprocessing.Queue()
        process = multiprocessing.Process(
            target=_run_check_correctness_worker,
            args=(self.ground_truth, agent_code, test_cases, result_queue)
        )
        try:
            process.start()
            # Joining before reading is acceptable because the worker puts a
            # single (status, value) tuple, which is normally small. A payload
            # larger than the pipe buffer would block the child's exit until
            # it is read (see multiprocessing "Joining processes that use queues").
            process.join(timeout=timeout_seconds)

            if process.is_alive():
                print(f"[ERROR] ColBench calculate_reward: check_correctness TIMED OUT after {timeout_seconds}s! Killing process...")
                process.terminate()
                process.join(timeout=5)
                if process.is_alive():
                    process.kill()
                    process.join()  # reap the killed child so it does not linger
                print("[ERROR] ColBench calculate_reward: check_correctness was killed due to timeout. Returning 0.0")
                return 0.0

            # The child has exited, so anything it put is already in the pipe.
            # A short bounded get() replaces Queue.empty(), which the docs
            # describe as unreliable.
            try:
                status, result = result_queue.get(timeout=1.0)
            except queue.Empty:
                print("[ERROR] ColBench calculate_reward: check_correctness process finished but no result in queue. Returning 0.0")
                return 0.0

            if status == 'success':
                print(f"[DEBUG] ColBench calculate_reward: check_correctness returned {result}")
                return float(result)

            print(f"[ERROR] ColBench calculate_reward: check_correctness raised exception: {result}")
            return 0.0
        finally:
            result_queue.close()
            try:
                process.close()
            except ValueError:
                # Raised only if the child is somehow still running; its
                # resources are then released when the object is collected.
                pass

    def close(self):
        """Clean up environment resources.

        Currently a no-op: the OpenAI clients are left for garbage collection
        so an environment that is closed and then reused keeps working.
        """
        pass
