
import os
import sys
import json
import argparse
import numpy as np
from typing import Dict, List, Tuple, Optional
from dataclasses import dataclass, field, asdict
from collections import defaultdict
import re
from difflib import SequenceMatcher
from datetime import datetime

import torch
from torch.utils.data import Dataset, DataLoader
from PIL import Image
from tqdm import tqdm

from transformers import (
    AutoProcessor,
    Qwen2_5_VLForConditionalGeneration,
)
from peft import PeftModel


def normalize_text(text: str) -> str:
    """Canonicalize *text* for comparison.

    The string is lower-cased, trimmed at both ends, and every internal run of
    whitespace is collapsed into a single space.
    """
    return re.sub(r'\s+', ' ', text.lower().strip())


def exact_match(pred: str, target: str) -> bool:
    """Return True when *pred* and *target* are identical after normalize_text()."""
    left, right = (normalize_text(s) for s in (pred, target))
    return left == right


def fuzzy_match_score(pred: str, target: str) -> float:
    """Similarity ratio in [0, 1] between the normalized strings.

    Uses difflib.SequenceMatcher with no junk heuristic; 1.0 means identical.
    """
    matcher = SequenceMatcher(None, normalize_text(pred), normalize_text(target))
    return matcher.ratio()


def edit_distance(pred: str, target: str) -> int:
    """Character-level Levenshtein distance between the normalized strings.

    Insertions, deletions and substitutions each cost 1. Only two DP rows are
    kept in memory; the result matches the full-table formulation.
    """
    source = normalize_text(pred)
    dest = normalize_text(target)

    # previous[j] = distance between the processed prefix of `source` and dest[:j]
    previous = list(range(len(dest) + 1))
    for row, s_char in enumerate(source, start=1):
        current = [row] + [0] * len(dest)
        for col, d_char in enumerate(dest, start=1):
            if s_char == d_char:
                current[col] = previous[col - 1]
            else:
                current[col] = 1 + min(previous[col], current[col - 1], previous[col - 1])
        previous = current

    return previous[-1]


def compute_bleu(pred: str, target: str, max_n: int = 4) -> float:
    """Sentence-level BLEU of *pred* against a single reference *target*.

    Tokens are whitespace-split normalized words. Modified (clipped) n-gram
    precisions are computed for n = 1 .. min(max_n, number of predicted tokens).
    They are combined with an unweighted geometric mean, and each precision is
    offset by 1e-10 before the log. A brevity penalty applies when the
    prediction has fewer tokens than the reference.

    Returns 0.0 when either side is empty or every precision is zero.
    """
    hyp = normalize_text(pred).split()
    ref = normalize_text(target).split()
    if not hyp or not ref:
        return 0.0

    def count_ngrams(tokens, order):
        counts = defaultdict(int)
        for start in range(len(tokens) - order + 1):
            counts[tuple(tokens[start:start + order])] += 1
        return counts

    highest_order = min(max_n, len(hyp))
    precisions = []
    for order in range(1, highest_order + 1):
        hyp_counts = count_ngrams(hyp, order)
        ref_counts = count_ngrams(ref, order)
        # Clip each predicted n-gram count by its count in the reference.
        matched = sum(min(cnt, ref_counts.get(gram, 0)) for gram, cnt in hyp_counts.items())
        proposed = sum(hyp_counts.values())
        precisions.append(matched / proposed if proposed > 0 else 0.0)

    if not precisions or not any(precisions):
        return 0.0

    mean_log = sum(np.log(p + 1e-10) for p in precisions) / len(precisions)

    penalty = 1.0
    if len(hyp) < len(ref):
        penalty = np.exp(1 - len(ref) / len(hyp))

    return penalty * np.exp(mean_log)


def word_level_precision_recall(pred: str, target: str) -> Tuple[float, float, float]:
    """Set-overlap precision, recall and F1 over unique normalized words.

    Duplicate words are ignored. An empty side yields 0.0 for the metric that
    divides by it, and F1 is 0.0 when precision + recall is 0.
    """
    pred_vocab = set(normalize_text(pred).split())
    target_vocab = set(normalize_text(target).split())
    shared = len(pred_vocab & target_vocab)

    precision = shared / len(pred_vocab) if pred_vocab else 0.0
    recall = shared / len(target_vocab) if target_vocab else 0.0

    denom = precision + recall
    f1 = 2 * precision * recall / denom if denom > 0 else 0.0

    return precision, recall, f1


def length_ratio(pred: str, target: str) -> float:
    """Ratio of predicted to target word counts (after normalization).

    With an empty target the ratio is 1.0 if the prediction is also empty,
    otherwise +inf.
    """
    n_pred = len(normalize_text(pred).split())
    n_target = len(normalize_text(target).split())

    if n_target:
        return n_pred / n_target
    return float('inf') if n_pred else 1.0


def char_length_ratio(pred: str, target: str) -> float:
    """Ratio of predicted to target character counts (after normalization).

    With an empty target the ratio is 1.0 if the prediction is also empty,
    otherwise +inf.
    """
    n_pred = len(normalize_text(pred))
    n_target = len(normalize_text(target))

    if n_target:
        return n_pred / n_target
    return float('inf') if n_pred else 1.0


@dataclass
class EvalMetrics:
    """Per-sample evaluation record.

    Holds the sample identifiers, the raw texts, and every score computed by
    compute_sample_metrics for one prediction/ground-truth pair.
    """
    # Identifiers: running sample number, index into the annotation list, and
    # observation index within the task.
    sample_id: int
    task_idx: int
    img_idx: int
    
    # Raw (un-normalized) texts.
    prediction: str
    ground_truth: str
    full_instruction: str
    
    # String-similarity scores; each helper compares normalize_text() output.
    exact_match: bool
    fuzzy_score: float  # difflib SequenceMatcher ratio, in [0, 1]
    edit_distance: int  # character-level Levenshtein distance
    bleu_score: float
    
    # Overlap of the sets of unique words.
    word_precision: float
    word_recall: float
    word_f1: float
    
    # Word counts and prediction/ground-truth word ratio (inf when the ground
    # truth is empty but the prediction is not; 1.0 when both are empty).
    pred_word_count: int
    gt_word_count: int
    word_length_ratio: float
    
    # Character counts of the normalized texts and their ratio (same edge cases).
    pred_char_count: int
    gt_char_count: int
    char_length_ratio: float
    
    # True when the prediction is correct (exact match or fuzzy >= threshold)
    # AND has fewer words than the ground truth.
    is_valid_shorter: bool


@dataclass
class AggregatedMetrics:
    """Dataset-level summary that aggregate_metrics builds from a list of EvalMetrics.

    Rates and distribution values are fractions in [0, 1]. Averages are
    unweighted means over all samples.
    """
    total_samples: int
    
    # Correctness: fraction of exact matches and mean similarity scores.
    exact_match_rate: float
    avg_fuzzy_score: float
    avg_edit_distance: float
    avg_bleu_score: float
    
    # Word-level set overlap, averaged per sample.
    avg_word_precision: float
    avg_word_recall: float
    avg_word_f1: float
    avg_match: float  # fraction of samples whose word_precision > 0.8
    
    # Mean prediction/ground-truth length ratios.
    # NOTE(review): one sample with an empty ground truth and a non-empty
    # prediction has ratio inf, which turns these means into inf.
    avg_word_length_ratio: float
    avg_char_length_ratio: float
    
    valid_shorter_rate: float  # fraction of all samples with is_valid_shorter set
    shorter_when_correct_rate: float  # among correct samples, fraction with fewer words than the ground truth (0.0 if none are correct)
    
    # Histograms mapping a bucket label to the fraction of samples in it.
    fuzzy_score_distribution: Dict[str, float] = field(default_factory=dict)
    length_ratio_distribution: Dict[str, float] = field(default_factory=dict)


def compute_sample_metrics(
    sample_id: int,
    task_idx: int,
    img_idx: int,
    prediction: str,
    ground_truth: str,
    full_instruction: str,
    fuzzy_threshold: float = 0.9
) -> EvalMetrics:
    """Score one prediction against its ground-truth segment.

    A prediction counts as *correct* when it matches exactly or its fuzzy
    score reaches ``fuzzy_threshold``. ``is_valid_shorter`` is set for correct
    predictions that use fewer words than the ground truth.
    """
    pred_norm = normalize_text(prediction)
    gt_norm = normalize_text(ground_truth)

    is_exact = exact_match(prediction, ground_truth)
    fuzzy = fuzzy_match_score(prediction, ground_truth)
    word_ratio = length_ratio(prediction, ground_truth)
    precision, recall, f1 = word_level_precision_recall(prediction, ground_truth)

    correct = is_exact or fuzzy >= fuzzy_threshold

    return EvalMetrics(
        sample_id=sample_id,
        task_idx=task_idx,
        img_idx=img_idx,
        prediction=prediction,
        ground_truth=ground_truth,
        full_instruction=full_instruction,
        exact_match=is_exact,
        fuzzy_score=fuzzy,
        edit_distance=edit_distance(prediction, ground_truth),
        bleu_score=compute_bleu(prediction, ground_truth),
        word_precision=precision,
        word_recall=recall,
        word_f1=f1,
        pred_word_count=len(pred_norm.split()),
        gt_word_count=len(gt_norm.split()),
        word_length_ratio=word_ratio,
        pred_char_count=len(pred_norm),
        gt_char_count=len(gt_norm),
        char_length_ratio=char_length_ratio(prediction, ground_truth),
        is_valid_shorter=correct and word_ratio < 1.0,
    )


def aggregate_metrics(
    sample_metrics: List[EvalMetrics],
    fuzzy_threshold: float = 0.9
) -> AggregatedMetrics:
    """Reduce per-sample metrics to dataset-level means, rates and histograms.

    A sample is "correct" when it is an exact match or its fuzzy score is at
    least ``fuzzy_threshold``. With no samples, every field is 0.0 and the
    distributions are empty.
    """
    total = len(sample_metrics)
    if not total:
        return AggregatedMetrics(
            total_samples=0,
            exact_match_rate=0.0,
            avg_fuzzy_score=0.0,
            avg_edit_distance=0.0,
            avg_bleu_score=0.0,
            avg_word_precision=0.0,
            avg_word_recall=0.0,
            avg_word_f1=0.0,
            avg_match=0.0,
            avg_word_length_ratio=0.0,
            avg_char_length_ratio=0.0,
            valid_shorter_rate=0.0,
            shorter_when_correct_rate=0.0
        )

    def mean_of(attr):
        # Plain left-to-right sum divided by the sample count.
        return sum(getattr(m, attr) for m in sample_metrics) / total

    def share(predicate):
        # Fraction of samples for which `predicate` is truthy.
        return sum(1 for m in sample_metrics if predicate(m)) / total

    correct = [m for m in sample_metrics if m.exact_match or m.fuzzy_score >= fuzzy_threshold]
    shorter_correct = sum(1 for m in correct if m.word_length_ratio < 1.0)
    shorter_when_correct_rate = shorter_correct / len(correct) if correct else 0.0

    fuzzy_histogram = {
        ">=0.9": share(lambda m: m.fuzzy_score >= 0.9),
        "0.7-0.9": share(lambda m: 0.7 <= m.fuzzy_score < 0.9),
        "0.5-0.7": share(lambda m: 0.5 <= m.fuzzy_score < 0.7),
        "<0.5": share(lambda m: m.fuzzy_score < 0.5),
    }

    length_histogram = {
        "<0.8": share(lambda m: m.word_length_ratio < 0.8),
        "0.8-1.0": share(lambda m: 0.8 <= m.word_length_ratio < 1.0),
        "1.0-1.2": share(lambda m: 1.0 <= m.word_length_ratio < 1.2),
        ">=1.2": share(lambda m: m.word_length_ratio >= 1.2),
    }

    return AggregatedMetrics(
        total_samples=total,
        exact_match_rate=share(lambda m: m.exact_match),
        avg_fuzzy_score=mean_of("fuzzy_score"),
        avg_edit_distance=mean_of("edit_distance"),
        avg_bleu_score=mean_of("bleu_score"),
        avg_word_precision=mean_of("word_precision"),
        avg_word_recall=mean_of("word_recall"),
        avg_word_f1=mean_of("word_f1"),
        avg_match=share(lambda m: m.word_precision > 0.8),
        avg_word_length_ratio=mean_of("word_length_ratio"),
        avg_char_length_ratio=mean_of("char_length_ratio"),
        valid_shorter_rate=share(lambda m: m.is_valid_shorter),
        shorter_when_correct_rate=shorter_when_correct_rate,
        fuzzy_score_distribution=fuzzy_histogram,
        length_ratio_distribution=length_histogram,
    )



class Qwen2VLEvalDataset(Dataset):
    """Evaluation samples pairing a navigation instruction with image observations.

    ``anno_path`` is a JSON file holding a list of tasks. Each task is read for
    ``full_instruction``, ``segments``, ``scan`` and ``task_id``. Observations
    for a task are looked up in ``<image_folder>/<scan>_<task_id>/`` (or
    ``<image_folder>/<task_id>/`` when ``scan`` is empty). They are named
    ``01.<ext>``, ``02.<ext>``, ..., where file ``k`` (1-based) belongs to
    segment ``k - 1``.

    * Single-turn mode: one sample per (task, segment) pair. A missing image
      raises FileNotFoundError.
    * Multi-turn mode: one sample per task, holding every observation that
      exists on disk, in segment order. Missing images are skipped.

    ``processor`` and ``max_seq_length`` are stored but not used by this class.
    """

    # Extensions probed, in order, when looking for an observation image.
    _IMAGE_EXTENSIONS = ('.jpg', '.png', '.jpeg', '.JPG', '.PNG')

    def __init__(
        self,
        anno_path: str,
        image_folder: str,
        processor,
        multi_turn: bool = False,
        max_seq_length: int = 2048,
    ):
        self.processor = processor
        self.image_folder = image_folder
        self.multi_turn = multi_turn
        self.max_seq_length = max_seq_length

        with open(anno_path, 'r', encoding='utf-8') as f:
            self.anno_data = json.load(f)

        if multi_turn:
            self.samples = [
                {'task_idx': task_idx, 'mode': 'multi_turn'}
                for task_idx in range(len(self.anno_data))
            ]
        else:
            self.samples = [
                {'task_idx': task_idx, 'img_idx': img_idx, 'mode': 'single_turn'}
                for task_idx, task in enumerate(self.anno_data)
                for img_idx in range(len(task.get('segments', [])))
            ]

        self.system_message = (
            "You are a vision-language navigation assistant. "
            "Given a complete navigation instruction and an image observation, "
            "identify which part of the instruction corresponds to what is shown in the image."
        )

    def __len__(self):
        return len(self.samples)

    def _get_task_folder(self, task):
        """Return the image directory for ``task`` (``<scan>_<task_id>`` or ``<task_id>``)."""
        scan = task.get('scan', '')
        task_id = task.get('task_id', '')
        folder_name = f"{scan}_{task_id}" if scan else str(task_id)
        return os.path.join(self.image_folder, folder_name)

    def _find_image(self, task_folder: str, img_idx: int) -> Optional[str]:
        """Return the path of observation ``img_idx`` (file ``{img_idx + 1:02d}.<ext>``), or None."""
        for ext in self._IMAGE_EXTENSIONS:
            path = os.path.join(task_folder, f"{img_idx + 1:02d}{ext}")
            if os.path.exists(path):
                return path
        return None

    @staticmethod
    def _load_image(image_path: str):
        """Decode ``image_path`` as an RGB PIL image and close the file handle right away."""
        with Image.open(image_path) as img:
            return img.convert('RGB')

    def __getitem__(self, idx) -> Dict:
        """Build the chat messages, images and ground truth for sample ``idx``.

        Single-turn items carry ``messages``, ``images``, ``ground_truth`` and
        ``img_idx``. Multi-turn items carry one user message per available
        observation in ``messages_list``, plus the parallel lists ``images``,
        ``ground_truths`` and ``img_indices`` (the segment index of each turn),
        and ``img_idx`` = -1.
        """
        sample = self.samples[idx]
        task = self.anno_data[sample['task_idx']]
        instruction = task.get('full_instruction', '')
        segments = task.get('segments', [])
        task_folder = self._get_task_folder(task)

        if sample['mode'] == 'single_turn':
            img_idx = sample['img_idx']
            segment = segments[img_idx] if img_idx < len(segments) else instruction

            image_path = self._find_image(task_folder, img_idx)
            if image_path is None:
                raise FileNotFoundError(f"Image not found: task {sample['task_idx']}, img {img_idx}")

            image = self._load_image(image_path)
            images = [image]

            messages = [
                {
                    "role": "system",
                    "content": [{"type": "text", "text": self.system_message}]
                },
                {
                    "role": "user",
                    "content": [
                        {"type": "image", "image": image},
                        {"type": "text", "text": (
                            f"Complete instruction: {instruction}\n\n"
                            "Based on this image observation, which part of the instruction is most relevant? "
                            "Please output the corresponding instruction segment."
                        )}
                    ]
                }
            ]

            return {
                'task_idx': sample['task_idx'],
                'img_idx': img_idx,
                'messages': messages,
                'images': images,
                'ground_truth': segment,
                'full_instruction': instruction
            }

        messages_list = []
        images = []
        ground_truths = []
        img_indices = []

        for img_idx, segment in enumerate(segments):
            image_path = self._find_image(task_folder, img_idx)
            if image_path is None:
                continue

            image = self._load_image(image_path)
            images.append(image)
            ground_truths.append(segment)
            img_indices.append(img_idx)

            # The complete instruction must go into the first turn that is
            # actually emitted, which is not image 0 when that image is missing.
            if not messages_list:
                user_text = (
                    f"Complete instruction: {instruction}\n\n"
                    "Here is the first observation. Which part corresponds to this?"
                )
            else:
                user_text = "Here is the next observation. Which part corresponds to this?"

            messages_list.append({
                "role": "user",
                "content": [
                    {"type": "image", "image": image},
                    {"type": "text", "text": user_text}
                ]
            })

        return {
            'task_idx': sample['task_idx'],
            'img_idx': -1,  # multi-turn item; per-turn indices are in 'img_indices'
            'messages_list': messages_list,
            'images': images,
            'ground_truths': ground_truths,
            'img_indices': img_indices,
            'full_instruction': instruction
        }



class Qwen2VLInstructionEvaluator:
    """Generate instruction-segment predictions with a Qwen2.5-VL model and score them.

    Decoding is greedy (``do_sample=False``). Each prediction is scored with
    ``compute_sample_metrics`` using ``fuzzy_threshold``.
    """

    # Chat-template control tokens; removed wherever they appear.
    _SPECIAL_TOKENS = ("<|im_start|>", "<|im_end|>")
    # Role markers the model may echo. They are removed only at the start or end
    # of the text, so ordinary words inside an instruction are never altered.
    _ROLE_MARKERS = ("assistant\n", "assistant:", "Assistant\n", "Assistant:")

    def __init__(
        self,
        model,
        processor,
        device: str = "cuda",
        max_new_tokens: int = 256,
        fuzzy_threshold: float = 0.9
    ):
        self.model = model
        self.processor = processor
        self.device = device
        self.max_new_tokens = max_new_tokens
        self.fuzzy_threshold = fuzzy_threshold

    def generate_prediction(
        self,
        messages: List[Dict],
        images: List[Image.Image]
    ) -> str:
        """Greedily decode the model's reply to ``messages``.

        Args:
            messages: Chat messages in the processor's chat-template format.
            images: The images referenced by ``messages``, in order.

        Returns:
            The decoded reply after ``_clean_prediction``, with surrounding
            whitespace removed.
        """
        self.model.eval()

        with torch.no_grad():
            text = self.processor.apply_chat_template(
                messages,
                tokenize=False,
                add_generation_prompt=True
            )

            inputs = self.processor(
                text=[text],
                images=images,
                return_tensors="pt",
                padding=True,
            )

            inputs = {k: v.to(self.device) if isinstance(v, torch.Tensor) else v
                      for k, v in inputs.items()}

            outputs = self.model.generate(
                **inputs,
                max_new_tokens=self.max_new_tokens,
                do_sample=False,
                use_cache=True,
                pad_token_id=self.processor.tokenizer.pad_token_id,
            )

            # The generated sequence starts with the prompt tokens; keep only the continuation.
            input_len = inputs["input_ids"].shape[1]
            generated_ids = outputs[0][input_len:]
            prediction = self.processor.tokenizer.decode(generated_ids, skip_special_tokens=True)

            prediction = self._clean_prediction(prediction)

        return prediction.strip()

    def _clean_prediction(self, prediction: str) -> str:
        """Remove chat-template residue from a decoded reply.

        Special tokens are deleted everywhere. Role markers such as
        "assistant:" are peeled off the start and end of the text, repeatedly
        when stacked. They are left alone in the middle, because there they are
        likely genuine content.
        """
        result = prediction
        for token in self._SPECIAL_TOKENS:
            result = result.replace(token, "")

        changed = True
        while changed:  # terminates: every iteration that sets `changed` shortens `result`
            changed = False
            for marker in self._ROLE_MARKERS:
                if result.startswith(marker):
                    result = result[len(marker):]
                    changed = True
                if result.endswith(marker):
                    result = result[:-len(marker)]
                    changed = True

        return result.strip()

    def evaluate(
        self,
        dataset: Qwen2VLEvalDataset,
        num_samples: Optional[int] = None
    ) -> Tuple[List[EvalMetrics], AggregatedMetrics]:
        """Run generation and scoring over ``dataset``.

        Args:
            dataset: Single-turn items (key ``messages``) are scored once each.
                Multi-turn items (key ``messages_list``) are replayed turn by
                turn, and the model's own previous answers are fed back as
                assistant turns.
            num_samples: If given, only the first ``num_samples`` items are
                evaluated (same semantics as list slicing). The dataset itself
                is not modified.

        Returns:
            A tuple of (per-sample metrics, aggregated metrics).
        """
        # Slicing a range mirrors list slicing exactly (None / negative values included),
        # so the caller's dataset.samples is left untouched.
        indices = range(len(dataset))[:num_samples]

        sample_metrics = []
        sample_id = 0

        for idx in tqdm(indices, desc="Evaluating"):
            data = dataset[idx]
            task_idx = data['task_idx']
            full_instruction = data['full_instruction']

            if 'messages' in data:
                prediction = self.generate_prediction(
                    messages=data['messages'],
                    images=data['images']
                )

                sample_metrics.append(compute_sample_metrics(
                    sample_id=sample_id,
                    task_idx=task_idx,
                    img_idx=data['img_idx'],
                    prediction=prediction,
                    ground_truth=data['ground_truth'],
                    full_instruction=full_instruction,
                    fuzzy_threshold=self.fuzzy_threshold
                ))
                sample_id += 1
                continue

            messages_list = data['messages_list']
            ground_truths = data['ground_truths']
            images = data['images']
            # Real segment index of each turn; turns can skip missing images.
            # Fall back to the turn number for datasets that do not report it.
            img_indices = data.get('img_indices', list(range(len(ground_truths))))

            conversation = [
                {
                    "role": "system",
                    "content": [{"type": "text", "text": dataset.system_message}]
                }
            ]

            for turn_idx, (msg, gt) in enumerate(zip(messages_list, ground_truths)):
                conversation.append(msg)

                # The conversation so far references turn_idx + 1 images.
                prediction = self.generate_prediction(
                    messages=conversation,
                    images=images[:turn_idx + 1]
                )

                sample_metrics.append(compute_sample_metrics(
                    sample_id=sample_id,
                    task_idx=task_idx,
                    img_idx=img_indices[turn_idx],
                    prediction=prediction,
                    ground_truth=gt,
                    full_instruction=full_instruction,
                    fuzzy_threshold=self.fuzzy_threshold
                ))
                sample_id += 1

                conversation.append({
                    "role": "assistant",
                    "content": [{"type": "text", "text": prediction}]
                })

        aggregated_metrics = aggregate_metrics(sample_metrics, self.fuzzy_threshold)

        return sample_metrics, aggregated_metrics


def evaluate_from_predictions(
    predictions_file: str,
    anno_path: str,
    fuzzy_threshold: float = 0.9
) -> Tuple[List[EvalMetrics], AggregatedMetrics]:
    """Score previously generated predictions against the annotation file.

    Args:
        predictions_file: JSON file holding either a list of prediction records,
            or a results file written by ``save_results`` (a dict whose
            ``"sample_metrics"`` entry is that list). Each record needs
            ``task_idx`` and ``prediction``; ``img_idx`` defaults to 0.
        anno_path: JSON annotation list; each task is read for ``segments`` and
            ``full_instruction``.
        fuzzy_threshold: Passed to ``compute_sample_metrics`` and
            ``aggregate_metrics``.

    Returns:
        A tuple of (per-sample metrics, aggregated metrics). The ground truth
        is ``segments[img_idx]``, or the task's ``full_instruction`` when
        ``img_idx`` is outside the segment list.

    Raises:
        ValueError: The predictions file is a dict without ``"sample_metrics"``.
        IndexError: A record's ``task_idx`` is outside the annotation list.
    """
    # Load predictions.
    with open(predictions_file, 'r', encoding='utf-8') as f:
        predictions = json.load(f)

    # Also accept the output of save_results, whose records sit under "sample_metrics".
    if isinstance(predictions, dict):
        if "sample_metrics" not in predictions:
            raise ValueError(
                f"{predictions_file}: expected a list of predictions "
                "or a dict with a 'sample_metrics' list"
            )
        predictions = predictions["sample_metrics"]

    # Load annotations.
    with open(anno_path, 'r', encoding='utf-8') as f:
        anno_data = json.load(f)

    sample_metrics = []

    for idx, pred_item in enumerate(predictions):
        task_idx = pred_item['task_idx']
        img_idx = pred_item.get('img_idx', 0)
        prediction = pred_item['prediction']

        # Negative indices would silently wrap around to the end of the list.
        if not 0 <= task_idx < len(anno_data):
            raise IndexError(
                f"task_idx {task_idx} out of range for {len(anno_data)} annotated tasks"
            )
        task = anno_data[task_idx]
        full_instruction = task.get('full_instruction', '')
        segments = task.get('segments', [])
        # Out-of-range img_idx (including markers such as -1) falls back to the
        # full instruction instead of wrapping to the last segment.
        ground_truth = segments[img_idx] if 0 <= img_idx < len(segments) else full_instruction

        sample_metrics.append(compute_sample_metrics(
            sample_id=idx,
            task_idx=task_idx,
            img_idx=img_idx,
            prediction=prediction,
            ground_truth=ground_truth,
            full_instruction=full_instruction,
            fuzzy_threshold=fuzzy_threshold
        ))

    aggregated_metrics = aggregate_metrics(sample_metrics, fuzzy_threshold)

    return sample_metrics, aggregated_metrics


def save_results(
    sample_metrics: List[EvalMetrics],
    aggregated_metrics: AggregatedMetrics,
    output_path: str
):
    """Write aggregated metrics and compact per-sample records to ``output_path`` as JSON.

    Each per-sample record keeps only the identifiers, texts and
    ``word_precision``. merge_shard_results rebuilds full metrics from these
    fields. The parent directory of ``output_path`` is created if it does not
    exist.

    NOTE(review): averaged length ratios can be inf (see length_ratio).
    json.dump then writes the non-standard token ``Infinity``.
    """
    simplified_sample_metrics = [
        {
            "sample_id": m.sample_id,
            "task_idx": m.task_idx,
            "img_idx": m.img_idx,
            "prediction": m.prediction,
            "ground_truth": m.ground_truth,
            "full_instruction": m.full_instruction,
            "word_precision": m.word_precision
        }
        for m in sample_metrics
    ]

    results = {
        "timestamp": datetime.now().isoformat(),
        "aggregated_metrics": asdict(aggregated_metrics),
        "sample_metrics": simplified_sample_metrics
    }

    # Create the destination directory so a long evaluation run is not lost at
    # the final write because the directory was missing.
    output_dir = os.path.dirname(output_path)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    with open(output_path, 'w', encoding='utf-8') as f:
        json.dump(results, f, ensure_ascii=False, indent=2)

    print(f"Results saved to {output_path}")


def print_summary(aggregated_metrics: AggregatedMetrics):
    """Print a human-readable report of ``aggregated_metrics`` to stdout."""
    am = aggregated_metrics
    rule = "=" * 60

    # (section title, rows); every row is printed with a two-space indent.
    sections = [
        ("Correctness Metrics", [
            f"Exact Match Rate: {am.exact_match_rate:.2%}",
            f"Avg Fuzzy Score:  {am.avg_fuzzy_score:.4f}",
            f"Avg Edit Distance: {am.avg_edit_distance:.2f}",
            f"Avg BLEU Score:   {am.avg_bleu_score:.4f}",
        ]),
        ("Word-Level Metrics", [
            f"Avg Precision: {am.avg_word_precision:.4f}",
            f"Avg Recall:    {am.avg_word_recall:.4f}",
            f"Avg F1 Score:  {am.avg_word_f1:.4f}",
            f"Avg Match:     {am.avg_match:.4f}  (precision > 0.8)",
        ]),
        ("Length Metrics", [
            f"Avg Word Length Ratio: {am.avg_word_length_ratio:.4f}",
            f"Avg Char Length Ratio: {am.avg_char_length_ratio:.4f}",
            f"Valid Shorter Rate:    {am.valid_shorter_rate:.2%}",
            f"Shorter When Correct:  {am.shorter_when_correct_rate:.2%}",
        ]),
        ("Fuzzy Score Distribution", [
            f"{bucket}: {frac:.2%}" for bucket, frac in am.fuzzy_score_distribution.items()
        ]),
        ("Length Ratio Distribution", [
            f"{bucket}: {frac:.2%}" for bucket, frac in am.length_ratio_distribution.items()
        ]),
    ]

    print("\n" + rule)
    print("Qwen2VL Instruction Segment Evaluation Summary")
    print(rule)
    print(f"\n Total Samples: {am.total_samples}")

    for title, rows in sections:
        print(f"\n {title}:")
        for row in rows:
            print(f"  {row}")

    print("\n" + rule)



def merge_shard_results(
    shard_files: List[str],
    output_path: str,
    fuzzy_threshold: float = 0.9
):
    """Combine per-shard result files into one re-scored report.

    Every per-sample record is re-scored from its prediction and ground truth
    with ``fuzzy_threshold``. The records are then ordered by
    (task_idx, img_idx), renumbered from 0, aggregated, printed and saved to
    ``output_path``.

    Raises:
        FileNotFoundError: A shard file does not exist.
    """
    merged = []

    for path in shard_files:
        if not os.path.exists(path):
            raise FileNotFoundError(f"Shard file not found: {path}")

        with open(path, 'r', encoding='utf-8') as handle:
            records = json.load(handle).get("sample_metrics", [])

        for rec in records:
            prediction, truth = rec["prediction"], rec["ground_truth"]
            merged.append(compute_sample_metrics(
                sample_id=rec["sample_id"],
                task_idx=rec["task_idx"],
                img_idx=rec["img_idx"],
                prediction=prediction,
                ground_truth=truth,
                full_instruction=rec["full_instruction"],
                fuzzy_threshold=fuzzy_threshold
            ))

    merged.sort(key=lambda s: (s.task_idx, s.img_idx))
    for new_id, metric in enumerate(merged):
        metric.sample_id = new_id

    summary = aggregate_metrics(merged, fuzzy_threshold)
    print_summary(summary)
    save_results(merged, summary, output_path)

    print(f"Merged {len(shard_files)} shards into {output_path}")
    print(f"Total samples: {len(merged)}")


def main():
    """Command-line entry point for the three evaluation modes.

    * ``merge``: combine shard result files (``--shard_files``) into
      ``--output_path``.
    * ``offline``: score a predictions file (``--predictions_file``) against
      ``--anno_path``.
    * ``online``: load the model (optionally merging a LoRA adapter), generate
      predictions for ``--anno_path`` / ``--image_folder``, optionally limited
      by ``--total_samples`` (or ``--num_samples``) and sharded, then score them.

    Offline and online modes print a summary and write results to
    ``--output_path``.

    Returns:
        ``(sample_metrics, aggregated_metrics)``, or ``(None, None)`` in merge mode.

    Raises:
        ValueError: A required argument for the chosen mode is missing, or the
            shard / sample-limit arguments are invalid.
    """
    parser = argparse.ArgumentParser(description="Evaluate Qwen2.5-VL Instruction Segmentation Model")
    parser.add_argument("--mode", type=str, default="offline",
                        choices=["online", "offline", "merge"],
                        help="Evaluation mode: online (with model), offline (from predictions file), or merge (merge shard results)")

    parser.add_argument("--predictions_file", type=str, default=None,
                        help="Path to predictions file (for offline mode)")

    parser.add_argument("--anno_path", type=str, default=None,
                        help="Path to annotation file (anno.json)")
    parser.add_argument("--image_folder", type=str, default=None,
                        help="Path to image folder (for online mode)")

    parser.add_argument("--model_path", type=str, default=None,
                        help="Path to base model or fine-tuned model")
    parser.add_argument("--lora_path", type=str, default=None,
                        help="Path to LoRA adapter (optional, for LoRA fine-tuned models)")

    parser.add_argument("--multi_turn", action="store_true",
                        help="Use multi-turn conversation mode")
    parser.add_argument("--num_samples", type=int, default=None,
                        help="Number of samples to evaluate in online mode; used when --total_samples is not given (None for all)")
    parser.add_argument("--max_new_tokens", type=int, default=256,
                        help="Maximum new tokens to generate")
    parser.add_argument("--fuzzy_threshold", type=float, default=0.9,
                        help="Threshold for fuzzy matching")
    parser.add_argument("--max_seq_length", type=int, default=2048,
                        help="Maximum sequence length")

    parser.add_argument("--output_path", type=str, default="./eval_results_qwen2vl.json",
                        help="Path to save evaluation results")
    parser.add_argument("--shard_id", type=int, default=0,
                        help="Current shard ID (0-indexed)")
    parser.add_argument("--num_shards", type=int, default=1,
                        help="Total number of shards (1 for single GPU)")
    parser.add_argument("--total_samples", type=int, default=None,
                        help="Total number of samples to evaluate across all shards (applied before sharding)")

    parser.add_argument("--shard_files", type=str, nargs="+", default=None,
                        help="List of shard result files to merge (for merge mode)")

    args = parser.parse_args()

    if args.mode == "merge":
        if args.shard_files is None or len(args.shard_files) == 0:
            raise ValueError("shard_files is required for merge mode")

        print("Running merge mode...")
        print(f"  Shard files: {args.shard_files}")
        print(f"  Output: {args.output_path}")

        merge_shard_results(
            shard_files=args.shard_files,
            output_path=args.output_path,
            fuzzy_threshold=args.fuzzy_threshold
        )
        return None, None

    elif args.mode == "offline":
        if args.predictions_file is None:
            raise ValueError("predictions_file is required for offline mode")
        if args.anno_path is None:
            raise ValueError("anno_path is required for offline mode")

        print("Running offline evaluation...")
        print(f"  Predictions: {args.predictions_file}")
        print(f"  Annotations: {args.anno_path}")

        sample_metrics, aggregated_metrics = evaluate_from_predictions(
            predictions_file=args.predictions_file,
            anno_path=args.anno_path,
            fuzzy_threshold=args.fuzzy_threshold
        )

    else:
        # Online evaluation.
        if args.model_path is None or args.image_folder is None:
            raise ValueError("model_path and image_folder are required for online mode")
        if args.anno_path is None:
            raise ValueError("anno_path is required for online mode")

        # Validate sharding and limits before the (slow) model load, so a bad
        # configuration fails fast instead of silently evaluating nothing or everything.
        if args.num_shards < 1:
            raise ValueError(f"num_shards must be >= 1, got {args.num_shards}")
        if not 0 <= args.shard_id < args.num_shards:
            raise ValueError(
                f"shard_id must be in [0, {args.num_shards}), got {args.shard_id}"
            )
        # --total_samples takes precedence. --num_samples was previously parsed
        # but silently ignored; it now acts as a fallback limit.
        sample_limit = args.total_samples if args.total_samples is not None else args.num_samples
        if sample_limit is not None and sample_limit < 0:
            raise ValueError(f"sample limit must be non-negative, got {sample_limit}")

        print("Running online evaluation...")
        print(f"  Model: {args.model_path}")
        if args.lora_path:
            print(f"  LoRA: {args.lora_path}")
        print(f"  Annotations: {args.anno_path}")
        print(f"  Images: {args.image_folder}")

        # Show shard information.
        if args.num_shards > 1:
            print(f"  Shard: {args.shard_id + 1}/{args.num_shards}")

        print("Loading processor...")
        processor = AutoProcessor.from_pretrained(
            args.model_path,
            trust_remote_code=True,
        )
        if processor.tokenizer.pad_token_id is None:
            processor.tokenizer.pad_token_id = processor.tokenizer.eos_token_id

        print("Loading model...")
        model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
            args.model_path,
            torch_dtype=torch.bfloat16,
            trust_remote_code=True,
            device_map="auto",
        )

        if args.lora_path:
            print(f"Loading LoRA adapter from {args.lora_path}...")
            model = PeftModel.from_pretrained(model, args.lora_path)
            model = model.merge_and_unload()

        model.eval()

        print("Creating dataset...")
        dataset = Qwen2VLEvalDataset(
            anno_path=args.anno_path,
            image_folder=args.image_folder,
            processor=processor,
            multi_turn=args.multi_turn,
            max_seq_length=args.max_seq_length,
        )
        total_samples = len(dataset)
        print(f"Total dataset size: {total_samples}")

        # The limit is applied before sharding so all shards agree on the global subset.
        if sample_limit is not None and sample_limit < total_samples:
            dataset.samples = dataset.samples[:sample_limit]
            print(f"Limited to first {sample_limit} samples")
            total_samples = sample_limit

        if args.num_shards > 1:
            samples_per_shard = (total_samples + args.num_shards - 1) // args.num_shards
            start_idx = args.shard_id * samples_per_shard
            end_idx = min(start_idx + samples_per_shard, total_samples)

            dataset.samples = dataset.samples[start_idx:end_idx]
            print(f"Shard {args.shard_id}: processing samples [{start_idx}, {end_idx}) = {len(dataset)} samples")

        evaluator = Qwen2VLInstructionEvaluator(
            model=model,
            processor=processor,
            device="cuda" if torch.cuda.is_available() else "cpu",
            max_new_tokens=args.max_new_tokens,
            fuzzy_threshold=args.fuzzy_threshold
        )

        sample_metrics, aggregated_metrics = evaluator.evaluate(
            dataset=dataset,
            num_samples=None  # the sample limit was already applied to dataset.samples above
        )

    print_summary(aggregated_metrics)

    save_results(sample_metrics, aggregated_metrics, args.output_path)

    return sample_metrics, aggregated_metrics


# Script entry point. main() returns its metrics for programmatic callers;
# the return value is discarded here.
if __name__ == "__main__":
    main()

