from .semantic_loss.py3psdd import PSddManager
from .semantic_loss import SemanticLoss
from pysdd.sdd import Vtree, SddManager
from typing import List
import torch
import os
import time

class WMC():
    """ Compile two-variable logical constraints into SDDs (cached on disk) and
    evaluate the semantic loss of predicted probabilities against them.

    Supported constraint names are the keys of ``CONSTRAINT2FUNCTION``:
    "implication", "negation", "inverse_implication" and "all".
    """

    def __init__(self, eps, n_variables=2):
        """
            eps:          constant added to every probability before the loss is computed
            n_variables:  number of SDD variables; overwritten by ``get_psdd`` whenever it
                          compiles a formula (4 for constraints needing negation variables, else 2)
        """
        self.n_variables = n_variables
        # Dispatch table: constraint name -> formula builder
        self.CONSTRAINT2FUNCTION = {
            "implication": self.implication_constraint,
            "negation": self.negation_constraint,
            "inverse_implication": self.inverse_implication_constraint,
            "all": self.all_constraints
        }
        self.eps = eps

    @staticmethod
    def need_negations(constraint):
        """ Checks constraint type to split cases by design: these constraints use
        two extra SDD variables (not_a, not_b) and two extra probabilities. """
        return constraint in ["negation", "all"]

    @staticmethod
    def _as_list(values):
        """ Return ``values`` as a plain Python list (accepts tensors and sequences). """
        return values.tolist() if hasattr(values, "tolist") else list(values)

    def implication_constraint(self, l1:object, l2:object):
        """ Given sdd literals, return the formula l1 -> l2, i.e. (not l1) or l2 """
        return (-l1 | l2)

    def inverse_implication_constraint(self, l1:object, l2:object):
        """ Given sdd literals, return the formula l2 -> l1, i.e. (not l2) or l1 """
        return (-l2 | l1)

    def negation_constraint(self, l1:object, l2:object, l1_not:object, l2_not:object):
        """ Given sdd literals, return the formula requiring exactly one of (l1, l1_not)
        and exactly one of (l2, l2_not) to be true """
        return ((l1 | l1_not) & (-(l1 & l1_not))) & ((l2 | l2_not) & (-(l2 & l2_not)))

    def all_constraints(self, l1:object, l2:object, l1_not:object, l2_not:object):
        """ Conjunction of implication, negation and inverse implication
        (i.e. l1 <-> l2 together with the negation constraint) """
        return (
            self.implication_constraint(l1, l2) &
            self.negation_constraint(l1, l2, l1_not, l2_not) &
            self.inverse_implication_constraint(l1, l2)
        )

    def get_psdd(self, literals:List[int], ground_labels:List[int], constraint:str):
        """ Compile the constraint formula for the solver to compute the valid worlds.

        The compiled SDD and its vtree are saved under ./cache and reused by later calls
        with the same constraint, literals and ground labels.

            literals:       2 ints (list or tensor)        sign of each constraint literal; 0 negates it
            ground_labels:  2 ints (list or tensor) or None
                            ground truth of the two raw variables: 1 forces true, 0 forces false,
                            any other value (e.g. -1, the default for None) leaves it free
            constraint:     str                            key of CONSTRAINT2FUNCTION

        Returns (sdd_path, vtree_path).
        Raises KeyError if ``constraint`` is unknown and has to be compiled.
        """
        lits = self._as_list(literals)
        grounds = [-1, -1] if ground_labels is None else self._as_list(ground_labels)
        needs_neg = WMC.need_negations(constraint)
        os.makedirs("cache", exist_ok=True)
        # The constraint is part of the key: the same literals/grounding compile to a
        # different formula (and possibly a different variable count) per constraint.
        cachename = "_".join([constraint, *map(str, lits), *map(str, grounds)])
        sdd_cachename = os.path.join("cache", f"{cachename}.sdd")
        vtree_cachename = os.path.join("cache", f"{cachename}.vtree")
        if os.path.isfile(sdd_cachename) and os.path.isfile(vtree_cachename):
            return sdd_cachename, vtree_cachename

        # Setup SDD and variables
        self.n_variables = 4 if needs_neg else 2
        a_symbol, b_symbol = lits
        vtree = Vtree(var_count=self.n_variables, var_order=list(range(1, self.n_variables + 1)), vtree_type="balanced")
        sdd = SddManager.from_vtree(vtree)
        sdd_vars = list(sdd.vars)
        var_a, var_b = sdd_vars[0], sdd_vars[1]
        # A symbol of 0 means the constraint refers to the negated literal
        a = -var_a if a_symbol == 0 else var_a
        b = -var_b if b_symbol == 0 else var_b
        build = self.CONSTRAINT2FUNCTION[constraint]
        if needs_neg:
            formula = build(l1=a, l2=b, l1_not=sdd_vars[2], l2_not=sdd_vars[3])
        else:
            formula = build(l1=a, l2=b)
        # Ground labels fix the truth value of the raw variables, independently of literal signs
        if grounds:
            ant, cons = grounds
            for label, var in ((ant, var_a), (cons, var_b)):
                if label == 1:
                    formula = formula & var
                elif label == 0:
                    formula = formula & -var
        sdd.save(sdd_cachename.encode('utf-8'), formula)
        vtree.save(vtree_cachename.encode('utf-8'))
        return sdd_cachename, vtree_cachename

    def sl(self, p1:object, p2:object, batch_symbols:List[List[int]], batch_facts:List[List[int]], constraint:str, p1_not:object=None, p2_not:object=None):
        """ Compute the summed semantic loss of a batch from probabilities, constraint
        literal signs and constraint type.

            p1, p2:          tensors concatenated along dim=1 into per-sample probability rows
                             (presumably (B, 1) each — TODO confirm against callers)
            batch_symbols:   per-sample literal signs, passed to get_psdd as ``literals``
            batch_facts:     per-sample ground labels, or None for no grounding
            constraint:      str, key of CONSTRAINT2FUNCTION
            p1_not, p2_not:  extra probability tensors, required when the constraint
                             needs negations ("negation", "all")

        Returns the sum of the per-sample losses (a zero tensor for an empty batch).
        Raises ValueError if negation probabilities are required but missing.
        """
        if WMC.need_negations(constraint):
            if p1_not is None or p2_not is None:
                raise ValueError("Invalid not_A or not_B probability tensor.")
            # ordering must match the SDD variables: a, b, not_a, not_b
            p_batch = torch.concat((p1, p2, p1_not, p2_not), dim=1)
        else:
            p_batch = torch.concat((p1, p2), dim=1)
        losses = []
        for idx in range(p_batch.shape[0]):
            s = batch_symbols[idx]
            g = None if batch_facts is None else batch_facts[idx]
            sdd_filepath, vtree_filepath = self.get_psdd(literals=s, ground_labels=g, constraint=constraint)
            # eps keeps probabilities away from exactly 0
            p = p_batch[idx].unsqueeze(0).unsqueeze(-1) + self.eps
            loss_fn = SemanticLoss(sdd_filepath, vtree_filepath)
            losses.append(loss_fn(probabilities=p).unsqueeze(0))
        if not losses:
            return p_batch.new_zeros(())
        # single concatenation instead of growing a tensor inside the loop
        return torch.cat(losses, dim=0).sum()