from openai import OpenAI
import os
import json
from prompts_openpvsg_v2 import wrap_prompt, user
# Neither list is referenced in the visible part of this file; presumably they
# are consumed by prompt/spec tooling elsewhere -- TODO confirm before removing.
var_list = ["A", "B", "C", "D", "E"]
indentation_list = ["-", "+"]
# Shared OpenAI client, constructed at import time and used by every
# chat-completion request in this module.
# NOTE(review): presumably picks up its credentials from the environment -- confirm.
client = OpenAI()

re_num = "\([0-9a-zA-Z\,\- ]+\)"
re_letter = "([a-z]+)\..*"
# re_cap_letter = "([A-Z]*)\..*"
import regex as re

def clean_cap(caption):
    """Strip parenthesised annotations from a caption and tidy the spacing.

    Every substring of the stripped caption that matches the module-level
    ``re_num`` pattern is removed. Afterwards one replacement pass each turns
    a double space into a single space, ``" ."`` into ``"."`` and ``" ,"``
    into ``","``.

    Args:
        caption: Raw caption string.

    Returns:
        The cleaned caption string.
    """
    cleaned = caption.strip()
    for annotation in re.findall(re_num, cleaned):
        cleaned = cleaned.replace(annotation, '')

    # NOTE: each replace is a single pass, so a run of three or more spaces is
    # shortened but not fully collapsed. The output is deliberately left
    # unchanged because cleaned captions are compared against cache keys.
    return cleaned.replace('  ', ' ').replace(' .', '.').replace(' ,', ',')
    
def _load_cache(cache_path):
    """Return the JSON cache stored at ``cache_path`` ({} if the file is missing)."""
    if not os.path.exists(cache_path):
        return {}
    with open(cache_path, 'r') as cache_file:
        return json.load(cache_file)


def _save_cache(cache, cache_path):
    """Write ``cache`` to ``cache_path`` as JSON and close the file handle."""
    with open(cache_path, 'w') as cache_file:
        json.dump(cache, cache_file)


def _request_specs(captions):
    """Send one batch of captions to the chat model and return the reply text."""
    prompt = wrap_prompt(captions, few_shot=True)
    response = client.chat.completions.create(
        model="gpt-4-0125-preview",
        response_format={"type": "json_object"},
        temperature=0,
        messages=[
            {"role": "system", "content": user},
            {"role": "user", "content": prompt},
        ],
    )
    return response.choices[0].message.content


def _collect_specs(payload):
    """Return every dict that carries a 'caption' key from a decoded reply.

    Accepted shapes: a list of entries, a dict of entries, or a dict with a
    single key wrapping either of those. Entries without a 'caption' key (or
    that are not dicts) are skipped.
    """
    if isinstance(payload, dict):
        entries = list(payload.values())
        if len(entries) == 1:
            only = entries[0]
            if isinstance(only, list):
                entries = only
            elif isinstance(only, dict) and 'caption' not in only:
                # A lone wrapper key around a dict of entries; a lone value
                # that already has 'caption' is itself the entry.
                entries = list(only.values())
    elif isinstance(payload, list):
        entries = payload
    else:
        entries = []
    return [entry for entry in entries
            if isinstance(entry, dict) and 'caption' in entry]


def _store_specs(cache, payload):
    """Add every entry found in ``payload`` to ``cache``, keyed by its caption."""
    for spec in _collect_specs(payload):
        cache[spec['caption']] = spec


def _batch_uncached(cleaned_captions, cache, batch_size, max_num):
    """Group captions missing from ``cache`` into lists of ``batch_size``.

    Stops consuming ``cleaned_captions`` once ``max_num`` distinct captions
    are queued. A caption is queued at most once, and a batch is only emitted
    when it holds at least one caption.
    """
    batches = []
    current = []
    queued = set()
    for caption in cleaned_captions:
        if len(queued) >= max_num:
            break
        if caption in cache or caption in queued:
            continue
        queued.add(caption)
        current.append(caption)
        if len(current) >= batch_size:
            batches.append(current)
            current = []
    if current:
        batches.append(current)
    return batches


def _fill_cache(cleaned_captions, cache_path, batch_size, max_num):
    """Query the model for uncached captions and persist the cache per batch."""
    cache = _load_cache(cache_path)
    batches = _batch_uncached(cleaned_captions, cache, batch_size, max_num)
    if not batches:
        return

    for batch in batches:
        content = _request_specs(batch)
        try:
            payload = json.loads(content)
        except (json.JSONDecodeError, TypeError):
            # TypeError covers a reply whose content is None.
            print("json decode error, consider lowering the batch size")
            continue

        _store_specs(cache, payload)
        # Saved after every batch so an interrupted run keeps earlier results.
        _save_cache(cache, cache_path)

    print('here')


def caption2spec(caption_ls, cache_path):
    """Query the model for all of ``caption_ls`` and merge the reply into the cache.

    Args:
        caption_ls: Captions to send in a single request (no cache filtering).
        cache_path: Path of the JSON cache file; created if missing.

    Returns:
        The updated cache dict, keyed by the 'caption' field of each entry.

    Raises:
        json.JSONDecodeError: If the model reply is not valid JSON.
    """
    cache = _load_cache(cache_path)
    payload = json.loads(_request_specs(caption_ls))
    _store_specs(cache, payload)
    _save_cache(cache, cache_path)
    return cache


def action2spec(data, cache_path, batch_size, max_num = 1000, processed_orig_captions=None):
    """Fill the cache with model replies for the caption values in ``data``.

    Args:
        data: Mapping of video id -> mapping of original caption -> caption.
            Each caption value is passed through ``clean_cap`` before lookup.
        cache_path: Path of the JSON cache file; created if missing.
        batch_size: Number of captions per model request.
        max_num: Upper bound on the number of captions queued in this call.
        processed_orig_captions: Optional iterable of ``(video id, original
            caption)`` pairs to skip.

    Returns:
        None. Results are written to ``cache_path``.
    """
    if processed_orig_captions is None:
        processed = frozenset()
    else:
        try:
            # Hashed lookup instead of a linear scan for every caption.
            processed = frozenset(processed_orig_captions)
        except TypeError:
            # Unhashable entries: keep the caller's container as is.
            processed = processed_orig_captions

    cleaned_captions = (
        clean_cap(caption)
        for vid, data_info in data.items()
        for orig_cap, caption in data_info.items()
        if (vid, orig_cap) not in processed
    )
    _fill_cache(cleaned_captions, cache_path, batch_size, max_num)


def orig_action2spec(data, cache_path, batch_size, max_num = 1000, processed_orig_captions=None):
    """Fill the cache with model replies for the original-caption keys in ``data``.

    Same flow as ``action2spec``, except that the keys of each per-video
    mapping (not the values) are cleaned and sent to the model.

    Args:
        data: Mapping of video id -> mapping keyed by original caption.
        cache_path: Path of the JSON cache file; created if missing.
        batch_size: Number of captions per model request.
        max_num: Upper bound on the number of captions queued in this call.
        processed_orig_captions: Accepted for signature parity with
            ``action2spec``; currently ignored.

    Returns:
        None. Results are written to ``cache_path``.
    """
    cleaned_captions = (
        clean_cap(orig_cap)
        for data_info in data.values()
        for orig_cap in data_info
    )
    _fill_cache(cleaned_captions, cache_path, batch_size, max_num)

        
if __name__ == "__main__":
    # Script entry point: load the caption file for one dataset and fill the
    # GPT cache for its original captions.

    dataset = "activity_net"
    cache_file_name = f"{dataset}_orig_tmp_v2_gpt4_cache.json"
    # data_file_name = 'pvsg.json'
    data_file_name = 'videollamav2_caption_30000.json'
    processed_file_names = "videollamav2_caption_4000.json"

    batch_size = 10
    max_num = 10000

    # Dataset directory resolved relative to this file's location.
    data_dir = os.path.abspath(os.path.join(os.path.abspath(__file__), f"../../../data/{dataset}"))
    data_nl_dir = os.path.join(data_dir, 'nl2spec')
    # Explicit check instead of assert, which is stripped under `python -O`.
    if not os.path.exists(data_dir):
        raise FileNotFoundError(f"dataset directory not found: {data_dir}")
    os.makedirs(data_nl_dir, exist_ok=True)

    cache_path = os.path.join(data_nl_dir, cache_file_name)
    data_path = os.path.join(data_dir, data_file_name)
    processed_file_path = os.path.join(data_dir, processed_file_names)

    with open(data_path, 'r') as f:
        data = json.load(f)

    # (video id, original caption) pairs already handled in an earlier run.
    processed_orig_captions = []
    if os.path.exists(processed_file_path):
        with open(processed_file_path, 'r') as f:
            orig2gpt = json.load(f)
        processed_orig_captions = [
            (vid, orig_cap)
            for vid, caption_mapping in orig2gpt.items()
            for orig_cap in caption_mapping
        ]

    # See video id in anno['split'].
    # data = {data_dict['video_id']: data_dict for data_dict in anno['data']}
    # data = json.load(open(data_path, 'r'))
    # Each video has info of video_id, meta, objects, relations, captions, qa_pairs, and summary.

    # action2spec(data, cache_path, batch_size, max_num=max_num, processed_orig_captions=processed_orig_captions)
    orig_action2spec(data, cache_path, batch_size, max_num=max_num, processed_orig_captions=processed_orig_captions)

    # cache = json.load(open(cache_path, 'r'))
    # for action, response in cache.items():
    #     process_response(response)
    print("end")