import hashlib
import json
import logging
import os
import numpy as np
import pandas as pd
import protenc
import torch
import ast
import time
import hydra
from omegaconf import DictConfig
from sklearn.model_selection import KFold, StratifiedKFold
from torch.utils.data import Dataset
from tqdm import tqdm
from typing import List
import pickle
import h5py
import lmdb

from haipr.utils.data_utils import DMS_file_for_LLM
from haipr.utils import AA_ALPHABET_WITH_EXTRAS
from haipr.utils.resolvers import register_resolvers
from haipr.embedding_manager import EmbeddingManager

# A logger for this file
logger = logging.getLogger(__name__)
# logger.setLevel(logging.DEBUG)


class HAIPRData(Dataset):
    """
    PyTorch Dataset for PPA (Protein Prediction and Analysis) data.

    Supports two main feature types:
    1. 'embedded': Uses an external embedder (e.g., protenc) to pre-compute embeddings.
    2. 'sequence': Provides raw sequences for the model to tokenize/embed.

    :param data: The input data as a pandas DataFrame.
    :param config: Hydra DictConfig
    """

    def __init__(
        self,
        config: DictConfig,  # OmegaConf DictConfig
        data: pd.DataFrame | None = None,
    ):
        """Initialize the HAIPRData dataset.

        Loads the table (from ``config.benchmark.data`` when ``data`` is not
        supplied), derives the ``sequence`` column through ``DMS_file_for_LLM``,
        validates the input, initializes bookkeeping state and finally runs
        :meth:`setup`.

        :param config: Hydra/OmegaConf configuration. This method reads
            ``config.benchmark.data`` (only when ``data`` is None),
            ``config.benchmark.pdb`` (optional), ``config.data.separator_token``,
            ``config.data.label_column``, ``config.data.num_splits`` and the
            optional ``config.data.focus`` / ``config.data.cache_format``.
        :param data: Optional pre-loaded DataFrame. NOTE: it is stored without
            copying, so the ``sequence`` column (and ``orig_idx`` /
            ``sample_id`` when absent) are added to the caller's object too.

        TODO: minimal requirements to instantiate the class with and without config
        """
        if data is None:
            self.data: pd.DataFrame = pd.read_csv(config.benchmark.data)
        else:
            self.data = data

        # Process sequence data using DMS_file_for_LLM for consistency
        # This handles both single-string and multi-chain formats
        self.focus = getattr(config.data, "focus", False)

        # Use DMS_file_for_LLM to process sequences
        processed_data, self.focus_chains = DMS_file_for_LLM(
            self.data,
            focus=self.focus,
            return_focus_chains=True,
            sep=config.data.separator_token,
        )
        self.variable_chains = self.focus_chains  # backward compatibility

        # Set the sequence column to the processed mutated_sequence
        self.data["sequence"] = processed_data["mutated_sequence"]

        # NOTE: validation runs after the "sequence" column has been added and
        # checks the raw benchmark sequence column, not the processed one.
        self._validate_input_data(self.data, config)

        # essential instance variables
        # Keep a lightweight snapshot of initial sequences for cache keys
        self._init_sequences = self.data["sequence"].tolist()
        self.config = config
        self.label_col = config.data.label_column
        self.num_splits = config.data.num_splits
        self.cache_key = None
        self.features_loaded = False

        # Optional structure file; None means sequence-only processing
        if hasattr(config.benchmark, "pdb") and config.benchmark.pdb:
            self.pdb = config.benchmark.pdb
            logger.debug(
                f"PDB path is set to {config.benchmark.pdb}, using to embed sequences with protenc"
            )
        else:
            logger.debug("No PDB Supplied Using Sequence Only.")
            self.pdb = None

        # Set up cache directory
        self._setup_cache_dir()

        # Set cache format (default to hdf5 for better performance)
        self.cache_format = getattr(config.data, "cache_format", "hdf5")

        # Add original indices if needed
        if "orig_idx" not in self.data.columns:
            self.data["orig_idx"] = np.arange(len(self.data))
        if "sample_id" not in self.data.columns:
            self.data["sample_id"] = np.arange(len(self.data))
        # Initialize state
        self.features = None
        self.embedder_instance = None
        # Initialize as empty list instead of None
        self.feature_cols: list[str] = []

        # Initialize embedding manager
        self.embedding_manager = EmbeddingManager(config)

        # Track active view of rows (training pool). Start with all rows.
        # Must exist before setup(): helpers called from there use len(self).
        self.active_idx = np.arange(len(self.data))
        # Track absolute test indices (w.r.t. self.data); None until set
        self.test_idx_abs: np.ndarray | None = None

        # Set up the dataset
        self.setup()
        # Data may have been filtered/reset in setup; refresh active view
        self.active_idx = np.arange(len(self.data))

    def __len__(self):
        """
        Get the length of the dataset.

        :return: The number of samples in the dataset.
        """

        return len(self.active_idx)

    def __getitem__(self, idx):
        """
        Get an item from the dataset.

        Args:
            idx: Integer index or slice/list/array of indices

        Returns:
            If feature_type is 'embedded': (features, labels)
            If feature_type is 'sequence': {"sequence": sequence_data, "labels": label_data, "sample_id": sample_id}
            If feature_type is 'multi_chain_sequence': {"sequence": sequence_data, "labels": label_data, "sample_id": sample_id}
        """

        if self.config.model.feature_type == "embedded":
            # Assumes self.feature_cols is populated with embedding column names
            if not self.feature_cols:
                raise RuntimeError(
                    "Feature columns not set for 'embedded' type. Call data.prepare_features() first."
                )

            # Map dataset indices robustly
            # - Slice: map via active view
            # - Array-like (ndim>=1): assume absolute indices (splits store absolute)
            # - Scalar: map via active view
            if isinstance(idx, slice):
                mapped_idx = self.active_idx[idx]
            else:
                idx_arr = np.asarray(idx)
                if idx_arr.ndim == 0:
                    mapped_idx = self.active_idx[int(idx_arr)]
                else:
                    mapped_idx = idx_arr.astype(int)

            labels_series = self.data[self.label_col].iloc[mapped_idx]
            if np.isscalar(labels_series) or getattr(labels_series, "shape", ()) == ():
                labels = labels_series
            else:
                # pandas Series -> list
                labels = labels_series.values.tolist()

            features = self.data[self.feature_cols].iloc[mapped_idx].to_numpy()

            return torch.tensor(features, dtype=torch.float32), torch.tensor(
                labels,
                dtype=(
                    torch.float32 if self.config.model.num_classes == 0 else torch.long
                ),
            )

        if self.config.model.feature_type == "sequence":
            if isinstance(idx, (slice, list, np.ndarray)):
                raise NotImplementedError(
                    "Slicing/batch indexing directly in __getitem__ for 'sequence' type returning dict is not standard for collate_fn. DataLoader fetches items individually."
                )
            else:  # Single index single sequence
                # For sequence features, idx should be used directly since active_idx is just sequential
                # When Subset is used, it passes the correct indices directly
                sequence_data = self.data["sequence"].iloc[idx]
                label_data = self.data[self.label_col].iloc[idx]
                result = {
                    "sequence": sequence_data,
                    "labels": torch.tensor(
                        label_data,
                        dtype=(
                            torch.float32
                            if self.config.model.num_classes == 0
                            else torch.long
                        ),
                    ),
                }
                # Add sample ID for DDP compatibility
                if "sample_id" in self.data.columns:
                    result["sample_id"] = self.data["sample_id"].iloc[idx]
                return result
        elif self.config.model.feature_type == "multi_chain_sequence":
            if not self.data["mutated_sequence"].iloc[0].startswith("{"):
                raise ValueError(
                    "Mutated sequence is not a multi-chain sequence format"
                )
            else:
                # For multi-chain sequence features, idx should be used directly
                sequence_data = self.data["mutated_sequence"].iloc[idx]
                label_data = self.data[self.label_col].iloc[idx]
                result = {
                    "sequence": sequence_data,
                    "labels": torch.tensor(
                        label_data,
                        dtype=(
                            torch.float32
                            if self.config.model.num_classes == 0
                            else torch.long
                        ),
                    ),
                }
                # Add sample ID for DDP compatibility
                if "sample_id" in self.data.columns:
                    result["sample_id"] = self.data["sample_id"].iloc[idx]
                return result
        else:
            raise ValueError(
                f"Unsupported feature_type in __getitem__: {self.config.model.feature_type}"
            )

    def get_sequences(self):
        """
        Get the sequences for the dataset.
        """
        return self.data["sequence"].to_list()

    def get_labels(self):
        """
        Get the labels for the dataset.
        """

        return self.data[self.label_col].to_numpy()

    def normalize_labels(self):
        """
        Normalize labels to be between 0 and 1.
        """
        logger.debug("Normalizing labels between 0 and 1")
        logger.debug(
            f"Initial range: {self.data[self.label_col].min()} to {self.data[self.label_col].max()}"
        )
        self.data[self.label_col] = (
            self.data[self.label_col] - self.data[self.label_col].min()
        ) / (self.data[self.label_col].max() - self.data[self.label_col].min())
        logger.debug(
            f"Normalized range: {self.data[self.label_col].min()} to {self.data[self.label_col].max()}"
        )

    def setup(self):
        """
        Set up the HAIPRData object.
        """
        logger.info("Setting up HAIPRData")
        logger.debug(f"Data: \n{self.data.head()}")
        logger.debug(f"Label column: {self.label_col}")
        logger.debug(
            f"Sequence column: {self.config.benchmark.sequence_column}")

        if self.config.data.normalize_labels:
            self.normalize_labels()

        if hasattr(self.config.benchmark, "binary") and self.config.benchmark.binary:
            # check if label column is binary
            if self.data[self.label_col].nunique() == 2:
                logger.info("Label column is binary")
                # nothing more to do
            else:
                raise ValueError("Label column is not binary")

        # Process labels first - Check task type and handle discretization
        elif hasattr(self.config, "task"):
            if self.config.task == "classification":
                if (
                    not hasattr(self.config, "num_classes")
                    or self.config.num_classes <= 0
                ):
                    logger.warning(
                        "num_classes not specified for classification task, defaulting to 2"
                    )
                    self.config.num_classes = 2
                logger.info(
                    f"Discretizing labels into {self.config.num_classes} classes for classification"
                )
                self.discretize_labels(self.config.num_classes)
            elif self.config.task == "regression":
                if hasattr(self.config, "num_classes") and self.config.num_classes > 0:
                    logger.warning(
                        "num_classes specified for regression task, ignoring"
                    )
                    self.config.num_classes = 0

        self.remove_nan_labels()

        # Remove constant labels if configured
        if self.config.data.remove_constant_labels:
            self.remove_constant_labels()

        # Log sequence column setup
        logger.info(
            f"Sequence columns initialized: sequence ({len(self.data['sequence'][0])} chars)"
        )

        # cut top percent of data by label for test data
        if self.config.data.cut_top is not None:
            self._analyze_sequence_variants()  # populate self.seq_list
            # likely already cached, would recompute otherwise since sequences hash changes
            self.prepare_features()
            self.cut_top_percent_by_label(self.config.data.cut_top)
        else:
            self._analyze_sequence_variants()
            self._setup_pos_wise_alphabet()

    def get_test_data(self):
        """
        Get the test data.
        """
        if self.test_idx_abs is None:
            raise ValueError("Test data not set. Call set_test_data() first.")
        return self.data.iloc[self.test_idx_abs]

    def get_train_data(self):
        """
        Get the current training pool view of the data.
        """
        return self.data.iloc[self.active_idx]

    def set_test_data(self, split_method: str, split_idx: int):
        """
        Set the test data. This will generate splits if they are not already generated.
        """
        logger.info(
            f"Setting test data for split method: {split_method} and split index: {split_idx}"
        )

        if self.test_idx_abs is not None:
            logger.warning(
                "Test data already set, resetting selection and regenerating splits"
            )
            self.test_idx_abs = None

        # Always regenerate splits on the current active view to avoid stale indices
        self.generate_splits(split_method)
        # splits store absolute indices
        self.test_idx_abs = np.array(self.splits[split_idx][1])
        # Update active_idx to exclude absolute test indices (training pool view)
        self.active_idx = self.active_idx[~np.isin(
            self.active_idx, self.test_idx_abs)]
        logger.info(
            f"Test size: {len(self.test_idx_abs)} | Train size: {len(self.active_idx)}"
        )

    def subsample_data(self, threshold=0):
        """
        Subsample the data.
        """
        if threshold > 0 and len(self.active_idx) > threshold:
            logger.info(
                f"Subsampling active training pool to {threshold} samples")
            rng = np.random.RandomState(self.config.seed)
            selected = rng.choice(
                self.active_idx, size=threshold, replace=False)
            # Keep stable ascending order for deterministic batches
            self.active_idx = np.sort(selected)
            logger.info(f"Remaining training pool: {len(self.active_idx)}")
        else:
            logger.info("No subsampling needed or threshold not exceeded")
        self.is_subsampled = True

    def _setup_cache_dir(self):
        """Set up the cache directory."""
        if (
            hasattr(self.config.data, "cache_dir")
            and self.config.data.cache_dir is not None
        ):
            self.cache_dir = self.config.data.cache_dir
        else:  # default cache dir
            self.cache_dir = os.path.join(
                os.path.expanduser("~"), ".cache", "haipr", "embeddings"
            )
        os.makedirs(self.cache_dir, exist_ok=True)
        logger.info(f"Using cache directory: {self.cache_dir}")

    def _setup_pos_wise_alphabet(self, full_alphabet: bool = False):
        """
        Set up the position-wise alphabet for the sequences.

        :param full_alphabet: If True, use the full amino acid alphabet for all positions.
                              If False, use only the observed amino acids at each position.
        """
        alphabet = [str(aa) for aa in "ACDEFGHIKLMNPQRSTVWY"]

        if full_alphabet:
            self.pos_wise_alphabet = [alphabet for _ in self.seq_df.columns]
        else:
            self.pos_wise_alphabet = [
                # NOTE: if non homogenous seq length,
                # convert NoneType to str for sorted()
                sorted(self.seq_df[col].astype(str).unique())
                for col in self.seq_df.columns
            ]
        logger.debug(f"Pos-wise alphabet: {self.pos_wise_alphabet}")

    def _analyze_sequence_variants(self):
        """Analyze sequence variants to compute key attributes.

        Computes:
        - Sequence dataframe representation
        - Amino acid frequencies per position
        - Mutated positions
        - Consensus sequence
        - Representative sequence
        - Distance metrics
        """
        logger.info("Analyzing sequence variants")
        logger.debug(f"data: \n{self.data.shape}")

        # Get full sequences for analysis (we need all positions for mutation analysis)
        self.seqs_list = self.data["sequence"].to_numpy()

        # Create sequence dataframe from processed sequences
        self.seq_df = pd.DataFrame([list(seq) for seq in self.seqs_list])
        logger.debug(f"Seq df: \n{self.seq_df.shape}")

        # Recalculate frequencies and consensus
        self.sfreqs = self.seq_df.apply(
            pd.Series.value_counts).fillna(0).astype(int)
        self.aa_per_pos = self.sfreqs.apply(
            lambda x: list(zip(x.index[x > 0], x[x > 0])), axis=0
        )
        logger.debug(f"Amino acids per position: \n{self.aa_per_pos}")
        # get positions with more than one amino acid
        self.mut_pos = self.aa_per_pos.loc[self.aa_per_pos.apply(len) > 1]
        logger.debug(
            f"Mutated positions: \n{self.aa_per_pos[self.mut_pos.index]}")
        self.consensus = self.sfreqs.idxmax(axis=0)

        # TODO: add wild_type to config
        if (
            hasattr(self.config.benchmark, "wild_type")
            and self.config.benchmark.wild_type
        ):
            self.wild_type = self.config.benchmark.wild_type
        else:
            self.wild_type = self.consensus

        # raise if no mutations
        if self.mut_pos.shape[0] == 0:
            raise ValueError("No mutations found in the dataset")

        # sequence-based attributes
        self._compute_representative_sequence()
        self._compute_dist_to_consensus()
        self._set_min_max_mutations()
        self.num_mut_sites = self.mut_pos.shape[0]

    def _set_min_max_mutations(self):
        """
        Set the minimum and maximum number of mutations in the dataset.
        """
        self.max_mutations = self.dist_consensus.max()
        self.min_mutations = self.dist_consensus.min()

    def _compute_dist_to_consensus(self):
        """
        Compute the distance of each sequence to the consensus sequence.
        """
        logger.info("Compute distannce to consesus")
        self.dist_consensus = (self.seq_df != self.consensus).sum(axis=1)

    def _compute_dist_to_best(self):
        logger.info("Compute distance to best")
        best_label_idx = np.argmax(self.data[self.label_col])
        seq_best = self.seq_df.iloc[best_label_idx]
        self.dist_best = (self.seq_df != seq_best).sum(axis=1)

    def _compute_representative_sequence(self):
        """
        Compute the representative sequence for the dataset.

        The representative sequence is the one that is closest to the consensus sequence.
        """
        self.rep_id = (self.seq_df == self.consensus).sum(axis=1).idxmax()
        self.representative = self.seqs_list[self.rep_id]

    def _validate_sequences(self):
        """Validate sequences before feature preparation."""
        if not hasattr(self, "data") or "sequence" not in self.data.columns:
            raise ValueError("Sequences not found. Run setup() first.")

        unique_sequences = self.data["sequence"].nunique()
        if unique_sequences == 1:
            raise ValueError(
                "Dataset contains only a single unique sequence variant")

        if unique_sequences < len(self.data):
            logger.warning(
                f"Found {len(self.data) - unique_sequences} duplicate sequences after processing"
            )

    def _validate_input_data(self, data: pd.DataFrame, config: DictConfig) -> None:
        """Validate input data and configuration."""
        if data.empty:
            raise ValueError("Input data cannot be empty")

        if config.benchmark.sequence_column not in data.columns:
            raise ValueError(
                f"Sequence column '{config.benchmark.sequence_column}' not found in data"
            )

        if config.data.label_column not in data.columns:
            raise ValueError(
                f"Label column '{config.data.label_column}' not found in data"
            )

        # Check for duplicate sequences
        if data[config.benchmark.sequence_column].duplicated().any():
            logger.warning(
                f"Found {data[config.benchmark.sequence_column].duplicated().sum()} duplicate sequences"
            )

        # Check for invalid amino acids in sequences
        if not data[config.benchmark.sequence_column][0].startswith("{"):
            # Only validate single string sequences, not multi-chain dict format
            invalid_aa_mask = data[config.benchmark.sequence_column].apply(
                lambda x: [aa for aa in x.strip(
                ) if aa not in AA_ALPHABET_WITH_EXTRAS]
            )
            if invalid_aa_mask.apply(len).any():
                logger.error(
                    f"Sequence column {config.benchmark.sequence_column} contains invalid amino acids: {invalid_aa_mask}"
                )
                raise ValueError(
                    f"Sequence column {config.benchmark.sequence_column} contains invalid amino acids"
                )

    def discretize_labels(self, num_classes):
        """
        Discretize labels into a specified number of classes.

        For classification tasks, converts continuous labels into discrete integer classes
        from 0 to num_classes-1. This ensures compatibility with sklearn classifiers.

        :param num_classes: The number of classes to discretize the labels into.
        :param set_labels: If True, set the discretized labels as the new labels for the dataset.
        :return: The discretized labels as integers.
        """
        logger.info(f"Discretizing labels into {num_classes} classes")
        logger.debug(
            f"Labels before discretization: {len(self.data[self.label_col].unique())}"
        )
        logger.debug(f"Label column: {self.label_col}")

        if isinstance(self.label_col, list) and len(self.label_col) > 1:
            raise ValueError(
                "Discretization is not supported for multiple label columns"
            )

        # Handle NaN values before discretization
        self.remove_nan_labels()

        # Use qcut to create discrete bins, but convert to integers
        # qcut returns categorical labels 0 to num_classes-1
        labels = pd.qcut(
            self.data[self.label_col],
            q=num_classes,
            labels=False,  # This ensures we get 0-based integer labels
            duplicates="drop",
        )  # Handle duplicate bin edges gracefully

        # Convert to integer type explicitly
        labels = labels.astype(int)

        logger.debug(
            f"Labels after discretization: {len(labels.unique())} unique integer classes"
        )
        logger.debug(
            f"Class distribution: {pd.Series(labels).value_counts().sort_index()}"
        )

        # Always update the data with discretized labels for classification tasks
        self.data["label_class"] = labels
        self.label_col = "label_class"

        return labels

    def remove_nan_labels(self):
        """
        Remove rows with NaN labels
        """
        logger.info("Removing rows with NaN labels")

        self.data.dropna(subset=[self.label_col], inplace=True)
        # Reset index of self.data but keep orig_idx intact
        self.data = self.data.reset_index(drop=True)

        logger.debug(f"Remaining Samples {len(self)}")

    def remove_constant_labels(self):
        """
        Remove samples with constant labels from the dataset.
        """
        logger.info("Removing constant labels")

        min_mask = self.data[self.label_col] == self.data[self.label_col].min()
        max_mask = self.data[self.label_col] == self.data[self.label_col].max()

        if min_mask.sum() > 1:
            logger.debug(
                f"Removing {min_mask.sum()} samples with minimum label")
            self.data = self.data[~min_mask]

        if max_mask.sum() > 1:
            logger.debug(
                f"Removing {max_mask.sum()} samples with maximum label")
            self.data = self.data[~max_mask]

        # Reset index to maintain consistency
        self.data = self.data.reset_index(drop=True)
        logger.debug(f"Remaining Samples {len(self)}")

    def fill_missing_labels(self, minimum=True):
        """
        Fill missing labels with either the minimum or maximum value.

        :param minimum: If True, fill missing labels with the minimum value. If False, use the maximum value.
        """
        logger.debug(
            f"Filling missing labels with min/max value {self.data[self.label_col].max()}"
        )

        if minimum:
            self.data[self.label_col] = self.data[self.label_col].fill(
                self.data[self.label_col].min()
            )
        else:
            self.data[self.label_col] = self.data[self.label_col].fill(
                self.data[self.label_col].max()
            )

    def cut_top_percent_by_label(self, percent=0.1, save_file=False):
        """
        Cut the top percent of the data by label, by sorting the data and removing the top percent
        """
        # sort data by label
        sorted_data = self.data.iloc[self.data[self.label_col].argsort()]
        # keep the first (1-percent) of the data
        self.data = sorted_data.iloc[
            : int(len(sorted_data) * (1 - percent))
        ].reset_index(drop=True)
        # write the top <percent> cutted data to file for test data
        if save_file:
            cut_data = sorted_data.iloc[int(len(sorted_data) * (1 - percent)):]
            cut_data.to_csv(
                f"{os.path.join(self.config.benchmark.data.split('/')[:-1])}/{self.config.benchmark.name}_cut_top_{percent}.csv",
                index=False,
            )
        logger.debug(f"Remaining Samples {len(self.data)}")

    def _initialize_one_hot_encoding(self, sparse=False):
        """Initialize one-hot encoding for sequences.

        :param sparse: Whether to use sparse encoding (not implemented)
        """

        if not hasattr(self, "pos_wise_alphabet"):
            self._setup_pos_wise_alphabet()
        self.one_hot_enc = pd.get_dummies(self.seq_df)

    def get_num_splits(self):
        """
        Get the number of splits in the dataset.

        :return: The number of splits.
        :raises ValueError: If splits have not been generated.
        """

        if not hasattr(self, "splits"):
            raise ValueError("Splits not generated")

        return len(self.splits)

    def generate_splits(self, split_method: str | None = None):
        """Generate data splits based on configuration."""
        splits = []
        logger.info("Generating splits")
        split_method = (
            self.config.data.split_method if split_method is None else split_method
        )
        logger.info(f"Generating splits using method: {split_method}")

        # If CV, generate splits over the active training view ONLY
        if split_method == "cv":
            labels_active = self.data[self.label_col].iloc[self.active_idx].to_numpy(
            )
            n_active = len(self.active_idx)

            # Choose splitter by label characteristics on the active view
            if (
                isinstance(labels_active[0], (int, float))
                and len(np.unique(labels_active)) > 10
            ):
                splitter = KFold(
                    n_splits=self.num_splits,
                    shuffle=True,
                    random_state=self.config.seed,
                )
                for train_v, val_v in splitter.split(np.arange(n_active)):
                    splits.append(
                        (self.active_idx[train_v], self.active_idx[val_v]))
            else:
                splitter = StratifiedKFold(
                    n_splits=self.num_splits,
                    shuffle=True,
                    random_state=self.config.seed,
                )
                for train_v, val_v in splitter.split(
                    np.arange(n_active), labels_active
                ):
                    splits.append(
                        (self.active_idx[train_v], self.active_idx[val_v]))

            self.splits = splits
            logger.info(f"Generated {len(splits)} splits")
            return splits

        # For other methods, use existing generators, then restrict to active view below
        if split_method in self.data.columns:
            logger.info("Data already has a split column, using it")
            split_generator = self._group_splits(split_method)
        elif split_method == "lomo":
            split_generator = self._lomo_splits()
        elif split_method == "lomo_hard":
            split_generator = self._lomo_hard_splits()
        elif split_method == "ood":
            split_generator = self._ood_splits()
        elif split_method == "skewed":
            split_generator = self._skewed_splits()
        elif split_method == "contig":
            split_generator = self._contig_splits()
        elif split_method == "modulo":
            split_generator = self._modulo_splits()
        else:
            raise ValueError(f"Invalid split method: {split_method}")

        for split_idx, (train, test) in enumerate(split_generator):
            logger.debug(f"Train: {len(train)}, Test: {len(test)}")
            # if train or test is less than 0.08 * len(self.data)
            total_len = len(train) + len(test)
            if len(train) < 0.08 * total_len:
                logger.warning(
                    f"Skipping split {split_idx} due to small train set: Train: {len(train)}, Test: {len(test)}"
                )

                continue
            elif len(test) < 0.08 * total_len:
                logger.warning(
                    f"Skipping split {split_idx} due to small test set: Train: {len(train)}, Test: {len(test)}"
                )

                continue
            # Convert to absolute indices if needed and intersect with active view
            train_abs = np.array(train)
            test_abs = np.array(test)
            logger.debug(
                f"Split {split_idx}: train_abs={train_abs}, test_abs={test_abs}")
            logger.debug(
                f"active_idx={self.active_idx}, len(active_idx)={len(self.active_idx)}")

            # If splitter produced indices over the active view (0..len(active_idx)-1),
            # map to absolute via active_idx. Otherwise treat as absolute already.
            if np.array_equal(
                np.sort(np.unique(np.concatenate([train_abs, test_abs]))),
                np.arange(len(self.active_idx)),
            ):
                train_abs = self.active_idx[train_abs]
                test_abs = self.active_idx[test_abs]

            # Ensure no overlap and both are within active view
            train_abs = train_abs[np.isin(train_abs, self.active_idx)]
            test_abs = test_abs[np.isin(test_abs, self.active_idx)]

            splits.append((train_abs, test_abs))

        if not splits:
            raise ValueError(
                f"No valid splits generated using method: {self.config.data.split_method}"
            )

        self.splits = splits
        logger.info(f"Generated {len(splits)} splits")
        return splits

    def get_single_split(self, split_method="cv", split_idx=0):
        """
        Get a single split from the generated splits.

        :param split_method: The method used to generate the splits.
        :param split_idx: The index of the split to retrieve.
        :return: The requested split.
        """

        if not hasattr(self, "splits"):
            self.generate_splits()

        return self.splits[split_idx]

    def _ood_splits(self):
        """
        Generate out-of-distribution splits based on label bins.
        """
        # Use active_idx instead of full dataset
        labels_active = self.data[self.label_col].iloc[self.active_idx]
        bins, _ = pd.qcut(
            labels_active,
            q=self.num_splits,
            labels=False,
            retbins=True,
            duplicates="drop",
        )
        valid_indices = self.active_idx

        for b in np.unique(bins):
            bin_mask = bins == b
            # Use boolean indexing with valid_indices
            train = valid_indices[~bin_mask]
            test = valid_indices[bin_mask]
            yield train, test

    def _skewed_splits(self, top_percent=0.4, sample_from_top_percent=0.5):
        """
        Generate skewed splits based on label distribution.
        Creates num_splits different train/test splits where test sets are sampled from the top percentile.

        :param top_percent: The percentage of top labels to consider for skewed sampling.
        :param sample_from_top_percent: The percentage of samples to draw from the top labels.
        :yield: Tuples of (train_indices, test_indices) for each split.
        """

        if not hasattr(self, "label_col"):
            raise ValueError("Labels are not set")

        # Use active_idx instead of full dataset
        labels_active = self.data[self.label_col].iloc[self.active_idx]
        sorted_indices = np.argsort(labels_active).values
        # Map back to absolute indices
        sorted_indices = self.active_idx[sorted_indices]
        n_samples = len(sorted_indices)

        # Calculate split points
        split_point = int(n_samples * (1 - top_percent))

        # Split indices into train and potential test
        train_indices = np.array(sorted_indices[:split_point])
        test_pool_indices = np.array(sorted_indices[split_point:])

        logger.debug(f"Total samples: {n_samples}")
        logger.debug(f"Train size: {len(train_indices)}")
        logger.debug(f"Test pool size: {len(test_pool_indices)}")

        # Generate multiple splits
        rng = np.random.RandomState(self.config.seed)

        for _ in range(self.num_splits):
            # Sample from test pool
            n_test = int(len(test_pool_indices) * sample_from_top_percent)
            test_mask = rng.choice(
                len(test_pool_indices), n_test, replace=False)

            # Get test indices
            test_indices = np.array(test_pool_indices[test_mask])

            # Add remaining test pool indices to train
            remaining_test = test_pool_indices[
                ~np.isin(np.arange(len(test_pool_indices)), test_mask)
            ]
            current_train = np.concatenate([train_indices, remaining_test])

            yield current_train, test_indices

    def _group_splits(self, split_by: str):
        """
        Generate splits based on a grouping column in the data.

        :param split_by: The name of the column to use for grouping.
        :yield: Tuples of (train_indices, test_indices) for each split.
        """
        logger.info(f"Splitting based on column {split_by}")

        for group, indices in self.data.groupby(split_by).groups.items():
            train = self.data.index[~self.data.index.isin(indices)]
            test = self.data.index[indices]
            yield train, test

    def _cv_splits(self):
        """
        Generate cross-validation splits.

        Uses KFold when the labels are numeric with more than 10 distinct
        values (treated as regression) and StratifiedKFold otherwise. The
        number of splits is ``self.num_splits`` and shuffling is seeded with
        ``self.config.seed``.

        :yield: Tuples of (train_indices, test_indices) for each split, as
            positions into ``self.data``.
        :raises ValueError: If the dataset has no rows.
        """
        labels = self.data[self.label_col].to_numpy()
        if len(labels) == 0:
            raise ValueError("Cannot generate CV splits for an empty dataset")

        # NumPy integer scalars (e.g. np.int64) are not instances of Python's
        # int, so they have to be listed explicitly to be seen as numeric.
        numeric_scalar = (int, float, np.integer, np.floating)
        is_continuous = (
            isinstance(labels[0], numeric_scalar) and len(np.unique(labels)) > 10
        )

        if is_continuous:
            # For regression or continuous labels
            splitter = KFold(
                n_splits=self.num_splits, shuffle=True, random_state=self.config.seed
            )
            logger.info(
                f"Using KFold with {self.num_splits} splits for regression task"
            )
        else:
            # For classification or discrete labels
            splitter = StratifiedKFold(
                n_splits=self.num_splits, shuffle=True, random_state=self.config.seed
            )
            logger.info(
                f"Using StratifiedKFold with {self.num_splits} splits for classification task"
            )

        yield from splitter.split(self.data.index, labels)

    def _lomo_splits(self):
        """
        Generate leave-one-mutation-out splits.
        """
        logger.debug(f"Mutated positions: \n{self.mut_pos.index}")
        # Use active_idx instead of full dataset
        valid_indices = self.active_idx

        for pos in self.mut_pos.index:
            if pos not in self.consensus.index or pos not in self.seq_df.columns:
                logger.warning(f"Position {pos} not found in data, skipping")

                continue

            for aa_at_pos, _ in self.aa_per_pos[pos]:
                logger.debug(f"Position {pos} - Amino acid: {aa_at_pos}")
                pos_mask = self.seq_df[pos] == aa_at_pos
                train = valid_indices[pos_mask]
                test = valid_indices[~pos_mask]

                yield train, test

    def _lomo_hard_splits(self):
        """
        Generates classical lomo splits but only selectes the num_splits most difficult splits
        by comparing the train and test distribution by mean
        """
        splits = []
        for train, test in self._lomo_splits():
            splits.append((train, test))

        # sort splits by the difference in mean label distribution between train and test
        splits.sort(
            key=lambda x: np.abs(
                np.mean(self.data.loc[x[0]][self.label_col])
                - np.mean(self.data.loc[x[1]][self.label_col])
            )
        )

        for train, test in splits[: self.num_splits]:
            yield train, test

    def _modulo_splits(self):
        """
        Generate modulo position-based splits over single mutants.

        The benchmark CSV is re-read and passed through ``DMS_file_for_LLM``;
        only rows whose ``mutant`` string contains a single mutation are kept.
        A row mutated at position ``p`` is placed in the test set of fold
        ``p % self.num_splits`` and in the train set of every other fold. An
        empty mutant string is treated as position 0.

        NOTE(review): the yielded indices are row numbers of the filtered,
        re-indexed single-mutant frame built here, not of ``self.data`` —
        confirm that callers expect this.

        :yield: Tuples of (train_indices, test_indices) as Python lists.
        """
        logger.info("Generating modulo position-based splits")
        folds = self.num_splits
        singles = DMS_file_for_LLM(pd.read_csv(
            self.config.benchmark.data), focus=True)
        singles = singles.loc[
            singles["mutant"].apply(lambda x: len(x.split(":")) < 2)
        ].reset_index(drop=True)
        logger.debug(f"Single-mutant rows used for modulo splits:\n{singles}")

        # Map each mutated position to the rows that carry a mutation there.
        # Mutation tokens look like <wt><position><mut>, hence the [1:-1] slice.
        rows_by_pos = {}
        for row, mutant in zip(singles.index, singles["mutant"]):
            for token in mutant.split(":"):
                pos = 0 if token == "" else int(token[1:-1])
                rows_by_pos.setdefault(pos, []).append(row)
        logger.debug(f"Mutated positions: {list(rows_by_pos)}")

        split = [[[], []] for _ in range(folds)]
        assigned = set()
        for pos, rows in rows_by_pos.items():
            at_fold = pos % folds
            # Only rows not already placed by an earlier position are added.
            fresh = list(set(rows) - assigned)
            split[at_fold][1].extend(fresh)
            for fold in range(folds):
                if fold != at_fold:
                    split[fold][0].extend(fresh)
            assigned.update(rows)

        for train, test in split:
            yield train, test

    def _contig_splits(self):
        """
        Generate contiguous position-based splits over single mutants.

        Taken from https://github.com/luwei0917/BindingGYM/training/main.py for reproducibility.
        DMS_file_for_LLM merges the multi-chain mutant representation into single : separated mutations with
        offsets for the individual chains. E.g. if chain 1 has length 10 and there is a mutation on chain 2 at
        position 5, the new mutation for chain 2 will be at position 15 (after concatenating the chains).

        Mutated positions are visited in ascending order and their rows are
        added to the test set of the current fold (and to the train set of all
        other folds). The fold index advances once the current fold holds at
        least an even share of the rows not yet claimed by earlier folds.

        If fewer than ``config.data.single_mutant_cutoff`` single mutants are
        available, a warning is logged and nothing is yielded.

        NOTE(review): the yielded indices are row numbers of the filtered,
        re-indexed single-mutant frame built here, not of ``self.data`` —
        confirm that callers expect this.

        :yield: Tuples of (train_indices, test_indices) as Python lists.
        """
        logger.info(
            "Generating contiguous position-based splits for single-mutations")
        folds = self.num_splits

        cutoff = self.config.data.single_mutant_cutoff

        # TODO: change to self reliant method to compute this for consistency
        singles = DMS_file_for_LLM(pd.read_csv(
            self.config.benchmark.data), focus=True)
        singles = singles.loc[
            singles["mutant"].apply(lambda x: len(x.split(":")) < 2)
        ].reset_index(drop=True)
        n_singles = singles.shape[0]
        if n_singles < cutoff:
            logger.warning(
                f"Trying to generate contig splits with less than {cutoff} single mutants"
            )
            # TODO: Maybe should raise
            return
        logger.info(
            f"Made it through cutoff with {n_singles} single mutants")

        # Map each mutated position to the rows that carry a mutation there.
        # Mutation tokens look like <wt><position><mut>, hence the [1:-1] slice.
        rows_by_pos = {}
        for row, mutant in zip(singles.index, singles["mutant"]):
            for token in mutant.split(":"):
                pos = 0 if token == "" else int(token[1:-1])
                rows_by_pos.setdefault(pos, []).append(row)
        logger.debug(f"Mutated positions: {list(rows_by_pos)}")

        split = [[[], []] for _ in range(folds)]
        assigned = set()
        at_fold = 0
        used_count = 0  # rows claimed by the folds that are already closed
        for pos, rows in sorted(rows_by_pos.items(), key=lambda item: item[0]):
            # Only rows not already placed by an earlier position are added.
            fresh = list(set(rows) - assigned)
            split[at_fold][1].extend(fresh)
            for fold in range(folds):
                if fold != at_fold:
                    split[fold][0].extend(fresh)
            assigned.update(rows)

            fold_size = len(split[at_fold][1])
            # Close the fold when it reached its share of the remaining rows;
            # the last fold stays open and takes whatever is left.
            share = (n_singles - used_count) / (folds - at_fold)
            if at_fold < folds - 1 and fold_size >= share:
                used_count += fold_size
                at_fold += 1

        for train, test in split:
            yield train, test

    def get_pdb_mutations(self):
        """
        Get the pdb mutations for the dataset.
        """
        def safe_eval(x):
            """Safely evaluate string representations of Python objects."""
            if isinstance(x, (dict, list)):
                return x
            try:
                return ast.literal_eval(x)
            except (ValueError, SyntaxError):
                # Fallback to eval for cases where literal_eval fails
                # This handles cases where the data was processed with eval() in DMS_file_for_LLM
                try:
                    return eval(x)
                except Exception as e:
                    logger.warning(
                        f"Failed to parse mutant_pdb value: {x}, error: {e}")
                    return {}

        return self.data["mutant_pdb"].apply(safe_eval)

    def get_leq_mutations(self, n):
        """
        Get the indices of the sequences with less or equal to n mutations
        NOTE: biologically the consensus is the wild-type sequence in most papers
        :param n: the number of mutations
        :return: the indices of the sequences with less or equal to n mutations
        """

        return self.dist_consensus[self.dist_consensus <= n].index

    def get_gt_mutations(self, n):
        """
        Get the indices of the sequences with more than n mutations

        :param n: the number of mutations
        :return: the indices of the sequences with more than n mutations
        """

        return self.dist_consensus[self.dist_consensus > n].index

    def get_representative(self):
        return self.representative

    def set_cutoff(self, cutoff):
        """
        :param cutoff: The cutoff value to use. If "auto", set to the median of the target values.
        """

        if cutoff == "auto":
            # set cutoff to the median of target values
            self.cutoff = self.data[self.label_col].median()
        else:
            self.cutoff = cutoff

    def get_mutated_region(self):
        """
        Get the range of mutated positions.

        :return: The range of mutated positions.
        """

        return range(self.mut_pos.index.min(), self.mut_pos.index.max())

    def get_mutated_positions(self):
        """
        Get the list of mutated positions.

        :return: The list of mutated positions.
        """

        return self.mut_pos.index.tolist()

    def _truncate_sequences(self):
        """Truncate sequences according to config."""
        if hasattr(self.config.benchmark, "truncate"):
            raise NotImplementedError(
                "Truncation not implemented for multiple chains")
            logger.debug(
                f"Truncating sequences for {self.config.benchmark.name}")
            # iterate over all keys in truncate
            for key in self.config.benchmark.truncate.keys():
                start = self.config.benchmark.truncate[key].start
                end = self.config.benchmark.truncate[key].end
                region = self.config.benchmark.truncate[key].region

            # Log original sequence lengths
            orig_lengths = self.data["sequence"].str.len()
            logger.debug(
                f"Original sequence lengths: min={orig_lengths.min()}, max={orig_lengths.max()}"
            )

            if hasattr(self.config.benchmark.truncate, "start") and hasattr(
                self.config.benchmark.truncate, "end"
            ):
                start = self.config.benchmark.truncate.start
                end = self.config.benchmark.truncate.end

                # Validate truncation parameters
                if start < 0:
                    raise ValueError(
                        f"Truncation start index {start} cannot be negative"
                    )
                if end <= start:
                    raise ValueError(
                        f"Truncation end index {end} must be greater than start index {start}"
                    )
                if start >= orig_lengths.min():
                    raise ValueError(
                        f"Truncation start index {start} is larger than shortest sequence length {orig_lengths.min()}"
                    )

                # Truncate from start to end
                self.data["sequence"] = self.data["sequence"].str[start:end]

            else:
                # Handle single-ended truncation
                if hasattr(self.config.benchmark.truncate, "start"):
                    start = self.config.benchmark.truncate.start
                    if start < 0:
                        raise ValueError(
                            f"Truncation start index {start} cannot be negative"
                        )
                    if start >= orig_lengths.min():
                        raise ValueError(
                            f"Truncation start index {start} is larger than shortest sequence length {orig_lengths.min()}"
                        )
                    self.data["sequence"] = self.data["sequence"].str[start:]

                elif hasattr(self.config.benchmark.truncate, "end"):
                    end = self.config.benchmark.truncate.end
                    if end <= 0:
                        raise ValueError(
                            f"Truncation end index {end} must be positive")
                    self.data["sequence"] = self.data["sequence"].str[:end]

            # Verify results
            new_lengths = self.data["sequence"].str.len()
            logger.debug(
                f"After truncation sequence lengths: min={new_lengths.min()}, max={new_lengths.max()}"
            )

            if new_lengths.min() == 0:
                raise ValueError(
                    "Truncation resulted in zero-length sequences")

    def cache_features(self, features):
        """
        Store features on the instance and write them to the cache.

        The cache key is always derived from the dataset defaults (no explicit
        features are passed to ``_get_cache_key``); otherwise the key would not
        match when checking whether features exist for a model/embedder on a
        benchmark, and they would be recomputed.

        :param features: The features to keep in ``self.features`` and persist.
        """
        self.features = features
        key = self.cache_key = self._get_cache_key()
        self._save_to_cache(features, key)

    def _get_cache_key(self, features=None):
        """Generate a minimal cache key based on dataset hash and model configuration."""
        features_to_hash = features if features is not None else self._init_sequences
        features_hash = hashlib.md5(
            str(features_to_hash).encode()).hexdigest()[:8]

        cache_components = {
            "benchmark": self.config.benchmark.name,
            "data_hash": features_hash,
            "model": self.config.model.name,
            "ft": self.config.model.feature_type,
        }

        if self.config.model.feature_type == "embedded":
            if not hasattr(self.config, "embedder") or not hasattr(
                self.config.embedder, "name"
            ):
                raise ValueError(
                    "Embedder configuration missing for feature_type 'embedded'"
                )
            cache_components["embedder_name"] = self.config.embedder.name

            if self.config.embedder.name == "protenc":
                if not hasattr(self.config.embedder, "model"):
                    raise ValueError(
                        "Protenc model configuration missing for protenc embedder."
                    )
                cache_components["embedder_model"] = self.config.embedder.model

        cache_str = "_".join(
            f"{k}-{v}" for k, v in sorted(cache_components.items()))
        logger.debug(f"Generated cache key components: {cache_components}")
        return cache_str

    def _get_cache_path(self, cache_key):
        """
        Get the path to the cache file for the given cache key.

        :param cache_key: The cache key
        :return: Path to the cache file
        """
        if self.cache_format == "hdf5":
            return os.path.join(self.cache_dir, f"{cache_key}.h5")
        elif self.cache_format == "lmdb":
            return os.path.join(self.cache_dir, f"{cache_key}")
        else:
            # Fallback to .pt for backward compatibility
            return os.path.join(self.cache_dir, f"{cache_key}.pt")

    def _save_to_cache(self, tokenized_inputs, cache_key):
        """
        Save tokenized inputs to the cache in the configured format.

        The writer is chosen by ``self.cache_format``: "hdf5" and "lmdb" have
        dedicated writers, every other value uses the .pt writer.

        :param tokenized_inputs: The tokenized inputs to save
        :param cache_key: The cache key
        """
        target = self._get_cache_path(cache_key)
        logger.info(f"Start Saving to cache: {target}")

        writers = {
            "hdf5": self._save_to_hdf5,
            "lmdb": self._save_to_lmdb,
        }
        # Anything that is not a known format falls back to .pt.
        write = writers.get(self.cache_format, self._save_to_pt)
        write(tokenized_inputs, target)

        logger.info(f"Finished Saving to cache: {target}")

    def _save_to_hdf5(self, tokenized_inputs, cache_path):
        """
        Save embeddings to an HDF5 file.

        The file gets a ``metadata`` group (row count of ``self.data``,
        creation time and the file's base name as attributes) and an
        ``embeddings`` group. Dict inputs are stored one dataset per key;
        non-empty lists are converted to arrays and gzip-compressed. Any other
        input is stored as a single compressed dataset named ``embeddings``.

        :param tokenized_inputs: Dict of name -> values, or a single array-like.
        :param cache_path: Destination file; an existing file is overwritten.
        """
        gzip_max = {"compression": "gzip", "compression_opts": 9}

        with h5py.File(cache_path, 'w') as f:
            # Save metadata
            meta = f.create_group('metadata')
            meta.attrs['num_sequences'] = len(self.data)
            meta.attrs['created_at'] = str(pd.Timestamp.now())
            meta.attrs['cache_key'] = os.path.basename(cache_path)

            # Save embeddings
            store = f.create_group('embeddings')

            if not isinstance(tokenized_inputs, dict):
                # Handle non-dict inputs
                store.create_dataset(
                    'embeddings', data=tokenized_inputs, **gzip_max)
                return

            for key, values in tokenized_inputs.items():
                if isinstance(values, list) and len(values) > 0:
                    # Convert list to numpy array
                    store.create_dataset(
                        key, data=np.array(values), **gzip_max)
                else:
                    store.create_dataset(key, data=values)

    def _save_to_lmdb(self, tokenized_inputs, cache_path):
        """
        Save embeddings to an LMDB environment.

        A ``metadata`` record (row count of ``self.data``, creation time and
        the directory's base name) is written next to the payload. Dict inputs
        are stored one record per key, anything else as a single ``embeddings``
        record. All values are pickled.

        The memory map is sized from the pickled payload (never below 1 GiB),
        so payloads larger than the former fixed 1 GiB limit can be written,
        and the environment is closed even when the write fails.

        :param tokenized_inputs: Dict of name -> values, or a single object.
        :param cache_path: Directory of the LMDB environment; created if missing.
        """
        os.makedirs(cache_path, exist_ok=True)

        metadata = {
            'num_sequences': len(self.data),
            'created_at': str(pd.Timestamp.now()),
            'cache_key': os.path.basename(cache_path)
        }
        records = [(b'metadata', pickle.dumps(metadata))]
        if isinstance(tokenized_inputs, dict):
            records.extend(
                (key.encode(), pickle.dumps(values))
                for key, values in tokenized_inputs.items()
            )
        else:
            records.append((b'embeddings', pickle.dumps(tokenized_inputs)))

        # LMDB raises once the database outgrows map_size. Leave headroom over
        # the raw payload for page overhead and for replacing existing records.
        payload_bytes = sum(len(key) + len(value) for key, value in records)
        map_size = max(2**30, 3 * payload_bytes)

        env = lmdb.open(cache_path, map_size=map_size)
        try:
            with env.begin(write=True) as txn:
                for key, value in records:
                    txn.put(key, value)
        finally:
            # Always release the environment, otherwise the same path cannot
            # be opened again from this process.
            env.close()

    def _save_to_pt(self, tokenized_inputs, cache_path):
        """Save embeddings to .pt format (original method)."""
        torch.save(tokenized_inputs, cache_path)

        # Save metadata
        metadata_path = cache_path.replace(".pt", "_metadata.json")
        metadata = {
            "num_sequences": len(self.data),
            "created_at": str(pd.Timestamp.now()),
        }
        with open(metadata_path, "w") as f:
            json.dump(metadata, f)

    def _load_from_cache(self, cache_key):
        """
        Load tokenized inputs from cache if they exist.

        :param cache_key: The cache key
        :return: The tokenized inputs if found in cache, None otherwise
        """
        cache_path = self._get_cache_path(cache_key)

        if not os.path.exists(cache_path) or self.config.data.recompute_features:
            return None

        logger.info(f"Loading from cache: {cache_path}")

        try:
            if self.cache_format == "hdf5":
                return self._load_from_hdf5(cache_path)
            elif self.cache_format == "lmdb":
                return self._load_from_lmdb(cache_path)
            else:
                # Fallback to .pt format
                return self._load_from_pt(cache_path)
        except Exception as e:
            logger.error(f"Error loading cache from {cache_path}: {e}")
            return None

    def _load_from_hdf5(self, cache_path):
        """
        Load embeddings from an HDF5 file.

        Every dataset in the file's ``embeddings`` group is read completely and
        converted to a nested Python list.

        :param cache_path: Path of the HDF5 file.
        :return: Dict mapping dataset name to its contents as a list.
        """
        with h5py.File(cache_path, 'r') as f:
            stored = f['embeddings']
            # [:] materializes the dataset before the file is closed.
            return {name: stored[name][:].tolist() for name in stored.keys()}

    def _load_from_lmdb(self, cache_path):
        """
        Load embeddings from an LMDB environment.

        All records except ``metadata`` are unpickled and returned under their
        decoded key. The environment is closed even when reading fails, so a
        corrupt cache does not block a later rewrite of the same path.

        :param cache_path: Directory of the LMDB environment.
        :return: Dict mapping record key to the unpickled value.
        """
        env = lmdb.open(cache_path, readonly=True)
        try:
            with env.begin() as txn:
                # SECURITY: pickle.loads can execute arbitrary code; the cache
                # directory must only contain files written by this project.
                return {
                    key.decode(): pickle.loads(value)
                    for key, value in txn.cursor()
                    if key != b'metadata'  # Skip metadata
                }
        finally:
            env.close()

    def _load_from_pt(self, cache_path):
        """Load embeddings from .pt format (original method)."""
        try:
            # First try loading with weights_only=True
            cached_dict = torch.load(cache_path, weights_only=True)
        except (pickle.UnpicklingError, RuntimeError):
            # If that fails, try loading without weights_only
            logger.debug("Loading cache with weights_only=False")
            cached_dict = torch.load(cache_path, weights_only=False)

        return cached_dict

    def _check_features_ready(self):
        """Check if features are ready in cache"""
        self.cache_key = self._get_cache_key(self._init_sequences)
        cache_path = self._get_cache_path(self.cache_key)

        if self.cache_format == "lmdb":
            # For LMDB, check if directory exists and has data
            return os.path.exists(cache_path) and os.path.isdir(cache_path)
        else:
            # For HDF5 and .pt, check if file exists
            return os.path.exists(cache_path)

    def wait_for_features_ready(self):
        """Wait for features to be ready in cache"""
        if self.config.model.feature_type == "sequence":
            logger.info(
                "Feature type is 'sequence', no cache waiting needed. Features are immediately available."
            )
            return

        self.seqs_list = self.data["sequence"].tolist()
        while not self._check_features_ready():
            logger.info("Waiting for features to be ready in cache")

            time.sleep(10)
        logger.info("Features are ready, continuing.")

    def prepare_features(self, sequences: List[str] = None):
        """
        Prepare features for the configured feature type.

        * 'sequence': the raw ``sequence`` column is the feature; nothing is
          computed or cached.
        * 'embedded': embeddings are loaded from the cache when possible,
          otherwise generated and written to the cache. The embedding columns
          are appended to ``self.data`` and listed in ``self.feature_cols``.

        :param sequences: Not used by this method.
        :raises NotImplementedError: If feature type is 'embedded' and the
            configured embedder is missing or not 'protenc'.
        :raises ValueError: If the feature type is neither 'embedded' nor
            'sequence'.
        """
        logger.info(
            f"Preparing features for type: {self.config.model.feature_type}")

        self._validate_sequences()  # Ensure 'sequence' is ready

        feature_type = self.config.model.feature_type

        if feature_type == "sequence":
            # The raw sequence itself is the feature; the model handles
            # tokenization/embedding, so nothing is derived or cached here.
            self.feature_cols = ["sequence"]
            logger.info(
                "Feature type is 'sequence'. Model will handle tokenization/embedding."
            )
            return

        if feature_type != "embedded":
            raise ValueError(
                f"Unsupported feature_type in config: {self.config.model.feature_type}. Must be 'embedded' or 'sequence'."
            )

        # Currently only protenc is implemented as an external embedder.
        has_embedder = hasattr(self.config, "embedder")
        if not has_embedder or self.config.embedder.name != "protenc":
            embedder_label = self.config.embedder.name if has_embedder else "None"
            raise NotImplementedError(
                f"Embedder '{embedder_label}' not supported for 'embedded' feature_type. Only 'protenc' is implemented."
            )

        sequences_for_embedding = self.data["sequence"].tolist()
        self.cache_key = self._get_cache_key(sequences_for_embedding)

        # Try the cache first; an unusable cache entry counts as a miss.
        features_df = None
        cached_features = self._load_from_cache(self.cache_key)
        if cached_features is not None:
            try:
                features_df = pd.DataFrame(cached_features)
                self.features_loaded = True
            except Exception as e:
                logger.error(
                    f"Error converting cached features to DataFrame: {e}")
                logger.error(f"Cached features: {cached_features}")
                logger.warning("Recompute features")
                features_df = None

        generated = features_df is None
        if generated:
            logger.info(
                f"Generating and caching embedded features: {self.cache_key}"
            )
            # _prepare_protenc_embeddings is the only embedder method for now.
            features_dict = self._prepare_protenc_embeddings()
            features_df = pd.DataFrame(features_dict)
        else:
            logger.info(
                f"Loaded embedded features from cache: {self.cache_key}")

        # Attach all feature columns in one concat to avoid fragmentation.
        self.data = pd.concat(
            [
                self.data.reset_index(drop=True),
                features_df.reset_index(drop=True),
            ],
            axis=1,
        )
        self.feature_cols = features_df.columns.tolist()
        if generated:
            self._save_to_cache(features_dict, self.cache_key)

        logger.debug(
            f"Feature columns for 'embedded' type: {self.feature_cols}")

    def _prepare_protenc_embeddings(self, sequences: List[str] = None):
        """Prepare embeddings using embedding manager. (Used for feature_type 'embedded')"""

        if sequences is None:
            sequences_to_embed = self.data["sequence"].tolist()
        else:
            sequences_to_embed = sequences

        logger.info(
            f"Generating embeddings for {len(sequences_to_embed)} sequences using embedding manager")

        # Use embedding manager to get embeddings
        embeddings_array = self.embedding_manager.get_embeddings(
            sequences_to_embed)

        # Convert to the expected format for backward compatibility
        average_embeddings = (
            self.config.embedder.average_sequence
            if hasattr(self.config.embedder, "average_sequence")
            else self.config.data.average_embeddings
        )

        if average_embeddings:
            # embeddings_array is shape (n_sequences, embedding_dim)
            embedding_dim = embeddings_array.shape[1]
            features_dict = {
                f"embed_{i}": embeddings_array[:, i].tolist()
                for i in range(embedding_dim)
            }
        else:
            # For per-residue embeddings, we need to handle this differently
            # The embedding manager currently returns averaged embeddings
            # We'll need to modify it to support per-residue if needed
            logger.warning(
                "Per-residue embeddings not yet supported with embedding manager, using averaged embeddings")
            embedding_dim = embeddings_array.shape[1]
            features_dict = {
                f"embed_{i}": embeddings_array[:, i].tolist()
                for i in range(embedding_dim)
            }

        logger.debug(f"Created {len(features_dict)} embedding features.")
        return features_dict


def prepare_features(config):
    """
    Generate and cache features for a config unless they are cached already.

    :param config: Hydra DictConfig used to build the dataset.
    :return: The dataset after feature preparation, or None when the features
        were already present in the cache.
    """
    dataset = HAIPRData(config)
    if dataset._check_features_ready():
        # Nothing to do; callers receive None in this case.
        return

    logger.info(f"Generating features: {dataset.cache_key}")
    logger.info(config.embedder)
    dataset.prepare_features()
    return dataset


@hydra.main(config_path="conf", config_name="data", version_base=None)
def main(config: DictConfig = None):
    """
    Hydra entry point: prepare features unless they are already cached.

    :param config: Configuration composed by Hydra from ``conf/data``.
    """
    dataset = HAIPRData(config)
    if not dataset._check_features_ready():
        logger.info(f"Generating features: {dataset.cache_key}")
        logger.info(config.embedder)
        dataset.prepare_features()

# Script entry point: only runs when the file is executed directly.
if __name__ == "__main__":
    # Register resolvers first, then hand control to the Hydra entry point.
    register_resolvers()
    main()
