import matplotlib.pyplot as plt
import numpy as np
import torch
from chroma import Protein
from mpl_toolkits.axes_grid1 import make_axes_locatable



def get_distance_matrix(X):
    """Return pairwise Euclidean distances between atom slot 1 of every residue.

    Args:
        X: Coordinate tensor indexed as ``X[batch, residue, atom, xyz]``.
            Only atom slot 1 is used; the original code labels it the CA atom
            (assumes backbone ordering N, CA, C, O — confirm for new inputs).

    Returns:
        Tensor of shape ``[N, R, R]`` whose ``[n, i, j]`` entry is the distance
        between residues ``i`` and ``j`` of structure ``n``.
    """
    ca_coords = X[:, :, 1, :]                 # [N, R, 3]
    rows = ca_coords.unsqueeze(2)             # [N, R, 1, 3]
    cols = ca_coords.unsqueeze(1)             # [N, 1, R, 3]
    offsets = rows - cols                     # [N, R, R, 3] via broadcasting
    return torch.linalg.vector_norm(offsets, dim=-1)

def load_true_protein(cif_path, device='cpu', *, chain_id=1):
    """Load the residues of one chain label from a CIF file and centre them.

    Args:
        cif_path: Path to the CIF file, parsed with ``Protein.from_CIF``.
        device: Device the returned tensors are moved to.
        chain_id: Residues whose ``abs(C)`` equals this value are kept. The
            default of 1 reproduces the original hard-coded selection. Taking
            the absolute value presumably also keeps residues carrying a
            negated chain label; confirm against Chroma's ``C`` convention.

    Returns:
        Tuple ``(X, C, S)`` restricted to the selected residues, each with a
        new leading dimension of size 1. ``X`` is translated by subtracting its
        mean over the first three dimensions (assumed batch, residue, atom —
        TODO confirm), so the distance geometry is unchanged.
    """
    X, C, S = Protein.from_CIF(cif_path).to_XCS()
    # Build the selection mask once and reuse it for all three tensors.
    keep = torch.abs(C) == chain_id
    X_gt = X[keep][None]
    S_gt = S[keep][None]
    C_gt = C[keep][None]
    # Boolean indexing returns a copy, so this in-place shift leaves X untouched.
    X_gt -= X_gt.mean(dim=(0, 1, 2))

    return X_gt.to(device), C_gt.to(device), S_gt.to(device)


gen_path = 'FK_protein_30_inf_2.CIF'
true_path = '7r5b.cif'

# Plot settings.
font_size = 18
vmax = 40  # upper end of the colour scale (Å); larger differences saturate

X_t, _, _ = load_true_protein(true_path)
X_g, _, _ = Protein.from_CIF(gen_path).to_XCS()

gen_dist = get_distance_matrix(X_g)
true_dist = get_distance_matrix(X_t)

# The reference is filtered to a single chain label by load_true_protein, but
# the generated structure is not; fail loudly rather than rely on broadcasting
# if the two do not describe the same residues.
if gen_dist.shape != true_dist.shape:
    raise ValueError(
        f"Distance matrices differ in shape: generated {tuple(gen_dist.shape)} "
        f"vs reference {tuple(true_dist.shape)}"
    )

diff = (gen_dist - true_dist).squeeze(0)  # [R, R]
diff_np = np.abs(diff.detach().cpu().numpy())

# Plot heatmap
fig, ax = plt.subplots(figsize=(6, 6))
im = ax.imshow(diff_np, vmin=0, vmax=vmax)

# --- RIGHT-SIDE COLORBAR ---
divider = make_axes_locatable(ax)
cax = divider.append_axes("right", size="5%", pad=0.3)

cbar = fig.colorbar(im, cax=cax)
cbar.set_label("Distance difference (Å)", fontsize=font_size)
cbar.ax.tick_params(labelsize=font_size)
cbar.ax.yaxis.set_label_position("right")
cbar.ax.yaxis.set_ticks_position("right")

# --- AXES LABELS / TICKS ---
# The matrix is residue x residue, so both axes carry the same label.
ax.set_xlabel("Residue index", fontsize=font_size)
ax.set_ylabel("Residue index", fontsize=font_size)
ax.tick_params(labelsize=font_size)

fig.tight_layout()
fig.savefig("heatmap.png", dpi=300, bbox_inches="tight")
plt.close(fig)


