import hydra
import lightning.pytorch as pl
from torch.nn import (
    MSELoss,
    L1Loss,
    SmoothL1Loss,
    HuberLoss,
    CrossEntropyLoss,
    BCEWithLogitsLoss,
    NLLLoss,
)
from peft import LoraConfig, TaskType, get_peft_model
import torch
import torch.distributed as dist
import logging
from haipr.utils import (
    compute_classification_metrics,
    compute_regression_metrics,
)
from typing import Dict, Any, List, Tuple
from omegaconf import DictConfig
from tqdm import tqdm
import numpy as np
from torch.utils.data import DataLoader, Subset, Dataset
from haipr.predictor import BasePredictor
from omegaconf import OmegaConf

logger = logging.getLogger(__name__)


class HAIPRModule(pl.LightningModule, BasePredictor):
    """
    Lightning Module for HAIPR models that are based on torch.
    Handles training, validation, and prediction steps.
    """

    def __init__(
        self,
        model: torch.nn.Module,
        criterion,
        num_classes=0,
        learning_rate=1e-4,
        weight_decay=0.01,
        val_chunk_size=1,
        val_chunk_log_interval=200,
        **kwargs,
    ):
        """
        Initialize the Lightning Module.

        Args:
            model: The PyTorch model to train (e.g., ESM3 instance with prediction head)
            criterion: The loss function to use
            num_classes: Number of classes (0 for regression)
            learning_rate: Learning rate for optimization
            weight_decay: Weight decay for optimization
            val_chunk_size: Factor applied to the estimated number of validation
                samples when ``_create_validation_chunk`` sizes its chunk
            val_chunk_log_interval: Number of training steps between validation chunk evaluations
            **kwargs: Optional settings. The keys read here are:
                train_num_workers: Worker count for the training DataLoader (default 0)
                val_num_workers: Worker count for the validation DataLoader (default 0)
                batch_size: Default batch size (default 1)
                peft: Mapping with a "target_modules" entry. When given, the
                    ``peft`` node of the Hydra "train" config is composed, its
                    target modules are overridden with this entry, and PEFT
                    adapters are injected into ``model``.

        Raises:
            ValueError: If ``criterion`` is not one of the accepted regression
                losses (``num_classes == 0``) or classification losses
                (``num_classes > 0``).
        """
        super().__init__()
        self.model = model
        self.criterion = criterion
        self.num_classes = num_classes
        self.learning_rate = learning_rate
        self.weight_decay = weight_decay
        self.val_chunk_size = val_chunk_size
        self.val_chunk_log_interval = val_chunk_log_interval
        self.peft_initialized = False
        # Validation chunk tracking
        self.val_chunk_data = None
        self.training_step_count = 0
        # Running training-loss accumulator. It is reset in on_train_epoch_start,
        # but is created here as well so training_step never hits a missing attribute.
        self.avg_train_loss = 0
        self.train_num_workers = kwargs.get("train_num_workers", 0)
        self.val_num_workers = kwargs.get("val_num_workers", 0)
        self.batch_size = kwargs.get("batch_size", 1)

        peft_options = kwargs.get("peft", None)

        if peft_options is not None:
            # NOTE: hydra.compose needs an initialized Hydra context.
            peft_config = hydra.compose(config_name="train").peft
            peft_config["target_modules"] = peft_options["target_modules"]
            self._initialize_peft_adapters(peft_config)
            self.peft_initialized = True

        # Must be called directly from __init__: Lightning inspects this frame
        # to collect the constructor arguments.
        self.save_hyperparameters(
            ignore=["model", "criterion", "prediction_head"]
        )
        # Initialize lists to store validation outputs
        self.validation_step_outputs = []

        # Validate criterion is appropriate for task

        if self.num_classes == 0 and not isinstance(
            criterion, (MSELoss, L1Loss, SmoothL1Loss, HuberLoss)
        ):
            raise ValueError(
                "Regression task requires a regression loss function"
            )
        elif self.num_classes > 0 and not isinstance(
            criterion, (CrossEntropyLoss, BCEWithLogitsLoss, NLLLoss)
        ):
            raise ValueError(
                "Classification task requires a classification loss function"
            )

        self.last_val_predictions = None
        self.last_val_metrics = None
        # Per-epoch validation accumulators. val_sample_ids is created here too,
        # so validation_step can append to it even before
        # on_validation_epoch_start has run.
        self.val_preds = []
        self.val_labels = []
        self.val_probs = []
        self.val_sample_ids = []
        self.best_val_loss = float("inf")
        self.best_val_metrics = {}
        self.best_val_predictions = None

    def _is_ddp_enabled(self):
        """Check if DDP is enabled and available.

        Returns:
            True when ``torch.distributed`` is available and initialized and
            this module is attached to a trainer exposing ``is_global_zero``;
            False otherwise, including when no trainer is attached.
        """
        # Check the process group first so the trainer is not touched at all
        # outside of distributed runs.
        if not (dist.is_available() and dist.is_initialized()):
            return False

        try:
            trainer = self.trainer
        except RuntimeError:
            # Lightning raises RuntimeError from the `trainer` property when the
            # module is not attached to a Trainer; treat that as "no DDP".
            return False

        return trainer is not None and hasattr(trainer, "is_global_zero")

    def _gather_predictions(self, tensor_list):
        """Collect every tensor of ``tensor_list`` from all processes.

        Without DDP the input list is returned as is. With DDP, each tensor is
        moved to ``self.device`` if needed and all-gathered; the per-process
        copies of each tensor are appended to the result in input order.

        Args:
            tensor_list: List of tensors held by this process.

        Returns:
            The input list (no DDP) or a flat list of gathered tensors.
        """
        if not self._is_ddp_enabled():
            return tensor_list

        collected = []

        for local_tensor in tensor_list:
            # all_gather needs the tensor on this process' device
            if local_tensor.device != self.device:
                local_tensor = local_tensor.to(self.device)

            # One receive buffer per process, shaped like the local tensor
            world_size = dist.get_world_size()
            receive_buffers = [
                torch.zeros_like(local_tensor) for _ in range(world_size)
            ]
            dist.all_gather(receive_buffers, local_tensor)
            collected += receive_buffers

        return collected

    def _gather_single_tensor(self, tensor):
        """All-gather one tensor across processes and concatenate on dim 0.

        Args:
            tensor: Tensor held by this process.

        Returns:
            ``tensor`` unchanged when DDP is not active; otherwise the
            concatenation (along dim 0) of the copies from every process.
        """
        if not self._is_ddp_enabled():
            return tensor

        # all_gather needs the tensor on this process' device
        if tensor.device != self.device:
            tensor = tensor.to(self.device)

        # One receive buffer per process, shaped like the local tensor
        world_size = dist.get_world_size()
        receive_buffers = [torch.zeros_like(tensor) for _ in range(world_size)]
        dist.all_gather(receive_buffers, tensor)

        merged = torch.cat(receive_buffers, dim=0)

        return merged

    def _initialize_peft_adapters(self, lora_config_dict: DictConfig):
        """Injects PEFT adapters (LoRA/DoRA) into self.model directly.

        Builds a ``LoraConfig`` from ``lora_config_dict`` and replaces
        ``self.model`` with the result of ``get_peft_model``. Afterwards the
        trainable/total parameter counts are logged.

        Args:
            lora_config_dict: Adapter settings. Keys read (with defaults):
                "target_modules" (None), "peft_type" ("lora"; "dora" enables
                DoRA), "rank" (8), "alpha" (16), "dropout" (0.1),
                "bias" ("none"). "target_modules" may be a comma-separated
                string, a list/tuple, or an OmegaConf ListConfig/DictConfig.

        Raises:
            ValueError: If "target_modules" has an unsupported type.
        """
        # Function-scope import: only DictConfig/OmegaConf are imported at module level.
        from omegaconf import ListConfig

        logger.info(
            f"Initializing PEFT adapters with config: {lora_config_dict}"
        )

        target_modules_conf = lora_config_dict.get("target_modules", None)

        if target_modules_conf is None:
            # Let LoraConfig handle default based on model type
            target_modules_list = None
        elif isinstance(target_modules_conf, str):
            # Strip whitespace so "q_proj, v_proj" yields usable module names
            target_modules_list = [
                name.strip() for name in target_modules_conf.split(",")
            ]
        elif isinstance(
            target_modules_conf, (list, tuple, ListConfig, DictConfig)
        ):
            # A list stored inside an OmegaConf config comes back as ListConfig,
            # which is not a `list` subclass, so it has to be accepted explicitly.
            # Ensure it's a Python list of strings
            target_modules_list = [str(item) for item in target_modules_conf]
        else:
            raise ValueError(
                f"Invalid target_modules type: {type(target_modules_conf)}. Expected str or list/ListConfig."
            )
        logger.info(f"PEFT target modules: {target_modules_list}")

        peft_type = lora_config_dict.get("peft_type", "lora")

        peft_config_instance = LoraConfig(
            use_dora=(peft_type.lower() == "dora"),
            task_type=TaskType.FEATURE_EXTRACTION,
            r=lora_config_dict.get("rank", 8),
            lora_alpha=lora_config_dict.get("alpha", 16),
            lora_dropout=lora_config_dict.get("dropout", 0.1),
            target_modules=target_modules_list,
            bias=lora_config_dict.get("bias", "none"),
            # modules_to_save = lora_config_dict.get("modules_to_save", None) # if you need to train other modules
        )

        self.model = get_peft_model(self.model, peft_config_instance)

        logger.info(f"PEFT adapters ({peft_type}) injected and activated.")

        # Report how much of the wrapped model is actually trainable
        trainable_params = 0
        all_param = 0

        for param in self.model.parameters():
            count = param.numel()
            all_param += count

            if param.requires_grad:
                trainable_params += count

        if all_param > 0:
            trainable_percentage = 100 * trainable_params / all_param
            logger.info(
                f"Trainable params: {trainable_params:,} || All params: {all_param:,} || Trainable %: {trainable_percentage:.2f}%"
            )
        else:
            logger.info("Model has no parameters.")

    def forward(self, batch_for_predictor: Dict[str, Any]):
        """Pass the batch dictionary to the wrapped model and return its output."""
        model_output = self.model(batch_for_predictor)

        return model_output

    def training_step(self, batch, batch_idx):
        """Training step.

        Runs the forward pass on ``batch``, computes the criterion against
        ``batch["labels"]``, logs it as "train_loss" and adds it to the
        epoch accumulator ``avg_train_loss``.

        Args:
            batch: Batch dictionary; must contain "labels". The whole batch is
                handed to ``forward``.
            batch_idx: Index of the batch (unused).

        Returns:
            The loss tensor used for backpropagation.
        """
        labels = batch["labels"]
        # The 'batch' itself is passed to HAIPRModule.forward,
        # which then passes it to the predictor's forward.
        outputs = self(batch)

        if self.num_classes == 0:  # regression
            labels = labels.float()
            # Ensure consistent shapes for regression

            if outputs.dim() == 1:
                outputs = outputs.view(-1, 1)

            if labels.dim() == 1:
                labels = labels.view(-1, 1)

            # Align outputs with labels to avoid silent broadcasting in the loss
            if outputs.shape != labels.shape:
                outputs = outputs.view_as(labels)
        else:  # classification
            # Class-index targets: logits keep their own shape (they carry a
            # class dimension the labels do not have), so they must not be
            # reshaped to the label shape.
            labels = labels.long()

        loss = self.criterion(outputs, labels)
        batch_size = labels.size(0)
        self.log(
            "train_loss", loss, on_step=True, on_epoch=False, prog_bar=True, batch_size=batch_size
        )
        # Accumulate a detached copy: adding the live loss would keep every
        # step's autograd graph alive until the end of the epoch.
        self.avg_train_loss = getattr(self, "avg_train_loss", 0) + loss.detach()

        # Periodic validation-chunk evaluation is currently disabled:
        # self.training_step_count += 1
        # if (self.training_step_count % self.val_chunk_log_interval == 0 and
        #         self.val_chunk_data is not None):
        #     self._evaluate_validation_chunk()

        return loss

    def validation_step(self, batch, batch_idx):
        """Validation step.

        Computes the loss for one validation batch, logs it as "val_loss"
        (epoch-level, synchronized across processes) and stores predictions,
        labels, class probabilities and optional sample ids for
        ``on_validation_epoch_end``.

        Args:
            batch: Batch dictionary; must contain "labels" and may contain
                "sample_id". The whole batch is handed to ``forward``.
            batch_idx: Index of the batch (unused).

        Returns:
            Dict with the batch loss under "val_loss".
        """
        labels = batch["labels"]
        # The 'batch' itself is passed to HAIPRModule.forward,
        # which then passes it to the predictor's forward.
        outputs = self(batch)

        if self.num_classes == 0:  # regression
            labels = labels.float()

            # Align outputs with labels to avoid silent broadcasting in the loss
            if outputs.shape != labels.shape:
                outputs = outputs.view_as(labels)
            loss = self.criterion(outputs, labels)
            # Ensure both predictions and labels are 2D tensors
            preds = outputs.view(-1, 1)  # Make predictions 2D
            labels_for_metric = labels.view(-1, 1)  # Make labels 2D
            probs = None
        else:  # classification
            # Logits are left untouched: argmax/softmax below need the class
            # dimension, which reshaping to the label shape would destroy.
            labels = labels.long()
            loss = self.criterion(outputs, labels)
            preds = torch.argmax(
                outputs, dim=1
            )  # argmax of the logits [0.2, 0.3, 0.5] -> 2
            labels_for_metric = labels
            probs = torch.softmax(
                outputs, dim=1
            )  # probabilities for each class [2, 3, 5] -> [0.2, 0.3, 0.5]
            logger.debug(f"probs: {probs.shape}")
            logger.debug(f"preds: {preds.shape}")

        # Log validation metrics
        batch_size = labels.size(0)
        self.log(
            "val_loss",
            loss,
            on_step=False,
            on_epoch=True,
            prog_bar=False,
            sync_dist=True,
            batch_size=batch_size,
        )

        # Store predictions, labels, and sample_ids for gathering at epoch end
        self.val_preds.append(preds.detach())
        self.val_labels.append(labels_for_metric.detach())

        if probs is not None:
            self.val_probs.append(probs.detach())

        # Sample ids (when provided) let the epoch-end hook restore a stable order
        if "sample_id" in batch:
            self.val_sample_ids.append(batch["sample_id"].detach().cpu())

        return {"val_loss": loss}

    def on_validation_epoch_start(self):
        """Start a validation epoch with empty prediction/label/prob/id buffers."""
        for buffer_name in (
            "val_preds",
            "val_labels",
            "val_probs",
            "val_sample_ids",
        ):
            # A fresh list per buffer; nothing is shared between them
            setattr(self, buffer_name, [])
        logger.debug("Validation epoch started - reset prediction lists")

    def on_validation_epoch_end(self):
        """Called at the end of validation to compute and log metrics.

        Steps:
            1. If nothing was collected, log NaN placeholders and return.
            2. Gather predictions, labels, probabilities and sample ids from
               all processes (DDP) or use the local buffers.
            3. Convert to NumPy and, when sample ids exist, sort predictions,
               labels and probabilities by sample id so they stay aligned.
            4. Store the result in ``last_val_predictions``.
            5. Compute task metrics, log them with a ``val_`` prefix and
               update the best-so-far tracking based on ``val_loss``.
        """

        if (
            not self.val_preds or not self.val_labels
        ):  # Check if lists are empty
            logger.warning(
                "Validation epoch end: No predictions or labels to compute metrics."
            )
            # Log placeholder metrics so monitored keys still exist

            if self.num_classes == 0:  # Regression
                metric_names = ["mse", "mae", "r2", "pearson_r", "spearman_r"]
            else:  # Classification
                metric_names = [
                    "accuracy",
                    "precision",
                    "recall",
                    "f1",
                    "roc_auc",
                    "mcc",
                ]
            metrics_to_log = {"val_loss": float("nan")}

            for metric_name in metric_names:
                metrics_to_log[f"val_{metric_name}"] = float("nan")
            self.log_dict(
                metrics_to_log, on_epoch=True, prog_bar=False, sync_dist=True, batch_size=1
            )

            return

        # The id buffer only exists once on_validation_epoch_start has run
        local_sample_ids = getattr(self, "val_sample_ids", [])

        # Gather predictions from all GPUs if DDP is enabled

        if self._is_ddp_enabled():
            gathered_preds = self._gather_predictions(self.val_preds)
            gathered_labels = self._gather_predictions(self.val_labels)
            gathered_probs = (
                self._gather_predictions(self.val_probs)
                if self.val_probs
                else []
            )
            gathered_sample_ids = (
                self._gather_predictions(local_sample_ids)
                if local_sample_ids
                else []
            )
        else:
            gathered_preds = self.val_preds
            gathered_labels = self.val_labels
            gathered_probs = self.val_probs
            gathered_sample_ids = local_sample_ids

        preds_all = torch.cat(gathered_preds)
        labels_all = torch.cat(gathered_labels)
        sample_ids_all = (
            torch.cat(gathered_sample_ids) if gathered_sample_ids else None
        )

        logger.debug(
            f"Validation epoch end: collected {len(preds_all)} predictions, {len(labels_all)} labels"
        )

        # Fetch val_loss from trainer as it's aggregated across batches/devices
        val_loss_tensor = self.trainer.callback_metrics.get("val_loss")
        val_loss = (
            val_loss_tensor.item()
            if val_loss_tensor is not None
            else float("nan")
        )

        # NumPy has no bfloat16, so upcast before converting

        if preds_all.dtype == torch.bfloat16:
            preds_all = preds_all.float()

        if labels_all.dtype == torch.bfloat16:
            labels_all = labels_all.float()

        preds_np = preds_all.cpu().numpy()
        labels_np = labels_all.cpu().numpy()

        probs_np = None

        if gathered_probs:
            try:
                probs_all = torch.cat(gathered_probs)

                # Same bfloat16 restriction as for predictions and labels
                if probs_all.dtype == torch.bfloat16:
                    probs_all = probs_all.float()
                probs_np = probs_all.cpu().numpy()
            except RuntimeError as e:
                logger.warning(
                    f"Could not concatenate validation probabilities: {e}"
                )

        # If sample_ids are present, sort predictions/labels/probs by sample_id.
        # Probabilities are reordered with the same index so they stay aligned
        # with the labels they are scored against.

        if sample_ids_all is not None:
            sample_ids_all = sample_ids_all.cpu().numpy().flatten()
            sort_idx = sample_ids_all.argsort()
            preds_np = preds_np[sort_idx]
            labels_np = labels_np[sort_idx]

            if probs_np is not None:
                probs_np = probs_np[sort_idx]
            sample_ids_all = sample_ids_all[sort_idx]

        self.last_val_predictions = {
            "preds": preds_np,
            "labels": labels_np,
        }

        if sample_ids_all is not None:
            self.last_val_predictions["sample_ids"] = sample_ids_all

        if probs_np is not None:
            self.last_val_predictions["probs"] = probs_np

        # Every process logs from here on. Logging with sync_dist=True is a
        # collective reduction, so all ranks have to take part in it; each rank
        # already holds the full gathered data, so the values are identical.

        # log learning rate on end of validation epoch (skipped when the
        # trainer has no optimizers, e.g. a validation-only run)
        optimizers = self.trainer.optimizers

        if optimizers:
            self.log(
                "learning_rate",
                optimizers[0].param_groups[0]["lr"],
                on_step=False,
                on_epoch=True,
                prog_bar=False,
                sync_dist=True,
            )

        if self.num_classes == 0:  # Regression
            metrics = compute_regression_metrics(labels_np, preds_np)
        else:  # Classification
            metrics = compute_classification_metrics(
                labels_np,
                preds_np,
                probs_np,  # Pass probs_np, which could be None
            )

        # Ensure val_loss is part of metrics logged
        metrics["val_loss"] = val_loss

        metrics_to_log = {
            (k if k.startswith("val_") else f"val_{k}"): v
            for k, v in metrics.items()
        }
        # Calculate total validation samples for batch_size
        total_val_samples = len(preds_np)
        self.log_dict(
            metrics_to_log, on_epoch=True, prog_bar=False, sync_dist=True, batch_size=total_val_samples
        )
        self.last_val_metrics = metrics_to_log

        # A NaN val_loss never compares as smaller, so it cannot become "best"
        if self.best_val_loss > val_loss:
            self.best_val_loss = val_loss
            self.best_val_metrics = metrics
            self.best_val_predictions = self.last_val_predictions

    def on_sanity_check_start(self):
        """Print a marker when Lightning's sanity check begins."""
        marker = "on_sanity_check_start"
        print(marker)

    def on_sanity_check_end(self):
        """Print a marker and discard validation state produced by the sanity check."""
        marker = "on_sanity_check_end"
        print(marker)
        # Results from the sanity check must not count as "best" or "last"
        self.best_val_loss = float("inf")
        self.best_val_metrics = {}
        self.best_val_predictions = self.last_val_predictions = None

    def configure_optimizers(self):
        """
        Build the default optimizer and learning rate scheduler.
        Can be overridden in the child (predictor) class.

        Returns:
            A pair ``([optimizer], [scheduler])``: an AdamW optimizer over all
            module parameters (using ``learning_rate`` and ``weight_decay``)
            and a CosineAnnealingLR scheduler with ``T_max=50`` and
            ``eta_min=0``.
        """
        # TODO: make this configurable through hydra
        adamw = torch.optim.AdamW(
            self.parameters(),
            lr=self.learning_rate,
            weight_decay=self.weight_decay,
        )

        # T_max is a fixed default of 50; override in a child class if needed
        cosine = torch.optim.lr_scheduler.CosineAnnealingLR(
            adamw, T_max=50, eta_min=0
        )

        # A linear warmup is currently disabled. The prototype was:
        # warmup_epochs = 5
        # warmup_scheduler = torch.optim.lr_scheduler.LinearLR(
        #     optimizer, start_factor=0.1, end_factor=1.0, total_iters=warmup_epochs
        # )
        # scheduler = torch.optim.lr_scheduler.SequentialLR(
        #     optimizer,
        #     schedulers=[warmup_scheduler, cosine_scheduler],
        #     milestones=[warmup_epochs]
        # )

        return [adamw], [cosine]

    def _create_validation_chunk(self):
        """Cache a leading slice of the validation dataloader in ``val_chunk_data``.

        Batches are read from the (first) validation dataloader until the
        number of samples seen reaches
        ``int(val_chunk_size * len(dataloader) * current_batch_size)``.
        The global torch seed is set to 42 before iterating. Any failure is
        logged as a warning and leaves ``val_chunk_data`` untouched.
        """
        if self.trainer is None or self.trainer.val_dataloaders is None:
            logger.warning(
                "No validation dataloader available for chunk creation"
            )

            return

        loader = self.trainer.val_dataloaders

        # With several validation dataloaders only the first one is used
        if isinstance(loader, list):
            loader = loader[0]

        # Set random seed for reproducible chunk selection
        torch.manual_seed(42)

        collected_batches = []
        samples_seen = 0

        try:
            for batch in loader:
                collected_batches.append(batch)

                # Batch size is read from the first entry of a dict batch when
                # that entry is a tensor; every other batch counts as 1 sample.
                current_size = 1

                if isinstance(batch, dict) and len(batch) > 0:
                    leading_value = batch[next(iter(batch.keys()))]

                    if torch.is_tensor(leading_value):
                        current_size = leading_value.size(0)

                samples_seen += current_size

                # The target is re-estimated from the current batch's size
                estimated_total = len(loader) * current_size
                target = int(self.val_chunk_size * estimated_total)

                if samples_seen >= target:
                    break

        except Exception as e:
            logger.warning(f"Error creating validation chunk: {e}")

            return

        if not collected_batches:
            logger.warning(
                "Failed to create validation chunk - no data collected"
            )

            return

        self.val_chunk_data = collected_batches
        logger.info(
            f"Created validation chunk with {samples_seen} samples from {len(collected_batches)} batches"
        )

    def _evaluate_validation_chunk(self):
        """Evaluate the validation chunk and log the loss.

        Runs the cached batches in ``val_chunk_data`` through the model
        without gradients and logs the sample-weighted mean loss as
        "val_chunk_loss". Batches that fail are skipped with a warning.
        Does nothing when no chunk has been created.
        """

        if self.val_chunk_data is None:
            return

        # Remember the current mode so it can be restored afterwards instead of
        # unconditionally forcing train mode.
        was_training = self.model.training
        self.model.eval()
        total_loss = 0.0
        total_samples = 0

        with torch.no_grad():
            for batch in self.val_chunk_data:
                # Move batch to device

                if isinstance(batch, dict):
                    batch = {
                        k: v.to(self.device) if torch.is_tensor(v) else v
                        for k, v in batch.items()
                    }
                else:
                    batch = batch.to(self.device)

                try:
                    labels = batch["labels"]
                    outputs = self(batch)

                    if self.num_classes == 0:  # regression
                        labels = labels.float()
                        # Ensure consistent shapes for regression

                        if outputs.dim() == 1:
                            outputs = outputs.view(-1, 1)

                        if labels.dim() == 1:
                            labels = labels.view(-1, 1)

                        # Ensure outputs have the same shape as labels

                        if outputs.shape != labels.shape:
                            outputs = outputs.view_as(labels)
                    else:  # classification
                        # Logits keep their class dimension; only the labels
                        # are cast to class indices.
                        labels = labels.long()

                    batch_loss = self.criterion(outputs, labels)

                    # Get batch size for weighted average
                    batch_size = labels.size(0)
                    total_loss += batch_loss.item() * batch_size
                    total_samples += batch_size

                except Exception as e:
                    # Best-effort evaluation: one bad batch must not stop training
                    logger.warning(
                        f"Error evaluating validation chunk batch: {e}"
                    )

                    continue

        self.model.train(was_training)

        if total_samples > 0:
            avg_loss = total_loss / total_samples
            self.log(
                "val_chunk_loss",
                avg_loss,
                on_step=True,
                on_epoch=False,
                prog_bar=True,
                batch_size=total_samples,
            )

    def on_train_start(self):
        """Forward the train-start hook to the parent classes when they define it."""
        parent = super()

        if hasattr(parent, "on_train_start"):
            parent.on_train_start()
        # Validation chunk creation is currently disabled:
        # self._create_validation_chunk()

    def on_train_epoch_start(self):
        """Forward the hook to the parent classes and reset per-epoch training state."""
        parent = super()

        if hasattr(parent, "on_train_epoch_start"):
            parent.on_train_epoch_start()

        # Fresh loss accumulator for the new epoch
        self.avg_train_loss = 0
        # Step counter restarts so interval-based logic behaves the same every epoch
        self.training_step_count = 0

    def on_train_epoch_end(self):
        """Log the epoch's accumulated training loss under "avg_train_loss".

        NOTE(review): ``training_step`` adds each step's loss to
        ``avg_train_loss`` and nothing divides it by the step count, so the
        logged value looks like a sum rather than a mean - confirm intent.
        """
        accumulated_loss = self.avg_train_loss
        self.log("avg_train_loss", accumulated_loss)

    def predict_sequences(
        self, sequences: List[str], params: Dict[str, Any] | None = None
    ) -> Dict[str, Any]:
        """
        Featurize all sequences in one pass and run batched inference on them.

        Args:
            sequences: List of protein sequence strings to predict on
            params: Optional settings; a "batch_size" entry overwrites
                ``self.batch_size`` (and stays set afterwards)

        Returns:
            Dictionary containing:
                - predictions: numpy array of predictions
                - probabilities: numpy array of probabilities (for classification tasks)
        """
        self.model.eval()  # inference only

        # A caller-supplied batch size replaces the stored one
        if params is not None and "batch_size" in params:
            self.batch_size = params["batch_size"]
            logger.debug(f"Setting batch size from params: {self.batch_size}")

        # Featurize everything upfront
        prepared = self.prepare_features(
            [{"sequence": seq} for seq in sequences]
        )

        # prepare_features may return one of several layouts
        # TODO: refactor, models are responsible for being consistent with their own format
        if not isinstance(prepared, dict):
            inputs = prepared
        elif "inputs" in prepared:
            # Nested layout (ESMC, ESM3 format)
            inputs = prepared["inputs"]
        else:
            # Flat layout (ESM2 format: {"input_ids": tensor, "attention_mask": tensor, "labels": tensor});
            # labels are not model inputs
            inputs = {
                name: value
                for name, value in prepared.items()
                if name != "labels"
            }

        # NumPy arrays and Python lists become tensors; other values stay as they are
        if isinstance(inputs, dict):
            for name, value in inputs.items():
                if isinstance(value, (np.ndarray, list)):
                    inputs[name] = torch.tensor(value)

        logger.info("Using Lightning Trainer for inference")

        return self._predict(inputs)

    def _predict(self, inputs: Dict[str, torch.Tensor]) -> Dict[str, Any]:
        """Use Lightning Trainer for prediction.

        Wraps the prepared tensors in a ``TensorDataset``, runs
        ``self._inference_trainer.predict`` over it in batches of
        ``self.batch_size`` and concatenates the per-batch outputs.

        Args:
            inputs: Either a dict of tensors (their values are passed to the
                dataset in dict order; the keys are not kept) or a single tensor.

        Returns:
            Dict with "predictions" (NumPy array of the concatenated model
            outputs) and, when ``num_classes > 0``, "probabilities" (softmax
            over the last dimension). If the trainer yields no batches, empty
            arrays are returned instead.
        """
        from torch.utils.data import DataLoader, TensorDataset

        # Create dataset from prepared features

        if isinstance(inputs, dict):
            # Handle dict inputs (multiple tensors)
            dataset = TensorDataset(*inputs.values())
        else:
            # Handle single tensor input
            dataset = TensorDataset(inputs)

        # Create dataloader (order must be kept so outputs line up with inputs)
        dataloader = DataLoader(
            dataset,
            batch_size=self.batch_size,
            shuffle=False,
        )

        # Use trainer.predict() - Lightning handles device placement automatically
        predictions_list = self._inference_trainer.predict(
            self, dataloaders=dataloader
        )

        logger.debug(
            f"Trainer predictions_list type: {type(predictions_list)}")

        if not predictions_list:
            # torch.cat cannot handle an empty list, so return empty results.
            # NOTE(review): the (0, width) shape assumes one column for
            # regression and num_classes columns for classification - confirm.
            logger.warning(
                "Trainer returned no predictions; returning empty results"
            )
            width = self.num_classes if self.num_classes > 0 else 1
            empty_np = torch.empty((0, width)).numpy()
            results = {"predictions": empty_np}

            if self.num_classes > 0:
                results["probabilities"] = empty_np.copy()

            return results

        logger.debug(
            f"Trainer predictions_list length: {len(predictions_list)}"
        )
        logger.debug(f"First batch type: {type(predictions_list[0])}")
        logger.debug(
            f"First batch shape: {predictions_list[0].shape if hasattr(predictions_list[0], 'shape') else 'no shape'}"
        )

        # Flatten the list of batches
        predictions_tensor = torch.cat(predictions_list, dim=0)
        predictions_np = predictions_tensor.cpu().numpy()

        logger.debug(f"Final predictions shape: {predictions_np.shape}")
        logger.debug(
            f"Final predictions sample: {predictions_np[:5] if len(predictions_np) > 0 else 'empty'}"
        )

        results = {"predictions": predictions_np}

        if self.num_classes > 0:
            probs = torch.softmax(predictions_tensor, dim=-1)
            results["probabilities"] = probs.cpu().numpy()

        return results

    def _create_tensor_dataset(
        self,
        features_dict: Dict[str, torch.Tensor],
        labels: torch.Tensor,
    ) -> Dataset:
        """
        Wrap pre-computed features and labels in a ``FeatureTensorDataset``.

        Args:
            features_dict: Dictionary of feature tensors (e.g., {"input_ids": tensor, "attention_mask": tensor})
            labels: Labels tensor

        Returns:
            The ``FeatureTensorDataset`` built from both arguments.
            NOTE(review): presumably each item looks like
            {"inputs": features, "labels": label} - confirm against
            ``haipr.models.utils.FeatureTensorDataset``.
        """
        from haipr.models.utils import FeatureTensorDataset

        labelled_dataset = FeatureTensorDataset(features_dict, labels)

        return labelled_dataset

    def _create_inference_dataset(
        self,
        features_dict: Dict[str, torch.Tensor],
    ) -> Dataset:
        """Wrap pre-computed features in a ``FeatureTensorDataset`` without labels.

        Args:
            features_dict: Dictionary of feature tensors

        Returns:
            The ``FeatureTensorDataset`` built with ``None`` as labels.
        """
        from haipr.models.utils import FeatureTensorDataset

        unlabelled_dataset = FeatureTensorDataset(features_dict, None)

        return unlabelled_dataset

    def _create_dataloaders(
        self,
        features_dict: Dict[str, torch.Tensor],
        labels: torch.Tensor,
        train_indices: np.ndarray,
        val_indices: np.ndarray,
        batch_size: int | None = None,
        shuffle_train: bool = True,
    ) -> Tuple[DataLoader, DataLoader]:
        """
        Create train and validation DataLoaders from pre-computed features.

        Both loaders share one dataset built by ``_create_tensor_dataset`` and
        select their samples through ``Subset``. Worker counts come from
        ``self.train_num_workers`` / ``self.val_num_workers``; pinned memory is
        always requested.

        Args:
            features_dict: Dictionary of feature tensors
            labels: Labels tensor
            train_indices: Training set indices
            val_indices: Validation set indices
            batch_size: Batch size (uses self.batch_size if None)
            shuffle_train: Whether to shuffle training data

        Returns:
            Tuple of (train_loader, val_loader)
        """

        if batch_size is None:
            batch_size = self.batch_size

        # The earlier lookup of num_workers/pin_memory on the trainer config
        # was removed: its results were never used by the loaders below.
        pin_memory = True

        # Create dataset from pre-computed features
        dataset = self._create_tensor_dataset(features_dict, labels)

        # Create DataLoaders
        train_loader = DataLoader(
            Subset(dataset, train_indices),
            batch_size=batch_size,
            shuffle=shuffle_train,
            num_workers=self.train_num_workers,
            pin_memory=pin_memory,
        )

        # Validation order is kept fixed
        val_loader = DataLoader(
            Subset(dataset, val_indices),
            batch_size=batch_size,
            shuffle=False,
            num_workers=self.val_num_workers,
            pin_memory=pin_memory,
        )

        logger.info(
            f"Created DataLoaders: train={len(train_indices)} samples, val={len(val_indices)} samples, train_num_workers={self.train_num_workers}, val_num_workers={self.val_num_workers}, pin_memory={pin_memory}"
        )

        return train_loader, val_loader

    def predict_step(self, batch, batch_idx):
        """Run one inference batch and return the model outputs.

        The batch is first normalized into the dict layout ``forward``
        receives:
            - dict: used as is
            - list/tuple of 2 tensors: "input_ids" and "attention_mask" (ESM2)
            - list/tuple of 1 tensor: "sequence_tokens" (ESMC)
            - any other list/tuple: its first element becomes "inputs"
            - anything else: the batch itself becomes "inputs"

        Returns:
            Model outputs, upcast from bfloat16 to float32 when needed and,
            for regression, expanded from 1-D to a column.
        """
        if isinstance(batch, dict):
            batch_dict = batch
        elif isinstance(batch, (list, tuple)):
            # TensorDataset yields plain tuples, so the dict is rebuilt by arity
            field_count = len(batch)

            if field_count == 2:  # (input_ids, attention_mask) for ESM2
                features = {
                    "input_ids": batch[0],
                    "attention_mask": batch[1],
                }
            elif field_count == 1:  # (sequence_tokens,) for ESMC
                features = {"sequence_tokens": batch[0]}
            else:
                # Fallback - assume first tensor is the main input
                features = batch[0]
            batch_dict = {"inputs": features}
        else:
            # Single tensor case
            batch_dict = {"inputs": batch}

        # Forward pass without gradient tracking
        with torch.no_grad():
            outputs = self.forward(batch_dict)

        # Convert BFloat16 to Float32 if needed
        if outputs.dtype == torch.bfloat16:
            outputs = outputs.float()

        # Regression outputs are always returned as a column
        is_regression = self.num_classes == 0

        if is_regression and outputs.dim() == 1:
            outputs = outputs.unsqueeze(1)

        return outputs

    def load_context(self, context):
        """Restore config, model weights and an inference Trainer from MLflow artifacts.

        Args:
            context: Object exposing an ``artifacts`` mapping. The "model"
                entry is required (path to a saved state dict); the "config"
                entry is optional and is loaded into ``self.cfg``.

        Raises:
            Exception: Any error raised while loading the weights or creating
                the Trainer is logged and re-raised unchanged.
        """
        logger.info(f"Loading context: {context}")

        artifacts = context.artifacts

        # Config is optional
        if "config" in artifacts:
            self.cfg = OmegaConf.load(artifacts["config"])
            logger.info("Loaded config from artifacts")

        try:
            # Weights go to CPU first; the Trainer places them on a device later.
            # NOTE(review): torch.load unpickles the file - only load trusted
            # artifacts, or consider weights_only=True if the torch version allows.
            state_dict = torch.load(artifacts["model"], map_location="cpu")
            self.model.load_state_dict(state_dict)
            self.model.eval()
            logger.info("Successfully loaded PyTorch model artifacts")
        except Exception as e:
            logger.error(f"Failed to load model artifacts: {e}")
            raise

        # Single-device Trainer used by _predict; no logging or checkpointing
        trainer_options = {
            "accelerator": "auto",
            "devices": 1,
            "logger": False,
            "enable_progress_bar": True,
            "enable_model_summary": False,
            "enable_checkpointing": False,
        }

        try:
            self._inference_trainer = pl.Trainer(**trainer_options)
            logger.info("Created Trainer for inference")
        except Exception as e:
            logger.error(f"Failed to create Trainer: {e}")
            raise

    def save_model(self, save_dir: str) -> str:
        """Persist the model for checkpoint restoration during training.

        Subclasses must override this and return the path of the saved model.

        Raises:
            NotImplementedError: Always, in this base implementation.
        """
        message = "save_model needs to be implemented in the subclass"

        raise NotImplementedError(message)
