import os
import json
import argparse
from tqdm import tqdm

# Keys of the per-question transition counters. "A2B" means the first
# (--res-rg-files) answer fell in confusion-matrix cell A and the second
# (--res-cd-files) answer fell in cell B for the same question.
_TRANSITION_KEYS = (
    "TP2TP", "TP2FN",
    "FP2FP", "FP2TN",
    "TN2TN", "TN2FP",
    "FN2FN", "FN2TP",
)

_SEPARATOR = '--------------'


def load_jsonl(path):
    """Load a JSON-lines file into a list of parsed objects.

    Args:
        path: Path to the file; a leading ``~`` is expanded.

    Returns:
        A list with one parsed JSON value per non-blank line, in file order.
    """
    # The context manager closes the handle even if a line fails to parse.
    with open(os.path.expanduser(path), "r", encoding="utf-8") as handle:
        # Blank lines (e.g. a trailing newline at EOF) carry no record.
        return [json.loads(line) for line in handle if line.strip()]


def count_transitions(ref_files, res_rg_files, res_cd_files):
    """Count how each answer's confusion-matrix cell changes between two runs.

    The three sequences are compared position by position. For a reference
    label of ``yes`` an answer is a hit (TP) if it contains ``yes`` and a miss
    (FN) otherwise; for any other label an answer is a hit (TN) if it contains
    ``no`` and a miss (FP) otherwise. Matching is case-insensitive.

    NOTE(review): the checks are plain substring tests, so an answer such as
    "I don't know" also counts as containing "no". Kept as-is so the reported
    numbers do not change.

    Args:
        ref_files: Reference records with ``question_id`` and ``label`` keys.
        res_rg_files: First set of answers (``question_id`` and ``text`` keys).
        res_cd_files: Second set of answers (``question_id`` and ``text`` keys).

    Returns:
        A dict mapping every key in ``_TRANSITION_KEYS`` to its count.

    Raises:
        ValueError: If an answer list is shorter than the reference list, or
            if the ``question_id`` values at some position do not agree.
    """
    answer_sets = (("res-rg", res_rg_files), ("res-cd", res_cd_files))
    for name, answers in answer_sets:
        if len(answers) < len(ref_files):
            raise ValueError(
                f"{name} has {len(answers)} entries but the reference has "
                f"{len(ref_files)}"
            )

    counts = dict.fromkeys(_TRANSITION_KEYS, 0)

    for index, ref_entry in enumerate(ref_files):
        question_id = ref_entry['question_id']
        # Explicit check instead of `assert`, which is skipped under -O.
        for name, answers in answer_sets:
            if answers[index]['question_id'] != question_id:
                raise ValueError(
                    f"question_id mismatch at position {index}: reference has "
                    f"{question_id!r}, {name} has "
                    f"{answers[index]['question_id']!r}"
                )

        label = ref_entry['label'].lower().strip()
        first = res_rg_files[index]['text'].lower().strip()
        second = res_cd_files[index]['text'].lower().strip()

        if label == 'yes':
            expected, hit, miss = 'yes', 'TP', 'FN'
        else:
            expected, hit, miss = 'no', 'TN', 'FP'

        before = hit if expected in first else miss
        after = hit if expected in second else miss
        counts[f"{before}2{after}"] += 1

    return counts


def report_transitions(counts):
    """Print the aggregate flips followed by the eight transition counters.

    Args:
        counts: Mapping as returned by ``count_transitions``.
    """
    # Answers that moved from a positive cell to a negative one, and back.
    p2n = counts["FP2TN"] + counts["TP2FN"]
    n2p = counts["FN2TP"] + counts["TN2FP"]

    print(f'P2N: {p2n}')
    print(f'N2P: {n2p}')
    print(_SEPARATOR)

    groups = (
        ("TP2TP", "TP2FN"),
        ("FP2FP", "FP2TN"),
        ("TN2TN", "TN2FP"),
        ("FN2FN", "FN2TP"),
    )
    for group in groups:
        for key in group:
            print(f'{key}:{counts[key]}')
        print(_SEPARATOR)


def main(argv=None):
    """Parse the command line, load the three files and print the report.

    Args:
        argv: Argument list to parse; ``None`` means ``sys.argv[1:]``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--ref-files", type=str, required=True,
                        help="JSONL file with question_id and label")
    parser.add_argument("--res-rg-files", type=str, required=True,
                        help="JSONL file with the first set of answers")
    parser.add_argument("--res-cd-files", type=str, required=True,
                        help="JSONL file with the second set of answers")
    args = parser.parse_args(argv)

    # open ground truth answers & generated answers
    ref_files = load_jsonl(args.ref_files)
    res_rg_files = load_jsonl(args.res_rg_files)
    res_cd_files = load_jsonl(args.res_cd_files)

    # report results
    report_transitions(count_transitions(ref_files, res_rg_files, res_cd_files))


if __name__ == "__main__":
    main()



