from spaghettini import quick_register

from src.learners.pvn_classification.ljn_double_base import LJNDoubleBase


@quick_register
class LJNDoubleFindThePlus(LJNDoubleBase):
    """``LJNDoubleBase`` learner specialised for the find-the-plus data batches.

    The only override is how a raw batch dict is split into the tuple the base
    learner consumes; all training/evaluation logic is inherited.
    """

    def unpack_data_batch(self, data_batch):
        """Split a batch dict into ``(p_xs, v_xs, ys, causal_feature)``.

        Reads the ``'img'``, ``'label'`` and ``'coords'`` entries of
        ``data_batch``.

        Returns:
            A 4-tuple:
            - the ``'img'`` entry with a new axis inserted at position 1, cast
              with ``.float()``. It is used for both the first and second
              elements, as the same object rather than a copy.
            - the ``'label'`` entry cast with ``.long()``.
            - the ``'coords'`` entry cast with ``.float()``.

        NOTE(review): the entries are presumably torch tensors and axis 1 is
        presumably a channel axis (e.g. (B, H, W) -> (B, 1, H, W)). Confirm
        against the dataset. The p/v naming presumably means prover/verifier
        inputs; verify against ``LJNDoubleBase``.
        """
        img = data_batch['img']
        shared_inputs = img.unsqueeze(1).float()
        labels = data_batch['label'].long()
        coords = data_batch['coords'].float()

        # Both network inputs alias one tensor. An in-place edit to one is
        # visible through the other.
        return shared_inputs, shared_inputs, labels, coords

    # Disabled task-specific logging hook, kept for reference.
    # NOTE(review): `FindThePlusCallback` is not imported in this file. Add
    # the import before re-enabling this method.
    #
    # def task_specific_logging(self, metric_logs, **kwargs):
    #     verifier = kwargs["verifier"]
    #     ys_true = kwargs["ys_true"]
    #     ys_train = kwargs["ys_train"]
    #     ys_aux = kwargs["ys_aux"]
    #     net_idx = kwargs["net_idx"]
    #     model_logs = kwargs["model_logs"]
    #     prepend_key = kwargs["prepend_key"]
    #
    #     FindThePlusCallback.log(logger=self.logger, verifier=verifier, ys_true=ys_true, ys_train=ys_train,
    #                             ys_aux=ys_aux, net_idx=net_idx, global_step=self.global_step, model_logs=model_logs,
    #                             prepend_key=prepend_key)
    #
    #     return metric_logs
