from abc import ABC, abstractmethod 
from utils.logic.nli import NLI
from utils.logic.fol import *
import re, copy, statistics
import heapq

def convert_camel_case_to_spaces(s):
    """Insert a space at every lowercase-to-uppercase boundary in ``s``.

    Only ASCII ``a-z`` followed by ``A-Z`` counts as a boundary, so runs of
    capitals (``"HTTPServer"``) are left untouched.
    """
    # Zero-width lookarounds mark the gap between the two letters, so the
    # letters themselves never have to be captured and re-emitted.
    boundary = r'(?<=[a-z])(?=[A-Z])'
    return re.sub(boundary, ' ', s)



class Prioritizer(ABC):
    """Abstract base class for callables that score a premise against a hypothesis.

    Subclasses implement ``__call__``. The two helpers convert between formula
    objects and the plain-text predicate phrases derived from them.
    """

    def __init__(self):
        """Create the prioritizer; the base class keeps no state."""

    @abstractmethod
    def __call__(self, recipe_premise, query_hypothesis):
        """Score ``recipe_premise`` against ``query_hypothesis`` (subclass hook)."""

    def _grammar2str(self, grammar):
        """Return one readable phrase per literal yielded by ``flatten_and(grammar)``.

        A ``Not`` literal becomes ``"not <name>"``; any other literal
        contributes its ``name``. Underscores are then turned into spaces and
        camelCase boundaries are split.
        """
        phrases = []
        for literal in flatten_and(grammar):
            if isinstance(literal, Not):
                raw_name = "not " + str(literal.arg.name)
            else:
                raw_name = str(literal.name)
            spaced = raw_name.replace("_", " ")
            phrases.append(convert_camel_case_to_spaces(spaced))
        return phrases

    def _str2grammar(self, list_str):
        """Build a conjunction of unary predicates over ``x`` from predicate names.

        Each string is used verbatim as the predicate name. A leading ``~`` is
        NOT interpreted as negation here: negations are left to the LLM.
        """
        conjuncts = []
        for predicate_name in list_str:
            conjuncts.append(Predicate(predicate_name, Variable('x')))
        return and_list(conjuncts)

class Pure_prioritizer(Prioritizer):
    """Prioritizer that scores a single premise/goal pair directly with NLI."""

    def __init__(self):
        """Create the prioritizer; no state is initialised."""

    def __call__(self, premise, goal):
        """
        Calculate the score of a clause based on its match or entailment with the goal.

        When ``premise`` is a string, both arguments are passed to the NLI
        model unchanged. Otherwise both are treated as formulas and only the
        first phrase produced by ``_grammar2str`` is used for each.

        A new ``NLI`` instance is created on every call.
        """
        if isinstance(premise, str):
            premise_text, goal_text = premise, goal
        else:
            premise_text = self._grammar2str(premise)[0]
            goal_text = self._grammar2str(goal)[0]
        scorer = NLI()
        return scorer(premise_text, goal_text)



class GD_prioritizer(Prioritizer):
    def __init__(self):
        pass
    
    def __call__(self, kb, goal, input_theta, k=2, prioritization='nli'):
        """
        Score every clause of ``kb`` against ``goal`` and keep the ``k`` best.

        Args:
            kb: iterable of clauses handed one by one to ``_score_clause``.
            goal: the goal literal every clause is scored against.
            input_theta: accepted for interface compatibility but not used;
                each clause is scored starting from a fresh empty substitution.
            k: number of top-ranked clauses to return.
            prioritization: recorded on ``self.prioritization``; this method
                does not otherwise read it.

        Returns:
            The three parallel lists produced by ``_get_topk_clauses``:
            score tuples, the original (unsubstituted) clauses, and thetas.

        Side effects:
            Sets ``self.prioritization`` and replaces ``self.nli`` with a new
            ``NLI`` instance on every call.
        """
        self.prioritization = prioritization
        self.nli = NLI()

        usable = []
        for candidate in kb:
            outcome, bindings = self._score_clause(candidate, goal, {})
            # a leading None marks a clause that cannot be used for this goal
            if outcome[0] is not None:
                usable.append((outcome, candidate, bindings))

        clause_scores = [entry[0] for entry in usable]
        original_clauses = [entry[1] for entry in usable]
        clause_thetas = [entry[2] for entry in usable]
        return self._get_topk_clauses(clause_scores, original_clauses, clause_thetas, k)
    
    def _get_topk_clauses(self, clause_scores, original_clauses, clause_thetas, k):
        scores_list = []
        for score in clause_scores:
            scores_list.append(score[1])
        sorted_tuples = sorted(enumerate(scores_list), key=lambda x: (x[1][0], x[1][1]), reverse=True)
        top_k_indices = [index for index, _ in sorted_tuples[:k]]

        clause_scores_topk = []; original_clauses_topk = []; clause_thetas_topk = []
        for ind in top_k_indices:
            clause_scores_topk.append(clause_scores[ind])
            original_clauses_topk.append(original_clauses[ind])
            clause_thetas_topk.append(clause_thetas[ind])

        return clause_scores_topk, original_clauses_topk, clause_thetas_topk
        

    
    
    def _score_clause(self, clause, goal, theta):
        """
        Calculate the score of a clause based on its match or entailment with the goal.

        ``goal`` must be a ``Predicate`` or a ``Not`` wrapping one. ``clause``
        may be a ``Predicate``, a ``Not`` wrapping a predicate, or a ``ForAll``
        whose ``body.rhs`` is one of those two. Any other clause shape is
        reported as "no match" rather than implicitly returning ``None``.

        Returns:
            ``(result, theta)`` where ``result`` is
            ``(clause, (type_score, pred_score), child_types)`` or
            ``(None, None, None)`` when the clause cannot be used.

            * ``type_score`` and ``child_types`` come from ``get_types_score``.
            * ``pred_score`` is 2 / -2 for a literal carrying the goal's
              predicate name (same / opposite polarity), 1 / -1 for a rule
              whose head carries it, and otherwise the NLI score of the clause
              text against the goal text.
            * ``theta`` is returned unchanged when the clause is rejected
              before unification, is ``{}`` when unification fails, and is the
              unifier otherwise.

        Raises:
            TypeError: if ``goal`` is neither a ``Predicate`` nor a ``Not``.
        """
        if isinstance(goal, Not):
            unsigned_goal, goal_negated = goal.arg, True
        elif isinstance(goal, Predicate):
            unsigned_goal, goal_negated = goal, False
        else:
            raise TypeError(
                f"goal must be a Predicate or a Not, got {type(goal).__name__}"
            )
        goal_str = self._grammar2str(goal)[0]

        if isinstance(clause, Predicate):
            return self._score_literal(
                clause, clause, False, unsigned_goal, goal_negated, goal_str, theta
            )
        if isinstance(clause, Not):
            return self._score_literal(
                clause, clause.arg, True, unsigned_goal, goal_negated, goal_str, theta
            )
        if isinstance(clause, ForAll):
            return self._score_rule(clause, unsigned_goal, goal_negated, goal_str, theta)
        # Unsupported clause shape: report "no match" so the caller can skip it.
        return (None, None, None), theta

    def _constant_args_differ(self, clause_args, goal_args):
        """Return True when both argument lists are all constants and differ somewhere."""
        both_ground = (
            all(isinstance(arg, Constant) for arg in clause_args)
            and all(isinstance(arg, Constant) for arg in goal_args)
        )
        if not both_ground:
            return False
        return any(
            str(clause_arg) != str(goal_arg)
            for clause_arg, goal_arg in zip(clause_args, goal_args)
        )

    def _score_literal(self, clause, atom, clause_negated, unsigned_goal,
                       goal_negated, goal_str, theta):
        """Score a (possibly negated) predicate ``clause``; ``atom`` is its unsigned predicate."""
        no_match = (None, None, None)

        # two literals with different number of arguments cannot be unified
        if len(atom.args) != len(unsigned_goal.args):
            return no_match, theta

        # two fully ground literals must agree on every argument
        if self._constant_args_differ(atom.args, unsigned_goal.args):
            return no_match, theta

        type_score, child_types = self.get_types_score(atom, unsigned_goal)

        if str(atom.name) != str(unsigned_goal.name):
            # Different predicate names: fall back to NLI, without unifying.
            if clause_negated:
                clause_text = self._grammar2str(self.negate_preds(atom))[0]
            else:
                # NOTE(review): the raw predicate name is used here, while every
                # other branch humanises the text via _grammar2str. Kept as-is
                # so existing scores do not shift; confirm whether intended.
                clause_text = str(atom.name)
            pred_score = self.nli(clause_text, goal_str)
            return (clause, (type_score, pred_score), child_types), theta

        theta = unify(atom.args, unsigned_goal.args, theta)
        if theta is None:
            return no_match, {}
        child_types = self.update_child_types(child_types, theta)
        clause = substitute(theta, clause)
        # Same predicate name: the sign encodes agreement of polarity.
        exact_score = 2 if clause_negated == goal_negated else -2
        return (clause, (type_score, exact_score), child_types), theta

    def _score_rule(self, clause, unsigned_goal, goal_negated, goal_str, theta):
        """Score a ``ForAll`` rule by matching its head (``body.rhs``) against the goal."""
        no_match = (None, None, None)

        head = clause.body.rhs
        if isinstance(head, Predicate):
            head_atom, head_negated = head, False
        elif isinstance(head, Not):
            head_atom, head_negated = head.arg, True
        else:
            return no_match, theta

        if len(head_atom.args) != len(unsigned_goal.args):
            return no_match, theta

        type_score, child_types = self.get_types_score(head_atom, unsigned_goal)

        # unify() receives the same kind of argument lists whatever the polarity
        # of head and goal, so an exception it raises is treated as "no unifier"
        # for every rule (the earlier code guarded only one polarity combination).
        try:
            theta = unify(head_atom.args, unsigned_goal.args, theta)
        except Exception:
            theta = None
        if theta is None:
            return no_match, {}
        child_types = self.update_child_types(child_types, theta)
        clause = substitute(theta, clause)

        # Re-read the head from the substituted rule.
        head = clause.body.rhs
        head_atom = head.arg if head_negated else head
        if str(head_atom.name) == str(unsigned_goal.name):
            pred_score = 1 if head_negated == goal_negated else -1
        elif head_negated:
            negated_rhs = self._grammar2str(self.negate_preds(head_atom))[0]
            pred_score = self.nli(negated_rhs, goal_str)
        else:
            rhs_str = self._grammar2str(head)[0]
            pred_score = self.nli(rhs_str, goal_str)
        return (clause, (type_score, pred_score), child_types), theta

                

    def update_child_types(self, child_types, theta):
        if theta is None:
            return child_types
        for key, value in child_types.items():
            if key in theta:
                child_types[key] = theta[key][1]
        return child_types


    def get_types_score(self,clause_rhs, goal):

        clause_scores = []
        child_types = {}
        goal_types = [arg.type for arg in goal.args]
        goal_args = [arg for arg in goal.args]
        for i, arg in enumerate(list(clause_rhs.args)):
            clause_rhs_type = arg.type
            if clause_rhs_type == goal_types[i]:
                clause_scores.append(1)
                child_types[arg.name] = goal_types[i]
            else:
                # the score of cluase type entailing the goal type
                score_1 = self.nli(clause_rhs_type, goal_types[i])
                # the score of goal type entailing the clause type
                score_2 = self.nli(goal_types[i], clause_rhs_type)
                if score_1 > score_2:
                    clause_scores.append(score_1)
                    child_types[arg.name] = clause_rhs_type
                else:
                    clause_scores.append(score_2)
                    if isinstance(arg, Variable) and isinstance(goal_args[i], Variable):
                        child_types[arg.name] = goal_types[i]
                
        type_score = statistics.mean(clause_scores)
        return type_score, child_types

            
    
    def negate_preds(self, literal):
        """
        Return the negation of a literal.

        A ``Predicate`` is wrapped in ``Not``; a ``Not`` is unwrapped to its
        ``arg``.

        Raises:
            ValueError: if ``literal`` is neither a ``Predicate`` nor a ``Not``.
        """
        if isinstance(literal, Predicate):
            return Not(literal)
        if isinstance(literal, Not):
            return literal.arg
        raise ValueError("literal is neither a Predicate nor a Not")