#!/usr/bin/env python3
"""Parallel branching processor implementing batch prompt branching generation"""

import torch
import os
import re
import json
import datetime
import numpy as np
from collections import defaultdict
import sys
from functools import lru_cache
import torch.nn.functional as F

from math_dapo import (
    last_boxed_only_string,
    remove_boxed,
    normalize_final_answer,
    is_correct_minerva,
    is_correct_strict_box,
    verify,
    compute_score
)

class ParallelBranchingProcessor:
    """Parallel branching processor"""
    
    def __init__(self, model, tokenizer, llm, branch_count, max_branch_depth, rollout_times, 
                 max_prob_threshold, global_file_manager, batch_size=100, strict_box_verify=True, logprobs_k=5):
        self.model = model
        self.tokenizer = tokenizer
        self.llm = llm
        self.branch_count = branch_count
        self.max_branch_depth = max_branch_depth
        self.rollout_times = rollout_times
        self.max_prob_threshold = max_prob_threshold
        self.global_file_manager = global_file_manager
        self.batch_size = batch_size
        self.strict_box_verify = strict_box_verify
        self.logprobs_k = logprobs_k
        
        print(f"Parallel Branching Processor initialized")
        print(f"   - Batch size: {batch_size}")
        print(f"   - Branch count: {branch_count}")
        print(f"   - Max depth: {max_branch_depth}")
        print(f"   - Max prob threshold: {max_prob_threshold}")
        print(f"   - Logprobs k: {logprobs_k}")
        print(f"   - Rollout times: {rollout_times}")
    
    def get_max_prob_from_logprobs(self, logprobs_data):
        """Return the top candidate's probability for the most recent token.

        Only the last entry of ``logprobs_data`` is inspected. Its candidate
        log-probabilities are pushed through a softmax, i.e. renormalised over
        the candidates that were returned, so the result is the top
        candidate's share of that subset rather than of the full vocabulary.

        Args:
            logprobs_data: Sequence with one mapping per generated token. The
                mapping values may be plain numbers or objects that carry the
                number in a ``logprob`` attribute (the form vLLM's ``Logprob``
                entries are expected to take; confirm for the installed
                version).

        Returns:
            A float in [0, 1]: 1.0 when there is nothing to inspect or an
            unexpected error occurs, 0.5 when the values are NaN/infinite or
            the probability computation fails, otherwise the largest
            renormalised probability.
        """
        try:
            if not logprobs_data:
                return 1.0

            # Only the most recently generated token matters.
            last_token_logprobs = logprobs_data[-1]
            if not last_token_logprobs:
                return 1.0

            # Unwrap objects exposing `.logprob`; plain numbers pass through.
            # Without this, torch.tensor() raises on wrapper objects and the
            # outer handler would report 1.0 for every token.
            logprobs_values = [
                getattr(entry, 'logprob', entry)
                for entry in last_token_logprobs.values()
            ]
            logprobs_tensor = torch.tensor(logprobs_values, dtype=torch.float32)

            # Reject NaN / infinite inputs up front.
            if torch.isnan(logprobs_tensor).any() or torch.isinf(logprobs_tensor).any():
                print(f"Warning: Invalid logprobs values detected, using fallback")
                return 0.5

            try:
                # log_softmax followed by exp avoids overflow in the softmax.
                log_probs = torch.log_softmax(logprobs_tensor, dim=0)
                probs = torch.exp(log_probs)

                if torch.isnan(probs).any() or torch.isinf(probs).any():
                    print(f"Warning: Invalid probability values detected, using fallback")
                    return 0.5

                max_prob = torch.max(probs).item()

                # Guard against tiny floating-point excursions outside [0, 1].
                max_prob = max(0.0, min(1.0, max_prob))

            except Exception as e:
                print(f"Warning: Error in probability calculation: {e}, using fallback")
                return 0.5

            return max_prob

        except Exception as e:
            print(f"Error getting max prob from logprobs: {e}")
            return 1.0
    
    def should_branch_at_token(self, max_prob, current_depth, max_depth):
        """Tell whether a sequence may fork at the token it just produced.

        Args:
            max_prob: Max probability reported for the token.
            current_depth: Number of branchings the sequence already went through.
            max_depth: Depth at which branching is no longer allowed.

        Returns:
            False when ``max_prob`` exceeds ``self.max_prob_threshold`` or when
            ``current_depth`` has reached ``max_depth``; True otherwise.
        """
        too_confident = max_prob > self.max_prob_threshold
        depth_exhausted = current_depth >= max_depth
        return not (too_confident or depth_exhausted)
    
    def process_parallel_branching(self, questions_batch, result_dir, model_name):
        """Parallel branching processing"""
        print(f"\nProcessing {len(questions_batch)} questions with parallel branching...")
        
        # init question states
        question_states = {}
        for i, question_data in enumerate(questions_batch):
            question_id = question_data['question_id']
            question_text = question_data['question']
            solution = question_data['solution']
            prompt = question_data['prompt']
            
            # convert initial prompt
            try:
                current_text = self.tokenizer.apply_chat_template(
                    prompt, tokenize=False, add_generation_prompt=True
                )
            except Exception as e:
                print(f"Error creating initial text for question {question_id}: {e}")
                current_text = str(prompt)
            
            question_states[question_id] = {
                'question_text': question_text,
                'solution': solution,
                'current_text': current_text,
                'initial_length': len(current_text),
                'branch_results': {},
                'log_content': [],
                'save_path': os.path.join(result_dir, f"{model_name}_{question_id}.txt"),
                'rollout_queue': []
            }
        
        # init all sequences
        all_sequences = []
        sequence_to_question = {}
        
        # create initial sequences per question
        for question_id, state in question_states.items():
            for i in range(self.branch_count):
                sequence_id = f"{question_id}_{i+1}"
                all_sequences.append({
                    'sequence_id': sequence_id,
                    'question_id': question_id,
                    'text': state['current_text'],
                    'path': f"{i+1}",
                    'branch_count': 0,
                    'is_active': True,
                    'initial_length': state['initial_length']
                })
                sequence_to_question[sequence_id] = question_id
        
        print(f"Initialized {len(all_sequences)} sequences for {len(questions_batch)} questions")
        
        # main loop: process all sequences in parallel
        iteration = 0
        while True:
            iteration += 1
            print(f"\nIteration {iteration}: Processing {len(all_sequences)} sequences")
            
            # split active and rollout sequences
            active_sequences = [seq for seq in all_sequences if seq['is_active']]
            
            if not active_sequences:
                print("No active sequences, breaking...")
                break
            
            print(f"Active sequences: {len(active_sequences)}")
            
            # debug samples (first 3)
            for seq in active_sequences[:3]:
                print(f"   - Active: {seq['sequence_id']}, Path: {seq['path']}, Branch Count: {seq['branch_count']}")
            
            # log
            for state in question_states.values():
                state['log_content'].append(f"=== Iteration {iteration} ===")
                state['log_content'].append(f"Total sequences: {len(all_sequences)}")
                state['log_content'].append(f"Active sequences: {len(active_sequences)}")
            
            # gather texts
            active_texts = [seq['text'] for seq in active_sequences]
            
            # batch generate next token (parallel)
            print(f"Batch generating {len(active_texts)} tokens with VLLM...")
            next_tokens, max_probabilities = self.batch_generate_with_probabilities(active_texts)
            print(f"Generated {len(next_tokens)} tokens successfully")
            
            # 处理每个序列的结果
            new_sequences = []
            rollout_sequences = []
            
            for i, (seq, next_token, max_prob) in enumerate(zip(active_sequences, next_tokens, max_probabilities)):
                # update text
                seq['text'] += next_token
                
                # decide branch
                should_branch = self.should_branch_at_token(
                    max_prob, 
                    seq['branch_count'], 
                    self.max_branch_depth
                )
                
                # debug
                print(f"   - Sequence {seq['sequence_id']}, Path: {seq['path']}, "
                      f"Branch Count: {seq['branch_count']}, Max Prob: {max_prob:.3f}, "
                      f"Threshold: {self.max_prob_threshold:.3f}, Should Branch: {should_branch}")
                
                if should_branch:
                    print(f"     WILL BRANCH: Prob {max_prob:.3f} <= Threshold {self.max_prob_threshold:.3f}")
                    
                    # max depth check
                    if seq['branch_count'] >= self.max_branch_depth:
                        print(f"     READY FOR ROLLOUT: Max depth reached")
                        seq['is_active'] = False
                        rollout_sequences.append(seq)
                    else:
                        # create branches
                        for j in range(self.branch_count):
                            new_seq = {
                                'sequence_id': f"{seq['sequence_id']}_{j+1}",
                                'question_id': seq['question_id'],
                                'text': seq['text'],
                                'path': f"{seq['path']}.{j+1}",
                                'branch_count': seq['branch_count'] + 1,
                                'is_active': True,
                                'initial_length': seq['initial_length']
                            }
                            new_sequences.append(new_seq)
                            sequence_to_question[new_seq['sequence_id']] = seq['question_id']
                        
                        # mark original as inactive
                        seq['is_active'] = False
                else:
                    if max_prob > self.max_prob_threshold:
                        print(f"     NO BRANCH: Prob {max_prob:.3f} > Threshold {self.max_prob_threshold:.3f}")
                    elif seq['branch_count'] >= self.max_branch_depth:
                        print(f"     NO BRANCH: Depth {seq['branch_count']} >= Max Depth {self.max_branch_depth}")
                    else:
                        print(f"     NO BRANCH: Unknown reason")
                    
                    # 继续当前序列
                    new_sequences.append(seq)
            
            # 更新序列列表
            all_sequences = new_sequences
            
            # batch rollout if needed
            if rollout_sequences:
                print(f"Processing {len(rollout_sequences)} sequences for rollout...")
                self.batch_rollout_sequences(rollout_sequences, question_states)
        
        # force rollout remaining active sequences
        remaining_active = [seq for seq in all_sequences if seq['is_active']]
        if remaining_active:
            print(f"Force rolling out {len(remaining_active)} remaining active sequences...")
            self.batch_rollout_sequences(remaining_active, question_states)
        
        # save results
        print(f"Saving results for {len(question_states)} questions...")
        for question_id, state in question_states.items():
            print(f"   - Saving question {question_id}...")
            self.save_question_results(question_id, state)
        
        print(f"All results saved successfully!")
        return question_states
    
    def batch_generate_with_probabilities(self, texts):
        """Sample one token for every text and report how confident the model was.

        Sampling uses temperature 0.7, top_p 1.0, ``max_tokens=1`` and asks for
        ``self.logprobs_k`` logprobs per token.

        Args:
            texts: List of prompt strings.

        Returns:
            ``(next_tokens, max_probabilities)``: two lists aligned with the
            generation outputs. ``max_probabilities`` holds the value from
            ``get_max_prob_from_logprobs`` or 0.5 when no logprobs are
            available. If anything raises, an empty string and 0.5 are
            returned for every input text.
        """
        if not texts:
            return [], []

        try:
            from vllm import SamplingParams

            sampling_params = SamplingParams(
                n=1,
                temperature=0.7,
                top_p=1.0,
                max_tokens=1,
                logprobs=self.logprobs_k
            )

            outputs = self.llm.generate(texts, sampling_params)

            next_tokens = []
            max_probabilities = []

            for i, output in enumerate(outputs):
                completion = output.outputs[0]
                next_token = completion.text
                next_tokens.append(next_token)

                # Per-token logprobs live on the completion (outputs[0]), not
                # on the request-level output object; reading them from the
                # latter always yields the 0.5 fallback. The request-level
                # attribute is still consulted as a last resort.
                token_logprobs = getattr(completion, 'logprobs', None)
                if not token_logprobs:
                    token_logprobs = getattr(output, 'logprobs', None)

                last_step = token_logprobs[-1] if token_logprobs else None
                if last_step:
                    # Entries may be wrapper objects exposing `.logprob`;
                    # hand plain numbers to the probability helper.
                    plain_step = {
                        token_id: getattr(entry, 'logprob', entry)
                        for token_id, entry in last_step.items()
                    }
                    max_prob = self.get_max_prob_from_logprobs([plain_step])
                else:
                    max_prob = 0.5

                max_probabilities.append(max_prob)

                # Debug output for the first few results only.
                if i < 3:
                    print(f"     Debug: Output {i}, Token: '{next_token}', Max Prob: {max_prob:.3f}")

            return next_tokens, max_probabilities

        except Exception as e:
            print(f"Error in batch_generate_with_probabilities: {e}")
            # Return neutral defaults so the caller's lists stay aligned.
            return [""] * len(texts), [0.5] * len(texts)
    
    def batch_rollout_sequences(self, sequences, question_states):
        """Finish the given sequences in one batched call and record their scores.

        Every sequence's current text is used as the prompt (temperature 0.7,
        top_p 1.0, up to 4096 new tokens). The generated continuation is
        appended to that text, the part after the initial prompt is scored with
        ``verify_answer_correctness``, and the outcome is stored in the owning
        question's ``'branch_results'`` under the sequence's path. A line is
        also appended to the question's ``'log_content'``.

        Args:
            sequences: List of sequence dicts; the keys ``'text'``,
                ``'question_id'``, ``'path'`` and ``'sequence_id'`` are read.
            question_states: Dict of per-question state dicts; the keys
                ``'solution'``, ``'initial_length'``, ``'branch_results'`` and
                ``'log_content'`` are used.

        Returns:
            None. Errors are printed and otherwise swallowed.
        """
        if not sequences:
            return

        try:
            from vllm import SamplingParams

            rollout_texts = [seq['text'] for seq in sequences]

            sampling_params = SamplingParams(
                n=1,
                temperature=0.7,
                top_p=1.0,
                max_tokens=4096,
                logprobs=0
            )

            outputs = self.llm.generate(rollout_texts, sampling_params)

            for seq, output in zip(sequences, outputs):
                # The completion text holds only the newly generated part, so
                # the complete branch is the prompt text plus that
                # continuation. Treating the continuation alone as the full
                # text would slice real content away below.
                full_text = seq['text'] + output.outputs[0].text
                question_id = seq['question_id']
                path = seq['path']

                state = question_states[question_id]
                # Everything the model produced for this branch, prompt excluded.
                response_text = full_text[state['initial_length']:]

                is_correct, extracted_answer, status = self.verify_answer_correctness(
                    response_text, path, state['solution']
                )
                boxed_answer = self.extract_boxed_answer(full_text, state['initial_length'])

                state['branch_results'][path] = {
                    'boxed_answer': boxed_answer,
                    'extracted_answer': extracted_answer,
                    'is_correct': is_correct,
                    'termination_reason': 'rollout_complete',
                    'status': status
                }

                final_content = response_text.strip()
                completion_info = f" [BRANCH {path} COMPLETE - ROLLOUT] {status}"
                state['log_content'].append(f"{final_content}{completion_info}")

                print(f"   Rollout completed for {seq['sequence_id']}: {status}")

        except Exception as e:
            print(f"Error in batch_rollout_sequences: {e}")
    
    def verify_answer_correctness(self, text, branch_path, solution):
        """Score a branch's text against the reference solution.

        Delegates to ``compute_score`` with ``self.strict_box_verify`` and
        reads the ``'acc'`` and ``'pred'`` entries of its result.

        Args:
            text: Text to score.
            branch_path: Accepted for the caller's convenience; not used here.
            solution: Ground-truth answer.

        Returns:
            ``(is_correct, extracted_answer, status)``. ``status`` is
            ``"CORRECT"``, a ``"WRONG (...)"`` message, or, when anything
            raises, a ``"VERIFICATION ERROR: ..."`` message paired with
            ``False`` and ``None``.
        """
        try:
            score = compute_score(
                solution_str=text,
                ground_truth=solution,
                strict_box_verify=self.strict_box_verify,
            )
            verdict = score['acc']
            predicted = score['pred']
            status = (
                "CORRECT"
                if verdict
                else f"WRONG (got: {predicted}, expected: {solution})"
            )
            return bool(verdict), predicted, status
        except Exception as e:
            return False, None, f"VERIFICATION ERROR: {e}"
    
    def extract_boxed_answer(self, text, initial_length):
        r"""Return the content of the last non-empty ``\boxed{...}`` after the prompt.

        Braces are matched by depth, so nested groups such as
        ``\boxed{\frac{1}{2}}`` are returned whole. The last box is preferred
        because a response may box intermediate values before its final
        answer. Boxes that are empty or never closed are skipped.

        Args:
            text: Full text; only ``text[initial_length:]`` is searched.
            initial_length: Number of leading characters to ignore.

        Returns:
            The boxed content, ``"[NO_BOXED_ANSWER]"`` when no usable box
            exists, or ``"[EXTRACTION_ERROR: ...]"`` if anything raises.
        """
        try:
            new_content = text[initial_length:]
            marker = "\\boxed{"

            # Walk the markers from the last one backwards.
            pos = new_content.rfind(marker)
            while pos != -1:
                start = pos + len(marker)
                depth = 1
                for idx in range(start, len(new_content)):
                    ch = new_content[idx]
                    if ch == "{":
                        depth += 1
                    elif ch == "}":
                        depth -= 1
                        if depth == 0:
                            if idx > start:
                                return new_content[start:idx]
                            # Empty box: fall through to an earlier marker.
                            break
                pos = new_content.rfind(marker, 0, pos)

            return "[NO_BOXED_ANSWER]"
        except Exception as e:
            return f"[EXTRACTION_ERROR: {e}]"
    
    def save_question_results(self, question_id, state):
        """保存单个问题的结果"""
        try:
            # ensure dir exists
            import os
            os.makedirs(os.path.dirname(state['save_path']), exist_ok=True)
            
            # write detailed log
            with open(state['save_path'], 'w', encoding='utf-8') as f:
                f.write("Parallel Branching System Log\n")
                f.write("=" * 50 + "\n")
                f.write(f"Question ID: {question_id}\n")
                f.write(f"Question: {state['question_text']}\n")
                f.write(f"Solution: {state['solution']}\n")
                f.write(f"Max Prob Threshold: {self.max_prob_threshold}\n")
                f.write(f"Logprobs k: {self.logprobs_k}\n")
                f.write("=" * 50 + "\n\n")
                f.write("\n".join(state['log_content']))
            
            print(f"     Saved detailed log to: {state['save_path']}")
        except Exception as e:
            print(f"     Error saving detailed log: {e}")
        
        # write overview
        try:
            overview_path = state['save_path'].replace('.txt', '_overview.txt')
            with open(overview_path, 'w', encoding='utf-8') as f:
                f.write("Parallel Branching System Overview\n")
                f.write("=" * 60 + "\n")
                f.write(f"Question ID: {question_id}\n")
                f.write(f"Ground Truth: {state['solution']}\n")
                f.write(f"Max Prob Threshold: {self.max_prob_threshold}\n")
                f.write(f"Logprobs k: {self.logprobs_k}\n")
                f.write(f"Total Branches: {len(state['branch_results'])}\n")
                f.write("=" * 60 + "\n\n")
                
                correct_count = sum(1 for result in state['branch_results'].values() if result['is_correct'])
                total_count = len(state['branch_results'])
                
                f.write(f"Summary: {correct_count}/{total_count} branches correct\n")
                f.write(f"Final Accuracy: {correct_count/total_count*100:.2f}%\n\n")
                
                for branch_path, result in sorted(state['branch_results'].items()):
                    status_icon = "OK" if result['is_correct'] else "ERR"
                    f.write(f"{branch_path}: {status_icon} {result['boxed_answer']} - {result['termination_reason']}\n")
            
            print(f"     Saved overview to: {overview_path}")
        except Exception as e:
            print(f"     Error saving overview: {e}")
        
        # 添加到全局文件管理器
        if self.global_file_manager:
            log_content = "\n".join(state['log_content'])
            self.global_file_manager.add_question_result(
                question_id=question_id,
                question_text=state['question_text'],
                solution=state['solution'],
                branch_results=state['branch_results'],
                log_content=log_content
            ) 