# Example input (answer file): ../POPE/llava_qa/answer/I4_7b_cfg0.0.jsonl
import json
import os
# load
# set args
import argparse
def get_question_file(answer_file):
    """Derive the question file name from an answer file name.

    The prefix before the first recognised marker (``_7b``, ``_13b`` or
    ``_cfg``, checked in that order) is taken and ``.json`` is appended,
    e.g. ``'I4_7b_cfg0.0.jsonl'`` -> ``'I4.json'``.

    Args:
        answer_file: Name of the answer file.

    Returns:
        The derived question file name.

    Raises:
        ValueError: If ``answer_file`` contains none of the markers.
    """
    # Model-size markers are tried before '_cfg', mirroring the original
    # priority. The underscore is part of the marker so that the membership
    # test and the split agree with each other.
    for marker in ('_7b', '_13b', '_cfg'):
        if marker in answer_file:
            return answer_file.split(marker)[0] + '.json'
    raise ValueError(
        'question_file is None (no model size or _cfg marker in {!r})'.format(answer_file)
    )

def qid2img_id(answer_ls, answer_file, question_file="I3.json",
               answer_dir='../POPE/llava_qa/answer',
               question_dir='../POPE/llava_qa/question'):
    """Load answers from ``answer_file`` and attach ``image_id`` to each one.

    Args:
        answer_ls: Kept for backward compatibility; its value is not used
            because the answers are always re-read from ``answer_file``.
        answer_file: Name of the JSONL answer file inside ``answer_dir``.
        question_file: Name of the question file inside ``question_dir``.
        answer_dir: Directory containing the answer file.
        question_dir: Directory containing the question file.

    Returns:
        The list of answer dicts decoded from ``answer_file``, as returned by
        ``match_qa_image_id``.
    """
    ## load data
    # The question list is the JSON value stored on the first line of the file.
    with open(os.path.join(question_dir, question_file), 'r') as f:
        question_ls = [json.loads(line) for line in f if line.strip()][0]

    # One JSON object per line; blank lines are skipped.
    with open(os.path.join(answer_dir, answer_file), 'r') as f:
        answer_ls = [json.loads(line) for line in f if line.strip()]

    # match_qa_image_id takes the questions first, then the answers.
    answer_ls = match_qa_image_id(question_ls, answer_ls)
    return answer_ls

def load_question_ls(question_dir='../POPE/llava_qa/question', question_file='I3.json'):
    """Load the question list from a question file.

    Every non-blank line of the file is decoded as JSON and the value decoded
    from the first such line is returned.

    Args:
        question_dir: Directory containing the question file.
        question_file: Name of the question file inside ``question_dir``.

    Returns:
        The JSON value stored on the first non-blank line of the file.

    Raises:
        IndexError: If the file contains no non-blank line.
    """
    ## load data
    with open(os.path.join(question_dir, question_file), 'r') as f:
        decoded_lines = [json.loads(line) for line in f if line.strip()]
    return decoded_lines[0]

def match_qa_image_id(question_ls, answer_ls):
    """Add ``'image_id'`` to each answer according to its ``'question_id'``.

    Args:
        question_ls: List of question dicts; each provides ``'id'``, and the
            matched ones provide ``'image'``.
        answer_ls: List of answer dicts, each providing ``'question_id'``.
            The dicts are modified in place.

    Returns:
        The same ``answer_ls`` object. Answers whose ``'question_id'`` matches
        no question are left without an ``'image_id'`` key.
    """
    # Index the questions once. setdefault keeps the first question for a
    # repeated id, which is the one a front-to-back scan would find.
    question_by_id = {}
    for question in question_ls:
        question_by_id.setdefault(question['id'], question)

    for answer in answer_ls:
        matched = question_by_id.get(answer['question_id'])
        if matched is not None:
            # 'image' is read only for matched questions.
            answer['image_id'] = matched['image']
    return answer_ls
if __name__ == '__main__':
    # Command-line entry point: read a question file and a JSONL answer file,
    # add 'image_id' to every answer whose 'question_id' matches a question
    # 'id', and write the result as JSONL into --save_dir.
    parser = argparse.ArgumentParser()
    parser.add_argument('--answer_dir', type=str, default='../POPE/llava_qa/answer')
    parser.add_argument('--question_dir', type=str, default='../POPE/llava_qa/question')
    parser.add_argument('--save_dir', type=str, default='../POPE/llava_qa/answer_v3')
    parser.add_argument('--question_file', type=str, default='I3.json')
    parser.add_argument('--answer_file', type=str, default='I3_7b_cfg0.0.jsonl')
    args = parser.parse_args()

    # NOTE: with the current default ('I3.json') this branch is never taken;
    # it only applies if the --question_file default is changed to None.
    if args.question_file is None:
        args.question_file = get_question_file(args.answer_file)

    ## load data
    # The question list is the JSON value stored on the first line of the file.
    with open(os.path.join(args.question_dir, args.question_file), 'r') as f:
        question_ls = [json.loads(line) for line in f if line.strip()][0]
    # One JSON object per line; blank lines (e.g. a trailing one) are skipped.
    with open(os.path.join(args.answer_dir, args.answer_file), 'r') as f:
        answer_ls = [json.loads(line) for line in f if line.strip()]

    ## add 'image_id' to each answer according to 'question_id'
    # Index the questions once. setdefault keeps the first question for a
    # repeated id, which is the one a front-to-back scan would find.
    question_by_id = {}
    for question in question_ls:
        question_by_id.setdefault(question['id'], question)

    for answer in answer_ls:
        matched = question_by_id.get(answer['question_id'])
        if matched is not None:
            answer['image_id'] = matched['image']

    ## save
    # make dir
    os.makedirs(args.save_dir, exist_ok=True)
    # Strip only the final extension: splitting on the first '.' would turn
    # 'I3_7b_cfg0.0.jsonl' into 'I3_7b_cfg0', making e.g. cfg1.1 and cfg1.5
    # collide on the same output file.
    answer_stem = os.path.splitext(args.answer_file)[0]
    save_path = os.path.join(args.save_dir, answer_stem + '_v2.jsonl')
    with open(save_path, 'w') as f:
        for answer in answer_ls:
            json.dump(answer, f)
            f.write('\n')
    print('save to {}'.format(save_path))