from idautils import *
from idaapi import *
from idc import *
from idc_bc695 import *
import ida_typeinf
import ida_nalt
import ida_auto
import sys
import idaapi
import idautils
import json

# The binary currently open in IDA; its knowledge-base JSON is looked up under
# the same name with any "strip_" marker removed.
filename = GetInputFile()
decom_buf = []

# Block until IDA's initial auto-analysis has finished.
idaapi.auto_wait()

# Load the ARM decompiler (the x86 'hexrays' / x64 'hexx64' plugins are not
# used by this script) and warn if Hex-Rays could not be initialised.
idaapi.load_plugin('hexarm')
hexrays_ok = idaapi.init_hexrays_plugin()
if not hexrays_ok:
    print('Unable to load Hex-rays')


orig_name = filename.replace('strip_', '')
kb_path = orig_name + '_kb.json'

# Ground-truth knowledge base produced for the unstripped binary.
with open(kb_path, 'r') as infile:
    kb = json.load(infile)


# Map each referencing function's name to the set of function start addresses
# it references. CodeRefsTo(ea, 0) is called with flow=0, so ordinary
# fall-through flow is not counted. Keys are the caller names as returned by
# idc.get_func_name (stringified); values are integer addresses.
callees = {}
for function_ea in idautils.Functions():
    for ref_ea in idautils.CodeRefsTo(function_ea, 0):
        caller_name = str(idc.get_func_name(ref_ea))
        callees.setdefault(caller_name, set()).add(function_ea)

def _restore_names(renamed):
    """Undo placeholder renames recorded as ``(ea, original_name)`` pairs.

    Restores in reverse (LIFO) order, so if the same address was renamed more
    than once (e.g. a function listed among its own callees), the name that
    was captured first is the one that ends up applied.
    """
    for ea, name in reversed(renamed):
        idaapi.set_name(ea, name, idaapi.SN_FORCE)


def _clear_user_lvars(cfunc):
    """Clear user-set names and types on every local variable of *cfunc*.

    Does nothing when *cfunc* is None (e.g. decompilation raised before a
    result was obtained).
    """
    if cfunc is None:
        return
    for lvar in cfunc.lvars:
        lvar.clr_user_name()
        lvar.clr_user_type()


def _apply_var_placeholder(lvar, idx, var_name, var_type, answer):
    """Rename *lvar* to ``VAR<idx>`` and retype it as a pointer to ``TYPE<idx>``.

    The ground-truth *var_name* / *var_type* are recorded in *answer* under
    the matching placeholder keys.
    """
    answer['VAR' + str(idx)] = var_name
    answer['TYPE' + str(idx)] = var_type
    lvar.name = 'VAR' + str(idx)
    tinfo = create_typedef('TYPE' + str(idx))
    lvar.set_lvar_type(make_pointer(tinfo))
    lvar.set_user_name()
    lvar.set_user_type()


def _rewrite_signature(line, type_placeholder):
    """Replace the return-type tokens of a pseudocode signature line.

    The first whitespace-separated token containing ``fastcall`` or ``FUNC1``
    is taken as the marker. Everything before it is collapsed into
    *type_placeholder*. If no marker is found, a warning is printed and the
    tokens are rejoined unchanged (whitespace is normalised either way).
    """
    comps = line.split()
    marker_idx = next(
        (i for i, comp in enumerate(comps)
         if 'fastcall' in comp or 'FUNC1' in comp),
        -1,
    )
    if marker_idx == -1:
        print('something wrong', line)
    else:
        # Same result as the old "pop until one token precedes the marker,
        # then overwrite it" loop. It also works when the marker is the very
        # first token, where that loop ran past the list and raised IndexError.
        comps = [type_placeholder] + comps[marker_idx:]
    return ' '.join(comps)


# For every function present in the knowledge base: rename it (and the callees
# the kb knows about) to FUNC<n> placeholders, rename/retype its variables to
# VAR<n>/TYPE<n>, capture the resulting pseudocode, then undo all the changes.
# ``answer`` maps each placeholder to its ground-truth value from ``kb``.
# NOTE(review): kb entries are assumed to look like
# {str(ea): {'name': ..., 'caller': {str(callee_ea): ...},
#            'var': {location: [name, type]}, 'type': ...}} -- inferred from
# the accesses below, confirm against the kb generator.
output = {}
for function_ea in idautils.Functions():
    func_key = str(function_ea)
    if func_key not in kb:
        print('no func addr', function_ea)
        continue
    entry = kb[func_key]

    # VAR<n> and TYPE<n> share one index; the function's own return type gets
    # the next free index after all variables.
    slot = 1
    func_cnt = 1
    answer = {}
    orig_f_name = idc.get_func_name(function_ea)

    idaapi.set_name(function_ea, 'FUNC' + str(func_cnt), idaapi.SN_FORCE)
    answer['FUNC' + str(func_cnt)] = entry['name']
    func_cnt += 1
    # Every rename performed for this function, so it can be undone exactly.
    renamed = [(function_ea, orig_f_name)]

    if orig_f_name in callees:
        for callee in callees[orig_f_name]:
            # json.load always yields str keys, so the int address must be
            # stringified (as every other kb lookup here already does).
            callee_key = str(callee)
            if callee_key not in entry['caller']:
                print('no function name', function_ea, callee)
                continue
            renamed.append((callee, idc.get_func_name(callee)))
            idaapi.set_name(callee, 'FUNC' + str(func_cnt), idaapi.SN_FORCE)
            # Record the answer under the same placeholder the callee received.
            answer['FUNC' + str(func_cnt)] = entry['caller'][callee_key]
            func_cnt += 1

    # Initialised up front so the except-path never sees an undefined or
    # stale (previous function's) cfunc.
    cfunc = None
    f = get_func(function_ea)
    try:
        cfunc = decompile(f)
        open_pseudocode(function_ea, 0)
        if cfunc is None:
            # No lvar changes were made; only the renames need undoing.
            _restore_names(renamed)
            continue

        for v in cfunc.arguments:
            reg = print_vdloc(v.location, 1)
            if reg not in entry['var']:
                print('no argument', v.location)
                continue
            var_name, var_type = entry['var'][reg]
            if var_name == "" or var_type == "":
                continue
            _apply_var_placeholder(v, slot, var_name, var_type, answer)
            slot += 1

        for v in cfunc.lvars:
            if v in cfunc.arguments:
                continue
            # Non-argument locals are keyed in kb by their location parsed as a
            # hex number (with '^' stripped). Locations that do not parse
            # (presumably register-allocated ones) are skipped.
            try:
                loc = int(print_vdloc(v.location, 1).replace('^', ''), 16)
            except Exception:
                continue
            if str(loc) not in entry['var']:
                print('no variable', function_ea, loc)
                continue
            var_name, var_type = entry['var'][str(loc)]
            if var_name == "" or var_type == "":
                continue
            _apply_var_placeholder(v, slot, var_name, var_type, answer)
            slot += 1

    except Exception as e:
        print(e)
        _restore_names(renamed)
        _clear_user_lvars(cfunc)
        print('error')
        continue

    lines = []
    for sline in cfunc.get_pseudocode():
        line = tag_remove(sline.line)
        # The first pseudocode line is the signature; swap its return type for
        # a placeholder when the kb knows the real type.
        if not lines and entry['type'] != '?':
            print('prev', line)
            answer['TYPE' + str(slot)] = entry['type']
            line = _rewrite_signature(line, 'TYPE' + str(slot))
        lines.append(line)

    decomp = '\n'.join(lines)

    _restore_names(renamed)
    _clear_user_lvars(cfunc)

    # Release the decompilation and recreate the function so no cached
    # decompiler state carries into later iterations.
    cfunc.reset()
    ida_funcs.del_func(function_ea)
    ida_funcs.add_func(function_ea)
    ida_auto.plan_and_wait(function_ea, function_ea + 1)

    output[orig_f_name] = {'funcbody': decomp, 'answer': answer}

# Write the collected samples to "<input file>.json" and terminate IDA.
print(filename)
out_path = filename + '.json'
with open(out_path, 'w') as outfile:
    json.dump(output, outfile, indent=2)

print('Done')
ida_pro.qexit(0)
