import torch
from models.utils.deepproblog_modules import DeepProblogModel
from utils.args import *
from utils.conf import get_device
from models.utils.utils_problog import *
from utils.losses import *



def get_parser() -> ArgumentParser:
    """Build the command-line argument parser for this model.

    Registers the shared management and experiment options through
    ``add_management_args`` and ``add_experiment_args`` and returns the
    resulting parser.
    """
    # A single literal: the former implicit concatenation
    # 'Learning via' 'Concept Extractor .' lost the separating space.
    parser = ArgumentParser(description='Learning via Concept Extractor.')
    add_management_args(parser)
    add_experiment_args(parser)
    return parser

class KANDPreProcess(DeepProblogModel):
    """Encode each image chunk of a multi-image input and return embeddings.

    ``forward`` splits its input into ``n_images`` equal chunks along the
    last dimension, runs every chunk through the shared ``encoder`` and
    stacks the per-image outputs. No training machinery is attached:
    ``get_loss`` returns ``None`` and ``start_optim`` creates no optimizer.
    """

    NAME = 'kandpreprocess'

    def __init__(self, encoder, n_images=2,
                 c_split=(), args=None,
                 model_dict=None,
                 n_facts=20, nr_classes=19):
        """Store the image/concept split and initialise the base model.

        Args:
            encoder: module applied to each input chunk in ``forward``.
            n_images: number of chunks the input is split into along its
                last dimension.
            c_split: explicit split of concepts; stored on the instance but
                not used by this class.
            args: accepted for interface compatibility; unused here.
            model_dict: forwarded unchanged to ``DeepProblogModel``.
            n_facts: forwarded unchanged to ``DeepProblogModel``.
            nr_classes: forwarded unchanged to ``DeepProblogModel``.
        """
        super().__init__(encoder=encoder, model_dict=model_dict,
                         n_facts=n_facts, nr_classes=nr_classes)

        # How many images the input holds and the explicit split of concepts.
        self.n_images = n_images
        self.c_split = c_split

        # No optimizer is built (``start_optim`` is a no-op).
        self.opt = None
        self.device = get_device()

    def forward(self, x):
        """Encode each of the ``n_images`` chunks of ``x`` and stack them.

        Args:
            x: tensor holding ``n_images`` inputs concatenated along its
                last dimension. NOTE(review): if ``x.size(-1)`` is not a
                multiple of ``n_images``, ``torch.split`` produces a shorter
                trailing chunk that is ignored here.

        Returns:
            dict with key ``'EMBS'``: the encoder outputs stacked along a new
            dim 1, i.e. shape ``(out.shape[0], n_images, *out.shape[1:])``
            where ``out`` is a single encoder output.
        """
        chunk_width = x.size(-1) // self.n_images
        chunks = torch.split(x, chunk_width, dim=-1)
        encoded = [self.encoder(chunks[i]) for i in range(self.n_images)]

        # stack (not cat) is used for every output rank >= 1: a 1-D encoder
        # output has no dim 1 to concatenate along, so torch.cat(..., dim=1)
        # would raise; stacking yields a (batch, n_images) tensor instead.
        embs = torch.stack(encoded, dim=1)

        return {'EMBS': embs}

    @staticmethod
    def get_loss(args=None):
        """Return ``None``: this model defines no training loss."""
        return None

    def start_optim(self, args):
        """Do nothing: no optimizer is created and ``self.opt`` is untouched."""
        return