import json
from argparse import ArgumentParser
from tqdm import tqdm


# Evaluate POPE-style yes/no answers against ground-truth labels.
#
# Inputs (JSON Lines, one object per line):
#   <ans_file>.json        objects with an 'answer' key (free text or null)
#   <ans_file>_label.json  objects with a 'label' key ('no' -> negative,
#                          anything else -> positive)
# Output: confusion counts and metrics printed to stdout, and the metrics
# written to <out_file>.txt.

# Integer codes for predictions / labels.
POS = 1
NEG = 0
NO_ANSWER = -1  # prediction code for a missing (null) model reply


def load_jsonl(path):
    """Parse a JSON Lines file into a list of objects.

    Blank lines (e.g. a trailing newline) are skipped. The file is read as
    UTF-8 and closed before returning.
    """
    with open(path, 'r', encoding='utf-8') as f:
        return [json.loads(line) for line in f if line.strip()]


def normalize_answer(text):
    """Map a free-form model reply to 'yes', 'no', or 'wrong'.

    A ``None`` reply is 'wrong' (no answer). Otherwise only the text before
    the first '.' is inspected: commas are removed, the rest is split on
    single spaces, and the reply is 'no' if any token is exactly 'No',
    'no' or 'not'; every other reply is 'yes'.
    """
    if text is None:
        return 'wrong'
    # split('.')[0] returns the whole string when there is no '.'.
    first_sentence = text.split('.')[0].replace(',', '')
    words = first_sentence.split(' ')
    # NOTE(review): matching is case-sensitive on purpose (e.g. 'NO' -> 'yes')
    # to stay comparable with the reference POPE evaluation; confirm before
    # changing.
    if 'No' in words or 'not' in words or 'no' in words:
        return 'no'
    return 'yes'


def answer_to_pred(text):
    """Convert a raw reply into NO_ANSWER (-1), NEG (0) or POS (1)."""
    verdict = normalize_answer(text)
    if verdict == 'wrong':
        return NO_ANSWER
    if verdict == 'no':
        return NEG
    return POS


def label_to_int(label):
    """Map a ground-truth label to NEG for exactly 'no', POS otherwise."""
    return NEG if label == 'no' else POS


def count_outcomes(pairs):
    """Tally (pred, label) pairs into a confusion matrix.

    Returns:
        Tuple ``(TP, FP, TN, FN, no_answer)``. Pairs whose prediction is
        NO_ANSWER are counted only in ``no_answer``.
    """
    tp = fp = tn = fn = no_answer = 0
    for pred, label in pairs:
        if pred == NO_ANSWER:
            no_answer += 1
        elif pred == POS and label == POS:
            tp += 1
        elif pred == POS and label == NEG:
            fp += 1
        elif pred == NEG and label == NEG:
            tn += 1
        elif pred == NEG and label == POS:
            fn += 1
    return tp, fp, tn, fn, no_answer


def safe_div(num, den):
    """Return ``num / den``, or 0.0 when the denominator is zero."""
    return num / den if den else 0.0


def compute_metrics(tp, fp, tn, fn, no_answer):
    """Compute (accuracy, precision, recall, F1) from confusion counts.

    Unanswered items count against accuracy (they are in the denominator
    but never correct). Any ratio with a zero denominator is reported as
    0.0 instead of raising ZeroDivisionError.
    """
    precision = safe_div(tp, tp + fp)
    recall = safe_div(tp, tp + fn)
    f1 = safe_div(2 * precision * recall, precision + recall)
    acc = safe_div(tp + tn, tp + tn + fp + fn + no_answer)
    return acc, precision, recall, f1


def main():
    """Parse CLI arguments, evaluate the answers, and report the metrics.

    Raises:
        ValueError: if the answer and label files hold different numbers
            of records (they are paired line by line).
    """
    parser = ArgumentParser()
    parser.add_argument('--ans_file', type=str, default='pope_results/llava_based_pope')
    parser.add_argument('--out_file', type=str, default='pope_results/llava_based_pope_result')
    parser.add_argument('--model', type=str, default='llava')
    args = parser.parse_args()

    ans_file = f'{args.ans_file}.json'
    label_file = f'{args.ans_file}_label.json'

    answers = load_jsonl(ans_file)
    labels = load_jsonl(label_file)
    # Answers and labels are paired by position; zip() would silently drop
    # the surplus of the longer file, so refuse to compute partial metrics.
    if len(answers) != len(labels):
        raise ValueError(
            f'{ans_file} has {len(answers)} records but {label_file} '
            f'has {len(labels)}'
        )

    pred_list = [answer_to_pred(answer['answer']) for answer in answers]
    label_list = [label_to_int(record['label']) for record in labels]
    yes_ratio = safe_div(pred_list.count(POS), len(pred_list))

    tp, fp, tn, fn, no_answer = count_outcomes(
        tqdm(zip(pred_list, label_list), total=len(pred_list))
    )

    print('TP\tFP\tTN\tFN\tNOANSWER')
    print('{}\t{}\t{}\t{}\t{}'.format(tp, fp, tn, fn, no_answer))

    acc, precision, recall, f1 = compute_metrics(tp, fp, tn, fn, no_answer)
    print('Accuracy: {}'.format(acc))
    print('Precision: {}'.format(precision))
    print('Recall: {}'.format(recall))
    print('F1 score: {}'.format(f1))
    print('Yes ratio: {}'.format(yes_ratio))

    s = f'Accuracy: {acc}\nPrecision: {precision}\nRecall: {recall}\nF1 score: {f1}\nYes ratio: {yes_ratio}\n'

    with open(f"{args.out_file}.txt", "w", encoding="utf-8") as file:
        file.write(s)


if __name__ == '__main__':
    main()
