# -*- coding: utf-8 -*-  

import pandas as pd
import re
import json
from sklearn.metrics import precision_score, recall_score, f1_score, confusion_matrix

def extract_last_0_to_3(text):
    """Return the final character of *text* that is one of '0', '1', '2', '3'.

    Every single character in the range '0'-'3' counts as a match, including
    digits that are part of a longer number (e.g. "10" yields "0").

    Args:
        text: The string to scan.

    Returns:
        The last matching character as a one-character string, or None when
        *text* contains no such character.
    """
    hits = re.findall(r'[0-3]', text)
    if not hits:
        return None
    return hits[-1]

def load_data(filepath):
    """Load records from a JSONL (JSON Lines) file.

    Each non-blank line is parsed as one JSON value. Blank or whitespace-only
    lines (e.g. an extra empty line at the end of the file) are skipped
    rather than raising ``json.JSONDecodeError``.

    Args:
        filepath: Path to the ``.jsonl`` file.

    Returns:
        A list containing one parsed JSON value per non-blank line, in file
        order.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    # 'utf-8-sig' decodes plain UTF-8 exactly like 'utf8' but also strips a
    # leading byte-order mark, which would otherwise break json.loads on the
    # first line.
    with open(filepath, 'r', encoding="utf-8-sig") as f:
        return [json.loads(line) for line in f if line.strip()]

def evaluate_predictions(data, mislist):
    """Compare each prediction with its gold label and collect metric inputs.

    For each entry, two things are extracted from both the gold ``label`` text
    and the model ``predict`` text:

    * a score: the last character in '0'-'3', via ``extract_last_0_to_3``;
    * an error category: the first element of *mislist* that occurs as a
      substring of the text, or ``'NULL'`` if none does.

    Args:
        data: Iterable of dicts, each with ``'label'`` and ``'predict'`` text
            fields.
        mislist: Ordered candidate error categories. Order matters: when
            several categories occur in one text, the earliest wins.

    Returns:
        A tuple ``(allcnt, truescore_cnt, truemis_cnt, multitruelabels,
        multipredicts, truelabels, prelabels)``:

        * ``allcnt``: number of entries processed;
        * ``truescore_cnt``: entries whose predicted score equals the gold score;
        * ``truemis_cnt``: entries whose predicted category equals the gold one;
        * ``multitruelabels`` / ``multipredicts``: per-entry gold / predicted
          category strings;
        * ``truelabels`` / ``prelabels``: per-entry binary flags, 0 when the
          category is ``'NULL'`` and 1 otherwise.

    Raises:
        ValueError: If a gold label contains no character in '0'-'3'.
    """
    truescore_cnt, truemis_cnt = 0, 0
    multitruelabels, multipredicts = [], []
    truelabels, prelabels = [], []
    allcnt = 0

    for entry in data:
        allcnt += 1
        label = entry['label']
        pre_ans = entry['predict']

        # Check for a missing gold score before converting: int(None) would
        # raise an unhelpful TypeError, and an `assert` is stripped under -O.
        gold_raw = extract_last_0_to_3(label)
        if gold_raw is None:
            raise ValueError(
                f"No 0-3 score found in gold label of entry {allcnt}: {label!r}"
            )
        goldscore = int(gold_raw)

        goldmis = next((mis for mis in mislist if mis in label), 'NULL')

        # A prediction without a parsable score gets -1, which can never
        # equal a gold score (always 0-3), so it counts as wrong.
        pre_raw = extract_last_0_to_3(pre_ans)
        pre_score = int(pre_raw) if pre_raw is not None else -1

        if pre_score == goldscore:
            truescore_cnt += 1

        pre_mis = next((mis for mis in mislist if mis in pre_ans), 'NULL')

        if pre_mis == goldmis:
            truemis_cnt += 1

        multitruelabels.append(goldmis)
        multipredicts.append(pre_mis)

        # Binary view: "some error attributed" (1) vs. "no error" (0).
        truelabels.append(0 if goldmis == 'NULL' else 1)
        prelabels.append(0 if pre_mis == 'NULL' else 1)

    return allcnt, truescore_cnt, truemis_cnt, multitruelabels, multipredicts, truelabels, prelabels

def print_metrics(allcnt, truescore_cnt, truemis_cnt, multitruelabels, multipredicts, truelabels, prelabels):
    """Print score accuracy plus multi-class and binary error-attribution metrics.

    The arguments are the tuple returned by ``evaluate_predictions``:

    Args:
        allcnt: Number of evaluated entries.
        truescore_cnt: Entries whose predicted score matched the gold score.
        truemis_cnt: Entries whose predicted error category matched the gold one.
        multitruelabels: Gold error category per entry.
        multipredicts: Predicted error category per entry.
        truelabels: Gold binary flag per entry (0 = 'NULL', 1 = other category).
        prelabels: Predicted binary flag per entry, same encoding.
    """
    print(allcnt)
    if allcnt == 0:
        # Accuracies would divide by zero and sklearn rejects empty inputs.
        print("No samples to evaluate.")
        return

    print(f"Number of correct scores: {truescore_cnt}, Score accuracy: {truescore_cnt/allcnt:.4f}")
    print(f"Number of correct error attributions: {truemis_cnt}, Error attribution accuracy: {truemis_cnt/allcnt:.4f}")

    # Fix the class order explicitly so per-class scores and the confusion
    # matrix can be labeled. This is the same sorted union of gold and
    # predicted labels that sklearn uses by default, so values are unchanged.
    labels = sorted(set(multitruelabels) | set(multipredicts))

    # zero_division=0 yields the same values as sklearn's default ("warn"
    # -> 0) without emitting UndefinedMetricWarning for absent classes.
    f1_per_class = f1_score(multitruelabels, multipredicts, labels=labels, average=None, zero_division=0)
    print("F1 score per class:")
    for class_name, class_f1 in zip(labels, f1_per_class):
        print(f"  {class_name}: {class_f1:.4f}")

    f1_weighted = f1_score(multitruelabels, multipredicts, average='weighted', zero_division=0)
    print(f"Overall weighted average F1 score: {f1_weighted:.4f}")

    print("Gold labels:", set(multitruelabels))
    print("Predicted labels:", set(multipredicts))

    conf_matrix = confusion_matrix(multitruelabels, multipredicts, labels=labels)
    print("Confusion matrix label order (rows = gold, columns = predicted):", labels)
    print("Confusion Matrix:\n", conf_matrix)

    precision = precision_score(truelabels, prelabels, zero_division=0)
    recall = recall_score(truelabels, prelabels, zero_division=0)
    f1 = f1_score(truelabels, prelabels, zero_division=0)

    print(f'Precision for error attribution: {precision:.4f}')
    print(f'Recall for error attribution: {recall:.4f}')
    print(f'F1 Score for error attribution: {f1:.4f}')

def main(filepath=r"xxx\inference_results.jsonl"):
    """Load inference results, evaluate them, and print the metrics.

    Args:
        filepath: Path to the JSONL file of inference results. The default is
            a placeholder and must be replaced (or passed explicitly) to point
            at a real file.
    """
    # NOTE(review): 'NULL' is itself a candidate here, and matching takes the
    # first entry found in the text. A text containing "NULL" therefore maps
    # to 'NULL' even if it also names a category listed after it. Confirm
    # this ordering is intended.
    mislist = ['Response Quality - Duplicate', 'Response Quality - Refusal to Answer', 'Response Quality - Truncation', 
               'Response Quality - Missing Answers', 'Response Quality - Noisy', 'Response Quality - Typos', 
               'Multi-turn Dialogue - Reference Error', 'Multi-turn Dialogue - Long-term Memory Loss', 
               'Reasoning Capability - Process Error', 'Reasoning Capability - Result Error', 'Safety - Safety', 
               'Creative Ability - Inappropriate Content', 'NULL', 'Knowledge Ability - Incorrect Answers', 
               'Knowledge Ability - Hallucination', 'Comprehension - Irrelevance', 'Instruction Following - Content Inconsistency', 
               'Instruction Following - Length Inconsistency', 'Instruction Following - Format Inconsistency', 
               'Other Errors']

    data = load_data(filepath)
    allcnt, truescore_cnt, truemis_cnt, multitruelabels, multipredicts, truelabels, prelabels = evaluate_predictions(data, mislist)
    print_metrics(allcnt, truescore_cnt, truemis_cnt, multitruelabels, multipredicts, truelabels, prelabels)

if __name__ == "__main__":
    # Run the evaluation only when executed as a script, not when imported.
    main()