from openai import OpenAI
import os
import json
from prompts_openpvsg_v2 import wrap_prompt, user
# Module-level constants; neither list is referenced by the code in this module's functions.
var_list = list("ABCDE")
indentation_list = list("-+")
# One shared OpenAI client, created at import time and used for every
# chat-completion request made by this module.
client = OpenAI()

# Parenthesised annotations such as "(1)" or "(A, B-C)" that clean_cap strips
# from captions. Raw strings: the plain literals contained invalid escape
# sequences (\( \, \- \) \.) that trigger SyntaxWarning on modern Python;
# the resulting pattern text is unchanged.
re_num = r"\([0-9a-zA-Z\,\- ]+\)"
re_letter = r"([a-z]+)\..*"
# re_cap_letter = r"([A-Z]*)\..*"
import regex as re

def clean_cap(caption):
    """Strip parenthetical annotations from a caption and tidy its spacing.

    Every substring matching the module-level ``re_num`` pattern (a
    parenthesised run of letters, digits, commas, hyphens and spaces, e.g.
    "(1)" or "(A, B-C)") is removed. Then runs of spaces are collapsed to a
    single space, spaces before "." and "," are dropped, and the result is
    stripped.

    Args:
        caption: raw caption text.

    Returns:
        The cleaned caption string.
    """
    new_cap = re.sub(re_num, '', caption.strip())
    # Removing several adjacent groups can leave runs of 3+ spaces; a single
    # str.replace('  ', ' ') pass only halves such runs, so collapse them fully.
    new_cap = re.sub(r' {2,}', ' ', new_cap)
    new_cap = new_cap.replace(' .', '.').replace(' ,', ',')
    # A leading/trailing parenthetical would otherwise leave a stray edge space.
    return new_cap.strip()
    
def _load_cache(cache_path):
    """Return the JSON cache stored at ``cache_path``, or ``{}`` if the file is absent."""
    if os.path.exists(cache_path):
        with open(cache_path, 'r') as f:
            return json.load(f)
    return {}


def _save_cache(cache, cache_path):
    """Write ``cache`` to ``cache_path`` as JSON, closing (and thus flushing) the file."""
    with open(cache_path, 'w') as f:
        json.dump(cache, f)


def _request_specs(caption_ls):
    """Send one batch of captions to the chat model and return its parsed JSON reply.

    The prompt is built with ``wrap_prompt(caption_ls, few_shot=True)`` and
    sent with ``user`` as the system message, requesting a JSON-object
    response at temperature 0.
    """
    prompt = wrap_prompt(caption_ls, few_shot=True)
    response = client.chat.completions.create(
        model="gpt-4-0125-preview",
        response_format={"type": "json_object"},
        temperature=0,
        messages=[
            {"role": "system", "content": user},
            {"role": "user", "content": prompt},
        ],
    )
    return json.loads(response.choices[0].message.content)


def _is_action_dict(obj):
    """Return True if ``obj`` is a dict with a ``'caption'`` key (one per-caption result)."""
    return isinstance(obj, dict) and 'caption' in obj


def _extract_action_dicts(action_responses):
    """Flatten a parsed model reply into the per-caption result dicts it contains.

    Accepted shapes (anything else yields no results):
      * a single result dict ``{"caption": ..., ...}``
      * a list of result dicts
      * a dict whose values are result dicts (e.g. keyed by index)
      * a one-key wrapper dict whose value is a list or dict of result dicts
    Entries lacking a ``'caption'`` key are dropped instead of raising.
    """
    if _is_action_dict(action_responses):
        return [action_responses]
    if isinstance(action_responses, list):
        candidates = action_responses
    elif isinstance(action_responses, dict):
        candidates = list(action_responses.values())
        # A lone non-result value is a wrapper such as {"results": [...]}.
        # Checking the shape (not comparing the key count with the batch
        # size) keeps a 1-caption batch from being parsed twice and keeps
        # short final batches from being dropped.
        if len(candidates) == 1 and not _is_action_dict(candidates[0]):
            inner = candidates[0]
            if isinstance(inner, list):
                candidates = inner
            elif isinstance(inner, dict):
                candidates = list(inner.values())
            else:
                candidates = []
    else:
        candidates = []
    return [d for d in candidates if _is_action_dict(d)]


def caption2spec(caption_ls, cache_path):
    """Query the model for one batch of captions and merge the results into the cache.

    Args:
        caption_ls: captions to send in a single request.
        cache_path: JSON file mapping caption -> result dict; created if absent.

    Returns:
        The updated cache dict, which is also written back to ``cache_path``.
    """
    cache = _load_cache(cache_path)
    for action_dict in _extract_action_dicts(_request_specs(caption_ls)):
        # NOTE(review): keyed by the caption echoed in the reply, not the one
        # sent; if the model rewrites a caption, later cache lookups will miss.
        cache[action_dict['caption']] = action_dict
    _save_cache(cache, cache_path)
    return cache


def action2spec(data, cache_path, batch_size, max_num=1000):
    """Convert up to ``max_num`` distinct video captions to specs, batching model calls.

    Args:
        data: mapping video id -> info dict whose ``'sentences'`` and
            ``'timestamps'`` sequences are zipped pairwise.
        cache_path: JSON cache file. Cleaned captions already present as keys
            are not re-queried. The file is rewritten after every batch, so
            completed batches survive a later failure.
        batch_size: maximum number of captions per request; must be >= 1.
        max_num: maximum number of distinct cleaned captions to consider.

    Raises:
        ValueError: if ``batch_size`` is less than 1.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    cache = _load_cache(cache_path)

    # Distinct cleaned captions -> (video id, timestamp) of first occurrence.
    all_captions = {}
    pending = []
    for vid, data_info in data.items():
        if len(all_captions) >= max_num:
            break
        for caption, time_stamp in zip(data_info['sentences'], data_info['timestamps']):
            # Check before counting so the max_num-th caption is also queued.
            if len(all_captions) >= max_num:
                break
            clean_des = clean_cap(caption)
            if clean_des in all_captions:
                continue  # duplicate caption: already counted and queued once
            all_captions[clean_des] = (vid, time_stamp)
            if clean_des not in cache:
                pending.append(clean_des)

    for start in range(0, len(pending), batch_size):
        batch = pending[start:start + batch_size]
        for action_dict in _extract_action_dicts(_request_specs(batch)):
            cache[action_dict['caption']] = action_dict
        _save_cache(cache, cache_path)

        
if __name__ == "__main__":

    dataset = "activity_net"
    cache_file_name = f"{dataset}_tmp_v2_gpt4_cache.json"
    data_file_name = 'temp.json'
    batch_size = 10

    # Resolves to <grandparent of this script's directory>/data/<dataset>.
    data_dir = os.path.abspath(os.path.join(os.path.abspath(__file__), f"../../../data/{dataset}"))
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if not os.path.exists(data_dir):
        raise FileNotFoundError(f"data directory not found: {data_dir}")
    data_nl_dir = os.path.join(data_dir, 'nl2spec')
    os.makedirs(data_nl_dir, exist_ok=True)

    cache_path = os.path.join(data_nl_dir, cache_file_name)
    data_path = os.path.join(data_dir, data_file_name)

    # Read through the managed handle; the old code opened the file a second
    # time inside the `with` and leaked that handle.
    with open(data_path, 'r') as f:
        data = json.load(f)

    action2spec(data, cache_path, batch_size)
    print("end")