import torch

from filling_strategies import filling

def get_reconstructed_signal(data, feature_mask, args, fp):
    """Return node features with the unobserved entries filled in by ``filling``.

    Entries where ``feature_mask`` is True are copied straight from ``data.x``.
    All other entries come from the output of ``filling``.

    Args:
        data: Graph object exposing ``.x`` (node features) and ``.edge_index``.
            Presumably a PyTorch Geometric ``Data`` instance; TODO confirm.
        feature_mask: Mask selecting the observed entries of ``data.x``. It is
            negated with ``~`` and used as the ``torch.where`` condition, so it
            is assumed to be a bool tensor broadcastable to ``data.x``
            (TODO confirm).
        args: Forwarded unchanged to ``filling``.
        fp: Forwarded unchanged to ``filling``.

    Returns:
        Tensor that equals ``data.x`` wherever ``feature_mask`` is True and the
        ``filling`` output everywhere else.
    """
    # Work on a copy so the caller's feature tensor is never mutated.
    masked_input = data.x.clone()
    # Mark the unobserved positions as NaN before handing off to the filler.
    # NOTE(review): this assignment presumably requires data.x to be floating
    # point; confirm.
    masked_input[~feature_mask] = float('nan')

    imputed = filling(data.edge_index, masked_input, feature_mask, args, fp)

    # Keep observed values exactly and take the filler's values only for the gaps.
    return torch.where(feature_mask, data.x, imputed)