import spacy
from typing import Dict, List, Union
import re
import string
from collections import Counter
import numpy as np
from rouge_score import rouge_scorer
import time
from test_chatgpt import ChatGPTAPI


class Metric:
    """Scores predicted answers against references and keeps running averages.

    Two answer variants are tracked side by side, labelled "active" and
    "original". For each variant the class accumulates entity-level F1,
    token-level F1, ROUGE-1/2/L, exact match, and an accuracy judged by a
    chatbot, each in its own logger object.
    """

    # Pre-built helpers for normalize_answer so they are not rebuilt per call.
    _PUNCT_TABLE = str.maketrans('', '', string.punctuation)
    _ARTICLES_RE = re.compile(r'\b(a|an|the)\b')

    def __init__(self, chatbot=None) -> None:
        """Load the spaCy pipeline, set up the judge chatbot and all loggers.

        Args:
            chatbot: Object exposing ``ask(prompt) -> str``. When ``None`` a
                ``ChatGPTAPI(model_name="gpt-3.5-turbo")`` is created here,
                so no client is built at import time or shared between
                instances.
        """
        self.spacy_nlp = spacy.load('en_core_web_sm')
        # spaCy is the only NER backend wired up; no dynamic lookup is needed.
        self.nlp = self.spacy_nlp
        if chatbot is None:
            chatbot = ChatGPTAPI(model_name="gpt-3.5-turbo")
        self.chatbot = chatbot
        self.acc_prompt = "In the following task, you are given a Question, a model Prediction for the Question, and a Ground-truth Answer to the Question. You should decide whether the model Prediction implies the Ground-truth Answer. Question:{0}\n\n Prediction:{1}\n\n Ground-truth Answer:{2}\n\n Does the Prediction imply the Ground-truth Answer? Output Yes or No. "

        self.active_entity_logger = Metric_Logger('active_entity_score', 4)
        self.original_entity_logger = Metric_Logger('original_entity_score', 4)

        self.active_f1_logger = Metric_Logger('active_f1_score', 3)
        self.original_f1_logger = Metric_Logger('original_f1_score', 3)

        self.active_rouge1_logger = Metric_Logger('active_rouge1_score', 4)
        self.original_rouge1_logger = Metric_Logger('original_rouge1_score', 4)

        self.active_rouge2_logger = Metric_Logger('active_rouge2_score', 4)
        self.original_rouge2_logger = Metric_Logger('original_rouge2_score', 4)

        self.active_rougeL_logger = Metric_Logger('active_rougeL_score', 4)
        self.original_rougeL_logger = Metric_Logger('original_rougeL_score', 4)

        self.active_em_logger = Metric_Logger('active_exact_match_score', 2)
        self.original_em_logger = Metric_Logger('original_exact_match_score', 2)

        self.active_prompt_acc_logger = count()
        self.original_prompt_acc_logger = count()

    def get_ner(self, text):
        """Return the named-entity spans that ``self.nlp`` finds in ``text``."""
        doc = self.nlp(text)
        return list(doc.ents)

    def entity_f1_score(
        self,
        prediction: str,
        ground_truth: Union[str, List[str]],
        ground_truth_id: str = None,
        debug: bool = False,
    ):
        """Compute entity-level precision/recall/F1 against one or more references.

        Entities are extracted with ``get_ner`` and compared after
        ``normalize_answer``. For several references, the best precision,
        recall and F1 are each taken independently (they may come from
        different references).

        Args:
            prediction: Predicted answer text.
            ground_truth: A reference string or a list of reference strings.
            ground_truth_id: Unused; kept for interface compatibility.
            debug: When true, print the extracted entities for each reference.

        Returns:
            Dict with ``ent_f1``, ``ent_precision``, ``ent_recall`` and
            ``num_ent`` (mean number of reference entities). An empty
            reference list yields all zeros.
        """
        if isinstance(ground_truth, str):
            ground_truth = [ground_truth]

        # The prediction's entities do not depend on the reference, so run
        # the NER pipeline over the prediction once instead of once per gold.
        pred_ents: List[str] = [self.normalize_answer(ent.text) for ent in self.get_ner(prediction)]
        pred_counts = Counter(pred_ents)

        p = r = f1 = num_ent = 0
        for gold in ground_truth:
            gold_ents: List[str] = [self.normalize_answer(ent.text) for ent in self.get_ner(gold)]
            common_ents = pred_counts & Counter(gold_ents)
            num_common_ents: int = sum(common_ents.values())
            if debug:
                print('PP', prediction)
                print('GG', gold)
                print('P', pred_ents)
                print('G', gold_ents)
                print('C', common_ents)
            # A side with no entities counts as perfect for that ratio.
            _p = (num_common_ents / len(pred_ents)) if pred_ents else 1
            _r = (num_common_ents / len(gold_ents)) if gold_ents else 1
            _f1 = (2 * _p * _r) / ((_p + _r) or 1)
            p, r, f1 = max(p, _p), max(r, _r), max(f1, _f1)
            num_ent += len(gold_ents)
        if ground_truth:
            num_ent /= len(ground_truth)
        return {'ent_f1': f1, 'ent_precision': p, 'ent_recall': r, 'num_ent': num_ent}

    def exact_match_score(
        self,
        prediction: str,
        ground_truth: str,
        ground_truth_id: str = None
    ):
        """Check whether the normalized prediction equals a normalized reference.

        Args:
            prediction: Predicted answer text.
            ground_truth: Reference answer text.
            ground_truth_id: Optional identifier; when truthy, aliases from
                ``self.get_all_alias(ground_truth_id)`` are also accepted.

        Returns:
            Dict with integer ``correct`` (1 or 0) and ``incorrect``
            (``1 - correct``).
        """
        ground_truths = {ground_truth}
        if ground_truth_id:
            # NOTE(review): get_all_alias is not defined on this class; it
            # must be supplied by a subclass or mixin for this path to work.
            ground_truths.update(self.get_all_alias(ground_truth_id))
        normalized_prediction = self.normalize_answer(prediction)
        correct = int(any(
            normalized_prediction == self.normalize_answer(gt)
            for gt in ground_truths
        ))
        return {'correct': correct, 'incorrect': 1 - correct}

    def f1_score(
        self,
        prediction: str,
        ground_truth: str,
        ground_truth_id: str = None
    ):
        """Compute token-overlap precision/recall/F1 after normalization.

        Args:
            prediction: Predicted answer text.
            ground_truth: Reference answer text.
            ground_truth_id: Unused; kept for interface compatibility.

        Returns:
            Dict with ``f1``, ``precision`` and ``recall``. All three are 0
            when no tokens overlap, including when either side has no tokens
            left after normalization.
        """
        prediction_tokens = self.normalize_answer(prediction).split()
        ground_truth_tokens = self.normalize_answer(ground_truth).split()
        common = Counter(prediction_tokens) & Counter(ground_truth_tokens)
        num_same = sum(common.values())

        # No overlap covers the empty-token cases too, so the divisions
        # below can never see a zero denominator.
        if num_same == 0:
            return {'f1': 0, 'precision': 0, 'recall': 0}

        precision = 1.0 * num_same / len(prediction_tokens)
        recall = 1.0 * num_same / len(ground_truth_tokens)
        f1 = (2 * precision * recall) / (precision + recall)
        return {'f1': f1, 'precision': precision, 'recall': recall}

    def normalize_answer(self, s):
        """Lowercase ``s``, strip ASCII punctuation, drop a/an/the, collapse spaces."""
        text = s.lower()
        text = text.translate(self._PUNCT_TABLE)
        text = self._ARTICLES_RE.sub(' ', text)
        return ' '.join(text.split())

    def _get_rouge_scorer(self):
        """Return a RougeScorer for rouge1/rouge2/rougeL, created on first use."""
        scorer = getattr(self, '_rouge_scorer', None)
        if scorer is None:
            scorer = rouge_scorer.RougeScorer(['rouge1', 'rouge2', 'rougeL'], use_stemmer=True)
            self._rouge_scorer = scorer
        return scorer

    def rouge_based_metric(self, prediction, ground_truth):
        """Score ``prediction`` against ``ground_truth`` with ROUGE-1/2/L.

        Returns:
            Dict keyed by ROUGE variant name; each value holds
            ``rouge_precision``, ``rouge_recall``, the library's ``fmeasure``
            and ``rouge_f1`` (recomputed with a 1e-10 term in the
            denominator to avoid division by zero).
        """
        scores = self._get_rouge_scorer().score(ground_truth, prediction)
        res = {}
        for name, score in scores.items():
            precision, recall, fmeasure = score
            f1 = 2 * (precision * recall) / (precision + recall + 1e-10)
            res[name] = {'rouge_f1': f1, 'rouge_precision': precision, 'rouge_recall': recall, 'fmeasure': fmeasure}
        return res

    def _judge_prediction(self, question, prediction, gt):
        """Ask the chatbot whether ``prediction`` implies ``gt``; return 1 or 0.

        The reply counts as positive when its lowercased text contains "yes".
        """
        reply = self.chatbot.ask(self.acc_prompt.format(question, prediction, gt)).lower()
        return 1 if "yes" in reply else 0

    def acc_evaluate(self, gt, original_ans, active_ans, question):
        """Judge both answers with the chatbot.

        Returns:
            Tuple ``(active_cnt, original_cnt)``, each 1 or 0.
        """
        # The original answer is queried first, then the active one.
        original_cnt = self._judge_prediction(question, original_ans, gt)
        active_cnt = self._judge_prediction(question, active_ans, gt)
        return active_cnt, original_cnt

    def calculate_score_for_one_method(self, ans, gt, question):
        """Score a single answer, record it in the "original" loggers, and
        return ``get_original_summary()``.
        """
        original_entity_score = self.entity_f1_score(ans, gt)
        original_f1_score = self.f1_score(ans, gt)
        original_rouges_score = self.rouge_based_metric(ans, gt)
        original_em_score = self.exact_match_score(ans, gt)

        # Only one answer is being judged, so query the chatbot once rather
        # than sending the identical prompt twice and discarding one reply.
        original_cnt = self._judge_prediction(question, ans, gt)

        self.original_em_logger.update(original_em_score)
        self.original_prompt_acc_logger.update(original_cnt)
        self.original_entity_logger.update(original_entity_score)
        self.original_f1_logger.update(original_f1_score)

        self.original_rouge1_logger.update(original_rouges_score['rouge1'])
        self.original_rouge2_logger.update(original_rouges_score['rouge2'])
        self.original_rougeL_logger.update(original_rouges_score['rougeL'])

        return self.get_original_summary()

    def calculate_all_scores(self, active_ans, orginal_ans, gt, question):
        """Score both answers, record them in their loggers, and return
        ``get_all_summary()``.

        The parameter name ``orginal_ans`` (sic) is kept for keyword callers.
        """
        active_entity_score = self.entity_f1_score(active_ans, gt)
        original_entity_score = self.entity_f1_score(orginal_ans, gt)

        active_f1_score = self.f1_score(active_ans, gt)
        original_f1_score = self.f1_score(orginal_ans, gt)

        active_rouge_score = self.rouge_based_metric(active_ans, gt)
        original_rouges_score = self.rouge_based_metric(orginal_ans, gt)

        active_em_score = self.exact_match_score(active_ans, gt)
        original_em_score = self.exact_match_score(orginal_ans, gt)

        active_cnt, original_cnt = self.acc_evaluate(gt, orginal_ans, active_ans, question)

        self.active_em_logger.update(active_em_score)
        self.original_em_logger.update(original_em_score)

        self.active_prompt_acc_logger.update(active_cnt)
        self.original_prompt_acc_logger.update(original_cnt)

        self.active_entity_logger.update(active_entity_score)
        self.original_entity_logger.update(original_entity_score)

        self.active_f1_logger.update(active_f1_score)
        self.original_f1_logger.update(original_f1_score)

        self.active_rouge1_logger.update(active_rouge_score['rouge1'])
        self.original_rouge1_logger.update(original_rouges_score['rouge1'])

        self.active_rouge2_logger.update(active_rouge_score['rouge2'])
        self.original_rouge2_logger.update(original_rouges_score['rouge2'])

        self.active_rougeL_logger.update(active_rouge_score['rougeL'])
        self.original_rougeL_logger.update(original_rouges_score['rougeL'])

        return self.get_all_summary()

    def get_all_summary(self):
        """Return every logger's summary, active before original per metric."""
        return [
            self.active_entity_logger.summary(), self.original_entity_logger.summary(),
            self.active_f1_logger.summary(), self.original_f1_logger.summary(),
            self.active_rouge1_logger.summary(), self.original_rouge1_logger.summary(),
            self.active_rouge2_logger.summary(), self.original_rouge2_logger.summary(),
            self.active_rougeL_logger.summary(), self.original_rougeL_logger.summary(),
            self.active_em_logger.summary(), self.original_em_logger.summary(),
            self.active_prompt_acc_logger.summary(), self.original_prompt_acc_logger.summary(),
        ]

    def get_em_summary(self):
        """Return the exact-match summaries as ``[active, original]``."""
        return [
            self.active_em_logger.summary(),
            self.original_em_logger.summary(),
        ]

    def get_original_summary(self):
        """Return the summaries of all "original" loggers."""
        return [
            self.original_entity_logger.summary(),
            self.original_f1_logger.summary(),
            self.original_rouge1_logger.summary(),
            self.original_rouge2_logger.summary(),
            self.original_rougeL_logger.summary(),
            self.original_em_logger.summary(),
            self.original_prompt_acc_logger.summary(),
        ]
        

class Metric_Logger():
    """Accumulates per-field score totals and their running averages."""

    def __init__(self, desc, field_count):
        # field_count is accepted for interface compatibility; it is not used.
        self.description = desc
        self.all_count = 0
        self.score: Dict[str, float] = {}
        self.avg: Dict[str, float] = {}
        # Timestamp taken at construction; summary() reports time since then.
        self.cur_time = time.time()

    def update(self, score: dict):
        """Add one sample's ``{field: value}`` scores and refresh the averages.

        Each average divides the field's running total by the number of
        update() calls made so far.
        """
        self.all_count += 1
        for name, value in score.items():
            if name in self.score:
                self.score[name] += value
            else:
                self.score[name] = value
            self.avg[name] = 1.0 * self.score[name] / self.all_count

    def summary(self):
        """Return description, sample count, averages and elapsed minutes."""
        elapsed_minutes = (time.time() - self.cur_time) / 60
        return {
            'description': self.description,
            'total_number': self.all_count,
            'avg_score': self.avg,
            'time_consuming': elapsed_minutes,
        }
    
class count():
    """Running total and mean of scalar scores."""

    def __init__(self):
        self.score = 0
        self.all_count = 0
        self.avg = 0
        # Timestamp taken at construction; summary() reports time since then.
        self.cur_time = time.time()

    def update(self, score):
        """Add one score and recompute the mean over all updates so far."""
        self.all_count += 1
        self.score += score
        # A zero total keeps the average at the integer 0.
        self.avg = 0 if self.score == 0 else 1.0 * self.score / self.all_count

    def summary(self):
        """Return the sample count, current mean and elapsed minutes."""
        elapsed_minutes = (time.time() - self.cur_time) / 60
        return {
            'total_number': self.all_count,
            'avg_score': self.avg,
            'time_consuming': elapsed_minutes,
        }