import json
import re
# Build the LLaVA-Bench style JSON-lines files (question, context, answer1, answer2) from the VEGA result file.
def extract_and_remove_pictures(text):
    """Split ``[Picture N]`` markers out of *text*.

    Args:
        text: String that may contain zero or more ``[Picture N]`` markers,
            where N is a run of digits.

    Returns:
        A ``(numbers, cleaned_text)`` tuple: ``numbers`` is the list of N
        values as ints, in order of appearance; ``cleaned_text`` is *text*
        with every marker deleted (surrounding whitespace is left as is).
    """
    marker = re.compile(r'\[Picture (\d+)\]')
    picture_ids = [int(match.group(1)) for match in marker.finditer(text)]
    stripped = marker.sub('', text)
    return picture_ids, stripped


# Model result file, read as one JSON object per line.
in_result_json = './test_result/VEGA_result.json'

#SciGraphQA validation json path
scigraph_json = './SciGraphQA/SciCapQA-test-with-deplot/data/1_percent_as_validation-00000-of-00001-3d2fb4f89a9cda25.json'

# Output files, written as one JSON object per line.
# NOTE(review): they are opened in append mode, so running the script twice
# appends a second copy of every record with ids restarting at 0 - confirm
# this is intended or clear the files before a re-run.
question_path = './llava_bench/Qwen_VEGA8k_4k2/question.json'
context_path = './llava_bench/Qwen_VEGA8k_4k2/context.json'
answer1_path = './llava_bench/Qwen_VEGA8k_4k2/answer1.json'
answer2_path = './llava_bench/Qwen_VEGA8k_4k2/answer2.json'

# Running record id; it is also used to build the fallback image name.
idd = 0

# Read both inputs up front. Blank lines are skipped so that a stray empty
# line does not abort the whole run with a JSONDecodeError.
with open(in_result_json,'r',encoding='utf-8') as inp, open(scigraph_json,'r',encoding='utf-8') as sci:
    inplines = [line for line in inp if line.strip()]
    # Map id -> first_mention, built in one pass so each result row is a
    # constant-time lookup. setdefault keeps the FIRST record of a duplicated id.
    first_mention_by_id = {}
    for sciline in sci:
        if not sciline.strip():
            continue
        record = json.loads(sciline)
        first_mention_by_id.setdefault(record['id'], record['first_mention'])

# Open every output once instead of once per record.
with open(question_path,'a',encoding='utf-8') as question_out, \
        open(context_path,'a',encoding='utf-8') as context_out, \
        open(answer1_path,'a',encoding='utf-8') as answer1_out, \
        open(answer2_path,'a',encoding='utf-8') as answer2_out:
    for line in inplines:
        data = json.loads(line)

        # Image of the ground-truth figure; fall back to a synthetic name when
        # the path list or the index is missing or unusable.
        try:
            image = data['image_paths'][int(data['truth_fig_idx'])]
        except (KeyError, IndexError, TypeError, ValueError):
            image = f'{idd}.png'

        question = data['question']

        sample_id = data['id']
        if sample_id not in first_mention_by_id:
            raise ValueError(f'id {sample_id!r} from {in_result_json} not found in {scigraph_json}')
        fm = first_mention_by_id[sample_id]

        question_data = {"question_id": idd, "image": image, "text": question, "category": "llava_bench_complex"}
        context_data = {"id": idd, "image": image, "captions": [data['caption']],"first_mention":fm}

        # flag is 1 only when the response cites exactly one [Picture N] marker
        # and N equals truth_fig_idx + 1 (truth_fig_idx is used above as a
        # 0-based list index, so the marker is compared as 1-based).
        numbers, cleaned_text = extract_and_remove_pictures(data['response'])
        if len(numbers) == 1 and numbers[0] == int(data['truth_fig_idx']) + 1:
            flag = 1
        else:
            flag = 0

        answer1_data = {"question_id": idd, "text": cleaned_text, "category": "llava_bench_complex","flag":flag}
        # answer2 carries data['answer'] with a constant flag of 1
        # (presumably the reference answer - confirm against the dataset).
        answer2_data = {"question_id": idd, "text": data['answer'], "category": "llava_bench_complex","flag": 1}
        idd += 1

        # All four payloads are built before anything is written, so a bad
        # record never leaves a partially written row behind.
        for out, payload in (
            (question_out, question_data),
            (context_out, context_data),
            (answer1_out, answer1_data),
            (answer2_out, answer2_data),
        ):
            json.dump(payload,out,ensure_ascii=False)
            out.write('\n')