from chroma import Chroma
from chroma.data.protein import Protein
import torch

# Scratch script: load several structures from CIF files, report the size of
# one structure's first to_XCS() tensor, and print its backbone ELBO score.

# Default-constructed Chroma object; the active code below only uses it for
# score_backbone().
chroma = Chroma()

# Structures read from CIF files. Only more_step_protein is referenced by the
# active code below; the other three are loaded but not used again in the
# visible part of this script.
true_protein = Protein.from_CIF("7r5b.cif")
generated_protein = Protein.from_CIF("steered_protein.cif")
better_generated_protein = Protein.from_CIF("fewer_spt_steps_steered_protein.cif")
more_step_protein = Protein.from_CIF("8ok3_gen.cif")

# to_XCS() is unpacked as a 3-tuple; only the first element is kept.
# presumably (coordinates, chain map, sequence) -- TODO confirm against the
# chroma Protein API.
more_step_xcs = more_step_protein.to_XCS()
x, _, _ = more_step_xcs
more_step_shape = x.size()
print(f"more step protein shape: {more_step_shape}")

# --- Disabled experiment: true vs. generated pairwise distances -------------
# NOTE(review): X_true / X_gen / S_true / S_gen are not defined anywhere in the
# visible script; presumably they came from to_XCS() on true_protein and
# generated_protein -- confirm before re-enabling.
# NOTE(review): true_pairwise below mixes X_true_CA with X_gen_CA, while the
# older true_dist variant used X_true_CA for both operands -- check which one
# was intended.
#
# print(f'S true shape: {S_true.size()}')
# print(f'S Gen shape: {S_gen.size()}')
#
# X_true_CA = X_true[:, :, 1, :].squeeze(0)
# X_gen_CA = X_gen[:, :, 1, :].squeeze(0)
#
# true_pairwise = torch.cdist(X_true_CA, X_gen_CA)
# gen_pairwise = torch.cdist(X_gen_CA, X_gen_CA)
#
# Older variant (was doubly commented out):
# # true_dist = torch.cdist(X_true_CA, X_true_CA)
# # gen_dist = torch.cdist(X_gen_CA, X_gen_CA)
# # print(f'X_true CA shape: {X_true_CA.size()}')
# # print(f'X_gen CA shape: {X_gen_CA.size()}')
# ----------------------------------------------------------------------------

# Score the structure; the result is indexed by metric name below.
results = chroma.score_backbone(more_step_protein)

# List the available metric names first, then report the ELBO entry.
print(results.keys())
elbo_result = results['elbo']
print(f"ELBO (Structural Realism): {elbo_result.score}")

# --- Disabled: additional metrics and shape dumps ---------------------------
# print(f"ELBO_X: {results['elbo_X'].score}")
# print(f"pair_MSE: {results['pair_mse'].score}")
# print(f"global_MSE: {results['global_mse'].score}")
#
# # print(f'X true shape: {X_true.size()}')
# # print(f'X Gen shape: {X_gen.size()}')
# ----------------------------------------------------------------------------