from sklearn.metrics import classification_report, confusion_matrix
import json
import argparse
import os
import re
import string
from tqdm import tqdm
from collections import Counter

def print_evaluation_results(predictions, gt_labels, num_of_classes=3):
    """Build a classification report and a confusion matrix for string labels.

    Despite the name, nothing is printed here; the caller decides what to
    output.

    Args:
        predictions: Predicted label strings.
        gt_labels: Gold label strings, aligned with ``predictions``.
        num_of_classes: 3 for refutes / supports / not enough info, or 2 for
            refutes / supports.

    Returns:
        A ``(macro_report, confusion_report)`` tuple: the text produced by
        ``classification_report`` (4 digits) and the array produced by
        ``confusion_matrix``.

    Raises:
        KeyError: If a label is not one of the class names for the chosen
            ``num_of_classes``.
        ValueError: If ``num_of_classes`` is neither 2 nor 3.
    """
    class_names_by_count = {
        3: ['refutes', 'supports', 'not enough info'],
        2: ['refutes', 'supports'],
    }
    if num_of_classes not in class_names_by_count:
        raise ValueError(
            f"num_of_classes must be 2 or 3, got {num_of_classes!r}")

    target_names = class_names_by_count[num_of_classes]
    label_map = {name: index for index, name in enumerate(target_names)}
    labels = [label_map[e] for e in gt_labels]
    predictions = [label_map[e] for e in predictions]

    # Pass the full class-id list explicitly: without it, sklearn infers the
    # classes from the data and rejects `target_names` whenever one class
    # never occurs in either list.
    class_ids = list(range(len(target_names)))
    macro_report = classification_report(
        labels, predictions, labels=class_ids,
        target_names=target_names, digits=4)
    confusion_report = confusion_matrix(labels, predictions, labels=class_ids)
    return macro_report, confusion_report

def evaluate_hover_by_hops(args, result_file):
    """Print a 2-class evaluation of a result file, split by hop count.

    Hop counts are looked up by sample id in
    ``<args.FV_data_path>/<args.dataset_name>/claims/dev.json``.

    Args:
        args: Object providing ``FV_data_path`` and ``dataset_name``.
        result_file: Path to a JSON list of records with ``id``, ``gold`` and
            ``prediction`` keys.

    Raises:
        KeyError: If a result id is missing from the dataset file or its hop
            count is not 2, 3 or 4.
    """
    with open(result_file, 'r') as f:
        results = json.load(f)

    with open(os.path.join(args.FV_data_path, args.dataset_name, 'claims', 'dev.json'), 'r') as f:
        dataset = json.load(f)

    id_num_hops_map = {sample['id']: sample['num_hops'] for sample in dataset}

    predictions = {'2_hop': [], '3_hop': [], '4_hop': []}
    gt_labels = {'2_hop': [], '3_hop': [], '4_hop': []}
    for sample in results:
        key = f"{id_num_hops_map[sample['id']]}_hop"
        gt_labels[key].append(sample['gold'].strip())
        predictions[key].append(sample['prediction'].strip())

    for key in predictions:
        print(key)
        if predictions[key]:
            macro_report, confusion_report = print_evaluation_results(
                predictions[key], gt_labels[key], num_of_classes=2)
            # The helper only builds the report; it has to be printed here.
            print(macro_report)
            print(confusion_report)
        else:
            # Nothing to score for this hop count.
            print('no samples')
        print()



def evaluate_feverous(result_file, paradigm='fc'):
    """Score binary supports/refutes predictions stored in a JSON result file.

    Args:
        result_file: Path to a JSON list of records with ``label`` and
            ``prediction`` keys.
        paradigm: ``'qa'`` treats a response containing ``true`` as
            ``supports``. Any other value reads the verdict from inside
            ``<response>...</response>`` and treats one containing
            ``support`` as ``supports``.

    Returns:
        The ``(macro_report, confusion_report)`` tuple produced by
        ``print_evaluation_results`` for two classes.
    """
    with open(result_file, 'r') as f:
        results = json.load(f)

    pattern = r"<response>(.*)</response>"
    predictions = []
    gt_labels = []
    for sample in results:
        gt_labels.append(sample['label'].strip())
        gpt_response = sample['prediction'].strip().lower()
        if paradigm == 'qa':
            is_supported = 'true' in gpt_response
        else:
            match = re.search(pattern, gpt_response, flags=re.DOTALL)
            # A response without the expected tags is judged on its whole
            # text instead of aborting the evaluation with an IndexError.
            verdict = match.group(1).strip() if match else gpt_response
            is_supported = 'support' in verdict
        predictions.append('supports' if is_supported else 'refutes')

    macro_report, confusion_report = print_evaluation_results(predictions, gt_labels, num_of_classes=2)
    return macro_report, confusion_report

def evaluate_hover(result_file, paradigm='fc'):
    """Score binary supports/refutes predictions separately per hop count.

    Args:
        result_file: Path to a JSON list of records with ``num_hops``,
            ``label`` and ``prediction`` keys.
        paradigm: ``'qa'`` treats a response containing ``true`` as
            ``supports``. Any other value reads the verdict from inside
            ``<response>...</response>`` and treats one containing
            ``support`` as ``supports``.

    Returns:
        Dict mapping ``'2_hop'`` / ``'3_hop'`` / ``'4_hop'`` to the report
        text for that bucket. Buckets with no samples are omitted.

    Raises:
        KeyError: If a record's hop count is not 2, 3 or 4.
    """
    with open(result_file, 'r') as f:
        results = json.load(f)

    pattern = r"<response>(.*)</response>"
    predictions = {'2_hop': [], '3_hop': [], '4_hop': []}
    gt_labels = {'2_hop': [], '3_hop': [], '4_hop': []}

    # The progress bar is only shown for the tag-based paradigm.
    samples = results if paradigm == 'qa' else tqdm(results)
    for sample in samples:
        # Each record carries its own hop count, so no id lookup is needed.
        key = f"{sample['num_hops']}_hop"
        gt_labels[key].append(sample['label'].strip())
        gpt_response = sample['prediction'].strip().lower()
        if paradigm == 'qa':
            is_supported = 'true' in gpt_response
        else:
            match = re.search(pattern, gpt_response, flags=re.DOTALL)
            # A response without the expected tags is judged on its whole
            # text instead of aborting the evaluation with an IndexError.
            verdict = match.group(1).strip() if match else gpt_response
            is_supported = 'support' in verdict
        predictions[key].append('supports' if is_supported else 'refutes')

    macro_reports = {}
    for key in predictions:
        if not predictions[key]:
            # An empty bucket cannot be scored; leave it out of the result.
            continue
        macro_report, confusion_report = print_evaluation_results(predictions[key], gt_labels[key], num_of_classes=2)
        macro_reports[key] = macro_report

    return macro_reports

def evaluate_strategyqa(result_file, paradigm='fc'):
    """Score true/false predictions as a binary supports/refutes task.

    A gold label whose lowercased string form is ``false`` counts as
    ``refutes``; every other label counts as ``supports``.

    Args:
        result_file: Path to a JSON list of records with ``label`` and
            ``prediction`` keys (plus ``id`` for the ``'qa'`` paradigm).
        paradigm: With ``'qa'``, the answer is read from inside
            ``<answer>...</answer>`` (the whole response if the tags are
            absent) and counts as ``supports`` when it contains ``true`` or
            ``yes``; each sample id is printed. With any other value, a
            response containing ``false`` counts as ``refutes``.

    Returns:
        The ``(macro_report, confusion_report)`` tuple produced by
        ``print_evaluation_results`` for two classes.
    """
    with open(result_file, 'r') as f:
        results = json.load(f)

    pattern = r"<answer>(.*)</answer>"
    use_answer_tags = paradigm == 'qa'
    predictions = []
    gt_labels = []
    for sample in results:
        if use_answer_tags:
            print(sample['id'])

        gold = str(sample['label']).strip().lower()
        gt_labels.append('refutes' if gold == 'false' else 'supports')

        gpt_response = sample['prediction'].strip().lower()
        if use_answer_tags:
            # An explicit match check replaces the former bare `except:`,
            # which would also have hidden unrelated errors.
            match = re.search(pattern, gpt_response, flags=re.DOTALL)
            answer = match.group(1).strip() if match else gpt_response
            is_supported = 'true' in answer or 'yes' in answer
        else:
            is_supported = 'false' not in gpt_response
        predictions.append('supports' if is_supported else 'refutes')

    macro_report, confusion_report = print_evaluation_results(predictions, gt_labels, num_of_classes=2)
    return macro_report, confusion_report

def normalize_answer(s):
    """Return ``s`` lowercased, without punctuation or articles, single-spaced.

    The steps run in this order: lowercase, drop every character found in
    ``string.punctuation``, replace the standalone words ``a``, ``an`` and
    ``the`` with a space, then collapse all whitespace runs to one space.
    """
    lowered = s.lower()
    punctuation_free = lowered.translate(
        str.maketrans('', '', string.punctuation))
    without_articles = re.sub(r'\b(a|an|the)\b', ' ', punctuation_free)
    return ' '.join(without_articles.split())

def exact_match_score(prediction, ground_truth):
    """Return True when both strings are equal after ``normalize_answer``."""
    normalized_prediction = normalize_answer(prediction)
    normalized_truth = normalize_answer(ground_truth)
    return normalized_prediction == normalized_truth

def _hotpotqa_answer_f1(normalized_prediction, normalized_ground_truth):
    """Return ``(f1, precision, recall)`` of token overlap for two normalized answers."""
    special_answers = ('yes', 'no', 'noanswer', 'unknown')
    # A yes/no/noanswer/unknown answer on either side only scores on an
    # exact match; partial token overlap is not credited.
    if normalized_prediction != normalized_ground_truth and (
            normalized_prediction in special_answers
            or normalized_ground_truth in special_answers):
        return 0, 0, 0

    prediction_tokens = normalized_prediction.split()
    ground_truth_tokens = normalized_ground_truth.split()
    common = Counter(prediction_tokens) & Counter(ground_truth_tokens)
    num_same = sum(common.values())
    if num_same == 0:
        return 0, 0, 0

    precision = 1.0 * num_same / len(prediction_tokens)
    recall = 1.0 * num_same / len(ground_truth_tokens)
    f1 = (2 * precision * recall) / (precision + recall)
    return f1, precision, recall


def evaluate_hotpotqa(result_file, paradigm='fc'):
    """Evaluate a JSON result file as free-form QA or as binary verification.

    Args:
        result_file: Path to a JSON list of records with ``label`` and
            ``prediction`` keys (plus ``id`` for the ``'qa'`` paradigm).
        paradigm: With ``'qa'``, the answer is read from inside
            ``<answer>...</answer>`` (the whole response if the tags are
            absent) and scored with exact match and token F1; sample ids,
            extracted answers and the final metrics are printed. With any
            other value, the verdict is read from inside
            ``<response>...</response>`` and one containing ``support``
            counts as ``supports``.

    Returns:
        For ``'qa'``: a dict of metrics averaged over the samples. Only
        ``em``, ``f1``, ``prec`` and ``recall`` are computed; the ``sp_*``
        and ``joint_*`` entries are never updated and stay 0. Otherwise: the
        ``(macro_report, confusion_report)`` tuple produced by
        ``print_evaluation_results`` for two classes.
    """
    with open(result_file, 'r') as f:
        results = json.load(f)

    if paradigm == 'qa':
        metrics = {'em': 0, 'f1': 0, 'prec': 0, 'recall': 0,
                   'sp_em': 0, 'sp_f1': 0, 'sp_prec': 0, 'sp_recall': 0,
                   'joint_em': 0, 'joint_f1': 0, 'joint_prec': 0, 'joint_recall': 0}
        pattern = r"<answer>(.*)</answer>"
        for sample in results:
            ground_truth = sample['label'].strip()
            gpt_response = sample['prediction'].strip().lower()
            print(sample['id'])
            match = re.search(pattern, gpt_response, flags=re.DOTALL)
            # Without the expected tags, score the whole response instead of
            # aborting the evaluation with an IndexError.
            gpt_response_regrex = match.group(1).strip() if match else gpt_response
            print(gpt_response_regrex)

            em = exact_match_score(gpt_response_regrex, ground_truth)
            f1, precision, recall = _hotpotqa_answer_f1(
                normalize_answer(gpt_response_regrex),
                normalize_answer(ground_truth))

            metrics['em'] += float(em)
            metrics['f1'] += f1
            metrics['prec'] += precision
            metrics['recall'] += recall

        N = len(results)
        if N:  # an empty result file keeps all-zero metrics
            for k in metrics:
                metrics[k] /= N
        print(metrics)
        return metrics

    gt_labels = []
    predictions = []
    pattern = r"<response>(.*)</response>"
    for sample in results:
        gt_labels.append(sample['label'].strip())
        gpt_response = sample['prediction'].strip().lower()
        match = re.search(pattern, gpt_response, flags=re.DOTALL)
        verdict = match.group(1).strip() if match else gpt_response
        predictions.append('supports' if 'support' in verdict else 'refutes')

    macro_report, confusion_report = print_evaluation_results(predictions, gt_labels, num_of_classes=2)
    return macro_report, confusion_report

def parse_args():
    """Parse the command-line options of the evaluation script.

    Returns:
        Namespace with string options ``dataset_name``, ``FV_data_path``
        (default ``''``), ``result_file`` and ``paradigm``. Options without
        a default are ``None`` when omitted.
    """
    cli = argparse.ArgumentParser()
    # dataset args
    option_specs = (
        ('--dataset_name', {}),
        ('--FV_data_path', {'default': ''}),
        ('--result_file', {}),
        ('--paradigm', {}),
    )
    for flag, extra in option_specs:
        cli.add_argument(flag, type=str, **extra)
    return cli.parse_args()

if __name__ == "__main__":
    # Pick the evaluator for the requested dataset and print its results.
    args = parse_args()
    if args.dataset_name in ('FEVEROUS', 'SCIFACT'):
        macro_report, confusion_report = evaluate_feverous(args.result_file, args.paradigm)
        print(len(macro_report.split('\n')))
        print(macro_report)
        print(confusion_report)
    elif args.dataset_name == 'HOVER':
        macro_reports = evaluate_hover(args.result_file, args.paradigm)
        for key, macro_report in macro_reports.items():
            print(key)
            print(macro_report)
    elif args.dataset_name in ('HOTPOTQA', 'QANGAROO', 'MuSiQue', '2WIKIMULTIHOPQA'):
        outcome = evaluate_hotpotqa(args.result_file, args.paradigm)
        if isinstance(outcome, dict):
            # 'qa' paradigm: a dict of averaged metrics.
            print(outcome)
            for key, value in outcome.items():
                print(key)
                print(value)
        else:
            # Any other paradigm returns (report, confusion matrix); calling
            # .items() on that tuple would raise AttributeError.
            macro_report, confusion_report = outcome
            print(macro_report)
            print(confusion_report)
    elif args.dataset_name == 'STRATEGYQA':
        macro_report, confusion_report = evaluate_strategyqa(args.result_file, args.paradigm)
        print(len(macro_report.split('\n')))
        print(macro_report)
        print(confusion_report)
    else:
        raise NotImplementedError
    
    
