import os
import argparse
import json
import os
import sys
# Resolve the directory three levels above this file so that the `llava`
# package imported below can be found when the script is run directly.
_script_dir = os.path.dirname(__file__)
llava_path = os.path.dirname(os.path.dirname(_script_dir))
if sys.path.count(llava_path) == 0:
    # Only extend (and echo) the search path the first time it is needed.
    sys.path.append(llava_path)
    print(sys.path)
from llava.eval.m4c_evaluator import EvalAIAnswerProcessor
from azfuse import File
from tqdm import tqdm

# def parse_args():
#     parser = argparse.ArgumentParser()
#     parser.add_argument('--dir', type=str, default="./playground/data/eval/vqav2")
#     parser.add_argument('--ckpt', type=str, required=True)
#     parser.add_argument('--split', type=str, required=True)
#     return parser.parse_args()


def main(src: str, dst: str, ann: str) -> None:
    """Convert a JSONL file of model answers into a submission JSON file.

    Each line of ``src`` is parsed as JSON; a line that decodes to a string
    is decoded a second time (double-encoded records). Lines that cannot be
    parsed are skipped and counted. Every question listed in ``ann`` then
    gets one output entry: the model's ``text`` passed through
    ``EvalAIAnswerProcessor`` when an answer exists, or an empty string
    (with the question id printed) when it does not.

    Args:
        src: Path to the merged answers file, one JSON record per line, each
            with ``question_id`` and ``text`` keys.
        dst: Path of the JSON file to write; its parent directory is created
            if it does not exist.
        ann: Path to the split file, one JSON record per line, each with a
            ``question_id`` key.

    Raises:
        KeyError: If a parsed record lacks ``question_id`` (or ``text`` for
            answer records).
    """
    # A bare filename has an empty dirname, and os.makedirs('') raises
    # FileNotFoundError, so only create the directory when there is one.
    dst_dir = os.path.dirname(dst)
    if dst_dir:
        os.makedirs(dst_dir, exist_ok=True)

    records = []
    error_line = 0
    with File.open(src) as f:
        lines = f.readlines()
    for line in lines:
        try:
            record = json.loads(line.strip())
            if isinstance(record, str):
                # Some writers double-encode the record; unwrap it once more.
                record = json.loads(record)
        except ValueError:
            # json.JSONDecodeError is a ValueError. Malformed lines are
            # tolerated and only counted, but interrupts are not swallowed.
            error_line += 1
        else:
            records.append(record)

    # Later duplicates of a question_id overwrite earlier ones.
    results = {x['question_id']: x['text'] for x in records}

    # Use a context manager so the annotation file handle is closed.
    with File.open(ann) as f:
        test_split = [json.loads(line) for line in f]

    print(f'total results: {len(results)}, total split: {len(test_split)}, error_line: {error_line}')

    all_answers = []

    answer_processor = EvalAIAnswerProcessor()

    for x in tqdm(test_split):
        question_id = x['question_id']
        if question_id not in results:
            # Report unanswered questions and submit an empty answer so the
            # output still has one entry per question in the split.
            print(question_id)
            answer = ''
        else:
            answer = answer_processor(results[question_id])
        all_answers.append({
            'question_id': question_id,
            'answer': answer
        })

    with File.open(dst, 'w') as f:
        json.dump(all_answers, f)


if __name__ == '__main__':
    # Expose main()'s parameters (src, dst, ann) as command-line arguments.
    import fire
    fire.Fire(main)
    # args = parse_args()


