from pyparsing import nestedExpr
from utils import bool_token, kw_preds, not_kw_preds, var2vid, rec_sub_val, all_binary_preds
import os 
import json
import re 
import pickle
from functools import reduce

##########################################################################
##############                                              ##############
##############         STSL language building blocks        ##############
##############                                              ##############
##########################################################################

   
########################################################################## 
##############                Helper Functions              ##############
##########################################################################

def process_arg_num(raw_arg, consts):
    """Map a raw argument token to its numeric id.

    Variables (``?name``) are looked up in ``var2vid``; constants are encoded
    as negative ids: the first entry of ``consts`` is -1, the second -2, etc.
    """
    if raw_arg[0] != '?':
        return -(consts.index(raw_arg) + 1)
    return var2vid[raw_arg[1:]]

def process_arg_str(raw_arg):
    """Return ``raw_arg`` with a leading ``?`` variable marker removed, if present."""
    return raw_arg[1:] if raw_arg[0] == '?' else raw_arg

def process_arg(raw_arg, consts, num=True):
    """Convert a raw argument token to a numeric id (``num`` truthy) or a bare string."""
    return process_arg_num(raw_arg, consts) if num else process_arg_str(raw_arg)

def merge_dict_ls(d1, d2):
    """Extend the lists in ``d1`` with the lists in ``d2`` in place and return ``d1``.

    Keys absent from ``d1`` start as an empty list before being extended.
    """
    for key, extra in d2.items():
        d1.setdefault(key, [])
        d1[key] += extra
    return d1

   
########################################################################## 
##############                 STSL Language                ##############
##########################################################################

class BoolFormula():
    """A boolean connective applied to a list of sub-formulas.

    ``name`` is the connective token (``to_scl``/``to_str`` only accept
    ``"and"``) and ``args`` are child formula objects, e.g. ``BoolFormula``,
    ``PosPredicate``, ``NegPredicate`` or ``InEqPred`` instances.
    ``v`` is the union of the children's ``v`` (variable) sets.
    """

    def __init__(self, name, args) -> None:
        self.name = name
        self.args = args
        # set().union() with no arguments yields an empty set, covering the no-child case.
        self.v = set().union(*[arg.v for arg in args])

    def __eq__(self, other):
        """Return True if ``other`` is a BoolFormula with the same connective and
        the same children, matched in any order."""
        if type(other) is not type(self):
            return False
        if self.name != other.name or len(self.args) != len(other.args):
            return False
        # Children are treated as an unordered multiset: pair each child with
        # an equal, not-yet-used child of ``other``. Only same-type children are
        # compared so heterogeneous child classes never cross-compare.
        remaining = list(other.args)
        for arg in self.args:
            for idx, candidate in enumerate(remaining):
                if type(candidate) is type(arg) and arg == candidate:
                    del remaining[idx]
                    break
            else:
                return False
        return True

    def collect_preds(self):
        """Return the set of predicate names used anywhere in this formula."""
        all_preds = set()
        for arg in self.args:
            all_preds = all_preds.union(arg.collect_preds())
        return all_preds

    def collect_tuples(self):
        """Return the concatenation of every child's ``collect_tuples()`` list."""
        all_preds = []
        for arg in self.args:
            all_preds.extend(arg.collect_tuples())
        return all_preds

    def collect_kws(self):
        """Return ``{'unary': set, 'binary': set}`` of predicate names found in the children."""
        result = {'unary': set(), 'binary': set()}
        for arg in self.args:
            arg_kws = arg.collect_kws()
            result['unary'].update(arg_kws['unary'])
            result['binary'].update(arg_kws['binary'])
        return result

    def to_scl(self, expression_id, consts):
        """Lower each child to SCL facts, threading the expression id left to right.

        Returns ``(last_eid, facts)`` where ``facts`` maps relation names to lists
        of tuples merged from all children; an empty formula returns
        ``(expression_id, {})``.
        """
        expressions = {}
        if len(self.args) == 0:
            return expression_id, expressions

        current_eid = expression_id
        for arg in self.args:
            current_eid, arg_scl = arg.to_scl(current_eid, consts)
            expressions = merge_dict_ls(expressions, arg_scl)

        # Only conjunctions are supported when lowering to SCL.
        assert self.name == "and"

        return current_eid, expressions

    def to_str(self, consts):
        """Render as a right-nested ``And(a, And(b, c))`` string.

        An empty formula renders as ``""``; a single child renders as itself.
        """
        expression_strs = [arg.to_str(consts) for arg in self.args]

        if not expression_strs:
            return ""

        expressions = reduce(lambda accu, elem: f"And({elem}, {accu})", reversed(expression_strs))

        # Only conjunctions are supported in the string form.
        assert self.name == "and"

        return expressions
    
    
class PosPredicate():
    """A positive (non-negated) atom ``name(arg, ...)``.

    ``args`` is a list of strings; those starting with ``?`` are variables and
    are collected into ``v``. Dashes in the predicate name become underscores.
    """

    def __init__(self, pred, args) -> None:
        if not isinstance(args, list):
            raise InvalidArgException
        if not all(isinstance(arg, str) for arg in args):
            raise InvalidArgException

        self.name = pred.replace('-', '_')
        self.args = args
        # startswith also tolerates an empty-string argument without IndexError.
        self.v = {arg for arg in args if arg.startswith('?')}

    def __eq__(self, other):
        """Return True only for another PosPredicate with the same name and args."""
        # Exact type check: a NegPredicate/InEqPred sharing name/args must not
        # compare equal to a positive atom.
        if type(other) is not type(self):
            return False
        return self.name == other.name and self.args == other.args

    def collect_preds(self):
        """Return ``{name}``."""
        return {self.name}

    def collect_tuples(self):
        """Return ``[(True, name, args)]``; True marks the atom as positive."""
        return [(True, self.name, self.args)]

    def collect_kws(self):
        """Classify the name as unary or binary by arity; other arities yield nothing."""
        if len(self.args) == 1:
            return {'unary': [self.name], 'binary': []}
        if len(self.args) == 2:
            return {'unary': [], 'binary': [self.name]}
        # Invalid arity: contribute no keywords.
        return {'unary': [], 'binary': []}

    def to_scl(self, expression_id, consts):
        """Emit a ``positive_unary_atom``/``positive_binary_atom`` fact.

        The fact is ``(new_eid, name, arg ids..., expression_id)`` where
        ``new_eid`` comes from the module-level ``exp_gen``. Other arities are
        skipped with a warning. Returns ``(new_eid, facts)``.
        """
        expressions = {}
        current_eid = expression_id
        next_eid = exp_gen.get()

        if len(self.args) == 1:
            expressions['positive_unary_atom'] = [
                (next_eid, self.name,
                 process_arg(self.args[0], consts),
                 current_eid)]
        elif len(self.args) == 2:
            expressions['positive_binary_atom'] = [
                (next_eid, self.name,
                 process_arg(self.args[0], consts),
                 process_arg(self.args[1], consts),
                 current_eid)]
        else:
            print(f"Warning: ignoring {self.name}({', '.join(self.args)})")

        return next_eid, expressions

    def to_str(self, consts):
        """Render as ``Unary("name",Var(i))`` / ``Binary("name",Const(-j),...)``.

        Raises:
            InvalidArgException: if an argument id is 0 or the arity is not 1 or 2.
        """
        wrapped_args = [f"\"{self.name}\""]

        for arg in self.args:
            arg_id = process_arg(arg, consts)
            if arg_id < 0:
                wrapped_args.append(f"Const({arg_id})")
            elif arg_id > 0:
                wrapped_args.append(f"Var({arg_id})")
            else:
                raise InvalidArgException()

        if len(self.args) == 1:
            logic_pred = "Unary"
        elif len(self.args) == 2:
            logic_pred = "Binary"
        else:
            raise InvalidArgException

        return f"{logic_pred}({','.join(wrapped_args)})"
    
class NegPredicate():
    """A negated atom ``not name(arg, ...)``.

    ``args`` is a list of strings; those starting with ``?`` are variables and
    are collected into ``v``. Dashes in the predicate name become underscores.
    """

    def __init__(self, pred, args) -> None:
        # Same validation as PosPredicate: args must be a list of strings.
        if not isinstance(args, list):
            raise InvalidArgException
        if not all(isinstance(arg, str) for arg in args):
            raise InvalidArgException

        self.name = pred.replace('-', '_')
        self.args = args
        self.v = {arg for arg in args if arg.startswith('?')}

    def __eq__(self, other):
        """Return True only for another NegPredicate with the same name and args."""
        # Exact type check: a positive atom with the same name/args is not equal.
        if type(other) is not type(self):
            return False
        return self.name == other.name and self.args == other.args

    def collect_preds(self):
        """Return ``{name}``."""
        return {self.name}

    def collect_tuples(self):
        """Return ``[(False, name, args)]``; False marks the atom as negated."""
        return [(False, self.name, self.args)]

    def to_scl(self, expression_id, consts):
        """Emit a ``negative_unary_atom``/``negative_binary_atom`` fact.

        The fact is ``(new_eid, name, arg ids..., expression_id)`` where
        ``new_eid`` comes from the module-level ``exp_gen``. Only arity 1 or 2
        is accepted (asserted). Returns ``(new_eid, facts)``.
        """
        expressions = {}
        current_eid = expression_id
        next_eid = exp_gen.get()

        if len(self.args) == 1:
            expressions['negative_unary_atom'] = [
                (next_eid, self.name,
                 process_arg(self.args[0], consts), current_eid)]
        else:
            assert len(self.args) == 2
            expressions['negative_binary_atom'] = [
                (next_eid, self.name,
                 process_arg(self.args[0], consts),
                 process_arg(self.args[1], consts),
                 current_eid)]

        return next_eid, expressions

    def collect_kws(self):
        """Classify the name as unary or binary by arity; other arities yield nothing."""
        if len(self.args) == 1:
            return {'unary': [self.name], 'binary': []}
        if len(self.args) == 2:
            return {'unary': [], 'binary': [self.name]}
        # Invalid arity: contribute no keywords.
        return {'unary': [], 'binary': []}

    def to_str(self, consts):
        """Render as ``NegUnary("name",...)`` / ``NegBinary("name",...)``.

        Raises:
            InvalidArgException: if an argument id is 0 or the arity is not 1 or 2.
        """
        wrapped_args = [f"\"{self.name}\""]
        for arg in self.args:
            arg_id = process_arg(arg, consts)
            if arg_id < 0:
                wrapped_args.append(f"Const({arg_id})")
            elif arg_id > 0:
                wrapped_args.append(f"Var({arg_id})")
            else:
                raise InvalidArgException()

        if len(self.args) == 1:
            logic_pred = "NegUnary"
        elif len(self.args) == 2:
            logic_pred = "NegBinary"
        else:
            raise InvalidArgException

        # Exactly one closing paren (the original emitted an unbalanced "))").
        return f"{logic_pred}({','.join(wrapped_args)})"

class InEqPred():
    """An inequality constraint between two arguments, built by
    ``process_logic_pred`` from a negated keyword predicate."""

    def __init__(self, pred, args) -> None:
        assert pred == '!='
        self.args = args
        self.v = {arg for arg in args if arg.startswith('?')}

    def __eq__(self, other) -> bool:
        """Return True only for another InEqPred over the same arguments."""
        # Without the type check any object with matching ``args`` compared equal.
        if type(other) is not type(self):
            return False
        return self.args == other.args

    def collect_preds(self):
        """Inequalities name no predicate."""
        return set()

    def collect_tuples(self):
        """Inequalities contribute no predicate tuples."""
        return []

    def collect_kws(self):
        """Inequalities contribute no keywords.

        BoolFormula.collect_kws calls this on every child, so it must exist.
        """
        return {'unary': [], 'binary': []}

    def to_scl(self, expression_id, consts):
        """Emit an ``inequality_constraint`` fact ``(new_eid, id0, id1, expression_id)``."""
        current_eid = expression_id
        next_eid = exp_gen.get()
        expressions = {}
        expressions['inequality_constraint'] = [
            (next_eid,
             process_arg(self.args[0], consts),
             process_arg(self.args[1], consts),
             current_eid)]

        return next_eid, expressions

    def to_str(self, consts):
        """No string form is defined for inequality constraints in this module.

        Raise InvalidArgException (the exception the other ``to_str`` methods
        use for unrenderable input) instead of an AttributeError, so callers
        that skip unrenderable captions, as the ``__main__`` loop does, can do so.
        """
        raise InvalidArgException(
            f"inequality constraint {self.args} has no string-program form")
    
##########################################################################
##############                                              ##############
##############             STSL language Parser             ##############
##############              (OpenPVSG Version)              ##############
##############                                              ##############
##########################################################################

##########################################################################
##############               Utility Functions              ##############
##########################################################################

class UnseenParamException(Exception):
    """Raised for a predicate that refers to a parameter that was never declared."""

class InvalidArgException(Exception):
    """Raised when a predicate's arguments are malformed or cannot be rendered."""

class InvalidPrediction(Exception):
    """Raised when a GPT prediction is missing required fields or cannot be parsed."""

def replace_nested_list(l, from_words, to_words):
    """Return a renamed copy of the nested list ``l``.

    An element equal to ``from_words[i]`` (first match) becomes ``to_words[i]``.
    Other strings are kept, nested lists are processed recursively, and any
    other element is dropped after printing a warning.
    """
    result = []
    for item in l:
        if isinstance(item, list):
            result.append(replace_nested_list(item, from_words, to_words))
        elif item in from_words:
            position = from_words.index(item)
            result.append(to_words[position])
        elif type(item) != str:
            print("Warning: should not be here")
        else:
            result.append(item)
    return result

def flatten_arg_list(l, args):
    """Collapse a pyparsing ``nestedExpr`` result into flat predicate lists.

    At each nesting level of ``l``:

    * string elements are collected as name tokens and, if any were seen,
      joined with single spaces and placed first in the returned list (so a
      multi-word name becomes one token);
    * an empty sub-list is skipped;
    * a "leaf" sub-list (one containing no list) is treated as an argument
      list: its strings are spliced directly into the current level, and any
      that are not in ``args`` are recorded as constants;
    * any other sub-list is flattened recursively and appended as a single
      nested element.

    Args:
        l: nested list of strings, presumably from ``nestedExpr(...).asList()``.
        args: known argument names; other strings in an argument list are
            treated as constants.

    Returns:
        ``(new_list, consts)``: the flattened list and the set of constant
        names found at this level and below.

    Raises:
        InvalidArgException: if a leaf argument list contains a non-string.
        When the recursive call for a sub-list raises ``UnseenParamException``
        or ``InvalidArgException``, that sub-list is dropped with a printed
        message instead of propagating.
    """
    new_list = []
    consts = set()
    pred = []
    
    
    for e in l:
        
        # A "leaf" has no list inside it. For a string element this loop walks
        # its characters and always leaves is_leaf True (unused for strings).
        is_leaf = True
        for ei in e:
            if type(ei) == list:
                is_leaf = False
                break
            
        if type(e) == str:
            # Name tokens; joined and prepended after the loop.
            # new_list.append(e)
            pred.append(e)
        if type(e) == list:
            if len(e) == 0:
                continue
            
            if is_leaf:
                # Argument list: splice into this level and record constants.
                for arg in e:
                    if not type(arg) == str:
                        raise InvalidArgException()
                    if arg not in args and arg not in consts:
                        consts.add(arg)
                new_list += e
                
            else:
                try:
                    fl, cs = flatten_arg_list(e, args)
                    new_list.append(fl)
                    for c in cs:
                        consts.add(c)
                except UnseenParamException:
                    # remove the predicates that include unseen params
                    # NOTE(review): nothing visible in this file raises UnseenParamException.
                    print('Remove unseen missing params')
                except InvalidArgException:
                    # Drop the malformed sub-list and keep going (best effort).
                    print ('invalid arguments')
    
    if not len(pred) == 0:
        new_list = [' '.join(pred)] + new_list 
        
    return new_list, consts
   
########################################################################## 
##############                Helper Functions              ##############
##########################################################################

class ExpGen:
    """Counter that hands out increasing expression ids.

    ``get`` pre-increments, so a fresh ``ExpGen()`` first returns 1.
    """

    def __init__(self, ct=0):
        self.ct = ct

    def get(self):
        """Advance the counter and return its new value."""
        self.ct = self.ct + 1
        return self.ct

    def reset(self):
        """Set the counter back to 0 (not to the constructor's start value)."""
        self.ct = 0


# Shared id generator used by the predicates' to_scl methods.
exp_gen = ExpGen()

def process_logic_pred(logic_ls):
    """Build a formula object from a parsed, nested-list logic expression.

    The head element ``logic_ls[0]`` selects the node type:

    * a boolean token (in ``bool_token``):
      - ``['not', [p, args...]]``, where ``p`` is not itself a boolean token,
        becomes an ``InEqPred`` (via ``not_kw_preds``) if ``p`` is a keyword
        predicate (``kw_preds``), otherwise a ``NegPredicate``;
      - any other connective becomes a ``BoolFormula`` over the recursively
        processed children;
    * a keyword predicate outside ``not`` is rejected;
    * anything else is a ``PosPredicate`` named by the head.

    Raises:
        InvalidArgException: if ``logic_ls`` is not a non-empty list.
        Exception: if a keyword predicate is used directly, or a predicate
            name does not start with a letter.
    """
    # Validate the shape before indexing: previously the list check ran only
    # after logic_ls[0][0] had already been read (mis-reading string characters)
    # and an empty list failed with an unrelated IndexError.
    if not isinstance(logic_ls, list) or len(logic_ls) == 0:
        raise InvalidArgException

    head = logic_ls[0]

    if head in bool_token:
        # Negated leaf predicate: (not (name args...)).
        if head == 'not' and len(logic_ls) > 1 and logic_ls[1][0] not in bool_token:
            inner_head, inner_args = logic_ls[1][0], logic_ls[1][1:]
            if inner_head in kw_preds:
                # Negated keyword predicate -> inequality constraint (presumably
                # the (not (= a b)) that GptCaption.parse_program makes from neq(...)).
                return InEqPred(not_kw_preds[inner_head], inner_args)
            return NegPredicate(inner_head, inner_args)

        # Ordinary boolean connective over sub-formulas.
        return BoolFormula(head, [process_logic_pred(element) for element in logic_ls[1:]])

    if head in kw_preds:
        raise Exception("No direct equality should be used")
    if not head[0].isalpha():
        raise Exception('Error: predicate should start with alphabets')

    # Positive predicate: (name args...).
    return PosPredicate(head, logic_ls[1:])

##########################################################################
##############           Parser for GPT Caption             ##############
##########################################################################

class GptCaption():
    """Parse one GPT-produced specification for a caption into formulas.

    ``gpt_result`` must contain a ``"time stamps"`` mapping whose values each
    carry ``"programmatic"``, ``"duration"`` and ``"video location"`` entries.
    It must also contain either ``"sequential descriptions"`` or
    ``"programmatic"`` at the top level, which are used to discover
    single-letter argument names. Missing pieces raise ``InvalidPrediction``.
    """

    def __init__(self, caption, gpt_result) -> None:
        self.caption = caption
        self.gpt_result = gpt_result
        # Accumulated as a set while parsing; process_time_step turns it into a list.
        self.consts = set()

        self.get_args()
        # process_time_step fills time_progs/time_duration/time_location and returns None.
        self.time_steps = self.process_time_step()
        self.unary_kws = set()
        self.binary_kws = set()

    def get_args(self):
        """Collect single-letter argument names into ``args`` and their ``?x`` forms
        into ``arg_replace`` (same order).

        Raises:
            InvalidPrediction: if neither source field is present, or a
                top-level programmatic clause lacks exactly one ``(...)`` group.
        """
        args = set()
        if "sequential descriptions" in self.gpt_result:
            for sentence in self.gpt_result["sequential descriptions"]:
                for token in sentence.replace(",", " ").split(" "):
                    # The pronoun "I" is not an argument.
                    if len(token) == 1 and token.isalpha() and token != 'I':
                        args.add(token)
        elif "programmatic" in self.gpt_result:
            for clause in self.gpt_result["programmatic"]:
                groups = re.findall(r'\(.*\)', clause)
                if len(groups) != 1:
                    raise InvalidPrediction()
                for token in groups[0][1:-1].split(','):
                    # Strip so "f(A, B)" yields "B" rather than " B".
                    token = token.strip()
                    if len(token) == 1 and token.isalpha():
                        args.add(token)
        else:
            raise InvalidPrediction()

        self.args = list(args)
        self.arg_replace = ['?' + e.lower() for e in self.args]

    def parse_program(self, program):
        """Parse one time stamp's list of clause strings into nested predicate lists.

        Returns ``None`` for an empty ``program``. Otherwise returns the elements of
        the single top-level ``and`` expression (the string ``'and'`` followed by
        one list per clause), with argument letters replaced by their ``?x`` forms.
        Constants found along the way are added to ``self.consts``.

        Raises:
            InvalidPrediction: if the parsed expression does not have a single root.
        """
        if len(program) == 0:
            return None

        # 1. Split compound clauses on " or ", " and ", a leading "or(" / "and("
        #    and ";", dropping a wrapping pair of parentheses when present.
        split_clauses = []
        for clause in program:
            pieces = [piece for piece in re.split(r' or | and |^or\(|^and\(|;', clause) if piece]
            if len(pieces) > 1 and pieces[0][0] == '(' and pieces[-1][-1] == ')':
                pieces[0] = pieces[0][1:]
                pieces[-1] = pieces[-1][:-1]
            split_clauses += pieces

        # 2. Split clauses written back-to-back as "a(X),b(Y)" at each "),".
        clean_prog = []
        for clause in split_clauses:
            parts = re.split(r'\),', clause)
            clean_prog.extend(part + ')' for part in parts[:-1])
            clean_prog.append(parts[-1])

        # 3. Wrap as one s-expression "(and (c1)(c2)...)" and normalise syntax:
        #    neq(...) -> not( = ...), commas -> spaces, and any remaining " or "
        #    is treated as " and ", like the split in step 1.
        connect_programmatic = '(and (' + ')('.join(clean_prog) + '))'
        connect_programmatic = connect_programmatic.replace('neq(', 'not( = ')
        connect_programmatic = connect_programmatic.replace(',', ' ')
        connect_programmatic = connect_programmatic.replace(' or ', ' and ')

        result = nestedExpr('(', ')').parseString(connect_programmatic).asList()
        result, consts = flatten_arg_list(result, self.args)
        result = replace_nested_list(result, self.args, self.arg_replace)
        self.consts.update(consts)

        if len(result) != 1:
            raise InvalidPrediction

        # 4. Drop "<head> [...]" pairs whose head is not a boolean token. Each
        #    predicate is judged on its own; previously one bad predicate caused
        #    every later one to be dropped as well.
        new_result = []
        for predicate in result[0]:
            if (isinstance(predicate, list) and len(predicate) == 2
                    and isinstance(predicate[1], list)
                    and predicate[0] not in bool_token):
                print(f"Unseen predicate: {predicate}")
                continue
            new_result.append(predicate)

        return new_result

    def process_time_step(self):
        """Parse every time stamp into a formula and record its duration/location.

        Fills ``time_progs``, ``time_duration`` and ``time_location`` with one
        entry per time stamp, in mapping order, then converts ``consts`` to a list.
        Clauses headed by ``'name'`` are removed. Those with exactly three
        elements also add a ``clause[1] -> clause[2]`` entry to a mapping that
        is passed to ``rec_sub_val`` for every remaining list clause (presumably
        substituting bound names).

        Raises:
            InvalidPrediction: if ``"time stamps"`` is missing, or a time stamp lacks
                ``programmatic``/``duration``/``video location`` or has an empty program.
        """
        self.time_progs = []
        self.time_duration = []
        self.time_location = []
        if 'time stamps' not in self.gpt_result:
            raise InvalidPrediction()

        # First pass: validate each event, parse its program once, and gather
        # the name bindings declared anywhere in the caption.
        events = list(self.gpt_result['time stamps'].values())
        programs = []
        arg2const = {}
        for event in events:
            if not ('programmatic' in event and 'duration' in event and 'video location' in event):
                raise InvalidPrediction()
            program = self.parse_program(event['programmatic'])
            if program is None:
                # Empty clause list; previously crashed with TypeError iterating None.
                raise InvalidPrediction()
            programs.append(program)
            for clause in program:
                if isinstance(clause, list) and clause and clause[0] == 'name' and len(clause) == 3:
                    arg2const[clause[1]] = clause[2]

        # Second pass: drop name clauses, apply the bindings and build one
        # formula per time stamp.
        for event, program in zip(events, programs):
            new_program = []
            for clause in program:
                if not isinstance(clause, list):
                    new_program.append(clause)
                elif not (clause and clause[0] == 'name'):
                    new_program.append(rec_sub_val(clause, arg2const))

            self.time_progs.append(process_logic_pred(new_program))
            self.time_location.append(event['video location'])
            self.time_duration.append(event['duration'])

        self.consts = list(self.consts)

    def collect_kws(self):
        """Return ``{'unary': [...], 'binary': [...]}`` of unique predicate names across time stamps."""
        unary, binary = set(), set()
        for prog in self.time_progs:
            prog_kws = prog.collect_kws()
            unary.update(prog_kws['unary'])
            binary.update(prog_kws['binary'])
        return {'unary': list(unary), 'binary': list(binary)}

    def to_scl(self):
        """Lower each time stamp's formula to SCL facts, each starting from expression id 0.

        Returns ``(eids, facts)``: parallel lists with one entry per time stamp.
        """
        scl_facts = []
        scl_eids = []

        for prog in self.time_progs:
            eid, facts = prog.to_scl(0, self.consts)
            scl_eids.append(eid)
            scl_facts.append(facts)

        return scl_eids, scl_facts

    def to_str(self):
        """Render the time stamps as a right-nested temporal program string.

        Example shape: ``Until(Finally(Logic(e1)), Until(Finally(Logic(e2)), Finally(Logic(e3))))``.

        Raises:
            InvalidPrediction: if there are no time stamps.
        """
        if len(self.time_progs) == 0:
            raise InvalidPrediction

        scl_prog_strs = [f"Finally(Logic({prog.to_str(self.consts)}))" for prog in self.time_progs]
        return reduce(lambda accu, elem: f"Until({elem}, {accu})", reversed(scl_prog_strs))

def process_gpt_spec(gpt_specs):
    """Lower every caption's GPT spec to SCL facts.

    Returns ``{caption: {'time_stamp_ids': [...], 'time_stamp_facts': [...]}}``;
    exceptions from ``GptCaption`` propagate.
    """
    all_actions = {}
    for action, spec in gpt_specs.items():
        scl_exp_ids, scl_facts = GptCaption(action, spec).to_scl()
        all_actions[action] = {
            'time_stamp_ids': scl_exp_ids,
            'time_stamp_facts': scl_facts,
        }
    return all_actions

if __name__ == "__main__":
    # Convert cached GPT specifications into program strings plus keyword lists
    # and write them as JSON next to the input.
    data_dir = os.path.abspath(os.path.join(os.path.abspath(__file__), '../../../data/open_pvsg'))
    if not os.path.exists(data_dir):
        raise FileNotFoundError(f"data directory not found: {data_dir}")

    gpt_specs_path = os.path.join(data_dir, 'nl2spec', "videollamav2_origcap_2_gpt_cache.json")
    # Alternative output used for other caption sources:
    #   os.path.join(data_dir, 'nl2spec', 'gpt_specs_prog_str.json')
    action_scl_path = os.path.join(data_dir, 'nl2spec', 'gpt_specs_videollamav2_prog_str.json')

    with open(gpt_specs_path, 'r', encoding='utf-8') as spec_file:
        gpt_specs = json.load(spec_file)

    all_action_scl = {}

    for action, spec in gpt_specs.items():
        try:
            prog = GptCaption(action, spec)
            prog_string = prog.to_str()
            prog_kws = prog.collect_kws()
        except (InvalidPrediction, InvalidArgException):
            # Skip captions whose GPT output cannot be turned into a program.
            continue

        time_stamps = spec['time stamps'].values()
        all_action_scl[action] = {
            'unary_kws': prog_kws['unary'],
            'binary_kws': prog_kws['binary'],
            'prog': prog_string,
            'consts': prog.consts,
            'args': prog.args,
            'caption': action,
            'duration': [ts_info['duration'] for ts_info in time_stamps],
            'video location': [ts_info['video location'] for ts_info in time_stamps],
        }

    # The context manager guarantees the output is flushed and closed.
    with open(action_scl_path, 'w', encoding='utf-8') as out_file:
        json.dump(all_action_scl, out_file)
    print(f"Wrote {len(all_action_scl)} programs to {action_scl_path}")