from chroma import Chroma
from chroma.data.protein import Protein
import torch

# Score a generated protein backbone with Chroma's built-in backbone scorer
# and print its ELBO (labelled "Structural Realism" in the output below).

# Build the Chroma model once; every scoring call below reuses it.
chroma = Chroma()

# Input structures (paths relative to the working directory).
# The generated structure previously scored here was 'FK_protein_20_inf_2.CIF'.
REFERENCE_CIF = '8ok3.cif'
GENERATED_CIF = 'FK_protein_50_7r5b.CIF'

# NOTE(review): the reference structure is parsed but not used by the scoring
# below. Also, the generated file name mentions 7r5b while the reference is
# 8ok3 -- confirm this pairing is the intended comparison target.
true_protein = Protein.from_CIF(REFERENCE_CIF)
better_generated_protein = Protein.from_CIF(GENERATED_CIF)

# score_backbone's result is indexed by metric name; the 'elbo' entry exposes
# its value via `.score`.
results = chroma.score_backbone(better_generated_protein)
elbo_score = results['elbo'].score

# To compare against the reference structure, score it the same way, e.g.:
#     true_results = chroma.score_backbone(true_protein)
#     print(true_results['elbo'].score)
print(f"ELBO (Structural Realism) of Generated: {elbo_score}")