import hashlib
import os
import tempfile
import json

def exec_python_code(python_code):
    """Execute ``python_code`` and return the value of its ``generate_buf()``.

    The source is executed in a fresh namespace; it must define a
    zero-argument callable named ``generate_buf``. While the code runs,
    ``sys.stdout``/``sys.stderr``/``sys.stdin`` are pointed at the null
    device (in text mode, so ``print()`` in the candidate code works) and
    the process working directory is changed to ``/tmp``.

    NOTE(security): this ``exec``s arbitrary code with full interpreter
    privileges. Only call it on trusted code or inside a sandbox.

    Args:
        python_code: Python source text to execute.

    Returns:
        Whatever ``generate_buf()`` returns, or ``None`` if executing the
        source or calling ``generate_buf`` raised (including ``SystemExit``),
        or if ``generate_buf`` is not defined.
    """
    import sys

    saved_streams = (sys.stdout, sys.stderr, sys.stdin)
    with open(os.devnull, 'w') as null_out, \
            open(os.devnull, 'w') as null_err, \
            open(os.devnull, 'r') as null_in:
        sys.stdout, sys.stderr, sys.stdin = null_out, null_err, null_in
        try:
            # The working directory is intentionally left at /tmp afterwards,
            # as before; callers normally run this in a throwaway process.
            os.chdir('/tmp')
            # An explicit dict is used instead of locals(): since Python 3.13
            # every locals() call returns an independent snapshot, so names
            # defined by exec() would be lost.
            namespace = {}
            exec(python_code, namespace)
            return namespace['generate_buf']()
        except (Exception, SystemExit):
            # Best effort: any failure of the candidate code means "no buffer".
            return None
        finally:
            # Restore the real streams before the null-device handles close.
            sys.stdout, sys.stderr, sys.stdin = saved_streams

def _exec_worker(conn, python_code):
    """Child-process entry point: run the code and send the result back."""
    try:
        conn.send(exec_python_code(python_code))
    except Exception:
        # The result could not be pickled/sent; report it as a failure.
        conn.send(None)
    finally:
        conn.close()


def exec_with_timeout(python_code, timeout=60):
    """Run ``exec_python_code`` in a child process with a hard time limit.

    A dedicated process is used (rather than a pool) so that a runaway
    candidate can actually be killed: shutting down a process pool waits
    for the running task to finish, which would defeat the timeout.

    Args:
        python_code: Python source text defining ``generate_buf()``.
        timeout: Maximum number of seconds to wait for the result;
            ``None`` waits indefinitely.

    Returns:
        The value returned by ``exec_python_code``, or ``None`` if the
        time limit expired, the child died without reporting a result, or
        the result could not be transferred back.
    """
    import multiprocessing

    recv_end, send_end = multiprocessing.Pipe(duplex=False)
    worker = multiprocessing.Process(target=_exec_worker,
                                     args=(send_end, python_code))
    worker.start()
    # Drop the parent's copy of the write end so that a crashed child shows
    # up as EOF on the read end instead of blocking until the timeout.
    send_end.close()
    try:
        if not recv_end.poll(timeout):
            return None
        try:
            return recv_end.recv()
        except (EOFError, OSError):
            return None
    finally:
        if worker.is_alive():
            worker.terminate()
            worker.join(1)
            if worker.is_alive():
                # The child ignored SIGTERM; force it down.
                worker.kill()
        worker.join()
        recv_end.close()

# Scratch directory used for intermediate coverage profile files.
tmp_dir = '/tmp'
# LLVM 14 coverage tools, referenced by absolute path.
llvm_cov = '/usr/lib/llvm-14/bin/llvm-cov'
llvm_profdata = '/usr/lib/llvm-14/bin/llvm-profdata'
# Instrumented binaries are looked up in "<parent of this file's directory>/bin".
_module_dir = os.path.dirname(os.path.abspath(__file__))
gcov_bin_dir = os.path.dirname(_module_dir) + '/bin'
def get_covered_regions(prog, in_file='./a.in', path_prefix_len=39):
    """Run an instrumented binary on ``in_file`` and collect executed regions.

    The binary ``{gcov_bin_dir}/{prog}`` is run with ``LLVM_PROFILE_FILE``
    pointing at a per-input profile in ``tmp_dir``; the raw profile is then
    merged with ``llvm-profdata`` and exported as JSON with ``llvm-cov``.
    All tools are invoked with argument lists (no shell), so input paths
    containing spaces or shell metacharacters are passed through verbatim.
    The intermediate profile files are removed before returning.

    Args:
        prog: File name of the instrumented binary inside ``gcov_bin_dir``.
        in_file: Path of the input file passed to the binary as its only
            argument.
        path_prefix_len: Number of leading characters dropped from every
            source file name in the report (presumably the length of the
            build-root prefix of the instrumented build -- TODO confirm).

    Returns:
        A set of ``(file_name, line_start, col_start, line_end, col_end)``
        tuples, one per region with a non-zero execution count. An empty
        set is returned when a tool cannot be started or the coverage
        export is not valid JSON.
    """
    import subprocess

    cov_binary = f'{gcov_bin_dir}/{prog}'
    # The hash only serves to derive a unique scratch-file name per input.
    in_file_hash = hashlib.md5(in_file.encode('utf-8')).hexdigest()
    profraw = f'{tmp_dir}/{prog}.{in_file_hash}.profraw'
    profdata = f'{tmp_dir}/{prog}.{in_file_hash}.profdata'

    def remove_profiles():
        for prof in (profraw, profdata):
            try:
                os.remove(prof)
            except OSError:
                pass

    # Make sure a stale profile from an earlier run cannot be picked up.
    remove_profiles()
    try:
        # The target may crash or exit non-zero; that is not an error here,
        # it simply results in little or no coverage being reported.
        subprocess.run([cov_binary, in_file],
                       env={**os.environ, 'LLVM_PROFILE_FILE': profraw},
                       stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
        subprocess.run([llvm_profdata, 'merge', f'-output={profdata}', profraw],
                       stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
        export = subprocess.run(
            [llvm_cov, 'export', '--instr-profile', profdata,
             '-format=text', cov_binary],
            stdout=subprocess.PIPE, stderr=subprocess.DEVNULL)
        dct = json.loads(export.stdout)
    except (OSError, ValueError):
        # A tool could not be executed, or the export was empty/invalid.
        return set()
    finally:
        remove_profiles()

    covered_regs = set()
    for fdct in dct['data'][0]['functions']:
        file_names = fdct['filenames']
        for region in fdct['regions']:
            # Only the first six fields of a region record are used; any
            # trailing fields are ignored.
            line_st, col_st, line_ed, col_ed, exec_cnt, file_id = region[:6]
            if exec_cnt == 0:
                continue
            file_name = file_names[file_id]
            covered_regs.add(
                (file_name[path_prefix_len:], line_st, col_st, line_ed, col_ed))
    return covered_regs

def verify_result(problem, python_code):
    """Check whether the generated input makes the target region execute.

    ``python_code`` is executed (with a 60 second limit) to obtain an input
    buffer from its ``generate_buf()``. The buffer is written to a temporary
    file, the instrumented program is run on it, and the set of covered
    regions is searched for the target region. The temporary file is
    removed afterwards.

    Args:
        problem: Mapping with an ``'entry_program'`` key (name of the
            instrumented binary) and a ``'target_location'`` mapping that
            provides ``file_name``, ``start_lineno``, ``start_colno``,
            ``end_lineno`` and ``end_colno``.
        python_code: Python source text defining ``generate_buf()``.

    Returns:
        ``True`` if the target region was covered. ``False`` if it was not,
        or if the code failed, timed out, or did not produce a ``bytes`` /
        ``bytearray`` buffer.
    """
    target_location = problem['target_location']
    region = (target_location['file_name'],
              target_location['start_lineno'],
              target_location['start_colno'],
              target_location['end_lineno'],
              target_location['end_colno'],)
    prog = problem['entry_program']

    buf = exec_with_timeout(python_code, timeout=60)
    # None means failure/timeout; any other non-bytes value cannot be
    # written to the input file and is treated as a failed candidate too.
    if not isinstance(buf, (bytes, bytearray)):
        return False

    tmp_file = tempfile.NamedTemporaryFile(delete=False)
    try:
        # Close the file before the target program reads it.
        with tmp_file:
            tmp_file.write(buf)
        return region in get_covered_regions(prog, tmp_file.name)
    finally:
        try:
            os.remove(tmp_file.name)
        except OSError:
            pass
    
# Manual smoke test: checks two candidate generators against one libpng
# coverage problem and prints the two verification results.
if __name__ == '__main__':
    # Problem description; verify_result() reads 'entry_program' and 'target_location' from it.
    problem = {"entry_program": "libpng_read_fuzzer", "project": "libpng", "repo_url": "https://github.com/glennrp/libpng.git", "commit": "cd0ea2a7f53b603d3d9b5b891c779c430047b39a", "problem_id": 234, "target_location": {"file_name": "libpng/libpng/pngrutil.c", "func_name": "png_handle_sCAL", "start_lineno": 2396, "start_colno": 4, "end_lineno": 2400, "end_colno": 5, "code": "      png_crc_finish(png_ptr, length);\n      png_chunk_benign_error(png_ptr, \"duplicate\");\n      return;\n", "surround_code": "\n   else if (info_ptr != NULL && (info_ptr->valid & PNG_INFO_sCAL) != 0)\n   {\n      png_crc_finish(png_ptr, length);\n      png_chunk_benign_error(png_ptr, \"duplicate\");\n      return;\n   }\n\n", "discovered_time": 4798}}
    # Candidate generator that, per its name, is expected to reach the target region.
    code_correct = "import zlib\n\ndef generate_buf():\n    # PNG signature\n    png_sig = bytes([0x89, 0x50, 0x4E, 0x47, 0x0D, 0x0A, 0x1A, 0x0A])\n    \n    # IHDR chunk (1x1 RGB)\n    ihdr_data = bytes([\n        0x00, 0x00, 0x00, 0x01,  # width\n        0x00, 0x00, 0x00, 0x01,  # height\n        0x08,  # bit depth\n        0x02,  # color type (RGB)\n        0x00,  # compression\n        0x00,  # filter\n        0x00   # interlace\n    ])\n    ihdr_type = b'IHDR'\n    ihdr_crc = zlib.crc32(ihdr_type + ihdr_data).to_bytes(4, 'big')\n    ihdr_chunk = bytes([0x00, 0x00, 0x00, 0x0D]) + ihdr_type + ihdr_data + ihdr_crc\n    \n    # First valid sCAL chunk\n    scal1_data = bytes([0x01]) + b'1.0\\x001.0'\n    scal1_type = b'sCAL'\n    scal1_crc = zlib.crc32(scal1_type + scal1_data).to_bytes(4, 'big')\n    scal1_chunk = bytes([0x00, 0x00, 0x00, 0x08]) + scal1_type + scal1_data + scal1_crc\n    \n    # Second sCAL chunk (duplicate)\n    scal2_data = bytes([0x01, 0x00, 0x00, 0x00])\n    scal2_type = b'sCAL'\n    scal2_crc = zlib.crc32(scal2_type + scal2_data).to_bytes(4, 'big')\n    scal2_chunk = bytes([0x00, 0x00, 0x00, 0x04]) + scal2_type + scal2_data + scal2_crc\n    \n    # IDAT chunk (minimal valid zlib data for 1x1 RGB)\n    idat_data = zlib.compress(b'\\x00\\x00\\x00\\x00\\x00')\n    idat_type = b'IDAT'\n    idat_crc = zlib.crc32(idat_type + idat_data).to_bytes(4, 'big')\n    idat_chunk = len(idat_data).to_bytes(4, 'big') + idat_type + idat_data + idat_crc\n    \n    # IEND chunk\n    iend_chunk = bytes([0x00, 0x00, 0x00, 0x00]) + b'IEND' + zlib.crc32(b'IEND').to_bytes(4, 'big')\n    \n    return png_sig + ihdr_chunk + scal1_chunk + scal2_chunk + idat_chunk + iend_chunk"
    # Candidate generator that, per its name, is expected NOT to reach the target region.
    code_wrong = "import zlib\ndef generate_buf():\n    # PNG signature\n    png = bytearray(b'\\x89PNG\\r\\n\\x1a\\n')\n    \n    # IHDR chunk (1x1 RGBA)\n    ihdr_data = bytes.fromhex('00000001 00000001 08020000 00')\n    ihdr_type = b'IHDR'\n    ihdr_crc = zlib.crc32(ihdr_type + ihdr_data).to_bytes(4, 'big')\n    png.extend((len(ihdr_data)).to_bytes(4, 'big') + ihdr_type + ihdr_data + ihdr_crc)\n    \n    # First valid sCAL chunk (unit=1, width=\"1\", height=\"1\")\n    scal1_data = bytes.fromhex('0131003100')\n    scal_type = b'sCAL'\n    scal1_crc = zlib.crc32(scal_type + scal1_data).to_bytes(4, 'big')\n    png.extend((len(scal1_data)).to_bytes(4, 'big') + scal_type + scal1_data + scal1_crc)\n    \n    # Second sCAL chunk (duplicate, minimal data)\n    scal2_data = bytes.fromhex('01000000')\n    scal2_crc = zlib.crc32(scal_type + scal2_data).to_bytes(4, 'big')\n    png.extend((len(scal2_data)).to_bytes(4, 'big') + scal_type + scal2_data + scal2_crc)\n    \n    # IDAT chunk (minimal data)\n    idat_data = bytes([0])\n    idat_type = b'IDAT'\n    idat_crc = zlib.crc32(idat_type + idat_data).to_bytes(4, 'big')\n    png.extend((len(idat_data)).to_bytes(4, 'big') + idat_type + idat_data + idat_crc)\n    \n    # IEND chunk\n    iend_type = b'IEND'\n    iend_crc = zlib.crc32(iend_type).to_bytes(4, 'big')\n    png.extend((0).to_bytes(4, 'big') + iend_type + iend_crc)\n    \n    return bytes(png)"
    # Prints the verification result of each candidate, in the order above.
    print(verify_result(problem, code_correct),
          verify_result(problem, code_wrong))