import os, argparse
from utils import restriction_prompt, query, get_code_with_markdown
from fp_utils import build_fp, get_diff_part


def migration(part_old, part_old_patched, part_new, context, model, temperature=0.4):
    """Ask an LLM to port a patch from an old kernel snippet onto a new one.

    Builds a single prompt containing the old snippet (`part_old`), its
    patched version (`part_old_patched`), the new snippet (`part_new`) and a
    textual description of every modification listed in `context`, sends it
    through `query`, and returns the code extracted from the response by
    `get_code_with_markdown`.

    Args:
        part_old: Source text of the old-version snippet.
        part_old_patched: Source text of `part_old` after the patch was applied.
        part_new: Source text of the new-version snippet that should receive
            the equivalent patch.
        context: Sized iterable of ``(diff_string, top_string, bottom_string)``
            tuples. ``diff_string`` is the change to apply; ``top_string`` /
            ``bottom_string`` are the anchor lines the change sits after /
            before. An empty string means "no anchor on that side".
        model: Short model alias, either ``'gpt-4'`` or ``'gpt-3.5'``.
        temperature: Sampling temperature forwarded to `query`. It defaults to
            0.4, the value previously hard-coded.

    Returns:
        Whatever `get_code_with_markdown` extracts from the model response.
        Presumably this is the generated `part_new_patched` source; confirm
        against utils.

    Raises:
        ValueError: If `model` is not one of the supported aliases.
    """
    # Short CLI aliases -> concrete model identifiers understood by `query`.
    model_ids = {
        'gpt-4': 'gpt-4-turbo',
        'gpt-3.5': 'gpt-3.5-turbo-16k',
    }
    if model not in model_ids:
        raise ValueError(
            f'unsupported model {model!r}; expected one of {sorted(model_ids)}'
        )
    mm = model_ids[model]

    # NOTE: all prompt text below is kept verbatim (including the
    # ungrammatical "there is N specific area") so generated prompts stay
    # comparable with earlier experiment runs.
    task_prompt = 'I am reaching out to you with a specialized code migration task where your expertise in Linux kernel development would be invaluable. Your assistance will help ensure the successful adaptation of existing code to the latest version of the Linux kernel. For this task, I will provide three code snippets for your consideration:\n'
    code_old_prompt = f'Code Snippet 1: The old version of the Linux kernel code snippet, which we will refer to as `part_old`.\n```\n{part_old}\n```\n'
    code_old_patched_prompt = f'Code Snippet 2: The corresponding code developed based on the old version of the Linux kernel code snippet `part_old`, referred to as `part_old_patched`.\n```\n{part_old_patched}\n```\n'
    code_new_prompt = f'Code Snippet 3: The new version of the Linux kernel code snippet, denoted as `part_new`.\n```\n{part_new}\n```\n'

    analysis_parts = [
        f'Upon preliminary analysis, it appears that there is {len(context)} specific area within `part_old_patched` that requires modification:\n'
    ]
    for idx, (diff_string, top_string, bottom_string) in enumerate(context):
        # Describe where the change goes using whichever anchor lines exist.
        if top_string != '':
            top_mod = f'situated after the line containing \n```\n{top_string}\n```\n'
        else:
            top_mod = ''

        if bottom_string == '':
            bottom_mod = ''
        elif top_mod != '':
            # Both anchors present: join them into one clause.
            bottom_mod = f', and before the line containing \n```\n{bottom_string}\n```\n'
        else:
            bottom_mod = f'before the line containing \n```\n{bottom_string}\n```\n'

        analysis_parts.append(
            f'The modification {idx+1} should be made {top_mod}{bottom_mod} with the change being \n```\n{diff_string}\n```\n'
        )
    analysis_prompt = ''.join(analysis_parts)

    query_prompt = 'It\'s likely that similar adjustments will need to be made within `part_new` to maintain functionality and compatibility. Given your extensive knowledge and experience in this field, could you kindly assist by generating the corresponding code snippet `part_new_patched` developed on `part_new`?'
    prompt = ''.join([
        task_prompt,
        code_old_prompt,
        code_old_patched_prompt,
        code_new_prompt,
        analysis_prompt,
        query_prompt,
        restriction_prompt,
    ])
    response = query(user_content=prompt, query_model=mm, temperature=temperature)
    response_code = get_code_with_markdown(response)

    return response_code


if __name__ == '__main__':
    # Command-line driver: for every <data>/<folder>/<part>/ directory, read the
    # old snippet and its patched version from the data tree and the new snippet
    # written by the previous step under ./results_miggpt/<model>/<folder>/<part>/.
    # Ask the LLM for part_new_patched and write it next to part_new.
    parser = argparse.ArgumentParser(description='migration with miggpt LLMs')
    # Restrict to the aliases `migration` accepts so a typo fails immediately
    # instead of after the directory walk has started.
    parser.add_argument('--model', type=str, help='model', default='gpt-4',
                        choices=['gpt-4', 'gpt-3.5'])
    parser.add_argument('--data', type=str, help='data path', default='./data_16k')
    args = parser.parse_args()
    model = args.model
    data_path = args.data
    # Output directory of the previous pipeline step; this step also writes its
    # results into the same per-part directories.
    last_step_results = f'./results_miggpt/{model}'

    folders = sorted(
        item for item in os.listdir(data_path)
        if os.path.isdir(os.path.join(data_path, item))
    )

    for ii, item in enumerate(folders):
        print(f'================================={ii}/{len(folders)-1}=====================================')
        folder_path = os.path.join(data_path, item)
        results_folder_path = os.path.join(last_step_results, item)
        print(folder_path)
        # The folder-name suffix determines the extension of the snippet files.
        if item.endswith('_h'):
            sa = '.h'
        elif item.endswith('_c'):
            sa = '.c'
        else:
            raise ValueError(
                f'cannot infer file type for {folder_path!r}: '
                'folder name must end with "_c" or "_h"'
            )

        parts = sorted(
            part for part in os.listdir(folder_path)
            if os.path.isdir(os.path.join(folder_path, part))
        )

        for jj, part in enumerate(parts):
            part_path = os.path.join(folder_path, part)
            results_part_path = os.path.join(results_folder_path, part)
            print('-----------------------------------')
            # '/' separator (was an invalid '\{' escape sequence).
            print(f'{jj}/{len(parts)-1}----{ii}: {part_path}')

            # Skip parts for which the previous step produced no part_new;
            # check this before reading any input files.
            part_new_file = f'{results_part_path}/part_new{sa}'
            if not os.path.exists(part_new_file):
                continue

            with open(f'{part_path}/part{sa}', 'r') as f:
                part_old = f.read()
            with open(f'{part_path}/part_patched{sa}', 'r') as f:
                part_old_patched = f.read()
            with open(part_new_file, 'r') as f:
                part_new = f.read()

            fp_part_old = build_fp(part_old)
            fp_part_old_patched = build_fp(part_old_patched)
            mod_context = get_diff_part(fp_part_old, fp_part_old_patched, part_old_patched)
            part_new_patched = migration(part_old, part_old_patched, part_new, mod_context, model=model)

            print(f'Generated part_new_patched{sa} and save in {results_part_path}')
            with open(f'{results_part_path}/part_new_patched{sa}', 'w') as f:
                f.write(part_new_patched)




