from typing import Dict, List, Tuple#, Set
from pysdd.sdd import SddNode
from problog.sdd_formula import SDDManager
import torch
#import sys
#from cmd_args import cmd_args, CIRCUIT_APPROXIMATIONS, TWOSTAGE_TECHNIQUES
# Controls whether evaluate_SDD builds its constant tensors on the GPU.
# Derived from the runtime rather than hard-coded to True, so the module also
# works on machines without CUDA (where every .cuda() call would raise).
useGPU = torch.cuda.is_available()
#from utils import get_weight
#import os

def evaluate_disjunction_tensor(lineage:torch.tensor, softmax_predictions:torch.tensor)->torch.tensor:
    """Vectorised disjunction probability ``1 - prod(1 - p_i)``.

    ``p_i`` ranges over the entries of ``softmax_predictions`` at every
    position where ``lineage`` is nonzero (found with ``torch.nonzero``).
    When no position is selected the product is empty (== 1), so the result
    is 0.

    Returns:
        A 0-dimensional tensor (``prod`` without ``dim`` reduces everything).
    """
    selected = softmax_predictions[torch.nonzero(lineage, as_tuple=False)]
    # Probability that none of the selected entries holds.
    none_hold = (1 - selected).prod()
    return 1 - none_hold

def evaluate_disjunction(lineage:torch.tensor, softmax_predictions:torch.tensor)->torch.tensor:
    """Loop-based disjunction probability ``1 - prod(1 - p_clas)``.

    Args:
        lineage: per-class labels; every position ``clas`` whose label equals
            1 contributes ``softmax_predictions[clas]`` to the disjunction.
        softmax_predictions: per-class probabilities (assumed 1-D, one entry
            per class — TODO confirm against callers).

    Returns:
        A shape-``[1]`` tensor on the same device and with the same dtype as
        ``softmax_predictions``; 0 when no label equals 1.
    """
    # Build the constant where the predictions already live instead of calling
    # .cuda() unconditionally (which crashed on CPU-only machines and copied a
    # fresh CPU tensor to the GPU each call). No requires_grad: gradients flow
    # through softmax_predictions, not through this constant.
    one = torch.ones(1, dtype=softmax_predictions.dtype, device=softmax_predictions.device)
    probability = one
    for clas, label in enumerate(lineage):
        if label == 1:
            probability = probability * (one - softmax_predictions[clas])
    return one - probability

def get_SDD_literal_to_weights_mapping(literaldict:Dict[str,int], softmax_predictions:torch.tensor) -> Dict[int,torch.tensor]:
    """Assign a weight to both polarities of every SDD literal.

    ``literaldict`` maps atom names of the form ``"C<class index>"`` to
    positive literal indices. For each atom the class index is parsed from the
    name (everything after the first character), the positive literal gets
    ``softmax_predictions[class]`` and its negation gets ``1 - that``.

    Returns:
        dict from signed literal index to weight tensor, inserted as
        ``lit, -lit`` pairs in ``literaldict`` order.
    """
    mapping = {}
    for atom_name, lit in literaldict.items():
        p = softmax_predictions[int(atom_name[1:])]
        mapping.update({lit: p, -lit: 1 - p})
    return mapping

def create_SDD(lineage:List[int])->Tuple[Dict[str,int],SddNode]:
    """Build an SDD for the disjunction of the classes marked in ``lineage``.

    Every position ``clas`` with ``lineage[clas] == 1`` becomes an atom named
    ``"C<clas>"`` bound to a fresh SDD literal. Literal indices are assigned
    consecutively from 1 in lineage order.

    Returns:
        ``(literaldict, formula)`` where ``literaldict`` maps atom name to
        literal index, and ``formula`` is the disjunction of those literals.
        When nothing is marked, ``formula`` is the SDD ``false`` node (an
        empty disjunction) rather than ``None``.
    """
    manager = SDDManager()
    literaldict = {}
    literal_nodes = []
    # Create all literals first (as before), then combine them.
    for clas, label in enumerate(lineage):
        if label == 1:
            index = len(literaldict) + 1
            literaldict["C" + str(clas)] = index
            literal_nodes.append(manager.literal(index))

    # false is the identity of disjunction (false OR x == x), so a non-empty
    # lineage yields the same formula, and an empty one yields a real node that
    # evaluate_SDD can handle instead of None.
    formula = manager.false()
    for node in literal_nodes:
        formula = formula.disjoin(node)

    return literaldict, formula
        
def evaluate_SDD(node:SddNode, weights:Dict[int,torch.tensor])->torch.tensor:
    """Evaluate an SDD bottom-up as a differentiable arithmetic circuit.

    Node values:
      * literal        -> ``weights[node.literal]`` (keyed by the signed literal)
      * true / false   -> constant 1 / 0
      * decision node  -> sum over its (prime, sub) elements of
                          value(prime) * value(sub)

    With weights where ``w(l) + w(-l) == 1`` this is the probability of the
    formula under independent literals. Evaluation uses an explicit stack (no
    recursion) and memoises each node, so shared sub-circuits are computed
    once.

    Args:
        node: root of the SDD.
        weights: weight tensor for every (signed) literal reachable from ``node``.

    Returns:
        The tensor for ``node``. Decision nodes accumulate onto a 1-element zero
        tensor, so their result has at least one dimension; a root that is a
        bare literal returns its weight tensor unchanged.

    Raises:
        KeyError: a literal in the SDD has no entry in ``weights``.
        ValueError: a node is none of decision / literal / true / false.
    """
    device = torch.device("cuda") if useGPU else torch.device("cpu")
    # Built once per call directly on the target device. Creating a CPU tensor
    # and calling .cuda() for every decision node costs a host-to-device copy
    # per node. No requires_grad: gradients flow through ``weights``.
    zero = torch.zeros(1, device=device)
    one = torch.ones(1, device=device)

    nodesToTensors = dict()
    stack = [node]

    while stack:
        top = stack[-1]
        if top in nodesToTensors:
            # A node shared by several parents may be pushed more than once.
            stack.pop()
            continue

        if top.is_decision():
            # Post-order: push every unevaluated prime/sub and revisit this
            # node (still on the stack underneath them) once they are done.
            pending = [
                child
                for prime, sub in top.elements()
                for child in (prime, sub)
                if child not in nodesToTensors
            ]
            if pending:
                stack.extend(pending)
                continue
            result = zero
            for prime, sub in top.elements():
                result = result + nodesToTensors[prime] * nodesToTensors[sub]
            nodesToTensors[top] = result
        elif top.is_literal():
            nodesToTensors[top] = weights[top.literal]
        elif top.is_false():
            nodesToTensors[top] = zero
        elif top.is_true():
            nodesToTensors[top] = one
        else:
            # Previously such a node was never popped, looping forever.
            raise ValueError("unsupported SDD node kind: %r" % (top,))
        stack.pop()

    return nodesToTensors[node]
            