#!/usr/bin/env python3
"""
Offline vLLM evaluation of LoRA checkpoints.

This script evaluates multiple LoRA checkpoints on specified evaluation datasets
using vLLM for fast inference. Supports both sequential and parallel evaluation.

Supports:
- Puzzle datasets (bridges, undead, galaxies, loopy, pattern) with templates
- Math datasets (AIME24, GSM8K) with <answer> tag format
- HuggingFace Hub LoRA checkpoints (--hf_lora_path)
- Baseline evaluation without LoRA (--no_lora)

Usage:
    # Puzzle evaluation with template (all checkpoints in directory)
    python tools/eval_lora_checkpoints.py \
        --base_model Qwen/Qwen2.5-7B-Instruct \
        --checkpoint_dir checkpoints/bridges_8ep \
        --eval_datasets anon-neurips26/bridges_5x5de_test200_intformat_json \
        --puzzle_templates_json '{"bridges": "prompts/bridges_intformat.txt"}' \
        --parallel --num_gpus 4

    # Single checkpoint evaluation (point directly to checkpoint folder)
    python tools/eval_lora_checkpoints.py \
        --base_model Qwen/Qwen2.5-7B-Instruct \
        --checkpoint_dir checkpoints/ds_r1_7b_lightr1_lr5e5/lora_epoch_5 \
        --eval_datasets math-ai/aime24

    # Math evaluation (AIME24)
    python tools/eval_lora_checkpoints.py \
        --base_model Qwen/Qwen2.5-7B-Instruct \
        --checkpoint_dir checkpoints/math_model \
        --eval_datasets math-ai/aime24

    # GSM8K evaluation
    python tools/eval_lora_checkpoints.py \
        --base_model Qwen/Qwen2.5-7B-Instruct \
        --checkpoint_dir checkpoints/math_model \
        --eval_datasets openai/gsm8k:main

    # HuggingFace Hub LoRA checkpoint evaluation
    python tools/eval_lora_checkpoints.py \
        --base_model Qwen/Qwen2.5-7B-Instruct \
        --hf_lora_path anon-neurips26/Qwen2.5-7B-Instruct-rsft-lightr1_only_8ep/best_checkpoint_aime24 \
        --eval_datasets math-ai/aime24

    # Baseline evaluation (no LoRA)
    python tools/eval_lora_checkpoints.py \
        --base_model Qwen/Qwen2.5-7B-Instruct \
        --no_lora \
        --eval_datasets math-ai/aime24
"""

import argparse
import json
import os
import re
import sys
from glob import glob
import multiprocessing
from multiprocessing import Process, Queue
from typing import Optional, Dict, List, Tuple, Any

sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

# =============================================================================
# Constants
# =============================================================================

# System prompt asking for <reasoning>...</reasoning> followed by the final
# solution inside <answer>...</answer> tags.
SYSTEM_PROMPT_RSFT = (
    "A conversation between User and Assistant. The user asks a question, "
    "and the Assistant solves it step by step by reasoning. "
    "Provide the reasoning in <reasoning> reasoning here </reasoning> "
    "and the final solution within <answer> answer here </answer>"
)

# System prompt asking for step-by-step work with the final answer in \boxed{}.
SYSTEM_PROMPT_MATH = (
    "You are a helpful assistant. Solve the problem step by step. "
    "Show your reasoning and put your final answer in \\boxed{}."
)

# Minimal chain-of-thought variant of the \boxed{} math prompt.
SYSTEM_PROMPT_MATH_COT = (
    "Please reason step by step, and put your final answer within \\boxed{}."
)

# Available prompt styles, keyed by the style name and mapped to either:
#   - a system prompt string,
#   - None, meaning "emit no system turn at all", or
#   - a sentinel string ('DEEPSEEK_R1_NATIVE' / 'QWEN3_THINKING') that
#     build_chat_prompt() recognises and turns into a model-specific layout
#     rather than using it as system prompt text.
PROMPT_STYLES = {
    'rsft': SYSTEM_PROMPT_RSFT,      # Uses <answer> tags
    'math': SYSTEM_PROMPT_MATH,       # Uses \boxed{} format
    'math_cot': SYSTEM_PROMPT_MATH_COT,  # Minimal CoT prompt with \boxed{}
    'none': None,                     # No system prompt (Qwen2.5 format without system)
    'deepseek': None,                 # Alias for 'none' - Qwen2.5 format without system
    'deepseek_r1': 'DEEPSEEK_R1_NATIVE',  # Native DeepSeek R1 format with <｜User｜> tokens
    'qwen3': 'QWEN3_THINKING',            # Qwen3 thinking mode: <|im_start|> format with <think> primer
}

# Puzzle families recognised by substring match against the dataset name
# (see detect_puzzle_type_from_name).
PUZZLE_TYPES = ['bridges', 'undead', 'galaxies', 'loopy', 'pattern']


# =============================================================================
# Dataset Format Detection
# =============================================================================

def detect_dataset_format(example: dict) -> str:
    """Classify a dataset row as 'puzzle', 'gsm8k', 'aime24', or 'unknown'.

    Checks are ordered from the least to the most ambiguous marker so that a
    puzzle row which merely happens to carry an ``id``/``url`` column is not
    mistaken for an AIME problem:

    1. A ``puzzlename`` column                      -> 'puzzle'
    2. ``question`` + ``answer`` containing '####'  -> 'gsm8k'
    3. ``problem`` + ``solution`` where
       - the solution is a list (a parsed grid)     -> 'puzzle'
       - the row has ``id`` or ``url``              -> 'aime24'
       - the solution string starts with '['        -> 'puzzle'

    Args:
        example: One row of the dataset (typically ``ds[0]``).

    Returns:
        The detected format name, or 'unknown' if nothing matched.
    """
    # An explicit puzzle name is the strongest signal; test it first.
    if 'puzzlename' in example:
        return 'puzzle'

    if 'question' in example and 'answer' in example:
        answer = example.get('answer', '')
        if isinstance(answer, str) and '####' in answer:
            return 'gsm8k'

    if 'problem' in example and 'solution' in example:
        solution = example.get('solution', '')
        # A structured (list) solution can only be a grid, never a math answer.
        if isinstance(solution, list):
            return 'puzzle'
        if 'id' in example or 'url' in example:
            return 'aime24'
        if isinstance(solution, str) and solution.startswith('['):
            return 'puzzle'

    return 'unknown'


def detect_puzzle_type_from_name(dataset_name: str) -> str:
    """Return the first known puzzle type contained in the dataset name.

    The match is a case-insensitive substring test against PUZZLE_TYPES, in
    list order. Returns 'unknown' when no puzzle type occurs in the name.
    """
    lowered = dataset_name.lower()
    return next((kind for kind in PUZZLE_TYPES if kind in lowered), 'unknown')


def detect_puzzle_type(example: dict, dataset_name: str) -> str:
    """Resolve the puzzle type for a dataset row.

    An explicit ``puzzlename`` value on the row takes precedence; otherwise
    the type is inferred from the dataset name.
    """
    has_explicit_name = 'puzzlename' in example
    return example['puzzlename'] if has_explicit_name else detect_puzzle_type_from_name(dataset_name)


# =============================================================================
# Puzzle Template Loading
# =============================================================================

def load_puzzle_templates(json_str: str) -> Dict[str, str]:
    """Load puzzle prompt templates from a JSON mapping of type -> file path.

    Args:
        json_str: JSON object such as ``{"bridges": "prompts/bridges.txt"}``.
            An empty/None value yields an empty mapping.

    Returns:
        Dict mapping each puzzle type to the *contents* of its template file.

    Raises:
        ValueError: If the JSON is malformed, is not a JSON object, or any
            referenced template path is not an existing file.
    """
    if not json_str:
        return {}

    try:
        mapping = json.loads(json_str)
    except json.JSONDecodeError as e:
        # Chain the original error so the JSON position stays in the traceback.
        raise ValueError(f"Invalid JSON for puzzle_templates_json: {e}") from e

    # A JSON list/string/number would otherwise fail later with an opaque
    # AttributeError on .items().
    if not isinstance(mapping, dict):
        raise ValueError(
            "puzzle_templates_json must be a JSON object mapping puzzle type to "
            f"template path, got {type(mapping).__name__}"
        )

    templates = {}
    for puzzle_type, template_path in mapping.items():
        # isfile (not exists) so that a directory path is reported as a missing
        # template instead of crashing inside open().
        if not isinstance(template_path, str) or not os.path.isfile(template_path):
            raise ValueError(f"Puzzle template not found for '{puzzle_type}': {template_path}")
        with open(template_path, 'r', encoding='utf-8') as f:
            templates[puzzle_type] = f.read()
        print(f"Loaded template for {puzzle_type}: {template_path}")

    return templates


# =============================================================================
# Prompt Construction
# =============================================================================

def format_puzzle_grid(puzzle_data: Any) -> str:
    """Render puzzle data as text.

    A list is treated as a sequence of rows: the cells of each row are joined
    with commas and the rows are joined with newlines. Any other value is
    returned as ``str(puzzle_data)``.
    """
    if not isinstance(puzzle_data, list):
        return str(puzzle_data)

    rendered_rows = [','.join(map(str, row)) for row in puzzle_data]
    return '\n'.join(rendered_rows)


def _fill_puzzle_template(template: str, problem_str: str) -> str:
    """Insert the puzzle text into a template, tolerating literal braces."""
    # Preferred forms: a named {problem} placeholder, then a positional {}.
    attempts = (
        lambda: template.format(problem=problem_str),
        lambda: template.format(problem_str),
    )
    for attempt in attempts:
        try:
            return attempt()
        except (KeyError, IndexError, ValueError):
            # KeyError/IndexError: the template has other {...} fields (e.g. a
            # JSON example). ValueError: unbalanced '{' or '}' in the template.
            continue

    # Last resort: the template contains literal braces str.format cannot
    # parse. Substitute both placeholder spellings in one pass so text inside
    # the inserted puzzle is never itself rewritten.
    return re.sub(r'\{problem\}|\{\}', lambda _match: problem_str, template)


def get_user_content(example: dict, puzzle_type: str, puzzle_templates: Dict[str, str], dataset_format: str) -> str:
    """Build the user-turn text for one dataset row.

    Args:
        example: Dataset row.
        puzzle_type: Puzzle family used to select a template.
        puzzle_templates: Mapping of puzzle type -> template text. A template
            marks the insertion point with ``{problem}`` or ``{}``.
        dataset_format: 'aime24', 'gsm8k', 'puzzle', or anything else.

    Returns:
        - 'aime24': the ``problem`` field.
        - 'gsm8k': the ``question`` field.
        - 'puzzle': the formatted grid (from ``problem``, else ``initial``,
          else ``grid``) inserted into the template; if no template exists for
          the puzzle type a warning is printed and the bare grid is returned.
        - otherwise: the first non-empty of ``problem``/``question``/``prompt``.
    """
    if dataset_format == 'aime24':
        return str(example.get('problem', ''))
    if dataset_format == 'gsm8k':
        return str(example.get('question', ''))
    if dataset_format != 'puzzle':
        return str(example.get('problem', '') or example.get('question', '') or example.get('prompt', ''))

    problem_data = example.get('problem', example.get('initial', example.get('grid', '')))
    problem_str = format_puzzle_grid(problem_data)

    if puzzle_type not in puzzle_templates:
        print(f"Warning: No template for puzzle type '{puzzle_type}'")
        return problem_str

    return _fill_puzzle_template(puzzle_templates[puzzle_type], problem_str)


def build_chat_prompt(user_content: str, prompt_style: str = 'rsft') -> str:
    """Wrap the user content in the chat markup for the chosen prompt style.

    Args:
        user_content: The user's question/problem.
        prompt_style: A key of PROMPT_STYLES ('rsft', 'math', 'math_cot',
            'none', 'deepseek', 'deepseek_r1', 'qwen3'). An unrecognised style
            falls back to the RSFT system prompt.

    Returns:
        The full prompt string, ending where the assistant's turn begins.
    """
    system_prompt = PROMPT_STYLES.get(prompt_style, SYSTEM_PROMPT_RSFT)

    if system_prompt == 'DEEPSEEK_R1_NATIVE':
        # DeepSeek R1 native layout. The special tokens use full-width bars and
        # U+2581 (LOWER ONE EIGHTH BLOCK), not ASCII. The trailing "<think>\n"
        # primes the model's reasoning section.
        return f"<｜begin▁of▁sentence｜><｜User｜>{user_content}<｜Assistant｜><think>\n"

    if system_prompt == 'QWEN3_THINKING':
        # Qwen3 thinking mode: no system turn, assistant primed with <think>.
        return f"<|im_start|>user\n{user_content}<|im_end|>\n<|im_start|>assistant\n<think>\n"

    if system_prompt is None:
        # 'none' / 'deepseek': <|im_start|> chat layout without a system turn.
        return f"<|im_start|>user\n{user_content}<|im_end|>\n<|im_start|>assistant\n"

    # Regular case: system turn, user turn, then the open assistant turn.
    return f"<|im_start|>system\n{system_prompt}<|im_end|>\n<|im_start|>user\n{user_content}<|im_end|>\n<|im_start|>assistant\n"


# =============================================================================
# Answer Extraction and Comparison
# =============================================================================

def extract_answer(response: str) -> Optional[str]:
    """Pull the final answer out of a model response.

    Formats are tried in this order; the first that yields a result wins:
    1. Closed <final>...</final> tags (last occurrence)
    2. Closed <answer>...</answer> tags (last occurrence)
    3. The last \\boxed{...}, with nested braces balanced (skipped if the
       braces never close or the content is empty)
    4. The last ```json {...} ``` fenced block, else the last bare
       {"response" ...} object
    5. An unterminated <final> tag, then an unterminated <answer> tag,
       taking everything up to the end of the response

    Returns the stripped answer text, or None if nothing matched.
    """
    # 1-2. Properly closed tags; the last one in the response is used.
    for closed_pattern in (r'<final>(.*?)</final>', r'<answer>(.*?)</answer>'):
        closed = re.findall(closed_pattern, response, re.DOTALL)
        if closed:
            return closed[-1].strip()

    # 3. Last \boxed{...}: walk forward counting braces so that nested content
    #    such as \boxed{\frac{1}{2}} is captured whole.
    boxed_openers = list(re.finditer(r'\\boxed\{', response))
    if boxed_openers:
        start = boxed_openers[-1].end()
        depth = 1
        for pos in range(start, len(response)):
            ch = response[pos]
            if ch == '{':
                depth += 1
            elif ch == '}':
                depth -= 1
                if depth == 0:
                    content = response[start:pos].strip()
                    if content:
                        return content
                    break

    # 4. JSON answers: fenced code block first, then a bare {"response"...}.
    fenced = re.findall(r'```json\s*(\{.*?\})\s*```', response, re.DOTALL)
    if fenced:
        return fenced[-1].strip()
    if '{"response"' in response:
        bare = re.findall(r'(\{"response".*?\})', response, re.DOTALL)
        if bare:
            return bare[-1].strip()

    # 5. Opening tag without a closing tag: take the rest of the response.
    for open_pattern in (r'<final>\s*(.*?)$', r'<answer>\s*(.*?)$'):
        dangling = re.search(open_pattern, response, re.DOTALL)
        if dangling:
            content = dangling.group(1).strip()
            if content:
                return content

    return None


def normalize_json_grid(grid_str: str) -> Any:
    """Normalize a grid string into a Python object suitable for comparison.

    Steps, in order:
    1. Empty/None input returns None.
    2. If the text looks like a ``{"response": ...}`` wrapper, unwrap it:
       via ``json.loads`` when it is valid JSON, otherwise via regex
       extraction of a ``[[...]]`` array or of the quoted response body.
    3. If the (unwrapped) text starts with '[', try to parse it as JSON,
       retrying once with ``\\"`` replaced by ``"``.
    4. Otherwise parse line by line: a line starting with '[' is parsed as
       JSON, any other line is split on commas with surrounding quotes
       stripped from each cell. Blank lines are skipped.
    5. If line-by-line parsing fails, the stripped string itself is returned.

    Returns:
        None, the parsed JSON value, a list of rows, or the raw string.
    """
    if not grid_str:
        return None
    grid_str = grid_str.strip()

    # Unwrap {"response": "..."} JSON wrapper (from puzzle_mode training)
    # The model often produces invalid JSON like {"response": "[["4", "8"...]]"}
    # where inner quotes break the outer JSON string, so we use regex extraction
    if grid_str.startswith('{') and '"response"' in grid_str:
        # Try valid JSON first
        try:
            obj = json.loads(grid_str)
            if isinstance(obj, dict) and 'response' in obj:
                # .strip() raises AttributeError when the response value is not
                # a string (e.g. an already-parsed list); that case is handled
                # by the regex fallback below.
                grid_str = obj['response'].strip()
        except (json.JSONDecodeError, AttributeError):
            # Fallback: extract content from {"response": "..."} wrapper
            import re
            # Try JSON array format first: [[ ... ]]
            # Greedy match: spans from the first '[[' to the last ']]'.
            grid_match = re.search(r'(\[\[.*\]\])', grid_str, re.DOTALL)
            if grid_match:
                grid_str = grid_match.group(1).strip()
            else:
                # ASCII grid format: extract content between "response": " and trailing "}
                # Handles literal newlines inside JSON string (model emits \n instead of \\n)
                resp_match = re.search(r'"response"\s*:\s*"(.*)"', grid_str, re.DOTALL)
                if resp_match:
                    content = resp_match.group(1).strip()
                    # Unescape any JSON escape sequences that did make it through
                    content = content.replace('\\n', '\n').replace('\\"', '"').replace('\\\\', '\\')
                    grid_str = content.strip()

    # Whole-string JSON parse for array-looking input.
    if grid_str.startswith('['):
        try:
            return json.loads(grid_str)
        except json.JSONDecodeError:
            # Try cleaning escaped quotes (e.g. from regex extraction of escaped JSON)
            cleaned = grid_str.replace('\\"', '"')
            if cleaned != grid_str:
                try:
                    return json.loads(cleaned)
                except json.JSONDecodeError:
                    pass

    # Line-oriented fallback: one grid row per non-blank line.
    lines = grid_str.strip().split('\n')
    if lines:
        try:
            grid = []
            for line in lines:
                line = line.strip()
                if line:
                    if line.startswith('['):
                        row = json.loads(line)
                    else:
                        # Comma-separated cells; strip whitespace and quotes.
                        row = [cell.strip().strip('"\'') for cell in line.split(',')]
                    grid.append(row)
            return grid
        except (json.JSONDecodeError, ValueError):
            pass
    # Nothing parsed: hand back the stripped text so callers can still compare
    # it by equality.
    return grid_str


def calculate_partial_correctness(predicted: List, solution: List, problem: Any = None) -> float:
    """Return the fraction of scored cells the prediction gets right.

    Cells are compared by their string form. When a problem grid is supplied,
    cells whose problem value already equals the solution value are treated
    as given and excluded from the score.

    Returns:
        0.0 if either grid is empty or the shapes differ (row count or any row
        length); 1.0 if no cell is left to score; otherwise correct / scored.
    """
    if not predicted or not solution:
        return 0.0
    if len(predicted) != len(solution):
        return 0.0

    # Resolve the optional problem grid: used as-is if it is a list, otherwise
    # parsed from its string form and kept only if parsing produced a list.
    givens = None
    if isinstance(problem, list):
        givens = problem
    elif problem is not None:
        parsed = normalize_json_grid(str(problem))
        if isinstance(parsed, list):
            givens = parsed

    scored = 0
    hits = 0
    for r, (pred_row, sol_row) in enumerate(zip(predicted, solution)):
        if len(pred_row) != len(sol_row):
            return 0.0
        row_has_givens = bool(givens) and r < len(givens)
        for c, (pred_cell, sol_cell) in enumerate(zip(pred_row, sol_row)):
            # Skip cells the problem statement already reveals.
            if row_has_givens and c < len(givens[r]) and str(givens[r][c]) == str(sol_cell):
                continue
            scored += 1
            if str(pred_cell) == str(sol_cell):
                hits += 1

    if scored > 0:
        return hits / scored
    return 1.0


def compare_puzzle_solutions(predicted: str, ground_truth: Any, problem: Any = None) -> Tuple[bool, float]:
    """Compare a predicted puzzle grid with the ground truth.

    Both sides are normalized to Python objects before comparison; a ground
    truth that is already a list is used as-is.

    Returns:
        (exact_match, partial_score). A missing prediction gives (False, 0.0).
        When both grids are lists the partial score is the per-cell score;
        otherwise (or if per-cell scoring raises) it is 1.0 for an exact match
        and 0.0 if not.
    """
    if predicted is None:
        return False, 0.0

    pred_grid = normalize_json_grid(predicted)
    if isinstance(ground_truth, list):
        gt_grid = ground_truth
    else:
        gt_grid = normalize_json_grid(str(ground_truth))

    exact_match = pred_grid == gt_grid
    all_or_nothing = 1.0 if exact_match else 0.0

    both_are_grids = isinstance(pred_grid, list) and isinstance(gt_grid, list)
    if not both_are_grids:
        return exact_match, all_or_nothing

    try:
        partial_score = calculate_partial_correctness(pred_grid, gt_grid, problem)
    except Exception:
        # Malformed rows (e.g. non-sequence rows) fall back to all-or-nothing.
        partial_score = all_or_nothing

    return exact_match, partial_score


def normalize_math_answer(answer: str) -> str:
    """Normalize a math answer string for exact-match comparison.

    - Empty/None input becomes "".
    - Surrounding whitespace is stripped.
    - If the text contains ``\\boxed{...}``, the content of the first such
      box replaces the answer. Braces are balanced, so nested content like
      ``\\boxed{\\frac{1}{2}}`` is kept whole rather than cut at the first '}'.
    - A purely numeric answer has its leading zeros removed ("007" -> "7").

    Args:
        answer: Raw answer text (model prediction or ground truth).

    Returns:
        The normalized answer string.
    """
    if not answer:
        return ""
    answer = answer.strip()

    opener = re.search(r'\\boxed\{', answer)
    if opener:
        start = opener.end()
        depth = 1
        for pos in range(start, len(answer)):
            ch = answer[pos]
            if ch == '{':
                depth += 1
            elif ch == '}':
                depth -= 1
                if depth == 0:
                    answer = answer[start:pos].strip()
                    break
        else:
            # Braces never balanced: fall back to cutting at the first '}',
            # which is what the simple pattern has always done.
            boxed = re.search(r'\\boxed\{([^}]*)\}', answer)
            if boxed:
                answer = boxed.group(1).strip()

    try:
        if answer.isdigit():
            answer = str(int(answer))
    except ValueError:
        # str.isdigit() accepts characters such as superscripts that int()
        # rejects; leave those answers untouched.
        pass
    return answer


def extract_gsm8k_answer(answer_text: str) -> str:
    """Return the final GSM8K answer: the text after the last '####' marker.

    If the marker is absent the whole text is returned. The result is
    stripped of surrounding whitespace in both cases.
    """
    # split() yields a single element when the marker is absent, so the last
    # piece is the right value either way.
    *_, tail = answer_text.split('####')
    return tail.strip()


def compare_math_solutions(predicted: str, ground_truth: str, dataset_format: str) -> Tuple[bool, float]:
    """Compare a predicted math answer with the ground truth.

    The prediction is always normalized. The ground truth is reduced to the
    text after '####' for 'gsm8k', and normalized like the prediction for
    every other format.

    Returns:
        (exact_match, score) where score is 1.0 on a match and 0.0 otherwise.
        A missing prediction gives (False, 0.0).
    """
    if predicted is None:
        return False, 0.0

    pred_normalized = normalize_math_answer(predicted)
    if dataset_format == 'gsm8k':
        gt_normalized = extract_gsm8k_answer(ground_truth)
    else:
        gt_normalized = normalize_math_answer(ground_truth)

    is_match = pred_normalized == gt_normalized
    score = 1.0 if is_match else 0.0
    return is_match, score


# =============================================================================
# Dataset Loading
# =============================================================================

def load_eval_dataset(
    dataset_name: str,
    max_samples: Optional[int] = None,
    puzzle_templates: Dict[str, str] = None,
    prompt_style: str = 'rsft'
) -> Tuple[List[str], List[dict], str]:
    """
    Load an evaluation dataset and build one chat prompt per sample.

    The 'test' split is tried first; if loading it raises ValueError the
    'train' split is used instead. The dataset format and puzzle type are
    detected from the first sample and applied to every sample.

    Args:
        dataset_name: HuggingFace dataset name. A ``name:config`` suffix
            selects a dataset config (not applied to names starting with '/').
        max_samples: If set and smaller than the dataset, keep only the first
            ``max_samples`` rows.
        puzzle_templates: Dict of puzzle type -> template text.
        prompt_style: Prompt style passed to build_chat_prompt.

    Returns: (prompts, ground_truth_records, dataset_format). An empty dataset
        yields ([], [], 'unknown').

    Raises:
        ValueError: If the dataset is a puzzle dataset and no template is
            available for its puzzle type.
    """
    from datasets import load_dataset

    if not puzzle_templates:
        puzzle_templates = {}

    # Handle dataset:config syntax
    if ':' in dataset_name and not dataset_name.startswith('/'):
        ds_path, ds_config = dataset_name.rsplit(':', 1)
        print(f"Loading dataset '{ds_path}' with config '{ds_config}'")
        load_args = (ds_path, ds_config)
    else:
        load_args = (dataset_name,)

    try:
        ds = load_dataset(*load_args, split="test")
    except ValueError:
        ds = load_dataset(*load_args, split="train")

    if max_samples and max_samples < len(ds):
        ds = ds.select(range(max_samples))

    if len(ds) == 0:
        return [], [], 'unknown'

    first_row = ds[0]
    dataset_format = detect_dataset_format(first_row)
    puzzle_type = detect_puzzle_type(first_row, dataset_name)

    print(f"  Dataset: {dataset_name}")
    print(f"  Format: {dataset_format}, Puzzle type: {puzzle_type}")
    print(f"  Samples: {len(ds)}, Prompt style: {prompt_style}")

    # Puzzle prompts are meaningless without their template, so fail early.
    if dataset_format == 'puzzle' and puzzle_type not in puzzle_templates:
        raise ValueError(
            f"Puzzle dataset '{dataset_name}' (type: {puzzle_type}) requires template.\n"
            f"Use --puzzle_templates_json '{{\"<{puzzle_type}>\": \"path/to/template.txt\"}}'"
        )

    is_gsm8k = dataset_format == 'gsm8k'
    prompts = []
    ground_truth_records = []

    for sample in ds:
        user_content = get_user_content(sample, puzzle_type, puzzle_templates, dataset_format)
        prompts.append(build_chat_prompt(user_content, prompt_style))

        if is_gsm8k:
            solution = sample.get('answer', '')
            problem = sample.get('question', '')
        else:
            solution = sample.get('solution', sample.get('answer', ''))
            problem = sample.get('problem', sample.get('question', ''))

        record = {
            'solution': solution,
            'problem': problem,
            'format': dataset_format,
            'puzzle_type': puzzle_type,
        }
        ground_truth_records.append(record)

    return prompts, ground_truth_records, dataset_format


# =============================================================================
# Evaluation Functions
# =============================================================================

def evaluate_responses(
    outputs: List, 
    ground_truth_records: List[dict], 
    dataset_format: str, 
    show_examples: int = 3,
    prompts: List[str] = None,
    return_details: bool = False,
    n_samples: int = 1
) -> dict:
    """Score model outputs against ground truth records.

    A prompt counts as correct if ANY of its sampled completions is an exact
    match. For each prompt the "best" completion is kept for reporting: the
    first exact match, otherwise the completion with the highest partial
    score, otherwise the first completion.

    NOTE(review): with n_samples > 1 the result is labelled 'pass@1' although
    the any-of-n rule above is what is computed - confirm the intended metric.

    Args:
        outputs: vLLM output objects; each exposes ``.outputs`` whose items
            have a ``.text`` attribute.
        ground_truth_records: Records with 'solution', 'format' and optionally
            'problem', aligned with ``outputs``.
        dataset_format: 'puzzle', 'aime24', 'gsm8k', etc.
        show_examples: Number of leading examples to print.
        prompts: Optional prompts, attached to the per-example details.
        return_details: If True, include a 'details' list in the result.
        n_samples: Number of samples per prompt.

    Returns:
        dict with 'accuracy', 'correct', 'total', 'format'; plus 'n_samples'
        and 'eval_method' when n_samples > 1, 'partial_correctness_avg' for
        puzzle datasets with at least one score, and 'details' on request.
    """
    def _score(candidate, solution, problem, fmt):
        # Puzzle grids get cell-level scoring; everything else is a math string.
        if fmt == 'puzzle':
            return compare_puzzle_solutions(candidate, solution, problem)
        return compare_math_solutions(candidate, str(solution), fmt)

    total = len(ground_truth_records)
    correct = 0
    partial_scores = []
    details = [] if return_details else None

    for idx, (output, gt_record) in enumerate(zip(outputs, ground_truth_records)):
        solution = gt_record['solution']
        problem = gt_record.get('problem')
        fmt = gt_record['format']

        # Track the best completion among the samples for this prompt.
        best_exact_match = False
        best_partial_score = 0.0
        best_predicted = None
        best_response = None
        num_correct_samples = 0

        for completion in output.outputs:
            response_text = completion.text
            predicted = extract_answer(response_text)
            exact_match, partial_score = _score(predicted, solution, problem, fmt)

            if exact_match:
                num_correct_samples += 1
                if best_exact_match:
                    continue  # keep the first correct completion
                best_exact_match = True
                best_predicted = predicted
                best_response = response_text
                best_partial_score = partial_score
            elif not best_exact_match and partial_score > best_partial_score:
                best_partial_score = partial_score
                best_predicted = predicted
                best_response = response_text

        # Nothing stood out (no match, no positive partial score): report the
        # first completion.
        if best_response is None:
            best_response = output.outputs[0].text
            best_predicted = extract_answer(best_response)
            best_exact_match, best_partial_score = _score(best_predicted, solution, problem, fmt)

        if best_exact_match:
            correct += 1
        partial_scores.append(best_partial_score)

        if return_details:
            if isinstance(solution, (list, dict)):
                ground_truth_value = solution
            else:
                ground_truth_value = str(solution)
            detail = {
                'idx': idx,
                'problem': str(problem) if problem else None,
                'ground_truth': ground_truth_value,
                'response': best_response,
                'extracted_answer': best_predicted,
                'exact_match': best_exact_match,
                'partial_score': best_partial_score,
            }
            if n_samples > 1:
                detail['n_samples'] = n_samples
                detail['num_correct_samples'] = num_correct_samples
            if prompts and idx < len(prompts):
                detail['prompt'] = prompts[idx]
            details.append(detail)

        if idx < show_examples:
            sol_text = str(solution)
            sol_preview = sol_text[:100] + ('...' if len(sol_text) > 100 else '')
            pred_suffix = '...' if best_predicted and len(best_predicted) > 100 else ''
            pred_preview = str(best_predicted)[:100] + pred_suffix
            print(f"\n  Example {idx + 1}:")
            print(f"    Ground truth: {sol_preview}")
            print(f"    Predicted: {pred_preview}")
            if n_samples > 1:
                print(f"    Exact match: {best_exact_match}, Partial: {best_partial_score:.4f} ({num_correct_samples}/{n_samples} samples correct)")
            else:
                print(f"    Exact match: {best_exact_match}, Partial: {best_partial_score:.4f}")

    accuracy = correct / total if total > 0 else 0.0
    result = {'accuracy': accuracy, 'correct': correct, 'total': total, 'format': dataset_format}
    if n_samples > 1:
        result['n_samples'] = n_samples
        result['eval_method'] = 'pass@1'
    if dataset_format == 'puzzle' and partial_scores:
        result['partial_correctness_avg'] = sum(partial_scores) / len(partial_scores)
    if return_details:
        result['details'] = details
    return result


def gpu_eval_worker(args_tuple, result_queue):
    """Persistent GPU worker: loads vLLM once, evaluates assigned (checkpoint, dataset) jobs.

    Each worker receives a list of (checkpoint_or_None, dataset_name) pairs, groups them
    by checkpoint to minimize LoRA adapter switches, and sends per-job results via queue.

    Each dataset is loaded and turned into prompts at most once per worker and
    reused for every checkpoint that is evaluated on it.

    Args:
        args_tuple: 11-tuple of (gpu_id, base_model, jobs, max_new_tokens,
            max_samples, puzzle_templates_json, max_model_len, prompt_style,
            temperature, dump_generations, n_samples).
        result_queue: Queue receiving one
            ``(checkpoint_key, dataset_name, results, generation_details)``
            tuple per job. ``checkpoint_key`` is the checkpoint path, or
            "baseline_no_lora" for jobs without a checkpoint;
            ``generation_details`` is None unless dump_generations is set.
    """
    (gpu_id, base_model, jobs, max_new_tokens, max_samples, puzzle_templates_json,
     max_model_len, prompt_style, temperature, dump_generations, n_samples) = args_tuple

    # Pin this process to its GPU before vLLM is imported below.
    os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_id)

    from vllm import LLM, SamplingParams
    from vllm.lora.request import LoRARequest
    from collections import OrderedDict

    # Group jobs by checkpoint for efficient LoRA switching
    job_groups = OrderedDict()
    for ckpt, ds_name in jobs:
        job_groups.setdefault(ckpt, []).append(ds_name)

    # LoRA support is only enabled when at least one job uses an adapter.
    has_lora = any(ckpt is not None for ckpt in job_groups)

    print(f"[GPU {gpu_id}] Loading vLLM (enable_lora={has_lora}), {len(jobs)} jobs across {len(job_groups)} checkpoint(s)")

    puzzle_templates = load_puzzle_templates(puzzle_templates_json) if puzzle_templates_json else {}

    llm_kwargs = dict(model=base_model, enable_lora=has_lora, tensor_parallel_size=1, trust_remote_code=True, max_model_len=max_model_len)
    if has_lora:
        llm_kwargs['max_lora_rank'] = 128
    llm = LLM(**llm_kwargs)
    sampling_params = SamplingParams(temperature=temperature, max_tokens=max_new_tokens, stop=["<|im_end|>", "<|endoftext|>"], n=n_samples)

    # Prompts and ground truth depend only on the dataset, not the checkpoint,
    # so build them once per dataset instead of once per (checkpoint, dataset).
    dataset_cache = {}

    lora_id_counter = 1
    for ckpt, datasets in job_groups.items():
        ckpt_key = ckpt if ckpt is not None else "baseline_no_lora"
        ckpt_display = os.path.basename(ckpt) if ckpt is not None else "baseline"

        lora_request = None
        if ckpt is not None:
            # Each adapter gets its own integer id within this worker.
            lora_request = LoRARequest("adapter", lora_id_counter, ckpt)
            lora_id_counter += 1
            print(f"[GPU {gpu_id}] Loaded LoRA: {ckpt_display}")
        else:
            print(f"[GPU {gpu_id}] Using base model (no LoRA)")

        for ds_name in datasets:
            if ds_name not in dataset_cache:
                dataset_cache[ds_name] = load_eval_dataset(ds_name, max_samples, puzzle_templates, prompt_style)
            prompts, ground_truth_records, dataset_format = dataset_cache[ds_name]
            total_samples = len(prompts) * n_samples
            print(f"[GPU {gpu_id}] {ckpt_display} | {ds_name}: generating {total_samples} responses...")

            if lora_request is not None:
                outputs = llm.generate(prompts, sampling_params, lora_request=lora_request)
            else:
                outputs = llm.generate(prompts, sampling_params)

            ds_results = evaluate_responses(outputs, ground_truth_records, dataset_format, show_examples=0,
                                              prompts=prompts, return_details=dump_generations, n_samples=n_samples)

            # Generations travel separately from the metrics dict.
            gen_details = None
            if dump_generations and 'details' in ds_results:
                gen_details = ds_results.pop('details')

            acc, correct, total = ds_results['accuracy'], ds_results['correct'], ds_results['total']
            partial_str = f", partial: {ds_results['partial_correctness_avg']:.4f}" if 'partial_correctness_avg' in ds_results else ""
            if n_samples > 1:
                print(f"[GPU {gpu_id}] {ckpt_display} | {ds_name}: pass@1={acc:.4f} ({correct}/{total}){partial_str}")
            else:
                print(f"[GPU {gpu_id}] {ckpt_display} | {ds_name}: {acc:.4f} ({correct}/{total}){partial_str}")

            result_queue.put((ckpt_key, ds_name, ds_results, gen_details))


def evaluate_jobs_parallel(args, checkpoints, puzzle_templates_json):
    """Distribute M checkpoints x N datasets across K GPUs.

    Jobs are chunked (not round-robin) so same-checkpoint jobs stay on the same
    GPU, minimizing LoRA adapter switches. Each GPU worker loads vLLM once and
    processes its assigned jobs sequentially. Only as many workers are started
    as there are non-empty job chunks.

    Args:
        args: Parsed CLI namespace. Reads eval_datasets, num_gpus, base_model,
            max_new_tokens, max_samples, max_model_len, prompt_style,
            temperature, dump_generations and (optionally) n_samples.
        checkpoints: list of lora_path strings, or None for baseline.
        puzzle_templates_json: Raw JSON string forwarded to each worker.

    Returns:
        ``{checkpoint_key: {dataset_name: metrics}}``, plus a '_generations'
        entry when any worker returned generation details. Returns ``{}`` when
        there are no jobs. If workers exit before reporting every job, a
        warning is written to stderr and the results collected so far are
        returned.

    Raises:
        ValueError: If there are jobs to run but no GPU to run them on.
    """
    import math
    import queue as queue_module

    n_samples = getattr(args, 'n_samples', 1)
    eval_datasets = args.eval_datasets

    # Build all jobs sorted by checkpoint (chunking keeps same-checkpoint jobs together)
    all_jobs = [(ckpt, ds) for ckpt in checkpoints for ds in eval_datasets]
    if not all_jobs:
        print("Parallel evaluation: no (checkpoint, dataset) jobs to run")
        return {}

    visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES", None)
    if visible_devices:
        available_gpus = []
        for entry in visible_devices.split(","):
            entry = entry.strip()
            if not entry:
                continue  # tolerate stray commas
            try:
                available_gpus.append(int(entry))
            except ValueError:
                # Non-numeric device identifiers are passed through unchanged;
                # the worker stringifies the id anyway.
                available_gpus.append(entry)
    else:
        available_gpus = list(range(args.num_gpus))
    num_gpus = min(args.num_gpus, len(available_gpus))
    if num_gpus < 1:
        raise ValueError(
            f"Parallel evaluation needs at least one GPU (num_gpus={args.num_gpus}, "
            f"visible devices={available_gpus})"
        )

    # Chunk jobs across GPUs. Slicing by stride yields only non-empty chunks,
    # so no worker is started (and no model loaded) without work to do.
    chunk_size = math.ceil(len(all_jobs) / min(num_gpus, len(all_jobs)))
    gpu_jobs = [all_jobs[start:start + chunk_size] for start in range(0, len(all_jobs), chunk_size)]
    num_workers = len(gpu_jobs)

    ckpt_labels = [os.path.basename(c) if c else "baseline" for c in checkpoints]
    print(f"Parallel evaluation: {len(checkpoints)} checkpoint(s) x {len(eval_datasets)} dataset(s) = {len(all_jobs)} jobs on {num_workers} GPUs")
    print(f"  Checkpoints: {ckpt_labels}")
    print(f"  Datasets: {eval_datasets}")

    for i, jobs in enumerate(gpu_jobs):
        ckpts_on_gpu = list(dict.fromkeys((os.path.basename(c) if c else "baseline") for c, _ in jobs))
        print(f"  GPU {available_gpus[i]}: {len(jobs)} jobs ({', '.join(ckpts_on_gpu)})")

    tasks = [
        (available_gpus[i], args.base_model, jobs,
         args.max_new_tokens, args.max_samples, puzzle_templates_json,
         args.max_model_len, args.prompt_style, args.temperature,
         args.dump_generations, n_samples)
        for i, jobs in enumerate(gpu_jobs)
    ]

    # 'spawn' gives every worker a fresh interpreter.
    mp_ctx = multiprocessing.get_context('spawn')
    result_queue = mp_ctx.Queue()

    processes = []
    for task in tasks:
        p = mp_ctx.Process(target=gpu_eval_worker, args=(task, result_queue))
        p.daemon = False
        processes.append(p)

    for p in processes:
        p.start()

    # Collect all results. Poll with a timeout instead of blocking forever so
    # that a crashed worker cannot hang the parent.
    poll_interval = 5.0
    results = {}
    all_generations = {}
    pending = len(all_jobs)
    while pending > 0:
        try:
            item = result_queue.get(timeout=poll_interval)
        except queue_module.Empty:
            if any(p.is_alive() for p in processes):
                continue
            # Every worker has exited: pick up anything still in flight, and
            # stop waiting once the queue stays empty.
            try:
                item = result_queue.get(timeout=poll_interval)
            except queue_module.Empty:
                break
        ckpt_key, ds_name, ds_results, gen_details = item
        results.setdefault(ckpt_key, {})[ds_name] = ds_results
        if gen_details is not None:
            all_generations.setdefault(ckpt_key, {})[ds_name] = gen_details
        pending -= 1

    for p in processes:
        p.join()

    if pending > 0:
        exit_codes = [p.exitcode for p in processes]
        print(
            f"Warning: {pending} of {len(all_jobs)} jobs produced no result "
            f"(worker exit codes: {exit_codes}); returning partial results",
            file=sys.stderr,
        )

    if all_generations:
        results['_generations'] = all_generations

    return results


def evaluate_sequential(args, checkpoints, puzzle_templates):
    """Run every checkpoint over every dataset, one after another, on one engine.

    A single LoRA-enabled vLLM engine is created for ``args.base_model`` and
    each checkpoint is attached per ``generate`` call through a ``LoRARequest``.

    Args:
        args: Parsed CLI namespace. Reads ``base_model``, ``max_model_len``,
            ``temperature``, ``max_new_tokens``, ``eval_datasets``,
            ``max_samples``, ``prompt_style``, ``dump_generations`` and,
            if present, ``n_samples`` (defaults to 1).
        checkpoints: LoRA checkpoint paths, evaluated in the given order.
        puzzle_templates: Template mapping forwarded to ``load_eval_dataset``.

    Returns:
        Dict ``{checkpoint_path: {dataset_name: metrics}}``. When
        ``args.dump_generations`` is set and the metrics carry ``'details'``,
        those are moved to ``results['_generations'][checkpoint][dataset]``.
    """
    from vllm import LLM, SamplingParams
    from vllm.lora.request import LoRARequest

    print(f"Loading vLLM with base model: {args.base_model}")

    engine = LLM(
        model=args.base_model,
        enable_lora=True,
        max_lora_rank=128,
        tensor_parallel_size=1,
        trust_remote_code=True,
        max_model_len=args.max_model_len,
    )

    # n > 1 draws several completions per prompt (pass@k style evaluation).
    n_samples = getattr(args, 'n_samples', 1)
    decoding = SamplingParams(
        temperature=args.temperature,
        max_tokens=args.max_new_tokens,
        stop=["<|im_end|>", "<|endoftext|>"],
        n=n_samples,
    )

    if n_samples > 1:
        print(f"\nPass@1 evaluation mode: {n_samples} samples per query, temperature={args.temperature}")

    # Load each dataset once; the same prompts are reused for all checkpoints.
    print(f"\nLoading {len(args.eval_datasets)} evaluation datasets...")
    eval_data = {}
    for name in args.eval_datasets:
        loaded_prompts, loaded_truth, loaded_format = load_eval_dataset(
            name, args.max_samples, puzzle_templates, args.prompt_style
        )
        eval_data[name] = (loaded_prompts, loaded_truth, loaded_format)

    banner = "=" * 60
    metric_prefix = "pass@1=" if n_samples > 1 else ""
    n_checkpoints = len(checkpoints)

    results = {}
    for lora_id, ckpt in enumerate(checkpoints, start=1):
        print(f"\n{banner}")
        print(f"Evaluating checkpoint {lora_id}/{n_checkpoints}: {os.path.basename(ckpt)}")
        print(banner)

        # Each checkpoint gets its own integer id (1-based position in the list).
        adapter = LoRARequest("adapter", lora_id, ckpt)

        results[ckpt] = {}

        for ds_name, (prompts, ground_truth_records, dataset_format) in eval_data.items():
            n_prompts = len(prompts)
            print(f"\n  Generating {n_prompts * n_samples} responses for {ds_name} ({n_prompts} prompts x {n_samples} samples)...")
            outputs = engine.generate(prompts, decoding, lora_request=adapter)

            ds_results = evaluate_responses(
                outputs,
                ground_truth_records,
                dataset_format,
                show_examples=3,
                prompts=prompts,
                return_details=args.dump_generations,
                n_samples=n_samples,
            )

            # Keep the bulky per-sample details out of the metrics dict.
            if args.dump_generations and 'details' in ds_results:
                per_ckpt = results.setdefault('_generations', {}).setdefault(ckpt, {})
                per_ckpt[ds_name] = ds_results.pop('details')

            results[ckpt][ds_name] = ds_results

            acc = ds_results['accuracy']
            correct = ds_results['correct']
            total = ds_results['total']
            if 'partial_correctness_avg' in ds_results:
                partial_str = f", partial: {ds_results['partial_correctness_avg']:.4f}"
            else:
                partial_str = ""
            print(f"  {ds_name}: {metric_prefix}{acc:.4f} ({correct}/{total}){partial_str}")

    return results


# =============================================================================
# Main
# =============================================================================

def download_hf_lora(hf_lora_path: str, cache_dir: Optional[str] = None) -> str:
    """
    Download a LoRA checkpoint from the HuggingFace Hub and return its local path.

    Args:
        hf_lora_path: Hub path, either "user/repo" or "user/repo/sub/folder".
            The first two segments form the repo id; anything after them is
            treated as a subfolder inside the repo. Empty segments (e.g. from
            a trailing slash) are ignored.
        cache_dir: Optional cache directory. Defaults to
            "<system temp dir>/hf_lora_cache".

    Returns:
        Local path to the LoRA checkpoint directory.

    Raises:
        ValueError: If the download fails (the original error is chained as
            ``__cause__``), or if the downloaded directory contains neither
            adapter_model.safetensors nor adapter_model.bin.
    """
    from huggingface_hub import snapshot_download
    import tempfile

    # Drop empty segments so "user/repo/sub/" does not yield a "sub/" subfolder
    # (which would produce broken "sub//*" download patterns).
    parts = [segment for segment in hf_lora_path.split("/") if segment]
    repo_id = "/".join(parts[:2])
    subfolder = "/".join(parts[2:]) or None

    print("Downloading LoRA from HuggingFace Hub:")
    print(f"  Repo: {repo_id}")
    if subfolder:
        print(f"  Subfolder: {subfolder}")

    cache_dir = cache_dir or os.path.join(tempfile.gettempdir(), "hf_lora_cache")
    os.makedirs(cache_dir, exist_ok=True)

    try:
        if subfolder:
            # Restrict the snapshot to the files under the requested subfolder.
            snapshot_root = snapshot_download(
                repo_id=repo_id,
                allow_patterns=[f"{subfolder}/*", f"{subfolder}/**/*"],
                cache_dir=cache_dir,
                local_dir_use_symlinks=False
            )
            local_path = os.path.join(snapshot_root, subfolder)
        else:
            # No subfolder given: fetch the entire repo.
            local_path = snapshot_download(
                repo_id=repo_id,
                cache_dir=cache_dir,
                local_dir_use_symlinks=False
            )
    except Exception as e:
        # Chain the cause so the underlying hub/network error stays in the traceback.
        raise ValueError(f"Failed to download LoRA from {hf_lora_path}: {e}") from e

    # Validate outside the try-block: a missing adapter file is not a download
    # failure and should not be reported as one.
    has_adapter = any(
        os.path.exists(os.path.join(local_path, fname))
        for fname in ("adapter_model.safetensors", "adapter_model.bin")
    )
    if not has_adapter:
        raise ValueError(
            f"No adapter_model.safetensors or adapter_model.bin found in {local_path}. "
            f"This doesn't appear to be a valid LoRA checkpoint."
        )

    print(f"  Downloaded to: {local_path}")
    return local_path


def evaluate_baseline(args, puzzle_templates):
    """Score the bare base model (no LoRA adapter) on every requested dataset.

    Args:
        args: Parsed CLI namespace. Reads ``base_model``, ``max_model_len``,
            ``temperature``, ``max_new_tokens``, ``eval_datasets``,
            ``max_samples``, ``prompt_style``, ``dump_generations`` and,
            if present, ``n_samples`` (defaults to 1).
        puzzle_templates: Template mapping forwarded to ``load_eval_dataset``.

    Returns:
        Dict ``{"baseline_no_lora": {dataset_name: metrics}}``. When
        ``args.dump_generations`` is set, a ``'_generations'`` entry of the
        form ``{"baseline_no_lora": {dataset_name: details}}`` is added.
    """
    from vllm import LLM, SamplingParams

    print(f"Loading vLLM with base model (no LoRA): {args.base_model}")

    n_samples = getattr(args, 'n_samples', 1)

    # LoRA support is switched off: nothing but the base weights is used.
    engine = LLM(
        model=args.base_model,
        enable_lora=False,
        tensor_parallel_size=1,
        trust_remote_code=True,
        max_model_len=args.max_model_len,
    )
    decoding = SamplingParams(
        temperature=args.temperature,
        max_tokens=args.max_new_tokens,
        stop=["<|im_end|>", "<|endoftext|>"],
        n=n_samples,
    )

    print(f"\nLoading {len(args.eval_datasets)} evaluation datasets...")
    print(f"Temperature: {args.temperature}")
    if n_samples > 1:
        print(f"Pass@1 evaluation: {n_samples} samples per query")

    eval_data = {}
    for name in args.eval_datasets:
        loaded_prompts, loaded_truth, loaded_format = load_eval_dataset(
            name, args.max_samples, puzzle_templates, args.prompt_style
        )
        eval_data[name] = (loaded_prompts, loaded_truth, loaded_format)

    ckpt_name = "baseline_no_lora"
    banner = "=" * 60
    print(f"\n{banner}")
    print("Evaluating baseline model (no LoRA)")
    print(banner)

    metrics_by_dataset = {}
    generations_by_dataset = {}  # filled only when dump_generations is on
    results = {ckpt_name: metrics_by_dataset}

    for ds_name, (prompts, ground_truth_records, dataset_format) in eval_data.items():
        n_prompts = len(prompts)
        print(f"\n  Generating {n_prompts * n_samples} responses for {ds_name} ({n_prompts} prompts x {n_samples} samples)...")
        outputs = engine.generate(prompts, decoding)

        ds_results = evaluate_responses(
            outputs,
            ground_truth_records,
            dataset_format,
            show_examples=3,
            prompts=prompts,
            return_details=args.dump_generations,
            n_samples=n_samples,
        )

        # Move per-sample details aside so the metrics dict stays small.
        if args.dump_generations and 'details' in ds_results:
            generations_by_dataset[ds_name] = ds_results.pop('details')

        metrics_by_dataset[ds_name] = ds_results

        acc = ds_results['accuracy']
        correct = ds_results['correct']
        total = ds_results['total']
        partial_str = (
            f", partial: {ds_results['partial_correctness_avg']:.4f}"
            if 'partial_correctness_avg' in ds_results
            else ""
        )
        print(f"  {ds_name}: {acc:.4f} ({correct}/{total}){partial_str}")

    if args.dump_generations:
        results['_generations'] = {ckpt_name: generations_by_dataset}

    return results


def evaluate_single_hf_lora(args, lora_path: str, puzzle_templates):
    """Score one LoRA checkpoint directory on every requested dataset.

    Args:
        args: Parsed CLI namespace. Reads ``base_model``, ``max_model_len``,
            ``temperature``, ``max_new_tokens``, ``eval_datasets``,
            ``max_samples``, ``prompt_style``, ``dump_generations`` and,
            if present, ``n_samples`` (defaults to 1).
        lora_path: Local path of the LoRA checkpoint (e.g. as returned by
            ``download_hf_lora``); also used as the key in the result dict.
        puzzle_templates: Template mapping forwarded to ``load_eval_dataset``.

    Returns:
        Dict ``{lora_path: {dataset_name: metrics}}``. When
        ``args.dump_generations`` is set, a ``'_generations'`` entry of the
        form ``{lora_path: {dataset_name: details}}`` is added.
    """
    from vllm import LLM, SamplingParams
    from vllm.lora.request import LoRARequest

    print(f"Loading vLLM with base model: {args.base_model}")

    n_samples = getattr(args, 'n_samples', 1)

    engine = LLM(
        model=args.base_model,
        enable_lora=True,
        max_lora_rank=128,
        tensor_parallel_size=1,
        trust_remote_code=True,
        max_model_len=args.max_model_len,
    )
    decoding = SamplingParams(
        temperature=args.temperature,
        max_tokens=args.max_new_tokens,
        stop=["<|im_end|>", "<|endoftext|>"],
        n=n_samples,
    )

    print(f"\nLoading {len(args.eval_datasets)} evaluation datasets...")
    print(f"Temperature: {args.temperature}")
    if n_samples > 1:
        print(f"Pass@1 evaluation: {n_samples} samples per query")

    eval_data = {}
    for name in args.eval_datasets:
        loaded_prompts, loaded_truth, loaded_format = load_eval_dataset(
            name, args.max_samples, puzzle_templates, args.prompt_style
        )
        eval_data[name] = (loaded_prompts, loaded_truth, loaded_format)

    ckpt_name = os.path.basename(lora_path)
    banner = "=" * 60
    print(f"\n{banner}")
    print(f"Evaluating checkpoint: {ckpt_name}")
    print(banner)

    adapter = LoRARequest("adapter", 1, lora_path)
    metric_prefix = "pass@1=" if n_samples > 1 else ""

    metrics_by_dataset = {}
    generations_by_dataset = {}  # filled only when dump_generations is on
    results = {lora_path: metrics_by_dataset}

    for ds_name, (prompts, ground_truth_records, dataset_format) in eval_data.items():
        n_prompts = len(prompts)
        print(f"\n  Generating {n_prompts * n_samples} responses for {ds_name} ({n_prompts} prompts x {n_samples} samples)...")
        outputs = engine.generate(prompts, decoding, lora_request=adapter)

        ds_results = evaluate_responses(
            outputs,
            ground_truth_records,
            dataset_format,
            show_examples=3,
            prompts=prompts,
            return_details=args.dump_generations,
            n_samples=n_samples,
        )

        # Move per-sample details aside so the metrics dict stays small.
        if args.dump_generations and 'details' in ds_results:
            generations_by_dataset[ds_name] = ds_results.pop('details')

        metrics_by_dataset[ds_name] = ds_results

        acc = ds_results['accuracy']
        correct = ds_results['correct']
        total = ds_results['total']
        partial_str = (
            f", partial: {ds_results['partial_correctness_avg']:.4f}"
            if 'partial_correctness_avg' in ds_results
            else ""
        )
        print(f"  {ds_name}: {metric_prefix}{acc:.4f} ({correct}/{total}){partial_str}")

    if args.dump_generations:
        results['_generations'] = {lora_path: generations_by_dataset}

    return results


def _has_adapter_weights(path):
    """Return True if *path* is a directory containing LoRA adapter weights (.safetensors or .bin)."""
    return os.path.isdir(path) and any(
        os.path.exists(os.path.join(path, fname))
        for fname in ("adapter_model.safetensors", "adapter_model.bin")
    )


def main():
    """CLI entry point: evaluate checkpoints, print a summary, and persist results.

    Three mutually exclusive modes are selected from the arguments, in this
    priority order:

    1. ``--no_lora``: evaluate the base model only.
    2. ``--hf_lora_path``: download one LoRA checkpoint from the Hub and evaluate it.
    3. ``--checkpoint_dir``: evaluate either that directory itself (if it holds
       adapter weights) or every matching checkpoint sub-directory.

    After evaluation the function prints a summary, optionally writes the
    generations to ``<output_dir>/evals/<dataset>/generations_<timestamp>.jsonl``,
    writes ``<output_dir>/eval_results_<timestamp>.json`` and runs the optional
    WandB / HuggingFace uploads.

    Returns:
        The results dict ``{checkpoint: {dataset: metrics}}`` (the
        ``'_generations'`` entry is removed once it has been written to disk).
    """
    parser = argparse.ArgumentParser(
        description="Evaluate LoRA checkpoints using vLLM (puzzles and math)",
        formatter_class=argparse.RawDescriptionHelpFormatter
    )
    parser.add_argument("--base_model", required=True, help="Base model name or path")
    parser.add_argument("--checkpoint_dir", default=None, help="Directory containing LoRA checkpoints")
    parser.add_argument("--hf_lora_path", default=None,
                        help="HuggingFace Hub LoRA path (e.g., user/repo/subfolder)")
    parser.add_argument("--no_lora", action="store_true",
                        help="Evaluate base model without LoRA (baseline)")
    parser.add_argument("--eval_datasets", nargs="+", required=True, help="HuggingFace dataset names")
    parser.add_argument("--max_new_tokens", type=int, default=8192, help="Maximum new tokens")
    parser.add_argument("--max_model_len", type=int, default=10000, help="Maximum model context length for vLLM")
    parser.add_argument("--max_samples", type=int, default=None, help="Maximum samples per dataset")
    parser.add_argument("--parallel", action="store_true", help="Parallel evaluation across GPUs")
    parser.add_argument("--num_gpus", type=int, default=4, help="Number of GPUs for parallel mode")
    parser.add_argument("--upload_best_to_hf", default=None, help="HF repo to upload best checkpoint")
    parser.add_argument("--checkpoint_pattern", default=None, help="Glob pattern for checkpoints")
    parser.add_argument("--wandb_project", default=None, help="WandB project name")
    parser.add_argument("--wandb_run_name", default=None, help="WandB run name")
    parser.add_argument("--upload_results_to_hf", default=None, help="HF repo for eval_results.json")
    parser.add_argument("--puzzle_templates_json", default=None,
                        help='JSON mapping: \'{"bridges": "prompts/bridges.txt"}\'')
    parser.add_argument("--output_dir", default=None,
                        help="Output directory for results (default: checkpoint_dir or current dir)")
    parser.add_argument("--prompt_style", default="rsft", choices=["rsft", "math", "math_cot", "none", "deepseek", "deepseek_r1", "qwen3"],
                        help="Prompt style: 'rsft' (system prompt with <answer> tags), 'math' (uses \\boxed{}), 'math_cot' (minimal CoT), 'none'/'deepseek' (Qwen2.5 format, no system), 'deepseek_r1' (native DeepSeek R1 format with <think>), 'qwen3' (Qwen3 thinking mode with <think> primer)")
    parser.add_argument("--temperature", type=float, default=0.0,
                        help="Sampling temperature (0.0 for greedy, >0 for sampling)")
    parser.add_argument("--n_samples", type=int, default=1,
                        help="Number of samples per query for pass@k evaluation (Light-R1 paper uses 64)")
    parser.add_argument("--dump_generations", action="store_true", default=True,
                        help="Save all generations to JSONL file (prompts, responses, ground truth)")
    parser.add_argument("--no_dump_generations", action="store_false", dest="dump_generations",
                        help="Disable saving generations to JSONL")
    args = parser.parse_args()

    # Validate arguments
    if not args.no_lora and not args.checkpoint_dir and not args.hf_lora_path:
        parser.error("Must specify --checkpoint_dir, --hf_lora_path, or --no_lora")

    # Load puzzle templates
    puzzle_templates = load_puzzle_templates(args.puzzle_templates_json) if args.puzzle_templates_json else {}

    # Determine output directory
    output_dir = args.output_dir or args.checkpoint_dir or "."
    os.makedirs(output_dir, exist_ok=True)

    # Mode 1: Baseline evaluation (no LoRA)
    if args.no_lora:
        print("="*60)
        print("BASELINE EVALUATION MODE (no LoRA)")
        print("="*60)
        if args.parallel:
            # The parallel path receives the raw JSON argument, not the parsed templates.
            results = evaluate_jobs_parallel(args, [None], args.puzzle_templates_json)
        else:
            results = evaluate_baseline(args, puzzle_templates)
        best_ckpt = "baseline_no_lora"
        checkpoint_scores = {best_ckpt: sum(r["accuracy"] for r in results[best_ckpt].values()) / len(results[best_ckpt])}

    # Mode 2: Single HF Hub LoRA checkpoint
    elif args.hf_lora_path:
        print("="*60)
        print("HUGGINGFACE HUB LORA EVALUATION MODE")
        print("="*60)
        # Download from HF Hub if needed
        lora_path = download_hf_lora(args.hf_lora_path)
        if args.parallel:
            results = evaluate_jobs_parallel(args, [lora_path], args.puzzle_templates_json)
        else:
            results = evaluate_single_hf_lora(args, lora_path, puzzle_templates)
        best_ckpt = lora_path
        checkpoint_scores = {lora_path: sum(r["accuracy"] for r in results[lora_path].values()) / len(results[lora_path])}

    # Mode 3: Local checkpoint directory (original behavior)
    else:
        if _has_adapter_weights(args.checkpoint_dir):
            # checkpoint_dir itself is a LoRA checkpoint - evaluate just this one
            print(f"Single checkpoint mode: {args.checkpoint_dir}")
            checkpoints = [args.checkpoint_dir]
        else:
            # Find checkpoints in directory
            if args.checkpoint_pattern:
                candidates = sorted(glob(os.path.join(args.checkpoint_dir, args.checkpoint_pattern)))
            else:
                candidates = sorted(
                    glob(os.path.join(args.checkpoint_dir, "lora_*")) +
                    glob(os.path.join(args.checkpoint_dir, "val_best_*"))
                )

            # Accept .safetensors and .bin adapters alike, matching the
            # single-checkpoint check above and the message printed below.
            checkpoints = [ckpt for ckpt in candidates if _has_adapter_weights(ckpt)]

        if not checkpoints:
            print(f"No valid checkpoints found in {args.checkpoint_dir}")
            print(f"  (Looking for adapter_model.safetensors or adapter_model.bin)")
            sys.exit(1)

        print(f"Found {len(checkpoints)} checkpoint(s) to evaluate:")
        for ckpt in checkpoints:
            print(f"  - {os.path.basename(ckpt)}")
        print()

        # Run evaluation
        if args.parallel:
            results = evaluate_jobs_parallel(args, checkpoints, args.puzzle_templates_json)
        else:
            results = evaluate_sequential(args, checkpoints, puzzle_templates)

        # Calculate checkpoint scores (mean accuracy over datasets)
        checkpoint_scores = {}
        for ckpt, ckpt_results in results.items():
            if ckpt == '_generations':
                continue  # Skip generations data
            avg_acc = sum(r["accuracy"] for r in ckpt_results.values()) / len(ckpt_results) if ckpt_results else 0
            checkpoint_scores[ckpt] = avg_acc
        best_ckpt = max(checkpoint_scores.keys(), key=lambda k: checkpoint_scores[k])

    # Print summary
    print(f"\n{'='*60}\nEVALUATION RESULTS SUMMARY\n{'='*60}")

    for ckpt, ckpt_results in sorted(results.items()):
        if ckpt == '_generations':
            continue  # Skip generations data
        ckpt_display = os.path.basename(ckpt) if ckpt != "baseline_no_lora" else "baseline_no_lora"
        print(f"\n{ckpt_display}:")
        for ds_name, ds_results in ckpt_results.items():
            partial_str = f" (partial: {ds_results['partial_correctness_avg']:.4f})" if 'partial_correctness_avg' in ds_results else ""
            print(f"  {ds_name}: {ds_results['accuracy']:.4f} ({ds_results['correct']}/{ds_results['total']}){partial_str}")

    best_ckpt_display = os.path.basename(best_ckpt) if best_ckpt != "baseline_no_lora" else "baseline_no_lora"
    print(f"\n{'='*60}\nBEST CHECKPOINT: {best_ckpt_display}")
    print(f"Average accuracy: {checkpoint_scores[best_ckpt]:.4f}")
    for ds_name, ds_results in results[best_ckpt].items():
        partial_str = f" (partial: {ds_results['partial_correctness_avg']:.4f})" if 'partial_correctness_avg' in ds_results else ""
        print(f"  {ds_name}: {ds_results['accuracy']:.4f}{partial_str}")
    print(f"{'='*60}")

    # Timestamp for unique filenames (never overwrite)
    from datetime import datetime
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")

    # Save generations if requested — per-dataset subdirectories
    if args.dump_generations and '_generations' in results:
        generations_data = results.pop('_generations')

        for ckpt_name, datasets in generations_data.items():
            ckpt_base = os.path.basename(ckpt_name) if ckpt_name != "baseline_no_lora" else ckpt_name
            for ds_name, details in datasets.items():
                # Create per-dataset subdirectory: evals/<dataset_slug>/
                ds_slug = ds_name.replace("/", "__")
                ds_dir = os.path.join(output_dir, "evals", ds_slug)
                os.makedirs(ds_dir, exist_ok=True)

                generations_path = os.path.join(ds_dir, f"generations_{timestamp}.jsonl")
                # Append, not truncate: the path depends only on dataset and
                # timestamp, so every checkpoint of this run writes to the same
                # file. Opening with 'w' would keep only the last checkpoint's
                # generations; each record carries its 'checkpoint' field.
                with open(generations_path, 'a') as f:
                    for detail in details:
                        record = {
                            'checkpoint': ckpt_base,
                            'dataset': ds_name,
                            **detail
                        }
                        f.write(json.dumps(record) + '\n')
                print(f"\nGenerations saved to: {generations_path}")

    # Save results — timestamped to never overwrite
    results_path = os.path.join(output_dir, f"eval_results_{timestamp}.json")
    serializable_results = {
        "checkpoints": {
            (os.path.basename(k) if k != "baseline_no_lora" else k): v
            for k, v in results.items()
            if k != '_generations'  # Exclude generations from main results
        },
        "best_checkpoint": best_ckpt_display,
        "best_average_accuracy": checkpoint_scores[best_ckpt],
        "eval_datasets": args.eval_datasets,
        "base_model": args.base_model,
        "mode": "baseline" if args.no_lora else ("hf_lora" if args.hf_lora_path else "local_checkpoints"),
    }
    if args.hf_lora_path:
        serializable_results["hf_lora_path"] = args.hf_lora_path
    if args.dump_generations:
        serializable_results["generations_file"] = f"evals/*/generations_{timestamp}.jsonl"
    with open(results_path, "w") as f:
        json.dump(serializable_results, f, indent=2)
    print(f"\nResults saved to: {results_path}")

    # WandB logging
    if args.wandb_project:
        _log_to_wandb(args, results, checkpoint_scores, best_ckpt, results_path)

    # HF uploads
    if args.upload_results_to_hf:
        _upload_results_to_hf(args, results_path)

    if args.upload_best_to_hf:
        _upload_best_checkpoint_to_hf(args, results, checkpoint_scores, best_ckpt)

    return results


def _log_to_wandb(args, results, checkpoint_scores, best_ckpt, results_path):
    """Log the evaluation summary to WandB (best-effort; never raises).

    Logs the best checkpoint's name and scores, a per-checkpoint results
    table, and the results JSON file as an artifact. Any failure (including a
    missing ``wandb`` package) is reported as a printed warning only.

    Args:
        args: Parsed CLI namespace. Reads ``wandb_project``, ``wandb_run_name``,
            ``checkpoint_dir``, ``base_model`` and ``eval_datasets``.
        results: Dict ``{checkpoint: {dataset_name: metrics}}``; a
            ``'_generations'`` entry, if present, is ignored.
        checkpoint_scores: Dict ``{checkpoint: average accuracy}``.
        best_ckpt: Key of the best checkpoint in ``results``.
        results_path: Path of the results JSON file to attach as an artifact.
    """
    try:
        import wandb

        # checkpoint_dir is None in --no_lora / --hf_lora_path mode; fall back
        # to the best checkpoint's name instead of crashing in basename(None).
        # normpath drops a trailing slash that would otherwise give "_eval".
        name_source = args.checkpoint_dir or best_ckpt
        run_name = args.wandb_run_name or f"{os.path.basename(os.path.normpath(name_source))}_eval"
        print(f"\nLogging results to WandB: {args.wandb_project}/{run_name}")

        wandb.init(
            project=args.wandb_project, name=run_name,
            config={"base_model": args.base_model, "eval_datasets": args.eval_datasets},
            job_type="evaluation"
        )

        wandb.log({
            "best_checkpoint": os.path.basename(best_ckpt),
            "best_average_accuracy": checkpoint_scores[best_ckpt]
        })

        for ds_name, ds_results in results[best_ckpt].items():
            # Turn "org/name-x:cfg" into a metric-friendly "name_x_cfg".
            ds_short = ds_name.split("/")[-1].replace("-", "_").replace(":", "_")
            wandb.log({f"best/{ds_short}_accuracy": ds_results["accuracy"]})
            if 'partial_correctness_avg' in ds_results:
                wandb.log({f"best/{ds_short}_partial": ds_results["partial_correctness_avg"]})

        # Results table: one row per checkpoint
        checkpoint_items = sorted(
            (ckpt, ckpt_results)
            for ckpt, ckpt_results in results.items()
            if ckpt != '_generations'  # generations are not metrics
        )
        table_data = []
        for ckpt, ckpt_results in checkpoint_items:
            row = {"checkpoint": os.path.basename(ckpt)}
            for ds_name, ds_results in ckpt_results.items():
                ds_short = ds_name.split("/")[-1].replace("-", "_").replace(":", "_")
                row[f"{ds_short}_accuracy"] = ds_results["accuracy"]
            row["average_accuracy"] = sum(r["accuracy"] for r in ckpt_results.values()) / len(ckpt_results) if ckpt_results else 0
            table_data.append(row)

        if table_data:
            # Use the union of all row keys and look values up by column name,
            # so rows with differing dataset sets cannot shift under the wrong header.
            columns = list(dict.fromkeys(key for row in table_data for key in row))
            wandb.log({"checkpoint_results": wandb.Table(
                columns=columns,
                data=[[row.get(col) for col in columns] for row in table_data]
            )})

        artifact = wandb.Artifact(name=f"eval_results_{run_name}", type="evaluation_results")
        artifact.add_file(results_path)
        wandb.log_artifact(artifact)
        wandb.finish()
        print("  ✓ Logged to WandB")
    except ImportError:
        print("Warning: wandb not installed")
    except Exception as e:
        print(f"Warning: WandB logging failed: {e}")


def _upload_results_to_hf(args, results_path):
    """Upload the results JSON to a HuggingFace model repo (best-effort; never raises).

    The local file at ``results_path`` is stored in the repo under the fixed
    name ``eval_results.json``. The target repo ``args.upload_results_to_hf``
    is created as private if it does not exist yet. Failures are printed,
    not raised.

    Args:
        args: Parsed CLI namespace. Reads ``upload_results_to_hf`` (repo id).
        results_path: Local path of the results JSON file.
    """
    try:
        from huggingface_hub import HfApi, create_repo
        api = HfApi()
        try:
            create_repo(repo_id=args.upload_results_to_hf, private=True, exist_ok=True, repo_type="model")
        except Exception as e:
            # Still attempt the upload (the repo may already exist), but
            # surface the reason instead of hiding it.
            print(f"Warning: could not create/verify repo {args.upload_results_to_hf}: {e}")
        api.upload_file(
            path_or_fileobj=results_path, path_in_repo="eval_results.json",
            repo_id=args.upload_results_to_hf, repo_type="model"
        )
        print(f"  ✓ Uploaded eval_results.json to {args.upload_results_to_hf}")
    except Exception as e:
        print(f"Error uploading to HF: {e}")


def _upload_best_checkpoint_to_hf(args, results, checkpoint_scores, best_ckpt):
    """Upload the best checkpoint folder to a HuggingFace model repo (best-effort; never raises).

    A ``README.md`` summarising the scores is written INTO the checkpoint
    directory, which is then uploaded to ``args.upload_best_to_hf`` under
    ``best_checkpoint_<first dataset short name>``. Failures are printed,
    not raised.

    Args:
        args: Parsed CLI namespace. Reads ``upload_best_to_hf`` (repo id),
            ``eval_datasets`` and ``base_model``.
        results: Dict ``{checkpoint: {dataset_name: metrics}}``.
        checkpoint_scores: Dict ``{checkpoint: average accuracy}``.
        best_ckpt: Key of the best checkpoint; must be a local directory for
            anything to be uploaded.
    """
    # Baseline runs use the label "baseline_no_lora" as best_ckpt: there is no
    # adapter directory to upload, so say so instead of failing on the README write.
    if not best_ckpt or not os.path.isdir(best_ckpt):
        print(f"Skipping best-checkpoint upload: {best_ckpt!r} is not a local checkpoint directory")
        return

    try:
        from huggingface_hub import HfApi, create_repo
        api = HfApi()
        try:
            create_repo(repo_id=args.upload_best_to_hf, private=True, exist_ok=True, repo_type="model")
        except Exception as e:
            # Still attempt the upload (the repo may already exist), but
            # surface the reason instead of hiding it.
            print(f"Warning: could not create/verify repo {args.upload_best_to_hf}: {e}")

        # The destination subfolder is named after the FIRST evaluation dataset only.
        ds_short = args.eval_datasets[0].split("/")[-1].replace("-", "_").replace(":", "_")
        subfolder = f"best_checkpoint_{ds_short}"

        readme_lines = [
            f"# LoRA Adapter - Best Checkpoint ({ds_short})\n\n",
            f"- **Base Model**: {args.base_model}\n",
            f"- **Checkpoint**: {os.path.basename(best_ckpt)}\n",
            f"- **Average Accuracy**: {checkpoint_scores[best_ckpt]:.4f}\n\n## Results\n",
        ]
        for ds_name, ds_results in results[best_ckpt].items():
            partial_str = f" (partial: {ds_results['partial_correctness_avg']:.4f})" if 'partial_correctness_avg' in ds_results else ""
            readme_lines.append(f"- {ds_name}: {ds_results['accuracy']:.4f}{partial_str}\n")

        readme_path = os.path.join(best_ckpt, "README.md")
        with open(readme_path, "w", encoding="utf-8") as f:
            f.write("".join(readme_lines))

        api.upload_folder(
            folder_path=best_ckpt, repo_id=args.upload_best_to_hf,
            repo_type="model", path_in_repo=subfolder,
            ignore_patterns=[".git*", "__pycache__", "*.pt"]
        )
        print(f"  ✓ Uploaded best checkpoint to {args.upload_best_to_hf}/{subfolder}")
    except Exception as e:
        print(f"Error uploading best checkpoint: {e}")


# Run the CLI only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()
