from chroma.layers.structure.rmsd import CrossRMSD
from chroma import Protein
import torch 


def setup_ground_truth(cif_path='7pzt.cif'):
    """Load the reference structure, keep residues with chain label +/-1, centre it.

    Args:
        cif_path: path to the reference mmCIF file.

    Returns:
        ``(X_t, C_t, S_t)`` as produced by ``Protein.to_XCS()``, each restricted
        to residues whose chain label has absolute value 1 and given a leading
        batch dimension of size 1. ``X_t`` is shifted so that its mean over
        dims (0, 1, 2) is zero (assumes X is (1, residues, atoms, 3) after
        batching, in which case this centres the coordinates — TODO confirm).
    """
    reference = Protein.from_CIF(cif_path)
    coords, chains, seq = reference.to_XCS()

    # Compute the residue selection once from the unfiltered chain map.
    keep = chains.abs() == 1

    coords = coords[keep].unsqueeze(0)
    seq = seq[keep].unsqueeze(0)
    chains = chains[keep].unsqueeze(0)

    # Subtract the per-coordinate mean over all leading dims.
    coords = coords - coords.mean(dim=(0, 1, 2))
    return coords, chains, seq




# Stale author note: 'FK_protein_20_inf_2.CIF' was labelled the "current"
# cif_path, but setup_generated_protein() below defaults to
# 'FK_protein_30_inf_2.CIF'.
# current cif_path='FK_protein_20_inf_2.CIF'

# NOTE(review): `new_protein` is reassigned further down in this file before
# it is passed to setup_generated_protein(), so this value is never used.
new_protein = 'FK_protein_30_7pzt.CIF'

def setup_generated_protein(cif_path='FK_protein_30_inf_2.CIF'):
    """Load a generated structure and return its ``to_XCS()`` tensors unchanged.

    No residue filtering or centring is applied here.

    Args:
        cif_path: path to the generated structure's mmCIF file.

    Returns:
        The three tensors ``(X_g, C_g, S_g)`` unpacked from ``Protein.to_XCS()``.
    """
    generated = Protein.from_CIF(cif_path)
    coords, chain_map, sequence = generated.to_XCS()
    return coords, chain_map, sequence

def _module_global(name):
    """Return module-level ``name``, raising NameError (like a bare lookup) if absent."""
    try:
        return globals()[name]
    except KeyError:
        raise NameError(
            f"name {name!r} is not defined; pass it to calculate_CA_RMSD explicitly"
        ) from None


def calculate_CA_RMSD(X_t, X_g, C_t=None, C_g=None):
    """Aligned C-alpha RMSD between a reference and a generated structure.

    Residues are selected independently in each structure (chain label == 1),
    atom slot 1 is taken as C-alpha, and the two selections are paired in
    order.

    Args:
        X_t: reference coordinates, indexed as ``X_t[0, residue, atom, xyz]``.
        X_g: generated coordinates, same layout as ``X_t``.
        C_t: chain map for ``X_t``, indexed as ``C_t[0, residue]``. If omitted,
            the module-level global ``C_t`` is used (legacy behaviour).
        C_g: chain map for ``X_g``. If omitted, the module-level ``C_g`` is used.

    Returns:
        The RMSD value returned by ``CrossRMSD().pairedRMSD`` with alignment.

    Raises:
        NameError: a chain map was not passed and no module-level one exists.
        ValueError: the two masks select different numbers of residues, so
            there is no one-to-one pairing to compute an RMSD over.
    """
    # Backward compatibility: earlier versions took no chain maps and read the
    # module-level globals implicitly. Passing them explicitly is preferred.
    if C_t is None:
        C_t = _module_global("C_t")
    if C_g is None:
        C_g = _module_global("C_g")

    mask_t = C_t[0] == 1
    mask_g = C_g[0] == 1

    X_t_ca = X_t[0, mask_t, 1, :]  # [N_t, 3]
    X_g_ca = X_g[0, mask_g, 1, :]  # [N_g, 3]

    # pairedRMSD pairs points one-to-one; fail clearly instead of deep inside it.
    if X_t_ca.shape[0] != X_g_ca.shape[0]:
        raise ValueError(
            f"CA selections differ in length: reference has {X_t_ca.shape[0]}, "
            f"generated has {X_g_ca.shape[0]}"
        )

    rmsd_ca, _ = CrossRMSD().pairedRMSD(
        X_g_ca.cpu().reshape(1, -1, 3),
        X_t_ca.cpu().reshape(1, -1, 3),
        compute_alignment=True,
    )
    return rmsd_ca


def calculate_all_atom_RMSD(X_t, X_g):
    """Aligned RMSD over every coordinate triple in ``X_t`` and ``X_g``.

    Both tensors are moved to CPU and flattened to ``(1, -1, 3)``; no residue
    or atom masking is applied. Points are paired in flattened order, so the
    two inputs presumably need the same number of points in the same layout.

    Returns:
        The first value returned by ``CrossRMSD().pairedRMSD`` with
        ``compute_alignment=True``.
    """
    def _as_point_cloud(points):
        return points.cpu().reshape(1, -1, 3)

    rmsd_value, _alignment = CrossRMSD().pairedRMSD(
        _as_point_cloud(X_g), _as_point_cloud(X_t), compute_alignment=True
    )
    return rmsd_value


def ca_rmsd_masked(X_t, X_g, C_gt, atom_index=1):
    """Aligned C-alpha RMSD over residues selected by the reference chain map.

    The same residue mask (``C_gt == 1``) is applied to both structures, so
    ``X_t`` and ``X_g`` must be residue-aligned: residue ``i`` in one
    corresponds to residue ``i`` in the other.

    Args:
        X_t: reference coordinates, ``(1, R, A, 3)`` or unbatched ``(R, A, 3)``.
        X_g: generated coordinates in the same layout and residue order.
        C_gt: reference chain map, ``(1, R)`` or unbatched ``(R,)``.
        atom_index: atom slot compared per residue. Defaults to 1, the slot
            this file treats as C-alpha. Pass ``None`` to compare every atom
            slot of the selected residues.

    Returns:
        The RMSD value returned by ``CrossRMSD().pairedRMSD`` with alignment.

    Raises:
        ValueError: ``X_t``, ``X_g`` and ``C_gt`` disagree on residue count.
    """
    # Ensure a leading batch dim so unbatched inputs are accepted as well.
    if X_t.dim() == 3:
        X_t = X_t[None]
    if X_g.dim() == 3:
        X_g = X_g[None]
    if C_gt.dim() == 1:
        C_gt = C_gt[None]

    mask = C_gt[0] == 1  # [R]
    n_res = mask.shape[0]
    if X_t.shape[1] != n_res or X_g.shape[1] != n_res:
        raise ValueError(
            f"residue counts differ: mask has {n_res}, reference has "
            f"{X_t.shape[1]}, generated has {X_g.shape[1]}"
        )

    if atom_index is None:
        # All atom slots of the selected residues: [N, A, 3].
        ca_t = X_t[0, mask]
        ca_g = X_g[0, mask]
    else:
        # A single atom slot per selected residue: [N, 3].
        ca_t = X_t[0, mask, atom_index]
        ca_g = X_g[0, mask, atom_index]

    rmsd, _ = CrossRMSD().pairedRMSD(
        ca_g.cpu().reshape(1, -1, 3),
        ca_t.cpu().reshape(1, -1, 3),
        compute_alignment=True,
    )
    return rmsd




# --- Script section: score a generated structure against a reference ---

# Reference (ground-truth) structure.
true_protein = '8ok3.cif'
X_t, C_t, S_t = setup_ground_truth(true_protein)

# Generated structure to compare against the reference.
new_protein = 'FK_protein_50_8ok3.CIF'
X_g, C_g, S_g = setup_generated_protein(new_protein)

# RMSD over the residues whose reference chain label equals 1.
rmsd = ca_rmsd_masked(X_t, X_g, C_t)
print(f'rmsd: {rmsd}')