from chroma import Chroma
from chroma.layers.structure.backbone import ProteinBackbone
from chroma.data.protein import Protein

# Single source of truth for where the protein tensors are placed; every
# call below that takes a device reads this value instead of repeating the
# literal.
device = "cuda"

# Chroma model wrapper, built with its default constructor arguments.
chroma = Chroma()

# Download the protein from the PDB database.
real_protein = Protein('7r5b', device=device)

# Run Chroma's design step on the downloaded structure and save the result.
# NOTE(review): the designed protein is written to a file named after the
# source PDB id ("7r5b.cif"); confirm that this is the intended output name.
protein = chroma.design(real_protein)
protein.to("7r5b.cif")


# Pull the X, C, S objects out of the *original* (not the redesigned) protein
# for inspection.
# NOTE(review): presumably coordinates / chain map / sequence tensors per the
# Chroma docs -- confirm before relying on their shapes.
X, C, S = real_protein.to_XCS(device=device)
print(f'C length: {C.size()}')
print(C)


# ca_mask = real_protein.atom_names == "CA"

# X_ca = X[ca_mask]   # shape: (147, 3)

# print(f'X_ca shape: {X_ca.shape}')
# X, C, S = real_protein.to_XCS(device='cuda')  # Optional

# print(f'X shape: {X.shape}')
# protein = chroma.design(real_protein)

# real_protein = Protein('7r5b').to(device)

# protein = chroma.design(real_protein)

# protein.to("1GFP-redesign.cif")