#!/usr/bin/env python3
# author: Jannis de Riz

import logging
import multiprocessing
import time
from concurrent.futures import ThreadPoolExecutor
from copy import deepcopy

import optuna
import hydra
import lightning.pytorch as pl
from lightning.pytorch.callbacks import (
    EarlyStopping,
    LearningRateMonitor,
    ModelCheckpoint,
)
import mlflow
import numpy as np
import torch
from lightning import seed_everything
from omegaconf import DictConfig, OmegaConf
from torch.nn import (
    BCEWithLogitsLoss,
    CrossEntropyLoss,
    HuberLoss,
    L1Loss,
    MSELoss,
    NLLLoss,
    SmoothL1Loss,
)

from haipr.data import HAIPRData
from haipr.utils.results_logger import ResultsLogger
from haipr.utils.resolvers import register_resolvers

# Module logger lives under the "haipr" namespace; only warnings and above
# are emitted by default.
logger = logging.getLogger("haipr.trainer")
logger.setLevel(logging.WARNING)

# Registry mapping the config-level loss key to the torch loss class.
# Insertion order is kept stable because the keys are listed in error messages.
loss_functions = dict(
    mse=MSELoss,
    ce=CrossEntropyLoss,
    bce=BCEWithLogitsLoss,
    smooth_l1=SmoothL1Loss,
    huber=HuberLoss,
    l1=L1Loss,
    nll=NLLLoss,
)

# Loss keys accepted for each task type
REGRESSION_LOSSES = ["mse", "smooth_l1", "huber", "l1"]
CLASSIFICATION_LOSSES = ["ce", "nll", "bce"]


class HAIPRTrainer:
    """
    A unified trainer class for HAIPR models.

    This class handles the entire process of loading data, setting up models,
    training, and evaluating for different types of models (ESM, SVR, SVC, MLP).

    :param cfg: A DictConfig object containing configuration parameters from Hydra.
    """

    def __init__(self, cfg):
        """Initialize the HAIPRTrainer with given configuration.

        :param cfg: Hydra/OmegaConf configuration. This method reads
            ``cfg.seed``, ``cfg.model.is_neural`` and
            ``cfg.mlflow.tracking_uri`` directly; the setup helpers it calls
            read further keys.
        """
        self.cfg = cfg

        # Seed all RNGs before any setup runs. The setup calls below generate
        # the data splits and instantiate the model, so seeding only at the
        # end of __init__ left those steps outside the seeded region.
        torch.manual_seed(self.cfg.seed)
        seed_everything(self.cfg.seed)

        # Add optimization related attributes
        self._study = None
        self._hyperopt_run_id = None
        self._search_space = {}
        self._n_trials_target = 0
        self._study_name = None
        # num_classes will be resolved in _resolve_num_classes()

        self.setup_basic_config()
        # The results logger starts without a run; the run is attached once
        # an MLflow run context is available.
        self.results_logger = ResultsLogger(
            cfg=self.cfg,
            run=None,
        )
        logger.debug(f"Results logger initialized {self.results_logger}")
        # Model type must be known before device and model setup
        self.is_neural = self.cfg.model.is_neural
        self.is_nested = False
        self.parent_run_id = None
        self.setup_job_and_device()
        self.setup_model_specific_config()

        mlflow.set_tracking_uri(self.cfg.mlflow.tracking_uri)

    def setup_mlflow(self):
        """Point MLflow at the configured tracking URI and select the experiment.

        Job 0 is the only job that creates or restores the experiment; every
        other job sleeps and retries until the experiment is usable. Up to
        three attempts are made, five seconds apart.

        :raises Exception: re-raises the last MLflow error once all attempts
            have failed with an exception.
        """
        logger.info(f"Job {self.job_num}: Setting up MLflow tracking")
        mlflow.set_tracking_uri(self.cfg.mlflow.tracking_uri)
        experiment_name = self.cfg.mlflow.experiment_name

        max_retries = 3
        retry_delay = 5

        for attempt in range(max_retries):
            try:
                logger.debug(
                    f"Job {self.job_num}: Attempt {attempt + 1} to set up experiment {experiment_name}"
                )
                experiment = mlflow.get_experiment_by_name(experiment_name)

                if experiment is None:
                    if self.job_num != 0:
                        logger.debug(
                            f"Job {self.job_num}: Waiting for Job 0 to create experiment"
                        )
                        time.sleep(retry_delay)

                        continue

                    logger.info(
                        f"Job 0: Creating new experiment {experiment_name}"
                    )
                    experiment_id = mlflow.create_experiment(experiment_name)
                    logger.info(
                        f"Job 0: Created new experiment with ID {experiment_id}"
                    )
                    mlflow.set_experiment(experiment_name)

                    break

                logger.debug(
                    f"Job {self.job_num}: Found existing experiment. Status: {experiment.lifecycle_stage}"
                )

                if experiment.lifecycle_stage != "deleted":
                    logger.info(
                        f"Job {self.job_num}: Using existing experiment ID: {experiment.experiment_id}"
                    )
                    mlflow.set_experiment(experiment_name)

                    break

                logger.warning(
                    f"Job {self.job_num}: Experiment {experiment_name} is in deleted state"
                )

                if self.job_num != 0:
                    logger.debug(
                        f"Job {self.job_num}: Waiting for Job 0 to restore experiment"
                    )
                    time.sleep(retry_delay)

                    continue

                # A soft-deleted experiment still owns its name, so it cannot
                # be deleted again or re-created under that name. Restoring
                # it is the supported way to make the name usable again.
                logger.info(
                    f"Job 0: Restoring deleted experiment {experiment_name}"
                )
                mlflow.tracking.MlflowClient().restore_experiment(
                    experiment.experiment_id
                )
                mlflow.set_experiment(experiment_name)
                logger.info(
                    f"Job 0: Successfully restored experiment with ID {experiment.experiment_id}"
                )

                break

            except Exception as e:
                logger.error(
                    f"Job {self.job_num}: Error setting up MLflow experiment: {str(e)}",
                    exc_info=True,
                )

                if attempt < max_retries - 1:
                    time.sleep(retry_delay)
                else:
                    raise
        else:
            # Reached only when every attempt ended in a "wait for Job 0"
            # retry; surface that instead of returning silently.
            logger.warning(
                f"Job {self.job_num}: Experiment {experiment_name} was not ready after {max_retries} attempts"
            )

    def setup_job_and_device(self):
        """Set up job handling and device (GPU/CPU) for training."""
        self._setup_job_handling()
        self._setup_device()

    def _setup_job_handling(self):
        """Detect Hydra multirun mode and record this job's number.

        Sets ``self.is_multirun`` and ``self.job_num``. Both keep their
        single-job defaults (``False`` / ``0``) when Hydra is not initialized
        or when querying it fails.
        """
        self.is_multirun, self.job_num = False, 0

        try:
            from hydra.core.hydra_config import HydraConfig

            if not HydraConfig.initialized():
                return

            runtime_cfg = HydraConfig.get()
            self.is_multirun = runtime_cfg.mode.name == "MULTIRUN"

            if not self.is_multirun:
                return

            self.job_num = runtime_cfg.job.num
            # Announce the job number on the package-level logger
            logging.getLogger("haipr").info(f"Running job {self.job_num}")
        except Exception as err:
            logger.warning(
                f"Could not get Hydra config: {err}. Defaulting to single job mode"
            )

    def _setup_device(self):
        """Resolve ``self.device`` and ``self.devices`` from config and hardware.

        ``self.device`` is the ``torch.device`` used for non-distributed work;
        ``self.devices`` is the value logged as the Lightning Trainer device
        configuration (``None``, a list of GPU ids, or ``"auto"``).
        ``self.use_ddp`` is read from ``cfg.data.use_ddp`` and switched off
        when at most one GPU is visible.

        Accepted ``cfg.devices`` values: missing/``None`` (GPU 0),
        ``"rotate"``, ``"all"``, an int, or a sequence of GPU ids. Anything
        else falls back to GPU 0 with a warning.
        """
        self.devices = None

        # Check if DDP is enabled
        self.use_ddp = getattr(self.cfg.data, "use_ddp", False)

        if not torch.cuda.is_available():
            logger.warning("No GPU available, using CPU")
            self.device = torch.device("cpu")
            self.devices = None  # Set to None for CPU-only training

            return

        if not hasattr(self.cfg, "devices") or self.cfg.devices is None:
            logger.info("No devices specified, defaulting to device 0")
            self.devices = [0]
            self.device = torch.device("cuda:0")

            return

        requested = self.cfg.devices

        # A sequence coming from a Hydra/OmegaConf config is a ListConfig,
        # which is not a `list` subclass. Normalise it (and tuples) to a plain
        # list so the list branch below is actually reached.
        if OmegaConf.is_list(requested) or isinstance(requested, tuple):
            requested = list(requested)

        # "rotate": one GPU per multirun job, chosen round-robin
        if isinstance(requested, str) and requested == "rotate":
            if self.is_multirun:
                gpu_id = self.job_num % torch.cuda.device_count()
                self.devices = [gpu_id]
                logger.info(
                    f"Rotating GPU selection: using GPU {gpu_id} for job {self.job_num}"
                )
                self.device = torch.device(f"cuda:{gpu_id}")
            else:
                # If not in multirun, treat "rotate" as use all GPUs
                logger.info("Not in multirun mode, using all GPUs")
                self.cfg.devices = "all"
                self.device = torch.device("cuda")  # rank 0 GPU
                self.devices = "auto"
        # Explicit list of GPU ids
        elif isinstance(requested, list):
            if self.is_multirun:
                # In multirun mode, rotate through the specified devices
                device_idx = self.job_num % len(requested)
                gpu_id = requested[device_idx]
                self.devices = [gpu_id]
                logger.info(
                    f"Rotating through specified devices: using GPU {gpu_id} for job {self.job_num}"
                )
                self.device = torch.device(f"cuda:{gpu_id}")
            else:
                # In single run mode, use all specified devices
                self.devices = requested
                logger.info(f"Using specified devices: {self.devices}")

                # First listed GPU serves non-distributed operations
                if len(self.devices) > 0:
                    self.device = torch.device(f"cuda:{self.devices[0]}")
                else:
                    self.device = torch.device("cpu")
        # "all": let Lightning pick every available device
        elif isinstance(requested, str) and requested == "all":
            self.devices = "auto"
            self.device = torch.device("cuda")
        # Single integer GPU id
        elif isinstance(requested, int):
            self.devices = [requested]
            self.device = torch.device(f"cuda:{requested}")
            logger.info(f"Using GPU device: {requested}")
        # Unrecognized config: fall back to GPU 0
        else:
            logger.warning(
                f"Unrecognized devices config: {self.cfg.devices}, defaulting to device 0"
            )
            self.devices = [0]
            self.device = torch.device("cuda:0")

        # Handle DDP configuration
        if self.use_ddp and torch.cuda.device_count() > 1:
            logger.info(
                "DDP enabled - will use distributed training across all available GPUs"
            )

            if isinstance(self.devices, list) and len(self.devices) == 1:
                # A single configured device is widened to all GPUs for DDP
                self.devices = "auto"
                logger.info(
                    "DDP enabled: Using all available GPUs for distributed training"
                )
            elif isinstance(self.devices, list) and len(self.devices) > 1:
                # Multiple configured devices are used as given
                logger.info(
                    f"DDP enabled: Using specified devices {self.devices} for distributed training"
                )
        elif self.use_ddp and torch.cuda.device_count() <= 1:
            logger.warning(
                "DDP enabled but only one GPU available - falling back to single GPU training"
            )
            self.use_ddp = False

        logger.info(
            f"Device configuration for Lightning Trainer: {self.devices}"
        )
        logger.info(f"DDP enabled: {self.use_ddp}")

    def setup_basic_config(self):
        """Set up basic configuration parameters."""
        self._best_val_loss = float("inf")

        # Data configuration
        self.split_method = self.cfg.data.split_method
        self.run_single_split = self.cfg.trainer.run_single_split
        self.num_splits = self.cfg.data.num_splits

        # Resolve num_classes parameter before any MLflow logging begins
        # This prevents conflicts when MLflow tries to log the same parameter with different values
        self._resolve_num_classes()

    def _resolve_num_classes(self):
        """Resolve the num_classes parameter based on task and model type.

        This method ensures that num_classes is properly set before any MLflow logging
        begins, preventing conflicts when MLflow tries to log the same parameter with
        different values.
        """
        # Handle task-based num_classes resolution

        if self.cfg.task == "classification" and self.cfg.num_classes == 0:
            logger.warning(
                "num_classes is 0 for classification task, defaulting to 2"
            )
            self.cfg.num_classes = 2
        elif self.cfg.task == "regression" and self.cfg.num_classes > 0:
            logger.warning(
                "num_classes is set for regression task, defaulting to 0"
            )
            self.cfg.num_classes = 0
        elif self.cfg.model.name == "svc" and self.cfg.num_classes == 0:
            logger.warning("num_classes is 0 for SVC, defaulting to 2")
            self.cfg.num_classes = 2
        elif self.cfg.model.name == "svr" and self.cfg.num_classes > 0:
            logger.warning("num_classes is set for SVR, defaulting to 0")
            self.cfg.num_classes = 0

        # Also update the instance variable to ensure consistency
        self.num_classes = self.cfg.num_classes

    def setup_model_specific_config(self):
        """Set up model-specific configuration."""

        if not hasattr(self.cfg, "model"):
            raise ValueError(
                "Model configuration is missing. Please check your config.yaml"
            )

        self.setup_predictor(self.cfg)

    def setup_neural_network_config(self):
        """Derive batch size, gradient accumulation and the loss criterion.

        Sets ``self.num_epochs``, ``self.patience``, ``self.batch_size``,
        ``self.accumulate_grad_batches`` and ``self.criterion``.

        Batch size is resolved in one of two ways:

        * ``cfg.global_batch_size`` set and positive: the per-step batch size
          is the smaller of the global and the trainer batch size, and
          gradient accumulation makes up the difference. The chosen batch
          size is written back to ``cfg.trainer`` and ``cfg.model``.
        * otherwise: the trainer batch size is used, capped at 8 when the
          selected GPU reports less than 32 GB of memory.

        :raises ValueError: if the configured loss is unknown or does not
            match the task implied by ``self.num_classes``.
        """
        # Training configuration
        self.num_epochs = self.cfg.trainer.max_epochs
        self.patience = self.cfg.trainer.patience

        global_batch_size = getattr(self.cfg, "global_batch_size", None)

        if global_batch_size is not None and global_batch_size > 0:
            model_batch_size = self.cfg.trainer.batch_size

            if global_batch_size > model_batch_size:
                # Keep the model batch size and accumulate gradients instead
                self.batch_size = model_batch_size
                self.accumulate_grad_batches = max(
                    1, global_batch_size // model_batch_size
                )
                # Integer division can undershoot the requested size, so
                # report the size that is actually achieved.
                effective_batch_size = (
                    self.batch_size * self.accumulate_grad_batches
                )
                logger.info(
                    f"Using model batch size: {model_batch_size} with accumulate_grad_batches: {self.accumulate_grad_batches} to achieve effective batch size: {effective_batch_size} (requested global batch size: {global_batch_size})"
                )
            else:
                # Global batch size fits in one step; no accumulation needed
                self.batch_size = global_batch_size
                self.accumulate_grad_batches = 1
                logger.info(f"Using global batch size: {global_batch_size}")
            self.cfg.trainer.batch_size = self.batch_size
            self.cfg.model.batch_size = self.batch_size
        else:
            requested_batch_size = self.cfg.trainer.batch_size

            if torch.cuda.is_available():
                total_memory = (
                    torch.cuda.get_device_properties(
                        self.device.index
                    ).total_memory
                    / 1e9
                )
                logger.debug(
                    f"Total memory: {total_memory} on device {self.device}"
                )

                if total_memory < 32:  # 32 GB
                    self.batch_size = min(requested_batch_size, 8)

                    # Warn only when the cap actually lowered the batch size
                    if self.batch_size < requested_batch_size:
                        logger.warning(
                            f"Reduced batch size to {self.batch_size} due to limited GPU memory"
                        )
                else:
                    self.batch_size = requested_batch_size
            else:
                self.batch_size = requested_batch_size

            # Use default accumulate_grad_batches from config
            self.accumulate_grad_batches = getattr(
                self.cfg.trainer, "accumulate_grad_batches", 1
            )

        # Default the loss from the task when none is configured
        if not hasattr(self.cfg.model, "loss") or self.cfg.model.loss is None:
            self.cfg.model.loss = "mse" if self.num_classes == 0 else "ce"
            logger.info(
                f"Setting default loss function for {self.cfg.task} task: {self.cfg.model.loss}"
            )

        loss_name = self.cfg.model.loss.lower()

        if loss_name not in loss_functions:
            raise ValueError(
                f"Unknown loss function: {loss_name}. "
                f"Supported loss functions: {list(loss_functions.keys())}"
            )

        # num_classes == 0 means regression, anything positive classification
        if self.num_classes == 0 and loss_name not in REGRESSION_LOSSES:
            raise ValueError(
                f"Loss function {loss_name} is not appropriate for regression task. "
                f"Use one of: {REGRESSION_LOSSES}"
            )
        elif self.num_classes > 0 and loss_name not in CLASSIFICATION_LOSSES:
            raise ValueError(
                f"Loss function {loss_name} is not appropriate for classification task. "
                f"Use one of: {CLASSIFICATION_LOSSES}"
            )

        # Instantiate the criterion with its default arguments
        self.criterion = loss_functions[loss_name]()

    def setup_data(self, cfg):
        """
        Build the dataset, prepare its features and generate the splits.

        The dataset is stored on ``self.data`` and shared with the results
        logger.

        :param cfg: Configuration object passed to ``HAIPRData``.
        """
        dataset = HAIPRData(config=cfg)
        self.data = dataset
        self.results_logger.data = dataset

        # Job 0 of a parallel run calls prepare_features() straight away.
        # Every other case waits for the features-ready signal first and then
        # calls prepare_features() as well.
        # NOTE(review): a non-parallel job 0 also takes the waiting path;
        # confirm wait_for_features_ready() returns promptly in that case.
        builds_features = self.job_num == 0 and self.cfg.parallel

        if not builds_features:
            dataset.wait_for_features_ready()
        dataset.prepare_features()

        data_cfg = self.cfg.data
        has_test_split = (
            hasattr(data_cfg, "test_split_idx")
            and data_cfg.test_split_idx is not None
        )

        if has_test_split:
            # Set the configured test split aside first
            dataset.set_test_data(
                data_cfg.split_method, data_cfg.test_split_idx
            )
        else:
            # No test split configured: generate splits on the full data
            dataset.generate_splits()

        # Subsample when the active set exceeds the threshold, then
        # regenerate the splits on the subsampled data
        threshold = data_cfg.subsample_threshold

        if threshold > 0 and len(dataset.active_idx) > threshold:
            dataset.subsample_data(threshold)
            dataset.generate_splits()

    def tune(self, cfg, parent_run_id=None, run_name=None) -> tuple[dict, list] | None:
        """Run training and evaluation inside a parent MLflow run.

        The run name is built from model, embedder (if any), task, benchmark
        and split method, unless ``cfg.mlflow.parent_run_name`` is set, in
        which case that name is used.

        :param cfg: Not read by this method; the trainer works from ``self.cfg``.
        :param parent_run_id: Not read; replaced by the ID of the run started
            here, which is also stored on ``self.parent_run_id``.
        :param run_name: Not read; the name is always rebuilt from the config.
        :return: ``(mean_metrics, predictions)`` from ``run_splits()``, or
            ``None`` when training or the MLflow run fails.
        """
        self.setup_mlflow()  # setup mlflow
        mlflow.set_experiment(self.cfg.mlflow.experiment_name)
        logger.info(f"Starting training for job {self.job_num}")

        run_name = f"{self.cfg.model.name}_"

        if hasattr(self.cfg, "embedder") and self.cfg.embedder is not None:
            run_name += f"{self.cfg.embedder.model}_"
        run_name += (
            f"{'classification' if self.num_classes > 0 else 'regression'}_"
        )
        run_name += f"{self.cfg.benchmark.name}_"
        run_name += f"{self.cfg.data.split_method}"

        if self.cfg.mlflow.parent_run_name:
            run_name = self.cfg.mlflow.parent_run_name
        try:
            with mlflow.start_run(
                run_name=run_name,
                nested=self.is_nested,
                log_system_metrics=True,
            ) as parent_run:
                logger.debug(f"Starting run {run_name}")
                mlflow.set_tag("is_parent", True)
                mlflow.set_tag("job_num", self.job_num)

                # Remember this run so split runs can attach to it
                parent_run_id = mlflow.active_run().info.run_id
                self.parent_run_id = parent_run_id
                logger.debug(f"Parent run using run ID: {parent_run_id}")
                mlflow.log_metrics({"job_num": self.job_num})

                if self.is_multirun:
                    mlflow.set_tag("is_multirun", True)

                # Update results logger with current run
                self.results_logger.run = parent_run
                self.results_logger.set_run_id(parent_run.info.run_id)

                # Log an input sample and the configuration
                self.results_logger.log_input_sample(
                    self.data.data,
                    context="HAIPRData",
                    tags={"benchmark_name": self.cfg.benchmark.name},
                )

                self.results_logger.log_config(self.cfg)
                try:
                    mean_metrics, predictions = self.run_splits()
                    # TODO: check if all metrics are finite and not nan and add tag to run

                    return mean_metrics, predictions

                except Exception as e:
                    logger.error(f"Error in training: {e}", exc_info=True)

                    if torch.cuda.is_available() and self.is_neural:
                        torch.cuda.empty_cache()
                    mlflow.set_tag("mlflow.runStatus", "FAILED")
                    # end_run() defaults to FINISHED; pass the status so the
                    # run is recorded as failed.
                    mlflow.end_run(status="FAILED")
                    return None

        except Exception as e:
            # Report the run ID when a run is known, not the Run object
            known_run = mlflow.active_run() or mlflow.last_active_run()
            run_label = (
                known_run.info.run_id
                if known_run is not None
                else "no_active_run"
            )
            logger.error(f"Error in MLflow run: {run_label} {e}", exc_info=True)
            mlflow.end_run()
            return None

        finally:
            # NOTE(review): the context manager above has already ended this
            # run. If an outer run is still active (is_nested), this call
            # looks like it would end that outer run as well - confirm intent.
            mlflow.end_run()

    def _get_process_device(self, process_idx):
        """Get the appropriate device for a given process index.

        Args:
            process_idx: Index of the current process

        Returns:
            str: Device specification for PyTorch Lightning
        """

        if not torch.cuda.is_available() or not self.is_neural:
            return "cpu"

        num_gpus = torch.cuda.device_count()

        if num_gpus == 0:
            return "cpu"

        # Assign GPU round-robin style
        gpu_idx = process_idx % num_gpus

        return f"cuda:{gpu_idx}"

    def _run_split_parallel_worker(self, split_info):
        """Run a single split in parallel mode (non-multirun only).

        Opens a nested MLflow run for the split. When ``self.parent_run_id``
        is known, the parent run is re-activated in this worker first so the
        split run is opened inside it. Any exception is logged and turned
        into a ``None`` result.

        Args:
            split_info: tuple of (split_num, process_idx)

        Returns:
            tuple: (metrics_dict, predictions) if successful, None if failed
        """
        split_num, process_idx = split_info
        logger.info(
            f"Running Split {split_num + 1} of {self.data.get_num_splits()} on process {process_idx}"
        )

        try:
            if self.parent_run_id is not None:
                # Ensure the parent run is active in this worker, then open
                # the nested split run
                with mlflow.start_run(run_id=self.parent_run_id):
                    with mlflow.start_run(
                        run_name=f"split_{split_num + 1}", nested=True
                    ) as split_run:
                        return self._execute_parallel_split(
                            split_num, process_idx, split_run
                        )

            # Fallback: open only a nested run (no explicit parent context)
            with mlflow.start_run(
                run_name=f"split_{split_num + 1}", nested=True
            ) as split_run:
                return self._execute_parallel_split(
                    split_num, process_idx, split_run
                )

        except Exception as e:
            logger.error(
                f"Error in split {split_num + 1}: {e}",
                exc_info=True,
                stack_info=True,
            )

            if torch.cuda.is_available() and self.is_neural:
                torch.cuda.empty_cache()

            return None

    def _execute_parallel_split(self, split_num, process_idx, split_run):
        """Tag, train and log one split inside the already-open ``split_run``.

        Args:
            split_num: zero-based index into ``self.data.splits``
            process_idx: worker index, used for tagging and device selection
            split_run: the active MLflow run opened for this split

        Returns:
            tuple: (metrics_dict, predictions), or None when ``fit_model``
            returns None
        """
        mlflow.set_tag("split", split_num + 1)

        if self.parent_run_id is not None:
            mlflow.set_tag("parent_run_id", self.parent_run_id)
        mlflow.set_tag("process_id", process_idx)
        mlflow.set_tag("benchmark", self.cfg.benchmark.name)
        mlflow.set_tag("model_name", self.cfg.model.name)

        # Use a dedicated results logger for this split run to avoid
        # cross-thread mutation
        local_results_logger = ResultsLogger(cfg=self.cfg, run=split_run)
        local_results_logger.set_run_id(split_run.info.run_id)
        local_results_logger.data = self.data
        logger.debug(
            f"Split {split_num + 1} using run ID: {split_run.info.run_id}"
        )
        local_results_logger.log_config(self.cfg)

        current_trainer_instance = None

        if self.is_neural:
            device = self._get_process_device(process_idx)
            logger.info(f"Using device {device} for split {split_num + 1}")

            # Trainer kwargs are bound to this split's run ID
            trainer_kwargs = self.get_trainer_kwargs(
                run_id=split_run.info.run_id
            )
            on_gpu = "cuda" in device
            trainer_kwargs["accelerator"] = "gpu" if on_gpu else "cpu"
            # CPU runs use one device; `None` is not accepted here by
            # current Lightning releases.
            trainer_kwargs["devices"] = (
                [int(device.split(":")[-1])] if on_gpu else 1
            )
            trainer_kwargs["num_sanity_val_steps"] = 0
            current_trainer_instance = pl.Trainer(**trainer_kwargs)

        train_indices, val_indices = self.data.splits[split_num]
        logger.debug(
            f"Train size: {len(train_indices)}, Val size: {len(val_indices)}"
        )

        # Set current split before running
        self._current_split = split_num

        # Each split trains its own copy of the predictor
        model_instance_for_split = deepcopy(self.model)

        split_results = model_instance_for_split.fit_model(
            self.data,
            train_indices,
            val_indices,
            current_trainer_instance,
            self.cfg,
        )

        if self.cfg.mlflow.log_models:
            try:
                # The split index is used for model naming in the logger
                local_results_logger.set_current_split(split_num)
                local_results_logger.log_model(
                    model_instance_for_split,
                    model_name=f"{self.cfg.mlflow.registered_model_name}_split_{split_num + 1}",
                )
            except Exception as e:
                # Model logging is best-effort; the split result still counts
                logger.error(
                    f"Failed to log model for split {split_num + 1}: {e}",
                    exc_info=True,
                )

        if split_results is None:
            logger.error(f"Failed to get results for split {split_num + 1}")

            return None

        return (
            split_results["metrics"],
            split_results["predictions"],
        )

    def setup_predictor(self, cfg):
        """Prepare the data and build the model described by the configuration.

        ``cfg`` is forwarded to ``setup_data`` only; the model itself is
        instantiated from ``self.cfg.model`` via Hydra and then configured
        through its ``setup_model`` hook. For neural models, PEFT adapters are
        injected when a ``peft`` section with ``target_modules`` is present and
        the model has not initialized them yet.

        Args:
            cfg: Configuration passed on to ``setup_data``.
        """
        self.setup_data(cfg)

        if self.is_neural:
            self.setup_neural_network_config()

        # Build the predictor and let it configure itself against the data
        self.model = hydra.utils.instantiate(self.cfg.model)
        self.model.setup_model(self.data, self.cfg)

        # Everything below concerns PEFT adapters, which only apply to neural models
        if not self.is_neural:
            return

        if not hasattr(self.cfg, "peft"):
            return

        if self.cfg.peft.target_modules is None:
            return

        if self.model.peft_initialized:
            return

        # Inject peft adapters, since they are not initialized yet
        self.model._initialize_peft_adapters(self.cfg.peft)

    def run_predictor(
        self, train_indices, val_indices, split_num=None, trainer_instance=None
    ):
        """Fit a copy of the model on one split and log plots and the model.

        The stored model is deep-copied so that every split starts from the
        same initial state.

        Args:
            train_indices: Indices of the training samples for this split.
            val_indices: Indices of the validation samples for this split;
                also passed to the plot logger as ``sample_indices``.
            split_num: Zero-based split index used for model naming and log
                messages. May be None, in which case the model is logged under
                the registered model name without a ``_split_N`` suffix.
            trainer_instance: Trainer handed to ``fit_model``; can be None for
                non-PyTorch models.

        Returns:
            dict | None: ``{"metrics": ..., "predictions": ...}`` taken from the
            result of ``fit_model``, or None when ``fit_model`` returned None.
        """
        predictor = deepcopy(self.model)
        results = predictor.fit_model(
            self.data,
            train_indices,
            val_indices,
            trainer_instance,  # can be none for non pytorch models
            self.cfg,
        )

        if results is None:
            # Callers already treat None as "split failed"; report it instead
            # of crashing on the subscript below.
            logger.error("Model fitting returned no results")

            return None

        predictions = results["predictions"]
        metrics = results["metrics"]

        self.results_logger.log_plots(
            predictions["true_values"],
            predictions["predictions"],
            (
                predictions["probabilities"]
                if "probabilities" in predictions
                else None
            ),
            sample_indices=val_indices,
        )

        # Log model for this split if enabled
        if self.cfg.mlflow.log_models:
            # split_num is optional: compute the label up front so that neither
            # the model name nor the error message can raise on None.
            split_label = "unknown" if split_num is None else split_num + 1
            try:
                # Set current split in results logger for proper model naming
                self.results_logger.set_current_split(split_num)

                model_name = f"{self.cfg.mlflow.registered_model_name}"
                if split_num is not None:
                    model_name = f"{model_name}_split_{split_label}"

                # Log the model
                self.results_logger.log_model(predictor, model_name=model_name)
            except Exception as e:
                # Don't fail the entire run if model logging fails
                logger.error(
                    f"Failed to log model for split {split_label}: {e}",
                    exc_info=True,
                )

        return {"metrics": metrics, "predictions": predictions}

    def get_trainer_kwargs(self, run_id=None):
        """Build the keyword arguments for a PyTorch Lightning ``Trainer``.

        Assembles two early-stopping callbacks (validation loss and average
        training loss), an optional checkpoint callback and an optional
        learning-rate monitor, creates an MLflow logger bound to the requested
        run, and collects the remaining settings from ``self.cfg.trainer``.

        Args:
            run_id: Optional MLflow run ID to use for the logger. If None, uses
                the current active run.

        Returns:
            dict: Keyword arguments that can be passed as ``pl.Trainer(**kwargs)``.

        Raises:
            RuntimeError: If ``run_id`` is None and no MLflow run is active.
        """
        trainer_cfg = self.cfg.trainer

        # Read the checkpointing flag once so the ModelCheckpoint callback and
        # ``enable_checkpointing`` can never disagree; a missing key means
        # "checkpointing disabled" instead of failing further down.
        checkpointing = getattr(trainer_cfg, "model_checkpointing", False)

        callbacks = [
            EarlyStopping(
                monitor="val_loss",
                patience=self.patience,
                mode="min",
                min_delta=trainer_cfg.stop_min_delta_val,
                check_finite=True,
                strict=False,  # Don't fail if metric doesn't exist initially
                check_on_train_epoch_end=False,
            ),
            EarlyStopping(
                monitor="avg_train_loss",
                patience=2,
                mode="min",
                min_delta=trainer_cfg.stop_min_delta_train,
                check_finite=True,
                strict=False,  # Don't fail if metric doesn't exist initially
                check_on_train_epoch_end=True,
            ),
        ]

        # Add checkpointing if enabled
        if checkpointing:
            callbacks.append(
                ModelCheckpoint(
                    monitor="val_loss",
                    mode="min",
                    save_top_k=2,
                    save_last=True,
                )
            )

        # Add learning rate monitor if scheduler is enabled
        if hasattr(trainer_cfg, "lrscheduler") and trainer_cfg.lrscheduler.enabled:
            callbacks.append(LearningRateMonitor(logging_interval=None))

        # Create MLflow logger with the specified run ID or current active run
        # This ensures each trainer gets a fresh logger for the correct run
        if run_id is None:
            active_run = mlflow.active_run()
            if active_run is None:
                raise RuntimeError(
                    "get_trainer_kwargs requires an active MLflow run "
                    "when no run_id is given"
                )
            current_run_id = active_run.info.run_id
        else:
            current_run_id = run_id

        train_logger = pl.loggers.MLFlowLogger(
            tracking_uri=self.cfg.mlflow.tracking_uri,
            experiment_name=self.cfg.mlflow.experiment_name,
            run_id=current_run_id,
        )

        trainer_kwargs = {
            "max_epochs": trainer_cfg.max_epochs,
            "default_root_dir": trainer_cfg.default_root_dir,
            "enable_checkpointing": checkpointing,
            "accelerator": "auto",
            # list of device ids [1,2,4], or all [all visible to torch] or int
            "devices": self.devices,
            "strategy": trainer_cfg.strategy,  # DDP strategy configuration
            "callbacks": callbacks,
            "enable_progress_bar": trainer_cfg.enable_progress_bar,
            "enable_model_summary": trainer_cfg.enable_model_summary,
            "gradient_clip_val": trainer_cfg.gradient_clip_val,
            # An attribute on the trainer object takes precedence over the config value
            "accumulate_grad_batches": getattr(
                self,
                "accumulate_grad_batches",
                trainer_cfg.accumulate_grad_batches,
            ),
            "logger": train_logger,
            "log_every_n_steps": trainer_cfg.log_every_n_steps,
            "precision": trainer_cfg.precision,
            "num_sanity_val_steps": 0,
            "fast_dev_run": trainer_cfg.fast_dev_run,
        }

        logger.debug("Trainer kwargs: %s", trainer_kwargs)

        return trainer_kwargs

    def run_split(self, split_num, run_name=None):
        """Train and evaluate one split inside its own nested MLflow run.

        Args:
            split_num: Index of the split to run.
            run_name: Name for the nested MLflow run. When None, a name of the
                form ``split_<n>`` (one-based) is used.

        Returns:
            tuple: (metrics_dict, predictions) if successful, None if failed
        """
        logger.info(
            f"Running Split {split_num + 1} of {self.data.get_num_splits()}"
        )
        try:
            run_name = (
                run_name if run_name is not None else f"split_{split_num + 1}"
            )

            with mlflow.start_run(run_name=run_name, nested=True):
                mlflow.set_tag("split", split_num + 1)
                mlflow.set_tag("run_type", "split")

                # Link the child run to its parent when the parent is known
                parent_id = self.parent_run_id
                if parent_id is not None:
                    mlflow.set_tag("parent_run_id", parent_id)

                # Point the results logger at the child run before logging
                child_run_id = mlflow.active_run().info.run_id
                self.results_logger.set_run_id(child_run_id)
                # log_config handles duplicates
                self.results_logger.log_config(self.cfg)

                outcome = self._execute_split(split_num, 0)

                return outcome

        except Exception as e:
            logger.error(f"Error in split {split_num + 1}: {e}", exc_info=True)

            # Release cached GPU memory after a failed neural-network split
            if torch.cuda.is_available() and self.is_neural:
                torch.cuda.empty_cache()

        return None

    def _execute_split(self, split_num, process_idx):
        """Execute the actual split training and evaluation.

        This is separated from run_split to handle MLflow run contexts properly.

        Args:
            split_num: Zero-based index into ``self.data.splits``.
            process_idx: Index handed to ``_get_process_device`` to pick the
                device for this split.

        Returns:
            tuple | None: ``(metrics, predictions)`` for the split, or None when
            ``run_predictor`` produced no results.
        """
        logger.debug(f"Executing split {split_num + 1} on process {process_idx}")
        self._current_split = split_num
        device = self._get_process_device(process_idx)
        logger.info(f"Using device {device} for split {split_num + 1}")
        trainer = None

        if self.is_neural:
            trainer_kwargs = self.get_trainer_kwargs()
            use_cuda = "cuda" in device
            trainer_kwargs["accelerator"] = "gpu" if use_cuda else "cpu"

            if self.cfg.parallel or not use_cuda:
                # Lightning documents ``devices`` as an int, a list of ids or a
                # string such as "auto"; None is not an accepted value, so the
                # CPU case uses "auto" as well.
                trainer_kwargs["devices"] = "auto"
            else:
                # Pin the trainer to the GPU index encoded in e.g. "cuda:1"
                trainer_kwargs["devices"] = [int(device.split(":")[-1])]
            trainer = pl.Trainer(**trainer_kwargs)

        # Get train/val indices for this split
        train_indices, val_indices = self.data.splits[split_num]
        logger.debug(
            f"Train size: {len(train_indices)}, Val size: {len(val_indices)}"
        )

        # Run appropriate training function
        split_results = self.run_predictor(
            train_indices, val_indices, split_num, trainer
        )

        if split_results is None:
            logger.error(f"Failed to get results for split {split_num + 1}")

            return None

        # Log metrics for this split
        self.results_logger.log_metrics(split_results["metrics"], step=None)

        return split_results["metrics"], split_results["predictions"]

    def evaluate_trial(
        self,
        trial_number: int | None = None,
        study_name: str | None = None,
        storage_config: dict | None = None,
    ) -> dict | None:
        """Public API: evaluate one trial of an Optuna study.

        Args:
            trial_number: Number of the trial to evaluate. None selects the
                study's best trial.
            study_name: Name of the Optuna study. None derives it from the
                configuration via ``HAIPROptimizer.create_study_name``.
            storage_config: Storage configuration dict. None obtains it from
                ``HAIPROptimizer.get_storage_config``.

        Returns:
            The result of ``_evaluate_with_trial_config`` for the chosen trial,
            or None when no trial was found or any step raised.
        """
        try:
            # Import here to avoid circular imports
            from haipr.optimize import HAIPROptimizer

            name_missing = study_name is None
            storage_missing = storage_config is None

            # Fill in whatever study information the caller left out
            if name_missing or storage_missing:
                setup_optimizer = HAIPROptimizer(self.cfg)

                if name_missing:
                    study_name = setup_optimizer.create_study_name(self.cfg)

                if storage_missing:
                    storage_config = setup_optimizer.get_storage_config()

            study = self._load_study_for_evaluation(study_name, storage_config)
            trial = self._get_trial_from_study(study, trial_number)

            if trial is not None:
                logger.info(f"Evaluating trial {trial.number}")
                logger.info(f"Trial value: {trial.value}")
                logger.info(f"Trial parameters: {trial.params}")

                # Merge the trial's parameters into the base configuration
                config_optimizer = HAIPROptimizer(self.cfg)
                trial_config = config_optimizer.update_trial_config(
                    self.cfg, trial.params
                )

                return self._evaluate_with_trial_config(
                    trial_config, trial.number
                )

            logger.error("No trial found to evaluate")

            return None

        except Exception as e:
            logger.error(f"Failed to evaluate trial: {e}", exc_info=True)

        return None

    def _load_study_for_evaluation(
        self, study_name: str, storage_config: dict
    ) -> optuna.Study:
        """Load an existing Optuna study so one of its trials can be evaluated.

        Args:
            study_name: Name of the study to load.
            storage_config: Dict providing ``storage`` and, optionally,
                ``grace_period`` (default 600) and ``max_retry`` (default 3).

        Returns:
            The study returned by ``HAIPROptimizer.load_study``.

        Raises:
            Exception: Any error raised while setting up storage or loading the
                study is logged and re-raised unchanged.
        """
        try:
            # Import here to avoid circular imports
            from haipr.optimize import HAIPROptimizer

            # A throwaway optimizer gives access to the storage/study helpers
            optimizer = HAIPROptimizer(self.cfg)

            # Minimal config carrying only what the storage setup reads
            optuna_settings = {
                "storage": storage_config.get("storage"),
                "grace_period": storage_config.get("grace_period", 600),
                "max_retry": storage_config.get("max_retry", 3),
            }
            storage_cfg = OmegaConf.create({"optuna": optuna_settings})

            loaded_study = optimizer.load_study(
                study_name, optimizer.setup_storage(storage_cfg)
            )
        except Exception as e:
            logger.error(
                f"Failed to load study for evaluation: {e}", exc_info=True
            )
            raise

        logger.info(f"Successfully loaded study '{study_name}' for evaluation")

        return loaded_study

    def _get_trial_from_study(
        self, study: optuna.Study, trial_number: int | None
    ) -> optuna.trial.FrozenTrial | None:
        """Get a specific trial or the best trial from a study.

        Args:
            study: Study to read trials from.
            trial_number: Number of the wanted trial; None selects
                ``study.best_trial``.

        Returns:
            The selected trial, or None when the study has no trials, the
            requested number does not exist, or looking it up raised (the
            error is logged).
        """
        try:
            if trial_number is None:
                # Get best trial
                if not study.trials:
                    logger.error("No trials found in study")

                    return None

                trial = study.best_trial
                logger.info(f"Using best trial: {trial.number}")
            else:
                # Get specific trial by number; stop at the first match
                # instead of collecting every match into a list.
                trial = next(
                    (t for t in study.trials if t.number == trial_number),
                    None,
                )

                if trial is None:
                    logger.error(f"Trial {trial_number} not found in study")

                    return None

                logger.info(f"Using specified trial: {trial.number}")

            return trial

        except Exception as e:
            logger.error(f"Failed to get trial from study: {e}", exc_info=True)

            return None

    def _evaluate_with_trial_config(
        self, trial_config: DictConfig, trial_number: int
    ) -> dict | None:
        """Execute evaluation with trial-specific configuration.

        Copies ``trial_config``, overrides the evaluation-related settings,
        builds a fresh ``HAIPRTrainer`` from the copy and runs a single split
        through its ``_execute_split``.

        Args:
            trial_config: Configuration with the trial's parameters applied.
            trial_number: Number of the trial, echoed back in the result.

        Returns:
            dict | None: Dict with ``metrics``, ``predictions``,
            ``trial_number`` and ``eval_split`` on success; None when the
            split produced no metrics or any step raised.
        """
        try:
            # Index of the split to evaluate, read before the copy below
            # clears it; falls back to split 0 when unset.
            test_split_idx = self.cfg.data.test_split_idx
            split_idx = test_split_idx if test_split_idx is not None else 0

            # Set evaluation-specific configurations on an unresolved copy
            eval_cfg = OmegaConf.create(
                OmegaConf.to_container(trial_config, resolve=False)
            )
            eval_cfg.trainer.run_single_split = test_split_idx
            eval_cfg.mlflow.log_models = True
            eval_cfg.data.test_split_idx = None  # use full data for evaluation
            eval_cfg.data.subsample_threshold = (
                0  # no subsampling for evaluation
            )

            # Create a new trainer with the trial configuration
            eval_trainer = HAIPRTrainer(eval_cfg)
            eval_trainer.is_nested = True
            eval_trainer.parent_run_id = getattr(self, "parent_run_id", None)

            # Set up the results logger with the current MLflow run if available
            active_run = mlflow.active_run()
            if active_run is not None:
                eval_trainer.results_logger.run = active_run
                eval_trainer.results_logger.set_run_id(active_run.info.run_id)

            eval_trainer.data.generate_splits()

            # Use the existing _execute_split method for consistency. It
            # returns None on failure, so check before unpacking; otherwise
            # the failure surfaces as an unrelated TypeError.
            split_result = eval_trainer._execute_split(split_idx, 0)
            if split_result is None:
                logger.error("Evaluation failed - no metrics returned")

                return None

            mean_metrics, predictions = split_result

            if not mean_metrics:
                logger.error("Evaluation failed - no metrics returned")

                return None

            logger.info("Evaluation completed successfully")
            logger.info(f"Evaluation metrics: {mean_metrics}")

            return {
                "metrics": mean_metrics,
                "predictions": predictions,
                "trial_number": trial_number,
                "eval_split": split_idx,
            }

        except Exception as e:
            logger.error(
                f"Failed to execute trial evaluation: {e}", exc_info=True
            )

            return None

    def run_splits(self, split_callback=None) -> tuple[dict, list]:
        """Run training and evaluation over data splits.

        Splits run sequentially for multirun without ``cfg.parallel`` or when
        ``cfg.run_sequential`` is set; otherwise they are dispatched to a
        thread pool. Per-split metrics are aggregated into mean and standard
        deviation and logged through the results logger.

        Args:
            split_callback: Optional callable invoked as
                ``split_callback(split_num, result)``. In the sequential path
                it is called for every split, including failed ones
                (``result`` is None); in the parallel path only for splits
                that produced a result.

        Returns:
            tuple[dict, list]: ``(mean_metrics, all_predictions)``. Each metric
            appears under its own name (mean) and ``<name>_std``. On an
            unexpected error ``({}, [])`` is returned.
        """
        mean_metrics = {}
        all_metrics = []
        all_predictions = []

        try:
            if (self.is_multirun and not self.cfg.parallel) or (
                hasattr(self.cfg, "run_sequential") and self.cfg.run_sequential
            ):
                # Use original sequential behavior for multirun, also for slurm based parallel
                logger.info("Running splits sequentially")

                for split_num in range(self.data.get_num_splits()):
                    self._current_split = split_num
                    result = self.run_split(
                        split_num
                    )  # Use original run_split for multirun

                    if split_callback is not None:
                        split_callback(split_num, result)

                    if result is not None:
                        metrics, predictions = result
                        all_metrics.append(metrics)
                        all_predictions.append(predictions)
            else:
                logger.info("Running splits in parallel (single run mode) or --parallel flag")

                # Determine number of processes based on available resources
                if self.is_neural and torch.cuda.is_available():
                    num_gpus = torch.cuda.device_count()
                    num_processes = min(num_gpus, self.data.get_num_splits())
                    logger.info(
                        f"Using {num_processes} processes with GPU support"
                    )
                else:
                    num_processes = min(
                        multiprocessing.cpu_count(), self.data.get_num_splits()
                    )
                    logger.info(f"Using {num_processes} CPU processes")

                # Create list of (split_num, process_idx) tuples for mapping;
                # splits are assigned to workers round-robin.
                split_process_pairs = [
                    (split_num, split_num % num_processes)
                    for split_num in range(self.data.get_num_splits())
                ]

                # NOTE(review): this is a truthiness test, so a value of 0
                # does not select split 0 - confirm whether that is intended.
                if self.run_single_split:
                    split_process_pairs = [(int(self.run_single_split), 0)]
                    num_processes = 1

                # Process splits in parallel using a thread pool
                with ThreadPoolExecutor(max_workers=num_processes) as executor:
                    # Map run_split across all splits with their process indices
                    future_results = list(
                        executor.map(
                            self._run_split_parallel_worker, split_process_pairs
                        )
                    )

                    if split_callback is not None:
                        # Pair each result with the split it was computed for;
                        # the position in the result list differs from the
                        # split number when only a single split is run.
                        for (split_num, _), result in zip(
                            split_process_pairs, future_results
                        ):
                            if result is not None:
                                split_callback(split_num, result)

                    # Process results
                    for result in future_results:
                        if result is not None:
                            metrics, predictions = result
                            all_metrics.append(metrics)
                            all_predictions.append(predictions)

            # Calculate mean metrics across all splits
            if all_metrics:
                # The first split's metric names define what gets aggregated
                for metric in all_metrics[0].keys():
                    values = [m[metric] for m in all_metrics]
                    mean_metrics[f"{metric}"] = float(np.mean(values))
                    mean_metrics[f"{metric}_std"] = float(np.std(values))
                    logger.debug(f"Mean {metric}: {mean_metrics[f'{metric}']}")

                # Log final metrics and predictions
                self.results_logger.log_metrics(mean_metrics, step=None)
                self.results_logger.log_run_metrics_and_predictions(
                    all_metrics, all_predictions, self.data
                )

            return mean_metrics, all_predictions

        except Exception as e:
            logger.error(f"Error in run_splits: {e}", exc_info=True)

            if torch.cuda.is_available() and self.is_neural:
                torch.cuda.empty_cache()

            return {}, []


def run(cfg):
    """Run the HAIPR trainer for one configuration and return its objective.

    Args:
        cfg: Hydra configuration for the run.

    Returns:
        ``0`` when only feature preparation was requested,
        ``float("inf")`` when tuning produced no result or no metrics,
        otherwise the value of ``cfg.model.optimization_metric`` as a float.

    Raises:
        KeyError: If the optimization metric is not among the mean metrics.
    """
    # Initialize trainer; setup_data prepares the features
    trainer = HAIPRTrainer(cfg)
    trainer.setup_data(cfg)
    torch.set_float32_matmul_precision("medium")

    # Feature preparation only: setup_data above already did the work
    if getattr(cfg, "prepare_features", False):
        logger.info("Features prepared and cached")

        return 0

    # Run training and get mean metrics
    outcome = trainer.tune(cfg)
    if outcome is None:
        logger.error("Training failed")

        return float("inf")

    mean_metrics, _predictions = outcome
    if mean_metrics is None:
        logger.error("Training failed")

        return float("inf")

    # Extract optimization metric
    optimization_metric = cfg.model.optimization_metric

    if optimization_metric in mean_metrics:
        val_loss = mean_metrics[optimization_metric]
        logger.info(
            f"Final optimization value ({optimization_metric}): {val_loss}")

        # legacy when using hydra builtin optuna
        return float(val_loss)

    logger.error(
        f"Optimization metric '{optimization_metric}' not found in available metrics: {list(mean_metrics.keys())}"
    )
    raise KeyError(
        f"Optimization metric '{optimization_metric}' not found")


@hydra.main(version_base=None, config_path="conf", config_name="train")
def main(cfg: DictConfig) -> float:
    """Hydra entry point: register the config resolvers and run training.

    Args:
        cfg: Configuration composed by Hydra from the ``train`` config in
            the ``conf`` directory.

    Returns:
        The value returned by :func:`run` for this configuration.
    """
    register_resolvers()

    return run(cfg)


if __name__ == "__main__":
    # Resolvers are registered here and again inside main(); presumably the
    # early call makes them available while Hydra composes the config.
    # NOTE(review): assumes register_resolvers() tolerates repeated calls - confirm.
    register_resolvers()
    main()
