import gzip
import math
import os
import random
from typing import Dict, List, Literal

import einops
import torch
from jaxtyping import Float
from loguru import logger
from torch.nn import functional as F
from torch_scatter import scatter_mean

from openfold.data import data_transforms
from openfold.np.residue_constants import atom_types
from proteinfoundation.utils.angle_utils import bond_angles, signed_dihedral_angle
from proteinfoundation.utils.fold_utils import extract_cath_code_by_level
from torch.nn.utils.rnn import pad_sequence


def get_index_embedding(indices, edim, max_len=2056):
    """Sinusoidal embedding of (residue) indices.

    Args:
        indices: Tensor of indices with arbitrary leading shape ``[...]``
            (integer or floating point).
        edim: Requested embedding width.
        max_len: Wavelength scale of the sinusoids.

    Returns:
        Tensor of shape ``[..., edim]``: ``edim // 2`` sine channels followed by
        ``edim // 2`` cosine channels. For odd ``edim`` a trailing zero channel is
        appended so the last dimension always equals ``edim`` (same convention as
        ``get_time_embedding``).
    """
    K = torch.arange(edim // 2, device=indices.device)
    # ``K`` broadcasts against ``indices[..., None]`` for any rank of ``indices``,
    # so no rank-specific reshaping is required.
    angles = indices[..., None] * math.pi / (max_len ** (2 * K / edim))
    pos_embedding = torch.cat([torch.sin(angles), torch.cos(angles)], dim=-1)
    if edim % 2 == 1:
        # Without this pad an odd ``edim`` would yield only ``edim - 1`` channels,
        # disagreeing with the declared feature dimension.
        pos_embedding = F.pad(pos_embedding, (0, 1), mode="constant")
    return pos_embedding


def get_time_embedding(
    t: Float[torch.Tensor, "b"], edim: int, max_positions: int = 2000
) -> Float[torch.Tensor, "b embdim"]:
    """Sinusoidal embedding of a batch of scalar times.

    Args:
        t: 1-D tensor of times, one per batch element (asserted to be 1-D).
        edim: Embedding width. Odd widths are zero-padded by one channel.
        max_positions: Scale applied to ``t`` before embedding. It also sets the
            geometric frequency spacing.

    Returns:
        Tensor of shape ``[b, edim]``: sine channels followed by cosine channels.
    """
    assert len(t.shape) == 1
    t = t * max_positions
    half_dim = edim // 2
    # For edim in {2, 3}, half_dim - 1 == 0 and the unguarded division raises
    # ZeroDivisionError. With a single frequency the spacing is irrelevant, so
    # clamp the denominator to 1.
    emb = math.log(max_positions) / max(half_dim - 1, 1)
    emb = torch.exp(torch.arange(half_dim, dtype=torch.float32, device=t.device) * -emb)
    emb = t.float()[:, None] * emb[None, :]
    emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
    if edim % 2 == 1:
        emb = F.pad(emb, (0, 1), mode="constant")
    assert emb.shape == (t.shape[0], edim)
    return emb


def bin_pairwise_distances(x, min_dist, max_dist, dim):
    """One-hot encode every pairwise Euclidean distance within each point set.

    Args:
        x: Tensor of shape ``[b, n, d]``.
        min_dist: Position of the first bin edge.
        max_dist: Position of the last bin edge.
        dim: Number of output bins (``dim - 1`` evenly spaced edges).

    Returns:
        Float tensor of shape ``[b, n, n, dim]``.
    """
    displacement = x.unsqueeze(2) - x.unsqueeze(1)
    distances = displacement.norm(dim=-1)
    edges = torch.linspace(min_dist, max_dist, dim - 1, device=x.device)
    return bin_and_one_hot(distances, edges)


def bin_and_one_hot(tensor, bin_limits):
    """Bucketize ``tensor`` by the sorted 1-D ``bin_limits`` and one-hot encode it.

    Values ``<= bin_limits[0]`` land in bucket 0 and values ``> bin_limits[-1]``
    land in the last bucket. The output has ``len(bin_limits) + 1`` classes on a
    new trailing axis and is returned as a floating-point tensor.
    """
    n_classes = len(bin_limits) + 1
    bucket_ids = torch.bucketize(tensor, bin_limits)
    return 1.0 * torch.nn.functional.one_hot(bucket_ids, n_classes)


def indices_force_start_w_one(pdb_idx, mask):
    """Shift each row of ``pdb_idx`` so its first entry becomes 1.

    Relative spacing (including gaps) is preserved. Positions where ``mask`` is
    False are overwritten with -1.
    """
    row_start = pdb_idx[:, :1]
    renumbered = (pdb_idx - row_start) + 1
    return renumbered.masked_fill(~mask, -1)


class Feature(torch.nn.Module):
    """Base class for features computed from a batch dictionary.

    Subclasses declare their output channel count through ``dim`` and implement
    ``forward(batch)``.
    """

    def __init__(self, dim: int):
        super().__init__()
        self.dim = dim

    def get_dim(self):
        """Return the number of output channels of this feature."""
        return self.dim

    def forward(self, batch: Dict):
        """Compute the feature. Overridden by subclasses; the base returns None."""
        pass

    @staticmethod
    def _reference_tensor(batch: Dict, err_msg: str):
        """Return a tensor from ``batch`` whose leading dims are (batch, length).

        Candidates are tried in order: ``batch["x_t"]["bb_ca"]``,
        ``batch["coords"]``, ``batch["z_latent"]``, ``batch["x_t"]["local_latents"]``.
        A batch whose ``x_t`` lacks ``bb_ca`` now falls through to the other keys
        instead of failing with UnboundLocalError.

        Raises:
            IOError: with ``err_msg`` if no candidate is present.
        """
        has_x_t = "x_t" in batch and batch["x_t"] is not None
        if has_x_t and "bb_ca" in batch["x_t"]:
            return batch["x_t"]["bb_ca"]
        if "coords" in batch:
            return batch["coords"]
        if "z_latent" in batch:
            return batch["z_latent"]
        if has_x_t and "local_latents" in batch["x_t"]:
            return batch["x_t"]["local_latents"]
        raise IOError(err_msg)

    def extract_bs_and_n(self, batch: Dict):
        """Return ``(batch_size, n)`` read from the first two dims of a reference tensor."""
        v = self._reference_tensor(
            batch, "Don't know how to extract batch size and n from batch..."
        )
        bs, n = v.shape[0], v.shape[1]
        return bs, n

    def extract_device(self, batch: Dict):
        """Return the device of the batch's reference tensor."""
        v = self._reference_tensor(
            batch, "Don't know how to extract device from batch..."
        )
        return v.device

    def assert_defaults_allowed(self, batch: Dict, ftype: str):
        """Raise if the batch requests strict features, i.e. forbids default fallbacks.

        Raises:
            IOError: when ``batch["strict_feats"]`` exists and is truthy.
        """
        if "strict_feats" in batch:
            if batch["strict_feats"]:
                raise IOError(
                    f"{ftype} feature requested but no appropriate feature provided. "
                    "Make sure to include the relevant transform in the data config."
                )


class ZeroFeat(Feature):
    """All-zeros placeholder feature.

    Produces ``[b, n, dim_feats_out]`` in ``"seq"`` mode and
    ``[b, n, n, dim_feats_out]`` in ``"pair"`` mode.
    """

    def __init__(
        self,
        dim_feats_out=128,
        mode: Literal["seq", "pair"] = "seq",
        name=None,
        **kwargs,
    ):
        # Honour the requested width; it was previously hard-coded to 128.
        super().__init__(dim=dim_feats_out)
        self.dim_feats_out = dim_feats_out
        self.mode = mode

    def forward(self, batch):
        b, n = self.extract_bs_and_n(batch)
        device = self.extract_device(batch)
        if self.mode == "seq":
            return torch.zeros((b, n, self.dim), device=device)
        if self.mode == "pair":
            # Previously this branch built the tensor but never returned it.
            return torch.zeros((b, n, n, self.dim), device=device)
        raise IOError(f"Mode {self.mode} wrong for zero feature")


class CroppedFlagSeqFeat(Feature):
    """Single-channel per-residue flag broadcast from ``batch["cropped"]``.

    Returns zeros of shape ``[b, n, 1]`` when the batch has no ``"cropped"`` entry.
    """

    def __init__(self, **kwargs):
        # Accept (and ignore) extra config kwargs like every sibling feature does,
        # so this feature can be built from the same generic config path.
        super().__init__(dim=1)

    def forward(self, batch):
        b, n = self.extract_bs_and_n(batch)
        device = self.extract_device(batch)
        if "cropped" not in batch:
            return torch.zeros((b, n, self.dim), device=device)
        # assumes batch["cropped"] holds one value per sample (shape [b]) — TODO
        # confirm. It is broadcast to [b, n, 1] here.
        cropped = batch["cropped"]
        return torch.ones((b, n, self.dim), device=device) * cropped[..., None, None]


class FoldEmbeddingSeqFeat(Feature):
    """Learned CATH fold-label embedding (C, A and T levels), broadcast over residues.

    Each CATH level has its own ``torch.nn.Embedding`` table with one extra row
    (index ``num_classes_<level>``) used for labels missing from the mapping. The
    three level embeddings are concatenated, so the feature has
    ``3 * fold_emb_dim`` channels. Proteins carrying several CATH codes are
    reduced to a single vector according to ``multilabel_mode``:

    * ``"sample"``: one code is chosen uniformly at random;
    * ``"average"``: the embeddings of all codes are averaged;
    * ``"transformer"``: a transformer encoder runs over the codes, and its
      outputs are averaged over the non-padded positions.
    """

    def __init__(
        self,
        fold_emb_dim,
        cath_code_dir,
        multilabel_mode="sample",
        fold_nhead=4,
        fold_nlayer=2,
        **kwargs,
    ):
        super().__init__(dim=fold_emb_dim * 3)
        self.create_mapping(cath_code_dir)
        # The extra (last) row of each table is the "unknown label" embedding.
        # The A-level table must be named ``embedding_A`` because ``forward`` looks
        # it up under that name. It was registered as ``embedding_fA``, which made
        # every forward pass raise AttributeError.
        self.embedding_C = torch.nn.Embedding(self.num_classes_C + 1, fold_emb_dim)
        self.embedding_A = torch.nn.Embedding(self.num_classes_A + 1, fold_emb_dim)
        self.embedding_T = torch.nn.Embedding(self.num_classes_T + 1, fold_emb_dim)
        # Non-persistent dummy buffer, used only to discover the module's device.
        self.register_buffer("_device_param", torch.tensor(0), persistent=False)
        assert multilabel_mode in ["sample", "average", "transformer"]
        self.multilabel_mode = multilabel_mode
        if multilabel_mode == "transformer":
            encoder_layer = torch.nn.TransformerEncoderLayer(
                fold_emb_dim * 3,
                nhead=fold_nhead,
                dim_feedforward=fold_emb_dim * 3,
                batch_first=True,
            )
            self.transformer = torch.nn.TransformerEncoder(encoder_layer, fold_nlayer)

    @property
    def device(self):
        """Device the module currently lives on (read from the dummy buffer)."""
        return next(self.buffers()).device

    def create_mapping(self, cath_code_dir):
        """Load, or build and cache, the CATH-code -> class-index mapping per level.

        The mapping is cached at ``<cath_code_dir>/cath_label_mapping.pt``. If that
        file is missing, it is built from ``<cath_code_dir>/cath-b-newest-all.gz``:
        every distinct code at levels C, A and T gets an index in sorted order.

        Sets ``class_mapping_{C,A,T}`` and ``num_classes_{C,A,T}``.
        """
        mapping_file = os.path.join(cath_code_dir, "cath_label_mapping.pt")
        if os.path.exists(mapping_file):
            # NOTE(review): torch.load unpickles; only point this at trusted files.
            class_mapping = torch.load(mapping_file)
        else:
            cath_code_file = os.path.join(cath_code_dir, "cath-b-newest-all.gz")
            cath_code_set = {"C": set(), "A": set(), "T": set()}
            with gzip.open(cath_code_file, "rt") as f:
                for line in f:
                    # Each line has exactly four whitespace-separated fields; the
                    # third is the CATH code.
                    _, _, cath_code, _ = line.strip().split()
                    for level, codes in cath_code_set.items():
                        codes.add(extract_cath_code_by_level(cath_code, level))
            class_mapping = {
                level: {k: i for i, k in enumerate(sorted(codes))}
                for level, codes in cath_code_set.items()
            }
            torch.save(class_mapping, mapping_file)

        self.class_mapping_C = class_mapping["C"]
        self.class_mapping_A = class_mapping["A"]
        self.class_mapping_T = class_mapping["T"]
        self.num_classes_C = len(self.class_mapping_C)
        self.num_classes_A = len(self.class_mapping_A)
        self.num_classes_T = len(self.class_mapping_T)

    def _unknown_label(self):
        """Index triple used for missing / unmapped labels (the extra embedding row)."""
        return [self.num_classes_C, self.num_classes_A, self.num_classes_T]

    def parse_label(self, cath_code_list):
        """Map per-sample lists of CATH code strings to ``[C, A, T]`` index triples.

        Unmapped codes use the unknown index. A sample with no codes gets one
        all-unknown triple, so every sample has at least one entry.
        """
        results = []
        for cath_codes in cath_code_list:
            result = [
                [
                    self.class_mapping_C.get(
                        extract_cath_code_by_level(cath_code, "C"), self.num_classes_C
                    ),
                    self.class_mapping_A.get(
                        extract_cath_code_by_level(cath_code, "A"), self.num_classes_A
                    ),
                    self.class_mapping_T.get(
                        extract_cath_code_by_level(cath_code, "T"), self.num_classes_T
                    ),
                ]
                for cath_code in cath_codes
            ]
            if len(cath_codes) == 0:
                result = [self._unknown_label()]
            results.append(result)
        return results

    def sample(self, cath_code_list):
        """Pick one index triple per sample uniformly at random (Python ``random``)."""
        results = []
        for cath_codes in cath_code_list:
            idx = random.randint(0, len(cath_codes) - 1)
            results.append(cath_codes[idx])
        return results

    def flatten(self, cath_code_list):
        """Concatenate all triples into ``[total, 3]`` plus a per-row sample id."""
        results = []
        batch_id = []
        for i, cath_codes in enumerate(cath_code_list):
            results += cath_codes
            batch_id += [i] * len(cath_codes)
        results = torch.as_tensor(results, device=self.device)
        batch_id = torch.as_tensor(batch_id, device=self.device)
        return results, batch_id

    def pad(self, cath_code_list):
        """Pad per-sample triples to a common length.

        Returns:
            ``(codes [b, L, 3], mask [b, L])``, where ``mask`` is True at padding.
            Padding uses the unknown triple. The inner lists are extended in place.
        """
        results = list(cath_code_list)
        max_num_label = max((len(codes) for codes in results), default=0)
        mask = []
        for i, codes in enumerate(results):
            n_pad = max_num_label - len(codes)
            mask.append([False] * len(codes) + [True] * n_pad)
            if n_pad > 0:
                results[i] += [self._unknown_label()] * n_pad
        results = torch.as_tensor(results, device=self.device)
        mask = torch.as_tensor(mask, device=self.device)
        return results, mask

    def _embed(self, codes):
        """Concatenate the C, A and T embeddings of an integer tensor ``[..., 3]``."""
        return torch.cat(
            [
                self.embedding_C(codes[..., 0]),
                self.embedding_A(codes[..., 1]),
                self.embedding_T(codes[..., 2]),
            ],
            dim=-1,
        )

    def forward(self, batch):
        bs, n = self.extract_bs_and_n(batch)
        if "cath_code" not in batch:
            # Placeholder code for every sample; presumably absent from the
            # mapping, so it resolves to the unknown indices — confirm.
            cath_code = [["x.x.x.x"]] * bs
        else:
            cath_code = batch["cath_code"]

        cath_code_list = self.parse_label(cath_code)
        if self.multilabel_mode == "sample":
            codes = torch.as_tensor(self.sample(cath_code_list), device=self.device)
            fold_emb = self._embed(codes)
        elif self.multilabel_mode == "average":
            codes, batch_id = self.flatten(cath_code_list)
            fold_emb = scatter_mean(self._embed(codes), batch_id, dim=0, dim_size=bs)
        elif self.multilabel_mode == "transformer":
            codes, mask = self.pad(cath_code_list)
            fold_emb = self.transformer(self._embed(codes), src_key_padding_mask=mask)
            # Mean over real (non-padding) label positions.
            keep = (~mask[:, :, None]).float()
            fold_emb = (fold_emb * keep).sum(dim=1) / (keep.sum(dim=1) + 1e-10)
        fold_emb = fold_emb[:, None, :]
        return fold_emb.expand((fold_emb.shape[0], n, fold_emb.shape[2]))


class TimeEmbeddingSeqFeat(Feature):
    """Sinusoidal embedding of ``batch["t"][data_mode_use]``, repeated for every residue.

    Output shape: ``[b, n, t_emb_dim]``.
    """

    def __init__(self, data_mode_use, t_emb_dim, **kwargs):
        super().__init__(dim=t_emb_dim)
        self.data_mode_use = data_mode_use

    def forward(self, batch):
        times = batch["t"][self.data_mode_use]
        seq_len = self.extract_bs_and_n(batch)[1]
        per_sample = get_time_embedding(times, edim=self.dim)  # [b, dim]
        return per_sample.unsqueeze(1).expand(
            per_sample.shape[0], seq_len, per_sample.shape[-1]
        )


class TimeEmbeddingPairFeat(Feature):
    """Sinusoidal embedding of ``batch["t"][data_mode_use]``, repeated for every residue pair.

    Output shape: ``[b, n, n, t_emb_dim]``.
    """

    def __init__(self, data_mode_use, t_emb_dim, **kwargs):
        super().__init__(dim=t_emb_dim)
        self.data_mode_use = data_mode_use

    def forward(self, batch):
        times = batch["t"][self.data_mode_use]
        seq_len = self.extract_bs_and_n(batch)[1]
        per_sample = get_time_embedding(times, edim=self.dim)  # [b, dim]
        batch_size, width = per_sample.shape[0], per_sample.shape[-1]
        return per_sample.unsqueeze(1).unsqueeze(1).expand(
            batch_size, seq_len, seq_len, width
        )


class IdxEmbeddingSeqFeat(Feature):
    """Sinusoidal embedding of residue indices.

    If ``batch["residue_pdb_idx"]`` is present, it is renumbered to start at 1
    (masked positions become -1 via ``batch["mask"]``). Otherwise the indices
    default to 1..n, unless the batch sets ``strict_feats``.
    """

    def __init__(self, idx_emb_dim, **kwargs):
        super().__init__(dim=idx_emb_dim)

    def forward(self, batch):
        if "residue_pdb_idx" in batch:
            inds = batch["residue_pdb_idx"]
            inds = indices_force_start_w_one(inds, batch["mask"])
        else:
            self.assert_defaults_allowed(batch, "Residue index sequence")
            b, n = self.extract_bs_and_n(batch)
            device = self.extract_device(batch)
            # Indices 1..n for every sample, built as a tensor op rather than a
            # Python nested list. Uses the default float dtype, as torch.Tensor(...) did.
            inds = (
                torch.arange(1, n + 1, device=device, dtype=torch.get_default_dtype())
                .unsqueeze(0)
                .expand(b, n)
            )
        return get_index_embedding(inds, edim=self.dim)


class ChainBreakPerResidueSeqFeat(Feature):
    """Per-residue chain-break indicator as a float ``[b, n, 1]`` tensor.

    Defaults to zeros when ``chain_breaks_per_residue`` is missing, unless the
    batch sets ``strict_feats``.
    """

    def __init__(self, **kwargs):
        super().__init__(dim=1)

    def forward(self, batch):
        if "chain_breaks_per_residue" in batch:
            # Multiplying by 1.0 turns bool/int flags into floats.
            breaks = batch["chain_breaks_per_residue"] * 1.0
        else:
            self.assert_defaults_allowed(batch, "Chain break sequence")
            b, n = self.extract_bs_and_n(batch)
            breaks = torch.zeros((b, n), device=self.extract_device(batch))
        return breaks.unsqueeze(-1)


class XscBBCASeqFeat(Feature):
    """Self-conditioning / recycling CA coordinates from ``batch[mode_key]["bb_ca"]``.

    Returns zeros of shape ``[b, n, 3]`` (warning once) when ``mode_key`` is absent.
    """

    def __init__(self, mode_key="x_sc", **kwargs):
        super().__init__(dim=3)
        self.mode_key = mode_key
        self._has_logged = False  # emit the fallback warning only once

    def forward(self, batch):
        if self.mode_key not in batch:
            b, n = self.extract_bs_and_n(batch)
            device = self.extract_device(batch)
            if not self._has_logged:
                logger.warning(
                    f"No {self.mode_key} in batch, returning zeros for XscBBCASeqFeat"
                )
                self._has_logged = True
            return torch.zeros(b, n, 3, device=device)

        modes = list(batch[self.mode_key])
        assert (
            "bb_ca" in modes
        ), f"`bb_ca` sc/recycle seq feature requested but key not available in data modes {modes}"
        return batch[self.mode_key]["bb_ca"]


class XscLocalLatentsSeqFeat(Feature):
    """Self-conditioning / recycling latents from ``batch[mode_key]["local_latents"]``.

    Returns zeros of shape ``[b, n, latent_dim]`` (warning once) when ``mode_key``
    is absent.
    """

    def __init__(self, latent_dim, mode_key="x_sc", **kwargs):
        super().__init__(dim=latent_dim)
        self.mode_key = mode_key
        self._has_logged = False  # emit the fallback warning only once

    def forward(self, batch):
        if self.mode_key not in batch:
            b, n = self.extract_bs_and_n(batch)
            device = self.extract_device(batch)
            if not self._has_logged:
                logger.warning(
                    f"No {self.mode_key} in batch, returning zeros for XscLocalLatentsSeqFeat"
                )
                self._has_logged = True
            return torch.zeros(b, n, self.dim, device=device)

        modes = list(batch[self.mode_key])
        assert (
            "local_latents" in modes
        ), f"`local_latents` sc/recycle seq feature requested but key not available in data modes {modes}"
        return batch[self.mode_key]["local_latents"]


class XtBBCASeqFeat(Feature):
    """Current noisy CA coordinates ``batch["x_t"]["bb_ca"]`` (3 channels)."""

    def __init__(self, **kwargs):
        super().__init__(dim=3)

    def forward(self, batch):
        x_t = batch["x_t"]
        modes = list(x_t)
        assert (
            "bb_ca" in modes
        ), f"`bb_ca` seq feat feature requested but key not available in data modes {modes}"
        return x_t["bb_ca"]


class XtLocalLatentsSeqFeat(Feature):
    """Current noisy latents ``batch["x_t"]["local_latents"]`` (``latent_dim`` channels)."""

    def __init__(self, latent_dim, **kwargs):
        super().__init__(dim=latent_dim)

    def forward(self, batch):
        x_t = batch["x_t"]
        modes = list(x_t)
        assert (
            "local_latents" in modes
        ), f"`local_latents` seq feat feature requested but key not available in data modes {modes}"
        return x_t["local_latents"]


class CaCoorsNanometersSeqFeat(Feature):
    """CA coordinates in nanometers (3 channels per residue).

    Uses ``batch["ca_coors_nm"]`` when present. Otherwise it takes atom slot 1
    of ``batch["coords_nm"]``.
    """

    def __init__(self, **kwargs):
        super().__init__(dim=3)

    def forward(self, batch):
        use_ca_key = "ca_coors_nm" in batch
        assert (
            use_ca_key or "coords_nm" in batch
        ), "`ca_coors_nm` nor `coords_nm` in batch, cannot compute CaCoorsNanometersSeqFeat"
        if use_ca_key:
            return batch["ca_coors_nm"]
        return batch["coords_nm"][:, :, 1, :]


class TryCaCoorsNanometersSeqFeat(CaCoorsNanometersSeqFeat):
    """Like ``CaCoorsNanometersSeqFeat``, but returns zeros (warning once) if no coords exist."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self._has_logged = False  # emit the fallback warning only once

    def forward(self, batch):
        if "ca_coors_nm" not in batch and "coords_nm" not in batch:
            b, n = self.extract_bs_and_n(batch)
            device = self.extract_device(batch)
            if not self._has_logged:
                logger.warning(
                    "No ca_coors_nm or coords_nm in batch, returning zeros for TryCaCoorsNanometersSeqFeat"
                )
                self._has_logged = True
            return torch.zeros(b, n, self.dim, device=device)
        return super().forward(batch)


class OptionalCaCoorsNanometersSeqFeat(CaCoorsNanometersSeqFeat):
    """CA coordinates only when ``batch["use_ca_coors_nm_feature"]`` is truthy; zeros otherwise."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self._has_logged = False  # emit the fallback warning only once

    def forward(self, batch):
        if not batch.get("use_ca_coors_nm_feature", False):
            b, n = self.extract_bs_and_n(batch)
            device = self.extract_device(batch)
            if not self._has_logged:
                logger.warning(
                    "use_ca_coors_nm_feature disabled or not in batch, returning zeros for OptionalCaCoorsNanometersSeqFeat"
                )
                self._has_logged = True
            return torch.zeros(b, n, self.dim, device=device)
        return super().forward(batch)


class ResidueTypeSeqFeat(Feature):
    """One-hot residue types over 20 classes, zeroed at padded positions.

    Reads ``batch["residue_type"]`` and the padding mask
    ``batch["mask_dict"]["residue_type"]``.
    """

    def __init__(self, **kwargs):
        super().__init__(dim=20)

    def forward(self, batch):
        assert (
            "residue_type" in batch
        ), "`residue_type` not in batch, cannot compute ResidueTypeSeqFeat"
        residue_types = batch["residue_type"]
        pad_mask = batch["mask_dict"]["residue_type"]
        # Padded entries are mapped to class 0 first so one_hot sees valid indices;
        # the second mask then clears them.
        encoded = F.one_hot(residue_types * pad_mask, num_classes=20)
        encoded = encoded * pad_mask[..., None]
        return encoded * 1.0


class OptionalResidueTypeSeqFeat(ResidueTypeSeqFeat):
    """Residue-type one-hot only when ``batch["use_residue_type_feature"]`` is truthy; zeros otherwise."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self._has_logged = False  # emit the fallback warning only once

    def forward(self, batch):
        if not batch.get("use_residue_type_feature", False):
            b, n = self.extract_bs_and_n(batch)
            device = self.extract_device(batch)
            if not self._has_logged:
                logger.warning(
                    "use_residue_type_feature disabled or not in batch, returning zeros for OptionalResidueTypeSeqFeat"
                )
                self._has_logged = True
            return torch.zeros(b, n, self.dim, device=device)
        return super().forward(batch)


class Atom37NanometersCoorsSeqFeat(Feature):
    """Flattened masked atom coordinates plus the atom mask: 37 * 3 + 37 = 148 channels.

    With ``rel=True``, coordinates are expressed relative to atom slot 1 (CA) of
    the same residue.
    """

    def __init__(self, rel=False, **kwargs):
        super().__init__(dim=int(37 * 4))
        self.rel = rel

    def forward(self, batch):
        assert (
            "coords_nm" in batch
        ), "`coords_nm` not in batch, cannot compute Atom37NanometersCoorsSeqFeat"
        assert (
            "coord_mask" in batch
        ), "`coord_mask` not in batch, cannot compute Atom37NanometersCoorsSeqFeat"
        atom_mask = batch["coord_mask"]
        xyz = batch["coords_nm"] * atom_mask[..., None]
        if self.rel:
            # Subtract each residue's CA, then re-mask so absent atoms stay zero.
            xyz = (xyz - xyz[:, :, 1:2, :]) * atom_mask[..., None]
        flat_xyz = einops.rearrange(xyz, "b n a t -> b n (a t)")
        return torch.cat([flat_xyz, atom_mask], dim=-1)


class BackboneTorsionAnglesSeqFeat(Feature):
    """One-hot binned backbone torsions (psi, omega, phi): 3 angles x 21 bins = 63 channels.

    Position i holds psi(N_i, CA_i, C_i, N_{i+1}), omega(CA_i, C_i, N_{i+1}, CA_{i+1})
    and phi(C_i, N_{i+1}, CA_{i+1}, C_{i+1}). Angles are zeroed where residues i and
    i+1 are not consecutive in the residue index, and at the last position.
    """

    def __init__(self, **kwargs):
        super().__init__(dim=int(3 * 21))

    def forward(self, batch):
        bb_torsion = self._get_bb_torsion_angles(batch)
        # 20 edges spanning [-pi, pi] give 21 one-hot buckets per angle.
        bb_torsion_feats = bin_and_one_hot(
            bb_torsion,
            torch.linspace(-torch.pi, torch.pi, 20, device=bb_torsion.device),
        )
        return einops.rearrange(bb_torsion_feats, "b n t d -> b n (t d)")

    def _get_bb_torsion_angles(self, batch):
        """Return raw torsions of shape ``[b, n, 3]`` from ``batch["coords"]`` slots 0/1/2 (N/CA/C)."""
        a37 = batch["coords"]
        if "residue_pdb_idx" in batch and batch["residue_pdb_idx"] is not None:
            idx = batch["residue_pdb_idx"]
        else:
            # Label fixed: it previously said "Relative sequence separation pair".
            self.assert_defaults_allowed(batch, "Backbone torsion angles sequence")
            b, n = self.extract_bs_and_n(batch)
            device = self.extract_device(batch)
            idx = (
                torch.arange(1, n + 1, device=device, dtype=torch.get_default_dtype())
                .unsqueeze(0)
                .expand(b, n)
            )
        N = a37[:, :, 0, :]
        CA = a37[:, :, 1, :]
        C = a37[:, :, 2, :]

        psi = signed_dihedral_angle(
            N[:, :-1, :], CA[:, :-1, :], C[:, :-1, :], N[:, 1:, :]
        )
        omega = signed_dihedral_angle(
            CA[:, :-1, :], C[:, :-1, :], N[:, 1:, :], CA[:, 1:, :]
        )
        phi = signed_dihedral_angle(
            C[:, :-1, :], N[:, 1:, :], CA[:, 1:, :], C[:, 1:, :]
        )
        bb_angles = torch.stack([psi, omega, phi], dim=-1)

        # Angles spanning a gap in residue numbering are meaningless; zero them.
        good_pair = idx[:, 1:] - idx[:, :-1] == 1
        bb_angles = bb_angles * good_pair[..., None]

        zero_pad = torch.zeros((a37.shape[0], 1, 3), device=bb_angles.device)
        return torch.cat([bb_angles, zero_pad], dim=1)


class BackboneBondAnglesSeqFeat(Feature):
    """One-hot binned backbone bond angles: 3 angles x 21 bins = 63 channels.

    Position i holds bond_angles(N_i, CA_i, C_i), bond_angles(CA_i, C_i, N_{i+1})
    and bond_angles(C_i, N_{i+1}, CA_{i+1}). The two inter-residue angles are
    zeroed across gaps in the residue index and at the last position.
    """

    def __init__(self, **kwargs):
        super().__init__(dim=int(3 * 21))

    def forward(self, batch):
        bb_bond_angle = self._get_bb_bond_angles(batch)
        # 20 edges spanning [-pi, pi] give 21 one-hot buckets per angle.
        bb_bond_angle_feats = bin_and_one_hot(
            bb_bond_angle,
            torch.linspace(-torch.pi, torch.pi, 20, device=bb_bond_angle.device),
        )
        return einops.rearrange(bb_bond_angle_feats, "b n t d -> b n (t d)")

    def _get_bb_bond_angles(self, batch):
        """Return raw bond angles of shape ``[b, n, 3]`` from ``batch["coords"]`` slots 0/1/2 (N/CA/C).

        Only ``batch["coords"]`` (and optionally ``residue_pdb_idx``) is read. An
        unused ``mask_dict`` lookup was removed because it raised KeyError on
        batches without that key.
        """
        a37 = batch["coords"]
        if "residue_pdb_idx" in batch and batch["residue_pdb_idx"] is not None:
            idx = batch["residue_pdb_idx"]
        else:
            self.assert_defaults_allowed(batch, "Backbone bond angles sequence")
            b, n = self.extract_bs_and_n(batch)
            device = self.extract_device(batch)
            idx = (
                torch.arange(1, n + 1, device=device, dtype=torch.get_default_dtype())
                .unsqueeze(0)
                .expand(b, n)
            )
        b = a37.shape[0]

        N = a37[:, :, 0, :]
        CA = a37[:, :, 1, :]
        C = a37[:, :, 2, :]
        theta_1 = bond_angles(N, CA, C)
        theta_2 = bond_angles(CA[:, :-1, :], C[:, :-1, :], N[:, 1:, :])
        theta_3 = bond_angles(C[:, :-1, :], N[:, 1:, :], CA[:, 1:, :])

        # Inter-residue angles across a numbering gap are meaningless; zero them.
        good_pair = idx[:, 1:] - idx[:, :-1] == 1
        theta_2 = theta_2 * good_pair
        theta_3 = theta_3 * good_pair

        zero_pad = torch.zeros((b, 1), device=theta_2.device)
        theta_2 = torch.cat([theta_2, zero_pad], dim=-1)
        theta_3 = torch.cat([theta_3, zero_pad], dim=-1)

        return torch.stack([theta_1, theta_2, theta_3], dim=-1)


class OpenfoldSideChainAnglesSeqFeat(Feature):
    """Binned side-chain torsions plus their validity mask: 4 x 21 + 4 = 88 channels.

    Uses openfold's ``atom37_to_torsion_angles`` and keeps its last four torsions
    (presumably chi1-chi4 in openfold's ordering — confirm).
    """

    def __init__(self, **kwargs):
        super().__init__(dim=int(4 * 21 + 4))

    def forward(self, batch):
        _, angles, torsion_angles_mask = self._get_sidechain_angles(batch)

        # 20 edges spanning [-pi, pi] give 21 one-hot buckets per angle.
        angles_feat = bin_and_one_hot(
            angles, torch.linspace(-torch.pi, torch.pi, 20, device=angles.device)
        )
        # Undefined torsions get an all-zero one-hot vector.
        angles_feat = angles_feat * torsion_angles_mask[..., None]
        angles_feat = einops.rearrange(angles_feat, "b n s d -> b n (s d)")
        return torch.cat([angles_feat, torsion_angles_mask], dim=-1)

    def _get_sidechain_angles(self, batch):
        """Compute the last four torsions from ``coords`` / ``coord_mask`` / ``residue_type``.

        Computation runs in float64 and results are cast back to the input dtype.

        Returns:
            ``(sin_cos [..., 4, 2], angles [..., 4], mask [..., 4] bool)``. The
            sin/cos pairs are unit-normalized, and everything is zeroed where the
            mask is False.
        """
        orig_dtype = batch["coords"].dtype
        p = {
            "aatype": batch["residue_type"],
            "all_atom_positions": batch["coords"].double(),
            "all_atom_mask": batch["coord_mask"].double(),
        }
        p = data_transforms.atom37_to_torsion_angles(prefix="")(p)
        # The transform also emits alt_torsion_angles_sin_cos; it was normalized
        # and masked here before but never returned, so that dead work is dropped.
        torsion_angles_mask = p["torsion_angles_mask"]
        sin_cos = p["torsion_angles_sin_cos"]
        sin_cos = sin_cos / (torch.linalg.norm(sin_cos, dim=-1, keepdim=True) + 1e-10)
        sin_cos = sin_cos * torsion_angles_mask[..., None]
        angles = torch.atan2(sin_cos[..., 0], sin_cos[..., 1]) * torsion_angles_mask

        return (
            sin_cos[..., -4:, :].to(dtype=orig_dtype),
            angles[..., -4:].to(dtype=orig_dtype),
            torsion_angles_mask[..., -4:].bool(),
        )


class LatentVariableSeqFeat(Feature):
    """Passes ``batch["z_latent"]`` through unchanged as a ``latent_z_dim``-channel feature."""

    def __init__(self, latent_z_dim, **kwargs):
        super().__init__(dim=latent_z_dim)

    def forward(self, batch):
        has_latent = "z_latent" in batch
        assert has_latent, "`z_latent` not in batch, cannot compute LatentVariableSeqFeat"
        latents = batch["z_latent"]
        return latents


class MotifAbsoluteCoordsSeqFeat(Feature):
    """Absolute motif coordinates + mask (148 channels) via ``Atom37NanometersCoorsSeqFeat``.

    Returns zeros (warning once) when ``x_motif`` or ``motif_mask`` is missing.
    """

    def __init__(self, **kwargs):
        super().__init__(dim=148)
        self._has_logged = False  # emit the fallback warning only once

    def forward(self, batch):
        if not ("x_motif" in batch and "motif_mask" in batch):
            b, n = self.extract_bs_and_n(batch)
            device = self.extract_device(batch)
            if not self._has_logged:
                logger.warning(
                    "No x_motif or motif_mask in batch, returning zeros for MotifAbsoluteCoordsSeqFeat"
                )
                self._has_logged = True
            return torch.zeros(b, n, self.dim, device=device)

        coords_feature = Atom37NanometersCoorsSeqFeat(rel=False)
        return coords_feature(
            {"coords_nm": batch["x_motif"], "coord_mask": batch["motif_mask"]}
        )


class MotifRelativeCoordsSeqFeat(Feature):
    """CA-relative motif coordinates + mask (148 channels) via ``Atom37NanometersCoorsSeqFeat``.

    Returns zeros (warning once) when the motif inputs are missing, or when any
    residue flagged in ``seq_motif_mask`` lacks a CA atom in ``motif_mask``.
    """

    def __init__(self, **kwargs):
        super().__init__(dim=148)
        # Shared by both fallback warnings; whichever fires first silences the rest.
        self._has_logged = False

    def _warn_once(self, message):
        """Log ``message`` unless a fallback warning was already emitted."""
        if not self._has_logged:
            logger.warning(message)
            self._has_logged = True

    def _zeros(self, batch):
        """All-zeros output of shape ``[b, n, dim]`` on the batch's device."""
        b, n = self.extract_bs_and_n(batch)
        return torch.zeros(b, n, self.dim, device=self.extract_device(batch))

    def forward(self, batch):
        required_keys = ("x_motif", "motif_mask", "seq_motif_mask")
        if not all(key in batch for key in required_keys):
            fallback = self._zeros(batch)
            self._warn_once(
                "No x_motif or motif_mask in batch, returning zeros for MotifRelativeCoordsSeqFeat"
            )
            return fallback

        motif_mask = batch["motif_mask"]
        ca_slot = torch.tensor([atom_types.index("CA")], device=motif_mask.device)
        has_ca = torch.all(motif_mask[:, :, ca_slot], dim=-1)
        # Only residues that belong to the motif must have a CA.
        ca_ok = torch.where(
            batch["seq_motif_mask"],
            has_ca,
            torch.ones_like(has_ca, dtype=torch.bool),
        )
        if not torch.all(ca_ok):
            self._warn_once(
                "Missing required CA atoms in motif region, returning zeros for MotifRelativeCoordsSeqFeat"
            )
            return self._zeros(batch)

        coords_feature = Atom37NanometersCoorsSeqFeat(rel=True)
        return coords_feature({"coords_nm": batch["x_motif"], "coord_mask": motif_mask})


class MotifSequenceSeqFeat(Feature):
    """One-hot motif residue types (20 channels) via ``ResidueTypeSeqFeat``.

    Returns zeros (warning once) when ``seq_motif`` or ``seq_motif_mask`` is missing.
    """

    def __init__(self, **kwargs):
        super().__init__(dim=20)
        self._has_logged = False  # emit the fallback warning only once

    def forward(self, batch):
        if not ("seq_motif" in batch and "seq_motif_mask" in batch):
            b, n = self.extract_bs_and_n(batch)
            device = self.extract_device(batch)
            if not self._has_logged:
                logger.warning(
                    "No seq_motif or seq_motif_mask in batch, returning zeros for MotifSequenceSeqFeat"
                )
                self._has_logged = True
            return torch.zeros(b, n, self.dim, device=device)

        seq_feature = ResidueTypeSeqFeat()
        return seq_feature(
            {
                "residue_type": batch["seq_motif"],
                "mask_dict": {"residue_type": batch["seq_motif_mask"]},
            }
        )


class MotifSideChainAnglesSeqFeat(Feature):
    """Motif side-chain torsion features (88 channels) via ``OpenfoldSideChainAnglesSeqFeat``.

    Returns zeros (warning once) when ``x_motif``, ``motif_mask`` or ``seq_motif``
    is missing.
    """

    def __init__(self, **kwargs):
        super().__init__(dim=88)
        self._has_logged = False  # emit the fallback warning only once

    def forward(self, batch):
        if not all(key in batch for key in ("x_motif", "motif_mask", "seq_motif")):
            b, n = self.extract_bs_and_n(batch)
            device = self.extract_device(batch)
            if not self._has_logged:
                logger.warning(
                    "Missing required motif data in batch, returning zeros for MotifSideChainAnglesSeqFeat"
                )
                self._has_logged = True
            return torch.zeros(b, n, self.dim, device=device)

        angles_feature = OpenfoldSideChainAnglesSeqFeat()
        return angles_feature(
            {
                "residue_type": batch["seq_motif"],
                "coords": batch["x_motif"],
                "coord_mask": batch["motif_mask"],
            }
        )


class MotifTorsionAnglesSeqFeat(Feature):
    """Backbone torsion-angle feature (63 channels) evaluated on the motif.

    Requires ``x_motif``, ``motif_mask`` and ``seq_motif_mask``. A zero tensor
    of shape ``(b, n, 63)`` is returned, with a one-time warning, when a key is
    missing or when a residue flagged in ``seq_motif_mask`` fails the backbone
    check below. Otherwise ``x_motif`` (and ``residue_pdb_idx`` if present, else
    ``None``) is passed to a new ``BackboneTorsionAnglesSeqFeat``.
    """

    def __init__(self, **kwargs):
        super().__init__(dim=63)
        # Shared by both warning paths: once either warning fires, neither fires again.
        self._has_logged = False

    def forward(self, batch):
        if not all(key in batch for key in ("x_motif", "motif_mask", "seq_motif_mask")):
            b, n = self.extract_bs_and_n(batch)
            device = self.extract_device(batch)
            if not self._has_logged:
                logger.warning(
                    "No x_motif or motif_mask in batch, returning zeros for MotifTorsionAnglesSeqFeat"
                )
                self._has_logged = True
            return torch.zeros(b, n, self.dim, device=device)

        atom_mask = batch["motif_mask"]
        backbone_idx = torch.tensor(
            [atom_types.index(name) for name in ("N", "CA", "C", "O")],
            device=atom_mask.device,
        )
        # NOTE(review): any() lets a residue pass as soon as ONE of N/CA/C/O is
        # present, although the warning below speaks of missing backbone atoms.
        # Confirm whether all() was intended.
        has_backbone = atom_mask[:, :, backbone_idx].any(dim=-1)
        # Positions where seq_motif_mask is False are treated as passing.
        backbone_ok = torch.where(
            batch["seq_motif_mask"],
            has_backbone,
            torch.ones_like(has_backbone, dtype=torch.bool),
        )
        if not backbone_ok.all():
            if not self._has_logged:
                logger.warning(
                    "Missing backbone atoms in motif region, returning zeros"
                )
                self._has_logged = True
            b, n = self.extract_bs_and_n(batch)
            device = self.extract_device(batch)
            return torch.zeros(b, n, self.dim, device=device)

        torsion_batch = dict(
            coords=batch["x_motif"],
            residue_pdb_idx=batch.get("residue_pdb_idx", None),
        )
        return BackboneTorsionAnglesSeqFeat()(torsion_batch)


class XmotifBulkTipSeqFeat(Feature):
    """Motif feature made of absolute coordinates, residue type and atom mask.

    The output concatenates, along the last axis, the absolute
    ``Atom37NanometersCoorsSeqFeat`` of ``x_motif``, the ``ResidueTypeSeqFeat``
    of ``seq_motif`` and ``motif_mask`` cast to float. Residues whose
    ``motif_mask`` row has a zero sum are zeroed. Without ``x_motif`` an
    all-zero tensor of width ``self.dim`` is returned and no warning is logged.
    """

    def __init__(self, **kwargs):
        super().__init__(dim=None)
        self.const_coors_abs = Atom37NanometersCoorsSeqFeat(rel=False)
        self.const_seq = ResidueTypeSeqFeat()
        # +37 for the per-atom motif mask appended in forward() (presumably atom37).
        self.dim = self.const_coors_abs.dim + self.const_seq.dim + 37

    def forward(self, batch):
        if "x_motif" not in batch:
            b, n = self.extract_bs_and_n(batch)
            device = self.extract_device(batch)
            return torch.zeros(b, n, self.dim, device=device)

        atom_mask = batch["motif_mask"]
        coords_feat = self.const_coors_abs(
            {"coords_nm": batch["x_motif"], "coord_mask": atom_mask}
        )

        # Residue-level mask: True where any entry of the atom-mask row is non-zero.
        residue_mask = atom_mask.sum(-1).bool()
        seq_feat = self.const_seq(
            {
                "residue_type": batch["seq_motif"],
                "mask_dict": {"residue_type": residue_mask},
            }
        )

        stacked = torch.cat([coords_feat, seq_feat, atom_mask * 1.0], dim=-1)
        return stacked * residue_mask[..., None]


class MotifMaskSeqFeat(Feature):
    """Exposes ``motif_mask`` (37 channels) as a floating-point feature.

    Falls back to a ``(b, n, 37)`` zero tensor, with a one-time warning, when
    ``motif_mask`` is not in the batch.
    """

    def __init__(self, **kwargs):
        super().__init__(dim=37)
        self._has_logged = False

    def forward(self, batch):
        if "motif_mask" not in batch:
            b, n = self.extract_bs_and_n(batch)
            device = self.extract_device(batch)
            if not self._has_logged:
                logger.warning(
                    "No motif_mask in batch, returning zeros for MotifMaskSeqFeat"
                )
                self._has_logged = True
            return torch.zeros(b, n, self.dim, device=device)
        # Multiplying by 1.0 promotes a bool/int mask to a floating dtype.
        return batch["motif_mask"] * 1.0


class XmotifSeqFeatUnindexed(Feature):
    """Motif features with only the motif residues kept ("unindexed" layout).

    Builds the same per-residue vector as the bulk motif feature: absolute
    coordinates, residue type, and the per-atom mask as float. Then, for every
    batch element, it keeps only residues whose mask row has a non-zero sum and
    right-pads the variable-length results with ``pad_sequence``.

    Returns:
        ``(features, mask)``:
            features: padded features of shape ``(b, max_kept, dim)``, padded
                with 0.0.
            mask: a boolean mask of shape ``(b, max_kept)`` that is False on
                padding.

    Raises:
        IOError: if ``x_motif`` is not in the batch.
    """

    def __init__(self, **kwargs):
        super().__init__(dim=None)
        self.const_coors_abs = Atom37NanometersCoorsSeqFeat(rel=False)
        self.const_seq = ResidueTypeSeqFeat()
        # +37 for the per-atom motif mask appended in forward().
        self.dim = self.const_coors_abs.dim + self.const_seq.dim + 37

    def forward(self, batch):
        if "x_motif" not in batch:
            raise IOError("No x_motif in batch")

        raw_mask = batch["motif_mask"]
        coords_feat = self.const_coors_abs(
            {"coords_nm": batch["x_motif"], "coord_mask": raw_mask}
        )

        residue_mask = raw_mask.sum(-1).bool()
        seq_feat = self.const_seq(
            {
                "residue_type": batch["seq_motif"],
                "mask_dict": {"residue_type": residue_mask},
            }
        )

        atom_mask_f = raw_mask * 1.0
        keep = atom_mask_f.sum(-1).bool()

        dense = torch.cat([coords_feat, seq_feat, atom_mask_f], dim=-1)
        dense = dense * residue_mask[..., None]

        # Gather the kept residues of each sample into its own variable-length tensor.
        gathered_feats, gathered_masks = [], []
        for sample_feat, sample_keep in zip(dense, keep):
            selected = sample_feat[sample_keep, ...]
            selected_mask = sample_keep[sample_keep]
            gathered_feats.append(selected * selected_mask[..., None])
            gathered_masks.append(selected_mask)
            assert torch.all(selected_mask), "Mask local wrong"

        padded_masks = pad_sequence(
            gathered_masks, batch_first=True, padding_value=False
        )
        padded_feats = pad_sequence(
            gathered_feats, batch_first=True, padding_value=0.0
        )
        return padded_feats, padded_masks


class BulkAllAtomXmotifSeqFeat(Feature):
    """All-atom motif feature bundle.

    Concatenates, along the last axis and in this order:
    - absolute and relative atom37 coordinates of ``x_motif``,
    - the residue type of ``seq_motif``,
    - side-chain angles,
    - backbone torsion angles,
    - ``motif_mask`` cast to float (37 channels).

    Residues whose ``motif_mask`` row has a zero sum are zeroed. Torsion angles
    use ``residue_pdb_idx`` when available. Otherwise, if defaults are allowed,
    they use positions ``1..n``. Without ``x_motif`` an all-zero tensor of width
    ``self.dim`` is returned.
    """

    def __init__(self, **kwargs):
        super().__init__(dim=None)
        self.const_coors_abs = Atom37NanometersCoorsSeqFeat(rel=False)
        self.const_coors_rel = Atom37NanometersCoorsSeqFeat(rel=True)
        self.const_seq = ResidueTypeSeqFeat()
        self.const_sc_angles = OpenfoldSideChainAnglesSeqFeat()
        self.const_torsion_angles = BackboneTorsionAnglesSeqFeat()

        sub_dims = [
            self.const_coors_abs.dim,
            self.const_coors_rel.dim,
            self.const_seq.dim,
            self.const_sc_angles.dim,
            self.const_torsion_angles.dim,
        ]
        # +37 for the per-atom motif mask appended in forward().
        self.dim = sum(sub_dims) + 37

    def forward(self, batch):
        if "x_motif" not in batch:
            b, n = self.extract_bs_and_n(batch)
            device = self.extract_device(batch)
            return torch.zeros(b, n, self.dim, device=device)

        coords_batch = {
            "coords_nm": batch["x_motif"],
            "coord_mask": batch["motif_mask"],
        }
        abs_feat = self.const_coors_abs(coords_batch)
        rel_feat = self.const_coors_rel(coords_batch)

        residue_mask = batch["motif_mask"].sum(-1).bool()
        seq_feat = self.const_seq(
            {
                "residue_type": batch["seq_motif"],
                "mask_dict": {"residue_type": residue_mask},
            }
        )

        sc_feat = self.const_sc_angles(
            {
                "residue_type": batch["seq_motif"],
                "coords": batch["x_motif"],
                "coord_mask": batch["motif_mask"],
            }
        )

        if "residue_pdb_idx" not in batch:
            self.assert_defaults_allowed(batch, "Relative sequence separation pair")
            b, n = self.extract_bs_and_n(batch)
            device = self.extract_device(batch)
            # 1-based positions 1..n, repeated for each batch element.
            residue_idx = torch.Tensor([list(range(1, n + 1))] * b).to(device)
        else:
            residue_idx = batch["residue_pdb_idx"]

        torsion_feat = self.const_torsion_angles(
            {"coords": batch["x_motif"], "residue_pdb_idx": residue_idx}
        )

        pieces = [
            abs_feat,
            rel_feat,
            seq_feat,
            sc_feat,
            torsion_feat,
            batch["motif_mask"] * 1.0,
        ]
        return torch.cat(pieces, dim=-1) * residue_mask[..., None]


class ChainIdxSeqFeat(Feature):
    """Per-residue ``chains`` value exposed as a single-channel feature.

    Returns ``batch["chains"]`` with a trailing axis of size 1 appended.

    Raises:
        ValueError: if ``chains`` is missing from the batch.
    """

    def __init__(self, **kwargs):
        super().__init__(dim=1)

    def forward(self, batch):
        if "chains" not in batch:
            raise ValueError("chains")
        return batch["chains"][..., None]


class SequenceSeparationPairFeat(Feature):
    """One-hot binned relative sequence separation ``idx_i - idx_j``.

    Uses ``residue_pdb_idx`` when present. Otherwise, if defaults are allowed,
    it uses positions ``1..n``. ``seq_sep_dim`` must be odd. The
    ``seq_sep_dim - 1`` bin edges are evenly spaced on
    ``[-(dim / 2 - 1), dim / 2 - 1]`` and handed to ``bin_and_one_hot``.
    """

    def __init__(self, seq_sep_dim, **kwargs):
        super().__init__(dim=seq_sep_dim)

    def forward(self, batch):
        if "residue_pdb_idx" not in batch:
            self.assert_defaults_allowed(batch, "Relative sequence separation pair")
            b, n = self.extract_bs_and_n(batch)
            device = self.extract_device(batch)
            positions = torch.Tensor([list(range(1, n + 1))] * b).to(device)
        else:
            positions = batch["residue_pdb_idx"]

        separation = positions.unsqueeze(2) - positions.unsqueeze(1)

        assert (
            self.dim % 2 == 1
        ), "Relative seq separation feature dimension must be odd and > 3"

        half_span = self.dim / 2.0 - 1
        edges = torch.linspace(
            -half_span, half_span, self.dim - 1, device=positions.device
        )
        return bin_and_one_hot(separation, edges)


class XtBBCAPairwiseDistancesPairFeat(Feature):
    """Binned pairwise distances of ``batch["x_t"]["bb_ca"]``.

    Args:
        xt_pair_dist_dim: number of output channels.
        xt_pair_dist_min: lower end of the bin-edge range.
        xt_pair_dist_max: upper end of the bin-edge range.
    """

    def __init__(self, xt_pair_dist_dim, xt_pair_dist_min, xt_pair_dist_max, **kwargs):
        super().__init__(dim=xt_pair_dist_dim)
        self.min_dist = xt_pair_dist_min
        self.max_dist = xt_pair_dist_max

    def forward(self, batch):
        x_t = batch["x_t"]
        data_modes_avail = list(x_t)
        assert (
            "bb_ca" in data_modes_avail
        ), f"`bb_ca` pair dist feature requested but key not available in data modes {data_modes_avail}"
        return bin_pairwise_distances(
            x=x_t["bb_ca"],
            min_dist=self.min_dist,
            max_dist=self.max_dist,
            dim=self.dim,
        )


class CaCoorsNanometersPairwiseDistancesPairFeat(Feature):
    """Binned pairwise CA-CA distances (30 channels, edges from 0.1 to 3.0).

    CA coordinates are taken from ``ca_coors_nm`` when present, otherwise from
    atom index 1 of ``coords_nm``.
    """

    def __init__(self, **kwargs):
        super().__init__(dim=30)
        self.min_dist = 0.1
        self.max_dist = 3.0

    def forward(self, batch):
        has_ca = "ca_coors_nm" in batch
        assert (
            has_ca or "coords_nm" in batch
        ), f"`ca_coors_nm` pair dist feature requested but key `ca_coors_nm` nor `coords_nm` not available"
        ca_coors = batch["ca_coors_nm"] if has_ca else batch["coords_nm"][:, :, 1, :]
        return bin_pairwise_distances(
            x=ca_coors,
            min_dist=self.min_dist,
            max_dist=self.max_dist,
            dim=self.dim,
        )


class OptionalCaCoorsNanometersPairwiseDistancesPairFeat(
    CaCoorsNanometersPairwiseDistancesPairFeat
):
    """CA pair-distance feature gated by ``batch["use_ca_coors_nm_feature"]``.

    When the flag is absent or falsy, returns zeros of shape
    ``(b, n, n, dim)`` and logs a warning once per instance.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self._has_logged = False

    def forward(self, batch):
        enabled = batch.get("use_ca_coors_nm_feature", False)
        if enabled:
            return super().forward(batch)

        b, n = self.extract_bs_and_n(batch)
        device = self.extract_device(batch)
        if not self._has_logged:
            logger.warning(
                "use_ca_coors_nm_feature disabled or not in batch, returning zeros for OptionalCaCoorsNanometersPairwiseDistancesPairFeat"
            )
            self._has_logged = True
        return torch.zeros(b, n, n, self.dim, device=device)


class XscBBCAPairwiseDistancesPairFeat(Feature):
    """Binned pairwise distances of ``batch[mode_key]["bb_ca"]``.

    ``mode_key`` defaults to ``"x_sc"``; the factory also instantiates it with
    ``"x_recycle"``. When ``mode_key`` is absent from the batch, zeros of shape
    ``(b, n, n, dim)`` are returned and a warning is logged once.

    Args:
        x_sc_pair_dist_dim: number of output channels.
        x_sc_pair_dist_min: lower end of the bin-edge range.
        x_sc_pair_dist_max: upper end of the bin-edge range.
        mode_key: batch key holding the dict with ``bb_ca`` coordinates.
    """

    def __init__(
        self,
        x_sc_pair_dist_dim,
        x_sc_pair_dist_min,
        x_sc_pair_dist_max,
        mode_key="x_sc",
        **kwargs,
    ):
        super().__init__(dim=x_sc_pair_dist_dim)
        self.min_dist = x_sc_pair_dist_min
        self.max_dist = x_sc_pair_dist_max
        self.mode_key = mode_key
        self._has_logged = False

    def forward(self, batch):
        if self.mode_key not in batch:
            b, n = self.extract_bs_and_n(batch)
            device = self.extract_device(batch)
            if not self._has_logged:
                logger.warning(
                    f"No {self.mode_key} in batch, returning zeros for XscBBCAPairwiseDistancesPairFeat"
                )
                self._has_logged = True
            return torch.zeros(b, n, n, self.dim, device=device)

        modes = batch[self.mode_key]
        data_modes_avail = list(modes)
        assert (
            "bb_ca" in data_modes_avail
        ), f"`bb_ca` sc/recycle pair dist feature requested but key not available in data modes {data_modes_avail}"
        return bin_pairwise_distances(
            x=modes["bb_ca"],
            min_dist=self.min_dist,
            max_dist=self.max_dist,
            dim=self.dim,
        )


class RelativeResidueOrientationPairFeat(Feature):
    """Binned inter-residue orientation angles (5 angles, 105 channels in total).

    For every residue pair (i, j), using N, CA and CB (atom37 indices 0, 1, 3
    of ``coords``), computes:

    * ``theta_12`` / ``theta_21``: signed dihedrals N_i-CA_i-CB_i-CB_j and
      N_j-CA_j-CB_j-CB_i,
    * ``phi_12`` / ``phi_21``: bond angles CA_i-CB_i-CB_j and CA_j-CB_j-CB_i,
    * ``w``: signed dihedral CA_i-CB_i-CB_j-CA_j.

    Each angle is binned with ``bin_and_one_hot`` over 20 edges evenly spaced on
    ``[-pi, pi]``. The per-angle one-hots are flattened into the last axis.
    Pairs where either residue lacks CA or CB in ``coord_mask`` are zeroed.
    """

    def __init__(self, **kwargs):
        super().__init__(dim=int(5 * 21))

    def forward(self, batch):
        coords = batch["coords"]
        atom_mask = batch["coord_mask"]

        has_ca = atom_mask[:, :, 1]
        has_cb = atom_mask[:, :, 3]
        # Both residues of the pair must have CA and CB. Each outer product puts
        # residue i on dim 1 and residue j on dim 2, giving (b, n, n) masks.
        pair_mask = has_ca[:, :, None] * has_ca[:, None, :]
        beta_carbon_pair_mask = has_cb[:, :, None] * has_cb[:, None, :]
        pair_mask = pair_mask * beta_carbon_pair_mask

        N = coords[:, :, 0, :]
        CA = coords[:, :, 1, :]
        CB = coords[:, :, 3, :]

        # *_p1 broadcast residue i along dim 1, *_p2 broadcast residue j along dim 2.
        N_p1, CA_p1, CB_p1 = map(lambda v: v[:, :, None, :], (N, CA, CB))
        N_p2, CA_p2, CB_p2 = map(lambda v: v[:, None, :, :], (N, CA, CB))

        theta_12 = signed_dihedral_angle(N_p1, CA_p1, CB_p1, CB_p2)
        theta_21 = signed_dihedral_angle(N_p2, CA_p2, CB_p2, CB_p1)
        phi_12 = bond_angles(CA_p1, CB_p1, CB_p2)
        phi_21 = bond_angles(CA_p2, CB_p2, CB_p1)
        w = signed_dihedral_angle(CA_p1, CB_p1, CB_p2, CA_p2)
        angles = torch.stack([theta_12, theta_21, phi_12, phi_21, w], dim=-1)

        angles_feat = bin_and_one_hot(
            angles, torch.linspace(-torch.pi, torch.pi, 20, device=angles.device)
        )
        angles_feat = einops.rearrange(angles_feat, "b n m f d -> b n m (f d)")
        angles_feat = angles_feat * pair_mask[..., None]
        return angles_feat


class BackbonePairDistancesNanometerPairFeat(Feature):
    """Binned distances from each residue's CA to backbone atoms of every residue.

    For residue pair (i, j) computes the distances CA_i-N_j, CA_i-CA_j,
    CA_i-C_j and CA_i-CB_j (atom37 indices 0, 1, 2, 3 of ``coords_nm``). Each
    is binned with ``bin_and_one_hot`` over 20 edges evenly spaced on
    ``[0.1, 2]``, and the results are concatenated in that order
    (``dim = 4 * 21``).

    Masking:
    - Before binning, distances are zeroed for pairs where residue i or j has
      no CA in ``coord_mask``.
    - The CA-CB distance is also zeroed where residue j has no CB.
    - The binned output is zeroed for pairs missing a CA.
    """

    def __init__(self, **kwargs):
        super().__init__(dim=int(4 * 21))

    def forward(self, batch):
        assert (
            "coords_nm" in batch
        ), "`coords_nm` not in batch, cannot compute BackbonePairDistancesNanometerPairFeat"
        coords = batch["coords_nm"]
        atom_mask = batch["coord_mask"]
        has_ca = atom_mask[:, :, 1]
        pair_mask = has_ca[:, None, :] * has_ca[:, :, None]
        has_cb = atom_mask[:, :, 3]

        # CA of residue i broadcast along dim 2 (the j axis).
        ca_i = coords[:, :, 1, :][:, :, None, :]
        # One (b, n, n) distance map per target atom of residue j: N, CA, C, CB.
        dists = [
            torch.linalg.norm(ca_i - coords[:, :, atom_idx, :][:, None, :, :], dim=-1)
            for atom_idx in (0, 1, 2, 3)
        ]
        # Zero the CA-CB distance wherever residue j has no CB in coord_mask.
        dists[3] = dists[3] * has_cb[:, None, :]

        bin_limits = torch.linspace(0.1, 2, 20, device=coords.device)
        binned = [
            bin_and_one_hot(d * pair_mask, bin_limits=bin_limits) for d in dists
        ]

        feat = torch.cat(binned, dim=-1)
        feat = feat * pair_mask[..., None]
        return feat


class XmotifPairwiseDistancesPairFeat(Feature):
    """Backbone pair-distance feature evaluated on the motif structure.

    Delegates to ``BackbonePairDistancesNanometerPairFeat`` with ``x_motif`` as
    ``coords_nm`` and ``motif_mask`` as ``coord_mask``. Without ``x_motif``,
    zeros of shape ``(b, n, n, dim)`` are returned and a warning is logged once.
    """

    def __init__(self, **kwargs):
        super().__init__(dim=None)
        self.const = BackbonePairDistancesNanometerPairFeat()
        self.dim = self.const.dim
        self._has_logged = False

    def forward(self, batch):
        if "x_motif" not in batch:
            b, n = self.extract_bs_and_n(batch)
            device = self.extract_device(batch)
            if not self._has_logged:
                logger.warning(
                    "No x_motif in batch, returning zeros for XmotifPairwiseDistancesPairFeat"
                )
                self._has_logged = True
            return torch.zeros(b, n, n, self.dim, device=device)

        return self.const(
            {"coords_nm": batch["x_motif"], "coord_mask": batch["motif_mask"]}
        )


class ChainIdxPairFeat(Feature):
    """Single-channel pair feature: outer product of ``batch["chains"]`` with itself.

    Output entry ``[b, i, j, 0]`` is ``chains[b, i] * chains[b, j]``.

    Raises:
        ValueError: if ``chains`` is missing from the batch.
    """

    def __init__(self, **kwargs):
        super().__init__(dim=1)

    def forward(self, batch):
        if "chains" not in batch:
            raise ValueError("chains")
        chains = batch["chains"]
        outer = torch.einsum("bi,bj->bij", chains, chains)
        return outer.unsqueeze(-1)


class StochasticTranslationSeqFeat(Feature):
    """Broadcasts the per-sample ``stochastic_translation`` to every residue.

    The batch value is indexed as ``[:, None, :]`` and expanded to
    ``(b, n, -1)``. Its shape is presumably ``(b, 3)``, matching ``dim=3``.
    When the key is absent, zeros of shape ``(b, n, 3)`` are returned and a
    warning is logged once.
    """

    def __init__(self, **kwargs):
        super().__init__(dim=3)
        self._has_logged = False

    def forward(self, batch):
        b, n = self.extract_bs_and_n(batch)
        if "stochastic_translation" in batch:
            per_sample = batch["stochastic_translation"]
            return per_sample[:, None, :].expand(b, n, -1)

        device = self.extract_device(batch)
        if not self._has_logged:
            logger.warning(
                "No stochastic_translation in batch, returning zeros for StochasticTranslationSeqFeat"
            )
            self._has_logged = True
        return torch.zeros((b, n, 3), device=device)


class FeatureFactory(torch.nn.Module):
    """Builds named sequence or pair features and projects them to a fixed width.

    Each name in ``feats`` is mapped to a feature module by :meth:`get_creator`.
    In ``forward`` the features are processed as follows:
    - concatenated along the last axis,
    - masked with ``batch["mask"]``,
    - projected to ``dim_feats_out`` by a bias-free linear layer,
    - passed through ``LayerNorm`` (if ``use_ln_out``),
    - masked again.

    Args:
        feats: feature names. ``None`` or an empty list makes the factory return
            the output of a ``ZeroFeat`` instead.
        dim_feats_out: number of output channels.
        use_ln_out: apply ``LayerNorm`` after the projection when True.
        mode: ``"seq"`` for per-residue features, ``"pair"`` for pairwise ones.
        **kwargs: forwarded to the feature constructors.
    """

    def __init__(
        self,
        feats: List[str],
        dim_feats_out: int,
        use_ln_out: bool,
        mode: Literal["seq", "pair"],
        **kwargs,
    ):

        super().__init__()
        self.mode = mode
        # Without requested features the factory short-circuits to a ZeroFeat;
        # no creators or projection layers are built in that case.
        self.ret_zero = feats is None or len(feats) == 0
        if self.ret_zero:
            logger.info("No features requested")
            self.zero_creator = ZeroFeat(dim_feats_out=dim_feats_out, mode=mode)
            return

        self.feat_creators = torch.nn.ModuleList(
            [self.get_creator(f, **kwargs) for f in feats]
        )
        self.ln_out = (
            torch.nn.LayerNorm(dim_feats_out) if use_ln_out else torch.nn.Identity()
        )
        self.linear_out = torch.nn.Linear(
            sum(c.get_dim() for c in self.feat_creators), dim_feats_out, bias=False
        )

    def get_creator(self, f, **kwargs):
        """Instantiate the feature module registered under name ``f``.

        Raises:
            IOError: if ``f`` is unknown for the current mode, or if the mode is
                neither ``"seq"`` nor ``"pair"``.
        """

        if self.mode == "seq":

            if f == "time_emb_bb_ca":
                return TimeEmbeddingSeqFeat(data_mode_use="bb_ca", **kwargs)
            elif f == "time_emb_local_latents":
                return TimeEmbeddingSeqFeat(data_mode_use="local_latents", **kwargs)

            elif f == "res_seq_pdb_idx":
                return IdxEmbeddingSeqFeat(**kwargs)
            elif f == "chain_break_per_res":
                return ChainBreakPerResidueSeqFeat(**kwargs)
            elif f == "chain_idx_seq":
                return ChainIdxSeqFeat(**kwargs)
            elif f == "fold_emb":
                return FoldEmbeddingSeqFeat(**kwargs)
            elif f == "cropped_flag_seq":
                return CroppedFlagSeqFeat()

            elif f == "x1_aatype":
                return ResidueTypeSeqFeat(**kwargs)
            elif f == "optional_res_type_seq_feat":
                return OptionalResidueTypeSeqFeat(**kwargs)

            elif f == "ca_coors_nm":
                return CaCoorsNanometersSeqFeat(**kwargs)
            elif f == "ca_coors_nm_try":
                return TryCaCoorsNanometersSeqFeat(**kwargs)
            elif f == "optional_ca_coors_nm_seq_feat":
                return OptionalCaCoorsNanometersSeqFeat(**kwargs)
            elif f == "x1_a37coors_nm":
                return Atom37NanometersCoorsSeqFeat(**kwargs)
            elif f == "x1_a37coors_nm_rel":
                return Atom37NanometersCoorsSeqFeat(rel=True, **kwargs)

            elif f == "xt_bb_ca":
                return XtBBCASeqFeat(**kwargs)
            elif f == "xt_local_latents":
                return XtLocalLatentsSeqFeat(**kwargs)
            elif f == "x_sc_bb_ca":
                return XscBBCASeqFeat(**kwargs)
            elif f == "x_recycle_bb_ca":
                return XscBBCASeqFeat(mode_key="x_recycle", **kwargs)
            elif f == "x_sc_local_latents":
                return XscLocalLatentsSeqFeat(**kwargs)
            elif f == "x_recycle_local_latents":
                return XscLocalLatentsSeqFeat(mode_key="x_recycle", **kwargs)

            elif f == "x1_bb_angles":
                return BackboneTorsionAnglesSeqFeat(**kwargs)
            elif f == "x1_bond_angles":
                return BackboneBondAnglesSeqFeat(**kwargs)
            elif f == "x1_sidechain_angles":
                return OpenfoldSideChainAnglesSeqFeat(**kwargs)

            elif f == "z_latent_seq":
                return LatentVariableSeqFeat(**kwargs)

            elif f == "motif_abs_coords":
                return MotifAbsoluteCoordsSeqFeat(**kwargs)
            elif f == "motif_rel_coords":
                return MotifRelativeCoordsSeqFeat(**kwargs)
            elif f == "motif_seq":
                return MotifSequenceSeqFeat(**kwargs)
            elif f == "motif_sc_angles":
                return MotifSideChainAnglesSeqFeat(**kwargs)
            elif f == "motif_torsion_angles":
                return MotifTorsionAnglesSeqFeat(**kwargs)
            elif f == "motif_mask":
                return MotifMaskSeqFeat(**kwargs)

            # "x_motif" is an alias for the bulk all-atom motif feature.
            elif f == "bulk_all_atom_xmotif" or f == "x_motif":
                return BulkAllAtomXmotifSeqFeat(**kwargs)

            elif f == "stochastic_translation":
                return StochasticTranslationSeqFeat(**kwargs)

            elif f == "zero_feat_seq":
                return ZeroFeat(**kwargs)
            else:
                raise IOError(f"Sequence feature {f} not implemented.")

        elif self.mode == "pair":

            if f == "time_emb_bb_ca":
                return TimeEmbeddingPairFeat(data_mode_use="bb_ca", **kwargs)
            elif f == "time_emb_local_latents":
                return TimeEmbeddingPairFeat(data_mode_use="local_latents", **kwargs)

            elif f == "rel_seq_sep":
                return SequenceSeparationPairFeat(**kwargs)

            elif f == "xt_bb_ca_pair_dists":
                return XtBBCAPairwiseDistancesPairFeat(**kwargs)
            elif f == "x_sc_bb_ca_pair_dists":
                return XscBBCAPairwiseDistancesPairFeat(**kwargs)
            elif f == "x_recycle_bb_ca_pair_dists":
                return XscBBCAPairwiseDistancesPairFeat(mode_key="x_recycle", **kwargs)
            elif f == "ca_coors_nm_pair_dists":
                return CaCoorsNanometersPairwiseDistancesPairFeat(**kwargs)
            elif f == "optional_ca_pair_dist":
                return OptionalCaCoorsNanometersPairwiseDistancesPairFeat(**kwargs)
            elif f == "x1_bb_pair_dists_nm":
                return BackbonePairDistancesNanometerPairFeat(**kwargs)
            elif f == "x_motif_pair_dists":
                return XmotifPairwiseDistancesPairFeat(**kwargs)

            elif f == "x1_bb_pair_orientation":
                return RelativeResidueOrientationPairFeat(**kwargs)

            elif f == "chain_idx_pair":
                return ChainIdxPairFeat(**kwargs)

            else:
                raise IOError(f"Pair feature {f} not implemented.")

        else:
            raise IOError(
                f"Wrong feature mode (creator): {self.mode}. Should be 'seq' or 'pair'."
            )

    def apply_padding_mask(self, feature_tensor, mask):
        """Zero out padded positions.

        In ``"seq"`` mode the tensor is multiplied by ``mask[..., None]``. In
        ``"pair"`` mode it is multiplied by the outer product ``mask_i * mask_j``.

        Raises:
            IOError: if the mode is neither ``"seq"`` nor ``"pair"``.
        """

        if self.mode == "seq":
            return feature_tensor * mask[..., None]
        elif self.mode == "pair":
            mask_pair = mask[:, None, :] * mask[:, :, None]
            return feature_tensor * mask_pair[..., None]
        else:
            raise IOError(
                f"Wrong feature mode (pad mask): {self.mode}. Should be 'seq' or 'pair'."
            )

    def forward(self, batch):
        """Compute, concatenate, mask and project all configured features."""

        if self.ret_zero:
            return self.zero_creator(batch)

        feature_tensors = [fcreator(batch) for fcreator in self.feat_creators]

        features = torch.cat(feature_tensors, dim=-1)
        features = self.apply_padding_mask(features, batch["mask"])

        features_proc = self.ln_out(self.linear_out(features))
        return self.apply_padding_mask(features_proc, batch["mask"])


class FeatureFactoryUidxMotif(torch.nn.Module):
    """Projects unindexed motif features to ``dim_feats_out`` channels.

    Wraps ``XmotifSeqFeatUnindexed``. Its padded features go through a bias-free
    linear layer and then an optional ``LayerNorm``. Padded rows are zeroed with
    the mask that the wrapped feature returns.

    ``forward`` returns ``(features, mask)``, where ``mask`` is passed through
    unchanged from ``XmotifSeqFeatUnindexed``.
    """

    def __init__(
        self,
        dim_feats_out: int,
        use_ln_out: bool,
        **kwargs,
    ):
        super().__init__()

        self.feat_creator = XmotifSeqFeatUnindexed(**kwargs)
        if use_ln_out:
            self.ln_out = torch.nn.LayerNorm(dim_feats_out)
        else:
            self.ln_out = torch.nn.Identity()
        in_dim = self.feat_creator.get_dim()
        self.linear_out = torch.nn.Linear(in_dim, dim_feats_out, bias=False)

    def forward(self, batch):
        motif_feat, motif_mask = self.feat_creator(batch)
        projected = self.linear_out(motif_feat)
        normed = self.ln_out(projected)
        return normed * motif_mask[..., None], motif_mask
