import json
import os
from typing import List, Dict, Any, Callable
import pandas as pd
import wandb
from tqdm import tqdm
from concurrent.futures import ThreadPoolExecutor, as_completed

# Import evaluation utility functions
from eval_utils import evaluate_numeric, evaluate_smiles, evaluate_string


class Evaluator:
    """Evaluator class, responsible only for evaluation-related functions. Recall is only calculated at the end, not recorded during intermediate steps."""
    
    def __init__(self,
                 wandb_project: str,
                 model_type: str,
                 model_name: str,
                 numeric_tolerance: float = 0.5,
                 max_workers: int = 16):
        """Store the evaluation configuration; no WandB run is started here.

        Args:
            wandb_project: Project name later handed to ``wandb.init``.
            model_type: Label used in the run name, result column names and reports.
            model_name: Model identifier used in the run name and run config.
            numeric_tolerance: Tolerance passed to ``evaluate_numeric`` for
                ``condition_prediction`` questions.
            max_workers: Size of the thread pool used by ``process_qa``.
        """
        # Run handle stays empty until init_wandb() is called.
        self.wandb_run = None

        # Per-sample recall records; appended while rows are processed and
        # aggregated only in calculate_metrics().
        self.recall_raw_data = []

        # Thread-pool size for process_qa().
        self.max_workers = max_workers
        self.numeric_tolerance = numeric_tolerance

        self.model_name = model_name
        self.model_type = model_type
        self.wandb_project = wandb_project
        
    def init_wandb(self, config: Dict[str, Any] = None) -> None:
        """Start a WandB run named ``<model_name>-<model_type>``.

        Args:
            config: Optional extra entries for the run config. The evaluator's
                own settings are written on top of it, so they win when a key
                appears in both.
        """
        # Copy so the caller's dict is never modified.
        run_config = dict(config) if config else {}
        run_config.update({
            "model_name": self.model_name,
            "model_type": self.model_type,
            "numeric_tolerance": self.numeric_tolerance,
            "max_workers": self.max_workers,  # thread-count setting, recorded for reference
        })

        self.wandb_run = wandb.init(
            project=self.wandb_project,
            name=f"{self.model_name}-{self.model_type}",
            config=run_config,
        )
    
    def evaluate_answer(self,
                       gt_answer: str,
                       pred_answer: str,
                       qa_type: str,
                       input_type: str) -> float:
        """Score one predicted answer against the ground truth.

        The scoring helper is chosen from the question type:

        * ``mass_prediction`` -> ``evaluate_numeric`` with a fixed tolerance of 1.0
        * ``condition_prediction`` -> ``evaluate_numeric`` with ``self.numeric_tolerance``
        * ``name_conversion`` with input type ``iupac`` -> ``evaluate_smiles``
        * anything else -> ``evaluate_string``

        Args:
            gt_answer: Ground-truth answer; converted to ``str`` and stripped.
            pred_answer: Predicted answer; converted to ``str`` and stripped.
            qa_type: Question type used for routing.
            input_type: Input type used for routing.

        Returns:
            Whatever the selected ``eval_utils`` helper returns.
        """
        gt_answer = str(gt_answer).strip()
        pred_answer = str(pred_answer).strip()
        # The rest of the class stores these two fields stripped; normalise
        # them here too so stray whitespace (e.g. 'mass_prediction ') cannot
        # silently route a numeric question to plain string matching.
        qa_type = str(qa_type).strip()
        input_type = str(input_type).strip()

        if qa_type == 'mass_prediction':
            return evaluate_numeric(gt_answer, pred_answer, tolerance=1.0)
        if qa_type == 'condition_prediction':
            return evaluate_numeric(gt_answer, pred_answer, tolerance=self.numeric_tolerance)
        if qa_type == 'name_conversion' and input_type == 'iupac':
            return evaluate_smiles(gt_answer, pred_answer)
        return evaluate_string(gt_answer, pred_answer)
    
    def _calculate_recall(self, row: pd.Series, retrieved_suffixes: List[str]) -> Dict[str, int]:
        """Calculate Recall@1 and Recall@5 metrics"""
        # Determine question association identifier (rxn_id takes priority, then mol_id)
        target_id = None
        if pd.notna(row['rxn_id']) and str(row['rxn_id']).strip():
            target_id = str(int(row['rxn_id'])).strip()
        elif pd.notna(row['mol_id']) and str(row['mol_id']).strip():
            target_id = str(int(row['mol_id'])).strip()
        
        if not target_id:
            return {"recall_1": None, "recall_5": None, "target_id": None}
            
        return {
            "target_id": target_id,
            "recall_1": 1 if target_id in retrieved_suffixes[:1] else 0,
            "recall_5": 1 if target_id in retrieved_suffixes[:5] else 0
        }
    
    def _process_single_row(self, row: pd.Series, prediction_function: Callable, include_recall: bool,** kwargs) -> Dict[str, Any]:
        """Internal function to process a single row of data, called by multiple threads"""
        try:
            # Call inference function to get prediction result
            pred_result = prediction_function(row, **kwargs)
            
            # Evaluate score
            score = self.evaluate_answer(
                gt_answer=row['answer'],
                pred_answer=pred_result["answer_short"],
                qa_type=row['qa_type'],
                input_type=row['input_type']
            )
            
            # Build result dictionary
            result = {
                'rxn_id': row['rxn_id'],
                'mol_id': row['mol_id'],
                'question': row['question'],
                'gt_answer': str(row['answer']).strip(),
                f'{self.model_type}_answer': pred_result["answer"],
                f'{self.model_type}_answer_short': str(pred_result["answer_short"]),
                'qa_type': row['qa_type'].strip(),
                'input_type': row['input_type'].strip(),
                'num_few_shot': pred_result.get('num_few_shot', 0),
                'few_shot_available': pred_result.get('few_shot_available', False),
                'score': score
            }
            
            # If needed, calculate and add Recall-related fields
            if include_recall:
                retrieved_suffixes = pred_result.get('retrieved_suffixes', '').split(',')
                recall_info = self._calculate_recall(row, retrieved_suffixes)
                
                # Save raw Recall data to list for final calculation
                if recall_info['recall_1'] is not None:
                    self.recall_raw_data.append({
                        'qa_type': row['qa_type'].strip(),
                        'input_type': row['input_type'].strip(),
                        'recall_1': recall_info['recall_1'],
                        'recall_5': recall_info['recall_5']
                    })
                
                result.update({
                    'target_id': recall_info['target_id'],
                    'retrieved_suffixes': pred_result.get('retrieved_suffixes'),
                    'recall_1': recall_info['recall_1'],
                    'recall_5': recall_info['recall_5'],
                    'retrieve_k': kwargs.get('retrieve_k'),
                    'retrieve_scoring': kwargs.get('retrieve_scoring')
                })
            
            # Log to WandB - only log Score-related data, not intermediate Recall values
            if self.wandb_run:
                qa_type = row['qa_type'].strip()
                input_type = row['input_type'].strip()
                log_data = {
                    "score": score,
                    f"score_qa_{qa_type}": score,
                    f"score_input_{input_type}": score,
                    f"score_{qa_type}_{input_type}": score
                }
                self.wandb_run.log(log_data)
            
            return result
                
        except Exception as e:
            error_msg = str(e)[:200]
            print(f"Error processing sample rxn_id={row['rxn_id']}, mol_id={row['mol_id']}: {error_msg}")
            result = {
                'rxn_id': row['rxn_id'],
                'mol_id': row['mol_id'],
                'question': row['question'],
                'gt_answer': str(row['answer']).strip(),
                f'{self.model_type}_answer': None,
                f'{self.model_type}_answer_short': None,
                'qa_type': row['qa_type'].strip(),
                'input_type': row['input_type'].strip(),
                'num_few_shot': kwargs.get('num_few_shot', 0),
                'few_shot_available': False,
                'score': None,
                'error': error_msg
            }
            
            if include_recall:
                result.update({
                    'target_id': None,
                    'retrieved_suffixes': None,
                    'recall_1': None,
                    'recall_5': None
                })
            
            return result
    
    def process_qa(self,
                  qa_df: pd.DataFrame,
                  prediction_function: Callable,
                  output_path: str = None,
                  include_recall: bool = False, **kwargs) -> pd.DataFrame:
        """
        Run prediction and scoring for every row of ``qa_df`` on a thread pool.

        Rows are processed concurrently, but the returned frame keeps the row
        order of ``qa_df`` so repeated runs produce comparable files.

        Args:
            qa_df: QA dataset
            prediction_function: Inference function that returns a dictionary containing the answer
            output_path: Result output path (tab-separated file); nothing is
                written when it is empty or None
            include_recall: Whether to calculate Recall metrics
            **kwargs: Additional arguments passed to prediction_function

        Returns:
            DataFrame with one result record per input row.
        """
        samples = [row for _, row in qa_df.iterrows()]
        results = [None] * len(samples)

        # Process using thread pool
        with ThreadPoolExecutor(max_workers=self.max_workers) as executor:
            # Remember each task's input position so completion order does not
            # decide the output order.
            future_positions = {
                executor.submit(
                    self._process_single_row,
                    row,
                    prediction_function,
                    include_recall,
                    **kwargs
                ): position
                for position, row in enumerate(samples)
            }

            # Monitor progress and collect results
            for future in tqdm(
                as_completed(future_positions),
                total=len(future_positions),
                desc=f"Processing {self.model_type} QA"
            ):
                results[future_positions[future]] = future.result()

        # Save results
        result_df = pd.DataFrame(results)
        if output_path:
            output_dir = os.path.dirname(output_path)
            # A bare filename has no directory part; os.makedirs('') would raise.
            if output_dir:
                os.makedirs(output_dir, exist_ok=True)
            result_df.to_csv(output_path, sep='\t', index=False, encoding='utf-8')
            print(f"Evaluation results saved to: {output_path}")

        return result_df
    
    # Metric aggregation, WandB logging and console reporting
    def calculate_metrics(self, result_df: pd.DataFrame, include_recall: bool = False) -> Dict[str, Any]:
        """Calculate evaluation metrics, Recall is only calculated here"""
        model_prefix = self.model_type
        valid_df = result_df[result_df['score'].notna()].copy()
        total_samples = len(result_df)
        valid_samples = len(valid_df)
        
        metrics = {
            'total_samples': total_samples,
            'valid_samples': valid_samples,
            'invalid_samples': total_samples - valid_samples,
            'valid_rate': round(valid_samples / total_samples * 100, 2) if total_samples > 0 else 0.0,
            # Score metrics
            'overall_average_score': round(valid_df['score'].mean(), 4) if valid_samples > 0 else None,
            'overall_score_std': round(valid_df['score'].std(), 4) if valid_samples > 0 else None,
            # Classification metrics
            'qa_type_averages': {},
            'input_type_averages': {},
            'combined_type_averages': {}
        }
        
        # If Recall metrics need to be calculated (mainly for RAG)
        if include_recall and self.recall_raw_data:
            # Convert to DataFrame for grouped calculation
            recall_df = pd.DataFrame(self.recall_raw_data)
            recall_valid_samples = len(recall_df)
            
            metrics.update({
                'recall_valid_samples': recall_valid_samples,
                'overall_recall_1': round(recall_df['recall_1'].mean(), 4) if recall_valid_samples > 0 else None,
                'overall_recall_5': round(recall_df['recall_5'].mean(), 4) if recall_valid_samples > 0 else None,
            })
        else:
            recall_df = pd.DataFrame()
        
        # Statistics by qa_type
        if 'qa_type' in valid_df.columns:
            score_type_avg = valid_df.groupby('qa_type').agg({
                'score': ['mean', 'count']
            }).round(4)
            score_type_avg.columns = ['average_score', 'sample_count']
            
            for qa_type in score_type_avg.index:
                type_data = {
                    'average_score': score_type_avg.loc[qa_type, 'average_score'],
                    'sample_count': score_type_avg.loc[qa_type, 'sample_count']
                }
                
                # If needed, add Recall metrics (calculated based on raw data)
                if include_recall and not recall_df.empty and qa_type in recall_df['qa_type'].values:
                    recall_type = recall_df[recall_df['qa_type'] == qa_type]
                    type_data.update({
                        'recall_1': round(recall_type['recall_1'].mean(), 4),
                        'recall_5': round(recall_type['recall_5'].mean(), 4)
                    })
                
                metrics['qa_type_averages'][qa_type] = type_data
        
        # Statistics by input_type
        if 'input_type' in valid_df.columns:
            score_input_avg = valid_df.groupby('input_type').agg({
                'score': ['mean', 'count']
            }).round(4)
            score_input_avg.columns = ['average_score', 'sample_count']
            
            for input_type in score_input_avg.index:
                input_data = {
                    'average_score': score_input_avg.loc[input_type, 'average_score'],
                    'sample_count': score_input_avg.loc[input_type, 'sample_count']
                }
                
                # If needed, add Recall metrics (calculated based on raw data)
                if include_recall and not recall_df.empty and input_type in recall_df['input_type'].values:
                    recall_input = recall_df[recall_df['input_type'] == input_type]
                    input_data.update({
                        'recall_1': round(recall_input['recall_1'].mean(), 4),
                        'recall_5': round(recall_input['recall_5'].mean(), 4)
                    })
                
                metrics['input_type_averages'][input_type] = input_data
        
        # Statistics by (qa_type, input_type) combination
        if 'qa_type' in valid_df.columns and 'input_type' in valid_df.columns:
            score_combined_avg = valid_df.groupby(['qa_type', 'input_type']).agg({
                'score': ['mean', 'count']
            }).round(4)
            score_combined_avg.columns = ['average_score', 'sample_count']
            
            for (qa_type, input_type) in score_combined_avg.index:
                combined_data = {
                    'average_score': score_combined_avg.loc[(qa_type, input_type), 'average_score'],
                    'sample_count': score_combined_avg.loc[(qa_type, input_type), 'sample_count']
                }
                
                # If needed, add Recall metrics (calculated based on raw data)
                if include_recall and not recall_df.empty:
                    mask = (recall_df['qa_type'] == qa_type) & (recall_df['input_type'] == input_type)
                    if mask.any():
                        recall_combined = recall_df[mask]
                        combined_data.update({
                            'recall_1': round(recall_combined['recall_1'].mean(), 4),
                            'recall_5': round(recall_combined['recall_5'].mean(), 4)
                        })
                
                metrics['combined_type_averages'][(qa_type, input_type)] = combined_data
        
        return metrics
    
    def log_metrics_to_wandb(self, metrics: Dict[str, Any], include_recall: bool = False) -> None:
        """Log metrics to WandB"""
        if not self.wandb_run:
            return
        
        # Log overall metrics
        self.wandb_run.log({
            "overall_average_score": metrics["overall_average_score"],
            "valid_rate": metrics["valid_rate"],
            "valid_samples": metrics["valid_samples"],
            "total_samples": metrics["total_samples"]
        })
        
        # If needed, log Recall metrics (final results only)
        if include_recall:
            self.wandb_run.log({
                "overall_recall_1": metrics["overall_recall_1"],
                "overall_recall_5": metrics["overall_recall_5"]
            })
        
        # Log metrics by qa_type classification
        for qa_type, data in metrics["qa_type_averages"].items():
            log_data = {
                f"average_score_qa_{qa_type}": data["average_score"],
                f"sample_count_qa_{qa_type}": data["sample_count"]
            }
            
            if include_recall and "recall_1" in data:
                log_data.update({
                    f"recall_1_qa_{qa_type}": data["recall_1"],
                    f"recall_5_qa_{qa_type}": data["recall_5"]
                })
            
            self.wandb_run.log(log_data)
        
        # Log metrics by input_type classification
        for input_type, data in metrics["input_type_averages"].items():
            log_data = {
                f"average_score_input_{input_type}": data["average_score"],
                f"sample_count_input_{input_type}": data["sample_count"]
            }
            
            if include_recall and "recall_1" in data:
                log_data.update({
                    f"recall_1_input_{input_type}": data["recall_1"],
                    f"recall_5_input_{input_type}": data["recall_5"]
                })
            
            self.wandb_run.log(log_data)
        
        # Log metrics by combined type classification
        for (qa_type, input_type), data in metrics["combined_type_averages"].items():
            log_data = {
                f"average_score_{qa_type}_{input_type}": data["average_score"],
                f"sample_count_{qa_type}_{input_type}": data["sample_count"]
            }
            
            if include_recall and "recall_1" in data:
                log_data.update({
                    f"recall_1_{qa_type}_{input_type}": data["recall_1"],
                    f"recall_5_{qa_type}_{input_type}": data["recall_5"]
                })
            
            self.wandb_run.log(log_data)
    
    def print_metrics(self, metrics: Dict[str, Any], include_recall: bool = False) -> None:
        """Print a formatted evaluation report to stdout.

        Args:
            metrics: Dict in the shape returned by ``calculate_metrics``.
            include_recall: Whether to print Recall@1 / Recall@5 lines. Recall
                entries that are absent from ``metrics`` (no sample had recall
                data) are reported as 0 samples / skipped instead of raising.
        """
        def _print_group(label: str, data: Dict[str, Any]) -> None:
            """Print one group's score line and, if available, its recall lines."""
            print(f"   - {label}:")
            print(f"     * Average score: {data['average_score']} (Sample count: {data['sample_count']})")
            if include_recall and "recall_1" in data:
                print(f"     * Recall@1: {data['recall_1']:.4f}")
                print(f"     * Recall@5: {data['recall_5']:.4f}")

        separator = "=" * 80
        print("\n" + separator)
        print(f"                 {self.model_type.upper()} EVALUATION REPORT")
        print(separator)

        # Basic statistics
        print("\n1. Basic Statistics:")
        print(f"   - Total samples: {metrics['total_samples']}")
        print(f"   - Valid scoring samples: {metrics['valid_samples']}")
        if include_recall:
            # The key is missing when no recall data was collected.
            print(f"   - Recall-calculable samples: {metrics.get('recall_valid_samples', 0)}")
        print(f"   - Valid scoring rate: {metrics['valid_rate']}%")

        # Overall metrics
        print("\n2. Overall Metrics:")
        if metrics['overall_average_score'] is not None:
            print(f"   - Average answer score: {metrics['overall_average_score']} (±{metrics['overall_score_std']})")
        if include_recall and metrics.get('overall_recall_1') is not None:
            print(f"   - Recall@1: {metrics['overall_recall_1']:.4f}")
            print(f"   - Recall@5: {metrics['overall_recall_5']:.4f}")

        # Classification by question type
        print("\n3. Classification by Question Type (qa_type):")
        for qa_type, data in metrics['qa_type_averages'].items():
            _print_group(f"{qa_type}", data)

        # Classification by input type
        print("\n4. Classification by Input Type (input_type):")
        for input_type, data in metrics['input_type_averages'].items():
            _print_group(f"{input_type}", data)

        # Classification by (question type, input type) combination
        print("\n5. Classification by (Question Type, Input Type) Combination:")
        for (qa_type, input_type), data in metrics['combined_type_averages'].items():
            _print_group(f"{qa_type} + {input_type}", data)

        print("\n" + separator)
    
    def finish_wandb(self) -> None:
        """Finish WandB run"""
        if self.wandb_run:
            self.wandb_run.finish()
