"""
Experiment Runner

This module provides a flexible experiment runner that can execute a series of
steps to run experiments with different configurations.
"""

import json
import pathlib
import time
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import pandas as pd
import yaml
from colorama import Fore, Style

from src.utils.decorator_utils import with_logger

from .config import config_manager

# Import and register all components


class ExperimentStep:
    """
    A step in an experiment pipeline.

    This class represents a single step in an experiment pipeline, with a
    function to execute and parameters to pass to it.
    """

    def __init__(self, name: str, func: Callable[..., Any], **kwargs: Any):
        """
        Initialise an experiment step.

        Args:
            name: The name of the step
            func: The function to execute
            **kwargs: Parameters to pass to the function
        """
        self.name = name
        self.func = func
        self.kwargs = kwargs

    @with_logger
    def execute(self, previous_results: Optional[Dict[str, Any]] = None) -> Any:
        """
        Execute the step.

        The call arguments are assembled from the step's own kwargs plus, when
        available, results of earlier steps. Explicit kwargs always win over
        values derived from previous results. Previous results are matched to
        the function's parameters in this order:

        1. If the function takes ``**kwargs``, every previous result is passed
           under its step name.
        2. Otherwise results whose step name equals a parameter name are
           passed, then ``setup_<x>`` results are mapped to a parameter ``<x>``.
        3. Still-missing required parameters are matched case-insensitively
           against step names.

        Two name-based special cases inject ``generation_results`` (from the
        ``run_generation`` step) and ``task_results`` (from ``run_task``); they
        are only applied when the function can actually receive the argument.

        NOTE(review): the ``logger`` name used below is not defined in this
        module; it is presumably supplied by the ``with_logger`` decorator.

        Args:
            previous_results: Results from previous steps, keyed by step name

        Returns:
            The result of the step

        Raises:
            Exception: Whatever the step function raises is logged and re-raised
        """
        import inspect

        # Only log to the experiment-specific logger to avoid duplicates
        # Add colored output for step execution
        logger.info(f"{Fore.CYAN}Executing step: {self.name}{Style.RESET_ALL}")

        # Get the function signature
        sig = inspect.signature(self.func)
        param_names = set(sig.parameters.keys())

        # functools.partial objects and callable instances have no __name__,
        # so fall back to their repr instead of failing on a log line.
        func_name = getattr(self.func, "__name__", repr(self.func))

        # Check if function accepts **kwargs (VAR_KEYWORD parameter)
        accepts_var_keyword = any(
            param.kind == inspect.Parameter.VAR_KEYWORD
            for param in sig.parameters.values()
        )

        # Start from a copy so the step's stored kwargs are never mutated
        kwargs = self.kwargs.copy()

        # Add the experiment logger to kwargs only if the function accepts it
        if "logger" in param_names:
            kwargs["logger"] = logger

        if previous_results:
            logger.debug(f"previous_results keys: {list(previous_results.keys())}")
            logger.debug(f"param_names: {param_names}")
            logger.debug(f"function name: {func_name}")

            # Map step names to parameter names
            # For example, "setup_llm" -> "llm"
            step_to_param = {
                step_name: step_name[len("setup_"):]
                for step_name in previous_results
                if step_name.startswith("setup_")
            }

            # Special case for evaluate_results function. Only inject the
            # argument when the function can receive it; otherwise the call
            # would fail with an unexpected-keyword TypeError.
            if (
                "evaluate_results" in self.name
                and "run_generation" in previous_results
                and ("generation_results" in param_names or accepts_var_keyword)
            ):
                kwargs["generation_results"] = previous_results["run_generation"]
                logger.info("Added generation_results from run_generation step")

            if accepts_var_keyword:
                # If function accepts **kwargs, pass all previous results
                logger.debug("Function accepts **kwargs, adding all previous results")
                for key, value in previous_results.items():
                    if key not in kwargs:
                        kwargs[key] = value
                        logger.debug(
                            f"Added parameter {key} from previous results (**kwargs)"
                        )
            else:
                # Original logic for functions with specific parameters
                # First, add any previous results that match parameter names directly
                for key, value in previous_results.items():
                    if key in param_names and key not in kwargs:
                        kwargs[key] = value
                        logger.debug(f"Added parameter {key} from previous results")

                # Then, try to map step names to parameter names
                for step_name, param_name in step_to_param.items():
                    if param_name in param_names and param_name not in kwargs:
                        kwargs[param_name] = previous_results[step_name]
                        logger.debug(
                            f"Mapped step {step_name} to parameter {param_name}"
                        )

            # Required parameters that are still unfilled. Identity comparison
            # with Parameter.empty avoids invoking __eq__ on arbitrary defaults.
            missing_params = [
                name
                for name, param in sig.parameters.items()
                if name not in kwargs
                and param.default is inspect.Parameter.empty
                and param.kind != inspect.Parameter.VAR_POSITIONAL
                and param.kind != inspect.Parameter.VAR_KEYWORD
                and name != "logger"  # Exclude logger parameter
            ]

            if missing_params:
                logger.debug(f"Missing parameters: {missing_params}")
                # Try to match by parameter name (case-insensitive)
                for param_name in missing_params:
                    for key, value in previous_results.items():
                        if key.lower() == param_name.lower() and key not in kwargs:
                            kwargs[param_name] = value
                            logger.debug(
                                f"Matched parameter {param_name} to {key} (case-insensitive)"
                            )
                            break

                # Special case for analyse_results function, guarded the same
                # way as the evaluate_results case above.
                if (
                    "analyse_results" in self.name
                    and "run_task" in previous_results
                    and ("task_results" in param_names or accepts_var_keyword)
                ):
                    kwargs["task_results"] = previous_results["run_task"]
                    logger.debug("Added task_results from run_task step")

        # Execute the function
        logger.info(f"Executing function: {func_name}")
        try:
            result = self.func(**kwargs)
        except Exception as e:
            error_msg = f"Error executing step {self.name}: {str(e)}"
            logger.error(error_msg, exc_info=True)
            raise

        logger.info(f"Step {self.name} completed successfully")
        return result


class ExperimentRunner:
    """
    A runner for executing experiments.

    This class provides a flexible way to define and execute experiments as a
    series of steps, with configuration and result tracking.

    NOTE(review): the ``logger`` name used inside the decorated methods is not
    defined in this module; it is presumably supplied by ``with_logger``.
    """

    @with_logger
    def __init__(
        self,
        name: str,
        config: Optional[Dict[str, Any]] = None,
        config_name: Optional[str] = None,
        output_dir: Optional[Union[str, pathlib.Path]] = None,
    ):
        """
        Initialise an experiment runner.

        The output directory is created (including parents) as a side effect.

        Args:
            name: The name of the experiment
            config: Pre-loaded configuration dictionary (preferred)
            config_name: The name of the configuration file to load if config not provided
            output_dir: The directory to save results to; when omitted, the
                ``paths.output`` config entry is used, defaulting to ``"output"``
        """
        self.name = name

        # Load configuration: explicit dict > named config file > manager default
        if config is not None:
            self.config = config
        elif config_name:
            self.config = config_manager.load_config(config_name)
        else:
            logger.info(f"Using default configuration for experiment '{name}'")
            self.config = config_manager.config

        # Set the output directory
        if output_dir is None:
            self.output_dir = pathlib.Path(
                self.config.get("paths", {}).get("output", "output")
            )
        else:
            self.output_dir = pathlib.Path(output_dir)

        # Create the output directory if it doesn't exist
        self.output_dir.mkdir(parents=True, exist_ok=True)

        # Steps in execution order
        self.steps: List[ExperimentStep] = []

        # Step results keyed by step name
        self.results: Dict[str, Any] = {}

        # Metrics collected for the experiment (e.g. "duration")
        self.metrics: Dict[str, float] = {}

        # Wall-clock start/end of the most recent run (time.time() values)
        self.start_time: Optional[float] = None
        self.end_time: Optional[float] = None

    def get_config(self) -> Dict[str, Any]:
        """
        Return the configuration in use.

        Returns:
            The configuration object held by the runner (not a copy)
        """
        return self.config

    @with_logger
    def add_step(
        self, name: str, func: Callable[..., Any], **kwargs: Any
    ) -> "ExperimentRunner":
        """
        Add a step to the experiment.

        Args:
            name: The name of the step
            func: The function to execute
            **kwargs: Parameters to pass to the function

        Returns:
            The experiment runner instance (for chaining)
        """
        step = ExperimentStep(name, func, **kwargs)
        self.steps.append(step)
        logger.info(f"Added step: {name}")
        return self

    @with_logger
    def run(self) -> Dict[str, Any]:
        """
        Run the experiment.

        Steps are executed in the order they were added; each step receives
        the results of all earlier steps. After the last step the run duration
        is stored under the ``"duration"`` metric and results are saved.

        Returns:
            The results of the experiment, keyed by step name

        Raises:
            Exception: Any exception raised by a step is logged and re-raised;
                in that case no results are saved
        """
        logger.info(f"Starting experiment: {self.name}")

        # Record the start time
        self.start_time = time.time()

        # Start from a clean results dictionary on every run
        self.results = {}

        # Execute each step
        total_steps = len(self.steps)
        for i, step in enumerate(self.steps, 1):
            logger.info(f"Executing step {i}/{total_steps}: {step.name}")
            try:
                # Pass the logger via kwargs so the decorator can use it
                result = step.execute(previous_results=self.results, logger=logger)
                self.results[step.name] = result
            except Exception as e:
                logger.error(f"Error in step {step.name}: {str(e)}", exc_info=True)
                raise

        # Record the end time
        self.end_time = time.time()

        # Calculate the duration
        duration = self.end_time - self.start_time
        self.metrics["duration"] = duration
        logger.info(f"Experiment completed in {duration:.2f} seconds")

        # Save the results
        self._save_results()

        return self.results

    @with_logger
    def add_metric(self, name: str, value: float) -> None:
        """
        Add a metric to the experiment.

        Args:
            name: The name of the metric
            value: The value of the metric
        """
        self.metrics[name] = value
        logger.info(f"Added metric: {name} = {value}")

    @with_logger
    def _save_results(self) -> None:
        """
        Save the experiment results to files.

        Creates ``<output_dir>/<name>_<timestamp>/`` and writes ``config.yaml``
        there. If a step returned a tuple whose third element is an
        ``ErrorTracker`` with at least one error, ``errors.json`` is written
        too and the tracker's summary is merged into the metrics. Configured
        output paths are then processed, and finally each step result is
        logged (not written to a file).
        """
        # Create a timestamp
        timestamp = time.strftime("%y%m%d-%H%M%S")

        # Create the experiment output directory
        exp_dir = self.output_dir / f"{self.name}_{timestamp}"
        exp_dir.mkdir(parents=True, exist_ok=True)

        logger.info(f"Saving results to: {exp_dir}")

        # Save the configuration
        config_path = exp_dir / "config.yaml"
        with open(config_path, "w", encoding="utf-8") as f:
            yaml.dump(self.config, f, default_flow_style=False)
        logger.info(f"Saved configuration to: {config_path}")

        # Save error information if available from any step results
        error_tracker = None
        for step_name, result in self.results.items():
            # Check if any step returned an error_tracker
            if isinstance(result, tuple) and len(result) >= 3:
                # Imported lazily so the dependency is only needed when a
                # step actually returns a candidate tuple
                from src.utils.error_tracking import ErrorTracker

                if isinstance(result[2], ErrorTracker):
                    error_tracker = result[2]
                    logger.info(f"Found error tracker from step: {step_name}")
                    break

        if error_tracker is not None and error_tracker.get_error_count() > 0:
            # Save detailed error information
            errors_path = exp_dir / "errors.json"
            with open(errors_path, "w", encoding="utf-8") as f:
                json.dump(error_tracker.to_dict(), f, indent=2, default=str)
            logger.info(f"Saved error details to: {errors_path}")

            # Add error metrics to the main metrics
            error_summary = error_tracker.get_summary()
            self.metrics.update(error_summary)
            logger.info(
                f"Added error statistics to metrics: {list(error_summary.keys())}"
            )

        # Metrics are now included in experiment_summary.json instead of separate metrics.json

        # Save results to configured output paths if specified
        self._save_configured_outputs()

        # Log the results instead of saving to separate files
        logger.info("Experiment step results:")
        saved_count = 0
        skipped_count = 0

        for name, result in self.results.items():
            # Log serializable results
            if isinstance(result, (str, int, float, bool)):
                logger.info(f"  {name}: {result}")
                saved_count += 1
            elif isinstance(result, (list, dict)):
                try:
                    result_str = json.dumps(result, indent=2)
                    logger.info(f"  {name}:\n{result_str}")
                    saved_count += 1
                except (TypeError, ValueError):
                    logger.info(f"Could not serialize result '{name}' to JSON")
                    skipped_count += 1
            elif isinstance(result, pd.DataFrame):
                logger.info(
                    f"  {name} (DataFrame): {result.shape[0]} rows, {result.shape[1]} columns"
                )
                # Log a summary of the DataFrame
                if not result.empty:
                    logger.info(f"    Columns: {list(result.columns)}")
                    if len(result) > 5:
                        logger.info(f"    First 5 rows:\n{result.head(5)}")
                    else:
                        logger.info(f"    All rows:\n{result}")
                saved_count += 1
            else:
                logger.info(f"Skipped non-serializable result: {name}")
                skipped_count += 1

        logger.info(
            f"Logged {saved_count} results, skipped {skipped_count} non-serializable results"
        )

    @with_logger
    def _save_configured_outputs(self) -> None:
        """
        Save results to configured output paths.

        Reads the ``output`` section of the config as a mapping of output key
        to output path and, for each entry, saves the first step result that
        fits the path's file type.
        """
        output_config = self.config.get("output", {})

        if not output_config:
            logger.info("No configured output paths found")
            return

        logger.debug("Processing configured output paths")

        # Simple approach: for each output config, find ANY step that has the right data
        for output_key, output_path in output_config.items():
            self._save_any_matching_result(output_key, output_path)

        logger.debug("Configured output paths processing completed")

    @with_logger
    def _get_safe_output_path(self, original_path: pathlib.Path) -> pathlib.Path:
        """
        Get a safe output path that won't overwrite existing files.

        If the file exists, returns a timestamped filename
        (``<stem>_<YYYYmmdd_HHMMSS><suffix>``) in the same directory.

        Args:
            original_path: The original desired path

        Returns:
            A safe path to use for writing
        """
        if not original_path.exists():
            return original_path

        # Use timestamped name for new file to avoid overwrite
        timestamp = time.strftime("%Y%m%d_%H%M%S")
        timestamped_path = original_path.with_name(
            f"{original_path.stem}_{timestamp}{original_path.suffix}"
        )
        logger.info(
            f"File exists, using timestamped filename to avoid overwrite: {timestamped_path}"
        )
        return timestamped_path

    @with_logger
    def _save_any_matching_result(self, output_key: str, output_path: str) -> None:
        """
        Find and save the first result that contains saveable data.

        Handles different file types based on extension:
        - .png, .jpg, .jpeg: matplotlib figures
        - .csv: DataFrames
        - .json: JSON-serializable data

        The ``embeddings`` output key is special: nothing is written, the
        outcome reported by the first step returning a dict with a
        ``success`` or ``error`` key is only logged.

        A failure while saving one step's result is logged as a warning and
        the next step's result is tried.

        Args:
            output_key: The key of the entry in the ``output`` config section
            output_path: The configured destination path
        """
        # Special case for embeddings (just logging). Handled before any path
        # work because no file is written for this key.
        if output_key == "embeddings":
            for step_name, result in self.results.items():
                if isinstance(result, dict) and (
                    "success" in result or "error" in result
                ):
                    if result.get("success"):
                        logger.info(
                            f"Embeddings results were saved by step '{step_name}'"
                        )
                    elif "error" in result:
                        logger.error(
                            f"Embeddings saving failed in step '{step_name}': {result['error']}"
                        )
                    return
            return

        original_path = pathlib.Path(output_path)
        original_path.parent.mkdir(parents=True, exist_ok=True)
        path = self._get_safe_output_path(original_path)
        file_extension = path.suffix.lower()

        # Handle different file types based on extension
        for step_name, result in self.results.items():
            try:
                # Handle image files (PNG, JPG, JPEG)
                if file_extension in (".png", ".jpg", ".jpeg"):
                    # Look for matplotlib figure in result
                    figure = None
                    if hasattr(result, "savefig"):
                        # Result is directly a matplotlib figure
                        figure = result
                    elif isinstance(result, dict):
                        # Look for figure in dict (common in embeddings results)
                        if "figure" in result and hasattr(result["figure"], "savefig"):
                            figure = result["figure"]
                        elif "fig" in result and hasattr(result["fig"], "savefig"):
                            figure = result["fig"]

                    if figure is not None:
                        figure.savefig(path, bbox_inches="tight", dpi=300)
                        logger.info(
                            f"Saved {output_key} figure from step '{step_name}' to: {path}"
                        )
                        return

                # Handle CSV files
                elif file_extension == ".csv":
                    # Check if result has DataFrame directly
                    if hasattr(result, "to_csv"):
                        result.to_csv(path, index=False)
                        logger.info(
                            f"Saved {output_key} results from step '{step_name}' to: {path}"
                        )
                        return

                    # Check if result is dict with DataFrame inside
                    elif isinstance(result, dict):
                        for value in result.values():
                            if hasattr(value, "to_csv"):
                                value.to_csv(path, index=False)
                                logger.info(
                                    f"Saved {output_key} results from step '{step_name}' to: {path}"
                                )
                                return

                # Handle JSON files
                elif file_extension == ".json":
                    # Only plain JSON types are candidates; None results are skipped
                    if result is None or not isinstance(
                        result, (dict, list, str, int, float, bool)
                    ):
                        continue

                    # Serialise before opening the file so a non-serialisable
                    # nested value cannot leave a truncated file behind.
                    payload = json.dumps(result, indent=2)
                    with open(path, "w", encoding="utf-8") as f:
                        f.write(payload)
                    logger.info(
                        f"Saved {output_key} JSON from step '{step_name}' to: {path}"
                    )
                    return

            except Exception as e:
                # Continue to next result if this one fails
                logger.warning(
                    f"Failed to save {output_key} result from step '{step_name}': {str(e)}"
                )
                continue

        # If no suitable data found, warn user
        logger.warning(
            f"No suitable data found for output '{output_key}' with extension '{file_extension}'. Available steps: {list(self.results.keys())}"
        )

    def compare(self, other: "ExperimentRunner") -> Dict[str, Tuple[float, float]]:
        """
        Compare this experiment with another experiment.

        Args:
            other: The other experiment to compare with

        Returns:
            A dictionary mapping each metric name found in either experiment
            to a ``(self_value, other_value)`` pair; a metric missing on one
            side is reported as NaN
        """
        comparisons: Dict[str, Tuple[float, float]] = {}

        # Compare metrics over the union of both metric name sets
        for name in set(self.metrics) | set(other.metrics):
            self_value = self.metrics.get(name, float("nan"))
            other_value = other.metrics.get(name, float("nan"))
            comparisons[name] = (self_value, other_value)

        return comparisons


@with_logger
def load_experiment_results(
    experiment_dir: Union[str, pathlib.Path],
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
    """
    Load experiment results from a directory.

    Reads ``config.yaml`` and, when present, ``errors.json`` from the
    directory. Step results themselves are not loaded: if ``experiment.log``
    exists, only its path is recorded.

    NOTE(review): the ``logger`` name used below is not defined in this
    module; it is presumably supplied by the ``with_logger`` decorator.

    Args:
        experiment_dir: The directory containing the experiment results

    Returns:
        A tuple of (config, results). ``config`` is the parsed ``config.yaml``
        (empty dict when the file is missing or empty). ``results`` may hold
        ``"_log_file"`` (path of the experiment log) and ``"_errors"``
        (content of ``errors.json``, only when non-empty).

    Raises:
        FileNotFoundError: If the experiment directory does not exist
    """
    experiment_dir = pathlib.Path(experiment_dir)

    logger.debug(f"Loading experiment results from: {experiment_dir}")

    # Check if the directory exists
    if not experiment_dir.exists():
        logger.error(f"Experiment directory not found: {experiment_dir}")
        raise FileNotFoundError(f"Experiment directory not found: {experiment_dir}")

    # Load the configuration
    config_path = experiment_dir / "config.yaml"
    if config_path.exists():
        with open(config_path, "r", encoding="utf-8") as f:
            # safe_load yields None for an empty document; keep the
            # documented dict return type in that case
            config = yaml.safe_load(f) or {}
        logger.debug(f"Loaded configuration from: {config_path}")
    else:
        logger.warning(f"Configuration file not found: {config_path}")
        config = {}

    # Load error information if available
    errors_path = experiment_dir / "errors.json"
    errors = {}
    if errors_path.exists():
        with open(errors_path, "r", encoding="utf-8") as f:
            errors = json.load(f)
        logger.debug(f"Loaded error details from: {errors_path}")
        # NOTE(review): len(errors) counts top-level entries of the loaded
        # JSON; whether that equals the number of errors depends on the
        # schema written by the error tracker - confirm.
        logger.debug(f"Found {len(errors)} LLM invocation errors in experiment")

    # Check for experiment log file
    log_path = experiment_dir / "experiment.log"
    results: Dict[str, Any] = {}

    if log_path.exists():
        logger.debug(f"Found experiment log file: {log_path}")
        logger.debug("Experiment step results are logged in the experiment log file.")
        logger.debug("To view detailed results, please check the experiment log file.")

        # Add a reference to the log file in the results
        results["_log_file"] = str(log_path)
    else:
        logger.warning(f"Experiment log file not found: {log_path}")

    # Add errors to results if available
    if errors:
        results["_errors"] = errors

    return config, results
