#!/usr/bin/env python3

import logging
import time
from typing import Any, cast
import hydra
import mlflow
import optuna
import pandas as pd
from hydra.core.override_parser.overrides_parser import OverridesParser
from omegaconf import DictConfig, ListConfig, OmegaConf
from haipr.train import HAIPRTrainer
from haipr.utils.optuna_sweeper import create_optuna_distribution_from_override
from haipr.utils.resolvers import register_resolvers
from haipr.utils.results_logger import ResultsLogger
import gc


# Module logger for the optimizer. Its level is pinned to WARNING at import
# time, so the info/debug records emitted by this module are dropped unless
# the level is lowered afterwards.
logger = logging.getLogger("haipr.optimizer")
logger.setLevel("WARNING")


class HAIPROptimizer:
    """Run hyperparameter optimization for HAIPR models using Optuna."""

    def __init__(self, cfg: DictConfig | ListConfig):
        """Store the run configuration and prepare tracking and job state.

        Creates a ``ResultsLogger`` with no run attached, points MLflow at
        ``cfg.mlflow.tracking_uri`` and inspects Hydra to detect multirun
        mode and this job's index.

        Args:
            cfg: Full run configuration (Hydra/OmegaConf).
        """
        self.cfg = cfg

        # Study state: _study/_study_name are set by _initialize_study,
        # _search_space by _parse_search_space, _n_trials_target by optimize,
        # _hyperopt_run_id by _create_new_hyperopt_run.
        self._study = None
        self._study_name: str | None = None
        self._search_space: dict[str, Any] = {}
        self._n_trials_target = 0
        self._hyperopt_run_id = None

        # Job bookkeeping; is_multirun/job_num are refined by _setup_job_handling.
        self.is_multirun = False
        self.job_num = 0
        self.is_nested = False

        self.results_logger = ResultsLogger(cfg=self.cfg, run=None)

        mlflow.set_tracking_uri(self.cfg.mlflow.tracking_uri)
        self._setup_job_handling()

    def _setup_job_handling(self) -> None:
        """Detect Hydra multirun mode and record this job's number.

        When Hydra is not initialized nothing is changed. In MULTIRUN mode
        ``self.job_num`` is taken from ``hydra.job.num`` (0 if absent). Any
        exception while inspecting Hydra is logged as a warning instead of
        propagating.
        """
        try:
            from hydra.core.hydra_config import HydraConfig

            if not HydraConfig.initialized():
                return

            hydra_cfg = HydraConfig.get()
            run_mode = getattr(getattr(hydra_cfg, "mode", None), "name", "SINGLERUN")
            self.is_multirun = run_mode == "MULTIRUN"
            if not self.is_multirun:
                return

            self.job_num = getattr(getattr(hydra_cfg, "job", None), "num", 0)
        except Exception as e:
            logger.warning(
                f"Could not get Hydra config: {e}. Defaulting to single job mode"
            )

    def setup_storage(self, cfg):
        """Create the Optuna storage configured by ``cfg.optuna.storage``.

        Made public for reuse by HAIPRTrainer.

        Dispatch on the storage value:

        * ``mysql...`` / ``postgresql...`` URL -> ``RDBStorage`` with a bounded
          connection pool, ``grace_period`` and a ``RetryFailedTrialCallback``
          (``max_retry``). Stale-trial detection (and therefore the grace
          period / retry callback) only runs when the optional
          ``cfg.optuna.heartbeat_interval`` is set.
        * ``sqlite...`` URL -> ``RDBStorage`` using the URL as given.
        * Path ending in ``.db`` -> SQLite ``RDBStorage`` for that file; if it
          cannot be created, a ``JournalStorage`` backed by the same path.
        * Anything else -> ``InMemoryStorage``.

        Args:
            cfg: Config whose ``optuna`` node may contain ``storage`` and, for
                MySQL/PostgreSQL, ``grace_period`` / ``max_retry`` (defaulting
                to 600 and 3, the same defaults as ``get_storage_config``) and
                ``heartbeat_interval`` (default ``None``).

        Returns:
            The storage object, or ``None`` when ``cfg.optuna.storage`` is
            missing or null.

        Raises:
            Exception: Re-raised if MySQL/PostgreSQL storage cannot be set up.
        """
        storage = None
        storage_uri = getattr(cfg.optuna, "storage", None)
        logger.debug("Setting up Optuna storage with config: %s", storage_uri)

        if storage_uri is not None:
            storage_uri = str(storage_uri)
            if storage_uri.startswith(("mysql", "postgresql")):
                grace_period = getattr(cfg.optuna, "grace_period", 600)
                max_retry = getattr(cfg.optuna, "max_retry", 3)
                heartbeat_interval = getattr(cfg.optuna, "heartbeat_interval", None)
                logger.info(f"Setting up RDB storage with URL: {storage_uri}")
                try:
                    logger.debug(
                        "Configuring RDB storage with grace period: %s, max_retry: %s",
                        grace_period,
                        max_retry,
                    )
                    storage = optuna.storages.RDBStorage(
                        url=storage_uri,
                        engine_kwargs={
                            "pool_size": 20,
                            "max_overflow": 0,
                            "pool_timeout": 30,
                        },
                        heartbeat_interval=heartbeat_interval,
                        grace_period=grace_period,
                        failed_trial_callback=optuna.storages.RetryFailedTrialCallback(
                            max_retry=max_retry,
                        ),
                    )
                    logger.info("RDB storage setup successful")
                except Exception as e:
                    logger.error(
                        f"Failed to setup RDB storage: {e}", exc_info=True)
                    raise
            elif storage_uri.startswith("sqlite"):
                # Already a SQLAlchemy URL: prefixing "sqlite:///" again would
                # turn it into an invalid database path.
                logger.info("Using SQLite storage")
                logger.debug("SQLite database URL: %s", storage_uri)
                storage = optuna.storages.RDBStorage(
                    url=storage_uri,
                    engine_kwargs={
                        "connect_args": {"check_same_thread": False},
                    },
                )
            elif storage_uri.endswith(".db"):
                logger.info("Using SQLite storage")
                logger.debug("SQLite database file: %s", storage_uri)
                try:
                    storage = optuna.storages.RDBStorage(
                        url=f"sqlite:///{storage_uri}",
                        engine_kwargs={
                            "connect_args": {"check_same_thread": False},
                        },
                    )
                except Exception as e:
                    logger.warning(f"Failed to create SQLite storage: {e}. Falling back to JournalFileStorage.")
                    file_storage = optuna.storages.JournalFileStorage(storage_uri)
                    storage = optuna.storages.JournalStorage(file_storage)
            else:
                logger.info("Using InMemory storage")
                storage = optuna.storages.InMemoryStorage()
        logger.debug("Storage setup complete. Storage type: %s",
                     type(storage).__name__)
        return storage

    def _initialize_study(self, cfg, storage):
        """Create or load this job's Optuna study and return it.

        The study name comes from ``create_study_name(cfg)``. In non-parallel
        mode, or as job 0 of a parallel run, the study is created (or loaded
        if it already exists); every other parallel job only loads it. The
        study's metric name is then set to
        ``self.cfg.model.optimization_metric``.

        Raises:
            ValueError: If no study is available after creation/loading.
        """
        # Use a distinct study per outer test split to prevent leakage across splits
        self._study_name = self.create_study_name(cfg)
        logger.debug("Initializing study with name: %s", self._study_name)

        pruner = optuna.pruners.MedianPruner() if cfg.optuna.pruner else None
        pruner_label = type(pruner).__name__ if pruner else "None"
        logger.debug("Using pruner: %s", pruner_label)

        sampler = self._create_optuna_sampler(cfg)
        logger.debug("Using sampler: %s", type(sampler).__name__)

        parallel = self.cfg.parallel
        if parallel and self.job_num != 0:
            logger.info(f"Job {self.job_num}: Loading existing study")
            self._handle_study_loading(storage)
        else:
            if not parallel:  # SLURM
                logger.info(
                    "Non-parallel mode: Handling initial study creation/loading")
            else:
                logger.info("Job 0: Handling initial study creation/loading")
            self._handle_study_creation(storage, pruner, sampler)

        if not self._study:
            raise ValueError("Study not found")
        self._study.set_metric_names([self.cfg.model.optimization_metric])

        logger.debug(
            "Study initialization complete. Metric names set to: %s",
            self.cfg.model.optimization_metric,
        )
        return self._study

    def _create_optuna_sampler(self, cfg=None):
        """Build the ``TPESampler`` used by the study.

        Args:
            cfg: Config providing ``seed`` and the ``optuna.*`` TPE settings
                (``n_startup_trials``, ``n_ei_candidates``, ``multivariate``,
                ``warn_independent_sampling``, ``consider_prior``,
                ``prior_weight``, ``consider_magic_clip``). Defaults to
                ``self.cfg``. ``_initialize_study`` passes the same config it
                uses for the study name and pruner, so all study components
                come from one config.

        Returns:
            A configured ``optuna.samplers.TPESampler``.
        """
        if cfg is None:
            cfg = self.cfg
        optuna_cfg = cfg.optuna
        return optuna.samplers.TPESampler(
            seed=cfg.seed,
            n_startup_trials=optuna_cfg.n_startup_trials,
            n_ei_candidates=optuna_cfg.n_ei_candidates,
            multivariate=optuna_cfg.multivariate,
            warn_independent_sampling=optuna_cfg.warn_independent_sampling,
            consider_prior=optuna_cfg.consider_prior,
            prior_weight=optuna_cfg.prior_weight,
            consider_magic_clip=optuna_cfg.consider_magic_clip,
        )

    def _handle_study_creation(self, storage, pruner, sampler):
        """Create (or load, if it exists) the study named ``self._study_name``.

        With ``cfg.optuna.clear_storage`` set, an existing study of that name
        is deleted first so optimization starts from an empty study. The
        resulting study is stored on ``self._study``.

        Args:
            storage: Optuna storage (or ``None``) passed to Optuna as-is.
            pruner: Pruner for the new study, or ``None``.
            sampler: Sampler for the new study.

        Raises:
            Exception: Re-raised from ``optuna.create_study`` after logging.
        """
        if self.cfg.optuna.clear_storage:
            # Only names are needed; get_all_study_names avoids building full
            # study summaries (including best trials) for every stored study.
            existing_study_names = optuna.study.get_all_study_names(
                storage=storage)
            if self._study_name in existing_study_names:
                logger.info(f"Clearing study: {self._study_name}")
                optuna.delete_study(
                    study_name=self._study_name, storage=storage)

        try:
            self._study = optuna.create_study(
                study_name=self._study_name,
                direction=self.cfg.model.optimization_direction,
                sampler=sampler,
                storage=storage,
                pruner=pruner,
                load_if_exists=True,
            )
            logger.info(
                f"Created/loaded study '{self._study_name}' successfully")
        except Exception as e:
            logger.error(f"Failed to create/load study: {e}", exc_info=True)
            raise

    def load_study(self, study_name: str, storage) -> optuna.Study:
        """Load an existing Optuna study. Made public for reuse by HAIPRTrainer.

        Makes at most two attempts and sleeps 10 seconds before each one
        (presumably to give the job that creates the study time to do so —
        confirm against the launcher).

        Args:
            study_name: Name of the study to load.
            storage: Optuna storage holding the study.

        Returns:
            The loaded study.

        Raises:
            Exception: Whatever ``optuna.load_study`` raised on the final attempt.
        """
        max_attempts = 2
        for attempt in range(1, max_attempts + 1):
            time.sleep(10)
            try:
                study = optuna.load_study(
                    study_name=study_name, storage=storage)
            except Exception as e:
                if attempt == max_attempts:
                    logger.error(f"Failed to load study: {e}")
                    raise
                logger.debug(
                    f"Failed to load study on first attempt: {e}, retrying in 10 seconds..."
                )
            else:
                logger.info(f"Loaded study '{study_name}' successfully")
                return study

    def _handle_study_loading(self, storage):
        """Load ``self._study_name`` from ``storage`` into ``self._study`` via ``load_study``."""
        loaded_study = self.load_study(self._study_name, storage)
        self._study = loaded_study

    def _create_new_hyperopt_run(self, run_name):
        """Create the parent "hyperopt" MLflow run for this study and return its ID.

        Sets the configured experiment (a failure is only logged), starts a
        nested run named ``run_name``, attaches ``self.results_logger`` to it,
        logs the config, study tags, the search space and the model config,
        stores the run ID in ``self._hyperopt_run_id`` and ends the run.

        The run is opened with ``with mlflow.start_run(...)`` so it is ended on
        every path. If any logging step raises, MLflow ends the run as FAILED
        and the exception propagates, instead of leaving a dangling active run.

        Args:
            run_name: Display name of the MLflow run.

        Returns:
            The MLflow run ID of the hyperopt run.
        """
        try:
            mlflow.set_experiment(self.cfg.mlflow.experiment_name)
        except Exception as e:
            logger.error(f"Failed to set experiment for hyperopt run: {e}")

        with mlflow.start_run(run_name=run_name, nested=True) as run:
            # Reuse trainer's results logger
            self.results_logger.set_run_id(run.info.run_id)
            self.results_logger.run = run

            # Log experiment setup
            self.results_logger.log_config(self.cfg)

            # A missing optuna node or storage entry is recorded as "None".
            study_storage = getattr(
                getattr(self.cfg, "optuna", None), "storage", None)
            if study_storage is None:
                study_storage = "None"

            mlflow.set_tags(
                {
                    "run_type": "hyperopt",
                    "study_name": self._study_name,
                    "optimization_direction": self.cfg.model.optimization_direction,
                    "sampler": (
                        self._study.sampler.__class__.__name__ if self._study else "None"
                    ),
                    "study_storage": study_storage,
                    "pruner": (
                        self._study.pruner.__class__.__name__
                        if self._study and self._study.pruner
                        else "None"
                    ),
                    "n_trials_target": self._n_trials_target,
                    "test_split_method": self.cfg.data.test_split_method,
                    "test_split_idx": self.cfg.data.test_split_idx,
                }
            )

            # Log search space and model configuration
            mlflow.log_dict(
                OmegaConf.to_container(
                    self.cfg.optuna.search_space, resolve=True),
                "search_space.yaml",
            )
            mlflow.log_dict(
                OmegaConf.to_container(self.cfg.model, resolve=True),
                "model_config.yaml",
            )

            self._hyperopt_run_id = run.info.run_id

        return self._hyperopt_run_id

    def _parse_search_space(self):
        """Translate ``cfg.optuna.search_space`` into Optuna distributions.

        Every ``name: spec`` entry is rendered as a Hydra override string
        ``"name=spec"``, parsed with Hydra's ``OverridesParser`` and converted
        by ``create_optuna_distribution_from_override``. Results are added to
        ``self._search_space`` keyed by each override's key element.
        """
        override_strings = [
            f"{name}={spec}"
            for name, spec in self.cfg.optuna.search_space.items()
        ]
        parsed_overrides = OverridesParser.create().parse_overrides(override_strings)

        for parsed in parsed_overrides:
            distribution = create_optuna_distribution_from_override(parsed)
            self._search_space[parsed.get_key_element()] = distribution

        logger.debug(f"Parsed search space: {self._search_space}")

    def _should_continue_optimization(self):
        """Decide whether this worker should keep running trials.

        Returns:
            ``False`` once ``self._n_trials_target`` trials are COMPLETE, or
            once at least that many trials exist and none is RUNNING.
            ``True`` otherwise. When the target count is reached but trials
            are still RUNNING, this sleeps 10 seconds before returning
            ``True`` so the caller polls again.
        """
        study = self._require_study()
        # Take a single snapshot: Study.trials deep-copies every trial on each
        # access, and separate reads could disagree while other workers are
        # updating the same storage.
        states = [t.state for t in study.get_trials(deepcopy=False)]
        n_complete = states.count(optuna.trial.TrialState.COMPLETE)
        n_running = states.count(optuna.trial.TrialState.RUNNING)
        n_total = len(states)

        if n_complete >= self._n_trials_target:
            logger.info(
                f"Target number of trials ({self._n_trials_target}) completed.")
            return False

        if n_total >= self._n_trials_target:
            if n_running == 0:
                logger.info(f"All {n_total} trials finished.")
                return False
            logger.info(f"Waiting for {n_running} trials to complete...")
            time.sleep(10)
            return True

        return True

    def _suggest_parameters(self, trial) -> dict:
        """Sample one value per entry of ``self._search_space`` on ``trial``.

        Uses the trial's private ``_suggest`` so the pre-built Optuna
        distribution objects can be sampled directly.

        Returns:
            Mapping of parameter name (dotted config path) to sampled value.
        """
        trial_params: dict[str, Any] = {
            name: trial._suggest(name, dist)
            for name, dist in self._search_space.items()
        }
        logger.info(f"Trial {trial.number} parameters: {trial_params}")
        return trial_params

    def update_trial_config(
        self, base_cfg: DictConfig | ListConfig, trial_params: dict
    ) -> DictConfig | ListConfig:
        """Update config with trial parameters. Made public for reuse by HAIPRTrainer.

        ``trial_params`` maps dotted config paths (e.g. ``"model.lr"``) to
        values. They are expanded into a nested dict and merged over an
        unresolved copy of ``base_cfg``. ``base_cfg`` itself is not modified.

        Returns:
            A new config with the trial values merged in.
        """
        # IMPORTANT: do not resolve here to avoid triggering user-defined resolvers
        # (e.g., ${select:...}) embedded in config values such as optuna.search_space.
        unresolved = OmegaConf.to_container(base_cfg, resolve=False)
        trial_cfg = OmegaConf.create(unresolved)

        nested: dict[str, Any] = {}
        for dotpath, value in trial_params.items():
            *parents, leaf = dotpath.split(".")
            node = nested
            for key in parents:
                node = node.setdefault(key, {})
            node[leaf] = value

        return OmegaConf.merge(trial_cfg, OmegaConf.create(nested))

    def _update_trial_config(self, trial_params) -> DictConfig | ListConfig:
        """Merge ``trial_params`` into ``self.cfg`` via ``update_trial_config``."""
        merged_cfg = self.update_trial_config(self.cfg, trial_params)
        return merged_cfg

    def create_study_name(self, cfg: DictConfig | ListConfig) -> str:
        """Create study name based on config. Made public for reuse by HAIPRTrainer.

        The base name is ``cfg.optuna.study_name`` when set. Otherwise it is
        the MD5 hex digest of the unresolved config YAML. When
        ``cfg.data.test_split_idx`` is present and not ``None``,
        ``"_split_<idx>"`` is appended so each outer test split gets its own
        study.
        """
        if cfg.optuna.study_name is not None:
            base_name = str(cfg.optuna.study_name)
        else:
            import hashlib

            # Serialize the config to a hash
            yaml_bytes = OmegaConf.to_yaml(cfg, resolve=False).encode("utf-8")
            base_name = hashlib.md5(yaml_bytes).hexdigest()

        try:
            data_cfg = getattr(cfg, "data", None)
            split_idx = getattr(data_cfg, "test_split_idx", None)
            suffix = "" if split_idx is None else f"_split_{split_idx}"
        except Exception:
            suffix = ""

        return base_name + suffix

    def get_storage_config(self) -> dict:
        """Get storage configuration for reuse by other classes.

        Returns a dict with ``storage``, ``grace_period`` and ``max_retry``
        read from ``self.cfg.optuna``. Missing entries default to ``None``,
        600 and 3 respectively.
        """
        optuna_cfg = self.cfg.optuna
        fallbacks = {"storage": None, "grace_period": 600, "max_retry": 3}
        return {
            key: getattr(optuna_cfg, key, default)
            for key, default in fallbacks.items()
        }

    def update_intermediate_metrics(self, split_num, result):
        """
        Report one split's optimization metric to the current Optuna trial.

        Used as the ``split_callback`` given to ``run_splits``. The value of
        ``cfg.model.optimization_metric`` from the split's metrics is reported
        on ``self.current_trial`` with ``step=split_num``. Any failure is
        logged as a warning and never raised.

        Args:
            split_num (int): The split number being updated; used as the step.
            result (tuple | dict): Either ``(metrics, predictions)`` or a
                mapping with a ``"metrics"`` entry.
        """
        logger.info(f"Updating intermediate metrics for split {split_num}")
        try:
            if isinstance(result, tuple):
                split_metrics = result[0]
            else:
                split_metrics = result["metrics"]
            report = self.current_trial.report
            report(split_metrics[self.cfg.model.optimization_metric], step=split_num)
        except Exception as e:
            logger.warning(
                f"Failed to update intermediate metrics for split {split_num}: {e}")

    def _run_trial(self, trial) -> tuple[float, dict] | None:
        """Run one Optuna trial end to end inside a nested MLflow run.

        Suggests parameters from ``self._search_space``, builds the trial
        config and an ``HAIPRTrainer``, runs all splits (reporting per-split
        values through ``update_intermediate_metrics``) and tells the study
        the mean optimization metric.

        Args:
            trial: Optuna trial to evaluate. It is finished here via
                ``study.tell`` (presumably it comes from ``study.ask()`` —
                confirm at the call site).

        Returns:
            ``(trial_value, mean_metrics)`` on success. ``None`` when training
            raised, produced no metrics, or lacked the optimization metric; in
            those cases the trial is told as FAIL.

        Raises:
            RuntimeError: If the trainer cannot be attached to the MLflow run.
                Any exception raised before training starts first tells the
                trial as FAIL, so it is never left RUNNING in storage.
        """
        logger.debug("Starting MLflow run for trial %d", trial.number)

        # Ensure experiment
        try:
            mlflow.set_experiment(self.cfg.mlflow.experiment_name)
        except Exception as e:
            logger.error(
                f"Failed to set experiment for trial {trial.number}: {e}")

        # update_intermediate_metrics reports through self.current_trial, so it
        # must reference the trial being run before any split finishes.
        self.current_trial = trial
        trial_trainer = None
        # Once training starts, _train_and_tell_trial owns telling the study.
        handed_to_training = False
        try:
            with mlflow.start_run(run_name=f"trial_{trial.number}", nested=True) as run:
                logger.debug("MLflow run started with ID: %s", run.info.run_id)

                # Log context
                mlflow.set_tags(
                    {
                        "trial_number": trial.number,
                        "run_type": "trial",
                        "job": self.job_num,
                        "study_name": self._study_name,
                        "model": self.cfg.model.name,
                        "embedder": (
                            self.cfg.embedder.name
                            if hasattr(self.cfg, "embedder") and self.cfg.embedder is not None
                            else None
                        ),
                        "embedder_model": (
                            self.cfg.embedder.model
                            if hasattr(self.cfg, "embedder") and self.cfg.embedder is not None
                            else None
                        ),
                    }
                )

                # Build trial-specific config and trainer
                trial_params = self._suggest_parameters(trial)
                # Only log trial-specific parameters. They exist only after
                # suggestion; trial.params is empty for a freshly asked trial.
                mlflow.log_params(
                    {f"trial_{k}": v for k, v in trial_params.items()})
                trial_cfg = self._update_trial_config(trial_params)

                trial_trainer = HAIPRTrainer(trial_cfg)
                self._attach_trial_trainer_to_run(trial_trainer, run, trial_cfg)

                handed_to_training = True
                return self._train_and_tell_trial(trial, trial_trainer)
        except Exception:
            if not handed_to_training:
                # Do not leave the trial RUNNING: _should_continue_optimization
                # would otherwise keep waiting on it.
                self._require_study().tell(
                    trial, state=optuna.trial.TrialState.FAIL)
            raise
        finally:
            # Release the trainer on every exit path, not only on success.
            trial_trainer = None
            gc.collect()  # Force garbage collection, free memory

    def _attach_trial_trainer_to_run(self, trial_trainer, run, trial_cfg) -> None:
        """Point ``trial_trainer`` and its ResultsLogger at ``run`` and log ``trial_cfg``.

        Raises:
            RuntimeError: If the parent run ID cannot be set or the config
                cannot be logged (the original error is chained).
        """
        # Ensure split runs in threads know their parent (trial) run
        try:
            trial_trainer.parent_run_id = run.info.run_id
        except Exception as e:
            raise RuntimeError("Failed to set parent run ID") from e

        # Attach nested run and log the trial config via ResultsLogger
        try:
            trial_trainer.results_logger.run = run
            trial_trainer.results_logger.set_run_id(run.info.run_id)
            trial_trainer.results_logger.log_config(trial_cfg)
        except Exception as e:
            raise RuntimeError(
                f"Failed to log trial config via ResultsLogger: {e}") from e

    def _train_and_tell_trial(self, trial, trial_trainer) -> tuple[float, dict] | None:
        """Run all splits with ``trial_trainer`` and tell Optuna the outcome.

        Returns:
            ``(trial_value, mean_metrics)`` after telling the objective value,
            or ``None`` after telling FAIL when training raises, yields no
            metrics, or lacks ``cfg.model.optimization_metric``.
        """
        try:
            # Run training and gather metrics
            mean_metrics, _ = trial_trainer.run_splits(
                split_callback=self.update_intermediate_metrics)

            if not mean_metrics:
                logger.error("Training failed for trial %d", trial.number)
                self._require_study().tell(
                    trial, state=optuna.trial.TrialState.FAIL
                )
                return None

            # Log aggregate metrics
            trial_trainer.results_logger.log_metrics(mean_metrics)

            optimization_metric = self.cfg.model.optimization_metric
            if optimization_metric not in mean_metrics:
                logger.error(
                    f"Trial {trial.number}: Optimization metric '{optimization_metric}' not found in metrics: {list(mean_metrics.keys())}"
                )
                self._require_study().tell(
                    trial, state=optuna.trial.TrialState.FAIL
                )
                return None

            trial_value = mean_metrics[optimization_metric]
            logger.info(
                f"Trial {trial.number} optimization value ({optimization_metric}): {trial_value}"
            )

            self._require_study().tell(trial, trial_value)
            logger.info(
                f"Trial {trial.number} completed and reported to Optuna")
            return trial_value, mean_metrics

        except Exception as e:
            # exc_info attaches the exception traceback (stack_info would only
            # show the logging call site).
            logger.error("Trial %d failed: %s", trial.number, e, exc_info=True)
            self._require_study().tell(trial, state=optuna.trial.TrialState.FAIL)
            return None

    def _create_study_visualizations(self, study):
        """Render Optuna plots for ``study`` and log them to the active MLflow run.

        Each figure is logged as ``<name>.png`` via ``mlflow.log_figure``. A
        failing plot is logged as an error and the remaining plots are still
        attempted. Parameter importances are restricted to parameters whose
        value in the first trial that has parameters is not a list/tuple.
        """
        # DEPRECATED: only trouble with plotly
        logger.info("Creating study visualization plots")

        first_params = next((t.params for t in study.trials if t.params), {})
        scalar_params = [
            param
            for param, sample in first_params.items()
            if not isinstance(sample, (list, tuple))
        ]

        plotters = (
            ("optimization_history", optuna.visualization.plot_optimization_history),
            ("parallel_coordinate", optuna.visualization.plot_parallel_coordinate),
            ("slice", optuna.visualization.plot_slice),
            ("timeline", optuna.visualization.plot_timeline),
        )

        for plot_name, plot_fn in plotters:
            try:
                figure = plot_fn(study)
                if figure:
                    mlflow.log_figure(figure, f"{plot_name}.png")
                    logger.debug(
                        f"Successfully created and logged {plot_name} plot")
            except Exception as e:
                logger.error(f"Failed to create {plot_name} visualization: {e}")

        if not scalar_params:
            return

        try:
            importance_fig = optuna.visualization.plot_param_importances(
                study, params=scalar_params
            )
            if importance_fig:
                mlflow.log_figure(importance_fig, "param_importances.png")
                logger.debug(
                    "Successfully created and logged param_importances plot"
                )
        except Exception as e:
            logger.error(
                f"Failed to create param_importances visualization: {e}")

    def on_optimization_complete(self, run):
        """Mark the best trial run of this study in MLflow.

        Searches ``run``'s experiment for trial runs of ``self._study_name``.
        It ignores runs whose ``metrics.<optimization_metric>`` is NaN and
        picks the best one according to ``cfg.model.optimization_direction``.
        The chosen run is tagged ``best_run=True``, and ``best_run_id`` is set
        on the currently active run.

        Args:
            run: The hyperopt MLflow run whose experiment is searched.

        Raises:
            ValueError: If no trial runs exist, the metric column is missing,
                no run has a non-NaN metric, or the direction is invalid.
        """
        logger.info("Optimization complete")

        runs_df: pd.DataFrame | Any = mlflow.search_runs(
            experiment_ids=[run.info.experiment_id],
            filter_string=f"tags.run_type = 'trial' AND tags.study_name = '{self._study_name}'",
            output_format="pandas",
        )

        if runs_df.empty:
            raise ValueError("No runs found after optimization")

        optimization_metric = self.cfg.model.optimization_metric
        metric_col = f"metrics.{optimization_metric}"
        if metric_col not in runs_df.columns:
            logger.error(
                f"Optimization metric {optimization_metric} not found in metric columns {runs_df.columns[runs_df.columns.str.contains('metric')]}")
            raise ValueError(
                f"Optimization metric '{optimization_metric}' not found in runs"
            )

        # Filter out runs with nan values for the optimization metric
        valid_runs = runs_df.dropna(subset=[metric_col])
        if valid_runs.empty:
            logger.error(
                f"No valid runs found with non-nan values for {metric_col}")
            raise ValueError(
                f"No valid runs found for optimization metric '{optimization_metric}'")

        direction = self.cfg.model.optimization_direction
        if direction == "maximize":
            best_label = valid_runs[metric_col].idxmax()
        elif direction == "minimize":
            best_label = valid_runs[metric_col].idxmin()
        else:
            raise ValueError(
                f"Invalid optimization direction: {direction}"
            )
        # idxmax/idxmin return an index *label*. After dropna the labels no
        # longer match row positions, so the row must be selected with .loc.
        best_run = valid_runs.loc[best_label]

        logger.info(f"Best run: {best_run.run_id}")
        logger.info(
            f"Optimization metric final value: {best_run[metric_col]}"
        )

        with mlflow.start_run(run_id=best_run.run_id, nested=True):
            mlflow.set_tag("best_run", True)

        # back to hyperopt context
        mlflow.set_tag("best_run_id", best_run.run_id)

    def evaluate_trial(self, trial_number: int | None = None) -> dict | None:
        """Public API: Evaluate a specific trial or best trial.

        Delegates to ``HAIPRTrainer.evaluate_trial`` on a fresh, nested
        trainer built from ``self.cfg``. All errors are logged and turned
        into ``None``.

        Args:
            trial_number: Specific trial number to evaluate. If None, evaluates best trial.

        Returns:
            Dictionary containing evaluation results with metrics and predictions, or None if failed.
        """
        try:
            if self._study is None:
                logger.error("Study not initialized. Run optimization first.")
                return None

            evaluator = HAIPRTrainer(self.cfg)
            evaluator.is_nested = True
            storage_config = self.get_storage_config()
            return evaluator.evaluate_trial(
                trial_number=trial_number,
                study_name=self._study_name,
                storage_config=storage_config,
            )
        except Exception as e:
            logger.error(f"Failed to evaluate trial: {e}", exc_info=True)
            return None

    def _run_evaluation(self, run, trial_number=None):
        """Evaluate a trial's configuration and log the result in a nested MLflow run.

        Args:
            run: Parent MLflow run; its ID becomes the evaluation trainer's
                ``parent_run_id``.
            trial_number: Trial to evaluate; ``None`` selects ``study.best_trial``.

        Returns:
            The evaluation results dict with ``trial_value`` added, or
            ``None`` if the trial cannot be found or evaluation fails. Errors
            are logged rather than raised.
        """
        logger.info("Starting evaluation with trial configuration")

        try:
            # Get trial information for logging
            study = self._require_study()
            if trial_number is None:
                trial = study.best_trial
            else:
                # Read-only lookup, so skip deep-copying every trial.
                matching_trials = [
                    t for t in study.get_trials(deepcopy=False)
                    if t.number == trial_number
                ]
                trial = matching_trials[0] if matching_trials else None

            if trial is None:
                logger.error("Could not find trial for logging comparison")
                return None

            # Run evaluation with nested MLflow run for proper organization
            with mlflow.start_run(
                run_name=f"evaluation_trial_{trial.number}", nested=True
            ) as eval_run:
                # Create trainer for evaluation with current config
                eval_trainer = HAIPRTrainer(self.cfg)
                eval_trainer.parent_run_id = run.info.run_id
                eval_trainer.is_nested = True

                # Set up the evaluation trainer's results logger with the current run
                eval_trainer.results_logger.run = eval_run
                eval_trainer.results_logger.set_run_id(eval_run.info.run_id)

                # Use trainer's evaluation method with study information
                eval_results = eval_trainer.evaluate_trial(
                    trial_number=trial_number,
                    study_name=self._study_name,
                    storage_config=self.get_storage_config(),
                )

                if eval_results is None:
                    logger.error("Evaluation failed")
                    return None

                mlflow.set_tags(
                    {
                        "run_type": "evaluation",
                        "trial_number": trial.number,
                        "trial_value": trial.value,
                        "test_split_idx": self.cfg.data.test_split_idx,
                        "evaluation_split": eval_results.get("eval_split"),
                    }
                )

                # Log trial parameters
                mlflow.log_params(
                    {f"best_{k}": v for k, v in trial.params.items()})

                # Log evaluation metrics
                if "metrics" in eval_results:
                    mlflow.log_metrics(eval_results["metrics"])

                # Log comparison between optimization and evaluation
                optimization_metric = self.cfg.model.optimization_metric
                opt_value = trial.value
                if (
                    "metrics" in eval_results
                    and optimization_metric in eval_results["metrics"]
                ):
                    eval_value = eval_results["metrics"][optimization_metric]
                    if opt_value is None:
                        # The trial has no objective value (e.g. it did not
                        # complete), so there is nothing to compare against.
                        logger.warning(
                            f"Trial {trial.number} has no optimization value; "
                            f"skipping comparison for {optimization_metric}"
                        )
                    else:
                        difference = eval_value - opt_value
                        logger.info(
                            f"Optimization {optimization_metric}: {opt_value}")
                        logger.info(
                            f"Evaluation {optimization_metric}: {eval_value}")
                        logger.info(f"Difference: {difference}")

                        # Log the comparison as metrics
                        mlflow.log_metrics(
                            {
                                f"optimization_{optimization_metric}": opt_value,
                                f"evaluation_{optimization_metric}": eval_value,
                                f"difference_{optimization_metric}": difference,
                            }
                        )

                # Add trial information to results
                eval_results["trial_value"] = trial.value

                return eval_results

        except Exception as e:
            logger.error(f"Failed to run evaluation: {e}", exc_info=True)
            return None

    def optimize(self) -> int:
        """Run a full hyperparameter-optimization session for this job.

        The session proceeds in these stages:

        1. Set up Optuna storage and initialize the study
           (``setup_storage`` / ``_initialize_study``).
        2. For a non-zero job in parallel mode, wait for job 0 and attach to
           the configured MLflow experiment, falling back to a new
           timestamp-suffixed experiment if that fails.
        3. Resolve the parent "HyperOpt" MLflow run: reuse a run tagged
           ``run_type='hyperopt'`` with this study's name, or create one.
           Non-zero parallel jobs never create it and fail instead.
        4. Ask the study for trials while ``_should_continue_optimization()``
           is true, logging each successful trial's value and mean metrics to
           the parent run.
        5. Call ``on_optimization_complete`` and, unless ``cfg.run_evaluation``
           is false, run the evaluation phase and tag its outcome.

        Returns:
            0 once the session has finished.

        Raises:
            ValueError: for a non-zero job in parallel mode when no HyperOpt
                run exists for this study yet, or when the experiment can be
                neither created nor found.
            Exception: errors from the parallel-worker experiment fallback
                (stage 2) are re-raised, as is any error while resolving the
                run for a non-zero parallel job. Job 0 / non-parallel runs
                instead fall back to creating a new HyperOpt run.
        """
        logger.info(
            "Starting hyperparameter optimization for job %d", self.job_num)
        self._n_trials_target = self.cfg.optuna.n_trials
        storage = self.setup_storage(self.cfg)
        self._initialize_study(self.cfg, storage)

        # Only job 0 (or any job when not running in parallel) may create the
        # shared parent HyperOpt run; other parallel jobs must find it.
        is_parallel_worker = self.job_num != 0 and self.cfg.parallel

        if is_parallel_worker:
            self._wait_and_set_worker_experiment()

        logger.debug("Retrieving experiment and run information")
        try:
            experiment = self._get_or_create_mlflow_experiment()

            existing_runs = mlflow.search_runs(
                experiment_ids=[experiment.experiment_id],
                filter_string=(
                    f"tags.run_type = 'hyperopt' AND tags.study_name = '{self._study_name}'"
                ),
            )

            if len(existing_runs) == 0:
                if is_parallel_worker:
                    error_msg = "HyperOpt run not found"
                    logger.error(error_msg)
                    raise ValueError(error_msg)
                logger.info(
                    "Job 0: No existing HyperOpt run found, creating a new one"
                )
                self._hyperopt_run_id = self._create_new_hyperopt_run(
                    self._default_hyperopt_run_name()
                )
                logger.info(
                    f"Job 0: Created new HyperOpt run with ID: {self._hyperopt_run_id}"
                )
            else:
                # mlflow.search_runs may return a Pandas DataFrame or list-like
                # depending on backend; access it like a DataFrame.
                runs_df = cast(Any, existing_runs)
                self._hyperopt_run_id = runs_df.iloc[0].run_id
                logger.info(
                    f"Using existing HyperOpt run: {self._hyperopt_run_id}")

        except Exception as e:
            logger.error(f"Error retrieving experiment information: {e}")
            if is_parallel_worker:
                raise
            logger.info(
                "Job 0: Error retrieving run information, creating a new HyperOpt run"
            )
            self._hyperopt_run_id = self._create_new_hyperopt_run(
                self._default_hyperopt_run_name()
            )
            logger.info(
                f"Job 0: Created new HyperOpt run with ID: {self._hyperopt_run_id}"
            )

        logger.debug("Parsing search space configuration")
        self._parse_search_space()

        logger.info("Enter HyperOpt Run context")
        with mlflow.start_run(
            run_id=self._hyperopt_run_id,
            nested=True,
            log_system_metrics=True,
        ) as run:
            logger.debug("Started MLflow run with ID: %s", run.info.run_id)

            # Point the results logger at the run that was actually started,
            # in case its ID changed in the meantime.
            self.results_logger.run = run
            self.results_logger.set_run_id(run.info.run_id)

            while self._should_continue_optimization():
                trial = self._require_study().ask()
                self.current_trial = trial
                logger.info(f"Starting trial {trial.number}")
                logger.debug("Trial parameters: %s", trial.params)

                result = self._run_trial(trial)
                if result is None:
                    continue

                trial_value, mean_metrics = result
                logger.debug(
                    "Trial %d completed with value: %f", trial.number, trial_value
                )
                mlflow.log_metrics(
                    {
                        "trial_value": trial_value,
                        **{f"trial_{k}": v for k, v in mean_metrics.items()},
                    },
                    step=trial.number,
                )

                study = self._require_study()
                n_complete = sum(
                    1
                    for t in study.trials
                    if t.state == optuna.trial.TrialState.COMPLETE
                )
                logger.debug(
                    "Completed trials: %d/%d", n_complete, self._n_trials_target
                )
                if n_complete >= self._n_trials_target:
                    logger.info("Final trial completed successfully")

            logger.info("Optimization loop completed")
            self.on_optimization_complete(run)

            # Run evaluation with best model configuration if enabled
            if getattr(self.cfg, "run_evaluation", True):
                logger.info("Starting evaluation phase")
                eval_results = self._run_evaluation(run)
                if eval_results:
                    logger.info("Evaluation completed successfully")
                    mlflow.set_tag("evaluation_completed", "True")
                    mlflow.set_tag("evaluation_metrics",
                                   str(eval_results["metrics"]))
                else:
                    logger.warning("Evaluation failed")
                    mlflow.set_tag("evaluation_completed", "False")
            else:
                logger.info("Evaluation disabled, skipping evaluation phase")
                mlflow.set_tag("evaluation_completed", "Skipped")

            mlflow.set_tag("completed", "True")
            # No explicit mlflow.end_run() here: leaving the ``with`` block
            # ends this run. Ending it inside the block as well may cause the
            # context manager's exit to end the enclosing parent run instead
            # (MLflow versions whose ActiveRun.__exit__ ends unconditionally).

        return 0

    def _wait_and_set_worker_experiment(self) -> None:
        """Attach a non-zero parallel job to the configured MLflow experiment.

        Sleeps 10 s to give job 0 time to set up MLflow, then sets the
        experiment named ``cfg.mlflow.experiment_name`` as active. If that
        fails, creates and activates ``<name>_<unix time>`` and writes the new
        name back into ``cfg.mlflow.experiment_name``.

        Raises:
            Exception: re-raised if the fallback experiment cannot be created
                or set.
        """
        logger.debug("Non-zero job: waiting for job 0 to set up MLflow")
        time.sleep(10)
        logger.debug(
            f"Setting experiment for job {self.job_num} to {self.cfg.mlflow.experiment_name}"
        )
        try:
            mlflow.set_experiment(self.cfg.mlflow.experiment_name)
            logger.debug(
                f"Successfully set experiment for job {self.job_num}")
            return
        except Exception as e:
            logger.error(
                f"Failed to set experiment for job {self.job_num}: {e}")

        new_experiment_name = (
            f"{self.cfg.mlflow.experiment_name}_{int(time.time())}"
        )
        logger.info(
            f"Creating new experiment {new_experiment_name} for job {self.job_num}"
        )
        try:
            mlflow.create_experiment(new_experiment_name)
            mlflow.set_experiment(new_experiment_name)
            self.cfg.mlflow.experiment_name = new_experiment_name
            logger.info(
                f"Successfully created and set new experiment for job {self.job_num}"
            )
        except Exception as e2:
            logger.error(f"Failed to create new experiment: {e2}")
            raise

    def _get_or_create_mlflow_experiment(self):
        """Return the MLflow experiment named ``cfg.mlflow.experiment_name``.

        Looks the experiment up first and creates it if absent. Creation can
        fail when another job created the same experiment concurrently. On
        failure, this waits 5 s and looks it up again before giving up. In
        every successful path the experiment is also made the active one.

        Raises:
            ValueError: if the experiment can be neither created nor found.
        """
        name = self.cfg.mlflow.experiment_name
        experiments = mlflow.search_experiments(
            filter_string=f"name='{name}'", max_results=1
        )
        experiment = experiments[0] if experiments else None

        if experiment is None:
            try:
                mlflow.create_experiment(name)
            except Exception as e:
                logger.warning(
                    f"Failed to create experiment: {e}; retrying lookup")
                time.sleep(5)
                experiment = mlflow.get_experiment_by_name(name)
                if experiment is None:
                    logger.error(f"Failed to create experiment: {e}")
                    raise ValueError(f"Failed to create experiment: {e}") from e
            else:
                experiment = mlflow.get_experiment_by_name(name)
                if experiment is None:
                    raise ValueError(
                        f"Experiment {name!r} not found after creation")

        mlflow.set_experiment(name)
        return experiment

    def _default_hyperopt_run_name(self) -> str:
        """Name for a newly created parent HyperOpt run.

        Uses ``cfg.mlflow.parent_run_name`` when it is set (truthy), otherwise
        ``HyperOpt_<study name>``.
        """
        parent_run_name = self.cfg.mlflow.parent_run_name
        if parent_run_name:
            return parent_run_name
        return f"HyperOpt_{self._study_name}"

    def get_optimized_params(self) -> dict:
        """Return the hyperparameters of the study's best trial.

        The study is fetched through ``_require_study`` so an uninitialized
        study yields a clear ``RuntimeError`` instead of an opaque
        ``AttributeError`` on ``None``.

        Returns:
            The ``params`` mapping of ``study.best_trial``.
        """
        return self._require_study().best_trial.params

    def cv_optimize(self) -> int:
        """Run a separate optimization for every test split index.

        For each ``split_idx`` in ``range(cfg.data.num_splits)``, this:

        - sets ``cfg.data.test_split_idx = split_idx``;
        - clears the cached Optuna study so ``optimize()`` initializes a fresh
          one;
        - runs ``optimize()``.

        Metrics and artifacts are logged per split by ``optimize()`` itself.
        A failure on one split is logged and does not stop the remaining
        splits.

        Note: ``cfg.data.test_split_idx`` is mutated in place and is left at
        the last split index when this returns.

        Returns:
            0 if every split succeeded, 1 if any split raised or returned a
            non-zero exit code.
        """
        num_splits = self.cfg.data.num_splits
        logger.info(
            f"Cross-validated optimization over {num_splits} test splits"
        )

        exit_code = 0
        for split_idx in range(num_splits):
            try:
                # Mutate cfg for current split and run a separate study per split
                self.cfg.data.test_split_idx = split_idx
                logger.info(
                    f"Starting optimization for test_split_idx={split_idx}")
                # Reset study so a fresh study is initialized for this split.
                # NOTE(review): whether it is a *distinct* study (vs. one
                # reloaded from storage) depends on ``_initialize_study`` —
                # e.g. whether the study name includes the split index.
                self._study = None
                split_exit_code = self.optimize()
            except Exception as e:
                logger.error(
                    f"Failed optimization for test_split_idx={split_idx}: {e}",
                    exc_info=True,
                )
                exit_code = 1
            else:
                # Propagate a non-zero result instead of silently ignoring it.
                if split_exit_code != 0:
                    logger.error(
                        f"Optimization for test_split_idx={split_idx} "
                        f"returned exit code {split_exit_code}"
                    )
                    exit_code = 1

        return exit_code

    def _require_study(self) -> optuna.study.Study:
        """Return the initialized Optuna study.

        Raises:
            RuntimeError: if no study has been initialized yet.
        """
        study = self._study
        if study is not None:
            return study
        raise RuntimeError("Optuna study is not initialized")


def optimize(cfg: DictConfig | ListConfig) -> int:
    """Run hyperparameter optimization with optional evaluation.

    This function runs hyperparameter optimization using Optuna. When
    ``cfg.run_cv`` is true, it optimizes once per test split via
    ``HAIPROptimizer.cv_optimize``. Otherwise it runs a single optimization
    and warns if ``cfg.data.test_split_idx`` is unset.

    After optimization completes, if ``run_evaluation`` is True in the
    config, a single training split is run with the best model configuration
    on the held-out test data.

    Args:
        cfg: Configuration object containing optimization and evaluation parameters

    Returns:
        Exit code (0 for success)
    """
    optimizer = HAIPROptimizer(cfg)
    if cfg.run_cv:
        return optimizer.cv_optimize()  # run all splits
    # ``test_split_idx == 0`` selects a valid split, so check for None
    # explicitly rather than relying on truthiness.
    if cfg.data.test_split_idx is None:
        logger.warning(
            "data.test_split_idx is not set, running optimization without Test split selection. "
            "Make sure to test models on data that is not used for optimization."
        )
    return optimizer.optimize()


@hydra.main(version_base=None, config_path="conf", config_name="optimize")
def main(cfg: DictConfig) -> int:
    """Hydra CLI entry point.

    Registers the project's custom OmegaConf resolvers, then delegates to
    the module-level ``optimize`` with the composed config.
    """
    register_resolvers()
    exit_code = optimize(cfg)
    return exit_code


if __name__ == "__main__":
    # Resolvers are registered here and again inside ``main``; presumably
    # ``register_resolvers`` tolerates repeated calls — verify.
    register_resolvers()
    main()
