import os, argparse
from utils import restriction_prompt, query, get_code_with_markdown, find_context, lines_eq_start, lines_eq_end, rm_none_line
from fp_utils import build_fp, collect_all_function_call_names, collect_all_function_def_names, anchor, tokenize_c_code


def _missing_definitions(name_list, code, line_view):
    """Return the entries of `name_list` whose definition is not visible in `code`.

    An entry `(name, types)` counts as visible when one line of `code`, seen
    through `line_view`, contains `name` together with at least one element
    of `types`.  An entry with no types is matched on its name alone.
    """
    views = [line_view(line) for line in code.splitlines()]
    missing = []
    for item in name_list:
        name, types = item
        visible = any(
            name in view and (not types or any(t in view for t in types))
            for view in views
        )
        if not visible:
            missing.append(item)
    return missing


def _found_def_ranges(response_code, code_new):
    """Return the index ranges in `code_new` of the functions defined in `response_code`."""
    fp_response = build_fp(response_code)
    def_names = [feature for feature, _ in collect_all_function_def_names(fp_response)]
    return list(find_context(def_names, code_new, out_type='idx'))


def _drop_ranges(lines, ranges):
    """Return `lines` without those whose index lies in any inclusive `(start, end)` range."""
    return [
        line for idx, line in enumerate(lines)
        if not any(first <= idx <= last for first, last in ranges)
    ]


def _line_at(lines, index):
    """Return `lines[index]`, or an empty string when the index is out of range."""
    try:
        return lines[index]
    except IndexError:
        return ''


def _end_anchor(lines):
    """Derive the trailing anchor of an answer.

    Returns `(end_line, end_line_up, end_type, add_lo, end_add)`:
    `end_type` is 'in' when the answer ends with one or two closing-brace
    lines (`add_lo` counts them) and 'out' otherwise; `end_add` is 1 when a
    ` */` line sits directly above those braces and is skipped as well.
    `end_line` is the last remaining line and `end_line_up` the line above it.
    """
    last = _line_at(lines, -1)
    if '}' in last and '{' not in last:
        end_type = 'in'
        add_lo = 2 if '}' in _line_at(lines, -2) else 1
    else:
        end_type = 'out'
        add_lo = 0

    pos = -1 - add_lo
    end_add = 0
    if _line_at(lines, pos) == ' */':
        end_add = 1
        pos -= 1
    # NOTE(review): the original double-brace branch used the ` */` line itself
    # as `end_line_up`; every other branch uses the line above `end_line`, so
    # that is done uniformly here.
    return _line_at(lines, pos), _line_at(lines, pos - 1), end_type, add_lo, end_add


def get_new_block(part_old, code_new, name_list=None, range_anchor=None, model='gpt-4', all_query_time=None):
    """Ask the LLM for the segment of `code_new` that corresponds to `part_old`.

    A prompt is built from `part_old`, `code_new`, the function-definition
    hints in `name_list` and the first/last-line hints in `range_anchor`.
    While some entry of `name_list` is not visible in the answer, the model is
    asked again (at most three more times) with the functions it already
    returned removed from the `new.c` text.  The final answer is then mapped
    back onto an exact line range of `code_new`.

    Args:
        part_old: the old code snippet.
        code_new: full text of the new file.
        name_list: list of `(name, types)` pairs for the function definitions
            in `part_old`; `types` is a sequence of strings.  Not mutated.
        range_anchor: empty, or a sequence whose first two items are the first
            and last line of `part_old`.
        model: 'gpt-4' or 'gpt-3.5'.
        all_query_time: running count of LLM queries; `None` is treated as 0.

    Returns:
        `(code, found, all_query_time)`.  When `found` is True, `code` is a
        slice of the lines of `code_new`; otherwise it is the raw answer (or
        '' when no code could be extracted from the first answer).

    Raises:
        ValueError: if `model` is not one of the supported names.
    """
    if model == 'gpt-4':
        mm = 'gpt-4-turbo'
    elif model == 'gpt-3.5':
        mm = 'gpt-3.5-turbo-16k'
    else:
        raise ValueError

    name_list = [] if name_list is None else name_list
    range_anchor = [] if range_anchor is None else range_anchor
    if all_query_time is None:
        all_query_time = 0

    task_prompt = 'We are facing a challenge that requires your specialized knowledge and expertise. We need to locate a corresponding segment of code, indicated as `part_new`, within a C file named `new.c` that matches semantically with a provided code snippet labeled as `part_old`. Given that `part_new`, the target code segment, originates from modifications made to `part_old`, it is essential to identify this correspondence accurately.\n'
    prompt_part_old = f'The starting point for your task involves comparing the following `part_old`: \n```\n{part_old}\n```\n'
    prompt_code_new = f'And the entire context available in the `new.c`: \n```\n{code_new}\n```\n'
    if not name_list:
        prompt_fuc_def = ''
    else:
        # Build "<types...> <name>" on a fresh list: the caller reuses
        # `name_list` across retries, so its `types` lists must stay untouched.
        feature = ', '.join(' '.join([*types, name]) for name, types in name_list)
        prompt_fuc_def = f'It appears that `part_old` encompasses the definition of the function `{feature}`. Your role is to pinpoint the matching code segment `part_new` within `new.c`. Please ensure that the identified function definitions are solely derived from `new.c`. Avoid constructing false code snippets by using the function definitions from `part_old`.\n'

    if not range_anchor:
        prompt_align = ''
    else:
        prompt_align = f'To facilitate the search, you may need to align `part_new` using the initial line `{range_anchor[0]}` and the final line `{range_anchor[1]}` from `part_old`.\n'

    prompt = task_prompt + prompt_part_old + prompt_code_new + prompt_fuc_def + prompt_align + restriction_prompt
    response = query(user_content=prompt, query_model=mm, temperature=0.5)
    response_code = get_code_with_markdown(response)
    all_query_time += 1
    if response_code is None:
        # No code could be extracted; report "not found" so the caller can retry.
        print('Not Found 2')
        return '', False, all_query_time

    code_new_lines = code_new.splitlines()

    missing = _missing_definitions(name_list, response_code, tokenize_c_code)
    if missing:
        # Hide what the model already returned so the next round can focus on
        # the definitions that are still missing.  Ranges are indices into
        # `code_new`, so they are accumulated and always applied to the
        # unfiltered line list.
        removed = _found_def_ranges(response_code, code_new)
        reduced_code_new = '\n'.join(_drop_ranges(code_new_lines, removed))
        last_response_code = response_code
        query_time = 0
        while missing and query_time <= 2:
            _prompt_code_new = f'Here is the code of `new.c`: \n```\n{reduced_code_new}\n```\n'
            _prompt = task_prompt + prompt_part_old + _prompt_code_new + prompt_fuc_def + prompt_align + restriction_prompt

            response = query(user_content=_prompt, query_model=mm, temperature=0.4)
            response_code = get_code_with_markdown(response)
            all_query_time += 1
            if response_code is None:
                response_code = last_response_code
                print('Back to last')
                break

            # NOTE(review): retry rounds match on plain substrings while the
            # first round matches on tokens; kept as in the original.
            missing = _missing_definitions(name_list, response_code, lambda line: line)
            if not missing:
                break
            removed.extend(_found_def_ranges(response_code, code_new))
            reduced_code_new = '\n'.join(_drop_ranges(code_new_lines, removed))
            last_response_code = response_code
            query_time += 1

    response_code_lines = response_code.splitlines()
    if not response_code_lines:
        print('Not Found 2')
        return response_code, False, all_query_time

    end_line, end_line_up, end_type, add_lo, end_add = _end_anchor(response_code_lines)

    # A leading bare `/*` is skipped for matching and re-included in the slice.
    if response_code_lines[0] == '/*':
        start_line = _line_at(response_code_lines, 1)
        start_add = -1
    else:
        start_line = response_code_lines[0]
        start_add = 0

    start = -1
    end = -1
    for idx, line in enumerate(code_new_lines):
        if lines_eq_start(line, start_line):
            # Keep the latest start candidate seen before an end match.
            start = idx
            continue
        if start == -1 or idx + add_lo >= len(code_new_lines):
            continue
        if lines_eq_end(line, end_line, code_new_lines[idx-1], end_line_up, end_type, code_new_lines[idx + add_lo]):
            # `add_lo` is 0 for end_type 'out', so this covers both cases.
            end = idx + add_lo
            break

    if start == -1:
        print('Not Found 2')
        return response_code, False, all_query_time
    if end == -1:
        print('Not Found 1')
        return response_code, False, all_query_time

    # Clamp so a match on line 0 with `start_add == -1` cannot wrap around.
    slice_start = max(start + start_add, 0)
    response_code = '\n'.join(code_new_lines[slice_start:end+1+end_add])
    return response_code, True, all_query_time


def add_extra_code(part_new, fp_part_old, code_new):
    """Prepend context for function calls that `part_new` introduces.

    The call names collected from `part_new` are compared with those collected
    from `fp_part_old`; for the names present only in `part_new`, the context
    returned by `find_context` over `code_new` is joined with newlines and put
    in front of `part_new`, separated by a blank line.  When there is no such
    context, `part_new` is returned unchanged.
    """
    fp_new = build_fp(part_new)
    calls_before = set(collect_all_function_call_names(fp_part_old))
    calls_after = set(collect_all_function_call_names(fp_new))
    introduced_calls = list(calls_after.difference(calls_before))

    extra_context = find_context(introduced_calls, code_new)
    if len(extra_context) == 0:
        return part_new

    print('add extra code...')
    extra_text = '\n'.join(extra_context)
    return f'{extra_text}\n\n{part_new}'


def more_align(part_new, part_old):
    """Trim `part_new` to the span delimited by the ends of `part_old`.

    The start is the first line of `part_new` that `lines_eq_start` matches
    against the first line of `part_old` (after `rm_none_line`).  The end is
    the last line of `part_new` that is exactly `}` when `part_old` ends with
    `}`, and otherwise the last line that `lines_eq_start` matches against the
    last line of `part_old`.

    Returns:
        The trimmed text, or `part_new` unchanged when `part_old` has no
        usable lines, when either boundary cannot be found, or when the end
        boundary lies before the start boundary.
    """
    part_new_lines = part_new.splitlines()
    part_old_lines = rm_none_line(part_old.splitlines())
    if not part_old_lines:
        # Nothing to align against; previously this raised IndexError.
        return part_new

    first_old = part_old_lines[0]
    last_old = part_old_lines[-1]

    start = next(
        (idx for idx, line in enumerate(part_new_lines) if lines_eq_start(line, first_old)),
        -1,
    )
    if start == -1:
        return part_new

    end = -1
    for idx in range(len(part_new_lines) - 1, -1, -1):
        line = part_new_lines[idx]
        if last_old == '}':
            is_end = line == '}'
        else:
            is_end = lines_eq_start(line, last_old)
        if is_end:
            end = idx
            break

    # `end < start` also covers "not found" (-1); an end before the start
    # would otherwise silently yield an empty result.
    if end < start:
        return part_new

    return '\n'.join(part_new_lines[start:end + 1])


if __name__ == '__main__':
    # Batch driver: for every `<name>_h` / `<name>_c` folder under the data
    # path, read `new.h` / `new.c`, and for every sub-folder read its
    # `part.h` / `part.c`, locate the matching segment with `get_new_block`,
    # and write the result plus the query count below ./results_miggpt/<model>/.
    parser = argparse.ArgumentParser(description='migratin with miggpt LLMs')
    parser.add_argument('--model', type=str, help='model', default='gpt-4')
    parser.add_argument('--data', type=str, help='data path', default='./data_16k')
    args = parser.parse_args()
    model = args.model
    data_path = args.data

    save_path = f'./results_miggpt/{model}'

    folders = sorted(
        item for item in os.listdir(data_path)
        if os.path.isdir(os.path.join(data_path, item))
    )

    for ii, item in enumerate(folders):
        print(f'================================={ii}/{len(folders)-1}=====================================')
        folder_path = os.path.join(data_path, item)
        save_folder_path = os.path.join(save_path, item)
        print(folder_path)
        # The folder-name suffix selects the file extension used inside it.
        if item.endswith('_h'):
            sa = '.h'
        elif item.endswith('_c'):
            sa = '.c'
        else:
            raise ValueError(f"folder name must end with '_h' or '_c': {item}")
        with open(f'{folder_path}/new{sa}', 'r') as f:
            code_new = f.read()

        parts = sorted(
            part for part in os.listdir(folder_path)
            if os.path.isdir(os.path.join(folder_path, part))
        )

        for jj, part in enumerate(parts):
            all_query_time = 0
            part_path = os.path.join(folder_path, part)
            save_part_path = os.path.join(save_folder_path, part)
            os.makedirs(save_part_path, exist_ok=True)

            print('-----------------------------------')
            # The backslash is escaped explicitly; the printed text is unchanged.
            print(f'{jj}\\{len(parts)-1}--{ii}: {part_path}')

            with open(f'{part_path}/part{sa}', 'r') as f:
                part_old = f.read()

            fp_part_old = build_fp(part_old)
            funcdef_list = collect_all_function_def_names(fp_part_old)
            part_anchor = anchor(fp_part_old)
            part_new, found, all_query_time = get_new_block(
                part_old=part_old, code_new=code_new, name_list=funcdef_list,
                range_anchor=part_anchor, model=model, all_query_time=all_query_time)

            # Up to two further attempts when the answer could not be mapped
            # back onto the new file.
            retries = 0
            while not found and retries < 2:
                retries += 1
                part_new, found, all_query_time = get_new_block(
                    part_old=part_old, code_new=code_new, name_list=funcdef_list,
                    range_anchor=part_anchor, model=model, all_query_time=all_query_time)

            part_new = more_align(part_new, part_old)
            part_new = add_extra_code(part_new, fp_part_old, code_new)
            # Report the directory the result is actually written to.
            print(f'Found part_new{sa} and save in {save_part_path}')
            with open(f'{save_part_path}/part_new{sa}', 'w') as f:
                f.write(part_new)

            with open(f'{save_part_path}/query_time.txt', 'w') as f:
                f.write(f'qyery_time:{all_query_time}')