# main file for our agent
import os, sys, json, re, random, copy
random.seed(98)
from abc import ABC, abstractmethod 
import numpy as np
from agent.llm.llm_actions import *
from utils.logic.fol import *
from utils.logic.nli import NLI
from utils.logger import logger
from utils.logic.queue import PriorityQueue, Node
from utils.wikidata_types import *

class Agent(ABC):
    """Abstract base for agents that answer a natural-language query via FOL.

    Provides LLM-backed helpers that translate the query into an FOL predicate
    and fetch a typed axiom for it. Subclasses implement ``__call__`` and
    ``prove``.
    """

    # Prompt sub-directory under agent/llm/llm_prompts/ for each supported LLM.
    _PROMPT_STYLES = {
        'gpt3.5': 'chat',
        'gemini': 'completion',
    }

    def __init__(self):
        pass

    @abstractmethod
    def __call__(self):
        pass

    @classmethod
    def _prompts_dir(cls, llm_name, template):
        """Return the prompt YAML path for ``template`` in ``llm_name``'s prompt style.

        Args:
            llm_name: LLM backend key, ``'gpt3.5'`` or ``'gemini'``.
            template: prompt file stem, e.g. ``'Query2FOL'``.

        Raises:
            ValueError: if ``llm_name`` has no configured prompt style.
        """
        style = cls._PROMPT_STYLES.get(llm_name)
        if style is None:
            raise ValueError(
                f"unsupported llm_name {llm_name!r}; "
                f"expected one of {sorted(cls._PROMPT_STYLES)}"
            )
        return f'agent/llm/llm_prompts/{style}/{template}.yaml'

    def _query2fol(self, question, llm_name, log):
        """Translate ``question`` into an FOL predicate using the LLM.

        The LLM output is expected to look like ``name(arg1, arg2, ...)``. The
        predicate name is the text before the first ``(``. The arguments are the
        text after it, up to the next ``)`` (or ``(``), split on commas. Every
        argument becomes a ``Constant``; nested terms are not supported.

        Args:
            question: natural-language query.
            llm_name: LLM backend, ``'gpt3.5'`` or ``'gemini'``.
            log: callable ``log(key, value)`` that receives the raw LLM
                response; ``None`` disables logging.

        Returns:
            The ``Predicate`` parsed from the LLM output.

        Raises:
            ValueError: if ``llm_name`` is unsupported or the LLM output has
                no ``(`` argument list.
        """
        prompts_dir = self._prompts_dir(llm_name, 'Query2FOL')
        query_fol_str, response = Query2FOL(prompts_dir, llm_name)(QUESTION=question)
        if log is not None:
            log("response", response)

        parts = query_fol_str.split('(')
        if len(parts) < 2:
            raise ValueError(
                f"LLM output is not of the form name(args): {query_fol_str!r}"
            )
        predicate = parts[0].strip()
        args = parts[1].split(')')[0].split(',')
        return Predicate(predicate, [Constant(arg.strip()) for arg in args])

    def _get_type_axiom(self, question, query_fol, llm_name, log):
        """Ask the LLM for a typed axiom matching ``query_fol``.

        Args:
            question: the original natural-language query.
            query_fol: FOL form of the query; sent to the LLM as ``str(query_fol)``.
            llm_name: LLM backend, ``'gpt3.5'`` or ``'gemini'``.
            log: accepted for symmetry with ``_query2fol``; currently unused.

        Returns:
            ``parse_predicate_string`` applied to the LLM's typed-axiom string
            (a string such as ``"requires-electricity(x) | x: electronic device"``).

        Raises:
            ValueError: if ``llm_name`` is unsupported.
        """
        prompts_dir = self._prompts_dir(llm_name, 'GetTypeAxiom')
        typed_query_str, _response = GetTypeAxiom(prompts_dir, llm_name)(
            QUESTION=question, FOL_QUESTION=str(query_fol)
        )
        return parse_predicate_string(typed_query_str)

    @abstractmethod
    def prove(self):
        pass




# class Typed_Resolution(Agent):
#     def __init__(self):
#         super().__init__()
#         self.nli = NLI()
    
#     def __call__(self, query, llm_name, log, max_steps):
#         answer = self._prove(query,llm_name,  log=log, max_steps=max_steps)
#         return str(answer)
    
#     def _prove(self, query, llm_name, log=None, max_steps=10):
#         # convert query to FOL
#         query_fol = self._query2fol(query, llm_name, log)

#         pos_query = query_fol
#         neg_query = self._negate_preds(pos_query)
#         log("negated_query", neg_query)

#         typed_axiom = self._get_type_axiom(query, query_fol, llm_name, log)
#         substitutions, scores = self._unify(typed_axiom, neg_query)

#         answer = all(a >= b for a, b in zip(scores, scores_neg))

#         return answer






class Typed_Resolution_synthetic(Agent):
    """Agent that answers a query by unifying it with an LLM-generated typed axiom.

    Pipeline:
    1. The query is converted to an FOL predicate and negated.
    2. The LLM supplies a typed axiom for the query.
    3. The negated query is unified twice: once with the axiom, and once with
       a copy whose argument types are negated (``"not <type>"``).
    4. NLI scores rate each type substitution. The answer is ``True`` when
       every positive score is >= the corresponding negated-type score.
    """

    def __init__(self):
        super().__init__()
        self.nli = NLI()

    def __call__(self, query, llm_name, log, max_steps, typed_constants='True'):
        """Run :meth:`prove` and return its answer as a string."""
        answer = self.prove(query, llm_name, log=log, max_steps=max_steps,
                            typed_constants=typed_constants)
        return str(answer)

    @staticmethod
    def _flag_enabled(flag):
        """Interpret the ``typed_constants`` flag; accepts ``'True'`` or ``True``."""
        return flag is True or flag == 'True'

    @staticmethod
    def _no_log(*_args, **_kwargs):
        """Discard log records; used when ``prove`` is called without ``log``."""
        return None

    def prove(self, query, llm_name, log=None, max_steps=10, typed_constants='True'):
        """Decide ``query`` by typed unification against an LLM-generated axiom.

        Args:
            query: natural-language query.
            llm_name: LLM backend forwarded to the FOL and type-axiom prompts.
            log: callable ``log(key, value)``; ``None`` disables logging.
            max_steps: accepted for interface compatibility; not used here.
            typed_constants: ``'True'`` (or ``True``) to attach Wikidata types
                to the query constants and compare against those. Otherwise a
                constant's name is used as its type.

        Returns:
            bool. ``False`` if the typed axiom does not unify with the query.
            ``True`` if only the type-negated axiom fails to unify. Otherwise,
            whether every positive score is >= its negated-type counterpart.
        """
        if log is None:
            log = self._no_log

        query_fol = self._query2fol(query, llm_name, log)
        if self._flag_enabled(typed_constants):
            self._set_const_types(query_fol)

        neg_query = self._negate_preds(query_fol)
        log("negated_query", neg_query)

        typed_axiom = self._get_type_axiom(query, query_fol, llm_name, log)
        typed_axiom_neg = self._negate_types(typed_axiom)

        substitutions, scores = self._unify(typed_axiom, neg_query, typed_constants)
        log("substitutions", substitutions)

        substitutions_neg, scores_neg = self._unify(typed_axiom_neg, neg_query, typed_constants)
        log("substitutions_neg", substitutions_neg)

        if substitutions is None:
            # The typed axiom cannot resolve against the query at all.
            return False
        if substitutions_neg is None:
            # NOTE(review): only the type-negated axiom failed to unify (e.g. an
            # axiom constant became "not <name>"), so nothing competes with the
            # positive unification. Confirm this is the intended answer.
            return True
        return all(a >= b for a, b in zip(scores, scores_neg))

    def _negate_preds(self, clause):
        """Return the negation of a literal: ``Not(p)`` for ``p``, and ``p`` for ``Not(p)``."""
        if isinstance(clause, Not):
            return clause.arg
        return Not(clause)

    @staticmethod
    def _negate_arg_types(args):
        """Return ``args`` with each argument's type negated.

        - A typed variable ``x: T`` becomes ``x: not T``.
        - A constant ``c`` becomes ``not c``.
        - Anything else (e.g. an untyped variable) is kept as-is, so the
          literal keeps its arity and can still be unified.
        """
        negated = []
        for arg in args:
            if isinstance(arg, Variable) and arg.type is not None:
                negated.append(Variable(arg.name, 'not ' + arg.type))
            elif isinstance(arg, Constant):
                negated.append(Constant('not ' + arg.name))
            else:
                negated.append(arg)
        return negated

    def _negate_types(self, clause):
        """Return a copy of ``clause`` whose argument types are negated.

        A surrounding ``Not`` is preserved.

        Raises:
            TypeError: if ``clause`` is neither a ``Predicate`` nor a ``Not``.
        """
        if isinstance(clause, Predicate):
            return Predicate(clause.name, self._negate_arg_types(clause.args[0]))
        if isinstance(clause, Not):
            inner = clause.arg
            return Not(Predicate(inner.name, self._negate_arg_types(inner.args[0])))
        raise TypeError(
            f"cannot negate types of {type(clause).__name__}: {clause!r}"
        )

    def _nli_score(self, premise, hypothesis):
        """Return ``self.nli(premise, hypothesis)['scores'][0]``.

        NOTE(review): presumably the entailment score; confirm against
        ``utils.logic.nli.NLI``.
        """
        return self.nli(premise, hypothesis)['scores'][0]

    def _pick_type(self, pos_type, neg_type):
        """Choose between two types by comparing NLI scores in both directions.

        Returns ``(pos_type, score)`` when ``nli(pos_type, neg_type)`` is
        strictly higher than the reverse; otherwise ``(neg_type, score)``.
        """
        score_pos = self._nli_score(pos_type, neg_type)
        score_neg = self._nli_score(neg_type, pos_type)
        if score_pos > score_neg:
            return pos_type, score_pos
        return neg_type, score_neg

    def _best_type(self, pairs, default_type, prefer_pos_on_tie):
        """Pick the highest-scoring type across candidate ``(pos_type, neg_type)`` pairs.

        For each pair, the side whose forward NLI score beats the reverse score
        is a candidate, and the best candidate overall wins.
        ``prefer_pos_on_tie`` lets the pos side win when the two directions tie.

        Returns ``(default_type, 0)`` when no candidate scores above 0, so
        the result is always defined.
        """
        best_type, best_score = default_type, 0
        for pos_type, neg_type in pairs:
            score_pos = self._nli_score(pos_type, neg_type)
            score_neg = self._nli_score(neg_type, pos_type)
            if prefer_pos_on_tie:
                pos_wins = score_pos >= score_neg
            else:
                pos_wins = score_pos > score_neg
            if pos_wins and score_pos > best_score:
                best_type, best_score = pos_type, score_pos
            elif score_neg > score_pos and score_neg > best_score:
                best_type, best_score = neg_type, score_neg
        return best_type, best_score

    @staticmethod
    def _const_types(const):
        """Return the candidate types of ``const`` as a list.

        A single string type (as stored by ``_set_const_types`` when no
        Wikidata types are found) is wrapped in a list so it is not iterated
        character by character. A missing or empty type falls back to
        ``[const.name]``.
        """
        types = getattr(const, 'type', None)
        if not types:
            return [const.name]
        if isinstance(types, str):
            return [types]
        return list(types)

    def _unify(self, pos, neg, typed_constants):
        """Unify predicate ``pos`` with the predicate wrapped by ``neg`` (a ``Not``).

        Args:
            pos: a ``Predicate`` (the typed axiom or its type-negated copy).
            neg: a ``Not`` wrapping the query predicate.
            typed_constants: when enabled (``'True'``/``True``), constants carry
                candidate types and the best-scoring one is used. Otherwise a
                constant's name serves as its type.

        Returns:
            ``(substitutions, scores)``.

            - ``substitutions`` maps each unified term to ``(term, type, score)``.
            - ``scores`` lists the scores in insertion order.
            - ``(None, None)`` if the literals differ in arity or contain
              mismatched constants.
        """
        use_types = self._flag_enabled(typed_constants)
        pos_args = pos.args[0]
        neg_args = neg.arg.args[0]
        if len(pos_args) != len(neg_args):
            return None, None

        substitutions = {}
        for p_arg, n_arg in zip(pos_args, neg_args):
            # Case 1: constant with constant.
            # TODO: exact string matches only; co-references are not handled yet.
            if isinstance(p_arg, Constant) and isinstance(n_arg, Constant):
                if p_arg.name != n_arg.name:
                    return None, None
                substitutions[p_arg] = (n_arg, n_arg, 1)
                substitutions[n_arg] = (p_arg, p_arg, 1)

            # Case 2a: constant (pos) with variable (neg).
            elif isinstance(p_arg, Constant) and isinstance(n_arg, Variable):
                if use_types:
                    sub_type, score = self._best_type(
                        [(t, n_arg.type) for t in self._const_types(p_arg)],
                        default_type=n_arg.type, prefer_pos_on_tie=False)
                else:
                    sub_type, score = self._pick_type(p_arg.name, n_arg.type)
                substitutions[n_arg] = (p_arg, sub_type, score)
                substitutions[p_arg] = (p_arg, sub_type, score)

            # Case 2b: variable (pos) with constant (neg).
            elif isinstance(p_arg, Variable) and isinstance(n_arg, Constant):
                if use_types:
                    sub_type, score = self._best_type(
                        [(p_arg.type, t) for t in self._const_types(n_arg)],
                        default_type=p_arg.type, prefer_pos_on_tie=True)
                else:
                    sub_type, score = self._pick_type(p_arg.type, n_arg.name)
                substitutions[p_arg] = (n_arg, sub_type, score)
                substitutions[n_arg] = (n_arg, sub_type, score)

            # Case 3: variable with variable.
            elif isinstance(p_arg, Variable) and isinstance(n_arg, Variable):
                sub_type, score = self._pick_type(p_arg.type, n_arg.type)
                substitutions[p_arg] = (n_arg, sub_type, score)
                substitutions[n_arg] = (p_arg, sub_type, score)

        scores = [entry[2] for entry in substitutions.values()]
        return substitutions, scores

    def _set_const_types(self, query_fol):
        """Attach Wikidata types to every constant of ``query_fol``, in place.

        Each constant's ``type`` becomes the value returned by
        ``get_wikidata_types``. If that value is falsy, the constant's name
        (a plain string) is stored instead. Input that is neither a
        ``Predicate`` nor a ``Not`` is left untouched.
        """
        if isinstance(query_fol, Predicate):
            literal = query_fol
        elif isinstance(query_fol, Not):
            literal = query_fol.arg
        else:
            return
        for const in literal.args[0]:
            const.type = get_wikidata_types(const.name) or const.name

    




