"""
MMMU Evaluation Script

This script evaluates models on the MMMU benchmark in one of three modes:
1. Simple mode: QwenVL model directly answers questions about images
2. Multi-stage mode: VLM describes image, then LLM answers question
3. LLM mode: LLM only answers questions without image access (baseline)

Usage:
    # Simple mode (QwenVL only)
    python MMMU.py --vlm_model /path/to/vlm --mode simple
    
    # Multi-stage mode (VLM+LLM)
    python MMMU.py --vlm_model /path/to/vlm --llm_model /path/to/llm --mode multi-stage
    
    # LLM mode (LLM only, baseline)
    python MMMU.py --llm_model /path/to/llm --mode llm
"""

import argparse
import os
import json
import torch
import numpy as np
from pathlib import Path
from tqdm import tqdm
from typing import Dict, Any, List, Optional
from datasets import load_dataset
from datetime import datetime
from collections import defaultdict
import tempfile
from PIL import Image
import re
import sys
import csv
from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor
from qwen_vl_utils import process_vision_info
from pathlib import Path as _Path

from rosetta.utils.multi_stage import MultiStageInference


# MMMU subject catalogue: 30 subjects grouped under 6 disciplines.
# Keys are discipline names; values are the dataset config names of the
# subjects in that discipline.
MMMU_SUBJECTS = {
    "Art_and_Design": ["Art", "Art_Theory", "Design", "Music"],
    "Business": ["Accounting", "Economics", "Finance", "Manage", "Marketing"],
    "Science": ["Biology", "Chemistry", "Geography", "Math", "Physics"],
    "Health_and_Medicine": [
        "Basic_Medical_Science",
        "Clinical_Medicine",
        "Diagnostics_and_Laboratory_Medicine",
        "Pharmacy",
        "Public_Health",
    ],
    "Humanities_and_Social_Science": ["History", "Literature", "Sociology", "Psychology"],
    "Tech_and_Engineering": [
        "Agriculture",
        "Architecture_and_Engineering",
        "Computer_Science",
        "Electronics",
        "Energy_and_Power",
        "Materials",
        "Mechanical_Engineering",
    ],
}

# Every subject in one flat list, in discipline order.
ALL_MMMU_SUBJECTS = [
    subject_name
    for discipline_subjects in MMMU_SUBJECTS.values()
    for subject_name in discipline_subjects
]


class MMMUEvaluator:
    """Evaluator for MMMU benchmark."""
    
    def __init__(
        self,
        vlm_model_path: Optional[str] = None,
        llm_model_path: Optional[str] = None,
        mode: str = "simple",
        output_dir: str = "outputs/mmmu",
        device: str = "cuda",
        max_new_tokens: int = 4096,
        run_name: Optional[str] = None,
        update_summary: bool = True,
        load_models: bool = True
    ):
        """
        Initialize MMMU evaluator.
        
        Args:
            vlm_model_path: Path to VLM model (required for simple/multi-stage modes)
            llm_model_path: Path to LLM model (required for multi-stage/llm modes)
            mode: Evaluation mode - "simple", "multi-stage", or "llm"
            output_dir: Directory for outputs
            device: Device to use
            max_new_tokens: Maximum number of new tokens to generate
            run_name: Optional run name to construct subfolder under output_dir (mmmu_<mode>_<run_name>)
            update_summary: Whether to update summary.json incrementally (disable in parallel workers)
            load_models: Whether to load models immediately (False for coordinator-only runs)
        """
        self.vlm_model_path = vlm_model_path
        self.llm_model_path = llm_model_path
        self.mode = mode
        self.output_dir = Path(output_dir)
        self.output_dir.mkdir(parents=True, exist_ok=True)
        self.device = device
        self.max_new_tokens = max_new_tokens
        self.update_summary = update_summary
        
        # Initialize results tracking
        self.all_results = {}
        self.discipline_results = defaultdict(list)
        self.skipped_examples = []
        self.skipped_counts = defaultdict(int)
        self.timestamp = datetime.now().strftime("%Y%m%d_%H%M")
        # Create run directory from run_name or timestamp, grouped by mode subfolder
        folder_name = f"mmmu_{self.mode}_{run_name}" if run_name else f"mmmu_{self.mode}_{self.timestamp}"
        mode_subdir = "llm" if self.mode == "llm" else ("simple" if self.mode == "simple" else "multi-stage")
        self.run_dir = self.output_dir / mode_subdir / folder_name
        self.run_dir.mkdir(parents=True, exist_ok=True)
        print(f"Run directory: {self.run_dir}")
        
        # Validate mode
        if mode not in ["simple", "multi-stage", "llm"]:
            raise ValueError(f"Invalid mode: {mode}. Must be 'simple', 'multi-stage', or 'llm'")
        
        if mode == "multi-stage" and not llm_model_path:
            raise ValueError("LLM model path required for multi-stage mode")
        
        if mode == "llm" and not llm_model_path:
            raise ValueError("LLM model path required for llm mode")
        
        if mode in ["simple", "multi-stage"] and not vlm_model_path:
            raise ValueError("VLM model path required for simple and multi-stage modes")
        
        # Load models based on mode (optional)
        if load_models:
            self._load_models()

    def _save_run_config(self, run_config: Dict[str, Any]):
        """Save the run configuration into the run directory."""
        try:
            cfg_file = self.run_dir / "config.json"
            with open(cfg_file, "w") as f:
                json.dump(run_config, f, indent=2)
            print(f"Run config saved to {cfg_file}")
        except Exception as e:
            print(f"Warning: failed to save run config: {e}")
    
    def _load_models(self):
        """Instantiate the model objects that the configured mode needs.

        - multi-stage: builds ``self.pipeline`` (a ``MultiStageInference``).
        - llm: loads ``self.llm_model`` and ``self.llm_tokenizer``.
        - anything else (simple): loads ``self.vlm_model`` and ``self.vlm_processor``.
        """
        if self.mode == "multi-stage":
            print("Loading multi-stage pipeline...")
            print(f"  VLM: {self.vlm_model_path}")
            print(f"  LLM: {self.llm_model_path}")
            self.pipeline = MultiStageInference(
                vlm_model_path=self.vlm_model_path,
                llm_model_path=self.llm_model_path,
                device=self.device,
                max_new_tokens=self.max_new_tokens,
            )
            return

        # Both remaining branches load in bfloat16 with the whole model on one device.
        if self.mode == "llm":
            print(f"Loading LLM model: {self.llm_model_path}")
            from transformers import AutoModelForCausalLM, AutoTokenizer
            self.llm_model = AutoModelForCausalLM.from_pretrained(
                self.llm_model_path,
                torch_dtype=torch.bfloat16,
                device_map={"": self.device},
            )
            tokenizer = AutoTokenizer.from_pretrained(self.llm_model_path)
            self.llm_tokenizer = tokenizer
            # Fall back to EOS as the padding token when none is defined.
            if tokenizer.pad_token is None:
                tokenizer.pad_token = tokenizer.eos_token
            return

        # simple mode
        print(f"Loading QwenVL model: {self.vlm_model_path}")
        self.vlm_model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
            self.vlm_model_path,
            torch_dtype=torch.bfloat16,
            device_map={"": self.device},
        )
        # Switch the default generation config to greedy decoding and clear
        # the sampling knobs.
        generation_config = self.vlm_model.generation_config
        generation_config.do_sample = False
        for sampling_knob in ("temperature", "top_p", "top_k"):
            setattr(generation_config, sampling_knob, None)
        self.vlm_processor = AutoProcessor.from_pretrained(self.vlm_model_path)
    
    def format_mmmu_prompt(self, question: str, options: List[str]) -> str:
        """
        Format MMMU question into standard prompt.
        
        Args:
            question: Question text (may contain <image 1> placeholder)
            options: List of answer options
            
        Returns:
            Formatted prompt
        """
        # Replace <image 1> with <image> for consistency
        question = question.replace("<image 1>", "<image>")
        
        # Build options
        option_text = ""
        for i, opt in enumerate(options):
            option_text += f"{chr(65+i)}. {opt}\n"
        
        # Use the exact format you specified
        prompt = (
            f"{question}\n"
            f"{option_text}"
            "Answer the preceding multiple choice question. The last line of your response should be of the following format: 'Answer: $LETTER' (without quotes) where LETTER is one of options. Think step by step before answering."
        )
        
        return prompt

    def _normalize_options(self, options_field: Any) -> List[str]:
        """Normalize options field into a list of option strings in A.. order."""
        if options_field is None:
            return []
        
        # MMMU stores options as string representation of Python list
        if isinstance(options_field, str):
            try:
                import ast
                parsed = ast.literal_eval(options_field)
                if isinstance(parsed, list):
                    return [str(item) for item in parsed]
            except (ValueError, SyntaxError):
                pass
            # Fallback: treat as single option
            return [options_field]
        
        # Already a list of strings
        if isinstance(options_field, list):
            return [str(item) for item in options_field]
        
        # Dict keyed by letters or indices
        if isinstance(options_field, dict):
            # Try A, B, C, D
            ordered: List[str] = []
            keys_letter = [k for k in ["A", "B", "C", "D", "E", "F"] if k in options_field]
            if keys_letter:
                for k in keys_letter:
                    ordered.append(str(options_field[k]))
                return ordered
            # Try numeric keys as strings
            keys_num = [k for k in ["0", "1", "2", "3", "4", "5"] if k in options_field]
            if keys_num:
                for k in keys_num:
                    ordered.append(str(options_field[k]))
                return ordered
            # Fallback: sort keys and use values
            for k in sorted(options_field.keys()):
                ordered.append(str(options_field[k]))
            return ordered
        
        # Fallback to string
        return [str(options_field)]

    def _normalize_answer_letter(self, answer_field: Any, num_options: int) -> Optional[str]:
        """Normalize answer into an uppercase letter within the available options.
        Accepts only alphabetic answers A..(A+num_options-1)."""
        if answer_field is None:
            return None
        
        # MMMU stores answer as string (e.g., 'B', 'C')
        if isinstance(answer_field, str):
            s = answer_field.strip().upper()
            allowed = [chr(65+i) for i in range(num_options)]
            if s in allowed:
                return s
            # Patterns like 'Option A'
            m = re.search(r"([A-Z])", s)
            if m and m.group(1) in allowed:
                return m.group(1)
            return None
        
        # Handle integer indices (fallback)
        if isinstance(answer_field, int):
            if 0 <= answer_field < num_options:
                return chr(65 + answer_field)
            return None
        
        return None
    
    def get_simple_answer(self, image_path: str, prompt: str) -> str:
        """
        Ask the VLM to answer a question about one image (simple mode).
        
        Args:
            image_path: Path to image
            prompt: Formatted question prompt
            
        Returns:
            Decoded text of the newly generated tokens only
        """
        # Single user turn: the image first, then the question text.
        conversation = [{
            "role": "user",
            "content": [
                {"type": "image", "image": image_path},
                {"type": "text", "text": prompt},
            ],
        }]

        chat_text = self.vlm_processor.apply_chat_template(
            conversation, tokenize=False, add_generation_prompt=True
        )
        images, videos = process_vision_info(conversation)
        batch = self.vlm_processor(
            text=[chat_text],
            images=images,
            videos=videos,
            padding=True,
            return_tensors="pt",
        ).to(self.device)

        with torch.inference_mode():
            generated = self.vlm_model.generate(
                **batch,
                max_new_tokens=self.max_new_tokens,
                do_sample=False,
            )

        # Drop the prompt tokens so only the model's continuation is decoded.
        prompt_length = batch["input_ids"].shape[-1]
        completions = self.vlm_processor.batch_decode(
            generated[:, prompt_length:],
            skip_special_tokens=True,
            clean_up_tokenization_spaces=False,
        )
        return completions[0]
    
    def get_llm_answer(self, prompt: str) -> str:
        """
        Ask the text-only LLM to answer a question (LLM-only mode).
        
        Args:
            prompt: Formatted question prompt (without image)
            
        Returns:
            Decoded text of the newly generated tokens only
        """
        conversation = [{"role": "user", "content": prompt}]

        # Tokenise through the chat template; enable_thinking=False is
        # forwarded to the template as an extra keyword argument.
        input_ids = self.llm_tokenizer.apply_chat_template(
            conversation,
            tokenize=True,
            add_generation_prompt=True,
            return_tensors="pt",
            enable_thinking=False,
        ).to(self.device)

        with torch.inference_mode():
            generated = self.llm_model.generate(
                input_ids,
                max_new_tokens=self.max_new_tokens,
                do_sample=False,
                temperature=None,
                top_p=None,
                top_k=None,
                pad_token_id=self.llm_tokenizer.eos_token_id,
            )

        # Keep only what follows the prompt tokens.
        prompt_length = input_ids.shape[-1]
        completions = self.llm_tokenizer.batch_decode(
            generated[:, prompt_length:],
            skip_special_tokens=True,
            clean_up_tokenization_spaces=False,
        )
        return completions[0]
    
    def get_multi_stage_answer(self, image_path: str, question: str, options: List[str]) -> Dict[str, str]:
        """
        Run one example through the multi-stage pipeline.
        
        Args:
            image_path: Path to image
            question: Question text
            options: Answer options
            
        Returns:
            Whatever ``self.pipeline.process`` returns; the caller reads its
            "answer" and "description" entries
        """
        answering_prompt = self.format_mmmu_prompt(question, options)
        return self.pipeline.process(
            image_path=image_path,
            question=answering_prompt,
            description_prompt="Describe this image for future question answering.",
        )
    
    def extract_answer_letter(self, text: str, allowed_letters: Optional[str] = None) -> Optional[str]:
        """
        Extract answer letter from generated text.
        
        Args:
            text: Generated text
            allowed_letters: Allowed answer letters (e.g., "ABCD")
            
        Returns:
            Answer letter (A, B, C, D) or None
        """
        # Determine allowed letters dynamically
        if allowed_letters is None:
            allowed_letters = "ABCD"
        char_class = "[" + re.escape(allowed_letters) + "]"
        
        # Try different patterns, prioritizing "Answer: X" format
        patterns = [
            rf"Answer:\s*({char_class})",  # "Answer: X" pattern (primary)
            rf"answer:\s*({char_class})",  # lowercase version
            rf"^({char_class})[\s\.\)]",  # Letter at start
            rf"The answer is\s*({char_class})",
            rf"the answer is\s*({char_class})",
            rf"\n({char_class})[\s\.\)]",  # Letter after newline
        ]
        
        for pattern in patterns:
            match = re.search(pattern, text, re.MULTILINE | re.IGNORECASE)
            if match:
                letter = match.group(1).upper()
                if letter in allowed_letters:
                    return letter
        
        # Last resort: find first allowed capital letter occurrence
        for letter in allowed_letters:
            if letter in text:
                return letter
        
        return None
    
    def process_mmmu_images(self, example: Dict[str, Any]) -> List[str]:
        """
        Process MMMU example to extract image paths.
        Only use image_1 (single image constraint).
        
        Args:
            example: MMMU dataset example
            
        Returns:
            List of temporary image file paths (max 1)
        """
        image_paths = []
        
        # Only use image_1 (single image constraint)
        img1 = example.get("image_1", None)
        if img1 is None:
            return []
        
        # Convert image_1 to usable path
        if isinstance(img1, Image.Image):
            with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as f:
                img1.save(f.name)
                image_paths.append(f.name)
        elif isinstance(img1, str):
            image_paths.append(img1)
        
        return image_paths
    
    def _save_subject_results(self, subject_results: Dict[str, Any]):
        """Save results for a single subject incrementally."""
        subject = subject_results["subject"]
        
        # Save individual subject results as JSON
        subject_file = self.run_dir / f"{subject}_results.json"
        with open(subject_file, "w") as f:
            json.dump(subject_results, f, indent=2)
        
        # Save detailed results as CSV
        csv_file = self.run_dir / f"{subject}_detailed.csv"
        fieldnames = [
            'id', 'subject', 'question', 'options', 'correct_answer', 'predicted_answer', 
            'is_correct', 'mode', 'generated_text', 'image_description'
        ]
        
        with open(csv_file, 'w', newline='', encoding='utf-8') as csvfile:
            writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
            writer.writeheader()
            for result in subject_results["results"]:
                # Convert options list to string for CSV
                csv_result = result.copy()
                csv_result['options'] = ' | '.join(result['options']) if result['options'] else ''
                writer.writerow(csv_result)
        
        print(f"Subject {subject} results saved to {subject_file} and {csv_file}")
    
    def _update_summary(self):
        """Update and save overall summary after each subject."""
        # Calculate overall and discipline accuracies
        total_correct = sum(r["correct"] for r in self.all_results.values())
        total_count = sum(r["total"] for r in self.all_results.values())
        overall_accuracy = total_correct / total_count if total_count > 0 else 0.0
        
        discipline_accuracies = {}
        for discipline, results in self.discipline_results.items():
            disc_correct = sum(r["correct"] for r in results)
            disc_total = sum(r["total"] for r in results)
            discipline_accuracies[discipline] = disc_correct / disc_total if disc_total > 0 else 0.0
        
        # Prepare summary
        summary = {
            "model": {
                "vlm": self.vlm_model_path,
                "llm": self.llm_model_path if self.mode == "multi-stage" else None,
                "mode": self.mode
            },
            "overall_accuracy": overall_accuracy,
            "discipline_accuracies": discipline_accuracies,
            "subject_accuracies": {s: r["accuracy"] for s, r in self.all_results.items()},
            "total_correct": total_correct,
            "total_questions": total_count,
            "timestamp": datetime.now().isoformat(),
            "subjects_completed": list(self.all_results.keys())
        }
        
        # Attach skipped info
        if hasattr(self, "skipped_counts") and hasattr(self, "skipped_examples"):
            summary["skipped_counts"] = dict(self.skipped_counts)
            summary["skipped_examples"] = self.skipped_examples[:50]
        
        # Save updated summary under run subfolder
        summary_file = self.run_dir / f"summary.json"
        with open(summary_file, "w") as f:
            json.dump(summary, f, indent=2)
        
        return summary
    
    def _load_results_from_run_dir(self):
        """Load all per-subject results from run_dir into memory (for final summary)."""
        self.all_results = {}
        self.discipline_results = defaultdict(list)
        for json_file in sorted(self.run_dir.glob("*_results.json")):
            try:
                with open(json_file, "r") as f:
                    subject_results = json.load(f)
            except Exception:
                continue
            subject = subject_results.get("subject", json_file.stem.replace("_results", ""))
            self.all_results[subject] = subject_results
            for discipline, subj_list in MMMU_SUBJECTS.items():
                if subject in subj_list:
                    self.discipline_results[discipline].append(subject_results)
                    break
    
    def _record_skip(self, example_id: Any, reason: str) -> None:
        """Register one example that was excluded from scoring."""
        self.skipped_counts[reason] += 1
        self.skipped_examples.append({"id": example_id, "reason": reason})

    def evaluate_subject(
        self,
        subject: str,
        split: str = "validation",
        limit: Optional[int] = None
    ) -> Dict[str, Any]:
        """
        Evaluate on a single MMMU subject.

        Examples with unusable options, an unrecognized answer, or (in the
        VLM modes) no image are skipped and recorded in the skip tracking;
        they do not count towards ``total``. An example that raises during
        processing is logged and likewise left out of the score.
        
        Args:
            subject: Subject name
            split: Dataset split (validation/test)
            limit: Limit number of examples
            
        Returns:
            Dict with "subject", "accuracy", "correct", "total" and the
            per-example "results" list
        """
        print(f"\nEvaluating subject: {subject} (mode: {self.mode})")
        
        # Load dataset
        dataset = load_dataset("MMMU/MMMU", subject, split=split)
        
        results = []
        correct = 0
        total = 0
        
        # Process examples
        examples_to_eval = dataset if limit is None else dataset.select(range(min(limit, len(dataset))))
        
        for example in tqdm(examples_to_eval, desc=f"Evaluating {subject}"):
            # Temp files created for THIS example. They are removed in the
            # `finally` below, so skips and errors cannot leak them.
            temp_files: List[str] = []
            try:
                example_id = example.get("id", "unknown")

                # Extract question and options (normalize formats)
                question = example.get("question", "")
                options = self._normalize_options(example.get("options"))
                if not options:
                    msg = "empty or invalid options"
                    self._record_skip(example_id, msg)
                    print(f"Skipping example {example_id}: {msg}")
                    continue

                correct_letter = self._normalize_answer_letter(example.get("answer"), len(options))
                if correct_letter is None:
                    msg = f"unrecognized answer {example.get('answer')}"
                    self._record_skip(example_id, msg)
                    print(f"Skipping example {example_id}: {msg}")
                    continue
                
                # Get answer based on mode
                description = None
                if self.mode == "llm":
                    # LLM mode: no image processing needed
                    prompt = self.format_mmmu_prompt(question, options)
                    answer_text = self.get_llm_answer(prompt)
                else:
                    # Process images for VLM modes
                    image_paths = self.process_mmmu_images(example)
                    # Only files materialised from in-memory images are ours
                    # to delete; a path string supplied by the dataset must
                    # be left untouched.
                    if not isinstance(example.get("image_1"), str):
                        temp_files = list(image_paths)

                    if not image_paths:
                        # Record the skip like the other skip reasons so the
                        # summary shows how many examples were excluded.
                        self._record_skip(example_id, "no image")
                        print(f"Warning: No images found for example {example_id}")
                        continue
                    
                    # Use first image (can be extended for multi-image support)
                    image_path = image_paths[0]
                    
                    if self.mode == "multi-stage":
                        result_dict = self.get_multi_stage_answer(image_path, question, options)
                        answer_text = result_dict["answer"]
                        description = result_dict["description"]
                    else:  # simple mode
                        prompt = self.format_mmmu_prompt(question, options)
                        answer_text = self.get_simple_answer(image_path, prompt)
                
                # Extract answer letter with dynamic allowed set
                allowed_letters = "".join(chr(65 + i) for i in range(len(options)))
                predicted_letter = self.extract_answer_letter(answer_text, allowed_letters)
                is_correct = predicted_letter == correct_letter
                
                if is_correct:
                    correct += 1
                total += 1
                
                # Store result
                results.append({
                    "id": example.get("id", f"{subject}_{total}"),
                    "subject": subject,
                    "question": question,
                    "options": options,
                    "correct_answer": correct_letter,
                    "predicted_answer": predicted_letter,
                    "is_correct": is_correct,
                    "mode": self.mode,
                    "generated_text": answer_text,
                    "image_description": description if self.mode == "multi-stage" else None,
                })
                        
            except Exception as e:
                # One failing example must not abort the whole subject.
                import traceback
                print(f"Error processing example in {subject}: {e}")
                print(f"Traceback:\n{traceback.format_exc()}")
                continue
            finally:
                # Best-effort removal of this example's temp image files.
                for tmp_path in temp_files:
                    try:
                        os.remove(tmp_path)
                    except OSError:
                        pass
        
        accuracy = correct / total if total > 0 else 0.0
        
        print(f"Subject {subject} completed: {correct}/{total} correct (accuracy: {accuracy:.2%})")
        
        subject_results = {
            "subject": subject,
            "accuracy": accuracy,
            "correct": correct,
            "total": total,
            "results": results
        }
        
        # Save results incrementally
        self._save_subject_results(subject_results)
        
        # Update overall tracking
        self.all_results[subject] = subject_results
        
        # Group by discipline
        for discipline, subj_list in MMMU_SUBJECTS.items():
            if subject in subj_list:
                self.discipline_results[discipline].append(subject_results)
                break
        
        # Update and save summary (only if enabled to avoid parallel write races)
        if self.update_summary:
            summary = self._update_summary()
            print(f"Updated overall accuracy: {summary['overall_accuracy']:.2%} ({summary['total_correct']}/{summary['total_questions']})")
        
        return subject_results
    
    def evaluate(
        self,
        subjects: Optional[List[str]] = None,
        split: str = "validation",
        limit: Optional[int] = None,
        save_results: bool = True,
        gpus: Optional[str] = None
    ) -> Dict[str, Any]:
        """
        Evaluate on multiple MMMU subjects and print a final report.
        
        Args:
            subjects: List of subjects (None for all)
            split: Dataset split
            limit: Limit per subject
            save_results: Whether to write the combined CSV at the end
            gpus: When not None, the combined CSV is not written here
            
        Returns:
            Final summary, recomputed from the per-subject files on disk
        """
        requested = ALL_MMMU_SUBJECTS if subjects is None else subjects

        # Keep known subjects; warn about the rest.
        valid_subjects = []
        for name in requested:
            if name in ALL_MMMU_SUBJECTS:
                valid_subjects.append(name)
            else:
                print(f"Warning: {name} not in MMMU subjects, skipping")

        # Outer progress bar over subjects; each subject saves its own results.
        progress = tqdm(valid_subjects, desc=f"Subjects ({self.mode})", unit="subject")
        for subject in progress:
            self.evaluate_subject(subject, split, limit)

        # The combined CSV is only produced when gpus is None.
        if save_results and gpus is None:
            self._save_combined_csv()

        # Rebuild state from the files on disk before computing the final summary.
        self._load_results_from_run_dir()
        final_summary = self._update_summary()

        # Final report
        separator = "=" * 60
        print("\n" + separator)
        print("MMMU Evaluation Results - FINAL")
        print(separator)
        print(f"Mode: {self.mode}")
        print(f"Overall Accuracy: {final_summary['overall_accuracy']:.2%} ({final_summary['total_correct']}/{final_summary['total_questions']})")
        print("\nDiscipline Accuracies:")
        for discipline_name, discipline_acc in final_summary['discipline_accuracies'].items():
            print(f"  {discipline_name}: {discipline_acc:.2%}")
        print("\nSubject Accuracies:")
        for subject_name, subject_acc in final_summary["subject_accuracies"].items():
            print(f"  {subject_name}: {subject_acc:.2%}")

        return final_summary
    
    def _save_combined_csv(self):
        """Save all results combined into a single CSV file."""
        combined_csv_file = self.run_dir / f"all_results.csv"
        fieldnames = [
            'id', 'subject', 'question', 'options', 'correct_answer', 'predicted_answer', 
            'is_correct', 'mode', 'generated_text', 'image_description'
        ]
        
        with open(combined_csv_file, 'w', newline='', encoding='utf-8') as csvfile:
            writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
            writer.writeheader()
            
            # Load any existing per-subject files (supports interrupted runs)
            for json_file in sorted(self.run_dir.glob("*_results.json")):
                try:
                    with open(json_file, "r") as f:
                        subject_results = json.load(f)
                except Exception:
                    continue
                for result in subject_results.get("results", []):
                    # Convert options list to string for CSV
                    csv_result = result.copy()
                    csv_result['options'] = ' | '.join(result['options']) if result['options'] else ''
                    writer.writerow(csv_result)
        
        print(f"Combined results saved to {combined_csv_file}")
    
    def _save_results(self, summary: Dict[str, Any], detailed_results: Dict[str, Any]):
        """Legacy method - results are now saved incrementally."""
        print("Note: Results are now saved incrementally. This method is deprecated.")


# ========================== Multiprocessing worker ===========================
def worker_process(
    gpu_id: int,
    shard: List[str],
    vlm_model: Optional[str],
    llm_model: Optional[str],
    mode: str,
    output_dir: str,
    max_new_tokens: int,
    run_dir: str,
    split: str,
    limit: Optional[int]
):
    """Evaluate a shard of subjects on one GPU.

    Args:
        gpu_id: CUDA device index used by this worker.
        shard: Subjects this worker evaluates, in order.
        vlm_model: Path to VLM model (may be None in llm mode).
        llm_model: Path to LLM model (may be None in simple mode).
        mode: Evaluation mode - "simple", "multi-stage", or "llm".
        output_dir: Base output directory.
        max_new_tokens: Maximum number of new tokens to generate.
        run_dir: Run directory that per-subject files are written into.
            Presumably the coordinator's run directory, so that its merge
            step finds every worker's files - verify against the caller.
        split: Dataset split.
        limit: Limit of examples per subject.
    """
    device = f"cuda:{gpu_id}"
    shared_dir = Path(run_dir) if run_dir else None

    # Recover the run name from a folder named "mmmu_<mode>_<run_name>", so
    # the evaluator builds that same folder instead of a timestamped one.
    run_name = None
    folder_prefix = f"mmmu_{mode}_"
    if shared_dir is not None and shared_dir.name.startswith(folder_prefix):
        run_name = shared_dir.name[len(folder_prefix):] or None

    worker = MMMUEvaluator(
        vlm_model_path=vlm_model,
        llm_model_path=llm_model,
        mode=mode,
        output_dir=output_dir,
        max_new_tokens=max_new_tokens,
        run_name=run_name,
        device=device,
        update_summary=False,
        load_models=True
    )

    # Previously `run_dir` was ignored and each worker wrote into its own
    # timestamp-named folder. Always write into the directory we were given.
    if shared_dir is not None:
        shared_dir.mkdir(parents=True, exist_ok=True)
        worker.run_dir = shared_dir

    for subject in shard:
        worker.evaluate_subject(subject, split=split, limit=limit)

def main():
    """Parse command-line arguments and run the requested MMMU workflow.

    Three execution paths are supported:

    * ``--combine_from DIR``: merge the per-subject result files already in
      ``DIR`` into the combined CSV and summary, then return. No models are
      loaded.
    * ``--gpus 0,1,...``: split the subjects across the listed GPUs, run one
      ``worker_process`` per non-empty shard, wait for all of them, then
      combine their results in the parent process.
    * otherwise: evaluate everything in this process.
    """
    parser = argparse.ArgumentParser(description="MMMU Evaluation")
    parser.add_argument(
        "--vlm_model",
        type=str,
        default="Qwen/Qwen2.5-VL-3B-Instruct",
        help="Path to VLM model (required for simple/multi-stage modes)"
    )
    parser.add_argument(
        "--llm_model",
        type=str,
        default="Qwen/Qwen3-4B",
        help="Path to LLM model (required for multi-stage/llm modes)"
    )
    parser.add_argument(
        "--mode",
        type=str,
        default="llm",
        choices=["simple", "multi-stage", "llm"],
        help="Evaluation mode: simple (VLM only), multi-stage (VLM+LLM), or llm (LLM only, baseline)"
    )
    parser.add_argument(
        "--subjects",
        nargs="+",
        default=None,
        help="Subjects to evaluate (default: all)"
    )
    parser.add_argument(
        "--split",
        type=str,
        default="validation",
        choices=["validation", "test"],
        help="Dataset split"
    )
    parser.add_argument(
        "--limit",
        type=int,
        default=None,
        help="Limit examples per subject"
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default="outputs/mmmu",
        help="Output directory"
    )
    parser.add_argument(
        "--max_new_tokens",
        type=int,
        default=4096,
        help="Maximum number of new tokens to generate"
    )
    parser.add_argument(
        "--run_name",
        type=str,
        default=None,
        help="Optional run name; per-run folder becomes mmmu_<mode>_<run_name> under output_dir."
    )
    parser.add_argument(
        "--gpus",
        type=str,
        default=None,
        help="Comma-separated GPU IDs to parallelize subjects (e.g., '0,1,2'). If unset, runs single-process."
    )
    parser.add_argument(
        "--combine_from",
        type=str,
        default=None,
        help="Path to an existing run folder (e.g., outputs/mmmu/simple/mmmu_simple_<name>) to merge per-subject files into all_results.csv and summary.json, then exit."
    )

    args = parser.parse_args()

    # Combine-only mode: merge existing results and exit
    if args.combine_from:
        existing_run_dir = _Path(args.combine_from)
        # Fail early with a clear message instead of writing a config and
        # empty outputs into (or crashing on) a path that is not a run folder.
        if not existing_run_dir.is_dir():
            parser.error(f"--combine_from is not an existing directory: {existing_run_dir}")

        # Evaluator used only for its combine helpers; load_models=False
        combiner = MMMUEvaluator(
            vlm_model_path=args.vlm_model,
            llm_model_path=args.llm_model,
            mode=args.mode,
            output_dir=args.output_dir,
            max_new_tokens=args.max_new_tokens,
            run_name=None,
            update_summary=True,
            load_models=False
        )
        # Override the run_dir to use the existing one
        combiner.run_dir = existing_run_dir

        # Save a minimal config to document combination
        combiner._save_run_config({
            "mode": args.mode,
            "combine_from": args.combine_from,
            "timestamp": datetime.now().isoformat()
        })
        combiner._save_combined_csv()
        combiner._load_results_from_run_dir()
        combiner._update_summary()
        return

    # Only hand each evaluator the model paths its mode actually uses.
    vlm_path = args.vlm_model if args.mode != "llm" else None
    llm_path = args.llm_model if args.mode in ["multi-stage", "llm"] else None

    # Run configuration recorded for both the parallel and single-process paths.
    run_config = {
        "mode": args.mode,
        "vlm_model": args.vlm_model,
        "llm_model": args.llm_model,
        "subjects": args.subjects,
        "split": args.split,
        "limit": args.limit,
        "output_dir": args.output_dir,
        "run_name": args.run_name,
        "gpus": args.gpus,
        "max_new_tokens": args.max_new_tokens
    }

    # Parallel or single-process execution
    if args.gpus:
        # Spawn one worker per GPU with a shard of subjects
        import torch.multiprocessing as mp
        gpu_ids = [int(x) for x in args.gpus.split(",") if x.strip() != ""]
        if not gpu_ids:
            raise ValueError("--gpus provided but no valid GPU IDs parsed")

        # Prepare subjects
        if args.subjects is None:
            all_subjects = ALL_MMMU_SUBJECTS
        else:
            all_subjects = [s for s in args.subjects if s in ALL_MMMU_SUBJECTS]
            # Tell the user about dropped names rather than skipping them silently.
            unknown_subjects = [s for s in args.subjects if s not in ALL_MMMU_SUBJECTS]
            if unknown_subjects:
                print(f"Warning: ignoring unknown subjects: {', '.join(unknown_subjects)}")
        if not all_subjects:
            print("No valid subjects to evaluate.")
            return

        # Make or reuse run_dir early without loading models
        tmp_eval = MMMUEvaluator(
            vlm_model_path=vlm_path,
            llm_model_path=llm_path,
            mode=args.mode,
            output_dir=args.output_dir,
            max_new_tokens=args.max_new_tokens,
            run_name=args.run_name,
            update_summary=False,  # avoid summary writes in workers
            load_models=False
        )
        run_dir = str(tmp_eval.run_dir)
        # Save run config
        tmp_eval._save_run_config(run_config)

        # Partition subjects across GPUs (strided, so shard sizes differ by at most one)
        subject_shards = [all_subjects[i::len(gpu_ids)] for i in range(len(gpu_ids))]

        processes = []
        for gpu_id, shard in zip(gpu_ids, subject_shards):
            if not shard:
                # More GPUs than subjects: do not start a worker that would
                # load models and then have nothing to evaluate.
                continue
            p = mp.Process(target=worker_process, args=(
                gpu_id,
                shard,
                vlm_path,
                llm_path,
                args.mode,
                args.output_dir,
                args.max_new_tokens,
                run_dir,
                args.split,
                args.limit
            ))
            p.start()
            processes.append((gpu_id, p))
        for _, p in processes:
            p.join()

        # A non-zero exit code means the worker died before finishing its
        # shard. Still combine whatever was written, but say so explicitly.
        failed = [(gpu_id, p.exitcode) for gpu_id, p in processes if p.exitcode != 0]
        if failed:
            details = ", ".join(f"GPU {gpu_id} (exit code {code})" for gpu_id, code in failed)
            print(
                f"Warning: {len(failed)} worker(s) exited abnormally: {details}. "
                "Combined results may be incomplete."
            )

        # Combine after all workers finish
        combiner = MMMUEvaluator(
            vlm_model_path=vlm_path,
            llm_model_path=llm_path,
            mode=args.mode,
            output_dir=args.output_dir,
            max_new_tokens=args.max_new_tokens,
            run_name=args.run_name,
            update_summary=True,
            load_models=False
        )
        # Use the same run directory that was created initially
        combiner.run_dir = tmp_eval.run_dir
        combiner._save_combined_csv()
        combiner._load_results_from_run_dir()
        combiner._update_summary()
    else:
        # Single-process
        evaluator = MMMUEvaluator(
            vlm_model_path=vlm_path,
            llm_model_path=llm_path,
            mode=args.mode,
            output_dir=args.output_dir,
            max_new_tokens=args.max_new_tokens,
            run_name=args.run_name
        )
        # Save run config
        evaluator._save_run_config(run_config)
        evaluator.evaluate(
            subjects=args.subjects,
            split=args.split,
            limit=args.limit,
            save_results=True
        )


if __name__ == "__main__":
    import torch.multiprocessing as mp

    # Select the start method before main() creates any worker process.
    # force=True replaces a start method that may already have been set.
    start_method = "spawn"
    mp.set_start_method(start_method, force=True)
    main()