import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
import os
import re
import json
import yaml
from datasets import load_dataset, Dataset, concatenate_datasets
from tqdm import tqdm
import datetime
import argparse
import numpy as np
from collections import defaultdict
import sys
from functools import lru_cache
import torch.nn.functional as F

from global_file_manager import GlobalFileManager
from max_prob_based_branching import MaxProbBasedBatchProcessor

from math_dapo import (
    last_boxed_only_string,
    remove_boxed,
    normalize_final_answer,
    is_correct_minerva,
    is_correct_strict_box,
    verify,
    compute_score
)

# Command-line interface. Parsing happens at import time (module level), so every
# constant derived further down in the file reflects the actual invocation.
parser = argparse.ArgumentParser(description='Max-Prob-Based Batch Parallel Token-Level Branching Evaluation')

# basic args
parser.add_argument("--model_path", type=str, required=True, help="Path to model directory")
parser.add_argument("--log_dir", type=str, default="eval_logs", help="Log directory")
parser.add_argument("--result_dir", type=str, help="Result directory")
parser.add_argument("--dataset_name", type=str, default="dapo_converted", help="Dataset name")
parser.add_argument("--dataset_split", type=str, default="test", help="Dataset split")
parser.add_argument("--math_subset", type=str, default="algebra", help="Math subset")

# max-prob branching args
parser.add_argument("--max_prob_threshold", type=float, default=0.7, help="Max probability threshold for branching")
parser.add_argument("--branch_count", type=int, default=3, help="Number of branches at each branching point")
parser.add_argument("--timestamp", type=str, help="Custom timestamp for result folder")
parser.add_argument("--max_branch_depth", type=int, default=5, help="Maximum branch depth")
parser.add_argument("--rollout_times", type=int, default=1, help="Number of rollouts at max branch depth")
parser.add_argument("--batch_size", type=int, default=100, help="Batch size for parallel processing")
parser.add_argument("--logprobs_k", type=int, default=5, help="Number of top-k logprobs to retrieve")

# sampling args
# NOTE(review): within this file --temperature/--top_p are only echoed in the startup
# banner and --max_tokens is not referenced at all -- confirm whether they should be
# forwarded to the batch processor.
parser.add_argument("--temperature", type=float, default=0.7, help="Sampling temperature")
parser.add_argument("--top_p", type=float, default=1.0, help="Top-p sampling threshold")
parser.add_argument("--max_tokens", type=int, default=4096, help="Max generated tokens")
# The flag used to be `store_true` with default=True, which made it impossible to ever
# get False. BooleanOptionalAction keeps `--strict_box` (and the True default) working
# exactly as before and additionally accepts `--no-strict_box` to disable it.
parser.add_argument("--strict_box", action=argparse.BooleanOptionalAction, default=True, help="Use strict box verification")

# GPU args
parser.add_argument("--gpu_devices", type=str, default="0,1,2,3", help="GPU devices to use (comma-separated)")
parser.add_argument("--tensor_parallel_size", type=int, default=4, help="Tensor parallel size")

# sample range args
parser.add_argument("--start_index", type=int, default=None, help="Start index for sample range (0-based)")
parser.add_argument("--end_index", type=int, default=None, help="End index for sample range (exclusive, 0-based)")

args = parser.parse_args()

# Module-level configuration derived from the parsed CLI arguments (`args`).

# init params
MODEL_PATH = args.model_path
LOG_DIR = args.log_dir
DATASET_NAME = args.dataset_name
DATASET_SPLIT = args.dataset_split
MATH_SUBSET = args.math_subset
STRICT_BOX_VERIFY = args.strict_box

# max-prob params
MAX_PROB_THRESHOLD = args.max_prob_threshold
BRANCH_COUNT = args.branch_count
LOGPROBS_K = args.logprobs_k
BATCH_SIZE = args.batch_size

# GPU params
GPU_DEVICES = args.gpu_devices
TENSOR_PARALLEL_SIZE = args.tensor_parallel_size

# sample range vars
START_INDEX = args.start_index
END_INDEX = args.end_index

# misc
# Last path component of the model path, used to name output files. normpath drops a
# trailing separator first, so "/models/foo/" yields "foo" instead of an empty name.
MODEL_NAME = os.path.basename(os.path.normpath(MODEL_PATH))

# timestamp: caller-supplied value wins, otherwise the current local time
if args.timestamp:
    TIMESTAMP = args.timestamp
else:
    TIMESTAMP = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")

# result dir: explicit --result_dir wins, otherwise a timestamped folder under LOG_DIR
if args.result_dir:
    RESULT_DIR = args.result_dir
else:
    RESULT_DIR = os.path.join(LOG_DIR, f"max_prob_batch_vllm_run_{TIMESTAMP}")

# Created eagerly at import time so later writers can assume the directory exists.
os.makedirs(RESULT_DIR, exist_ok=True)

def get_simple_prompt(question):
    """Wrap *question* as a single-turn chat: a one-element list with a user message."""
    user_message = dict(role="user", content=question)
    return [user_message]

def extract_question_text(question_data):
    """Pull the question text out of a string or dict payload.

    A string is returned unchanged. For a dict, the value stored under the
    first present key among 'question', 'text', 'prompt' (checked in that
    order) is returned as-is. Anything else, including a dict with none of
    those keys, is returned as ``str(question_data)``.
    """
    if isinstance(question_data, str):
        return question_data

    if isinstance(question_data, dict):
        # Key priority mirrors the lookup order: question > text > prompt.
        for candidate in ('question', 'text', 'prompt'):
            if candidate in question_data:
                return question_data[candidate]

    return str(question_data)

def _dapo_record_to_fields(record):
    """Build the evaluation columns for one raw DAPO record.

    Reads the content of the first message in ``record['prompt']`` as the
    question and ``record['reward_model']['ground_truth']`` as the solution.
    """
    question = record['prompt'][0]['content']
    return {
        'prompt': get_simple_prompt(question),
        'solution': record['reward_model']['ground_truth'],
        'question': question,
        'clean_question': extract_question_text(question),
    }


def get_dapo_converted_questions(split="train", name='all') -> Dataset:
    """Load DAPO-Math-17k-converted dataset with the new format.

    The parquet file at a fixed local path is read first (via pandas). If that
    fails for any reason, the error is printed and the dataset directory is
    loaded with ``load_dataset(path, name)[split]`` instead.

    Args:
        split: Split to select; only used by the ``load_dataset`` fallback.
        name: Dataset config name; only used by the ``load_dataset`` fallback.

    Returns:
        The dataset with 'prompt', 'solution', 'question' and
        'clean_question' columns set by ``_dapo_record_to_fields``.
    """
    parquet_path = '/data_train/kaylhao/hzz/GRPO/datasets/DAPO-Math-17k-converted/all/train-00000-of-00001.parquet'

    try:
        import pandas as pd
        df = pd.read_parquet(parquet_path)
        data = Dataset.from_pandas(df).map(_dapo_record_to_fields)
    except Exception as e:
        # Deliberate best-effort: any failure of the fast parquet path (missing file,
        # missing pandas, schema mismatch) falls back to the generic loader.
        print(f"❌ Error loading DAPO-Math-17k-converted dataset: {e}")
        path = '/data_train/kaylhao/hzz/GRPO/datasets/DAPO-Math-17k-converted'
        return load_dataset(path, name)[split].map(_dapo_record_to_fields)

    # Reporting lives outside the try so a problem while printing is not mistaken
    # for a load failure and does not trigger the fallback loader.
    print(f"✅ Successfully loaded DAPO-Math-17k-converted dataset from {parquet_path}")
    print(f"   - Total samples: {len(data)}")
    if len(data) > 0:
        # Guarded: indexing row 0 of an empty dataset would raise.
        print(f"   - First question preview: {data[0]['question'][:100]}...")

    return data

def get_gsm8k_questions(split="train") -> Dataset:
    """Load a GSM8K split from the local dataset directory and add evaluation columns.

    Each example gains 'prompt', 'solution', 'question' and 'clean_question'.
    'solution' is the text after the "####" marker in the 'answer' field with
    commas and dollar signs removed; because it is formatted through an
    f-string, an answer without the marker yields the literal string "None".
    """
    file_path = '/data_train/kaylhao/hzz/GRPO/datasets/gsm8k'

    def extract_hash_answer(text: str) -> str | None:
        """Return the cleaned segment following the first "####", or None if absent."""
        if "####" in text:
            tail = text.split("####")[1]
            return tail.strip().replace(",", "").replace("$", "")
        return None

    def to_fields(example):
        """Map one raw GSM8K example to the evaluation columns."""
        return {
            'prompt': get_simple_prompt(example['question']),
            'solution': f"{extract_hash_answer(example['answer'])}",
            'question': example['question'],
            'clean_question': extract_question_text(example['question'])
        }

    selected = load_dataset(file_path, 'default')[split]
    return selected.map(to_fields)

def get_math_questions(split="train", name='algebra') -> Dataset:
    """Load one subset of the local MATH dataset and add evaluation columns.

    Args:
        split: Split to select from the loaded dataset.
        name: Config name passed to ``load_dataset`` (e.g. 'algebra').

    Returns:
        The split with 'prompt', 'solution', 'question' and 'clean_question'
        columns, all derived from each example's 'problem' and 'solution'.
    """
    file_path = '/data_train/kaylhao/hzz/GRPO/datasets/MATH'

    def to_fields(example):
        """Map one raw MATH example to the evaluation columns."""
        return {
            'prompt': get_simple_prompt(example['problem']),
            'solution': example['solution'],
            'question': example['problem'],
            'clean_question': extract_question_text(example['problem'])
        }

    selected = load_dataset(file_path, name)[split]
    return selected.map(to_fields)

def evaluate_model_with_max_prob_branching() -> float:
    """Main evaluation: max-prob based batch-parallel token-level branching.

    Steps, all driven by the module-level configuration constants:

    1. Load the dataset selected by DATASET_NAME ("gsm8k", "math" or
       "dapo_converted").
    2. Resolve the half-open sample range [start_idx, end_idx) from
       START_INDEX / END_INDEX, clamping it to the dataset bounds.
    3. Initialise a vLLM ``LLM`` and take its tokenizer.
    4. Create a GlobalFileManager and a MaxProbBasedBatchProcessor and feed
       the questions to the processor in chunks of BATCH_SIZE.
    5. Compute per-question accuracy (correct branches / total branches) and
       the overall accuracy (sum of per-question accuracies divided by the
       number of questions in the range).
    6. Write a JSON summary to
       ``RESULT_DIR/{MODEL_NAME}_max_prob_batch_summary.json``.

    Returns:
        The overall accuracy as a fraction.

    Raises:
        ValueError: If DATASET_NAME is unknown, the dataset is empty, or the
            resolved sample range is empty (start >= end).
        Exception: Any error raised while initialising vLLM is printed and
            re-raised unchanged.
    """

    # load dataset
    print("Loading dataset...")
    if DATASET_NAME == "gsm8k":
        dataset = get_gsm8k_questions(split=DATASET_SPLIT)
    elif DATASET_NAME == "math":
        dataset = get_math_questions(split=DATASET_SPLIT, name=MATH_SUBSET)
    elif DATASET_NAME == "dapo_converted":
        dataset = get_dapo_converted_questions(split=DATASET_SPLIT)
    else:
        raise ValueError(f"Unknown dataset: {DATASET_NAME}")

    print(f"Dataset loaded. Size: {len(dataset)}")

    # validate dataset not empty
    if len(dataset) == 0:
        raise ValueError("Dataset is empty! Please check dataset loading.")

    print(f"Max-Prob-Based Batch Parallel Token-Level Branching Parameters:")
    print(f"   - Max Prob Threshold: {MAX_PROB_THRESHOLD}")
    print(f"   - Branch Count (K): {BRANCH_COUNT}")
    print(f"   - Logprobs k: {LOGPROBS_K}")
    print(f"   - Batch Size: {BATCH_SIZE}")
    print(f"   - Max Branch Depth: {args.max_branch_depth}")
    print(f"   - Rollout Times: {args.rollout_times}")
    print(f"   - Using VLLM: True")

    # handle sample range
    if START_INDEX is not None or END_INDEX is not None:
        # defaults: a missing bound falls back to the corresponding dataset edge
        start_idx = START_INDEX if START_INDEX is not None else 0
        end_idx = END_INDEX if END_INDEX is not None else len(dataset)

        # validate range: out-of-bounds values are clamped with a warning
        if start_idx < 0:
            print(f"⚠️ Warning: start_index {start_idx} < 0, setting to 0")
            start_idx = 0

        if end_idx > len(dataset):
            print(f"⚠️ Warning: end_index {end_idx} > dataset_size {len(dataset)}, setting to {len(dataset)}")
            end_idx = len(dataset)

        # an empty (or inverted) range is an error rather than a silent no-op
        if start_idx >= end_idx:
            raise ValueError(f"Invalid sample range: start_index={start_idx}, end_index={end_idx}, dataset_size={len(dataset)}")

        actual_num_questions = end_idx - start_idx
        print(f"Processing sample range: {start_idx}-{end_idx-1} ({actual_num_questions} questions)")
        print(f"   - Sample range: [{start_idx}, {end_idx})")
    else:
        # process all if no range
        actual_num_questions = len(dataset)
        start_idx = 0
        end_idx = actual_num_questions
        print(f"\n🎯 Processing all {actual_num_questions} questions in batches of {BATCH_SIZE}:")

    # init VLLM
    print("Initializing VLLM...")
    try:
        # imported here rather than at module top
        # NOTE(review): SamplingParams is imported but not used in this function.
        from vllm import LLM, SamplingParams

        # set GPU devices
        os.environ["CUDA_VISIBLE_DEVICES"] = GPU_DEVICES

        # init VLLM
        llm = LLM(
            model=MODEL_PATH,
            tensor_parallel_size=TENSOR_PARALLEL_SIZE,
            trust_remote_code=True,
            gpu_memory_utilization=0.9
        )

        # get tokenizer
        tokenizer = llm.get_tokenizer()

        print(f"VLLM initialized successfully")
        print(f"   - Model: {MODEL_NAME}")
        print(f"   - Tensor parallel size: {TENSOR_PARALLEL_SIZE}")
        print(f"   - GPU devices: {GPU_DEVICES}")

    except Exception as e:
        # log for the console, then propagate: evaluation cannot continue without the engine
        print(f"❌ Error initializing VLLM: {e}")
        raise

    # init global file manager
    global_file_manager = GlobalFileManager(MODEL_NAME, RESULT_DIR)

    # init max-prob batch processor
    # NOTE(review): --temperature, --top_p and --max_tokens are not passed to the
    # processor here -- confirm whether it obtains sampling settings elsewhere.
    max_prob_processor = MaxProbBasedBatchProcessor(
        model=None,  # no transformers model is used
        tokenizer=tokenizer,
        llm=llm,
        branch_count=BRANCH_COUNT,
        max_branch_depth=args.max_branch_depth,
        rollout_times=args.rollout_times,
        max_prob_threshold=MAX_PROB_THRESHOLD,
        global_file_manager=global_file_manager,
        batch_size=BATCH_SIZE,
        strict_box_verify=STRICT_BOX_VERIFY,
        logprobs_k=LOGPROBS_K
    )

    all_results = []
    # running sum of per-question accuracies (each in [0, 1]), not a count of correct questions
    total_accuracy = 0.0

    # process in batches
    for batch_start in range(start_idx, end_idx, BATCH_SIZE):
        batch_end = min(batch_start + BATCH_SIZE, end_idx)
        current_batch_size = batch_end - batch_start

        print(f"\n" + "="*60)
        # NOTE: the batch number is batch_start // BATCH_SIZE + 1, i.e. based on the
        # absolute dataset index, so it does not start at 1 when start_idx >= BATCH_SIZE
        print(f"🔄 Processing batch {batch_start//BATCH_SIZE + 1}: questions {batch_start+1}-{batch_end}")
        print("="*60)

        # prepare batch data
        questions_batch = []
        for i in range(batch_start, batch_end):
            item = dataset[i]
            questions_batch.append({
                'question_id': i + 1,  # 1-based id derived from the 0-based dataset index
                'question': item['question'],
                'solution': item['solution'],
                'prompt': item['prompt']
            })

        # process batch with max-prob branching
        batch_results = max_prob_processor.process_batch_with_max_prob_branching(
            questions_batch, RESULT_DIR, MODEL_NAME
        )

        # handle batch results
        # presumably a mapping question_id -> state dict with 'branch_results',
        # 'question_text' and 'solution' (the keys read below) -- verify against the processor
        for question_id, state in batch_results.items():
            # compute question accuracy
            correct_branches = sum(1 for result in state['branch_results'].values() if result['is_correct'])
            total_branches = len(state['branch_results'])

            # a question with no branches counts as 0% rather than dividing by zero
            if total_branches > 0:
                question_accuracy = correct_branches / total_branches
            else:
                question_accuracy = 0.0

            total_accuracy += question_accuracy

            # store question result
            question_result = {
                "question_id": question_id,
                "question": state['question_text'],
                "ground_truth": state['solution'],
                "max_prob_branching_results": {
                    "max_prob_threshold": MAX_PROB_THRESHOLD,
                    "logprobs_k": LOGPROBS_K,
                    "total_branches": total_branches,
                    "correct_branches": correct_branches,
                    "question_accuracy": question_accuracy,
                    "branch_results": state['branch_results']
                }
            }

            all_results.append(question_result)

            print(f"   - Question {question_id}: {correct_branches}/{total_branches} branches correct ({question_accuracy*100:.1f}%)")

    # overall accuracy: mean of per-question accuracies over the requested range size
    # (the divisor is the range size, not the number of results actually returned)
    overall_accuracy = total_accuracy / actual_num_questions if actual_num_questions > 0 else 0.0

    print(f"\n" + "="*60)
    print(f"🎉 All {actual_num_questions} questions completed!")
    # NOTE(review): int(total_accuracy) truncates the summed fractional accuracies,
    # so the "x/y" shown here is not a count of fully-correct questions.
    print(f"📊 Overall Accuracy: {overall_accuracy*100:.2f}% ({int(total_accuracy)}/{actual_num_questions})")
    print("="*60)

    # finalize global file manager
    global_stats = global_file_manager.finalize_results()

    # save summary
    summary_results = {
        "timestamp": TIMESTAMP,
        "model": MODEL_NAME,
        "model_path": MODEL_PATH,
        "dataset": DATASET_NAME,
        "num_questions_processed": actual_num_questions,
        "sample_range": {
            "start_index": start_idx,
            "end_index": end_idx,
            "range_size": actual_num_questions
        },
        "max_prob_branching_parameters": {
            "max_prob_threshold": MAX_PROB_THRESHOLD,
            "logprobs_k": LOGPROBS_K,
            "batch_size": BATCH_SIZE,
            "branch_count": BRANCH_COUNT,
            "max_branch_depth": args.max_branch_depth,
            "rollout_times": args.rollout_times
        },
        "overall_results": {
            "total_accuracy": total_accuracy,
            "overall_accuracy": overall_accuracy,
            "questions_processed": actual_num_questions,
            "global_stats": global_stats
        },
        "individual_results": all_results
    }

    # save to json (ensure_ascii=False keeps non-ASCII text readable in the file)
    summary_path = os.path.join(RESULT_DIR, f"{MODEL_NAME}_max_prob_batch_summary.json")
    with open(summary_path, 'w', encoding='utf-8') as f:
        json.dump(summary_results, f, indent=2, ensure_ascii=False)

    print(f"Summary results saved to: {summary_path}")

    return overall_accuracy

# Script entry point: print a banner describing the run configuration, run the
# evaluation, and report where results were written.
if __name__ == "__main__":
    print("=" * 60)
    print("Max-Prob-Based Batch Parallel Token-Level Branching Evaluation")
    print("=" * 60)
    print(f"  Model: {MODEL_NAME}")
    print(f"     Path: {MODEL_PATH}")
    print(f"  Dataset: {DATASET_NAME}/{DATASET_SPLIT}")
    print(f"  Max-Prob Branching Parameters:")
    print(f"     - Max Prob Threshold: {MAX_PROB_THRESHOLD}")
    print(f"     - Logprobs k: {LOGPROBS_K}")
    print(f"     - Batch Size: {BATCH_SIZE}")
    print(f"     - Branch Count (K): {BRANCH_COUNT}")
    print(f"     - Max Branch Depth: {args.max_branch_depth}")
    print(f"     - Rollout Times: {args.rollout_times}")
    if START_INDEX is not None or END_INDEX is not None:
        start_display = START_INDEX if START_INDEX is not None else 0
        end_display = END_INDEX if END_INDEX is not None else "end"
        print(f"     - Sample Range: [{start_display}, {end_display})")
        # The count is only computable when an explicit end index was given; the
        # dataset size is not known yet at this point. Subtracting from the "end"
        # placeholder string used to raise TypeError when only --start_index was set.
        if END_INDEX is not None:
            print(f"     - Questions to Process: {END_INDEX - start_display}")
        else:
            print(f"     - Questions to Process: from index {start_display} to end of dataset")
    else:
        print(f"     - Sample Range: All questions")
    print(f"  Model Config:")
    print(f"     - Model Type: VLLM")
    print(f"     - Temperature: {args.temperature}")
    print(f"     - Top-p: {args.top_p}")
    print(f"  📂 Output Directory: {RESULT_DIR}")
    print("=" * 60)

    try:
        final_accuracy = evaluate_model_with_max_prob_branching()

        print("\n" + "=" * 60)
        print("Max-Prob-Based Batch Parallel Token-Level Branching Evaluation Complete!")
        print("=" * 60)
        print(f"  Overall Accuracy: {final_accuracy*100:.2f}%")
        print(f"  Results saved to: {RESULT_DIR}")
        print(f"  Generated Files:")
        print(f"     - {MODEL_NAME}_all_results.txt (Main log)")
        print(f"     - {MODEL_NAME}_all_overview.txt (Main overview)")
        print(f"     - {MODEL_NAME}_max_prob_batch_summary.json (Summary)")
        print(f"     - Individual question files")
        print("=" * 60)

    except Exception as e:
        # Top-level boundary: report the failure with a full traceback and point at
        # the output directory, since earlier batches may already have written files.
        print(f"\nError during evaluation: {e}")
        import traceback
        traceback.print_exc()
        print(f"Partial results may be saved in: {RESULT_DIR}")