from pyparsing import nestedExpr
from utils import bool_token, kw_preds, not_kw_preds, var2vid, const2cid, rec_sub_val, all_binary_preds
import os 
import json
import re 
import pickle


class ExpGen():
    """Counter handing out fresh integer expression ids.

    ``get`` increments before returning, so a counter built with the
    default ``ct=0`` yields 1, 2, 3, ...; ``reset`` rewinds it to zero.
    """

    def __init__(self, ct=0):
        # Most recently issued id (the starting value when nothing issued yet).
        self.ct = ct

    def get(self):
        """Advance the counter and return the new id."""
        self.ct = self.ct + 1
        return self.ct

    def reset(self):
        """Rewind the counter back to zero."""
        self.ct = 0

# Shared id source used by the predicate classes' to_scl methods.
exp_gen = ExpGen()
class BoolFormula():
    """Boolean connective node over predicate / sub-formula children.

    Only ``and`` nodes are accepted by ``to_scl`` / ``to_str`` (they assert
    it). ``v`` is the union of the children's ``v`` variable sets.
    """

    def __init__(self, name, args) -> None:
        self.name = name
        self.args = args
        # set().union() with no arguments is just an empty set, so an empty
        # connective needs no special case.
        self.v = set().union(*[arg.v for arg in args])

    def __eq__(self, other):
        """Structural equality: same connective, same children in any order.

        Children are matched one-to-one, so a duplicated child must appear
        the same number of times on both sides.
        """
        if not isinstance(other, BoolFormula):
            return NotImplemented
        if self.name != other.name or len(self.args) != len(other.args):
            return False
        unmatched = list(other.args)
        for arg in self.args:
            for idx, candidate in enumerate(unmatched):
                if arg == candidate:
                    del unmatched[idx]
                    break
            else:
                return False
        return True

    def collect_preds(self):
        """Return the set of predicate names used anywhere under this node."""
        return set().union(*[arg.collect_preds() for arg in self.args])

    def collect_tuples(self):
        """Return the children's ``(is_positive, name, args)`` tuples in order."""
        return [item for arg in self.args for item in arg.collect_tuples()]

    def to_scl(self, expression_id, consts):
        """Convert the children to SCL fact tuples, chaining expression ids.

        Each child is converted with the id returned by the previous child,
        starting from ``expression_id``. Returns ``(last_id, facts)`` where
        ``facts`` maps a relation name to a list of tuples.
        """
        expressions = {}
        if len(self.args) == 0:
            return expression_id, expressions

        current_eid = expression_id
        for arg in self.args:
            current_eid, arg_scl = arg.to_scl(current_eid, consts)
            expressions = merge_dict_ls(expressions, arg_scl)

        # Should only have this case
        assert self.name == "and"

        return current_eid, expressions

    def to_str(self, expression_id, consts):
        """Render as ``And(child, ...)``; an empty node renders as ``""``.

        Returns ``(last_id, text)`` with ids threaded through the children
        the same way as ``to_scl``.
        """
        if len(self.args) == 0:
            return expression_id, ""

        current_eid = expression_id
        expression_strs = []
        for arg in self.args:
            current_eid, arg_str = arg.to_str(current_eid, consts)
            expression_strs.append(arg_str)

        # Should only have this case
        assert self.name == "and"

        return current_eid, "And(" + ', '.join(expression_strs) + ')'
    
    
def process_arg_num(raw_arg, consts):
    """Encode one raw argument as an integer id.

    ``?name`` variables are looked up in ``var2vid`` under ``name`` (the
    ``?`` removed). Anything else is a constant, encoded as the negated
    1-based position of ``raw_arg`` in ``consts`` (first constant -> -1).
    """
    if raw_arg[0] != '?':
        return -(consts.index(raw_arg) + 1)
    return var2vid[raw_arg[1:]]

def process_arg_str(raw_arg):
    """Return the argument name, dropping a leading ``?`` variable marker."""
    return raw_arg[1:] if raw_arg[0] == '?' else raw_arg

def process_arg(raw_arg, consts, num=True):
    """Encode ``raw_arg`` as an integer id (``num`` truthy) or a plain name."""
    return process_arg_num(raw_arg, consts) if num else process_arg_str(raw_arg)

def merge_dict_ls(d1, d2):
    """Extend ``d1``'s lists in place with ``d2``'s lists, key by key.

    A key missing from ``d1`` starts as a fresh empty list, so ``d2``'s
    lists are copied rather than aliased. Returns ``d1`` itself.
    """
    for key, values in d2.items():
        d1.setdefault(key, [])
        d1[key] += values
    return d1

class PosPredicate():
    """Positive atom ``name(args...)``; ``-`` in the name becomes ``_``.

    ``v`` holds the arguments that are variables (strings starting with ``?``).
    """

    def __init__(self, pred, args) -> None:
        self.name = pred.replace('-', '_')
        self.args = args
        self.v = set([arg for arg in args if '?' == arg[0]])

    def __eq__(self, other):
        # Only another positive atom can be equal; otherwise a NegPredicate
        # with the same name/args would compare equal.
        if not isinstance(other, PosPredicate):
            return NotImplemented
        return self.name == other.name and self.args == other.args

    def collect_preds(self):
        """Return ``{name}``."""
        return set([self.name])

    def collect_tuples(self):
        """Return ``[(True, name, args)]`` (True marks a positive atom)."""
        return [(True, self.name, self.args)]

    def to_scl(self, expression_id, consts):
        """Emit one positive unary/binary atom fact chained to ``expression_id``.

        The fact is ``(new_id, name, *encoded_args, expression_id)``; the new
        id comes from ``exp_gen`` and is returned with the facts dict.
        Atoms with any other arity are skipped with a warning and the
        incoming id is returned unchanged, so the next atom's parent id still
        refers to an emitted expression.
        """
        arity = len(self.args)
        if arity not in (1, 2):
            print(f"Warning: ignoring {self.name}({', '.join(map(str, self.args))})")
            return expression_id, {}

        next_eid = exp_gen.get()
        if arity == 1:
            fact = (next_eid, self.name,
                    process_arg(self.args[0], consts),
                    expression_id)
            return next_eid, {'positive_unary_atom': [fact]}

        fact = (next_eid, self.name,
                process_arg(self.args[0], consts),
                process_arg(self.args[1], consts),
                expression_id)
        return next_eid, {'positive_binary_atom': [fact]}

    def to_str(self, expression_id, consts):
        """Render as ``name(arg, ...)`` with ``?`` variable markers stripped.

        No expression id is allocated; ``expression_id`` is returned unchanged
        so the signature matches ``BoolFormula.to_str``.
        """
        rendered = ', '.join(process_arg(arg, consts, num=False) for arg in self.args)
        return expression_id, f"{self.name}({rendered})"
    
class NegPredicate():
    """Negated atom ``not name(args...)``; ``-`` in the name becomes ``_``.

    ``v`` holds the arguments that are variables (strings starting with ``?``).
    """

    def __init__(self, pred, args) -> None:
        self.name = pred.replace('-', '_')
        self.args = args
        self.v = set([arg for arg in args if '?' == arg[0]])

    def __eq__(self, other):
        # Only another negated atom can be equal; otherwise a PosPredicate
        # with the same name/args would compare equal.
        if not isinstance(other, NegPredicate):
            return NotImplemented
        return self.name == other.name and self.args == other.args

    def collect_preds(self):
        """Return ``{name}``."""
        return set([self.name])

    def collect_tuples(self):
        """Return ``[(False, name, args)]`` (False marks a negated atom)."""
        return [(False, self.name, self.args)]

    def to_scl(self, expression_id, consts):
        """Emit one negative unary/binary atom fact chained to ``expression_id``.

        The fact is ``(new_id, name, *encoded_args, expression_id)`` with the
        new id taken from ``exp_gen``. Only arities 1 and 2 are supported
        (asserted).
        """
        next_eid = exp_gen.get()

        if len(self.args) == 1:
            fact = (next_eid, self.name,
                    process_arg(self.args[0], consts), expression_id)
            return next_eid, {'negative_unary_atom': [fact]}

        assert(len(self.args) == 2)
        fact = (next_eid, self.name,
                process_arg(self.args[0], consts),
                process_arg(self.args[1], consts),
                expression_id)
        return next_eid, {'negative_binary_atom': [fact]}

class InEqPred():
    """Inequality constraint ``args[0] != args[1]``; it has no predicate name.

    ``v`` holds the arguments that are variables (strings starting with ``?``).
    """

    def __init__(self, pred, args) -> None:
        assert pred == '!='
        self.args = args
        self.v = set([arg for arg in args if '?' == arg[0]])

    def __eq__(self, other) -> bool:
        # Compare only against other inequalities; any object with equal
        # .args used to compare equal.
        if not isinstance(other, InEqPred):
            return NotImplemented
        return self.args == other.args

    def collect_preds(self):
        """Inequalities contribute no predicate names."""
        return set()

    def collect_tuples(self):
        """Inequalities contribute no atom tuples."""
        return []

    def to_scl(self, expression_id, consts):
        """Emit one ``inequality_constraint`` fact chained to ``expression_id``.

        The fact is ``(new_id, encoded_arg0, encoded_arg1, expression_id)``
        with the new id taken from ``exp_gen``.
        """
        next_eid = exp_gen.get()
        fact = (next_eid,
                process_arg(self.args[0], consts),
                process_arg(self.args[1], consts),
                expression_id)
        return next_eid, {'inequality_constraint': [fact]}


class Action():
    """A PDDL action: name, parameters, precondition and effect formulas."""

    def __init__(self, name, param, precondition, effect) -> None:
        self.name = name
        self.param = param
        self.precondition = precondition
        self.effect = effect

    def collect_preds(self):
        """Return predicate names used in the precondition or the effect."""
        pred_preds = self.precondition.collect_preds()
        effect_preds = self.effect.collect_preds()
        return pred_preds.union(effect_preds)

    def collect_tuples(self):
        """Return precondition atom tuples followed by effect atom tuples."""
        return self.precondition.collect_tuples() + self.effect.collect_tuples()

    def to_scl(self, consts=None):
        """Convert precondition and effect into one dict of SCL facts.

        Both formulas are converted starting from expression id 0. The id
        returned for each is stored as ``'precondition'`` / ``'effect'``.

        Args:
            consts: constant names used to encode non-variable arguments
                (passed through to the formulas' ``to_scl``); defaults to
                an empty list.
        """
        # Formula to_scl methods take (expression_id, consts).
        consts = [] if consts is None else consts
        scl_facts = {}
        pred_cond_eid, precondition_scl = self.precondition.to_scl(0, consts)
        scl_facts = merge_dict_ls(scl_facts, precondition_scl)
        effect_eid, effect_scl = self.effect.to_scl(0, consts)
        scl_facts = merge_dict_ls(scl_facts, effect_scl)
        scl_facts['precondition'] = [tuple([pred_cond_eid])]
        scl_facts['effect'] = [tuple([effect_eid])]

        return scl_facts


class Axiom():
    """A PDDL axiom: its context, implied part and variable list, stored as parsed."""

    def __init__(self, context, implies, variables) -> None:
        self.context = context
        self.implies = implies
        self.variables = variables

    def to_scl(self):
        """Axioms have no SCL encoding yet."""
        raise NotImplementedError("Axiom.to_scl is not implemented")

class PDDLProg():
    """Parsed PDDL domain: types, constants, predicates, actions and axioms."""

    def __init__(self, types, consts, predicates, actions, axioms) -> None:
        self.consts = consts
        self.types = types
        self.predicates = predicates
        # Actions indexed by name (a later duplicate name wins).
        self.actions = {action.name: action for action in actions}
        self.axioms = axioms

    def action_to_scl(self, name) -> dict:
        """Convert one named action to its SCL fact dict.

        The shared ``exp_gen`` counter is reset first, so ids do not depend
        on earlier conversions.
        """
        exp_gen.reset()
        return self.actions[name].to_scl()

    def to_scl(self):
        """Convert every action, keyed by name, resetting ids per action."""
        return {name: self.action_to_scl(name) for name in self.actions}

def process_type_defs(def_ls):
    """Turn a PDDL typed list into ``[{'name': ..., 'type': ...}, ...]``.

    Every name before a ``-`` gets the token after it as its type, so both
    ``a - t b - u`` and ``a b - t`` are handled. Trailing names with no
    ``- type`` yield dicts with only ``'name'``. Non-string items (e.g. a
    nested parameter list) are treated as names.
    """
    type_def = []
    pending = []
    tokens = iter(def_ls)
    for token in tokens:
        if token == '-':
            # The token after '-' (if any) types every pending name.
            type_token = next(tokens, None)
            for name in pending:
                entry = {'name': name}
                if type_token is not None:
                    entry['type'] = type_token
                type_def.append(entry)
            pending = []
        else:
            pending.append(token)
    type_def.extend({'name': name} for name in pending)
    return type_def

def string_var_transition(ls, current_id=0, prefix='?o'):
    """Upper-case every constant argument of a nested clause, recursively.

    ``ls[0]`` (the head) is kept unchanged. String arguments starting with
    ``?`` (variables) are kept, other strings are upper-cased, and list
    arguments are processed recursively. ``current_id`` and ``prefix`` are
    only forwarded to the recursive calls.
    """
    def convert(item):
        if type(item) != str:
            return string_var_transition(item, current_id, prefix)
        return item if item[0] == '?' else item.upper()

    return [ls[0]] + [convert(item) for item in ls[1:]]

class PDDLParser():

    def __init__(self) -> None:
        self.binary_kws = set()
        self.unary_kws = set()
        self.arg2const = set()

    def process_params(self, params_ls):
        return [param[0] for param in params_ls]

    def process_logic_pred(self, logic_ls):

        if logic_ls[0] in bool_token:
            args = []

            # Negative leaf predicates
            if logic_ls[0] == 'not' and len(logic_ls) > 1 and not logic_ls[1][0] in bool_token:
                # Not equivalent
                if logic_ls[1][0] in kw_preds:
                    args = logic_ls[1][1:]
                    return InEqPred(not_kw_preds[logic_ls[1][0]], args)

                # Negtive predicates
                else:
                    args = logic_ls[1][1:]
                    return NegPredicate(logic_ls[1][0], args)

            # Ordinary boolean formula
            else:
                for element in logic_ls[1:]:
                    b = self.process_logic_pred(element)
                    args.append(b)
                return BoolFormula(logic_ls[0], args)

        # Equvalent ?
        elif logic_ls[0] in kw_preds:
            raise Exception("No direct equlency should be used")
        elif not logic_ls[0][0].isalpha():
            raise Exception('Error: predicate should start with alphabets')
        # Positive predicates
        else:
            args = logic_ls[1:]
            if len(args) == 1:
                self.unary_kws.add(logic_ls[0])
            elif len(args) == 2 and not logic_ls[0] == 'name':
                self.binary_kws.add(logic_ls[0])
            elif logic_ls[0] == 'name':
                assert len(logic_ls) == 3
                self.arg2const[logic_ls[1]] = logic_ls[2] 
                
            return PosPredicate(logic_ls[0], args)

    def parse(self, pddl_prog: str) -> PDDLProg:
        no_comment_prog = ""
        for line in pddl_prog.split('\n'):
            # if ';' in line:
            #     print(line)
            if len(line.strip()) == 0 or line.strip()[0] == ';':
                continue
            no_comment_prog += (line)

        result = nestedExpr('(',')').parseString(no_comment_prog).asList()
        assert len(result) == 1
        result = result[0]

        prog = []
        for tokens in result:
            current_prog_dict = {}
            key = None

            if type(tokens) == str or tokens[0] == 'domain':
                continue

            for token in tokens:
                if type(token) == str and token[0] == ':':
                    current_prog_dict[token[1:]] = []
                    key = token[1:]
                else:
                    current_prog_dict[key].append(token)

            prog.append(current_prog_dict)

        predicates = [(p[0].replace('-', '_'), len([i for i in p if i[0] == '?'])) for p in prog[3]['predicates']]

        types = process_type_defs(prog[1]['types'])
        consts = process_type_defs(prog[2]['constants'])
        actions = []
        axioms = []
        v_dict = {}

        for prog_single in prog[4:]:
            if 'action' in prog_single:
                name=prog_single['action']
                assert len(name) == 1
                name = name[0]
                param=process_type_defs(prog_single['parameters'])

                preprocessed_precondition=string_var_transition(prog_single['precondition'][0])
                precondition=self.process_logic_pred(preprocessed_precondition)

                preprocessed_effect= string_var_transition(prog_single['effect'][0])
                effect=self.process_logic_pred(preprocessed_effect)

                action = Action(name, param, precondition, effect)
                actions.append(action)

            if 'axiom' in prog_single:
                context=prog_single['context']
                implies=prog_single['implies']
                variables=process_type_defs(prog_single['vars'])
                axioms.append(Axiom(context=context, implies=implies, variables=variables))

        prog = PDDLProg(types, consts, predicates, actions, axioms)
        return prog

def replace_nested_list(l, from_words, to_words):
    """Return a copy of nested list ``l`` with words substituted.

    An element equal to ``from_words[i]`` (first match) becomes
    ``to_words[i]``; other strings are kept; sub-lists are processed
    recursively; any other element is dropped with a printed warning.
    """
    replaced = []
    for item in l:
        if isinstance(item, list):
            replaced.append(replace_nested_list(item, from_words, to_words))
            continue
        if item in from_words:
            replaced.append(to_words[from_words.index(item)])
            continue
        if type(item) == str:
            replaced.append(item)
            continue
        print("Warning: should not be here")
    return replaced
    
class UnseenParamException(Exception):
    """Signals a clause with an unknown parameter; flatten_arg_list drops such sub-clauses."""


class InvalidArgException(Exception):
    """Raised by flatten_arg_list when an argument group holds a non-string item."""


class InvalidPrediction(Exception):
    """Raised for malformed GPT output; the script skips that caption."""

def flatten_arg_list(l, args):
    """Flatten pyparsing output so each predicate's arguments sit inline.

    Walks the nested list ``l`` (as produced by ``nestedExpr(...).asList()``).
    String elements at a level are joined with spaces into one head token,
    placed first. A sub-list whose first element is in ``args`` or in
    ``const2cid`` is an argument group: its items are spliced inline and any
    item not in ``args`` is recorded as a constant. Any other non-empty
    sub-list is flattened recursively and appended as a nested list; if
    that recursive call raises, the sub-list is dropped with a message.

    Args:
        l: nested list of strings / lists.
        args: known argument names.

    Returns:
        ``(new_list, consts)``: the flattened list and the set of constant
        names collected from argument groups.

    Raises:
        InvalidArgException: if an argument group at this level contains a
            non-string item (only caught when this is a recursive call).
    """
    new_list = []
    consts = set()
    pred = []
    
    for e in l:
        if type(e) == str:
            # new_list.append(e)
            pred.append(e)
        if type(e) == list:
            if len(e) == 0:
                continue
            # Argument group: splice items inline, remember constants.
            if e[0] in args or e[0] in const2cid:
                for arg in e:
                    if not type(arg) == str:
                        raise InvalidArgException()
                    if arg not in args and arg not in consts:
                        consts.add(arg)
                new_list += e
                
            else:
                # Nested clause: flatten it, or drop it if it is invalid.
                try:
                    fl, cs = flatten_arg_list(e, args)
                    new_list.append(fl)
                    for c in cs:
                        consts.add(c)
                except UnseenParamException:
                    # remove the predicates that include unseen params
                    print('Remove unseen missing params')
                except InvalidArgException:
                    print ('invalid arguments')
    
    # All string tokens of this level become a single head token.
    if not len(pred) == 0:
        new_list = [' '.join(pred)] + new_list 
        
    return new_list, consts

class GptCaption():
    """One caption's GPT annotation, parsed into per-time-stamp formulas.

    ``gpt_result`` must contain ``"sequential descriptions"`` or
    ``"programmatic"`` (used to find argument letters) and a
    ``"time stamps"`` mapping whose events each carry ``programmatic``,
    ``duration`` and ``video location``. Malformed input raises
    ``InvalidPrediction``.
    """

    def __init__(self, caption, gpt_result) -> None:

        self.caption = caption
        self.gpt_result = gpt_result
        self.parser = PDDLParser()
        # Constants seen while parsing; turned into a sorted list at the end
        # of process_time_step.
        self.consts = set()

        self.get_args()
        # process_time_step fills time_progs / time_location / time_duration
        # and returns None.
        self.time_steps = self.process_time_step()
        self.unary_kws = set()
        self.binary_kws = set()

    def get_args(self):
        """Collect the single-letter argument names used by the caption.

        Sets ``self.args`` to the sorted letters and ``self.arg_replace`` to
        the matching lower-cased ``?`` variables, in the same order.

        Raises:
            InvalidPrediction: if neither ``"sequential descriptions"`` nor
                ``"programmatic"`` is present, or a programmatic clause does
                not contain exactly one parenthesised argument list.
        """
        letters = set()
        if "sequential descriptions" in self.gpt_result:
            for sentence in self.gpt_result["sequential descriptions"]:
                for token in sentence.replace(",", " ").split(" "):
                    if len(token) == 1 and token.isalpha():
                        letters.add(token)
        elif "programmatic" in self.gpt_result:
            for clause in self.gpt_result["programmatic"]:
                groups = re.findall(r'\(.*\)', clause)
                if len(groups) != 1:
                    raise InvalidPrediction()
                for token in groups[0][1:-1].split(','):
                    # Strip so "f(a, b)" yields "b" as well as "a".
                    token = token.strip()
                    if len(token) == 1 and token.isalpha():
                        letters.add(token)
        else:
            raise InvalidPrediction()

        # Sorted (not set order) so args / arg_replace are reproducible.
        self.args = sorted(letters)
        self.arg_replace = ['?' + e.lower() for e in self.args]

    def parse_program(self, program):
        """Parse a list of clause strings into a flattened formula list.

        Returns None for an empty list. Otherwise the clauses are wrapped in
        ``(and ...)``, parsed with ``nestedExpr``, flattened, and argument
        letters are replaced by ``?`` variables; the elements of the single
        resulting expression are returned, minus unseen ``pred(<list>)``
        entries. Constants found are added to ``self.consts``.

        Raises:
            InvalidPrediction: if parsing does not yield exactly one expression.
        """
        if len(program) == 0:
            return None

        # Split each clause on " or " / " and " / ";" and a leading
        # "or(" / "and(" so every piece is a single predicate call.
        pieces = []
        for clause in program:
            parts = [i for i in re.split(r' or | and |^or\(|^and\(|;', clause) if not len(i) == 0]
            # "(p(x) and q(y))" splits into "(p(x)" / "q(y))": drop the
            # wrapping parentheses.
            if len(parts) > 1 and parts[0][0] == '(' and parts[-1][-1] == ')':
                parts[0] = parts[0][1:]
                parts[-1] = parts[-1][:-1]
            pieces += parts

        # Split "p(x),q(y)" into separate calls, restoring each ")".
        calls = []
        for piece in pieces:
            chunks = re.split(r'\),', piece)
            calls.extend(chunk + ')' for chunk in chunks[:-1])
            calls.append(chunks[-1])

        # "(and (p(x y))(q(y)))": commas become spaces, neq(...) -> not(= ...).
        # NOTE(review): flatten_arg_list joins an all-string list such as
        # ['=', 'x', 'y'] into the single token "= x y", so neq clauses may
        # not survive as two-argument atoms -- confirm.
        connect_programmatic = '(and (' + ')('.join(calls) + '))'
        connect_programmatic = connect_programmatic.replace('neq(', 'not( = ')
        connect_programmatic = connect_programmatic.replace(',', ' ')
        connect_programmatic = connect_programmatic.replace(' or ', ' and ')

        result = nestedExpr('(', ')').parseString(connect_programmatic).asList()
        result, consts = flatten_arg_list(result, self.args)
        result = replace_nested_list(result, self.args, self.arg_replace)
        for const in consts:
            self.consts.add(const)

        if len(result) != 1:
            raise InvalidPrediction()
        result = result[0]

        # Drop "pred(<list>)"-shaped entries whose head is not a boolean
        # connective; each entry is judged on its own.
        new_result = []
        for predicate in result:
            if (type(predicate) == list and len(predicate) == 2
                    and type(predicate[1]) == list
                    and predicate[0] not in bool_token):
                print(f"Unseen predicate: {predicate}")
                continue
            new_result.append(predicate)

        return new_result

    def process_time_step(self):
        """Build one formula per time stamp and record its duration/location.

        ``name(x, c)`` clauses from all time stamps map variable ``x`` to
        constant ``c``; that mapping is substituted into the other clauses
        and the ``name`` clauses themselves are dropped.

        Raises:
            InvalidPrediction: if ``time stamps`` is missing, an event lacks
                ``programmatic`` / ``duration`` / ``video location``, or an
                event's clause list is empty.
        """
        self.time_progs = []
        self.time_duration = []
        self.time_location = []
        if 'time stamps' not in self.gpt_result:
            raise InvalidPrediction()

        # First pass: validate and parse every event once, gathering the
        # variable -> constant mapping from all name(...) clauses.
        arg2const = {}
        parsed = []
        for event in self.gpt_result['time stamps'].values():
            if not ('programmatic' in event and 'duration' in event and 'video location' in event):
                raise InvalidPrediction()
            program = self.parse_program(event['programmatic'])
            if program is None:
                raise InvalidPrediction()
            for clause in program:
                if type(clause) == list and clause[0] == 'name':
                    arg2const[clause[1]] = clause[2]
            parsed.append((event, program))

        # Second pass: substitute constants, drop name(...) clauses, and
        # build the formula for each time stamp.
        for event, program in parsed:
            new_program = []
            for clause in program:
                if type(clause) == list:
                    if clause[0] == 'name':
                        continue
                    new_program.append(rec_sub_val(clause, arg2const))
                else:
                    new_program.append(clause)

            self.time_progs.append(self.parser.process_logic_pred(new_program))
            self.time_location.append(event['video location'])
            self.time_duration.append(event['duration'])

        # Sorted so constant ids (list positions used by process_arg) are
        # reproducible across runs.
        self.consts = sorted(self.consts)

    def to_scl(self):
        """Convert every time-stamp formula to SCL, each rooted at id 0.

        Returns ``(eids, facts)``: per time stamp, the id of its last
        expression and its facts dict.
        """
        scl_facts = []
        scl_eids = []

        for prog in self.time_progs:
            eid, facts = prog.to_scl(0, self.consts)
            scl_eids.append(eid)
            scl_facts.append(facts)

        return scl_eids, scl_facts

def process_gpt_spec(gpt_specs):
    """Convert each caption's GPT spec into per-time-stamp SCL ids and facts.

    Returns a dict keyed by caption with ``time_stamp_ids`` and
    ``time_stamp_facts`` lists. ``InvalidPrediction`` from a malformed spec
    propagates to the caller.
    """
    converted = {}
    for caption, spec in gpt_specs.items():
        ids, facts = GptCaption(caption, spec).to_scl()
        converted[caption] = {'time_stamp_ids': ids, 'time_stamp_facts': facts}
    return converted

if __name__ == "__main__":
    # Convert the cached GPT-4 specs for OpenPVSG into SCL facts (JSON) and
    # pickled GptCaption objects, written to the same nl2spec directory.
    data_dir = os.path.abspath(os.path.join(os.path.abspath(__file__), '../../../data/open_pvsg'))
    if not os.path.exists(data_dir):
        raise FileNotFoundError(data_dir)
    gpt_specs_path = os.path.join(data_dir, 'nl2spec', "open_pvsg_v2_gpt4_cache.json")
    action_scl_path = os.path.join(data_dir, 'nl2spec', 'gpt_specs_scl.json')
    action_path = os.path.join(data_dir, 'nl2spec', 'gpt_specs.pkl')

    with open(gpt_specs_path, 'r', encoding='utf-8') as f:
        gpt_specs = json.load(f)

    all_action_scl = {}
    all_action = {}

    for action, spec in gpt_specs.items():
        try:
            prog = GptCaption(action, spec)
        except InvalidPrediction:
            # Malformed GPT output for this caption: skip it.
            continue

        all_action[action] = prog

        scl_exp_ids, scl_facts = prog.to_scl()
        time_stamps = list(spec['time stamps'].values())
        all_action_scl[action] = {
            'time_stamp_ids': scl_exp_ids,
            'time_stamp_facts': scl_facts,
            'consts': prog.consts,
            'args': prog.args,
            'unary_kws': list(prog.parser.unary_kws),
            'binary_kws': list(prog.parser.binary_kws),
            'caption': action,
            'duration': [ts_info['duration'] for ts_info in time_stamps],
            'video location': [ts_info['video location'] for ts_info in time_stamps],
        }

    # Context managers guarantee both outputs are flushed and closed.
    with open(action_scl_path, 'w', encoding='utf-8') as f:
        json.dump(all_action_scl, f)
    with open(action_path, 'wb') as f:
        pickle.dump(all_action, f)
    print(f"Wrote {len(all_action_scl)} captions to {action_scl_path}")