import json
import re
import os
# import matplotlib.pyplot as plt
# import numpy as np
# import pandas as pd
import random
random.seed(0)


def parse_results(mcq_res):
    """Extract the multiple-choice letter from a model response.

    Searches ``mcq_res`` for the text ``Answer: <letter>`` where the letter
    is one of A-D. The search is case-insensitive, so the captured letter is
    upper-cased before being returned; otherwise a response such as
    ``"answer: b"`` would yield ``"b"`` and fail a case-sensitive comparison
    against an upper-case label.
    NOTE(review): assumes the gold labels are upper-case letters - confirm
    against the result files.

    Args:
        mcq_res: Raw response text produced for one question.

    Returns:
        The matched letter as an upper-case string, or ``None`` when
        ``mcq_res`` is not a string or contains no ``Answer: <letter>``.
    """
    # A missing / non-text response cannot contain an answer.
    if not isinstance(mcq_res, str):
        return None

    mcq_patterns = [
        r"Answer: ([A-D])",
    ]
    mcq_ans = None
    for p in mcq_patterns:
        m = re.search(p, mcq_res, re.DOTALL | re.IGNORECASE)
        if m:
            # IGNORECASE lets [A-D] match "a"-"d" too; normalise the case.
            # When several patterns match, the last one in the list wins.
            mcq_ans = m.group(1).upper()

    return mcq_ans
def _load_json(path):
    """Return the parsed JSON document stored at ``path``.

    The file is opened in a ``with`` block so the handle is closed as soon
    as parsing finishes instead of being left to the garbage collector.
    """
    with open(path, 'r', encoding='utf-8') as f:
        return json.load(f)


# Id collections loaded from JSON; the evaluation loop below tests line
# indices of the PrimeKG result files for membership in the first two.
low_score_samples_ids = _load_json('data/llama3-8B-it_low_score_samples_ids.json')
high_score_samples_ids = _load_json('data/llama3-8B-it_high_score_samples_ids.json')
primkg_sample_type_ids = _load_json('primekg/primekg_multifaceteval_entry_type.json')

# Models to evaluate; each name is substituted into the result-file paths.
model_names = [
    'llama3-8B-it',
    'llama3-8B-indirect-multi-10-samp-20-new-ref-single-1e-05lr-1ep',
]

# Line indices of the MMLU result file that are counted as "medicine".
mmlu_medicine_ids = _load_json('data/mmlu_medicine_ids.json')
# mmlu_medicine_question_relation_mapping = json.load(open('data/mmlu_medicine_question_relation_mapping.json','r'))

# Re-index the loaded entries as {entry[0]: {entry[1]: [entry[2], ...]}}.
injected_knowledge = _load_json('data/llama3-8B-it_low_score_samples.json')
injected_knowledge_dict = {}
for entry in injected_knowledge:
    rel = entry[1]
    injected_knowledge_dict.setdefault(entry[0], {}).setdefault(rel, []).append(entry[2])

# Result accumulators (not filled in by the code visible in this section).
performance_detailed = []
performance = []
performance_dict = {}
performance_detailed_dict = {}
performance_group = []
available_models = []
def _safe_ratio(numerator, denominator):
    """Return ``numerator / denominator``, or NaN when the denominator is 0.

    Used for every reported average so that an empty group prints ``nan``
    instead of aborting the whole report with ``ZeroDivisionError``.
    """
    return numerator / denominator if denominator else float('nan')


# Build the id lookups once: membership is tested for every line of every
# result file, and a set gives O(1) tests where a JSON list would be O(n).
_low_score_id_set = set(low_score_samples_ids)
_high_score_id_set = set(high_score_samples_ids)
_mmlu_medicine_id_set = set(mmlu_medicine_ids)

for model_name in model_names:
    path = 'results/primekg_probe/{}_results.json'.format(model_name)

    # Per-model state. Several of these are not used by the code in this
    # section but are kept so any later code relying on them still works.
    scores = []
    avg_mcq_score, avg_other_mcq_score = 0, 0
    avg_high_related_score, high_related_ttl = 0, 0
    avg_low_related_score, low_related_ttl = 0, 0
    other_ttl = 0
    hard_num = 0
    cnt = 0
    df = {'head': [], 'relation': [], 'tail': [], 'acc': [], 'type': []}
    other_results = []
    mmlu_non_medicine_ttl = 0
    mmlu_non_medicine_acc = 0
    mmlu_medicine_ttl, mmlu_medicine_acc = 0, 0
    mmlu_high_related_ttl, mmlu_high_related_acc = 0, 0
    mmlu_low_related_ttl, mmlu_low_related_acc = 0, 0
    medqa_high_related_ttl, medqa_high_related_acc = 0, 0
    medqa_low_related_ttl, medqa_low_related_acc = 0, 0
    arc_challenge_ttl, arc_challenge_acc = 0, 0
    openbookqa_ttl, openbookqa_acc = 0, 0
    commonsenseqa_ttl, commonsenseqa_acc = 0, 0
    pubmedqa_ttl, pubmedqa_acc = 0, 0

    # --- PrimeKG probe: one JSON array per line -------------------------
    with open(path, 'r', encoding='utf-8') as f:
        for i, line in enumerate(f):
            cnt += 1
            item = json.loads(line)
            relation = item[1]

            # Last element: raw responses; second to last: gold answers.
            mcq_res_list = item[-1]
            mcq_pred_list = [parse_results(res) for res in mcq_res_list]
            mcq_ans_list = item[-2]

            # Per-sample score = fraction of its questions answered correctly.
            mcq_score_list = [
                int(mcq_pred == mcq_ans)
                for mcq_pred, mcq_ans in zip(mcq_pred_list, mcq_ans_list)
            ]
            mcq_score = sum(mcq_score_list) / len(mcq_score_list)

            # The line index is the sample id; samples in neither id set
            # are ignored.
            if i in _low_score_id_set:
                avg_mcq_score += mcq_score
                hard_num += 1
            elif i in _high_score_id_set:
                avg_other_mcq_score += mcq_score
                other_ttl += 1

    avg_mcq_score = _safe_ratio(avg_mcq_score, hard_num)
    avg_other_mcq_score = _safe_ratio(avg_other_mcq_score, other_ttl)
    print('PrimeKG: {:.3f} {:.3f}'.format(avg_mcq_score, avg_other_mcq_score))

    # --- Downstream benchmarks: one JSON object per line ----------------
    # NOTE: the final print reads other_results by position, so the first
    # two entries of this list must stay medqa_test, mmlu_test.
    for dataset in ['medqa_test', 'mmlu_test', 'arc_challenge_test', 'openbookqa_test', 'commonsenseqa_test']:
        acc, ttl = 0, 0
        result_path = 'results/{}/{}_results.json'.format(dataset, model_name)
        if not os.path.exists(result_path):
            # FileNotFoundError is still caught by ``except Exception``.
            raise FileNotFoundError('No results for {} {}'.format(dataset, model_name))

        is_pubmedqa = dataset == 'pubmedqa_test'
        with open(result_path, 'r', encoding='utf-8') as f:
            for i, line in enumerate(f):
                item = json.loads(line)
                response = item['response']
                answer = item['answer_idx']

                if is_pubmedqa:
                    match = re.search(r"Answer: (yes|no|maybe)", response, re.IGNORECASE)
                    fallback_choices = ['yes', 'no', 'maybe']
                else:
                    match = re.search(r"Answer: ([A-E])", response)
                    # NOTE(review): the pattern accepts A-E but the random
                    # fallback only draws from A-D - confirm this is intended.
                    fallback_choices = ['A', 'B', 'C', 'D']

                # Unparseable responses get a random guess (seeded above).
                if match is None:
                    pred = random.choice(fallback_choices)
                else:
                    pred = match.group(1)

                score = int(pred.lower() == answer.lower())
                acc += score
                ttl += 1

                if dataset == 'mmlu_test':
                    if i not in _mmlu_medicine_id_set:
                        mmlu_non_medicine_acc += score
                        mmlu_non_medicine_ttl += 1
                    else:
                        mmlu_medicine_acc += score
                        mmlu_medicine_ttl += 1
                elif dataset == 'arc_challenge_test':
                    arc_challenge_acc += score
                    arc_challenge_ttl += 1
                elif dataset == 'commonsenseqa_test':
                    commonsenseqa_acc += score
                    commonsenseqa_ttl += 1
        other_results.append(_safe_ratio(acc, ttl))

    print('MedQA: {:.3f} MMLU: {:.3f} non-medicine {:.3f} medicine {:.3f} ARC {:.3f} CommonsenseQA {:.3f}'.format(
        other_results[0],
        other_results[1],
        _safe_ratio(mmlu_non_medicine_acc, mmlu_non_medicine_ttl),
        _safe_ratio(mmlu_medicine_acc, mmlu_medicine_ttl),
        _safe_ratio(arc_challenge_acc, arc_challenge_ttl),
        _safe_ratio(commonsenseqa_acc, commonsenseqa_ttl),
    ))
    