from spaghettini import quick_register

from src.learners.pvn_classification.ljn_single_base import LJNSingleBase


@quick_register
class LJNSingleFindThePlus(LJNSingleBase):
    """LJN single-input learner variant for the "find the plus" data.

    Only the batch-unpacking step is customised here; everything else is
    inherited unchanged from ``LJNSingleBase``.
    """

    def unpack_data_batch(self, data_batch):
        """Split a raw batch into the pieces the base learner consumes.

        Args:
            data_batch: mapping that must contain the keys ``'img'``,
                ``'label'`` and ``'coords'``; the values are expected to be
                tensor-like objects exposing ``unsqueeze``, ``float`` and
                ``long`` (presumably torch tensors - confirm with the data
                loader).

        Returns:
            A 4-tuple ``(p_xs, v_xs, ys, other_data)``:
              * ``p_xs`` / ``v_xs`` - the same object: ``data_batch['img']``
                with a new axis inserted at dim 1 and cast to float
                (NOTE(review): the extra axis looks like a channel dim and the
                p/v prefixes look like prover/verifier inputs - verify against
                ``LJNSingleBase``);
              * ``ys`` - ``data_batch['label']`` cast to long;
              * ``other_data`` - dict with a single ``'correct_proofs'`` entry
                holding ``data_batch['coords']`` cast to float.
        """
        images = data_batch['img'].unsqueeze(1).float()
        labels = data_batch['label'].long()
        proof_coords = data_batch['coords'].float()

        # Both input slots deliberately receive the identical image tensor.
        return images, images, labels, {'correct_proofs': proof_coords}
