import torch
from torch.utils.data import Dataset
import numpy as np

class MNIST_Addition(Dataset):
    """Dataset of image pairs drawn from ``dataset``, each with an integer label.

    Examples are read from a text file in which every non-blank line holds
    three whitespace-separated integers ``i1 i2 l``:
      * ``i1`` and ``i2`` are indices into ``dataset``.
      * ``l`` is the pair's label (presumably the sum of the two digits;
        confirm against whatever generated the file).

    Args:
        dataset: Indexable collection where ``dataset[i][0]`` is an image.
            This is assumed to be a tensor, e.g. torchvision MNIST
            (TODO confirm); ``.flatten()`` is called on it when
            ``flat_for_spn`` is true.
        examples: Path to the examples text file.
        flat_for_spn: If true, both images are returned flattened to 1-D.

    Raises:
        ValueError: If a non-blank line does not contain exactly three
            integers.
    """

    def __init__(self, dataset, examples, flat_for_spn):
        self.data = []
        self.dataset = dataset
        self.flat_for_spn = flat_for_spn

        with open(examples, encoding='utf-8') as f:
            for lineno, line in enumerate(f, start=1):
                # split() with no argument tolerates tabs and repeated spaces.
                # It also yields [] for blank lines, which the old
                # split(' ') turned into [''] and then crashed on int('').
                fields = line.split()
                if not fields:
                    continue
                # Fail at load time rather than later inside __getitem__.
                if len(fields) != 3:
                    raise ValueError(
                        '{}:{}: expected 3 integers (i1 i2 label), got {}'.format(
                            examples, lineno, len(fields)))
                self.data.append(tuple(int(x) for x in fields))

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        """Return ``(image1, image2, label)`` for example ``index``."""
        i1, i2, label = self.data[index]
        x1 = self.dataset[i1][0]
        x2 = self.dataset[i2][0]
        if self.flat_for_spn:
            # Collapse each image to a single dimension.
            x1, x2 = x1.flatten(), x2.flatten()
        return x1, x2, label



def get_data_and_query_list(train_dataset):
    """Build parallel numpy arrays of input dicts and constraint queries.

    The dataset is iterated exactly once. Each element is unpacked as
    ``(img1, img2, label)``.

    Args:
        train_dataset: Iterable of ``(img1, img2, label)`` triples. The
            images must support ``.unsqueeze(0)`` (assumed torch tensors).

    Returns:
        A pair ``(data, queries)``:
          * ``data`` is a numpy array of dicts mapping ``'i1'`` and ``'i2'``
            to the images with a new leading dimension of size 1.
          * ``queries`` is a numpy array of strings of the form
            ``':- not addition(i1, i2, <label>).'``.
    """
    query_template = ':- not addition(i1, i2, {}).'
    samples, queries = [], []

    for first_img, second_img, label in train_dataset:
        samples.append(dict(i1=first_img.unsqueeze(0),
                            i2=second_img.unsqueeze(0)))
        queries.append(query_template.format(label))

    return np.array(samples), np.array(queries)