# Copyright 2021 AlQuraishi Laboratory
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import itertools
from functools import reduce, wraps
from operator import add

import numpy as np
import torch

from openfold.config import NUM_RES, NUM_EXTRA_SEQ, NUM_TEMPLATES, NUM_MSA_SEQ
from entity import entity_constants as ec
from openfold.utils.rigid_utils import Rotation, Rigid
from openfold.utils.tensor_utils import (
    tree_map,
    tensor_tree_map,
    batched_gather,
)
from utils.funcs import (
    generate_Cbeta,
)


# Names of per-sequence MSA feature keys. Within this part of the file the
# list is only referenced by the commented-out MSA sampling/cropping helpers
# (which iterate it and also look up "extra_" + name variants).
MSA_FEATURE_NAMES = [
    "msa",
    "deletion_matrix",
    "msa_mask",
    "msa_row_mask",
    "bert_mask",
    "true_msa",
]


def cast_to_64bit_ints(protein):
    """Promote every int32 tensor in ``protein`` to int64, in place.

    Values of any other dtype are left untouched. The same dict object is
    returned so the function can be chained in a transform pipeline.
    """
    for key in list(protein):
        value = protein[key]
        if value.dtype != torch.int32:
            continue
        protein[key] = value.to(torch.int64)
    return protein


def make_one_hot(x, num_classes):
    """Return a one-hot encoding of integer tensor ``x`` along a new last axis.

    The result has shape ``[*x.shape, num_classes]``, the default floating
    dtype, and lives on the same device as ``x``.
    """
    encoded = torch.zeros(tuple(x.shape) + (num_classes,), device=x.device)
    # scatter_ writes 1 at position x[...] of the new axis and returns `encoded`.
    return encoded.scatter_(-1, x.unsqueeze(-1), 1)


def make_seq_mask(protein):
    """Add ``seq_mask``: an all-ones float32 tensor shaped like ``token_type``."""
    token_shape = protein["token_type"].shape
    protein["seq_mask"] = torch.ones(token_shape, dtype=torch.float32)
    return protein


# def make_template_mask(protein):
#     protein["template_mask"] = torch.ones(
#         protein["template_aatype"].shape[0], dtype=torch.float32
#     )
#     return protein


def curry1(f):
    """Supply all arguments but the first.

    ``curry1(f)(*args, **kwargs)`` returns a one-argument callable
    ``x -> f(x, *args, **kwargs)``. The outer callable carries ``f``'s
    metadata via ``functools.wraps``.
    """
    @wraps(f)
    def fc(*args, **kwargs):
        def bound(x):
            return f(x, *args, **kwargs)
        return bound

    return fc


# def make_all_atom_aatype(protein):
#     protein["all_atom_aatype"] = protein["aatype"]
#     return protein


# def fix_templates_aatype(protein):
#     # Map one-hot to indices
#     num_templates = protein["template_aatype"].shape[0]
#     if(num_templates > 0):
#         protein["template_aatype"] = torch.argmax(
#             protein["template_aatype"], dim=-1
#         )
#         # Map hhsearch-aatype to our aatype.
#         new_order_list = rc.MAP_HHBLITS_AATYPE_TO_OUR_AATYPE
#         new_order = torch.tensor(
#             new_order_list, dtype=torch.int64, device=protein["aatype"].device,
#         ).expand(num_templates, -1)
#         protein["template_aatype"] = torch.gather(
#             new_order, 1, index=protein["template_aatype"]
#         )

#     return protein


# def correct_msa_restypes(protein):
#     """Correct MSA restype to have the same order as rc."""
#     new_order_list = rc.MAP_HHBLITS_AATYPE_TO_OUR_AATYPE
#     new_order = torch.tensor(
#         [new_order_list] * protein["msa"].shape[1], 
#         device=protein["msa"].device,
#     ).transpose(0, 1)
#     protein["msa"] = torch.gather(new_order, 0, protein["msa"])

#     perm_matrix = np.zeros((22, 22), dtype=np.float32)
#     perm_matrix[range(len(new_order_list)), new_order_list] = 1.0

#     for k in protein:
#         if "profile" in k:
#             num_dim = protein[k].shape.as_list()[-1]
#             assert num_dim in [
#                 20,
#                 21,
#                 22,
#             ], "num_dim for %s out of expected range: %s" % (k, num_dim)
#             protein[k] = torch.dot(protein[k], perm_matrix[:num_dim, :num_dim])
    
#     return protein


def squeeze_features(protein):
    """Remove singleton and repeated dimensions in protein features.

    * ``token_type`` is (re)computed as the argmax of ``target_feat`` over
      its last axis.
    * For a fixed set of keys, a trailing axis of size 1 is squeezed away
      (torch tensors via ``torch.squeeze``, anything else via ``np.squeeze``).
    * ``seq_length`` is then reduced to its first element.
    """
    protein["token_type"] = torch.argmax(protein["target_feat"], dim=-1)

    squeezable_keys = (
        "domain_name",
        "seq_length",
        "resolution",
        "between_segment_residues",
        "token_index",
    )
    for key in squeezable_keys:
        if key not in protein:
            continue
        value = protein[key]
        last_dim = value.shape[-1]
        if not (isinstance(last_dim, int) and last_dim == 1):
            continue
        if torch.is_tensor(value):
            protein[key] = torch.squeeze(value, dim=-1)
        else:
            protein[key] = np.squeeze(value, axis=-1)

    if "seq_length" in protein:
        protein["seq_length"] = protein["seq_length"][0]

    return protein


# @curry1
# def randomly_replace_msa_with_unknown(protein, replace_proportion):
#     """Replace a portion of the MSA with 'X'."""
#     msa_mask = torch.rand(protein["msa"].shape) < replace_proportion
#     x_idx = 20
#     gap_idx = 21
#     msa_mask = torch.logical_and(msa_mask, protein["msa"] != gap_idx)
#     protein["msa"] = torch.where(
#         msa_mask,
#         torch.ones_like(protein["msa"]) * x_idx,
#         protein["msa"]
#     )
#     aatype_mask = torch.rand(protein["aatype"].shape) < replace_proportion

#     protein["aatype"] = torch.where(
#         aatype_mask,
#         torch.ones_like(protein["aatype"]) * x_idx,
#         protein["aatype"],
#     )
#     return protein


# @curry1
# def sample_msa(protein, max_seq, keep_extra, seed=None):
#     """Sample MSA randomly, remaining sequences are stored are stored as `extra_*`."""
#     num_seq = protein["msa"].shape[0]

#     g = None
#     if seed is not None:
#         g = torch.Generator(device=protein["msa"].device)
#         g.manual_seed(seed)

#     shuffled = torch.randperm(num_seq - 1, generator=g) + 1
#     index_order = torch.cat(
#         (torch.tensor([0], device=shuffled.device), shuffled), 
#         dim=0
#     )
#     num_sel = min(max_seq, num_seq)
#     sel_seq, not_sel_seq = torch.split(
#         index_order, [num_sel, num_seq - num_sel]
#     )

#     for k in MSA_FEATURE_NAMES:
#         if k in protein:
#             if keep_extra:
#                 protein["extra_" + k] = torch.index_select(
#                     protein[k], 0, not_sel_seq
#                 )
#             protein[k] = torch.index_select(protein[k], 0, sel_seq)

#     return protein


# @curry1
# def add_distillation_flag(protein, distillation):
#     protein['is_distillation'] = distillation
#     return protein

# @curry1
# def sample_msa_distillation(protein, max_seq):
#     if(protein["is_distillation"] == 1):
#         protein = sample_msa(max_seq, keep_extra=False)(protein)
#     return protein


# @curry1
# def crop_extra_msa(protein, max_extra_msa):
#     num_seq = protein["extra_msa"].shape[0]
#     num_sel = min(max_extra_msa, num_seq)
#     select_indices = torch.randperm(num_seq)[:num_sel]
#     for k in MSA_FEATURE_NAMES:
#         if "extra_" + k in protein:
#             protein["extra_" + k] = torch.index_select(
#                 protein["extra_" + k], 0, select_indices
#             )
    
#     return protein


# def delete_extra_msa(protein):
#     for k in MSA_FEATURE_NAMES:
#         if "extra_" + k in protein:
#             del protein["extra_" + k]
#     return protein


# # Not used in inference
# @curry1
# def block_delete_msa(protein, config):
#     num_seq = protein["msa"].shape[0]
#     block_num_seq = torch.floor(
#         torch.tensor(num_seq, dtype=torch.float32, device=protein["msa"].device)
#         * config.msa_fraction_per_block
#     ).to(torch.int32)

#     if int(block_num_seq) == 0:
#         return protein

#     if config.randomize_num_blocks:
#         nb = int(torch.randint(
#             low=0,
#             high=config.num_blocks + 1,
#             size=(1,),
#             device=protein["msa"].device,
#         )[0])
#     else:
#         nb = config.num_blocks

#     del_block_starts = torch.randint(low=1, high=num_seq, size=(nb,), device=protein["msa"].device)
#     del_blocks = del_block_starts[:, None] + torch.arange(start=0, end=block_num_seq)
#     del_blocks = torch.clip(del_blocks, 1, num_seq - 1)
#     del_indices = torch.unique(torch.reshape(del_blocks, [-1]))

#     # Make sure we keep the original sequence
#     combined = torch.cat((torch.arange(start=0, end=num_seq), del_indices)).long()
#     uniques, counts = combined.unique(return_counts=True)
#     keep_indices = uniques[counts == 1]

#     assert int(keep_indices[0]) == 0
#     for k in MSA_FEATURE_NAMES:
#         if k in protein:
#             protein[k] = torch.index_select(protein[k], 0, keep_indices)

#     return protein


# @curry1
# def nearest_neighbor_clusters(protein, gap_agreement_weight=0.0):
#     weights = torch.cat(
#         [
#             torch.ones(21, device=protein["msa"].device), 
#             gap_agreement_weight * torch.ones(1, device=protein["msa"].device),
#             torch.zeros(1, device=protein["msa"].device)
#         ],
#         0,
#     )

#     # Make agreement score as weighted Hamming distance
#     msa_one_hot = make_one_hot(protein["msa"], 23)
#     sample_one_hot = protein["msa_mask"][:, :, None] * msa_one_hot
#     extra_msa_one_hot = make_one_hot(protein["extra_msa"], 23)
#     extra_one_hot = protein["extra_msa_mask"][:, :, None] * extra_msa_one_hot

#     num_seq, num_res, _ = sample_one_hot.shape
#     extra_num_seq, _, _ = extra_one_hot.shape

#     # Compute tf.einsum('mrc,nrc,c->mn', sample_one_hot, extra_one_hot, weights)
#     # in an optimized fashion to avoid possible memory or computation blowup.
#     agreement = torch.matmul(
#         torch.reshape(extra_one_hot, [extra_num_seq, num_res * 23]),
#         torch.reshape(
#             sample_one_hot * weights, [num_seq, num_res * 23]
#         ).transpose(0, 1),
#     )

#     # Assign each sequence in the extra sequences to the closest MSA sample
#     protein["extra_cluster_assignment"] = torch.argmax(agreement, dim=1).to(
#         torch.int64
#     )
    
#     return protein


def unsorted_segment_sum(data, segment_ids, num_segments):
    """
    Computes the sum along segments of a tensor. Similar to
    tf.unsorted_segment_sum, but only supports 1-D indices.

    Row ``i`` of ``data`` is added into row ``segment_ids[i]`` of the output;
    output rows that receive no input rows stay zero.

    The accumulation dtype is chosen so inputs are not silently rounded:
    float64 data is summed in float64, (non-bool) integer data in int64, and
    everything else (float32/float16/bfloat16/bool/complex) in float32, as
    before. The result is cast back to ``data.dtype``.

    :param data: A tensor whose segments are to be summed.
    :param segment_ids: The 1-D segment indices tensor, one entry per row of
        ``data`` (must be an int64 tensor, as required by ``scatter_add_``).
    :param num_segments: The number of segments.
    :return: A tensor of shape ``[num_segments, *data.shape[1:]]`` with the
        same data type as the data argument.
    """
    assert (
        len(segment_ids.shape) == 1 and
        segment_ids.shape[0] == data.shape[0]
    )

    if data.dtype == torch.float64:
        acc_dtype = torch.float64
    elif (
        not data.is_floating_point()
        and not data.is_complex()
        and data.dtype != torch.bool
    ):
        # Integer sums are exact in int64; float32 would round above 2**24.
        acc_dtype = torch.int64
    else:
        acc_dtype = torch.float32

    # Broadcast each row's segment id over the trailing dims so scatter_add_
    # routes every element of row i into output row segment_ids[i].
    index = segment_ids.view(
        segment_ids.shape[0], *((1,) * (data.dim() - 1))
    ).expand(data.shape)
    out_shape = [num_segments] + list(data.shape[1:])
    summed = torch.zeros(
        *out_shape, dtype=acc_dtype, device=segment_ids.device
    ).scatter_add_(0, index, data.to(acc_dtype))
    return summed.type(data.dtype)


# @curry1
# def summarize_clusters(protein):
#     """Produce profile and deletion_matrix_mean within each cluster."""
#     num_seq = protein["msa"].shape[0]

#     def csum(x):
#         return unsorted_segment_sum(
#             x, protein["extra_cluster_assignment"], num_seq
#         )

#     mask = protein["extra_msa_mask"]
#     mask_counts = 1e-6 + protein["msa_mask"] + csum(mask)  # Include center

#     msa_sum = csum(mask[:, :, None] * make_one_hot(protein["extra_msa"], 23))
#     msa_sum += make_one_hot(protein["msa"], 23)  # Original sequence
#     protein["cluster_profile"] = msa_sum / mask_counts[:, :, None]
#     del msa_sum

#     del_sum = csum(mask * protein["extra_deletion_matrix"])
#     del_sum += protein["deletion_matrix"]  # Original sequence
#     protein["cluster_deletion_mean"] = del_sum / mask_counts
#     del del_sum
    
#     return protein


# def make_msa_mask(protein):
#     """Mask features are all ones, but will later be zero-padded."""
#     protein["msa_mask"] = torch.ones(protein["msa"].shape, dtype=torch.float32)
#     protein["msa_row_mask"] = torch.ones(
#         (protein["msa"].shape[0]), dtype=torch.float32
#     )
#     return protein


def pseudo_beta_fn(token_type, entity_type, all_atom_positions, all_atom_mask):
    """Create pseudo beta features.

    Chooses one representative atom slot per token:
      * tokens whose ``entity_type`` argmax is "molecule" -> "*MolAtom"
      * GLY tokens                                         -> "CA"
      * all other tokens                                   -> "CB"
    For positions the GLY rule is applied last (it wins over the molecule
    rule); for the mask the molecule rule is checked first.

    Returns ``pseudo_beta`` alone when ``all_atom_mask`` is None, otherwise
    the tuple ``(pseudo_beta, pseudo_beta_mask)``.
    """
    ca_idx = ec.atom_order["CA"]
    cb_idx = ec.atom_order["CB"]
    mol_idx = ec.atom_order["*MolAtom"]

    is_gly = torch.eq(token_type, ec.token_type_order["GLY"])
    is_mol = torch.eq(entity_type.argmax(dim=-1), ec.entity_type_order["molecule"])

    def slot_xyz(idx):
        return all_atom_positions[..., idx, :]

    # A trailing singleton on the condition broadcasts against the xyz axis,
    # which is equivalent to tiling the condition three times. (CB is read
    # directly rather than reconstructed from N/CA/C backbone atoms.)
    pseudo_beta = torch.where(is_mol[..., None], slot_xyz(mol_idx), slot_xyz(cb_idx))
    pseudo_beta = torch.where(is_gly[..., None], slot_xyz(ca_idx), pseudo_beta)

    if all_atom_mask is None:
        return pseudo_beta

    residue_mask = torch.where(
        is_gly, all_atom_mask[..., ca_idx], all_atom_mask[..., cb_idx]
    )
    pseudo_beta_mask = torch.where(is_mol, all_atom_mask[..., mol_idx], residue_mask)
    return pseudo_beta, pseudo_beta_mask


@curry1
def make_pseudo_beta(protein, prefix=""):
    """Create pseudo-beta (alpha for glycine) position and mask.

    Reads ``token_type``, ``entity_type``, ``all_atom_positions`` and
    ``all_atom_mask`` (each with ``prefix`` prepended) and stores the
    results of ``pseudo_beta_fn`` under ``pseudo_beta`` and
    ``pseudo_beta_mask``. Only the empty prefix is accepted.
    """
    assert prefix in [""]
    inputs = [
        protein[prefix + name]
        for name in ("token_type", "entity_type", "all_atom_positions", "all_atom_mask")
    ]
    beta_xyz, beta_mask = pseudo_beta_fn(*inputs)
    protein[prefix + "pseudo_beta"] = beta_xyz
    protein[prefix + "pseudo_beta_mask"] = beta_mask
    return protein


@curry1
def add_constant_field(protein, key, value):
    """Store ``value`` as a tensor under ``key``, on the device of ``protein["msa"]``."""
    target_device = protein["msa"].device
    protein[key] = torch.tensor(value, device=target_device)
    return protein


def shaped_categorical(probs, epsilon=1e-10):
    """Draw one categorical sample per leading position of ``probs``.

    ``probs`` holds (unnormalised) class weights on its last axis; ``epsilon``
    is added to every weight before sampling. The returned integer tensor
    has shape ``probs.shape[:-1]``.
    """
    batch_shape = probs.shape[:-1]
    num_classes = probs.shape[-1]
    flat_weights = (probs + epsilon).reshape(-1, num_classes)
    draws = torch.distributions.categorical.Categorical(flat_weights).sample()
    return draws.reshape(batch_shape)


def make_hhblits_profile(protein):
    """Compute the HHblits MSA profile if not already present.

    The profile is the mean of the 22-class one-hot encoding of ``msa``
    over its first axis; an existing ``hhblits_profile`` entry is kept.
    """
    if "hhblits_profile" not in protein:
        protein["hhblits_profile"] = make_one_hot(protein["msa"], 22).mean(dim=0)
    return protein


@curry1
def make_masked_msa(protein, config, replace_fraction):
    """Create data for BERT on raw MSA.

    Each MSA entry is replaced, with probability ``replace_fraction``, by a
    sample from a mixture of: a uniform distribution over the first 20
    classes (``config.uniform_prob``), the ``hhblits_profile``
    (``config.profile_prob``), the original entry (``config.same_prob``),
    and an extra trailing [MASK] class receiving the remaining probability.

    Writes ``bert_mask`` (float32 replacement mask), ``true_msa`` (the
    original ``msa``) and overwrites ``msa`` with the corrupted version.
    """
    msa = protein["msa"]

    # Uniform weight over the 20 standard classes; the last two get none.
    uniform_aa = torch.tensor(
        [0.05] * 20 + [0.0, 0.0],
        dtype=torch.float32,
        device=protein["aatype"].device,
    )
    categorical_probs = (
        config.uniform_prob * uniform_aa
        + config.profile_prob * protein["hhblits_profile"]
        + config.same_prob * make_one_hot(msa, 22)
    )

    mask_prob = (
        1.0 - config.profile_prob - config.same_prob - config.uniform_prob
    )
    assert mask_prob >= 0.0

    # F.pad takes (left, right) pairs starting from the LAST axis, so this
    # appends exactly one [MASK] column, filled with mask_prob.
    pad_spec = [0, 1] + [0, 0] * (categorical_probs.dim() - 1)
    categorical_probs = torch.nn.functional.pad(
        categorical_probs, pad_spec, value=mask_prob
    )

    mask_position = torch.rand(msa.shape) < replace_fraction
    bert_msa = torch.where(mask_position, shaped_categorical(categorical_probs), msa)

    # Mix real and masked MSA
    protein["bert_mask"] = mask_position.to(torch.float32)
    protein["true_msa"] = msa
    protein["msa"] = bert_msa

    return protein


@curry1
def make_fixed_size(
    protein,
    shape_schema,
    msa_cluster_size=0,
    extra_msa_size=0,
    num_res=0,
    num_templates=0,
):
    """Guess at the MSA and sequence dimension to make fixed size.

    Every feature except ``extra_cluster_assignment`` is padded on the right
    with zeros so that each axis named NUM_RES / NUM_MSA_SEQ / NUM_EXTRA_SEQ
    / NUM_TEMPLATES in ``shape_schema[key]`` reaches the requested size. A
    requested size of 0 (the default), or an axis name not in that set,
    leaves the axis unchanged. If the target is smaller than the current
    size, the pad amount is negative, which F.pad treats as cropping.
    """
    target_by_axis = {
        NUM_RES: num_res,
        NUM_MSA_SEQ: msa_cluster_size,
        NUM_EXTRA_SEQ: extra_msa_size,
        NUM_TEMPLATES: num_templates,
    }

    for key, value in protein.items():
        # Don't transfer this to the accelerator.
        if key == "extra_cluster_assignment":
            continue
        shape = list(value.shape)
        schema = shape_schema[key]
        msg = "Rank mismatch between shape and shape schema for"
        assert len(shape) == len(schema), f"{msg} {key}: {shape} vs {schema}"

        target_shape = [
            target_by_axis.get(axis, None) or size
            for size, axis in zip(shape, schema)
        ]
        # F.pad wants (left, right) pairs starting from the last axis; we only
        # ever pad on the right.
        padding = [
            amount
            for size, target in reversed(list(zip(shape, target_shape)))
            for amount in (0, target - size)
        ]
        if padding:
            protein[key] = torch.reshape(
                torch.nn.functional.pad(value, padding), target_shape
            )

    return protein


@curry1
def make_msa_feat(protein):
    """Create and concatenate MSA features.

    Writes:
      * ``msa_feat``: 23-class one-hot ``msa``, has-deletion flag and squashed
        deletion value, followed by ``cluster_profile`` and the squashed
        ``cluster_deletion_mean`` when ``cluster_profile`` is present.
      * ``target_feat``: domain-break flag followed by a 21-class one-hot of
        ``aatype``.
      * ``extra_has_deletion`` / ``extra_deletion_value`` when
        ``extra_deletion_matrix`` is present.
    """

    def squash_deletions(counts):
        # 2/pi * atan(x / 3): smooth, monotone rescaling of deletion counts.
        return torch.atan(counts / 3.0) * (2.0 / np.pi)

    # Whether there is a domain break. Always zero for chains, but kept for
    # compatibility with domain datasets.
    has_break = torch.clip(
        protein["between_segment_residues"].to(torch.float32), 0, 1
    )
    target_parts = [
        has_break.unsqueeze(-1),
        make_one_hot(protein["aatype"], 21),  # Everyone gets the original sequence.
    ]

    deletions = protein["deletion_matrix"]
    msa_parts = [
        make_one_hot(protein["msa"], 23),
        torch.clip(deletions, 0.0, 1.0).unsqueeze(-1),
        squash_deletions(deletions).unsqueeze(-1),
    ]

    if "cluster_profile" in protein:
        mean_deletion = squash_deletions(protein["cluster_deletion_mean"])
        msa_parts.append(protein["cluster_profile"])
        msa_parts.append(mean_deletion.unsqueeze(-1))

    if "extra_deletion_matrix" in protein:
        extra_deletions = protein["extra_deletion_matrix"]
        protein["extra_has_deletion"] = torch.clip(extra_deletions, 0.0, 1.0)
        protein["extra_deletion_value"] = squash_deletions(extra_deletions)

    protein["msa_feat"] = torch.cat(msa_parts, dim=-1)
    protein["target_feat"] = torch.cat(target_parts, dim=-1)
    return protein


@curry1
def select_feat(protein, feature_list):
    """Return a new dict with only the entries of ``protein`` whose key is in ``feature_list``."""
    selected = {}
    for key, value in protein.items():
        if key in feature_list:
            selected[key] = value
    return selected


@curry1
def crop_templates(protein, max_templates):
    """Keep only the first ``max_templates`` entries of every ``template_*`` feature (in place)."""
    template_keys = [key for key in protein if key.startswith("template_")]
    for key in template_keys:
        protein[key] = protein[key][:max_templates]
    return protein


def make_atom14_masks(complex):
    """Construct denser atom positions (14 dimensions instead of full).

    Builds per-token-type lookup tables from ``ec`` and gathers them with
    ``complex["token_type"]``, adding:
      * ``atom14_atom_exists``: 1.0 where the atom14 slot has an atom name.
      * ``residx_atom14_to_atomFull``: atomFull index for each atom14 slot
        (0 for empty slots).
      * ``residx_atomFull_to_atom14``: atom14 slot for each atomFull atom
        (0 where the atom is not in the token's atom14 list).
      * ``atomFull_atom_exists``: 1.0 for atoms listed in ``ec.token_atoms``.
    """
    device = complex["token_type"].device

    atom14_to_full_rows = []
    full_to_atom14_rows = []
    atom14_exists_rows = []
    for tok_type in ec.token_types:
        atom14_names = ec.toktype_to_atom14_names[tok_type]
        # Empty names mark unused atom14 slots: they map to index 0 and are
        # masked out.
        atom14_to_full_rows.append(
            [ec.atom_order[name] if name else 0 for name in atom14_names]
        )
        slot_of = {name: slot for slot, name in enumerate(atom14_names)}
        full_to_atom14_rows.append([slot_of.get(name, 0) for name in ec.atom_types])
        atom14_exists_rows.append([1.0 if name else 0.0 for name in atom14_names])

    atom14_to_full = torch.tensor(atom14_to_full_rows, dtype=torch.int32, device=device)
    full_to_atom14 = torch.tensor(full_to_atom14_rows, dtype=torch.int32, device=device)
    atom14_exists = torch.tensor(atom14_exists_rows, dtype=torch.float32, device=device)
    token_type = complex["token_type"].to(torch.long)

    # Per-token tables: gather each token type's row.
    complex["atom14_atom_exists"] = atom14_exists[token_type]
    complex["residx_atom14_to_atomFull"] = atom14_to_full[token_type].long()
    complex["residx_atomFull_to_atom14"] = full_to_atom14[token_type].long()

    full_exists = torch.zeros(
        [ec.token_type_num, ec.atom_type_num],
        dtype=torch.float32, device=device,
    )
    for row, tok_type in enumerate(ec.token_types):
        for atom_name in ec.token_atoms[tok_type]:
            full_exists[row, ec.atom_order[atom_name]] = 1
    complex["atomFull_atom_exists"] = full_exists[token_type]

    return complex


def make_atom14_masks_np(batch):
    """NumPy front end for ``make_atom14_masks``.

    Every ``np.ndarray`` leaf of ``batch`` is turned into a CPU tensor, the
    masks are computed, and all tensor leaves of the result are converted
    back to NumPy arrays.
    """
    as_tensors = tree_map(lambda arr: torch.tensor(arr, device="cpu"), batch, np.ndarray)
    with_masks = make_atom14_masks(as_tensors)
    return tensor_tree_map(lambda t: np.array(t), with_masks)

# this function is designed for ligand-protein setting exclusively
def centering_positions(complex):
    """Translate all atom positions so the ligand centroid sits at the origin.

    The centroid is the mean of the "*MolAtom" slot positions over tokens
    whose "*MolAtom" mask entry is set. After translation, every position is
    multiplied by ``all_atom_mask`` so absent atoms stay at zero.

    Assumes (TODO confirm) ``all_atom_positions`` is [N_tok, N_atom, 3] and
    ``all_atom_mask`` is [N_tok, N_atom] with no batch dimension, as implied
    by the ``[:, idx]`` indexing below.

    Returns the same dict with ``all_atom_positions`` replaced by a new
    tensor; the caller's original position tensor is not modified.
    """
    molAtom_idx = ec.atom_order["*MolAtom"]
    positions = complex["all_atom_positions"]
    atom_mask = complex["all_atom_mask"]

    lig_mask = atom_mask[:, molAtom_idx]
    lig_pos = positions[:, molAtom_idx]
    # Only masked-in slots contribute, so garbage coordinates in absent
    # "*MolAtom" slots cannot drag the centroid.
    lig_sum = (lig_pos * lig_mask[:, None]).sum(dim=0)
    # Clamp (rather than add an epsilon) so the result is the exact mean when
    # ligand atoms exist, and 0 (lig_sum is 0) when none do.
    num_lig = torch.clamp(lig_mask.sum().to(lig_sum.dtype), min=1e-5)
    lig_center = lig_sum / num_lig

    # Out-of-place: do not mutate the tensor the caller passed in.
    complex["all_atom_positions"] = (
        (positions - lig_center[None, None, :]) * atom_mask[..., None]
    )
    return complex


def make_atom14_positions(complex):
    """Constructs denser atom positions (14 dimensions instead of full).

    Expects ``complex`` to already hold ``atom14_atom_exists`` and
    ``residx_atom14_to_atomFull`` (make_atom14_masks writes these keys), plus
    ``all_atom_mask``, ``all_atom_positions`` and ``token_type``.

    Adds:
        atom14_gt_exists          atom14 existence mask * gathered atomFull mask
        atom14_gt_positions       atomFull positions gathered into atom14 order,
                                  zeroed where atom14_gt_exists is 0
        atom14_alt_gt_positions   the same positions with the atoms of every
                                  ``ec.token_atom_renaming_swaps`` pair exchanged
        atom14_alt_gt_exists      the gt mask permuted the same way
        atom14_atom_is_ambiguous  1 for atom14 slots taking part in a swap
    and re-stores ``atom14_atom_exists`` unchanged.
    """
    residx_atom14_mask = complex["atom14_atom_exists"]
    residx_atom14_to_atomFull = complex["residx_atom14_to_atomFull"]
    mask_dtype = complex["all_atom_mask"].dtype
    device = complex["all_atom_mask"].device

    # Create a mask for known ground truth positions.
    residx_atom14_gt_mask = residx_atom14_mask * batched_gather(
        complex["all_atom_mask"],
        residx_atom14_to_atomFull,
        dim=-1,
        no_batch_dims=len(complex["all_atom_mask"].shape[:-1]),
    )

    # Gather the ground truth positions.
    residx_atom14_gt_positions = residx_atom14_gt_mask[..., None] * (
        batched_gather(
            complex["all_atom_positions"],
            residx_atom14_to_atomFull,
            dim=-2,
            no_batch_dims=len(complex["all_atom_positions"].shape[:-2]),
        )
    )

    complex["atom14_atom_exists"] = residx_atom14_mask
    complex["atom14_gt_exists"] = residx_atom14_gt_mask
    complex["atom14_gt_positions"] = residx_atom14_gt_positions

    # Matrices for renaming ambiguous atoms: identity for token types without
    # swaps, otherwise a 14x14 permutation matrix with a 1 at
    # [i, correspondences[i]].
    all_matrices = {
        token_type: torch.eye(14, dtype=mask_dtype, device=device)
        for token_type in ec.token_types
    }
    for toktype, swap in ec.token_atom_renaming_swaps.items():
        atom14_names = ec.toktype_to_atom14_names[toktype]
        correspondences = torch.arange(14, device=device)
        for source_atom_swap, target_atom_swap in swap.items():
            source_index = atom14_names.index(source_atom_swap)
            target_index = atom14_names.index(target_atom_swap)
            correspondences[source_index] = target_index
            correspondences[target_index] = source_index
        # Build the matrix once, after every swap of this token type has been
        # applied; an empty swap dict therefore yields the identity instead of
        # reusing another token type's matrix.
        all_matrices[toktype] = torch.eye(
            14, dtype=mask_dtype, device=device
        )[correspondences]

    renaming_matrices = torch.stack(
        [all_matrices[toktype] for toktype in ec.token_types]
    )

    # Pick the transformation matrices for the given token sequence
    # shape (num_res, 14, 14).
    renaming_transform = renaming_matrices[complex["token_type"]]

    # Apply it to the ground truth positions. shape (num_res, 14, 3).
    alternative_gt_positions = torch.einsum(
        "...rac,...rab->...rbc", residx_atom14_gt_positions, renaming_transform
    )
    complex["atom14_alt_gt_positions"] = alternative_gt_positions

    # Create the mask for the alternative ground truth (differs from the
    # ground truth mask, if only one of the atoms in an ambiguous pair has a
    # ground truth position).
    alternative_gt_mask = torch.einsum(
        "...ra,...rab->...rb", residx_atom14_gt_mask, renaming_transform
    )
    complex["atom14_alt_gt_exists"] = alternative_gt_mask

    # Create an ambiguous atoms mask.  shape: (ec.token_type_num, 14).
    toktype_atom14_is_ambiguous = complex["all_atom_mask"].new_zeros(
        (ec.token_type_num, 14)
    )
    for toktype, swap in ec.token_atom_renaming_swaps.items():
        restype = ec.token_type_order[toktype]
        atom14_names = ec.toktype_to_atom14_names[toktype]
        for atom_name1, atom_name2 in swap.items():
            toktype_atom14_is_ambiguous[restype, atom14_names.index(atom_name1)] = 1
            toktype_atom14_is_ambiguous[restype, atom14_names.index(atom_name2)] = 1

    # From this create an ambiguous_mask for the given sequence.
    complex["atom14_atom_is_ambiguous"] = toktype_atom14_is_ambiguous[
        complex["token_type"]
    ]

    return complex


def rigid_from_3_points(
    p_neg_x_axis: torch.Tensor, 
    origin: torch.Tensor, 
    p_xy_plane: torch.Tensor, 
    single_atom_mask: torch.Tensor,
    eps: float = 1e-8
) -> Rigid:
    """
        Implements algorithm 21. Constructs transformations from sets of 3 
        points using the Gram-Schmidt algorithm.

        Args:
            p_neg_x_axis: [*, 3] coordinates
            origin: [*, 3] coordinates used as frame origins
            p_xy_plane: [*, 3] coordinates
            single_atom_mask: [*] boolean mask (tensor or array-like) of
                frames defined by a single atom. Their rotation is replaced
                by diag(-1, 1, -1) instead of the Gram-Schmidt result.
            eps: Small epsilon value
        Returns:
            A transformation object of shape [*]
    """
    p_neg_x_axis = torch.unbind(p_neg_x_axis, dim=-1)
    origin = torch.unbind(origin, dim=-1)
    p_xy_plane = torch.unbind(p_xy_plane, dim=-1)

    e0 = [c1 - c2 for c1, c2 in zip(origin, p_neg_x_axis)]
    e1 = [c1 - c2 for c1, c2 in zip(p_xy_plane, origin)]

    denom = torch.sqrt(sum((c * c for c in e0)) + eps)
    e0 = [c / denom for c in e0]
    dot = sum((c1 * c2 for c1, c2 in zip(e0, e1)))
    e1 = [c2 - c1 * dot for c1, c2 in zip(e0, e1)]
    denom = torch.sqrt(sum((c * c for c in e1)) + eps)
    e1 = [c / denom for c in e1]
    e2 = [
        e0[1] * e1[2] - e0[2] * e1[1],
        e0[2] * e1[0] - e0[0] * e1[2],
        e0[0] * e1[1] - e0[1] * e1[0],
    ]

    rots = torch.stack([c for tup in zip(e0, e1, e2) for c in tup], dim=-1)
    rots = rots.reshape(rots.shape[:-1] + (3, 3))

    # here x and z axis for backbone will be flipped afterward
    # Built from `rots` so dtype/device match: boolean-mask assignment
    # (index_put_) rejects a float32 CPU source for float64 or CUDA `rots`.
    flip_rot = rots.new_tensor([[-1., 0., 0.],
                                [0., 1., 0.],
                                [0., 0., -1.]])
    single_atom_mask = torch.as_tensor(
        single_atom_mask, dtype=torch.bool, device=rots.device
    )
    rots[single_atom_mask, :, :] = flip_rot

    rot_obj = Rotation(rot_mats=rots, quats=None)

    return Rigid(rot_obj, torch.stack(origin, dim=-1))


def atomFull_to_frames(complex, eps=1e-8):
    """Compute ground-truth rigid-group frames from atomFull coordinates.

    Every token type has 8 rigid-group slots. The per-entity default groups
    come from ``ec.entity_default_rigid_groups``; slots 4-7 hold the chi-angle
    groups from ``ec.chi_angles_atoms`` / ``ec.chi_angles_mask``. Each frame
    is built by rigid_from_3_points from three base atoms. Entities with a
    single default group are flagged so that rigid_from_3_points gives them
    a fixed rotation.

    Args:
        complex: dict holding ``token_type`` [*, N_tok],
            ``all_atom_positions`` [*, N_tok, ec.atom_type_num, 3] and
            ``all_atom_mask`` [*, N_tok, ec.atom_type_num] (shapes as
            implied by the gather calls below — TODO confirm against the
            data pipeline).
        eps: epsilon forwarded to rigid_from_3_points.

    Returns:
        ``complex`` with these keys added:
            rigidgroups_gt_frames           4x4 frames, [*, N_tok, 8, 4, 4]
            rigidgroups_gt_exists           [*, N_tok, 8]
            rigidgroups_group_exists        [*, N_tok, 8]
            rigidgroups_group_is_ambiguous  [*, N_tok, 8]
            rigidgroups_alt_gt_frames       4x4 frames, [*, N_tok, 8, 4, 4]
    """
    token_type_seq = complex["token_type"]
    all_atom_positions = complex["all_atom_positions"]
    all_atom_mask = complex["all_atom_mask"]

    batch_dims = len(token_type_seq.shape[:-1])

    # Per-token-type tables: [token type, rigid group (0..7), ...].
    toktype_rigidgroup_base_atom_names = np.full([ec.token_type_num, 8, 3], "", dtype=object)
    toktype_singlegroup_mask = np.zeros([ec.token_type_num, 8], dtype=np.bool_)
    toktype_rigidgroup_mask = all_atom_mask.new_zeros(
        (*token_type_seq.shape[:-1], ec.token_type_num, 8),
    )

    token_entity_type = np.asarray(ec.token_entity_type)
    for entity_type, default_groups in ec.entity_default_rigid_groups.items():
        is_cur = token_entity_type == ec.entity_type_order[entity_type]
        for rg_idx, rg_atoms in default_groups.items():
            toktype_rigidgroup_base_atom_names[is_cur, rg_idx, :] = rg_atoms
            # Index the token-type axis explicitly so that leading batch
            # dimensions are not mistaken for it.
            toktype_rigidgroup_mask[..., torch.from_numpy(is_cur), rg_idx] = 1
        if len(default_groups) == 1:
            rg_idx = next(iter(default_groups))
            toktype_singlegroup_mask[is_cur, rg_idx] = True

    for token_type_idx, token_type in enumerate(ec.token_types):
        for chi_idx in range(4):
            if ec.chi_angles_mask[token_type_idx][chi_idx]:
                names = ec.chi_angles_atoms[token_type][chi_idx]
                toktype_rigidgroup_base_atom_names[
                    token_type_idx, chi_idx + 4, :
                ] = names[1:]

    toktype_rigidgroup_mask[..., 4:] = all_atom_mask.new_tensor(
        ec.chi_angles_mask
    )

    # Atom names -> atomFull indices; unused slots ("") map to index 0.
    lookuptable = ec.atom_order.copy()
    lookuptable[""] = 0
    lookup = np.vectorize(lambda x: lookuptable[x])
    toktype_rigidgroup_base_atomFull_idx = lookup(
        toktype_rigidgroup_base_atom_names,
    )
    toktype_rigidgroup_base_atomFull_idx = token_type_seq.new_tensor(
        toktype_rigidgroup_base_atomFull_idx,
    )
    toktype_rigidgroup_base_atomFull_idx = (
        toktype_rigidgroup_base_atomFull_idx.view(
            *((1,) * batch_dims), *toktype_rigidgroup_base_atomFull_idx.shape
        )
    )

    # Same treatment as the base-atom table: a torch tensor on the right
    # device with singleton batch dims, so batched_gather works for batched
    # inputs and rigid_from_3_points receives a bool tensor.
    singlegroup_mask_t = torch.tensor(
        toktype_singlegroup_mask, dtype=torch.bool, device=token_type_seq.device
    )
    singlegroup_mask_t = singlegroup_mask_t.view(
        *((1,) * batch_dims), *singlegroup_mask_t.shape
    )

    tokidx_rigidgroup_base_atomFull_idx = batched_gather(
        toktype_rigidgroup_base_atomFull_idx,
        token_type_seq,
        dim=-3,
        no_batch_dims=batch_dims,
    )

    tokidx_singlegroup_mask = batched_gather(
        singlegroup_mask_t,
        token_type_seq,
        dim=-2,
        no_batch_dims=batch_dims,
    )

    base_atom_pos = batched_gather(
        all_atom_positions,
        tokidx_rigidgroup_base_atomFull_idx,
        dim=-2,
        no_batch_dims=len(all_atom_positions.shape[:-2]),
    )

    gt_frames = rigid_from_3_points(
        p_neg_x_axis=base_atom_pos[..., 0, :],
        origin=base_atom_pos[..., 1, :],
        p_xy_plane=base_atom_pos[..., 2, :],
        single_atom_mask=tokidx_singlegroup_mask,
        eps=eps,
    )

    group_exists = batched_gather(
        toktype_rigidgroup_mask,
        token_type_seq,
        dim=-2,
        no_batch_dims=batch_dims,
    )

    # A frame exists only if all three of its base atoms exist.
    gt_atoms_exist = batched_gather(
        all_atom_mask,
        tokidx_rigidgroup_base_atomFull_idx,
        dim=-1,
        no_batch_dims=len(all_atom_mask.shape[:-1]),
    )
    gt_exists = torch.min(gt_atoms_exist, dim=-1)[0] * group_exists

    # Flip the x and z axes of group 0.
    rots = torch.eye(3, dtype=all_atom_mask.dtype, device=token_type_seq.device)
    rots = torch.tile(rots, (*((1,) * batch_dims), 8, 1, 1))
    rots[..., 0, 0, 0] = -1
    rots[..., 0, 2, 2] = -1
    rots = Rotation(rot_mats=rots)

    gt_frames = gt_frames.compose(Rigid(rots, None))

    # For token types with renaming swaps, the last chi group is ambiguous;
    # its alternative frame is rotated by diag(1, -1, -1).
    toktype_rigidgroup_is_ambiguous = all_atom_mask.new_zeros(
        *((1,) * batch_dims), ec.token_type_num, 8
    )
    toktype_rigidgroup_rots = torch.eye(
        3, dtype=all_atom_mask.dtype, device=token_type_seq.device
    )
    toktype_rigidgroup_rots = torch.tile(
        toktype_rigidgroup_rots,
        (*((1,) * batch_dims), ec.token_type_num, 8, 1, 1),
    )

    for token_type in ec.token_atom_renaming_swaps:
        token_type_idx = ec.token_type_order[token_type]
        chi_idx = int(sum(ec.chi_angles_mask[token_type_idx]) - 1)
        toktype_rigidgroup_is_ambiguous[..., token_type_idx, chi_idx + 4] = 1
        toktype_rigidgroup_rots[..., token_type_idx, chi_idx + 4, 1, 1] = -1
        toktype_rigidgroup_rots[..., token_type_idx, chi_idx + 4, 2, 2] = -1

    tokidx_rigidgroup_is_ambiguous = batched_gather(
        toktype_rigidgroup_is_ambiguous,
        token_type_seq,
        dim=-2,
        no_batch_dims=batch_dims,
    )

    tokidx_rigidgroup_ambiguity_rot = batched_gather(
        toktype_rigidgroup_rots,
        token_type_seq,
        dim=-4,
        no_batch_dims=batch_dims,
    )

    tokidx_rigidgroup_ambiguity_rot = Rotation(
        rot_mats=tokidx_rigidgroup_ambiguity_rot
    )
    alt_gt_frames = gt_frames.compose(
        Rigid(tokidx_rigidgroup_ambiguity_rot, None)
    )

    gt_frames_tensor = gt_frames.to_tensor_4x4()
    alt_gt_frames_tensor = alt_gt_frames.to_tensor_4x4()

    complex["rigidgroups_gt_frames"] = gt_frames_tensor
    complex["rigidgroups_gt_exists"] = gt_exists
    complex["rigidgroups_group_exists"] = group_exists # not used in model
    complex["rigidgroups_group_is_ambiguous"] = tokidx_rigidgroup_is_ambiguous
    complex["rigidgroups_alt_gt_frames"] = alt_gt_frames_tensor

    return complex


def get_chi_atom_indices():
    """Return the atom indices needed to compute chi angles for every token type.

    Returns:
        A nested Python list (not a tensor; callers wrap it with
        `torch.as_tensor`) with one entry per element of `ec.token_types`, in
        that order. Each entry lists, per chi angle, the four atom indices
        (looked up in `ec.atom_order`) that define it. Entries are padded with
        `[0, 0, 0, 0]` up to four chi angles, so chi angles that are not
        defined for a token type point at atom index 0 by default.
    """
    chi_atom_indices = []
    for token_type in ec.token_types:
        atom_indices = [
            [ec.atom_order[atom] for atom in chi_angle]
            for chi_angle in ec.chi_angles_atoms[token_type]
        ]
        # Dummy quadruples for chi angles not defined on this token type.
        atom_indices.extend([0, 0, 0, 0] for _ in range(4 - len(atom_indices)))
        chi_atom_indices.append(atom_indices)

    return chi_atom_indices


@curry1
def atomFull_to_torsion_angles(
    complex,
    prefix="",
):
    """
    Convert coordinates to torsion angles.

    This function is extremely sensitive to floating point imprecisions
    and should be run with double precision whenever possible.

    The function is wrapped by `curry1` ("Supply all arguments but the
    first"), so callers presumably pass `prefix` first and the feature dict
    afterwards. Confirm against `curry1`.

    Args:
        Dict containing:
            * (prefix)token_type:
                [*, N_res] token type indices; used directly to index
                get_chi_atom_indices(), ec.chi_angles_mask and
                ec.chi_pi_periodic, so every value must be in range for
                those tables
            * (prefix)all_atom_positions:
                [*, N_res, ec.atom_type_num, 3] atom positions (in atomFull
                format)
            * (prefix)all_atom_mask:
                [*, N_res, ec.atom_type_num] atom position mask
        prefix: string prepended to every key that is read and written.
    Returns:
        The same dictionary (modified in place) updated with the following
        features:

        "(prefix)torsion_angles_sin_cos" ([*, N_res, 7, 2])
            Torsion angles as (sin, cos) pairs, ordered pre-omega, phi, psi,
            then the four chi angles
        "(prefix)alt_torsion_angles_sin_cos" ([*, N_res, 7, 2])
            Alternate torsion angles (accounting for 180-degree symmetry)
        "(prefix)torsion_angles_mask" ([*, N_res, 7])
            Torsion angles mask
    """
    token_type = complex[prefix + "token_type"]
    all_atom_positions = complex[prefix + "all_atom_positions"]
    all_atom_mask = complex[prefix + "all_atom_mask"]

    # Clamping is disabled: token_type is used unclamped as an index below.
    # token_type = torch.clamp(token_type, max=20)

    pad = all_atom_positions.new_zeros(
        [*all_atom_positions.shape[:-3], 1, ec.atom_type_num, 3]
    )

    # Prepend a zero row along the token axis (dim -3) and drop the last
    # token, so index i holds token i-1's atoms (all zeros for token 0).
    prev_all_atom_positions = torch.cat(
        [pad, all_atom_positions[..., :-1, :, :]], dim=-3
    )

    pad = all_atom_mask.new_zeros([*all_atom_mask.shape[:-2], 1, ec.atom_type_num])

    # Same shift for the mask; token 0 therefore gets a zero "previous" mask,
    # which zeroes its pre-omega and phi masks below.
    prev_all_atom_mask = torch.cat([pad, all_atom_mask[..., :-1, :]], dim=-2)

    # Each torsion is defined by 4 atoms. Atom slots 0-2 (and 4 for psi) are
    # presumably the backbone atoms in AlphaFold's atom37-like order
    # (N, CA, C, ..., O). TODO confirm against ec.atom_order.
    # pre-omega: previous token slots 1,2 + current token slots 0,1.
    pre_omega_atom_pos = torch.cat(
        [prev_all_atom_positions[..., 1:3, :], all_atom_positions[..., :2, :]],
        dim=-2,
    )
    # phi: previous token slot 2 + current token slots 0,1,2.
    phi_atom_pos = torch.cat(
        [prev_all_atom_positions[..., 2:3, :], all_atom_positions[..., :3, :]],
        dim=-2,
    )
    # psi: current token slots 0,1,2,4.
    psi_atom_pos = torch.cat(
        [all_atom_positions[..., :3, :], all_atom_positions[..., 4:5, :]],
        dim=-2,
    )

    # A torsion is valid only if all four of its atoms are present.
    pre_omega_mask = torch.prod(
        prev_all_atom_mask[..., 1:3], dim=-1
    ) * torch.prod(all_atom_mask[..., :2], dim=-1)
    phi_mask = prev_all_atom_mask[..., 2] * torch.prod(
        all_atom_mask[..., :3], dim=-1, dtype=all_atom_mask.dtype
    )
    psi_mask = (
        torch.prod(all_atom_mask[..., :3], dim=-1, dtype=all_atom_mask.dtype)
        * all_atom_mask[..., 4]
    )

    chi_atom_indices = torch.as_tensor(
        get_chi_atom_indices(), device=token_type.device
    )

    # Per-token chi atom indices, expected [*, N_res, 4 chis, 4 atoms];
    # gather the corresponding coordinates -> expected [*, N_res, 4, 4, 3].
    atom_indices = chi_atom_indices[..., token_type, :, :]
    chis_atom_pos = batched_gather(
        all_atom_positions, atom_indices, -2, len(atom_indices.shape[:-2])
    )

    chi_angles_mask = list(ec.chi_angles_mask)
    chi_angles_mask = all_atom_mask.new_tensor(chi_angles_mask)

    # Which chi angles exist for each token's type.
    chis_mask = chi_angles_mask[token_type, :]

    # Also require all four defining atoms of each chi to be present.
    chi_angle_atoms_mask = batched_gather(
        all_atom_mask,
        atom_indices,
        dim=-1,
        no_batch_dims=len(atom_indices.shape[:-2]),
    )
    chi_angle_atoms_mask = torch.prod(
        chi_angle_atoms_mask, dim=-1, dtype=chi_angle_atoms_mask.dtype
    )
    chis_mask = chis_mask * chi_angle_atoms_mask

    # Stack the 7 torsions (pre-omega, phi, psi, chi1-4) along dim -3.
    torsions_atom_pos = torch.cat(
        [
            pre_omega_atom_pos[..., None, :, :],
            phi_atom_pos[..., None, :, :],
            psi_atom_pos[..., None, :, :],
            chis_atom_pos,
        ],
        dim=-3,
    )

    torsion_angles_mask = torch.cat(
        [
            pre_omega_mask[..., None],
            phi_mask[..., None],
            psi_mask[..., None],
            chis_mask,
        ],
        dim=-1,
    )

    # Local frame built from atoms 1, 2 and 0 of each torsion quadruple
    # (argument roles as defined by Rigid.from_3_points).
    torsion_frames = Rigid.from_3_points(
        torsions_atom_pos[..., 1, :],
        torsions_atom_pos[..., 2, :],
        torsions_atom_pos[..., 0, :],
        eps=1e-8,
    )

    # Express the 4th atom in that local frame.
    fourth_atom_rel_pos = torsion_frames.invert().apply(
        torsions_atom_pos[..., 3, :]
    )

    # The (z, y) components of the 4th atom give the unnormalised (sin, cos).
    torsion_angles_sin_cos = torch.stack(
        [fourth_atom_rel_pos[..., 2], fourth_atom_rel_pos[..., 1]], dim=-1
    )

    # Normalise to unit length; 1e-8 guards against division by zero.
    denom = torch.sqrt(
        torch.sum(
            torch.square(torsion_angles_sin_cos),
            dim=-1,
            dtype=torsion_angles_sin_cos.dtype,
            keepdims=True,
        )
        + 1e-8
    )
    torsion_angles_sin_cos = torsion_angles_sin_cos / denom

    # Negate both sin and cos of torsion 2 (psi), i.e. shift it by pi.
    # Presumably because psi is measured to atom slot 4 rather than the next
    # token's first backbone atom. TODO confirm. The index expression
    # broadcasts the 7 factors over the leading dims and the sin/cos axis.
    torsion_angles_sin_cos = torsion_angles_sin_cos * all_atom_mask.new_tensor(
        [1.0, 1.0, -1.0, 1.0, 1.0, 1.0, 1.0],
    )[((None,) * len(torsion_angles_sin_cos.shape[:-2])) + (slice(None), None)]

    # Per-token flags for chi angles that are symmetric under a pi rotation.
    chi_is_ambiguous = torsion_angles_sin_cos.new_tensor(
        ec.chi_pi_periodic,
    )[token_type, ...]

    # Factor 1 for the three backbone torsions; -1 (a pi shift) for
    # pi-periodic chi angles, 1 otherwise.
    mirror_torsion_angles = torch.cat(
        [
            all_atom_mask.new_ones(*token_type.shape, 3),
            1.0 - 2.0 * chi_is_ambiguous,
        ],
        dim=-1,
    )

    alt_torsion_angles_sin_cos = (
        torsion_angles_sin_cos * mirror_torsion_angles[..., None]
    )

    complex[prefix + "torsion_angles_sin_cos"] = torsion_angles_sin_cos
    complex[prefix + "alt_torsion_angles_sin_cos"] = alt_torsion_angles_sin_cos
    complex[prefix + "torsion_angles_mask"] = torsion_angles_mask

    return complex


def get_backbone_frames(complex):
    """Expose the backbone rigid frame and its existence mask.

    Rigid group 0 is taken from the per-group ground-truth frames and stored
    under "backbone_rigid_tensor". The matching entry of
    "rigidgroups_gt_exists" is stored under "backbone_rigid_mask". The dict is
    updated in place and returned.

    NOTE: AlphaFold uses tensor_7s here; this keeps the 4x4 matrix form
    (the last two dims of "rigidgroups_gt_frames").
    """
    group_frames = complex["rigidgroups_gt_frames"]
    complex["backbone_rigid_tensor"] = group_frames[..., 0, :, :]

    group_exists = complex["rigidgroups_gt_exists"]
    complex["backbone_rigid_mask"] = group_exists[..., 0]

    return complex


# def get_chi_angles(complex):
#     dtype = complex["all_atom_mask"].dtype
#     complex["chi_angles_sin_cos"] = (
#         complex["torsion_angles_sin_cos"][..., 3:, :]
#     ).to(dtype)
#     complex["chi_mask"] = complex["torsion_angles_mask"][..., 3:].to(dtype)

#     return complex


@curry1
def random_crop_to_size(
    protein,
    crop_size,
    max_templates,
    shape_schema,
    subsample_templates=False,
    seed=None,
):
    """Crop randomly to `crop_size`, or keep as is if shorter than that.

    Every feature listed in `shape_schema` that has a NUM_RES dimension is
    sliced along all of its NUM_RES dimensions with one shared random window,
    so features stay aligned. Features whose key starts with "template" also
    have their first dimension cropped to at most `max_templates` entries,
    after an optional random permutation.

    Args:
        protein: feature dict containing at least "seq_length"; cropped in
            place.
        crop_size: maximum number of residues to keep.
        max_templates: maximum number of templates to keep.
        shape_schema: maps feature names to per-dimension placeholders;
            NUM_RES marks residue dimensions.
        subsample_templates: if True and "template_mask" has templates,
            permute the templates randomly and start from a random template.
        seed: optional seed for a dedicated torch.Generator, making the crop
            reproducible.

    Returns:
        The same dict, with cropped features and "seq_length" set to the
        number of residues kept.
    """
    # We want each ensemble to be cropped the same way
    device = protein["seq_length"].device

    g = None
    if seed is not None:
        g = torch.Generator(device=device)
        g.manual_seed(seed)

    seq_length = protein["seq_length"]

    if "template_mask" in protein:
        num_templates = protein["template_mask"].shape[-1]
    else:
        num_templates = 0

    # No need to subsample templates if there aren't any
    subsample_templates = subsample_templates and num_templates

    num_res_crop_size = min(int(seq_length), crop_size)

    def _randint(lower, upper):
        # Uniform integer from the closed interval [lower, upper].
        return int(torch.randint(
            lower,
            upper + 1,
            (1,),
            device=device,
            generator=g,
        )[0])

    if subsample_templates:
        templates_crop_start = _randint(0, num_templates)
        templates_select_indices = torch.randperm(
            num_templates, device=device, generator=g
        )
    else:
        templates_crop_start = 0

    num_templates_crop_size = min(
        num_templates - templates_crop_start, max_templates
    )

    n = seq_length - num_res_crop_size
    if "use_clamped_fape" in protein and protein["use_clamped_fape"] == 1.:
        # Crop start uniform over every valid offset in [0, n].
        right_anchor = n
    else:
        # Two-stage sampling: draw a random upper bound, then sample below it.
        # This biases the crop start toward the beginning of the sequence.
        x = _randint(0, n)
        right_anchor = n - x

    num_res_crop_start = _randint(0, right_anchor)

    for k, v in protein.items():
        if k not in shape_schema or (
            "template" not in k and NUM_RES not in shape_schema[k]
        ):
            continue

        # randomly permute the templates before cropping them.
        if k.startswith("template") and subsample_templates:
            v = v[templates_select_indices]

        slices = []
        for i, (dim_size, dim) in enumerate(zip(shape_schema[k], v.shape)):
            is_num_res = dim_size == NUM_RES
            if i == 0 and k.startswith("template"):
                dim_crop_size = num_templates_crop_size
                dim_crop_start = templates_crop_start
            else:
                dim_crop_start = num_res_crop_start if is_num_res else 0
                dim_crop_size = num_res_crop_size if is_num_res else dim
            slices.append(slice(dim_crop_start, dim_crop_start + dim_crop_size))
        # Index with a tuple: a plain list of slices relies on legacy
        # "sequence as tuple" indexing (an error in modern NumPy).
        protein[k] = v[tuple(slices)]

    protein["seq_length"] = protein["seq_length"].new_tensor(num_res_crop_size)

    return protein

# Per-chain feature keys (presumably consumed by chain-level processing
# elsewhere; confirm against callers).
CHAIN_FEATS = [
    'atom_positions',
    'aatype',
    'atom_mask',
    'residue_index',
    'b_factors',
]
# Features that pad_feats copies through unchanged (no padding).
UNPADDED_FEATS = [
    't',
    'rot_score_scaling',
    'trans_score_scaling',
    't_seq',
    't_struct',
    'seq_length',
]
# Rigid-transform features; pad_feats pads them with identity rigids.
RIGID_FEATS = [
    'rigids_0',
    'init_frames',
]
# Features that pad_feats pads along dim 1 in addition to dim 0.
PAIR_FEATS = [
    'rel_rots',
    'pair_feat',
]

def pad_feats(raw_feats, max_len, use_torch=False):
    """Pad a dict of features to `max_len` along their residue axes.

    - Features in UNPADDED_FEATS are copied through unchanged.
    - Features in RIGID_FEATS are padded with identity rigids (pad_rigid).
    - Every other feature is zero-padded along dim 0; features in PAIR_FEATS
      are additionally zero-padded along dim 1.

    Args:
        raw_feats: mapping from feature name to array/tensor.
        max_len: target size of each padded axis.
        use_torch: pad with torch instead of numpy. This is applied to every
            zero-padding call, including the extra PAIR_FEATS axis.

    Returns:
        A new dict with the padded features; `raw_feats` is not modified.
    """
    # Build the exclusion set once instead of concatenating lists per item.
    excluded = set(UNPADDED_FEATS + RIGID_FEATS)
    padded_feats = {
        feat_name: pad(feat, max_len, use_torch=use_torch)
        for feat_name, feat in raw_feats.items()
        if feat_name not in excluded
    }
    for feat_name in PAIR_FEATS:
        if feat_name in padded_feats:
            # Pair features presumably have two residue axes; pad the second.
            padded_feats[feat_name] = pad(
                padded_feats[feat_name], max_len, pad_idx=1, use_torch=use_torch
            )
    for feat_name in UNPADDED_FEATS:
        if feat_name in raw_feats:
            padded_feats[feat_name] = raw_feats[feat_name]
    for feat_name in RIGID_FEATS:
        if feat_name in raw_feats:
            padded_feats[feat_name] = pad_rigid(raw_feats[feat_name], max_len)
    return padded_feats

def pad_rigid(rigid: torch.Tensor, max_len: int):
    """Append identity rigid transforms along dim 0 up to `max_len`.

    Args:
        rigid: tensor whose first dimension indexes rigids. It is presumably
            in the [N, 7] `to_tensor_7` layout, because the identity padding
            is appended in that format. TODO confirm with callers.
        max_len: desired size of dim 0 after padding.

    Returns:
        `rigid` concatenated along dim 0 with `max_len - N` identity rigids
        (as 7-vectors) of the same dtype and device.
    """
    num_rigids = rigid.shape[0]
    pad_amt = max_len - num_rigids
    # Named so it does not shadow this function.
    identity_pad = Rigid.identity(
        (pad_amt,), dtype=rigid.dtype, device=rigid.device, requires_grad=False)
    return torch.cat([rigid, identity_pad.to_tensor_7()], dim=0)

def pad(x: np.ndarray, max_len: int, pad_idx=0, use_torch=False, reverse=False):
    """Zero-pad one dimension of an array up to `max_len`.

    Args:
        x: numpy array (or torch tensor when `use_torch` is True) to pad.
        max_len: desired size of dimension `pad_idx` after padding.
        pad_idx: dimension to pad (negative values count from the end).
        use_torch: pad with `torch.nn.functional.pad` instead of `np.pad`,
            so a torch tensor stays a torch tensor.
        reverse: if True, insert the padding before the existing entries
            (left padding) instead of after them (right padding).

    Returns:
        `x` with dimension `pad_idx` zero-padded to `max_len`; all other
        dimensions are unchanged.

    Raises:
        ValueError: if `x` is already longer than `max_len` along `pad_idx`.
    """
    # Pad only the requested (residue) dimension.
    seq_len = x.shape[pad_idx]
    pad_amt = max_len - seq_len
    if pad_amt < 0:
        raise ValueError(f'Invalid pad amount {pad_amt}')
    pad_widths = [(0, 0)] * x.ndim
    pad_widths[pad_idx] = (pad_amt, 0) if reverse else (0, pad_amt)
    if use_torch:
        # torch has no `torch.pad`. torch.nn.functional.pad takes a flat
        # sequence ordered from the LAST dimension to the first:
        # (last_before, last_after, second_last_before, ...).
        flat_widths = [w for widths in reversed(pad_widths) for w in widths]
        return torch.nn.functional.pad(x, flat_widths)
    return np.pad(x, pad_widths)
