#!/usr/bin/env python3
"""
Run direct online best-of-n evaluation guided by a process reward model (PRM).
"""

import argparse
import os
import logging
import random
import gc
import numpy as np
import torch
from tqdm import tqdm
from datasets import load_dataset
from lm_polygraph import WhiteboxModel
import traceback
from online_bestofn.direct_online_bestofn_prm import DirectOnlineBestOfNPRM
from utils import parse_ans, parse_answer_by_dataset_type
from online_bestofn.deepseek_annotation import Annotator
from online_bestofn.run_direct_online_bestofn_multigpu import _is_correct_answer

# Configure the root logger at INFO level when this module is imported, and
# create the named logger used by every function in this script.
logging.basicConfig(level=logging.INFO)
log = logging.getLogger("direct_online_bon_prm")

from transformers import AutoTokenizer, AutoModelForCausalLM


def load_tokenizer(model_path: str):
    """Load the Hugging Face tokenizer for ``model_path`` configured for this script.

    The tokenizer's chat template is cleared so that no chat template is applied
    by it, and padding is put on the left, as needed for batched generation with
    decoder-only models.
    """
    tok = AutoTokenizer.from_pretrained(model_path)
    tok.chat_template = None
    # Decoder-only models generate from the right edge, so pad on the left.
    tok.padding_side = 'left'
    return tok


def load_model(model_path: str, device_map: str):
    """Load a causal language model from ``model_path``.

    Args:
        model_path: Hugging Face model id or local checkpoint directory.
        device_map: Passed straight to ``from_pretrained`` (e.g. ``"cuda:0"``).

    ``trust_remote_code=True`` permits checkpoints that ship custom modelling
    code. No ``torch_dtype`` is passed, so the library's default applies.
    """
    return AutoModelForCausalLM.from_pretrained(
        model_path,
        device_map=device_map,
        trust_remote_code=True,
    )


def get_parser():
    """Build the ``argparse`` parser for this script's command line.

    ``--dataset-path`` and ``--save-dir`` are required; all other options have
    defaults. ``--resume`` is on by default and can be turned off with
    ``--no-resume`` (both write to ``args.resume``).
    """
    parser = argparse.ArgumentParser(description="Direct online best-of-n with PRM")
    add = parser.add_argument

    # --- Dataset ---
    add("--dataset-path", type=str, required=True,
        help="Dataset to evaluate on (HuggingFace name or local path)")
    add("--dataset-split", type=str, default="test",
        help="Dataset split to use")
    add("--subset", type=int, default=None,
        help="Only process first N samples from dataset")

    # --- Models ---
    add("--model-path", type=str, default="Qwen/Qwen3-1.7B",
        help="Base model for generation")
    add("--prm-path", type=str, default="Qwen/Qwen2.5-Math-7B-PRM800K",
        help="Path to PRM model")

    # --- Generation ---
    add("--n", type=int, default=10,
        help="Number of candidates per step")
    add("--temperature", type=float, default=0.7,
        help="Generation temperature")
    add("--max-new-tokens", type=int, default=250,
        help="Max tokens per step")
    add("--max-steps", type=int, default=30,
        help="Maximum number of reasoning steps")

    # --- Memory ---
    add("--batch-size", type=int, default=None,
        help="Batch size for candidate generation (default: same as --n)")
    add("--sequential-generation", action="store_true",
        help="Generate candidates one by one to save memory")

    # --- Output ---
    add("--save-dir", type=str, required=True,
        help="Directory to save results (will create subdirectories)")
    add("--prompt-file", type=str, default=None,
        help="Path to prompt template file (optional)")

    # --- System / evaluation ---
    add("--device", type=str, default="cuda:0",
        help="Device to use for base model (default: cuda:0)")
    add("--prm-device", type=str, default="cuda:1",
        help="Device to use for PRM model (default: cuda:1)")
    add("--seed", type=int, default=42,
        help="Random seed")
    add("--verbose", action="store_true",
        help="Enable verbose logging")
    add("--hf-cache", type=str, default=None,
        help="HuggingFace cache directory")
    # --resume / --no-resume share the ``resume`` destination; resuming is the default.
    add("--resume", action="store_true", default=True,
        help="Resume from existing save file (default: True)")
    add("--no-resume", dest="resume", action="store_false",
        help="Do not resume from existing save file")
    add("--correctness-mode", type=str, default="exact_match",
        choices=["exact_match", "deepseek"],
        help="Method for checking answer correctness (default: exact_match)")
    add("--n-threads", type=int, default=1,
        help="Number of threads for DeepSeek verification (default: 1)")
    add("--annotation-prompt-type", type=str, default="non_unique",
        choices=["unique", "non_unique"],
        help="Type of annotation prompt for DeepSeek (default: non_unique)")

    return parser


def load_prompt_template(prompt_file: str) -> str:
    """Return the prompt template stored in ``prompt_file``, or the default template.

    Args:
        prompt_file: Path to a UTF-8 text file, or ``None``/empty.

    Returns:
        The file contents stripped of surrounding whitespace. If ``prompt_file``
        is falsy or does not exist, a built-in default template containing a
        ``{question}`` placeholder is returned instead.

    A non-existent path is logged as a warning, so a mistyped ``--prompt-file``
    does not silently change the prompt used for generation.
    """
    if prompt_file and os.path.exists(prompt_file):
        with open(prompt_file, 'r', encoding='utf-8') as f:
            return f.read().strip()
    if prompt_file:
        log.warning(
            f"Prompt file {prompt_file!r} not found; "
            f"falling back to the default prompt template"
        )
    # Default prompt template for PRM
    return "Question: {question}\n\nLet's solve this step by step.\n\n"


def prepare_dataset_with_prompts(dataset, prompt_template: str):
    """Return ``dataset`` with every example's ``"question"`` wrapped in ``prompt_template``.

    Args:
        dataset: Any object with a ``map(fn)`` method applying ``fn`` to each
            example dict (e.g. a ``datasets.Dataset``).
        prompt_template: Template text. If it contains ``{question}``, it is
            filled in with ``str.format``. Templates that ``str.format`` cannot
            handle, such as ones with literal LaTeX braces like ``\\boxed{}``,
            fall back to a plain textual substitution of ``{question}``.
            Templates without the placeholder are prepended to the question.

    Returns:
        Whatever ``dataset.map`` returns.
    """

    def add_prompt(example):
        question = example["question"]
        if "{question}" in prompt_template:
            try:
                example["question"] = prompt_template.format(question=question)
            except (KeyError, IndexError, ValueError):
                # Other braces in the template (e.g. "\boxed{}" or "{x}") make
                # str.format fail; substitute only the placeholder instead.
                example["question"] = prompt_template.replace("{question}", question)
        else:
            example["question"] = prompt_template + question
        return example

    return dataset.map(add_prompt)


# Substring of the dataset path -> answer-parsing family passed to
# `_is_correct_answer`. Checked in order; the first match wins.
_DATASET_TYPE_KEYWORDS = (
    ('science', 'scienceqa'),
    ('strategy', 'strategyqa'),
    ('gsm8k', 'maths'),
    ('proofnet', 'maths'),
    ('maths', 'maths'),
    ('meeting', 'planning'),
    ('trip', 'planning'),
    ('calendar', 'planning'),
)


def _infer_dataset_type(dataset_path: str) -> str:
    """Return the answer-parsing family for ``dataset_path`` ('maths' if unknown)."""
    for keyword, dataset_type in _DATASET_TYPE_KEYWORDS:
        if keyword in dataset_path:
            return dataset_type
    # Default to generic parsing
    return 'maths'


def _load_existing_results(save_path: str, dataset):
    """Load a previous results file for resuming and check it matches ``dataset``.

    Returns:
        ``(results, processed_indices)``. If the file cannot be loaded, the
        failure is logged and ``([], set())`` is returned. The run then starts
        from scratch, and the next save overwrites the file.

    Raises:
        ValueError: if a saved result whose index lies inside ``dataset`` has a
            different question or gold answer than the current dataset.
    """
    log.info(f"Loading existing results from {save_path}")
    try:
        # NOTE(review): torch.load unpickles the file; only resume from files
        # written by this script.
        results = torch.load(save_path)
        processed_indices = {r["index"] for r in results}
        log.info(f"Loaded {len(results)} existing results")
        log.info(f"Already processed indices: {sorted(processed_indices)}")
    except Exception as e:
        log.warning(f"Failed to load existing results: {e}")
        return [], set()

    # Validation is kept outside the try above so that a mismatch (or a malformed
    # record) is never mistaken for a load failure and silently discarded.
    log.info("Validating existing results against current dataset...")
    for result in results:
        idx = result["index"]
        if idx >= len(dataset):
            continue  # no current sample to compare against
        sample = dataset[idx]
        if (result["question"] != sample["question"] or
                result["gold_answer"] != sample["answer"]):
            raise ValueError(
                f"Sample mismatch at index {idx}!\n"
                f"Existing question: {result['question']}...\n"
                f"Current question: {sample['question']}...\n"
                f"Existing answer: {result['gold_answer']}\n"
                f"Current answer: {sample['answer']}\n"
                f"The saved results appear to be from a different dataset!"
            )
    log.info("Validation passed - all existing results match current dataset")
    return results, processed_indices


def _check_exact_match(results, dataset_path: str, verbose: bool) -> None:
    """Recompute ``result["is_correct_exact_match"]`` in place for every non-error result.

    Any exception raised while checking a sample is logged, and that sample is
    marked incorrect.
    """
    # Clear any existing correctness results to force recheck
    log.info("Clearing existing correctness results to force recheck...")
    for result in results:
        result.pop("is_correct_exact_match", None)

    dataset_type = _infer_dataset_type(dataset_path)  # same for every sample
    for i, result in enumerate(tqdm(results, desc="Checking correctness (exact match)")):
        if "error" in result:
            continue
        try:
            is_correct = _is_correct_answer(result["generated_answer"], result["gold_answer"], dataset_type)
            result["is_correct_exact_match"] = is_correct

            if verbose and i % 10 == 0:  # Log every 10th for less clutter
                log.info(f"\nSample {result['index']}:")
                log.info(f"Generated answer: {parse_answer_by_dataset_type(result['generated_answer'], dataset_type)}")
                log.info(f"Gold answer: {parse_answer_by_dataset_type(result['gold_answer'], dataset_type)}")
                log.info(f"Correct: {is_correct}")
        except Exception as e:
            log.error(f"Error checking correctness for sample {result['index']}: {e}")
            result["is_correct_exact_match"] = False


def _check_deepseek(results, args) -> None:
    """Recompute ``result["is_correct_deepseek"]`` in place using the DeepSeek ``Annotator``.

    An annotation of 0 counts as correct and any other value as incorrect. NaN
    (unclear) annotations, and every sample when the annotator call raises, are
    marked incorrect.
    """
    # Clear any existing correctness results to force recheck
    log.info("Clearing existing DeepSeek correctness results to force recheck...")
    for result in results:
        result.pop("is_correct_deepseek", None)

    log.info(f"Using DeepSeek verification with {args.n_threads} threads")

    # The annotator prompt is expected to use a {q} placeholder, so rename
    # {question} for compatibility.
    prompt_template = load_prompt_template(args.prompt_file) if args.prompt_file else "{q}"
    if "{question}" in prompt_template:
        prompt_template = prompt_template.replace("{question}", "{q}")

    annotator = Annotator(
        prompt=prompt_template,
        n_threads=args.n_threads,
        cache_path="~/.cache",
        annotation_prompt_type=args.annotation_prompt_type,
    )

    # Always send every non-error result; repeated queries hit the DeepSeek cache.
    result_indices = [i for i, r in enumerate(results) if "error" not in r]
    if not result_indices:
        return
    problems = [results[i]["question"] for i in result_indices]
    solutions = [results[i]["generated_answer"] for i in result_indices]
    gold_answers = [results[i]["gold_answer"] for i in result_indices]

    log.info(f"Verifying {len(problems)} solutions with DeepSeek ({args.annotation_prompt_type} prompt)...")
    try:
        annotations = annotator(problems, solutions, gold_answers)

        for idx, annotation in zip(result_indices, annotations):
            result = results[idx]
            if np.isnan(annotation):
                log.warning(f"DeepSeek returned unclear result for sample {result['index']}, marking as incorrect")
                result["is_correct_deepseek"] = False
            else:
                result["is_correct_deepseek"] = (annotation == 0)  # 0 = correct, 1 = incorrect

            if args.verbose and (idx - result_indices[0]) % 10 == 0:
                log.info(f"\nSample {result['index']}:")
                log.info(f"DeepSeek annotation: {annotation}")
                log.info(f"Correct: {result['is_correct_deepseek']}")
    except Exception as e:
        log.error(f"Error during DeepSeek verification: {e}")
        # Fall back to marking all as incorrect
        for idx in result_indices:
            results[idx]["is_correct_deepseek"] = False


def _log_summary(results, correctness_mode: str) -> None:
    """Log correctness/completion counts and PRM step-reward statistics for ``results``."""
    correctness_key = f"is_correct_{correctness_mode}"
    total = len(results)
    correct = sum(r.get(correctness_key, False) for r in results)
    completed = sum(r.get("completed", False) for r in results)
    errors = sum("error" in r for r in results)

    def share(count):
        # An empty run (e.g. --subset 0) previously raised ZeroDivisionError here.
        return f"{count / total:.1%}" if total else "n/a"

    log.info(f"\n{'='*60}")
    log.info("Evaluation Summary:")
    log.info(f"  - Correctness mode: {correctness_mode}")
    log.info(f"  - Total samples: {total}")
    log.info(f"  - Completed: {completed} ({share(completed)})")
    log.info(f"  - Correct ({correctness_mode}): {correct} ({share(correct)})")
    log.info(f"  - Errors: {errors}")

    if completed > 0:
        log.info(f"  - Accuracy (of completed): {correct/completed:.1%}")

    # Step statistics over trajectories that have at least one scored step;
    # step_scores are PRM rewards (higher = better).
    all_rewards = []
    all_steps = []
    for r in results:
        if r.get("step_scores"):
            all_rewards.extend(r["step_scores"])
            all_steps.append(len(r["steps"]))

    if all_rewards:
        log.info("\nStep Statistics:")
        log.info(f"  - Avg steps per trajectory: {np.mean(all_steps):.1f}")
        log.info(f"  - Avg step reward: {np.mean(all_rewards):.3f}")
        log.info(f"  - Min step reward: {np.min(all_rewards):.3f}")
        log.info(f"  - Max step reward: {np.max(all_rewards):.3f}")


def main(args):
    """Run PRM-guided direct online best-of-n over a dataset, then score the answers.

    Phase 1 generates one trajectory per sample, skipping indices already in a
    resumed save file. It checkpoints to
    ``<save_dir>/<dataset_name>/<prm_name>.pt`` whenever the result count is a
    multiple of 10, and once more at the end. Phase 2 recomputes correctness for
    *all* results with ``args.correctness_mode`` and saves again. A summary is
    logged at the end.

    Args:
        args: Namespace produced by ``get_parser().parse_args()``.

    Raises:
        ValueError: when resuming from a save file whose samples do not match
            the current dataset.
    """
    # Set random seeds
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(args.seed)
        torch.cuda.manual_seed_all(args.seed)
    log.info(f"Set random seed to {args.seed}")

    # Save path uses the last path component of the dataset and PRM names.
    dataset_name = args.dataset_path.split('/')[-1]
    prm_name = args.prm_path.split('/')[-1]
    save_path = os.path.join(args.save_dir, dataset_name, f"{prm_name}.pt")
    os.makedirs(os.path.dirname(save_path), exist_ok=True)
    log.info(f"Results will be saved to: {save_path}")

    log.info(f"Loading dataset: {args.dataset_path} ({args.dataset_split})")
    dataset = load_dataset(
        args.dataset_path,
        split=args.dataset_split,
        cache_dir=args.hf_cache
    )
    # Questions are kept raw; any prompt template is handed to the generator below.

    log.info(f"Loading model: {args.model_path}")
    tokenizer = load_tokenizer(args.model_path)
    base_model = load_model(args.model_path, args.device)
    base_model.eval()
    model = WhiteboxModel(base_model, tokenizer)

    log.info("Starting direct online best-of-n evaluation with PRM")
    log.info(f"  - PRM model: {args.prm_path}")
    log.info(f"  - Candidates per step: {args.n}")
    log.info(f"  - Temperature: {args.temperature}")
    log.info(f"  - Max tokens per step: {args.max_new_tokens}")
    log.info(f"  - Max steps: {args.max_steps}")

    # Load existing results if resuming
    results, processed_indices = [], set()
    if args.resume and os.path.exists(save_path):
        results, processed_indices = _load_existing_results(save_path, dataset)

    log.info(f"Using device {args.device} for base model, {args.prm_device} for PRM")

    # Determine batch size for generation
    batch_size = args.batch_size if args.batch_size else args.n
    if args.sequential_generation:
        batch_size = 1
        log.info("Using sequential generation (batch_size=1) to save memory")
    elif batch_size < args.n:
        log.info(f"Using batch generation with batch_size={batch_size}")

    prompt_template = load_prompt_template(args.prompt_file) if args.prompt_file else None

    generator = DirectOnlineBestOfNPRM(
        model=model,
        prm_model_path=args.prm_path,
        candidates_per_step=args.n,
        max_steps=args.max_steps,
        max_new_tokens=args.max_new_tokens,
        temperature=args.temperature,
        device=args.device,
        prm_device=args.prm_device,
        verbose=args.verbose,
        generation_batch_size=batch_size,
        prompt_template=prompt_template
    )

    # `is not None` so that an explicit --subset 0 means "no samples", not "all".
    subset_size = len(dataset) if args.subset is None else min(args.subset, len(dataset))

    # Phase 1: Generate trajectories (without checking correctness)
    log.info(f"\n{'='*60}")
    log.info("Phase 1: Generating trajectories")
    log.info(f"{'='*60}")

    for i in tqdm(range(subset_size), desc="Generating trajectories"):
        if i in processed_indices:
            if args.verbose:
                log.info(f"Skipping trajectory generation for sample {i} (already processed)")
            continue

        sample = dataset[i]

        if args.verbose:
            log.info(f"\n{'='*60}")
            log.info(f"Sample {i+1}/{subset_size}")
            log.info(f"Question: {sample['question'][:200]}...")

        try:
            result = generator.generate_trajectory(sample["question"])

            # Strip the echoed question from the trajectory to get the answer text.
            generated_text = result["trajectory"]
            if sample["question"] in generated_text:
                generated_text = generated_text.replace(sample["question"], "").strip()

            # Store result WITHOUT correctness check (done in phase 2)
            results.append({
                "index": i,
                "question": sample["question"],
                "gold_answer": sample["answer"],
                "generated_trajectory": result["trajectory"],
                "generated_answer": generated_text,
                "steps": result["steps"],
                "step_scores": result["step_scores"],  # These are rewards (higher=better)
                "completed": result["completed"]
            })

            if args.verbose:
                log.info(f"Generated: {generated_text}")
                log.info(f"Num steps: {len(result['steps'])}")
                if result['step_scores']:
                    log.info(f"Avg step reward: {np.mean(result['step_scores']):.3f}")

        except Exception as e:
            log.error(f"Error processing sample {i}: {e}")
            traceback.print_exc()

            results.append({
                "index": i,
                "question": sample["question"],
                "gold_answer": sample["answer"],
                "error": str(e),
                "completed": False
            })

        # Save periodically
        if len(results) % 10 == 0:
            torch.save(results, save_path)
            log.info(f"Saved {len(results)} results to {save_path}")

            # Clean up memory periodically
            torch.cuda.empty_cache()
            gc.collect()

    # Final save after generation
    torch.save(results, save_path)
    log.info(f"Final save after generation: {len(results)} results to {save_path}")

    generator.cleanup()

    # Phase 2: Check correctness for all results (including resumed ones)
    log.info(f"\n{'='*60}")
    log.info("Phase 2: Checking correctness for ALL samples (including resumed)")
    log.info(f"{'='*60}")

    if args.correctness_mode == "exact_match":
        _check_exact_match(results, args.dataset_path, args.verbose)
    elif args.correctness_mode == "deepseek":
        _check_deepseek(results, args)

    # Final save with correctness results for all samples
    torch.save(results, save_path)
    log.info(f"Final save with correctness: {len(results)} results to {save_path}")

    _log_summary(results, args.correctness_mode)


if __name__ == "__main__":
    # Script entry point: parse the command line and run the evaluation.
    main(get_parser().parse_args())