import json, os, random, argparse, time, re, itertools
from openai import OpenAI
from tqdm import tqdm
from prompt_templates import *
import nltk
from nltk import pos_tag, word_tokenize, RegexpParser
from nltk.corpus import stopwords


# Module-level OpenAI client shared by get_llm_output().
# The constructor is given a '<KEY>' placeholder; the key is then replaced by the
# OPENAI_API_KEY environment variable (os.environ.get yields None when it is unset).
client = OpenAI(api_key='<KEY>')
client.api_key = os.environ.get('OPENAI_API_KEY')
# NOTE(review): 'api_url' looks like a placeholder that must be replaced with the
# real endpoint before any request can succeed -- confirm.
client.base_url = 'api_url'

# Paraphrased task instructions; the process_*_qa_dataset functions pick one at
# random (random.choice) to open every generated human turn.
task_instructs = [
    "Based on the following captions describing keyframes of a video, answer the next question.",
    "Given the video captions listed below, respond to the upcoming question.",
    "Taking into account the sequence of captions of a video, please answer the following question.",
    "Referencing the video captions that are shown below, please address the next query.",
    "After reviewing the captions from the video in sequential order, please proceed to answer the subsequent question.",
    "Using the video captions displayed in a sequential manner as reference, respond to the forthcoming question.",
    "With the video captions presented in their chronological sequence below, address the upcoming question.",
    "Based on the sequence of captions for a video, answer the question that follows.",
    "Referring to the video captions listed in sequential order below, please reply to the subsequent question."
]

# Answer-format prompts keyed by question type; one entry is picked at random and
# appended after the options of every generated question.
# Only the "multi-choice" type is defined here.
answer_prompts = {
    "multi-choice": [
        "Only give the best option:\n",
        "Provide only the top choice:\n",
        "Offer just the best selection:\n",
        "Deliver only the best choice:\n",
        "Return just the best option:\n",
        "Show only the top selection:\n"
    ],
}

def get_nouns(sentence):
    """Return the set of tokens in ``sentence`` that are POS-tagged as nouns.

    Tokens tagged NN, NNS, NNP or NNPS are kept, except the words
    'image' and 'images' (compared case-insensitively).
    """
    noun_tags = ['NN', 'NNS', 'NNP', 'NNPS']
    ignored_words = ['image', 'images']
    found = set()
    for token, tag in pos_tag(word_tokenize(sentence)):
        if tag in noun_tags and token.lower() not in ignored_words:
            found.add(token)
    return found

def check_common_nouns(nouns1, nouns2):
    """Return the members of the set ``nouns1`` that also occur in ``nouns2``.

    ``nouns2`` may be any iterable; an empty result means nothing is shared.
    """
    return nouns1.intersection(nouns2)

def extract_noun_phrases(sentence, random_sample=True):
    """Extract multi-word noun phrases from ``sentence``.

    A phrase is an optional DT token, any number of JJ tokens and one or more
    NN tokens. Phrases containing 'image' or 'close', and single-word phrases,
    are dropped.

    Returns False when no phrase qualifies; otherwise one randomly chosen
    phrase (``random_sample=True``) or the full list of phrases.
    """
    chunker = RegexpParser("NP: {<DT>?<JJ>*<NN>+}")
    parsed = chunker.parse(pos_tag(word_tokenize(sentence)))

    candidates = []
    for chunk in parsed.subtrees():
        if chunk.label() != 'NP':
            continue
        phrase = " ".join(token for token, _ in chunk.leaves())
        if 'image' in phrase or 'close' in phrase:
            continue
        if len(phrase.split()) > 1:
            candidates.append(phrase)

    if not candidates:
        return False
    return random.choice(candidates) if random_sample else candidates

def noun_phrase_to_sentence(noun_phrase):
    """Wrap ``noun_phrase`` in one randomly chosen short sentence template."""
    patterns = ["[np] is shown.", "we can see [np].", "[np] appears.", "[np] is visible.", "[np] is present."]
    return random.choice(patterns).replace('[np]', noun_phrase)

def load_json(json_file):
    """Read ``json_file`` and return its parsed JSON content."""
    with open(json_file, 'r') as handle:
        return json.load(handle)

def save_json(datas, json_file):
    """Write ``datas`` to ``json_file`` as JSON indented by 4 spaces.

    The file is opened in 'w' mode, so existing content is overwritten.
    Returns None.
    """
    with open(json_file, 'w') as f:
        # json.dump returns None; its result is deliberately not bound to anything.
        json.dump(datas, f, indent=4)

def combine_captions(args, captions, combined_cap_path):
    """Sample random groups of 2 to 6 captions and save them as JSON.

    For every group size, up to ``args.num_combine`` distinct id groups are
    drawn by rejection sampling. Each group is stored under the key made of
    its caption indices joined with '_'.

    Two files are written: ``combined_cap_path`` with the full captions, and a
    '-first-sent.json' sibling holding only the text up to the first '.' of
    every caption.

    Args:
        args: namespace providing ``num_combine``.
        captions: sequence of caption strings.
        combined_cap_path: output path of the combined-caption JSON file.

    Returns:
        dict mapping group key -> list of full captions.
    """
    import math  # local import: only used for the combination-count guard below

    ids = list(range(len(captions)))
    combined_captions = {}
    for num_cap in [2, 3, 4, 5, 6]:
        if num_cap > len(ids):
            # random.sample cannot draw more ids than exist; no group of this size is possible.
            continue
        # Never ask for more distinct groups than exist, otherwise the
        # rejection-sampling loop below could never finish.
        target = min(args.num_combine, math.comb(len(ids), num_cap))
        # One discarded draw is kept on purpose: the previous implementation made
        # it, and dropping it would change the output of seeded runs.
        random.sample(ids, k=num_cap)
        groups, seen = [], set()
        while len(groups) < target:
            sample = set(random.sample(ids, k=num_cap))
            fingerprint = frozenset(sample)
            if fingerprint not in seen:  # O(1) duplicate check
                seen.add(fingerprint)
                groups.append(sample)
        for group in groups:
            members = list(group)
            combined_captions['_'.join(str(idx) for idx in members)] = [captions[idx] for idx in members]
    combined_captions_first_sent = {
        key: [cap.split('.')[0] + '.' for cap in caps]
        for key, caps in combined_captions.items()
    }
    save_json(combined_captions, combined_cap_path)
    save_json(combined_captions_first_sent, combined_cap_path.replace('.json', '') + '-first-sent.json')
    return combined_captions

def get_llm_output(prompt, model='gpt-4-turbo-2024-04-09', max_tokens=1024, temperature=0.7, response_format='json_object'):
    """Send ``prompt`` to the chat-completions endpoint of the module-level client.

    The request uses a fixed system message plus ``prompt`` as the user message.

    Returns:
        The first choice's message content, parsed with json.loads when
        ``response_format`` is 'json_object', otherwise as the raw string.
    """
    conversation = [
        {"role": "system", "content": "You are an AI assistant for question generation."},
        {"role": "user", "content": prompt},
    ]
    completion = client.chat.completions.create(
        model=model,
        messages=conversation,
        temperature=temperature,
        max_tokens=max_tokens,
        response_format={"type": response_format},
    )
    reply = completion.choices[0].message.content
    if response_format == 'json_object':
        return json.loads(reply)
    return reply

def generate_order_qa(all_captions, combined_cap_path):
    """
        Generate order-related text qa from captions through get_llm_output.

        Results are checkpointed to '<combined_cap_path without .json>-qa.json'
        after every item; ids already present in that file are skipped, and
        nothing is requested when the file already holds at least as many
        entries as ``all_captions``.
    """
    output_path = combined_cap_path.replace('.json', '') + '-qa.json'
    qa_datas = load_json(output_path) if os.path.exists(output_path) else {}
    if len(qa_datas) >= len(all_captions):
        return qa_datas

    # Visit the caption groups in a random order.
    shuffled = list(all_captions.items())
    random.shuffle(shuffled)
    for cap_id, captions in tqdm(dict(shuffled).items()):
        if cap_id in qa_datas:
            continue
        numbered = "".join(f"{idx}. {text}\n" for idx, text in enumerate(captions, start=1))
        prompt = generate_qa_prompt_order + numbered + '\nQuestions and Answers:\n'
        entry = {'captions': captions}
        entry.update(get_llm_output(prompt))
        qa_datas[cap_id] = entry
        save_json(qa_datas, output_path)
        time.sleep(3)
    return qa_datas

def generate_attribute_qa(all_captions, combined_cap_path, maxtry=10):
    """
        Generate attribute-related text qa from captions through get_llm_output.

        Args:
            all_captions: dict id -> {dimension: [caption, caption]}. The caption
                lists are shuffled in place before prompting.
            combined_cap_path: the checkpoint is written to this path with
                '.json' replaced by '-qa.json'.
            maxtry: number of retries allowed for EACH item after its first
                failed attempt.

        Returns:
            dict id -> qa data, or None for items that still failed after all
            retries. The dict is saved after every item; ids already present
            in the checkpoint are skipped.
    """
    output_path = combined_cap_path.replace('.json', '')+'-qa.json'
    if os.path.exists(output_path):
        qa_datas = load_json(output_path)
    else:
        qa_datas = {}

    for cap_id, captions in tqdm(all_captions.items()):
        if cap_id in qa_datas:
            continue

        # shuffle the order of captions to avoid bias
        for dim in captions:
            random.shuffle(captions[dim])

        qa_data = {'captions': captions}
        prompt = ""
        for i, cap_pair in enumerate(captions.values()):
            prompt += f"Caption Pair {i+1}:\n1.{cap_pair[0]}\n2.{cap_pair[1]}\n\n"
        prompt = generate_qa_prompt_attribute + prompt + '\nQuestions and Answers:\n'

        # Fresh budget per item: a few early failures must not use up the
        # retries of every later item.
        retries_left = maxtry
        while True:
            try:
                qa_data.update(get_llm_output(prompt))
                qa_datas[cap_id] = qa_data
                save_json(qa_datas, output_path)
                break
            except Exception as err:  # not a bare except: Ctrl-C must still stop the run
                print(f"Request failed: {err!r}")
                if retries_left <= 0:
                    qa_datas[cap_id] = None
                    save_json(qa_datas, output_path)
                    break
                retries_left -= 1
                print(f"Not success! {retries_left} retries remaining...")
                time.sleep(10)
    return qa_datas

def generate_refer_qa(all_captions, output_path, maxtry=10):
    """
        Generate temporal-referring text qa from captions through get_llm_output.

        About 30% of the items (random.random() < 0.3) use the negation prompt:
        every caption of a group is rewritten IN PLACE as
        '<shared prefix> a <own suffix> and a <next caption's suffix>', where
        prefix and suffixes come from splitting the captions on ' a '. The
        other items use the normal prompt, which reads three captions per group.

        Args:
            all_captions: dict id -> {dimension: [captions]}. The caption lists
                are shuffled in place before prompting.
            output_path: JSON checkpoint, loaded when it exists and rewritten
                after every item.
            maxtry: number of retries allowed for EACH item after its first
                failed attempt.

        Returns:
            dict id -> qa data, or None for items that still failed after all
            retries.
    """
    if os.path.exists(output_path):
        qa_datas = load_json(output_path)
    else:
        qa_datas = {}

    for cap_id, captions in tqdm(all_captions.items()):
        if cap_id in qa_datas:
            continue

        # shuffle the order of captions to avoid bias
        for dim in captions:
            random.shuffle(captions[dim])

        qa_data = {'captions': captions}
        prompt = ""

        qtype = 'negation' if random.random()<0.3 else 'normal'
        if qtype=='negation':
            for i, cap_pair in enumerate(captions.values()):
                prefix = cap_pair[0].split(' a ')[0] + ' a '
                # Suffixes are computed before the captions are overwritten below.
                sufixes = [' a '.join(cap.split(' a ')[1:]).strip('.') for cap in cap_pair]
                prompt += f"###Caption Group {i+1}:\n"
                for j, _ in enumerate(cap_pair):
                    cap_pair[j] = f"{prefix}{sufixes[j]} and a {sufixes[(j+1)%len(sufixes)]}"
                    prompt += f"Caption {j+1}. {cap_pair[j]}\n"
                prompt += "\n"
            prompt = generate_qa_prompt_refer_negation + prompt + '\nNegation Questions and Answers:\n'
        else:
            for i, cap_pair in enumerate(captions.values()):
                prompt += f"###Caption Group {i+1}:\nCaption 1.{cap_pair[0]}\nCaption 2.{cap_pair[1]}\nCaption 3.{cap_pair[2]}\n\n"
            prompt = generate_qa_prompt_refer + prompt + '\nQuestions and Answers:\n'

        # Fresh budget per item: a few early failures must not use up the
        # retries of every later item.
        retries_left = maxtry
        while True:
            try:
                qa_data.update(get_llm_output(prompt))
                qa_datas[cap_id] = qa_data
                save_json(qa_datas, output_path)
                break
            except Exception as err:  # not a bare except: Ctrl-C must still stop the run
                print(f"Request failed: {err!r}")
                if retries_left <= 0:
                    qa_datas[cap_id] = None
                    save_json(qa_datas, output_path)
                    break
                retries_left -= 1
                print(f"Not success! {retries_left} retries remaining...")
                time.sleep(10)
    return qa_datas

def generate_attribute_captions(captions, output_path, maxtry=10):
    """
        Generate attribute-related captions through get_llm_output.

        Args:
            captions: dict id -> caption string.
            output_path: JSON checkpoint, loaded when it exists and rewritten
                after every successful item.
            maxtry: number of retries allowed for EACH item after its first
                failed attempt.

        Returns:
            dict id -> the 'captions' field of the model reply, or None for
            items that still failed after all retries. Returns the checkpoint
            unchanged when it already has as many entries as ``captions``.
    """
    if os.path.exists(output_path):
        combined_captions = load_json(output_path)
    else:
        combined_captions = {}
    if len(combined_captions)==len(captions):
        return combined_captions

    for cap_id, caption in tqdm(captions.items()):
        if cap_id in combined_captions:
            continue
        # Fresh budget per item: a few early failures must not use up the
        # retries of every later item.
        retries_left = maxtry
        while True:
            try:
                prompt = generate_caption_prompt_attribute + caption + '\nEnriched Captions:\n'
                combined_captions[cap_id] = get_llm_output(prompt, temperature=0.5)['captions']
                save_json(combined_captions, output_path)
                break
            except Exception as err:  # not a bare except: Ctrl-C must still stop the run
                print(f"Request failed: {err!r}")
                if retries_left <= 0:
                    combined_captions[cap_id] = None
                    break
                retries_left -= 1
                print(f"Not success! {retries_left} retries remaining...")
                time.sleep(10)
    return combined_captions

def generate_counterfactual_captions(captions, output_path, maxtry=10):
    """
        Generate counterfactual captions by modifying the original caption,
        through get_llm_output.

        Args:
            captions: dict id -> caption string.
            output_path: JSON checkpoint, loaded when it exists and rewritten
                after every successful item.
            maxtry: number of retries allowed for EACH item after its first
                failed attempt.

        Returns:
            dict id -> the 'captions' field of the model reply, or None for
            items that still failed after all retries. Returns the checkpoint
            unchanged when it already has as many entries as ``captions``.
    """
    if os.path.exists(output_path):
        combined_captions = load_json(output_path)
    else:
        combined_captions = {}
    if len(combined_captions)==len(captions):
        return combined_captions

    for cap_id, caption in tqdm(captions.items()):
        if cap_id in combined_captions:
            continue
        # Fresh budget per item: a few early failures must not use up the
        # retries of every later item.
        retries_left = maxtry
        while True:
            try:
                prompt = generate_caption_prompt_counterfactual + caption + '\nModified Captions:\n'
                combined_captions[cap_id] = get_llm_output(prompt, temperature=0.5)['captions']
                save_json(combined_captions, output_path)
                time.sleep(2)
                break
            except Exception as err:  # not a bare except: Ctrl-C must still stop the run
                print(f"Request failed: {err!r}")
                if retries_left <= 0:
                    combined_captions[cap_id] = None
                    break
                retries_left -= 1
                print(f"Not success! {retries_left} retries remaining...")
                time.sleep(10)
    return combined_captions
    
# Three alternative numbering styles for caption lines, each covering 1..199.
# NOTE(review): not referenced anywhere in the visible part of this file --
# confirm whether it is still needed.
capid_prefix = [
    [f"{i}. " for i in range(1, 200)],
    [f"Caption {i}: " for i in range(1, 200)],
    [f"Keyframe {i}: " for i in range(1, 200)],
]
def process_order_qa_dataset(qa_datas, caption_pool, output_path, aspect, num_val=500, num_distract=100):
    """
        Convert generated order text qa to unified format

        Every question becomes a two-turn sample (human prompt, gpt answer)
        with shuffled, lettered options. The long-context variant, where the
        captions are mixed with distractor captions, is only built for
        questions whose first option contains ',', 'then', '->' or 'follow'.
        Only the long-context train/val splits are written to ``output_path``.

        Args:
            qa_datas: dict id -> {'captions': [...], 'qas': [{'question',
                'options', 'answer'}, ...]}; option lists are shuffled in place.
            caption_pool: distractor pool passed to sample_distract_captions.
            output_path: directory of the output files.
            aspect: tag used in the output file names.
            num_val: number of samples held out for validation (taken from the
                end of the shuffled list).
            num_distract: distractor count is drawn from
                [num_distract-50, num_distract+50].

        Returns:
            The short-context training samples.
    """
    option_chars = ['(A) ', '(B) ', '(C) ', '(D) ', '(E) ', '(F) ']
    train_qas, train_qas_long = [], []  # train_qas: qa data in the training format
    for qa_data in tqdm(qa_datas.values()):
        caption_str = '\n'.join(qa_data['captions'])
        long_caption_str = sample_distract_captions(get_nouns(' '.join(qa_data['captions'])), qa_data['captions'], caption_pool, random.randint(num_distract-50, num_distract+50))
        for qa in qa_data['qas']:
            random.shuffle(qa['options'])
            opt_str = ""
            # Reset per question so that a missing answer can never reuse the
            # label computed for an earlier question.
            mc_a = None
            for opt, char in zip(qa['options'], option_chars):
                if qa['answer']==opt:
                    mc_a = char + opt
                opt_str += char + opt + '\n'
            if mc_a is None:
                print("answer not in option")
                continue
            mc_q = f"{qa['question']}\n{opt_str}"
            instruct = random.choice(task_instructs)
            ans_prompt_mc = random.choice(answer_prompts['multi-choice'])
            train_qas.append([{"from": "human", "value": f"{instruct}\n\nCaptions:\n{caption_str}\nQuestion: {mc_q}\n{ans_prompt_mc}"}, {"from": "gpt", "value": mc_a}])
            if any(str_ in qa['options'][0] for str_ in [',', 'then', '->', 'follow']):
                train_qas_long.append([{"from": "human", "value": f"{instruct}\n\nCaptions:\n{long_caption_str}\nQuestion: {mc_q}\n{ans_prompt_mc}"}, {"from": "gpt", "value": mc_a}])
    random.shuffle(train_qas)
    random.shuffle(train_qas_long)
    # Explicit cut index: slicing with [:-num_val] would empty the training
    # split when num_val is 0.
    cut = max(len(train_qas) - num_val, 0)
    train_qas, val_qas = train_qas[:cut], train_qas[cut:]
    cut_long = max(len(train_qas_long) - num_val, 0)
    train_qas_long, val_qas_long = train_qas_long[:cut_long], train_qas_long[cut_long:]
    print(len(train_qas), len(val_qas), len(train_qas_long), len(val_qas_long))
    save_json(train_qas_long, f"{output_path}/temp_qa_{aspect}_train_long_4x.json")
    save_json(train_qas_long[:15000], f"{output_path}/temp_qa_{aspect}_train_long_4x_15k.json")
    save_json(val_qas_long, f"{output_path}/temp_qa_{aspect}_val_long_4x.json")
    return train_qas

def sample_distract_captions(question_nouns, captions, caption_pool, num_caps, insert_style='random', add_anchor=False):
    """Build a long caption context by mixing ``captions`` with unrelated pool captions.

    Distractors are taken from ``caption_pool`` in its iteration order and kept
    only when they share no noun with ``question_nouns``; collection stops once
    ``num_caps`` have been gathered.

    Args:
        question_nouns: set of nouns that distractor captions must not contain.
        captions: list of relevant captions; MUTATED IN PLACE by the 'separate'
            style (a missing trailing '.' is appended, and the anchor sentence is
            attached when ``add_anchor`` is set).
        caption_pool: dict id -> {'caption': str, 'nouns': collection}.
        num_caps: maximum number of distractor captions.
        insert_style:
            'random'     -- choose 'together' or 'separate' at random;
            'determined' -- insert up to three captions near the start (0-10%),
                            the middle (40-60%) and the end (90-100%) of the
                            distractor list (zip stops after three indices);
            'together'   -- insert all captions consecutively at a random index;
            'separate'   -- insert the captions at sorted random positions.
        add_anchor: when True, a noun phrase from a random pool caption is turned
            into an anchor sentence and attached to 1-3 random distractors; only
            the 'separate' style also attaches it to the relevant captions and
            records the per-caption occurrence index.

    Returns:
        The newline-joined caption string, or, when ``add_anchor`` is True, the
        tuple (caption string, anchor phrase, occurrence-count dict). The dict
        holds 'total' plus, in the 'separate' style, one entry per caption index.
    """
    distract_captions = []
    if add_anchor:
        identifier_phrase_count = {}
        # NOTE(review): this loop has no exit condition other than finding a
        # suitable pool caption -- confirm the pool always contains one.
        while True:
            identifier_cap = random.choice(list(caption_pool.values()))
            identifier_phrase = extract_noun_phrases(identifier_cap['caption'].lower())
            if (not check_common_nouns(question_nouns, identifier_cap['nouns'])) and identifier_phrase:
                question_nouns = question_nouns.union(identifier_cap['nouns'])     # ensure that the identifier phrase does not appear in the distractor captions
                identifier_sent = noun_phrase_to_sentence(identifier_phrase)
                break
    # Collect distractors that have no noun in common with the question nouns.
    for id, cap in caption_pool.items():
        if len(distract_captions)>=num_caps:
            break
        common_noun_phrases = check_common_nouns(question_nouns, cap['nouns'])
        if not common_noun_phrases:
            distract_captions.append(cap['caption'])
    
    # Insert identifier phrase into random distractor captions
    if add_anchor:
        # NOTE(review): random.sample needs at least as many distractors as the
        # drawn count (1-3) -- confirm callers always collect enough.
        insert_ids = random.sample(list(range(len(distract_captions))), random.randint(1, 3))
        # 'total' = anchored distractors + relevant captions.
        identifier_phrase_count['total'] = len(insert_ids) + len(captions)
        for id in insert_ids:
            # The anchor sentence goes after or before the caption with equal probability.
            distract_captions[id] = distract_captions[id] + ' ' + identifier_sent if random.random()<0.5 else identifier_sent + ' ' + distract_captions[id]

    if insert_style=='random':
        insert_style = random.choice(['together', 'separate'])
    if insert_style=='determined':
        # One index in each of the first 10%, the middle 40-60% and the last 10%.
        inds = random.randint(0, int(0.1*len(distract_captions))), random.randint(int(0.4*len(distract_captions)), int(0.6*len(distract_captions))), random.randint(int(0.9*len(distract_captions)), len(distract_captions))
        for ind, cap in zip(inds, captions):
            distract_captions.insert(ind, cap)
            
    elif insert_style=='together':
        # NOTE(review): random.choice is called on an empty list when no
        # distractor was collected -- confirm that cannot happen.
        ind = random.choice(list(range(len(distract_captions))))
        for i, cap in enumerate(captions):
            distract_captions.insert(ind+i, captions[i])
    elif insert_style=='separate':
        # Sorted distinct positions keep the relevant captions in their given order.
        insert_inds = sorted(random.sample(range(len(distract_captions) + 1), len(captions)))
        for cid, insert_ind in enumerate(insert_inds):
            if not captions[cid].endswith('.'):
                captions[cid] += '.'
            # Insert identifier phrase into relevant captions
            if add_anchor:
                captions[cid] = captions[cid] + ' ' + identifier_sent if random.random()<0.5 else identifier_sent + ' ' + captions[cid]
                # Occurrences of the anchor sentence before the insert position, plus this one.
                identifier_phrase_count[cid] = ' '.join(distract_captions[:insert_ind]).count(identifier_sent)+1
            distract_captions.insert(insert_ind, captions[cid])
    distract_captions = '\n'.join(distract_captions)
    if add_anchor:
        return distract_captions, identifier_phrase, identifier_phrase_count
    else:
        return distract_captions

def process_attribute_qa_dataset(qa_datas, caption_pool, output_path, aspect, num_val=500, num_distract=100):
    """
        Convert generated attribute text qa to unified format

        Every (caption pair, question) becomes a short-context sample and a
        long-context sample in which the captions are mixed with distractor
        captions. Only the long-context train/val splits are written to
        ``output_path``.

        Args:
            qa_datas: dict id -> {'captions': {dim: [...]}, 'qas': {dim:
                {'question', 'options', 'answer'}}}. Entries that are None or
                empty are skipped. Option lists are shuffled in place.
            caption_pool: distractor pool passed to sample_distract_captions.
            output_path: directory of the output files.
            aspect: tag used in the output file names.
            num_val: number of samples held out for validation (taken from the
                end of the shuffled list).
            num_distract: distractor count is drawn from
                [num_distract-50, num_distract+50].

        Returns:
            The short-context training samples.
    """
    option_chars = ['(A) ', '(B) ', '(C) ', '(D) ', '(E) ', '(F) ']
    train_qas = []  # train_qas: qa data in the training format
    train_qas_long = []
    for example in tqdm(qa_datas.values()):
        if not example:
            # generate_attribute_qa stores None for items whose generation failed.
            continue
        for caps, qa in zip(example['captions'].values(), example['qas'].values()):

            caption_str = '\n'.join(caps)

            long_caption_str = sample_distract_captions(get_nouns(qa['question']), caps, caption_pool, random.randint(num_distract-50, num_distract+50))

            random.shuffle(qa['options'])
            opt_str = ""
            # Reset per question so that a missing answer can never reuse the
            # label computed for an earlier question.
            mc_a = None
            for opt, char in zip(qa['options'], option_chars):
                if qa['answer']==opt:
                    mc_a = char + opt
                opt_str += char + opt + '\n'
            if mc_a is None:
                print("answer not in option")
                continue
            mc_q = f"{qa['question']}\n{opt_str}"
            instruct = random.choice(task_instructs)
            ans_prompt_mc = random.choice(answer_prompts['multi-choice'])
            train_qas.append([{"from": "human", "value": f"{instruct}\n\nCaptions:\n{caption_str}\nQuestion: {mc_q}\n{ans_prompt_mc}"}, {"from": "gpt", "value": mc_a}])
            train_qas_long.append([{"from": "human", "value": f"{instruct}\n\nCaptions:\n{long_caption_str}\nQuestion: {mc_q}\n{ans_prompt_mc}"}, {"from": "gpt", "value": mc_a}])
    random.shuffle(train_qas)
    random.shuffle(train_qas_long)
    # Explicit cut index: slicing with [:-num_val] would empty the training
    # split when num_val is 0.
    cut = max(len(train_qas) - num_val, 0)
    train_qas, val_qas = train_qas[:cut], train_qas[cut:]
    cut_long = max(len(train_qas_long) - num_val, 0)
    train_qas_long, val_qas_long = train_qas_long[:cut_long], train_qas_long[cut_long:]
    print(len(train_qas), len(val_qas))
    save_json(train_qas_long, f"{output_path}/temp_qa_{aspect}_train_long_2x.json")
    save_json(train_qas_long[:15000], f"{output_path}/temp_qa_{aspect}_train_long_2x_15k.json")
    save_json(val_qas_long, f"{output_path}/temp_qa_{aspect}_val_long_2x.json")
    return train_qas

# Temporal position phrases appended to questions.
#   "begin" / "middle" / "end": absolute positions in the video.
#   1..6 and "last": templates for the n-th / last appearance of an anchor;
#   the "[anchor]" placeholder is replaced with the anchor phrase by the caller.
# NOTE(review): several phrases contain spelling mistakes ("begings", "thee",
# "begining"). refer_any2temp() finds these phrases by substring match inside
# previously generated questions, so correcting them here would stop already
# generated data from matching -- regenerate the data if they are ever fixed.
pos_strs = {
    "begin": ["at the begin of the video", "at the beginning of the video", "at the start of the video", "when the video begings", "close to thee begin of the video", "close to the start of the video", "close to the begining of the video"],
    "middle": ["at the middle of the video", "at the middle part of the video", "in the midst of the video"],
    "end": ["at the end of the video", "at the final of the video", "when the video comes to end", "close to the end of the video", "close to the final of the video"],
    1: [
        "when [anchor] appears the first time", "when [anchor] first appears", "when [anchor] first appears in the video", "when [anchor] is seen for the first time", "at the moment [anchor] first appears", "upon the initial appearance of [anchor]",
        "when [anchor] appears the 1st time", "when [anchor] is seen for the 1st time", "upon the 1st appearance of [anchor]"
        ],
    2: [
        "when [anchor] appears the second time", "when [anchor] is seen for the second time", "at the moment [anchor] appears the second time", "upon the second appearance of [anchor]",
        "when [anchor] appears the 2nd time", "when [anchor] is seen for the 2nd time", "at the moment [anchor] appears the 2nd time", "upon the 2nd appearance of [anchor]"
        ],
    3: [
        "when [anchor] appears the third time", "when [anchor] is seen for the third time", "at the moment [anchor] appears the third time", "upon the third appearance of [anchor]",
        "when [anchor] appears the 3rd time", "when [anchor] is seen for the 3rd time", "at the moment [anchor] appears the 3rd time", "upon the 3rd appearance of [anchor]"
        ],
    4: [
        "when [anchor] appears the fourth time", "when [anchor] is seen for the fourth time", "at the moment [anchor] appears the fourth time", "upon the fourth appearance of [anchor]",
        "when [anchor] appears the 4th time", "when [anchor] is seen for the 4th time", "at the moment [anchor] appears the 4th time", "upon the 4th appearance of [anchor]"
        ],
    5: [
        "when [anchor] appears the fifth time", "when [anchor] is seen for the fifth time", "at the moment [anchor] appears the fifth time", "upon the fifth appearance of [anchor]",
        "when [anchor] appears the 5th time", "when [anchor] is seen for the 5th time", "at the moment [anchor] appears the 5th time", "upon the 5th appearance of [anchor]",
        ],
    6: [
        "when [anchor] appears the sixth time", "when [anchor] is seen for the sixth time", "at the moment [anchor] appears the sixth time", "upon the sixth appearance of [anchor]",
        "when [anchor] appears the 6th time", "when [anchor] is seen for the 6th time", "at the moment [anchor] appears the 6th time", "upon the 6th appearance of [anchor]",
        ],
    "last": [
        "when [anchor] appears the last time", "when [anchor] is seen for the last time", "at the moment [anchor] appears the last time", "upon the last appearance of [anchor]",
        ],
}
# Flat list of every absolute-position phrase (begin, then middle, then end).
pos_str_list = []
for pos in ['begin', 'middle', 'end']:
    pos_str_list += pos_strs[pos]

def process_refer_qa_dataset(qa_datas, caption_pool, output_path, aspect, refer_type, num_val=500, num_distract=100):
    """
        Convert temporal referring text qa to unified format
            refer_type: 'begin_end' or 'anchor'

        For every caption group the captions are shuffled together with their
        answers, mixed with distractor captions, and one question per caption
        is produced by inserting a position phrase before the '?' of the base
        question:
            'begin_end' -- the captions sit at the begin / middle / end of the
                           context and the phrase names that position;
            'anchor'    -- the phrase refers to the n-th (or, half of the time
                           for the final one, the last) appearance of an
                           anchor phrase.

        Args:
            qa_datas: dict id -> {'captions': {dim: [...]}, 'qas': {dim:
                {'question', 'options', 'answers'}}}. Entries that are None or
                empty are skipped. Option lists are shuffled in place.
            caption_pool: distractor pool passed to sample_distract_captions.
            output_path: directory of the output files.
            aspect: tag used in the output file names.
            refer_type: 'begin_end' or 'anchor'.
            num_val: number of samples held out for validation (taken from the
                end of the shuffled list).
            num_distract: distractor count is drawn from
                [num_distract-50, num_distract+50].

        Returns:
            The training samples (also written to disk with the val split).

        Raises:
            ValueError: if ``refer_type`` is not one of the supported values.
    """
    if refer_type not in ('begin_end', 'anchor'):
        # Fail early with a clear message instead of hitting an unbound name later.
        raise ValueError(f"refer_type must be 'begin_end' or 'anchor', got {refer_type!r}")

    option_chars = ['(A) ', '(B) ', '(C) ', '(D) ', '(E) ', '(F) ']
    train_qas = []  # train_qas: qa data in the training format
    for example in tqdm(qa_datas.values()):
        if not example:
            # generate_refer_qa stores None for items whose generation failed.
            continue
        for caps, qa in zip(example['captions'].values(), example['qas'].values()):

            new_qas = [{'caption': cap, 'answer': ans} for cap, ans in zip(caps, qa['answers'].values())]
            random.shuffle(new_qas)
            caps = [item['caption'] for item in new_qas]
            if refer_type=='anchor':
                long_caption_str, anchor_phrase, anchor_count = sample_distract_captions(get_nouns(' '.join(caps)), caps, caption_pool, random.randint(num_distract-50, num_distract+50), add_anchor=True, insert_style='separate')
            else:
                long_caption_str = sample_distract_captions(get_nouns(' '.join(caps)), caps, caption_pool, random.randint(num_distract-50, num_distract+50), insert_style='determined')

            random.shuffle(qa['options'])
            opt_str = ""
            for opt, char in zip(qa['options'], option_chars):
                for item in new_qas:
                    if item['answer']==opt:
                        item['answer_str'] = char + opt
                opt_str += char + opt + '\n'

            instruct = random.choice(task_instructs)
            ans_prompt_mc = random.choice(answer_prompts['multi-choice'])
            if refer_type=='begin_end':
                positions = ['begin', 'middle', 'end']
            else:
                positions = [anchor_count[i] for i in range(len(caps))]
                # The final caption may be referred to as the "last" appearance,
                # but only when no anchored distractor follows it.
                positions[-1] = 'last' if positions[-1]==anchor_count['total'] and random.random()<0.5 else positions[-1]
            for item, pos in zip(new_qas, positions):
                pos_str = random.choice(pos_strs[pos])
                if refer_type=='anchor':
                    pos_str = pos_str.replace('[anchor]', anchor_phrase)
                item['question'] = f"{qa['question'].replace('?', ' '+pos_str+'?')}\n{opt_str}"
                if not 'answer_str' in item:
                    print("answer not in option")
                    continue
                train_qas.append([{"from": "human", "value": f"{instruct}\n\nCaptions:\n{long_caption_str}\nQuestion: {item['question']}\n{ans_prompt_mc}"}, {"from": "gpt", "value": item['answer_str']}])
    random.shuffle(train_qas)
    # Explicit cut index: slicing with [:-num_val] would empty the training
    # split when num_val is 0.
    cut = max(len(train_qas) - num_val, 0)
    train_qas, val_qas = train_qas[:cut], train_qas[cut:]
    print(len(train_qas), len(val_qas))
    save_json(train_qas, f"{output_path}/temp_qa_{aspect}_{refer_type}_train.json")
    save_json(val_qas, f"{output_path}/temp_qa_{aspect}_{refer_type}_val.json")
    return train_qas

def statement2dict(question, answers, new_statements):
    """Map '<question without ?>-<answer>' keys to the generated statements.

    ``new_statements`` is expected to hold one '<label>: <statement>' line per
    answer, in the same order as ``answers``. Blank lines are ignored so that
    extra spacing in the model reply does not shift the pairing, and only the
    first ': ' separates label from statement, so statements may themselves
    contain ': '.

    Raises:
        IndexError: if a non-blank line has no ': ' separator.
    """
    lines = [line for line in new_statements.split('\n') if line.strip()]
    question_key = question.strip('?')
    dict_statement = {}
    for ans, line in zip(answers, lines):
        # maxsplit=1 keeps any later ': ' inside the statement text.
        dict_statement[f"{question_key}-{ans}"] = line.split(': ', 1)[1].strip()
    return dict_statement

# Building blocks for "when does X happen" questions: a random prefix and a
# random suffix are concatenated directly, and "[statement]" in the suffix is
# replaced with the statement text. Every prefix therefore has to end with a
# space, otherwise the two halves are glued together.
any2temp_questions = {
    "prefix": [
        "In which part of the video ",
        "During which part of the video ",
        "During which segment of the video ",
        "Where in the video ",
        "At which moment in the video ",
        "In what section of the video ",
    ],
    "suffix": [
        "can we observe [statement]?",
        "does [statement] happen?",
        "can [statement] be observed?",
        "the statement \"[statement]\" is true?",
    ]
}

def refer_any2temp(output_path, refer_type, num_val=500):
    """Turn "what happens at position P" samples into "at which position does S happen" samples.

    Step 1: for every (question, answer) of '<output_path>/temp_qa_refer.json'
    a declarative statement is requested through get_llm_output and cached in
    'temp_qa_refer_statements.json' under the key '<question without ?>-<answer>'.

    Step 2: every sample of the '<refer_type>_temp2any' train/val files is
    rewritten: the position phrase found in its question becomes the correct
    option among shuffled begin/middle/end phrases, and the question asks where
    the statement happens. Only the begin/middle/end phrases are recognised.
    Samples without a recognisable position phrase or without a cached
    statement are skipped.

    Args:
        output_path: directory holding the input files and receiving the outputs.
        refer_type: tag used in the input/output file names.
        num_val: number of samples (taken from the end) written to the val file.
    """
    src_qas = load_json(f"{output_path}/temp_qa_refer.json")
    train_qa_path, val_qa_path = f"{output_path}/temp_qa_refer_{refer_type}_temp2any_train.json", f"{output_path}/temp_qa_refer_{refer_type}_temp2any_val.json"
    src_qas_long = load_json(train_qa_path)
    src_qas_long += load_json(val_qa_path)

    output_qa_path, output_statement_path = f"{output_path}/temp_qa_refer_{refer_type}_any2temp.json", f"{output_path}/temp_qa_refer_statements.json"
    if os.path.exists(output_qa_path):
        any2temp_qas = load_json(output_qa_path)
    else:
        any2temp_qas = {}
    if os.path.exists(output_statement_path):
        statements = load_json(output_statement_path)
    else:
        statements = {}
    for sample in tqdm(src_qas.values()):
        if not sample:
            # Failed generations are stored as None upstream.
            continue
        for qa in sample['qas'].values():
            question, answers = qa['question'], list(qa['answers'].values())
            if all(f"{question.strip('?')}-{ans}" in statements for ans in answers):
                continue
            prompt = any2temp_prompt + f"Question: {question}\n"
            for aid, ans in enumerate(answers):
                prompt += f"Answer {aid+1}: {ans}\n"
            maxtry = 10
            while True:
                try:
                    new_statements = get_llm_output(prompt, response_format='text', model='gpt-3.5-turbo-0125')
                    new_statements = statement2dict(question, answers, new_statements)
                    statements.update(new_statements)
                    time.sleep(1)
                    break
                except Exception as err:  # not a bare except: Ctrl-C must still stop the run
                    print(f"Request failed: {err!r}")
                    if maxtry<=0:
                        break
                    maxtry -= 1
                    print(f"Not success! {maxtry} retries remaining...")
                    time.sleep(10)
    save_json(statements, output_statement_path)

    option_chars = ['(A) ', '(B) ', '(C) ', '(D) ', '(E) ', '(F) ']
    for sid, sample in enumerate(tqdm(src_qas_long)):
        if str(sid) in any2temp_qas:
            continue
        context, question = sample[0]['value'].split('\nQuestion: ')
        ans_prompt = sample[0]['value'].split('\n\n')[-1]
        question = question.split('\n')[0]
        # Drop only the leading "(X) " label; the answer text may contain ') ' itself.
        answer = sample[1]['value'].split(') ', 1)[1]

        # Reset per sample so that a miss can never reuse an earlier sample's values.
        statement = None
        for pos_str in pos_str_list:
            if pos_str in question:
                key = f"{question.split(pos_str)[0].strip()}-{answer}"
                statement = statements.get(key)
                break

        new_answer = None
        options = ['begin', 'middle', 'end']
        for pos in options:
            if any(pos_str in question for pos_str in pos_strs[pos]):
                new_answer = pos
        if statement is None or new_answer is None:
            print(f"skip sample {sid}: no position phrase or no statement found")
            continue

        opt_str = ""
        random.shuffle(options)
        for opt, char in zip(options, option_chars):
            cur_opt_str = random.choice(pos_strs[opt])
            if new_answer==opt:
                answer_str = char + cur_opt_str
            opt_str += char + cur_opt_str + '\n'

        prefix, suffix = random.choice(any2temp_questions['prefix']), random.choice(any2temp_questions["suffix"]).replace('[statement]', statement.lower())
        new_question = f"{context}\nQuestion: {prefix}{suffix}\n{opt_str}\n{ans_prompt}"
        new_sample = [{'from': 'human', 'value': new_question}, {'from': 'gpt', 'value': answer_str}]
        # String keys, matching both the resume check above and what JSON stores.
        any2temp_qas[str(sid)] = new_sample
    save_json(any2temp_qas, output_qa_path)

    all_samples = list(any2temp_qas.values())
    # Explicit cut index: slicing with [:-num_val] would empty the training
    # split when num_val is 0.
    cut = max(len(all_samples) - num_val, 0)
    new_samples_train, new_samples_val = all_samples[:cut], all_samples[cut:]
    output_train_path, output_val_path = f"{output_path}/temp_qa_refer_{refer_type}_any2temp_train.json", f"{output_path}/temp_qa_refer_{refer_type}_any2temp_val.json"
    save_json(new_samples_train, output_train_path)
    save_json(new_samples_val, output_val_path)

def build_caption_pool(captions, num_caps=500000, noun_type='word'):
    """
    Build a pool of single-sentence captions annotated with the nouns they contain.

    Every caption is truncated at its first '.', (a '.' is re-appended), the
    truncated captions are shuffled with the module-level `random` state, and at
    most `num_caps` of them are kept. Captions for which no nouns are found are
    left out of the pool.

    Args:
        captions: iterable of caption strings. The input is not modified.
        num_caps: maximum number of (shuffled) captions to consider.
        noun_type: 'word' extracts nouns with `get_nouns`; any other value
            (intended: 'phrase') uses `extract_noun_phrases(cap, random_sample=False)`.

    Returns:
        dict mapping the caption's index in the shuffled, truncated list to
        {'caption': <first sentence>, 'nouns': <extracted nouns>}. Indices of
        skipped captions are absent, so the keys are not necessarily contiguous.
    """
    # Read from the `captions` argument (the previous code used the module-level
    # `src_captions`, which only exists when the file is executed as a script).
    first_sentences = [cap.split('.')[0] + '.' for cap in captions]
    random.shuffle(first_sentences)
    first_sentences = first_sentences[:num_caps]

    caption_pool = {}
    for idx, cap in enumerate(tqdm(first_sentences)):
        if noun_type == 'word':
            nouns = get_nouns(cap)
        else:
            nouns = extract_noun_phrases(cap, random_sample=False)
        if not nouns:
            # Nothing to match distractors against; skip this caption.
            continue
        caption_pool[idx] = {'caption': cap, 'nouns': nouns}
    return caption_pool
    

if __name__ == '__main__':
    # Script entry point: load the source captions, build the caption pool and
    # run the generation pipeline selected by --aspect.
    parser = argparse.ArgumentParser()
    parser.add_argument('--src_path', default='path_to/LLaVA-ReCap-558K.json')
    parser.add_argument('--tgt_path', default='outputs')
    # 'order_gpt' is the name the dispatch below historically tested for, but it was
    # missing from `choices`, which made that branch unreachable and turned
    # `--aspect order` into a silent no-op. Both spellings are accepted now.
    parser.add_argument('--aspect', default='order_template', choices=['order', 'order_gpt', 'attribute', 'refer', 'order_template'])
    parser.add_argument('--num_distract', default=100, type=int, help="number of distractor captions for building long-context")
    parser.add_argument('--num_combine', default=2000, type=int, help="number of caption combinations")
    args = parser.parse_args()

    # Fixed seed so that shuffling/sampling below is reproducible across runs.
    random.seed(42)
    # combine image captions as context
    combined_cap_path = f"{args.tgt_path}/{os.path.basename(args.src_path).replace('.json', '')}-Combine-{args.aspect}.json"
    src_datas = load_json(args.src_path)
    src_captions = [d['conversations'][1]['value'] for d in src_datas]
    caption_pool = build_caption_pool(src_captions, noun_type='phrase' if args.aspect=='order_template' else 'word')
    if args.aspect in ('order', 'order_gpt'):
        combined_cap_first_sent_path = combined_cap_path.replace('.json', '')+'-first-sent.json'
        # NOTE(review): existence is checked on `combined_cap_path` but the cached data
        # is loaded from the "-first-sent" file; kept as in the original - confirm intent.
        if not os.path.exists(combined_cap_path):
            combined_captions = combine_captions(args, src_captions, combined_cap_path)
        else:
            combined_captions = load_json(combined_cap_first_sent_path)

        qa_datas = generate_order_qa(combined_captions, combined_cap_first_sent_path)
        process_order_qa_dataset(qa_datas, caption_pool, args.tgt_path, args.aspect, num_distract=args.num_distract)
    elif args.aspect=='order_template':
        # These values are used instead of args.num_distract in this branch.
        num_distract, num_samples, num_val = 200, 30000, 500
        # Surface-form variations that are sampled independently for every generated sample.
        order_template_setting = {
            "num_tgt_caps": [3, 4, 5, 6],
            "what_to_order": ['sentence', 'phrase'],
            "prefixs": [
                ['(a)', '(b)', '(c)', '(d)', '(e)', '(f)'],
                ['(1)', '(2)', '(3)', '(4)', '(5)', '(6)'],
                ['①', '②', '③', '④', '⑤', '⑥']
            ],
            "opt_chars": [
                ['(A) ', '(B) ', '(C) ', '(D) ', '(E) ', '(F) '],
                ['A. ', 'B. ', 'C. ', 'D. ', 'E. ', 'F. '],
            ],
            "opt_connector": [' -> ', ' --> ', ','],
            "shuf_cap_connector": ['\n', ' ', '; '],
            "question": {
                "sentence": [
                    "Arrange the following captions in the correct chronological order as they appear in the video.",
                    "Place the following captions in the order they are shown in the video.",
                    "Sort the following captions into their chronological order as per the video.",
                    "Rearrange the following captions to match the order they appear in the video.",
                    "Reorder the following captions according to the above video.",
                    "Order the following captions in the same sequence as they unfold in the video.",
                    "List the following captions in the chronological progression they appear in the video.",
                ],
                "prefix_phrase": [
                    "Sort these events from the video by their chronological order.",
                    "Organize the listed items from the video according to their time sequence: ",
                    "Arrange the following events according to the correct order that they were shown in the video: ",
                    "In which order are the following items appear in the video?",
                    "What is the chronological order of the following events according to the above video?",
                    "According to the video, in which order do the following events happen?"
                ]
            }
        }

        # One train/val file pair is written per output type:
        #   'prefix'   - multiple choice over orderings of the labels given to the shuffled items
        #   'phrase'   - multiple choice over orderings of noun phrases
        #   'sentence' - free-form answer listing the captions in their original order
        for output_type in ['prefix', 'phrase', 'sentence']:
            samples = []
            for _ in tqdm(range(num_samples)):
                instruct = random.choice(task_instructs)
                num_tgt_cap = random.choice(order_template_setting['num_tgt_caps'])
                # NOTE(review): the value list is rebuilt on every iteration; it could be
                # hoisted if sample_distract_captions is confirmed not to mutate caption_pool.
                tgt_caps = random.sample(list(caption_pool.values()), num_tgt_cap)
                tgt_cap_nouns = []
                for cap in tgt_caps:
                    tgt_cap_nouns += cap['nouns']
                # Long context: the target captions together with a random number
                # (num_distract +/- 50) of distractor captions.
                cap_str = sample_distract_captions(set(tgt_cap_nouns), [cap['caption'] for cap in tgt_caps], caption_pool,
                                                random.randint(num_distract-50, num_distract+50), insert_style='separate')
                if output_type in ['prefix', 'phrase']:
                    question = random.choice(order_template_setting['question']['prefix_phrase'])
                    ans_prompt = random.choice(answer_prompts['multi-choice'])
                    # Remember each caption's original position before shuffling.
                    tgt_caps = [(i, cap) for i, cap in enumerate(tgt_caps)]
                    random.shuffle(tgt_caps)
                    if output_type=='prefix':
                        opt_connector = random.choice(order_template_setting['opt_connector']+['', ' '])
                        prefixs = random.choice(order_template_setting['prefixs'])
                        shuf_cap_connector = random.choice(order_template_setting['shuf_cap_connector'])
                        what_to_order = random.choice(order_template_setting['what_to_order'])
                        correct_order = {}
                        shuffled_cap_str = ""
                        for prefix, cap in zip(prefixs, tgt_caps):
                            identifier = cap[1]['caption'] if what_to_order=='sentence' else random.choice(cap[1]['nouns'])
                            shuffled_cap_str += f"{prefix} {identifier}{shuf_cap_connector}"
                            # original position -> label shown next to the shuffled item
                            correct_order[cap[0]] = prefix
                    elif output_type=='phrase':
                        opt_connector = random.choice(order_template_setting['opt_connector'])
                        correct_order = {}
                        shuffled_cap_str = ""
                        for cap in tgt_caps:
                            identifier = random.choice(cap[1]['nouns'])
                            shuffled_cap_str += f"{identifier}\n"
                            # original position -> noun phrase standing in for the caption
                            correct_order[cap[0]] = identifier

                    # Labels/phrases sorted by original position give the gold ordering.
                    correct_order = list(dict(sorted(correct_order.items())).values())
                    permutations = list(itertools.permutations(correct_order))
                    random.shuffle(permutations)
                    answer = opt_connector.join(correct_order)
                    # Take up to three wrong orderings. Candidates equal to the answer or to an
                    # already chosen distractor are skipped; previously the first three shuffled
                    # permutations were used as-is, so the correct ordering could show up twice
                    # among the options and make the gold label ambiguous.
                    distractors = []
                    for permut in permutations:
                        candidate = opt_connector.join(permut)
                        if candidate != answer and candidate not in distractors:
                            distractors.append(candidate)
                            if len(distractors) == 3:
                                break
                    options = [answer] + distractors
                    random.shuffle(options)
                    opt_str = ""
                    for opt, char in zip(options, random.choice(order_template_setting['opt_chars'])[:len(options)]):
                        if answer==opt:
                            answer_str = char + opt
                        opt_str += char + opt + '\n'
                    sample = [{"from": "human", "value": f"{instruct}\n\nCaptions:\n{cap_str}\nQuestion: {question}\n{shuffled_cap_str}\n{opt_str}\n{ans_prompt}"},
                            {"from": "gpt", "value": answer_str}]
                elif output_type=='sentence':
                    question = random.choice(order_template_setting['question']['sentence'])
                    # The answer is the captions in sampled order; the question lists them shuffled.
                    answer_str = '\n'.join([cap['caption'] for cap in tgt_caps])
                    random.shuffle(tgt_caps)
                    shuffled_cap_str = "\n".join([cap['caption'] for cap in tgt_caps])
                    sample = [{"from": "human", "value": f"{instruct}\n\nCaptions:\n{cap_str}\n\nQuestion: {question}\n{shuffled_cap_str}"},
                            {"from": "gpt", "value": answer_str}]
                samples.append(sample)
            # The last num_val samples are held out for validation.
            samples_train, samples_val = samples[:-num_val], samples[-num_val:]
            save_json(samples_train, f"{args.tgt_path}/temp_qa_order_shuffle_{output_type}_train.json")
            save_json(samples_val, f"{args.tgt_path}/temp_qa_order_shuffle_{output_type}_val.json")
    elif args.aspect=='attribute':
        combined_captions = generate_attribute_captions(src_captions, combined_cap_path)
        qa_datas = generate_attribute_qa(combined_captions, combined_cap_path)
        process_attribute_qa_dataset(qa_datas, caption_pool, args.tgt_path, args.aspect, num_distract=args.num_distract)
    elif args.aspect=='refer':
        output_caption_path, output_qa_path = f"{args.tgt_path}/counterfactual_captions.json", f"{args.tgt_path}/temp_qa_refer.json"
        combined_captions = generate_counterfactual_captions(src_captions, output_path=output_caption_path)
        qa_datas = generate_refer_qa(combined_captions, output_qa_path)
        process_refer_qa_dataset(qa_datas, caption_pool, args.tgt_path, args.aspect, num_distract=args.num_distract, refer_type='begin_end')
        refer_any2temp(output_path=args.tgt_path, refer_type='begin_end')