#!/usr/bin/env python3

import logging
import multiprocessing
import time
from concurrent.futures import ThreadPoolExecutor
from copy import deepcopy

import hydra
import lightning.pytorch as pl
import mlflow
import numpy as np
import optuna
import torch
from lightning import seed_everything
from lightning.pytorch.callbacks import (EarlyStopping, LearningRateMonitor,
                                         ModelCheckpoint)
from omegaconf import DictConfig, OmegaConf
from sklearn.svm import SVC, SVR  # type: ignore[import]
from torch.nn import (BCEWithLogitsLoss, CrossEntropyLoss, HuberLoss, L1Loss,
                      MSELoss, NLLLoss, SmoothL1Loss)

from haipr.data import HAIPRData
from haipr.utils.resolvers import register_resolvers
from haipr.utils.results_logger import ResultsLogger

# Module logger. A fixed name (rather than __name__) keeps records under the
# "haipr.trainer" logger no matter how this module is imported.
logger = logging.getLogger("haipr.trainer")
logger.setLevel(logging.WARNING)

# Registry of supported loss functions, keyed by the short name accepted in
# ``cfg.model.loss``. Values are loss classes, instantiated when selected.
loss_functions = dict(
    mse=MSELoss,
    ce=CrossEntropyLoss,
    bce=BCEWithLogitsLoss,
    smooth_l1=SmoothL1Loss,
    huber=HuberLoss,
    l1=L1Loss,
    nll=NLLLoss,
)

# Loss names permitted per task type; the trainer treats num_classes == 0 as
# regression and num_classes > 0 as classification when validating.
REGRESSION_LOSSES = ["mse", "smooth_l1", "huber", "l1"]
CLASSIFICATION_LOSSES = ["ce", "nll", "bce"]


class HAIPRTrainer:
    """
    A unified trainer class for HAIPR models.

    This class handles the entire process of loading data, setting up models,
    training, and evaluating for different types of models (ESM, SVR, SVC, MLP).

    :param cfg: A DictConfig object containing configuration parameters from Hydra.
    """

    def __init__(self, cfg):
        """Initialize the HAIPRTrainer with given configuration.

        RNGs are seeded before anything else. ``setup_model_specific_config()``
        further down builds the data splits, optional subsampling and the model
        itself, and any of these may draw random numbers. Seeding after them
        (as previously done) made them non-reproducible for a given ``cfg.seed``.

        :param cfg: Hydra ``DictConfig`` holding the full run configuration.
        """
        self.cfg = cfg
        # Seed Python, NumPy and torch RNGs up front; seed_everything already
        # calls torch.manual_seed, so no separate torch seeding is needed.
        seed_everything(self.cfg.seed)
        # Point MLflow at the configured server before any component that may
        # interact with it is constructed.
        mlflow.set_tracking_uri(self.cfg.mlflow.tracking_uri)

        # Hyperparameter-optimisation state
        self._study = None
        self._hyperopt_run_id = None
        self._search_space = {}
        self._n_trials_target = 0
        self._study_name = None

        # Reads data-split options and resolves num_classes.
        self.setup_basic_config()
        # Starts without a run; tune() attaches the parent run later.
        self.results_logger = ResultsLogger(
            cfg=self.cfg,
            run=None,
        )
        logger.debug(f"Results logger initialized {self.results_logger}")
        self.is_neural = self.cfg.model.is_neural
        self.is_nested = False
        self.parent_run_id = None
        # Must run before model setup: setup_data() reads self.job_num and the
        # neural-network config reads self.device.
        self.setup_job_and_device()
        self.setup_model_specific_config()

    def setup_mlflow(self):
        """Setup MLflow tracking and select the configured experiment.

        Job 0 owns the experiment lifecycle. It creates the experiment when it
        does not exist, and restores it when it is in the ``deleted`` stage.
        A deleted experiment still reserves its name, so it cannot be deleted
        again or re-created under the same name. All other jobs poll until the
        experiment is active. Up to ``max_retries`` attempts are made,
        ``retry_delay`` seconds apart.

        Once an active experiment is available, it is selected with
        ``mlflow.set_experiment``. If the retries run out first, a warning is
        logged and ``mlflow.set_experiment`` is still called, instead of
        returning with no experiment selected. That call creates a missing
        experiment and raises for a deleted one.

        :raises Exception: the last MLflow error if the final attempt raised.
        """
        logger.info(f"Job {self.job_num}: Setting up MLflow tracking")
        mlflow.set_tracking_uri(self.cfg.mlflow.tracking_uri)
        experiment_name = self.cfg.mlflow.experiment_name

        max_retries = 3
        retry_delay = 5

        for attempt in range(max_retries):
            try:
                logger.debug(
                    f"Job {self.job_num}: Attempt {attempt + 1} to set up experiment {experiment_name}"
                )
                experiment = mlflow.get_experiment_by_name(experiment_name)

                if experiment is not None:
                    logger.debug(
                        f"Job {self.job_num}: Found existing experiment. Status: {experiment.lifecycle_stage}"
                    )

                    if experiment.lifecycle_stage != "deleted":
                        logger.info(
                            f"Job {self.job_num}: Using existing experiment ID: {experiment.experiment_id}"
                        )
                        mlflow.set_experiment(experiment_name)

                        return

                    logger.warning(
                        f"Job {self.job_num}: Experiment {experiment_name} is in deleted state"
                    )

                if self.job_num == 0:
                    if experiment is None:
                        logger.info(
                            f"Job 0: Creating new experiment {experiment_name}")
                        experiment_id = mlflow.create_experiment(
                            experiment_name)
                        logger.info(
                            f"Job 0: Created new experiment with ID {experiment_id}"
                        )
                    else:
                        # delete_experiment() cannot act on an already-deleted
                        # experiment and create_experiment() fails because the
                        # deleted one still owns the name, so restore it.
                        logger.info(
                            f"Job 0: Restoring deleted experiment {experiment_name}"
                        )
                        mlflow.tracking.MlflowClient().restore_experiment(
                            experiment.experiment_id
                        )
                        logger.info(
                            f"Job 0: Successfully restored experiment with ID {experiment.experiment_id}"
                        )
                    mlflow.set_experiment(experiment_name)

                    return

                action = "create" if experiment is None else "recreate"
                logger.debug(
                    f"Job {self.job_num}: Waiting for Job 0 to {action} experiment"
                )

            except Exception as e:
                logger.error(
                    f"Job {self.job_num}: Error setting up MLflow experiment: {str(e)}",
                    exc_info=True,
                )

                if attempt == max_retries - 1:
                    raise

            # No point sleeping after the final attempt.
            if attempt < max_retries - 1:
                time.sleep(retry_delay)

        logger.warning(
            f"Job {self.job_num}: Experiment {experiment_name} not ready after "
            f"{max_retries} attempts; falling back to mlflow.set_experiment"
        )
        mlflow.set_experiment(experiment_name)

    def setup_job_and_device(self):
        """Resolve the Hydra job context, then choose the compute device(s).

        Job handling must come first: device selection reads ``is_multirun``
        and ``job_num`` to rotate GPUs across multirun jobs.
        """
        for configure_step in (self._setup_job_handling, self._setup_device):
            configure_step()

    def _setup_job_handling(self):
        """Detect whether this process is a Hydra multirun job and which one.

        Sets ``self.is_multirun`` and ``self.job_num``. They default to
        single-job values (``False`` / ``0``) and keep them when Hydra is not
        initialised or its runtime config cannot be read.
        """
        self.is_multirun, self.job_num = False, 0

        try:
            from hydra.core.hydra_config import HydraConfig

            if not HydraConfig.initialized():
                return

            runtime_cfg = HydraConfig.get()
            self.is_multirun = runtime_cfg.mode.name == "MULTIRUN"

            if not self.is_multirun:
                return

            self.job_num = runtime_cfg.job.num
            # Announce the job number on the package-level logger
            logging.getLogger("haipr").info(
                f"Running job {self.job_num}")
        except Exception as e:
            logger.warning(
                f"Could not get Hydra config: {e}. Defaulting to single job mode"
            )

    def _setup_device(self):
        """Set up device configuration based on model type and config.

        Populates:

        * ``self.device``: the ``torch.device`` used for non-distributed work.
        * ``self.devices``: the value intended for the Lightning Trainer
          (``None`` on CPU, a plain list of GPU ids, or ``"auto"``).
        * ``self.use_ddp``: read from ``cfg.data.use_ddp``, and switched off
          when only one GPU is visible.

        ``cfg.devices`` may be absent/``None``, ``"rotate"``, ``"all"``, an int,
        or a sequence of GPU ids. Hydra delivers YAML lists as OmegaConf
        ``ListConfig`` objects, which are ``collections.abc.Sequence``
        instances but not ``list`` instances. Sequences are therefore detected
        via ``Sequence`` and copied into a plain ``list``. Otherwise a
        configured list would be treated as unrecognised.
        """
        from collections.abc import Sequence

        self.devices = None
        self.use_ddp = getattr(self.cfg.data, "use_ddp", False)

        if not torch.cuda.is_available():
            logger.warning("No GPU available, using CPU")
            self.device = torch.device("cpu")
            self.devices = None  # No explicit GPU list for CPU-only training

            return

        if not hasattr(self.cfg, "devices") or self.cfg.devices is None:
            logger.info("No devices specified, defaulting to device 0")
            self.devices = [0]
            self.device = torch.device("cuda:0")

            return

        devices_cfg = self.cfg.devices
        # Strings are sequences too, so exclude them explicitly.
        is_device_sequence = isinstance(devices_cfg, Sequence) and not isinstance(
            devices_cfg, (str, bytes)
        )

        if isinstance(devices_cfg, str) and devices_cfg == "rotate":
            if self.is_multirun:
                gpu_id = self.job_num % torch.cuda.device_count()
                self.devices = [gpu_id]
                logger.info(
                    f"Rotating GPU selection: using GPU {gpu_id} for job {self.job_num}"
                )
                self.device = torch.device(f"cuda:{gpu_id}")
            else:
                # Outside multirun, "rotate" means: use all GPUs
                logger.info("Not in multirun mode, using all GPUs")
                self.cfg.devices = "all"
                self.device = torch.device("cuda")  # rank 0 GPU
                self.devices = "auto"
        elif is_device_sequence:
            device_list = list(devices_cfg)

            if self.is_multirun:
                # In multirun mode, rotate through the specified devices
                device_idx = self.job_num % len(device_list)
                gpu_id = device_list[device_idx]
                self.devices = [gpu_id]
                logger.info(
                    f"Rotating through specified devices: using GPU {gpu_id} for job {self.job_num}"
                )
                self.device = torch.device(f"cuda:{gpu_id}")
            else:
                # In single run mode, use all specified devices
                self.devices = device_list
                logger.info(f"Using specified devices: {self.devices}")

                # First listed GPU serves non-distributed operations
                if len(self.devices) > 0:
                    self.device = torch.device(f"cuda:{self.devices[0]}")
                else:
                    self.device = torch.device("cpu")
        elif isinstance(devices_cfg, str) and devices_cfg == "all":
            self.devices = "auto"  # Let Lightning handle all available devices
            # Use first GPU for non-distributed ops
            self.device = torch.device("cuda")
        elif isinstance(devices_cfg, int):
            self.devices = [devices_cfg]
            self.device = torch.device(f"cuda:{devices_cfg}")
            logger.info(f"Using GPU device: {devices_cfg}")
        else:
            logger.warning(
                f"Unrecognized devices config: {devices_cfg}, defaulting to device 0"
            )
            self.devices = [0]
            self.device = torch.device("cuda:0")

        # DDP configuration
        if self.use_ddp and torch.cuda.device_count() > 1:
            logger.info(
                "DDP enabled - will use distributed training across all available GPUs"
            )

            if isinstance(self.devices, list) and len(self.devices) == 1:
                # A single device was selected: widen to all GPUs for DDP
                self.devices = "auto"
                logger.info(
                    "DDP enabled: Using all available GPUs for distributed training"
                )
            elif isinstance(self.devices, list) and len(self.devices) > 1:
                logger.info(
                    f"DDP enabled: Using specified devices {self.devices} for distributed training"
                )
        elif self.use_ddp and torch.cuda.device_count() <= 1:
            logger.warning(
                "DDP enabled but only one GPU available - falling back to single GPU training"
            )
            self.use_ddp = False

        logger.info(
            f"Device configuration for Lightning Trainer: {self.devices}")
        logger.info(f"DDP enabled: {self.use_ddp}")

    def setup_basic_config(self):
        """Set up basic configuration parameters."""
        self._best_val_loss = float("inf")

        # Data configuration
        self.split_method = self.cfg.data.split_method
        self.run_single_split = self.cfg.trainer.run_single_split
        self.num_splits = self.cfg.data.num_splits

        # Resolve num_classes parameter before any MLflow logging begins
        # This prevents conflicts when MLflow tries to log the same parameter with different values
        self._resolve_num_classes()

    def _resolve_num_classes(self):
        """Resolve the num_classes parameter based on task and model type.

        This method ensures that num_classes is properly set before any MLflow logging
        begins, preventing conflicts when MLflow tries to log the same parameter with
        different values.
        """
        # Handle task-based num_classes resolution

        if self.cfg.task == "classification" and self.cfg.num_classes == 0:
            logger.warning(
                "num_classes is 0 for classification task, defaulting to 2")
            self.cfg.num_classes = 2
        elif self.cfg.task == "regression" and self.cfg.num_classes > 0:
            logger.warning(
                "num_classes is set for regression task, defaulting to 0")
            self.cfg.num_classes = 0
        elif self.cfg.model.name == "svc" and self.cfg.num_classes == 0:
            logger.warning("num_classes is 0 for SVC, defaulting to 2")
            self.cfg.num_classes = 2
        elif self.cfg.model.name == "svr" and self.cfg.num_classes > 0:
            logger.warning("num_classes is set for SVR, defaulting to 0")
            self.cfg.num_classes = 0

        # Also update the instance variable to ensure consistency
        self.num_classes = self.cfg.num_classes

    def setup_model_specific_config(self):
        """Set up model-specific configuration."""

        if not hasattr(self.cfg, "model"):
            raise ValueError(
                "Model configuration is missing. Please check your config.yaml"
            )

        if not hasattr(self.cfg.model, "type"):
            raise ValueError("Model type is not specified in the model config")

        # num_classes is already resolved in _resolve_num_classes()
        # Just log the final value for debugging
        logger.info(
            f"Setting up {self.cfg.task} task with {self.num_classes} classes"
        )

        self.setup_predictor(self.cfg)

    def setup_neural_network_config(self):
        """Set up neural network specific configuration.

        Sets ``num_epochs``, ``patience``, ``batch_size``,
        ``accumulate_grad_batches`` and ``criterion``. Fills in a default
        ``cfg.model.loss`` when none is configured.

        Batch size resolution:

        * When ``cfg.global_batch_size`` is a positive number, it is the target
          effective batch size. If it exceeds ``cfg.trainer.batch_size``, the
          per-step batch stays at the trainer value and gradients are
          accumulated. The step count uses integer division, so the effective
          size is rounded down, with a warning, when the two do not divide
          evenly. The per-step size is written back to ``cfg.trainer.batch_size``
          and ``cfg.model.batch_size``.
        * Otherwise ``cfg.trainer.batch_size`` is used, capped at 8 when the
          selected GPU reports less than 32 GB. ``accumulate_grad_batches``
          then comes from ``cfg.trainer`` (default 1).

        :raises ValueError: if the loss name is unknown, or does not fit the
            task (``num_classes == 0`` is treated as regression).
        """
        self.num_epochs = self.cfg.trainer.max_epochs
        self.patience = self.cfg.trainer.patience

        global_batch_size = getattr(self.cfg, 'global_batch_size', None)

        if global_batch_size is not None and global_batch_size > 0:
            model_batch_size = self.cfg.trainer.batch_size

            if global_batch_size > model_batch_size:
                # Keep the per-step batch and accumulate gradients instead
                self.batch_size = model_batch_size
                self.accumulate_grad_batches = max(
                    1, global_batch_size // model_batch_size)
                effective_batch_size = self.batch_size * self.accumulate_grad_batches

                if effective_batch_size != global_batch_size:
                    logger.warning(
                        f"global_batch_size {global_batch_size} is not a multiple of "
                        f"batch size {model_batch_size}; effective batch size will be "
                        f"{effective_batch_size}"
                    )
                logger.info(
                    f"Using model batch size: {model_batch_size} with accumulate_grad_batches: "
                    f"{self.accumulate_grad_batches} to achieve effective batch size: "
                    f"{effective_batch_size}"
                )
            else:
                # Global batch fits in one step; no accumulation needed
                self.batch_size = global_batch_size
                self.accumulate_grad_batches = 1
                logger.info(f"Using global batch size: {global_batch_size}")
            self.cfg.trainer.batch_size = self.batch_size
            self.cfg.model.batch_size = self.batch_size
        else:
            configured_batch_size = self.cfg.trainer.batch_size
            self.batch_size = configured_batch_size

            if torch.cuda.is_available():
                total_memory = (
                    torch.cuda.get_device_properties(
                        self.device.index).total_memory / 1e9
                )
                logger.debug(
                    f"Total memory: {total_memory} on device {self.device}")

                if total_memory < 32:  # 32 GB
                    self.batch_size = min(configured_batch_size, 8)

                    # Warn only when the cap actually lowered the configured size
                    if self.batch_size < configured_batch_size:
                        logger.warning(
                            f"Reduced batch size to {self.batch_size} due to limited GPU memory"
                        )

            self.accumulate_grad_batches = getattr(
                self.cfg.trainer, 'accumulate_grad_batches', 1)

        # Default loss by task when not specified or null
        if not hasattr(self.cfg.model, "loss") or self.cfg.model.loss is None:
            self.cfg.model.loss = "mse" if self.num_classes == 0 else "ce"
            logger.info(
                f"Setting default loss function for {self.cfg.task} task: {self.cfg.model.loss}"
            )

        loss_name = self.cfg.model.loss.lower()

        if loss_name not in loss_functions:
            raise ValueError(
                f"Unknown loss function: {loss_name}. "
                f"Supported loss functions: {list(loss_functions.keys())}"
            )

        if self.num_classes == 0 and loss_name not in REGRESSION_LOSSES:
            raise ValueError(
                f"Loss function {loss_name} is not appropriate for regression task. "
                f"Use one of: {REGRESSION_LOSSES}"
            )
        elif self.num_classes > 0 and loss_name not in CLASSIFICATION_LOSSES:
            raise ValueError(
                f"Loss function {loss_name} is not appropriate for classification task. "
                f"Use one of: {CLASSIFICATION_LOSSES}"
            )

        self.criterion = loss_functions[loss_name]()

    def setup_data(self, cfg):
        """
        Build the dataset, prepare its features and create the splits.

        :param cfg: Configuration object handed to ``HAIPRData``.
        """
        dataset = HAIPRData(
            config=cfg,
        )
        self.data = dataset
        self.results_logger.data = dataset

        # Job 0 of a parallel run generates the features. Every other caller
        # first waits for them to be ready, then loads them from the cache
        # (a cache hit skips loading the embedder model).
        # NOTE(review): when cfg.parallel is falsy, job 0 also takes the waiting
        # path; confirm wait_for_features_ready() cannot block forever there.
        generates_features = self.job_num == 0 and self.cfg.parallel

        if not generates_features:
            self.data.wait_for_features_ready()
        self.data.prepare_features()

        # Set the configured test split aside first; otherwise split all data.
        test_split_idx = getattr(self.cfg.data, "test_split_idx", None)

        if test_split_idx is None:
            self.data.generate_splits()
        else:
            self.data.set_test_data(self.cfg.data.split_method, test_split_idx)

        # Optionally subsample the remaining training data, then re-split it.
        threshold = self.cfg.data.subsample_threshold

        if threshold > 0 and len(self.data.active_idx) > threshold:
            self.data.subsample_data(threshold)
            self.data.generate_splits()

    def tune(self, cfg, parent_run_id=None, run_name=None):
        """Run training and evaluation for all splits inside an MLflow parent run.

        :param cfg: Not read; ``self.cfg`` is used. Kept for API compatibility.
        :param parent_run_id: Not read; kept for API compatibility. The ID of
            the run opened here is stored on ``self.parent_run_id``.
        :param run_name: Optional name for the MLflow run. When omitted, a name
            is generated from model, embedder, task, benchmark and split method.
            A non-empty ``cfg.mlflow.parent_run_name`` overrides both.
        :returns: ``(mean_metrics, predictions)`` from ``run_splits()``, or
            ``None`` when training or the MLflow run fails.
        """
        self.setup_mlflow()
        mlflow.set_experiment(self.cfg.mlflow.experiment_name)
        logger.info(f"Starting training for job {self.job_num}")

        # Previously the run_name argument was always overwritten; honour it.
        if not run_name:
            name_parts = [str(self.cfg.model.name)]

            if hasattr(self.cfg, "embedder") and self.cfg.embedder is not None:
                name_parts.append(str(self.cfg.embedder.model))
            name_parts.append(
                "classification" if self.num_classes > 0 else "regression")
            name_parts.append(str(self.cfg.benchmark.name))
            name_parts.append(str(self.cfg.data.split_method))
            run_name = "_".join(name_parts)

        if self.cfg.mlflow.parent_run_name:
            run_name = self.cfg.mlflow.parent_run_name

        try:
            with mlflow.start_run(
                run_name=run_name,
                nested=self.is_nested,
                log_system_metrics=True,
            ) as parent_run:
                logger.debug(f"Starting run {run_name}")
                mlflow.set_tag("is_parent", True)
                mlflow.set_tag("job_num", self.job_num)

                # Stored so split workers can nest their runs under this one.
                self.parent_run_id = parent_run.info.run_id
                logger.debug(f"Parent run using run ID: {self.parent_run_id}")
                mlflow.log_metrics({"job_num": self.job_num})

                if self.is_multirun:
                    mlflow.set_tag("is_multirun", True)

                # Attach the results logger to the current run
                self.results_logger.run = parent_run
                self.results_logger.set_run_id(parent_run.info.run_id)

                # Experiment inputs and configuration are logged on the parent only
                self.results_logger.log_input_sample(
                    self.data.data,
                    context="HAIPRData",
                    tags={"benchmark_name": self.cfg.benchmark.name},
                )
                self.results_logger.log_config(self.cfg)

                try:
                    mean_metrics, predictions = self.run_splits()
                    # TODO: check if all metrics are finite and not nan and add tag to run

                    return mean_metrics, predictions

                except Exception as e:
                    logger.error(f"Error in training: {e}", exc_info=True)

                    if torch.cuda.is_available() and self.is_neural:
                        torch.cuda.empty_cache()

                    return None

        except Exception as e:
            # Safely get run ID if available
            active_run = (
                mlflow.active_run() or mlflow.last_active_run() or "no_active_run"
            )
            logger.error(
                f"Error in MLflow run: {active_run} {e}", exc_info=True)

            return None

    def _get_process_device(self, process_idx):
        """Get the appropriate device for a given process index.

        Args:
            process_idx: Index of the current process

        Returns:
            str: Device specification for PyTorch Lightning
        """

        if not torch.cuda.is_available() or not self.is_neural:
            return "cpu"

        num_gpus = torch.cuda.device_count()

        if num_gpus == 0:
            return "cpu"

        # Assign GPU round-robin style
        gpu_idx = process_idx % num_gpus

        return f"cuda:{gpu_idx}"

    def _run_split_parallel_worker(self, split_info):
        """Run a single split in parallel mode (non-multirun only).

        Opens a nested MLflow run named ``split_<n>``. When ``self.parent_run_id``
        is set, the parent run is first resumed in this worker so the split
        run nests under it. Otherwise the split run is opened on its own, and
        the config is logged to it because no parent run holds it.

        Args:
            split_info: tuple of (split_num, process_idx)

        Returns:
            tuple: (metrics_dict, predictions) if successful, None if failed
        """
        split_num, process_idx = split_info
        logger.info(
            f"Running Split {split_num + 1} of {self.data.get_num_splits()} on process {process_idx}"
        )
        split_run_name = f"split_{split_num + 1}"

        try:
            if self.parent_run_id is not None:
                # Ensure the parent run is active in this worker, then nest under it
                with mlflow.start_run(run_id=self.parent_run_id):
                    with mlflow.start_run(
                        run_name=split_run_name, nested=True
                    ) as split_run:
                        return self._execute_split_run(
                            split_run, split_num, process_idx,
                            parent_run_id=self.parent_run_id,
                        )

            # Fallback: open only a nested run (no explicit parent context)
            with mlflow.start_run(run_name=split_run_name, nested=True) as split_run:
                return self._execute_split_run(
                    split_run, split_num, process_idx, parent_run_id=None
                )

        except Exception as e:
            logger.error(
                f"Error in split {split_num + 1}: {e}", exc_info=True, stack_info=True)

            if torch.cuda.is_available() and self.is_neural:
                torch.cuda.empty_cache()

            return None

    def _execute_split_run(self, split_run, split_num, process_idx, parent_run_id=None):
        """Train and evaluate one split inside the already-active ``split_run``.

        Tags the run, builds a split-local ResultsLogger and, for neural models,
        a Lightning Trainer. Fits a deep copy of ``self.model`` on the split's
        train/val indices and, if enabled, logs the fitted model.

        :param parent_run_id: ID of the resumed parent run. Pass ``None`` when
            the split run has no parent; the config is then logged to the split
            run itself.
        :returns: ``(metrics, predictions)``, or ``None`` when ``fit_model``
            returned ``None``.
        """
        mlflow.set_tag("split", split_num + 1)

        if parent_run_id is not None:
            mlflow.set_tag("parent_run_id", parent_run_id)
        mlflow.set_tag("process_id", process_idx)
        mlflow.set_tag("benchmark", self.cfg.benchmark.name)
        mlflow.set_tag("model_type", self.cfg.model.type)
        mlflow.set_tag("model_name", self.cfg.model.name)

        # Dedicated results logger per split run to avoid cross-thread mutation
        # of the shared self.results_logger.
        local_results_logger = ResultsLogger(cfg=self.cfg, run=split_run)
        local_results_logger.set_run_id(split_run.info.run_id)
        local_results_logger.data = self.data
        logger.debug(
            f"Split {split_num + 1} using run ID: {split_run.info.run_id}"
        )

        current_trainer_instance = (
            self._build_split_trainer(split_run, split_num, process_idx)
            if self.is_neural
            else None
        )

        train_indices, val_indices = self.data.splits[split_num]
        logger.debug(
            f"Train size: {len(train_indices)}, Val size: {len(val_indices)}"
        )

        self._current_split = split_num

        # Each split trains its own copy of the predictor.
        model_instance_for_split = deepcopy(self.model)

        if parent_run_id is None:
            # No parent run holds the config, so record it on the split run.
            local_results_logger.log_config(self.cfg)

        split_results = model_instance_for_split.fit_model(
            self.data,
            train_indices,
            val_indices,
            current_trainer_instance,
            self.cfg,
        )

        # Check for failure before logging, so failed splits register no model.
        if split_results is None:
            logger.error(f"Failed to get results for split {split_num + 1}")

            return None

        if self.cfg.mlflow.log_models:
            try:
                # Current split drives the per-split model naming
                local_results_logger.set_current_split(split_num)
                local_results_logger.log_model(
                    model_instance_for_split,
                    model_name=f"{self.cfg.mlflow.registered_model_name}_split_{split_num + 1}",
                )
            except Exception as e:
                logger.error(
                    f"Failed to log model for split {split_num + 1}: {e}",
                    exc_info=True,
                )

        return split_results["metrics"], split_results["predictions"]

    def _build_split_trainer(self, split_run, split_num, process_idx):
        """Create a Lightning Trainer on this worker's device for ``split_run``.

        The device comes from ``_get_process_device``. The split run's ID is
        passed to ``get_trainer_kwargs``, and sanity validation steps are
        disabled.
        """
        device = self._get_process_device(process_idx)
        logger.info(f"Using device {device} for split {split_num + 1}")

        trainer_kwargs = self.get_trainer_kwargs(run_id=split_run.info.run_id)
        on_gpu = "cuda" in device
        trainer_kwargs["accelerator"] = "gpu" if on_gpu else "cpu"
        # Lightning's CPU accelerator requires an int > 0 for ``devices``;
        # ``None`` is rejected, so a single CPU device is requested explicitly.
        trainer_kwargs["devices"] = (
            [int(device.split(":")[-1])] if on_gpu else 1
        )
        trainer_kwargs["num_sanity_val_steps"] = 0

        return pl.Trainer(**trainer_kwargs)

    def setup_predictor(self, cfg):
        """Prepare data, instantiate the configured model and attach PEFT adapters.

        Steps, in order:
          1. ``self.setup_data(cfg)``.
          2. For neural models, ``self.setup_neural_network_config()``.
          3. Instantiate ``self.cfg.model`` via Hydra and call its ``setup_model``.
          4. For neural models whose config has a ``peft`` section with
             non-None ``target_modules``, inject PEFT adapters unless the model
             reports they are already initialized.

        Args:
            cfg: Configuration passed to ``setup_data``.
        """
        self.setup_data(cfg)

        if self.is_neural:
            self.setup_neural_network_config()

        model = hydra.utils.instantiate(self.cfg.model)
        self.model = model
        model.setup_model(self.data, self.cfg)

        if not self.is_neural:
            return

        peft_requested = (
            hasattr(self.cfg, "peft")
            and self.cfg.peft.target_modules is not None
        )
        if peft_requested and not self.model.peft_initialized:
            # Only inject adapters once per model instance
            self.model._initialize_peft_adapters(self.cfg.peft)

    def run_predictor(
        self, train_indices, val_indices, split_num=None, trainer_instance=None
    ):
        """Fit ``self.model`` on one train/validation split and log the results.

        Args:
            train_indices: Indices of the training samples in ``self.data``.
            val_indices: Indices of the validation samples; also forwarded to
                ``log_plots`` as ``sample_indices``.
            split_num: Zero-based split index used for model naming. When None,
                the model is logged under the bare registered model name.
            trainer_instance: Lightning ``Trainer`` for neural models; may be
                None for non-PyTorch models.

        Returns:
            dict with ``"metrics"`` and ``"predictions"`` keys, or None when
            ``fit_model`` returned no results.
        """
        results = self.model.fit_model(
            self.data,
            train_indices,
            val_indices,
            trainer_instance,  # can be None for non-PyTorch models
            self.cfg,
        )

        # 1-based label for messages; split_num defaults to None
        split_label = "n/a" if split_num is None else split_num + 1

        if results is None:
            # Return None so callers can take their "no results" path instead of
            # crashing on a subscript of None.
            logger.error(f"Model fit returned no results for split {split_label}")

            return None

        predictions = results["predictions"]
        metrics = results["metrics"]

        self.results_logger.log_plots(
            predictions["true_values"],
            predictions["predictions"],
            predictions["probabilities"] if "probabilities" in predictions else None,
            sample_indices=val_indices,
        )

        # Log model for this split if enabled (best effort: never fails the run)
        if self.cfg.mlflow.log_models:
            try:
                base_name = self.cfg.mlflow.registered_model_name
                model_name = (
                    base_name
                    if split_num is None
                    else f"{base_name}_split_{split_num + 1}"
                )
                # Set current split in results logger for proper model naming
                self.results_logger.set_current_split(split_num)
                self.results_logger.log_model(self.model, model_name=model_name)
            except Exception as e:
                logger.error(
                    f"Failed to log model for split {split_label}: {e}", exc_info=True
                )

        return {"metrics": metrics, "predictions": predictions}

    def get_trainer_kwargs(self, run_id=None):
        """Build keyword arguments for a PyTorch Lightning ``Trainer``.

        The kwargs include early stopping on ``val_loss``, optional model
        checkpointing and learning-rate monitoring, and an ``MLFlowLogger``
        bound to a specific MLflow run.

        Args:
            run_id: MLflow run ID the Lightning logger should write to. If None,
                the currently active MLflow run is used.

        Returns:
            dict: Keyword arguments suitable for ``pl.Trainer(**kwargs)``.

        Raises:
            RuntimeError: If ``run_id`` is None and no MLflow run is active.
        """
        trainer_cfg = self.cfg.trainer
        # A missing key means "disabled". The same value drives both the
        # ModelCheckpoint callback and ``enable_checkpointing`` so they agree.
        model_checkpointing = getattr(trainer_cfg, "model_checkpointing", False)

        callbacks = [
            EarlyStopping(
                monitor="val_loss",
                patience=self.patience,
                mode="min",
                min_delta=0.005,
                check_finite=True,
                strict=False,  # Don't fail if metric doesn't exist initially
                check_on_train_epoch_end=False,
            )
        ]

        if model_checkpointing:
            callbacks.append(
                ModelCheckpoint(
                    monitor="val_loss",
                    mode="min",
                    save_top_k=2,
                    save_last=True,
                )
            )

        # Add learning rate monitor if scheduler is enabled
        if hasattr(trainer_cfg, "lrscheduler") and trainer_cfg.lrscheduler.enabled:
            callbacks.append(LearningRateMonitor(logging_interval=None))

        # Each trainer gets a fresh logger for the requested (or active) run
        if run_id is None:
            active_run = mlflow.active_run()
            if active_run is None:
                raise RuntimeError(
                    "No active MLflow run; pass run_id explicitly to get_trainer_kwargs()"
                )
            run_id = active_run.info.run_id

        train_logger = pl.loggers.MLFlowLogger(
            tracking_uri=self.cfg.tracking.mlflow.uri,
            experiment_name=self.cfg.mlflow.experiment_name,
            run_id=run_id,
        )

        # An instance-level override wins; only read the config when none is set
        if hasattr(self, "accumulate_grad_batches"):
            accumulate_grad_batches = self.accumulate_grad_batches
        else:
            accumulate_grad_batches = trainer_cfg.accumulate_grad_batches

        trainer_kwargs = {
            "max_epochs": trainer_cfg.max_epochs,
            "default_root_dir": trainer_cfg.default_root_dir,
            "enable_checkpointing": model_checkpointing,
            "accelerator": "auto",
            # list of device ids [1,2,4], or all [all visible to torch] or int
            "devices": self.devices,
            "strategy": trainer_cfg.strategy,  # DDP strategy configuration
            "callbacks": callbacks,
            "enable_progress_bar": trainer_cfg.enable_progress_bar,
            "enable_model_summary": trainer_cfg.enable_model_summary,
            "gradient_clip_val": trainer_cfg.gradient_clip_val,
            "accumulate_grad_batches": accumulate_grad_batches,
            "logger": train_logger,
            "log_every_n_steps": trainer_cfg.log_every_n_steps,
            "precision": trainer_cfg.precision,
            "num_sanity_val_steps": 0,
            "fast_dev_run": trainer_cfg.fast_dev_run,
        }

        logger.debug("Trainer kwargs: %s", trainer_kwargs)

        return trainer_kwargs

    def run_split(self, split_num, run_name=None):
        """Run training and evaluation for a single split in a nested MLflow run.

        Args:
            split_num: Zero-based index of the split to run.
            run_name: Name of the nested MLflow run; defaults to
                ``"split_<split_num + 1>"``.

        Returns:
            tuple: ``(metrics_dict, predictions)`` if successful, None if failed.
        """
        logger.info(
            f"Running Split {split_num + 1} of {self.data.get_num_splits()}")
        # The results logger is pointed at the child run below. Remember the
        # enclosing run so it can be restored afterwards; otherwise anything
        # logged after this split (e.g. aggregate metrics) would presumably
        # land in the child run.
        outer_run = mlflow.active_run()
        try:
            if run_name is None:
                run_name = f"split_{split_num + 1}"
            with mlflow.start_run(run_name=run_name, nested=True):
                mlflow.set_tag("split", split_num + 1)
                mlflow.set_tag("run_type", "split")
                # Tag parent run if known and update logger to child run
                if self.parent_run_id is not None:
                    mlflow.set_tag("parent_run_id", self.parent_run_id)
                current_run_id = mlflow.active_run().info.run_id
                self.results_logger.set_run_id(current_run_id)

                return self._execute_split(split_num, 0)

        except Exception as e:
            logger.error(f"Error in split {split_num + 1}: {e}", exc_info=True)

            if torch.cuda.is_available() and self.is_neural:
                torch.cuda.empty_cache()

            return None
        finally:
            if outer_run is not None:
                self.results_logger.set_run_id(outer_run.info.run_id)

    def _execute_split(self, split_num, process_idx):
        """Train and evaluate one split inside whatever MLflow run is active.

        Kept separate from ``run_split`` so the caller owns the MLflow run
        context.

        Args:
            split_num: Index into ``self.data.splits``.
            process_idx: Worker index used to pick a device via
                ``_get_process_device``.

        Returns:
            ``(metrics, predictions)`` on success, or None when
            ``run_predictor`` returned None.
        """
        self._current_split = split_num
        device = self._get_process_device(process_idx)
        logger.info(f"Using device {device} for split {split_num + 1}")

        trainer = None
        if self.is_neural:
            # Pin the Lightning trainer to this worker's device
            on_gpu = "cuda" in device
            kwargs = self.get_trainer_kwargs()
            kwargs["accelerator"] = "gpu" if on_gpu else "cpu"
            kwargs["devices"] = [int(device.split(":")[-1])] if on_gpu else None
            trainer = pl.Trainer(**kwargs)

        train_idx, val_idx = self.data.splits[split_num]
        logger.debug(
            f"Train size: {len(train_idx)}, Val size: {len(val_idx)}")

        outcome = self.run_predictor(train_idx, val_idx, split_num, trainer)
        if outcome is None:
            logger.error(f"Failed to get results for split {split_num + 1}")

            return None

        split_metrics = outcome["metrics"]
        self.results_logger.log_metrics(split_metrics, step=None)

        return split_metrics, outcome["predictions"]

    def evaluate_trial(self, trial_number: int | None = None,
                       study_name: str | None = None,
                       storage_config: dict | None = None) -> dict | None:
        """Public API: Evaluate a specific trial or the best trial from a study.

        Args:
            trial_number: Specific trial number to evaluate. If None, evaluates
                the best trial.
            study_name: Name of the Optuna study. If None, derived from config.
            storage_config: Storage configuration dict. If None, derived from
                config.

        Returns:
            Dictionary containing evaluation results with metrics and
            predictions, or None if anything failed (errors are logged).
        """
        try:
            # Import here to avoid circular imports
            from haipr.optimize import HAIPROptimizer

            # One helper instance serves study lookup and trial-config creation
            optimizer = HAIPROptimizer(self.cfg)

            if study_name is None:
                study_name = optimizer.create_study_name(self.cfg)

            if storage_config is None:
                storage_config = optimizer.get_storage_config()

            study = self._load_study_for_evaluation(study_name, storage_config)

            trial = self._get_trial_from_study(study, trial_number)

            if trial is None:
                logger.error("No trial found to evaluate")

                return None

            logger.info(f"Evaluating trial {trial.number}")
            logger.info(f"Trial value: {trial.value}")
            logger.info(f"Trial parameters: {trial.params}")

            trial_config = optimizer.update_trial_config(self.cfg, trial.params)

            return self._evaluate_with_trial_config(trial_config, trial.number)

        except Exception as e:
            logger.error(f"Failed to evaluate trial: {e}", exc_info=True)

            return None

    def _load_study_for_evaluation(self, study_name: str, storage_config: dict) -> optuna.Study:
        """Load an existing Optuna study so one of its trials can be re-evaluated.

        Storage is set up through a temporary ``HAIPROptimizer``. It is given a
        minimal config whose ``optuna`` section carries ``storage``,
        ``grace_period`` (default 600) and ``max_retry`` (default 3), read from
        ``storage_config``.

        Args:
            study_name: Name of the study to load.
            storage_config: Mapping with a ``storage`` entry and optional
                ``grace_period`` / ``max_retry`` entries.

        Returns:
            The loaded study.

        Raises:
            Any exception from import, storage setup or loading is logged and
            re-raised unchanged.
        """
        try:
            # Imported lazily to avoid a circular import
            from haipr.optimize import HAIPROptimizer

            helper = HAIPROptimizer(self.cfg)

            optuna_section = {
                "storage": storage_config.get("storage"),
                "grace_period": storage_config.get("grace_period", 600),
                "max_retry": storage_config.get("max_retry", 3),
            }
            storage = helper.setup_storage(
                OmegaConf.create({"optuna": optuna_section}))
            study = helper.load_study(study_name, storage)
        except Exception as e:
            logger.error(
                f"Failed to load study for evaluation: {e}", exc_info=True)
            raise
        else:
            logger.info(
                f"Successfully loaded study '{study_name}' for evaluation")

            return study

    def _get_trial_from_study(
        self, study: optuna.Study, trial_number: int | None
    ) -> optuna.trial.FrozenTrial | None:
        """Return a specific trial, or the best completed trial, from ``study``.

        Args:
            study: Loaded Optuna study.
            trial_number: Trial number to fetch; None selects ``study.best_trial``.

        Returns:
            The matching frozen trial, or None if it does not exist, no trial
            has completed (when asking for the best one), or lookup failed.
            All of these cases are logged.
        """
        try:
            if trial_number is None:
                # best_trial raises ValueError when nothing has completed, so
                # check explicitly to give a clear message.
                completed = study.get_trials(
                    deepcopy=False, states=(optuna.trial.TrialState.COMPLETE,)
                )
                if not completed:
                    logger.error("No completed trials found in study")

                    return None

                trial = study.best_trial
                logger.info(f"Using best trial: {trial.number}")
            else:
                # deepcopy=False: read-only lookup, no need to copy every trial
                trial = next(
                    (
                        t
                        for t in study.get_trials(deepcopy=False)
                        if t.number == trial_number
                    ),
                    None,
                )

                if trial is None:
                    logger.error(f"Trial {trial_number} not found in study")

                    return None

                logger.info(f"Using specified trial: {trial.number}")

            return trial

        except Exception as e:
            logger.error(f"Failed to get trial from study: {e}", exc_info=True)

            return None

    def _evaluate_with_trial_config(self, trial_config: DictConfig, trial_number: int) -> dict | None:
        """Re-run a single split with a trial's configuration and return its results.

        A fresh ``HAIPRTrainer`` is built from a copy of ``trial_config`` with
        model logging on, no held-out test split and no subsampling. It then
        executes the split given by ``self.cfg.data.test_split_idx`` (0 if unset).

        Args:
            trial_config: Full configuration with the trial's hyperparameters.
            trial_number: Optuna trial number, echoed back in the result.

        Returns:
            Dict with ``metrics``, ``predictions``, ``trial_number`` and
            ``eval_split`` keys, or None if evaluation failed (logged).
        """
        try:
            test_split_idx = self.cfg.data.test_split_idx
            # The split evaluated below; falls back to split 0
            split_idx = test_split_idx if test_split_idx is not None else 0

            # Independent, unresolved copy so trial_config itself is untouched
            eval_cfg = OmegaConf.create(
                OmegaConf.to_container(trial_config, resolve=False))
            eval_cfg.trainer.run_single_split = test_split_idx
            eval_cfg.mlflow.log_models = True
            eval_cfg.data.test_split_idx = None  # use full data for evaluation
            eval_cfg.data.subsample_threshold = 0  # no subsampling for evaluation

            eval_trainer = HAIPRTrainer(eval_cfg)
            eval_trainer.is_nested = True
            eval_trainer.parent_run_id = getattr(self, 'parent_run_id', None)
            eval_trainer.data.generate_splits()

            # Reuse _execute_split for consistency with regular training
            result = eval_trainer._execute_split(split_idx, 0)

            # _execute_split returns None on failure; don't unpack it blindly
            if result is None:
                logger.error(
                    "Evaluation failed - split execution returned no results")

                return None

            metrics, predictions = result

            if not metrics:
                logger.error("Evaluation failed - no metrics returned")

                return None

            logger.info("Evaluation completed successfully")
            logger.info(f"Evaluation metrics: {metrics}")

            return {
                "metrics": metrics,
                "predictions": predictions,
                "trial_number": trial_number,
                "eval_split": split_idx,
            }

        except Exception as e:
            logger.error(
                f"Failed to execute trial evaluation: {e}", exc_info=True)

            return None

    def run_splits(self) -> tuple[dict, list]:
        """Train and evaluate the data splits, then aggregate their metrics.

        Splits run sequentially via ``run_split`` in multirun mode or when
        ``cfg.run_sequential`` is set. Otherwise they are dispatched to a thread
        pool via ``_run_split_parallel_worker``, one worker per GPU (neural
        models with CUDA) or per CPU, capped at the number of splits. When
        ``self.run_single_split`` is set, only that split is run. Failed splits
        (None results) are skipped.

        Returns:
            ``(mean_metrics, all_predictions)``. ``mean_metrics`` holds
            ``mean_<name>`` and ``<name>_std`` for every metric of the first
            successful split. On error, ``({}, [])`` is returned.
        """
        mean_metrics = {}
        all_metrics = []
        all_predictions = []

        try:
            num_splits = self.data.get_num_splits()
            run_sequential = self.is_multirun or (
                hasattr(self.cfg, "run_sequential") and self.cfg.run_sequential
            )

            if run_sequential:
                # Original sequential behavior for multirun, also for slurm based parallel
                logger.info("Running splits sequentially (multirun mode)")

                results = []
                for split_num in range(num_splits):
                    self._current_split = split_num
                    results.append(self.run_split(split_num))
            else:
                logger.info("Running splits in parallel (single run mode)")

                if self.is_neural and torch.cuda.is_available():
                    num_processes = min(torch.cuda.device_count(), num_splits)
                    logger.info(
                        f"Using {num_processes} processes with GPU support")
                else:
                    num_processes = min(multiprocessing.cpu_count(), num_splits)
                    logger.info(f"Using {num_processes} CPU processes")

                single_split = self.run_single_split
                # Split index 0 is valid, so only None/False disable single-split mode
                if single_split is not None and single_split is not False:
                    split_process_pairs = [(int(single_split), 0)]
                    num_processes = 1
                else:
                    # ThreadPoolExecutor rejects max_workers=0 (e.g. zero splits)
                    num_processes = max(num_processes, 1)
                    # Round-robin assignment of splits to worker/device indices
                    split_process_pairs = [
                        (split_num, split_num % num_processes)
                        for split_num in range(num_splits)
                    ]

                with ThreadPoolExecutor(max_workers=num_processes) as executor:
                    results = list(
                        executor.map(
                            self._run_split_parallel_worker, split_process_pairs
                        )
                    )

            for result in results:
                if result is not None:
                    metrics, predictions = result
                    all_metrics.append(metrics)
                    all_predictions.append(predictions)

            # Calculate mean/std of each metric across successful splits
            if all_metrics:
                for metric in all_metrics[0]:
                    values = [m[metric] for m in all_metrics]
                    mean_metrics[f"mean_{metric}"] = float(np.mean(values))
                    mean_metrics[f"{metric}_std"] = float(np.std(values))
                    logger.debug(
                        f"Mean {metric}: {mean_metrics[f'mean_{metric}']}")

                # Log final metrics and predictions
                self.results_logger.log_metrics(mean_metrics, step=None)
                self.results_logger.log_run_metrics(
                    all_metrics, all_predictions, self.data
                )

            return mean_metrics, all_predictions

        except Exception as e:
            logger.error(f"Error in run_splits: {e}", exc_info=True)

            if torch.cuda.is_available() and self.is_neural:
                torch.cuda.empty_cache()

            return {}, []


def run(cfg):
    """Main entry point for the HAIPR trainer.

    Builds a ``HAIPRTrainer`` and prepares data/features. Unless
    ``cfg.prepare_features`` is set, it then runs ``trainer.tune`` and returns
    the configured optimization metric.

    Args:
        cfg: Hydra configuration.

    Returns:
        0 when only preparing features; ``float("inf")`` when tuning produced
        no metrics; otherwise ``float`` of ``cfg.model.optimization_metric``.

    Raises:
        KeyError: If the optimization metric is missing from non-empty metrics.
    """
    trainer = HAIPRTrainer(cfg)
    trainer.setup_data(cfg)  # prepares features
    torch.set_float32_matmul_precision("medium")

    if hasattr(cfg, "prepare_features") and cfg.prepare_features:
        logger.info("Features prepared and cached")

        return 0  # we are done here, setup_data already prepared features

    mean_metrics, _ = trainer.tune(cfg)

    # Treat None and an empty dict alike: run_splits returns {} on error, and
    # an empty result has no optimization metric to report.
    if not mean_metrics:
        logger.error("Training failed")

        return float("inf")

    optimization_metric = cfg.model.optimization_metric

    if optimization_metric not in mean_metrics:
        logger.error(
            f"Optimization metric '{optimization_metric}' not found in available metrics: {list(mean_metrics.keys())}"
        )
        raise KeyError(
            f"Optimization metric '{optimization_metric}' not found")

    objective = mean_metrics[optimization_metric]
    logger.info(
        f"Final optimization value ({optimization_metric}): {objective}")

    # Return value is the objective for Hydra's builtin Optuna sweeper (legacy)
    return float(objective)


@hydra.main(version_base=None, config_path="conf", config_name="train")
def main(cfg: DictConfig) -> float:
    """Hydra entry point: register resolvers, then delegate to ``run``.

    Args:
        cfg: Configuration composed by Hydra from ``conf/train``.

    Returns:
        Whatever ``run(cfg)`` returns (the optimization objective, or 0 when
        only preparing features).
    """
    register_resolvers()
    objective = run(cfg)
    return objective


if __name__ == "__main__":
    # Register custom OmegaConf resolvers before main() hands control to Hydra.
    # main() calls register_resolvers() again.
    # NOTE(review): confirm register_resolvers() tolerates being called twice.
    register_resolvers()
    main()
