from chroma import Chroma
from chroma.layers.structure.backbone import ProteinBackbone
from chroma.data.protein import Protein
# Protein is defined in Mean_Flow_Steering/steering_chroma/chroma/chroma/data/protein.py
import torch 
# Instantiate the Chroma model. NOTE(review): presumably this loads the
# pretrained weights; in this script it is only referenced by the
# commented-out sampling example further down.
chroma = Chroma()

# Prefer the GPU but fall back to CPU so the script also runs on machines
# without CUDA (a hard-coded "cuda" raises there).
device = "cuda" if torch.cuda.is_available() else "cpu"

B = 1                # number of proteins in the batch
lengths = [147]      # residues per chain; e.g. [100, 80] for a two-chain complex
num_residues = sum(lengths)

# Random initial coordinates, shape (B, num_residues, 4, 3): 4 atoms per
# residue, each with xyz. NOTE(review): the 4 atoms are presumably the
# backbone N, CA, C, O expected by Chroma's XCS format -- confirm.
X = torch.randn(B, num_residues, 4, 3, device=device)

# Chain map: one integer chain id per residue, shared across the batch.
# In Chroma's XCS convention 0 marks masked/padding positions and chain ids
# are positive integers, so numbering must start at 1 -- starting at 0 would
# flag every residue of the first chain as masked. dtype is pinned to long
# rather than inferred from the fill value.
C_single = torch.cat(
    [
        torch.full((length,), chain_id, dtype=torch.long, device=device)
        for chain_id, length in enumerate(lengths, start=1)
    ]
)
C = C_single.unsqueeze(0).repeat(B, 1)  # (B, num_residues)

# Placeholder sequence tokens: all zeros, same shape/dtype/device as C.
# NOTE(review): token 0 presumably maps to the first letter of Chroma's
# amino-acid alphabet (alanine), i.e. a poly-Ala sequence -- confirm.
S = torch.zeros_like(C)

protein_init = Protein.from_XCS(X, C, S)


# num_residues = sum(lengths)
# X = torch.randn(1, num_residues, 4, 3, device=device) 

# # Create chain mapping
# C = torch.cat([torch.full((l,), i, device=device) for i, l in enumerate(lengths)]).unsqueeze(0)

# # Create protein
# protein_init = Protein.from_XCS(X, C, torch.zeros_like(C))


# real_protein = Protein('7r5b')

# # Sample with Chroma
# protein = chroma.sample(sde_func="ode", protein_init=real_protein)
# protein.to("my_sample.cif")

