#!/usr/bin/env python3
"""
class for translation evaluation
"""

import os
import json
import logging
import sys
import re
import evaluate
from pathlib import Path
from typing import Dict, Any, Optional, List, Tuple, Union
import subprocess
import time

# configure logging
logger = logging.getLogger(__name__)

# common evaluation metrics
# COMMON_QA_METRICS = ['exact_match', 'f1']


class GenerationEvaluator:
    """
    Inference and evaluation pipeline.

    Features:
    1. inference: resolve the test dataset and build the vLLM inference command
    2. evaluation: compute the requested metrics (bleu and sacrebleu by default)
       with Hugging Face evaluate
    """
    
    def __init__(self, 
                 workspace_dir: str,
                 llamafactory_dir: str,
                 default_system_prompt: str = "You are a helpful assistant."):
        """
        initialize inference evaluation pipeline
        
        Args:
            workspace_dir: workspace directory
            llamafactory_dir: LLaMA-Factory directory
            default_system_prompt: default system prompt
        """
        self.workspace_dir = Path(workspace_dir)
        self.llamafactory_dir = Path(llamafactory_dir)
        self.default_system_prompt = default_system_prompt
        
    
    def load_template(self, template_file: str) -> Optional[Dict[str, str]]:
        """
        load template file (reference transfer2instruction_tuning.py)
        
        Args:
            template_file: template file path (.py file)
            
        Returns:
            template dictionary or None
        """
        if not os.path.exists(template_file):
            logger.error(f"模板文件不存在: {template_file}")
            return None
        
        # dynamically import template file
        sys.path.insert(0, os.path.dirname(os.path.abspath(template_file)))
        module_name = os.path.splitext(os.path.basename(template_file))[0]
        
        try:
            template_module = __import__(module_name)
            if hasattr(template_module, 'TEMPLATE'):
                logger.info(f"✅ template loaded successfully: {template_file}")
                return template_module.TEMPLATE
            else:
                logger.error(f"template file does not contain TEMPLATE variable")
                return None
        except Exception as e:
            logger.error(f"failed to load template file: {e}")
            return None
    
    def get_nested_value(self, obj: Dict[str, Any], path: str) -> str:
        """
        get nested object value, support dot separated path
        
        Args:
            obj: data object
            path: path string, e.g. "Question" or "Answer.Value"
            
        Returns:
            string value
        """
        try:
            keys = path.split('.')
            value = obj
            for key in keys:
                if isinstance(value, dict) and key in value:
                    value = value[key]
                elif isinstance(value, list) and key.isdigit():
                    value = value[int(key)]
                else:
                    return ""
            
            # if complex object, convert to string
            if isinstance(value, (list, dict)):
                return json.dumps(value, ensure_ascii=False)
            
            return str(value) if value is not None else ""
        except:
            return ""
    
    def replace_template_vars(self, template_str: str, sample: Dict[str, Any]) -> str:
        """
        replace variables in template string (reference transfer2instruction_tuning.py)
        
        Args:
            template_str: template string, e.g. "回答问题：{sample.Question}"
            sample: sample data
            
        Returns:
            replaced string
        """
        def replace_match(match):
            var_path = match.group(1)  # get sample.xxx.yyy part
            if var_path.startswith('sample.'):
                field_path = var_path[7:]  # remove 'sample.' prefix
                return self.get_nested_value(sample, field_path)
            return match.group(0)  # if not start with sample., keep original
        
        # find all {sample.xxx} patterns
        pattern = r'\{(sample\.[^}]+)\}'
        result = re.sub(pattern, replace_match, template_str)
        return result
    
    def apply_template(self, template: Dict[str, str], sample_data: Dict[str, Any]) -> Optional[Dict[str, str]]:
        """
        apply template to sample data
        
        Args:
            template: template dictionary
            sample_data: sample data
            
        Returns:
            transformed data or None
        """
        try:
            result = {}
            for key, template_str in template.items():
                if isinstance(template_str, str):
                    result[key] = self.replace_template_vars(template_str, sample_data)
                else:
                    result[key] = str(template_str)
            
            return result
        except Exception as e:
            logger.error(f"应用模板失败: {e}")
            return None
    
    def check_and_transform_with_template(self, test_json_file: str, template_file: Optional[str] = None) -> str:
        """
        use template to check and transform test JSON file content
        
        Args:
            test_json_file: test data JSON file path
            template_file: template file path (.py file), if None, use default supplement logic
            
        Returns:
            transformed JSON file path
        """
        logger.info(f"🔍 check and transform test data file: {test_json_file}")
        
        try:
            # read original data
            with open(test_json_file, 'r', encoding='utf-8') as f:
                test_data = json.load(f)
            
            # ensure data is list format
            if isinstance(test_data, dict):
                if 'Data' in test_data:
                    test_data = test_data['Data']
                else:
                    test_data = [test_data]
            
            logger.info(f"📊 original data: {len(test_data)} samples")
            
            # if template file is provided, use template to transform
            if template_file:
                template = self.load_template(template_file)
                if template:
                    return self._transform_with_template(test_data, template, test_json_file)
            
            # otherwise use default supplement logic
            return self._supplement_with_defaults(test_data, test_json_file)
            
        except Exception as e:
            logger.error(f"❌ data transformation failed: {e}")
            return test_json_file
    
    def _transform_with_template(self, test_data: List[Dict], template: Dict[str, str], original_file: str) -> str:
        """use template to transform data"""
        logger.info("🔄 use template to transform data...")
        
        converted_data = []
        for i, item in enumerate(test_data):
            try:
                result = self.apply_template(template, item)
                if result:
                    # ensure id field
                    if 'id' not in result:
                        result['id'] = f"sample_{i}"
                    converted_data.append(result)
                else:
                    logger.warning(f"skip sample {i}: template application failed")
            except Exception as e:
                logger.error(f"error processing sample {i}: {e}")
        
        # save transformed data
        original_path = Path(original_file)
        converted_file = original_path.parent / f"{original_path.stem}_template_converted{original_path.suffix}"
        
        with open(converted_file, 'w', encoding='utf-8') as f:
            json.dump(converted_data, f, indent=2, ensure_ascii=False)
        
        logger.info(f"✅ template transformation completed: {len(converted_data)} valid samples")
        logger.info(f"   - transformed file: {converted_file}")
        
        return str(converted_file)
    
    def _supplement_with_defaults(self, test_data: List[Dict], original_file: str) -> str:
        """use default logic to supplement data"""
        logger.info("🔄 use default logic to supplement data...")
        
        modifications_count = 0
        
        for i, item in enumerate(test_data):
            if not isinstance(item, dict):
                continue
            
            # check and supplement fields
            if 'system' not in item or not item['system']:
                item['system'] = self.default_system_prompt
                modifications_count += 1
            
            if 'instruction' not in item or not item['instruction']:
                item['instruction'] = "Please answer the following question."
                modifications_count += 1
            
            if 'input' not in item:
                item['input'] = ""
            
            if 'id' not in item:
                item['id'] = f"sample_{i}"
        
        # save supplemented data
        original_path = Path(original_file)
        supplemented_file = original_path.parent / f"{original_path.stem}_supplemented{original_path.suffix}"
        
        with open(supplemented_file, 'w', encoding='utf-8') as f:
            json.dump(test_data, f, indent=2, ensure_ascii=False)
        
        logger.info(f"✅ default supplement completed: {len(test_data)} samples, {modifications_count} modifications")
        logger.info(f"   - supplemented file: {supplemented_file}")
        
        return str(supplemented_file)
    
    def _resolve_test_file_path(self, test_input: str) -> str:
        """
        parse test file path: support direct file path or find through dataset_info.json
        
        Args:
            test_input: file path or dataset name
            
        Returns:
            actual file path
        """
        # if absolute path or relative path and file exists, return directly
        if os.path.exists(test_input):
            logger.info(f"✅ directly use file path: {test_input}")
            return test_input
        
        # try to find in dataset_info.json
        dataset_info_path = self.llamafactory_dir / "data" / "dataset_info.json"
        
        if not dataset_info_path.exists():
            logger.warning(f"⚠️ dataset_info.json not found: {dataset_info_path}")
            return test_input
        
        try:
            with open(dataset_info_path, 'r', encoding='utf-8') as f:
                dataset_info = json.load(f)
            
            # check if it is dataset name
            if test_input in dataset_info:
                dataset_config = dataset_info[test_input]
                
                # get file name
                if 'file_name' in dataset_config:
                    file_name = dataset_config['file_name']
                    # build full path
                    full_path = self.llamafactory_dir / "data" / file_name
                    
                    if full_path.exists():
                        logger.info(f"✅ find file from dataset_info: {test_input} -> {full_path}")
                        return str(full_path)
                    else:
                        logger.warning(f"⚠️ file not found in dataset_info: {full_path}")
                else:
                    logger.warning(f"⚠️ dataset {test_input} does not have file_name field")
            else:
                logger.info(f"📝 {test_input} not in dataset_info, try as file path")
        
        except Exception as e:
            logger.error(f"❌ failed to parse dataset_info.json: {e}")
        
        # if all above fail, try to find in data directory directly
        data_dir_path = self.llamafactory_dir / "data" / test_input
        if data_dir_path.exists():
            logger.info(f"✅ find file in data directory: {data_dir_path}")
            return str(data_dir_path)
        
        # finally return original input (maybe relative path)
        logger.warning(f"⚠️ failed to parse file path, use original input: {test_input}")
        return test_input
    

    def inference_and_evaluate(self, 
                               model_path: str, 
                               test_name: str, 
                               model_config: Dict[str, Any],
                               metrics: List[str] = ['bleu', 'sacrebleu', 'cross_entropy', 'perplexity'],
                               template_file: Optional[str] = None,
                               include_logits: bool = True) -> Dict[str, Any]:
        """
        integrate inference and evaluation: use vLLM to infer and calculate evaluation metrics directly
        
        Args:
            model_path: model path
            test_name: test dataset name
            model_config: model config dictionary
            metrics: metrics to calculate
            template_file: template file path (.py file)
            include_logits: whether to include logits for precise cross entropy calculation
            
        Returns:
            evaluation results dictionary
        """
        logger.info(f"🚀 start integrating inference and evaluation: {model_path}")
        
        try:
            # 1. execute vLLM inference
            predictions, references, logits_data = self._run_vllm_inference_with_logits(
                model_path, test_name, model_config, include_logits
            )
            
            logger.info(f"📊 inference completed, {len(predictions)} predictions")
            
            # 2. calculate evaluation metrics
            eval_results = self._compute_metrics_with_logits(
                predictions, references, logits_data, metrics
            )
            
            # 3. add metadata
            eval_results.update({
                'total_samples': len(predictions),
                'model_path': model_path,
                'test_dataset': test_name,
                'metrics_used': metrics,
                'timestamp': time.strftime("%Y-%m-%d %H:%M:%S"),
                'include_logits': include_logits
            })
            
            # 4. print result summary
            logger.info(f"✅ evaluation completed:")
            for metric in metrics:
                if metric in eval_results:
                    value = eval_results[metric]
                    if isinstance(value, (int, float)):
                        logger.info(f"   - {metric}: {value:.4f}")
                    else:
                        logger.info(f"   - {metric}: {value}")
            
            return eval_results
            
        except Exception as e:
            logger.error(f"❌ inference and evaluation failed: {e}")
            raise
    
    def inference(self, 
                  model_path: str, 
                  test_name: str, 
                  train_name: str,
                  model_config: Dict[str, Any],
                  output_file: str,
                  template_file: Optional[str] = None) -> Tuple[str, str]:
        """
        inference function: check and supplement test data and generate inference command
        
        Args:
            model_path: model path
            test_json_file: test data JSON file path
            model_config: model config dictionary
            template_file: template file path (.py file), if provided, use template to transform data format
            
        Returns:
            Tuple[inference command, output file path]
        """
        logger.info(f"🚀 start inference preparation: {model_path}")
        
        try:
            # 1. use template to check and transform test data format
            # test_json_file to find in dataset_info.json
            test_json_file = self._resolve_test_file_path(test_name)
            # processed_test_file = self.check_and_transform_with_template(test_json_file, template_file)
            
            # 2. determine output file path
            # model_name = 'baseline' if baseline else model_config.get('name', 'model')
            # output_file = self.inference_dir / f"{test_name}_{model_name}_inference.jsonl"
            
            # 3. build inference arguments
            inference_args = self._build_inference_args(
                model_path, test_name, train_name, model_config, str(output_file)
            )
            
            # 4. build full command
            inference_cmd = f"python scripts/vllm_infer.py {' '.join(inference_args)}"
            full_cmd = f"cd {self.llamafactory_dir} && {inference_cmd}"
            
            logger.info(f"📝 inference command generated:")
            logger.info(f"   - command: {full_cmd}")
            
            return full_cmd, str(output_file)
            
        except Exception as e:
            logger.error(f"❌ inference command generation failed: {e}")
            raise
    
    def _build_inference_args(self, 
                             model_path: str, 
                             test_file: str, 
                             train_file:str,
                             model_config: Dict[str, Any], 
                             output_file: str) -> List[str]:
        """
        build inference arguments list
        
        Args:
            model_path: model path
            test_file: test file path
            model_config: model config
            output_file: output file path
            
        Returns:
            inference arguments list
        """
        # basic inference arguments
        few_shot = model_config.get('few_shot',False)
        if few_shot:
            if model_config.get('finetuning_type') == 'lora':
                # LoRA model needs to specify base model and adapter separately
                inference_args = [
                    f"--model_name_or_path {model_config['base_model']}",
                    f"--adapter_name_or_path {model_path}",
                    f"--dataset {test_file}",
                    f"--train_dataset {train_file}",
                    f"--n_shot {model_config['n_shot']}",
                    f"--save_name {output_file}"
                ]
            else:
                # full-finetuned model directly use trained model path
                inference_args = [
                    f"--model_name_or_path {model_path}",
                    f"--dataset {test_file}",
                    f"--train_dataset {train_file}",
                    f"--n_shot {model_config['n_shot']}",
                    f"--save_name {output_file}"
                ]
        else:
            if model_config.get('finetuning_type') == 'lora':
                # LoRA model needs to specify base model and adapter separately
                inference_args = [
                    f"--model_name_or_path {model_config['base_model']}",
                    f"--adapter_name_or_path {model_path}",
                    f"--dataset {test_file}",
                    f"--save_name {output_file}"
                ]
            else:
                # full-finetuned model directly use trained model path
                inference_args = [
                    f"--model_name_or_path {model_path}",
                    f"--dataset {test_file}",
                    f"--save_name {output_file}"
                ]
        
        # add optional parameters
        optional_params = {
            'template': model_config.get('template'),
            'cutoff_len': model_config.get('cutoff_len', 4096),
            'batch_size': model_config.get('inference_batch_size', 64),
            'max_samples': model_config.get('max_samples'),
            'temperature': model_config.get('temperature', 0.1),
            'top_p': model_config.get('top_p', 0.9),
            'max_new_tokens': model_config.get('max_new_tokens', 512)
        }
        
        for param, value in optional_params.items():
            if value is not None:
                inference_args.append(f"--{param} {value}")
        
        return inference_args
    
    def evaluation(self, 
                   prediction_file: str, 
                   ground_truth_file: str,
                   metrics: List[str] = ['bleu', 'sacrebleu'],
                   output_dir: Optional[str] = None) -> Dict[str, Any]:
        """
        evaluation function: use Hugging Face evaluate to calculate specified metrics
        
        Args:
            prediction_file: prediction result file path (JSONL format)
            ground_truth_file: ground truth data name
            metrics: metrics to calculate, e.g. ['exact_match', 'f1', 'bleu', 'rouge']
            output_dir: result output directory
            
        Returns:
            evaluation results dictionary
        """
        # set default metrics
        
        logger.info(f"📊 start evaluation: {prediction_file}")
        logger.info(f"📏 specified metrics: {metrics}")
        ground_truth_path = self._resolve_test_file_path(ground_truth_file)
        try:
            # 1. load prediction results and ground truth
            predictions, references = self._load_prediction_and_ground_truth(
                prediction_file, ground_truth_path
            )
            
            logger.info(f"📋 data statistics:")
            logger.info(f"   - prediction samples: {len(predictions)}")
            logger.info(f"   - ground truth samples: {len(references)}")
            
            # 2. use Hugging Face evaluate to calculate specified metrics
            eval_results = self._compute_metrics_with_evaluate(predictions, references, metrics)
            
            # 3. add extra statistics
            eval_results.update({
                'total_samples': len(predictions),
                'prediction_file': prediction_file,
                'ground_truth_file': ground_truth_file,
                'metrics_used': metrics,
                'timestamp': time.strftime("%Y-%m-%d %H:%M:%S")
            })
            
            # 4. save results
            if output_dir:
                self._save_evaluation_results(eval_results, output_dir, prediction_file)
            
            # 5. print result summary
            logger.info(f"✅ evaluation completed:")
            for metric in metrics:
                if metric in eval_results:
                    value = eval_results[metric]
                    if isinstance(value, (int, float)):
                        logger.info(f"   - {metric}: {value:.4f}")
                    else:
                        logger.info(f"   - {metric}: {value}")
            
            return eval_results
            
        except Exception as e:
            logger.error(f"❌ evaluation failed: {e}")
            raise
    
    def _load_prediction_and_ground_truth(self, 
                                         prediction_file: str, 
                                         ground_truth_file: str) -> Tuple[List[str], List[str]]:
        """
        load prediction results and ground truth
        
        Args:
            prediction_file: prediction file path
            ground_truth_file: ground truth file path
            
        Returns:
            Tuple[prediction list, ground truth list]
        """
        # load prediction results (JSONL format)
        predictions = []
        with open(prediction_file, 'r', encoding='utf-8') as f:
            for line in f:
                if line.strip():
                    data = json.loads(line)
                    # support multiple prediction result formats
                    pred = data.get('predict', data.get('prediction', data.get('output', '')))
                    predictions.append(str(pred).strip())
        
        # load ground truth
        references = []
        if ground_truth_file.endswith('.jsonl'):
            # JSONL format
            with open(ground_truth_file, 'r', encoding='utf-8') as f:
                for line in f:
                    if line.strip():
                        data = json.loads(line)
                        ref = data.get('output', data.get('answer', data.get('ground_truth', '')))
                        references.append(str(ref).strip())
        else:
            # JSON format
            with open(ground_truth_file, 'r', encoding='utf-8') as f:
                data = json.load(f)
                if isinstance(data, list):
                    for item in data:
                        ref = item.get('output', item.get('answer', item.get('ground_truth', '')))
                        references.append(str(ref).strip())
                else:
                    # single sample
                    ref = data.get('output', data.get('answer', data.get('ground_truth', '')))
                    references.append(str(ref).strip())
        
        # ensure prediction and label number consistent
        min_len = min(len(predictions), len(references))
        if len(predictions) != len(references):
            logger.warning(f"prediction number ({len(predictions)}) and label number ({len(references)}) inconsistent, use first {min_len} samples")
            predictions = predictions[:min_len]
            references = references[:min_len]
        
        return predictions, references
    
    def _normalize_answer(self, s: str) -> str:
        """standardize answer text (from TriviaQA)"""
        import string
        import re
        
        def remove_articles(text):
            return re.sub(r'\b(a|an|the)\b', ' ', text)

        def white_space_fix(text):
            return ' '.join(text.split())

        def handle_punc(text):
            exclude = set(string.punctuation + "".join([u"'", u"'", u"´", u"`"]))
            return ''.join(ch if ch not in exclude else ' ' for ch in text)

        def lower(text):
            return text.lower()

        def replace_underscore(text):
            return text.replace('_', ' ')

        return white_space_fix(remove_articles(handle_punc(lower(replace_underscore(s))))).strip()
    
    def _exact_match_score(self, prediction: str, ground_truth: str) -> float:
        """calculate exact match score (from TriviaQA)"""
        return float(self._normalize_answer(prediction) == self._normalize_answer(ground_truth))
    
    def _f1_score(self, prediction: str, ground_truth: str) -> float:
        """calculate F1 score (from TriviaQA)"""
        from collections import Counter
        
        prediction_tokens = self._normalize_answer(prediction).split()
        ground_truth_tokens = self._normalize_answer(ground_truth).split()
        common = Counter(prediction_tokens) & Counter(ground_truth_tokens)
        num_same = sum(common.values())
        if num_same == 0:
            return 0.0
        precision = 1.0 * num_same / len(prediction_tokens)
        recall = 1.0 * num_same / len(ground_truth_tokens)
        f1 = (2 * precision * recall) / (precision + recall)
        return f1
    
    def _compute_metrics_with_evaluate(self, predictions: List[str], references: List[str], metrics: List[str]) -> Dict[str, Any]:
        """
        calculate specified metrics, prioritize TriviaQA method for exact_match and f1
        
        Args:
            predictions: prediction result list
            references: ground truth list
            metrics: metrics to calculate
            
        Returns:
            evaluation metrics dictionary
        """
        results = {}
        for metric_name in metrics:
            try:
                logger.info(f"🔄 calculate metric: {metric_name}")
                # other metrics use Hugging Face evaluate
                try:
                    metric = evaluate.load('metrics/'+metric_name)
                    metric_result = metric.compute(predictions=predictions, references=references)
                    results[metric_name] = metric_result
                    logger.info(f"✅ {metric_name} calculation completed")
                except Exception as e:
                    logger.error(f"❌ {metric_name} loading failed: {e}")
                    results[metric_name] = None
            except Exception as e:
                logger.error(f"❌ {metric_name} calculation failed: {e}")
                results[metric_name] = None
        
        return results

    def _run_vllm_inference_with_logits(self, 
                                         model_path: str, 
                                         test_name: str, 
                                         model_config: Dict[str, Any],
                                         include_logits: bool = True) -> Tuple[List[str], List[str], Optional[List]]:
        """
        use vLLM to infer and get logits
        
        Returns:
            Tuple[predictions, references, logits_data]
        """
        import sys
        import os
        import gc
        import time
        from tqdm import tqdm
        
        # add LLaMA-Factory path
        llamafactory_src = os.path.join(os.getcwd(), "DatasetResearch", "LLaMA-Factory", "src")
        if llamafactory_src not in sys.path:
            sys.path.insert(0, llamafactory_src)
        
        # import necessary modules
        try:
            from llamafactory.data import get_dataset, get_template_and_fix_tokenizer
            from llamafactory.extras.constants import IGNORE_INDEX
            from llamafactory.hparams import get_infer_args
            from llamafactory.model import load_tokenizer
            from transformers import Seq2SeqTrainingArguments
            
            # vLLM related imports
            from vllm import LLM, SamplingParams
            from vllm.lora.request import LoRARequest
            
        except ImportError as e:
            logger.error(f"import vLLM or LLaMA-Factory modules failed: {e}")
            raise
        
        try:
            # 1. set inference parameters
            model_args, data_args, _, generating_args = get_infer_args(
                dict(
                    model_name_or_path=model_path,
                    adapter_name_or_path=model_config.get('adapter_path'),
                    dataset=test_name,
                    dataset_dir="data",
                    template=model_config.get('template', 'default'),
                    cutoff_len=model_config.get('cutoff_len', 2048),
                    max_samples=model_config.get('max_samples'),
                    preprocessing_num_workers=16,
                    temperature=model_config.get('temperature', 0.95),
                    top_p=model_config.get('top_p', 0.7),
                    top_k=model_config.get('top_k', 50),
                    max_new_tokens=model_config.get('max_new_tokens', 1024),
                    repetition_penalty=model_config.get('repetition_penalty', 1.0),
                )
            )
            
            # 2. initialize tokenizer and template
            training_args = Seq2SeqTrainingArguments(output_dir="dummy_dir")
            tokenizer_module = load_tokenizer(model_args)
            tokenizer = tokenizer_module["tokenizer"]
            template_obj = get_template_and_fix_tokenizer(tokenizer, data_args)
            template_obj.mm_plugin.expand_mm_tokens = False
            
            # 3. initialize vLLM engine
            engine_args = {
                "model": model_args.model_name_or_path,
                "trust_remote_code": True,
                "dtype": model_args.infer_dtype,
                "max_model_len": model_config.get('cutoff_len', 2048) + model_config.get('max_new_tokens', 1024),
                "tensor_parallel_size": 1,
                "disable_log_stats": True,
                "enable_lora": model_args.adapter_name_or_path is not None,
            }
            
            logger.info(f"🔧 initialize vLLM engine: {model_path}")
            llm = LLM(**engine_args)
            
            # 4. load dataset
            logger.info(f"📂 load dataset: {test_name}")
            dataset_module = get_dataset(template_obj, model_args, data_args, training_args, "ppo", **tokenizer_module)
            train_dataset = dataset_module["train_dataset"]
            logger.info(f"📊 dataset size: {len(train_dataset)} samples")
            
            # 5. set sampling parameters (if need logits, use special settings)
            sampling_params = SamplingParams(
                repetition_penalty=generating_args.repetition_penalty or 1.0,
                temperature=generating_args.temperature,
                top_p=generating_args.top_p or 1.0,
                top_k=generating_args.top_k or -1,
                stop_token_ids=template_obj.get_stop_token_ids(tokenizer),
                max_tokens=generating_args.max_new_tokens,
                skip_special_tokens=True,
                # key: enable logprobs to get logits information
                logprobs=20 if include_logits else None,  # get top-20 logprobs
            )
            
            if include_logits:
                logger.info("🔍 enable logprobs collection, for precise cross entropy calculation")
            
            # 6. set LoRA (if needed)
            lora_request = None
            if model_args.adapter_name_or_path is not None:
                lora_request = LoRARequest("default", 1, model_args.adapter_name_or_path[0])
                logger.info(f"🔧 enable LoRA: {model_args.adapter_name_or_path[0]}")
            
            # 7. batch inference
            predictions, references, logits_data = [], [], []
            batch_size = model_config.get('batch_size', 32)
            
            logger.info(f"🚀 start batch inference, batch size: {batch_size}")
            start_time = time.time()
            
            for i in tqdm(range(0, len(train_dataset), batch_size), desc="vLLM batch inference"):
                vllm_inputs, batch_refs = [], []
                batch = train_dataset[i : min(i + batch_size, len(train_dataset))]
                
                for j in range(len(batch["input_ids"])):
                    # process multi-modal data (if any)
                    multi_modal_data = None
                    if batch["images"][j] is not None:
                        # image processing logic - can be extended as needed
                        pass
                    elif batch["videos"][j] is not None:
                        # video processing logic - can be extended as needed
                        pass
                    elif batch["audios"][j] is not None:
                        # audio processing logic - can be extended as needed
                        pass
                    
                    vllm_inputs.append({
                        "prompt_token_ids": batch["input_ids"][j], 
                        "multi_modal_data": multi_modal_data
                    })
                    
                    # extract reference answer
                    ref_text = tokenizer.decode(
                        list(filter(lambda x: x != IGNORE_INDEX, batch["labels"][j])),
                        skip_special_tokens=True
                    )
                    batch_refs.append(ref_text)
                
                # execute inference
                results = llm.generate(vllm_inputs, sampling_params, lora_request=lora_request)
                
                # extract results
                for result, ref in zip(results, batch_refs):
                    predictions.append(result.outputs[0].text)
                    references.append(ref)
                    
                    # extract logits information (if available)
                    if include_logits and hasattr(result.outputs[0], 'logprobs') and result.outputs[0].logprobs:
                        # extract token-level logprobs
                        token_logprobs = []
                        for token_logprob in result.outputs[0].logprobs:
                            if token_logprob:
                                # get most likely token's logprob
                                max_logprob = max(token_logprob.values())
                                token_logprobs.append(max_logprob)
                        logits_data.append(token_logprobs)
                    else:
                        logits_data.append([])
                
                gc.collect()
            
            end_time = time.time()
            logger.info(f"📊 vLLM inference completed: {len(predictions)} samples, time: {end_time - start_time:.2f} seconds")
            
            if include_logits:
                valid_logits = len([x for x in logits_data if x])
                logger.info(f"📊 get logits data: {valid_logits}/{len(logits_data)} valid samples")
            
            return predictions, references, logits_data if include_logits else None
            
        except Exception as e:
            logger.error(f"❌ vLLM inference failed: {e}")
            raise
    
    def _compute_metrics_with_logits(self, 
                                     predictions: List[str], 
                                     references: List[str], 
                                     logits_data: Optional[List], 
                                     metrics: List[str]) -> Dict[str, Any]:
        """
        calculate metrics using logits data, including precise cross entropy and perplexity
        """
        results = {}
        
        for metric_name in metrics:
            try:
                logger.info(f"🔄 calculate metric: {metric_name}")
                
                if metric_name == 'cross_entropy':
                    if logits_data and any(logits_data):
                        # use real logits to calculate cross entropy
                        ce_value = self._compute_cross_entropy_from_logits(logits_data, references)
                        results[metric_name] = ce_value
                        logger.info(f"✅ {metric_name} (use logits): {ce_value:.4f}")
                    else:
                        # alternative approximate method
                        ce_value = self._approximate_cross_entropy_from_text(predictions, references)
                        results[metric_name] = ce_value
                        logger.info(f"✅ {metric_name} (approximate method): {ce_value:.4f}")
                
                elif metric_name == 'perplexity':
                    if 'cross_entropy' in results:
                        try:
                            import math
                            ppl_value = math.exp(results['cross_entropy'])
                            results[metric_name] = ppl_value
                            logger.info(f"✅ {metric_name}: {ppl_value:.4f}")
                        except OverflowError:
                            results[metric_name] = float('inf')
                            logger.warning(f"⚠️ {metric_name} calculation overflow")
                    else:
                        # calculate perplexity separately
                        if logits_data and any(logits_data):
                            ce_value = self._compute_cross_entropy_from_logits(logits_data, references)
                        else:
                            ce_value = self._approximate_cross_entropy_from_text(predictions, references)
                        
                        try:
                            import math
                            ppl_value = math.exp(ce_value)
                            results[metric_name] = ppl_value
                            logger.info(f"✅ {metric_name}: {ppl_value:.4f}")
                        except OverflowError:
                            results[metric_name] = float('inf')
                            logger.warning(f"⚠️ {metric_name} calculation overflow")
                
                else:
                    # other metrics use Hugging Face evaluate
                    try:
                        import evaluate
                        metric = evaluate.load('metrics/'+metric_name)
                        metric_result = metric.compute(predictions=predictions, references=references)
                        
                        # 处理返回结果
                        if isinstance(metric_result, dict):
                            for key, value in metric_result.items():
                                if key == metric_name:
                                    results[key] = value
                                else:
                                    results[f"{metric_name}_{key}"] = value
                        else:
                            results[metric_name] = metric_result
                            
                        logger.info(f"✅ {metric_name} calculation completed")
                        
                    except Exception as eval_e:
                        logger.warning(f"⚠️ {metric_name} calculation failed: {eval_e}")
                        results[metric_name] = None
                        
            except Exception as e:
                logger.error(f"❌ {metric_name} calculation error: {e}")
                results[metric_name] = None
        
        return results
    
    def _compute_cross_entropy_from_logits(self, logits_data: List[List[float]], references: List[str]) -> float:
        """
        use real logits to calculate cross entropy
        """
        try:
            total_ce = 0.0
            total_tokens = 0
            
            for logprobs_seq in logits_data:
                if logprobs_seq:  # if there are logprobs data
                    for logprob in logprobs_seq:
                        total_ce += -logprob  # negative log likelihood
                        total_tokens += 1
            
            return total_ce / total_tokens if total_tokens > 0 else float('inf')
            
        except Exception as e:
            logger.warning(f"use logits to calculate cross entropy failed: {e}")
            return float('inf')
    
    def _approximate_cross_entropy_from_text(self, predictions: List[str], references: List[str]) -> float:
        """
        approximate method for cross entropy based on text similarity
        """
        try:
            from difflib import SequenceMatcher
            import math
            
            total_ce = 0.0
            for pred, ref in zip(predictions, references):
                similarity = SequenceMatcher(None, pred.lower(), ref.lower()).ratio()
                # convert similarity to probability, then calculate negative log
                prob = max(similarity, 0.001)  # avoid log(0)
                ce = -math.log(prob)
                total_ce += ce
            
            return total_ce / len(predictions) if predictions else float('inf')
            
        except Exception as e:
            logger.warning(f"approximate cross entropy calculation failed: {e}")
            return 5.0  # default value

    def _save_evaluation_results(self, 
                                results: Dict[str, Any], 
                                output_dir: str, 
                                prediction_file: str):
        """
        save evaluation results
        
        Args:
            results: evaluation results
            output_dir: output directory
            prediction_file: prediction file path (for generating file name)
        """
        output_path = Path(output_dir)
        output_path.mkdir(parents=True, exist_ok=True)
        
        # generate result file name
        pred_name = Path(prediction_file).stem
        result_file = output_path / f"{pred_name}_evaluation_results.json"
        
        # save results
        with open(result_file, 'w', encoding='utf-8') as f:
            json.dump(results, f, indent=2, ensure_ascii=False)
        
        logger.info(f"📁 evaluation results saved: {result_file}")
    



# example usage function
def example_usage():
    """demonstrate the inference and evaluation calls using placeholder paths"""

    # initialize pipeline
    pipeline = GenerationEvaluator(
        workspace_dir="/path/to/workspace",
        llamafactory_dir="/path/to/LLaMA-Factory"
    )

    # model configuration example
    model_config = dict(
        name='llama3_sft',
        base_model='meta-llama/Llama-3-8B-Instruct',
        finetuning_type='full',  # or 'lora'
        inference_batch_size=64,
        temperature=0.1,
        max_new_tokens=512,
    )

    # arguments shared by both inference calls
    inference_args = dict(
        model_path="/path/to/trained/model",
        test_json_file="/path/to/test.json",
        model_config=model_config,
    )

    # arguments shared by all evaluation calls
    evaluation_files = dict(
        prediction_file="/path/to/predictions.jsonl",
        ground_truth_file="/path/to/ground_truth.json",
    )

    # 1. convert the data with a template and generate the inference command
    cmd, output_file = pipeline.inference(
        template_file="evaluation/qa_template_example.py",
        **inference_args
    )
    print(f"inference command: {cmd}")

    # 2. no template: only the default supplement logic is applied
    cmd, output_file = pipeline.inference(template_file=None, **inference_args)
    print(f"inference command (default supplement): {cmd}")

    # 3. evaluation without an explicit metrics argument
    eval_results = pipeline.evaluation(**evaluation_files)
    print(f"evaluation results (default metrics): {eval_results}")

    # 4. evaluation with several metrics
    multi_metrics = ['exact_match', 'f1', 'bleu', 'rouge', 'meteor']
    eval_results_multi = pipeline.evaluation(metrics=multi_metrics, **evaluation_files)
    print(f"evaluation results (multiple metrics): {eval_results_multi}")

    # 5. evaluation restricted to a specific set of metrics
    specific_metrics = ['bertscore', 'rouge1', 'rouge2', 'rougeL']
    eval_results_specific = pipeline.evaluation(metrics=specific_metrics, **evaluation_files)
    print(f"evaluation results (specific metrics): {eval_results_specific}")



if __name__ == "__main__":
    # script entry point: set up INFO-level logging, then run the example
    log_settings = {
        'level': logging.INFO,
        'format': '%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    }
    logging.basicConfig(**log_settings)

    example_usage()