"""
Step Factory for Dynamic Experiment Steps

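This module provides a factory for creating different types of experiment steps
from configuration. It supports component steps (built via the component registries),
function steps (built-in functions), and custom steps (user-defined classes).
It also provides VariableResolver, which expands ``${path.to.value}`` references
against the experiment configuration.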
"""

import importlib
import re
from typing import Any, Callable, Dict, Optional

from src.utils.decorator_utils import with_logger
from src.utils.logging_utils import log_prompt_usage

from .experiment import ExperimentStep
from .registry import (
    embeddings_registry,
    evaluator_registry,
    llm_registry,
    prompt_optimiser_registry,
    task_registry,
)


class StepFactory:
    """
    Factory for creating experiment steps from configuration.

    This factory can create three types of steps:
    1. Component steps: Use the registry system to instantiate components
    2. Function steps: Execute built-in experiment functions
    3. Custom steps: Load and execute user-defined classes

    Every step is returned as an ``ExperimentStep`` wrapping a closure that
    accepts ``**kwargs``. The built-in functions below look up the results of
    earlier steps in those kwargs by step name (e.g. ``init_llm``).
    """

    def __init__(self):
        """Initialise the step factory with its registries and built-in functions."""
        # Maps a component step's ``component_type`` to the registry that builds it.
        self.registries = {
            "llm": llm_registry,
            "task": task_registry,
            "prompt_optimiser": prompt_optimiser_registry,
            "evaluator": evaluator_registry,
            "embeddings": embeddings_registry,
        }

        # Maps a function step's ``function`` name to the bound method it runs.
        self.builtin_functions = {
            "optimise_prompt": self._optimise_prompt_function,
            "run_generation": self._run_generation_function,
            "evaluate_results": self._evaluate_results_function,
            "analyse_results": self._analyse_results_function,
        }

    @with_logger
    def create_step(
        self,
        step_config: Dict[str, Any],
        experiment_config: Dict[str, Any],
        variable_resolver: Optional[Callable[[str], str]] = None,
    ) -> ExperimentStep:
        """
        Create an experiment step from configuration.

        Args:
            step_config: Configuration for the step. Must contain ``name`` and
                ``type`` (``component``, ``function`` or ``custom``), plus the
                keys that type requires.
            experiment_config: Full experiment configuration for context
            variable_resolver: Function to resolve ``${...}`` interpolations
                (only used by component steps)

        Returns:
            An ExperimentStep instance

        Raises:
            KeyError: If a required key is missing from ``step_config``
            ValueError: If step configuration is invalid
            ImportError: If a custom step's module cannot be imported
            AttributeError: If a custom step's class is not found in its module
        """
        step_name = step_config["name"]
        step_type = step_config["type"]

        logger.debug(f"Creating step '{step_name}' of type '{step_type}'")

        if step_type == "component":
            return self._create_component_step(
                step_config, experiment_config, variable_resolver
            )
        elif step_type == "function":
            return self._create_function_step(step_config, experiment_config)
        elif step_type == "custom":
            return self._create_custom_step(step_config, experiment_config)
        else:
            raise ValueError(f"Unknown step type: {step_type}")

    @with_logger
    def _create_component_step(
        self,
        step_config: Dict[str, Any],
        experiment_config: Dict[str, Any],
        variable_resolver: Optional[Callable[[str], str]] = None,
    ) -> ExperimentStep:
        """
        Create a component step that uses the registry system.

        The registry and component configuration are resolved now. The
        component itself is only instantiated (via ``registry.create``) when
        the returned step's setup function runs.

        Raises:
            ValueError: If ``component_type`` has no registry, or the
                component's configuration is not a mapping.
        """
        step_name = step_config["name"]
        component_type = step_config["component_type"]
        component_name = step_config["component_name"]

        # Resolve variable interpolation in component_name. Non-string names
        # (e.g. a bare number in YAML) are left untouched instead of crashing
        # on the substring test.
        if (
            variable_resolver is not None
            and isinstance(component_name, str)
            and "${" in component_name
        ):
            component_name = variable_resolver(component_name)

        logger.debug(f"Creating component step: {component_type}.{component_name}")

        # Get the appropriate registry
        if component_type not in self.registries:
            raise ValueError(f"Unknown component type: {component_type}")

        registry = self.registries[component_type]

        # Get component configuration
        component_config = self._get_component_config(
            component_type, component_name, experiment_config
        )

        @with_logger
        def setup_component(**kwargs):
            """Instantiate the configured component; ``kwargs`` are accepted but unused."""
            logger.debug(f"Setting up {component_type} component: {component_name}")
            try:
                instance = registry.create(component_name, **component_config)
                logger.debug(f"Successfully created {component_type} instance")
                return instance
            except Exception as e:
                logger.error(f"Failed to create {component_type} instance: {str(e)}")
                raise

        return ExperimentStep(step_name, setup_component)

    @with_logger
    def _create_function_step(
        self,
        step_config: Dict[str, Any],
        experiment_config: Dict[str, Any],
    ) -> ExperimentStep:
        """
        Create a function step that executes a built-in function.

        Raises:
            ValueError: If ``function`` does not name a built-in function.
        """
        step_name = step_config["name"]
        function_name = step_config["function"]

        logger.debug(f"Creating function step: {function_name}")

        if function_name not in self.builtin_functions:
            raise ValueError(f"Unknown built-in function: {function_name}")

        function = self.builtin_functions[function_name]

        # Additional configuration for the function. An explicit ``config: null``
        # is treated as "no config" so that ``**func_config`` below cannot fail.
        func_config = step_config.get("config") or {}

        @with_logger
        def execute_function(**kwargs):
            """Wrapper function for built-in function execution."""
            logger.debug(f"Executing built-in function: {function_name}")
            try:
                # Call the bound method with the kwargs
                return function(**kwargs)
            except Exception as e:
                logger.error(f"Failed to execute function {function_name}: {str(e)}")
                raise

        # func_config is forwarded to ExperimentStep as keyword arguments —
        # presumably merged into the call kwargs at run time; confirm in ExperimentStep.
        return ExperimentStep(step_name, execute_function, **func_config)

    @with_logger
    def _create_custom_step(
        self,
        step_config: Dict[str, Any],
        experiment_config: Dict[str, Any],
    ) -> ExperimentStep:
        """
        Create a custom step that loads a user-defined class.

        The class is imported now. A fresh instance is constructed with the
        step's ``config`` on every execution. The step then calls the instance's
        ``execute`` method if it has a callable one; otherwise it calls the
        instance itself.

        Raises:
            ImportError: If the module cannot be imported.
            AttributeError: If the class is not found in the module.
        """
        step_name = step_config["name"]
        module_path = step_config["module"]
        class_name = step_config["class"]

        logger.info(f"Creating custom step: {module_path}.{class_name}")

        # Keep each try body to the single call whose failure it describes, so
        # an unrelated AttributeError is not misreported as a missing class.
        try:
            module = importlib.import_module(module_path)
        except ImportError as e:
            logger.error(f"Failed to import custom step module {module_path}: {str(e)}")
            raise

        try:
            step_class = getattr(module, class_name)
        except AttributeError as e:
            logger.error(
                f"Class {class_name} not found in module {module_path}: {str(e)}"
            )
            raise

        # Constructor arguments for the custom class; ``config: null`` means none.
        custom_config = step_config.get("config") or {}

        @with_logger
        def execute_custom_step(**kwargs):
            """Instantiate the custom class and run it with the step kwargs."""
            logger.info(f"Executing custom step: {class_name}")
            try:
                instance = step_class(**custom_config)
                # Only use ``execute`` when it is actually callable; a plain
                # attribute named ``execute`` must not shadow a callable instance.
                execute = getattr(instance, "execute", None)
                if callable(execute):
                    return execute(**kwargs)
                if callable(instance):
                    return instance(**kwargs)
                raise ValueError(
                    f"Custom step class {class_name} must have execute method or be callable"
                )
            except Exception as e:
                logger.error(f"Failed to execute custom step: {str(e)}")
                raise

        return ExperimentStep(step_name, execute_custom_step)

    def _get_component_config(
        self,
        component_type: str,
        component_name: str,
        experiment_config: Dict[str, Any],
    ) -> Dict[str, Any]:
        """
        Get configuration for a component from the experiment config.

        Looks in ``experiment_config["components"][component_type][component_name]``
        first, then in the legacy location
        ``experiment_config[component_type][component_name]``. Sections that are
        missing or are not dicts (e.g. a legacy scalar ``task: gsm8k``) are
        skipped rather than indexed.

        Returns:
            A shallow copy of the component's configuration dict, or ``{}`` if
            it is not found or is explicitly empty (``None``).

        Raises:
            ValueError: If the component's configuration exists but is not a dict.
        """
        # New ``components`` structure first, then the top-level legacy structure.
        containers = (experiment_config.get("components"), experiment_config)

        for container in containers:
            if not isinstance(container, dict):
                continue
            type_config = container.get(component_type)
            if not isinstance(type_config, dict) or component_name not in type_config:
                continue

            component_config = type_config[component_name]
            if component_config is None:
                return {}
            if not isinstance(component_config, dict):
                raise ValueError(
                    f"Configuration for {component_type} component "
                    f"'{component_name}' must be a mapping, got "
                    f"{type(component_config).__name__}"
                )
            # Copy so callers cannot mutate the shared experiment configuration.
            return dict(component_config)

        # Return empty config if not found
        return {}

    # Built-in function implementations
    @with_logger
    def _optimise_prompt_function(self, **kwargs):
        """
        Built-in function for prompt optimisation.

        Reads the ``init_llm``, ``init_task`` and ``init_prompt_optimiser`` step
        results from ``kwargs``. It seeds the optimiser with the task's base
        message template and then calls ``optimise(task, llm)``.

        Returns:
            The optimiser after optimisation, or None when no optimiser was
            provided (optimisation is skipped).

        Raises:
            ValueError: If an optimiser is provided but the LLM or task is missing.
        """
        # Log keys only (as the other built-ins do); full reprs can be huge.
        logger.debug(f"_optimise_prompt_function kwargs: {list(kwargs.keys())}")

        # Extract required components from kwargs using the actual step names
        llm = kwargs.get("init_llm")
        task = kwargs.get("init_task")
        prompt_optimiser = kwargs.get("init_prompt_optimiser")

        # Test for None rather than truthiness. Objects that define
        # __len__/__bool__ (e.g. a task over an empty dataset) must not be
        # mistaken for missing components.
        if prompt_optimiser is None:
            logger.info("No prompt optimiser provided, skipping optimisation")
            return None

        if llm is None or task is None:
            raise ValueError("LLM and task are required for prompt optimisation")

        logger.debug("Starting prompt optimisation process")

        # Get the base prompt from the task and log its usage
        msg_template = task.get_prompt_msg_template()
        log_prompt_usage(logger, str(msg_template), "base")

        # Set the message template on the optimiser, then optimise
        prompt_optimiser.set_message_template(msg_template)
        prompt_optimiser.optimise(task, llm)

        logger.debug("Prompt optimisation complete")

        return prompt_optimiser

    @with_logger
    def _run_generation_function(self, **kwargs):
        """
        Built-in function for running generation with the optimised or base prompt.

        Reads ``init_llm``, ``init_task`` and (optionally) ``optimise_prompt``
        from ``kwargs``. If an optimiser result is present, the task's message
        template is replaced with ``optimiser.apply()`` before running.

        Returns:
            The ``(results_df, score, error_tracker)`` tuple from ``task.run(llm)``.

        Raises:
            ValueError: If the LLM or task is missing.
        """
        logger.debug(f"_run_generation_function kwargs: {list(kwargs.keys())}")

        # Extract required components from kwargs using the actual step names
        llm = kwargs.get("init_llm")
        task = kwargs.get("init_task")
        optimiser = kwargs.get("optimise_prompt")

        # ``is None`` rather than truthiness; see _optimise_prompt_function.
        if llm is None or task is None:
            raise ValueError("LLM and task are required for generation")

        # Check if optimisation was performed
        if optimiser is None:
            logger.debug("Running task with base prompt (no optimisation performed)")
            base_prompt = task.get_prompt_msg_template()
            log_prompt_usage(logger, str(base_prompt), "final task")
        else:
            logger.debug("Running task with optimised prompt")
            optimised_msg_template = optimiser.apply()
            log_prompt_usage(logger, str(optimised_msg_template), "optimised task")
            task.update_prompt_msg_template(optimised_msg_template)

        results_df, score, error_tracker = task.run(llm)

        from src.utils.logging_utils import log_evaluation_score

        # Dataset name: the task's ``name`` attribute if truthy, otherwise the
        # class name lower-cased with "task"/"handler" removed.
        dataset_name = getattr(
            task, "name", None
        ) or task.__class__.__name__.lower().replace("task", "").replace("handler", "")
        log_evaluation_score(logger, score, dataset_name)

        logger.info(f"Results dataframe shape: {results_df.shape}")

        return results_df, score, error_tracker

    @with_logger
    def _evaluate_results_function(self, **kwargs):
        """
        Built-in function for evaluating results.

        Compares LLM-generated answers against ground truth answers and provides
        detailed evaluation metrics including accuracy, error counts, and success rates.

        Args:
            **kwargs: Contains results from previous steps, specifically:
                - run_generation: Tuple (or list) of (results_df, score, error_tracker)

        Returns:
            Dict containing evaluation results with metrics and statistics

        Raises:
            ValueError: If ``run_generation`` is missing or malformed, or the
                results DataFrame lacks a required column.
        """
        logger.debug(f"_evaluate_results_function kwargs: {list(kwargs.keys())}")

        # Extract results from run_generation step
        run_generation_results = kwargs.get("run_generation")
        if run_generation_results is None:
            raise ValueError("run_generation results are required for evaluation")

        # Unpack the results; accept a list as well as a tuple (e.g. after a
        # serialisation round-trip).
        if (
            isinstance(run_generation_results, (tuple, list))
            and len(run_generation_results) >= 3
        ):
            results_df, calculated_score, error_tracker = run_generation_results[:3]
        else:
            raise ValueError(
                "Invalid run_generation results format. Expected (results_df, score, error_tracker)"
            )

        logger.info(f"Evaluating results for {len(results_df)} questions")

        # Basic validation
        required_columns = ["question", "answer", "llm_answer", "score"]
        missing_columns = [
            col for col in required_columns if col not in results_df.columns
        ]
        if missing_columns:
            raise ValueError(
                f"Missing required columns in results DataFrame: {missing_columns}"
            )

        total_questions = len(results_df)

        # Count correct answers (score column should contain 1 for correct, 0 for incorrect)
        correct_answers = results_df["score"].sum()
        incorrect_answers = total_questions - correct_answers
        accuracy = correct_answers / total_questions if total_questions > 0 else 0.0

        # Count failed/missing responses (where llm_answer is None/NaN)
        failed_responses = results_df["llm_answer"].isna().sum()
        successful_responses = total_questions - failed_responses

        # Success rate for LLM invocations
        llm_success_rate = (
            successful_responses / total_questions if total_questions > 0 else 0.0
        )

        # Error statistics from the error tracker, when it provides them
        total_errors = (
            error_tracker.get_error_count()
            if hasattr(error_tracker, "get_error_count")
            else 0
        )
        error_success_rate = (
            error_tracker.get_success_rate(total_questions)
            if hasattr(error_tracker, "get_success_rate")
            else 1.0
        )

        evaluation_results = {
            "accuracy": accuracy,
            "total_questions": total_questions,
            "correct_answers": int(correct_answers),
            "incorrect_answers": int(incorrect_answers),
            "failed_responses": int(failed_responses),
            "results_dataframe": results_df,
            "successful_responses": int(successful_responses),
            "llm_success_rate": llm_success_rate,
            "calculated_score": calculated_score,  # Score from the task's eval_handler
            "error_statistics": {
                "total_errors": total_errors,
                "error_success_rate": error_success_rate,
            },
            "status": "completed",
            "message": f"Evaluation completed: {correct_answers}/{total_questions} correct ({accuracy:.2%})",
        }

        # Format the task score defensively: a None/non-numeric score must not
        # crash the evaluation after all metrics were computed.
        try:
            score_text = f"{calculated_score:.4f}"
        except (TypeError, ValueError):
            score_text = str(calculated_score)

        # Log detailed evaluation results
        logger.info("Evaluation Results:")
        logger.info(f"  Total Questions: {total_questions}")
        logger.info(
            f"  Correct Answers: {correct_answers}/{total_questions} ({accuracy:.2%})"
        )
        logger.info(f"  Failed Responses: {failed_responses}")
        logger.info(f"  LLM Success Rate: {llm_success_rate:.2%}")
        logger.info(f"  Calculated Score: {score_text}")

        if total_errors > 0:
            logger.info(f"  LLM Invocation Errors: {total_errors}")
            logger.info(f"  Error Success Rate: {error_success_rate:.2%}")

        return evaluation_results

    @with_logger
    def _analyse_results_function(self, **kwargs):
        """Built-in function for analysing results (placeholder: ignores kwargs, returns a fixed status)."""
        return {"status": "analysed", "message": "Analysis completed"}


class VariableResolver:
    """
    Resolver for variable interpolation in configuration values.

    Supports syntax like ${components.llm.default} to reference configuration
    values. Each ``${...}`` reference is replaced by ``str()`` of the value at
    that dot-separated path. Path segments index dicts by key and lists/tuples
    by non-negative integer. Substitution is a single pass: text produced by a
    substitution is never re-scanned for further references.
    """

    # Matches "${...}"; group 1 is the raw (unstripped) path text.
    _VARIABLE_PATTERN = re.compile(r"\$\{([^}]+)\}")

    def __init__(self, config: Dict[str, Any]):
        """
        Initialise the variable resolver.

        Args:
            config: The full configuration dictionary
        """
        self.config = config

    @with_logger
    def resolve(self, value: str) -> str:
        """
        Resolve variable interpolations in a string value.

        Non-string values, and strings without ``${``, are returned unchanged.

        Args:
            value: String that may contain variable references like ${path.to.value}

        Returns:
            String with variables resolved

        Raises:
            ValueError: If a referenced variable is not found
        """
        if not isinstance(value, str) or "${" not in value:
            return value

        logger.debug(f"Resolving variables in: {value}")

        def _substitute(match: "re.Match[str]") -> str:
            # Surrounding whitespace inside the braces is ignored for lookup.
            var_path = match.group(1).strip()
            try:
                resolved = self._get_nested_value(var_path)
            except KeyError as err:
                raise ValueError(f"Variable not found: {var_path}") from err
            return str(resolved)

        # re.sub with a callable inserts the returned text literally and scans
        # left-to-right once, so resolved values containing "${...}" stay as-is.
        resolved_value = self._VARIABLE_PATTERN.sub(_substitute, value)

        logger.debug(f"Final resolved value: {resolved_value}")
        return resolved_value

    def _get_nested_value(self, path: str) -> Any:
        """
        Get a nested value from the configuration using dot notation.

        Args:
            path: Dot-separated path like "components.llm.default". A segment
                made only of decimal digits also indexes into a list or tuple
                (e.g. "steps.0.name").

        Returns:
            The value at the specified path

        Raises:
            KeyError: If the path is not found
        """
        current: Any = self.config

        for part in path.split("."):
            if isinstance(current, dict) and part in current:
                current = current[part]
            elif (
                isinstance(current, (list, tuple))
                and part.isdecimal()
                and int(part) < len(current)
            ):
                current = current[int(part)]
            else:
                raise KeyError(f"Path not found: {path}")

        return current


# Module-level StepFactory instance, shared by every importer of ``step_factory``.
step_factory: StepFactory = StepFactory()
