import torch


class ChargeBalanceGuide:
    """Callable that scores a sample by the magnitude of its charge.

    The score is ``abs(charge)`` of either the previous or the clean sample,
    returned as a float32 tensor on the device of that sample's positions.
    """

    def __init__(self, use_prev=False):
        """
        Args:
            use_prev: If True, score ``prev_sample``; otherwise score
                ``clean_sample``.
        """
        self.use_prev = use_prev

    def __call__(self, prev_sample, clean_sample):
        """Return the absolute charge of the selected sample.

        Args:
            prev_sample: Sample scored when ``use_prev`` is True. Must provide
                ``get_charge()`` and ``get_positions()`` (the latter's result
                must expose ``.device``).
            clean_sample: Sample scored when ``use_prev`` is False; same
                interface as ``prev_sample``.

        Returns:
            A float32 tensor holding ``abs(charge)``, detached from any
            autograd graph, on the device of the selected sample's positions.
        """
        cand = prev_sample if self.use_prev else clean_sample
        charge = cand.get_charge()
        device = cand.get_positions().device
        if isinstance(charge, torch.Tensor):
            # torch.tensor(t) on an existing tensor emits a UserWarning and is
            # documented as equivalent to t.clone().detach(). detach() yields the
            # same gradient-free values without the warning; no clone is needed
            # because abs() below always allocates a fresh tensor.
            charge = charge.detach()
        else:
            charge = torch.tensor(charge)
        # Cast and move in a single conversion instead of .float().to(device).
        return charge.to(device=device, dtype=torch.float32).abs()
