from chroma.layers.structure.rmsd import CrossRMSD
from chroma import Protein
import torch 
from chroma import Chroma
from chroma.data.protein import Protein
import torch

# Single Chroma model instance, shared by the ELBO scoring calls
# (`chroma.score_backbone`) in the evaluation script further down.
# NOTE(review): this runs at import time; presumably it loads pretrained
# weights (possibly downloading them) — confirm before importing this module
# from elsewhere.
chroma = Chroma()

def setup_ground_truth(cif_path):
    """Load a reference structure and keep only its chain-1 residues, centred.

    Args:
        cif_path: Path to the experimental mmCIF file.

    Returns:
        A ``(X, C, S)`` triple built from ``Protein.to_XCS()``:

        - Only residues whose chain label satisfies ``|C| == 1`` are kept.
        - Each tensor gets a fresh leading batch dimension of size 1.
        - ``X`` is translated so that its mean over dimensions 0, 1 and 2 is
          zero.
    """
    X, C, S = Protein.from_CIF(cif_path).to_XCS()

    # Build the selection once, from the untouched chain map. The abs()
    # suggests chain labels can also be negative — presumably residues flagged
    # as missing coordinates; verify against Chroma's convention.
    keep = torch.abs(C) == 1

    X_sel = X[keep][None]
    S_sel = S[keep][None]
    C_sel = C[keep][None]

    # Subtract the centroid over every leading axis.
    X_sel -= X_sel.mean(dim=(0, 1, 2))
    return X_sel, C_sel, S_sel


def setup_generated_protein(cif_path):
    """Load a generated structure and return its ``(X, C, S)`` tensors as-is.

    Unlike ``setup_ground_truth``, no chain filtering, re-batching or centring
    is applied; the three tensors come straight from ``Protein.to_XCS()``.

    NOTE(review): ``ca_rmsd_masked`` indexes these tensors with a mask derived
    from the (chain-1-filtered) reference. That presumably requires the
    generated structure to have the same residue count as the filtered
    reference — verify for multi-chain targets.
    """
    protein = Protein.from_CIF(cif_path)
    coords, chain_map, sequence = protein.to_XCS()
    return coords, chain_map, sequence



def ca_rmsd_masked(X_t, X_g, C_gt, atom_index=1):
    """Superposition RMSD between matched residues of two structures.

    Residues are selected where the reference chain map satisfies
    ``C_gt[0] == 1``.

    When the coordinates carry an atom axis, only the atom at ``atom_index`` is
    compared for each residue. The default of 1 is CA in Chroma's backbone
    ordering (N, CA, C, O), so the result is a true C-alpha RMSD. Pass
    ``atom_index=None`` to compare every atom of each selected residue instead.

    Args:
        X_t: Reference coordinates with batch size 1, either ``(1, R, A, 3)``
            (per-atom) or ``(1, R, 3)`` (one point per residue).
        X_g: Coordinates superposed onto ``X_t``. Same layout and same residue
            count ``R``.
        C_gt: Reference chain map with shape ``(1, R)``.
        atom_index: Atom slot to compare when an atom axis is present, or
            ``None`` for all atoms.

    Returns:
        The RMSD tensor produced by ``CrossRMSD().pairedRMSD``, computed on CPU.

    Raises:
        ValueError: If the mask length differs from the residue axis of
            ``X_t`` or ``X_g``.
    """
    mask = C_gt[0] == 1  # [R]

    # Fail with a readable message instead of torch's boolean-index error.
    for label, X in (("X_t", X_t), ("X_g", X_g)):
        if X.shape[1] != mask.shape[0]:
            raise ValueError(
                f"{label} has {X.shape[1]} residues but the mask has "
                f"{mask.shape[0]}; both structures must be residue-aligned"
            )

    def _select(X):
        # Keep the masked residues: [R', A, 3] or [R', 3].
        picked = X[0, mask]
        if atom_index is not None and picked.dim() == 3:
            # One atom per residue. Without this step, reshape would mix every
            # backbone atom into the point set.
            picked = picked[:, atom_index, :]
        return picked.reshape(1, -1, 3)

    ca_t = _select(X_t)
    ca_g = _select(X_g)

    rmsd, _ = CrossRMSD().pairedRMSD(ca_g.cpu(), ca_t.cpu(), compute_alignment=True)
    return rmsd

# Generated samples (three per target) and their experimental reference files.
# NOTE(review): only the 8ok3 files are evaluated below; the 7pzt and 7r5b
# paths are defined but not used in this script.
ok3_gen_1 = "8ok3_gen_1.cif"
ok3_gen_2 = "8ok3_gen_2.cif"
ok3_gen_3 = "8ok3_gen_3.cif"

ok3_true = "8ok3.cif"

pzt_gen_1 = "7pzt_gen_1.cif"
pzt_gen_2 = "7pzt_gen_2.cif"
pzt_gen_3 = "7pzt_gen_3.cif"

pzt_true = "7pzt.cif"


r5b_gen_1 = "7r5b_gen_1.cif"
r5b_gen_2 = "7r5b_gen_2.cif"
r5b_gen_3 = "7r5b_gen_3.cif"

r5b_true = "7r5b.cif"

# Parse every generated CIF exactly once. The same Protein object provides
# both the coordinate tensors and the ELBO score, instead of re-reading each
# file through setup_generated_protein().
gen_1 = Protein.from_CIF(ok3_gen_1)
gen_2 = Protein.from_CIF(ok3_gen_2)
gen_3 = Protein.from_CIF(ok3_gen_3)

# Reference: chain-1 residues only, batched and centred by setup_ground_truth().
X_t, C_t, S_t = setup_ground_truth(ok3_true)

# Generated samples: raw (X, C, S) tensors, the same values
# setup_generated_protein() would return for these files.
X_g_1, C_g_1, S_g_1 = gen_1.to_XCS()
X_g_2, C_g_2, S_g_2 = gen_2.to_XCS()
X_g_3, C_g_3, S_g_3 = gen_3.to_XCS()

# ELBO from Chroma's backbone scorer for each sample.
elbo_1 = chroma.score_backbone(gen_1)['elbo'].score
elbo_2 = chroma.score_backbone(gen_2)['elbo'].score
elbo_3 = chroma.score_backbone(gen_3)['elbo'].score

# Masked superposition RMSD of each sample against the reference.
rmsd_1 = ca_rmsd_masked(X_g=X_g_1, X_t=X_t, C_gt=C_t)
rmsd_2 = ca_rmsd_masked(X_g=X_g_2, X_t=X_t, C_gt=C_t)
rmsd_3 = ca_rmsd_masked(X_g=X_g_3, X_t=X_t, C_gt=C_t)

# Summary statistics across the three samples (population std, i.e. ddof=0).
rmsds = torch.stack([rmsd_1, rmsd_2, rmsd_3], dim=0)
print(rmsds)
elbos = torch.tensor([elbo_1, elbo_2, elbo_3], dtype=torch.float32, device=X_t.device)
mean_rmsd = rmsds.mean(dim=0)
std_rmsd = rmsds.std(dim=0, unbiased=False)  # population std
mean_elbo = elbos.mean(dim=0)
std_elbo = elbos.std(dim=0, unbiased=False)  # population std
print("mean RMSD:", mean_rmsd)
print("std  RMSD:", std_rmsd)

print(f'mean elbo: {mean_elbo}')
print(f'std elbo: {std_elbo}')
