import os
import json
import logging
from typing import Optional

from omegaconf import DictConfig

# from vllm import LLM, SamplingParams # Keep LLM import for type hinting if needed elsewhere, but not used directly here

# Local imports
from ACD.generate_acd_tasks import load_task_family

# Import the sandbox utility
from dns.docker_sandbox import run_task_in_sandbox


# Module-level logger. ACDTask instances log through their own per-task logger
# (named "<ClassName>[<task directory name>]") rather than through this one.
logger = logging.getLogger(__name__)


class ACDTask:
    """
    Represents a task generated by the Automated Capability Discovery (ACD) process.

    Loads the task implementation (``task.py``, via ``load_task_family``) and its
    metadata (``task.json``) from a task directory, exposes the instructions of the
    task's *first* example as an LLM prompt, and scores LLM responses for that
    example by running the task's ``score`` function in a Docker sandbox.

    Attributes:
        task_dir: Directory the task was loaded from (as passed by the caller).
        cfg: Hydra configuration forwarded to the sandbox runner.
        task_impl: Instance of the task family class returned by ``load_task_family``.
        task_json: Parsed contents of ``task.json`` (a JSON object).
        task_id: Name of the task directory (trailing path separators ignored).
        task_name: ``name_of_task`` from ``task.json``, falling back to ``task_id``.
        description: ``description_of_task`` from ``task.json``, or ``"N/A"``.
        first_example_id: Key of the first example from ``get_tasks()``, or ``None``.
        first_example_data: Data of the first example, or ``None``.
    """

    def __init__(self, task_dir: str, cfg: DictConfig):
        """
        Initializes the ACDTask adapter.

        Args:
            task_dir: Path to the directory containing the ACD task ('task.py', 'task.json').
            cfg: The main Hydra configuration object, used for sandbox settings.

        Raises:
            ValueError: If the task implementation or ``task.json`` cannot be loaded
                (including a ``task.json`` that is not a JSON object), or if the
                examples / first-example instructions cannot be produced.
        """
        self.task_dir = task_dir
        self.cfg = cfg
        # normpath strips trailing separators; without it basename("tasks/foo/")
        # is "" and the task would get an empty ID, name fallback and logger suffix.
        dir_name = os.path.basename(os.path.normpath(task_dir))
        self.logger = logging.getLogger(f"{self.__class__.__name__}[{dir_name}]")

        # Load ACD task implementation and metadata
        try:
            self.task_impl = load_task_family(task_dir)()  # Instantiate the task family
            task_json_path = os.path.join(task_dir, "task.json")
            # JSON is UTF-8 by spec; don't depend on the platform's locale encoding.
            with open(task_json_path, "r", encoding="utf-8") as f:
                self.task_json = json.load(f)
            # The metadata is read with dict.get below, so it must be an object.
            if not isinstance(self.task_json, dict):
                raise TypeError(
                    f"task.json must contain a JSON object, got {type(self.task_json).__name__}"
                )
        except Exception as e:
            self.logger.exception(f"Failed to load ACD task from {task_dir}: {e}")
            raise ValueError(f"Could not load ACD task from {task_dir}") from e

        # Assign a unique ID to this task instance (the directory name)
        self.task_id = dir_name
        self.task_name = self.task_json.get("name_of_task", dir_name)
        self.description = self.task_json.get("description_of_task", "N/A")

        # Only the first example is evaluated; these stay None when there is none.
        self.first_example_id = None
        self.first_example_data = None
        self._instructions: Optional[str] = None  # Cached instructions
        try:
            all_tasks_data = self.task_impl.get_tasks()
            if not all_tasks_data:
                self.logger.warning(f"ACD task {self.task_name} reported zero examples.")
            else:
                # Assumes get_tasks() returns a mapping of example ID -> example data
                # (implied by the indexing below); its first key is the example used.
                self.first_example_id = next(iter(all_tasks_data))
                self.first_example_data = all_tasks_data[self.first_example_id]
                # Pre-calculate and cache instructions
                instructions = self.task_impl.get_instructions(self.first_example_data)
                # Explicit check rather than `assert`, which is stripped under `python -O`.
                if not isinstance(instructions, str):
                    raise TypeError(f"Instruction is not a string, Instr: {instructions}")
                self._instructions = instructions
                self.logger.info(
                    f"Initialized ACDTask '{self.task_name}' (ID: {self.task_id}). Will evaluate example ID: {self.first_example_id}"
                )
        except Exception as e:
            self.logger.exception(
                f"Failed to get/process task examples from ACD task {self.task_name}: {e}"
            )
            raise ValueError(
                f"Could not process examples for ACD task {self.task_name}"
            ) from e

    def get_instructions(self) -> Optional[str]:
        """Returns the cached instructions for the first example, or None if there are none."""
        return self._instructions

    def get_evaluation_prompt(self) -> Optional[str]:
        """
        Generates the prompt string for the LLM to evaluate the first example.

        The prompt is currently the first example's instructions, unmodified.

        Returns:
            The prompt string, or None if instructions are not available.
        """
        instructions = self.get_instructions()
        if instructions is None:
            self.logger.error(
                f"Cannot generate prompt for task {self.task_name}: Instructions not available."
            )
            return None
        return instructions

    def evaluate_response_sandboxed(
        self, raw_output: Optional[str], *, threshold: float = 0.5
    ) -> float:
        """
        Evaluates a given raw LLM output string using the Docker sandbox for the first example.

        The task's ``score`` function (from ``task.py``) is run in the sandbox with
        ``{"task_data": <first example data>, "answer": raw_output}`` and its result
        is binarized against ``threshold``.

        Args:
            raw_output: The raw text output generated by the LLM for the task prompt.
                ``None`` or the marker ``"<GENERATION FAILED>"`` skips evaluation.
            threshold: Minimum sandbox score counted as a pass (default 0.5).

        Returns:
            The score (0.0 or 1.0) after thresholding. Returns 0.0 on any error.
        """
        if self.first_example_id is None or self.first_example_data is None:
            self.logger.error(
                f"Task {self.task_name} has no valid first example to evaluate response against."
            )
            return 0.0

        if raw_output is None or raw_output == "<GENERATION FAILED>":
            # Explicit failure marker (or no output at all): nothing to score.
            self.logger.warning(
                f"Skipping sandbox evaluation for task {self.task_name} due to missing or failed generation output."
            )
            return 0.0

        # Score the response using the Docker sandbox
        task_script = os.path.join(self.task_dir, "task.py")
        input_data = {
            "task_data": self.first_example_data,
            "answer": raw_output,
        }

        try:
            sandbox_score = run_task_in_sandbox(
                task_script_path=task_script,
                function_name="score",
                input_data=input_data,
                cfg=self.cfg,
            )
        except Exception as sandbox_err:
            self.logger.exception(
                f"Sandbox execution failed for task {self.task_name}: {sandbox_err}"
            )
            return 0.0  # Treat sandbox failure as task failure

        # Ensure score is numeric. OverflowError covers ints too large for a float.
        try:
            numeric_score = float(sandbox_score)
        except (ValueError, TypeError, OverflowError):
            self.logger.error(
                f"Sandbox for task {self.task_name} returned non-numeric score: {sandbox_score}. Assigning 0.0."
            )
            return 0.0

        # NaN compares False against any threshold, so it maps to 0.0.
        score = 1.0 if numeric_score >= threshold else 0.0
        self.logger.debug(
            f"Task {self.task_name} sandbox score: {sandbox_score} -> {score}"
        )
        return score
