import os, time
import argparse
from utils import restriction_prompt, query, get_code_with_markdown


def migration(part_old, part_old_patched, part_new, model, temperature=0.4):
    """Ask an LLM to port a patch from an old kernel snippet onto a new one.

    Builds a single user prompt containing the three snippets, each wrapped in
    a Markdown code fence. ``restriction_prompt`` (from ``utils``) is appended
    to the prompt. The prompt is sent through ``utils.query``, and the code
    block is then extracted from the reply with ``get_code_with_markdown``.

    Args:
        part_old: Code snippet from the old kernel version.
        part_old_patched: ``part_old`` after the patch was applied.
        part_new: The new kernel version's counterpart of ``part_old``.
        model: Short model alias, either ``'gpt-4'`` or ``'gpt-3.5'``.
        temperature: Sampling temperature forwarded to ``query``
            (default 0.4, the value previously hard-coded).

    Returns:
        Whatever ``get_code_with_markdown`` extracts from the model reply.
        The script below treats ``None`` as "no code produced"; presumably
        that is the helper's failure value. Verify against ``utils``.

    Raises:
        ValueError: If ``model`` is not a supported alias.
    """
    # Map the short CLI alias to the concrete API model id.
    model_ids = {
        'gpt-4': 'gpt-4-turbo',
        'gpt-3.5': 'gpt-3.5-turbo-16k',
    }
    if model not in model_ids:
        raise ValueError(
            f'unsupported model {model!r}; expected one of {sorted(model_ids)}'
        )
    query_model = model_ids[model]

    task_prompt = 'I am reaching out to you with a specialized code migration task where your expertise in Linux kernel development would be invaluable. Your assistance will help ensure the successful adaptation of existing code to the latest version of the Linux kernel. For this task, I will provide three code snippets for your consideration:\n'
    code_old_prompt = f'Code Snippet 1: The old version of the Linux kernel code snippet, which we will refer to as `part_old`.\n```\n{part_old}\n```\n'
    code_old_patched_prompt = f'Code Snippet 2: The corresponding code developed based on the old version of the Linux kernel code snippet `part_old`, referred to as `part_old_patched`.\n```\n{part_old_patched}\n```\n'
    code_new_prompt = f'Code Snippet 3: The new version of the Linux kernel code snippet, denoted as `part_new`.\n```\n{part_new}\n```\n'
    prompt = task_prompt + code_old_prompt + code_old_patched_prompt + code_new_prompt + restriction_prompt

    response = query(user_content=prompt, query_model=query_model, temperature=temperature)
    return get_code_with_markdown(response)


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='migratin with vanilla LLMs')
    parser.add_argument('--model', type=str, help='model', default='gpt-4')
    parser.add_argument('--data', type=str, help='data path', default='./data_16k')
    args = parser.parse_args()
    model = args.model
    data_path = args.data
    last_step_results = f'./results_vanilla/{model}'
    
    items = os.listdir(data_path)
    folders = [item for item in items if os.path.isdir(os.path.join(data_path, item))]
    folders.sort()

    for ii, item in enumerate(folders):
        print(f'================================={ii}/{len(folders)-1}=====================================')
        folder_path = os.path.join(data_path, item)
        results_folder_path = os.path.join(last_step_results, item)
        print(folder_path)
        if item[-2:] == '_h':
            sa = '.h'
        elif item[-2:] == '_c':
            sa = '.c'
        else:
            raise ValueError

        folder_items = os.listdir(folder_path)
        parts = [part for part in folder_items if os.path.isdir(os.path.join(folder_path, part))]
        parts.sort()

        for jj, part in enumerate(parts):
            part_path = os.path.join(folder_path, part)
            results_part_path = os.path.join(results_folder_path, part)
            print('-----------------------------------')
            print(f'{jj}\{len(parts)-1}----{ii}: {part_path}')

            with open(f'{part_path}/part{sa}', 'r') as f:
                part_old = f.read()
            with open(f'{part_path}/part_patched{sa}', 'r') as f:
                part_old_patched = f.read()
            
            if not os.path.exists(f'{results_part_path}/part_new{sa}'):
                continue
            else:
                with open(f'{results_part_path}/part_new{sa}', 'r') as f:
                    part_new = f.read()
            part_new_patched = migration(part_old, part_old_patched, part_new, model=model)

            if part_new_patched is None:
                continue

            print(f'Generated part_new_patched{sa} and save in {results_part_path}')
            with open(f'{results_part_path}/part_new_patched{sa}', 'w') as f:
                f.write(part_new_patched)
