"""
Dynamic Experiment Runner

This module provides a dynamic experiment runner that can parse step definitions
from configuration, handle dependencies, and execute steps in the correct order.
It extends the existing ExperimentRunner architecture while adding support for
the new step-based configuration system.
"""

import time
from collections import defaultdict, deque
from typing import Any, Dict, List, Optional, Set, Tuple

from src.utils.decorator_utils import with_logger

from .experiment import ExperimentRunner
from .step_factory import StepFactory, VariableResolver


class DynamicExperimentRunner(ExperimentRunner):
    """
    Dynamic experiment runner that creates steps from configuration.

    This runner extends the base ExperimentRunner to support:
    - Step definitions from configuration
    - Dependency resolution and ordering
    - Variable interpolation
    - Optional step execution

    Attributes set by this class:
        step_factory: Factory used to build step instances from step configs.
        variable_resolver: Resolver built from the experiment configuration.
        step_dependencies: Mapping of step name -> names of steps it depends on.
        optional_steps: Names of steps flagged ``optional`` in their config.
        uncreated_steps: Names of optional steps that could not be created
            while parsing the configuration (they are absent from ``steps``).

    NOTE(review): ``self.config``, ``self.steps``, ``self.metrics`` and
    ``self._save_results`` are presumably provided by ``ExperimentRunner``;
    ``logger`` is presumably injected by ``@with_logger`` -- confirm.
    """

    @with_logger
    def __init__(
        self,
        name: str,
        config: Optional[Dict[str, Any]] = None,
        config_name: Optional[str] = None,
        output_dir: Optional[str] = None,
    ):
        """
        Initialise the dynamic experiment runner.

        Args:
            name: The name of the experiment
            config: Pre-loaded configuration dictionary
            config_name: The name of the configuration file to load
            output_dir: The directory to save results to
        """
        super().__init__(name, config, config_name, output_dir)

        self.step_factory = StepFactory()
        self.variable_resolver = VariableResolver(self.config)
        self.step_dependencies: Dict[str, List[str]] = {}
        self.optional_steps: Set[str] = set()
        # Optional steps whose construction failed; run() must treat them as
        # unavailable so that their dependents do not execute without them.
        self.uncreated_steps: Set[str] = set()

        # Parse and create steps from configuration
        self._parse_steps_from_config()

    @staticmethod
    def _normalize_dependencies(depends_on: Any) -> List[str]:
        """
        Normalise a raw ``depends_on`` value into a duplicate-free list.

        Args:
            depends_on: ``None``, a single step name, or an iterable of names

        Returns:
            List of dependency names in first-seen order, without duplicates
        """
        if depends_on is None:
            return []
        if isinstance(depends_on, str):
            # A bare string is one dependency, not an iterable of characters.
            return [depends_on]

        unique: List[str] = []
        for dep in depends_on:
            if dep not in unique:
                unique.append(dep)
        return unique

    @with_logger
    def _parse_steps_from_config(self) -> None:
        """
        Parse step definitions from configuration and create step instances.

        Steps are created in dependency order and appended to ``self.steps``.
        If a step name occurs more than once, the later definition overrides
        the earlier one (a warning is logged).

        Raises:
            KeyError: If a step definition has no ``name`` key
            ValueError: If dependencies are unknown or circular
            Exception: Whatever the step factory raises for a required step
        """
        steps_config = self.config.get("steps", [])

        if not steps_config:
            logger.info("No steps defined in configuration, using legacy mode")
            return

        logger.debug(f"Parsing {len(steps_config)} steps from configuration")

        # First pass: index step configs by name and collect dependencies
        step_configs = {}
        for step_config in steps_config:
            step_name = step_config["name"]
            if step_name in step_configs:
                logger.warning(
                    f"Duplicate step name '{step_name}' in configuration; "
                    "the later definition overrides the earlier one"
                )
                # Do not let the overridden definition's flag leak through.
                self.optional_steps.discard(step_name)
            step_configs[step_name] = step_config

            # Track dependencies
            self.step_dependencies[step_name] = self._normalize_dependencies(
                step_config.get("depends_on")
            )

            # Track optional steps
            if step_config.get("optional", False):
                self.optional_steps.add(step_name)

        # Resolve dependencies and create execution order
        execution_order = self._resolve_dependencies(list(step_configs))

        # Second pass: create steps in dependency order
        for step_name in execution_order:
            step_config = step_configs[step_name]
            try:
                step = self.step_factory.create_step(
                    step_config, self.config, self.variable_resolver.resolve
                )
                self.steps.append(step)
                logger.debug(f"Created step: {step_name}")
            except Exception as e:
                if step_name in self.optional_steps:
                    # Remember the gap so run() can skip or reject dependents.
                    self.uncreated_steps.add(step_name)
                    logger.warning(
                        f"Failed to create optional step '{step_name}'. This step will be skipped: {str(e)}"
                    )
                else:
                    logger.error(
                        f"Failed to create required step '{step_name}': {str(e)}"
                    )
                    raise

    @with_logger
    def _resolve_dependencies(
        self,
        step_names: List[str],
        dependencies: Optional[Dict[str, List[str]]] = None,
    ) -> List[str]:
        """
        Resolve step dependencies and return execution order.

        Only the steps listed in ``step_names`` take part in the sort; entries
        of the dependency mapping for other steps are ignored.

        Args:
            step_names: List of all step names
            dependencies: Mapping of step name -> dependency names. Defaults
                to ``self.step_dependencies``.

        Returns:
            List of step names in execution order

        Raises:
            ValueError: If circular dependencies are detected or dependencies are missing
        """
        logger.debug("Resolving step dependencies")

        graph = self.step_dependencies if dependencies is None else dependencies
        ordered_names = list(dict.fromkeys(step_names))
        known_names = set(ordered_names)

        # Validate that all dependencies exist. Dependencies are de-duplicated
        # so that a name listed twice is counted once in the in-degree below.
        deps_by_step: Dict[str, List[str]] = {}
        for step_name in ordered_names:
            step_deps = self._normalize_dependencies(graph.get(step_name))
            for dep in step_deps:
                if dep not in known_names:
                    raise ValueError(
                        f"Step '{step_name}' depends on unknown step '{dep}'"
                    )
            deps_by_step[step_name] = step_deps

        # Topological sort using Kahn's algorithm
        # In-degree = number of distinct dependencies not yet scheduled
        in_degree = {name: len(deps_by_step[name]) for name in ordered_names}

        # Reverse adjacency: dependency -> steps waiting on it (input order)
        dependents: Dict[str, List[str]] = {name: [] for name in ordered_names}
        for step_name in ordered_names:
            for dep in deps_by_step[step_name]:
                dependents[dep].append(step_name)

        # Initialise queue with steps that have no dependencies
        queue = deque(name for name in ordered_names if in_degree[name] == 0)
        execution_order = []

        while queue:
            current_step = queue.popleft()
            execution_order.append(current_step)

            # Release the steps that were waiting on the current one
            for waiting_step in dependents[current_step]:
                in_degree[waiting_step] -= 1
                if in_degree[waiting_step] == 0:
                    queue.append(waiting_step)

        # Anything left unscheduled is part of (or behind) a cycle
        if len(execution_order) != len(ordered_names):
            remaining_steps = set(ordered_names) - set(execution_order)
            raise ValueError(
                f"Circular dependency detected among steps: {remaining_steps}"
            )

        logger.debug(f"Resolved execution order: {execution_order}")
        return execution_order

    @with_logger
    def run(self) -> Dict[str, Any]:
        """
        Run the experiment with dependency-aware execution.

        A step is not executed when one of its dependencies failed, was
        skipped, or could not be created. In that case an optional step is
        skipped and a required step aborts the run. A failing optional step
        is recorded in the results as ``{"error": ..., "failed": True}``.

        Returns:
            The results of the experiment

        Raises:
            RuntimeError: If a required step depends on an unavailable step
            Exception: Whatever a required step raises while executing
        """
        logger.info(f"Starting dynamic experiment: {self.name}")

        # Record the start time
        self.start_time = time.time()

        # Initialize the results dictionary
        self.results = {}
        failed_steps = set()

        # Execute each step
        total_steps = len(self.steps)
        for i, step in enumerate(self.steps, 1):
            step_name = step.name
            logger.info(f"Executing step {i}/{total_steps}: {step_name}")

            # Check if dependencies are satisfied: a dependency is unusable if
            # it failed or was skipped in this run, or was never created.
            dependencies = self.step_dependencies.get(step_name, [])
            missing_deps = [
                dep
                for dep in dependencies
                if dep in failed_steps or dep in self.uncreated_steps
            ]

            if missing_deps:
                error_msg = f"Step '{step_name}' cannot execute due to failed dependencies: {missing_deps}"
                if step_name in self.optional_steps:
                    logger.warning(f"Skipping optional step: {error_msg}")
                    failed_steps.add(step_name)
                    continue
                else:
                    logger.error(error_msg)
                    raise RuntimeError(error_msg)

            try:
                # Execute the step
                result = step.execute(previous_results=self.results)
                self.results[step_name] = result

            except Exception as e:
                error_msg = f"Error in step '{step_name}': {str(e)}"

                if step_name in self.optional_steps:
                    logger.warning(f"Optional step failed: {error_msg}")
                    failed_steps.add(step_name)
                    # Store the error in results for debugging
                    self.results[step_name] = {"error": str(e), "failed": True}
                else:
                    logger.error(error_msg, exc_info=True)
                    raise

        # Record the end time
        self.end_time = time.time()

        # Calculate the duration
        duration = self.end_time - self.start_time
        self.metrics["duration"] = duration

        # Add execution summary to metrics (counts cover created steps only)
        self.metrics["total_steps"] = total_steps
        self.metrics["failed_optional_steps"] = len(failed_steps)
        self.metrics["successful_steps"] = total_steps - len(failed_steps)

        logger.info(f"Dynamic experiment completed in {duration:.2f} seconds")
        logger.info(
            f"Successful steps: {self.metrics['successful_steps']}/{total_steps}"
        )

        if failed_steps:
            logger.info(f"Failed optional steps: {list(failed_steps)}")

        # Save the results
        self._save_results()

        return self.results

    def add_step_from_config(
        self, step_config: Dict[str, Any]
    ) -> "DynamicExperimentRunner":
        """
        Add a step from configuration.

        The step is appended after all existing steps; its dependencies are
        recorded but no re-ordering of existing steps takes place.

        Args:
            step_config: Configuration dictionary for the step

        Returns:
            The experiment runner instance (for chaining)
        """
        step = self.step_factory.create_step(
            step_config, self.config, self.variable_resolver.resolve
        )

        # Update dependencies
        step_name = step_config["name"]
        self.step_dependencies[step_name] = self._normalize_dependencies(
            step_config.get("depends_on")
        )

        # Track optional steps
        if step_config.get("optional", False):
            self.optional_steps.add(step_name)

        # The step now exists, so it no longer counts as uncreated.
        self.uncreated_steps.discard(step_name)

        self.steps.append(step)
        return self

    def get_step_dependencies(self) -> Dict[str, List[str]]:
        """Get a copy of the step dependency graph (inner lists are copied too)."""
        return {name: list(deps) for name, deps in self.step_dependencies.items()}

    def get_optional_steps(self) -> Set[str]:
        """Get a copy of the set of optional step names."""
        return self.optional_steps.copy()

    @with_logger
    def validate_configuration(self) -> Tuple[bool, List[str]]:
        """
        Validate the steps in the experiment configuration.

        Validation works from ``self.config["steps"]`` only, so steps added
        through ``add_step_from_config`` do not influence the result.

        Returns:
            Tuple of (is_valid, list_of_errors)
        """
        errors = []

        # Check if steps are defined
        steps_config = self.config.get("steps", [])
        if not steps_config:
            return True, []  # Legacy mode is valid

        # Validate step configurations
        step_names = set()
        ordered_names = []  # named steps in configuration order
        for step_config in steps_config:
            step_name = step_config.get("name")
            if not step_name:
                errors.append("Step missing required 'name' field")
                continue

            if step_name in step_names:
                errors.append(f"Duplicate step name: {step_name}")
            else:
                ordered_names.append(step_name)
            step_names.add(step_name)

            step_type = step_config.get("type")
            if not step_type:
                errors.append(f"Step '{step_name}' missing required 'type' field")
                continue

            # Validate step type specific requirements
            if step_type == "component":
                if not step_config.get("component_type"):
                    errors.append(
                        f"Component step '{step_name}' missing 'component_type'"
                    )
                if not step_config.get("component_name"):
                    errors.append(
                        f"Component step '{step_name}' missing 'component_name'"
                    )
            elif step_type == "function":
                if not step_config.get("function"):
                    errors.append(f"Function step '{step_name}' missing 'function'")
            elif step_type == "custom":
                if not step_config.get("module"):
                    errors.append(f"Custom step '{step_name}' missing 'module'")
                if not step_config.get("class"):
                    errors.append(f"Custom step '{step_name}' missing 'class'")
            else:
                errors.append(f"Step '{step_name}' has unknown type: {step_type}")

        # Validate dependencies and build the graph for the cycle check
        dependency_graph: Dict[str, List[str]] = {}
        for step_config in steps_config:
            step_name = step_config.get("name")
            if not step_name:
                continue  # already reported above

            dependencies = self._normalize_dependencies(
                step_config.get("depends_on")
            )
            known_deps = []
            for dep in dependencies:
                if dep in step_names:
                    known_deps.append(dep)
                else:
                    errors.append(f"Step '{step_name}' depends on unknown step '{dep}'")
            # Unknown dependencies are left out so they are reported only once.
            dependency_graph[step_name] = known_deps

        # Check for circular dependencies
        try:
            self._resolve_dependencies(ordered_names, dependency_graph)
        except ValueError as e:
            errors.append(str(e))

        is_valid = len(errors) == 0
        if is_valid:
            logger.debug("Validation of configuration steps have passed")
        else:
            logger.error(
                f"Validation of configuration steps failed with {len(errors)} errors"
            )
            for error in errors:
                logger.error(f"  - {error}")

        return is_valid, errors
