import os
import argparse
from utils import restriction_prompt, query, get_code_with_markdown


def get_new_block(part_old, code_new, model='gpt-4'):
    """Ask an LLM to locate the counterpart of ``part_old`` inside ``code_new``.

    A prompt is assembled from a fixed task description, the old snippet, the
    full text of the new file and the project's ``restriction_prompt``; it is
    sent through ``query`` and the reply is passed to
    ``get_code_with_markdown``.

    Args:
        part_old: Text of the original code snippet.
        code_new: Full text of the new file to search in.
        model: Model alias, either ``'gpt-4'`` or ``'gpt-3.5'``.

    Returns:
        Whatever ``get_code_with_markdown`` extracts from the model reply.
        NOTE(review): the caller checks the result against ``None``, so the
        helper presumably returns ``None`` when nothing can be extracted --
        confirm against ``utils``.

    Raises:
        ValueError: If ``model`` is not one of the supported aliases.
    """
    # Alias accepted on the command line -> concrete model name given to `query`.
    model_names = {
        'gpt-4': 'gpt-4-turbo',
        'gpt-3.5': 'gpt-3.5-turbo-16k',
    }
    # Validate once, before any prompt is built or any request is made.
    try:
        mm = model_names[model]
    except (KeyError, TypeError):
        raise ValueError(
            f'unsupported model {model!r}; expected one of {sorted(model_names)}'
        ) from None

    # query LLM
    task_prompt = 'We are facing a challenge that requires your specialized knowledge and expertise. We need to locate a corresponding segment of code, indicated as `part_new`, within a C file named `new.c` that matches semantically with a provided code snippet labeled as `part_old`. Given that `part_new`, the target code segment, originates from modifications made to `part_old`, it is essential to identify this correspondence accurately.\n'
    prompt_part_old = f'The starting point for your task involves comparing the following `part_old`: \n```\n{part_old}\n```\n'
    prompt_code_new = f'And the entire context available in the `new.c`: \n```\n{code_new}\n```\n'

    prompt = task_prompt + prompt_part_old + prompt_code_new + restriction_prompt
    response = query(user_content=prompt, query_model=mm, temperature=0.5)

    return get_code_with_markdown(response)


def main():
    """Run the LLM block-matching query over every sample under the data path.

    Each sub-directory of ``--data`` whose name ends in ``_c`` or ``_h`` is a
    sample holding ``new.c`` / ``new.h`` plus one sub-directory per part, each
    with a ``part.c`` / ``part.h`` file. For every part, ``get_new_block`` is
    called and a non-``None`` result is written to
    ``./results_gpt/<model>/<sample>/<part>/part_new.c`` (or ``.h``).

    Raises:
        ValueError: If a sample directory name ends in neither ``_c`` nor ``_h``.
    """
    parser = argparse.ArgumentParser(description='migration with vanilla LLMs')
    parser.add_argument('--model', type=str, help='model', default='gpt-4')
    parser.add_argument('--data', type=str, help='data path', default='./data_16k')
    args = parser.parse_args()
    model = args.model
    data_path = args.data

    save_path = f'./results_gpt/{model}'

    # Only directories are samples; sort for a reproducible processing order.
    folders = sorted(
        item for item in os.listdir(data_path)
        if os.path.isdir(os.path.join(data_path, item))
    )

    for ii, item in enumerate(folders):
        print(f'================================={ii}/{len(folders)-1}=====================================')
        folder_path = os.path.join(data_path, item)
        save_folder_path = os.path.join(save_path, item)
        print(folder_path)

        # The directory-name suffix selects the extension of every file in the sample.
        if item.endswith('_h'):
            sa = '.h'
        elif item.endswith('_c'):
            sa = '.c'
        else:
            raise ValueError(
                f"sample directory {item!r} must end with '_c' or '_h'"
            )

        with open(os.path.join(folder_path, f'new{sa}'), 'r') as f:
            code_new = f.read()

        parts = sorted(
            part for part in os.listdir(folder_path)
            if os.path.isdir(os.path.join(folder_path, part))
        )

        for jj, part in enumerate(parts):
            part_path = os.path.join(folder_path, part)
            save_part_path = os.path.join(save_folder_path, part)
            # exist_ok avoids the separate existence check before creating.
            os.makedirs(save_part_path, exist_ok=True)

            print('-----------------------------------')
            # '/' matches the sample-level progress line above; the previous
            # backslash formed an invalid escape sequence inside the f-string.
            print(f'{jj}/{len(parts)-1}--{ii}: {part_path}')

            with open(os.path.join(part_path, f'part{sa}'), 'r') as f:
                part_old = f.read()

            part_new = get_new_block(part_old=part_old, code_new=code_new, model=model)
            if part_new is None:
                # Nothing to save for this part; move on to the next one.
                continue

            print(f'Found part_new{sa} and save in {save_part_path}')
            with open(os.path.join(save_part_path, f'part_new{sa}'), 'w') as f:
                f.write(part_new)


if __name__ == '__main__':
    main()
                
