import json
import re
import numpy as np
from collections import defaultdict

def ndcg_at_k(true_answers, relevance, k):
    """Compute binary-gain NDCG@k for one ranked list of predictions.

    Args:
        true_answers: Collection of gold answers. Only its length is used, to
            decide how many hits an ideal ranking would contain.
        relevance: Gains of the predictions in ranked order (this file passes
            1 for a hit and 0 for a miss).
        k: Rank cut-off.

    Returns:
        DCG of the first ``k`` gains divided by the ideal DCG, or 0.0 when the
        ideal DCG is zero (no gold answers, or ``k`` not positive).
    """
    if k <= 0:
        return 0.0

    # The gain at 0-based position i is discounted by log2(i + 2).
    dcg = sum(
        (gain / np.log2(pos + 2) for pos, gain in enumerate(relevance[:k])),
        0.0,
    )

    # The ideal ranking is sized by the requested cut-off, not by how many
    # predictions were actually supplied; otherwise a short prediction list
    # would be normalised against a shortened ideal and score too high.
    ideal_hits = min(len(true_answers), k)
    idcg = sum((1.0 / np.log2(pos + 2) for pos in range(ideal_hits)), 0.0)

    if idcg == 0:
        return 0.0

    return dcg / idcg

def recall_at_k(predicted, true_answers, k):
    """Return the fraction of gold answers found in the top-k predictions.

    Args:
        predicted: Ranked predictions; only the first ``k`` are considered.
        true_answers: Gold answers (hashable items).
        k: Rank cut-off.

    Returns:
        Number of distinct top-k predictions that are gold answers, divided by
        ``len(true_answers)``; 0.0 when there are no gold answers.
    """
    # Without gold answers recall is undefined; report 0.0 rather than raising
    # ZeroDivisionError, as ndcg_at_k does for a zero ideal DCG.
    if len(true_answers) == 0:
        return 0.0
    hits = set(predicted[:k]) & set(true_answers)
    return len(hits) / len(true_answers)

def clean_gold_answers(answers):
    """Return the gold answers lower-cased, in their original order.

    ``answers`` is iterated element by element and ``.lower()`` is called on
    each item; no bracket stripping or splitting is performed here.
    """
    lowered = []
    for answer in answers:
        lowered.append(answer.lower())
    return lowered

def clean_candi_answers(answer_str):
    """Split a ';'-separated candidate string into normalised answers.

    Square brackets are removed, the text is split on ';', and every piece is
    lower-cased and stripped of surrounding whitespace.

    Args:
        answer_str: Raw candidate text, e.g. ``"[Nausea; Headache]"``.

    Returns:
        List of non-empty, lower-cased, stripped answers in original order.
    """
    unbracketed = answer_str.replace('[', '').replace(']', '')
    pieces = (piece.lower().strip() for piece in unbracketed.split(';'))
    # Filter after stripping so whitespace-only pieces (e.g. after a trailing
    # "; ") do not come through as empty answers.
    return [piece for piece in pieces if piece]

def clean_answers(answer_str):
    """Extract the model's answer list from free-form answer text.

    Only the text from the last '[' onwards is kept (the whole text when there
    is no '['). Brackets are removed and the remainder is split on ';' if it
    contains one, otherwise on ','. Each piece is lower-cased and stripped.

    Args:
        answer_str: Raw answer text, e.g. ``"... are [Nausea; Headache]"``.

    Returns:
        List of non-empty, lower-cased, stripped answers in original order.
    """
    start = answer_str.rfind('[')
    # rfind returns -1 when there is no '['; slicing with -1 would keep only
    # the last character, so fall back to the whole string in that case.
    if start != -1:
        answer_str = answer_str[start:]
    answer_str = answer_str.strip()

    separator = ';' if ';' in answer_str else ','
    unbracketed = answer_str.replace('[', '').replace(']', '')
    pieces = (piece.lower().strip() for piece in unbracketed.split(separator))
    # Filter after stripping so whitespace-only pieces are dropped as well.
    return [piece for piece in pieces if piece]

def _extract_candidate_text(question):
    """Return the text following 'Candidate Answers: [' up to its end marker."""
    marker = 'Candidate Answers: ['
    marker_pos = question.find(marker)
    if marker_pos == -1:
        # No candidate section at all; do not slice at an arbitrary offset.
        return ''
    tail = question[marker_pos + len(marker):]
    # The list ends at the first of these markers that is present; if neither
    # is, everything after the opening marker is taken.
    for terminator in ('] Related Facts', '] Answer'):
        end = tail.find(terminator)
        if end != -1:
            return tail[:end].strip()
    return tail.strip()


def _extract_model_answer_text(answer):
    """Return the text after the last 'side effects are' in the answer."""
    marker = 'side effects are'
    marker_pos = answer.rfind(marker)
    # Without the marker, fall back to the whole answer instead of slicing at
    # an arbitrary offset derived from -1.
    start = 0 if marker_pos == -1 else marker_pos + len(marker)
    text = answer[start:].strip()
    # Drop one trailing non-alphanumeric character (such as a final '.'), but
    # never cut a letter or digit off the last answer.
    if text and not text[-1].isalnum():
        text = text[:-1]
    return text


def _binary_relevance(predictions, gold_answers):
    """Mark each prediction 1 if some gold answer is a substring of it, else 0."""
    return [
        int(any(gold in prediction for gold in gold_answers))
        for prediction in predictions
    ]


def check(file_path):
    """Score model answers in a JSON-lines file and write a per-question report.

    Each non-blank input line is parsed as a JSON object from which these keys
    are read: 'id', 'drugs' (its third element holds the gold answers),
    'question' (searched for a 'Candidate Answers: [' section) and 'answer'
    (the text after the last 'side effects are' is parsed as the predictions).
    A prediction counts as a hit when any gold answer is a substring of it.

    For every question a JSON record with R@5, N@5, the gold answers, the
    parsed predictions and the parsed candidates is written to
    ``'check/ch-' + file_path[7:]``, followed by one summary line. The summary
    is also printed.

    Args:
        file_path: Path of the JSON-lines file to evaluate. Its first seven
            characters are dropped to build the report name (presumably a
            directory prefix such as 'output/' - TODO confirm). The 'check'
            directory must already exist.

    Raises:
        IndexError: If a record's 'drugs' list has fewer than three elements.
    """
    k_model = 5

    # NOTE(review): nothing below adds to the candidate totals, so CR@5 and
    # CN@5 are always reported as 0.0000. Kept so the output format is stable.
    total_recall_candidate = 0
    total_ndcg_candidate = 0
    total_recall_model = 0
    total_ndcg_model = 0
    num_questions = 0

    # Open the input first so a missing input does not truncate an existing
    # report; both handles are closed on every exit path.
    with open(file_path, 'r') as f:
        with open('check/ch-' + file_path[7:], 'w') as fch:
            for line in f:
                if not line.strip():
                    continue  # tolerate blank lines between records

                data = json.loads(line)
                id_ = data.get('id', 0)
                drugs = data.get('drugs', [])
                correct_answers_list = clean_gold_answers(drugs[2])

                # Candidate answers embedded in the question text.
                candidate_answers_list = clean_candi_answers(
                    _extract_candidate_text(data.get('question', ''))
                )

                # Predictions parsed from the model's answer text.
                model_answers_list = clean_answers(
                    _extract_model_answer_text(data.get('answer', ''))
                )

                relevance_model = _binary_relevance(
                    model_answers_list, correct_answers_list
                )

                # NOTE(review): several predictions matching the same gold
                # answer each count as a hit, so this ratio can exceed 1.
                if correct_answers_list:
                    recall_model = (
                        sum(relevance_model[:k_model]) / len(correct_answers_list)
                    )
                else:
                    # No gold answers: report 0.0 instead of dividing by zero.
                    recall_model = 0.0

                try:
                    ndcg_model = ndcg_at_k(
                        correct_answers_list, relevance_model, k_model
                    )
                except Exception:
                    # Show the offending inputs, then let the error propagate
                    # instead of exiting silently with a success status.
                    print('model', correct_answers_list, model_answers_list, relevance_model)
                    raise

                total_recall_model += recall_model
                total_ndcg_model += ndcg_model
                num_questions += 1

                record = {
                    'id': id_,
                    'R@5': f"{recall_model:.4f}",
                    'N@5': f"{ndcg_model:.4f}",
                    'gold_answers': correct_answers_list,
                    'answer': model_answers_list,
                    'candidate': candidate_answers_list,
                }
                fch.write(json.dumps(record) + '\n')

            # Calculate and print metrics
            if num_questions > 0:
                average_recall_candidate = total_recall_candidate / num_questions
                average_ndcg_candidate = total_ndcg_candidate / num_questions
                average_recall_model = total_recall_model / num_questions
                average_ndcg_model = total_ndcg_model / num_questions
                print(f"{file_path}  R@5: {average_recall_model:.4f}  N@5: {average_ndcg_model:.4f}  CR@5: {average_recall_candidate:.4f}  CN@5: {average_ndcg_candidate:.4f}")
                print('Number of questions:', num_questions)
                result_str = f"R@5: {average_recall_model:.4f}  N@5: {average_ndcg_model:.4f}  CR@5: {average_recall_candidate:.4f}  CN@5: {average_ndcg_candidate:.4f}"
                fch.write(result_str)
            else:
                print("No valid questions processed.")


def CheckABC(answer_str, gold_answer, candidates):
    """Score a multiple-choice reply: 1 if it contains the gold answer, else 0.

    The reply is scanned for the option markers 'A.', 'B.', 'C.' and 'D.'.
    If the earliest marker found has a corresponding entry in ``candidates``,
    the reply is replaced by '<letter>. <that candidate, lower-cased>';
    otherwise the reply itself is lower-cased. The result says whether the
    lower-cased ``gold_answer`` is a substring of that text.
    """
    # A reply that starts with 'Based' loses its first character before the
    # option markers are searched.
    text = answer_str[1:] if answer_str.startswith('Based') else answer_str

    letters = 'ABCD'
    # Position of each option marker that occurs in the reply, keyed by the
    # option's slot (0 for A, 1 for B, ...).
    positions = {}
    for slot, letter in enumerate(letters):
        where = text.find(letter + '.')
        if where >= 0:
            positions[slot] = where

    available = len(candidates)
    first_slot = min(positions, key=positions.get) if positions else None

    if first_slot is not None and first_slot < available:
        text = letters[first_slot] + '. ' + candidates[first_slot].lower()
    else:
        text = text.lower()

    return int(gold_answer.lower() in text)
        

if __name__ == "__main__":
    # Script entry point: evaluate the file named 'outputfile'.
    check(file_path='outputfile')