import torch
import os
import re
import json
import datetime
import numpy as np
from collections import defaultdict
import sys
from functools import lru_cache
import torch.nn.functional as F
import math

from math_dapo import (
    last_boxed_only_string,
    remove_boxed,
    normalize_final_answer,
    is_correct_minerva,
    is_correct_strict_box,
    verify,
    compute_score
)

class MaxProbBasedBatchProcessor:
    """Max-prob-based batch-parallel processor.

    Runs a token-level tree search over a batch of questions. Each layer advances
    every active node by exactly one token (``max_tokens=1``). Before generating,
    the top-1 next-token probability of every node is measured:

    * nodes whose probability is <= ``max_prob_threshold`` are duplicated into
      ``branch_count`` children, which then diverge through sampling;
    * all other nodes continue as a single node.

    Branches that emit an end-of-sequence token are scored immediately with
    ``compute_score``. Branches still open after ``max_branch_depth`` layers are
    finished with ``rollout_times`` full rollouts.
    """

    def __init__(self, model, tokenizer, llm, branch_count, max_branch_depth, rollout_times,
                 max_prob_threshold, global_file_manager, batch_size=100, strict_box_verify=True, logprobs_k=5):
        """Store the search configuration.

        Args:
            model: Stored as-is; not used by the methods of this class.
            tokenizer: Must provide ``apply_chat_template`` and ``eos_token``.
            llm: vLLM ``LLM``-like engine exposing ``generate(prompts, sampling_params)``.
            branch_count: Number of children created when a node branches.
            max_branch_depth: Number of one-token layers to expand (also the branch-depth cap).
            rollout_times: Full rollouts per branch that is still open after the last layer.
            max_prob_threshold: Branch when the top-1 next-token probability is <= this value.
            global_file_manager: Optional sink with ``add_question_result(...)``; skipped when falsy.
            batch_size: Stored and printed; not otherwise used by this class.
            strict_box_verify: Forwarded to ``compute_score``.
            logprobs_k: Number of top logprobs requested from vLLM.
        """
        self.model = model
        self.tokenizer = tokenizer
        self.llm = llm
        self.branch_count = branch_count
        self.max_branch_depth = max_branch_depth
        self.rollout_times = rollout_times
        self.max_prob_threshold = max_prob_threshold
        self.global_file_manager = global_file_manager
        self.batch_size = batch_size
        self.strict_box_verify = strict_box_verify
        self.logprobs_k = logprobs_k
        # (text, branch_path, question_id, solution) tuples awaiting full rollouts.
        self.global_rollout_queue = []

        print(f"Max-Prob-Based Batch Processor initialized")
        print(f"   - Batch size: {batch_size}")
        print(f"   - Branch count: {branch_count}")
        print(f"   - Max depth: {max_branch_depth}")
        print(f"   - Max prob threshold: {max_prob_threshold}")
        print(f"   - Logprobs k: {logprobs_k}")
        print(f"   - Rollout times: {rollout_times}")

    # ------------------------------------------------------------------ helpers

    @staticmethod
    def _sample_logprobs(output):
        """Return the sample logprobs of the first completion of a vLLM request output, or None.

        vLLM keeps per-token sample logprobs on ``RequestOutput.outputs[i].logprobs``,
        not on the ``RequestOutput`` itself. An object exposing ``.logprobs`` directly
        is accepted as a fallback.
        """
        completions = getattr(output, "outputs", None)
        if completions:
            return getattr(completions[0], "logprobs", None)
        return getattr(output, "logprobs", None)

    @staticmethod
    def _path_depth(path):
        """Number of branch points on a dotted branch path ("" -> 0, "1.2" -> 2)."""
        return len(path.split('.')) if path else 0

    @staticmethod
    def _path_sort_key(path):
        """Sort key that orders dotted paths numerically ("2" before "10")."""
        return tuple((0, int(part)) if part.isdigit() else (1, part)
                     for part in path.split('.') if part)

    @staticmethod
    def _last_boxed_content(text):
        """Return the brace-balanced body of the last closed ``\\boxed{...}`` in ``text``, or None."""
        marker = "\\boxed{"
        start = text.rfind(marker)
        while start >= 0:
            body_start = start + len(marker)
            depth = 1
            for idx in range(body_start, len(text)):
                char = text[idx]
                if char == "{":
                    depth += 1
                elif char == "}":
                    depth -= 1
                    if depth == 0:
                        return text[body_start:idx]
            # This occurrence is never closed; try the previous one.
            start = text.rfind(marker, 0, start)
        return None

    # ------------------------------------------------------- probability probing

    def get_max_prob_from_logprobs(self, logprobs_data):
        """Get the top-1 probability at the last generated position from vLLM logprobs.

        Args:
            logprobs_data: vLLM sample logprobs. This is a list with one entry per
                generated token. Each entry is a dict mapping token id to either a
                float log-probability or a ``Logprob``-like object with ``.logprob``.

        Returns:
            ``exp(max logprob)`` clamped to [0, 1]. The returned log-probs are already
            normalised, so they are NOT renormalised over the truncated top-k set.
            This keeps the value independent of ``logprobs_k``. Returns 1.0
            ("certain, do not branch") when no usable data is available.
        """
        try:
            if not logprobs_data:
                return 1.0

            last_token_logprobs = logprobs_data[-1]
            if not last_token_logprobs:
                return 1.0

            # Newer vLLM wraps values in Logprob objects; older versions use plain floats.
            values = [float(getattr(value, "logprob", value)) for value in last_token_logprobs.values()]
            values = [value for value in values if not math.isnan(value)]
            if not values:
                return 1.0

            max_prob = math.exp(max(values))
            return max(0.0, min(1.0, max_prob))

        except Exception as e:
            print(f"Error getting max prob from logprobs: {e}")
            return 1.0

    def should_branch_at_token(self, max_prob, current_depth, max_depth):
        """Decide whether to branch at the current token.

        Low top-1 probability means the model is uncertain, which is where branching
        is useful. Returns False when ``max_prob`` exceeds the threshold or
        ``current_depth`` has reached ``max_depth``; otherwise True.
        """
        if max_prob > self.max_prob_threshold:
            return False  # confident prediction: no need to branch

        if current_depth >= max_depth:
            return False

        return True

    # ---------------------------------------------------------------- main loop

    def process_batch_with_max_prob_branching(self, questions_batch, result_dir, model_name):
        """Run the max-prob branching search for a batch of questions and save the results.

        Args:
            questions_batch: Sized iterable of dicts with keys ``question_id``,
                ``question``, ``solution`` and ``prompt``. ``prompt`` is passed to
                ``tokenizer.apply_chat_template``.
            result_dir: Directory for the per-question log files.
            model_name: Prefix used in the per-question file names.

        Returns:
            Dict mapping question_id to its final search state.
        """
        print(f"\n🔄 Processing batch of {len(questions_batch)} questions with max-prob-based branching...")

        question_states = {}
        for question_data in questions_batch:
            question_states[question_data['question_id']] = self._init_question_state(
                question_data, result_dir, model_name
            )

        for layer in range(self.max_branch_depth):
            print(f"\n🚀 Layer {layer}: Processing with max-prob-based branching")

            all_current_responses, all_question_ids, all_branch_paths = self._collect_active_nodes(question_states)

            if not all_current_responses:
                print(f"📝 No active nodes in layer {layer}")
                break

            print(f"📊 Processing {len(all_current_responses)} nodes in parallel...")
            print(f"   - Batch size for VLLM: {len(all_current_responses)}")
            print(f"   - Questions in batch: {set(all_question_ids)}")

            # Step 1: one batched engine call measuring every node's top-1 next-token probability.
            all_max_probabilities = self.batch_get_max_probabilities(all_current_responses)

            # Step 2: decide per node whether to branch (uncertain nodes get branch_count copies).
            expanded_responses, expanded_question_ids, expanded_branch_paths = self._expand_nodes(
                question_states, all_current_responses, all_question_ids,
                all_branch_paths, all_max_probabilities, layer
            )

            if not expanded_responses:
                print(f"📝 No nodes to generate in layer {layer}")
                break

            print(f"📊 Batch generating {len(expanded_responses)} responses in parallel...")

            # Step 3: sample one token for every expanded node in a single engine call.
            next_layer_responses = self.batch_generate_next_tokens(expanded_responses)

            # Step 4: score finished branches and carry the rest into the next layer.
            self.process_batch_results(question_states, expanded_question_ids, expanded_branch_paths,
                                       expanded_responses, next_layer_responses, layer)

        # Finish still-open branches with full rollouts, then persist everything.
        self.process_global_rollout_queue(question_states)

        for question_id, state in question_states.items():
            self.save_question_results(question_id, state)

        return question_states

    def _init_question_state(self, question_data, result_dir, model_name):
        """Build the initial search state (root node = chat-templated prompt) for one question."""
        question_id = question_data['question_id']
        question_text = question_data['question']
        solution = question_data['solution']
        prompt = question_data['prompt']

        try:
            current_text = self.tokenizer.apply_chat_template(
                prompt, tokenize=False, add_generation_prompt=True
            )
        except Exception as e:
            print(f"❌ Error creating initial text for question {question_id}: {e}")
            current_text = str(prompt)

        return {
            'question_text': question_text,
            'solution': solution,
            'current_text': current_text,
            'initial_length': len(current_text),
            'current_layer': [(current_text, "")],  # (text, branch path) pairs still being expanded
            'branch_results': {},                   # branch path -> result dict
            'log_content': [],
            'save_path': os.path.join(result_dir, f"{model_name}_{question_id}.txt"),
            'max_prob_history': [],                 # (layer, branch path, max prob) per measured node
        }

    @staticmethod
    def _collect_active_nodes(question_states):
        """Flatten every question's current layer into parallel (texts, question_ids, paths) lists."""
        texts, question_ids, paths = [], [], []
        for question_id, state in question_states.items():
            for text, path in state['current_layer']:
                texts.append(text)
                question_ids.append(question_id)
                paths.append(path)
        return texts, question_ids, paths

    def _expand_nodes(self, question_states, responses, question_ids, paths, max_probs, layer):
        """Turn current nodes into generation requests, duplicating uncertain nodes into children."""
        expanded_responses, expanded_question_ids, expanded_branch_paths = [], [], []

        for response, question_id, path, max_prob in zip(responses, question_ids, paths, max_probs):
            current_depth = self._path_depth(path)
            should_branch = self.should_branch_at_token(max_prob, current_depth, self.max_branch_depth)
            question_states[question_id]['max_prob_history'].append((layer, path, max_prob))

            print(f"   - Question {question_id}, Path: {path}, Depth: {current_depth}, Max Prob: {max_prob:.3f}, Threshold: {self.max_prob_threshold:.3f}, Should Branch: {should_branch}")
            if should_branch:
                print(f"     ✅ WILL BRANCH: Prob {max_prob:.3f} <= Threshold {self.max_prob_threshold:.3f}")
            elif max_prob > self.max_prob_threshold:
                print(f"     ❌ NO BRANCH: Prob {max_prob:.3f} > Threshold {self.max_prob_threshold:.3f}")
            elif current_depth >= self.max_branch_depth:
                print(f"     ❌ NO BRANCH: Depth {current_depth} >= Max Depth {self.max_branch_depth}")
            else:
                print(f"     ❌ NO BRANCH: Unknown reason")

            if should_branch:
                # Identical copies; they diverge because the next token is sampled independently.
                for j in range(self.branch_count):
                    expanded_responses.append(response)
                    expanded_question_ids.append(question_id)
                    expanded_branch_paths.append(f"{path}.{j + 1}" if path else f"{j + 1}")
            else:
                # Keep a single continuation under the same path.
                expanded_responses.append(response)
                expanded_question_ids.append(question_id)
                expanded_branch_paths.append(path)

        return expanded_responses, expanded_question_ids, expanded_branch_paths

    # ------------------------------------------------------------ vLLM calls

    def get_max_prob_from_vllm(self, text):
        """Measure the top-1 next-token probability for a single prompt.

        Single-text counterpart of ``batch_get_max_probabilities``; it is not called
        by the other methods of this class. Falls back to
        ``calculate_heuristic_max_prob`` when vLLM fails or returns no logprobs.
        """
        try:
            from vllm import SamplingParams

            sampling_params = SamplingParams(
                n=1,
                temperature=0.7,
                top_p=1.0,
                max_tokens=1,  # one token is enough to observe the next-token distribution
                logprobs=self.logprobs_k,
            )

            outputs = self.llm.generate([text], sampling_params)

            if outputs:
                logprobs = self._sample_logprobs(outputs[0])
                if logprobs:
                    return self.get_max_prob_from_logprobs(logprobs)

            return self.calculate_heuristic_max_prob(text)

        except Exception as e:
            print(f"❌ Error getting max prob from VLLM: {e}")
            return self.calculate_heuristic_max_prob(text)

    def calculate_heuristic_max_prob(self, text):
        """Heuristic stand-in for the top-1 probability when no logprobs are available.

        Short texts (< 50 chars) get 0.9. Otherwise the score is
        ``max(0.1, 1 - 2 * density)``, where density is the fraction of characters
        that are digits or math symbols. Returns 0.5 if the computation fails.
        """
        try:
            if len(text) < 50:
                return 0.9

            math_symbols = ['+', '-', '*', '/', '=', '(', ')', '\\', '{', '}']
            math_count = sum(1 for char in text if char in math_symbols)
            digit_count = sum(1 for char in text if char.isdigit())

            complexity = (math_count + digit_count) / len(text)

            # Higher symbol/digit density -> lower pseudo-probability, floored at 0.1.
            return max(0.1, 1.0 - complexity * 2.0)

        except Exception:
            return 0.5

    def batch_get_max_probabilities(self, texts):
        """Measure the top-1 next-token probability of every prompt in one engine call.

        Returns one float per entry of ``texts`` in the same order. Entries without
        logprobs fall back to ``calculate_heuristic_max_prob``. The list is padded
        with heuristics if the engine returned fewer outputs, so downstream ``zip``
        calls never drop nodes. On any error every entry uses the heuristic.
        """
        try:
            from vllm import SamplingParams

            print(f"🔍 Batch getting max probabilities for {len(texts)} texts...")

            sampling_params = SamplingParams(
                n=1,
                temperature=0.7,
                top_p=1.0,
                max_tokens=1,  # one token is enough to observe the next-token distribution
                logprobs=self.logprobs_k,
            )

            outputs = self.llm.generate(texts, sampling_params)

            max_probabilities = []
            for text, output in zip(texts, outputs):
                logprobs = self._sample_logprobs(output)
                if logprobs:
                    max_probabilities.append(self.get_max_prob_from_logprobs(logprobs))
                else:
                    max_probabilities.append(self.calculate_heuristic_max_prob(text))

            for text in texts[len(max_probabilities):]:
                max_probabilities.append(self.calculate_heuristic_max_prob(text))

            print(f"✅ Got max probabilities: {len(max_probabilities)} values")
            return max_probabilities

        except Exception as e:
            print(f"❌ Error in batch_get_max_probabilities: {e}")
            return [self.calculate_heuristic_max_prob(text) for text in texts]

    def batch_generate_next_tokens(self, texts):
        """Sample exactly one continuation token for every prompt in one engine call.

        Returns one string per prompt, in the same order. By default vLLM drops the
        text of special tokens (``skip_special_tokens=True``). So when a completion
        finished with ``finish_reason == "stop"`` (EOS or another stop token), the
        tokenizer's ``eos_token`` is appended. This lets ``check_branch_termination``
        detect the end of the branch. On any error a list of empty strings is
        returned.
        """
        try:
            from vllm import SamplingParams

            print(f"🚀 Batch generating next tokens for {len(texts)} texts...")

            # Logprobs are not requested here: only the sampled token is used.
            sampling_params = SamplingParams(
                n=1,
                temperature=0.7,
                top_p=1.0,
                max_tokens=1,
            )

            outputs = self.llm.generate(texts, sampling_params)

            eos_token = getattr(self.tokenizer, "eos_token", None)
            next_tokens = []
            for output in outputs:
                completion = output.outputs[0]
                new_text = completion.text
                finished = getattr(completion, "finish_reason", None) == "stop"
                if finished and eos_token and not new_text.endswith(eos_token):
                    new_text += eos_token
                next_tokens.append(new_text)

            print(f"✅ Generated next tokens: {len(next_tokens)} responses")
            return next_tokens

        except Exception as e:
            print(f"❌ Error in batch_generate_next_tokens: {e}")
            return [""] * len(texts)

    # -------------------------------------------------------- result handling

    def process_batch_results(self, question_states, all_question_ids, all_branch_paths,
                              all_expanded_responses, next_layer_responses, layer):
        """Apply one layer of generated tokens to every question's state.

        The four sequences are parallel, with one entry per expanded node.
        Entries are grouped by question explicitly, so they do not need to be
        contiguous. Terminated branches are scored into ``branch_results`` and
        ``log_content``. The others become the question's next ``current_layer``.
        On the last layer (``layer == max_branch_depth - 1``), still-open branches
        are queued for rollout.
        """
        grouped = defaultdict(list)
        for question_id, path, original_response, new_text in zip(
            all_question_ids, all_branch_paths, all_expanded_responses, next_layer_responses
        ):
            grouped[question_id].append((original_response, path, new_text))

        is_last_layer = layer == self.max_branch_depth - 1

        for question_id, state in question_states.items():
            if not state['current_layer'] or question_id not in grouped:
                continue

            non_terminated_branches = []
            for original_response, path, new_text in grouped[question_id]:
                full_text = original_response + new_text
                should_terminate, termination_reason = self.check_branch_termination(
                    full_text, state['initial_length']
                )
                if should_terminate:
                    self._record_terminated_branch(state, full_text, path, termination_reason)
                else:
                    non_terminated_branches.append((full_text, path))

            state['current_layer'] = non_terminated_branches

            if is_last_layer:
                for text, path in non_terminated_branches:
                    self.add_to_global_rollout_queue(text, path, question_id, state['solution'])

    def _record_terminated_branch(self, state, full_text, path, termination_reason):
        """Score a finished branch and store its result and log line in ``state``."""
        is_correct, extracted_answer, status = self.verify_answer_correctness(full_text, path, state['solution'])
        boxed_answer = self.extract_boxed_answer(full_text, state['initial_length'])

        state['branch_results'][path] = {
            'boxed_answer': boxed_answer,
            'extracted_answer': extracted_answer,
            'is_correct': is_correct,
            'termination_reason': termination_reason,
            'status': status
        }

        final_content = full_text[state['initial_length']:].strip()
        completion_info = f" [BRANCH {path} COMPLETE - {termination_reason.upper()}] {status}"
        state['log_content'].append(f"{final_content}{completion_info}")

    def check_branch_termination(self, text, initial_length):
        """Return ``(True, "eos_token")`` if the generated text ends with the tokenizer's EOS.

        Returns ``(False, "none")`` when nothing was generated yet
        (``len(text) <= initial_length``) or when no EOS token is configured.
        """
        if len(text) <= initial_length:
            return False, "none"

        eos_token = getattr(getattr(self, "tokenizer", None), "eos_token", None)
        if isinstance(eos_token, str) and eos_token and text.endswith(eos_token):
            return True, "eos_token"

        return False, "none"

    def verify_answer_correctness(self, text, branch_path, solution):
        """Score ``text`` against ``solution`` with ``compute_score``.

        Returns:
            ``(is_correct, extracted_answer, status)``. On any verification error this
            is ``(False, None, "⚠️ VERIFICATION ERROR: ...")``. ``branch_path`` is
            accepted for interface compatibility and not used.
        """
        try:
            result = compute_score(
                solution_str=text,
                ground_truth=solution,
                strict_box_verify=self.strict_box_verify
            )

            is_correct = result['acc']
            extracted_answer = result['pred']

            if is_correct:
                status = "✅ CORRECT"
            else:
                status = f"❌ WRONG (got: {extracted_answer}, expected: {solution})"

            return bool(is_correct), extracted_answer, status

        except Exception as e:
            error_msg = f"⚠️ VERIFICATION ERROR: {e}"
            return False, None, error_msg

    def extract_boxed_answer(self, text, initial_length):
        """Return the body of the last ``\\boxed{...}`` in the generated part of ``text``.

        Only ``text[initial_length:]`` is searched, so boxes in the prompt are
        ignored. Nested braces are handled (``\\boxed{\\frac{1}{2}}`` -> ``\\frac{1}{2}``).
        Returns ``"Invalid"`` when no closed box is found.
        """
        try:
            answer = self._last_boxed_content(text[initial_length:])
            return answer if answer is not None else "Invalid"
        except Exception:
            return "Invalid"

    # -------------------------------------------------------------- rollouts

    def add_to_global_rollout_queue(self, text, branch_path, question_id, solution):
        """Queue an unfinished branch for full rollouts."""
        if not hasattr(self, 'global_rollout_queue'):
            self.global_rollout_queue = []

        self.global_rollout_queue.append((text, branch_path, question_id, solution))

    def process_global_rollout_queue(self, question_states):
        """Roll out every queued branch ``rollout_times`` times and record its accuracy.

        All rollouts are generated in one batched call. Each branch gets a result in
        its question's ``branch_results`` with ``is_correct`` set to "at least one
        rollout was correct". Entries for questions not present in
        ``question_states`` are skipped. The queue is always cleared afterwards.
        """
        queue = getattr(self, 'global_rollout_queue', None)
        if not queue:
            return

        try:
            print(f"🎲 Processing global rollout queue with {len(queue)} branches...")

            all_rollout_texts = []
            for text, _branch_path, _question_id, _solution in queue:
                all_rollout_texts.extend([text] * self.rollout_times)

            print(f"🎲 Batch rollout generating {len(all_rollout_texts)} responses...")

            all_rollout_results = self.batch_rollout_vllm(all_rollout_texts)

            for i, (text, branch_path, question_id, solution) in enumerate(queue):
                state = question_states.get(question_id)
                if state is None:
                    print(f"⚠️ Skipping rollout for unknown question {question_id}")
                    continue

                start_idx = i * self.rollout_times
                branch_rollout_results = all_rollout_results[start_idx:start_idx + self.rollout_times]

                validated_results = [
                    (rollout_text, self._rollout_is_correct(text + rollout_text, solution))
                    for rollout_text, _ in branch_rollout_results
                ]

                correct_count = sum(1 for _, is_correct in validated_results if is_correct)
                rollout_acc = correct_count / self.rollout_times if self.rollout_times > 0 else 0.0

                state['branch_results'][branch_path] = {
                    'boxed_answer': '[ROLLOUT]',
                    'extracted_answer': '[ROLLOUT]',
                    'is_correct': rollout_acc > 0,
                    'termination_reason': f'rollout_{self.rollout_times}',
                    # Use the integer count directly; acc*n can round down (e.g. 1/49*49 < 1).
                    'status': f'ROLLOUT ACCURACY: {rollout_acc:.2f} ({correct_count}/{self.rollout_times})',
                    'rollout_results': validated_results
                }
        finally:
            # Never let stale entries leak into the next batch.
            queue.clear()

    def _rollout_is_correct(self, full_text, solution):
        """Return whether ``compute_score`` accepts ``full_text``; False on any scoring error."""
        try:
            result = compute_score(
                solution_str=full_text,
                ground_truth=solution,
                strict_box_verify=self.strict_box_verify
            )
            return bool(result['acc'])
        except Exception:
            return False

    def batch_rollout_vllm(self, rollout_texts):
        """Generate one full continuation (up to 3072 tokens) per prompt in one engine call.

        Returns ``(stripped_text, False)`` pairs in prompt order. On any error every
        entry is ``("ERROR", False)``.
        """
        try:
            from vllm import SamplingParams

            sampling_params = SamplingParams(
                n=1,
                temperature=0.7,
                top_p=1.0,
                max_tokens=3072
            )

            outputs = self.llm.generate(rollout_texts, sampling_params)

            return [(output.outputs[0].text.strip(), False) for output in outputs]

        except Exception as e:
            print(f"❌ Error in batch_rollout_vllm: {e}")
            return [("ERROR", False)] * len(rollout_texts)

    # ----------------------------------------------------------------- output

    def save_question_results(self, question_id, state):
        """Write the detailed log and the tree overview for one question, then report it.

        The overview is written next to ``state['save_path']`` with ``_overview``
        inserted before the extension. Branches are listed in numeric path order.
        Accuracy is 0.00% when there are no branch results.
        """
        with open(state['save_path'], 'w', encoding='utf-8') as f:
            f.write("🚀 Max-Prob-Based Token-Level Branching Log\n")
            f.write("=" * 50 + "\n")
            f.write(f"Question ID: {question_id}\n")
            f.write(f"Question: {state['question_text']}\n")
            f.write(f"Solution: {state['solution']}\n")
            f.write(f"Max Prob Threshold: {self.max_prob_threshold}\n")
            f.write(f"Logprobs k: {self.logprobs_k}\n")
            f.write("=" * 50 + "\n\n")
            f.write("\n".join(state['log_content']))

        root, ext = os.path.splitext(state['save_path'])
        overview_path = f"{root}_overview{ext}"
        branch_results = state['branch_results']

        with open(overview_path, 'w', encoding='utf-8') as f:
            f.write("🌳 Max-Prob-Based K-ary Tree Overview\n")
            f.write("=" * 60 + "\n")
            f.write(f"Question ID: {question_id}\n")
            f.write(f"Ground Truth: {state['solution']}\n")
            f.write(f"Max Prob Threshold: {self.max_prob_threshold}\n")
            f.write(f"Logprobs k: {self.logprobs_k}\n")
            f.write(f"Total Branches: {len(branch_results)}\n")
            f.write("=" * 60 + "\n\n")

            correct_count = sum(1 for result in branch_results.values() if result['is_correct'])
            total_count = len(branch_results)
            accuracy = correct_count / total_count * 100 if total_count else 0.0

            f.write(f"📊 Summary: {correct_count}/{total_count} branches correct\n")
            f.write(f"📊 Final Accuracy: {accuracy:.2f}%\n\n")

            for branch_path, result in sorted(branch_results.items(),
                                              key=lambda item: self._path_sort_key(item[0])):
                status_icon = "✅" if result['is_correct'] else "❌"
                f.write(f"{branch_path}: {status_icon} {result['boxed_answer']} - {result['termination_reason']}\n")

        if self.global_file_manager:
            log_content = "\n".join(state['log_content'])
            self.global_file_manager.add_question_result(
                question_id=question_id,
                question_text=state['question_text'],
                solution=state['solution'],
                branch_results=branch_results,
                log_content=log_content
            )