from chroma import Chroma
from chroma.data.protein import Protein
import torch

# Construct a Chroma instance. NOTE(review): `chroma` is not referenced in the
# visible part of this script — confirm whether the instance is still needed.
chroma = Chroma()

# true_protein = Protein.from_CIF("7r5b.cif", canonicalize=True)
# Load the reference structure from a local mmCIF file; the path is relative
# to the current working directory.
true_protein_7pzt = Protein.from_CIF("7pzt.cif", canonicalize=True)

# Unpack the structure into three tensors. X_t is indexed further down as
# [batch, residue, atom, xyz]; C_t and S_t are presumably the chain map and
# the sequence tensor — TODO confirm against the Chroma documentation.
X_t, C_t, S_t = true_protein_7pzt.to_XCS()

def get_distance_matrix(X):
    """Return all pairwise Euclidean distances between atom-slot-1 positions.

    Args:
        X: 4-D coordinate tensor indexed as ``[batch, residue, atom, xyz]``.
            Only atom slot 1 is read (labelled CA by the original author).

    Returns:
        Tensor of shape ``[batch, residue, residue]`` where entry
        ``[b, i, j]`` is the distance between residues ``i`` and ``j`` of
        batch element ``b``.
    """
    ca = X[:, :, 1, :]  # [N, R, 3]
    # Broadcast a column view against a row view to get every (i, j)
    # displacement vector in one shot.
    as_rows = ca.unsqueeze(2)  # [N, R, 1, 3]
    as_cols = ca.unsqueeze(1)  # [N, 1, R, 3]
    displacement = as_rows - as_cols  # [N, R, R, 3]
    return torch.linalg.norm(displacement, dim=-1)  # [N, R, R]


# Pairwise distance matrix for the loaded structure, shape [N, R, R].
D_t = get_distance_matrix(X_t)

# Distance cutoff; the report below labels it in Å.
thresh = 6.0
n_residues = D_t.shape[1]

# Count each unordered residue pair exactly once by keeping only the strict
# upper triangle (i < j). The previous "(full count - n_residues) // 2" form
# was only correct for a single batch element whose self-distances all fall
# below the cutoff (e.g. it miscounts with NaN coordinates or N > 1).
# NOTE(review): residues are not filtered by C_t here; if X_t carries
# placeholder coordinates for unresolved residues they are counted too —
# confirm whether a C_t-based mask is wanted.
residue_idx = torch.arange(n_residues, device=D_t.device)
upper_pairs = residue_idx[:, None] < residue_idx[None, :]  # [R, R], i < j
n_close = ((D_t < thresh) & upper_pairs).sum().item()
print(f"Number of distances < {thresh} Å: {n_close}")

# fk_steering_gen   = Protein.from_CIF("log_space_fk.cif", canonicalize=True)

# X_t, C_t, S_t = true_protein.to_XCS(all_atom=False)
# # X_g, C_g, S_g = gen_protein.to_XCS(all_atom=False)
# X_g_better, C_g_better, S_g_better = fk_steering_gen.to_XCS(all_atom=False)


# def masked_backbone(protein: Protein, mask: torch.Tensor) -> Protein:
#     X, C, S = protein.to_XCS(all_atom=False)
#     X = X[:, mask[0]]
#     C = C[:, mask[0]]
#     S = S[:, mask[0]]
#     return Protein.from_XCS(X, C, S)

# mask = (C_t > 0) & (C_g_better > 0)   # shape (1, N)

# true_bb = masked_backbone(true_protein, mask)
# fk_steering_bb  = masked_backbone(fk_steering_gen,  mask)

# # true_bb.to_PDB("true_backbone.pdb")
# # gen_bb.to_PDB("gen_backbone.pdb")
# fk_steering_bb.to_PDB("low_space_fk.pdb")


