from pyparsing import nestedExpr, Word, alphas
import pyparsing as pp

from utils import bool_token, kw_preds, not_kw_preds, var2vid, rec_sub_val, all_binary_preds
import os
import json
import re
import pickle
from functools import reduce

##########################################################################
##############                                              ##############
##############         STSL language building blocks        ##############
##############                                              ##############
##########################################################################


##########################################################################
##############                Helper Functions              ##############
##########################################################################

def process_arg_num(raw_arg, consts):
    """Encode a raw argument token as an integer id.

    A token whose first character is ``?`` is a variable: the text after the
    ``?`` is looked up in ``var2vid``. Any other token is a constant and is
    encoded as a negative number derived from its position in ``consts``
    (position 0 -> -1, position 1 -> -2, ...).

    Errors from the ``var2vid`` lookup or from ``consts.index`` (e.g. an
    unknown constant) propagate to the caller.
    """
    if raw_arg[0] != '?':
        return -(consts.index(raw_arg) + 1)
    return var2vid[raw_arg[1:]]

def process_arg_str(raw_arg):
    """Return the textual form of an argument token.

    A leading ``?`` (variable marker) is dropped; every other token is
    returned unchanged.
    """
    return raw_arg[1:] if raw_arg[0] == '?' else raw_arg

def process_arg(raw_arg, consts, num=True):
    """Convert a raw argument token to its numeric or textual form.

    With ``num`` truthy the token is encoded by ``process_arg_num`` (which
    needs ``consts``); otherwise ``process_arg_str`` is used and ``consts``
    is ignored.
    """
    return process_arg_num(raw_arg, consts) if num else process_arg_str(raw_arg)

def merge_dict_ls(d1, d2):
    """Merge the list-valued mapping ``d2`` into ``d1`` in place.

    For every key of ``d2`` its values are appended to the list stored under
    the same key in ``d1``; missing keys start from an empty list.
    ``d1`` is mutated and also returned.
    """
    for key, values in d2.items():
        d1.setdefault(key, [])
        d1[key] += values
    return d1


##########################################################################
##############                 STSL Language                ##############
##########################################################################

class BoolFormula():
    """Boolean connective applied to a list of sub-formulas.

    ``name`` is the connective token and ``args`` the sub-formulas (other
    formulas or predicate objects). ``v`` is the union of the ``v`` sets of
    all arguments. ``to_scl`` and ``to_str`` assert that the connective is
    ``"and"`` whenever they produce a non-empty result.
    """

    def __init__(self, name, args) -> None:
        self.name = name
        self.args = args
        # Union of the variable sets of all arguments (empty set if no args).
        self.v = set().union(*(arg.v for arg in args))

    def __eq__(self, other):
        """Compare two formulas, ignoring the order of their arguments.

        Two formulas are equal when they share the connective name and their
        arguments can be paired one-to-one with equal partners.
        """
        if not isinstance(other, BoolFormula):
            return NotImplemented
        if self.name != other.name or len(self.args) != len(other.args):
            return False

        # Greedy one-to-one matching: every argument of ``self`` must consume
        # a distinct equal argument of ``other``.
        unmatched = list(other.args)
        for arg in self.args:
            for idx, candidate in enumerate(unmatched):
                if arg == candidate:
                    del unmatched[idx]
                    break
            else:
                return False
        return True

    # Defining __eq__ already makes instances unhashable; state it explicitly.
    __hash__ = None

    def collect_preds(self):
        """Return the set of predicate names used anywhere below this node."""
        all_preds = set()
        for arg in self.args:
            all_preds = all_preds.union(arg.collect_preds())
        return all_preds

    def collect_tuples(self):
        """Return the concatenated ``collect_tuples`` lists of all arguments."""
        collected = []
        for arg in self.args:
            collected += arg.collect_tuples()
        return collected

    def collect_kws(self):
        """Return ``{'unary': set, 'binary': set}`` of predicate names."""
        result = {'unary': set(), 'binary': set()}
        for arg in self.args:
            arg_kws = arg.collect_kws()
            result['unary'].update(arg_kws['unary'])
            result['binary'].update(arg_kws['binary'])
        return result

    def to_scl(self, expression_id, consts):
        """Collect the fact tuples of all arguments.

        The expression id returned by each argument is passed on as the
        starting id of the next one. Returns ``(last_id, facts)`` where
        ``facts`` maps a fact kind to a list of tuples. A formula without
        arguments returns ``(expression_id, {})``.
        """
        expressions = {}
        if len(self.args) == 0:
            return expression_id, expressions

        current_eid = expression_id
        for arg in self.args:
            current_eid, arg_scl = arg.to_scl(current_eid, consts)
            expressions = merge_dict_ls(expressions, arg_scl)

        # Should only have this case
        assert self.name == "and"

        return current_eid, expressions

    def to_str(self, consts):
        """Render the formula as right-nested ``And(a, And(b, c))`` text.

        Arguments whose ``to_str`` raises ``InvalidArgException`` are left
        out. Returns ``""`` when nothing could be rendered.
        """
        expression_strs = []
        for arg in self.args:
            try:
                arg_str = arg.to_str(consts)
            except InvalidArgException:
                continue
            expression_strs.append(arg_str)

        if len(expression_strs) == 0:
            return ""

        # Fold from the right so the first argument ends up outermost.
        expressions = reduce(
            lambda accu, elem: f"And({elem}, {accu})",
            reversed(expression_strs))

        # Should only have this case
        assert self.name == "and"

        return expressions


class PosPredicate():
    """Positive atom ``name(arg)`` or ``name(arg1, arg2)``.

    ``name`` is the predicate with ``-`` replaced by ``_``; ``args`` is the
    list of raw argument tokens; ``v`` is the set of tokens that start with
    ``?`` (variables).
    """

    def __init__(self, pred, args) -> None:
        """Raise ``InvalidArgException`` unless ``args`` is a list of strings."""
        if not isinstance(args, list):
            raise InvalidArgException

        for arg in args:
            if not isinstance(arg, str):
                raise InvalidArgException

        self.name = pred.replace('-', '_')
        self.args = args
        # startswith() also tolerates an empty token, unlike arg[0].
        self.v = {arg for arg in args if arg.startswith('?')}

    def __eq__(self, other):
        """Equal only to another positive atom with the same name and args."""
        # Without the type check a negative atom with the same name and
        # arguments would compare equal to this positive one.
        if not isinstance(other, PosPredicate):
            return False
        return self.name == other.name and self.args == other.args

    # Defining __eq__ already makes instances unhashable; state it explicitly.
    __hash__ = None

    def collect_preds(self):
        """Return a one-element set holding the predicate name."""
        return {self.name}

    def collect_tuples(self):
        """Return ``[(True, name, args)]``; ``True`` marks a positive atom."""
        return [(True, self.name, self.args)]

    def collect_kws(self):
        """Classify the predicate name as unary or binary by argument count."""
        if len(self.args) == 1:
            return {'unary': [self.name], 'binary': []}
        elif len(self.args) == 2:
            return {'unary': [], 'binary': [self.name]}
        else:
            # Invalid Case
            return {'unary': [], 'binary': []}

    def to_scl(self, expression_id, consts):
        """Emit the fact tuple for this atom.

        A fresh id is drawn from the module-level ``exp_gen``. The tuple is
        ``(new_id, name, arg..., expression_id)`` with the arguments encoded
        by ``process_arg``. Atoms that are neither unary nor binary emit
        nothing but still consume an id. Returns ``(new_id, facts)``.
        """
        expressions = {}
        current_eid = expression_id
        next_eid = exp_gen.get()

        if len(self.args) == 1:
            expressions['positive_unary_atom'] = [
                (next_eid, self.name,
                 process_arg(self.args[0], consts),
                 current_eid)]
        elif len(self.args) == 2:
            expressions['positive_binary_atom'] = [
                (next_eid, self.name,
                 process_arg(self.args[0], consts),
                 process_arg(self.args[1], consts),
                 current_eid)]
        else:
            print(f"Warning: ignoring {self.name}({', '.join(self.args)})")

        return next_eid, expressions

    def to_str(self, consts):
        """Render as ``Unary("name",X)`` or ``Binary("name",X,Y)``.

        Each argument becomes ``Const(id)`` for a negative id and ``Var(id)``
        for a positive id. Raises ``InvalidArgException`` for an id of 0 or
        for any other argument count.
        """
        wrapped_args = [f"\"{self.name}\""]

        for arg in self.args:
            arg = process_arg(arg, consts)
            if arg < 0:
                wrapped_args.append(f"Const({arg})")
            elif arg > 0:
                wrapped_args.append(f"Var({arg})")
            else:
                raise InvalidArgException()

        if len(self.args) == 1:
            logic_pred = "Unary"
        elif len(self.args) == 2:
            logic_pred = "Binary"
        else:
            raise InvalidArgException

        return f"{logic_pred}({','.join(wrapped_args)})"

class NegPredicate():
    """Negated atom ``not name(arg)`` or ``not name(arg1, arg2)``.

    ``name`` is the predicate with ``-`` replaced by ``_``; ``args`` is the
    list of raw argument tokens; ``v`` is the set of tokens that start with
    ``?`` (variables).
    """

    def __init__(self, pred, args) -> None:
        """Raise ``InvalidArgException`` unless ``args`` is a list of strings."""
        # Same validation as PosPredicate: a bare string would otherwise be
        # accepted and treated character by character.
        if not isinstance(args, list):
            raise InvalidArgException

        for arg in args:
            if not isinstance(arg, str):
                raise InvalidArgException

        self.name = pred.replace('-', '_')
        self.args = args
        # startswith() also tolerates an empty token, unlike arg[0].
        self.v = {arg for arg in args if arg.startswith('?')}

    def __eq__(self, other):
        """Equal only to another negative atom with the same name and args."""
        # Without the type check a positive atom with the same name and
        # arguments would compare equal to this negated one.
        if not isinstance(other, NegPredicate):
            return False
        return self.name == other.name and self.args == other.args

    # Defining __eq__ already makes instances unhashable; state it explicitly.
    __hash__ = None

    def collect_preds(self):
        """Return a one-element set holding the predicate name."""
        return {self.name}

    def collect_tuples(self):
        """Return ``[(False, name, args)]``; ``False`` marks a negated atom."""
        return [(False, self.name, self.args)]

    def to_scl(self, expression_id, consts):
        """Emit the fact tuple for this negated atom.

        A fresh id is drawn from the module-level ``exp_gen``. The tuple is
        ``(new_id, name, arg..., expression_id)`` with the arguments encoded
        by ``process_arg``. Anything that is not unary must be binary
        (asserted). Returns ``(new_id, facts)``.
        """
        expressions = {}
        current_eid = expression_id
        next_eid = exp_gen.get()

        if len(self.args) == 1:
            expressions['negative_unary_atom'] = [
                (next_eid, self.name,
                 process_arg(self.args[0], consts), current_eid)]
        else:
            assert(len(self.args) == 2)
            expressions['negative_binary_atom'] = [
                (next_eid, self.name,
                 process_arg(self.args[0], consts),
                 process_arg(self.args[1], consts),
                 current_eid)]

        return next_eid, expressions

    def collect_kws(self):
        """Classify the predicate name as unary or binary by argument count."""
        if len(self.args) == 1:
            return {'unary': [self.name], 'binary': []}
        elif len(self.args) == 2:
            return {'unary': [], 'binary': [self.name]}
        else:
            # Invalid Case
            return {'unary': [], 'binary': []}

    def to_str(self, consts):
        """Render as ``NegUnary("name",X)`` or ``NegBinary("name",X,Y)``.

        Each argument becomes ``Const(id)`` for a negative id and ``Var(id)``
        for a positive id. Raises ``InvalidArgException`` for an id of 0 or
        for any other argument count.
        """
        wrapped_args = [f"\"{self.name}\""]
        for arg in self.args:
            arg = process_arg(arg, consts)
            if arg < 0:
                wrapped_args.append(f"Const({arg})")
            elif arg > 0:
                wrapped_args.append(f"Var({arg})")
            else:
                raise InvalidArgException()

        if len(self.args) == 1:
            logic_pred = "NegUnary"
        elif len(self.args) == 2:
            logic_pred = "NegBinary"
        else:
            raise InvalidArgException

        # Exactly one closing parenthesis: a second one would unbalance every
        # formula string this atom is embedded in.
        return f"{logic_pred}({','.join(wrapped_args)})"

class InEqPred():
    """Inequality constraint between two arguments.

    Only the ``'!='`` predicate is accepted (asserted). ``v`` is the set of
    arguments whose first character is ``?``.

    NOTE(review): this class defines no ``to_str``, while
    ``BoolFormula.to_str`` calls ``to_str`` on every argument - confirm
    whether inequalities need a string form.
    """

    def __init__(self, pred, args) -> None:
        assert pred == '!='
        self.args = args
        self.v = {arg for arg in args if '?' == arg[0]}

    def __eq__(self, other) -> bool:
        """Equal only to another inequality over the same arguments."""
        # Without the type check any object carrying equal ``args`` (e.g. a
        # predicate atom) would compare equal to this constraint.
        if not isinstance(other, InEqPred):
            return False
        return self.args == other.args

    # Defining __eq__ already makes instances unhashable; state it explicitly.
    __hash__ = None

    def collect_preds(self):
        """An inequality contributes no predicate names."""
        return set()

    def collect_tuples(self):
        """An inequality contributes no predicate tuples."""
        return []

    def collect_kws(self):
        """An inequality contributes no unary or binary keywords.

        Provided so that ``BoolFormula.collect_kws``, which calls this method
        on every argument, also works when a formula contains an inequality.
        """
        return {'unary': [], 'binary': []}

    def to_scl(self, expression_id, consts):
        """Emit the ``inequality_constraint`` fact tuple.

        A fresh id is drawn from the module-level ``exp_gen``; the tuple is
        ``(new_id, arg1, arg2, expression_id)`` with both arguments encoded
        by ``process_arg``. Returns ``(new_id, facts)``.
        """
        current_eid = expression_id
        next_eid = exp_gen.get()
        expressions = {}
        expressions['inequality_constraint'] = [
            (next_eid,
             process_arg(self.args[0], consts),
             process_arg(self.args[1], consts),
             current_eid)]

        return next_eid, expressions

##########################################################################
##############                                              ##############
##############             STSL language Parser             ##############
##############              (OpenPVSG Version)              ##############
##############                                              ##############
##########################################################################

##########################################################################
##############               Utility Functions              ##############
##########################################################################

class UnseenParamException(Exception):
    """Signals a predicate that refers to a parameter that was never seen.

    ``flatten_arg_list`` catches this exception and drops the offending
    sub-expression.
    """

class InvalidArgException(Exception):
    """Signals malformed predicate arguments (wrong type, count or id)."""

class InvalidPrediction(Exception):
    """Signals a GPT result that lacks required fields or cannot be parsed."""

def replace_nested_list(l, from_words, to_words):
    """Return a copy of the nested list ``l`` with words substituted.

    Every element found in ``from_words`` is replaced by the entry at the
    same position in ``to_words``. Nested lists are processed recursively,
    other strings are copied unchanged, and any remaining element is dropped
    after printing a warning.
    """
    replaced = []
    for item in l:
        if isinstance(item, list):
            replaced.append(replace_nested_list(item, from_words, to_words))
            continue
        if item in from_words:
            replaced.append(to_words[from_words.index(item)])
            continue
        if type(item) == str:
            replaced.append(item)
            continue
        print("Warning: should not be here")
    return replaced

def flatten_arg_list(l, args):
    """Flatten a parsed nested expression and collect its constants.

    String elements of ``l`` (except ``','``) are joined with spaces into a
    single head token placed in front of the result. A list element that
    contains no further list is a leaf: its strings (except ``','``) are
    spliced directly into the result, and each one not found in ``args`` is
    recorded as a constant. Any other non-empty list is processed
    recursively and appended as one nested list; a sub-expression raising
    ``UnseenParamException`` or ``InvalidArgException`` is skipped with a
    printed message.

    Returns ``(flattened_list, consts)`` where ``consts`` is a set.
    Raises ``InvalidArgException`` if a leaf holds a non-string element.
    """
    flattened = []
    consts = set()
    head_words = []

    for element in l:
        # Evaluated for every element (strings are scanned per character),
        # mirroring the unconditional scan of the element's children.
        contains_list = any(type(child) == list for child in element)

        if type(element) == str and element != ',':
            head_words.append(element)

        if type(element) != list or len(element) == 0:
            continue

        if not contains_list:
            leaf_args = []
            for arg in element:
                if type(arg) != str:
                    raise InvalidArgException()
                if arg == ',':
                    continue
                if arg not in args:
                    consts.add(arg)
                leaf_args.append(arg)
            flattened += leaf_args
            continue

        try:
            sub_list, sub_consts = flatten_arg_list(element, args)
            flattened.append(sub_list)
            consts.update(sub_consts)
        except UnseenParamException:
            # remove the predicates that include unseen params
            print('Remove unseen missing params')
        except InvalidArgException:
            print ('invalid arguments')

    if head_words:
        flattened = [' '.join(head_words)] + flattened

    return flattened, consts

##########################################################################
##############                Helper Functions              ##############
##########################################################################

class ExpGen():
    """Counter that hands out increasing expression ids."""

    def __init__(self, ct=0):
        # Last id handed out; the first call to get() returns ct + 1.
        self.ct = ct

    def get(self):
        """Advance the counter by one and return the new value."""
        next_id = self.ct + 1
        self.ct = next_id
        return next_id

    def reset(self):
        """Set the counter back to 0 (regardless of the initial ``ct``)."""
        self.ct = 0


# Module-wide generator shared by the to_scl methods of the predicate classes.
exp_gen = ExpGen()

def process_logic_pred(logic_ls):
    """Turn a parsed nested list into formula / predicate objects.

    The first element of ``logic_ls`` decides the result:

    * a boolean token: ``not`` directly applied to a non-boolean expression
      yields an ``InEqPred`` (when that expression is a keyword predicate) or
      a ``NegPredicate``; anything else yields a ``BoolFormula`` whose
      arguments are the recursively processed remaining elements;
    * a keyword predicate used directly raises ``Exception``;
    * a head that does not start with a letter, or a non-list ``logic_ls``,
      raises ``InvalidArgException``;
    * otherwise a ``PosPredicate`` is built from the head and the rest.
    """
    head = logic_ls[0]

    if head in bool_token:
        # Negative leaf predicates
        if head == 'not' and len(logic_ls) > 1 and not logic_ls[1][0] in bool_token:
            negated = logic_ls[1]
            negated_args = negated[1:]
            if negated[0] in kw_preds:
                # Negated equality becomes an inequality constraint.
                return InEqPred(not_kw_preds[negated[0]], negated_args)
            # Negative predicates
            return NegPredicate(negated[0], negated_args)

        # Ordinary boolean formula
        sub_formulas = [process_logic_pred(element) for element in logic_ls[1:]]
        return BoolFormula(head, sub_formulas)

    # Direct equality is not supported.
    if head in kw_preds:
        raise Exception("No direct equlency should be used")

    if not head[0].isalpha():
        raise InvalidArgException()

    # Positive predicates
    if not type(logic_ls) == list:
        raise InvalidArgException
    return PosPredicate(head, logic_ls[1:])

##########################################################################
##############           Parser for GPT Caption             ##############
##########################################################################

class GptCaption():
    """Parsed form of one GPT-generated specification for a caption.

    Construction normalises ``gpt_result`` (see ``get_args``), extracts the
    single-letter argument names, and parses every entry of
    ``gpt_result['time stamps']`` into a formula object.

    Attributes set during construction:
        caption, gpt_result: the constructor arguments (``gpt_result`` may be
            extended in place by ``get_args``).
        args: sorted list of single-letter argument names.
        arg_replace: ``'?' + name.lower()`` for each entry of ``args``.
        consts: sorted list of constants met while parsing.
        time_progs / time_duration / time_location: one entry per time stamp.
        time_steps: return value of ``process_time_step`` (always ``None``).
    """

    def __init__(self, caption, gpt_result) -> None:

        self.caption = caption
        self.gpt_result = gpt_result
        self.consts = set()

        self.get_args()
        self.time_steps = self.process_time_step()
        self.unary_kws = set()
        self.binary_kws = set()

    def get_args(self):
        """Normalise ``gpt_result`` and collect the argument names.

        A result with neither ``"sequential descriptions"`` nor
        ``"programmatic"`` must be a single-event record (caption, location,
        relations/attributes, event length); it is rewritten in place into a
        one-entry ``'time stamps'`` mapping. Anything else in that situation
        raises ``InvalidPrediction``.

        Argument names are the single alphabetic tokens found either in the
        sequential descriptions (excluding ``I``) or between the parentheses
        of the top-level programmatic clauses.
        """
        result = self.gpt_result

        if not ("sequential descriptions" in result or "programmatic" in result):
            has_location = 'video_location' in result or 'location' in result
            has_facts = 'relations' in result or 'attributes' in result
            has_length = 'event_length' in result or 'event length' in result

            # Check if only one single time stamp is recorded
            if not ('caption' in result and has_location and has_facts and has_length):
                raise InvalidPrediction()

            new_prog = {}
            new_prog['programmatic'] = []
            if 'relations' in result:
                new_prog['programmatic'] += result['relations']
            if 'attributes' in result:
                new_prog['programmatic'] += result['attributes']

            if 'video_location' in result:
                new_prog['video location'] = result['video_location']
            elif 'location' in result:
                new_prog['video location'] = result['location']

            if 'event_length' in result:
                new_prog['duration'] = result['event_length']
            elif 'event length' in result:
                new_prog['duration'] = result['event length']

            # Key spelling kept as is; it is part of the stored data.
            new_prog['decription'] = result['caption']

            assert not 'time stamps' in result
            result['time stamps'] = {}
            result['time stamps']['1'] = new_prog

            if not "sequential descriptions" in result:
                result["sequential descriptions"] = [result['caption']]

        names = set()
        if "sequential descriptions" in result:
            for sentence in result["sequential descriptions"]:
                cleaned = sentence.replace(",", " ").replace("'s", " ")
                for token in cleaned.split(" "):
                    if len(token) == 1 and token.isalpha() and not token == 'I':
                        names.add(token)
        else:
            for clause in result["programmatic"]:
                tokens = re.findall(r'\(.*\)', clause)
                assert len(tokens) == 1
                # NOTE(review): tokens are not stripped, so an argument that
                # follows ", " (comma + space) is not picked up - confirm.
                for token in tokens[0][1:-1].split(','):
                    if len(token) == 1 and token.isalpha():
                        names.add(token)

        # Sorted so that the argument order does not depend on set iteration
        # order (which varies between interpreter runs for strings).
        self.args = sorted(names)
        self.arg_replace = ['?' + e.lower() for e in self.args]

    def parse_program(self, program):
        """Parse a list of clause strings into a nested predicate list.

        The clauses are split on ``and`` / ``or`` / ``;`` and on ``),``,
        joined under a single ``and``, parsed with pyparsing, flattened, and
        argument names are replaced by their ``?x`` variable form. Constants
        found on the way are added to ``self.consts``.

        Returns ``None`` for an empty ``program``; otherwise the list
        ``['and', pred, pred, ...]`` without list entries of the shape
        ``[head, [...]]`` whose head is not a boolean token.

        Raises:
            InvalidPrediction: if the text does not yield exactly one
                top-level expression.
        """
        if len(program) == 0:
            return None

        # Split every clause on connectives and drop empty pieces.
        pieces = []
        for clause in program:
            parts = [part
                     for part in re.split(r' or | and |^or\(|^and\(|;', clause)
                     if part]
            # A split group wrapped in one pair of parentheses loses them.
            if len(parts) > 1 and parts[0][0] == '(' and parts[-1][-1] == ')':
                parts[0] = parts[0][1:]
                parts[-1] = parts[-1][:-1]
            pieces += parts

        # Split "p(a),q(b)" into separate predicates, restoring the ')'.
        predicates = []
        for piece in pieces:
            chunks = re.split(r'\),', piece)
            predicates.extend(chunk + ')' for chunk in chunks[:-1])
            predicates.append(chunks[-1])

        connect_programmatic = '(and(' + ')('.join(predicates) + '))'
        connect_programmatic = connect_programmatic.replace('neq(', 'not( = ')
        connect_programmatic = connect_programmatic.replace(' or ', 'and')

        argument = pp.Word(pp.alphanums + pp.pyparsing_unicode.Latin1.alphas + '_' + ' '+ '-' + '\'' + '[' + ']') | pp.Word(',')
        result = nestedExpr('(',')', content=argument).parseString(connect_programmatic).asList()

        result, consts = flatten_arg_list(result, self.args)
        result = replace_nested_list(result, self.args, self.arg_replace)

        for const in consts:
            self.consts.add(const)

        if not len(result) == 1:
            raise InvalidPrediction

        result = result[0]

        kept = []
        for predicate in result:
            if type(predicate) == list:
                # A [head, [...]] entry is only valid for boolean tokens.
                # Checked per predicate so that one bad entry does not also
                # discard every list predicate that follows it.
                if (len(predicate) == 2 and type(predicate[1]) == list
                        and not predicate[0] in bool_token):
                    print(f"Unseen predicate: {predicate}")
                    continue
            kept.append(predicate)

        return kept

    def process_time_step(self):
        """Build one formula object per entry of ``gpt_result['time stamps']``.

        ``name(arg, const)`` clauses from all time stamps are gathered first
        and then substituted into every other clause via ``rec_sub_val``;
        the ``name`` clauses themselves are dropped. Fills ``time_progs``,
        ``time_location`` and ``time_duration`` and turns ``self.consts``
        into a sorted list. Returns ``None``.

        Raises:
            InvalidPrediction: if ``'time stamps'`` is missing, an event
                lacks ``programmatic`` / ``duration`` / ``video location``,
                or an event has an empty program.
        """
        self.time_progs = []
        self.time_duration = []
        self.time_location = []
        if not 'time stamps' in self.gpt_result:
            raise InvalidPrediction()

        # Pass 1: validate and parse every event, collecting name bindings
        # across all time stamps before any substitution takes place.
        arg2const = {}
        parsed_events = []
        for event in self.gpt_result['time stamps'].values():
            if not ('programmatic' in event and 'duration' in event and 'video location' in event):
                raise InvalidPrediction()

            program = self.parse_program(event['programmatic'])
            if program is None:
                # An empty program cannot be turned into a formula.
                raise InvalidPrediction()

            for clause in program:
                if type(clause) == list and clause[0] == 'name' and len(clause) == 3:
                    arg2const[clause[1]] = clause[2]
            parsed_events.append((event, program))

        # Pass 2: drop the name clauses, substitute the bindings, build formulas.
        for event, program in parsed_events:
            new_program = []
            for clause in program:
                if type(clause) != list:
                    new_program.append(clause)
                elif clause[0] != 'name':
                    new_program.append(rec_sub_val(clause, arg2const))

            time_prog = process_logic_pred(new_program)
            self.time_progs.append(time_prog)
            self.time_location.append(event['video location'])
            self.time_duration.append(event['duration'])

        # Sorted so that constant ids (positions in this list) are stable
        # from run to run.
        self.consts = sorted(self.consts)

    def collect_kws(self):
        """Return sorted unary and binary predicate names of all time stamps."""
        unary = set()
        binary = set()

        for prog in self.time_progs:
            prog_kws = prog.collect_kws()
            unary.update(prog_kws['unary'])
            binary.update(prog_kws['binary'])

        return {'unary': sorted(unary), 'binary': sorted(binary)}

    def to_scl(self):
        """Return ``(ids, facts)`` with one ``to_scl(0, consts)`` result per time stamp."""
        scl_facts = []
        scl_eids = []

        for prog in self.time_progs:
            eid, facts = prog.to_scl(0, self.consts)
            scl_eids.append(eid)
            scl_facts.append(facts)

        return scl_eids, scl_facts

    def to_str(self):
        """Render all time stamps as a right-nested ``Until`` chain.

        Each time stamp becomes ``Finally(Logic(...))``; three of them give
        ``Until(Finally(e1), Until(Finally(e2), Finally(e3)))``.

        Raises:
            InvalidPrediction: if there is no time stamp program.
        """
        if len(self.time_progs) == 0:
            raise InvalidPrediction()

        scl_prog_strs = [
            f"Finally(Logic({prog.to_str(self.consts)}))"
            for prog in self.time_progs
        ]

        # Fold from the right so the first time stamp ends up outermost.
        return reduce(
            lambda accu, elem: f"Until({elem}, {accu})",
            reversed(scl_prog_strs))

def process_gpt_spec(gpt_specs):
    """Convert every GPT spec into its SCL ids and facts.

    ``gpt_specs`` maps an action (caption) to its raw GPT result. The
    returned dict maps each action to ``{'time_stamp_ids': ...,
    'time_stamp_facts': ...}`` as produced by ``GptCaption.to_scl``.
    Exceptions raised while parsing a spec propagate to the caller.
    """
    all_actions = {}
    for action, spec in gpt_specs.items():
        scl_exp_ids, scl_facts = GptCaption(action, spec).to_scl()
        all_actions[action] = {
            'time_stamp_ids': scl_exp_ids,
            'time_stamp_facts': scl_facts,
        }

    return all_actions

if __name__ == "__main__":
    # Script entry point: convert the cached GPT specs into program strings
    # plus keyword / constant metadata and write them out as one JSON file.

    data_dir = os.path.abspath(os.path.join(os.path.abspath(__file__), '../../../data/open_pvsg'))
    assert os.path.exists(data_dir)
    gpt_specs_path = os.path.join(data_dir, 'nl2spec', "videollamav2_origcap_2_gpt_cache.json")
    action_scl_path = os.path.join(data_dir, 'nl2spec', 'gpt_specs_videollamav2_prog_str.json')

    # Context managers make sure both files are closed (and the output is
    # flushed) even if something fails in between.
    with open(gpt_specs_path, 'r') as specs_file:
        gpt_specs = json.load(specs_file)

    all_action_scl = {}
    all_action = {}
    total_spec_ct = 0
    wrong_pred_ct = 0
    wrong_arg_ct = 0

    for action, spec in gpt_specs.items():
        total_spec_ct += 1
        try:
            prog = GptCaption(action, spec)
            prog_string = prog.to_str()
            prog_kws = prog.collect_kws()

        except InvalidPrediction:
            # Spec is missing required fields or could not be parsed; skip it.
            wrong_pred_ct += 1
            continue
        except InvalidArgException:
            # Spec contains malformed predicate arguments; skip it.
            wrong_arg_ct += 1
            continue

        all_action[action] = prog

        all_action_scl[action] = {}

        all_action_scl[action]['unary_kws'] = prog_kws['unary']
        all_action_scl[action]['binary_kws'] = prog_kws['binary']

        all_action_scl[action]['prog'] = prog_string
        all_action_scl[action]['consts'] = prog.consts
        all_action_scl[action]['args'] = prog.args
        all_action_scl[action]['caption'] = action
        # GptCaption may have added 'time stamps' to ``spec`` in place.
        all_action_scl[action]['duration'] = [ts_info['duration'] for ts_info in spec['time stamps'].values()]
        all_action_scl[action]['video location'] = [ts_info['video location'] for ts_info in spec['time stamps'].values()]

    with open(action_scl_path, 'w') as out_file:
        json.dump(all_action_scl, out_file)

    # Report the counters gathered above instead of a bare debug marker.
    print(f"Processed {total_spec_ct} specs: {len(all_action_scl)} written, "
          f"{wrong_pred_ct} invalid predictions, {wrong_arg_ct} invalid arguments")