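"""
Multi-GPU runner for the direct online best-of-n implementation.

Generation is sharded across GPUs (one worker process per GPU); answer
correctness is checked afterwards in the main process.

Based on the architecture of bestofn/run_bestofn_uhead_multigpu.py.
"""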

import argparse
import os
import logging
import random
import numpy as np
import torch
import torch.multiprocessing as mp
from typing import List, Dict, Tuple
import time
import gc
from tqdm import tqdm

from datasets import load_dataset, load_from_disk
from lm_polygraph import WhiteboxModel

from online_bestofn.direct_online_bestofn import DirectOnlineBestOfN
from utils import parse_answer_by_dataset_type
from configs.load_qwen import load_model as load_qwen_model, load_tokenizer as load_qwen_tokenizer
from online_bestofn.deepseek_annotation import Annotator

# Module-wide logging: INFO level, "time - logger - level - message" lines.
logging.basicConfig(
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
# Logger used by the main (parent) process; workers install their own handler.
log = logging.getLogger("direct_online_bestofn_multigpu")


def get_parser():
    """Create the command-line parser for this runner.

    Options are registered in a fixed order, which is also the order they are
    listed in ``--help``. ``--memory-efficient``/``--no-memory-efficient`` and
    ``--resume``/``--no-resume`` are paired flags writing to one destination.

    Returns:
        argparse.ArgumentParser: parser with every CLI option registered.
    """
    cli = argparse.ArgumentParser()
    opt = cli.add_argument

    # Data input / result output
    opt("--dataset-path", type=str, required=True,
        help="Dataset to evaluate on (HuggingFace name or local path)")
    opt("--dataset-split", type=str, default="test", help="Dataset split")
    opt("--save-dir", type=str, required=True, help="Directory to save results")
    opt("--save-frequency", type=int, default=10, help="Save every n samples")

    # Models and hardware
    opt("--model-path", type=str, default="Qwen/Qwen3-1.7B", help="Model name or path")
    opt("--uhead-path", type=str, required=True, help="UHead model path")
    opt("--device", type=str, default="auto",
        help="Device to use. Options: 'auto' (all GPUs), '0,1,2' (specific GPUs), 'cuda:0' (single)")

    # Search / generation settings
    opt("--n", type=int, default=10, help="Number of candidates per step")
    opt("--temperature", type=float, default=0.7, help="Generation temperature")
    opt("--max-new-tokens", type=int, default=250, help="Max tokens per step")
    opt("--max-steps", type=int, default=20, help="Max reasoning steps")
    opt("--subset", type=int, default=None, help="Only process first N samples")
    opt("--seed", type=int, default=42, help="Random seed")
    opt("--generation-batch-size", type=int, default=None,
        help="Batch size for generation (default: same as n)")
    opt("--feature-batch-size", type=int, default=1,
        help="Batch size for feature extraction")
    opt("--memory-efficient", action="store_true", default=True,
        help="Enable memory efficient mode (default: True)")
    opt("--no-memory-efficient", dest="memory_efficient", action="store_false",
        help="Disable memory efficient mode")
    opt("--verbose", action="store_true", help="Enable verbose logging")

    # Correctness checking
    opt("--correctness-mode", type=str, default="exact_match",
        choices=["exact_match", "deepseek"],
        help="Method for checking answer correctness (default: exact_match)")
    opt("--n-threads", type=int, default=1,
        help="Number of threads for DeepSeek verification (default: 1)")
    opt("--annotation-prompt-type", type=str, default="non_unique",
        choices=["unique", "non_unique"],
        help="Type of annotation prompt for DeepSeek (default: non_unique)")
    opt("--prompt-file", type=str, default=None,
        help="Path to prompt template file (optional)")

    # Resuming and debugging
    opt("--resume", action="store_true", default=False,
        help="Resume from existing save file")
    opt("--no-resume", dest="resume", action="store_false",
        help="Do not resume from existing save file")
    opt("--debug", action="store_true", help="Run in debug mode (single process)")
    opt("--debug-sample", type=int, default=None, help="Debug specific sample index")
    return cli


def parse_device_arg(device_arg):
    """Parse the ``--device`` argument into a list of GPU indices.

    Accepted forms (surrounding whitespace is ignored):
      * ``"auto"``      - every CUDA device torch can see,
      * ``"N"``         - a single device,
      * ``"cuda:N"``    - a single device,
      * ``"N,M,..."``   - several devices; each entry may also be ``cuda:N``.

    Args:
        device_arg: the raw ``--device`` string.

    Returns:
        A non-empty list of non-negative ints.

    Raises:
        ValueError: ``"auto"`` with no CUDA devices, or a malformed / negative id
            (including empty entries such as a trailing comma).
    """
    spec = str(device_arg).strip()
    if spec == "auto":
        num_gpus = torch.cuda.device_count()
        if num_gpus == 0:
            raise ValueError("No CUDA devices available.")
        return list(range(num_gpus))

    gpu_ids = []
    for token in spec.split(","):
        token = token.strip()
        if token.startswith("cuda:"):
            token = token[len("cuda:"):]
        try:
            gpu_id = int(token)
        except ValueError as err:
            # Covers empty entries ("0,1,"), bare "cuda", and non-numeric ids.
            raise ValueError(f"Invalid device argument: {device_arg}") from err
        if gpu_id < 0:
            raise ValueError(
                f"Invalid device argument: {device_arg} (GPU ids must be non-negative)"
            )
        gpu_ids.append(gpu_id)
    return gpu_ids


def get_physical_gpu_id(logical_id: int) -> int:
    """Map a logical GPU index to the physical index listed in CUDA_VISIBLE_DEVICES.

    Note: This function is kept for reference but not used in the current
    implementation.

    Args:
        logical_id: index as seen by torch inside this process (0-based).

    Returns:
        The physical device id at position ``logical_id`` of CUDA_VISIBLE_DEVICES,
        or ``logical_id`` itself when the variable is unset.

    Raises:
        ValueError: ``logical_id`` is negative (it would otherwise wrap around the
            list), out of range, or CUDA_VISIBLE_DEVICES holds non-integer
            entries (e.g. GPU UUIDs) that cannot be returned as an int.
    """
    if logical_id < 0:
        raise ValueError(f"Logical GPU ID must be non-negative, got {logical_id}")

    visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES")
    if visible_devices is None:
        return logical_id

    # Drop empty entries so "" (no devices) or "0,1," don't trip int('').
    entries = [x.strip() for x in visible_devices.split(",") if x.strip()]
    try:
        physical_ids = [int(x) for x in entries]
    except ValueError as err:
        raise ValueError(
            f"CUDA_VISIBLE_DEVICES={visible_devices} contains non-integer entries; "
            f"cannot map logical GPU ID {logical_id} to a physical index"
        ) from err

    if logical_id >= len(physical_ids):
        raise ValueError(f"Logical GPU ID {logical_id} exceeds available devices in CUDA_VISIBLE_DEVICES={visible_devices}")

    return physical_ids[logical_id]


def load_model(model_path, device):
    """Build a WhiteboxModel for ``model_path`` placed on ``device``.

    The tokenizer is loaded first, then the base model (which is switched to
    eval mode) via the Qwen loaders from ``configs.load_qwen``.
    """
    tok = load_qwen_tokenizer(model_path)
    net = load_qwen_model(model_path, device)
    net.eval()
    wrapped = WhiteboxModel(net, tok)
    return wrapped


def split_dataset(dataset, num_splits):
    """Cut ``dataset`` into at most ``num_splits`` contiguous chunks.

    Every chunk holds ceil(len / num_splits) samples except possibly the last;
    no empty chunks are produced, so fewer than ``num_splits`` chunks come back
    when the dataset is small.

    Returns:
        List of ``(chunk, start_idx)`` pairs, where ``start_idx`` is the global
        index of the chunk's first sample (used later to restore ordering).
    """
    total = len(dataset)
    chunk_len = -(-total // num_splits)  # ceiling division
    if total == 0:
        return []
    return [
        (dataset.select(range(first, min(first + chunk_len, total))), first)
        for first in range(0, total, chunk_len)
    ]


def setup_logger_for_worker(gpu_id: int):
    """Route this process's logging through one "[GPU N]"-prefixed stream handler.

    All existing handlers on the root logger are removed and replaced by a single
    ``StreamHandler`` (stderr) whose lines look like
    ``[GPU X] - logger.name - message``. Exception tracebacks and stack info are
    appended on the following lines, so ``logger.exception(...)`` stays useful.

    Args:
        gpu_id: id shown in the prefix of every line.

    Returns:
        The logger named ``worker_gpu_<gpu_id>``.
    """
    class GPUFormatter(logging.Formatter):
        """Format records as ``[GPU X] - name - message`` plus any traceback."""

        def __init__(self):
            super().__init__()
            self.gpu_id = gpu_id

        def format(self, record):
            text = f"[GPU {self.gpu_id}] - {record.name} - {record.getMessage()}"
            # Without this, exc_info/stack_info attached by logger.exception()
            # (or stack_info=True) would be silently dropped.
            if record.exc_info and not record.exc_text:
                record.exc_text = self.formatException(record.exc_info)
            if record.exc_text:
                text = f"{text}\n{record.exc_text}"
            if record.stack_info:
                text = f"{text}\n{self.formatStack(record.stack_info)}"
            return text

    root_logger = logging.getLogger()
    # Replace any inherited handlers so each line is emitted exactly once.
    root_logger.handlers.clear()
    handler = logging.StreamHandler()
    handler.setFormatter(GPUFormatter())
    root_logger.addHandler(handler)

    return logging.getLogger(f"worker_gpu_{gpu_id}")


def worker_process(
    gpu_id: int,
    dataset_split,
    args,
    start_idx: int,
    result_queue: mp.Queue,
    save_path: str,
    processed_indices: set = None
):
    """
    Worker process for multi-GPU online best-of-n.
    Each worker independently processes its dataset split.

    Runs in a child process started by ``main``. Loads the model on ``cuda:<gpu_id>``,
    generates one trajectory per not-yet-processed sample of ``dataset_split``, and
    checkpoints its partial results to a temp file next to ``save_path``.

    Args:
        gpu_id: CUDA device index used for this worker; it also offsets the random
            seed, picks the tqdm row, and names the temp checkpoint file.
        dataset_split: this worker's shard; iterating it must yield mappings with
            "question" and "answer" keys.
        args: parsed CLI namespace from ``get_parser()``.
        start_idx: global dataset index of the shard's first sample.
        result_queue: queue on which exactly one message is put:
            ``(gpu_id, start_idx, results, temp_save_path)`` on success or
            ``(gpu_id, start_idx, None, error_message)`` on failure.
        save_path: final output path of the run; the temp checkpoint is written in
            the same directory, named after its file stem.
        processed_indices: global indices to skip (already present in resumed results).
    """
    # Fallback so the outer ``except`` can always log, even if logger setup fails.
    logger = logging.getLogger(f"worker_gpu_{gpu_id}")
    try:
        logger = setup_logger_for_worker(gpu_id)
        logger.info(f"Worker starting on GPU {gpu_id}")

        already_done = processed_indices or set()

        # Temp checkpoint lives next to the final output and carries its stem, so
        # runs on different datasets/uheads sharing --save-dir don't clobber each other.
        save_dir = os.path.dirname(save_path) or "."
        stem = os.path.splitext(os.path.basename(save_path))[0]
        temp_save_path = os.path.join(save_dir, f"{stem}.gpu{gpu_id}.tmp")

        samples_to_process = sum(
            1 for i in range(len(dataset_split)) if (start_idx + i) not in already_done
        )
        if already_done:
            logger.info(f"Processing {samples_to_process} samples (skipping {len(dataset_split) - samples_to_process} already processed)")
        else:
            logger.info(f"Processing {samples_to_process} samples")

        if samples_to_process == 0:
            # Nothing to generate: avoid loading the model onto the GPU at all.
            result_queue.put((gpu_id, start_idx, [], temp_save_path))
            logger.info("Worker finished successfully")
            return

        torch.cuda.set_device(gpu_id)
        device = f"cuda:{gpu_id}"
        logger.info(f"Using device: {device} (logical GPU {gpu_id})")

        # Offset the seed per GPU so workers don't sample identical candidates.
        seed = args.seed + gpu_id
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)

        logger.info(f"Loading model {args.model_path}")
        model = load_model(args.model_path, device)

        generator = DirectOnlineBestOfN(
            model=model,
            uhead_path=args.uhead_path,
            candidates_per_step=args.n,
            max_steps=args.max_steps,
            max_new_tokens=args.max_new_tokens,
            temperature=args.temperature,
            device=device,
            verbose=args.verbose,
            generation_batch_size=args.generation_batch_size,
            feature_batch_size=args.feature_batch_size,
            memory_efficient=args.memory_efficient
        )

        results = []
        pbar = tqdm(
            enumerate(dataset_split),
            total=len(dataset_split),
            desc=f"GPU {gpu_id}",
            position=gpu_id,
            leave=True
        )

        for i, sample in pbar:
            global_idx = start_idx + i

            if global_idx in already_done:
                if args.verbose:
                    logger.info(f"Skipping trajectory generation for sample {global_idx} (already processed)")
                pbar.set_postfix(status="skipped")
                continue

            # Only generation and result construction are guarded here, so a later
            # checkpoint/log failure can never append a second entry for this index.
            try:
                pbar.set_postfix(idx=global_idx, status="generating")
                if args.verbose:
                    logger.info(f"\nProcessing sample {i+1}/{len(dataset_split)} (global index: {global_idx})")
                    logger.info(f"Question: {sample['question']}")

                result = generator.generate_trajectory(sample["question"])

                # Remove any copy of the question text from the trajectory.
                # Correctness is not checked here - that happens in phase 2.
                generated_text = result["trajectory"].replace(sample["question"], "").strip()

                result_dict = {
                    "index": global_idx,
                    "question": sample["question"],
                    "gold_answer": sample["answer"],
                    "generated_trajectory": result["trajectory"],
                    "generated_answer": generated_text,
                    "steps": result["steps"],
                    "step_scores": result["step_scores"],
                    "completed": result["completed"]
                }
                postfix = {"status": "completed", "steps": len(result["steps"])}
            except Exception as e:
                logger.exception(f"Error processing sample {global_idx}: {e}")
                result_dict = {
                    "index": global_idx,
                    "question": sample["question"],
                    "gold_answer": sample["answer"],
                    "error": str(e),
                    "completed": False
                }
                postfix = {"status": "error"}

            results.append(result_dict)
            pbar.set_postfix(idx=global_idx, **postfix)

            if args.verbose and "error" not in result_dict:
                logger.info(f"Generated answer: {result_dict['generated_answer'][:100]}...")

            # Periodic checkpoint, counted in newly produced results so skipped
            # (resumed) samples don't shift the cadence; <= 0 disables it.
            if args.save_frequency > 0 and len(results) % args.save_frequency == 0:
                try:
                    torch.save(results, temp_save_path)
                    logger.info(f"Saved {len(results)} results to {temp_save_path}")
                except Exception as e:
                    logger.warning(f"Checkpoint save to {temp_save_path} failed: {e}")

            if args.memory_efficient:
                # Collect unreferenced tensors first so empty_cache can release them.
                gc.collect()
                torch.cuda.empty_cache()

        torch.save(results, temp_save_path)
        logger.info(f"Final save: {len(results)} results to {temp_save_path}")

        generator.cleanup()
        del generator, model
        gc.collect()
        torch.cuda.empty_cache()

        result_queue.put((gpu_id, start_idx, results, temp_save_path))
        logger.info("Worker finished successfully")

    except Exception as e:
        logger.exception(f"Worker failed with error: {e}")
        result_queue.put((gpu_id, start_idx, None, str(e)))


def merge_results(all_results: List[Tuple[int, int, List[Dict], str]], save_path: str):
    """Merge per-worker results, delete their temp checkpoints, and save the merge.

    Args:
        all_results: one tuple per worker, as put on the queue by ``worker_process``:
            ``(gpu_id, start_idx, results, temp_path)`` on success or
            ``(gpu_id, start_idx, None, error_message)`` on failure. Not mutated.
        save_path: file the merged list is written to with ``torch.save``.

    Returns:
        The result dicts of all successful workers, sorted by their "index" key.
        Samples of failed workers are missing; each failure is logged as an error.
    """
    merged_results = []
    for gpu_id, start_idx, results, temp_path in sorted(all_results, key=lambda item: item[1]):
        if results is None:
            # On failure the 4th field is the error message, not a file path.
            log.error(
                f"Worker on GPU {gpu_id} (split starting at index {start_idx}) failed; "
                f"its samples are missing from the results: {temp_path}"
            )
            continue

        merged_results.extend(results)
        if os.path.exists(temp_path):
            os.remove(temp_path)
            log.info(f"Removed temporary file: {temp_path}")

    merged_results.sort(key=lambda r: r["index"])

    torch.save(merged_results, save_path)
    log.info(f"Saved {len(merged_results)} total results to {save_path}")

    return merged_results


def _is_correct_answer(generated: str, gold: str, dataset_type: str) -> bool:
    """Return True when the parsed generated answer equals the parsed gold answer.

    Both strings are parsed with ``parse_answer_by_dataset_type`` (the gold one
    with ``is_gold=True``). An unparseable answer (``None``), any parsing or
    comparison error, or a mismatch all count as incorrect.
    """
    try:
        pred = parse_answer_by_dataset_type(generated, dataset_type)
        gold_parsed = parse_answer_by_dataset_type(gold, dataset_type, is_gold=True)
        if pred is None or gold_parsed is None:
            return False
        # Per-sample details at DEBUG so they don't flood the phase-2 progress bar.
        log.debug(f"Generated answer: {pred}, Gold answer: {gold_parsed}")
        log.debug(f"Generated answer type: {type(pred)}, Gold answer type: {type(gold_parsed)}")
        return bool(pred == gold_parsed)
    except Exception as e:
        # Catch Exception (not bare except) so Ctrl-C / SystemExit still propagate.
        log.debug(f"Answer comparison failed for dataset type {dataset_type}: {e}")
        return False


def load_prompt_template(prompt_file: str) -> str:
    """Load a prompt template from ``prompt_file``, or return the default ``"{q}"``.

    Args:
        prompt_file: path to a UTF-8 text file, or a falsy value for the default.

    Returns:
        The file's contents with surrounding whitespace stripped, or ``"{q}"``
        when no file is given or the given path does not exist (a warning is
        logged in the latter case so a mistyped path is not silently ignored).
    """
    default_template = "{q}"
    if not prompt_file:
        return default_template
    if not os.path.exists(prompt_file):
        log.warning(f"Prompt file {prompt_file} not found; using the default '{default_template}' template")
        return default_template
    with open(prompt_file, "r", encoding="utf-8") as f:
        return f.read().strip()


def verify_with_deepseek(results: List[Dict], args) -> List[Dict]:
    """Mark answers correct/incorrect using the DeepSeek ``Annotator``.

    Only results without an "error" key and with a "generated_answer" are sent.
    Each sent result gets ``result["is_correct_deepseek"]`` set to a bool: True
    when its annotation is 0, False for any other value, for ``None``/NaN
    ("unclear"), and for every sent result if the annotator call raises.

    Args:
        results: result dicts (keys "index", "question", "gold_answer", ...);
            mutated in place.
        args: CLI namespace; reads ``prompt_file``, ``n_threads`` and
            ``annotation_prompt_type``.

    Returns:
        The same ``results`` list.
    """
    log.info(f"Verifying {len(results)} results with DeepSeek ({args.annotation_prompt_type} prompt)")

    prompt_template = load_prompt_template(args.prompt_file)
    # The annotator prompt appears to expect a "{q}" placeholder (the default
    # template is "{q}"); accept templates written with "{question}" as well.
    prompt_template = prompt_template.replace("{question}", "{q}")

    annotator = Annotator(
        prompt=prompt_template,
        n_threads=args.n_threads,
        # Expand "~" here; a literal "~/.cache" would become a "~" directory in the
        # CWD unless Annotator expands it itself (not visible from this file).
        cache_path=os.path.expanduser("~/.cache"),
        annotation_prompt_type=args.annotation_prompt_type
    )

    pending = [
        (i, r) for i, r in enumerate(results)
        if "error" not in r and "generated_answer" in r
    ]
    if not pending:
        return results

    result_indices = [i for i, _ in pending]
    problems = [r["question"] for _, r in pending]
    solutions = [r["generated_answer"] for _, r in pending]
    gold_answers = [r["gold_answer"] for _, r in pending]

    try:
        annotations = list(annotator(problems, solutions, gold_answers))
        if len(annotations) != len(result_indices):
            log.warning(
                f"DeepSeek returned {len(annotations)} annotations for {len(result_indices)} "
                f"samples; unmatched samples get no is_correct_deepseek entry"
            )

        for idx, annotation in zip(result_indices, annotations):
            if annotation is None or np.isnan(annotation):
                log.warning(f"DeepSeek returned unclear result for sample {results[idx]['index']}, marking as incorrect")
                results[idx]["is_correct_deepseek"] = False
            else:
                results[idx]["is_correct_deepseek"] = bool(annotation == 0)  # 0 = correct, 1 = incorrect

    except Exception as e:
        log.exception(f"Error during DeepSeek verification: {e}")
        # Fall back to marking every submitted sample as incorrect.
        for idx in result_indices:
            results[idx]["is_correct_deepseek"] = False

    return results


def _path_leaf(path: str) -> str:
    """Return the last '/'-separated component of ``path``, ignoring trailing slashes."""
    return path.rstrip("/").split("/")[-1]


def _infer_dataset_type(dataset_path: str) -> str:
    """Map a dataset path to the answer-parser type by substring (first match wins; default 'maths')."""
    rules = (
        ("science", "scienceqa"),
        ("strategy", "strategyqa"),
        ("gsm8k", "maths"),
        ("proofnet", "maths"),
        ("maths", "maths"),
        ("meeting", "planning"),
        ("trip", "planning"),
        ("calendar", "planning"),
    )
    for keyword, dataset_type in rules:
        if keyword in dataset_path:
            return dataset_type
    return "maths"


def _load_existing_results(save_path: str, resume: bool) -> Tuple[List[Dict], set]:
    """Load saved results for --resume; returns ``(results, processed_indices)``, empty if none/unreadable."""
    if not (resume and os.path.exists(save_path)):
        return [], set()

    log.info(f"Resuming: Loading existing results from {save_path}")
    try:
        # NOTE(review): on torch>=2.6 torch.load defaults to weights_only=True, which
        # rejects non-primitive objects (e.g. numpy scalars) - confirm what step_scores holds.
        existing_results = torch.load(save_path)
        processed_indices = {r["index"] for r in existing_results}
    except Exception as e:
        log.warning(f"Failed to load existing results: {e}")
        return [], set()

    log.info(f"Loaded {len(existing_results)} existing results")
    log.info(f"Already processed indices: {sorted(processed_indices) if len(processed_indices) < 20 else f'{len(processed_indices)} indices'}")
    return existing_results, processed_indices


def _load_dataset(dataset_path: str, dataset_split: str, subset):
    """Load the evaluation dataset (local ``save_to_disk`` dir or Hub name), optionally truncated to ``subset``."""
    log.info(f"Loading dataset from {dataset_path}")
    if os.path.exists(dataset_path):
        dataset = load_from_disk(dataset_path)
        # A multi-split save loads as a DatasetDict (a dict subclass). Checking the
        # type first avoids ``split in Dataset``, which scans every row.
        if isinstance(dataset, dict):
            if dataset_split not in dataset:
                raise ValueError(
                    f"Split {dataset_split!r} not found in {dataset_path}; "
                    f"available splits: {sorted(dataset)}"
                )
            dataset = dataset[dataset_split]
    else:
        dataset = load_dataset(dataset_path, split=dataset_split)

    if subset:
        dataset = dataset.select(range(min(subset, len(dataset))))
    return dataset


def _validate_existing_results(existing_results: List[Dict], dataset) -> None:
    """Raise ValueError if any resumed result's question/answer differs from the current dataset."""
    log.info("Validating existing results against current dataset...")
    for result in existing_results:
        idx = result["index"]
        if idx < len(dataset):
            sample = dataset[idx]
            if (result["question"] != sample["question"] or
                    result["gold_answer"] != sample["answer"]):
                raise ValueError(
                    f"Sample mismatch at index {idx}!\n"
                    f"Existing question: {result['question'][:100]}...\n"
                    f"Current question: {sample['question'][:100]}...\n"
                    f"Existing answer: {result['gold_answer']}\n"
                    f"Current answer: {sample['answer']}\n"
                    f"The saved results appear to be from a different dataset!"
                )
    log.info("Validation passed - all existing results match current dataset")


def _validate_gpu_ids(gpu_ids: List[int]) -> None:
    """Reject duplicate or out-of-range CUDA indices before any model is loaded.

    Indices are torch device indices, i.e. relative to CUDA_VISIBLE_DEVICES when set.
    Duplicates are rejected because workers key temp files, seeds and results by id.
    """
    if len(set(gpu_ids)) != len(gpu_ids):
        raise ValueError(f"Duplicate GPU ids in --device: {gpu_ids}")
    visible = torch.cuda.device_count()
    out_of_range = [g for g in gpu_ids if not 0 <= g < visible]
    if out_of_range:
        raise ValueError(
            f"GPU id(s) {out_of_range} out of range: torch sees {visible} CUDA device(s) "
            f"(ids are relative to CUDA_VISIBLE_DEVICES when it is set)"
        )


def _run_debug(args, dataset, gpu_id: int, save_path: str, dataset_type: str):
    """Generate in-process on one GPU for debugging.

    Returns ``(results, debug_save_path)``. The debug path is derived from
    ``save_path`` so a debug run never overwrites the main results file.
    """
    log.info("=" * 60)
    log.info("Running in DEBUG MODE - single process")
    log.info("You can now use pdb.set_trace() for debugging")
    log.info("=" * 60)

    if args.debug_sample is not None:
        if not 0 <= args.debug_sample < len(dataset):
            raise ValueError(
                f"--debug-sample {args.debug_sample} is out of range for a dataset of {len(dataset)} samples"
            )
        debug_dataset = dataset.select([args.debug_sample])
        start_idx = args.debug_sample
    else:
        debug_dataset = dataset.select(range(min(5, len(dataset))))
        start_idx = 0

    device = f"cuda:{gpu_id}"
    log.info(f"Using device: {device}")

    model = load_model(args.model_path, device)
    generator = DirectOnlineBestOfN(
        model=model,
        uhead_path=args.uhead_path,
        candidates_per_step=args.n,
        max_steps=args.max_steps,
        max_new_tokens=args.max_new_tokens,
        temperature=args.temperature,
        device=device,
        verbose=args.verbose,
        generation_batch_size=args.generation_batch_size,
        feature_batch_size=args.feature_batch_size,
        memory_efficient=args.memory_efficient
    )

    results = []
    pbar = tqdm(enumerate(debug_dataset), total=len(debug_dataset), desc="Debug mode")
    for offset, sample in pbar:
        global_idx = start_idx + offset
        pbar.set_postfix(idx=global_idx)
        log.info(f"\nProcessing sample {offset+1}/{len(debug_dataset)} (global index: {global_idx})")

        # Drop `import pdb; pdb.set_trace()` here (or anywhere) while debugging.
        result = generator.generate_trajectory(sample["question"])

        generated_text = result["trajectory"]
        if "question" in sample and sample["question"] in generated_text:
            generated_text = generated_text.replace(sample["question"], "").strip()

        results.append({
            "index": global_idx,
            "question": sample["question"],
            "gold_answer": sample["answer"],
            "generated_trajectory": result["trajectory"],
            "generated_answer": generated_text,
            "steps": result["steps"],
            "step_scores": result["step_scores"],
            "completed": result["completed"]
        })

        log.info(f"Generated answer: {parse_answer_by_dataset_type(generated_text, dataset_type)}")
        log.info(f"Gold answer: {sample['answer']}")

    # splitext (not str.replace) so a ".pt" elsewhere in the path is untouched.
    root, ext = os.path.splitext(save_path)
    debug_save_path = f"{root}_debug{ext}"
    torch.save(results, debug_save_path)
    log.info(f"Debug results saved to: {debug_save_path}")

    generator.cleanup()
    return results, debug_save_path


def _collect_worker_results(workers, result_queue, poll_seconds: float = 30.0):
    """Receive one queue message per worker without hanging on a dead worker.

    Args:
        workers: list of ``(gpu_id, start_idx, process)``; gpu ids must be unique.
        result_queue: queue the workers report on (see ``worker_process``).
        poll_seconds: how long each ``get`` waits before checking worker liveness.

    A worker found dead without having reported, on two consecutive polls, is
    recorded as failed (``results=None``). The extra poll gives data the worker
    flushed just before exiting a chance to be read.
    """
    import queue

    received = {}
    dead_polls = {}
    while len(received) < len(workers):
        try:
            item = result_queue.get(timeout=poll_seconds)
        except queue.Empty:
            for gpu_id, start_idx, proc in workers:
                if gpu_id in received or proc.is_alive():
                    continue
                dead_polls[gpu_id] = dead_polls.get(gpu_id, 0) + 1
                if dead_polls[gpu_id] >= 2:
                    message = f"worker process exited with code {proc.exitcode} without reporting results"
                    log.error(f"GPU {gpu_id}: {message}")
                    received[gpu_id] = (gpu_id, start_idx, None, message)
            continue
        received[item[0]] = item
        log.info(f"Received results from GPU {item[0]}")
    return list(received.values())


def _run_multi_gpu(args, dataset, gpu_ids, save_path, existing_results, processed_indices):
    """Shard ``dataset`` over ``gpu_ids``, run one worker per shard, and return merged results.

    When resuming, new results are first merged into a scratch file and only the
    combined (existing + new) list is written to ``save_path``, so the existing
    results are never replaced by a partial list.
    """
    dataset_splits = split_dataset(dataset, len(gpu_ids))
    log.info(f"Split dataset into {len(dataset_splits)} parts")

    mp.set_start_method('spawn', force=True)
    result_queue = mp.Queue()
    workers = []
    # Each shard runs on the GPU the user asked for (same convention as debug mode).
    for (split, start_idx), gpu_id in zip(dataset_splits, gpu_ids):
        proc = mp.Process(
            target=worker_process,
            args=(gpu_id, split, args, start_idx, result_queue, save_path, processed_indices)
        )
        proc.start()
        workers.append((gpu_id, start_idx, proc))
        log.info(f"Started worker on GPU {gpu_id} with {len(split)} samples")

    # Drain the queue before joining: joining first can deadlock on large payloads.
    all_results = _collect_worker_results(workers, result_queue)
    for _, _, proc in workers:
        proc.join()

    if not existing_results:
        return merge_results(all_results, save_path)

    scratch_path = f"{save_path}.new"
    new_results = merge_results(all_results, scratch_path)
    log.info(f"Combining {len(new_results)} new results with {len(existing_results)} existing results")
    combined = {r["index"]: r for r in existing_results}
    combined.update((r["index"], r) for r in new_results)  # new results win on overlap
    merged = sorted(combined.values(), key=lambda r: r["index"])
    torch.save(merged, save_path)
    log.info(f"Saved combined {len(merged)} results to {save_path}")
    os.remove(scratch_path)
    return merged


def _check_correctness(merged_results: List[Dict], args, dataset_type: str) -> List[Dict]:
    """Phase 2: recompute correctness for every result with the selected mode; returns the results."""
    log.info(f"\n{'='*60}")
    log.info("Phase 2: Checking correctness for ALL samples (including resumed)")
    log.info(f"{'='*60}")

    if args.correctness_mode == "exact_match":
        log.info("Clearing existing correctness results to force recheck...")
        for result in merged_results:
            result.pop("is_correct_exact_match", None)

        for result in tqdm(merged_results, desc="Checking correctness (exact match)"):
            if "error" in result:
                continue
            try:
                result["is_correct_exact_match"] = _is_correct_answer(
                    result["generated_answer"], result["gold_answer"], dataset_type
                )
            except Exception as e:
                log.error(f"Error checking correctness for sample {result['index']}: {e}")
                result["is_correct_exact_match"] = False

    elif args.correctness_mode == "deepseek":
        log.info("Clearing existing DeepSeek correctness results to force recheck...")
        for result in merged_results:
            result.pop("is_correct_deepseek", None)
        merged_results = verify_with_deepseek(merged_results, args)

    return merged_results


def main():
    """Run generation (multi-GPU workers, or single-process --debug) then score the results.

    Phase 1 generates trajectories for samples not already in the resumed results.
    Phase 2 recomputes correctness for every result and saves the final file:
    ``<save-dir>/<dataset>/<uhead>.pt``, or its ``_debug`` sibling in debug mode.
    """
    args = get_parser().parse_args()

    gpu_ids = parse_device_arg(args.device)
    num_gpus = len(gpu_ids)
    log.info(f"Using {num_gpus} GPUs: {gpu_ids}")

    # Same directory layout as the single-GPU version.
    dataset_name = _path_leaf(args.dataset_path)
    uhead_name = _path_leaf(args.uhead_path)
    save_path = os.path.join(args.save_dir, dataset_name, f"{uhead_name}.pt")
    os.makedirs(os.path.dirname(save_path), exist_ok=True)
    log.info(f"Results will be saved to: {save_path}")

    existing_results, processed_indices = _load_existing_results(save_path, args.resume)

    dataset = _load_dataset(args.dataset_path, args.dataset_split, args.subset)
    log.info(f"Total samples: {len(dataset)}")

    if existing_results:
        _validate_existing_results(existing_results, dataset)

    dataset_type = _infer_dataset_type(args.dataset_path)
    final_path = save_path

    remaining_samples = sum(1 for i in range(len(dataset)) if i not in processed_indices)
    if processed_indices:
        log.info(f"Samples remaining to process: {remaining_samples}")

    if processed_indices and remaining_samples == 0:
        # Still run Phase 2 so correctness is rechecked for resumed results.
        log.info("All samples already processed - skipping to correctness checking")
        merged_results = existing_results
    elif args.debug:
        _validate_gpu_ids(gpu_ids)
        merged_results, final_path = _run_debug(args, dataset, gpu_ids[0], save_path, dataset_type)
    else:
        _validate_gpu_ids(gpu_ids)
        merged_results = _run_multi_gpu(
            args, dataset, gpu_ids, save_path, existing_results, processed_indices
        )

    merged_results = _check_correctness(merged_results, args, dataset_type)

    torch.save(merged_results, final_path)
    log.info(f"Final save with correctness: {len(merged_results)} results to {final_path}")

    correctness_key = f"is_correct_{args.correctness_mode}"
    correct = sum(r.get(correctness_key, False) for r in merged_results)
    total = len(merged_results)
    accuracy = correct / total if total > 0 else 0

    log.info(f"\n{'='*60}")
    log.info("Evaluation complete!")
    log.info(f"Correctness mode: {args.correctness_mode}")
    log.info(f"Total samples: {total}")
    log.info(f"Correct: {correct}")
    log.info(f"Accuracy: {accuracy:.2%}")
    log.info(f"Results saved to: {final_path}")


if __name__ == "__main__":
    # Script entry point. Spawned workers import this module under a different
    # name, so the guard keeps them from re-running main(). main() returns None,
    # so SystemExit(None) exits with status 0 exactly like falling off the end.
    raise SystemExit(main())