import json
import re
import numpy as np
from sklearn.metrics import f1_score

def jaccard_similarity(str1, str2):
    """Return the Jaccard similarity of the whitespace-token sets of two strings.

    Each string is split on whitespace and de-duplicated; the score is
    ``|intersection| / |union|`` of the two token sets. Comparison is
    case-sensitive, so callers that want case-insensitive matching must
    lowercase the inputs first.

    Args:
        str1: First string.
        str2: Second string.

    Returns:
        A float in ``[0, 1]``. When neither string contains any token the
        union is empty and ``0.0`` is returned instead of dividing by zero.
    """
    set1 = set(str1.split())
    set2 = set(str2.split())
    union = set1 | set2
    if not union:
        # No tokens on either side: define the similarity as zero overlap.
        return 0.0
    return len(set1 & set2) / len(union)


def _top_candidates(question_str, limit=3):
    """Return up to *limit* lowercased candidate sentences from a question prompt.

    The text is cut to start at the last 'Candidate Answers' marker (if any)
    and to end before the last 'Related Facts' marker (if any), then split on
    '. '. The result always contains at least one element (possibly '').
    """
    start = question_str.rfind('Candidate Answers')
    if start != -1:
        question_str = question_str[start:]
    end = question_str.rfind('Related Facts')
    if end != -1:
        question_str = question_str[:end]
    return [sentence.strip().lower() for sentence in question_str.split('. ')[:limit]]


def _nearest_label(text, labmap, default=0):
    """Return the label whose relation name has the highest Jaccard overlap with *text*.

    Ties keep the first name encountered. When no name shares any token with
    *text*, *default* is returned so that every record gets a well-defined
    prediction.
    """
    best_label = default
    best_sim = 0
    for name, label in labmap.items():
        try:
            sim = jaccard_similarity(name, text)
        except ZeroDivisionError:
            # Neither string has any token, so there is no overlap to score.
            continue
        if sim > best_sim:
            best_sim = sim
            best_label = label
    return best_label


def check(file_path):
    """Score a JSONL file of model answers and write a per-record report.

    Each non-blank line of *file_path* is a JSON object with the keys 'id',
    'drugs' (``[drug1, drug2, template]``), 'question' and 'answer'. The gold
    answer is the template with '#Drug1'/'#Drug2' substituted and a trailing
    period removed. For every record one JSON line is written to the report
    file with:

    * 'Hit'  - 1 if the gold answer occurs (case-insensitively) in the raw answer.
    * 'HitA' - 1 if the gold answer occurs in the first candidate sentence
      extracted from the question.
    * 'acc'  - 1/0 for whether the gold answer occurs in the extracted answer,
      or 2 when the answer is the 'correct answer is ***' placeholder
      (placeholders are excluded from ACC).

    The predicted label for macro-F1 is the gold label when 'acc' is 1,
    otherwise the relation whose name is most similar to the extracted answer
    (label 0 if nothing overlaps). A summary line with Hit / ACC / F1 is
    appended to the report and ACC / F1 are printed.

    Args:
        file_path: Path of the JSONL file to evaluate. The report is written
            to ``'check/ch-' + file_path[10:]``.

    Raises:
        ValueError: If a record's 'drugs' field does not have three entries.
        KeyError: If a record's template is not a relation name in
            'data/DB_all_rel.json'.
    """
    with open('data/DB_all_rel.json', 'r', encoding='utf-8') as f:
        relation_vocab = json.load(f)
    # Map each relation's 'name' field to its key converted to an integer.
    labmap = {v.get('name', 'Unknown'): int(k) for k, v in relation_vocab.items()}

    total_questions = 0
    correct_answers = 0
    scored = []  # 0/1 verdicts for records whose answer is not the placeholder
    labels = []
    preds = []

    # file_path[10:] drops the first 10 characters, which is the length of the
    # 'output/DB-' prefix of the path used by the script entry point.
    report_path = 'check/ch-' + file_path[10:]

    with open(file_path, 'r', encoding='utf-8') as f:
        with open(report_path, 'w', encoding='utf-8') as fch:
            for line in f:
                if not line.strip():
                    # Tolerate blank lines (e.g. a trailing newline at EOF).
                    continue
                record = json.loads(line)
                id_ = record.get('id', 0)
                drugs = record.get('drugs', [])
                question_str = record.get('question', '')
                answer_str = record.get('answer', '')

                if len(drugs) != 3:
                    raise ValueError(
                        f"record {id_!r}: expected 'drugs' to be "
                        f"[drug1, drug2, template], got {drugs!r}"
                    )
                drug1, drug2, template_string = drugs
                gold_answer = template_string.replace('#Drug1', drug1).replace('#Drug2', drug2)
                if gold_answer.endswith('.'):
                    # Drop the final period only; never truncate a real character.
                    gold_answer = gold_answer[:-1]
                gold_lower = gold_answer.lower()
                gold_label = labmap[template_string]
                labels.append(gold_label)

                # Hit: gold answer anywhere in the raw answer (case-insensitive).
                hit = int(gold_lower in answer_str.lower())
                correct_answers += hit

                # HitA: gold answer inside the first candidate sentence.
                hitA = int(gold_lower in _top_candidates(question_str)[0])

                total_questions += 1

                answer_str = extract_answer(answer_str)
                if answer_str in ('correct answer is ***.', 'correct answer is ***'):
                    verdict = 2
                else:
                    verdict = int(gold_lower in answer_str.lower())
                    scored.append(verdict)

                if verdict == 1:
                    preds.append(gold_label)
                else:
                    # Pick the relation name most similar to the answer.
                    # NOTE(review): this comparison is case-sensitive, unlike the
                    # gold-answer matching above - confirm that is intended.
                    preds.append(_nearest_label(answer_str, labmap))

                fch.write(json.dumps({
                    'id': id_,
                    'Hit': hit,
                    'HitA': hitA,
                    'acc': verdict,
                    'gold_answers': gold_answer,
                    'answer': answer_str,
                }) + '\n')

            # Calculate and print metrics
            if total_questions > 0:
                hit = correct_answers / total_questions
                # ACC is undefined when every answer was the placeholder.
                acc_abc = float(np.mean(scored)) if scored else float('nan')
                f1 = f1_score(labels, preds, average='macro')
                print(f"{file_path}  ACC: {acc_abc:.4f} F1: {f1:.4f}")
                result_str = f"Hit : {hit:.4f}  ACC: {acc_abc:.4f}  F1: {f1:.4f}"
                fch.write(result_str)
            else:
                print("No valid questions processed.")
    print('total question:', total_questions, len(scored))

def CheckABC(answer_str, gold_answer, candidates):
    """Return 1 if the gold answer is contained in the option the answer selects.

    The answer text is scanned for the option markers 'A.', 'B.', 'C.' and
    'D.'. The marker that appears earliest decides the selected option; when
    a candidate exists at that position the text to search becomes
    '<letter>. <candidate lowercased>'. Otherwise (no marker, or no candidate
    for the earliest marker) the lowercased answer itself is searched.

    Args:
        answer_str: The model's answer text.
        gold_answer: The expected answer; matched case-insensitively.
        candidates: Candidate answers, index 0 corresponding to option A.

    Returns:
        1 if ``gold_answer.lower()`` is a substring of the resolved text, else 0.
    """
    # An answer starting with 'Based' loses its first character before scanning.
    text = answer_str[1:] if answer_str.startswith('Based') else answer_str
    available = len(candidates)

    # Locate the option marker that occurs first in the text, if any.
    first_slot = None
    first_pos = -1
    for slot, letter in enumerate('ABCD'):
        pos = text.find(letter + '.')
        if pos == -1:
            continue
        if first_slot is None or pos < first_pos:
            first_slot, first_pos = slot, pos

    if first_slot is not None and available > first_slot:
        resolved = 'ABCD'[first_slot] + '. ' + candidates[first_slot].lower()
    else:
        resolved = text.lower()

    return int(gold_answer.lower() in resolved)
        


def extract_answer(s):
    """Return *s* starting at its first 'the interaction is' phrase.

    The phrase is matched case-insensitively. When it does not occur, *s* is
    returned unchanged.
    """
    found = re.search(r'the interaction is', s, re.IGNORECASE)
    return s if found is None else s[found.start():]

if __name__ == "__main__":
    # Script entry point: evaluate one fixed model-output file.
    check('output/DB-S1-d11N3p1k5-rag11-re5-same.jsonl')