import torch 
from chroma import Protein
import matplotlib.pyplot as plt
import numpy as np 

def get_distance_matrix(X):
    """Return the pairwise distances between atom slot 1 of every residue.

    Args:
        X: Coordinate tensor indexed as ``X[batch, residue, atom, xyz]``.
            Atom slot 1 is treated as the C-alpha (assumes the N, CA, C, O
            backbone ordering used by Chroma -- confirm for other sources).

    Returns:
        Tensor of shape ``[N, R, R]`` whose ``[n, i, j]`` entry is the
        Euclidean distance between residues ``i`` and ``j`` of structure ``n``.
    """
    ca_coords = X[:, :, 1, :]            # [N, R, 3]
    as_rows = ca_coords.unsqueeze(2)     # [N, R, 1, 3]
    as_cols = ca_coords.unsqueeze(1)     # [N, 1, R, 3]
    offsets = as_rows - as_cols          # [N, R, R, 3] via broadcasting
    return torch.linalg.norm(offsets, dim=-1)

def load_true_protein(cif_path, device='cpu', chain=1):
    """Load a reference structure from a CIF file, keep one chain, and center it.

    Args:
        cif_path: Path to the CIF file, parsed with ``Protein.from_CIF``.
        device: Device the returned tensors are moved to.
        chain: Chain index to keep. Residues are selected where
            ``abs(C) == chain``, so entries carrying a negative chain label
            are kept as well (presumably residues Chroma flags as missing --
            confirm). Defaults to 1, the previously hard-coded chain.

    Returns:
        Tuple ``(X, C, S)`` restricted to the selected residues, each with a
        leading batch dimension of size 1. ``X`` is translated so that the
        mean over all selected residues and atoms lies at the origin.

    Raises:
        ValueError: If no residue matches ``chain`` (centering an empty
            selection would otherwise silently produce NaNs).
    """
    X, C, S = Protein.from_CIF(cif_path).to_XCS()

    # Build the residue mask once and reuse it for all three tensors.
    keep = torch.abs(C) == chain
    if not bool(keep.any()):
        raise ValueError(f"No residues with chain index {chain} in {cif_path!r}")

    X_gt = X[keep][None]
    S_gt = S[keep][None]
    C_gt = C[keep][None]

    # Subtract the centroid taken over the batch, residue and atom axes,
    # leaving a single xyz offset that broadcasts over every atom.
    X_gt = X_gt - X_gt.mean(dim=(0, 1, 2))

    return X_gt.to(device), C_gt.to(device), S_gt.to(device)


# Generated structure to evaluate and the reference structure it is compared to.
gen_path = '40_spt_steps_steered_protein.cif'
true_path = '7r5b.cif'

# Upper end of the colour scale in Å; larger differences saturate at the top
# colour instead of stretching the scale.
DIFF_VMAX = 40

X_t, _, _ = load_true_protein(true_path)
X_g, _, _ = Protein.from_CIF(gen_path).to_XCS()

gen_dist = get_distance_matrix(X_g)
true_dist = get_distance_matrix(X_t)

# The reference is filtered to a single chain but the generated structure is
# not, so a residue-count mismatch is possible. Fail with a readable message
# rather than an opaque broadcasting error in the subtraction below.
if gen_dist.shape != true_dist.shape:
    raise ValueError(
        f"Distance matrix shapes differ: generated {tuple(gen_dist.shape)} "
        f"vs reference {tuple(true_dist.shape)}. The generated structure must "
        f"cover the same residues as chain 1 of {true_path}."
    )

diff = gen_dist - true_dist
diff = diff.squeeze(0)               # [R, R] -- assumes one structure per file
diff_np = np.abs(diff.detach().cpu().numpy())

# Plot the ABSOLUTE difference as a heatmap (hence vmin=0 and |...| labels).
fig, ax = plt.subplots(figsize=(6, 5))
im = ax.imshow(diff_np, vmin=0, vmax=DIFF_VMAX)
fig.colorbar(im, ax=ax, label="Absolute distance difference (Å)")
ax.set_title("Pairwise CA Distance Difference\n|Generated − True|")
ax.set_xlabel("Residue index")
ax.set_ylabel("Residue index")

fig.tight_layout()
fig.savefig("heatmap.png", dpi=300)
plt.close(fig)