from openai import OpenAI
import os
import json
from prompts_20bn_v2 import wrap_prompt
# Placeholder names substituted, in order, for each "[something]" token
# found by clean_des(); also used as the argument list of the action.
var_list = list("ABC")
# NOTE(review): not referenced in the visible part of this file.
indentation_list = list("-+")
# Shared OpenAI client, constructed once at import time.
client = OpenAI()

# Regex source strings (kept as plain str so callers can compile or embed them).
# Raw strings avoid the invalid "\." escape warning; the values are unchanged.
# NOTE(review): neither pattern is referenced in the visible part of this file.
re_num = r"([0-9]+)\..*"
re_letter = r"([a-z]+)\..*"
# re_cap_letter = r"([A-Z]*)\..*"
import regex as re

def clean_des(action, description):
    """Name the "[something]" slots of a template and build the action signature.

    Each standalone, space-delimited token exactly equal to "[something]" is
    replaced by the next name from ``var_list``. The names used are appended
    to ``action`` as a comma-separated argument list.

    Example: ("Put", "Putting [something] onto [something]")
             -> ("Put(A,B)", "Putting A onto B")

    Only exact matches are replaced. Tokens with attached punctuation or
    different capitalisation are left as-is.

    Returns:
        (action_with_args, rewritten_description)

    Raises:
        IndexError: if there are more placeholders than entries in var_list.
    """
    slots_used = 0
    rewritten = []
    for token in description.split(' '):
        if token != "[something]":
            rewritten.append(token)
            continue
        rewritten.append(var_list[slots_used])
        slots_used += 1

    signature = ','.join(var_list[:slots_used])
    return action + "(" + signature + ")", ' '.join(rewritten)
    
def _load_cache(cache_path):
    """Return the JSON cache stored at *cache_path*, or {} if the file is absent."""
    if not os.path.exists(cache_path):
        return {}
    with open(cache_path, 'r') as f:
        return json.load(f)


def _save_cache(cache, cache_path):
    """Write *cache* as JSON through a temp file swapped in with ``os.replace``.

    Writing a sibling file first means an interrupted run leaves the previous
    cache intact instead of a truncated one.
    """
    tmp_path = os.fspath(cache_path) + '.tmp'
    with open(tmp_path, 'w') as f:
        json.dump(cache, f)
    os.replace(tmp_path, cache_path)


def _batch_uncached(actions, cache, batch_size):
    """Split not-yet-cached actions into batches of (action, clean_des(...)) pairs."""
    pending = [
        (action, clean_des(action, action_des))
        for action, action_des in actions.items()
        if action not in cache
    ]
    # A non-positive batch_size degrades to one action per batch. This matches
    # the previous append-then-check loop.
    step = max(1, batch_size)
    return [pending[i:i + step] for i in range(0, len(pending), step)]


def _merge_response(action_responses, batch, cache):
    """Merge one parsed model response into *cache*; return False if its shape is unknown.

    Accepted shapes:
      * a list of dicts, each carrying an 'action' key;
      * a single-entry dict wrapping such a list, e.g. {"actions": [...]};
      * a dict keyed by action with exactly one entry per action in *batch*.
    """
    # JSON mode forces a top-level object, so a list is often wrapped in one key.
    if isinstance(action_responses, dict) and len(action_responses) == 1:
        only_value = next(iter(action_responses.values()))
        if isinstance(only_value, list):
            action_responses = only_value

    if isinstance(action_responses, list):
        for action_dict in action_responses:
            cache[action_dict['action']] = action_dict
        return True

    # Compare against the real batch length: the final batch may be smaller
    # than batch_size.
    if isinstance(action_responses, dict) and len(action_responses) == len(batch):
        cache.update(action_responses)
        return True

    return False


def action2spec(actions, cache_path, batch_size=5):
    """Query the LLM for a spec of every uncached action and persist results to a JSON cache.

    Args:
        actions: mapping of action name -> template description. Each
            description is passed through clean_des() before prompting.
        cache_path: path of the JSON cache file. It is read if present and
            rewritten after every batch.
        batch_size: maximum number of actions sent per chat request.

    Returns:
        None. Results are only written to *cache_path*.

    Cache keys come from the model's response ('action' fields or dict keys).
    NOTE(review): if those differ from the input action names, the same
    actions will be requested again on the next run.
    """
    cache = _load_cache(cache_path)
    batches = _batch_uncached(actions, cache, batch_size)
    if not batches:
        return

    for batch in batches:
        prompt = wrap_prompt(batch, few_shot=True)

        response = client.chat.completions.create(
            model="gpt-4-0125-preview",
            response_format={"type": "json_object"},
            temperature=0,
            messages=[
                {"role": "system", "content": "You are a super user in logic programming. "},
                {"role": "user", "content": prompt},
            ],
        )

        action_responses = json.loads(response.choices[0].message.content)
        if not _merge_response(action_responses, batch, cache):
            skipped = [action for action, _ in batch]
            print(f"Unrecognized response shape; skipping batch {skipped}")

        # Persist after every batch so a later API failure keeps finished work.
        _save_cache(cache, cache_path)


if __name__ == "__main__":

    dataset = "20bn"
    cache_file_name = f"{dataset}_gpt_cache.json"
    action_file_name = "template_mapping.json"
    batch_size = 10

    # "../../.." is joined onto the file path itself, so the first ".." only
    # strips the filename. The result is <two levels above this file's
    # directory>/data/<dataset>/nl2spec.
    data_dir = os.path.abspath(
        os.path.join(os.path.abspath(__file__), f"../../../data/{dataset}/nl2spec")
    )
    if not os.path.isdir(data_dir):
        raise FileNotFoundError(f"Data directory not found: {data_dir}")

    cache_path = os.path.join(data_dir, cache_file_name)
    action_path = os.path.join(data_dir, action_file_name)
    with open(action_path, 'r') as f:
        actions = json.load(f)

    action2spec(actions, cache_path, batch_size)
    print("end")