import logging
from dataclasses import asdict, dataclass
from typing import Any, Dict, List, Optional, Callable, Union
from functools import lru_cache
import pickle

from smart_open import open
import numpy as np
import torch
from rdkit import Chem
from rdkit.Chem import AllChem
from tqdm import tqdm


from coarsebind_public.mol_encoder.mol_enc_inferrer import MolEncInferrer
from coarsebind_public.mol_encoder.io_schema import MolEncIOSchema
from coarsebind_public.esm2.esm2_inference import ESM2Infer
from coarsebind_public.esm2.io_schema import ESM2IOSchema

from coarsebind_public.coarsebind.model.models.coarsebind_affinity import (
    CoarseBindAffinity,
)
from coarsebind_public.coarsebind.io_schema import (
    IOSchemaCoarseBind,
    OutputSchemaDistoPotency,
)
from coarsebind_public.coarsebind.utils import (
    pad_to_max,
    global_potency_feats_collate_fn,
    disto_potency_feats_collate_fn_2,
)
from coarsebind_public.coarsebind.io_transforms import (
    to_numpy,
)
from coarsebind_public.coarsebind.coarse_cofold import (
    CoarseCofoldInf,
)
from coarsebind_public.coarsebind.writer import (
    PairformerWriter,
    LitePairformerWriter,
    AffinityPairformerWriter,
    BaseWriter,
)
from coarsebind_public.coarsebind.model.models.coarsebind_pf import (
    CoarseBindPF,
    const,
)
from coarsebind_public.mol_encoder.util.s3.s3_io import cache_read


logger = logging.getLogger(__name__)


def tokenize_from_io(io: IOSchemaCoarseBind) -> Dict[str, torch.Tensor]:
    """
    Convert an IOSchemaCoarseBind into per-token features for the pairformer model.

    Ligand atoms and protein residues are merged into a single token axis:
    ligand atoms come first, followed by the protein chains.

    Tokenization strategy:
        - Ligand atoms: entity_id=0, asym_id=0, sym_id=0, residue_index=0,
          res_type/res_name "UNK", mol_type NONPOLYMER.
        - Protein chains: identical sequences share one entity_id (assigned in
          order of first appearance, starting at 1). Chains are emitted grouped
          by entity, so copies of the same sequence are adjacent on the token
          axis even when they are not adjacent in ``io.sequence``.
        - asym_id increments per emitted chain (starting at 1); sym_id counts
          copies within an entity (starting at 0).
        - residue_index starts at 1 and keeps increasing across chains (it is
          not reset at chain boundaries).

    Args:
        io (IOSchemaCoarseBind): Input schema. Fields read here:
            - mol_enc_io.e3nn_embed: per-atom ligand embeddings (numpy)
            - mol_enc_io.bonds: ligand bond index pairs (numpy); columns 0 and 1
              are used as atom indices
            - sequence: list of protein sequence strings
            - esm2_io: one entry per chain; ``embed[1:-1]`` is used, i.e. the
              first and last rows are dropped
              (presumably BOS/EOS positions -- TODO confirm)
            - res_num: optional per-chain residue numbers (list of lists,
              aligned with ``sequence``); defaults to 1..len(seq) per chain
            - binding_site_res_num: optional per-chain residue numbers to keep
              (list of lists, aligned with ``sequence``)

    Returns:
        Dict[str, torch.Tensor]: Per-token features. All values are tensors
        except ``res_name``, which is a numpy array of strings.
            - asym_id, sym_id, entity_id, residue_index: long, (num_tokens,)
            - res_type: one-hot, (num_tokens, const.num_tokens)
            - mol_type: long, (num_tokens,)
            - res_num: residue numbers, 0 for ligand atoms
            - res_name: numpy array of token names
            - token_bonds: symmetric ligand adjacency, (num_tokens, num_tokens, 1)
            - input_token_embeds: (num_tokens, esm2_dim); ligand rows hold the
              MolEnc embedding in their leading columns, zero elsewhere
            - token_index: 0..num_tokens-1
            - token_h_mask, token_pad_mask: all True
            - pocket_feature: one-hot of UNSPECIFIED for every token
            - cyclic_period: all zeros

    Note:
        When ``io.binding_site_res_num`` is given, all ligand tokens are kept
        and, for each chain, only residues whose res_num is listed for that
        chain are kept. Chains with no corresponding list are dropped entirely.
        ``token_index`` is renumbered after cropping.
    """

    # ---- ligand -----------------------------------------------------------
    num_lig_atoms = len(io.mol_enc_io.e3nn_embed)

    lig_res_type = torch.full((num_lig_atoms,), const.token_ids["UNK"], dtype=torch.long)
    lig_res_name = np.array(["UNK"] * num_lig_atoms)
    lig_mol_type = torch.full(
        (num_lig_atoms,), const.chain_type_ids["NONPOLYMER"], dtype=torch.long
    )
    # entity_id / sym_id / asym_id / residue_index are all 0 for the ligand
    lig_zeros = torch.zeros(num_lig_atoms, dtype=torch.long)

    lig_bonds = torch.from_numpy(io.mol_enc_io.bonds)

    # ---- protein chains ---------------------------------------------------
    # Group chain positions (indices into io.sequence) by sequence. Dict
    # insertion order gives entities in order of first appearance.
    chains_by_seq: Dict[str, List[int]] = {}
    seq_to_esm2 = {}
    for chain_idx, seq in enumerate(io.sequence):
        chains_by_seq.setdefault(seq, []).append(chain_idx)
        seq_to_esm2[seq] = torch.from_numpy(io.esm2_io[chain_idx].embed[1:-1])

    seq_entity_id = []
    seq_sym_id = []
    seq_asym_id = []
    seq_res_type = []
    seq_res_name = []
    seq_res_mol_type = []
    esm2_embed = []
    seq_residue_index = []

    # chain_order[k] is the io.sequence index of the k-th chain on the token
    # axis; asym_id_of_chain maps an io.sequence index to its asym_id. Both are
    # needed because per-chain inputs (res_num, binding_site_res_num) are given
    # in io.sequence order while tokens are grouped by entity.
    chain_order: List[int] = []
    asym_id_of_chain: Dict[int, int] = {}

    next_asym_id = 1  # 0 is reserved for the ligand
    next_residue_index = 1  # 0 is reserved for the ligand
    for entity_id, (seq, chain_idxs) in enumerate(chains_by_seq.items(), start=1):
        seq_len = len(seq)
        token_names = [const.prot_letter_to_token[aa] for aa in seq]
        token_type_ids = [const.token_ids[name] for name in token_names]

        for sym_id, chain_idx in enumerate(chain_idxs):
            chain_order.append(chain_idx)
            asym_id_of_chain[chain_idx] = next_asym_id

            seq_entity_id.extend([entity_id] * seq_len)
            seq_sym_id.extend([sym_id] * seq_len)
            seq_asym_id.extend([next_asym_id] * seq_len)
            next_asym_id += 1

            seq_residue_index.extend(range(next_residue_index, next_residue_index + seq_len))
            next_residue_index += seq_len

            seq_res_type.extend(token_type_ids)
            seq_res_name.extend(token_names)
            seq_res_mol_type.extend([const.chain_type_ids["PROTEIN"]] * seq_len)
            esm2_embed.append(seq_to_esm2[seq])

    seq_entity_id = torch.tensor(seq_entity_id, dtype=torch.long)
    seq_sym_id = torch.tensor(seq_sym_id, dtype=torch.long)
    seq_asym_id = torch.tensor(seq_asym_id, dtype=torch.long)
    seq_res_type = torch.tensor(seq_res_type, dtype=torch.long)
    seq_res_mol_type = torch.tensor(seq_res_mol_type, dtype=torch.long)
    esm2_embed = torch.cat(esm2_embed, dim=0)
    seq_residue_index = torch.tensor(seq_residue_index, dtype=torch.long)
    seq_res_name = np.array(seq_res_name)

    # res_num associates a canonical residue number with each residue; the
    # ligand gets 0. Chains are concatenated in token order (chain_order), not
    # io.sequence order, so the numbers line up with the tokens they describe.
    if io.res_num is not None:
        chain_res_nums = [torch.tensor(io.res_num[chain_idx]) for chain_idx in chain_order]
    else:
        chain_res_nums = [
            torch.arange(1, len(io.sequence[chain_idx]) + 1) for chain_idx in chain_order
        ]
    res_num = torch.cat([torch.zeros(num_lig_atoms, dtype=torch.long)] + chain_res_nums, dim=0)

    res_type = torch.cat([lig_res_type, seq_res_type], dim=0)
    res_type = torch.nn.functional.one_hot(res_type, num_classes=const.num_tokens)

    tokenized_io = {
        "asym_id": torch.cat([lig_zeros, seq_asym_id], dim=0),
        "sym_id": torch.cat([lig_zeros, seq_sym_id], dim=0),
        "entity_id": torch.cat([lig_zeros, seq_entity_id], dim=0),
        "residue_index": torch.cat([lig_zeros, seq_residue_index], dim=0),
        "res_type": res_type,
        "mol_type": torch.cat([lig_mol_type, seq_res_mol_type], dim=0),
        "res_num": res_num,
        "res_name": np.concatenate([lig_res_name, seq_res_name], axis=0),
    }

    num_tokens = len(tokenized_io["res_type"])

    # ---- bonds ------------------------------------------------------------
    # Only ligand bonds are set; ligand atoms occupy the first token slots so
    # the bond indices can be used directly. Skip when there are no bonds
    # (an empty array cannot be column-indexed).
    token_bonds = torch.zeros((num_tokens, num_tokens), dtype=torch.long)
    if lig_bonds.numel() > 0:
        bond_src = lig_bonds[:, 0].long()
        bond_dst = lig_bonds[:, 1].long()
        token_bonds[bond_src, bond_dst] = 1
        token_bonds[bond_dst, bond_src] = 1

    tokenized_io["token_bonds"] = token_bonds.unsqueeze(-1)

    # ---- embeddings -------------------------------------------------------
    # One buffer with the ESM2 width; ligand rows are filled only in their
    # leading columns (the ligand embedding is assumed to be no wider than the
    # ESM2 embedding).
    input_token_embeds = torch.zeros(num_tokens, esm2_embed.shape[-1])

    lig_embed = torch.from_numpy(io.mol_enc_io.e3nn_embed)
    input_token_embeds[
        :num_lig_atoms,
        : lig_embed.shape[-1],
    ] = lig_embed

    input_token_embeds[num_lig_atoms:] = esm2_embed

    tokenized_io["input_token_embeds"] = input_token_embeds

    tokenized_io["token_index"] = torch.arange(0, num_tokens, dtype=torch.long)

    tokenized_io["token_h_mask"] = torch.ones(num_tokens, dtype=torch.bool)
    tokenized_io["token_pad_mask"] = torch.ones(num_tokens, dtype=torch.bool)

    # TODO: user input pocket conditioning
    pocket_feature = torch.full(
        (num_tokens,), const.pocket_contact_info["UNSPECIFIED"], dtype=torch.long
    )
    pocket_feature = torch.nn.functional.one_hot(
        pocket_feature, num_classes=len(const.pocket_contact_info)
    )

    tokenized_io["pocket_feature"] = pocket_feature

    tokenized_io["cyclic_period"] = torch.zeros(num_tokens, dtype=torch.long)

    # ---- optional binding-site cropping -----------------------------------
    if io.binding_site_res_num is not None:
        keep_mask = torch.zeros(num_tokens, dtype=torch.bool)

        # Ligand atoms are always kept.
        keep_mask[:num_lig_atoms] = True

        asym_ids = tokenized_io["asym_id"]

        # binding_site_res_num is a list of lists indexed like io.sequence.
        for chain_idx, chain_binding_sites in enumerate(io.binding_site_res_num):
            target_asym_id = asym_id_of_chain.get(chain_idx)
            if target_asym_id is None:
                # More binding-site lists than chains: nothing to select.
                continue

            chain_mask = asym_ids == target_asym_id
            keep_mask[chain_mask] = torch.isin(
                res_num[chain_mask],
                torch.as_tensor(chain_binding_sites, dtype=res_num.dtype),
            )

        keep_mask_np = keep_mask.numpy()

        # Crop every per-token feature along its first axis. token_bonds is
        # pairwise and token_index is regenerated, so both are handled below.
        for key, value in tokenized_io.items():
            if key in ("token_bonds", "token_index"):
                continue
            if isinstance(value, np.ndarray):
                tokenized_io[key] = value[keep_mask_np]
            else:
                tokenized_io[key] = value[keep_mask]

        # Pairwise feature: crop both token axes.
        tokenized_io["token_bonds"] = tokenized_io["token_bonds"][keep_mask][:, keep_mask]

        # Renumber tokens sequentially from 0.
        num_tokens_cropped = keep_mask.sum().item()
        tokenized_io["token_index"] = torch.arange(0, num_tokens_cropped, dtype=torch.long)

    return tokenized_io


def collate(data: list[dict[str, torch.Tensor]]) -> dict[str, torch.Tensor]:
    """Merge a list of per-sample feature dicts into one batch dict.

    The key set is taken from the first sample. For every key the values of
    all samples are gathered; what happens next depends on the first value:

    * tensor, all samples share one shape -> stacked along a new leading axis
    * tensor, shapes differ -> handed to ``pad_to_max(values, 0)`` and the
      first element of its result is used
    * anything else -> left as a plain Python list, one entry per sample

    Parameters
    ----------
    data : list[dict[str, Tensor]]
        Per-sample feature dictionaries. Must be non-empty.

    Returns
    -------
    dict[str, Tensor]
        The batched features, keyed like the first sample.

    """
    batch = {}
    for name in data[0].keys():
        column = [sample[name] for sample in data]
        head = column[0]

        if isinstance(head, torch.Tensor):
            if any(item.shape != head.shape for item in column):
                # Ragged shapes: pad instead of stacking.
                column, _ = pad_to_max(column, 0)
            else:
                column = torch.stack(column, dim=0)

        batch[name] = column

    return batch


def batch_to_device(batch: Dict[str, torch.Tensor], device: str) -> Dict[str, torch.Tensor]:
    """
    Build a copy of ``batch`` with its tensor values placed on ``device``.

    Only top-level values are inspected:
        - ``None`` values are dropped from the result.
        - ``torch.Tensor`` values are transferred with ``non_blocking=True``.
        - Every other value is carried over unchanged (same object).

    Args:
        batch (Dict[str, torch.Tensor]): Mapping of feature name to value.
        device (str): The target device (e.g., 'cpu', 'cuda:0').

    Returns:
        Dict[str, torch.Tensor]: A new dictionary; the input is not modified.
    """
    moved = {}
    for name, value in batch.items():
        if value is None:
            continue

        if isinstance(value, torch.Tensor):
            value = value.to(device, non_blocking=True)

        moved[name] = value

    return moved


class InferencePipe(object):
    """
    Re-iterable batcher that chunks rows and transforms every chunk.

    Iterating yields ``(batch_idxs, transformed)`` pairs, where ``batch_idxs``
    are the caller-supplied indices of the rows in the chunk and
    ``transformed`` is ``xform_routine`` applied to the list of those rows.
    """

    def __init__(
        self,
        data: List[IOSchemaCoarseBind],
        original_idxs: List[int],
        batch_size: int,
        xform_routine: Callable,
    ):
        """
        Store the inputs; no batching happens until iteration.

        Args:
            data (List[IOSchema]): Rows to batch, in the order they are emitted.
            original_idxs (List[int]): One index per row, reported alongside
                each batch so results can be mapped back by the caller.
            batch_size (int): Number of rows per batch.
            xform_routine (Callable): Called with the list of rows of a batch.
        """
        self.batch_size = batch_size
        self.xform_routine = xform_routine
        self.data = data
        self.original_idxs = original_idxs

        assert len(self.data) == len(self.original_idxs)

    def __iter__(self):
        """
        Yield ``(batch_idxs, xform_routine(batch))`` for consecutive chunks.

        The final chunk may hold fewer than ``batch_size`` rows.
        """
        pending_rows, pending_idxs = [], []
        for position, row in enumerate(self.data):
            pending_rows.append(row)
            pending_idxs.append(self.original_idxs[position])

            if len(pending_rows) != int(self.batch_size):
                continue

            # Chunk is full: emit it and start collecting a fresh one.
            yield pending_idxs, self.xform_routine(pending_rows)
            pending_rows, pending_idxs = [], []

        # Leftover rows that did not fill a complete chunk.
        if pending_rows:
            yield pending_idxs, self.xform_routine(pending_rows)


class CoarseBindInferrer(object):

    def __init__(
        self,
        model_doc_potency: str = None,
        model_doc_pairformer: str = None,
        device: str = "cuda",
        pairformer_writer: BaseWriter = None,
        run_coarse_cofold: bool = False,
        batch_size: int = 1,
        max_tokens: int = 512,
        precision=torch.bfloat16,
        use_kernels: bool = False,
        mol_enc_uri: str = None,
        mol_enc_smiles_tokenizer_path: str = None,
        mol_enc_graph_tokenizer_path: str = None,
    ):
        """
        Record the inference configuration; no model is loaded here.

        Args:
            model_doc_potency (str, optional): Location of the potency model
                artifact. Stored as given; consumed by init_models().
            model_doc_pairformer (str, optional): Location of the pairformer
                checkpoint. Stored as given; consumed by init_models(), which
                also decides how it interacts with model_doc_potency.
            device (str): Compute device. Default: 'cuda'
            pairformer_writer (BaseWriter, optional): Writer for pairformer
                output. When omitted (or falsy), a PairformerWriter is created
                if run_coarse_cofold is True, otherwise a LitePairformerWriter.
            run_coarse_cofold (bool): Stored on the instance and used to pick
                the default pairformer writer. Default: False
            batch_size (int): Number of complexes per batch. Default: 1
            max_tokens (int): Token limit handed to pre_xform() by predict().
                Default: 512
            precision: Torch dtype, stored on the instance.
                Default: torch.bfloat16
            use_kernels (bool): Forwarded to the pairformer checkpoint loader
                in init_models(). Default: False
            mol_enc_uri (str): Location of the MolEnc model artifacts.
            mol_enc_smiles_tokenizer_path (str): Path to the MolEnc SMILES
                tokenizer file.
            mol_enc_graph_tokenizer_path (str): Path to the MolEnc graph
                tokenizer file.

        Note:
            An AffinityPairformerWriter is always created as
            ``self.affinity_writer``. Model handles and deployment metadata
            are initialised to None and filled in later.
        """

        # --- model artifact locations ---
        self.model_doc_potency = model_doc_potency
        self.model_doc_pairformer = model_doc_pairformer
        self.mol_enc_uri = mol_enc_uri
        self.mol_enc_smiles_tokenizer_path = mol_enc_smiles_tokenizer_path
        self.mol_enc_graph_tokenizer_path = mol_enc_graph_tokenizer_path

        # --- runtime settings ---
        self.device = device
        self.batch_size = batch_size
        self.max_tokens = max_tokens
        self.precision = precision
        self.use_kernels = use_kernels
        self.run_coarse_cofold = run_coarse_cofold

        # --- output writers ---
        # A caller-supplied writer wins; otherwise the choice depends on
        # whether coarse co-folding is requested.
        if pairformer_writer:
            self.pairformer_writer = pairformer_writer
        elif run_coarse_cofold:
            self.pairformer_writer = PairformerWriter()
        else:
            self.pairformer_writer = LitePairformerWriter()
        self.affinity_writer = AffinityPairformerWriter()

        # --- deployment info (set by from_mlops) ---
        self.deployment_id: Optional[str] = None
        self.artifact_id: Optional[str] = None
        self.model_task: Optional[str] = None
        self.model_source_env: Optional[str] = None
        self.model_unique_metadata: Optional[Dict[str, Any]] = None

        # --- model handles (populated by init_models) ---
        self.mol_enc_inferrer: Optional[MolEncInferrer] = None
        self.model_pairformer: Optional[CoarseBindPF] = None
        self.esm2_inferrer: Optional[ESM2Infer] = None
        self.model_potency: Optional[CoarseBindAffinity] = None

    def init_models(self):
        """
        Load every model component needed by predict().

        Steps, in order:

        1. **Potency model** (only if ``model_doc_potency`` is set): loaded with
           ``CoarseBindAffinity.from_artifact``, switched to eval mode and moved
           to ``self.device``. If no pairformer path was given explicitly, the
           potency model's ``base_pairformer`` is used. An explicitly given
           ``model_doc_pairformer`` is kept, and a warning is logged when it
           differs from ``base_pairformer``.

        2. **MolEnc inferrer**: built from ``mol_enc_uri`` and the two tokenizer
           paths, then its model is loaded.

        3. **Pairformer**: ``CoarseBindPF.load_from_checkpoint`` reads the
           ``.ckpt`` file (opened through ``cache_read``) onto CPU; the model is
           then moved to ``self.device``, switched to eval mode, and its
           ``default_recycling_steps`` is set to 3.

        4. **ESM2 inferrer**: ``esm2_t33_650M_UR50D``, representation layer 33.

        Note:
            - Only ``.ckpt`` pairformer checkpoints are supported.
            - ``self.precision`` is not applied in this method.
            - predict() calls this lazily when the pairformer is not loaded.

        Raises:
            ValueError: If no pairformer path is available, if the pairformer
                path does not end in ``.ckpt``, or if the MolEnc model URI or
                tokenizer paths are missing.
        """

        # Load potency model if provided
        if self.model_doc_potency:

            self.model_potency = CoarseBindAffinity.from_artifact(self.model_doc_potency)

            self.model_potency.eval()
            self.model_potency.to(self.device)

            base_pairformer = self.model_potency.base_pairformer
            if not self.model_doc_pairformer:
                # No explicit pairformer: fall back to the one the potency
                # model was built on.
                self.model_doc_pairformer = base_pairformer
                logger.info(f"Using pairformer from potency model: {self.model_doc_pairformer}")
            elif self.model_doc_pairformer != base_pairformer:
                # Honour the explicit choice, but make the mismatch visible.
                logger.warning(
                    "Explicit model_doc_pairformer %s differs from the potency model's "
                    "base_pairformer %s; using the explicit path.",
                    self.model_doc_pairformer,
                    base_pairformer,
                )

        if not self.model_doc_pairformer:
            # Without this guard the suffix check below fails with an
            # AttributeError on None.
            raise ValueError(
                "model_doc_pairformer is not set. Provide model_doc_pairformer, or a "
                "model_doc_potency whose artifact defines base_pairformer."
            )

        if not self.model_doc_pairformer.endswith(".ckpt"):
            raise ValueError("Unknown model type")

        # Load MolEnc inferrer (generates both smiles_embed and e3nn_embed)
        if not self.mol_enc_uri:
            raise ValueError(
                "mol_enc_uri is required. Please provide path to MolEnc model artifacts."
            )

        if not self.mol_enc_smiles_tokenizer_path or not self.mol_enc_graph_tokenizer_path:
            raise ValueError(
                "mol_enc_smiles_tokenizer_path and mol_enc_graph_tokenizer_path are required. "
                "Please provide paths to tokenizer vocabulary files."
            )

        self.mol_enc_inferrer = MolEncInferrer(
            model_uri=self.mol_enc_uri,
            device=self.device,
            smiles_tokenizer_path=self.mol_enc_smiles_tokenizer_path,
            graph_tokenizer_path=self.mol_enc_graph_tokenizer_path,
        )
        self.mol_enc_inferrer._load_model()

        # Deserialize on CPU first, then move the whole module to the device.
        with cache_read(self.model_doc_pairformer, "rb") as f:

            self.model_pairformer = CoarseBindPF.load_from_checkpoint(
                f,
                map_location="cpu",
                use_kernels=self.use_kernels,
            )

        self.model_pairformer = self.model_pairformer.to(self.device)
        self.model_pairformer.eval()

        self.model_pairformer.default_recycling_steps = 3

        esm2_model_name = "esm2_t33_650M_UR50D"
        esm2_rep_layer = 33
        self.esm2_inferrer = ESM2Infer(
            model_name=esm2_model_name,
            rep_layer=esm2_rep_layer,
            device=self.device,
        )

    def initial_checks_and_featurization(self, row: IOSchemaCoarseBind) -> IOSchemaCoarseBind:
        """
        Validate and normalise a single input row in place.

        The row is stamped with this inferrer's deployment metadata, its SMILES
        is canonicalised with RDKit, and its per-chain fields are brought into
        list-of-lists form.

        Args:
            row (IOSchemaCoarseBind): Row to check. Fields read: ``smiles``,
                ``sequence``, ``complex_id``, ``res_num``,
                ``binding_site_res_num``.

        Returns:
            IOSchemaCoarseBind: The same row object, with ``canon_smiles`` set,
            ``sequence`` as a list of strings, ``res_num`` as one list per chain
            (defaulting to 1..len(seq)), and ``binding_site_res_num`` (if
            given) as a list of lists.

        Raises:
            ValueError: If the SMILES cannot be parsed, if ``sequence`` is
                missing or ``complex_id`` is provided, if ``res_num`` does not
                match the chains in count or length, or if any binding-site
                list has fewer than 10 entries.
        """
        min_binding_site_residues = 10

        # Set deployment info from class variables
        row.deployment_id = self.deployment_id
        row.artifact_id = self.artifact_id
        row.model_task = self.model_task
        row.model_source_env = self.model_source_env
        row.artifact_s3 = self.model_doc_potency or self.model_doc_pairformer

        # MolFromSmiles returns None for an unparsable SMILES; report that
        # explicitly instead of passing None on to MolToSmiles.
        mol = Chem.MolFromSmiles(row.smiles)
        if mol is None:
            raise ValueError(f"Invalid SMILES: {row.smiles!r}")
        row.canon_smiles = Chem.MolToSmiles(mol)

        # A sequence is mandatory and complex_id is rejected in every case.
        if row.sequence is None:
            if row.complex_id is not None:
                raise ValueError("Cannot provide complex_id without sequence")
            raise ValueError("Either sequence or complex_id must be provided")
        if row.complex_id is not None:
            raise ValueError("Cannot provide both sequence and complex_id")

        # A bare string means a single chain.
        if isinstance(row.sequence, str):
            row.sequence = [row.sequence]

        if row.res_num is None:
            # Default numbering: 1-indexed per chain.
            row.res_num = [list(range(1, len(seq) + 1)) for seq in row.sequence]
        elif not all(isinstance(r, list) for r in row.res_num):
            # A flat list is taken as the numbering of a single chain.
            row.res_num = [row.res_num]

        # One res_num list per chain, each as long as its chain.
        if len(row.res_num) != len(row.sequence):
            raise ValueError(
                f"Number of res_num lists ({len(row.res_num)}) does not match "
                f"number of sequences ({len(row.sequence)})"
            )
        for i, seq in enumerate(row.sequence):
            if len(row.res_num[i]) != len(seq):
                raise ValueError(
                    f"Length mismatch between sequence and res_num for chain {i}: "
                    f"{len(seq)} != {len(row.res_num[i])}"
                )

        if row.binding_site_res_num is not None:
            # A flat list is taken as the binding site of a single chain.
            if not all(isinstance(r, list) for r in row.binding_site_res_num):
                row.binding_site_res_num = [row.binding_site_res_num]

            for binding_sites in row.binding_site_res_num:
                if len(binding_sites) < min_binding_site_residues:
                    raise ValueError(
                        "Binding site residue numbers must include at least 10 residues."
                    )

        return row

    def get_esm2(self, io_schema: List[IOSchemaCoarseBind]) -> Dict[str, ESM2IOSchema]:
        """
        Run ESM2 once per distinct sequence and index the results by sequence.

        Rows flagged with ``error`` are ignored. Each distinct sequence across
        the remaining rows is sent to ``self.esm2_inferrer.predict`` in a
        single call.

        Args:
            io_schema (List[IOSchemaCoarseBind]): Rows whose ``sequence`` lists
                should be embedded.

        Returns:
            Dict[str, ESM2IOSchema]: ESM2 result for every distinct sequence.

        Raises:
            ValueError: If a row contains ``None`` as a sequence, or if ESM2
                reports an error for any sequence.
        """
        # Deduplicate so each sequence is embedded only once.
        unique_seqs = {seq for row in io_schema if not row.error for seq in row.sequence}

        if None in unique_seqs:
            raise ValueError("sequence is not set for some rows. ")

        esm2_results = self.esm2_inferrer.predict(list(unique_seqs))

        by_sequence: Dict[str, ESM2IOSchema] = {}
        for result in esm2_results:
            # A single failed sequence aborts the whole call.
            if result.error:
                raise ValueError(
                    f"ESM2 featurization error for sequence: {result.sequence}. Stopping to prevent issues."
                )
            by_sequence[result.sequence] = result

        return by_sequence

    def pre_xform(
        self,
        io_schema: List[IOSchemaCoarseBind],
        max_tokens: int,
    ) -> List[IOSchemaCoarseBind]:
        """
        Validate and featurize a list of rows before model inference.

        Steps:
        1. Run initial_checks_and_featurization() on every row; a failing row
           is flagged (``error`` / ``error_msg``) instead of aborting the call.
        2. If any valid row lacks ``esm2_io``, compute ESM2 embeddings and
           assign them to all valid rows.
        3. If a MolEnc inferrer is loaded and any valid row lacks
           ``mol_enc_io``, compute MolEnc embeddings for all valid rows. Rows
           that already failed are not sent to MolEnc and keep their original
           error message.
        4. Flag rows whose token count (protein residues + ligand atoms)
           exceeds ``max_tokens``. With ``binding_site_res_num`` set, only the
           residues selected by the binding site are counted.

        Args:
            io_schema (List[IOSchemaCoarseBind]): Rows to preprocess; modified
                in place.
            max_tokens (int): The maximum allowed number of tokens (protein
                residues + ligand atoms).

        Returns:
            List[IOSchemaCoarseBind]: The same list, with features attached and
            failing rows flagged.

        Raises:
            ValueError: If every row fails the initial checks (this includes an
                empty input list).
        """

        for i, row in enumerate(io_schema):
            try:
                row = self.initial_checks_and_featurization(row)
            except Exception as e:
                # Per-row failures are recorded, not raised, so one bad row
                # does not sink the batch.
                row.error = True
                row.error_msg = f"Error in initial checks: {e}"
                logger.warning("Error processing row %d: %s", i, e)

            io_schema[i] = row

        # check if all failed
        if all(row.error for row in io_schema):
            raise ValueError("All rows failed initial checks.")

        # Only rows that passed the checks take part in featurization.
        valid_rows = [row for row in io_schema if not row.error]

        # esm2 features
        if any(row.esm2_io is None for row in valid_rows):
            seq_to_io_schema = self.get_esm2(io_schema)
            for row in valid_rows:
                row.esm2_io = [seq_to_io_schema[seq] for seq in row.sequence]

        # MolEnc features (both smiles_embed and e3nn_embed)
        compute_mol_enc = self.mol_enc_inferrer is not None and any(
            row.mol_enc_io is None for row in valid_rows
        )
        if compute_mol_enc:
            mol_enc_results = self.mol_enc_inferrer.predict(
                [row.canon_smiles for row in valid_rows],
                batch_size=128,
                crop_hydrogens_flag=True,
                prog_bar=False,
            )
            for i, row in enumerate(valid_rows):
                mol_enc_result = mol_enc_results[i]
                if mol_enc_result.error:
                    row.error = True
                    row.error_msg = "mol_enc embedding failed"
                else:
                    row.mol_enc_io = mol_enc_result

        for row in io_schema:
            if row.error:
                continue

            # check length
            num_lig_atoms = 0
            if row.mol_enc_io is not None:
                num_lig_atoms = len(row.mol_enc_io.e3nn_embed)

            # row.sequence is a list of chains, so residues must be summed
            # over chains (len(row.sequence) is only the chain count).
            if row.binding_site_res_num is None:
                num_residues = sum(len(seq) for seq in row.sequence)
            else:
                # Count only the residues the binding-site crop will keep;
                # chains without a binding-site list contribute nothing.
                num_residues = 0
                for chain_res_num, chain_sites in zip(row.res_num, row.binding_site_res_num):
                    site_set = set(chain_sites)
                    num_residues += sum(1 for r in chain_res_num if r in site_set)

            if num_residues + num_lig_atoms > max_tokens:
                row.error = True
                # error_msg may be unset on a row that has not failed yet.
                row.error_msg = (
                    row.error_msg or ""
                ) + f"\n number of tokens > max_tokens={max_tokens}"

        return io_schema

    def xform_routine(self, ios: List[IOSchemaCoarseBind]) -> Dict[str, torch.Tensor]:
        """
        Turn a batch of rows into one feature dictionary.

        Each row is tokenized with tokenize_from_io(), the per-row results are
        merged with collate(), and the output of
        global_potency_feats_collate_fn() on the same rows is merged on top
        (its values win if a key appears in both).

        Args:
            ios (List[IOSchemaCoarseBind]): Rows forming one batch.

        Returns:
            Dict[str, torch.Tensor]: Combined batch features.
        """
        per_row_tokens = []
        for io in ios:
            per_row_tokens.append(tokenize_from_io(io))

        features = dict(collate(per_row_tokens))
        features.update(global_potency_feats_collate_fn(ios))

        return features

    def predict(
        self,
        input_data: List[IOSchemaCoarseBind],
        # TODO: load as defaults ?
        pairformer_model_forward_kwargs: Dict[str, Any] = {
            "recycling_steps": 3,
        },
        potency_model_forward_kwargs: Dict[str, Any] = {
            "epinet_samples": 100,
        },
        coarse_cofold_kwargs: Optional[Dict[str, Any]] = {},
    ) -> List[IOSchemaCoarseBind]:
        """
        Run end-to-end inference pipeline for protein-ligand binding prediction.

        This is the main inference entry point that orchestrates:
        1. Input validation and preprocessing
        2. Feature extraction (ESM2 for proteins, MolEnc for ligands)
        3. Batching and sorting by sequence length
        4. Pairformer prediction (distance bins, representations)
        5. Potency prediction (optional, if potency model loaded)
        6. Coarse co-folding (optional, for 3D structure generation)

        **Pipeline Details**:

        - **Preprocessing**: Validates inputs, generates embeddings, tokenizes sequences
        - **Pairformer**: Predicts inter-atomic distance distributions and pairwise features
        - **Potency**:
            - Crops to affinity-relevant pairs (ligand-ligand, ligand-residue)
            - Predicts IC50 (quantitative) and binary classification
            - Optionally samples from epinet for uncertainty quantification
        - **Postprocessing**: Converts outputs to numpy, computes statistics

        Args:
            input_data (List[IOSchemaCoarseBind]): List of input complexes to predict.
                Each IOSchemaCoarseBind should contain:
                - smiles: Ligand SMILES string
                - sequence: List of protein sequence strings
                - target_name: Target protein identifier (optional)
                - assay_id: Assay identifier for potency prediction (optional)

            pairformer_model_forward_kwargs (Dict[str, Any]): Pairformer model kwargs.
                - recycling_steps (int): Number of iterative refinement steps. Default: 3
                - precomputed_embeds (bool): Use precomputed embeddings. Default: True

            potency_model_forward_kwargs (Dict[str, Any]): Potency model kwargs.
                - epinet_samples (int): Number of epinet samples for uncertainty.
                  If > 0, returns pred_25_pct, pred_75_pct, pred_epinet_samples.
                  If 0, only returns single point prediction. Default: 100

            run_coarse_cofold (bool): Whether to generate 3D structure from distogram.
                Requires additional co-folding module. Default: False

            coarse_cofold_kwargs (Optional[Dict[str, Any]]): Co-folding parameters.
                Only used if run_coarse_cofold=True.

        Returns:
            List[IOSchemaCoarseBind]: Input list updated with predictions:
                - disto_output: OutputSchemaDisto with distance bins, z/s representations
                - disto_potency_output: OutputSchemaDistoPotency (if potency model loaded)
                    - pred_quant: IC50 prediction (mean if epinet_samples > 0)
                    - pred_binary: Binary classification (active/inactive)
                    - pred_25_pct, pred_75_pct: Uncertainty bounds (if epinet_samples > 0)
                    - pred_epinet_samples: All epinet samples (if epinet_samples > 0)
                - pred: Shortcut to pred_quant for convenience
                - error: Boolean indicating if prediction failed
                - error_msg: Error message if error=True

        Example:
            >>> # Basic usage
            >>> results = inf.predict(input_data)
            >>> for r in results:
            ...     print(f"IC50: {r.pred}, Binary: {r.disto_potency_output.pred_binary}")
            >>>
            >>> # With uncertainty quantification
            >>> results = inf.predict(
            ...     input_data,
            ...     potency_model_forward_kwargs={"epinet_samples": 100}
            ... )
            >>> for r in results:
            ...     print(f"IC50: {r.pred} [{r.disto_potency_output.pred_25_pct}, {r.disto_potency_output.pred_75_pct}]")
            >>>
            >>> # Without epinet sampling (faster)
            >>> results = inf.predict(
            ...     input_data,
            ...     potency_model_forward_kwargs={"epinet_samples": 0}
            ... )

        Raises:
            ValueError: If input_data is not a list of IOSchemaCoarseBind objects

        Note:
            - Complexes exceeding max_tokens are automatically skipped with error flag set
            - Batching is done dynamically based on sequence length for efficiency
            - All models are automatically initialized on first call if not already done
        """

        if not isinstance(input_data, list) or not all(
            isinstance(row, IOSchemaCoarseBind) for row in input_data
        ):
            raise ValueError("input_data must be a list of IOSchema")

        if self.model_pairformer is None:
            self.init_models()

        featurized_inputs = self.pre_xform(input_data, self.max_tokens)

        # sort by seq length + mol_enc_graph_embed len
        def sort_fn(idx: int) -> int:
            if not featurized_inputs[idx].error:
                return len(featurized_inputs[idx].sequence)
            else:
                return 0

        old_to_sorted_idx = sorted(
            range(len(featurized_inputs)),
            key=sort_fn,
            reverse=True,
        )
        sorted_inputs = [featurized_inputs[i] for i in old_to_sorted_idx]
        sorted_idx_to_og = {i: old_idx for i, old_idx in enumerate(old_to_sorted_idx)}

        # get indices of rows without errors
        no_error_indices = [i for i, row in enumerate(sorted_inputs) if not row.error]

        filtered_sorted_input_data = [sorted_inputs[i] for i in no_error_indices]
        original_idxs = [sorted_idx_to_og[i] for i in no_error_indices]

        # run inference
        pipe = InferencePipe(
            data=filtered_sorted_input_data,
            original_idxs=original_idxs,
            batch_size=self.batch_size,
            xform_routine=self.xform_routine,
        )

        for batch_original_idx, batch_features in tqdm(
            pipe,
            desc="disto inference",
            total=len(filtered_sorted_input_data) // self.batch_size,
        ):
            with torch.no_grad():

                with torch.autocast(device_type="cuda", dtype=self.precision):

                    batch_features = batch_to_device(batch_features, self.device)
                    _output = self.model_pairformer(
                        batch_features,
                        **pairformer_model_forward_kwargs,
                    )

                    # Run potency prediction if model is available
                    if self.model_potency:
                        # Crop to affinity-relevant regions
                        pairformer_outputs_cropped = []
                        for i in range(len(batch_original_idx)):
                            pairformer_outputs_cropped.append(
                                self.affinity_writer.process(
                                    _output,
                                    batch_features,
                                    i,
                                )
                            )

                        # Collate potency features from pairformer outputs
                        disto_potency_feats = disto_potency_feats_collate_fn_2(
                            pairformer_outputs_cropped
                        )
                        potency_features = {**batch_features, **disto_potency_feats}

                        # Run potency model
                        potency_output: Dict[str, torch.Tensor] = self.model_potency(
                            potency_features,
                            **potency_model_forward_kwargs,
                        )

                        # Assign potency predictions to output schema
                        for i in range(len(batch_original_idx)):
                            if potency_output["pred_quant"].ndim == 2:
                                pred = potency_output["pred_quant"][i].mean().item()
                                pred_epinet_samples = potency_output["pred_quant"][i].cpu().numpy()
                                pred_25_pct = np.percentile(pred_epinet_samples, 25)
                                pred_75_pct = np.percentile(pred_epinet_samples, 75)
                            elif potency_output["pred_quant"].ndim == 1:
                                pred = potency_output["pred_quant"][i].item()
                                pred_epinet_samples = None
                                pred_25_pct = None
                                pred_75_pct = None
                            else:
                                raise ValueError("Invalid tensor shape for pred_quant.")

                            pred_binary = potency_output["pred_binary"][i].item()

                            featurized_inputs[batch_original_idx[i]].disto_potency_output = (
                                OutputSchemaDistoPotency(
                                    pred_quant=pred,
                                    pred_binary=pred_binary,
                                    pred_epinet_samples=pred_epinet_samples,
                                    pred_25_pct=pred_25_pct,
                                    pred_75_pct=pred_75_pct,
                                )
                            )

                            # Simple pred shortcut
                            featurized_inputs[batch_original_idx[i]].pred = pred

                    # Process pairformer outputs for each item in batch
                    for i in range(len(batch_original_idx)):
                        pairformer_output = self.pairformer_writer.process(
                            _output,
                            batch_features,
                            i,
                        )
                        featurized_inputs[batch_original_idx[i]].disto_output = to_numpy(
                            pairformer_output
                        )

        if self.run_coarse_cofold:

            self.coarse_cofold_predict(featurized_inputs, coarse_cofold_kwargs)

        return featurized_inputs

    def coarse_cofold_predict(
        self,
        io_data: List[IOSchemaCoarseBind],
        coarse_cofold_kwargs: Optional[Dict[str, Any]] = None,
    ) -> List[IOSchemaCoarseBind]:
        """
        Run coarse co-folding on the input data.

        A fresh ``CoarseCofoldInf`` is constructed on ``self.device`` for every
        call, and its ``predict`` method is invoked on ``io_data``.

        Args:
            io_data (List[IOSchemaCoarseBind]): Records to co-fold. Presumably
                these must already carry the distance predictions produced by
                the main prediction method -- confirm against
                ``CoarseCofoldInf.predict``.
            coarse_cofold_kwargs (Optional[Dict[str, Any]]): Extra keyword
                arguments forwarded to the ``CoarseCofoldInf`` initializer.
                ``None`` (the default) and an empty dict are equivalent: no
                extra arguments are forwarded.

        Returns:
            List[IOSchemaCoarseBind]: Whatever ``CoarseCofoldInf.predict``
                returns for ``io_data`` (expected to be the records updated
                with the generated coarse poses).
        """

        # Normalise None to an empty mapping: ``**None`` raises TypeError, and
        # a ``{}`` literal as the default would be shared across all calls.
        init_kwargs: Dict[str, Any] = (
            {} if coarse_cofold_kwargs is None else coarse_cofold_kwargs
        )

        coarse_cofold_inf = CoarseCofoldInf(
            device=self.device,
            **init_kwargs,
        )

        return coarse_cofold_inf.predict(
            io_schema=io_data,
        )
