"""
Code Execution Environment for Curriculum-Based Training.

This module provides safe code execution with subprocess isolation,
supporting curriculum-based code generation training.
"""

import logging
import os
import random
import re
import secrets
import shutil
import subprocess
import sys
import tempfile
import time
from dataclasses import dataclass, field
from typing import List, Dict, Any, Optional

logger = logging.getLogger(__name__)


@dataclass
class CodeEnvConfig:
    """Tunable limits for ``CurriculumCodeEnv``.

    All fields are optional and keep their positional order:
        timeout_seconds: wall-clock limit handed to ``subprocess.run``.
        max_code_length: code longer than this many characters is rejected
            without being executed.
        max_output_length: captured stdout/stderr are cut to this many chars.
        sandbox_enabled: flag carried by the config; ``CurriculumCodeEnv``
            does not read it.
    """

    timeout_seconds: float = field(default=10.0)
    max_code_length: int = field(default=2048)
    max_output_length: int = field(default=4096)
    sandbox_enabled: bool = field(default=True)


@dataclass
class ExecutionResult:
    """Outcome of running one piece of code in the environment.

    Attributes:
        reward: scalar reward assigned to the run.
        info: free-form diagnostics (e.g. ``passed``, ``error``); every
            instance gets its own fresh dict by default.
        execution_time: elapsed seconds; defaults to 0.0.
    """

    reward: float
    info: Dict[str, Any] = field(default_factory=dict)  # per-instance, never shared
    execution_time: float = field(default=0.0)


@dataclass
class CodeTask:
    """One programming problem together with the tests used to grade it."""

    task_id: str  # unique identifier, e.g. "synthetic_0_add" or a dataset id
    prompt: str  # text given to the model before its completion
    canonical_solution: str  # reference solution text
    test_cases: str  # Python source run after the submitted code
    entry_point: str  # name of the function under test ("" when unknown)


class CurriculumCodeEnv:
    """
    Code execution environment supporting curriculum-based training.

    Each ``step()`` writes the submitted code followed by the current task's
    test cases to a script, runs it in a separate Python interpreter with a
    timeout, and returns a binary reward. The reward is 1.0 only if the script
    exits with status 0 *and* execution reached the end of the test cases.
    Otherwise it is 0.0.

    Isolation is limited to a separate process, a wall-clock timeout, a
    closed stdin and a throwaway working directory. This is not a security
    sandbox, and ``CodeEnvConfig.sandbox_enabled`` is not read by this class.
    """

    def __init__(
        self,
        tasks: List[CodeTask],
        config: Optional[CodeEnvConfig] = None,
    ):
        """
        Initialize the code execution environment.

        Args:
            tasks: List of CodeTask objects sampled from by ``reset()``.
            config: Environment configuration; defaults to ``CodeEnvConfig()``.
        """
        self.tasks = tasks
        self.config = config or CodeEnvConfig()
        self.current_task = None
        self.current_task_idx = 0
        # The first execution is logged in detail once per environment.
        self._logged_first_execution = False

    def reset(
        self,
        task_idx: Optional[int] = None,
        curriculum_task: Optional[Any] = None,
    ) -> Dict[str, Any]:
        """
        Reset the environment with a new task.

        Args:
            task_idx: Index into ``self.tasks`` of the task to use (optional).
            curriculum_task: Any task-like object (optional). It overrides
                ``task_idx``. Its ``task_id``, ``curriculum_prompt``/``prompt``
                and ``test_cases`` attributes are read via ``getattr``.

        Returns:
            Observation dictionary with ``task_id`` and ``prompt``.

        Raises:
            ValueError: if a random task is requested but ``self.tasks`` is empty.
            IndexError: if ``task_idx`` is out of range for ``self.tasks``.
        """
        if curriculum_task is not None:
            # Use the curriculum task directly; it is not part of self.tasks.
            self.current_task = curriculum_task
            self.current_task_idx = -1
        elif task_idx is not None:
            self.current_task_idx = task_idx
            self.current_task = self.tasks[task_idx]
        else:
            if not self.tasks:
                raise ValueError("Cannot sample a random task: the task list is empty")
            self.current_task_idx = random.randrange(len(self.tasks))
            self.current_task = self.tasks[self.current_task_idx]

        return {
            "task_id": getattr(self.current_task, "task_id", f"task_{self.current_task_idx}"),
            "prompt": getattr(self.current_task, "curriculum_prompt",
                              getattr(self.current_task, "prompt", "")),
        }

    def step(self, code: str) -> ExecutionResult:
        """
        Execute the provided code against the current task's tests.

        Args:
            code: The code to execute (should include prompt + generated completion).

        Returns:
            ExecutionResult with reward and execution info. The reward is 0.0
            without running anything if no task is set or the task's
            ``test_cases`` are empty or shorter than 10 non-blank characters.
        """
        if self.current_task is None:
            return ExecutionResult(
                reward=0.0,
                info={"error": "No task set. Call reset() first.", "passed": False},
                execution_time=0.0,
            )

        test_cases = getattr(self.current_task, "test_cases", "")
        task_id = getattr(self.current_task, "task_id", "unknown")

        # Refuse to reward code checked by (nearly) empty tests.
        if not test_cases or len(test_cases.strip()) < 10:
            logger.warning(
                "Task %s: Empty or too short test_cases (len=%d)",
                task_id,
                len(test_cases.strip()) if test_cases else 0,
            )
            return ExecutionResult(
                reward=0.0,
                info={"error": "empty_test_cases", "passed": False, "task_id": task_id},
                execution_time=0.0,
            )

        # Heuristic: tests without assert/check() likely verify nothing.
        if "assert" not in test_cases and "check(" not in test_cases:
            logger.warning(
                "Task %s: test_cases has no 'assert' or 'check()' - may not actually test anything",
                task_id,
            )

        return self._execute_code_safely(code, test_cases)

    def _execute_code_safely(self, code: str, test_cases: str) -> ExecutionResult:
        """
        Run ``code`` followed by ``test_cases`` in a fresh Python subprocess.

        The script runs under the current interpreter (``sys.executable``) in a
        private temporary directory with stdin closed. Its output is decoded as
        UTF-8, and undecodable bytes are replaced.

        Args:
            code: The code to execute (prompt + generated completion).
            test_cases: Test statements appended after ``code``.

        Returns:
            ExecutionResult with reward 1.0 if the tests ran to completion and
            the process exited with status 0, otherwise 0.0. ``info`` always
            holds ``passed``. It also holds either ``returncode``/``stdout``/
            ``stderr`` from the run, or an ``error`` description.
        """
        if len(code) > self.config.max_code_length:
            return ExecutionResult(
                reward=0.0,
                info={"passed": False, "error": "Code exceeds maximum length"},
                execution_time=0.0,
            )

        logger.debug(
            "Executing code (%d chars) with tests (%d chars)", len(code), len(test_cases)
        )

        # Log the first execution attempt in detail to help debug prompt/test wiring.
        if not getattr(self, "_logged_first_execution", False):
            self._logged_first_execution = True
            logger.info("First execution attempt - Code preview:\n%s...", code[:500])
            logger.info("Test cases preview:\n%s...", test_cases[:300])

        # Exit status 0 alone does not prove the tests ran: generated code can
        # call sys.exit(0) / os._exit(0) before reaching them. A random per-run
        # marker is printed after the last test and must appear in stdout. This
        # stops trivial early exits but is not tamper-proof (the script can read
        # its own source).
        marker = f"__CODE_ENV_TESTS_COMPLETED_{secrets.token_hex(8)}__"
        full_code = f"{code}\n\n{test_cases}\n\nprint({marker!r})\n"

        workdir = None
        start_time = time.perf_counter()

        try:
            # Private directory: it is the child's cwd and sys.path[0], so files
            # the code creates stay out of the caller's cwd and are removed below.
            workdir = tempfile.mkdtemp(prefix="code_env_")
            script_path = os.path.join(workdir, "solution.py")
            with open(script_path, "w", encoding="utf-8") as f:
                f.write(full_code)

            result = subprocess.run(
                # Same interpreter as the caller; a bare "python" may be missing
                # from PATH or resolve to a different installation.
                [sys.executable or "python", script_path],
                cwd=workdir,
                # Otherwise input() would read from (or block on) the caller's stdin.
                stdin=subprocess.DEVNULL,
                capture_output=True,
                encoding="utf-8",
                errors="replace",
                timeout=self.config.timeout_seconds,
                env={
                    **os.environ,
                    "PYTHONDONTWRITEBYTECODE": "1",
                    "PYTHONIOENCODING": "utf-8",
                },
            )
            execution_time = time.perf_counter() - start_time

            raw_stdout = result.stdout or ""
            tests_completed = marker in raw_stdout
            passed = result.returncode == 0 and tests_completed

            limit = self.config.max_output_length
            # Hide the internal marker from callers; truncate after removing it.
            stdout = raw_stdout.replace(f"{marker}\n", "")[:limit]
            stderr = (result.stderr or "")[:limit]

            info: Dict[str, Any] = {
                "passed": passed,
                "returncode": result.returncode,
                "stdout": stdout,
                "stderr": stderr,
            }
            if result.returncode == 0 and not tests_completed:
                info["error"] = "exited_before_tests_completed"

            if not passed:
                logger.debug(
                    "Execution failed (returncode=%s): %s",
                    result.returncode,
                    stderr[:200] if stderr else "no stderr",
                )

            return ExecutionResult(
                reward=1.0 if passed else 0.0,
                info=info,
                execution_time=execution_time,
            )

        except subprocess.TimeoutExpired:
            execution_time = time.perf_counter() - start_time
            return ExecutionResult(
                reward=0.0,
                info={
                    "passed": False,
                    "error": "timeout",
                    "timeout_seconds": self.config.timeout_seconds,
                },
                execution_time=execution_time,
            )

        except Exception as e:
            # An infrastructure failure (e.g. temp dir not writable), not a test
            # failure. It is logged so a silent stream of zero rewards is noticed.
            logger.warning("Code execution failed unexpectedly: %s", e)
            execution_time = time.perf_counter() - start_time
            return ExecutionResult(
                reward=0.0,
                info={
                    "passed": False,
                    "error": str(e),
                },
                execution_time=execution_time,
            )

        finally:
            if workdir is not None:
                # Best-effort cleanup; leftover files must not fail the step.
                shutil.rmtree(workdir, ignore_errors=True)


def create_synthetic_tasks(num_tasks: int = 20) -> List[CodeTask]:
    """
    Build a list of small, self-checking programming tasks.

    The templates below are cycled in order: task ``i`` uses template
    ``i % len(templates)`` and gets the id ``synthetic_{i}_{name}``. Each
    prompt is a function header plus a one-line docstring. It ends with a
    four-space indent so the reference body can be appended directly.

    Args:
        num_tasks: How many tasks to return; zero or negative yields ``[]``.

    Returns:
        List of CodeTask objects.
    """
    # (function name, parameters, docstring, reference body, test asserts).
    # Bodies continue after the prompt's trailing indent, so every later
    # line of a body carries its own four leading spaces.
    specs = [
        ("add", "a, b", "Add two numbers and return the result.",
         "return a + b",
         "assert add(1, 2) == 3\nassert add(-1, 1) == 0\nassert add(0, 0) == 0"),
        ("multiply", "a, b", "Multiply two numbers and return the result.",
         "return a * b",
         "assert multiply(2, 3) == 6\nassert multiply(-2, 3) == -6\nassert multiply(0, 5) == 0"),
        ("max_of_two", "a, b", "Return the maximum of two numbers.",
         "return a if a > b else b",
         "assert max_of_two(1, 2) == 2\nassert max_of_two(5, 3) == 5\nassert max_of_two(4, 4) == 4"),
        ("is_even", "n", "Return True if n is even, False otherwise.",
         "return n % 2 == 0",
         "assert is_even(2) == True\nassert is_even(3) == False\nassert is_even(0) == True"),
        ("abs_value", "n", "Return the absolute value of n.",
         "return n if n >= 0 else -n",
         "assert abs_value(5) == 5\nassert abs_value(-5) == 5\nassert abs_value(0) == 0"),
        ("sum_list", "lst", "Return the sum of all elements in the list.",
         "total = 0\n    for x in lst:\n        total += x\n    return total",
         "assert sum_list([1, 2, 3]) == 6\nassert sum_list([]) == 0\nassert sum_list([5]) == 5"),
        ("list_length", "lst", "Return the number of elements in the list.",
         "count = 0\n    for _ in lst:\n        count += 1\n    return count",
         "assert list_length([1, 2, 3]) == 3\nassert list_length([]) == 0\nassert list_length([5]) == 1"),
        ("first_element", "lst", "Return the first element of the list, or None if empty.",
         "if len(lst) == 0:\n        return None\n    return lst[0]",
         "assert first_element([1, 2, 3]) == 1\nassert first_element([]) == None\nassert first_element([5]) == 5"),
        ("last_element", "lst", "Return the last element of the list, or None if empty.",
         "if len(lst) == 0:\n        return None\n    return lst[-1]",
         "assert last_element([1, 2, 3]) == 3\nassert last_element([]) == None\nassert last_element([5]) == 5"),
        ("contains", "lst, x", "Return True if x is in the list, False otherwise.",
         "for item in lst:\n        if item == x:\n            return True\n    return False",
         "assert contains([1, 2, 3], 2) == True\nassert contains([1, 2, 3], 4) == False\nassert contains([], 1) == False"),
        ("double_list", "lst", "Return a new list with each element doubled.",
         "result = []\n    for x in lst:\n        result.append(x * 2)\n    return result",
         "assert double_list([1, 2, 3]) == [2, 4, 6]\nassert double_list([]) == []\nassert double_list([5]) == [10]"),
        ("reverse_string", "s", "Return the reversed string.",
         "return s[::-1]",
         "assert reverse_string('hello') == 'olleh'\nassert reverse_string('') == ''\nassert reverse_string('a') == 'a'"),
        ("count_char", "s, c", "Count occurrences of character c in string s.",
         "count = 0\n    for ch in s:\n        if ch == c:\n            count += 1\n    return count",
         "assert count_char('hello', 'l') == 2\nassert count_char('hello', 'x') == 0\nassert count_char('', 'a') == 0"),
        ("is_palindrome", "s", "Return True if string s is a palindrome.",
         "return s == s[::-1]",
         "assert is_palindrome('racecar') == True\nassert is_palindrome('hello') == False\nassert is_palindrome('') == True"),
        ("factorial", "n", "Return the factorial of n (n >= 0).",
         "if n <= 1:\n        return 1\n    return n * factorial(n - 1)",
         "assert factorial(5) == 120\nassert factorial(0) == 1\nassert factorial(1) == 1"),
        ("fibonacci", "n", "Return the nth Fibonacci number (0-indexed).",
         "if n <= 0:\n        return 0\n    if n == 1:\n        return 1\n    return fibonacci(n - 1) + fibonacci(n - 2)",
         "assert fibonacci(0) == 0\nassert fibonacci(1) == 1\nassert fibonacci(6) == 8"),
        ("power", "base, exp", "Return base raised to the power exp (exp >= 0).",
         "result = 1\n    for _ in range(exp):\n        result *= base\n    return result",
         "assert power(2, 3) == 8\nassert power(5, 0) == 1\nassert power(3, 2) == 9"),
        ("gcd", "a, b", "Return the greatest common divisor of a and b.",
         "while b:\n        a, b = b, a % b\n    return a",
         "assert gcd(12, 8) == 4\nassert gcd(17, 13) == 1\nassert gcd(100, 25) == 25"),
        ("is_prime", "n", "Return True if n is a prime number, False otherwise.",
         "if n < 2:\n        return False\n    for i in range(2, int(n ** 0.5) + 1):\n        if n % i == 0:\n            return False\n    return True",
         "assert is_prime(2) == True\nassert is_prime(17) == True\nassert is_prime(4) == False\nassert is_prime(1) == False"),
        ("list_max", "lst", "Return the maximum element in the list, or None if empty.",
         "if not lst:\n        return None\n    max_val = lst[0]\n    for x in lst[1:]:\n        if x > max_val:\n            max_val = x\n    return max_val",
         "assert list_max([1, 3, 2]) == 3\nassert list_max([5]) == 5\nassert list_max([]) == None"),
    ]

    def build(index):
        # Cycle through the specs so any num_tasks is supported.
        name, params, doc, body, asserts = specs[index % len(specs)]
        return CodeTask(
            task_id=f"synthetic_{index}_{name}",
            prompt=f'def {name}({params}):\n    """{doc}"""\n    ',
            canonical_solution=body,
            test_cases=asserts,
            entry_point=name,
        )

    return [build(index) for index in range(num_tasks)]


def load_humaneval_tasks() -> List[CodeTask]:
    """
    Convert the HumanEval test split into CodeTask objects.

    The dataset is fetched with ``datasets.load_dataset``, authenticating with
    the ``HF_TOKEN`` environment variable when it is set. A trailing
    ``check(<entry_point>)`` call is appended to each task's test source
    unless the source already contains one, so that executing the source
    runs the checks.

    Returns:
        One CodeTask per dataset row. Returns an empty list (after logging a
        warning) when ``datasets`` is unavailable or loading fails.
    """
    try:
        from datasets import load_dataset

        rows = load_dataset(
            "openai/openai_humaneval",
            split="test",
            token=os.environ.get("HF_TOKEN"),
        )

        def to_task(row):
            entry_point = row.get("entry_point", "")
            check_call = f"check({entry_point})"
            test_source = row.get("test", "")
            # The test source is expected to define check(); make sure it is
            # actually invoked on the entry point.
            if entry_point and check_call not in test_source:
                test_source = f"{test_source}\n{check_call}"
            return CodeTask(
                task_id=row["task_id"],
                prompt=row["prompt"],
                canonical_solution=row["canonical_solution"],
                test_cases=test_source,
                entry_point=entry_point,
            )

        tasks = [to_task(row) for row in rows]
        logger.info(f"Loaded {len(tasks)} HumanEval tasks")
        return tasks

    except ImportError:
        logger.warning("datasets library not available, cannot load HumanEval")
        return []
    except Exception as e:
        logger.warning(f"Failed to load HumanEval: {e}")
        return []


def _mbpp_entry_point(code: str, test_cases: str) -> str:
    """Return the function in ``code`` that ``test_cases`` calls, else the first def, else ""."""
    defined = []
    for line in code.splitlines():
        if line.strip().startswith("def "):
            name = line.split("def ", 1)[1].split("(", 1)[0].strip()
            if name:
                defined.append(name)
    # A reference solution may define helper functions before the function
    # under test, so prefer a defined name that the assertions actually call.
    for name in defined:
        if re.search(r"\b" + re.escape(name) + r"\s*\(", test_cases):
            return name
    return defined[0] if defined else ""


def load_mbpp_tasks() -> List[CodeTask]:
    """
    Load MBPP (Mostly Basic Python Problems) benchmark tasks.

    Each task's ``test_cases`` is the dataset's ``test_setup_code`` (when
    non-empty) followed by its ``test_list`` assertions, one per line. The
    prompt is the problem text wrapped in a module docstring. ``entry_point``
    is the function defined in the reference ``code`` that the assertions
    call. It falls back to the first ``def``, or "" if there is none.

    Returns:
        List of CodeTask objects from MBPP. Returns an empty list (after
        logging a warning) if ``datasets`` is not installed or loading fails.
    """
    try:
        from datasets import load_dataset

        dataset = load_dataset("google-research-datasets/mbpp", split="test")

        tasks = []
        for idx, item in enumerate(dataset):
            # MBPP provides test_list with assertion strings.
            asserts = "\n".join(item.get("test_list") or [])
            # Optional setup (e.g. imports) must run before the assertions;
            # without it, affected tasks fail no matter what the model writes.
            setup = (item.get("test_setup_code") or "").strip()
            test_cases = f"{setup}\n{asserts}" if setup else asserts

            code = item.get("code") or ""

            # NOTE(review): the prompt is only the problem text and does not
            # name the function the assertions call, so a model has to guess
            # it. Confirm that callers supply the signature/tests elsewhere.
            task = CodeTask(
                task_id=f"mbpp_{item.get('task_id', idx)}",
                prompt=f'"""{item.get("text", "")}\n"""\n',
                canonical_solution=code,
                test_cases=test_cases,
                entry_point=_mbpp_entry_point(code, asserts),
            )
            tasks.append(task)

        logger.info(f"Loaded {len(tasks)} MBPP tasks")
        return tasks

    except ImportError:
        logger.warning("datasets library not available, cannot load MBPP")
        return []
    except Exception as e:
        logger.warning(f"Failed to load MBPP: {e}")
        return []


if __name__ == "__main__":
    # Manual smoke test: a reference solution, a wrong body and an infinite loop.
    print("Testing CurriculumCodeEnv...")

    demo_tasks = create_synthetic_tasks(5)
    print(f"Created {len(demo_tasks)} synthetic tasks")

    demo_env = CurriculumCodeEnv(demo_tasks)

    # Case 1: the reference solution for task 0 is expected to pass.
    observation = demo_env.reset(task_idx=0)
    task_prompt = observation["prompt"]
    print(f"\nTask: {observation['task_id']}")
    print(f"Prompt:\n{task_prompt}")

    outcome = demo_env.step(task_prompt + demo_tasks[0].canonical_solution)
    print("\nCorrect code result:")
    print(f"  Reward: {outcome.reward}")
    print(f"  Passed: {outcome.info.get('passed')}")
    print(f"  Execution time: {outcome.execution_time:.4f}s")

    # Case 2: a wrong body on the same task (no reset needed) should fail.
    outcome = demo_env.step(task_prompt + "return a - b  # Wrong!")
    print("\nIncorrect code result:")
    print(f"  Reward: {outcome.reward}")
    print(f"  Passed: {outcome.info.get('passed')}")

    # Case 3: an infinite loop under a shortened timeout.
    slow_env = CurriculumCodeEnv(demo_tasks, CodeEnvConfig(timeout_seconds=1.0))
    slow_env.reset(task_idx=0)
    outcome = slow_env.step("while True: pass")
    print("\nTimeout code result:")
    print(f"  Reward: {outcome.reward}")
    print(f"  Error: {outcome.info.get('error')}")

    print("\nEnvironment created successfully!")
