import pickle
from tokens import count_token
import os

# Token budget for the text produced by one query; lookups stop gathering
# more code once their accumulated result reaches this many tokens.
MAX_TOKEN = 3000
# Directory (next to this module) holding the pickled code-info indexes,
# one `<entry_program>.pkl` file per program.
code_info_dir = '{}/cinfo'.format(os.path.dirname(os.path.abspath(__file__)))
class CodeBrowser:
    """Answer code-lookup queries against a pickled code-info index.

    The index is loaded from ``{code_info_dir}/{entry_program}.pkl``.  Most
    lookups stop collecting snippets once their accumulated text reaches
    ``MAX_TOKEN`` tokens (as measured by ``count_token``), so results stay
    roughly within one query's budget.
    """

    def __init__(self, entry_program):
        """Load the code-info index for *entry_program*.

        The pickle must hold a 7-tuple, unpacked into (shapes as used below):
          func_code_mp   -- {(name, ?, ?, class): value whose 6th field is code}
          called_by_mp   -- {function name: iterable of caller code}
          global_code_mp -- {name: iterable of definition code}
          global_mp      -- {name: iterable of code that uses the name}
          type_code_mp   -- {type name: iterable of definition code}
          unit_lst       -- {(path, startline, startcol, endline, endcol): (code, ty)}
          macro_mp       -- {macro name: iterable of definition code}

        NOTE(review): ``pickle.load`` can execute arbitrary code; only load
        index files produced by a trusted pipeline.
        """
        pkl_path = f'{code_info_dir}/{entry_program}.pkl'
        # `with` guarantees the file handle is closed once loading finishes.
        with open(pkl_path, 'rb') as pkl_file:
            (self.func_code_mp, self.called_by_mp, self.global_code_mp,
             self.global_mp, self.type_code_mp, self.unit_lst,
             self.macro_mp) = pickle.load(pkl_file)

    @staticmethod
    def _names_overlap(query, name):
        """Fuzzy name match: true when either name contains the other.

        Empty (or otherwise falsy) names never match, because ``'' in s`` is
        always true and an unnamed entry would otherwise match every query.
        """
        return bool(query) and bool(name) and (query in name or name in query)

    @staticmethod
    def _fallback_hit(kind, name, code):
        """Format a fallback lookup hit, or return '' when *code* is empty/None."""
        if not code:
            return ''
        return f'There is a {kind} named {name}, the code is shown below:\n{code}\n'

    def get_func_code(self, func_name, class_name):
        """Return the source of functions named *func_name*, or None.

        Entries whose class equals *class_name* are tried first (None matches
        entries stored without a class).  If none match, every function with
        that name is considered regardless of class.  Snippets are joined with
        newlines in the index's iteration order, and collection stops once the
        joined text reaches MAX_TOKEN tokens.
        """
        # A dict is used as an insertion-ordered set so the output is
        # deterministic (a plain set's order varies with hash randomization).
        found = {}
        for (fname, _, _, cname), (_, _, _, _, _, code) in self.func_code_mp.items():
            if fname != func_name or cname != class_name:
                continue
            if count_token(code) >= MAX_TOKEN:
                if found:
                    continue
                # Nothing collected yet: keep a character-truncated prefix
                # rather than return nothing for an oversized function.
                code = code[:MAX_TOKEN * 5]
            found[code] = None
            if count_token('\n'.join(found)) >= MAX_TOKEN:
                break
        if not found:
            # Relaxed pass ignoring the class; oversized bodies are skipped.
            for (fname, _, _, _), (_, _, _, _, _, code) in self.func_code_mp.items():
                if fname != func_name or count_token(code) >= MAX_TOKEN:
                    continue
                found[code] = None
                if count_token('\n'.join(found)) >= MAX_TOKEN:
                    break
        return '\n'.join(found) if found else None

    def get_called_by(self, func_name):
        """Return caller code of *func_name*, each snippet newline-terminated.

        Only exact name matches are used; an unknown name yields ''.  Snippets
        of MAX_TOKEN tokens or more are skipped, and collection stops once the
        result reaches MAX_TOKEN tokens.
        """
        ret = ''
        for caller_code in self.called_by_mp.get(func_name, ()):
            if count_token(caller_code) >= MAX_TOKEN:
                continue
            ret += caller_code + '\n'
            if count_token(ret) >= MAX_TOKEN:
                break
        return ret

    def get_global_var_used(self, var_name):
        """Return code that uses *var_name* (a global or class/struct member).

        Exact name match only; returns '' when nothing is known.  Uses the same
        skip-oversized / stop-at-budget rule as ``get_called_by``.
        """
        ret = ''
        for user_code in self.global_mp.get(var_name, ()):
            if count_token(user_code) >= MAX_TOKEN:
                continue
            ret += user_code + '\n'
            if count_token(ret) >= MAX_TOKEN:
                break
        return ret

    def get_global_var(self, var_name):
        """Return the definition code of *var_name* (a global or enum), or None.

        Falls back to a fuzzy substring match on names when there is no exact
        entry.  Snippets of MAX_TOKEN tokens or more are skipped, and
        collection stops once the result reaches MAX_TOKEN tokens.
        """
        if var_name in self.global_code_mp:
            candidate_sets = [self.global_code_mp[var_name]]
        else:
            candidate_sets = [
                code_set
                for vname, code_set in self.global_code_mp.items()
                if self._names_overlap(var_name, vname)
            ]
        ret = ''
        for code_set in candidate_sets:
            for code in code_set:
                if count_token(code) >= MAX_TOKEN:
                    continue
                ret += code + '\n'
                if count_token(ret) >= MAX_TOKEN:
                    break
            if count_token(ret) >= MAX_TOKEN:
                break
        return ret or None

    def get_type_code(self, type_name):
        """Return the class/struct/enum definition code for *type_name*, or None.

        An exact entry returns all of its definitions (uncapped).  Otherwise
        types whose names overlap *type_name* are concatenated until the
        result reaches MAX_TOKEN tokens.  As a last resort, an exact entry in
        the global-definition table is used with the usual budget rules.
        """
        ret = ''
        if type_name in self.type_code_mp:
            ret = '\n'.join(self.type_code_mp[type_name]) + '\n'
        else:
            for tyname, code_st in self.type_code_mp.items():
                if not self._names_overlap(type_name, tyname):
                    continue
                ret += '\n'.join(code_st) + '\n'
                # Fuzzy matches can be numerous; respect the query budget.
                if count_token(ret) >= MAX_TOKEN:
                    break
        if not ret and type_name in self.global_code_mp:
            for code in self.global_code_mp[type_name]:
                if count_token(code) >= MAX_TOKEN:
                    continue
                ret += code + '\n'
                if count_token(ret) >= MAX_TOKEN:
                    break
        return ret or None

    def get_code_unit(self, file_path_target, lineno):
        """Return ``(code, ty)`` for a code unit enclosing *lineno* of a file.

        ``ty`` is passed through from the index as-is.  When several units
        enclose the line, the longest one under MAX_TOKEN tokens wins.  If none
        fits, the shortest unit is returned truncated to ``4 * MAX_TOKEN``
        characters.

        Raises:
            AssertionError: if no unit covers the location.  It is raised
                explicitly so the check also holds under ``python -O``.
        """
        target = os.path.abspath(file_path_target)
        matches = [
            (code, ty)
            for (file_path, startline, _, endline, _), (code, ty) in self.unit_lst.items()
            if os.path.abspath(file_path) == target and startline <= lineno <= endline
        ]
        if not matches:
            raise AssertionError(
                f'no code unit covers {file_path_target}:{lineno}')
        # Longest first, so the first unit under budget is the largest that fits.
        matches.sort(key=lambda unit: len(unit[0]), reverse=True)
        for code, ty in matches:
            if count_token(code) < MAX_TOKEN:
                return code, ty
        code, ty = matches[-1]
        return code[:4 * MAX_TOKEN], ty

    def get_macro_def(self, macro_name):
        """Return the definitions of macro *macro_name* joined by newlines, or None.

        Definitions of MAX_TOKEN tokens or more are skipped, and collection
        stops once the joined text reaches MAX_TOKEN tokens.
        """
        found = {}  # insertion-ordered set, for deterministic output
        for code in self.macro_mp.get(macro_name, ()):
            if count_token(code) >= MAX_TOKEN:
                continue
            found[code] = None
            if count_token('\n'.join(found)) >= MAX_TOKEN:
                break
        return '\n'.join(found) if found else None

    @staticmethod
    def simplify_name(name):
        """Reduce a qualified or decorated name to its bare identifier.

        Keeps the part after the last ``::``, ``.`` and ``->`` (in that order),
        then drops any ``[...]`` subscript and ``<...>`` template suffix, e.g.
        ``'ns::Foo<int>'`` -> ``'Foo'`` and ``'obj->items[3]'`` -> ``'items'``.
        """
        if '::' in name:
            name = name.split('::')[-1].strip()
        if '.' in name:
            name = name.split('.')[-1].strip()
        if '->' in name:
            name = name.split('->')[-1].strip()
        if '[' in name:
            name = name.split('[')[0].strip()
        if '<' in name:
            name = name.split('<')[0].strip()
        return name

    def get_query_result(self, json_dct):
        """Answer a batch of lookup requests and return one combined text report.

        *json_dct* may contain any of these keys, each mapping to a list of
        names (presumably parsed from a JSON request; confirm with callers):

        - ``func_name``: function bodies (``Class::method`` selects the class).
        - ``called_name``: code of the callers of each function.
        - ``global_used`` / ``class_member_used``: code that uses each name.
        - ``global_def``: definitions of globals/enums.
        - ``type_name``: class/struct/enum definitions.

        When a primary lookup finds nothing, other kinds (function, type,
        macro) are tried under the same name, and any hits are reported.
        """
        parts = []
        for func_name in json_dct.get('func_name', ()):
            if '::' in func_name:
                class_name, func_name = func_name.split('::')[-2:]
            else:
                class_name = None
            func_name = CodeBrowser.simplify_name(func_name)
            func_code = self.get_func_code(func_name, class_name)
            if func_code:
                parts.append(f'The code of function {func_name} is shown below.\n'
                             + func_code + '\n')
                continue
            parts.append(f'There is no function named {func_name}.\n')
            parts.append(self._fallback_hit(
                'class/struct/enum', func_name, self.get_type_code(func_name)))
            parts.append(self._fallback_hit(
                'macro', func_name, self.get_macro_def(func_name)))

        for called_name in json_dct.get('called_name', ()):
            called_name = CodeBrowser.simplify_name(called_name)
            parts.append(f'The code that calls {called_name} is shown below:\n')
            parts.append(self.get_called_by(called_name) + '\n')

        used_names = [*json_dct.get('global_used', ()),
                      *json_dct.get('class_member_used', ())]
        for used_name in used_names:
            used_name = CodeBrowser.simplify_name(used_name)
            parts.append(f'The code that used {used_name} is shown below:\n')
            parts.append(self.get_global_var_used(used_name))

        for global_def in json_dct.get('global_def', ()):
            global_def = CodeBrowser.simplify_name(global_def)
            global_code = self.get_global_var(global_def)
            if global_code:
                parts.append(f'The code of global variable {global_def} is shown below:\n'
                             + global_code + '\n')
                continue
            parts.append(f'There is no global variable named {global_def}.\n')
            parts.append(self._fallback_hit(
                'function', global_def, self.get_func_code(global_def, None)))
            parts.append(self._fallback_hit(
                'macro', global_def, self.get_macro_def(global_def)))
            parts.append(self._fallback_hit(
                'class/struct/enum', global_def, self.get_type_code(global_def)))

        for type_name in json_dct.get('type_name', ()):
            type_name = CodeBrowser.simplify_name(type_name)
            ty_code = self.get_type_code(type_name)
            if ty_code:
                parts.append(f'The code of class/struct/enum {type_name} is shown below:\n'
                             + ty_code + '\n')
                continue
            parts.append(f'There is no class/struct/enum named {type_name}.\n')
            parts.append(self._fallback_hit(
                'function', type_name, self.get_func_code(type_name, None)))
            parts.append(self._fallback_hit(
                'macro', type_name, self.get_macro_def(type_name)))
        return ''.join(parts)
        
        
if __name__ == '__main__':
    # Manual smoke run against the prebuilt index for the libpng read fuzzer:
    # print the fuzzer entry point, then the callers of png_handle_sCAL.
    browser = CodeBrowser('libpng_read_fuzzer')
    demo_queries = (
        {'func_name': ['LLVMFuzzerTestOneInput']},
        {'called_name': ['png_handle_sCAL']},
    )
    for query in demo_queries:
        print(browser.get_query_result(query))