"""
Benchmark execution functions using Optuna for hyperparameter optimization.

This module provides the core execution engine for running hyperparameter
optimization experiments with comprehensive error handling and progress tracking.
"""

from collections import namedtuple
from pathlib import Path
import numpy as np
import optuna

from sklearn.model_selection import cross_val_score
from joblib import parallel_backend
from sklearn.metrics import mean_squared_error
import dill
import os
import datetime
import time
from sklearn.base import BaseEstimator, RegressorMixin
from typing import Dict, Any, Union, List, Tuple


# --- Error Handling Wrapper ---


class SafeEstimatorWrapper(BaseEstimator, RegressorMixin):
    """
    A wrapper for scikit-learn estimators to catch exceptions during `fit`
    and save the context (data and hyperparameters) for debugging.

    ``get_params``/``set_params`` and unknown attribute lookups are delegated to
    the wrapped estimator, so the wrapper exposes the inner estimator's
    hyperparameters instead of an ``estimator`` parameter. Because of that the
    default ``sklearn.base.clone`` logic cannot rebuild the wrapper, which is why
    ``__sklearn_clone__`` is implemented explicitly.
    """

    def __init__(self, estimator):
        self.estimator = estimator

    def fit(self, X, y=None, **fit_params):
        """
        Fit the wrapped estimator and return ``self``.

        If fitting raises an ``Exception``, the failing context is reported and
        dumped via `handle_error`, then the exception is re-raised unchanged.
        ``KeyboardInterrupt``/``SystemExit`` propagate without writing a dump.

        Returns:
            self, following the scikit-learn ``fit`` convention.
        """
        try:
            self.estimator.fit(X, y, **fit_params)
        except Exception as e:
            self.handle_error(e, X, y)
            raise
        return self

    def predict(self, X):
        """Predicts using the fitted wrapped estimator."""
        return self.estimator.predict(X)

    def get_params(self, deep=True):
        """Get parameters for the wrapped estimator."""
        return self.estimator.get_params(deep=deep)

    def set_params(self, **params):
        """Set parameters for the wrapped estimator and return ``self``."""
        self.estimator.set_params(**params)
        return self

    def __sklearn_clone__(self):
        """Clone the wrapped estimator and wrap the clone in a new wrapper."""
        from sklearn.base import clone

        return SafeEstimatorWrapper(clone(self.estimator))

    def handle_error(self, e, X, y):
        """
        Print a short report and save the hyperparameters, data split and
        exception to ``cv_error_logs/`` (pickled with dill) for later analysis.

        This is best effort: any failure while collecting parameters or writing
        the dump is reported but never raised, so the original fit error is the
        one that propagates.
        """
        print("--- ERROR during cross-validation fit ---")
        estimator_name = self.estimator.__class__.__name__
        try:
            params = self.estimator.get_params()
        except Exception as params_e:
            params = f"<get_params() failed: {params_e}>"
        print(f"Estimator: {estimator_name}")
        print(f"Parameters: {params}")

        error_dir = "cv_error_logs"
        # Microsecond timestamp keeps concurrent fold failures from colliding.
        timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S_%f")
        filename = os.path.join(error_dir, f"error__{estimator_name}__{timestamp}.pkl")

        error_data = {"params": params, "X": X, "y": y, "exception": e}

        try:
            os.makedirs(error_dir, exist_ok=True)
            with open(filename, "wb") as f:
                dill.dump(error_data, f)
            print(f"Saved error data (hyperparameters and data split) to {filename}")
        except Exception as dump_e:
            print(f"CRITICAL: Failed to dump error data to {filename}: {dump_e}")
        print("--- End of error report ---")

    def __getattr__(self, name):
        """
        Delegates attribute access to the wrapped estimator to ensure
        compatibility with sklearn APIs (e.g., accessing `feature_importances_`).
        """
        # Only reached when normal lookup fails. If `estimator` itself is missing
        # (e.g. mid-unpickling), delegating would recurse forever.
        if name == "estimator":
            raise AttributeError(name)
        return getattr(self.estimator, name)


# --- Parameter Sampling Utilities ---


def _sample_single_parameter(
    trial: optuna.Trial,
    name: str,
    distribution: Union[Tuple, List],
) -> Any:
    """
    Draw one hyperparameter value from ``distribution`` through ``trial``.

    A distribution *spec* is a tuple or list (lists come from JSON configs) of
    at least two items whose first item is one of "randint", "uniform",
    "loguniform" or "1-loguniform", followed by exactly two bounds:

    * ``("randint", low, high)``     -> int in [low, high)  (high is exclusive)
    * ``("uniform", low, high)``     -> float in [low, high]
    * ``("loguniform", low, high)``  -> float in [low, high], log-scaled
    * ``("1-loguniform", a, b)``     -> ``1 - x`` with x log-scaled in [a, b].
      Note that Optuna records the raw ``x`` for this parameter, not ``1 - x``.

    Any other list is treated as a categorical set of choices.

    Args:
        trial: Optuna trial object
        name: Parameter name
        distribution: Distribution spec or categorical choice list

    Returns:
        Sampled parameter value

    Raises:
        ValueError: A spec does not have exactly two bounds, or
            ``distribution`` is neither a spec nor a list.
    """
    is_spec = (
        isinstance(distribution, (tuple, list))
        and len(distribution) >= 2
        and distribution[0] in {"randint", "uniform", "loguniform", "1-loguniform"}
    )

    if not is_spec:
        if isinstance(distribution, list):
            return trial.suggest_categorical(name, distribution)
        raise ValueError(f"Unsupported distribution for {name}: {distribution}")

    kind, *bounds = distribution
    if len(bounds) != 2:
        arg_names = "(a, b)" if kind == "1-loguniform" else "(low, high)"
        raise ValueError(f"{kind} requires 2 arguments {arg_names}, got {len(bounds)}")
    lo, hi = bounds

    if kind == "randint":
        return trial.suggest_int(name, lo, hi - 1)
    if kind == "uniform":
        return trial.suggest_float(name, lo, hi)
    if kind == "loguniform":
        return trial.suggest_float(name, lo, hi, log=True)
    # Only "1-loguniform" remains.
    return 1 - trial.suggest_float(name, lo, hi, log=True)


def _sample_independent_parameters(
    trial: optuna.Trial,
    param_distributions: Dict[str, Any],
    dependent_params: List[str] = None,
) -> Dict[str, Any]:
    """
    Sample every parameter in ``param_distributions`` except the dependent ones.

    Args:
        trial: Optuna trial object
        param_distributions: Mapping of parameter name to distribution spec
        dependent_params: Names to leave out (sampled later by a conditional
            pass); ``None`` means nothing is skipped

    Returns:
        Mapping of parameter name to sampled value, in ``param_distributions``
        order
    """
    skipped = [] if dependent_params is None else dependent_params
    return {
        param: _sample_single_parameter(trial, param, spec)
        for param, spec in param_distributions.items()
        if param not in skipped
    }


def _sample_conditional_parameters(
    trial: optuna.Trial,
    param_distributions: Dict[str, Any],
    sampled_params: Dict[str, Any],
    fixed_params: Dict[str, Any] = None,
) -> Dict[str, Any]:
    """
    Sample the parameters whose relevance depends on ``split_strategy``.

    The strategy is read from ``sampled_params`` first, then ``fixed_params``.
    The strategies map to parameters as follows:

    * "random" -> split_try, colsample_bytree
    * "top_k"  -> top_k, must_fill_all_k
    * anything else, e.g. "best_split" -> none of them

    A parameter is only sampled if it also appears in ``param_distributions``.

    Args:
        trial: Optuna trial object
        param_distributions: Mapping of parameter name to distribution spec
        sampled_params: Parameters already sampled for this trial
        fixed_params: Non-tuned parameters that may carry ``split_strategy``

    Returns:
        Mapping of the conditional parameter names to sampled values. It is
        empty when no strategy is set.
    """
    fixed = {} if fixed_params is None else fixed_params

    strategy = sampled_params.get("split_strategy") or fixed.get("split_strategy")
    if strategy is None:
        return {}

    # Categorical choices may not be strings; compare on the string form.
    strategy_key = strategy if isinstance(strategy, str) else str(strategy)
    activated = {
        "random": ("split_try", "colsample_bytree"),
        "top_k": ("top_k", "must_fill_all_k"),
    }.get(strategy_key, ())

    return {
        param: _sample_single_parameter(trial, param, param_distributions[param])
        for param in activated
        if param in param_distributions
    }


def _sample_all_parameters(
    trial: optuna.Trial,
    param_distributions: Dict[str, Any],
    fixed_params: Dict[str, Any] = None,
) -> Dict[str, Any]:
    """
    Sample all parameters with conditional dependencies handled properly.

    The parameters tied to ``split_strategy`` are split_try, colsample_bytree,
    top_k and must_fill_all_k. They are deferred to
    `_sample_conditional_parameters` only when ``split_strategy`` is actually
    tuned (in ``param_distributions``) or fixed (in ``fixed_params``).

    Without a ``split_strategy`` the conditional pass samples nothing. The
    parameters are therefore sampled like any other, e.g. for a model that
    merely shares the name ``colsample_bytree``. Otherwise they would silently
    never be tuned.

    Args:
        trial: Optuna trial object
        param_distributions: Dictionary of parameter distributions
        fixed_params: Fixed parameters (not sampled) that conditions may depend on

    Returns:
        Dictionary of all sampled parameters
    """
    if fixed_params is None:
        fixed_params = {}

    split_dependent = ("split_try", "colsample_bytree", "top_k", "must_fill_all_k")
    has_split_strategy = (
        "split_strategy" in param_distributions or "split_strategy" in fixed_params
    )
    dependent_params = (
        [p for p in split_dependent if p in param_distributions]
        if has_split_strategy
        else []
    )

    # First pass: everything that does not hinge on split_strategy.
    params = _sample_independent_parameters(
        trial, param_distributions, dependent_params
    )

    # Second pass: strategy-specific parameters. fixed_params is passed so a
    # fixed split_strategy is honoured; returns {} when there is none.
    params.update(
        _sample_conditional_parameters(trial, param_distributions, params, fixed_params)
    )
    return params


def _generate_grid_search_space(
    param_distributions: Dict[str, Any],
    fixed_params: Dict[str, Any] = None,
) -> Dict[str, List[Any]]:
    """
    Generate grid search space from parameter distributions for GridSampler.

    Args:
        param_distributions: Dictionary of parameter distributions
        fixed_params: Fixed parameters (not included in grid)

    Returns:
        Dictionary in format {param_name: [value1, value2, ...]} for GridSampler

    Note:
        For grid search, we include all parameters in param_distributions. If split_strategy
        is in param_distributions (not fixed), GridSampler will create all combinations.
        Some combinations may be invalid (e.g., split_strategy="random" with top_k), but
        the model will handle this the same way as with other samplers.
    """
    if fixed_params is None:
        fixed_params = {}

    grid_space = {}

    # Define which parameters are dependent on split_strategy
    dependent_params = [
        "split_try",
        "colsample_bytree",
        "top_k",
        "must_fill_all_k",
    ]

    # Get split_strategy from fixed_params (if it's fixed, we can filter dependent params)
    split_strategy = fixed_params.get("split_strategy")

    # Check if split_strategy is in param_distributions (will be sampled)
    split_strategy_in_dist = "split_strategy" in param_distributions

    for param_name, distribution in param_distributions.items():
        # If split_strategy is fixed, filter dependent params based on its value
        if not split_strategy_in_dist and param_name in dependent_params:
            if split_strategy == "random" and param_name in [
                "top_k",
                "must_fill_all_k",
            ]:
                continue  # Skip top_k/must_fill_all_k for random split_strategy
            elif split_strategy == "top_k" and param_name in [
                "split_try",
                "colsample_bytree",
            ]:
                continue  # Skip split_try/colsample_bytree for top_k split_strategy
            elif split_strategy == "best_split":
                continue  # Skip all dependent params for best_split
            # Otherwise include it (e.g., split_try/colsample_bytree for random, top_k/must_fill_all_k for top_k)

        # If split_strategy is in param_distributions, include all params
        # GridSampler will create all combinations, some may be invalid but that's handled by the model

        if isinstance(distribution, tuple) and len(distribution) >= 2:
            dist_type = distribution[0]
            args = distribution[1:]

            if dist_type == "randint":
                if len(args) != 2:
                    raise ValueError(
                        f"randint requires 2 arguments (low, high), got {len(args)}"
                    )
                low, high = args
                # Generate integer grid: use all values if range is small, otherwise sample
                if high - low <= 20:
                    grid_space[param_name] = list(range(low, high))
                else:
                    # Generate 8-10 points for larger ranges
                    num_points = min(10, max(8, (high - low) // 2))
                    step = (high - low) / (num_points - 1) if num_points > 1 else 1
                    grid_space[param_name] = [
                        int(low + i * step) for i in range(num_points)
                    ]
                    # Ensure high-1 is included
                    if grid_space[param_name][-1] != high - 1:
                        grid_space[param_name][-1] = high - 1
                    grid_space[param_name] = sorted(list(set(grid_space[param_name])))

            elif dist_type == "uniform":
                if len(args) != 2:
                    raise ValueError(
                        f"uniform requires 2 arguments (low, high), got {len(args)}"
                    )
                low, high = args
                # Generate 5-8 points for uniform distributions
                num_points = min(8, max(5, int((high - low) * 10)))
                step = (high - low) / (num_points - 1) if num_points > 1 else 0
                grid_space[param_name] = [low + i * step for i in range(num_points)]
                # Ensure high is included
                if abs(grid_space[param_name][-1] - high) > 1e-10:
                    grid_space[param_name][-1] = high
                grid_space[param_name] = sorted(list(set(grid_space[param_name])))

            elif dist_type == "loguniform":
                if len(args) != 2:
                    raise ValueError(
                        f"loguniform requires 2 arguments (low, high), got {len(args)}"
                    )
                low, high = args
                # Generate 5-7 points with logarithmic spacing
                num_points = min(7, max(5, int(np.log10(high / low) * 2) + 1))
                log_low = np.log10(low)
                log_high = np.log10(high)
                log_step = (
                    (log_high - log_low) / (num_points - 1) if num_points > 1 else 0
                )
                grid_space[param_name] = [
                    10 ** (log_low + i * log_step) for i in range(num_points)
                ]
                # Ensure high is included
                if abs(grid_space[param_name][-1] - high) / high > 1e-6:
                    grid_space[param_name][-1] = high
                grid_space[param_name] = sorted(list(set(grid_space[param_name])))

            else:
                # Unsupported distribution type - skip
                continue

        elif isinstance(distribution, list):
            # Categorical: use all provided options
            grid_space[param_name] = distribution

    return grid_space


# Bundle of everything `run_optuna_benchmark` needs for one optimization run.
# `optimization_method` defaults to "optuna" so that Task instances built with
# the older 14-field form keep working, matching the backward-compatible
# fallback that `run_optuna_benchmark` applies.
Task = namedtuple(
    "Task",
    [
        "estimator",
        "param_distributions",
        "fixed_params",  # Fixed parameters to initialize model with
        "X_train",
        "y_train",
        "X_test",
        "y_test",
        "n_iter",
        "cv",
        "random_state",
        "n_jobs",
        "study_name",
        "storage_name",
        "experiment_id",  # Add experiment_id for progress tracking
        "optimization_method",  # "optuna", "random", or "grid"
    ],
    defaults=("optuna",),
)


def _create_sampler(
    optimization_method, param_distributions, fixed_params, n_iter, random_state
):
    """Build the Optuna sampler for "random", "grid", or (any other value) TPE."""
    if optimization_method == "random":
        sampler = optuna.samplers.RandomSampler(seed=random_state)
        print("Using RandomSampler", flush=True)
        return sampler

    if optimization_method == "grid":
        grid_search_space = _generate_grid_search_space(
            param_distributions, fixed_params
        )
        total_combinations = 1
        for values in grid_search_space.values():
            total_combinations *= len(values)
        print(
            f"Generated grid search space with {total_combinations} total combinations",
            flush=True,
        )
        if total_combinations > n_iter:
            print(
                f"Warning: Grid has {total_combinations} combinations but only {n_iter} trials will run. "
                f"Only a subset of combinations will be evaluated.",
                flush=True,
            )
        sampler = optuna.samplers.GridSampler(grid_search_space, seed=random_state)
        print("Using GridSampler", flush=True)
        return sampler

    sampler = optuna.samplers.TPESampler(
        seed=random_state, multivariate=True, group=True
    )
    print("Using TPESampler", flush=True)
    return sampler


def _load_or_create_study(study_name, storage_name, sampler, direction):
    """Resume the named study from journal storage, or create a new one.

    The configured ``sampler`` is passed when loading too. Otherwise Optuna
    would fall back to its default sampler, and a resumed run would ignore the
    requested random/grid/seeded-TPE sampling.
    """
    storage = None
    if study_name and storage_name:
        # Use shared log file for all studies in the experiment
        print(f"Using shared log file storage: {storage_name}", flush=True)
        storage = optuna.storages.JournalStorage(
            optuna.storages.journal.JournalFileBackend(storage_name)
        )

    print("Creating or loading Optuna study...", flush=True)

    if storage is None:
        # In-memory study: there is nothing to resume. (load_study with
        # study_name=None and no storage would raise ValueError.)
        print(f"Creating new study: {study_name}", flush=True)
        return optuna.create_study(
            direction=direction, sampler=sampler, study_name=study_name
        )

    try:
        study = optuna.load_study(
            study_name=study_name, storage=storage, sampler=sampler
        )
    except KeyError:
        # JournalStorage raises KeyError for an unknown study name.
        # load_if_exists guards against another process creating it meanwhile.
        print(f"Creating new study: {study_name}", flush=True)
        return optuna.create_study(
            direction=direction,
            sampler=sampler,
            study_name=study_name,
            storage=storage,
            load_if_exists=True,
        )

    print(f"Loaded existing study: {study_name}", flush=True)
    if study.trials:
        print(
            f"Found existing study with {len(study.trials)} trials. "
            f"If you get distribution compatibility errors, delete the log file: {storage_name}",
            flush=True,
        )
    return study


def _build_model(estimator, fixed_params, params):
    """Clone the base estimator, then apply fixed params and then tuned params.

    Only a `SafeEstimatorWrapper` is unwrapped. Other meta-estimators also
    expose an ``estimator`` attribute (e.g. bagging/boosting wrappers) and must
    be tuned as-is.
    """
    from sklearn.base import clone

    base_estimator = (
        estimator.estimator
        if isinstance(estimator, SafeEstimatorWrapper)
        else estimator
    )
    model = clone(base_estimator)
    if fixed_params:
        model = model.set_params(**fixed_params)
    if params:
        model = model.set_params(**params)
    return model


def _decode_best_params(raw_params, param_distributions):
    """Convert Optuna's stored parameter values into the values given to the model.

    `_sample_single_parameter` feeds ``1 - x`` to the model for "1-loguniform"
    parameters, but Optuna records the raw ``x``. Those entries are inverted
    here. All other values are returned unchanged.
    """
    decoded = dict(raw_params)
    for name, value in raw_params.items():
        spec = param_distributions.get(name)
        if (
            isinstance(spec, (tuple, list))
            and len(spec) >= 2
            and spec[0] == "1-loguniform"
        ):
            decoded[name] = 1 - value
    return decoded


def run_optuna_benchmark(
    task: Task,
    callbacks: list = None,
    direction="minimize",
):
    """
    Run Optuna hyperparameter optimization for one estimator and evaluate it.

    The sampler depends on ``task.optimization_method``: "random" uses
    RandomSampler, "grid" uses GridSampler over `_generate_grid_search_space`,
    and anything else uses multivariate TPE. A study stored in
    ``task.storage_name`` is resumed until it has ``task.n_iter`` trials in
    total. Each trial scores the candidate with cross-validated MSE. The best
    parameters are then refit on the full training set and scored on the test
    set.

    Args:
        task: Task object containing all optimization parameters
        callbacks: List of Optuna callbacks to use during optimization
        direction: Optimization direction ("minimize" or "maximize")

    Returns:
        Dictionary with keys "model", "best_params" (values as passed to the
        model), "fixed_params", "best_cv_score" (CV MSE), "test_mse",
        "test_rmse" and "fit_time" (seconds spent in ``study.optimize``, 0 if
        skipped). It is an empty dict if the study has no completed trials.
    """
    estimator = task.estimator
    param_distributions = task.param_distributions
    fixed_params = task.fixed_params or {}
    X_train = task.X_train
    y_train = task.y_train
    X_test = task.X_test
    y_test = task.y_test
    n_iter = task.n_iter
    cv = task.cv
    random_state = task.random_state
    n_jobs = task.n_jobs
    study_name = task.study_name
    storage_name = task.storage_name

    if not hasattr(estimator, "_name"):
        estimator._name = estimator.__class__.__name__
    print(f"Estimator: {estimator._name}", flush=True)
    print(f"Storage: {storage_name}", flush=True)
    print(f"Study Name: {study_name}", flush=True)
    print(f"Train set size: X{X_train.shape}, y{y_train.shape}", flush=True)
    print(f"Test set size: X{X_test.shape}, y{y_test.shape}", flush=True)
    print("Parameters:", flush=True)
    print(f"n_iter: {n_iter}", flush=True)
    print(f"cv: {cv}", flush=True)
    print(f"random_state: {random_state}", flush=True)
    print(f"n_jobs: {n_jobs}", flush=True)

    # Default to "optuna" (TPE) for backward compatibility.
    optimization_method = getattr(task, "optimization_method", "optuna")
    print(f"optimization_method: {optimization_method}", flush=True)

    sampler = _create_sampler(
        optimization_method, param_distributions, fixed_params, n_iter, random_state
    )
    study = _load_or_create_study(study_name, storage_name, sampler, direction)

    def objective(trial: optuna.Trial):
        # fixed_params is passed so a fixed split_strategy drives conditional sampling.
        sampled_params = _sample_all_parameters(
            trial, param_distributions, fixed_params
        )
        safe_model = SafeEstimatorWrapper(
            _build_model(estimator, fixed_params, sampled_params)
        )
        with parallel_backend("threading"):
            scores = cross_val_score(
                safe_model,
                X_train,
                y_train,
                scoring="neg_mean_squared_error",
                cv=cv,
                error_score="raise",
                n_jobs=n_jobs,
            )
        # Scores are negated MSE; return the MSE itself.
        return -np.mean(scores)

    existing_trials = len(study.trials)
    print(f"Study has {existing_trials} existing trials.", flush=True)
    trials_needed = n_iter - existing_trials

    if trials_needed > 0:
        print(
            f"Continuing study '{study_name}', running {trials_needed} more trials to reach {n_iter} total (currently {existing_trials}).",
            flush=True,
        )
        start = time.perf_counter()
        study.optimize(
            objective,
            n_trials=trials_needed,
            callbacks=callbacks if callbacks is not None else [],
        )
        fit_time = time.perf_counter() - start
    else:
        print(
            f"Study already has {existing_trials} trials, which is >= n_iter={n_iter}. Skipping optimization.",
            flush=True,
        )
        fit_time = 0

    try:
        raw_best_params = study.best_params
        best_cv_score = -study.best_value
    except ValueError:
        print(
            "Study has no completed trials. Cannot determine best parameters.",
            flush=True,
        )
        return {}

    best_params = _decode_best_params(raw_best_params, param_distributions)
    best_estimator = _build_model(estimator, fixed_params, best_params)
    best_estimator.fit(X_train, y_train)

    preds = best_estimator.predict(X_test)
    test_mse = mean_squared_error(y_test, preds)
    test_rmse = np.sqrt(test_mse)

    return {
        "model": estimator._name,
        "best_params": best_params,
        "fixed_params": fixed_params,  # Include fixed parameters in results
        "best_cv_score": best_cv_score,
        "test_mse": test_mse,
        "test_rmse": test_rmse,
        "fit_time": fit_time,
    }
