import torch
from chroma.data.protein import Protein
from chroma import Chroma



def load_true_protein(cif_path, device='cpu'):
    """Load a reference structure from a CIF file as centered (X, C, S) tensors.

    Only positions whose chain-map entry satisfies ``|C| == 1`` are kept
    (presumably the first chain under Chroma's XCS convention — confirm).
    Each kept tensor gets a new leading batch axis of size 1. The
    coordinates are then translated so that their mean over all batch,
    residue and atom positions is zero.

    Args:
        cif_path: Path to the CIF file to parse with ``Protein.from_CIF``.
        device: Device the returned tensors are moved to.

    Returns:
        Tuple ``(X_gt, C_gt, S_gt)``: coordinates, chain map, sequence.
    """
    X, C, S = Protein.from_CIF(cif_path).to_XCS()
    keep = C.abs() == 1
    coords, chain_map, seq = (t[keep][None] for t in (X, C, S))
    # Centre the coordinates: average over every axis except the final xyz one.
    coords = coords - coords.mean(dim=(0, 1, 2))
    return coords.to(device), chain_map.to(device), seq.to(device)




def get_distance_matrix(X):
    """Return Euclidean distances between atom index 1 of every residue pair.

    Atom index 1 is treated as the CA atom. ``X`` must be indexable as
    ``X[:, :, 1, :]``; presumably its shape is (batch, residues, atoms, 3)
    — confirm with callers.

    Returns:
        Tensor of shape (batch, residues, residues) with entry ``[b, i, j]``
        equal to the distance between residues ``i`` and ``j`` of batch ``b``.
    """
    ca = X[:, :, 1, :]
    # (B, N, 1, 3) - (B, 1, N, 3) broadcasts to all (i, j) displacement vectors.
    displacement = ca.unsqueeze(2) - ca.unsqueeze(1)
    return torch.linalg.vector_norm(displacement, ord=2, dim=-1)






class ProperPairwiseReward:
    """Reward comparing a structure's CA distance matrix to a reference one.

    The reference structure is loaded from a CIF file at construction time.
    Calling the instance on coordinates ``X`` returns the negated mean squared
    difference between the two CA–CA distance matrices, divided by ``scale``.
    Scores closer to 0 mean closer agreement.
    """

    def __init__(self, cif_path, device="cuda", scale=80.0):
        """Load the reference structure and precompute its distance matrix.

        Args:
            cif_path: Path to the reference CIF file.
            device: Device holding the reference distance matrix and mask.
            scale: Positive divisor applied to the mean squared error.
                Defaults to 80, the constant previously hard-coded here. Its
                origin is not documented in this file.
        """
        self.device = device
        self.scale = scale
        print(f'CIF path: {cif_path}')
        X_ref, _C_ref, _S_ref = load_true_protein(cif_path, device=device)
        self.D_ref = get_distance_matrix(X_ref)
        # All-ones mask: every CA pair contributes. ones_like keeps the mask on
        # D_ref's device (it was previously pinned to "cuda", which broke
        # device="cpu" and CPU-only machines).
        self.mask = torch.ones_like(self.D_ref)
        self.Y = self.mask * self.D_ref

    def __call__(self, X):
        """Score coordinates ``X`` against the reference structure.

        Args:
            X: Coordinate tensor accepted by ``get_distance_matrix``.
                Presumably shaped (batch, residues, atoms, 3), with the same
                residue count as the reference and on ``self.device``
                (confirm with callers).

        Returns:
            Tensor of shape (batch,) equal to
            ``-mean((Y - mask * D) ** 2) / scale``, averaged over both
            residue axes.
        """
        D = get_distance_matrix(X)
        mse = ((self.Y - self.mask * D) ** 2).mean(dim=(1, 2))
        return -mse / self.scale