from pyparsing import nestedExpr
from common import bool_token, kw_preds, not_kw_preds, var2vid, const2cid

class ExpGen:
    """Counter that hands out increasing integer expression ids."""

    def __init__(self, ct=0):
        # Last id handed out; the next get() returns ct + 1.
        self.ct = ct

    def get(self):
        """Increment the counter and return the new id."""
        new_id = self.ct + 1
        self.ct = new_id
        return new_id

    def reset(self):
        """Restart numbering so the next get() returns 1."""
        self.ct = 0

# Shared id generator used by the to_scl methods below;
# PDDLProg.action_to_scl resets it before converting an action.
exp_gen = ExpGen()
class BoolFormula():
    """A boolean connective (e.g. ``and``) applied to child formulas.

    ``args`` holds child formula objects (BoolFormula / PosPredicate /
    NegPredicate / InEqPred as built by PDDLParser); ``v`` is the union of the
    children's ``v`` variable sets.
    """

    def __init__(self, name, args) -> None:
        self.name = name
        self.args = args
        # Union of the children's variable sets; empty when there are no children.
        self.v = set().union(*(arg.v for arg in args))

    def __eq__(self, other):
        """Order-insensitive structural equality.

        Two formulas are equal when they use the same connective and their
        argument lists contain the same children, each child matched exactly
        once regardless of position.
        """
        if not isinstance(other, BoolFormula):
            return NotImplemented
        if self.name != other.name or len(self.args) != len(other.args):
            return False
        unmatched = list(other.args)
        for arg in self.args:
            for idx, candidate in enumerate(unmatched):
                # Only compare children of the same class, so different node
                # kinds are never handed to each other's __eq__.
                if type(candidate) is type(arg) and arg == candidate:
                    del unmatched[idx]
                    break
            else:
                return False
        return True

    def collect_preds(self):
        """Return the set of predicate names used anywhere in this formula."""
        return set().union(*(arg.collect_preds() for arg in self.args))

    def collect_tuples(self):
        """Return the children's literal tuples, concatenated in child order."""
        all_tuples = []
        for arg in self.args:
            all_tuples.extend(arg.collect_tuples())
        return all_tuples

    def to_scl(self, expression_id):
        """Emit SCL facts for all children, chaining expression ids.

        Each child's ``to_scl`` receives the id returned by the previous child
        (the first child receives ``expression_id``); their fact dictionaries
        are merged with ``merge_dict_ls``.

        Returns:
            ``(last_eid, facts)``; ``(expression_id, {})`` when there are no
            children.

        Raises:
            NotImplementedError: if a non-empty formula's connective is not
                ``and`` (the only connective with an SCL encoding here).
        """
        if not self.args:
            return expression_id, {}
        if self.name != "and":
            raise NotImplementedError(
                "to_scl only supports 'and' formulas, got {!r}".format(self.name))

        current_eid = expression_id
        expressions = {}
        for arg in self.args:
            current_eid, arg_scl = arg.to_scl(current_eid)
            expressions = merge_dict_ls(expressions, arg_scl)
        return current_eid, expressions

def process_arg_num(raw_arg):
    """Map a PDDL argument token to its numeric id.

    Variables (leading '?') are looked up in ``var2vid`` with the '?' removed;
    every other token is looked up in ``const2cid`` unchanged.
    """
    is_variable = raw_arg[0] == '?'
    if is_variable:
        return var2vid[raw_arg[1:]]
    return const2cid[raw_arg]

def process_arg_str(raw_arg):
    """Return the argument token with a single leading '?' marker removed."""
    return raw_arg[1:] if raw_arg[0] == '?' else raw_arg

def process_arg(raw_arg, num=True):
    """Convert an argument token to a numeric id (``num`` truthy) or a bare name."""
    converter = process_arg_num if num else process_arg_str
    return converter(raw_arg)

def merge_dict_ls(d1, d2):
    """Append each list in ``d2`` onto the list under the same key in ``d1``.

    ``d1`` is modified in place (missing keys start as a fresh empty list) and
    is also returned for convenience.
    """
    for key, values in d2.items():
        d1.setdefault(key, [])
        d1[key] += values
    return d1

class PosPredicate():
    """A positive literal ``(pred arg ...)``.

    ``name`` is ``pred`` with '-' replaced by '_' (SCL-friendly); ``v`` is the
    set of arguments that are variables (start with '?').
    """

    def __init__(self, pred, args) -> None:
        self.name = pred.replace('-', '_')
        self.args = args
        self.v = {arg for arg in args if arg[0] == '?'}

    def __eq__(self, other):
        """Equal to another PosPredicate with the same name and argument list."""
        # Other node kinds (NegPredicate, InEqPred, plain strings) are never
        # equal; returning NotImplemented also avoids reading attributes they
        # may not have.
        if not isinstance(other, PosPredicate):
            return NotImplemented
        return self.name == other.name and self.args == other.args

    def collect_preds(self):
        """Return ``{name}``."""
        return {self.name}

    def collect_tuples(self):
        """Return ``[(True, name, args)]``; True marks a positive literal."""
        return [(True, self.name, self.args)]

    def to_scl(self, expression_id):
        """Emit one ``positive_unary_atom`` / ``positive_binary_atom`` fact.

        The fact is ``(new_eid, name, *arg_ids, expression_id)`` where
        ``new_eid`` comes from the shared ``exp_gen`` counter and each arg id
        comes from ``process_arg``.

        Returns:
            ``(new_eid, {relation: [fact]})``.

        Raises:
            ValueError: if the predicate does not have exactly 1 or 2 arguments.
        """
        if len(self.args) == 1:
            relation = 'positive_unary_atom'
        elif len(self.args) == 2:
            relation = 'positive_binary_atom'
        else:
            raise ValueError(
                "PosPredicate {!r} must have 1 or 2 arguments for SCL export, "
                "got {}".format(self.name, len(self.args)))

        next_eid = exp_gen.get()
        fact = ((next_eid, self.name)
                + tuple(process_arg(arg) for arg in self.args)
                + (expression_id,))
        return next_eid, {relation: [fact]}

class NegPredicate():
    """A negated literal ``(not (pred arg ...))``.

    ``name`` is ``pred`` with '-' replaced by '_' (SCL-friendly); ``v`` is the
    set of arguments that are variables (start with '?').
    """

    def __init__(self, pred, args) -> None:
        self.name = pred.replace('-', '_')
        self.args = args
        self.v = {arg for arg in args if arg[0] == '?'}

    def __eq__(self, other):
        """Equal to another NegPredicate with the same name and argument list."""
        # A positive literal with the same atom is NOT equal to this negated
        # one; foreign types are left to Python's default comparison.
        if not isinstance(other, NegPredicate):
            return NotImplemented
        return self.name == other.name and self.args == other.args

    def collect_preds(self):
        """Return ``{name}``."""
        return {self.name}

    def collect_tuples(self):
        """Return ``[(False, name, args)]``; False marks a negated literal."""
        return [(False, self.name, self.args)]

    def to_scl(self, expression_id):
        """Emit one ``negative_unary_atom`` / ``negative_binary_atom`` fact.

        The fact is ``(new_eid, name, *arg_ids, expression_id)`` where
        ``new_eid`` comes from the shared ``exp_gen`` counter and each arg id
        comes from ``process_arg``.

        Returns:
            ``(new_eid, {relation: [fact]})``.

        Raises:
            ValueError: if the predicate does not have exactly 1 or 2 arguments.
        """
        if len(self.args) == 1:
            relation = 'negative_unary_atom'
        elif len(self.args) == 2:
            relation = 'negative_binary_atom'
        else:
            raise ValueError(
                "NegPredicate {!r} must have 1 or 2 arguments for SCL export, "
                "got {}".format(self.name, len(self.args)))

        next_eid = exp_gen.get()
        fact = ((next_eid, self.name)
                + tuple(process_arg(arg) for arg in self.args)
                + (expression_id,))
        return next_eid, {relation: [fact]}

class InEqPred():
    """An inequality constraint between two arguments.

    PDDLParser builds it from a negated keyword predicate such as
    ``(not (= ?a ?b))``; ``v`` is the set of arguments that are variables.
    """

    def __init__(self, pred, args) -> None:
        # Explicit check instead of `assert`, which is stripped under -O.
        if pred != '!=':
            raise ValueError(
                "InEqPred expects the '!=' predicate, got {!r}".format(pred))
        self.args = args
        self.v = {arg for arg in args if arg[0] == '?'}

    def __eq__(self, other) -> bool:
        """Equal to another InEqPred with the same (ordered) argument list."""
        # Without this check any object with a matching `.args` (e.g. a
        # PosPredicate) compared equal, and objects without `.args` raised.
        if not isinstance(other, InEqPred):
            return NotImplemented
        return self.args == other.args

    def collect_preds(self):
        """Inequalities contribute no predicate names."""
        return set()

    def collect_tuples(self):
        """Inequalities contribute no literal tuples."""
        return []

    def to_scl(self, expression_id):
        """Emit one ``inequality_constraint`` fact.

        The fact is ``(new_eid, arg0_id, arg1_id, expression_id)`` with
        ``new_eid`` from the shared ``exp_gen`` counter.

        Returns:
            ``(new_eid, {'inequality_constraint': [fact]})``.

        Raises:
            ValueError: if there are not exactly two arguments.
        """
        if len(self.args) != 2:
            raise ValueError(
                "InEqPred needs exactly 2 arguments, got {}".format(len(self.args)))
        next_eid = exp_gen.get()
        fact = (next_eid,
                process_arg(self.args[0]),
                process_arg(self.args[1]),
                expression_id)
        return next_eid, {'inequality_constraint': [fact]}

class Action:
    """A PDDL action: name, parameters, precondition formula and effect formula."""

    def __init__(self, name, param, precondition, effect) -> None:
        self.name, self.param = name, param
        self.precondition, self.effect = precondition, effect

    def collect_preds(self):
        """Return the predicate names used in the precondition or the effect."""
        return self.precondition.collect_preds().union(self.effect.collect_preds())

    def collect_tuples(self):
        """Return precondition literal tuples followed by effect literal tuples."""
        return self.precondition.collect_tuples() + self.effect.collect_tuples()

    def to_scl(self):
        """Build the SCL fact dictionary for this action.

        The precondition and the effect are each converted starting from
        expression id 0; their facts are merged, and the final id of each
        chain is recorded under ``'precondition'`` and ``'effect'``.
        """
        pre_eid, pre_facts = self.precondition.to_scl(0)
        facts = merge_dict_ls({}, pre_facts)
        eff_eid, eff_facts = self.effect.to_scl(0)
        facts = merge_dict_ls(facts, eff_facts)
        facts['precondition'] = [(pre_eid,)]
        facts['effect'] = [(eff_eid,)]
        return facts


class Axiom():
    """Container for a PDDL ``(:axiom ...)`` section.

    ``context`` and ``implies`` hold the raw token lists collected by the
    parser; ``variables`` holds the parsed ``:vars`` typed list. SCL
    conversion is not implemented.
    """

    def __init__(self, context, implies, variables) -> None:
        self.context = context
        self.implies = implies
        self.variables = variables

    def to_scl(self):
        """Convert this axiom to SCL facts.

        Raises:
            NotImplementedError: always; axioms have no SCL encoding yet.
        """
        # Previously declared without `self`, so calling it on an instance
        # raised TypeError; fail explicitly instead of silently returning None.
        raise NotImplementedError("Axiom.to_scl is not implemented")

class PDDLProg():
    """Parsed PDDL domain: types, constants, predicate signatures, actions, axioms.

    Actions are indexed by name in ``self.actions``.
    """

    def __init__(self, types, consts, predicates, actions, axioms) -> None:
        self.consts = consts
        self.types = types
        self.predicates = predicates
        # A later action with a duplicate name overwrites an earlier one.
        self.actions = {action.name: action for action in actions}
        self.axioms = axioms

    def action_to_scl(self, name) -> dict:
        """Return the SCL fact dictionary for the action called ``name``.

        The shared expression-id counter is reset first, so the numbering does
        not depend on earlier conversions.

        Raises:
            KeyError: if there is no action with that name.
        """
        exp_gen.reset()
        return self.actions[name].to_scl()

    def to_scl(self):
        """Return ``{action_name: scl_facts}`` for every action.

        Each action goes through ``action_to_scl`` so its expression ids are
        numbered independently of the other actions and of previous calls
        (previously the counter was never reset here).
        """
        return {name: self.action_to_scl(name) for name in self.actions}

def process_type_defs(def_ls):
    """Parse a PDDL typed list into ``[{'name': ..., 'type': ...}, ...]``.

    Supports the standard typed-list form where several names share one
    type, e.g. ``['?a', '?b', '-', 'sth', 'hand', '-', 'object']`` gives four
    entries in order. Names not followed by ``- <type>`` are returned as
    ``{'name': name}`` without a ``'type'`` key.

    Raises:
        ValueError: if a '-' is not followed by a type token.
    """
    missing = object()
    type_def = []
    pending_names = []
    tokens = iter(def_ls)
    for token in tokens:
        if token == '-':
            type_token = next(tokens, missing)
            if type_token is missing:
                raise ValueError(
                    "typed list ends with '-' but no type: {!r}".format(def_ls))
            # The type applies to every name seen since the previous '- type'.
            type_def.extend({'name': name, 'type': type_token}
                            for name in pending_names)
            pending_names = []
        else:
            pending_names.append(token)

    # Trailing names that were never given an explicit type.
    type_def.extend({'name': name} for name in pending_names)
    return type_def

def string_var_transition(ls, current_id=0, prefix='?o'):
    """Return a copy of a nested-list PDDL clause with constants upper-cased.

    The head ``ls[0]`` is copied unchanged. Each following string argument
    that starts with '?' (a variable) is kept as-is, and any other string
    argument is upper-cased. Nested lists are transformed recursively.

    ``current_id`` and ``prefix`` are only forwarded to the recursive calls
    and do not affect the result.
    """
    def transform(item):
        if type(item) is not str:
            return string_var_transition(item, current_id, prefix)
        if item[0] == '?':
            return item
        return item.upper()

    head = ls[0]
    return [head] + [transform(item) for item in ls[1:]]

class PDDLParser():
    """Parse PDDL domain text into a PDDLProg."""

    def __init__(self) -> None:
        pass

    def process_params(self, params_ls):
        """Return the first element of every entry in ``params_ls``."""
        return [param[0] for param in params_ls]

    def process_logic_pred(self, logic_ls):
        """Convert a nested-list PDDL formula into formula objects.

        - ``(not (p ...))`` where ``p`` is not a connective becomes an
          InEqPred when ``p`` is in ``kw_preds`` (presumably ``=``), otherwise
          a NegPredicate.
        - Any other head in ``bool_token`` becomes a BoolFormula over the
          recursively converted children.
        - A bare keyword predicate, or a head not starting with a letter,
          raises ``Exception``; otherwise the result is a PosPredicate.
        """
        if logic_ls[0] in bool_token:

            # Negative leaf predicates
            if logic_ls[0] == 'not' and len(logic_ls) > 1 and not logic_ls[1][0] in bool_token:
                args = logic_ls[1][1:]
                # Negated keyword predicate (e.g. equality) -> inequality
                if logic_ls[1][0] in kw_preds:
                    return InEqPred(not_kw_preds[logic_ls[1][0]], args)
                # Negative predicate
                return NegPredicate(logic_ls[1][0], args)

            # Ordinary boolean formula
            children = [self.process_logic_pred(element) for element in logic_ls[1:]]
            return BoolFormula(logic_ls[0], children)

        # Keyword predicate (e.g. equality) used without negation
        elif logic_ls[0] in kw_preds:
            raise Exception("No direct equlency should be used")
        elif not logic_ls[0][0].isalpha():
            raise Exception('Error: predicate should start with alphabets')
        # Positive predicates
        else:
            args = logic_ls[1:]
            return PosPredicate(logic_ls[0], args)

    @staticmethod
    def _strip_comments(pddl_prog):
        """Drop ';' comments and blank lines; keep newlines as token separators."""
        code_lines = []
        for line in pddl_prog.split('\n'):
            # PDDL comments run from ';' to the end of the line.
            code = line.split(';', 1)[0]
            if code.strip():
                code_lines.append(code)
        # Joining with '\n' (not '') keeps tokens on adjacent lines separate.
        return '\n'.join(code_lines)

    @staticmethod
    def _split_sections(top_level):
        """Turn the ``(define ...)`` body into one dict per ``(:keyword ...)`` section.

        Each ':name' token opens a key ``name`` collecting the tokens that
        follow it; the first key of each dict is the section header.
        """
        sections = []
        for tokens in top_level:
            # Skip the 'define' keyword, empty groups and '(domain <name>)'.
            if isinstance(tokens, str) or not tokens or tokens[0] == 'domain':
                continue
            section = {}
            key = None
            for token in tokens:
                if isinstance(token, str) and token[0] == ':':
                    key = token[1:]
                    section[key] = []
                elif key is None:
                    raise ValueError(
                        "section does not start with a ':' keyword: {!r}".format(tokens))
                else:
                    section[key].append(token)
            sections.append(section)
        return sections

    @staticmethod
    def _section_body(sections, header):
        """Return the tokens of the first section with ``header`` ([] if absent)."""
        for section in sections:
            if next(iter(section)) == header:
                return section[header]
        return []

    @staticmethod
    def _typed_list(tokens):
        """Return the typed-list tokens of a ':parameters' / ':vars' entry.

        The entry normally holds one parenthesised group, e.g.
        ``[['?a', '-', 'sth']]``; that inner list is returned. A flat or empty
        entry is returned unchanged.
        """
        if len(tokens) == 1 and isinstance(tokens[0], list):
            return tokens[0]
        return tokens

    def _build_action(self, section):
        """Build an Action from an ``(:action ...)`` section dict."""
        names = section['action']
        if len(names) != 1:
            raise ValueError(
                "expected exactly one action name, got {!r}".format(names))
        param = process_type_defs(self._typed_list(section.get('parameters', [])))
        precondition = self.process_logic_pred(
            string_var_transition(section['precondition'][0]))
        effect = self.process_logic_pred(
            string_var_transition(section['effect'][0]))
        return Action(names[0], param, precondition, effect)

    def _build_axiom(self, section):
        """Build an Axiom from an ``(:axiom ...)`` section dict."""
        return Axiom(context=section['context'],
                     implies=section['implies'],
                     variables=process_type_defs(self._typed_list(section['vars'])))

    def parse(self, pddl_prog: str) -> "PDDLProg":
        """Parse a PDDL domain definition.

        Sections are located by their ``:keyword`` header rather than by
        position, so optional sections (e.g. ``:constants``) may be omitted;
        a missing ``:types`` / ``:constants`` / ``:predicates`` section
        yields an empty list.

        Args:
            pddl_prog: PDDL source text of a ``(define (domain ...) ...)`` form.

        Returns:
            A PDDLProg with the domain's types, constants, predicate
            ``(name, arity)`` pairs, actions and axioms.

        Raises:
            ValueError: if the text does not yield exactly one top-level
                expression, a section does not start with a ':' keyword, or
                an action does not have exactly one name.
        """
        code = self._strip_comments(pddl_prog)
        result = nestedExpr('(', ')').parseString(code).asList()
        if len(result) != 1:
            raise ValueError(
                "expected one top-level PDDL expression, got {}".format(len(result)))
        sections = self._split_sections(result[0])

        # (name, number of '?' variables) for every predicate declaration.
        predicates = [
            (p[0].replace('-', '_'), sum(1 for i in p if i[0] == '?'))
            for p in self._section_body(sections, 'predicates')
        ]
        types = process_type_defs(self._section_body(sections, 'types'))
        consts = process_type_defs(self._section_body(sections, 'constants'))

        actions = []
        axioms = []
        for section in sections:
            header = next(iter(section))
            if header == 'action':
                actions.append(self._build_action(section))
            elif header == 'axiom':
                axioms.append(self._build_axiom(section))

        return PDDLProg(types, consts, predicates, actions, axioms)

if __name__ == "__main__":
    # Smoke test: parse a small excerpt of the twentybn domain (types,
    # constants, predicates and the "approach"/"attach" actions) and print
    # the SCL facts generated for the "approach" action.
    test_pddl = '''
    (define (domain twentybn)
	(:requirements :strips :typing :equality :negative-preconditions :quantified-preconditions :conditional-effects :domain-axioms :derived-predicates)
	(:types
		location - object
		sth - object
		void - object
	)
	(:constants
		hand - object
	)
	(:predicates
		; static properties
		(is-bendable ?a - sth)
		; TODO: Add not fluid to preconditions
		(is-fluid ?a - sth)
		(is-holdable ?a - sth)
		(is-rigid ?a - sth)
		(is-spreadable ?a - sth)
		(is-tearable ?a - sth)

		; mutable properties
		(above ?a - sth ?b - sth)
		(attached ?a - sth ?b - sth)
		(behind ?a - sth ?b - sth)
		(broken ?a - sth)
		(close ?a - sth)
		(closed ?a - sth)
		(deformed ?a - sth)
		(empty ?a - sth)
		(far ?a - sth)
		(fits ?a - sth ?b - sth)
		(folded ?a - sth)
		(full ?a - sth)
		(has-hole ?a - sth)
		(high ?a - sth)
		(in ?a - sth ?b - object)
		(infront ?a - sth ?b - sth)
		(left ?a - sth)
		(low ?a - sth)
		(nextto ?a - sth ?b - sth)
		(on ?a - sth ?b - sth)
		(onsurface ?a - sth)
		(open ?a - sth)
		(right ?a - sth)
		(stacked ?a - sth)
		(stretched ?a - sth)
		(torn ?a - sth)
		(touching ?a - object ?b - object)
		(twisted ?a - sth)
		(under ?a - sth ?b - sth)
		(upright ?a - sth)
		(visible ?a - object)
	)

	; 0 Approaching something with your camera
	(:action approach
		:parameters (?a - sth)
		:precondition (and
			(not (close ?a))
			(visible ?a)
			(not (visible hand))
		)
		:effect (close ?a)
	)

    ; 1 Attaching something to something
	(:action attach
		:parameters (?a - sth ?b - sth)
		:precondition (and
			(not (= ?a ?b))
			(not (far ?b))
			(in ?a hand)
			(not (touching ?a ?b))
			(visible ?b)
		)
		:effect (and
			(attached ?a ?b)
			(not (touching ?a hand))
		)
	))
    '''

    parser = PDDLParser()
    prog = parser.parse(test_pddl)
    print(prog.action_to_scl("approach"))
