import torch


from chroma.data.protein import Protein



def pairwise_reward(x):
    """Return the negative MSE between filtered distances of ``x`` and a noisy reference.

    The coordinates at index 1 of the third axis of ``x`` are extracted,
    their pairwise distance matrix is computed and filtered by
    ``compute_distances_and_filter``, and the result is compared against the
    matrix returned by ``get_noisy_distances``.

    Args:
        x: Coordinate tensor indexed with four axes. The variable names in
            this file suggest index 1 of the third axis is the C-alpha atom
            and that the leading axis is a size-1 batch axis — TODO confirm
            against the caller.

    Returns:
        A scalar tensor equal to ``-MSE``; values closer to zero mean the
        two filtered distance matrices agree more closely.
    """
    # squeeze(0) only drops the leading axis when it has size 1.
    ca_coords = x[:, :, 1, :].squeeze(0)
    filtered_pairwise_distances = compute_distances_and_filter(ca_coords)

    # Build the reference on the same device as the input tensor.
    noisy_filtered_distances = get_noisy_distances(device=x.device)

    # Sanity check: both operands of the subtraction below must share a device.
    assert filtered_pairwise_distances.device == noisy_filtered_distances.device, (
        f"device mismatch: {filtered_pairwise_distances.device} "
        f"vs {noisy_filtered_distances.device}"
    )

    # NOTE(review): the subtraction relies on both matrices having the same
    # (or broadcastable) shape, i.e. x and the reference structure having the
    # same number of residues — confirm this holds for all callers.
    mse = torch.mean((filtered_pairwise_distances - noisy_filtered_distances) ** 2)

    # Negative MSE for reward (minimizing MSE = maximizing negative MSE).
    return -mse



def compute_distances_and_filter(x, cutoff=6):
    """Compute pairwise Euclidean distances between rows of ``x`` and zero the far ones.

    Args:
        x: Tensor of points accepted by ``torch.cdist`` (each row along the
            second-to-last axis is one point).
        cutoff: Distances greater than or equal to this value are replaced
            by 0. Defaults to 6, the value previously hard-coded here.

    Returns:
        A new tensor of pairwise distances on the same device as ``x`` in
        which every entry ``>= cutoff`` has been set to 0. Entries below the
        cutoff (including the zero diagonal) are kept unchanged.
    """
    distances = torch.cdist(x, x)  # cdist keeps the result on x's device

    # Use the out-of-place masked_fill rather than writing into the cdist
    # output: in-place edits of an op's result can conflict with autograd
    # when gradients are later requested through this function.
    return distances.masked_fill(distances >= cutoff, 0)


def get_noisy_distances(device, cif_path='../reference_proteins/7r5b.cif', noise_std=0.5):
    """Load a reference structure and return its filtered distance matrix with noise added.

    Args:
        device: Torch device the reference coordinates are moved to before
            distances are computed.
        cif_path: Path of the CIF file to load. The default is a relative
            path, so it is resolved against the current working directory.
        noise_std: Scale applied to the standard-normal noise. Defaults to
            0.5, the value previously hard-coded here.

    Returns:
        The absolute value of (filtered reference distances + noise), on
        ``device``.

    Note:
        The file is re-read and fresh noise is drawn on every call, so two
        calls return different tensors. Noise is added to every entry,
        including the diagonal and the entries zeroed by
        ``compute_distances_and_filter`` — NOTE(review): confirm that is
        intended.
    """
    X, _, _ = Protein.from_CIF(cif_path).to_XCS()

    # Index 1 of the third axis is presumably the C-alpha atom (per the
    # variable naming used in this file) — TODO confirm.
    reference_ca = X[:, :, 1, :].squeeze(0).to(device)
    reference_distances = compute_distances_and_filter(reference_ca)

    # In-place add is safe here: reference_distances is a freshly computed
    # tensor that nothing else holds a reference to.
    reference_distances += noise_std * torch.randn(
        size=reference_distances.size(),
        device=reference_distances.device,
    )

    # abs() keeps the result non-negative after the noise perturbation.
    return reference_distances.abs()


# Load the saved (X, C, S) triple onto the CPU.
# NOTE(review): torch.load unpickles the file; only load files you trust.
X, C, S = torch.load("test_backbone.pt", map_location="cpu")

# Score the first two entries of X and collect the rewards in order.
reward = []
for i in (0, 1):
    reward.append(pairwise_reward(X[i]))



