import numpy as np
import re
import os
import json
from sklearn.metrics import precision_score, recall_score, f1_score

class QAEvaluator:
    """Score model predictions stored in a JSON file.

    Each loaded prediction is a dict. Records whose ``type`` is
    ``'Basic verification'`` are scored by :meth:`evaluate_basic`. Records
    whose ``type`` is ``'Multi-Label'`` are scored by
    :meth:`evaluate_multi_choice`.
    """

    BASIC_TYPE = 'Basic verification'
    MULTI_TYPE = 'Multi-Label'
    # Matches ``*tag*`` spans in question text; the capture group is the tag.
    _STAR_SPAN = re.compile(r'\*(.*?)\*')

    def __init__(self, predictions_path):
        """Load predictions from ``predictions_path``.

        Args:
            predictions_path: Path to a UTF-8 encoded JSON file. The methods
                below iterate the loaded value and call ``.get`` on each
                item, so it is expected to be a list of dicts.
        """
        with open(predictions_path, 'r', encoding='utf-8') as f:
            self.predictions = json.load(f)

    # ------------------------------------------------------------------ helpers

    @staticmethod
    def _prf(tp, fp, fn):
        """Return ``(precision, recall, f1)`` from counts; any 0/0 becomes 0.0."""
        precision = tp / (tp + fp) if (tp + fp) > 0 else 0.0
        recall = tp / (tp + fn) if (tp + fn) > 0 else 0.0
        f1 = (2 * precision * recall / (precision + recall)
              if (precision + recall) > 0 else 0.0)
        return precision, recall, f1

    @staticmethod
    def _mean(values):
        """Return the mean of ``values``, or 0.0 if empty.

        A bare ``np.mean([])`` would return NaN and emit a RuntimeWarning.
        """
        return np.mean(values) if values else 0.0

    def _of_type(self, type_name):
        """Return the loaded predictions whose ``type`` field equals ``type_name``."""
        return [p for p in self.predictions if p.get('type') == type_name]

    def _normalize_tags(self, tags):
        """Return ``tags`` as a list of stripped, lower-cased strings.

        Falsy input (``None``, ``[]``, ``''``) yields ``[]``. A bare string is
        treated as a single tag instead of being split into characters.
        """
        if not tags:
            return []
        if isinstance(tags, str):
            tags = [tags]
        return [str(tag).strip().lower() for tag in tags]

    @staticmethod
    def _config_correct(pred):
        """Return 1 if ``predicted_config`` occurs in ``question``, else 0.

        The comparison is a case-insensitive substring test. A missing or
        blank config, or a missing question, scores 0. Without that guard an
        empty config would match every question, because ``''`` is a
        substring of any string.
        """
        question = pred.get('question')
        config = pred.get('predicted_config')
        if question is None or config is None:
            return 0
        needle = str(config).strip().lower()
        if not needle:
            return 0
        return int(needle in str(question).lower())

    # ------------------------------------------------------------------ metrics

    def evaluate_basic(self):
        """Score ``'Basic verification'`` predictions.

        Records lacking a ``ground_truth`` or ``predicted_tags`` key are
        skipped. Tags are normalised with :meth:`_normalize_tags` and counted
        element-wise, so duplicate tags are counted once per occurrence.

        Returns:
            dict with keys:

            * ``config_accuracy``: fraction of records scored 1 by
              :meth:`_config_correct`.
            * ``tag_precision`` / ``tag_recall``: macro averages of the
              per-record values.
            * ``tag_f1``: micro F1 over tag counts pooled across all records.
              It is not the harmonic mean of the two macro values above.
            * ``top3_accuracy``: fraction of records where every
              ground-truth tag is among the first three predicted tags.

            Every value is 0.0 when no record is scored.
        """
        results = []
        global_tp = global_fp = global_fn = 0

        for pred in self._of_type(self.BASIC_TYPE):
            if 'ground_truth' not in pred or 'predicted_tags' not in pred:
                continue

            gt_tags = self._normalize_tags(pred.get('ground_truth', []))
            pred_tags = self._normalize_tags(pred.get('predicted_tags', []))
            # Sets only speed up membership tests. Counting still walks the
            # lists, so duplicates contribute as they did before.
            gt_set = set(gt_tags)
            pred_set = set(pred_tags)

            tp = sum(1 for tag in gt_tags if tag in pred_set)
            fp = sum(1 for tag in pred_tags if tag not in gt_set)
            fn = sum(1 for tag in gt_tags if tag not in pred_set)

            global_tp += tp
            global_fp += fp
            global_fn += fn

            precision, recall, f1 = self._prf(tp, fp, fn)
            # Order matters here: only the first three predictions count.
            top3 = pred_tags[:3]

            results.append({
                'config_correct': self._config_correct(pred),
                'precision': precision,
                'recall': recall,
                'f1': f1,
                'top3_cover': int(all(tag in top3 for tag in gt_tags)),
            })

        _, _, micro_f1 = self._prf(global_tp, global_fp, global_fn)

        return {
            'config_accuracy': self._mean([r['config_correct'] for r in results]),
            'tag_precision': self._mean([r['precision'] for r in results]),
            'tag_recall': self._mean([r['recall'] for r in results]),
            'tag_f1': micro_f1,
            'top3_accuracy': self._mean([r['top3_cover'] for r in results]),
        }

    def evaluate_multi_choice(self):
        """Score ``'Multi-Label'`` predictions as set comparisons.

        ``ground_truth`` and ``predicted_config`` are each turned into a set
        of lower-cased strings. A missing or null field counts as no options.

        Returns:
            dict with keys:

            * ``exact_match_rate``: fraction of records whose option sets
              are identical.
            * ``precision`` / ``recall``: macro averages of the per-record
              values.
            * ``f1_score``: micro F1 over option counts pooled across records.

            Every value is 0.0 when there are no such records.
        """
        results = []
        global_tp = global_fp = global_fn = 0

        for pred in self._of_type(self.MULTI_TYPE):
            # ``or []`` also covers an explicit JSON null, which would
            # otherwise raise TypeError when iterated.
            gt_options = {str(opt).lower() for opt in (pred.get('ground_truth') or [])}
            pred_options = {str(opt).lower() for opt in (pred.get('predicted_config') or [])}

            tp = len(pred_options & gt_options)
            fp = len(pred_options - gt_options)
            fn = len(gt_options - pred_options)

            global_tp += tp
            global_fp += fp
            global_fn += fn

            precision, recall, f1 = self._prf(tp, fp, fn)
            results.append({
                'exact_match': int(pred_options == gt_options),
                'precision': precision,
                'recall': recall,
                'f1': f1,
            })

        _, _, micro_f1 = self._prf(global_tp, global_fp, global_fn)

        return {
            'exact_match_rate': self._mean([r['exact_match'] for r in results]),
            'precision': self._mean([r['precision'] for r in results]),
            'recall': self._mean([r['recall'] for r in results]),
            'f1_score': micro_f1,
        }

    def _check_supporting_tags(self, prediction):
        """Return 1 if every supporting tag appears as a ``*starred*`` span.

        The spans come from the question text. Both sides are stripped and
        lower-cased, so matching is case-insensitive on both. A missing
        question has no starred spans.
        """
        question = str(prediction.get('question') or '')
        question_tags = {
            tag.strip().lower() for tag in self._STAR_SPAN.findall(question)
        }
        predicted_tags = set(
            self._normalize_tags(prediction.get('supporting_tags', []))
        )
        return int(predicted_tags.issubset(question_tags))

    # ------------------------------------------------------------------ reports

    def generate_basic_report(self):
        """Return a summary plus basic metrics.

        ``basic_metrics`` is None when there are no basic records.
        """
        basic = self._of_type(self.BASIC_TYPE)
        report = {
            'summary': {
                'total_questions': len(self.predictions),
                'basic_count': len(basic),
            },
            'basic_metrics': self.evaluate_basic() if basic else None,
        }
        return report

    def generate_choice_report(self):
        """Return a summary plus multi-label metrics.

        The metrics fields are None when there are no multi-label records.
        """
        multi = self._of_type(self.MULTI_TYPE)
        multi_metrics = self.evaluate_multi_choice() if multi else None
        # Build the report.
        report = {
            'summary': {
                'total_questions': len(self.predictions),
                'multi_count': len(multi),
                'multi_f1': multi_metrics['f1_score'] if multi_metrics else None,
            },
            'detailed_metrics': {
                'multi': multi_metrics,
            },
        }

        return report