from dataclasses import dataclass
import hashlib
import json
import logging
import os
import numpy as np
import pandas as pd

import torch
import ast
import time
import hydra
from omegaconf import DictConfig
from sklearn.model_selection import KFold, StratifiedKFold
from torch.utils.data import Dataset
from typing import List, Optional
import pickle
import h5py
import lmdb
from biotite.structure.io.pdb import PDBFile
from biotite.structure import chain_iter, get_residues
from biotite.structure.info import one_letter_code
import biotite.structure as struc

from torch_geometric.data import Data

from haipr.utils import AA_ALPHABET_WITH_EXTRAS
from haipr.utils.resolvers import register_resolvers
from haipr.embedding_manager import EmbeddingManager


logger = logging.getLogger(__name__)


def extract_calpha_coords(pdb_file: str, chain_id: str | None = None) -> np.ndarray:
    """Collect the C-alpha coordinates stored in a PDB file.

    Args:
        pdb_file: Path handed to ``PDBFile.read``.
        chain_id: When given, only atoms of this chain are considered;
            ``None`` keeps every chain.

    Returns:
        The coordinates of all atoms named ``"CA"``. When several
        chains (or several models of a stacked structure) contribute, their
        coordinate blocks are stacked vertically in iteration order.

    Raises:
        ValueError: If no C-alpha atom matches the request.
    """
    atom_array = PDBFile.read(pdb_file).get_structure()

    # A structure exposing a positive ``stack_depth()`` is walked model by
    # model; anything else is handled as one single model.
    is_stack = hasattr(
        atom_array, "stack_depth") and atom_array.stack_depth() > 0
    if is_stack:
        models = (atom_array[i] for i in range(atom_array.stack_depth()))
    else:
        models = (atom_array,)

    per_chain_coords = []
    for model in models:
        for ch_id in struc.get_chains(model):
            if not (chain_id is None or ch_id == chain_id):
                continue
            chain = model[model.chain_id == ch_id]
            ca = chain[chain.atom_name == "CA"]
            if len(ca) > 0:
                per_chain_coords.append(ca.coord)

    if len(per_chain_coords) == 0:
        if chain_id is None:
            raise ValueError("No C-alpha atoms found in PDB file")
        raise ValueError(
            f"No C-alpha atoms found for chain '{chain_id}' in PDB file")

    if len(per_chain_coords) == 1:
        return per_chain_coords[0]
    return np.vstack(per_chain_coords)


def build_graph_from_coords(
    coords: np.ndarray,
    distance_cutoff: float = 10.0,
    node_features: Optional[torch.Tensor] = None
) -> Data:
    """Build a distance-cutoff graph from point coordinates.

    Two distinct nodes are connected (in both directions) when their
    Euclidean distance is at most ``distance_cutoff``.

    Args:
        coords: Coordinates, one row per node.
        distance_cutoff: Maximum distance for which an edge is created.
        node_features: Optional node feature tensor; when omitted the
            coordinates themselves are used as node features.

    Returns:
        A ``Data`` object with ``x``, ``edge_index`` (shape ``[2, E]``),
        ``edge_attr`` (shape ``[E, 1]``, the distance of each edge) and
        ``pos``.
    """
    coords_tensor = torch.tensor(coords, dtype=torch.float32)
    dist_matrix = torch.cdist(coords_tensor, coords_tensor)

    edge_mask = dist_matrix <= distance_cutoff
    edge_mask.fill_diagonal_(False)  # no self-loops
    edge_index = edge_mask.nonzero(as_tuple=False).t().contiguous()

    # One feature row per edge, aligned with the columns of ``edge_index``.
    # The dense N x N matrix used previously did not line up with the edges.
    edge_attr = dist_matrix[edge_index[0], edge_index[1]].unsqueeze(-1)

    if node_features is None:
        node_features = coords_tensor

    return Data(
        x=node_features,
        edge_index=edge_index,
        edge_attr=edge_attr,
        pos=coords_tensor,
    )


class HAIPRData(Dataset):

    def __init__(
        self,
        config: DictConfig,
        data: pd.DataFrame | None = None,
    ):
        """Load the benchmark table and prepare the dataset state.

        Args:
            config: Run configuration; the ``benchmark``, ``data``, ``model``
                and ``embedder`` sections are read here.
            data: Optional pre-loaded table. When ``None`` the CSV at
                ``config.benchmark_dir / config.benchmark.data`` is read.
        """
        if data is None:
            self._data: pd.DataFrame = pd.read_csv(
                os.path.join(config.benchmark_dir, config.benchmark.data)
            )
        else:
            self._data = data

        self.config = config
        self.label_col = config.data.label_column
        self.num_splits = config.data.num_splits
        self.cache_key = None
        self.features_loaded = False

        # First non-empty chain break token wins: model, then embedder, then data.
        self.chain_break_token = ""
        if hasattr(self.config.model, "chain_break_token") and self.config.model.chain_break_token:
            self.chain_break_token = self.config.model.chain_break_token
        elif hasattr(self.config.embedder, "chain_break_token") and self.config.embedder.chain_break_token:
            self.chain_break_token = self.config.embedder.chain_break_token
        elif hasattr(self.config.data, "chain_break_token") and self.config.data.chain_break_token:
            self.chain_break_token = self.config.data.chain_break_token

        if hasattr(config.benchmark, "pdb") and config.benchmark.pdb:
            self.pdb = os.path.join(config.benchmark_dir, config.benchmark.pdb)
            logger.debug(
                f"PDB path is set to {self.pdb}, using to embed sequences with protenc"
            )
        else:
            logger.debug("No PDB Supplied Using Sequence Only.")
            self.pdb = None

        # NOTE(security): ``eval`` executes whatever the table cells contain.
        # Only load benchmark files from trusted sources; consider
        # ``ast.literal_eval`` if the cells are always plain literals.
        if "mutant_pdb" in self._data.columns:
            self.mutant_pdb = self._data["mutant_pdb"].apply(eval)
        else:
            self.mutant_pdb = None
        self.mutants = self._data["mutant"].apply(eval)

        self.focus = getattr(config.data, "focus", True)
        self.focus_chains = self._compute_focus_chains()

        # Structure is used when either the model or the embedder asks for it.
        self._use_structure = False
        if hasattr(self.config.model, "use_structure") and self.config.model.use_structure:
            self._use_structure = True
        elif hasattr(self.config, "embedder") and self.config.embedder is not None:
            if hasattr(self.config.embedder, "use_structure") and self.config.embedder.use_structure:
                self._use_structure = True
        logger.debug("use_structure %s", self._use_structure)
        logger.debug("focus %s", self.focus)

        # Setup cache directory before building sequences (needed for filtered PDB creation)
        self.cache_format = getattr(config.data, "cache_format", "lmdb")
        self._setup_cache_dir()

        self._effective_pdb_path: Optional[str] = None
        if self._use_structure:
            self._effective_pdb_path = self._get_effective_pdb_path()
        logger.debug("effective_pdb_path %s", self._effective_pdb_path)

        self._build_sequence_column()
        if "sample_id" not in self._data.columns:
            self._data["sample_id"] = np.arange(len(self._data))

        self.features: np.ndarray | None = None
        self.embedder_instance = None
        self.feature_cols: list[str] = []
        self.embedding_manager = EmbeddingManager(config)
        self._active_idx = np.arange(len(self._data))
        self.test_idx_abs: np.ndarray | None = None

        self.setup()
        # NOTE(review): this resets the active view to every row, which
        # discards any narrowing of ``active_idx`` performed inside
        # ``setup()`` (NaN removal, constant-label removal, cut_top).
        # Confirm that this is intended.
        self.active_idx = np.arange(len(self.data))

    @property
    def active_idx(self):
        """Row positions of ``self.data`` that form the current view; ``__len__`` reports its length."""
        return self._active_idx

    @active_idx.setter
    def active_idx(self, value):
        """Replace the current view; ``value`` is stored as-is, without validation."""
        self._active_idx = value

    @property
    def data(self):
        """The underlying table (the same object, not a copy)."""
        return self._data

    @data.setter
    def data(self, value):
        self._data = value

        self._active_idx = np.arange(len(self._data))
        logger.debug(
            f"Data modified, resetting active_idx to {len(self._active_idx)} indices"
        )

    def __len__(self):
        """Return the number of rows in the active view (not the full table)."""

        return len(self.active_idx)

    def __getitem__(self, idx):

        if self.config.model.feature_type == "embedded":

            if self.features is None:
                raise RuntimeError(
                    "Features array not set for 'embedded' type. Call data.prepare_features() first."
                )

            if isinstance(idx, slice):
                mapped_idx = self.active_idx[idx]
            else:
                idx_arr = np.asarray(idx)
                if idx_arr.ndim == 0:
                    mapped_idx = self.active_idx[int(idx_arr)]
                else:
                    mapped_idx = self.active_idx[idx_arr.astype(int)]

            labels_series = self.data[self.label_col].iloc[mapped_idx]
            if np.isscalar(labels_series) or getattr(labels_series, "shape", ()) == ():
                labels = labels_series
            else:

                labels = labels_series.values.tolist()

            features = self.features[mapped_idx]

            return torch.tensor(features, dtype=torch.float32), torch.tensor(
                labels,
                dtype=(
                    torch.float32 if self.config.model.num_classes == 0 else torch.long
                ),
            )

        if self.config.model.feature_type == "sequence":
            if isinstance(idx, (slice, list, np.ndarray)):
                raise NotImplementedError(
                    "Slicing/batch indexing directly in __getitem__ for 'sequence' type returning dict is not standard for collate_fn. DataLoader fetches items individually."
                )
            else:

                sequence_data = self.data["sequence_flat"].iloc[idx]
                label_data = self.data[self.label_col].iloc[idx]
                result = {
                    "sequence": sequence_data,
                    "labels": torch.tensor(
                        label_data,
                        dtype=(
                            torch.float32
                            if self.config.model.num_classes == 0
                            else torch.long
                        ),
                    ),
                }

                if "sample_id" in self.data.columns:
                    result["sample_id"] = self.data["sample_id"].iloc[idx]
                return result
        elif self.config.model.feature_type == "multi_chain_sequence":
            if not self.data["mutated_sequence"].iloc[0].startswith("{"):
                raise ValueError(
                    "Mutated sequence is not a multi-chain sequence format"
                )
            else:

                sequence_data = self.data["mutated_sequence"].iloc[idx]
                label_data = self.data[self.label_col].iloc[idx]
                result = {
                    "sequence": sequence_data,
                    "labels": torch.tensor(
                        label_data,
                        dtype=(
                            torch.float32
                            if self.config.model.num_classes == 0
                            else torch.long
                        ),
                    ),
                }

                if "sample_id" in self.data.columns:
                    result["sample_id"] = self.data["sample_id"].iloc[idx]
                return result
        else:
            raise ValueError(
                f"Unsupported feature_type in __getitem__: {self.config.model.feature_type}"
            )

    def _compute_focus_chains(self) -> list[str]:
        if self.mutants is None:
            return []

        focus_chains: set[str] = set()
        for mutants in self.mutants:
            if isinstance(mutants, dict):
                for cid, spec in mutants.items():
                    if spec not in ("", None):
                        focus_chains.add(cid)
        return list(focus_chains)

    def _build_sequence_column(self) -> None:
        if "mutated_sequence" not in self._data.columns:
            raise ValueError("Mutated sequence column not found")
        # Determine chain order: use PDB order if available, otherwise use dict keys
        chain_order: List[str] = []
        pdb_path_to_use = self._effective_pdb_path if self._effective_pdb_path is not None else self.pdb
        if pdb_path_to_use is not None and self._use_structure:
            try:
                chain_order = self._get_pdb_chain_order(pdb_path_to_use)
            except Exception:
                chain_order = []

        # Determine which chains to include based on focus
        if self.focus and self.focus_chains:
            chains_to_use = self.focus_chains
        else:
            # Use all chains (focus=False or no focus_chains)
            chains_to_use = None

        flat_sequences: list[str] = []
        for _, row in self._data.iterrows():
            mutated = row.get("mutated_sequence", None)

            # Parse string representation of dict if needed
            if isinstance(mutated, str):
                mutated = eval(mutated)

            if isinstance(mutated, dict):
                # Determine order: PDB order if available, otherwise sorted dict keys
                if chain_order:
                    # Use PDB order, but only include chains present in mutated dict
                    ordered_chains = [
                        cid for cid in chain_order if cid in mutated]
                else:
                    # No PDB, use sorted dict keys
                    ordered_chains = sorted(mutated.keys())

                # Filter by focus if needed
                if chains_to_use is not None:
                    ordered_chains = [
                        cid for cid in ordered_chains if cid in chains_to_use]

                seq = self.chain_break_token.join(str(mutated.get(cid, ""))
                                           for cid in ordered_chains)
            else:
                seq = str(mutated) if mutated is not None else ""
            flat_sequences.append(seq)

        self._data["sequence_flat"] = flat_sequences

    def normalize_labels(self):
        logger.debug("Normalizing labels between 0 and 1")
        logger.debug(
            f"Initial range: {self.data[self.label_col].min()} to {self.data[self.label_col].max()}"
        )
        self.data[self.label_col] = (
            self.data[self.label_col] - self.data[self.label_col].min()
        ) / (self.data[self.label_col].max() - self.data[self.label_col].min())
        logger.debug(
            f"Normalized range: {self.data[self.label_col].min()} to {self.data[self.label_col].max()}"
        )

    def setup(self):
        logger.info("Setting up HAIPRData")
        logger.debug(f"Data: \n{self.data.head()}")
        logger.debug(f"Label column: {self.label_col}")
        logger.debug(
            f"Sequence column: {self.config.benchmark.sequence_column}")

        if self.config.data.normalize_labels:
            self.normalize_labels()

        if hasattr(self.config.benchmark, "binary") and self.config.benchmark.binary:

            if self.data[self.label_col].nunique() == 2:
                logger.info("Label column is binary")

            else:
                raise ValueError("Label column is not binary")

        elif hasattr(self.config, "task"):
            if self.config.task == "classification":
                if (
                    not hasattr(self.config, "num_classes")
                    or self.config.num_classes <= 0
                ):
                    logger.warning(
                        "num_classes not specified for classification task, defaulting to 2"
                    )
                    self.config.num_classes = 2
                logger.info(
                    f"Discretizing labels into {self.config.num_classes} classes for classification"
                )
                self.discretize_labels(self.config.num_classes)
            elif self.config.task == "regression":
                if hasattr(self.config, "num_classes") and self.config.num_classes > 0:
                    logger.warning(
                        "num_classes specified for regression task, ignoring"
                    )
                    self.config.num_classes = 0

        self.remove_nan_labels()

        if self.config.data.remove_constant_labels:
            self.remove_constant_labels()

        logger.info(
            f"Sequence columns initialized: sequence ({len(self.data['sequence_flat'][0])} chars)"
        )

        if self.config.data.cut_top is not None:
            self._analyze_sequence_variants()
            self.prepare_features()
            self.cut_top_percent_by_label(self.config.data.cut_top)
        else:
            self._analyze_sequence_variants()
            self._setup_pos_wise_alphabet()

    def set_test_data(self, split_method: str, split_idx: int, num_splits: int = 5):
        logger.info(
            f"Setting test data for split method: {split_method} and split index: {split_idx}"
        )

        if self.test_idx_abs is not None:
            logger.warning(
                "Test data already set, resetting selection and regenerating splits"
            )
            self.test_idx_abs = None

        if self.num_splits != num_splits:
            logger.debug(
                f"Create test_data with {num_splits} splits using {split_method} method")
            previous_num_splits = self.num_splits
            self.num_splits = num_splits
            self.generate_splits(split_method)
            self.num_splits = previous_num_splits
        else:
            self.generate_splits(split_method)

        self.test_idx_abs = np.array(self.splits[split_idx][1])

        self.active_idx = self.active_idx[~np.isin(
            self.active_idx, self.test_idx_abs)]
        logger.info(
            f"Test size: {len(self.test_idx_abs)} | Train size: {len(self.active_idx)}"
        )

    def subsample_data(self, threshold=0):
        if threshold > 0 and len(self.active_idx) > threshold:
            logger.info(
                f"Subsampling active training pool to {threshold} samples")
            rng = np.random.RandomState(self.config.seed)
            selected = rng.choice(
                self.active_idx, size=threshold, replace=False)

            self.active_idx = np.sort(selected)
            logger.info(f"Remaining training pool: {len(self.active_idx)}")
        else:
            logger.info("No subsampling needed or threshold not exceeded")
        self.is_subsampled = True

    def _setup_cache_dir(self):
        if (
            hasattr(self.config.data, "cache_dir")
            and self.config.data.cache_dir is not None
        ):
            self.cache_dir = self.config.data.cache_dir
        else:
            self.cache_dir = os.path.join(
                os.path.expanduser("~"), ".cache", "haipr", "embeddings"
            )
        os.makedirs(self.cache_dir, exist_ok=True)
        logger.info(f"Using cache directory: {self.cache_dir}")

    def _setup_pos_wise_alphabet(self, full_alphabet: bool = False):
        alphabet = [str(aa) for aa in "ACDEFGHIKLMNPQRSTVWY"]

        if full_alphabet:
            self.pos_wise_alphabet = [alphabet for _ in self.seq_df.columns]
        else:
            self.pos_wise_alphabet = [


                sorted(self.seq_df[col].astype(str).unique())
                for col in self.seq_df.columns
            ]
        logger.debug(f"Pos-wise alphabet: {self.pos_wise_alphabet}")

    def _analyze_sequence_variants(self):
        logger.info("Analyzing sequence variants")
        logger.debug(f"data: \n{self.data.shape}")

        self.seqs_list = self.data["sequence_flat"].to_numpy()

        self.seq_df = pd.DataFrame([list(seq) for seq in self.seqs_list])
        logger.debug(f"Seq df: \n{self.seq_df.shape}")

        self.sfreqs = self.seq_df.apply(
            pd.Series.value_counts).fillna(0).astype(int)
        self.aa_per_pos = self.sfreqs.apply(
            lambda x: list(zip(x.index[x > 0], x[x > 0])), axis=0
        )
        logger.debug(f"Amino acids per position: \n{self.aa_per_pos}")

        self.mut_pos = self.aa_per_pos.loc[self.aa_per_pos.apply(len) > 1]
        logger.debug(
            f"Mutated positions: \n{self.aa_per_pos[self.mut_pos.index]}")
        self.consensus = self.sfreqs.idxmax(axis=0)

        if (
            hasattr(self.config.benchmark, "wild_type")
            and self.config.benchmark.wild_type
        ):
            self.wild_type = self.config.benchmark.wild_type
        else:
            self.wild_type = self.consensus

        if self.mut_pos.shape[0] == 0:
            raise ValueError("No mutations found in the dataset")

        self._compute_representative_sequence()
        self._compute_dist_to_consensus()
        self._set_min_max_mutations()
        self.num_mut_sites = self.mut_pos.shape[0]

    def _set_min_max_mutations(self):
        self.max_mutations = self.dist_consensus.max()
        self.min_mutations = self.dist_consensus.min()

    def _compute_dist_to_consensus(self):
        logger.info("Compute distannce to consesus")
        self.dist_consensus = (self.seq_df != self.consensus).sum(axis=1)

    def _compute_dist_to_best(self):
        logger.info("Compute distance to best")
        best_label_idx = np.argmax(self.data[self.label_col])
        seq_best = self.seq_df.iloc[best_label_idx]
        self.dist_best = (self.seq_df != seq_best).sum(axis=1)

    def _compute_representative_sequence(self):
        self.rep_id = (self.seq_df == self.consensus).sum(axis=1).idxmax()
        self.representative = self.seqs_list[self.rep_id]

    def _validate_sequences(self):
        if not hasattr(self, "data") or "sequence_flat" not in self.data.columns:
            raise ValueError("Sequences not found. Run setup() first.")

        unique_sequences = self.data["sequence_flat"].nunique()
        if unique_sequences == 1:
            raise ValueError(
                "Dataset contains only a single unique sequence variant")

        if unique_sequences < len(self.data):
            logger.warning(
                f"Found {len(self.data) - unique_sequences} duplicate sequences after processing"
            )

    def discretize_labels(self, num_classes):
        logger.info(f"Discretizing labels into {num_classes} classes")
        logger.debug(
            f"Labels before discretization: {len(self.data[self.label_col].unique())}"
        )
        logger.debug(f"Label column: {self.label_col}")

        if isinstance(self.label_col, list) and len(self.label_col) > 1:
            raise ValueError(
                "Discretization is not supported for multiple label columns"
            )

        self.remove_nan_labels()

        labels = pd.qcut(
            self.data[self.label_col],
            q=num_classes,
            labels=False,
            duplicates="drop",
        )

        labels = labels.astype(int)

        logger.debug(
            f"Labels after discretization: {len(labels.unique())} unique integer classes"
        )
        logger.debug(
            f"Class distribution: {pd.Series(labels).value_counts().sort_index()}"
        )

        self.data["label_class"] = labels
        self.label_col = "label_class"

        return labels

    def remove_nan_labels(self):
        logger.info("Removing rows with NaN labels")
        non_nan_mask = ~self.data[self.label_col].isna()
        self.active_idx = self.active_idx[non_nan_mask.iloc[self.active_idx]]
        logger.debug(f"Remaining Samples {len(self)}")

    def remove_constant_labels(self):
        logger.info("Removing constant labels")
        active_labels = self.data[self.label_col].iloc[self.active_idx]
        min_mask = active_labels == active_labels.min()
        max_mask = active_labels == active_labels.max()
        keep_mask = ~(min_mask | max_mask)
        if min_mask.sum() > 1:
            logger.debug(
                f"Removing {min_mask.sum()} samples with minimum label")
        if max_mask.sum() > 1:
            logger.debug(
                f"Removing {max_mask.sum()} samples with maximum label")
        self.active_idx = self.active_idx[keep_mask]
        logger.debug(f"Remaining Samples {len(self)}")

    def fill_missing_labels(self, minimum=True):
        logger.debug(
            f"Filling missing labels with min/max value {self.data[self.label_col].max()}"
        )

        if minimum:
            self.data[self.label_col] = self.data[self.label_col].fillna(
                self.data[self.label_col].min()
            )
        else:
            self.data[self.label_col] = self.data[self.label_col].fillna(
                self.data[self.label_col].max()
            )

    def cut_top_percent_by_label(self, percent=0.1, save_file=False):

        active_labels = self.data[self.label_col].iloc[self.active_idx]
        sorted_indices = self.active_idx[np.argsort(active_labels)]

        keep_count = int(len(sorted_indices) * (1 - percent))
        self.active_idx = sorted_indices[:keep_count]

        if save_file:
            cut_indices = sorted_indices[keep_count:]
            cut_data = self.data.iloc[cut_indices]
            cut_data.to_csv(
                f"{os.path.join(self.config.benchmark.data.split('/')[:-1])}/{self.config.benchmark.name}_cut_top_{percent}.csv",
                index=False,
            )
        logger.debug(f"Remaining Samples {len(self)}")

    def _initialize_one_hot_encoding(self):

        if not hasattr(self, "pos_wise_alphabet"):
            self._setup_pos_wise_alphabet()
        self.one_hot_enc = pd.get_dummies(self.seq_df)

    def get_num_splits(self):

        if not hasattr(self, "splits"):
            raise ValueError("Splits not generated")

        return len(self.splits)

    def _map_indices_to_active_view(self, train_indices, test_indices):
        train_abs = np.array(train_indices)
        test_abs = np.array(test_indices)

        if np.array_equal(
            np.sort(np.unique(np.concatenate([train_abs, test_abs]))),
            np.arange(len(self.active_idx)),
        ):
            train_abs = self.active_idx[train_abs]
            test_abs = self.active_idx[test_abs]

        train_abs = train_abs[np.isin(train_abs, self.active_idx)]
        test_abs = test_abs[np.isin(test_abs, self.active_idx)]

        return train_abs, test_abs

    def _subsample_train_indices(self, train_indices):
        subsample_train = getattr(self.config.data, 'subsample_train', 0)
        if subsample_train > 0 and len(train_indices) > subsample_train:
            logger.info(
                f"Subsampling train data to {subsample_train} samples")
            rng = np.random.RandomState(self.config.seed)
            selected = rng.choice(
                train_indices, size=subsample_train, replace=False)
            return selected
        return train_indices

    def _validate_split_size(self, train_indices, test_indices, split_idx):
        total_len = len(train_indices) + len(test_indices)
        if len(train_indices) < 0.08 * total_len:
            logger.warning(
                f"Skipping split {split_idx} due to small train set: Train: {len(train_indices)}, Test: {len(test_indices)}"
            )
            return False
        elif len(test_indices) < 0.08 * total_len:
            logger.warning(
                f"Skipping split {split_idx} due to small test set: Train: {len(train_indices)}, Test: {len(test_indices)}"
            )
            return False
        return True

    def generate_splits(self, split_method: str | None = None):
        splits = []
        logger.info("Generating splits")
        split_method = (
            self.config.data.split_method if split_method is None else split_method
        )
        logger.info(f"Generating splits using method: {split_method}")

        if split_method == "cv":
            split_generator = self._cv_splits()
        elif split_method in self.data.columns:
            logger.info("Data already has a split column, using it")
            split_generator = self._group_splits(split_method)
        elif split_method == "lomo":
            split_generator = self._lomo_splits()
        elif split_method == "lomo_hard":
            split_generator = self._lomo_hard_splits()
        elif split_method == "ood":
            split_generator = self._ood_splits()
        elif split_method == "skewed":
            split_generator = self._skewed_splits()
        elif split_method == "contig":
            split_generator = self._contig_splits()
        elif split_method == "modulo":
            split_generator = self._modulo_splits()
        else:
            raise ValueError(f"Invalid split method: {split_method}")

        for split_idx, (train, test) in enumerate(split_generator):
            logger.debug(f"Train: {len(train)}, Test: {len(test)}")

            if not self._validate_split_size(train, test, split_idx):
                continue

            train_abs, test_abs = self._map_indices_to_active_view(train, test)

            logger.debug(
                f"Split {split_idx}: train_abs={train_abs}, test_abs={test_abs}"
            )
            logger.debug(
                f"active_idx={self.active_idx}, len(active_idx)={len(self.active_idx)}"
            )

            train_abs = self._subsample_train_indices(train_abs)

            splits.append((train_abs, test_abs))

        if not splits:
            raise ValueError(
                f"No valid splits generated using method: {self.config.data.split_method}"
            )

        self.splits = splits
        logger.info(f"Generated {len(splits)} splits")

        header = f"{'Split':<6} {'Train Size':<12} {'Test Size':<10}"
        logger.info(header)
        logger.info('-' * len(header))
        for i, (train, test) in enumerate(splits):
            logger.info(f"{i:<6} {len(train):<12} {len(test):<10}")

        return splits

    def _ood_splits(self):

        labels_active = self.data[self.label_col].iloc[self.active_idx]

        bins, _ = pd.qcut(
            labels_active,
            q=self.num_splits,
            labels=False,
            retbins=True,
            duplicates="drop",
        )
        valid_indices = self.active_idx

        for b in np.unique(bins):
            bin_mask = bins == b

            train = valid_indices[~bin_mask]
            test = valid_indices[bin_mask]
            yield train, test

    def _skewed_splits(self, top_percent=0.4, sample_from_top_percent=0.5):

        if not hasattr(self, "label_col"):
            raise ValueError("Labels are not set")

        labels_active = self.data[self.label_col].iloc[self.active_idx]
        sorted_indices = np.argsort(labels_active).values

        sorted_indices = self.active_idx[sorted_indices]
        n_samples = len(sorted_indices)

        split_point = int(n_samples * (1 - top_percent))

        train_indices = np.array(sorted_indices[:split_point])
        test_pool_indices = np.array(sorted_indices[split_point:])

        logger.debug(f"Total samples: {n_samples}")
        logger.debug(f"Train size: {len(train_indices)}")
        logger.debug(f"Test pool size: {len(test_pool_indices)}")

        rng = np.random.RandomState(self.config.seed)

        for _ in range(self.num_splits):

            n_test = int(len(test_pool_indices) * sample_from_top_percent)
            test_mask = rng.choice(
                len(test_pool_indices), n_test, replace=False)

            test_indices = np.array(test_pool_indices[test_mask])

            remaining_test = test_pool_indices[
                ~np.isin(np.arange(len(test_pool_indices)), test_mask)
            ]
            current_train = np.concatenate([train_indices, remaining_test])

            yield current_train, test_indices

    def _group_splits(self, split_by: str):
        logger.info(f"Splitting based on column {split_by}")

        for group, indices in self.data.groupby(split_by).groups.items():
            train = self.data.index[~self.data.index.isin(indices)]
            test = self.data.index[indices]
            yield train, test

    def _cv_splits(self):
        """Yield shuffled k-fold splits as positions within the active view.

        Numeric labels use ``KFold``; any other label dtype uses
        ``StratifiedKFold`` on the labels. Both are seeded with
        ``config.seed`` and use ``self.num_splits`` folds.

        NOTE(review): integer class labels are numeric too, so they take
        the ``KFold`` path; confirm that stratification is not expected
        there.
        """
        active_labels = self.data[self.label_col].iloc[self.active_idx].to_numpy()
        view_positions = np.arange(len(self.active_idx))

        labels_are_numeric = np.issubdtype(active_labels.dtype, np.number)

        if labels_are_numeric:
            folds = KFold(
                n_splits=self.num_splits,
                shuffle=True,
                random_state=self.config.seed,
            )
            logger.info(
                f"Using KFold with {self.num_splits} splits for regression task"
            )
            fold_iter = folds.split(view_positions)
        else:
            folds = StratifiedKFold(
                n_splits=self.num_splits,
                shuffle=True,
                random_state=self.config.seed,
            )
            logger.info(
                f"Using StratifiedKFold with {self.num_splits} splits for classification task"
            )
            fold_iter = folds.split(view_positions, active_labels)

        for train_part, held_out_part in fold_iter:
            yield train_part, held_out_part

    def _lomo_splits(self):
        logger.debug(f"Mutated positions: \n{self.mut_pos.index}")

        valid_indices = self.active_idx

        for pos in self.mut_pos.index:
            if pos not in self.consensus.index or pos not in self.seq_df.columns:
                logger.warning(f"Position {pos} not found in data, skipping")

                continue

            for aa_at_pos, _ in self.aa_per_pos[pos]:
                logger.debug(f"Position {pos} - Amino acid: {aa_at_pos}")

                pos_mask_full = self.seq_df[pos] == aa_at_pos

                pos_mask_active = pos_mask_full.iloc[valid_indices]
                train = valid_indices[pos_mask_active]
                test = valid_indices[~pos_mask_active]

                yield train, test

    def _lomo_hard_splits(self):
        splits = []
        for train, test in self._lomo_splits():
            splits.append((train, test))

        splits.sort(
            key=lambda x: np.abs(
                np.mean(self.data.loc[x[0]][self.label_col])
                - np.mean(self.data.loc[x[1]][self.label_col])
            )
        )

        for train, test in splits[: self.num_splits]:
            yield train, test

    def _modulo_splits(self):
        logger.info("Generating modulo position-based splits")
        folds = self.num_splits

        if "mutant" not in self._data.columns:
            raise ValueError("No 'mutant' column available for modulo splits")

        df = self._data.reset_index(drop=True)

        def _is_single_mutant(muts: object) -> bool:
            if not isinstance(muts, dict):
                return False
            count = 0
            for spec in muts.values():
                if isinstance(spec, str) and spec:
                    count += len([p for p in spec.split(":") if p])
            return count == 1

        train_df = df.loc[df["mutant"].apply(
            _is_single_mutant)].reset_index(drop=True)

        all_muts: dict[int, list[int]] = {}

        for i in train_df.index:
            row = train_df.loc[i]
            chain_ids = row["chain_id"]
            wt_seq_dic = row["wildtype_sequence"]
            mutants = row["mutant"]

            offsets: dict[str, int] = {}
            offset = 0
            for cid in chain_ids:
                offsets[cid] = offset
                if isinstance(wt_seq_dic, dict) and cid in wt_seq_dic:
                    offset += len(wt_seq_dic[cid])
                else:
                    offset += 0

            for cid, ms in mutants.items():
                if not isinstance(ms, str) or ms == "":
                    continue
                for m in ms.split(":"):
                    if not m:
                        continue
                    pos = int(m[1:-1]) + offsets.get(cid, 0)
                    all_muts.setdefault(pos, []).append(i)

        split: list[list[list[int]]] = [[[], []] for _ in range(folds)]
        used_idxs: set[int] = set()

        for pos, idxs in all_muts.items():
            at_fold = pos % folds
            idxs_set = set(idxs)
            new_idxs = list(idxs_set - used_idxs)
            split[at_fold][1].extend(new_idxs)
            for fold in range(folds):
                if fold != at_fold:
                    split[fold][0].extend(new_idxs)
            used_idxs |= idxs_set

        for train, test in split:
            yield train, test

    def _contig_splits(self):
        logger.info(
            "Generating contiguous position-based splits for single-mutations")
        folds = self.num_splits

        cutoff = self.config.data.single_mutant_cutoff

        if "mutant" not in self._data.columns:
            raise ValueError("No 'mutant' column available for contig splits")

        df = self._data.reset_index(drop=True)

        def _is_single_mutant(muts: object) -> bool:
            if not isinstance(muts, dict):
                return False
            count = 0
            for spec in muts.values():
                if isinstance(spec, str) and spec:
                    count += len([p for p in spec.split(":") if p])
            return count == 1

        train = df.loc[df["mutant"].apply(
            _is_single_mutant)].reset_index(drop=True)

        if train.shape[0] < cutoff:
            logger.warning(
                f"Trying to generate contig splits with less than {cutoff} single mutants"
            )

            return
        logger.info(
            f"Made it through cutoff with {train.shape[0]} single mutants")
        all_muts: dict[int, list[int]] = {}

        for i in train.index:
            row = train.loc[i]
            chain_ids = row["chain_id"]
            wt_seq_dic = row["wildtype_sequence"]
            mutants = row["mutant"]

            offsets: dict[str, int] = {}
            offset = 0
            for cid in chain_ids:
                offsets[cid] = offset
                if isinstance(wt_seq_dic, dict) and cid in wt_seq_dic:
                    offset += len(wt_seq_dic[cid])
                else:
                    offset += 0

            for cid, ms in mutants.items():
                if not isinstance(ms, str) or ms == "":
                    continue
                for m in ms.split(":"):
                    if not m:
                        continue
                    pos = int(m[1:-1]) + offsets.get(cid, 0)
                    all_muts.setdefault(pos, []).append(i)

        fold_count = [0 for _ in range(folds)]
        used_idxs = set()
        sorted_muts = sorted(all_muts.items(), key=lambda x: x[0])
        at_fold = 0
        used_count = 0
        for i in range(len(sorted_muts)):
            pos, idxs = sorted_muts[i]
            idxs = set(idxs)
            split[at_fold][1].extend(list(idxs - used_idxs))
            fold_count[at_fold] = len(split[at_fold][1])
            for fold in range(folds):
                if fold != at_fold:
                    split[fold][0].extend(list(idxs - used_idxs))
            used_idxs |= idxs
            if (
                fold_count[at_fold] >= (
                    train.shape[0] - used_count) / (folds - at_fold)
                and at_fold < folds - 1
            ):
                used_count += fold_count[at_fold]
                at_fold += 1

        for train, test in split:
            yield train, test

    def get_single_split(self, split_method="cv", split_idx=0):
        """Return the split stored at ``split_idx`` in ``self.splits``.

        ``self.generate_splits()`` is called first when the instance has no
        ``splits`` attribute yet. ``split_method`` is accepted but not used
        here.
        """
        splits_missing = not hasattr(self, "splits")
        if splits_missing:
            self.generate_splits()

        chosen = self.splits[split_idx]
        return chosen

    def _get_pdb_chain_sequences_dict(self, indices: Optional[np.ndarray] = None) -> list[dict[str, str]]:
        """Returns per-chain sequences as dicts for each sample. Used internally by models that need per-chain sequences."""
        if not self._use_structure:
            raise ValueError(
                "_get_pdb_chain_sequences_dict() requires use_structure=True")

        if indices is None:
            row_indices = np.arange(len(self._data))
        else:
            row_indices = np.asarray(indices, dtype=int)

        pdb_chain_template = self._get_pdb_chain_template()
        pdb_muts = self.mutant_pdb
        pdb_path = self._effective_pdb_path
        if pdb_path is None:
            raise ValueError(
                "PDB path is not set. Cannot extract PDB sequences.")
        chain_order = self._get_pdb_chain_order(pdb_path)

        results = []
        for row_idx in row_indices:
            row_muts = pdb_muts.iloc[row_idx]
            if not isinstance(row_muts, dict):
                row_muts = {}

            chain_map = {}
            for cid in chain_order:
                if cid not in pdb_chain_template:
                    continue

                tmpl = pdb_chain_template[cid]
                res_ids = tmpl["res_ids"]
                res1 = list(tmpl["res1letter"])

                muts_spec = row_muts.get(cid, "")
                if isinstance(muts_spec, str) and muts_spec:
                    for mut in muts_spec.split(":"):
                        if not mut:
                            continue
                        res = mut[0]
                        mut_to = mut[-1]
                        pos = int(mut[1:-1])

                        if pos not in res_ids:
                            raise ValueError(
                                f"Residue id {pos} not found for chain {cid} "
                                f"in PDB {self.config.benchmark.name}"
                            )
                        if hasattr(res_ids, "tolist"):
                            ids_list = res_ids.tolist()
                        else:
                            ids_list = res_ids
                        local_idx = ids_list.index(pos)

                        if res != res1[local_idx]:
                            raise ValueError(
                                f"Residue mismatch in {self.config.benchmark.name} "
                                f"for chain {cid} at {pos}: expected {res}, "
                                f"got {res1[local_idx]}"
                            )
                        res1[local_idx] = mut_to

                chain_map[cid] = "".join(res1)

            results.append(chain_map)

        return results

    def _get_pdb_chain_template(self) -> dict[str, dict[str, object]]:
        """Read residue ids and one-letter residues per chain from the PDB.

        The file at ``self._effective_pdb_path`` is parsed and, for every
        chain id returned by ``_get_pdb_chain_order``, the first chain segment
        with that id is used.

        Returns:
            Mapping of chain id to ``{"res_ids": ..., "res1letter": ...}``
            where ``res1letter`` is a list of one-letter codes.

        Raises:
            ValueError: If the effective PDB path is ``None``.
        """
        pdb_path = self._effective_pdb_path
        if pdb_path is None:
            raise ValueError(
                "PDB path is not set. Please set the pdb path in the benchmark config file."
            )

        structure = PDBFile.read(pdb_path).get_structure()
        templates: dict[str, dict[str, object]] = {}

        for cid in self._get_pdb_chain_order(pdb_path):
            # First chain segment carrying this id, if there is one.
            segment = next(
                (atoms for atoms in chain_iter(structure)
                 if atoms.chain_id[0] == cid),
                None,
            )
            if segment is None:
                continue

            ids, three_letter = get_residues(segment)
            templates[cid] = {
                "res_ids": ids,
                "res1letter": [one_letter_code(name) for name in three_letter],
            }

        return templates

    def _get_effective_pdb_path(self) -> Optional[str]:
        """Returns filtered PDB path if focus=True and use_structure=True, else original PDB path."""
        if not self._use_structure:
            return None
        if self.focus and self.focus_chains is not None:
            return self._create_filtered_pdb(self.focus_chains)
        return self.pdb

    def _create_filtered_pdb(self, chains_to_keep: List[str]) -> str:
        """Write (or reuse) a copy of ``self.pdb`` restricted to given chains.

        The output lives in ``self.cache_dir`` and is named after the
        benchmark and the sorted chain ids; an existing file at that path is
        reused without being rewritten.

        Args:
            chains_to_keep: Chain ids to retain.

        Returns:
            Path of the filtered PDB file.

        Raises:
            ValueError: If ``self.pdb`` is ``None`` or none of the requested
                chains occurs in the structure.
        """
        if self.pdb is None:
            raise ValueError(
                "PDB path is not set. Cannot create filtered PDB.")

        chain_list = "".join(sorted(chains_to_keep))
        filtered_pdb_path = os.path.join(
            self.cache_dir,
            f"filtered_pdb_{self.config.benchmark.name}_{chain_list}.pdb"
        )
        logger.debug(f"Filtered PDB path: {filtered_pdb_path}")

        if os.path.exists(filtered_pdb_path):
            logger.debug(f"Reusing existing filtered PDB: {filtered_pdb_path}")
            return filtered_pdb_path

        logger.info(
            f"Creating filtered PDB with chains {chains_to_keep}: {filtered_pdb_path}")

        structure = PDBFile.read(self.pdb).get_structure()

        kept_chains = [
            chain_atoms
            for chain_atoms in chain_iter(structure)
            if chain_atoms.chain_id[0] in chains_to_keep
        ]

        if not kept_chains:
            raise ValueError(
                f"No chains found in PDB matching {chains_to_keep}")

        # Chains differ in length and annotations, so they are joined along
        # the atom axis with ``+``. ``struc.stack`` would instead try to build
        # a multi-model stack, which requires identical annotations.
        filtered_structure = kept_chains[0]
        for chain_atoms in kept_chains[1:]:
            filtered_structure = filtered_structure + chain_atoms

        filtered_pdb_file = PDBFile()
        filtered_pdb_file.set_structure(filtered_structure)
        filtered_pdb_file.write(filtered_pdb_path)

        logger.info(f"Created filtered PDB: {filtered_pdb_path}")
        return filtered_pdb_path

    def _get_pdb_chain_order(self, pdb_path: str) -> List[str]:
        """Return the distinct chain ids of a PDB file in first-seen order.

        Args:
            pdb_path: Path of the PDB file to parse.

        Returns:
            Chain ids in order of first appearance, without sorting.
        """
        structure = PDBFile.read(pdb_path).get_structure()
        # ``dict.fromkeys`` drops repeats while keeping insertion order.
        first_seen = dict.fromkeys(
            chain_atoms.chain_id[0] for chain_atoms in chain_iter(structure)
        )
        return list(first_seen)

    def _get_pdb_sequences_from_template(self) -> np.ndarray:
        """Extracts PDB sequences using template and mutations. Returns flattened sequences in PDB chain order, respecting focus."""
        if self.pdb is None:
            raise ValueError(
                "PDB path is not set. Cannot extract PDB sequences.")

        pdb_chain_template = self._get_pdb_chain_template()
        pdb_muts = self.mutant_pdb
        pdb_path = self._effective_pdb_path
        if pdb_path is None:
            raise ValueError(
                "PDB path is not set. Cannot extract PDB sequences.")
        chain_order = self._get_pdb_chain_order(pdb_path)

        # Filter chain_order by focus if needed
        if self.focus and self.focus_chains:
            chain_order = [
                cid for cid in chain_order if cid in self.focus_chains]

        flat_seqs = []
        for row_idx in range(len(self._data)):
            row_muts = pdb_muts.iloc[row_idx]
            if not isinstance(row_muts, dict):
                row_muts = {}

            chain_map = {}
            for cid in chain_order:
                if cid not in pdb_chain_template:
                    continue

                tmpl = pdb_chain_template[cid]
                res_ids = tmpl["res_ids"]
                res1 = list(tmpl["res1letter"])

                muts_spec = row_muts.get(cid, "")
                if isinstance(muts_spec, str) and muts_spec:
                    for mut in muts_spec.split(":"):
                        if not mut:
                            continue
                        res = mut[0]
                        mut_to = mut[-1]
                        pos = int(mut[1:-1])

                        if pos not in res_ids:
                            raise ValueError(
                                f"Residue id {pos} not found for chain {cid} "
                                f"in PDB {self.config.benchmark.name}"
                            )
                        if hasattr(res_ids, "tolist"):
                            ids_list = res_ids.tolist()
                        else:
                            ids_list = res_ids
                        try:
                            local_idx = ids_list.index(pos)
                        except ValueError:
                            raise ValueError(
                                f"Residue id {pos} not found in converted ids list "
                                f"for chain {cid}"
                            )

                        if res != res1[local_idx]:
                            raise ValueError(
                                f"Residue mismatch in {self.config.benchmark.name} "
                                f"for chain {cid} at {pos}: expected {res}, "
                                f"got {res1[local_idx]}"
                            )
                        res1[local_idx] = mut_to

                chain_map[cid] = "".join(res1)

            flat_seq = self.chain_break_token.join(str(chain_map[cid])
                                                   for cid in chain_order if cid in chain_map)
            flat_seqs.append(flat_seq)

        return np.asarray(flat_seqs, dtype=object)

    def get_sequences(self) -> List[str]:
        """Return the sequences as a list of strings.

        With structure use enabled the sequences come from
        ``_get_pdb_sequences_from_template``; otherwise from the
        ``sequence_flat`` column of ``self.data``.
        """
        if self._use_structure:
            return self._get_pdb_sequences_from_template().tolist()
        return self.data["sequence_flat"].to_list()

    def get_pdb_coords(self) -> dict[str, np.ndarray]:
        """
        Return PDB C-alpha coordinates for the dataset, handling chain selection.

        - If ``self.focus`` is False, coordinates for all chains present in the PDB
          are returned.
        - If ``self.focus`` is True, only chains that appear in ``self.focus_chains``
          (i.e. chains with mutations anywhere in the dataset) are returned.

        Returns:
            Mapping of chain id to the value returned by
            ``extract_calpha_coords`` for that chain, in PDB chain order.

        Raises:
            ValueError: If no effective PDB path is available or no chain is
                selected.
        """
        pdb_path = self._get_effective_pdb_path()
        if pdb_path is None:
            raise ValueError(
                "PDB path is not set. Please set the pdb path in the benchmark config file."
            )

        # The helpers below parse the file themselves, so it is not parsed here.
        chain_order = self._get_pdb_chain_order(pdb_path)

        if self.focus and self.focus_chains:
            chains_to_use = [
                cid for cid in chain_order if cid in self.focus_chains]
        else:
            chains_to_use = chain_order

        chain_map: dict[str, np.ndarray] = {
            cid: extract_calpha_coords(pdb_path, cid) for cid in chains_to_use
        }

        if not chain_map:
            raise ValueError(
                f"No coordinates could be extracted for chains {chains_to_use} from PDB {pdb_path}")

        return chain_map

    def get_labels(self, indices: Optional[np.ndarray] = None):
        """Return the label column of ``self.data`` as a NumPy array.

        Args:
            indices: Positional row indices to select; all rows when ``None``.
        """
        labels = self.data[self.label_col]
        if indices is not None:
            labels = labels.iloc[indices]
        return labels.to_numpy()

    def get_test_data(self):
        """Return the rows of ``self.data`` at positions ``self.test_idx_abs``.

        Raises:
            ValueError: If ``self.test_idx_abs`` is ``None``.
        """
        test_positions = self.test_idx_abs
        if test_positions is not None:
            return self.data.iloc[test_positions]
        raise ValueError("Test data not set. Call set_test_data() first.")

    def get_train_data(self):
        """Return the rows of ``self.data`` at positions ``self.active_idx``."""
        active_positions = self.active_idx
        return self.data.iloc[active_positions]

    def get_leq_mutations(self, n):
        """Return the index of ``self.dist_consensus`` entries that are <= ``n``."""
        distances = self.dist_consensus
        within = distances <= n
        return distances[within].index

    def get_gt_mutations(self, n):
        """Return the index of ``self.dist_consensus`` entries that are > ``n``."""
        distances = self.dist_consensus
        beyond = distances > n
        return distances[beyond].index

    def get_mutations(self):
        """Return the ``mutant`` column of ``self._data`` as a list."""
        mutant_column = self._data["mutant"]
        return mutant_column.tolist()

    def get_representative(self):
        """Return the value stored in ``self.representative``."""
        representative = self.representative
        return representative

    def set_cutoff(self, cutoff):
        """Store the label cutoff on ``self.cutoff``.

        Args:
            cutoff: The value to store, or the string ``"auto"`` to use the
                median of the label column of ``self.data``.
        """
        self.cutoff = (
            self.data[self.label_col].median() if cutoff == "auto" else cutoff
        )

    def get_mutated_region(self):
        """Return a ``range`` from the smallest to the largest mutated position.

        The bounds are the minimum and maximum of ``self.mut_pos.index``. The
        upper bound is exclusive, so the largest mutated position itself is
        not part of the returned range.
        NOTE(review): confirm that excluding the last position is intended.
        """
        positions = self.mut_pos.index
        first, last = positions.min(), positions.max()
        return range(first, last)

    def get_mutated_positions(self):
        """Return the index of ``self.mut_pos`` as a list."""
        positions = self.mut_pos.index
        return positions.tolist()

    def cache_features(self, features):
        """Store ``features`` on the instance and write them to the cache.

        Non-array input is converted with ``np.asarray``. The cache key is
        obtained from ``_get_cache_key()`` called without arguments and kept
        on ``self.cache_key``.

        Args:
            features: Array or array-like of features.
        """
        already_array = isinstance(features, np.ndarray)
        self.features = features if already_array else np.asarray(features)

        key = self._get_cache_key()
        self.cache_key = key
        self._save_to_cache(self.features, key)

    def _get_cache_key(self, features=None):
        features_to_hash = features if features is not None else self._init_sequences
        features_hash = hashlib.md5(
            str(features_to_hash).encode()).hexdigest()[:8]

        cache_components = {
            "benchmark": self.config.benchmark.name,
            "data_hash": features_hash,
            "ft": self.config.model.feature_type,
        }

        if self.config.model.feature_type == "embedded":
            if not hasattr(self.config, "embedder") or not hasattr(
                self.config.embedder, "name"
            ):
                raise ValueError(
                    "Embedder configuration missing for feature_type 'embedded'"
                )
            cache_components["embedder_name"] = self.config.embedder.name

            if self.config.embedder.name == "protenc":
                if not hasattr(self.config.embedder, "model"):
                    raise ValueError(
                        "Protenc model configuration missing for protenc embedder."
                    )
                cache_components["embedder_model"] = self.config.embedder.model
                if not self.config.embedder.average_sequence:
                    cache_components["per"] = "res"
                if self.config.embedder.use_structure:
                    cache_components["use"] = "structure"
        cache_str = "_".join(
            f"{k}-{v}" for k, v in sorted(cache_components.items()))
        logger.debug(f"Generated cache key components: {cache_components}")
        return cache_str

    def _get_cache_path(self, cache_key):
        if self.cache_format == "hdf5":
            return os.path.join(self.cache_dir, f"{cache_key}.h5")
        elif self.cache_format == "lmdb":
            return os.path.join(self.cache_dir, f"{cache_key}")
        else:

            return os.path.join(self.cache_dir, f"{cache_key}.pt")

    def _save_to_cache(self, tokenized_inputs, cache_key):
        cache_path = self._get_cache_path(cache_key)
        logger.info(f"Start Saving to cache: {cache_path}")

        if self.cache_format == "hdf5":
            self._save_to_hdf5(tokenized_inputs, cache_path)
        elif self.cache_format == "lmdb":
            self._save_to_lmdb(tokenized_inputs, cache_path)
        else:

            self._save_to_pt(tokenized_inputs, cache_path)

        logger.info(f"Finished Saving to cache: {cache_path}")

    def _save_to_hdf5(self, tokenized_inputs, cache_path):
        """Write features and metadata to an HDF5 file at ``cache_path``.

        The file gets a ``metadata`` group (attributes ``num_sequences``,
        ``created_at``, ``cache_key``) and an ``embeddings`` group holding a
        gzip-compressed ``embeddings`` dataset. Dict input is stacked from its
        values and transposed; other non-array input goes through
        ``np.asarray``.
        """
        with h5py.File(cache_path, "w") as h5_file:
            meta_attrs = h5_file.create_group("metadata").attrs
            meta_attrs["num_sequences"] = len(self.data)
            meta_attrs["created_at"] = str(pd.Timestamp.now())
            meta_attrs["cache_key"] = os.path.basename(cache_path)

            target_group = h5_file.create_group("embeddings")

            if isinstance(tokenized_inputs, dict):
                payload = np.array(list(tokenized_inputs.values())).T
            elif isinstance(tokenized_inputs, np.ndarray):
                payload = tokenized_inputs
            else:
                payload = np.asarray(tokenized_inputs)

            dataset_options = {"compression": "gzip", "compression_opts": 9}
            target_group.create_dataset(
                "embeddings", data=payload, **dataset_options)

    def _save_to_lmdb(self, tokenized_inputs, cache_path):
        """Write features and metadata to an LMDB environment at ``cache_path``.

        A pickled metadata dict is stored under ``b"metadata"`` and every row
        of the feature array is pickled under its stringified row index.
        One-dimensional input is reshaped to a single column first. Dict input
        is stacked from its values and transposed.
        """
        os.makedirs(cache_path, exist_ok=True)

        map_size_bytes = 100 * 1024 ** 3
        with lmdb.open(cache_path, map_size=map_size_bytes) as env, \
                env.begin(write=True) as txn:
            header = {
                "num_sequences": len(self.data),
                "created_at": str(pd.Timestamp.now()),
                "cache_key": os.path.basename(cache_path),
            }
            txn.put(b"metadata", pickle.dumps(header))

            if isinstance(tokenized_inputs, dict):
                matrix = np.array(list(tokenized_inputs.values())).T
            elif isinstance(tokenized_inputs, np.ndarray):
                matrix = tokenized_inputs
            else:
                matrix = np.asarray(tokenized_inputs)

            # Rows are stored individually to avoid LMDB value-size limits;
            # keys are the ASCII row indices b"0", b"1", ...
            if matrix.ndim == 1:
                matrix = matrix.reshape(-1, 1)

            for row_no in range(matrix.shape[0]):
                serialized = pickle.dumps(
                    matrix[row_no], protocol=pickle.HIGHEST_PROTOCOL)
                txn.put(str(row_no).encode("ascii"), serialized)

    def _save_to_pt(self, tokenized_inputs, cache_path):

        if isinstance(tokenized_inputs, np.ndarray):
            embeddings_array = tokenized_inputs
        elif isinstance(tokenized_inputs, dict):

            embeddings_array = np.array(list(tokenized_inputs.values())).T
        else:
            embeddings_array = np.asarray(tokenized_inputs)

        torch.save(embeddings_array, cache_path)

        metadata_path = cache_path.replace(".pt", "_metadata.json")
        metadata = {
            "num_sequences": len(self.data),
            "created_at": str(pd.Timestamp.now()),
        }
        with open(metadata_path, "w") as f:
            json.dump(metadata, f)

    def _load_from_cache(self, cache_key):
        cache_path = self._get_cache_path(cache_key)

        if not os.path.exists(cache_path) or self.config.data.recompute_features:
            return None

        logger.info(f"Loading from cache: {cache_path}")

        try:
            if self.cache_format == "hdf5":
                features = self._load_from_hdf5(cache_path)
            elif self.cache_format == "lmdb":
                features = self._load_from_lmdb(cache_path)
            else:

                features = self._load_from_pt(cache_path)

            if np.all(features == features[0]) or np.isnan(features).any():
                raise ValueError(
                    f"Features are all the same or contain NaNs, recomputing features")
                return None
            return features
        except Exception as e:
            logger.error(f"Error loading cache from {cache_path}: {e}")
            return None

    def _load_from_hdf5(self, cache_path) -> np.ndarray:
        """Read the feature array from the HDF5 cache at ``cache_path``.

        The ``embeddings`` dataset inside the ``embeddings`` group is returned
        when present; otherwise every dataset of that group is read, stacked
        and transposed.
        """
        with h5py.File(cache_path, "r") as h5_file:
            group = h5_file["embeddings"]

            if "embeddings" not in group.keys():
                # Layout with one dataset per key inside the group.
                columns = [group[name][:] for name in group.keys()]
                return np.array(columns).T

            return np.array(group["embeddings"][:])

    def _load_from_lmdb(self, cache_path) -> np.ndarray:
        """Read the feature array from the LMDB cache at ``cache_path``.

        Three layouts are handled, in this order: a single pickled
        ``b"embeddings"`` entry; per-row entries keyed by numeric strings,
        which are ordered by index and stacked; and any other set of keys,
        whose values are stacked and transposed. Values are unpickled, so the
        cache directory has to be trusted.

        Raises:
            ValueError: If the environment holds no entries besides metadata.
        """
        with lmdb.open(cache_path, readonly=True) as env, env.begin() as txn:
            # Backward compatibility: older caches store a single "embeddings" entry.
            legacy_blob = txn.get(b"embeddings")
            if legacy_blob is not None:
                return np.asarray(pickle.loads(legacy_blob))

            # New format: per-sequence entries keyed by row index (b"0", b"1", ...).
            entries: dict[str, object] = {
                raw_key.decode(): pickle.loads(raw_value)
                for raw_key, raw_value in txn.cursor()
                if raw_key != b"metadata"
            }

            if not entries:
                raise ValueError(
                    f"No embeddings found in LMDB cache at {cache_path}"
                )

            if not all(name.isdigit() for name in entries):
                # Fallback for very old dict-of-arrays layout: transpose into (N, D).
                return np.array(list(entries.values())).T

            # All keys numeric: per-sequence layout, ordered by row index.
            row_order = sorted(int(name) for name in entries)
            ordered_rows = [entries[str(row_no)] for row_no in row_order]
            if isinstance(ordered_rows[0], np.ndarray):
                return np.stack(
                    [np.asarray(row) for row in ordered_rows], axis=0)
            return np.asarray(ordered_rows)

    def _load_from_pt(self, cache_path) -> np.ndarray:
        try:
            cached_data = torch.load(cache_path, weights_only=True)
        except (pickle.UnpicklingError, RuntimeError):

            logger.debug("Loading cache with weights_only=False")
            cached_data = torch.load(cache_path, weights_only=False)

        if isinstance(cached_data, np.ndarray):
            return cached_data
        elif isinstance(cached_data, torch.Tensor):
            return cached_data.numpy()
        elif isinstance(cached_data, dict):

            return np.array(list(cached_data.values())).T
        else:
            return np.asarray(cached_data)

    def _check_features_ready(self):

        if self.config.model.feature_type == "sequence":
            return True

        if self.config.model.feature_type == "embedded":
            sequences_for_embedding = self.get_sequences()
            self.cache_key = self._get_cache_key(sequences_for_embedding)
        else:

            self.cache_key = self._get_cache_key(self._init_sequences)

        cache_path = self._get_cache_path(self.cache_key)

        if self.cache_format == "lmdb":
            return os.path.exists(cache_path) and os.path.isdir(cache_path)
        else:

            return os.path.exists(cache_path)

    def wait_for_features_ready(self):
        """Block until the cached features exist.

        Returns immediately for feature type ``"sequence"``. Otherwise
        ``_check_features_ready`` is polled, sleeping 10 seconds between
        attempts, until it reports success.
        """
        if self.config.model.feature_type != "sequence":
            while True:
                if self._check_features_ready():
                    break
                logger.info("Waiting for features to be ready in cache")
                time.sleep(10)
            logger.info("Features are ready, continuing.")
            return

        logger.info(
            "Feature type is 'sequence', no cache waiting needed. Features are immediately available."
        )

    def prepare_features(self, sequences: List[str] | None = None):
        logger.info(
            f"Preparing features for type: {self.config.model.feature_type}")
        logger.info(f"Cache format: {self.cache_format}")

        self._validate_sequences()

        if self.config.model.feature_type == "embedded":
            if (
                not hasattr(self.config, "embedder")
                or self.config.embedder.name != "protenc"
            ):

                raise NotImplementedError(
                    f"Embedder '{self.config.embedder.name if hasattr(self.config, 'embedder') else 'None'}' not supported for 'embedded' feature_type. Only 'protenc' is implemented."
                )

            sequences_for_embedding = self.get_sequences()
            self.cache_key = self._get_cache_key(sequences_for_embedding)

            logger.info(f"Loading features from cache: {self.cache_key}")
            cached_features = self._load_from_cache(self.cache_key)
            if cached_features is not None:
                logger.info(
                    f"loaded{len(cached_features)} features from cache")
                try:

                    if isinstance(cached_features, np.ndarray):
                        cached_features_array = cached_features
                    elif isinstance(cached_features, dict):

                        cached_features_array = np.array(
                            list(cached_features.values())).T
                    else:
                        cached_features_array = np.asarray(cached_features)

                    if len(cached_features_array) != len(self.data):
                        raise ValueError(
                            f"Cached features length {len(cached_features_array)} does not match dataset length {len(self.data)}"
                        )

                    self.features = cached_features_array
                    self.features_loaded = True
                    logger.info(
                        f"Loaded embedded features from cache: {self.cache_key}, shape: {self.features.shape}"
                    )
                except Exception as e:
                    logger.error(
                        f"Error loading cached features: {e}")
                    logger.error(
                        f"Cached features type: {type(cached_features)}")
                    logger.warning("Recomputing features")
                    cached_features_array = None
            else:
                cached_features_array = None

            if cached_features_array is None:
                logger.info(
                    f"Generating and caching embedded features: {self.cache_key}"
                )
                features_array = self._prepare_protenc_embeddings()
                self.features_loaded = True
                if np.all(features_array == features_array[0]):
                    raise ValueError(f"Features are all the same")
                    return None

                if len(features_array) != len(self.data):
                    raise ValueError(
                        f"Generated features length {len(features_array)} does not match dataset length {len(self.data)}"
                    )

                self.features = features_array
                self._save_to_cache(features_array, self.cache_key)
                logger.info(
                    f"Generated embedded features, shape: {self.features.shape}"
                )

            if self.features is not None:
                if self.features.ndim == 2:
                    self.feature_cols = [
                        f"feature_{i}" for i in range(self.features.shape[1])]
                else:
                    self.feature_cols = ["features"]
                logger.debug(
                    f"Feature array shape: {self.features.shape}, feature_cols metadata: {len(self.feature_cols)} columns"
                )

        elif self.config.model.feature_type == "sequence":
            self.feature_cols = ["sequence_flat"]
            logger.info(
                "Feature type is 'sequence'. Model will handle tokenization/embedding."
            )
        else:
            raise ValueError(
                f"Unsupported feature_type in config: {self.config.model.feature_type}. Must be 'embedded' or 'sequence'."
            )

    def _prepare_protenc_embeddings(self, sequences: List[str] | None = None):

        if sequences is None:
            sequences_to_embed = self.get_sequences()
        else:
            sequences_to_embed = sequences

        logger.info(
            f"Generating embeddings for {len(sequences_to_embed)} sequences using embedding manager"
        )

        pdb_path = self._get_effective_pdb_path() if self._use_structure else None
        print("pdb_path", pdb_path)
        embeddings_array = self.embedding_manager.get_embeddings(
            sequences_to_embed, pdb_path
        )
        print("embeddings_array", embeddings_array)
        logger.debug(
            f"Created embeddings array with shape: {embeddings_array.shape}")
        return embeddings_array


def prepare_features(config):
    """Generate and cache features for ``config`` unless they already exist.

    Returns:
        The ``HAIPRData`` instance when features had to be prepared, ``None``
        when the cache was already in place.
    """
    data = HAIPRData(config)
    if not data._check_features_ready():
        logger.info(f"Generating features: {data.cache_key}")
        logger.info(config.embedder)
        data.prepare_features()
        return data
    return None


@hydra.main(config_path="conf", config_name="data", version_base=None)
def main(config: DictConfig = None):
    """Hydra entry point: prepare features unless they are already cached."""
    data = HAIPRData(config)
    if not data._check_features_ready():
        logger.info(f"Generating features: {data.cache_key}")
        logger.info(config.embedder)
        data.prepare_features()


if __name__ == "__main__":
    # Script entry point: register the project's resolvers, then hand control
    # to the Hydra-decorated ``main``.
    # NOTE(review): presumably the resolvers must be registered before Hydra
    # composes the config - confirm.
    register_resolvers()
    main()
