import pickle
import re
import sys
import time

import clingo
import torch
from torch.nn.utils import clip_grad_norm_
import numpy as np
import torch.nn as nn
import time
import utils

from mvpp import MVPP
from sklearn.metrics import confusion_matrix

from tqdm import tqdm

from EinsumNetwork import Graph, EinsumNetwork

#from multiprocessing import Pool
from pathos.multiprocessing import ProcessingPool as Pool
from itertools import repeat




def compute_gradients_splitwise(networkOutput_split, query_batch_split, mvpp, n, normalProbs, alpha, dmvpp, method, opt):
    """
    Compute the gradients, stable models and P(Q) for one part of a batch.

    For every query of the split the network outputs are written into the
    parameters of ``dmvpp`` (only if alpha < 1), the probabilities are
    normalised and the gradients are computed with the selected method.

    @param networkOutput_split: nested container indexed as
        networkOutput_split[m][t][bidx][i*n[m]+j], i.e. network name -> term ->
        position of the query inside this split -> flat output index
    @param query_batch_split: the queries of this split; must provide ``tolist()``
    @param mvpp: dictionary with the keys 'networkPrRuleNum' and 'networkProb'
        (and 'program_asp' for the methods 'network_prediction' and 'penalty')
    @param n: mapping from network name to the domain size used for indexing
    @param normalProbs: probabilities of the normal prob. rules, or None/empty
    @param alpha: the network outputs are only copied into dmvpp if alpha < 1
    @param dmvpp: the MVPP object used to compute stable models and gradients
    @param method: one of 'exact', 'slot', 'sampling', 'network_prediction', 'penalty'
    @param opt: forwarded to dmvpp.gradients_one_query (method 'exact' only)
    :return: three lists (gradients, stable models, P(Q)) with one entry per
        processed query
    :raises ValueError: if method is not one of the supported names
    """
    queries = query_batch_split.tolist()

    gradient_batch_list_split = []
    model_batch_list_split = []
    prob_q_batch_list_split = []

    for bidx, query in enumerate(queries):

        # Step 2: if alpha is less than 1, we compute the semantic gradients (default: alpha=0)
        if alpha < 1:

            # Step 2.1: replace the parameters in the MVPP program with network outputs
            for ruleIdx in range(mvpp['networkPrRuleNum']):
                probs = [networkOutput_split[m][t][bidx][i*n[m]+j] for (m, i, t, j) in mvpp['networkProb'][ruleIdx]]

                # a binary rule only stores p; the second atom gets 1-p
                if len(probs) == 1:
                    probs = [probs[0], 1 - probs[0]]
                dmvpp.parameters[ruleIdx] = probs

            # Step 2.2: replace the parameters for normal prob. rules with the updated probabilities
            if normalProbs:
                for ruleIdx, probs in enumerate(normalProbs):
                    dmvpp.parameters[mvpp['networkPrRuleNum'] + ruleIdx] = probs

        dmvpp.normalize_probs()

        if method == 'exact':  # default
            gradients, models = dmvpp.gradients_one_query(query, opt=opt)
        elif method == 'slot':
            models = dmvpp.find_one_most_probable_SM_under_query_noWC(query)
            gradients = dmvpp.mvppLearn(models)
        elif method == 'sampling':
            models = dmvpp.sample_query(query, num=10)
            gradients = dmvpp.mvppLearn(models)
        elif method == 'network_prediction':
            models = dmvpp.find_one_most_probable_SM_under_query_noWC()
            check = SLASH.satisfy(models[0], mvpp['program_asp'] + query)
            gradients = dmvpp.mvppLearn(models)
            if check:
                # NOTE(review): nothing is recorded for a query whose prediction
                # already satisfies it, so the returned lists can be shorter
                # than the split -- confirm that callers expect this.
                continue
            gradients = -gradients
        elif method == 'penalty':
            models = dmvpp.find_all_SM_under_query()
            # penalise every stable model that violates the current query
            models_noSM = [model for model in models if not SLASH.satisfy(model, mvpp['program_asp'] + query)]
            gradients = - dmvpp.mvppLearn(models_noSM)
        else:
            raise ValueError(
                "Error: the method '%s' should be one of 'exact', 'slot', 'sampling', "
                "'network_prediction' or 'penalty'" % method)

        prob_q = dmvpp.sum_probability_for_stable_models(models)

        model_batch_list_split.append(models)
        gradient_batch_list_split.append(gradients)
        prob_q_batch_list_split.append(prob_q)

    return gradient_batch_list_split, model_batch_list_split, prob_q_batch_list_split




class SLASH(object):
    def __init__(self, dprogram, networkMapping, optimizers, gpu=True):

        """
        @param dprogram: a string for a NeurASP program
        @param networkMapping: a dictionary maps network names to neural networks modules
        @param optimizers: a dictionary maps network names to their optimizers
        
        @param gpu: a Boolean denoting whether the user wants to use GPU for training and testing
        """
        self.device = torch.device('cuda' if torch.cuda.is_available() and gpu else 'cpu')

        self.dprogram = dprogram
        self.const = {} # the mapping from c to v for rule #const c=v.
        self.n = {} # the mapping from network name to an integer n denoting the domain size; n would be 1 or N (>=3); note that n=2 in theory is implemented as n=1
        self.e = {} # the mapping from network name to an integer e
        self.domain = {} # the mapping from network name to the domain of the predicate in that network atom
        self.normalProbs = None # record the probabilities from normal prob rules
        self.networkOutputs = {}
        self.networkGradients = {}
        self.networkTypes = {}
        if gpu==True:
            self.networkMapping = {key : nn.DataParallel(networkMapping[key].to(self.device)) for key in networkMapping}
        else:
            self.networkMapping = networkMapping
        self.optimizers = optimizers
        # self.mvpp is a dictionary consisting of 4 keys: 
        # 1. 'program': a string denoting an MVPP program where the probabilistic rules generated from network are followed by other rules;
        # 2. 'networkProb': a list of lists of tuples, each tuple is of the form (model, i ,term, j)
        # 3. 'atom': a list of list of atoms, where each list of atoms is corresponding to a prob. rule
        # 4. 'networkPrRuleNum': an integer denoting the number of probabilistic rules generated from network
        self.mvpp = {'networkProb': [], 'atom': [], 'networkPrRuleNum': 0, 'program': ''}
        self.mvpp['program'], self.mvpp['program_pr'], self.mvpp['program_asp'] = self.parse(query='')
        self.stableModels = [] # a list of stable models, where each stable model is a list
        self.prob_q = [] # a list of probabilites for each query in the batch


    def constReplacement(self, t):
        """ Substitute every constant c in t by v if '#const c=v.' was recorded

        The comma-separated pieces of t are stripped of surrounding whitespace,
        replaced if they are a known constant, and joined again with ','.

        @param t: a string, which is a term representing an input to a neural network
        """
        replaced = []
        for piece in t.split(','):
            name = piece.strip()
            # unknown names are kept as they are
            replaced.append(self.const.get(name, name))
        return ','.join(replaced)

    def networkAtom2MVPPrules(self, networkAtom):
        """
        Translate one grounded neural atom into MVPP probabilistic rules.

        The atom has the form ``<type>(m(e,t),(v1,...,vk))`` with <type> in
        {network, pc, tabnet}. The network m is registered in self.networkTypes,
        self.n, self.e, self.domain and self.networkOutputs, and one rule per
        index i in range(e) is generated and recorded in self.mvpp.

        @param networkAtom: a string of a neural atom
        :return: the list of generated MVPP rules (strings); the list is empty
            if the domain has fewer than 2 values
        :raises ValueError: if networkAtom does not have the expected form
        """

        # STEP 1: obtain all information
        regex = r'^(network|pc|tabnet)\((.+)\((.+)\),\((.+)\)\)$'
        out = re.search(regex, networkAtom)
        if out is None:
            raise ValueError('Error: cannot parse the neural atom %s' % networkAtom)

        network_type = out.group(1)
        m = out.group(2)
        e, t = out.group(3).split(',', 1) # in case t is of the form t1,...,tk, we only split on the first comma
        domain = out.group(4).split(',')

        # a probabilistic choice needs at least two values; report it before
        # anything is registered for this network
        if len(domain) < 2:
            print('Error: the number of element in the domain %s is less than 2' % domain)
            return []

        self.networkTypes[m] = network_type

        t = self.constReplacement(t)
        e = int(self.constReplacement(e.strip()))

        # a domain of size 2 is implemented as n = 1: only one probability p is
        # read from the network, the second atom gets 1-p
        n = 1 if len(domain) == 2 else len(domain)
        self.n[m] = n
        self.e[m] = e
        self.domain[m] = domain

        # reserve a slot for the output of m on input t
        self.networkOutputs.setdefault(m, {}).setdefault(t, None)

        # STEP 2: generate MVPP rules, one per index i
        mvppRules = []
        for i in range(e):
            atoms = ['{}({}, {}, {})'.format(m, i, t, value) for value in domain]
            mvppRules.append('; '.join('@0.0 ' + atom for atom in atoms) + '.')
            # (m, i, t, j) addresses output j of network m for index i on input t
            self.mvpp['networkProb'].append([(m, i, t, j) for j in range(n)])
            self.mvpp['atom'].append(atoms)
            self.mvpp['networkPrRuleNum'] += 1
        return mvppRules


    def parse(self, query=''):
        """
        Parse the program (plus an optional query) into its MVPP and ASP parts.

        Records all '#const c=v.' definitions in self.const, stores the
        probabilities of normal prob. rules in self.normalProbs (only if none
        are stored yet), grounds the program with clingo and turns every
        grounded network/pc/tabnet atom into MVPP rules.

        @param query: a string that is appended to the program before parsing
        :return: a triple of strings (MVPP rules followed by the remaining
            program lines, the MVPP rules only, the remaining program lines only)
        """
        dprogram = self.dprogram + query

        # 1. Obtain all const definitions c for each rule #const c=v.
        #    (one definition per line is expected)
        for definition in re.finditer(r'#const\s+(.+?)=(.+)\.', dprogram):
            self.const[definition.group(1).strip()] = definition.group(2).strip()

        # 2. Generate prob. rules for grounded network atoms
        clingo_control = clingo.Control(["--warn=none"])

        # 2.1 remove weak constraints and comments
        program = re.sub(r'\n:~ .+\.[ \t]*\[.+\]', '\n', dprogram)
        program = re.sub(r'\n%[^\n]*', '\n', program)

        # 2.2 replace [] with ()
        program = program.replace('[', '(').replace(']', ')')

        # 2.3 use MVPP package to parse prob. rules and obtain ASP counter-part
        mvpp = MVPP(program)
        if mvpp.parameters and not self.normalProbs:
            self.normalProbs = mvpp.parameters
        pi_prime = mvpp.pi_prime

        # 2.4 use clingo to generate all grounded network atoms and turn them into prob. rules
        clingo_control.add("base", [], pi_prime)
        clingo_control.ground([("base", [])])
        symbols = [atom.symbol for atom in clingo_control.symbolic_atoms]

        mvppRules = []
        for symbol in symbols:
            if symbol.name in ('network', 'pc', 'tabnet'):
                mvppRules.extend(self.networkAtom2MVPPrules(str(symbol)))

        # 3. obtain the ASP part in the original NeurASP program, i.e. every
        #    non-empty line that does not start with a neural atom
        neural_atom_line = re.compile(r'^\s*(?:pc|network|tabnet)\(')
        lines = [line.strip() for line in dprogram.split('\n') if line and not neural_atom_line.match(line)]

        return '\n'.join(mvppRules + lines), '\n'.join(mvppRules), '\n'.join(lines)


    @staticmethod
    def satisfy(model, asp):
        """
        Check whether a model is consistent with an ASP program.

        Every atom of the model is added as a fact to the program; the result
        is True exactly if clingo reports the combined program as 'SAT'.

        @param model: a stable model in the form of a list of atoms, where each atom is a string
        @param asp: an ASP program (constraints) in the form of a string
        """
        facts = ''.join(atom + '.\n' for atom in model)
        asp_with_facts = asp + '\n' + facts

        control = clingo.Control(['--warn=none'])
        control.add('base', [], asp_with_facts)
        control.ground([('base', [])])
        return str(control.solve()) == 'SAT'

        
    def infer(self, dataDic, query='', mvpp='', postProcessing=None):
        """
        Return the result of dmvpp.find_one_most_probable_SM_under_query_noWC
        for the given data and query.

        @param dataDic: a dictionary that maps terms to tensors/np-arrays
        @param query: a string which is a set of constraints denoting a query
        @param mvpp: an MVPP program used in inference
        @param postProcessing: not used by this method
        """

        # Step 1: get the output of each neural network as a flat list
        for name, per_term in self.networkOutputs.items():
            net = self.networkMapping[name]
            net.eval()
            for term in per_term:
                entry = dataDic[term]
                # dataDic maps the term either to a tuple (dataTensor, {'m': labelTensor})
                # or to the dataTensor directly
                inputs = entry[0] if isinstance(entry, tuple) else entry
                per_term[term] = net(inputs).view(-1).tolist()

        # Step 2: turn the network outputs into a set of MVPP probabilistic rules
        rule_lines = []
        for rule_no in range(self.mvpp['networkPrRuleNum']):
            atoms = self.mvpp['atom'][rule_no]
            probs = [self.networkOutputs[m][t][i*self.n[m]+j] for (m, i, t, j) in self.mvpp['networkProb'][rule_no]]
            if len(probs) == 1:
                # binary rule: the second atom gets the complementary probability
                rule_lines.append('@{} {}; @{} {}.\n'.format(probs[0], atoms[0], 1 - probs[0], atoms[1]))
            else:
                weighted = ['@{} {}'.format(prob, atoms[pos]) for pos, prob in enumerate(probs)]
                rule_lines.append('; '.join(weighted) + '.\n')

        # Step 3: find an optimal SM under query
        dmvpp = MVPP(''.join(rule_lines) + mvpp)
        return dmvpp.find_one_most_probable_SM_under_query_noWC(query=query)


 
    def learn(self, dataList, queryList, epoch, alpha=0, lossFunc='cross', method='exact', lr=0.01, opt=False, storeSM=False, smPickle=None, accEpoch=0, batchSize=1, use_em=False, train_slot=False,  slot_net=None, p_num=1, marginalisation_masks=None):
        """
        Train the networks on pairs of data and queries.

        For every random batch: forward the data through every network, compute
        the gradients of the queries with compute_gradients_splitwise on p_num
        worker processes, and update the networks either with their optimizers
        or, if use_em is set, with the EM routines of the networks.

        @param dataList: a list of dictionaries, where each dictionary maps terms to either a tensor/np-array or a tuple (tensor/np-array, {'m': labelTensor});
            it is indexed with an index tensor, so presumably a numpy object array -- TODO confirm
        @param queryList: a list of strings, where each string is a set of constraints denoting query; indexed like dataList
        @param epoch: an integer denoting the number of epochs
        @param alpha: a real number between 0 and 1 denoting the weight of cross entropy loss; (1-alpha) is the weight of semantic loss
        @param lossFunc: a string in {'cross'} or a loss function object in pytorch
        @param method: a string in {'exact', 'slot', 'sampling', 'network_prediction', 'penalty'} selecting how the gradients are computed
        @param lr: a real number between 0 and 1 denoting the learning rate for the probabilities in probabilistic rules
        @param opt: stands for optimal -> if true we can select optimal stable models
        @param storeSM: boolean indicating to store or not store the stable models for later usage
        @param smPickle: path to store/load the stable models if storeSM is true
        @param accEpoch: if not 0, the training accuracy is reported every accEpoch batches
        @param batchSize: a positive integer denoting the batch size, i.e., how many data instances do we use to update the network parameters for once
        @param use_em: if true, the networks are updated with their EM routines instead of the optimizers
        @param train_slot: if true (and use_em is false), slot_net is run in training mode with gradients
        @param slot_net: optional slot attention module applied to data_batch['im'] before the networks
        @param p_num: a positive integer denoting the number of processor cores to be used during the training
        @param marginalisation_masks: a list entailing one marginalisation mask for each batch of dataList
        """

        assert len(dataList) == len(queryList), 'Error: the length of dataList does not equal to the length of queryList'
        assert alpha >= 0 and alpha <= 1, 'Error: the value of alpha should be within [0, 1]'
        assert p_num >= 1 and isinstance(p_num, int), 'Error: the number of processors used should greater equals one and a natural number'

        def unwrap(net):
            # the networks are only wrapped in DataParallel when gpu was requested
            return net.module if isinstance(net, nn.DataParallel) else net

        # if the pickle file for stable models is given, we will either read all stable models from it or
        # store all newly generated stable models in that pickle file in case the pickle file cannot be loaded
        # NOTE(security): pickle.load can execute arbitrary code; only pass trusted files as smPickle.
        # NOTE(review): storeSM and savePickle are not read again inside this method.
        savePickle = False
        if smPickle is not None:
            storeSM = True
            try:
                with open(smPickle, 'rb') as fp:
                    self.stableModels = pickle.load(fp)
            except Exception:
                savePickle = True

        # get the mvpp program by self.mvpp
        if method in ('network_prediction', 'penalty'):
            dmvpp = MVPP(self.mvpp['program_pr'])
        else:
            dmvpp = MVPP(self.mvpp['program'])

        # we train all neural network models; train() also reaches a wrapped module
        for m in self.networkMapping:
            self.networkMapping[m].train() #torch training mode

        forward_time = 0.0
        asp_time = 0.0
        backward_time = 0.0

        # we train for 'epoch' times of epochs. Learning for multiple epochs can also be done in an outer loop by specifying epoch = 1
        for epochIdx in range(epoch):

            #create random batches of the training data
            train_N = len(dataList)
            idx_batches = torch.randperm(train_N).split(batchSize)

            #iterate over batches; i is the batch counter, idx holds the indices of the batch
            for i, idx in tqdm(enumerate(idx_batches), total=len(idx_batches)):

                start_time = time.time()

                # reset the optimizers before anything is backpropagated for this
                # batch, so that the supervised loss of Step 1 is not discarded
                if not use_em:
                    for m in self.optimizers:
                        self.optimizers[m].zero_grad()

                #we have a list of hashmaps but we want a hashmap of lists such that we can forward the data batchwise.
                #generate batches of the form dict {im1 : im1_batch, im2: im2_batch}
                data_batch = {k: [dic[k] for dic in dataList[idx]] for k in dataList[idx][0]}
                query_batch = queryList[idx]

                # If we have marginalisation masks, than we have to pick one for the batch
                if marginalisation_masks is not None:
                    marg_mask = marginalisation_masks[i]
                else:
                    marg_mask = None

                #STEP 0: APPLY SLOT ATTENTION TO TRANSFORM IMAGE TO SLOTS
                #we have a map which is : im: im_data
                #we want a map which is : s1: slot1_data, s2: slot2_data, s3: slot3_data
                if slot_net is not None:
                    if use_em or train_slot == False:
                        slot_net.eval()
                        with torch.no_grad():
                            dataTensor_after_slot = slot_net(torch.cat(data_batch['im']).to(self.device)) #forward the image

                    #only train the slot module if train_slot is true and we dont use EM
                    else:
                        slot_net.train()
                        dataTensor_after_slot = slot_net(torch.cat(data_batch['im']).to(self.device)) #forward the image

                    #add the slot outputs to the data batch
                    for slot_num in range(slot_net.n_slots):
                        key = 's'+str(slot_num+1)
                        data_batch[key] = dataTensor_after_slot[:,slot_num,:]

                # Step 1: get the output of each neural network and initialize the gradients
                networkOutput = {}
                networkLLOutput = {}
                # sparsity losses, only filled for networks of type 'tabnet'
                networkSparsityLossOutputs = {}

                for m in self.networkOutputs:

                    networkOutput[m] = {}
                    networkLLOutput[m] = {}
                    if self.networkTypes[m] == 'tabnet':
                        networkSparsityLossOutputs[m] = {}

                    for t in self.networkOutputs[m]: #iterate over all inputs to be forwarded through the neural net

                        labelTensor = None
                        # if data maps t to tuple (dataTensor, {'m': labelTensor})
                        if isinstance(data_batch[t], tuple):
                            dataTensor = data_batch[t][0]
                            if m in data_batch[t][1]:
                                labelTensor = data_batch[t][1][m]
                        # if data maps t to dataTensor directly
                        else:
                            dataTensor = data_batch[t]

                        #we have a list of data elements but want a Tensor of the form [batchsize,...]
                        if isinstance(dataTensor, list):
                            dataTensor = torch.stack(dataTensor).squeeze(dim=1)

                        #one forward pass to get the outputs
                        if self.networkTypes[m] == 'pc':
                            networkOutput[m][t], networkLLOutput[m][t],_ = self.networkMapping[m].forward(dataTensor.to(self.device), ll_out=True, marg_idx=marg_mask)
                        elif self.networkTypes[m] == 'network':
                            networkOutput[m][t] = self.networkMapping[m].forward(dataTensor.to(self.device))
                        elif self.networkTypes[m] == 'tabnet':
                            networkOutput[m][t], networkSparsityLossOutputs[m][t] = self.networkMapping[m].forward(dataTensor.to(self.device), sparsity_loss=True)

                        #store the outputs of the neural networks as a class variable
                        self.networkOutputs[m][t] = networkOutput[m][t] #this is of shape [first batch entry, second batch entry,...]

                        # if alpha is greater than 0 and the labelTensor is given in dataList, we compute the network gradients
                        if alpha > 0 and labelTensor is not None:
                            if isinstance(lossFunc, str):
                                if lossFunc != 'cross':
                                    raise ValueError("Error: the loss function '%s' is not supported" % lossFunc)
                                criterion = torch.nn.NLLLoss()
                                loss = alpha * criterion(torch.log(networkOutput[m][t].view(-1, self.n[m])), labelTensor.long().view(-1))
                            else:
                                criterion = lossFunc
                                loss = alpha * criterion(networkOutput[m][t].view(-1, self.n[m]), labelTensor)
                            loss.backward(retain_graph=True)

                step1 = time.time()
                forward_time += step1 - start_time

                # Step 2: compute the gradients
                #partition the batch for the different processors; the last one also takes the remainder
                base_size, remainder = divmod(len(idx), p_num)
                partition = [base_size] * p_num
                partition[-1] += remainder

                query_batch_split = np.split(query_batch, np.cumsum(partition))[:-1]

                #nested structure {procces_id: {m: {t:[bs_split, num_classes_of_property]}}
                split_networkoutputs = {s: {m: {} for m in networkOutput} for s in range(p_num)}
                for m in networkOutput:
                    for t in networkOutput[m]:
                        # the workers only need detached values on the cpu
                        chunks = torch.split(networkOutput[m][t].detach().cpu(), partition, dim=0)
                        for sidx, chunk in enumerate(chunks):
                            split_networkoutputs[sidx][m][t] = chunk

                #run wmc function on p_num processors and collect all the gradients
                with Pool(p_num) as p:
                    gradient_batch_list_splits, model_batch_list_splits, prob_q_batch_list_splits= zip(*p.map(
                        compute_gradients_splitwise,
                        split_networkoutputs.values(), query_batch_split,
                            repeat(self.mvpp), repeat(self.n), repeat(self.normalProbs) ,repeat(alpha), repeat(dmvpp), repeat(method), repeat(opt)))

                #concatenate the splits obtained from each process
                gradient_batch_list = np.concatenate(gradient_batch_list_splits)
                model_batch_list = np.concatenate(model_batch_list_splits)
                prob_q_batch_list = np.concatenate(prob_q_batch_list_splits)

                #store the gradients, the stable models and p(Q) of the last batch processed
                self.networkGradients = gradient_batch_list
                self.stableModels = model_batch_list
                self.prob_q = prob_q_batch_list

                result = []

                # Step 3: update parameters in neural networks
                step2 = time.time()
                asp_time += step2 - step1

                #iterate over the batch
                #gradient_batch_list has entries of the form [batch_size, num_anotated_disjunctions, output_size of predicate]
                for bidx, grad_elem in enumerate(gradient_batch_list):

                    #iterate over all neural nets m and literals t
                    for midx, m in enumerate(networkOutput):
                        for tidx,t in enumerate(networkOutput[m]):

                            # position of the gradient that belongs to (m, t)
                            gidx = (midx * len(networkOutput[m]) + tidx)
                            gradients = torch.Tensor(grad_elem[gidx].astype(float)).to(self.device)

                            if self.networkTypes[m] == 'tabnet':
                                result.append((networkOutput[m][t][bidx] * gradients + networkSparsityLossOutputs[m][t] * torch.ones_like(gradients)))
                            else:
                                result.append((networkOutput[m][t][bidx] * gradients))

                #stack together all products, take the (negative) mean and then do one backward pass
                result_stacked = torch.cat(result, dim=0)
                if use_em:
                    result_ll = result_stacked.mean()
                    result_ll.backward()

                    for m in networkLLOutput:
                        unwrap(self.networkMapping[m]).em_process_batch()

                else:
                    result_nll = -result_stacked.mean()

                    #backward pass
                    result_nll.backward(retain_graph=True)

                    #apply gradients with each optimizer
                    for m in self.optimizers:
                        if self.networkTypes[m] == 'tabnet':
                            clip_grad_norm_(self.networkMapping[m].parameters(), 1)
                        self.optimizers[m].step()

                last_step = time.time()
                backward_time += last_step - step2

                # Step 4: if alpha is less than 1, we update probabilities in normal prob. rules
                # NOTE(review): `gradients` is whatever Step 3 assigned last (the tensor of
                # one literal for the last batch element), which looks unintended; it is
                # left unchanged here -- confirm the intended source of these gradients.
                if alpha < 1:
                    if self.normalProbs:
                        gradientsNormal = gradients[self.mvpp['networkPrRuleNum']:].tolist()
                        for ruleIdx, ruleGradients in enumerate(gradientsNormal):
                            ruleIdxMVPP = self.mvpp['networkPrRuleNum']+ruleIdx
                            for atomIdx, b in enumerate(dmvpp.learnable[ruleIdxMVPP]):
                                if b == True:
                                    dmvpp.parameters[ruleIdxMVPP][atomIdx] += lr * gradientsNormal[ruleIdx][atomIdx]
                        dmvpp.normalize_probs()
                        self.normalProbs = dmvpp.parameters[self.mvpp['networkPrRuleNum']:]

                # Step 5: show training accuracy every accEpoch batches
                if accEpoch != 0 and (i + 1) % accEpoch == 0:
                    print('Training accuracy at interation {}:'.format(i + 1))
                    self.testConstraint(dataList, queryList, [self.mvpp['program']])

            #em update, once per epoch for every registered network
            if use_em:
                for m in self.networkOutputs:
                    unwrap(self.networkMapping[m]).em_update()

            print("forward time: ", forward_time)
            print("asp time:", asp_time)
            print("backward time: ", backward_time)

            

    def testNetwork(self, network, testLoader, ret_confusion=False):
        """
        Return a real number in [0,100] denoting accuracy
        @network is the name of the neural network or probabilisitc circuit to check the accuracy. 
        @testLoader is the input and output pairs.
        """
        self.networkMapping[network].eval()
        # check if total prediction is correct
        correct = 0
        total = 0
        # check if each single prediction is correct
        singleCorrect = 0
        singleTotal = 0
        
        #list to collect targets and predictions for confusion matrix
        y_target = []
        y_pred = []
        with torch.no_grad():
            for data, target in testLoader:
                                
                output = self.networkMapping[network](data.to(self.device))
                if self.n[network] > 2 :
                    pred = output.argmax(dim=-1, keepdim=True) # get the index of the max log-probability
                    target = target.to(self.device).view_as(pred)
                    
                    correctionMatrix = (target.int() == pred.int()).view(target.shape[0], -1)
                    y_target = np.concatenate( (y_target, target.int().flatten().cpu() ))
                    y_pred = np.concatenate( (y_pred , pred.int().flatten().cpu()) )
                    
                    
                    correct += correctionMatrix.all(1).sum().item()
                    total += target.shape[0]
                    singleCorrect += correctionMatrix.sum().item()
                    singleTotal += target.numel()
                else: 
                    pred = np.array([int(i[0]<0.5) for i in output.tolist()])
                    target = target.numpy()
                    
                    #y_target.append(target)
                    #y_pred.append(pred.int())
                    
                    correct += (pred.reshape(target.shape) == target).sum()
                    total += len(pred)
        accuracy = correct / total

        if self.n[network] > 2:
            singleAccuracy = singleCorrect / singleTotal
        else:
            singleAccuracy = 0
        
        #print(correct,"/", total, "=", correct/total)
        #print(singleCorrect,"/", singleTotal, "=", singleCorrect/singleTotal)

        
        if ret_confusion:
            confusionMatrix = confusion_matrix(np.array(y_target), np.array(y_pred))
            return accuracy, singleAccuracy, confusionMatrix

        return accuracy, singleAccuracy
    
    # We interpret the most probable stable model(s) as the prediction of the inference mode
    # and check the accuracy of the inference mode by checking whether the query is satisfied by the prediction
    def testInferenceResults(self, dataList, queryList):
        """ Return a real number in [0,1] denoting the accuracy
        @param dataList: a list of dictionaries, where each dictionary maps terms to tensors/np-arrays
        @param queryList: a list of strings, where each string is a set of constraints denoting a query
        """
        assert len(dataList) == len(queryList), 'Error: the length of dataList does not equal to the length of queryList'

        correct = 0
        for dataIdx, data in enumerate(dataList):
            models = self.infer(data, query=':- mistake.', mvpp=self.mvpp['program_asp'])
            for model in models:
                if self.satisfy(model, queryList[dataIdx]):
                    correct += 1
                    break
        accuracy = 100. * correct / len(dataList)
        return accuracy


    def testConstraint(self, dataList, queryList, mvppList):
        """
        @param dataList: a list of dictionaries, where each dictionary maps terms to tensors/np-arrays
        @param queryList: a list of strings, where each string is a set of constraints denoting a query
        @param mvppList: a list of MVPP programs (each is a string)
        """
        assert len(dataList) == len(queryList), 'Error: the length of dataList does not equal to the length of queryList'

        # we evaluate all nerual networks
        for func in self.networkMapping:
            self.networkMapping[func].eval()

        # we count the correct prediction for each mvpp program
        count = [0]*len(mvppList)

        for dataIdx, data in enumerate(dataList):
            # data is a dictionary. we need to edit its key if the key contains a defined const c
            # where c is defined in rule #const c=v.
            for key in data:
                data[self.constReplacement(key)] = data.pop(key)

            # Step 1: get the output of each neural network
            for model in self.networkOutputs:
                for t in self.networkOutputs[model]:
                    self.networkOutputs[model][t] = self.networkMapping[model](data[t].to(self.device)).view(-1).tolist()

            # Step 2: turn the network outputs into a set of ASP facts
            aspFacts = ''
            for ruleIdx in range(self.mvpp['networkPrRuleNum']):
                probs = [self.networkOutputs[m][t][i*self.n[model]+j] for (m, i, t, j) in self.mvpp['networkProb'][ruleIdx]]
                if len(probs) == 1:
                    atomIdx = int(probs[0] < 0.5) # t is of index 0 and f is of index 1
                else:
                    atomIdx = probs.index(max(probs))
                aspFacts += self.mvpp['atom'][ruleIdx][atomIdx] + '.\n'

            # Step 3: check whether each MVPP program is satisfied
            for programIdx, program in enumerate(mvppList):
                # if the program has weak constraints
                if re.search(r':~.+\.[ \t]*\[.+\]', program) or re.search(r':~.+\.[ \t]*\[.+\]', queryList[dataIdx]):
                    choiceRules = ''
                    for ruleIdx in range(self.mvpp['networkPrRuleNum']):
                        choiceRules += '1{' + '; '.join(self.mvpp['atom'][ruleIdx]) + '}1.\n'
                    mvpp = MVPP(program+choiceRules)
                    models = mvpp.find_all_opt_SM_under_query_WC(query=queryList[dataIdx])
                    models = [set(model) for model in models] # each model is a set of atoms
                    targetAtoms = aspFacts.split('.\n')
                    targetAtoms = set([atom.strip().replace(' ','') for atom in targetAtoms if atom.strip()])
                    if any(targetAtoms.issubset(model) for model in models):
                        count[programIdx] += 1
                else:
                    mvpp = MVPP(aspFacts + program)
                    if mvpp.find_one_SM_under_query(query=queryList[dataIdx]):
                        count[programIdx] += 1
        for programIdx, program in enumerate(mvppList):
            print('The accuracy for constraint {} is {}'.format(programIdx+1, float(count[programIdx])/len(dataList)))

            
            
    
    def forward_slot_attention_pipeline(self, slot_net, data_batch):
        """
        Makes one forward pass through the slot attention pipeline to obtain the probabilities/log likelihoods for all classes for each object.
        The pipeline includes the SlotAttention module followed by probabilistic circuits for the probabilities of the discrete properties.
        No gradients are recorded during the pass.

        @param slot_net: The SlotAttention module
        @param data_batch: The data batch to be forwarded: a list of dictionaries or a single dictionary,
                           where each dictionary holds the image tensor under the key 'im'
        :return: (probabilities, ll), two dictionaries indexed by slot name ("s0", "s1", ...) and then by
                 network name, holding the first (posterior) and second (log likelihood) output of each network
        """
        with torch.no_grad():

            # a single sample may be handed over as a bare dictionary; a list is always taken
            # as a batch, also if it has only one element
            if isinstance(data_batch, dict):
                data_batch = [data_batch]

            #transform list of hashmaps to hashmap of lists for batchwise forwarding
            data_batch = {k: [dic[k] for dic in data_batch] for k in data_batch[0]}

            probabilities = {}  # map to store all output probabilities(posterior)
            ll = {} #map to store all log likelihoods

            #forward img to get slots
            dataTensor_after_slot = slot_net(torch.cat(data_batch['im']).to(self.device))

            #dataTensor_after_slot has shape [bs, num_slots, slot_vector_length]
            _, num_slots, _ = dataTensor_after_slot.shape

            #iterate over all slots and forward them through all nets (shape + color + ... )
            for sdx in range(num_slots):
                key = "s"+str(sdx)
                slot = dataTensor_after_slot[:, sdx, :]

                probabilities[key] = {}
                ll[key] = {}

                for network in self.networkMapping:
                    posterior, out, _ = self.networkMapping[network].forward(slot, ll_out=True)
                    probabilities[key][network] = posterior
                    ll[key][network] = out

        return probabilities, ll
