from openai import OpenAI
import os
import json
from prompts_openpvsg_v2 import wrap_prompt, user
# Module-level symbol tables: five uppercase letters and two marker characters.
# Not referenced in this file's visible code — presumably used by importers (TODO confirm).
var_list = list("ABCDE")
indentation_list = list("-+")
client = OpenAI()  # shared API client used by caption2spec and action2spec below; created at import time

# Parenthesised run of letters, digits, commas, hyphens and spaces, e.g. "(1, 2)";
# clean_cap removes every match. Raw strings keep the backslashes literal without
# triggering the invalid-escape-sequence warning that the plain literals raised.
re_num = r"\([0-9a-zA-Z\,\- ]+\)"
# Lowercase letters followed by a literal dot and anything after; group 1 is the letters.
re_letter = r"([a-z]+)\..*"
import regex as re

def clean_cap(caption):
    """Strip parenthesised annotations from a caption and tidy the spacing.

    Every substring matching ``re_num`` (a parenthesised run of letters,
    digits, commas, hyphens and spaces, e.g. ``"(1, 2)"``) is removed. Then one
    pass replaces double spaces with single spaces, spaces before ``.`` and
    ``,`` are dropped, and the result is stripped.

    The result is used as a cache key by ``action2spec``, so its exact form
    must stay stable.

    Args:
        caption: raw caption text.

    Returns:
        The cleaned caption string.
    """
    cleaned = caption
    # Each matched annotation is removed with str.replace (every occurrence of
    # that text), exactly as before, so previously cached keys still match.
    for annotation in re.findall(re_num, caption):
        cleaned = cleaned.replace(annotation, '')
    # NOTE(review): a single replace pass does not fully collapse runs of three
    # or more spaces (e.g. left by two adjacent annotations). Collapsing them
    # would change existing cache keys, so the original behavior is kept.
    cleaned = cleaned.replace('  ', ' ')
    cleaned = cleaned.replace(' .', '.')
    cleaned = cleaned.replace(' ,', ',')
    return cleaned.strip()
    
def _extract_spec_entries(parsed, depth=2):
    """Return the spec entries (dicts with a string ``'caption'``) in a parsed reply.

    Accepts the reply shapes this module handles: a list of entries, a dict
    mapping ids to entries, a single entry dict, or a single-key wrapper around
    a list/dict of entries. Containers are unwrapped at most ``depth`` levels;
    anything that is neither an entry nor a container is ignored.
    """
    if isinstance(parsed, dict) and isinstance(parsed.get('caption'), str):
        return [parsed]
    if depth <= 0:
        return []
    if isinstance(parsed, dict):
        children = parsed.values()
    elif isinstance(parsed, list):
        children = parsed
    else:
        return []
    entries = []
    for child in children:
        entries.extend(_extract_spec_entries(child, depth - 1))
    return entries


def caption2spec(caption_ls, cache_path):
    """Translate captions into specs with one GPT-4 call and merge them into the cache.

    Args:
        caption_ls: captions sent together in a single prompt built by
            ``wrap_prompt``.
        cache_path: JSON file mapping caption -> spec entry. It is loaded if it
            exists, then rewritten with the merged result.

    Returns:
        The merged cache dict. Entries are keyed by the ``'caption'`` field the
        model returns, which may differ from the caption that was sent.

    Raises:
        json.JSONDecodeError: if the model reply is not valid JSON.
    """
    if os.path.exists(cache_path):
        with open(cache_path, 'r') as f:
            cache = json.load(f)
    else:
        cache = {}

    prompt = wrap_prompt(caption_ls, few_shot=True)
    response = client.chat.completions.create(
        model="gpt-4-0125-preview",
        response_format={"type": "json_object"},
        temperature=0,
        messages=[
            {"role": "system", "content": user},
            {"role": "user", "content": prompt},
        ],
    )
    parsed = json.loads(response.choices[0].message.content)

    # A single shape-tolerant pass. The old branch logic ran two branches on the
    # same reply when one caption was sent (crashing on non-entry values) and
    # dropped dict replies whose size did not equal the caption count.
    for entry in _extract_spec_entries(parsed):
        cache[entry['caption']] = entry

    # Write to a temp file and swap it in, so an interrupted dump cannot
    # truncate the existing cache.
    tmp_path = f"{cache_path}.tmp"
    with open(tmp_path, 'w') as f:
        json.dump(cache, f)
    os.replace(tmp_path, cache_path)
    return cache
        
def action2spec(data, cache_path, batch_size, processed_captions=None):
    """Translate eligible, not-yet-cached captions into specs via batched GPT-4 calls.

    Args:
        data: mapping ``video_id -> {caption_orig: caption_vllama}``. Each
            ``caption_vllama`` is cleaned with ``clean_cap``, and the cleaned
            text is the cache key.
        cache_path: JSON file mapping cleaned caption -> spec entry. It is
            loaded if present and rewritten after every batch, so progress
            survives an interrupted run.
        batch_size: maximum number of captions per request; must be >= 1.
        processed_captions: collection of cleaned captions that may be
            translated. Only captions in it that are not already cached are
            sent. ``None`` (the default) means no caption is eligible.

    Returns:
        None.

    Raises:
        ValueError: if ``batch_size`` < 1.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size!r}")
    # Build a set once so membership tests are O(1) even if a list is passed.
    eligible = set(processed_captions) if processed_captions is not None else set()

    if os.path.exists(cache_path):
        with open(cache_path, 'r') as f:
            cache = json.load(f)
    else:
        cache = {}

    # Insertion-ordered dict used as an ordered set: a caption shared by
    # several videos is queued (and paid for) only once.
    queued = {}
    for data_info in data.values():
        for caption_vllama in data_info.values():
            clean_des = clean_cap(caption_vllama)
            if clean_des not in cache and clean_des in eligible:
                queued[clean_des] = None
    pending = list(queued)
    if not pending:
        return

    def _spec_entries(node, depth=2):
        """Return dicts with a string 'caption' found in node, unwrapping up to depth containers."""
        if isinstance(node, dict) and isinstance(node.get('caption'), str):
            return [node]
        if depth <= 0:
            return []
        if isinstance(node, dict):
            children = node.values()
        elif isinstance(node, list):
            children = node
        else:
            return []
        found = []
        for child in children:
            found.extend(_spec_entries(child, depth - 1))
        return found

    tmp_path = f"{cache_path}.tmp"
    for start in range(0, len(pending), batch_size):
        batch = pending[start:start + batch_size]
        prompt = wrap_prompt(batch, few_shot=True)
        response = client.chat.completions.create(
            model="gpt-4-0125-preview",
            response_format={"type": "json_object"},
            temperature=0,
            messages=[
                {"role": "system", "content": user},
                {"role": "user", "content": prompt},
            ],
        )
        try:
            parsed = json.loads(response.choices[0].message.content)
        except (json.JSONDecodeError, TypeError) as err:
            # Skip only this batch; its captions stay uncached and are retried
            # on the next run instead of aborting all remaining batches.
            print(f"action2spec: skipping batch starting at {start}: {err}")
            continue

        # Shape-tolerant extraction. The old code compared the reply size with
        # batch_size, so it discarded a final, smaller batch returned as a dict.
        # NOTE(review): entries are keyed by the caption the model echoes back;
        # if it rewrites a caption, the sent one remains uncached.
        for entry in _spec_entries(parsed):
            cache[entry['caption']] = entry

        # Temp file + atomic swap: an interrupted dump cannot corrupt the cache.
        with open(tmp_path, 'w') as f:
            json.dump(cache, f)
        os.replace(tmp_path, cache_path)

        
if __name__ == "__main__":

    dataset = "open_pvsg"
    # Alternative (simple-caption) configuration:
    # cache_file_name = f"{dataset}_vllamav2_simple_gpt4_cache.json"
    # data_file_name = 'videollamav2_simple_caption.json'
    cache_file_name = f"{dataset}_vllamav2_gpt4_cache.json"
    data_file_name = 'videollamav2_caption_10000.json'
    processed_file_names = "orig2gpt.json"
    batch_size = 8

    # <root>/data/<dataset>, where <root> is two directories above the one
    # containing this file.
    data_dir = os.path.abspath(os.path.join(os.path.abspath(__file__), f"../../../data/{dataset}"))
    data_nl_dir = os.path.join(data_dir, 'nl2spec')
    if not os.path.isdir(data_dir):
        # Explicit check: an assert would be stripped under `python -O`.
        raise FileNotFoundError(f"dataset directory not found: {data_dir}")
    os.makedirs(data_nl_dir, exist_ok=True)

    cache_path = os.path.join(data_nl_dir, cache_file_name)
    data_path = os.path.join(data_dir, data_file_name)
    processed_file_path = os.path.join(data_dir, processed_file_names)

    # Shape consumed by action2spec: {video_id: {caption_orig: caption_vllama}}.
    with open(data_path, 'r') as f:
        data = json.load(f)

    # Presumably {video_id: {caption: caption}}; the inner values form the set
    # of captions action2spec may translate (TODO confirm file schema).
    processed_captions = set()
    if os.path.exists(processed_file_path):
        with open(processed_file_path, 'r') as f:
            orig2gpt = json.load(f)
        for caption_mapping in orig2gpt.values():
            processed_captions.update(caption_mapping.values())
    else:
        print(f"warning: {processed_file_path} not found; no captions are eligible")

    action2spec(data, cache_path, batch_size, processed_captions=processed_captions)
    print("end")