#!/usr/bin/env python3
"""
Convert Inspiration Retrieval data to MCQ (Multiple Choice Question) format.

MINIMAL CHANGES ONLY:
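- Input: Replace the number in each "### Candidate N" header with a letter label [A], [B], ..., [O]
- Output: Replace the last "**Selected Title starts:** XXX **Selected Title ends**" with "**Selected ID starts:** [X] **Selected ID ends**"
- Candidate number references in the response (e.g. "Candidate 7") are rewritten to letter labels ("Candidate [G]")
- Everything else (selection reason text, other sample fields) remains UNCHANGED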
"""

import json
import re
import argparse
from functools import partial
from multiprocessing import Pool
from tqdm import tqdm
from typing import List, Optional, Tuple
import multiprocessing

from common_utils import jaccard_similarity


# Letter labels available for candidates: the 26 uppercase ASCII letters, in order.
LABELS = [chr(code) for code in range(ord('A'), ord('Z') + 1)]  # ['A', 'B', ..., 'Z']


def num_to_letter(num: int) -> str:
    """
    Map a 1-based candidate number to a bracketed letter label.

    1 -> "[A]", 2 -> "[B]", ..., 26 -> "[Z]". Numbers outside 1..26 have no
    letter and are returned as their plain decimal string.
    """
    if num < 1 or num > 26:
        # No letter available: hand the number back unchanged (as text).
        return str(num)
    return '[' + chr(ord('A') + num - 1) + ']'


def replace_candidate_refs_in_reasoning(text: str, num_candidates: int = 15) -> str:
    """
    Replace candidate number references with letter labels in reasoning trace.

    Rewritten forms (N is a number in 1..num_candidates):
    - "Candidate 7", "Candidate #7", "candidate 7", "CANDIDATE 7"
      (the "#" is dropped: "Candidate #7" -> "Candidate [G]")
    - "Candidates 4, 7, 8, and 10" (plural with list)
    - "Candidates 3 to 5", "Candidates 1-3, 5, 7-9" (ranges, any of - – —)
    - "candidates are 3, 5, 10", "candidates like 3, 5"
    - "Candidate 5和11" (Chinese connector, with or without spaces)
    - "**Candidate 7**", "**Candidate 7:**", "Candidate 7's", "Candidate 7,"
    - "the 7th candidate", "7th candidate" -> "(the )Candidate [G]"
    - "paper 7" -> "Candidate [G]" (Pattern 2c treats it as a candidate reference)
    - a trailing number right after an already-converted label:
      "[C], [H], and maybe 10" -> "... and maybe [J]"

    Left untouched:
    - "a list of 15" (describing count, not candidate ID)
    - Isolated numbers without "candidate"/"paper" context
    - Numbers outside 1..num_candidates, and numbers with three or more
      digits ("Candidate 123" is never partially rewritten)

    Args:
        text: Text to process (reasoning trace, selection reason, etc.)
        num_candidates: Number of candidates (default 15). Only numbers in
            1..num_candidates are converted; letters stop at [Z] (26).

    Returns:
        Text with candidate numbers replaced by letter labels
        ([A]-[O] for the default of 15 candidates)
    """
    # Regex building blocks.
    # A candidate number: one or two digits NOT followed by another digit, so a
    # longer number such as "123" is never matched as "12" + leftover "3".
    num_pat = r'\d{1,2}(?![0-9])'
    # A list item: a single number or an "X-Y" range.
    item_pat = num_pat + r'(?:[-–—]' + num_pat + r')?'

    def number_list(element: str) -> str:
        # "E, E, ..." with an optional final "and/or/和 E". The second connector
        # alternative accepts "和" with no surrounding spaces ("5和11").
        return (
            element
            + r'(?:\s*,\s*' + element + r')*'
            + r'(?:(?:\s*,?\s*(?:and|or|和)\s+|\s*和\s*)' + element + r')?'
        )

    def to_label(num: int) -> str:
        # Same mapping as the module-level num_to_letter(); kept local so this
        # function depends on nothing but ``re``.
        if 1 <= num <= 26:
            return f'[{chr(ord("A") + num - 1)}]'
        return f'{num}'

    def in_range(num: int) -> bool:
        return 1 <= num <= num_candidates

    # Helper to replace a single number
    def replace_single_num(num: int) -> str:
        if in_range(num):
            return to_label(num)
        return str(num)

    # Helper to replace all numbers in a string (including X-Y ranges)
    def replace_nums_in_str(s: str) -> str:
        # First, handle X-Y range patterns (preserve the dash)
        def range_replacer(m):
            return (replace_single_num(int(m.group(1)))
                    + m.group(2)
                    + replace_single_num(int(m.group(3))))
        s = re.sub(r'(\d{1,2})([-–—])(\d{1,2})', range_replacer, s)

        # Then, handle standalone numbers. Digit look-arounds are used instead
        # of \b because \b does not fire between a digit and a CJK character
        # (both are "word" characters), which would leave "5和11" unconverted.
        def num_replacer(n_match):
            return replace_single_num(int(n_match.group(0)))
        return re.sub(r'(?<![0-9])\d{1,2}(?![0-9])', num_replacer, s)

    # Shared callback for the "prefix + number list" patterns (1e, 1a, 1c).
    def replace_prefixed_list(m):
        return m.group(1) + replace_nums_in_str(m.group(2))

    # Note: Use (?<![A-Za-z]) instead of \b at start to handle Chinese characters
    # Chinese chars don't trigger word boundary, so 但Candidate won't match with \b

    # Pattern 1d: "Candidates X to Y" (range format) - MUST run before the list
    # patterns to avoid partial matching of "Candidates X to Y" as "Candidates X"
    def replace_candidates_range(m):
        num1 = int(m.group(2))
        num2 = int(m.group(3))
        if in_range(num1) and in_range(num2):
            return m.group(1) + to_label(num1) + " to " + to_label(num2)
        return m.group(0)

    text = re.sub(
        r'((?<![A-Za-z])candidates?\s+)(\d{1,2})\s+to\s+(\d{1,2})(?![0-9])',
        replace_candidates_range,
        text,
        flags=re.IGNORECASE
    )

    # Pattern 1e: "Candidate(s) X-Y, X, X-Y" (list of ranges and/or single numbers)
    # e.g., "Candidates 1-3, 5, 7-9, 11-15", "Candidates 4, 7, and 10", "Candidate 7"
    text = re.sub(
        r'((?<![A-Za-z])candidates?\s+)(' + number_list(item_pat) + r')',
        replace_prefixed_list,
        text,
        flags=re.IGNORECASE
    )

    # Pattern 1f: Single "Candidates X-Y" (for cases not caught by 1e)
    def replace_candidates_hyphen_range(m):
        num1 = int(m.group(2))
        num2 = int(m.group(4))
        if in_range(num1) and in_range(num2):
            # m.group(3) preserves the original dash character
            return m.group(1) + to_label(num1) + m.group(3) + to_label(num2)
        return m.group(0)

    text = re.sub(
        r'((?<![A-Za-z])candidates?\s+)(\d{1,2})([-–—])(\d{1,2})(?![0-9])',
        replace_candidates_hyphen_range,
        text,
        flags=re.IGNORECASE
    )

    # Pattern 1a: "Candidates X, Y, Z, and/or W" (plural with number list)
    text = re.sub(
        r'((?<![A-Za-z])candidates\s+)(' + number_list(num_pat) + r')',
        replace_prefixed_list,
        text,
        flags=re.IGNORECASE
    )

    # Pattern 1b: "candidates are/is #?X, Y, Z" (are/is + optional # + number list)
    # Note: # is removed for cleaner output
    def replace_candidates_are(m):
        return m.group(1) + replace_nums_in_str(m.group(2).replace('#', ''))

    text = re.sub(
        r'((?<![A-Za-z])candidates?\s+(?:are|is)\s+)(' + number_list(r'#?' + num_pat) + r')',
        replace_candidates_are,
        text,
        flags=re.IGNORECASE
    )

    # Pattern 1c: "Candidates like X, Y, Z" (like + number list)
    text = re.sub(
        r'((?<![A-Za-z])candidates?\s+like\s+)(' + number_list(num_pat) + r')',
        replace_prefixed_list,
        text,
        flags=re.IGNORECASE
    )

    # Pattern 2a: "CandidateX和Y" / "Candidate X and Y" (singular with connectors)
    # Note: whitespace around the connector is captured to preserve spacing.
    # Mostly reached when there is no space after "Candidate" (1e needs one).
    def replace_candidate_connector(m):
        return (m.group(1)
                + replace_single_num(int(m.group(2)))
                + m.group(3)
                + replace_single_num(int(m.group(4))))

    text = re.sub(
        r'((?<![A-Za-z])candidate\s*)(\d{1,2})(\s*(?:和|and|or)\s*)(\d{1,2})(?![0-9])',
        replace_candidate_connector,
        text,
        flags=re.IGNORECASE
    )

    # Pattern 2b: "Candidate" (singular) + optional # + number
    # Handles normal cases and edge cases like "byCandidate 3" (missing space)
    def replace_candidate_ref(m):
        num = int(m.group(2))
        if in_range(num):
            return m.group(1) + to_label(num)
        return m.group(0)

    text = re.sub(
        r'(candidate\s*)#?\s*(\d{1,2})(?![0-9])',
        replace_candidate_ref,
        text,
        flags=re.IGNORECASE
    )

    # Pattern 2c: "paper X" referring to candidate (e.g., "Paper 9 is a randomized trial")
    # Replace "paper X" with "Candidate [X]" for consistency
    def replace_paper_ref(m):
        num = int(m.group(2))
        if in_range(num):
            return "Candidate " + to_label(num)
        return m.group(0)

    text = re.sub(
        r'((?<![A-Za-z])paper\s+)(\d{1,2})(?![0-9])',
        replace_paper_ref,
        text,
        flags=re.IGNORECASE
    )

    # Pattern 3: Ordinal + "candidate" (the 7th candidate, 7th candidate)
    def replace_ordinal(m):
        num = int(m.group(2))
        if in_range(num):
            prefix = m.group(1) if m.group(1) else ''
            return prefix + f'Candidate {to_label(num)}'
        return m.group(0)

    text = re.sub(
        r'((?:^|[^A-Za-z])the\s+)?(\d{1,2})(?:st|nd|rd|th)\s+candidate(?![A-Za-z])',
        replace_ordinal,
        text,
        flags=re.IGNORECASE
    )

    # Pattern 4: Handle trailing numbers after already-replaced candidates
    # e.g., "Candidates [C], [H], [M], and maybe 10" -> "... and maybe [J]"
    # This must run AFTER the main replacements
    def replace_trailing_num(m):
        num = int(m.group(2))
        if in_range(num):
            return m.group(1) + to_label(num)
        return m.group(0)

    if num_candidates >= 1:
        # The preceding label may be any letter this call can produce: [A]-[O]
        # for 15 candidates, up to [A]-[Z] for 26 or more.
        last_label = chr(ord('A') + min(num_candidates, 26) - 1)
        text = re.sub(
            r'(\[[A-' + last_label + r']\]\s*,?\s*(?:and|or)\s+(?:maybe\s+)?)(\d{1,2})(?![0-9])',
            replace_trailing_num,
            text
        )

    return text


def normalize_title(title: str) -> str:
    """
    Normalize title for matching: remove HTML tags, markdown, extra spaces.

    Steps: strip + lowercase, undo markdown escapes ("\\*" -> "*", "\\#" -> "#"),
    drop HTML tags, drop leading AND trailing runs of "*" (bold/italic markers
    wrapping the whole title), then collapse whitespace.

    Args:
        title: Raw title text.

    Returns:
        The normalized title (may be an empty string).
    """
    s = title.strip().lower()
    # Remove markdown escapes
    s = s.replace('\\*', '*').replace('\\#', '#')
    # Remove HTML tags like <i>, </i>, <sub>, etc.
    s = re.sub(r'<[^>]+>', '', s)
    # Remove leading * or ** (markdown bold)
    s = re.sub(r'^\s*\*+\s*', '', s)
    # Remove the matching trailing * or ** as well; otherwise "**Title**"
    # normalizes to "title**" and can never exact-match the plain title.
    s = re.sub(r'\s*\*+\s*$', '', s)
    # Normalize whitespace
    s = re.sub(r'\s+', ' ', s).strip()
    return s


def find_label_for_title(title: str, candidate_titles: List[str]) -> Optional[str]:
    """
    Return the letter label (A-Z) of the candidate that matches *title*.

    Three passes are tried in order, each over all candidates, and the first
    hit wins:
    1. Exact equality of the normalized titles
    2. Prefix containment: the first 50 normalized characters of one title
       occur in the other (only when both are at least 30 characters long)
    3. Best Jaccard similarity (from common_utils), accepted when >= 0.5

    Args:
        title: The title to find
        candidate_titles: List of candidate titles (must be <= 26)

    Returns:
        The label of the matching candidate, or None when no pass matches.

    Raises:
        ValueError: If candidate_titles has more than 26 items (exceeds alphabet)
    """
    if len(candidate_titles) > len(LABELS):
        raise ValueError(f"Too many candidates: {len(candidate_titles)} > {len(LABELS)} (alphabet limit)")

    wanted = normalize_title(title)

    # Pass 1: exact match
    for idx, candidate in enumerate(candidate_titles):
        if normalize_title(candidate) == wanted:
            return LABELS[idx]

    # Pass 2: prefix containment (both sides >= 30 chars to avoid false positives)
    wanted_is_long = len(wanted) >= 30
    wanted_head = wanted[:50]
    for idx, candidate in enumerate(candidate_titles):
        normalized = normalize_title(candidate)
        if not wanted_is_long or len(normalized) < 30:
            continue
        if wanted_head in normalized or normalized[:50] in wanted:
            return LABELS[idx]

    # Pass 3: fuzzy match; keep the first candidate with the strictly highest score
    top_idx, top_score = -1, 0.0
    for idx, candidate in enumerate(candidate_titles):
        score = jaccard_similarity(wanted, normalize_title(candidate))
        if score > top_score:
            top_idx, top_score = idx, score

    # Reject weak matches (similarity below 0.5) and the "nothing scored" case
    if top_idx < 0 or top_score < 0.5:
        return None
    return LABELS[top_idx]


def add_labels_to_prompt(prompt: str, num_candidates: int = 15) -> str:
    """
    Add letter labels to candidate headers in prompt.

    Changes: "### Candidate X" -> "### Candidate [Y]"

    This removes the number and uses only the letter label for cleaner output.
    Only headers of the exact form "### Candidate <N>" followed by a newline
    are rewritten.

    Args:
        prompt: The original prompt text
        num_candidates: Expected number of candidates (default 15). If the
            prompt contains a different number of "### Candidate " headers, a
            warning is printed and the count found in the prompt is used.

    Returns:
        Modified prompt with letter labels. At most len(LABELS) headers are
        relabelled; any beyond that keep their number.
    """
    assert num_candidates <= len(LABELS), "Number of candidates must be less than or equal to length of alphabet"

    # Validate we have the expected number of candidates
    actual_count = prompt.count("### Candidate ")
    if actual_count != num_candidates:
        # Still process but with actual count
        print(f"Warning: Expected {num_candidates} candidates, but found {actual_count} in prompt.")
        num_candidates = actual_count

    # The count found in the prompt can exceed the alphabet even though the
    # argument was validated above; slicing LABELS clamps it instead of
    # raising IndexError.
    new_prompt = prompt
    for number, label in enumerate(LABELS[:num_candidates], start=1):
        new_prompt = new_prompt.replace(
            f"### Candidate {number}\n",
            f"### Candidate [{label}]\n",
        )

    return new_prompt


def replace_title_with_id(response: str, label: str) -> Optional[str]:
    """
    Swap the final "Selected Title" section of *response* for a "Selected ID" one.

    Only the LAST occurrence is rewritten, so earlier selections left behind by
    model self-correction are kept as they are.

    Standard format expected (malformed formats should be filtered beforehand):
        **Selected Title starts:** TITLE **Selected Title ends**
        **Selection Reason starts:** REASON **Selection Reason ends**

    Result:
        **Selected ID starts:** [X] **Selected ID ends**
        **Selection Reason starts:** REASON **Selection Reason ends**

    Returns:
        - the rewritten response when a title section was found;
        - None when the last "**Selected Title starts:**" marker is followed by
          neither an end marker nor a Selection Reason (truncated/malformed);
        - the response unchanged when no known title pattern is present.
    """
    id_block = f'**Selected ID starts:** [{label}] **Selected ID ends**'

    # --- Primary format: "**Selected Title starts:**" marker present ---------
    start_markers = list(re.finditer(r'\*\*Selected Title starts:\*\*', response))
    if start_markers:
        marker = start_markers[-1]
        head = response[:marker.start()]
        tail_offset = marker.end()
        tail = response[tail_offset:]

        # Standard case: explicit end marker; it is consumed together with the title.
        end_marker = re.search(r'\*\*Selected Title ends\*\*', tail)
        if end_marker:
            return head + id_block + response[tail_offset + end_marker.end():]

        # End marker missing: the title runs up to the Selection Reason. A blank
        # line is inserted so the result matches the standard layout
        # (ID block, empty line, Selection Reason).
        reason_marker = re.search(r'\*\*Selection Reason starts:\*\*', tail)
        if reason_marker:
            return head + id_block + '\n\n' + response[tail_offset + reason_marker.start():]

        # Neither marker follows: incomplete sample, signal it to the caller.
        return None

    # --- Fallback 1: "Selected Title starts:" without the ** markers ---------
    # e.g., "Selected Title starts: **title**\nSelected Title ends."
    # The match extends over following lines until one begins with "Selection".
    bare_markers = list(re.finditer(
        r'Selected Title starts:[^\n]*(?:\n(?!Selection)[^\n]*)*', response
    ))
    if bare_markers:
        hit = bare_markers[-1]
        cut = hit.end()
        # Swallow a directly following "Selected Title ends." as well.
        closing = re.match(r'\s*Selected Title ends\.?\s*', response[cut:])
        if closing:
            cut += closing.end()
        return response[:hit.start()] + id_block + response[cut:]

    # --- Fallbacks 2-4: single-span patterns, tried in this order ------------
    single_span_patterns = (
        # "**Selected Title:** XXX" or "**Selected Title**: XXX"
        r'\*\*Selected Title\*?\*?:?\*?\s*[^\n]*',
        # "Selected Title:" without asterisks (case insensitive)
        r'(?i)Selected Title:?\s*[^\n]*',
        # 'The selected paper is **"Title"** by Candidate X'
        r'The selected paper is \*\*"[^"]+"\*\* by Candidate \d+\.?',
    )
    for pattern in single_span_patterns:
        hits = list(re.finditer(pattern, response))
        if hits:
            hit = hits[-1]
            return response[:hit.start()] + id_block + response[hit.end():]

    # Nothing recognisable: hand the response back untouched.
    return response


def extract_last_selected_title(response: str) -> Optional[str]:
    """
    Return the title of the LAST selection found in *response*, or None.

    Responses can contain several selections when the model corrected itself;
    the final one is the answer that counts. The marker-delimited format is
    tried first, then progressively looser single-line formats.
    """
    # Preferred format: text between the start marker and either the end
    # marker or the beginning of the Selection Reason.
    delimited = re.findall(
        r'\*\*Selected Title starts:\*\*(.*?)(?:\*\*Selected Title ends\*\*|\*\*Selection Reason)',
        response, re.DOTALL
    )
    if delimited:
        return delimited[-1].strip()

    # Looser formats; each captures the title in its single group.
    fallback_patterns = (
        r'\*\*Selected Title\*?\*?:?\*?\s*([^\n]+)',
        r'Selected Title:?\s*([^\n]+)',
        r'The selected paper is \*\*"([^"]+)"\*\*',
    )
    for pattern in fallback_patterns:
        found = re.findall(pattern, response)
        if found:
            return found[-1].strip()

    return None


def normalize_response_format(response: str) -> str:
    """
    Remove stray whitespace between "**" and the section marker names.

    Normalizations:
    - "** Selected Title" -> "**Selected Title"
    - "** Selection Reason" -> "**Selection Reason"
    """
    # Title markers first, then reason markers.
    for marker_name in ('Selected Title', 'Selection Reason'):
        response = re.sub(r'\*\*\s+' + marker_name, '**' + marker_name, response)
    return response


def count_selections(response: str) -> int:
    """
    Count the Selected Title blocks in a response.

    Formats are checked from strictest to loosest and the count of the first
    format that occurs at all is returned (counts are not summed across
    formats). Returns 0 when none occurs.
    Note: Response should be normalized first using normalize_response_format().
    """
    format_patterns = (
        # Primary format: **Selected Title starts:**
        r'\*\*Selected Title starts:\*\*',
        # Variant: **Selected Title:** or **Selected Title**:
        r'\*\*Selected Title\*?\*?:',
        # Variant: "Selected Title:" without asterisks
        r'(?<!\*)Selected Title:',
        # Variant: "The selected paper is **"
        r'The selected paper is \*\*"',
    )
    found = 0
    for pattern in format_patterns:
        found = len(re.findall(pattern, response))
        if found:
            break
    return found


def count_selection_reasons(response: str) -> int:
    """Return how many times "**Selection Reason starts:" occurs in *response*."""
    return sum(1 for _ in re.finditer(r'\*\*Selection Reason starts:', response))


def has_nested_selection(response: str) -> bool:
    """
    Report whether the selection markers in *response* are nested into each other.

    Two malformed layouts are detected:
    1. The first Selection Reason block contains a "**Selected Title starts:" marker
    2. The first Selection Reason start lies between the first Selected Title
       start and the first Selected Title end

    Either one is a data-quality problem that could lead to a wrong conversion.
    """
    # Layout 1: a title marker inside the (first) reason block.
    reason_block = re.search(
        r'\*\*Selection Reason starts:\*\*(.*?)\*\*Selection Reason ends\*\*',
        response, re.DOTALL
    )
    if reason_block is not None and '**Selected Title starts:' in reason_block.group(1):
        return True

    # Layout 2: reason start located inside the title block. Uses the first
    # occurrence of each marker; all three must be present.
    title_open, title_close, reason_open = (
        response.find(marker)
        for marker in (
            '**Selected Title starts:**',
            '**Selected Title ends**',
            '**Selection Reason starts:**',
        )
    )
    if title_open < 0 or title_close < 0 or reason_open < 0:
        return False
    return title_open < reason_open < title_close




def process_single_sample(result: dict, skip_multi_selection: bool = False,
                          skip_nested: bool = True) -> Optional[dict]:
    """
    Process a single sample, returning None if conversion fails.

    A sample is converted by relabelling the candidate headers in its prompt,
    replacing the last selected title in its response with the matching letter
    ID, and rewriting candidate number references in the response to letters.

    Args:
        result: The sample dictionary. Reads 'prompt', 'response',
                'all_candidate_titles' and, as a fallback for the title,
                'selected_title_recovered' / 'selected_title_raw'.
        skip_multi_selection: If True, skip samples with multiple Selected Title blocks
                              (model self-corrections). Default False.
        skip_nested: If True, skip samples with nested/malformed selection structure.
                     Default True (recommended as these cause incorrect conversions).

    Returns:
        A copy of *result* with 'prompt' and 'response' converted and the extra
        keys 'mcq_label' and 'num_selections', or None when the sample is
        skipped or cannot be converted (missing/non-string prompt or response,
        no candidates or more than 26, no title found, title not matching any
        candidate, malformed response).
    """
    # A sample without usable prompt/response text cannot be converted; treat
    # it as a failed conversion instead of raising KeyError/TypeError.
    raw_response = result.get('response')
    raw_prompt = result.get('prompt')
    if not isinstance(raw_response, str) or not isinstance(raw_prompt, str):
        return None

    # "or []" also covers an explicit None value for the key.
    candidate_titles = result.get('all_candidate_titles') or []
    num_candidates = len(candidate_titles)

    # Validate candidate count (must be between 1 and 26)
    if num_candidates == 0 or num_candidates > len(LABELS):
        return None

    # Normalize response format (handles space variants like "** Selected Title")
    response = normalize_response_format(raw_response)

    # Skip samples with nested or malformed selection structure (data quality issue)
    if skip_nested and (has_nested_selection(response) or count_selection_reasons(response) > 1):
        return None

    # Optionally skip samples with multiple selections (lower quality)
    num_selections = count_selections(response)
    if skip_multi_selection and num_selections > 1:
        return None

    # IMPORTANT: Extract the LAST selected title from response, not from metadata
    # This handles cases where model made multiple selections (self-correction)
    last_title = extract_last_selected_title(response)

    # Fallback to metadata if extraction fails
    if not last_title:
        last_title = result.get('selected_title_recovered') or result.get('selected_title_raw')

    if not last_title:
        return None

    # Find label for the LAST selected title
    label = find_label_for_title(last_title, candidate_titles)
    if label is None:
        return None

    # Convert prompt (add labels) - pass num_candidates for validation
    new_prompt = add_labels_to_prompt(raw_prompt, num_candidates)

    # Convert response (replace last title with ID)
    new_response = replace_title_with_id(response, label)

    # Verify conversion happened (replace_title_with_id returns None for malformed data)
    if new_response is None or '**Selected ID starts:**' not in new_response:
        return None

    # Replace candidate number references with letters in reasoning trace
    # This ensures consistency: prompt and reasoning both use letter labels
    new_response = replace_candidate_refs_in_reasoning(new_response, num_candidates)

    # Preserve all original fields, only update prompt and response
    # This ensures compatibility with downstream scripts like inspiration_retrieval_prepare_sft_data_to_go.py
    converted = dict(result)  # Copy all original fields
    converted.update({
        'prompt': new_prompt,
        'response': new_response,
        # Add MCQ-specific fields
        'mcq_label': label,
        'num_selections': num_selections,  # Track how many selections were in original
    })
    return converted


def _process_entry(entry: dict, skip_multi_selection: bool, skip_nested: bool) -> Tuple[Optional[dict], dict]:
    """
    Process a single entry (for parallel processing).
    
    Returns:
        Tuple of (converted_entry or None, stats_dict)
    """
    stats = {
        'total': 0,
        'converted': 0,
        'skipped_multi': 0,
        'skipped_nested': 0,
    }
    
    if 'results' not in entry:
        return None, stats
    
    new_results = []
    for result in entry['results']:
        stats['total'] += 1
        # Normalize response format before checking
        response = normalize_response_format(result.get('response', ''))
        
        # Track skipped reasons
        if skip_nested and (has_nested_selection(response) or count_selection_reasons(response) > 1):
            stats['skipped_nested'] += 1
            continue
        if skip_multi_selection and count_selections(response) > 1:
            stats['skipped_multi'] += 1
            continue
        
        converted = process_single_sample(result, skip_multi_selection=skip_multi_selection, 
                                          skip_nested=skip_nested)
        if converted:
            new_results.append(converted)
            stats['converted'] += 1
    
    if new_results:
        # Preserve all original entry-level fields for compatibility with downstream scripts
        converted_entry = dict(entry)  # Copy all original fields (accuracy, correct_count, year_pmid, etc.)
        converted_entry['results'] = new_results  # Update with converted results
        return converted_entry, stats
    
    return None, stats


def process_data(input_path: str, output_path: str, skip_multi_selection: bool = False,
                 skip_nested: bool = True, num_workers: int = 1, verbose: bool = True):
    """
    Process the entire dataset and convert to MCQ format.

    Args:
        input_path: Path to input JSON file (read as UTF-8)
        output_path: Path to output JSON file (written as UTF-8)
        skip_multi_selection: If True, skip samples with multiple selections (lower quality).
                              Recommended for training as these have ~11% accuracy vs ~25% for single.
        skip_nested: If True (default), skip samples with nested/malformed
                     selection structure.
        num_workers: Number of worker processes (default 1). A multiprocessing
                     Pool is always used, even with one worker.
                     Set to 0 or -1 to use all available CPUs.
        verbose: Print progress information

    Input Data Format (JSON):
    -------------------------
    {
        "all_results": [
            {
                "results": [
                    {
                        "prompt": str,                    # The input prompt with candidate papers
                        "response": str,                  # Model's reasoning + selected title + selection reason
                        "selected_title_raw": str,        # Raw extracted title from response
                        "selected_title_recovered": str,  # Cleaned up title (may match first selection, not last!)
                        "ground_truth_title": str,        # The correct answer title
                        "all_candidate_titles": List[str], # List of 15 candidate paper titles
                        "is_correct": bool,               # Whether selection matches ground truth
                        "background": str,                # Research background
                        "ground_truth_full": dict,        # Full ground truth paper info
                    },
                    ...  # Multiple rejection samples per entry
                ],
                "accuracy": float,
                "correct_count": int,
                "total_samples": int,
                "background": str,
                "ground_truth": dict,
            },
            ...
        ]
    }

    Note on "response" format:
    -------------------------
    The response typically follows this structure:

    [Reasoning text - may include <think>...</think> tags]

    **Selected Title starts:** TITLE **Selected Title ends**
    **Selection Reason starts:** REASON **Selection Reason ends**

    Some samples have MULTIPLE selections (model self-correction during generation).
    In such cases, we extract and convert only the LAST selection.
    Multi-selection samples have ~11% accuracy vs ~25% for single-selection samples.

    Output Data Format:
    ------------------
    Same structure but with:
    - prompt: Candidate headers changed from "### Candidate X" to "### Candidate [A-O]"
    - response: Last "**Selected Title starts:** XXX **Selected Title ends**"
                replaced with "**Selected ID starts:** [A-O] **Selected ID ends**"
    plus top-level 'mean_accuracy' and a 'config' dict with conversion counts.
    Entries are written in completion order (imap_unordered), which is not
    guaranteed to match the input order when several workers are used.
    """
    # Determine number of workers
    if num_workers <= 0:
        num_workers = multiprocessing.cpu_count()

    if verbose:
        print(f"Loading data from {input_path}...")
        if skip_multi_selection:
            print("  (Skipping multi-selection samples)")
        if skip_nested:
            print("  (Skipping nested/malformed samples)")
        if num_workers > 1:
            print(f"  (Using {num_workers} workers)")

    # Explicit UTF-8: the data contains non-ASCII text and must not depend on
    # the platform's default encoding.
    with open(input_path, 'r', encoding='utf-8') as f:
        data = json.load(f)

    entries = data['all_results']
    converted_entries = []
    total_samples = 0
    converted_samples = 0
    skipped_multi = 0
    skipped_nested = 0

    # Process entries through a worker pool (one worker when num_workers == 1)
    process_func = partial(_process_entry,
                           skip_multi_selection=skip_multi_selection,
                           skip_nested=skip_nested)
    # Roughly ten chunks per worker, never fewer than one entry per chunk
    chunksize = max(1, len(entries) // (num_workers * 10))

    with Pool(processes=num_workers) as pool:
        for converted_entry, stats in tqdm(
            pool.imap_unordered(process_func, entries, chunksize=chunksize),
            total=len(entries), desc="Processing", disable=not verbose
        ):
            total_samples += stats['total']
            converted_samples += stats['converted']
            skipped_multi += stats['skipped_multi']
            skipped_nested += stats['skipped_nested']
            if converted_entry:
                converted_entries.append(converted_entry)

    # Calculate mean_accuracy for compatibility with downstream scripts
    # (only entries that carry an 'accuracy' field contribute)
    accuracies = [e['accuracy'] for e in converted_entries if 'accuracy' in e]
    mean_accuracy = sum(accuracies) / len(accuracies) if accuracies else 0.0

    output_data = {
        'all_results': converted_entries,
        'mean_accuracy': mean_accuracy,  # Required by inspiration_retrieval_prepare_sft_data_to_go.py
        'config': {
            'source_file': input_path,
            'total_samples': converted_samples,  # number of samples in the OUTPUT
            'failed_samples': total_samples - converted_samples,
            'skip_multi_selection': skip_multi_selection,
            'skip_nested': skip_nested,
            'skipped_multi_selection': skipped_multi,
            'skipped_nested': skipped_nested,
        }
    }

    if verbose:
        # Guard the percentage against an input without any samples
        converted_pct = 100 * converted_samples / total_samples if total_samples else 0.0
        print(f"\nTotal samples: {total_samples}")
        print(f"Converted: {converted_samples} ({converted_pct:.2f}%)")
        if skip_multi_selection:
            print(f"Skipped (multi-selection): {skipped_multi}")
        if skip_nested:
            print(f"Skipped (nested): {skipped_nested}")
        print(f"Saving to {output_path}...")

    # ensure_ascii=False emits raw non-ASCII characters, so the file must be
    # opened as UTF-8 explicitly.
    with open(output_path, 'w', encoding='utf-8') as f:
        json.dump(output_data, f, ensure_ascii=False, indent=2)

    if verbose:
        print("Done!")


def main():
    """Parse the command-line options and run the conversion via process_data()."""
    cli = argparse.ArgumentParser(description='Convert IR data to MCQ format')

    # Required file locations
    for path_option in ('--input_path', '--output_path'):
        cli.add_argument(path_option, type=str, required=True)

    # Filtering switches
    cli.add_argument(
        '--skip_multi_selection',
        action='store_true',
        help='Skip samples with multiple selections (recommended, they have lower quality)',
    )
    cli.add_argument(
        '--no_skip_nested',
        action='store_true',
        help='Do NOT skip nested/malformed samples (default: skip them)',
    )

    # Execution options
    cli.add_argument(
        '--num_workers',
        type=int,
        default=1,
        help='Number of parallel workers (default 1). Use 0 or -1 for all CPUs.',
    )
    cli.add_argument('--quiet', action='store_true')

    options = cli.parse_args()

    # The two negative flags are inverted here: nested samples are skipped
    # unless --no_skip_nested is given, and output is verbose unless --quiet.
    keep_nested = options.no_skip_nested
    process_data(
        options.input_path,
        options.output_path,
        skip_multi_selection=options.skip_multi_selection,
        skip_nested=not keep_nested,
        num_workers=options.num_workers,
        verbose=not options.quiet,
    )


# Script entry point: run the CLI only when this file is executed directly,
# not when it is imported as a module.
if __name__ == '__main__':
    main()

