import hashlib
import json
import os
import random
import re
import subprocess

from cpp_demangle import demangle
from tqdm import tqdm

def extract_pure_function_name(func_name: str) -> str:
    """Reduce a (possibly mangled) C++ symbol to its bare function identifier.

    The symbol is demangled, a trailing ``const`` qualifier and the argument
    list are removed, the qualification is dropped by splitting on ``::``
    outside of any ``()``/``<>`` nesting, trailing template arguments are
    stripped from the last segment, and finally everything up to the last
    space (e.g. a return type) is discarded.

    Args:
        func_name: Symbol name to reduce.

    Returns:
        The bare identifier. When demangling or parsing raises, ``func_name``
        is returned unchanged (deliberate best-effort fallback).
    """
    try:
        demangled_name = demangle(func_name).strip()
        # Strip "const" only when it is a standalone trailing qualifier.
        # A plain endswith('const') would also truncate identifiers that
        # merely end in those letters (e.g. "is_const" -> "is_").
        if demangled_name.endswith((' const', ')const')):
            demangled_name = demangled_name[:-5].strip()
        no_args = re.sub(r'\([^()]*(?:\(.*\))?[^()]*\)$', '', demangled_name)

        parts = []
        current = []
        # Nesting depth of "(" / "<". Closers are not matched by kind, they
        # simply reduce the depth; "::" only separates scopes at depth 0.
        depth = 0
        i = 0
        n = len(no_args)
        while i < n:
            c = no_args[i]
            if c in '(<':
                depth += 1
            elif c in ')>' and depth:
                depth -= 1
            if depth == 0 and no_args.startswith('::', i):
                parts.append(''.join(current))
                current = []
                i += 2
                continue
            current.append(c)
            i += 1
        if current:
            parts.append(''.join(current))
        if not parts:
            return demangled_name

        # Drop a trailing template argument list from the last scope segment.
        pure_name = re.sub(r'<([^<>]*(<.*>)?[^<>]*)>$', '', parts[-1])
        # Keep only the text after the last space. split() returns the whole
        # string when there is no space, so no guard is needed.
        return pure_name.split(' ')[-1]
    except Exception:
        # Any failure (including an input that cannot be demangled) falls
        # back to the caller's original name.
        return func_name

# Binaries run for coverage are looked up in the "bin" directory that sits
# next to the directory containing this file.
_script_dir = os.path.dirname(os.path.abspath(__file__))
gcov_bin_dir = os.path.dirname(_script_dir) + '/bin'

# Directory that receives the per-input profile artefacts.
tmp_dir = '/tmp'

# LLVM 14 coverage tools.
llvm_cov = '/usr/lib/llvm-14/bin/llvm-cov'
llvm_profdata = '/usr/lib/llvm-14/bin/llvm-profdata'

def get_source_of_region(file_name, func_name, line_st, col_st, line_ed, col_ed):
    """Read the source lines of a region together with surrounding context.

    Args:
        file_name: Path of the source file to read.
        func_name: Not used by this function; kept for interface
            compatibility with existing callers.
        line_st: 1-based first line of the region.
        col_st: 1-based first column of the region on ``line_st``.
        line_ed: 1-based last line of the region.
        col_ed: 1-based end column on ``line_ed`` (exclusive).

    Returns:
        A tuple ``(linenos, lines, srd_lines)``:

        * ``linenos`` - line numbers inside the region whose in-region
          segment is longer than 4 characters,
        * ``lines`` - the complete text of those lines,
        * ``srd_lines`` - every line from ``line_st - 2`` to ``line_ed + 2``
          (clipped to the file).

        All three lists are empty when the file cannot be opened.
    """
    ret_linenos, ret_lines, srd_lines = [], [], []
    try:
        # errors='replace' keeps one undecodable byte in a source file from
        # aborting the whole run.
        f = open(file_name, 'r', errors='replace')
    except OSError:
        return ret_linenos, ret_lines, srd_lines

    with f:
        # Iterate lazily so reading stops right after the context window.
        for lineno, line in enumerate(f, start=1):
            # Stop only once the two trailing context lines were collected.
            # Stopping at line_ed + 1 would drop the second one.
            if lineno > line_ed + 2:
                break
            if lineno >= line_st - 2:
                srd_lines.append(line)
            if lineno < line_st or lineno > line_ed:
                continue

            # Part of this line that lies inside the region.
            if lineno == line_st:
                if line_st == line_ed:
                    line_seg = line[col_st - 1:col_ed - 1]
                else:
                    line_seg = line[col_st - 1:]
            elif lineno == line_ed:
                line_seg = line[:col_ed - 1]
            else:
                line_seg = line

            # Segments of at most 4 characters are skipped (presumably too
            # short to be meaningful, e.g. a lone brace).
            if len(line_seg) > 4:
                ret_linenos.append(lineno)
                ret_lines.append(line)
    return ret_linenos, ret_lines, srd_lines

# Module-level cache of region source lookups, filled by get_covered_regions:
#   (stripped file_name, pure func_name, line_st, col_st, line_ed, col_ed)
#       -> (linenos, lines, srd_lines), or None when no usable source lines
#          were found for the region.
region_mp = {}

def get_covered_regions(prog, in_file='./a.in', src_prefix_len=39):
    """Run the coverage binary on one input and return the regions it executed.

    The binary ``{gcov_bin_dir}/{prog}`` is executed with ``in_file`` as its
    only argument, the resulting raw profile is merged with ``llvm-profdata``
    and exported to JSON with ``llvm-cov export``. Every region with a
    non-zero execution count whose source could be read is returned.

    Args:
        prog: Name of the binary inside ``gcov_bin_dir``.
        in_file: Path of the input file handed to the binary.
        src_prefix_len: Number of leading characters removed from each source
            path to build the region key. The default of 39 is the value that
            used to be hard-coded (presumably the length of the build-tree
            prefix; verify it for your checkout).

    Returns:
        A set of ``(file_name, func_name, line_st, col_st, line_ed, col_ed)``
        tuples. The set is empty when a tool cannot be started or the
        exported JSON cannot be read or parsed.

    Side effects:
        Writes profile artefacts to ``tmp_dir`` and caches region source
        lookups in the module-level ``region_mp``.
    """
    cov_binary = f'{gcov_bin_dir}/{prog}'

    # os.fsencode produces the same bytes as ASCII encoding for ASCII paths
    # but does not raise on other file names.
    in_file_hash = hashlib.md5(os.fsencode(in_file)).hexdigest()

    profraw = f'{tmp_dir}/{prog}.{in_file_hash}.profraw'
    profdata = f'{tmp_dir}/{prog}.{in_file_hash}.profdata'
    profjson = f'{tmp_dir}/{prog}.{in_file_hash}.json'

    # Remove artefacts left by an earlier run so a failed run cannot be
    # mistaken for fresh data. Removal is best effort, as it was with `rm`.
    for prof in (profraw, profdata, profjson):
        try:
            os.remove(prof)
        except OSError:
            pass

    # Argument lists (no shell) keep seed names containing spaces or shell
    # metacharacters intact and harmless.
    quiet = {'stdout': subprocess.DEVNULL, 'stderr': subprocess.DEVNULL}
    try:
        subprocess.run(
            [cov_binary, in_file],
            env={**os.environ, 'LLVM_PROFILE_FILE': profraw},
            **quiet,
        )
        subprocess.run(
            [llvm_profdata, 'merge', f'-output={profdata}', profraw],
            **quiet,
        )
        with open(profjson, 'w') as out:
            subprocess.run(
                [llvm_cov, 'export', '--instr-profile', profdata,
                 '-format=text', cov_binary],
                stdout=out,
                stderr=subprocess.DEVNULL,
            )
        with open(profjson, 'r') as f:
            dct = json.load(f)
    except (OSError, ValueError):
        # A tool could not be started, or the export is missing/invalid
        # (json.JSONDecodeError is a ValueError).
        return set()

    covered_regs = set()
    for fdct in dct['data'][0]['functions']:
        file_names = fdct['filenames']
        func_name = fdct['name']
        if ':' in func_name:
            # Keep the part after the first ':' of a "prefix:name" entry.
            func_name = func_name.split(':')[1]

        # Reduced lazily: functions without any executed region never pay
        # for it, the others pay once instead of once per region.
        pure_name = None
        for region in fdct['regions']:
            line_st, col_st, line_ed, col_ed, exec_cnt, file_id, _, _ = region
            if exec_cnt == 0:
                continue
            if pure_name is None:
                pure_name = extract_pure_function_name(func_name)
            file_name = file_names[file_id]
            reg = (file_name[src_prefix_len:], pure_name,
                   line_st, col_st, line_ed, col_ed)
            if reg not in region_mp:
                linenos, lines, srd_lines = get_source_of_region(
                    file_name, func_name, line_st, col_st, line_ed, col_ed)
                # None marks regions without usable source lines.
                region_mp[reg] = (linenos, lines, srd_lines) if linenos else None
            if region_mp[reg] is not None:
                covered_regs.add(reg)
    return covered_regs

def get_newly_covered_regs(prog, fuzz_queue):
    """Replay a fuzzing queue and record which seed first covered which regions.

    Only files whose name starts with ``id:`` and either ends with ``+cov``
    or contains ``orig`` are replayed, in sorted path order. Each seed gets a
    timestamp: its ``os.path.getctime`` relative to the first seed, in whole
    seconds, clamped to be non-negative and never larger than that of any
    later seed.

    Args:
        prog: Name of the coverage binary, passed to ``get_covered_regions``.
        fuzz_queue: Directory holding the queue files.

    Returns:
        A list of ``(new_regions, timestamp)`` tuples, one for every seed
        that covered at least one region not covered by an earlier seed.
        The list is empty when the queue holds no matching seed.
    """
    seeds = sorted(
        f'{fuzz_queue}/{f}'
        for f in os.listdir(fuzz_queue)
        if f.startswith('id:') and (f.endswith('+cov') or 'orig' in f)
    )
    if not seeds:
        # Nothing to replay; without this guard seeds[0] raises IndexError.
        return []

    start_time = os.path.getctime(seeds[0])
    fuzz_seeds = []
    for path in seeds:
        elapsed = int(os.path.getctime(path) - start_time)
        fuzz_seeds.append([path, max(elapsed, 0)])

    # Walk backwards and cap every timestamp at the smallest one seen so
    # far, which makes the sequence non-decreasing in queue order.
    cur_ts = fuzz_seeds[-1][1]
    for entry in reversed(fuzz_seeds):
        cur_ts = min(cur_ts, entry[1])
        entry[1] = cur_ts

    cov_regs = set()
    n_covs = []
    for sd, ts in tqdm(fuzz_seeds):
        seed_regs = get_covered_regions(prog, in_file=sd)
        n_cov = seed_regs - cov_regs
        if not n_cov:
            continue
        n_covs.append((n_cov, ts))
        cov_regs |= seed_regs

    return n_covs

def extract_target_locs(n_covs, threshold_time, max_code_lines=5,
                        max_context_lines=9):
    """Pick one target location per coverage discovery made late enough.

    For every ``(regions, ts)`` entry with ``ts >= threshold_time`` the
    regions are shuffled and the first one whose first source line is neither
    a ``#define`` line nor contains a comment marker (``//``, ``/*``, ``*/``)
    is chosen. Entries without such a region are skipped.

    Args:
        n_covs: List of ``(regions, ts)`` tuples; every region must be a key
            of the module-level ``region_mp`` with a non-``None`` value.
        threshold_time: Minimum timestamp an entry needs to be considered.
        max_code_lines: Maximum number of region lines kept in ``code``.
        max_context_lines: Maximum number of lines kept in ``surround_code``.

    Returns:
        A list of dicts with the keys ``file_name``, ``func_name``,
        ``start_lineno``, ``start_colno``, ``end_lineno``, ``end_colno``,
        ``code``, ``surround_code`` and ``discovered_time``.
    """
    target_locs = []
    for n_cov, ts in n_covs:
        if ts < threshold_time:
            continue

        candidates = list(n_cov)
        random.shuffle(candidates)
        reg = None
        for cur_reg in candidates:
            line0 = region_mp[cur_reg][1][0]
            is_define = '#' in line0 and 'define' in line0
            has_comment = any(tok in line0 for tok in ('//', '/*', '*/'))
            if not (is_define or has_comment):
                reg = cur_reg
                break
        if reg is None:
            continue

        _, lines, srd_lines = region_mp[reg]
        file_name, func_name, line_st, col_st, line_ed, col_ed = reg
        target_locs.append({
            'file_name': file_name,
            # The region key already carries the reduced function name, so
            # it is used as is instead of being reduced a second time.
            'func_name': func_name,
            'start_lineno': line_st,
            'start_colno': col_st,
            'end_lineno': line_ed,
            'end_colno': col_ed,
            'code': ''.join(lines[:max_code_lines]),
            'surround_code': ''.join(srd_lines[:max_context_lines]),
            'discovered_time': ts,
        })
    return target_locs

def ttg_gen(entry_program, fuzz_queue, threshold_time):
    """Replay ``fuzz_queue`` with ``entry_program`` and return target locations.

    Chains ``get_newly_covered_regs`` and ``extract_target_locs``; only
    discoveries with a timestamp of at least ``threshold_time`` are kept.
    """
    return extract_target_locs(
        get_newly_covered_regs(entry_program, fuzz_queue),
        threshold_time,
    )

if __name__ == '__main__':
    # Example run: replay the libpng queue and print every target location
    # discovered at least `threshold_time` seconds after the first seed.
    entry_program, fuzz_queue = 'libpng_read_fuzzer', './libpng/fuzzout/queue'
    threshold_time = 3600

    for target in ttg_gen(entry_program, fuzz_queue, threshold_time):
        print(target)