
from chroma import Chroma
from chroma.data.protein import Protein
import torch

# NOTE(review): `chroma` is not referenced anywhere in the visible script.
# Constructing it may load model weights, so confirm it is truly unneeded
# before removing it.
chroma = Chroma()

# Input structures: the reference structure and the steered generation.
# Earlier comparisons used "steered_protein.cif" and
# "fewer_spt_steps_steered_protein.cif" as the generated input instead.
REFERENCE_CIF = "7r5b.cif"
STEERED_CIF = "30_spt_steps_steered_protein.cif"

true_protein = Protein.from_CIF(REFERENCE_CIF, canonicalize=True)
better_gen = Protein.from_CIF(STEERED_CIF, canonicalize=True)

# Backbone-only tensor triples (all_atom=False). Presumably these are
# coordinates (X), chain map (C) and sequence (S), per Chroma's XCS
# convention; confirm against the Protein API.
X_t, C_t, S_t = true_protein.to_XCS(all_atom=False)
X_g_better, C_g_better, S_g_better = better_gen.to_XCS(all_atom=False)


def masked_backbone(protein: Protein, mask: torch.Tensor) -> Protein:
    """Return a backbone-only Protein restricted to the residues selected by ``mask``.

    Args:
        protein: Source structure. Only its backbone representation
            (``to_XCS(all_atom=False)``) is used.
        mask: Boolean residue selector with shape ``(1, N)`` or ``(N,)``,
            where ``N`` is the length of axis 1 of the protein's X/C/S
            tensors. For a 2-D mask, only the first row is used.

    Returns:
        A new Protein built with ``Protein.from_XCS`` from the selected
        positions. The leading (batch) axis of X/C/S is kept.

    Raises:
        ValueError: If ``mask`` is neither 1-D nor 2-D.
    """
    X, C, S = protein.to_XCS(all_atom=False)

    # Reduce the mask to a 1-D selector. A 1-D mask must not be indexed with
    # [0]: that gives a 0-d bool tensor, and indexing with it inserts a new
    # axis instead of selecting residues.
    if mask.dim() == 2:
        residue_mask = mask[0]
    elif mask.dim() == 1:
        residue_mask = mask
    else:
        raise ValueError(
            f"mask must have shape (N,) or (1, N); got {tuple(mask.shape)}"
        )

    # Select along axis 1, the axis X, C and S share, so the leading axis is
    # preserved.
    X = X[:, residue_mask]
    C = C[:, residue_mask]
    S = S[:, residue_mask]
    return Protein.from_XCS(X, C, S)

# The two structures are compared position by position, so their C tensors
# must have identical shapes (i.e. be residue-aligned). Check explicitly:
# otherwise torch either raises an opaque broadcasting error or silently
# broadcasts a length-1 side.
if C_t.shape != C_g_better.shape:
    raise ValueError(
        "Reference and generated structures are not residue-aligned: "
        f"C tensors have shapes {tuple(C_t.shape)} and {tuple(C_g_better.shape)}"
    )

# Keep only positions whose C entry is positive in BOTH structures.
# NOTE(review): in Chroma's chain-map convention, C > 0 presumably marks
# residues with resolved coordinates (negative = missing, 0 = padding).
# Confirm this.
mask = (C_t > 0) & (C_g_better > 0)   # shape (1, N)

true_bb = masked_backbone(true_protein, mask)
# gen_bb  = masked_backbone(gen_protein,  mask)
better_gen_bb = masked_backbone(better_gen, mask)

# Only the masked steered structure is written. The other exports are
# disabled but kept for toggling.
# true_bb.to_PDB("true_backbone.pdb")
# gen_bb.to_PDB("gen_backbone.pdb")
better_gen_bb.to_PDB("30_step_SPT.pdb")

# Unmasked (full-structure) exports, kept for reference:
# true_protein.to_PDB("true_backbone.pdb")
# gen_protein.to_PDB("gen_backbone.pdb")
# better_gen.to_PDB("better_gen_backbone.pdb")

