import argparse
import numpy as np
import os
import torch
import json
from collections import Counter, defaultdict
from vllm import LLM, SamplingParams
from transformers import AutoTokenizer
from datasets import load_dataset
import re
import math

# Instance totals per evaluation config. Keys have the form
# "<dataset>_<suffix>"; the suffix looks like the name of the model whose
# responses are being judged (TODO confirm). Every suffix shares the same total
# for a given dataset, so the table is derived from the per-dataset totals
# instead of being spelled out key by key. main() divides these totals by 8 to
# obtain the scoring denominator.
_INSTANCES_PER_DATASET = {
    "aime": 240,
    "amc": 664,
    "gpqa": 1584,
    "leetcode": 1440,
    "math": 4000,
    "minerva_math": 2176,
    "olympiadbench": 5400,
}

_CONFIG_SUFFIXES = (
    "Eurus2_7B_sft",
    "Meta_Llama_3.1_70B_Instruct",
    "Qwen2.5_7B_Instruct",
)

# Insertion order (suffix-major, dataset-minor) matches the order in which the
# configs are processed when --configs is not given.
total_instances_dict = {
    f"{dataset}_{suffix}": count
    for suffix in _CONFIG_SUFFIXES
    for dataset, count in _INSTANCES_PER_DATASET.items()
}

def extract_answer(solution_text: str):
    r"""Return the stripped contents of the last ``\boxed{...}`` in the text.

    Braces are matched by nesting depth, so answers that themselves contain
    braces (``\boxed{\frac{1}{2}}``) are returned whole instead of being cut
    at the first ``}``.

    Args:
        solution_text: Text to search.

    Returns:
        The content of the last boxed expression with surrounding whitespace
        removed, or ``None`` when the text contains no closed ``\boxed{``.
    """
    marker = '\\boxed{'
    answer = None
    cursor = 0
    while True:
        start = solution_text.find(marker, cursor)
        if start == -1:
            return answer

        body_start = start + len(marker)
        depth = 1
        end = body_start
        while end < len(solution_text):
            char = solution_text[end]
            if char == '{':
                depth += 1
            elif char == '}':
                depth -= 1
                if depth == 0:
                    break
            end += 1

        if depth != 0:
            # Unbalanced braces: fall back to cutting at the first closing
            # brace, which is what a simple "[^}]*" pattern would capture.
            end = solution_text.find('}', body_start)
            if end == -1:
                # No closing brace anywhere after this marker, so neither
                # this nor any later marker can be closed.
                return answer

        answer = solution_text[body_start:end].strip()
        cursor = end + 1

def apply_chat_templates(toker, messages):
    """Render *messages* with the tokenizer's chat template and tokenize them.

    The conversation is rendered to text with a generation prompt appended.
    If the rendered text contains the literal ``<|begin_of_text|>`` marker,
    only the text between its first occurrence and the next one (or the end)
    is tokenized. Tokenization is done with ``add_special_tokens=False``.

    Args:
        toker: Tokenizer object exposing ``apply_chat_template`` and being
            callable on a string (returning an object with ``input_ids``).
        messages: Chat messages passed through to ``apply_chat_template``.

    Returns:
        A one-element list holding the token ids of the rendered prompt.
    """
    bos_marker = "<|begin_of_text|>"
    rendered = toker.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    text = rendered.split(bos_marker)[1] if bos_marker in rendered else rendered
    return [toker(text, add_special_tokens=False).input_ids]


def prepare_input_boxed(template, input_data, toker, step_idx):
    """Build tokenized critique prompts for step *step_idx* of each record.

    Each step is wrapped as ``<paragraph_k>...</paragraph_k>``. Steps before
    *step_idx* are concatenated into the ``tagged_response`` template field
    and step *step_idx* fills ``current_paragraph``. For the first step the
    "[Previous Paragraph(s)]" section is removed from the template before
    formatting. Records that have no step at *step_idx* are skipped.

    Args:
        template: Prompt template with ``{problem}``, ``{tagged_response}``
            and ``{current_paragraph}`` fields.
        input_data: Iterable of records with ``'id'``, ``'problem'`` and
            ``'steps'`` entries.
        toker: Tokenizer handed to ``apply_chat_templates``.
        step_idx: Index of the step to be judged.

    Returns:
        ``(token_id_lists, checks)``: one token-id list per generated prompt
        and a parallel list of ``{'id': ...}`` dicts.
    """
    prompts = []
    checks = []
    for record in input_data:
        record_id = record['id']
        problem = record['problem']
        steps = record['steps']

        # Nothing to judge when the record has no step at this index.
        if not 0 <= step_idx < len(steps):
            continue

        tagged = [
            f'<paragraph_{n}>\n{steps[n]}\n</paragraph_{n}>\n\n'
            for n in range(step_idx + 1)
        ]
        current = tagged[-1]

        if step_idx == 0:
            # First step: there are no previous paragraphs to show.
            prompt = template.replace("[Previous Paragraph(s)]\n\n{tagged_response}\n\n", "").format(problem=problem, current_paragraph=current)
        else:
            prompt = template.format(problem=problem, tagged_response=''.join(tagged[:-1]), current_paragraph=current)

        prompts.append(prompt.strip())
        checks.append({'id': record_id})

    token_id_lists = []
    for prompt in prompts:
        token_id_lists.extend(
            apply_chat_templates(toker, [{'role': 'user', 'content': prompt}])
        )

    return token_id_lists, checks

def get_logprobs(generation, toker, use_softmax=False):
    """Read the boxed 0/1 verdict of a generation and the score of token "1".

    The generated token ids are searched for the rightmost four-token sequence
    ``boxed`` ``{`` ``0|1`` ``}``. The per-token log-probabilities recorded at
    the verdict position are then used to score the "1" answer.

    NOTE(review): the token ids below are hard-coded and ``toker`` is never
    consulted, so they presumably match the vocabulary of one specific
    tokenizer — confirm before using a different judge model.

    Args:
        generation: Object whose ``outputs[0]`` exposes ``token_ids`` and a
            per-position ``logprobs`` sequence of ``{token_id: entry}``
            mappings, each entry having a ``.logprob`` attribute.
        toker: Unused.
        use_softmax: When true, return the probability of "1" renormalised
            over the two tokens "0" and "1"; otherwise return the raw
            log-probability of "1".

    Returns:
        ``(prediction, score)``. ``prediction`` is the boxed digit (1 or 0).
        ``score`` always refers to token "1", whichever digit was generated.
        When no boxed verdict is found the result is ``(0, 0.0)`` with
        ``use_softmax`` and ``(0, float('-inf'))`` without.
    """
    ZERO, ONE = 15, 16            # token ids for "0" and "1"
    BOXED, OPEN_BRACE = 79075, 90  # token ids for "boxed" and "{"
    # Token ids accepted as the closing brace (several tokens contain "}").
    CLOSERS = (92, 31716, 532, 7810)

    output = generation.outputs[0]
    tokens = output.token_ids

    # Scan right-to-left so the first hit is the rightmost verdict.
    answer_pos = None
    for start in range(len(tokens) - 4, -1, -1):
        window = list(tokens[start:start + 4])
        if (
            window[0] == BOXED
            and window[1] == OPEN_BRACE
            and window[2] in (ZERO, ONE)
            and window[3] in CLOSERS
        ):
            answer_pos = start + 2
            break

    if answer_pos is None:
        return 0, (0.0 if use_softmax else float('-inf'))

    step_logprobs = output.logprobs[answer_pos]

    def lookup(token_id):
        # A token absent from the recorded candidates counts as -inf.
        if token_id in step_logprobs:
            return step_logprobs[token_id].logprob
        return float('-inf')

    zero_logprob = lookup(ZERO)
    one_logprob = lookup(ONE)
    prediction = 1 if tokens[answer_pos] == ONE else 0

    if not use_softmax:
        return prediction, one_logprob

    # Two-way softmax, shifted by the maximum for numerical stability.
    peak = max(zero_logprob, one_logprob)
    weight_zero = math.exp(zero_logprob - peak)
    weight_one = math.exp(one_logprob - peak)
    return prediction, weight_one / (weight_zero + weight_one)

def _base_id(instance_id):
    """Return the part of *instance_id* before its last '/' (the group key)."""
    return instance_id.rsplit('/', 1)[0]


def _load_checkpoint(path):
    """Read a JSONL checkpoint into a dict keyed by the full instance id."""
    saved = {}
    with open(path) as f:
        for line in f:
            if line.strip():
                record = json.loads(line)
                saved[record['id']] = record
    return saved


def _write_checkpoint(path, instance_groups, results):
    """Atomically rewrite the checkpoint with one JSON line per instance."""
    tmp_path = path + '.tmp'
    with open(tmp_path, 'w') as f:
        for base_id, instances in instance_groups.items():
            res = results[base_id]
            for slot, instance in enumerate(instances):
                data = {
                    'id': instance['id'],
                    'critiques': res['critiques'][slot],
                    'logprobs': res['logprobs'][slot],
                    'predictions': res['predictions'][slot],
                    'has_duplicate_steps': res['has_duplicate_steps'][slot]
                }
                f.write(json.dumps(data) + '\n')
    # Replace in one step so an interrupted run never leaves a half-written
    # checkpoint behind.
    os.replace(tmp_path, path)


def _select_best_index(candidates, use_softmax):
    """Pick the index of the best response among *candidates*.

    Args:
        candidates: List of ``(index, predictions, scores)`` tuples, one per
            eligible response. ``scores`` are probabilities when
            *use_softmax* is true and log-probabilities otherwise.
        use_softmax: Selects how scores are interpreted.

    Returns:
        The index of the response whose steps were all predicted correct and
        whose summed log-score is highest; failing that, the response whose
        worst step score is highest; failing that, 0.
    """
    # Sentinel that get_logprobs/main use for "no valid verdict" in each mode.
    invalid = 0.0 if use_softmax else float('-inf')

    perfect = []
    for index, predictions, scores in candidates:
        if not scores or any(s == invalid for s in scores):
            continue
        if all(p == 1 for p in predictions):
            # Probabilities must be logged before summing; log-probabilities
            # are already in log space (math.log on them would raise).
            if use_softmax:
                log_product = sum(math.log(s) for s in scores)
            else:
                log_product = sum(scores)
            perfect.append((index, log_product))
    if perfect:
        return max(perfect, key=lambda pair: pair[1])[0]

    best_index = -1
    best_min_score = float('-inf')
    for index, _, scores in candidates:
        # In softmax mode a single invalid step disqualifies the response.
        if use_softmax and any(s == 0.0 for s in scores):
            continue
        usable = [s for s in scores if s != float('-inf')]
        if not usable:
            continue
        min_score = min(usable)
        if min_score > best_min_score:
            best_min_score = min_score
            best_index = index

    return best_index if best_index != -1 else 0


def main():
    """Critique every solution step with the judge model and score best-of-N.

    For each selected config the script:

    1. groups dataset instances by the id prefix before the last '/';
    2. critiques the steps of every instance without duplicate steps, one
       step index at a time across all instances, retrying generations that
       contain no valid boxed verdict and checkpointing after each step;
    3. for N in (1, 2, 4, 8) marks the best of the first N responses of each
       group, prints the resulting accuracy and writes all responses to
       ``<output_dir>/<model_name>/<config>_results_N=<N>_score=<score>.json``.

    With ``--resume`` the per-instance progress stored in the checkpoint is
    restored and processing continues from the first unprocessed step of each
    instance.
    """
    all_configs = list(total_instances_dict)

    parser = argparse.ArgumentParser()
    parser.add_argument('--configs', type=str, nargs='+', default=None,
                        choices=all_configs)
    parser.add_argument('--model_path', type=str, required=True)
    parser.add_argument("--output_dir", type=str, default='./bon_new_outputs')
    parser.add_argument("--model_cache_dir", type=str, default='')
    parser.add_argument('--use_voting', action='store_true')
    parser.add_argument('--voting_n', type=int, default=8)
    parser.add_argument('--resume', action='store_true', help='Resume from checkpoint if exists')
    parser.add_argument('--use_softmax', action='store_true')
    args = parser.parse_args()

    # normpath drops a trailing slash, which would otherwise make the
    # basename (and thus the output sub-directory) empty.
    args.model_name = os.path.basename(os.path.normpath(args.model_path))

    toker = AutoTokenizer.from_pretrained(args.model_path)
    with open('./templates/process_critique_correctness.txt') as f:
        template = f.read().strip()

    llm = LLM(
        model=args.model_path, tokenizer=args.model_path,
        gpu_memory_utilization=0.9,
        tensor_parallel_size=torch.cuda.device_count(),
        enable_prefix_caching=True, swap_space=16,
        max_num_seqs=20,
        download_dir=args.model_cache_dir
    )

    # One SamplingParams per attempt. The first attempt keeps seed 84; later
    # attempts shift the seed so that a retry of the same prompt is not just a
    # replay of the sample that was already rejected.
    max_attempts = 3
    attempt_params = [
        SamplingParams(temperature=0.6, logprobs=20, max_tokens=32768, seed=84 + attempt)
        for attempt in range(max_attempts)
    ]

    # Score that get_logprobs returns when no valid boxed verdict was found.
    invalid_score = 0.0 if args.use_softmax else float('-inf')

    if args.configs is None:
        args.configs = all_configs

    output_dir = os.path.join(args.output_dir, args.model_name)
    os.makedirs(output_dir, exist_ok=True)

    dataset = load_dataset('prometheus-eval/filtered_bon_setting')

    for config in args.configs:
        checkpoint_file = os.path.join(output_dir, f'{config}_checkpoint.jsonl')

        # Per-instance progress from a previous run, keyed by full id.
        saved = {}
        if args.resume and os.path.exists(checkpoint_file):
            saved = _load_checkpoint(checkpoint_file)

        input_data = dataset[config]

        # Group instances by base id (before last /). An instance's position
        # inside its group is its slot in every per-group results list.
        instance_groups = defaultdict(list)
        for item in input_data:
            steps = item['steps']
            item['has_duplicate_steps'] = len(steps) != len(set(steps))
            instance_groups[_base_id(item['id'])].append(item)

        # Fresh results for this config, overlaid with checkpointed progress.
        results = {}
        for base_id, instances in instance_groups.items():
            res = {
                'critiques': [[] for _ in instances],
                'logprobs': [[] for _ in instances],
                'predictions': [[] for _ in instances],
                'has_duplicate_steps': [inst['has_duplicate_steps'] for inst in instances]
            }
            for slot, inst in enumerate(instances):
                record = saved.get(inst['id'])
                if record is None or inst['has_duplicate_steps']:
                    continue
                critiques = record.get('critiques', [])
                logprobs = record.get('logprobs', [])
                predictions = record.get('predictions', [])
                # Only trust records that are internally consistent and do
                # not claim more steps than the instance has.
                if len(critiques) == len(logprobs) == len(predictions) <= len(inst['steps']):
                    res['critiques'][slot] = list(critiques)
                    res['logprobs'][slot] = list(logprobs)
                    res['predictions'][slot] = list(predictions)
            results[base_id] = res

        # Instances with duplicate steps are never critiqued.
        active_instances = [
            (base_id, slot, inst)
            for base_id, instances in instance_groups.items()
            for slot, inst in enumerate(instances)
            if not inst['has_duplicate_steps']
        ]

        max_steps = max((len(inst['steps']) for _, _, inst in active_instances), default=0)

        # Process step by step across all instances
        for step_idx in range(max_steps):
            # Instances whose next unprocessed step is exactly step_idx.
            pending = []
            for base_id, slot, inst in active_instances:
                if step_idx >= len(inst['steps']):
                    continue
                if len(results[base_id]['critiques'][slot]) != step_idx:
                    continue
                prompt_ids, _ = prepare_input_boxed(template, [inst], toker, step_idx)
                if prompt_ids:
                    pending.append((base_id, slot, prompt_ids[0]))

            if not pending:
                continue

            # Generate for all pending instances at once; instances whose
            # output has no valid verdict are regenerated on the next attempt.
            last_invalid_text = {}
            for params in attempt_params:
                generations = llm.generate(prompt_token_ids=[ids for _, _, ids in pending],
                                           sampling_params=params)

                still_pending = []
                for (base_id, slot, ids), generation in zip(pending, generations):
                    prediction, logprob = get_logprobs(generation, toker, args.use_softmax)
                    text = generation.outputs[0].text

                    if logprob == invalid_score:
                        last_invalid_text[(base_id, slot)] = text
                        still_pending.append((base_id, slot, ids))
                    else:
                        results[base_id]['critiques'][slot].append(text)
                        results[base_id]['logprobs'][slot].append(logprob)
                        results[base_id]['predictions'][slot].append(prediction)

                pending = still_pending
                if not pending:
                    break

            # Whatever is still pending failed every attempt: record each
            # instance's own last critique with the invalid sentinel.
            for base_id, slot, _ in pending:
                results[base_id]['critiques'][slot].append(last_invalid_text[(base_id, slot)])
                results[base_id]['logprobs'][slot].append(invalid_score)
                results[base_id]['predictions'][slot].append(0)

            # Save checkpoint after processing this step for all instances
            _write_checkpoint(checkpoint_file, instance_groups, results)

        # Verify all instances have correct number of steps processed
        for base_id, slot, inst in active_instances:
            num_steps = len(inst['steps'])
            num_critiques = len(results[base_id]['critiques'][slot])
            assert num_critiques == num_steps, f"Instance {inst['id']} has {num_critiques} critiques for {num_steps} steps"

        # Process final results for different N values after all steps are done
        for N in [1, 2, 4, 8]:
            final_results = []
            for base_id, res in results.items():
                instances = instance_groups[base_id][:N]  # Only consider first N instances

                candidates = [
                    (idx, res['predictions'][idx], res['logprobs'][idx])
                    for idx, inst in enumerate(instances)
                    if not inst['has_duplicate_steps']
                ]
                best_idx = _select_best_index(candidates, args.use_softmax)

                # Save all responses with their metrics
                for idx, instance in enumerate(instances):
                    final_results.append({
                        'instance': instance,
                        'generated_critique': res['critiques'][idx],
                        'logprobs': res['logprobs'][idx],
                        'predictions': res['predictions'][idx],
                        'is_best_response': (idx == best_idx)
                    })

            # Count correct final answers from best responses
            correct_count = sum(1 for instance in final_results if instance['is_best_response'] and instance['instance'].get('final_answer_correct', False))
            # NOTE(review): 8 is presumably the number of responses per
            # problem, making this the problem count — confirm.
            total_instances = total_instances_dict[config]/8
            score = correct_count / total_instances if total_instances > 0 else 0
            print(f"{config} N={N} Score: {score:.4f} ({correct_count}/{total_instances})")

            # Save all results for this N value
            with open(os.path.join(output_dir, f'{config}_results_N={N}_score={str(score)}.json'), 'w') as f:
                json.dump(final_results, f, indent=4)

        if os.path.exists(checkpoint_file):
            os.remove(checkpoint_file)

# Run the evaluation only when executed as a script, not when imported.
if __name__ == '__main__':
    main()